feat: add close to clients; style: formatting pass

This commit is contained in:
NiceAesth 2023-03-10 15:35:41 +02:00
parent 458b686769
commit c5ca63b014
12 changed files with 212 additions and 185 deletions

View File

@ -18,7 +18,7 @@ class Player(pomice.Player):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.queue = pomice.Queue() self.queue = pomice.Queue()
self.controller: discord.Message = None self.controller: discord.Message = None
# Set context here so we can send a now playing embed # Set context here so we can send a now playing embed
@ -43,12 +43,12 @@ class Player(pomice.Player):
if self.controller: if self.controller:
with suppress(discord.HTTPException): with suppress(discord.HTTPException):
await self.controller.delete() await self.controller.delete()
# Queue up the next track, else teardown the player # Queue up the next track, else teardown the player
try: try:
track: pomice.Track = self.queue.get() track: pomice.Track = self.queue.get()
except pomice.QueueEmpty: except pomice.QueueEmpty:
return await self.teardown() return await self.teardown()
await self.play(track) await self.play(track)
@ -68,12 +68,12 @@ class Player(pomice.Player):
with suppress((discord.HTTPException), (KeyError)): with suppress((discord.HTTPException), (KeyError)):
await self.destroy() await self.destroy()
if self.controller: if self.controller:
await self.controller.delete() await self.controller.delete()
async def set_context(self, ctx: commands.Context): async def set_context(self, ctx: commands.Context):
"""Set context for the player""" """Set context for the player"""
self.context = ctx self.context = ctx
self.dj = ctx.author self.dj = ctx.author
@ -81,20 +81,20 @@ class Player(pomice.Player):
class Music(commands.Cog): class Music(commands.Cog):
def __init__(self, bot: commands.Bot) -> None: def __init__(self, bot: commands.Bot) -> None:
self.bot = bot self.bot = bot
# In order to initialize a node, or really do anything in this library, # In order to initialize a node, or really do anything in this library,
# you need to make a node pool # you need to make a node pool
self.pomice = pomice.NodePool() self.pomice = pomice.NodePool()
# Start the node # Start the node
bot.loop.create_task(self.start_nodes()) bot.loop.create_task(self.start_nodes())
async def start_nodes(self): async def start_nodes(self):
# Waiting for the bot to get ready before connecting to nodes. # Waiting for the bot to get ready before connecting to nodes.
await self.bot.wait_until_ready() await self.bot.wait_until_ready()
# You can pass in Spotify credentials to enable Spotify querying. # You can pass in Spotify credentials to enable Spotify querying.
# If you do not pass in valid Spotify credentials, Spotify querying will not work # If you do not pass in valid Spotify credentials, Spotify querying will not work
await self.pomice.create_node( await self.pomice.create_node(
bot=self.bot, bot=self.bot,
host="127.0.0.1", host="127.0.0.1",
@ -128,7 +128,7 @@ class Music(commands.Cog):
# we can just skip to the next track # we can just skip to the next track
# Of course, you can modify this to do whatever you like # Of course, you can modify this to do whatever you like
@commands.Cog.listener() @commands.Cog.listener()
async def on_pomice_track_end(self, player: Player, track, _): async def on_pomice_track_end(self, player: Player, track, _):
await player.do_next() await player.do_next()
@ -140,7 +140,7 @@ class Music(commands.Cog):
@commands.Cog.listener() @commands.Cog.listener()
async def on_pomice_track_exception(self, player: Player, track, _): async def on_pomice_track_exception(self, player: Player, track, _):
await player.do_next() await player.do_next()
@commands.command(aliases=['joi', 'j', 'summon', 'su', 'con', 'connect']) @commands.command(aliases=['joi', 'j', 'summon', 'su', 'con', 'connect'])
async def join(self, ctx: commands.Context, *, channel: discord.VoiceChannel = None) -> None: async def join(self, ctx: commands.Context, *, channel: discord.VoiceChannel = None) -> None:
if not channel: if not channel:
@ -165,14 +165,14 @@ class Music(commands.Cog):
await player.destroy() await player.destroy()
await ctx.send("Player has left the channel.") await ctx.send("Player has left the channel.")
@commands.command(aliases=['pla', 'p']) @commands.command(aliases=['pla', 'p'])
async def play(self, ctx: commands.Context, *, search: str) -> None: async def play(self, ctx: commands.Context, *, search: str) -> None:
# Checks if the player is in the channel before we play anything # Checks if the player is in the channel before we play anything
if not (player := ctx.voice_client): if not (player := ctx.voice_client):
await ctx.author.voice.channel.connect(cls=Player) await ctx.author.voice.channel.connect(cls=Player)
player: Player = ctx.voice_client player: Player = ctx.voice_client
await player.set_context(ctx=ctx) await player.set_context(ctx=ctx)
# If you search a keyword, Pomice will automagically search the result using YouTube # If you search a keyword, Pomice will automagically search the result using YouTube
# You can pass in "search_type=" as an argument to change the search type # You can pass in "search_type=" as an argument to change the search type
@ -180,11 +180,11 @@ class Music(commands.Cog):
# will search up any keyword results on YouTube Music # will search up any keyword results on YouTube Music
# We will also set the context here to get special features, like a track.requester object # We will also set the context here to get special features, like a track.requester object
results = await player.get_tracks(search, ctx=ctx) results = await player.get_tracks(search, ctx=ctx)
if not results: if not results:
return await ctx.send("No results were found for that search term", delete_after=7) return await ctx.send("No results were found for that search term", delete_after=7)
if isinstance(results, pomice.Playlist): if isinstance(results, pomice.Playlist):
for track in results.tracks: for track in results.tracks:
player.queue.put(track) player.queue.put(track)

View File

@ -9,22 +9,22 @@ class MyBot(commands.Bot):
command_prefix="!", command_prefix="!",
activity=discord.Activity(type=discord.ActivityType.listening, name="to music!") activity=discord.Activity(type=discord.ActivityType.listening, name="to music!")
) )
self.add_cog(Music(self)) self.add_cog(Music(self))
self.loop.create_task(self.cogs["Music"].start_nodes()) self.loop.create_task(self.cogs["Music"].start_nodes())
async def on_ready(self) -> None: async def on_ready(self) -> None:
print("I'm online!") print("I'm online!")
class Music(commands.Cog): class Music(commands.Cog):
def __init__(self, bot: commands.Bot) -> None: def __init__(self, bot: commands.Bot) -> None:
self.bot = bot self.bot = bot
# In order to initialize a node, or really do anything in this library, # In order to initialize a node, or really do anything in this library,
# you need to make a node pool # you need to make a node pool
self.pomice = pomice.NodePool() self.pomice = pomice.NodePool()
async def start_nodes(self): async def start_nodes(self):
# You can pass in Spotify credentials to enable Spotify querying. # You can pass in Spotify credentials to enable Spotify querying.
# If you do not pass in valid Spotify credentials, Spotify querying will not work # If you do not pass in valid Spotify credentials, Spotify querying will not work
@ -36,7 +36,7 @@ class Music(commands.Cog):
identifier="MAIN" identifier="MAIN"
) )
print(f"Node is ready!") print(f"Node is ready!")
@commands.command(aliases=["connect"]) @commands.command(aliases=["connect"])
async def join(self, ctx: commands.Context, *, channel: discord.VoiceChannel = None) -> None: async def join(self, ctx: commands.Context, *, channel: discord.VoiceChannel = None) -> None:
if not channel: if not channel:
@ -62,24 +62,24 @@ class Music(commands.Cog):
await player.destroy() await player.destroy()
await ctx.send("Player has left the channel.") await ctx.send("Player has left the channel.")
@commands.command(aliases=["p"]) @commands.command(aliases=["p"])
async def play(self, ctx: commands.Context, *, search: str) -> None: async def play(self, ctx: commands.Context, *, search: str) -> None:
# Checks if the player is in the channel before we play anything # Checks if the player is in the channel before we play anything
if not ctx.voice_client: if not ctx.voice_client:
await ctx.invoke(self.join) await ctx.invoke(self.join)
player: pomice.Player = ctx.voice_client player: pomice.Player = ctx.voice_client
# If you search a keyword, Pomice will automagically search the result using YouTube # If you search a keyword, Pomice will automagically search the result using YouTube
# You can pass in "search_type=" as an argument to change the search type # You can pass in "search_type=" as an argument to change the search type
# i.e: player.get_tracks("query", search_type=SearchType.ytmsearch) # i.e: player.get_tracks("query", search_type=SearchType.ytmsearch)
# will search up any keyword results on YouTube Music # will search up any keyword results on YouTube Music
results = await player.get_tracks(search) results = await player.get_tracks(search)
if not results: if not results:
raise commands.CommandError("No results were found for that search term.") raise commands.CommandError("No results were found for that search term.")
if isinstance(results, pomice.Playlist): if isinstance(results, pomice.Playlist):
await player.play(track=results.tracks[0]) await player.play(track=results.tracks[0])
else: else:
@ -124,6 +124,6 @@ class Music(commands.Cog):
await player.stop() await player.stop()
await ctx.send("Player has been stopped") await ctx.send("Player has been stopped")
bot = MyBot() bot = MyBot()
bot.run("token") bot.run("token")

