fix a couple of outstanding bugs
This commit is contained in:
parent
9d831d3ecd
commit
de7385d8ff
|
|
@ -3,8 +3,9 @@ Pomice
|
|||
~~~~~~
|
||||
The modern Lavalink wrapper designed for discord.py.
|
||||
|
||||
:copyright: 2023, cloudwithax
|
||||
:license: GPL-3.0
|
||||
Copyright (c) 2023, cloudwithax
|
||||
|
||||
Licensed under GPL-3.0
|
||||
"""
|
||||
import discord
|
||||
|
||||
|
|
@ -18,9 +19,11 @@ if not discord.version_info.major >= 2:
|
|||
"using 'pip install discord.py'"
|
||||
)
|
||||
|
||||
__version__ = "2.1.1"
|
||||
__version__ = "2.2a"
|
||||
__title__ = "pomice"
|
||||
__author__ = "cloudwithax"
|
||||
__license__ = "GPL-3.0"
|
||||
__copyright__ = "Copyright (c) 2023, cloudwithax"
|
||||
|
||||
from .enums import *
|
||||
from .events import *
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import aiohttp
|
||||
import orjson as json
|
||||
|
|
@ -6,6 +8,10 @@ import base64
|
|||
from datetime import datetime
|
||||
from .objects import *
|
||||
from .exceptions import *
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..pool import Node
|
||||
|
||||
AM_URL_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)")
|
||||
AM_SINGLE_IN_ALBUM_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)")
|
||||
|
|
@ -18,10 +24,11 @@ class Client:
|
|||
and translating it to a valid Lavalink track. No client auth is required here.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, node: Node) -> None:
|
||||
self.token: str = None
|
||||
self.expiry: datetime = None
|
||||
self.session: aiohttp.ClientSession = aiohttp.ClientSession()
|
||||
self.node = node
|
||||
self.session = self.node._session
|
||||
self.headers = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -42,11 +42,11 @@ class TrackType(Enum):
|
|||
"""
|
||||
|
||||
# We don't have to define anything special for these, since these just serve as flags
|
||||
YOUTUBE = "youtube_track"
|
||||
SOUNDCLOUD = "soundcloud_track"
|
||||
SPOTIFY = "spotify_track"
|
||||
APPLE_MUSIC = "apple_music_track"
|
||||
HTTP = "http_source"
|
||||
YOUTUBE = "youtube"
|
||||
SOUNDCLOUD = "soundcloud"
|
||||
SPOTIFY = "spotify"
|
||||
APPLE_MUSIC = "apple_music"
|
||||
HTTP = "http"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
|
@ -65,10 +65,10 @@ class PlaylistType(Enum):
|
|||
"""
|
||||
|
||||
# We don't have to define anything special for these, since these just serve as flags
|
||||
YOUTUBE = "youtube_playlist"
|
||||
SOUNDCLOUD = "soundcloud_playlist"
|
||||
SPOTIFY = "spotify_playlist"
|
||||
APPLE_MUSIC = "apple_music_list"
|
||||
YOUTUBE = "youtube"
|
||||
SOUNDCLOUD = "soundcloud"
|
||||
SPOTIFY = "spotify"
|
||||
APPLE_MUSIC = "apple_music"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
|
@ -114,27 +114,6 @@ class LoopMode(Enum):
|
|||
QUEUE = "queue"
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
class PlatformRecommendation(Enum):
|
||||
|
||||
"""
|
||||
The enum for choosing what platform you want for recommendations.
|
||||
This feature is exclusively for the recommendations function.
|
||||
If you are not using this feature, this class is not necessary.
|
||||
|
||||
PlatformRecommendation.SPOTIFY sets the recommendations to come from Spotify
|
||||
|
||||
PlatformRecommendation.YOUTUBE sets the recommendations to come from YouTube
|
||||
|
||||
"""
|
||||
|
||||
# We don't have to define anything special for these, since these just serve as flags
|
||||
SPOTIFY = "spotify"
|
||||
YOUTUBE = "youtube"
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
|
|
|||
|
|
@ -7,10 +7,6 @@ from discord.ext import commands
|
|||
from .enums import SearchType, TrackType, PlaylistType
|
||||
from .filters import Filter
|
||||
|
||||
from . import (
|
||||
spotify,
|
||||
applemusic
|
||||
)
|
||||
|
||||
|
||||
class Track:
|
||||
|
|
|
|||
|
|
@ -16,11 +16,11 @@ from discord import (
|
|||
from discord.ext import commands
|
||||
|
||||
from . import events
|
||||
from .enums import SearchType, PlatformRecommendation
|
||||
from .enums import SearchType
|
||||
from .events import PomiceEvent, TrackEndEvent, TrackStartEvent
|
||||
from .exceptions import FilterInvalidArgument, FilterTagAlreadyInUse, FilterTagInvalid, TrackInvalidPosition, TrackLoadError
|
||||
from .filters import Filter
|
||||
from .objects import Track, Playlist
|
||||
from .objects import Track
|
||||
from .pool import Node, NodePool
|
||||
|
||||
class Filters:
|
||||
|
|
@ -111,24 +111,24 @@ class Player(VoiceProtocol):
|
|||
*,
|
||||
node: Node = None
|
||||
):
|
||||
self.client = client
|
||||
self.client: Optional[Client] = client
|
||||
self._bot: Union[Client, commands.Bot] = client
|
||||
self.channel = channel
|
||||
self._guild = channel.guild if channel else None
|
||||
self.channel: Optional[VoiceChannel] = channel
|
||||
self._guild: Guild = channel.guild if channel else None
|
||||
|
||||
self._node = node if node else NodePool.get_node()
|
||||
self._node: Node = node if node else NodePool.get_node()
|
||||
self._current: Track = None
|
||||
self._filters: Filters = Filters()
|
||||
self._volume = 100
|
||||
self._paused = False
|
||||
self._is_connected = False
|
||||
self._volume: int = 100
|
||||
self._paused: bool = False
|
||||
self._is_connected: bool = False
|
||||
|
||||
self._position = 0
|
||||
self._last_position = 0
|
||||
self._last_update = 0
|
||||
self._position: int = 0
|
||||
self._last_position: int = 0
|
||||
self._last_update: int = 0
|
||||
self._ending_track: Optional[Track] = None
|
||||
|
||||
self._voice_state = {}
|
||||
self._voice_state: dict = {}
|
||||
|
||||
self._player_endpoint_uri = f'sessions/{self._node._session_id}/players'
|
||||
|
||||
|
|
@ -211,7 +211,7 @@ class Player(VoiceProtocol):
|
|||
|
||||
async def _update_state(self, data: dict):
|
||||
state: dict = data.get("state")
|
||||
self._last_update = time.time() * 1000
|
||||
self._last_update = int(state.get("time")) * 1000
|
||||
self._is_connected = state.get("connected")
|
||||
self._last_position = state.get("position")
|
||||
|
||||
|
|
@ -366,7 +366,7 @@ class Player(VoiceProtocol):
|
|||
data = {
|
||||
"encodedTrack": search.track_id,
|
||||
"position": str(start),
|
||||
"endTime": str(end)
|
||||
"endTime": str(track.length)
|
||||
}
|
||||
track.original = search
|
||||
track.track_id = search.track_id
|
||||
|
|
@ -375,7 +375,7 @@ class Player(VoiceProtocol):
|
|||
data = {
|
||||
"encodedTrack": track.track_id,
|
||||
"position": str(start),
|
||||
"endTime": str(end)
|
||||
"endTime": str(track.length)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -400,7 +400,12 @@ class Player(VoiceProtocol):
|
|||
# Now apply all filters
|
||||
for filter in track.filters:
|
||||
await self.add_filter(filter=filter)
|
||||
|
||||
|
||||
# Lavalink v4 changed the way the end time parameter works
|
||||
# so now the end time cannot be zero.
|
||||
# If it isnt zero, it'll match the length of the track,
|
||||
# otherwise itll be set here:
|
||||
|
||||
if end > 0:
|
||||
data["endTime"] = str(end)
|
||||
|
||||
|
|
|
|||
131
pomice/pool.py
131
pomice/pool.py
|
|
@ -3,13 +3,13 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import random
|
||||
import re
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
import logging
|
||||
import aiohttp
|
||||
|
||||
from discord import Client
|
||||
from discord.ext import commands
|
||||
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
from . import (
|
||||
__version__,
|
||||
|
|
@ -37,6 +37,8 @@ from .routeplanner import RoutePlanner
|
|||
if TYPE_CHECKING:
|
||||
from .player import Player
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Node:
|
||||
"""The base class for a node.
|
||||
|
|
@ -48,7 +50,7 @@ class Node:
|
|||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pool,
|
||||
pool: NodePool,
|
||||
bot: Union[Client, commands.Bot],
|
||||
host: str,
|
||||
port: int,
|
||||
|
|
@ -59,7 +61,8 @@ class Node:
|
|||
session: Optional[aiohttp.ClientSession] = None,
|
||||
spotify_client_id: Optional[str] = None,
|
||||
spotify_client_secret: Optional[str] = None,
|
||||
apple_music: bool = False
|
||||
apple_music: bool = False,
|
||||
fallback: bool = False
|
||||
|
||||
):
|
||||
self._bot = bot
|
||||
|
|
@ -70,6 +73,7 @@ class Node:
|
|||
self._identifier = identifier
|
||||
self._heartbeat = heartbeat
|
||||
self._secure = secure
|
||||
self.fallback = fallback
|
||||
|
||||
|
||||
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}/v3/websocket"
|
||||
|
|
@ -100,11 +104,11 @@ class Node:
|
|||
|
||||
if self._spotify_client_id and self._spotify_client_secret:
|
||||
self._spotify_client = spotify.Client(
|
||||
self._spotify_client_id, self._spotify_client_secret
|
||||
self, self._spotify_client_id, self._spotify_client_secret
|
||||
)
|
||||
|
||||
if apple_music:
|
||||
self._apple_music_client = applemusic.Client()
|
||||
self._apple_music_client = applemusic.Client(self)
|
||||
|
||||
self._bot.add_listener(self._update_handler, "on_socket_response")
|
||||
|
||||
|
|
@ -165,6 +169,7 @@ class Node:
|
|||
|
||||
if data["t"] == "VOICE_SERVER_UPDATE":
|
||||
guild_id = int(data["d"]["guild_id"])
|
||||
_log.debug(f"Recieved voice server update message from guild ID: {guild_id}")
|
||||
try:
|
||||
player = self._players[guild_id]
|
||||
await player.on_voice_server_update(data["d"])
|
||||
|
|
@ -176,6 +181,7 @@ class Node:
|
|||
return
|
||||
|
||||
guild_id = int(data["d"]["guild_id"])
|
||||
_log.debug(f"Recieved voice state update message from guild ID: {guild_id}")
|
||||
try:
|
||||
player = self._players[guild_id]
|
||||
await player.on_voice_state_update(data["d"])
|
||||
|
|
@ -187,13 +193,13 @@ class Node:
|
|||
|
||||
while True:
|
||||
msg = await self._websocket.receive()
|
||||
if msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
||||
retry = backoff.delay()
|
||||
await asyncio.sleep(retry)
|
||||
if not self.is_connected:
|
||||
self._bot.loop.create_task(self.connect())
|
||||
if not self.is_connected:
|
||||
asyncio.create_task(self.connect())
|
||||
else:
|
||||
self._bot.loop.create_task(self._handle_payload(msg.json()))
|
||||
asyncio.create_task(self._handle_payload(msg.json()))
|
||||
|
||||
async def _handle_payload(self, data: dict):
|
||||
op = data.get("op", None)
|
||||
|
|
@ -216,32 +222,6 @@ class Node:
|
|||
elif op == "playerUpdate":
|
||||
await player._update_state(data)
|
||||
|
||||
def _get_type(self, query: str):
|
||||
if match := URLRegex.LAVALINK_SEARCH.match(query):
|
||||
type = match.group("type")
|
||||
if type == "sc":
|
||||
return TrackType.SOUNDCLOUD
|
||||
|
||||
return TrackType.YOUTUBE
|
||||
|
||||
|
||||
elif URLRegex.YOUTUBE_URL.match(query):
|
||||
if URLRegex.YOUTUBE_PLAYLIST_URL.match(query):
|
||||
return PlaylistType.YOUTUBE
|
||||
|
||||
return TrackType.YOUTUBE
|
||||
|
||||
elif URLRegex.SOUNDCLOUD_URL.match(query):
|
||||
if URLRegex.SOUNDCLOUD_TRACK_IN_SET_URL.match(query):
|
||||
return TrackType.SOUNDCLOUD
|
||||
if URLRegex.SOUNDCLOUD_PLAYLIST_URL.match(query):
|
||||
return PlaylistType.SOUNDCLOUD
|
||||
|
||||
return TrackType.SOUNDCLOUD
|
||||
|
||||
else:
|
||||
return TrackType.HTTP
|
||||
|
||||
async def send(
|
||||
self,
|
||||
method: str,
|
||||
|
|
@ -249,9 +229,10 @@ class Node:
|
|||
include_version: bool = True,
|
||||
guild_id: Optional[Union[int, str]] = None,
|
||||
query: Optional[str] = None,
|
||||
data: Optional[Union[dict, str]] = None
|
||||
data: Optional[Union[dict, str]] = None,
|
||||
ignore_if_available: bool = False,
|
||||
):
|
||||
if not self._available:
|
||||
if not ignore_if_available and not self._available:
|
||||
raise NodeNotAvailable(
|
||||
f"The node '{self._identifier}' is unavailable."
|
||||
)
|
||||
|
|
@ -264,7 +245,8 @@ class Node:
|
|||
|
||||
async with self._session.request(method=method, url=uri, headers=self._headers, json=data or {}) as resp:
|
||||
if resp.status >= 300:
|
||||
raise NodeRestException(f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}')
|
||||
data: dict = await resp.json()
|
||||
raise NodeRestException(f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {data["message"]}')
|
||||
|
||||
if method == "DELETE" or resp.status == 204:
|
||||
return await resp.json(content_type=None)
|
||||
|
|
@ -285,34 +267,42 @@ class Node:
|
|||
await self._bot.wait_until_ready()
|
||||
|
||||
try:
|
||||
version = await self.send(method="GET", path="version", ignore_if_available=True, include_version=False)
|
||||
version = version.replace(".", "")
|
||||
if not version.endswith('-SNAPSHOT') and int(version) < 370:
|
||||
self._available = False
|
||||
raise LavalinkVersionIncompatible(
|
||||
"The Lavalink version you're using is incompatible. "
|
||||
"Lavalink version 3.7.0 or above is required to use this library."
|
||||
)
|
||||
|
||||
self._websocket = await self._session.ws_connect(
|
||||
self._websocket_uri, headers=self._headers, heartbeat=self._heartbeat
|
||||
)
|
||||
self._task = self._bot.loop.create_task(self._listen())
|
||||
self._available = True
|
||||
version = await self.send(method="GET", path="version", include_version=False)
|
||||
version = version.replace(".", "")
|
||||
if int(version) < 370:
|
||||
raise LavalinkVersionIncompatible(
|
||||
"The Lavalink version you're using is incompatible."
|
||||
"Lavalink version 3.7.0 or above is required to use this library."
|
||||
)
|
||||
|
||||
self._version = version[:1]
|
||||
if not self._task:
|
||||
self._task = asyncio.create_task(self._listen())
|
||||
|
||||
self._available = True
|
||||
if version.endswith('-SNAPSHOT'):
|
||||
# we're just gonna assume all snapshot versions correlate with v4
|
||||
self._version = 4
|
||||
else:
|
||||
self._version = version[:1]
|
||||
return self
|
||||
|
||||
except aiohttp.ClientConnectorError:
|
||||
except (aiohttp.ClientConnectorError, ConnectionRefusedError):
|
||||
raise NodeConnectionFailure(
|
||||
f"The connection to node '{self._identifier}' failed."
|
||||
)
|
||||
) from None
|
||||
except aiohttp.WSServerHandshakeError:
|
||||
raise NodeConnectionFailure(
|
||||
f"The password for node '{self._identifier}' is invalid."
|
||||
)
|
||||
) from None
|
||||
except aiohttp.InvalidURL:
|
||||
raise NodeConnectionFailure(
|
||||
f"The URI for node '{self._identifier}' is invalid."
|
||||
)
|
||||
) from None
|
||||
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnects a connected Lavalink node and removes it from the node pool.
|
||||
|
|
@ -322,7 +312,8 @@ class Node:
|
|||
await player.destroy()
|
||||
|
||||
await self._websocket.close()
|
||||
del self._pool.nodes[self._identifier]
|
||||
await self._session.close()
|
||||
del self._pool._nodes[self._identifier]
|
||||
self.available = False
|
||||
self._task.cancel()
|
||||
|
||||
|
|
@ -338,8 +329,6 @@ class Node:
|
|||
Context object on the track it builds.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
data: dict = await self.send(method="GET", path="decodetrack", query=f"encodedTrack={identifier}")
|
||||
return Track(track_id=identifier, ctx=ctx, info=data)
|
||||
|
||||
|
|
@ -537,8 +526,6 @@ class Node:
|
|||
|
||||
load_type = data.get("loadType")
|
||||
|
||||
query_type = self._get_type(query)
|
||||
|
||||
if not load_type:
|
||||
raise TrackLoadError("There was an error while trying to load this track.")
|
||||
|
||||
|
|
@ -550,19 +537,14 @@ class Node:
|
|||
return None
|
||||
|
||||
elif load_type == "PLAYLIST_LOADED":
|
||||
if query_type == PlaylistType.SOUNDCLOUD:
|
||||
track_type = TrackType.SOUNDCLOUD
|
||||
else:
|
||||
track_type = TrackType.YOUTUBE
|
||||
|
||||
tracks = [
|
||||
Track(track_id=track["track"], info=track["info"], ctx=ctx, track_type=track_type)
|
||||
Track(track_id=track["encoded"], info=track["info"], ctx=ctx, track_type=TrackType(track["info"]["sourceName"]))
|
||||
for track in data["tracks"]
|
||||
]
|
||||
return Playlist(
|
||||
playlist_info=data["playlistInfo"],
|
||||
tracks=tracks,
|
||||
playlist_type=query_type,
|
||||
playlist_type=PlaylistType(tracks[0].track_type.value),
|
||||
thumbnail=tracks[0].thumbnail,
|
||||
uri=query
|
||||
)
|
||||
|
|
@ -570,10 +552,10 @@ class Node:
|
|||
elif load_type == "SEARCH_RESULT" or load_type == "TRACK_LOADED":
|
||||
return [
|
||||
Track(
|
||||
track_id=track["track"],
|
||||
track_id=track["encoded"],
|
||||
info=track["info"],
|
||||
ctx=ctx,
|
||||
track_type=query_type,
|
||||
track_type=TrackType(track["info"]["sourceName"]),
|
||||
filters=filters,
|
||||
timestamp=timestamp
|
||||
)
|
||||
|
|
@ -629,7 +611,7 @@ class NodePool:
|
|||
This holds all the nodes that are to be used by the bot.
|
||||
"""
|
||||
|
||||
_nodes: dict = {}
|
||||
_nodes: Dict[str, Node] = {}
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Pomice.NodePool node_count={self.node_count}>"
|
||||
|
|
@ -658,7 +640,7 @@ class NodePool:
|
|||
based on how players it has. This method will return a node with
|
||||
the least amount of players
|
||||
"""
|
||||
available_nodes = [node for node in cls._nodes.values() if node._available]
|
||||
available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available]
|
||||
|
||||
if not available_nodes:
|
||||
raise NoNodesAvailable("There are no nodes available.")
|
||||
|
|
@ -704,7 +686,8 @@ class NodePool:
|
|||
spotify_client_id: Optional[str] = None,
|
||||
spotify_client_secret: Optional[str] = None,
|
||||
session: Optional[aiohttp.ClientSession] = None,
|
||||
apple_music: bool = False
|
||||
apple_music: bool = False,
|
||||
fallback: bool = False
|
||||
|
||||
) -> Node:
|
||||
"""Creates a Node object to be then added into the node pool.
|
||||
|
|
@ -718,7 +701,7 @@ class NodePool:
|
|||
identifier=identifier, secure=secure, heartbeat=heartbeat,
|
||||
spotify_client_id=spotify_client_id,
|
||||
session=session, spotify_client_secret=spotify_client_secret,
|
||||
apple_music=apple_music
|
||||
apple_music=apple_music, fallback=fallback
|
||||
)
|
||||
|
||||
await node.connect()
|
||||
|
|
|
|||
|
|
@ -1,14 +1,18 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import time
|
||||
from base64 import b64encode
|
||||
|
||||
import aiohttp
|
||||
import orjson as json
|
||||
|
||||
|
||||
from base64 import b64encode
|
||||
from typing import TYPE_CHECKING
|
||||
from .exceptions import InvalidSpotifyURL, SpotifyRequestException
|
||||
from .objects import *
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..pool import Node
|
||||
|
||||
|
||||
GRANT_URL = "https://accounts.spotify.com/api/token"
|
||||
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
|
||||
SPOTIFY_URL_REGEX = re.compile(
|
||||
|
|
@ -22,11 +26,12 @@ class Client:
|
|||
for any Spotify URL you throw at it.
|
||||
"""
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str) -> None:
|
||||
def __init__(self, node: Node, client_id: str, client_secret: str) -> None:
|
||||
self._client_id = client_id
|
||||
self._client_secret = client_secret
|
||||
self.node = node
|
||||
|
||||
self.session = aiohttp.ClientSession()
|
||||
self.session = self.node._session
|
||||
|
||||
self._bearer_token: str = None
|
||||
self._expiry = 0
|
||||
|
|
|
|||
|
|
@ -120,8 +120,6 @@ class RouteStats:
|
|||
self.block_index = details.get("blockIndex")
|
||||
self.address_index = details.get("currentAddressIndex")
|
||||
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.RouteStats route_strategy={self.strategy!r} failing_addresses={len(self.failing_addresses)}>"
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue