From c5ca63b014244ac044460f8613dae528830d67f0 Mon Sep 17 00:00:00 2001 From: NiceAesth Date: Fri, 10 Mar 2023 15:35:41 +0200 Subject: [PATCH 1/2] feat: add close to clients; style: formatting pass --- examples/advanced.py | 34 ++++++------ examples/basic.py | 26 +++++----- pomice/applemusic/client.py | 60 +++++++++++++--------- pomice/enums.py | 10 ++-- pomice/events.py | 2 +- pomice/objects.py | 6 +-- pomice/player.py | 100 ++++++++++++++++++------------------ pomice/pool.py | 76 +++++++++++++-------------- pomice/queue.py | 14 ++--- pomice/spotify/client.py | 61 ++++++++++++++-------- pomice/spotify/objects.py | 2 +- pomice/utils.py | 6 +-- 12 files changed, 212 insertions(+), 185 deletions(-) diff --git a/examples/advanced.py b/examples/advanced.py index 2ae9a03..4f78ff1 100644 --- a/examples/advanced.py +++ b/examples/advanced.py @@ -18,7 +18,7 @@ class Player(pomice.Player): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + self.queue = pomice.Queue() self.controller: discord.Message = None # Set context here so we can send a now playing embed @@ -43,12 +43,12 @@ class Player(pomice.Player): if self.controller: with suppress(discord.HTTPException): await self.controller.delete() - + # Queue up the next track, else teardown the player try: track: pomice.Track = self.queue.get() - except pomice.QueueEmpty: + except pomice.QueueEmpty: return await self.teardown() await self.play(track) @@ -68,12 +68,12 @@ class Player(pomice.Player): with suppress((discord.HTTPException), (KeyError)): await self.destroy() if self.controller: - await self.controller.delete() + await self.controller.delete() async def set_context(self, ctx: commands.Context): """Set context for the player""" - self.context = ctx - self.dj = ctx.author + self.context = ctx + self.dj = ctx.author @@ -81,20 +81,20 @@ class Player(pomice.Player): class Music(commands.Cog): def __init__(self, bot: commands.Bot) -> None: self.bot = bot - + # In order to initialize a node, or really do anything in this library, # you need to make a node pool self.pomice = pomice.NodePool() # Start the node bot.loop.create_task(self.start_nodes()) - + async def start_nodes(self): # Waiting for the bot to get ready before connecting to nodes. await self.bot.wait_until_ready() - + # You can pass in Spotify credentials to enable Spotify querying. - # If you do not pass in valid Spotify credentials, Spotify querying will not work + # If you do not pass in valid Spotify credentials, Spotify querying will not work await self.pomice.create_node( bot=self.bot, host="127.0.0.1", @@ -128,7 +128,7 @@ class Music(commands.Cog): # we can just skip to the next track # Of course, you can modify this to do whatever you like - + @commands.Cog.listener() async def on_pomice_track_end(self, player: Player, track, _): await player.do_next() @@ -140,7 +140,7 @@ class Music(commands.Cog): @commands.Cog.listener() async def on_pomice_track_exception(self, player: Player, track, _): await player.do_next() - + @commands.command(aliases=['joi', 'j', 'summon', 'su', 'con', 'connect']) async def join(self, ctx: commands.Context, *, channel: discord.VoiceChannel = None) -> None: if not channel: @@ -165,14 +165,14 @@ class Music(commands.Cog): await player.destroy() await ctx.send("Player has left the channel.") - + @commands.command(aliases=['pla', 'p']) async def play(self, ctx: commands.Context, *, search: str) -> None: # Checks if the player is in the channel before we play anything if not (player := ctx.voice_client): await ctx.author.voice.channel.connect(cls=Player) player: Player = ctx.voice_client - await player.set_context(ctx=ctx) + await player.set_context(ctx=ctx) # If you search a keyword, Pomice will automagically search the result using YouTube # You can pass in "search_type=" as an argument to change the search type @@ -180,11 +180,11 @@ class Music(commands.Cog): # will search up any keyword results on YouTube Music # We will also set the context here to get special features, like a track.requester object - results = await player.get_tracks(search, ctx=ctx) - + results = await player.get_tracks(search, ctx=ctx) + if not results: return await ctx.send("No results were found for that search term", delete_after=7) - + if isinstance(results, pomice.Playlist): for track in results.tracks: player.queue.put(track) diff --git a/examples/basic.py b/examples/basic.py index 5eb33b3..7268870 100644 --- a/examples/basic.py +++ b/examples/basic.py @@ -9,22 +9,22 @@ class MyBot(commands.Bot): command_prefix="!", activity=discord.Activity(type=discord.ActivityType.listening, name="to music!") ) - + self.add_cog(Music(self)) self.loop.create_task(self.cogs["Music"].start_nodes()) async def on_ready(self) -> None: print("I'm online!") - - + + class Music(commands.Cog): def __init__(self, bot: commands.Bot) -> None: self.bot = bot - + # In order to initialize a node, or really do anything in this library, # you need to make a node pool self.pomice = pomice.NodePool() - + async def start_nodes(self): # You can pass in Spotify credentials to enable Spotify querying. # If you do not pass in valid Spotify credentials, Spotify querying will not work @@ -36,7 +36,7 @@ class Music(commands.Cog): identifier="MAIN" ) print(f"Node is ready!") - + @commands.command(aliases=["connect"]) async def join(self, ctx: commands.Context, *, channel: discord.VoiceChannel = None) -> None: if not channel: @@ -62,24 +62,24 @@ class Music(commands.Cog): await player.destroy() await ctx.send("Player has left the channel.") - + @commands.command(aliases=["p"]) async def play(self, ctx: commands.Context, *, search: str) -> None: # Checks if the player is in the channel before we play anything if not ctx.voice_client: - await ctx.invoke(self.join) + await ctx.invoke(self.join) - player: pomice.Player = ctx.voice_client + player: pomice.Player = ctx.voice_client # If you search a keyword, Pomice will automagically search the result using YouTube # You can pass in "search_type=" as an argument to change the search type # i.e: player.get_tracks("query", search_type=SearchType.ytmsearch) # will search up any keyword results on YouTube Music - results = await player.get_tracks(search) - + results = await player.get_tracks(search) + if not results: raise commands.CommandError("No results were found for that search term.") - + if isinstance(results, pomice.Playlist): await player.play(track=results.tracks[0]) else: @@ -124,6 +124,6 @@ class Music(commands.Cog): await player.stop() await ctx.send("Player has been stopped") - + bot = MyBot() bot.run("token") diff --git a/pomice/applemusic/client.py b/pomice/applemusic/client.py index 643076d..9023326 100644 --- a/pomice/applemusic/client.py +++ b/pomice/applemusic/client.py @@ -8,18 +8,19 @@ import base64 from datetime import datetime from .objects import * from .exceptions import * -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from ..pool import Node - -AM_URL_REGEX = re.compile(r"https?://music.apple.com/(?P[a-zA-Z]{2})/(?Palbum|playlist|song|artist)/(?P.+)/(?P[^?]+)") -AM_SINGLE_IN_ALBUM_REGEX = re.compile(r"https?://music.apple.com/(?P[a-zA-Z]{2})/(?Palbum|playlist|song|artist)/(?P.+)/(?P.+)(\?i=)(?P.+)") +AM_URL_REGEX = re.compile( + r"https?://music.apple.com/(?P[a-zA-Z]{2})/(?Palbum|playlist|song|artist)/(?P.+)/(?P[^?]+)" +) +AM_SINGLE_IN_ALBUM_REGEX = re.compile( + r"https?://music.apple.com/(?P[a-zA-Z]{2})/(?Palbum|playlist|song|artist)/(?P.+)/(?P.+)(\?i=)(?P.+)" +) AM_REQ_URL = "https://api.music.apple.com/v1/catalog/{country}/{type}s/{id}" AM_BASE_URL = "https://api.music.apple.com" + class Client: - """The base Apple Music client for Pomice. + """The base Apple Music client for Pomice. This will do all the heavy lifting of getting tracks from Apple Music and translating it to a valid Lavalink track. No client auth is required here. """ @@ -30,28 +31,30 @@ class Client: self.session: aiohttp.ClientSession = None self.headers = None - async def request_token(self): if not self.session: self.session = aiohttp.ClientSession() - async with self.session.get("https://music.apple.com/assets/index.919fe17f.js") as resp: + async with self.session.get( + "https://music.apple.com/assets/index.919fe17f.js" + ) as resp: if resp.status != 200: raise AppleMusicRequestException( f"Error while fetching results: {resp.status} {resp.reason}" ) text = await resp.text() - result = re.search("\"(eyJ.+?)\"", text).group(1) + result = re.search('"(eyJ.+?)"', text).group(1) self.token = result self.headers = { - 'Authorization': f"Bearer {result}", - 'Origin': 'https://apple.com', + "Authorization": f"Bearer {result}", + "Origin": "https://apple.com", } token_split = self.token.split(".")[1] - token_json = base64.b64decode(token_split + '=' * (-len(token_split) % 4)).decode() + token_json = base64.b64decode( + token_split + "=" * (-len(token_split) % 4) + ).decode() token_data = json.loads(token_json) self.expiry = datetime.fromtimestamp(token_data["exp"]) - async def search(self, query: str): if not self.token or datetime.utcnow() > self.expiry: @@ -72,7 +75,6 @@ class Client: request_url = AM_REQ_URL.format(country=country, type=type, id=id) else: request_url = AM_REQ_URL.format(country=country, type=type, id=id) - async with self.session.get(request_url, headers=self.headers) as resp: if resp.status != 200: @@ -83,15 +85,16 @@ class Client: data = data["data"][0] - if type == "song": return Song(data) - + elif type == "album": return Album(data) elif type == "artist": - async with self.session.get(f"{request_url}/view/top-songs", headers=self.headers) as resp: + async with self.session.get( + f"{request_url}/view/top-songs", headers=self.headers + ) as resp: if resp.status != 200: raise AppleMusicRequestException( f"Error while fetching results: {resp.status} {resp.reason}" @@ -101,20 +104,24 @@ class Client: return Artist(data, tracks=tracks) - else: + else: track_data: dict = data["relationships"]["tracks"] - + tracks = [Song(track) for track in track_data.get("data")] if not len(tracks): - raise AppleMusicRequestException("This playlist is empty and therefore cannot be queued.") + raise AppleMusicRequestException( + "This playlist is empty and therefore cannot be queued." + ) - if track_data.get("next"): + if track_data.get("next"): next_page_url = AM_BASE_URL + track_data.get("next") while next_page_url is not None: - async with self.session.get(next_page_url, headers=self.headers) as resp: + async with self.session.get( + next_page_url, headers=self.headers + ) as resp: if resp.status != 200: raise AppleMusicRequestException( f"Error while fetching results: {resp.status} {resp.reason}" @@ -128,6 +135,9 @@ class Client: else: next_page_url = None - + return Playlist(data, tracks) - return Playlist(data, tracks) \ No newline at end of file + async def close(self): + if self.session: + await self.session.close() + self.session = None diff --git a/pomice/enums.py b/pomice/enums.py index 1f42de8..ae94d25 100644 --- a/pomice/enums.py +++ b/pomice/enums.py @@ -89,7 +89,7 @@ class PlaylistType(Enum): class NodeAlgorithm(Enum): """ The enum for the different node algorithms in Pomice. - + The enums in this class are to only differentiate different methods, since the actual method is handled in the get_best_node() method. @@ -123,7 +123,7 @@ class LoopMode(Enum): # We don't have to define anything special for these, since these just serve as flags TRACK = "track" QUEUE = "queue" - + def __str__(self) -> str: return self.value @@ -135,16 +135,16 @@ class RouteStrategy(Enum): This feature is exclusively for the RoutePlanner class. If you are not using this feature, this class is not necessary. - RouteStrategy.ROTATE_ON_BAN specifies that the node is rotating IPs + RouteStrategy.ROTATE_ON_BAN specifies that the node is rotating IPs whenever they get banned by Youtube. RouteStrategy.LOAD_BALANCE specifies that the node is selecting random IPs to balance out requests between them. - RouteStrategy.NANO_SWITCH specifies that the node is switching + RouteStrategy.NANO_SWITCH specifies that the node is switching between IPs every CPU clock cycle. - RouteStrategy.ROTATING_NANO_SWITCH specifies that the node is switching + RouteStrategy.ROTATING_NANO_SWITCH specifies that the node is switching between IPs every CPU clock cycle and is rotating between IP blocks on ban. diff --git a/pomice/events.py b/pomice/events.py index 1f8f3d2..75da529 100644 --- a/pomice/events.py +++ b/pomice/events.py @@ -23,7 +23,7 @@ __all__ = ( class PomiceEvent: - """The base class for all events dispatched by a node. + """The base class for all events dispatched by a node. Every event must be formatted within your bot's code as a listener. i.e: If you want to listen for when a track starts, the event would be: ```py diff --git a/pomice/objects.py b/pomice/objects.py index 3d13c61..4e03376 100644 --- a/pomice/objects.py +++ b/pomice/objects.py @@ -34,7 +34,7 @@ class Track: self.timestamp: Optional[float] = timestamp if self.track_type == TrackType.SPOTIFY or self.track_type == TrackType.APPLE_MUSIC: - self.original: Optional[Track] = None + self.original: Optional[Track] = None else: self.original = self self._search_type: SearchType = search_type @@ -46,10 +46,10 @@ class Track: self.uri: str = info.get("uri") self.identifier: str = info.get("identifier") self.isrc: str = info.get("isrc") - + if self.uri: if info.get("thumbnail"): - self.thumbnail: str = info.get("thumbnail") + self.thumbnail: str = info.get("thumbnail") elif self.track_type == TrackType.SOUNDCLOUD: # ok so theres no feasible way of getting a Soundcloud image URL # so we're just gonna leave it blank for brevity diff --git a/pomice/player.py b/pomice/player.py index 37621df..c84a3ad 100644 --- a/pomice/player.py +++ b/pomice/player.py @@ -78,7 +78,7 @@ class Filters: def get_preload_filters(self): """Get all preloaded filters""" - return [f for f in self._filters if f.preload == True] + return [f for f in self._filters if f.preload == True] def get_all_payloads(self): """Returns a formatted dict of all the filter payloads""" @@ -127,10 +127,10 @@ class Player(VoiceProtocol): return self def __init__( - self, - client: Optional[Client] = None, - channel: Optional[VoiceChannel] = None, - *, + self, + client: Optional[Client] = None, + channel: Optional[VoiceChannel] = None, + *, node: Node = None ): self.client: Optional[Client] = client @@ -240,7 +240,7 @@ class Player(VoiceProtocol): async def _dispatch_voice_update(self, voice_data: Dict[str, Any]): if {"sessionId", "event"} != self._voice_state.keys(): return - + data = { "token": voice_data['event']['token'], "endpoint": voice_data['event']['endpoint'], @@ -248,9 +248,9 @@ class Player(VoiceProtocol): } await self._node.send( - method="PATCH", - path=self._player_endpoint_uri, - guild_id=self._guild.id, + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, data={"voice": data} ) @@ -302,15 +302,15 @@ class Player(VoiceProtocol): You can pass in a discord.py Context object to get a Context object on any track you search. - You may also pass in a List of filters + You may also pass in a List of filters to be applied to your track once it plays. """ return await self._node.get_tracks(query, ctx=ctx, search_type=search_type, filters=filters) async def get_recommendations( - self, - *, - track: Track, + self, + *, + track: Track, ctx: Optional[commands.Context] = None ) -> Union[List[Track], None]: """ @@ -329,9 +329,9 @@ class Player(VoiceProtocol): """Stops the currently playing track.""" self._current = None await self._node.send( - method="PATCH", - path=self._player_endpoint_uri, - guild_id=self._guild.id, + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, data={'encodedTrack': None} ) @@ -371,8 +371,8 @@ class Player(VoiceProtocol): # First lets try using the tracks ISRC, every track has one (hopefully) try: if not track.isrc: - # We have to bare raise here because theres no other way to skip this block feasibly - raise + # We have to bare raise here because theres no other way to skip this block feasibly + raise search: Track = (await self._node.get_tracks( f"{track._search_type}:{track.isrc}", ctx=track.ctx))[0] except Exception: @@ -389,7 +389,7 @@ class Player(VoiceProtocol): "encodedTrack": search.track_id, "position": str(start), "endTime": str(track.length) - } + } track.original = search track.track_id = search.track_id # Set track_id for later lavalink searches @@ -412,8 +412,8 @@ class Player(VoiceProtocol): await self.remove_filter(filter_tag=filter.tag) # Global filters take precedence over track filters - # So if no global filters are detected, lets apply any - # necessary track filters + # So if no global filters are detected, lets apply any + # necessary track filters # Check if theres no global filters and if the track has any filters # that need to be applied @@ -427,15 +427,15 @@ class Player(VoiceProtocol): # so now the end time cannot be zero. # If it isnt zero, it'll match the length of the track, # otherwise itll be set here: - + if end > 0: data["endTime"] = str(end) await self._node.send( - method="PATCH", - path=self._player_endpoint_uri, - guild_id=self._guild.id, - data=data, + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, + data=data, query=f"noReplace={ignore_if_playing}" ) @@ -449,9 +449,9 @@ class Player(VoiceProtocol): ) await self._node.send( - method="PATCH", - path=self._player_endpoint_uri, - guild_id=self._guild.id, + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, data={"position": position} ) return self._position @@ -459,9 +459,9 @@ class Player(VoiceProtocol): async def set_pause(self, pause: bool) -> bool: """Sets the pause state of the currently playing track.""" await self._node.send( - method="PATCH", - path=self._player_endpoint_uri, - guild_id=self._guild.id, + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, data={"paused": pause} ) self._paused = pause @@ -470,9 +470,9 @@ class Player(VoiceProtocol): async def set_volume(self, volume: int) -> int: """Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500.""" await self._node.send( - method="PATCH", - path=self._player_endpoint_uri, - guild_id=self._guild.id, + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, data={"volume": volume} ) self._volume = volume @@ -485,18 +485,18 @@ class Player(VoiceProtocol): (You must have a song playing in order for `fast_apply` to work.) """ - + self._filters.add_filter(filter=filter) payload = self._filters.get_all_payloads() await self._node.send( - method="PATCH", - path=self._player_endpoint_uri, - guild_id=self._guild.id, + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, data={"filters": payload} ) if fast_apply: await self.seek(self.position) - + return self._filters async def remove_filter(self, filter_tag: str, fast_apply: bool = False) -> Filter: @@ -506,18 +506,18 @@ class Player(VoiceProtocol): (You must have a song playing in order for `fast_apply` to work.) """ - + self._filters.remove_filter(filter_tag=filter_tag) payload = self._filters.get_all_payloads() await self._node.send( - method="PATCH", - path=self._player_endpoint_uri, - guild_id=self._guild.id, + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, data={"filters": payload} ) if fast_apply: await self.seek(self.position) - + return self._filters async def reset_filters(self, *, fast_apply: bool = False): @@ -534,14 +534,14 @@ class Player(VoiceProtocol): ) self._filters.reset_filters() await self._node.send( - method="PATCH", - path=self._player_endpoint_uri, - guild_id=self._guild.id, + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, data={"filters": {}} ) if fast_apply: await self.seek(self.position) - - + + diff --git a/pomice/pool.py b/pomice/pool.py index 3122755..8f72d31 100644 --- a/pomice/pool.py +++ b/pomice/pool.py @@ -12,7 +12,7 @@ from typing import Dict, List, Optional, TYPE_CHECKING, Union from urllib.parse import quote from . import ( - __version__, + __version__, spotify, applemusic ) @@ -40,8 +40,8 @@ if TYPE_CHECKING: __all__ = ('Node', 'NodePool') class Node: - """The base class for a node. - This node object represents a Lavalink node. + """The base class for a node. + This node object represents a Lavalink node. To enable Spotify searching, pass in a proper Spotify Client ID and Spotify Client Secret To enable Apple music, set the "apple_music" parameter to "True" """ @@ -74,10 +74,10 @@ class Node: self._heartbeat: int = heartbeat self._secure: bool = secure self.fallback: bool = fallback - - - self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}" + + + self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}" self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}" self._session: Optional[aiohttp.ClientSession] = session @@ -88,7 +88,7 @@ class Node: self._session_id: str = None self._available: bool = False self._version: str = None - + self._route_planner = RoutePlanner(self) self._headers = { @@ -196,8 +196,8 @@ class Node: if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING): retry = backoff.delay() await asyncio.sleep(retry) - if not self.is_connected: - self._loop.create_task(self.connect()) + if not self.is_connected: + self._loop.create_task(self.connect()) else: self._loop.create_task(self._handle_payload(msg.json())) @@ -223,12 +223,12 @@ class Node: await player._update_state(data) async def send( - self, + self, method: str, - path: str, - include_version: bool = True, - guild_id: Optional[Union[int, str]] = None, - query: Optional[str] = None, + path: str, + include_version: bool = True, + guild_id: Optional[Union[int, str]] = None, + query: Optional[str] = None, data: Optional[Union[dict, str]] = None, ignore_if_available: bool = False, ): @@ -253,10 +253,10 @@ class Node: if resp.content_type == "text/plain": return await resp.text() - + return await resp.json() - + def get_player(self, guild_id: int): """Takes a guild ID as a parameter. Returns a pomice Player object.""" @@ -278,24 +278,24 @@ class Node: "The Lavalink version you're using is incompatible. " "Lavalink version 3.7.0 or above is required to use this library." ) - + if version.endswith('-SNAPSHOT'): # we're just gonna assume all snapshot versions correlate with v4 self._version = 4 else: - self._version = version[:1] + self._version = version[:1] self._websocket = await self._session.ws_connect( f"{self._websocket_uri}/v{self._version}/websocket", - headers=self._headers, + headers=self._headers, heartbeat=self._heartbeat ) if not self._task: self._task = self._loop.create_task(self._listen()) - self._available = True + self._available = True return self except (aiohttp.ClientConnectorError, ConnectionRefusedError): @@ -322,11 +322,11 @@ class Node: await self._websocket.close() await self._session.close() if self._spotify_client: - await self._spotify_client.session.close() + await self._spotify_client.close() if self._apple_music_client: - await self._apple_music_client.session.close() - + await self._apple_music_client.close() + del self._pool._nodes[self._identifier] self.available = False self._task.cancel() @@ -362,11 +362,11 @@ class Node: You can pass in a discord.py Context object to get a Context object on any track you search. - You may also pass in a List of filters + You may also pass in a List of filters to be applied to your track once it plays. """ - timestamp = None + timestamp = None if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query): query = f"{search_type}:{query}" @@ -374,7 +374,7 @@ class Node: if filters: for filter in filters: filter.set_preload() - + if URLRegex.AM_URL.match(query): if not self._apple_music_client: raise AppleMusicNotEnabled( @@ -382,7 +382,7 @@ class Node: "Please set apple_music to True in your Node class." ) - apple_music_results = await self._apple_music_client.search(query=query) + apple_music_results = await self._apple_music_client.search(query=query) if isinstance(apple_music_results, applemusic.Song): return [ Track( @@ -501,7 +501,7 @@ class Node: ) elif discord_url := URLRegex.DISCORD_MP3_URL.match(query): - + data: dict = await self.send(method="GET", path="loadtracks", query=f"identifier={quote(query)}") track: dict = data["tracks"][0] @@ -533,9 +533,9 @@ class Node: # If query is a video thats part of a playlist, get the video and queue that instead # (I can't tell you how much i've wanted to implement this in here) - if (match := URLRegex.YOUTUBE_VID_IN_PLAYLIST.match(query)): + if (match := URLRegex.YOUTUBE_VID_IN_PLAYLIST.match(query)): query = match.group("video") - + data: dict = await self.send(method="GET", path="loadtracks", query=f"identifier={quote(query)}") load_type = data.get("loadType") @@ -577,14 +577,14 @@ class Node: ] async def get_recommendations( - self, - *, - track: Track, + self, + *, + track: Track, ctx: Optional[commands.Context] = None ) -> Union[List[Track], None]: """ Gets recommendations from either YouTube or Spotify. - The track that is passed in must be either from + The track that is passed in must be either from YouTube or Spotify or else this will not work. You can pass in a discord.py Context object to get a Context object on all tracks that get recommended. @@ -613,12 +613,12 @@ class Node: ] return tracks - elif track.track_type == TrackType.YOUTUBE: + elif track.track_type == TrackType.YOUTUBE: tracks = await self.get_tracks(query=f"ytsearch:https://www.youtube.com/watch?v={track.identifier}&list=RD{track.identifier}", ctx=ctx) return tracks else: raise TrackLoadError("The specfied track must be either a YouTube or Spotify track to recieve recommendations.") - + class NodePool: """The base class for the node pool. @@ -666,7 +666,7 @@ class NodePool: elif algorithm == NodeAlgorithm.by_players: tested_nodes = {node: len(node.players.keys()) for node in available_nodes} return min(tested_nodes, key=tested_nodes.get) - + @classmethod def get_node(cls, *, identifier: str = None) -> Node: @@ -714,7 +714,7 @@ class NodePool: node = Node( pool=cls, bot=bot, host=host, port=port, password=password, identifier=identifier, secure=secure, heartbeat=heartbeat, - loop=loop, spotify_client_id=spotify_client_id, + loop=loop, spotify_client_id=spotify_client_id, session=session, spotify_client_secret=spotify_client_secret, apple_music=apple_music, fallback=fallback ) diff --git a/pomice/queue.py b/pomice/queue.py index 72d18b1..97a1f8d 100644 --- a/pomice/queue.py +++ b/pomice/queue.py @@ -107,7 +107,7 @@ class Queue(Iterable[Track]): raise TypeError(f"Adding '{type(other)}' type to the queue is not supported.") - def _get(self) -> Track: + def _get(self) -> Track: return self._queue.pop(0) def _drop(self) -> Track: @@ -298,7 +298,7 @@ class Queue(Iterable[Track]): def set_loop_mode(self, mode: LoopMode) -> None: """ - Sets the loop mode of the queue. + Sets the loop mode of the queue. Takes the LoopMode enum as an argument. """ self._loop_mode = mode @@ -306,11 +306,11 @@ class Queue(Iterable[Track]): try: index = self._index(self._current_item) except ValueError: - index = 0 + index = 0 if self._current_item not in self._queue: self._queue.insert(index, self._current_item) self._current_item = self._queue[index] - + def disable_loop(self) -> None: """ @@ -320,12 +320,12 @@ class Queue(Iterable[Track]): if not self._loop_mode: raise QueueException("Queue loop is already disabled.") - if self._loop_mode == LoopMode.QUEUE: - index = self.find_position(self._current_item) + 1 + if self._loop_mode == LoopMode.QUEUE: + index = self.find_position(self._current_item) + 1 self._queue = self._queue[index:] self._loop_mode = None - + def shuffle(self) -> None: """Shuffles the queue.""" diff --git a/pomice/spotify/client.py b/pomice/spotify/client.py index f3669a2..94a5771 100644 --- a/pomice/spotify/client.py +++ b/pomice/spotify/client.py @@ -8,9 +8,7 @@ import orjson as json from base64 import b64encode from typing import TYPE_CHECKING from .exceptions import InvalidSpotifyURL, SpotifyRequestException -from .objects import * - - +from .objects import * GRANT_URL = "https://accounts.spotify.com/api/token" @@ -22,8 +20,8 @@ SPOTIFY_URL_REGEX = re.compile( class Client: """The base client for the Spotify module of Pomice. - This class will do all the heavy lifting of getting all the metadata - for any Spotify URL you throw at it. + This class will do all the heavy lifting of getting all the metadata + for any Spotify URL you throw at it. """ def __init__(self, client_id: str, client_secret: str) -> None: @@ -34,7 +32,9 @@ class Client: self._bearer_token: str = None self._expiry = 0 - self._auth_token = b64encode(f"{self._client_id}:{self._client_secret}".encode()) + self._auth_token = b64encode( + f"{self._client_id}:{self._client_secret}".encode() + ) self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"} self._bearer_headers = None @@ -44,7 +44,9 @@ class Client: if not self.session: self.session = aiohttp.ClientSession() - async with self.session.post(GRANT_URL, data=_data, headers=self._grant_headers) as resp: + async with self.session.post( + GRANT_URL, data=_data, headers=self._grant_headers + ) as resp: if resp.status != 200: raise SpotifyRequestException( f"Error fetching bearer token: {resp.status} {resp.reason}" @@ -82,28 +84,35 @@ class Client: elif spotify_type == "album": return Album(data) elif spotify_type == "artist": - async with self.session.get(f"{request_url}/top-tracks?market=US", headers=self._bearer_headers) as resp: - if resp.status != 200: - raise SpotifyRequestException( - f"Error while fetching results: {resp.status} {resp.reason}" - ) + async with self.session.get( + f"{request_url}/top-tracks?market=US", headers=self._bearer_headers + ) as resp: + if resp.status != 200: + raise SpotifyRequestException( + f"Error while fetching results: {resp.status} {resp.reason}" + ) - track_data: dict = await resp.json(loads=json.loads) - tracks = track_data['tracks'] - return Artist(data, tracks) + track_data: dict = await resp.json(loads=json.loads) + tracks = track_data["tracks"] + return Artist(data, tracks) else: tracks = [ Track(track["track"]) - for track in data["tracks"]["items"] if track["track"] is not None + for track in data["tracks"]["items"] + if track["track"] is not None ] if not len(tracks): - raise SpotifyRequestException("This playlist is empty and therefore cannot be queued.") - + raise SpotifyRequestException( + "This playlist is empty and therefore cannot be queued." + ) + next_page_url = data["tracks"]["next"] while next_page_url is not None: - async with self.session.get(next_page_url, headers=self._bearer_headers) as resp: + async with self.session.get( + next_page_url, headers=self._bearer_headers + ) as resp: if resp.status != 200: raise SpotifyRequestException( f"Error while fetching results: {resp.status} {resp.reason}" @@ -113,7 +122,8 @@ class Client: tracks += [ Track(track["track"]) - for track in next_data["items"] if track["track"] is not None + for track in next_data["items"] + if track["track"] is not None ] next_page_url = next_data["next"] @@ -133,7 +143,9 @@ class Client: if not spotify_type == "track": raise InvalidSpotifyURL("The provided query is not a Spotify track.") - request_url = REQUEST_URL.format(type="recommendation", id=f"?seed_tracks={spotify_id}") + request_url = REQUEST_URL.format( + type="recommendation", id=f"?seed_tracks={spotify_id}" + ) async with self.session.get(request_url, headers=self._bearer_headers) as resp: if resp.status != 200: @@ -145,4 +157,9 @@ class Client: tracks = [Track(track) for track in data["tracks"]] - return tracks \ No newline at end of file + return tracks + + async def close(self) -> None: + if self.session: + await self.session.close() + self.session = None diff --git a/pomice/spotify/objects.py b/pomice/spotify/objects.py index ec7b48c..d4cc652 100644 --- a/pomice/spotify/objects.py +++ b/pomice/spotify/objects.py @@ -43,7 +43,7 @@ class Playlist: if data.get("images") and len(data["images"]): self.image: str = data["images"][0]["url"] else: - self.image = self.tracks[0].image + self.image = self.tracks[0].image self.uri = data["external_urls"]["spotify"] def __repr__(self) -> str: diff --git a/pomice/utils.py b/pomice/utils.py index 31f4a3a..9c7c1ea 100644 --- a/pomice/utils.py +++ b/pomice/utils.py @@ -93,7 +93,7 @@ class NodeStats: class FailingIPBlock: """ The base class for the failing IP block object from the route planner stats. - Gives critical information about any failing addresses on the block + Gives critical information about any failing addresses on the block and the time they failed. """ def __init__(self, data: dict) -> None: @@ -102,7 +102,7 @@ class FailingIPBlock: def __repr__(self) -> str: return f"" - + class RouteStats: """ @@ -182,7 +182,7 @@ class Ping: def get_ping(self): s = self._create_socket(socket.AF_INET, socket.SOCK_STREAM) - + cost_time = self.timer.cost( (s.connect, s.shutdown), ((self._host, self._port), None)) From 4564e89b4e0acdf484689d86315a8c941c0ceca6 Mon Sep 17 00:00:00 2001 From: NiceAesth Date: Fri, 10 Mar 2023 15:57:33 +0200 Subject: [PATCH 2/2] feat: add `disconnect()` to node pool --- pomice/pool.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pomice/pool.py b/pomice/pool.py index 8f72d31..cabb5e2 100644 --- a/pomice/pool.py +++ b/pomice/pool.py @@ -721,4 +721,9 @@ class NodePool: await node.connect() cls._nodes[node._identifier] = node - return node \ No newline at end of file + return node + + async def disconnect(self) -> None: + """Disconnects all nodes from the node pool.""" + for node in self._nodes.copy().values(): + await node.disconnect() \ No newline at end of file