View File

@ -8,18 +8,19 @@ import base64
from datetime import datetime from datetime import datetime
from .objects import * from .objects import *
from .exceptions import * from .exceptions import *
from typing import TYPE_CHECKING
if TYPE_CHECKING: AM_URL_REGEX = re.compile(
from ..pool import Node r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)"
)
AM_URL_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)") AM_SINGLE_IN_ALBUM_REGEX = re.compile(
AM_SINGLE_IN_ALBUM_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)") r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)"
)
AM_REQ_URL = "https://api.music.apple.com/v1/catalog/{country}/{type}s/{id}" AM_REQ_URL = "https://api.music.apple.com/v1/catalog/{country}/{type}s/{id}"
AM_BASE_URL = "https://api.music.apple.com" AM_BASE_URL = "https://api.music.apple.com"
class Client: class Client:
"""The base Apple Music client for Pomice. """The base Apple Music client for Pomice.
This will do all the heavy lifting of getting tracks from Apple Music This will do all the heavy lifting of getting tracks from Apple Music
and translating it to a valid Lavalink track. No client auth is required here. and translating it to a valid Lavalink track. No client auth is required here.
""" """
@ -30,28 +31,30 @@ class Client:
self.session: aiohttp.ClientSession = None self.session: aiohttp.ClientSession = None
self.headers = None self.headers = None
async def request_token(self): async def request_token(self):
if not self.session: if not self.session:
self.session = aiohttp.ClientSession() self.session = aiohttp.ClientSession()
async with self.session.get("https://music.apple.com/assets/index.919fe17f.js") as resp: async with self.session.get(
"https://music.apple.com/assets/index.919fe17f.js"
) as resp:
if resp.status != 200: if resp.status != 200:
raise AppleMusicRequestException( raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}" f"Error while fetching results: {resp.status} {resp.reason}"
) )
text = await resp.text() text = await resp.text()
result = re.search("\"(eyJ.+?)\"", text).group(1) result = re.search('"(eyJ.+?)"', text).group(1)
self.token = result self.token = result
self.headers = { self.headers = {
'Authorization': f"Bearer {result}", "Authorization": f"Bearer {result}",
'Origin': 'https://apple.com', "Origin": "https://apple.com",
} }
token_split = self.token.split(".")[1] token_split = self.token.split(".")[1]
token_json = base64.b64decode(token_split + '=' * (-len(token_split) % 4)).decode() token_json = base64.b64decode(
token_split + "=" * (-len(token_split) % 4)
).decode()
token_data = json.loads(token_json) token_data = json.loads(token_json)
self.expiry = datetime.fromtimestamp(token_data["exp"]) self.expiry = datetime.fromtimestamp(token_data["exp"])
async def search(self, query: str): async def search(self, query: str):
if not self.token or datetime.utcnow() > self.expiry: if not self.token or datetime.utcnow() > self.expiry:
@ -72,7 +75,6 @@ class Client:
request_url = AM_REQ_URL.format(country=country, type=type, id=id) request_url = AM_REQ_URL.format(country=country, type=type, id=id)
else: else:
request_url = AM_REQ_URL.format(country=country, type=type, id=id) request_url = AM_REQ_URL.format(country=country, type=type, id=id)
async with self.session.get(request_url, headers=self.headers) as resp: async with self.session.get(request_url, headers=self.headers) as resp:
if resp.status != 200: if resp.status != 200:
@ -83,15 +85,16 @@ class Client:
data = data["data"][0] data = data["data"][0]
if type == "song": if type == "song":
return Song(data) return Song(data)
elif type == "album": elif type == "album":
return Album(data) return Album(data)
elif type == "artist": elif type == "artist":
async with self.session.get(f"{request_url}/view/top-songs", headers=self.headers) as resp: async with self.session.get(
f"{request_url}/view/top-songs", headers=self.headers
) as resp:
if resp.status != 200: if resp.status != 200:
raise AppleMusicRequestException( raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}" f"Error while fetching results: {resp.status} {resp.reason}"
@ -101,20 +104,24 @@ class Client:
return Artist(data, tracks=tracks) return Artist(data, tracks=tracks)
else: else:
track_data: dict = data["relationships"]["tracks"] track_data: dict = data["relationships"]["tracks"]
tracks = [Song(track) for track in track_data.get("data")] tracks = [Song(track) for track in track_data.get("data")]
if not len(tracks): if not len(tracks):
raise AppleMusicRequestException("This playlist is empty and therefore cannot be queued.") raise AppleMusicRequestException(
"This playlist is empty and therefore cannot be queued."
)
if track_data.get("next"): if track_data.get("next"):
next_page_url = AM_BASE_URL + track_data.get("next") next_page_url = AM_BASE_URL + track_data.get("next")
while next_page_url is not None: while next_page_url is not None:
async with self.session.get(next_page_url, headers=self.headers) as resp: async with self.session.get(
next_page_url, headers=self.headers
) as resp:
if resp.status != 200: if resp.status != 200:
raise AppleMusicRequestException( raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}" f"Error while fetching results: {resp.status} {resp.reason}"
@ -128,6 +135,9 @@ class Client:
else: else:
next_page_url = None next_page_url = None
return Playlist(data, tracks)
return Playlist(data, tracks) async def close(self):
if self.session:
await self.session.close()
self.session = None

