diff --git a/pomice/__init__.py b/pomice/__init__.py index d0fd69c..0d7a0f1 100644 --- a/pomice/__init__.py +++ b/pomice/__init__.py @@ -3,8 +3,9 @@ Pomice ~~~~~~ The modern Lavalink wrapper designed for discord.py. -:copyright: 2023, cloudwithax -:license: GPL-3.0 +Copyright (c) 2023, cloudwithax + +Licensed under GPL-3.0 """ import discord @@ -18,9 +19,11 @@ if not discord.version_info.major >= 2: "using 'pip install discord.py'" ) -__version__ = "2.1.1" +__version__ = "2.2a" __title__ = "pomice" __author__ = "cloudwithax" +__license__ = "GPL-3.0" +__copyright__ = "Copyright (c) 2023, cloudwithax" from .enums import * from .events import * diff --git a/pomice/applemusic/client.py b/pomice/applemusic/client.py index e0b00b3..655f9af 100644 --- a/pomice/applemusic/client.py +++ b/pomice/applemusic/client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re import aiohttp import orjson as json @@ -6,6 +8,10 @@ import base64 from datetime import datetime from .objects import * from .exceptions import * +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..pool import Node AM_URL_REGEX = re.compile(r"https?://music.apple.com/(?P[a-zA-Z]{2})/(?Palbum|playlist|song|artist)/(?P.+)/(?P[^?]+)") AM_SINGLE_IN_ALBUM_REGEX = re.compile(r"https?://music.apple.com/(?P[a-zA-Z]{2})/(?Palbum|playlist|song|artist)/(?P.+)/(?P.+)(\?i=)(?P.+)") @@ -18,10 +24,11 @@ class Client: and translating it to a valid Lavalink track. No client auth is required here. """ - def __init__(self) -> None: + def __init__(self, node: Node) -> None: self.token: str = None self.expiry: datetime = None - self.session: aiohttp.ClientSession = aiohttp.ClientSession() + self.node = node + self.session = self.node._session self.headers = None diff --git a/pomice/enums.py b/pomice/enums.py index 54217b3..309fc14 100644 --- a/pomice/enums.py +++ b/pomice/enums.py @@ -42,11 +42,11 @@ class TrackType(Enum): """ # We don't have to define anything special for these, since these just serve as flags - YOUTUBE = "youtube_track" - SOUNDCLOUD = "soundcloud_track" - SPOTIFY = "spotify_track" - APPLE_MUSIC = "apple_music_track" - HTTP = "http_source" + YOUTUBE = "youtube" + SOUNDCLOUD = "soundcloud" + SPOTIFY = "spotify" + APPLE_MUSIC = "apple_music" + HTTP = "http" def __str__(self) -> str: return self.value @@ -65,10 +65,10 @@ class PlaylistType(Enum): """ # We don't have to define anything special for these, since these just serve as flags - YOUTUBE = "youtube_playlist" - SOUNDCLOUD = "soundcloud_playlist" - SPOTIFY = "spotify_playlist" - APPLE_MUSIC = "apple_music_list" + YOUTUBE = "youtube" + SOUNDCLOUD = "soundcloud" + SPOTIFY = "spotify" + APPLE_MUSIC = "apple_music" def __str__(self) -> str: return self.value @@ -114,27 +114,6 @@ class LoopMode(Enum): QUEUE = "queue" - def __str__(self) -> str: - return self.value - -class PlatformRecommendation(Enum): - - """ - The enum for choosing what platform you want for recommendations. - This feature is exclusively for the recommendations function. - If you are not using this feature, this class is not necessary. - - PlatformRecommendation.SPOTIFY sets the recommendations to come from Spotify - - PlatformRecommendation.YOUTUBE sets the recommendations to come from YouTube - - """ - - # We don't have to define anything special for these, since these just serve as flags - SPOTIFY = "spotify" - YOUTUBE = "youtube" - - def __str__(self) -> str: return self.value diff --git a/pomice/objects.py b/pomice/objects.py index 78fea40..7cfe269 100644 --- a/pomice/objects.py +++ b/pomice/objects.py @@ -7,10 +7,6 @@ from discord.ext import commands from .enums import SearchType, TrackType, PlaylistType from .filters import Filter -from . import ( - spotify, - applemusic -) class Track: diff --git a/pomice/player.py b/pomice/player.py index d9b28e2..71626e4 100644 --- a/pomice/player.py +++ b/pomice/player.py @@ -16,11 +16,11 @@ from discord import ( from discord.ext import commands from . import events -from .enums import SearchType, PlatformRecommendation +from .enums import SearchType from .events import PomiceEvent, TrackEndEvent, TrackStartEvent from .exceptions import FilterInvalidArgument, FilterTagAlreadyInUse, FilterTagInvalid, TrackInvalidPosition, TrackLoadError from .filters import Filter -from .objects import Track, Playlist +from .objects import Track from .pool import Node, NodePool class Filters: @@ -111,24 +111,24 @@ class Player(VoiceProtocol): *, node: Node = None ): - self.client = client + self.client: Optional[Client] = client self._bot: Union[Client, commands.Bot] = client - self.channel = channel - self._guild = channel.guild if channel else None + self.channel: Optional[VoiceChannel] = channel + self._guild: Guild = channel.guild if channel else None - self._node = node if node else NodePool.get_node() + self._node: Node = node if node else NodePool.get_node() self._current: Track = None self._filters: Filters = Filters() - self._volume = 100 - self._paused = False - self._is_connected = False + self._volume: int = 100 + self._paused: bool = False + self._is_connected: bool = False - self._position = 0 - self._last_position = 0 - self._last_update = 0 + self._position: int = 0 + self._last_position: int = 0 + self._last_update: int = 0 self._ending_track: Optional[Track] = None - self._voice_state = {} + self._voice_state: dict = {} self._player_endpoint_uri = f'sessions/{self._node._session_id}/players' @@ -211,7 +211,7 @@ class Player(VoiceProtocol): async def _update_state(self, data: dict): state: dict = data.get("state") - self._last_update = time.time() * 1000 + self._last_update = int(state.get("time")) * 1000 self._is_connected = state.get("connected") self._last_position = state.get("position") @@ -366,7 +366,7 @@ class Player(VoiceProtocol): data = { "encodedTrack": search.track_id, "position": str(start), - "endTime": str(end) + "endTime": str(track.length) } track.original = search track.track_id = search.track_id @@ -375,7 +375,7 @@ class Player(VoiceProtocol): data = { "encodedTrack": track.track_id, "position": str(start), - "endTime": str(end) + "endTime": str(track.length) } @@ -400,7 +400,12 @@ class Player(VoiceProtocol): # Now apply all filters for filter in track.filters: await self.add_filter(filter=filter) - + + # Lavalink v4 changed the way the end time parameter works + # so now the end time cannot be zero. + # If it isnt zero, it'll match the length of the track, + # otherwise itll be set here: + if end > 0: data["endTime"] = str(end) diff --git a/pomice/pool.py b/pomice/pool.py index 67d8bcc..480b625 100644 --- a/pomice/pool.py +++ b/pomice/pool.py @@ -3,13 +3,13 @@ from __future__ import annotations import asyncio import random import re -from typing import Dict, List, Optional, TYPE_CHECKING, Union -from urllib.parse import quote - +import logging import aiohttp + from discord import Client from discord.ext import commands - +from typing import Dict, List, Optional, TYPE_CHECKING, Union +from urllib.parse import quote from . import ( __version__, @@ -37,6 +37,8 @@ from .routeplanner import RoutePlanner if TYPE_CHECKING: from .player import Player +_log = logging.getLogger(__name__) + class Node: """The base class for a node. @@ -48,7 +50,7 @@ class Node: def __init__( self, *, - pool, + pool: NodePool, bot: Union[Client, commands.Bot], host: str, port: int, @@ -59,7 +61,8 @@ class Node: session: Optional[aiohttp.ClientSession] = None, spotify_client_id: Optional[str] = None, spotify_client_secret: Optional[str] = None, - apple_music: bool = False + apple_music: bool = False, + fallback: bool = False ): self._bot = bot @@ -70,6 +73,7 @@ class Node: self._identifier = identifier self._heartbeat = heartbeat self._secure = secure + self.fallback = fallback self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}/v3/websocket" @@ -100,11 +104,11 @@ class Node: if self._spotify_client_id and self._spotify_client_secret: self._spotify_client = spotify.Client( - self._spotify_client_id, self._spotify_client_secret + self, self._spotify_client_id, self._spotify_client_secret ) if apple_music: - self._apple_music_client = applemusic.Client() + self._apple_music_client = applemusic.Client(self) self._bot.add_listener(self._update_handler, "on_socket_response") @@ -165,6 +169,7 @@ class Node: if data["t"] == "VOICE_SERVER_UPDATE": guild_id = int(data["d"]["guild_id"]) + _log.debug(f"Recieved voice server update message from guild ID: {guild_id}") try: player = self._players[guild_id] await player.on_voice_server_update(data["d"]) @@ -176,6 +181,7 @@ class Node: return guild_id = int(data["d"]["guild_id"]) + _log.debug(f"Recieved voice state update message from guild ID: {guild_id}") try: player = self._players[guild_id] await player.on_voice_state_update(data["d"]) @@ -187,13 +193,13 @@ class Node: while True: msg = await self._websocket.receive() - if msg.type == aiohttp.WSMsgType.CLOSED: + if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING): retry = backoff.delay() await asyncio.sleep(retry) - if not self.is_connected: - self._bot.loop.create_task(self.connect()) + if not self.is_connected: + asyncio.create_task(self.connect()) else: - self._bot.loop.create_task(self._handle_payload(msg.json())) + asyncio.create_task(self._handle_payload(msg.json())) async def _handle_payload(self, data: dict): op = data.get("op", None) @@ -216,32 +222,6 @@ class Node: elif op == "playerUpdate": await player._update_state(data) - def _get_type(self, query: str): - if match := URLRegex.LAVALINK_SEARCH.match(query): - type = match.group("type") - if type == "sc": - return TrackType.SOUNDCLOUD - - return TrackType.YOUTUBE - - - elif URLRegex.YOUTUBE_URL.match(query): - if URLRegex.YOUTUBE_PLAYLIST_URL.match(query): - return PlaylistType.YOUTUBE - - return TrackType.YOUTUBE - - elif URLRegex.SOUNDCLOUD_URL.match(query): - if URLRegex.SOUNDCLOUD_TRACK_IN_SET_URL.match(query): - return TrackType.SOUNDCLOUD - if URLRegex.SOUNDCLOUD_PLAYLIST_URL.match(query): - return PlaylistType.SOUNDCLOUD - - return TrackType.SOUNDCLOUD - - else: - return TrackType.HTTP - async def send( self, method: str, @@ -249,9 +229,10 @@ class Node: include_version: bool = True, guild_id: Optional[Union[int, str]] = None, query: Optional[str] = None, - data: Optional[Union[dict, str]] = None + data: Optional[Union[dict, str]] = None, + ignore_if_available: bool = False, ): - if not self._available: + if not ignore_if_available and not self._available: raise NodeNotAvailable( f"The node '{self._identifier}' is unavailable." ) @@ -264,7 +245,8 @@ class Node: async with self._session.request(method=method, url=uri, headers=self._headers, json=data or {}) as resp: if resp.status >= 300: - raise NodeRestException(f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}') + data: dict = await resp.json() + raise NodeRestException(f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {data["message"]}') if method == "DELETE" or resp.status == 204: return await resp.json(content_type=None) @@ -285,34 +267,42 @@ class Node: await self._bot.wait_until_ready() try: + version = await self.send(method="GET", path="version", ignore_if_available=True, include_version=False) + version = version.replace(".", "") + if not version.endswith('-SNAPSHOT') and int(version) < 370: + self._available = False + raise LavalinkVersionIncompatible( + "The Lavalink version you're using is incompatible. " + "Lavalink version 3.7.0 or above is required to use this library." + ) + self._websocket = await self._session.ws_connect( self._websocket_uri, headers=self._headers, heartbeat=self._heartbeat ) - self._task = self._bot.loop.create_task(self._listen()) - self._available = True - version = await self.send(method="GET", path="version", include_version=False) - version = version.replace(".", "") - if int(version) < 370: - raise LavalinkVersionIncompatible( - "The Lavalink version you're using is incompatible." - "Lavalink version 3.7.0 or above is required to use this library." - ) - - self._version = version[:1] + if not self._task: + self._task = asyncio.create_task(self._listen()) + + self._available = True + if version.endswith('-SNAPSHOT'): + # we're just gonna assume all snapshot versions correlate with v4 + self._version = 4 + else: + self._version = version[:1] return self - except aiohttp.ClientConnectorError: + except (aiohttp.ClientConnectorError, ConnectionRefusedError): raise NodeConnectionFailure( f"The connection to node '{self._identifier}' failed." - ) + ) from None except aiohttp.WSServerHandshakeError: raise NodeConnectionFailure( f"The password for node '{self._identifier}' is invalid." - ) + ) from None except aiohttp.InvalidURL: raise NodeConnectionFailure( f"The URI for node '{self._identifier}' is invalid." - ) + ) from None + async def disconnect(self): """Disconnects a connected Lavalink node and removes it from the node pool. @@ -322,7 +312,8 @@ class Node: await player.destroy() await self._websocket.close() - del self._pool.nodes[self._identifier] + await self._session.close() + del self._pool._nodes[self._identifier] self.available = False self._task.cancel() @@ -338,8 +329,6 @@ class Node: Context object on the track it builds. """ - - data: dict = await self.send(method="GET", path="decodetrack", query=f"encodedTrack={identifier}") return Track(track_id=identifier, ctx=ctx, info=data) @@ -537,8 +526,6 @@ class Node: load_type = data.get("loadType") - query_type = self._get_type(query) - if not load_type: raise TrackLoadError("There was an error while trying to load this track.") @@ -550,19 +537,14 @@ class Node: return None elif load_type == "PLAYLIST_LOADED": - if query_type == PlaylistType.SOUNDCLOUD: - track_type = TrackType.SOUNDCLOUD - else: - track_type = TrackType.YOUTUBE - tracks = [ - Track(track_id=track["track"], info=track["info"], ctx=ctx, track_type=track_type) + Track(track_id=track["encoded"], info=track["info"], ctx=ctx, track_type=TrackType(track["info"]["sourceName"])) for track in data["tracks"] ] return Playlist( playlist_info=data["playlistInfo"], tracks=tracks, - playlist_type=query_type, + playlist_type=PlaylistType(tracks[0].track_type.value), thumbnail=tracks[0].thumbnail, uri=query ) @@ -570,10 +552,10 @@ class Node: elif load_type == "SEARCH_RESULT" or load_type == "TRACK_LOADED": return [ Track( - track_id=track["track"], + track_id=track["encoded"], info=track["info"], ctx=ctx, - track_type=query_type, + track_type=TrackType(track["info"]["sourceName"]), filters=filters, timestamp=timestamp ) @@ -629,7 +611,7 @@ class NodePool: This holds all the nodes that are to be used by the bot. """ - _nodes: dict = {} + _nodes: Dict[str, Node] = {} def __repr__(self): return f"" @@ -658,7 +640,7 @@ class NodePool: based on how players it has. This method will return a node with the least amount of players """ - available_nodes = [node for node in cls._nodes.values() if node._available] + available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available] if not available_nodes: raise NoNodesAvailable("There are no nodes available.") @@ -704,7 +686,8 @@ class NodePool: spotify_client_id: Optional[str] = None, spotify_client_secret: Optional[str] = None, session: Optional[aiohttp.ClientSession] = None, - apple_music: bool = False + apple_music: bool = False, + fallback: bool = False ) -> Node: """Creates a Node object to be then added into the node pool. @@ -718,7 +701,7 @@ class NodePool: identifier=identifier, secure=secure, heartbeat=heartbeat, spotify_client_id=spotify_client_id, session=session, spotify_client_secret=spotify_client_secret, - apple_music=apple_music + apple_music=apple_music, fallback=fallback ) await node.connect() diff --git a/pomice/spotify/client.py b/pomice/spotify/client.py index 460c3e3..063c262 100644 --- a/pomice/spotify/client.py +++ b/pomice/spotify/client.py @@ -1,14 +1,18 @@ +from __future__ import annotations + import re import time -from base64 import b64encode - -import aiohttp import orjson as json - +from base64 import b64encode +from typing import TYPE_CHECKING from .exceptions import InvalidSpotifyURL, SpotifyRequestException from .objects import * +if TYPE_CHECKING: + from ..pool import Node + + GRANT_URL = "https://accounts.spotify.com/api/token" REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}" SPOTIFY_URL_REGEX = re.compile( @@ -22,11 +26,12 @@ class Client: for any Spotify URL you throw at it. """ - def __init__(self, client_id: str, client_secret: str) -> None: + def __init__(self, node: Node, client_id: str, client_secret: str) -> None: self._client_id = client_id self._client_secret = client_secret + self.node = node - self.session = aiohttp.ClientSession() + self.session = self.node._session self._bearer_token: str = None self._expiry = 0 diff --git a/pomice/utils.py b/pomice/utils.py index 947b086..36486a9 100644 --- a/pomice/utils.py +++ b/pomice/utils.py @@ -120,8 +120,6 @@ class RouteStats: self.block_index = details.get("blockIndex") self.address_index = details.get("currentAddressIndex") - - def __repr__(self) -> str: return f""