View File

@ -89,7 +89,7 @@ class PlaylistType(Enum):
class NodeAlgorithm(Enum): class NodeAlgorithm(Enum):
""" """
The enum for the different node algorithms in Pomice. The enum for the different node algorithms in Pomice.
The enums in this class are to only differentiate different The enums in this class are to only differentiate different
methods, since the actual method is handled in the methods, since the actual method is handled in the
get_best_node() method. get_best_node() method.
@ -123,7 +123,7 @@ class LoopMode(Enum):
# We don't have to define anything special for these, since these just serve as flags # We don't have to define anything special for these, since these just serve as flags
TRACK = "track" TRACK = "track"
QUEUE = "queue" QUEUE = "queue"
def __str__(self) -> str: def __str__(self) -> str:
return self.value return self.value
@ -135,16 +135,16 @@ class RouteStrategy(Enum):
This feature is exclusively for the RoutePlanner class. This feature is exclusively for the RoutePlanner class.
If you are not using this feature, this class is not necessary. If you are not using this feature, this class is not necessary.
RouteStrategy.ROTATE_ON_BAN specifies that the node is rotating IPs RouteStrategy.ROTATE_ON_BAN specifies that the node is rotating IPs
whenever they get banned by Youtube. whenever they get banned by Youtube.
RouteStrategy.LOAD_BALANCE specifies that the node is selecting RouteStrategy.LOAD_BALANCE specifies that the node is selecting
random IPs to balance out requests between them. random IPs to balance out requests between them.
RouteStrategy.NANO_SWITCH specifies that the node is switching RouteStrategy.NANO_SWITCH specifies that the node is switching
between IPs every CPU clock cycle. between IPs every CPU clock cycle.
RouteStrategy.ROTATING_NANO_SWITCH specifies that the node is switching RouteStrategy.ROTATING_NANO_SWITCH specifies that the node is switching
between IPs every CPU clock cycle and is rotating between IP blocks on between IPs every CPU clock cycle and is rotating between IP blocks on
ban. ban.

View File

@ -23,7 +23,7 @@ __all__ = (
class PomiceEvent: class PomiceEvent:
"""The base class for all events dispatched by a node. """The base class for all events dispatched by a node.
Every event must be formatted within your bot's code as a listener. Every event must be formatted within your bot's code as a listener.
i.e: If you want to listen for when a track starts, the event would be: i.e: If you want to listen for when a track starts, the event would be:
```py ```py

View File

@ -34,7 +34,7 @@ class Track:
self.timestamp: Optional[float] = timestamp self.timestamp: Optional[float] = timestamp
if self.track_type == TrackType.SPOTIFY or self.track_type == TrackType.APPLE_MUSIC: if self.track_type == TrackType.SPOTIFY or self.track_type == TrackType.APPLE_MUSIC:
self.original: Optional[Track] = None self.original: Optional[Track] = None
else: else:
self.original = self self.original = self
self._search_type: SearchType = search_type self._search_type: SearchType = search_type
@ -46,10 +46,10 @@ class Track:
self.uri: str = info.get("uri") self.uri: str = info.get("uri")
self.identifier: str = info.get("identifier") self.identifier: str = info.get("identifier")
self.isrc: str = info.get("isrc") self.isrc: str = info.get("isrc")
if self.uri: if self.uri:
if info.get("thumbnail"): if info.get("thumbnail"):
self.thumbnail: str = info.get("thumbnail") self.thumbnail: str = info.get("thumbnail")
elif self.track_type == TrackType.SOUNDCLOUD: elif self.track_type == TrackType.SOUNDCLOUD:
# ok so theres no feasible way of getting a Soundcloud image URL # ok so theres no feasible way of getting a Soundcloud image URL
# so we're just gonna leave it blank for brevity # so we're just gonna leave it blank for brevity

View File

@ -78,7 +78,7 @@ class Filters:
def get_preload_filters(self): def get_preload_filters(self):
"""Get all preloaded filters""" """Get all preloaded filters"""
return [f for f in self._filters if f.preload == True] return [f for f in self._filters if f.preload == True]
def get_all_payloads(self): def get_all_payloads(self):
"""Returns a formatted dict of all the filter payloads""" """Returns a formatted dict of all the filter payloads"""
@ -127,10 +127,10 @@ class Player(VoiceProtocol):
return self return self
def __init__( def __init__(
self, self,
client: Optional[Client] = None, client: Optional[Client] = None,
channel: Optional[VoiceChannel] = None, channel: Optional[VoiceChannel] = None,
*, *,
node: Node = None node: Node = None
): ):
self.client: Optional[Client] = client self.client: Optional[Client] = client
@ -240,7 +240,7 @@ class Player(VoiceProtocol):
async def _dispatch_voice_update(self, voice_data: Dict[str, Any]): async def _dispatch_voice_update(self, voice_data: Dict[str, Any]):
if {"sessionId", "event"} != self._voice_state.keys(): if {"sessionId", "event"} != self._voice_state.keys():
return return
data = { data = {
"token": voice_data['event']['token'], "token": voice_data['event']['token'],
"endpoint": voice_data['event']['endpoint'], "endpoint": voice_data['event']['endpoint'],
@ -248,9 +248,9 @@ class Player(VoiceProtocol):
} }
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"voice": data} data={"voice": data}
) )
@ -302,15 +302,15 @@ class Player(VoiceProtocol):
You can pass in a discord.py Context object to get a You can pass in a discord.py Context object to get a
Context object on any track you search. Context object on any track you search.
You may also pass in a List of filters You may also pass in a List of filters
to be applied to your track once it plays. to be applied to your track once it plays.
""" """
return await self._node.get_tracks(query, ctx=ctx, search_type=search_type, filters=filters) return await self._node.get_tracks(query, ctx=ctx, search_type=search_type, filters=filters)
async def get_recommendations( async def get_recommendations(
self, self,
*, *,
track: Track, track: Track,
ctx: Optional[commands.Context] = None ctx: Optional[commands.Context] = None
) -> Union[List[Track], None]: ) -> Union[List[Track], None]:
""" """
@ -329,9 +329,9 @@ class Player(VoiceProtocol):
"""Stops the currently playing track.""" """Stops the currently playing track."""
self._current = None self._current = None
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={'encodedTrack': None} data={'encodedTrack': None}
) )
@ -371,8 +371,8 @@ class Player(VoiceProtocol):
# First lets try using the tracks ISRC, every track has one (hopefully) # First lets try using the tracks ISRC, every track has one (hopefully)
try: try:
if not track.isrc: if not track.isrc:
# We have to bare raise here because theres no other way to skip this block feasibly # We have to bare raise here because theres no other way to skip this block feasibly
raise raise
search: Track = (await self._node.get_tracks( search: Track = (await self._node.get_tracks(
f"{track._search_type}:{track.isrc}", ctx=track.ctx))[0] f"{track._search_type}:{track.isrc}", ctx=track.ctx))[0]
except Exception: except Exception:
@ -389,7 +389,7 @@ class Player(VoiceProtocol):
"encodedTrack": search.track_id, "encodedTrack": search.track_id,
"position": str(start), "position": str(start),
"endTime": str(track.length) "endTime": str(track.length)
} }
track.original = search track.original = search
track.track_id = search.track_id track.track_id = search.track_id
# Set track_id for later lavalink searches # Set track_id for later lavalink searches
@ -412,8 +412,8 @@ class Player(VoiceProtocol):
await self.remove_filter(filter_tag=filter.tag) await self.remove_filter(filter_tag=filter.tag)
# Global filters take precedence over track filters # Global filters take precedence over track filters
# So if no global filters are detected, lets apply any # So if no global filters are detected, lets apply any
# necessary track filters # necessary track filters
# Check if theres no global filters and if the track has any filters # Check if theres no global filters and if the track has any filters
# that need to be applied # that need to be applied
@ -427,15 +427,15 @@ class Player(VoiceProtocol):
# so now the end time cannot be zero. # so now the end time cannot be zero.
# If it isnt zero, it'll match the length of the track, # If it isnt zero, it'll match the length of the track,
# otherwise itll be set here: # otherwise itll be set here:
if end > 0: if end > 0:
data["endTime"] = str(end) data["endTime"] = str(end)
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data=data, data=data,
query=f"noReplace={ignore_if_playing}" query=f"noReplace={ignore_if_playing}"
) )
@ -449,9 +449,9 @@ class Player(VoiceProtocol):
) )
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"position": position} data={"position": position}
) )
return self._position return self._position
@ -459,9 +459,9 @@ class Player(VoiceProtocol):
async def set_pause(self, pause: bool) -> bool: async def set_pause(self, pause: bool) -> bool:
"""Sets the pause state of the currently playing track.""" """Sets the pause state of the currently playing track."""
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"paused": pause} data={"paused": pause}
) )
self._paused = pause self._paused = pause
@ -470,9 +470,9 @@ class Player(VoiceProtocol):
async def set_volume(self, volume: int) -> int: async def set_volume(self, volume: int) -> int:
"""Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500.""" """Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500."""
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"volume": volume} data={"volume": volume}
) )
self._volume = volume self._volume = volume
@ -485,18 +485,18 @@ class Player(VoiceProtocol):
(You must have a song playing in order for `fast_apply` to work.) (You must have a song playing in order for `fast_apply` to work.)
""" """
self._filters.add_filter(filter=filter) self._filters.add_filter(filter=filter)
payload = self._filters.get_all_payloads() payload = self._filters.get_all_payloads()
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"filters": payload} data={"filters": payload}
) )
if fast_apply: if fast_apply:
await self.seek(self.position) await self.seek(self.position)
return self._filters return self._filters
async def remove_filter(self, filter_tag: str, fast_apply: bool = False) -> Filter: async def remove_filter(self, filter_tag: str, fast_apply: bool = False) -> Filter:
@ -506,18 +506,18 @@ class Player(VoiceProtocol):
(You must have a song playing in order for `fast_apply` to work.) (You must have a song playing in order for `fast_apply` to work.)
""" """
self._filters.remove_filter(filter_tag=filter_tag) self._filters.remove_filter(filter_tag=filter_tag)
payload = self._filters.get_all_payloads() payload = self._filters.get_all_payloads()
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"filters": payload} data={"filters": payload}
) )
if fast_apply: if fast_apply:
await self.seek(self.position) await self.seek(self.position)
return self._filters return self._filters
async def reset_filters(self, *, fast_apply: bool = False): async def reset_filters(self, *, fast_apply: bool = False):
@ -534,14 +534,14 @@ class Player(VoiceProtocol):
) )
self._filters.reset_filters() self._filters.reset_filters()
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"filters": {}} data={"filters": {}}
) )
if fast_apply: if fast_apply:
await self.seek(self.position) await self.seek(self.position)

View File

@ -12,7 +12,7 @@ from typing import Dict, List, Optional, TYPE_CHECKING, Union
from urllib.parse import quote from urllib.parse import quote
from . import ( from . import (
__version__, __version__,
spotify, spotify,
applemusic applemusic
) )
@ -40,8 +40,8 @@ if TYPE_CHECKING:
__all__ = ('Node', 'NodePool') __all__ = ('Node', 'NodePool')
class Node: class Node:
"""The base class for a node. """The base class for a node.
This node object represents a Lavalink node. This node object represents a Lavalink node.
To enable Spotify searching, pass in a proper Spotify Client ID and Spotify Client Secret To enable Spotify searching, pass in a proper Spotify Client ID and Spotify Client Secret
To enable Apple music, set the "apple_music" parameter to "True" To enable Apple music, set the "apple_music" parameter to "True"
""" """
@ -74,10 +74,10 @@ class Node:
self._heartbeat: int = heartbeat self._heartbeat: int = heartbeat
self._secure: bool = secure self._secure: bool = secure
self.fallback: bool = fallback self.fallback: bool = fallback
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}" self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
self._session: Optional[aiohttp.ClientSession] = session self._session: Optional[aiohttp.ClientSession] = session
@ -88,7 +88,7 @@ class Node:
self._session_id: str = None self._session_id: str = None
self._available: bool = False self._available: bool = False
self._version: str = None self._version: str = None
self._route_planner = RoutePlanner(self) self._route_planner = RoutePlanner(self)
self._headers = { self._headers = {
@ -196,8 +196,8 @@ class Node:
if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING): if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
retry = backoff.delay() retry = backoff.delay()
await asyncio.sleep(retry) await asyncio.sleep(retry)
if not self.is_connected: if not self.is_connected:
self._loop.create_task(self.connect()) self._loop.create_task(self.connect())
else: else:
self._loop.create_task(self._handle_payload(msg.json())) self._loop.create_task(self._handle_payload(msg.json()))
@ -223,12 +223,12 @@ class Node:
await player._update_state(data) await player._update_state(data)
async def send( async def send(
self, self,
method: str, method: str,
path: str, path: str,
include_version: bool = True, include_version: bool = True,
guild_id: Optional[Union[int, str]] = None, guild_id: Optional[Union[int, str]] = None,
query: Optional[str] = None, query: Optional[str] = None,
data: Optional[Union[dict, str]] = None, data: Optional[Union[dict, str]] = None,
ignore_if_available: bool = False, ignore_if_available: bool = False,
): ):
@ -253,10 +253,10 @@ class Node:
if resp.content_type == "text/plain": if resp.content_type == "text/plain":
return await resp.text() return await resp.text()
return await resp.json() return await resp.json()
def get_player(self, guild_id: int): def get_player(self, guild_id: int):
"""Takes a guild ID as a parameter. Returns a pomice Player object.""" """Takes a guild ID as a parameter. Returns a pomice Player object."""
@ -278,24 +278,24 @@ class Node:
"The Lavalink version you're using is incompatible. " "The Lavalink version you're using is incompatible. "
"Lavalink version 3.7.0 or above is required to use this library." "Lavalink version 3.7.0 or above is required to use this library."
) )
if version.endswith('-SNAPSHOT'): if version.endswith('-SNAPSHOT'):
# we're just gonna assume all snapshot versions correlate with v4 # we're just gonna assume all snapshot versions correlate with v4
self._version = 4 self._version = 4
else: else:
self._version = version[:1] self._version = version[:1]
self._websocket = await self._session.ws_connect( self._websocket = await self._session.ws_connect(
f"{self._websocket_uri}/v{self._version}/websocket", f"{self._websocket_uri}/v{self._version}/websocket",
headers=self._headers, headers=self._headers,
heartbeat=self._heartbeat heartbeat=self._heartbeat
) )
if not self._task: if not self._task:
self._task = self._loop.create_task(self._listen()) self._task = self._loop.create_task(self._listen())
self._available = True self._available = True
return self return self
except (aiohttp.ClientConnectorError, ConnectionRefusedError): except (aiohttp.ClientConnectorError, ConnectionRefusedError):
@ -322,11 +322,11 @@ class Node:
await self._websocket.close() await self._websocket.close()
await self._session.close() await self._session.close()
if self._spotify_client: if self._spotify_client:
await self._spotify_client.session.close() await self._spotify_client.close()
if self._apple_music_client: if self._apple_music_client:
await self._apple_music_client.session.close() await self._apple_music_client.close()
del self._pool._nodes[self._identifier] del self._pool._nodes[self._identifier]
self.available = False self.available = False
self._task.cancel() self._task.cancel()
@ -362,11 +362,11 @@ class Node:
You can pass in a discord.py Context object to get a You can pass in a discord.py Context object to get a
Context object on any track you search. Context object on any track you search.
You may also pass in a List of filters You may also pass in a List of filters
to be applied to your track once it plays. to be applied to your track once it plays.
""" """
timestamp = None timestamp = None
if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query): if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query):
query = f"{search_type}:{query}" query = f"{search_type}:{query}"
@ -374,7 +374,7 @@ class Node:
if filters: if filters:
for filter in filters: for filter in filters:
filter.set_preload() filter.set_preload()
if URLRegex.AM_URL.match(query): if URLRegex.AM_URL.match(query):
if not self._apple_music_client: if not self._apple_music_client:
raise AppleMusicNotEnabled( raise AppleMusicNotEnabled(
@ -382,7 +382,7 @@ class Node:
"Please set apple_music to True in your Node class." "Please set apple_music to True in your Node class."
) )
apple_music_results = await self._apple_music_client.search(query=query) apple_music_results = await self._apple_music_client.search(query=query)
if isinstance(apple_music_results, applemusic.Song): if isinstance(apple_music_results, applemusic.Song):
return [ return [
Track( Track(
@ -501,7 +501,7 @@ class Node:
) )
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query): elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
data: dict = await self.send(method="GET", path="loadtracks", query=f"identifier={quote(query)}") data: dict = await self.send(method="GET", path="loadtracks", query=f"identifier={quote(query)}")
track: dict = data["tracks"][0] track: dict = data["tracks"][0]
@ -533,9 +533,9 @@ class Node:
# If query is a video thats part of a playlist, get the video and queue that instead # If query is a video thats part of a playlist, get the video and queue that instead
# (I can't tell you how much i've wanted to implement this in here) # (I can't tell you how much i've wanted to implement this in here)
if (match := URLRegex.YOUTUBE_VID_IN_PLAYLIST.match(query)): if (match := URLRegex.YOUTUBE_VID_IN_PLAYLIST.match(query)):
query = match.group("video") query = match.group("video")
data: dict = await self.send(method="GET", path="loadtracks", query=f"identifier={quote(query)}") data: dict = await self.send(method="GET", path="loadtracks", query=f"identifier={quote(query)}")
load_type = data.get("loadType") load_type = data.get("loadType")
@ -577,14 +577,14 @@ class Node:
] ]
async def get_recommendations( async def get_recommendations(
self, self,
*, *,
track: Track, track: Track,
ctx: Optional[commands.Context] = None ctx: Optional[commands.Context] = None
) -> Union[List[Track], None]: ) -> Union[List[Track], None]:
""" """
Gets recommendations from either YouTube or Spotify. Gets recommendations from either YouTube or Spotify.
The track that is passed in must be either from The track that is passed in must be either from
YouTube or Spotify or else this will not work. YouTube or Spotify or else this will not work.
You can pass in a discord.py Context object to get a You can pass in a discord.py Context object to get a
Context object on all tracks that get recommended. Context object on all tracks that get recommended.
@ -613,12 +613,12 @@ class Node:
] ]
return tracks return tracks
elif track.track_type == TrackType.YOUTUBE: elif track.track_type == TrackType.YOUTUBE:
tracks = await self.get_tracks(query=f"ytsearch:https://www.youtube.com/watch?v={track.identifier}&list=RD{track.identifier}", ctx=ctx) tracks = await self.get_tracks(query=f"ytsearch:https://www.youtube.com/watch?v={track.identifier}&list=RD{track.identifier}", ctx=ctx)
return tracks return tracks
else: else:
raise TrackLoadError("The specfied track must be either a YouTube or Spotify track to recieve recommendations.") raise TrackLoadError("The specfied track must be either a YouTube or Spotify track to recieve recommendations.")
class NodePool: class NodePool:
"""The base class for the node pool. """The base class for the node pool.
@ -666,7 +666,7 @@ class NodePool:
elif algorithm == NodeAlgorithm.by_players: elif algorithm == NodeAlgorithm.by_players:
tested_nodes = {node: len(node.players.keys()) for node in available_nodes} tested_nodes = {node: len(node.players.keys()) for node in available_nodes}
return min(tested_nodes, key=tested_nodes.get) return min(tested_nodes, key=tested_nodes.get)
@classmethod @classmethod
def get_node(cls, *, identifier: str = None) -> Node: def get_node(cls, *, identifier: str = None) -> Node:
@ -714,7 +714,7 @@ class NodePool:
node = Node( node = Node(
pool=cls, bot=bot, host=host, port=port, password=password, pool=cls, bot=bot, host=host, port=port, password=password,
identifier=identifier, secure=secure, heartbeat=heartbeat, identifier=identifier, secure=secure, heartbeat=heartbeat,
loop=loop, spotify_client_id=spotify_client_id, loop=loop, spotify_client_id=spotify_client_id,
session=session, spotify_client_secret=spotify_client_secret, session=session, spotify_client_secret=spotify_client_secret,
apple_music=apple_music, fallback=fallback apple_music=apple_music, fallback=fallback
) )

View File

@ -107,7 +107,7 @@ class Queue(Iterable[Track]):
raise TypeError(f"Adding '{type(other)}' type to the queue is not supported.") raise TypeError(f"Adding '{type(other)}' type to the queue is not supported.")
def _get(self) -> Track: def _get(self) -> Track:
return self._queue.pop(0) return self._queue.pop(0)
def _drop(self) -> Track: def _drop(self) -> Track:
@ -298,7 +298,7 @@ class Queue(Iterable[Track]):
def set_loop_mode(self, mode: LoopMode) -> None: def set_loop_mode(self, mode: LoopMode) -> None:
""" """
Sets the loop mode of the queue. Sets the loop mode of the queue.
Takes the LoopMode enum as an argument. Takes the LoopMode enum as an argument.
""" """
self._loop_mode = mode self._loop_mode = mode
@ -306,11 +306,11 @@ class Queue(Iterable[Track]):
try: try:
index = self._index(self._current_item) index = self._index(self._current_item)
except ValueError: except ValueError:
index = 0 index = 0
if self._current_item not in self._queue: if self._current_item not in self._queue:
self._queue.insert(index, self._current_item) self._queue.insert(index, self._current_item)
self._current_item = self._queue[index] self._current_item = self._queue[index]
def disable_loop(self) -> None: def disable_loop(self) -> None:
""" """
@ -320,12 +320,12 @@ class Queue(Iterable[Track]):
if not self._loop_mode: if not self._loop_mode:
raise QueueException("Queue loop is already disabled.") raise QueueException("Queue loop is already disabled.")
if self._loop_mode == LoopMode.QUEUE: if self._loop_mode == LoopMode.QUEUE:
index = self.find_position(self._current_item) + 1 index = self.find_position(self._current_item) + 1
self._queue = self._queue[index:] self._queue = self._queue[index:]
self._loop_mode = None self._loop_mode = None
def shuffle(self) -> None: def shuffle(self) -> None:
"""Shuffles the queue.""" """Shuffles the queue."""

View File

@ -8,9 +8,7 @@ import orjson as json
from base64 import b64encode from base64 import b64encode
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from .exceptions import InvalidSpotifyURL, SpotifyRequestException from .exceptions import InvalidSpotifyURL, SpotifyRequestException
from .objects import * from .objects import *
GRANT_URL = "https://accounts.spotify.com/api/token" GRANT_URL = "https://accounts.spotify.com/api/token"
@ -22,8 +20,8 @@ SPOTIFY_URL_REGEX = re.compile(
class Client: class Client:
"""The base client for the Spotify module of Pomice. """The base client for the Spotify module of Pomice.
This class will do all the heavy lifting of getting all the metadata This class will do all the heavy lifting of getting all the metadata
for any Spotify URL you throw at it. for any Spotify URL you throw at it.
""" """
def __init__(self, client_id: str, client_secret: str) -> None: def __init__(self, client_id: str, client_secret: str) -> None:
@ -34,7 +32,9 @@ class Client:
self._bearer_token: str = None self._bearer_token: str = None
self._expiry = 0 self._expiry = 0
self._auth_token = b64encode(f"{self._client_id}:{self._client_secret}".encode()) self._auth_token = b64encode(
f"{self._client_id}:{self._client_secret}".encode()
)
self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"} self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"}
self._bearer_headers = None self._bearer_headers = None
@ -44,7 +44,9 @@ class Client:
if not self.session: if not self.session:
self.session = aiohttp.ClientSession() self.session = aiohttp.ClientSession()
async with self.session.post(GRANT_URL, data=_data, headers=self._grant_headers) as resp: async with self.session.post(
GRANT_URL, data=_data, headers=self._grant_headers
) as resp:
if resp.status != 200: if resp.status != 200:
raise SpotifyRequestException( raise SpotifyRequestException(
f"Error fetching bearer token: {resp.status} {resp.reason}" f"Error fetching bearer token: {resp.status} {resp.reason}"
@ -82,28 +84,35 @@ class Client:
elif spotify_type == "album": elif spotify_type == "album":
return Album(data) return Album(data)
elif spotify_type == "artist": elif spotify_type == "artist":
async with self.session.get(f"{request_url}/top-tracks?market=US", headers=self._bearer_headers) as resp: async with self.session.get(
if resp.status != 200: f"{request_url}/top-tracks?market=US", headers=self._bearer_headers
raise SpotifyRequestException( ) as resp:
f"Error while fetching results: {resp.status} {resp.reason}" if resp.status != 200:
) raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}"
)
track_data: dict = await resp.json(loads=json.loads) track_data: dict = await resp.json(loads=json.loads)
tracks = track_data['tracks'] tracks = track_data["tracks"]
return Artist(data, tracks) return Artist(data, tracks)
else: else:
tracks = [ tracks = [
Track(track["track"]) Track(track["track"])
for track in data["tracks"]["items"] if track["track"] is not None for track in data["tracks"]["items"]
if track["track"] is not None
] ]
if not len(tracks): if not len(tracks):
raise SpotifyRequestException("This playlist is empty and therefore cannot be queued.") raise SpotifyRequestException(
"This playlist is empty and therefore cannot be queued."
)
next_page_url = data["tracks"]["next"] next_page_url = data["tracks"]["next"]
while next_page_url is not None: while next_page_url is not None:
async with self.session.get(next_page_url, headers=self._bearer_headers) as resp: async with self.session.get(
next_page_url, headers=self._bearer_headers
) as resp:
if resp.status != 200: if resp.status != 200:
raise SpotifyRequestException( raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}" f"Error while fetching results: {resp.status} {resp.reason}"
@ -113,7 +122,8 @@ class Client:
tracks += [ tracks += [
Track(track["track"]) Track(track["track"])
for track in next_data["items"] if track["track"] is not None for track in next_data["items"]
if track["track"] is not None
] ]
next_page_url = next_data["next"] next_page_url = next_data["next"]
@ -133,7 +143,9 @@ class Client:
if not spotify_type == "track": if not spotify_type == "track":
raise InvalidSpotifyURL("The provided query is not a Spotify track.") raise InvalidSpotifyURL("The provided query is not a Spotify track.")
request_url = REQUEST_URL.format(type="recommendation", id=f"?seed_tracks={spotify_id}") request_url = REQUEST_URL.format(
type="recommendation", id=f"?seed_tracks={spotify_id}"
)
async with self.session.get(request_url, headers=self._bearer_headers) as resp: async with self.session.get(request_url, headers=self._bearer_headers) as resp:
if resp.status != 200: if resp.status != 200:
@ -145,4 +157,9 @@ class Client:
tracks = [Track(track) for track in data["tracks"]] tracks = [Track(track) for track in data["tracks"]]
return tracks return tracks
async def close(self) -> None:
if self.session:
await self.session.close()
self.session = None

View File

@ -43,7 +43,7 @@ class Playlist:
if data.get("images") and len(data["images"]): if data.get("images") and len(data["images"]):
self.image: str = data["images"][0]["url"] self.image: str = data["images"][0]["url"]
else: else:
self.image = self.tracks[0].image self.image = self.tracks[0].image
self.uri = data["external_urls"]["spotify"] self.uri = data["external_urls"]["spotify"]
def __repr__(self) -> str: def __repr__(self) -> str:

View File

@ -93,7 +93,7 @@ class NodeStats:
class FailingIPBlock: class FailingIPBlock:
""" """
The base class for the failing IP block object from the route planner stats. The base class for the failing IP block object from the route planner stats.
Gives critical information about any failing addresses on the block Gives critical information about any failing addresses on the block
and the time they failed. and the time they failed.
""" """
def __init__(self, data: dict) -> None: def __init__(self, data: dict) -> None:
@ -102,7 +102,7 @@ class FailingIPBlock:
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Pomice.FailingIPBlock address={self.address} failing_time={self.failing_time}>" return f"<Pomice.FailingIPBlock address={self.address} failing_time={self.failing_time}>"
class RouteStats: class RouteStats:
""" """
@ -182,7 +182,7 @@ class Ping:
def get_ping(self): def get_ping(self):
s = self._create_socket(socket.AF_INET, socket.SOCK_STREAM) s = self._create_socket(socket.AF_INET, socket.SOCK_STREAM)
cost_time = self.timer.cost( cost_time = self.timer.cost(
(s.connect, s.shutdown), (s.connect, s.shutdown),
((self._host, self._port), None)) ((self._host, self._port), None))