Merge pull request #34 from NiceAesth/close-clients

feat: add close to clients; style: formatting pass
This commit is contained in:
Clxud 2023-03-10 20:44:01 -05:00 committed by GitHub
commit 145634ce79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 218 additions and 186 deletions

View File

@ -18,7 +18,7 @@ class Player(pomice.Player):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.queue = pomice.Queue()
self.controller: discord.Message = None
# Set context here so we can send a now playing embed
@ -43,12 +43,12 @@ class Player(pomice.Player):
if self.controller:
with suppress(discord.HTTPException):
await self.controller.delete()
# Queue up the next track, else teardown the player
try:
track: pomice.Track = self.queue.get()
except pomice.QueueEmpty:
except pomice.QueueEmpty:
return await self.teardown()
await self.play(track)
@ -68,12 +68,12 @@ class Player(pomice.Player):
with suppress((discord.HTTPException), (KeyError)):
await self.destroy()
if self.controller:
await self.controller.delete()
await self.controller.delete()
async def set_context(self, ctx: commands.Context):
"""Set context for the player"""
self.context = ctx
self.dj = ctx.author
self.context = ctx
self.dj = ctx.author
@ -81,20 +81,20 @@ class Player(pomice.Player):
class Music(commands.Cog):
def __init__(self, bot: commands.Bot) -> None:
self.bot = bot
# In order to initialize a node, or really do anything in this library,
# you need to make a node pool
self.pomice = pomice.NodePool()
# Start the node
bot.loop.create_task(self.start_nodes())
async def start_nodes(self):
# Waiting for the bot to get ready before connecting to nodes.
await self.bot.wait_until_ready()
# You can pass in Spotify credentials to enable Spotify querying.
# If you do not pass in valid Spotify credentials, Spotify querying will not work
# If you do not pass in valid Spotify credentials, Spotify querying will not work
await self.pomice.create_node(
bot=self.bot,
host="127.0.0.1",
@ -128,7 +128,7 @@ class Music(commands.Cog):
# we can just skip to the next track
# Of course, you can modify this to do whatever you like
@commands.Cog.listener()
async def on_pomice_track_end(self, player: Player, track, _):
await player.do_next()
@ -140,7 +140,7 @@ class Music(commands.Cog):
@commands.Cog.listener()
async def on_pomice_track_exception(self, player: Player, track, _):
await player.do_next()
@commands.command(aliases=['joi', 'j', 'summon', 'su', 'con', 'connect'])
async def join(self, ctx: commands.Context, *, channel: discord.VoiceChannel = None) -> None:
if not channel:
@ -165,14 +165,14 @@ class Music(commands.Cog):
await player.destroy()
await ctx.send("Player has left the channel.")
@commands.command(aliases=['pla', 'p'])
async def play(self, ctx: commands.Context, *, search: str) -> None:
# Checks if the player is in the channel before we play anything
if not (player := ctx.voice_client):
await ctx.author.voice.channel.connect(cls=Player)
player: Player = ctx.voice_client
await player.set_context(ctx=ctx)
await player.set_context(ctx=ctx)
# If you search a keyword, Pomice will automagically search the result using YouTube
# You can pass in "search_type=" as an argument to change the search type
@ -180,11 +180,11 @@ class Music(commands.Cog):
# will search up any keyword results on YouTube Music
# We will also set the context here to get special features, like a track.requester object
results = await player.get_tracks(search, ctx=ctx)
results = await player.get_tracks(search, ctx=ctx)
if not results:
return await ctx.send("No results were found for that search term", delete_after=7)
if isinstance(results, pomice.Playlist):
for track in results.tracks:
player.queue.put(track)

View File

@ -9,22 +9,22 @@ class MyBot(commands.Bot):
command_prefix="!",
activity=discord.Activity(type=discord.ActivityType.listening, name="to music!")
)
self.add_cog(Music(self))
self.loop.create_task(self.cogs["Music"].start_nodes())
async def on_ready(self) -> None:
print("I'm online!")
class Music(commands.Cog):
def __init__(self, bot: commands.Bot) -> None:
self.bot = bot
# In order to initialize a node, or really do anything in this library,
# you need to make a node pool
self.pomice = pomice.NodePool()
async def start_nodes(self):
# You can pass in Spotify credentials to enable Spotify querying.
# If you do not pass in valid Spotify credentials, Spotify querying will not work
@ -36,7 +36,7 @@ class Music(commands.Cog):
identifier="MAIN"
)
print(f"Node is ready!")
@commands.command(aliases=["connect"])
async def join(self, ctx: commands.Context, *, channel: discord.VoiceChannel = None) -> None:
if not channel:
@ -62,24 +62,24 @@ class Music(commands.Cog):
await player.destroy()
await ctx.send("Player has left the channel.")
@commands.command(aliases=["p"])
async def play(self, ctx: commands.Context, *, search: str) -> None:
# Checks if the player is in the channel before we play anything
if not ctx.voice_client:
await ctx.invoke(self.join)
await ctx.invoke(self.join)
player: pomice.Player = ctx.voice_client
player: pomice.Player = ctx.voice_client
# If you search a keyword, Pomice will automagically search the result using YouTube
# You can pass in "search_type=" as an argument to change the search type
# i.e: player.get_tracks("query", search_type=SearchType.ytmsearch)
# will search up any keyword results on YouTube Music
results = await player.get_tracks(search)
results = await player.get_tracks(search)
if not results:
raise commands.CommandError("No results were found for that search term.")
if isinstance(results, pomice.Playlist):
await player.play(track=results.tracks[0])
else:
@ -124,6 +124,6 @@ class Music(commands.Cog):
await player.stop()
await ctx.send("Player has been stopped")
bot = MyBot()
bot.run("token")

View File

@ -8,18 +8,19 @@ import base64
from datetime import datetime
from .objects import *
from .exceptions import *
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..pool import Node
AM_URL_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)")
AM_SINGLE_IN_ALBUM_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)")
AM_URL_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)"
)
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)"
)
AM_REQ_URL = "https://api.music.apple.com/v1/catalog/{country}/{type}s/{id}"
AM_BASE_URL = "https://api.music.apple.com"
class Client:
"""The base Apple Music client for Pomice.
"""The base Apple Music client for Pomice.
This will do all the heavy lifting of getting tracks from Apple Music
and translating it to a valid Lavalink track. No client auth is required here.
"""
@ -30,28 +31,30 @@ class Client:
self.session: aiohttp.ClientSession = None
self.headers = None
async def request_token(self):
if not self.session:
self.session = aiohttp.ClientSession()
async with self.session.get("https://music.apple.com/assets/index.919fe17f.js") as resp:
async with self.session.get(
"https://music.apple.com/assets/index.919fe17f.js"
) as resp:
if resp.status != 200:
raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}"
)
text = await resp.text()
result = re.search("\"(eyJ.+?)\"", text).group(1)
result = re.search('"(eyJ.+?)"', text).group(1)
self.token = result
self.headers = {
'Authorization': f"Bearer {result}",
'Origin': 'https://apple.com',
"Authorization": f"Bearer {result}",
"Origin": "https://apple.com",
}
token_split = self.token.split(".")[1]
token_json = base64.b64decode(token_split + '=' * (-len(token_split) % 4)).decode()
token_json = base64.b64decode(
token_split + "=" * (-len(token_split) % 4)
).decode()
token_data = json.loads(token_json)
self.expiry = datetime.fromtimestamp(token_data["exp"])
async def search(self, query: str):
if not self.token or datetime.utcnow() > self.expiry:
@ -72,7 +75,6 @@ class Client:
request_url = AM_REQ_URL.format(country=country, type=type, id=id)
else:
request_url = AM_REQ_URL.format(country=country, type=type, id=id)
async with self.session.get(request_url, headers=self.headers) as resp:
if resp.status != 200:
@ -83,15 +85,16 @@ class Client:
data = data["data"][0]
if type == "song":
return Song(data)
elif type == "album":
return Album(data)
elif type == "artist":
async with self.session.get(f"{request_url}/view/top-songs", headers=self.headers) as resp:
async with self.session.get(
f"{request_url}/view/top-songs", headers=self.headers
) as resp:
if resp.status != 200:
raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}"
@ -101,20 +104,24 @@ class Client:
return Artist(data, tracks=tracks)
else:
else:
track_data: dict = data["relationships"]["tracks"]
tracks = [Song(track) for track in track_data.get("data")]
if not len(tracks):
raise AppleMusicRequestException("This playlist is empty and therefore cannot be queued.")
raise AppleMusicRequestException(
"This playlist is empty and therefore cannot be queued."
)
if track_data.get("next"):
if track_data.get("next"):
next_page_url = AM_BASE_URL + track_data.get("next")
while next_page_url is not None:
async with self.session.get(next_page_url, headers=self.headers) as resp:
async with self.session.get(
next_page_url, headers=self.headers
) as resp:
if resp.status != 200:
raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}"
@ -128,6 +135,9 @@ class Client:
else:
next_page_url = None
return Playlist(data, tracks)
return Playlist(data, tracks)
async def close(self):
if self.session:
await self.session.close()
self.session = None

View File

@ -89,7 +89,7 @@ class PlaylistType(Enum):
class NodeAlgorithm(Enum):
"""
The enum for the different node algorithms in Pomice.
The enums in this class are to only differentiate different
methods, since the actual method is handled in the
get_best_node() method.
@ -123,7 +123,7 @@ class LoopMode(Enum):
# We don't have to define anything special for these, since these just serve as flags
TRACK = "track"
QUEUE = "queue"
def __str__(self) -> str:
return self.value
@ -135,16 +135,16 @@ class RouteStrategy(Enum):
This feature is exclusively for the RoutePlanner class.
If you are not using this feature, this class is not necessary.
RouteStrategy.ROTATE_ON_BAN specifies that the node is rotating IPs
RouteStrategy.ROTATE_ON_BAN specifies that the node is rotating IPs
whenever they get banned by Youtube.
RouteStrategy.LOAD_BALANCE specifies that the node is selecting
random IPs to balance out requests between them.
RouteStrategy.NANO_SWITCH specifies that the node is switching
RouteStrategy.NANO_SWITCH specifies that the node is switching
between IPs every CPU clock cycle.
RouteStrategy.ROTATING_NANO_SWITCH specifies that the node is switching
RouteStrategy.ROTATING_NANO_SWITCH specifies that the node is switching
between IPs every CPU clock cycle and is rotating between IP blocks on
ban.

View File

@ -23,7 +23,7 @@ __all__ = (
class PomiceEvent:
"""The base class for all events dispatched by a node.
"""The base class for all events dispatched by a node.
Every event must be formatted within your bot's code as a listener.
i.e: If you want to listen for when a track starts, the event would be:
```py

View File

@ -34,7 +34,7 @@ class Track:
self.timestamp: Optional[float] = timestamp
if self.track_type == TrackType.SPOTIFY or self.track_type == TrackType.APPLE_MUSIC:
self.original: Optional[Track] = None
self.original: Optional[Track] = None
else:
self.original = self
self._search_type: SearchType = search_type
@ -46,10 +46,10 @@ class Track:
self.uri: str = info.get("uri")
self.identifier: str = info.get("identifier")
self.isrc: str = info.get("isrc")
if self.uri:
if info.get("thumbnail"):
self.thumbnail: str = info.get("thumbnail")
self.thumbnail: str = info.get("thumbnail")
elif self.track_type == TrackType.SOUNDCLOUD:
# ok so theres no feasible way of getting a Soundcloud image URL
# so we're just gonna leave it blank for brevity

View File

@ -78,7 +78,7 @@ class Filters:
def get_preload_filters(self):
"""Get all preloaded filters"""
return [f for f in self._filters if f.preload == True]
return [f for f in self._filters if f.preload == True]
def get_all_payloads(self):
"""Returns a formatted dict of all the filter payloads"""
@ -127,10 +127,10 @@ class Player(VoiceProtocol):
return self
def __init__(
self,
client: Optional[Client] = None,
channel: Optional[VoiceChannel] = None,
*,
self,
client: Optional[Client] = None,
channel: Optional[VoiceChannel] = None,
*,
node: Node = None
):
self.client: Optional[Client] = client
@ -240,7 +240,7 @@ class Player(VoiceProtocol):
async def _dispatch_voice_update(self, voice_data: Dict[str, Any]):
if {"sessionId", "event"} != self._voice_state.keys():
return
data = {
"token": voice_data['event']['token'],
"endpoint": voice_data['event']['endpoint'],
@ -248,9 +248,9 @@ class Player(VoiceProtocol):
}
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"voice": data}
)
@ -302,15 +302,15 @@ class Player(VoiceProtocol):
You can pass in a discord.py Context object to get a
Context object on any track you search.
You may also pass in a List of filters
You may also pass in a List of filters
to be applied to your track once it plays.
"""
return await self._node.get_tracks(query, ctx=ctx, search_type=search_type, filters=filters)
async def get_recommendations(
self,
*,
track: Track,
self,
*,
track: Track,
ctx: Optional[commands.Context] = None
) -> Union[List[Track], None]:
"""
@ -329,9 +329,9 @@ class Player(VoiceProtocol):
"""Stops the currently playing track."""
self._current = None
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={'encodedTrack': None}
)
@ -371,8 +371,8 @@ class Player(VoiceProtocol):
# First lets try using the tracks ISRC, every track has one (hopefully)
try:
if not track.isrc:
# We have to bare raise here because theres no other way to skip this block feasibly
raise
# We have to bare raise here because theres no other way to skip this block feasibly
raise
search: Track = (await self._node.get_tracks(
f"{track._search_type}:{track.isrc}", ctx=track.ctx))[0]
except Exception:
@ -389,7 +389,7 @@ class Player(VoiceProtocol):
"encodedTrack": search.track_id,
"position": str(start),
"endTime": str(track.length)
}
}
track.original = search
track.track_id = search.track_id
# Set track_id for later lavalink searches
@ -412,8 +412,8 @@ class Player(VoiceProtocol):
await self.remove_filter(filter_tag=filter.tag)
# Global filters take precedence over track filters
# So if no global filters are detected, lets apply any
# necessary track filters
# So if no global filters are detected, lets apply any
# necessary track filters
# Check if theres no global filters and if the track has any filters
# that need to be applied
@ -427,15 +427,15 @@ class Player(VoiceProtocol):
# so now the end time cannot be zero.
# If it isnt zero, it'll match the length of the track,
# otherwise itll be set here:
if end > 0:
data["endTime"] = str(end)
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data=data,
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data=data,
query=f"noReplace={ignore_if_playing}"
)
@ -449,9 +449,9 @@ class Player(VoiceProtocol):
)
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"position": position}
)
return self._position
@ -459,9 +459,9 @@ class Player(VoiceProtocol):
async def set_pause(self, pause: bool) -> bool:
"""Sets the pause state of the currently playing track."""
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"paused": pause}
)
self._paused = pause
@ -470,9 +470,9 @@ class Player(VoiceProtocol):
async def set_volume(self, volume: int) -> int:
"""Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500."""
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"volume": volume}
)
self._volume = volume
@ -485,18 +485,18 @@ class Player(VoiceProtocol):
(You must have a song playing in order for `fast_apply` to work.)
"""
self._filters.add_filter(filter=filter)
payload = self._filters.get_all_payloads()
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"filters": payload}
)
if fast_apply:
await self.seek(self.position)
return self._filters
async def remove_filter(self, filter_tag: str, fast_apply: bool = False) -> Filter:
@ -506,18 +506,18 @@ class Player(VoiceProtocol):
(You must have a song playing in order for `fast_apply` to work.)
"""
self._filters.remove_filter(filter_tag=filter_tag)
payload = self._filters.get_all_payloads()
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"filters": payload}
)
if fast_apply:
await self.seek(self.position)
return self._filters
async def reset_filters(self, *, fast_apply: bool = False):
@ -534,14 +534,14 @@ class Player(VoiceProtocol):
)
self._filters.reset_filters()
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"filters": {}}
)
if fast_apply:
await self.seek(self.position)

View File

@ -12,7 +12,7 @@ from typing import Dict, List, Optional, TYPE_CHECKING, Union
from urllib.parse import quote
from . import (
__version__,
__version__,
spotify,
applemusic
)
@ -40,8 +40,8 @@ if TYPE_CHECKING:
__all__ = ('Node', 'NodePool')
class Node:
"""The base class for a node.
This node object represents a Lavalink node.
"""The base class for a node.
This node object represents a Lavalink node.
To enable Spotify searching, pass in a proper Spotify Client ID and Spotify Client Secret
To enable Apple music, set the "apple_music" parameter to "True"
"""
@ -74,10 +74,10 @@ class Node:
self._heartbeat: int = heartbeat
self._secure: bool = secure
self.fallback: bool = fallback
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
self._session: Optional[aiohttp.ClientSession] = session
@ -88,7 +88,7 @@ class Node:
self._session_id: str = None
self._available: bool = False
self._version: str = None
self._route_planner = RoutePlanner(self)
self._headers = {
@ -196,8 +196,8 @@ class Node:
if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
retry = backoff.delay()
await asyncio.sleep(retry)
if not self.is_connected:
self._loop.create_task(self.connect())
if not self.is_connected:
self._loop.create_task(self.connect())
else:
self._loop.create_task(self._handle_payload(msg.json()))
@ -223,12 +223,12 @@ class Node:
await player._update_state(data)
async def send(
self,
self,
method: str,
path: str,
include_version: bool = True,
guild_id: Optional[Union[int, str]] = None,
query: Optional[str] = None,
path: str,
include_version: bool = True,
guild_id: Optional[Union[int, str]] = None,
query: Optional[str] = None,
data: Optional[Union[dict, str]] = None,
ignore_if_available: bool = False,
):
@ -253,10 +253,10 @@ class Node:
if resp.content_type == "text/plain":
return await resp.text()
return await resp.json()
def get_player(self, guild_id: int):
"""Takes a guild ID as a parameter. Returns a pomice Player object."""
@ -278,24 +278,24 @@ class Node:
"The Lavalink version you're using is incompatible. "
"Lavalink version 3.7.0 or above is required to use this library."
)
if version.endswith('-SNAPSHOT'):
# we're just gonna assume all snapshot versions correlate with v4
self._version = 4
else:
self._version = version[:1]
self._version = version[:1]
self._websocket = await self._session.ws_connect(
f"{self._websocket_uri}/v{self._version}/websocket",
headers=self._headers,
headers=self._headers,
heartbeat=self._heartbeat
)
if not self._task:
self._task = self._loop.create_task(self._listen())
self._available = True
self._available = True
return self
except (aiohttp.ClientConnectorError, ConnectionRefusedError):
@ -322,11 +322,11 @@ class Node:
await self._websocket.close()
await self._session.close()
if self._spotify_client:
await self._spotify_client.session.close()
await self._spotify_client.close()
if self._apple_music_client:
await self._apple_music_client.session.close()
await self._apple_music_client.close()
del self._pool._nodes[self._identifier]
self.available = False
self._task.cancel()
@ -362,11 +362,11 @@ class Node:
You can pass in a discord.py Context object to get a
Context object on any track you search.
You may also pass in a List of filters
You may also pass in a List of filters
to be applied to your track once it plays.
"""
timestamp = None
timestamp = None
if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query):
query = f"{search_type}:{query}"
@ -374,7 +374,7 @@ class Node:
if filters:
for filter in filters:
filter.set_preload()
if URLRegex.AM_URL.match(query):
if not self._apple_music_client:
raise AppleMusicNotEnabled(
@ -382,7 +382,7 @@ class Node:
"Please set apple_music to True in your Node class."
)
apple_music_results = await self._apple_music_client.search(query=query)
apple_music_results = await self._apple_music_client.search(query=query)
if isinstance(apple_music_results, applemusic.Song):
return [
Track(
@ -501,7 +501,7 @@ class Node:
)
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
data: dict = await self.send(method="GET", path="loadtracks", query=f"identifier={quote(query)}")
track: dict = data["tracks"][0]
@ -533,9 +533,9 @@ class Node:
# If query is a video thats part of a playlist, get the video and queue that instead
# (I can't tell you how much i've wanted to implement this in here)
if (match := URLRegex.YOUTUBE_VID_IN_PLAYLIST.match(query)):
if (match := URLRegex.YOUTUBE_VID_IN_PLAYLIST.match(query)):
query = match.group("video")
data: dict = await self.send(method="GET", path="loadtracks", query=f"identifier={quote(query)}")
load_type = data.get("loadType")
@ -577,14 +577,14 @@ class Node:
]
async def get_recommendations(
self,
*,
track: Track,
self,
*,
track: Track,
ctx: Optional[commands.Context] = None
) -> Union[List[Track], None]:
"""
Gets recommendations from either YouTube or Spotify.
The track that is passed in must be either from
The track that is passed in must be either from
YouTube or Spotify or else this will not work.
You can pass in a discord.py Context object to get a
Context object on all tracks that get recommended.
@ -613,12 +613,12 @@ class Node:
]
return tracks
elif track.track_type == TrackType.YOUTUBE:
elif track.track_type == TrackType.YOUTUBE:
tracks = await self.get_tracks(query=f"ytsearch:https://www.youtube.com/watch?v={track.identifier}&list=RD{track.identifier}", ctx=ctx)
return tracks
else:
raise TrackLoadError("The specfied track must be either a YouTube or Spotify track to recieve recommendations.")
class NodePool:
"""The base class for the node pool.
@ -666,7 +666,7 @@ class NodePool:
elif algorithm == NodeAlgorithm.by_players:
tested_nodes = {node: len(node.players.keys()) for node in available_nodes}
return min(tested_nodes, key=tested_nodes.get)
@classmethod
def get_node(cls, *, identifier: str = None) -> Node:
@ -714,11 +714,16 @@ class NodePool:
node = Node(
pool=cls, bot=bot, host=host, port=port, password=password,
identifier=identifier, secure=secure, heartbeat=heartbeat,
loop=loop, spotify_client_id=spotify_client_id,
loop=loop, spotify_client_id=spotify_client_id,
session=session, spotify_client_secret=spotify_client_secret,
apple_music=apple_music, fallback=fallback
)
await node.connect()
cls._nodes[node._identifier] = node
return node
return node
async def disconnect(self) -> None:
"""Disconnects all nodes from the node pool."""
for node in self._nodes.copy().values():
await node.disconnect()

View File

@ -107,7 +107,7 @@ class Queue(Iterable[Track]):
raise TypeError(f"Adding '{type(other)}' type to the queue is not supported.")
def _get(self) -> Track:
def _get(self) -> Track:
return self._queue.pop(0)
def _drop(self) -> Track:
@ -298,7 +298,7 @@ class Queue(Iterable[Track]):
def set_loop_mode(self, mode: LoopMode) -> None:
"""
Sets the loop mode of the queue.
Sets the loop mode of the queue.
Takes the LoopMode enum as an argument.
"""
self._loop_mode = mode
@ -306,11 +306,11 @@ class Queue(Iterable[Track]):
try:
index = self._index(self._current_item)
except ValueError:
index = 0
index = 0
if self._current_item not in self._queue:
self._queue.insert(index, self._current_item)
self._current_item = self._queue[index]
def disable_loop(self) -> None:
"""
@ -320,12 +320,12 @@ class Queue(Iterable[Track]):
if not self._loop_mode:
raise QueueException("Queue loop is already disabled.")
if self._loop_mode == LoopMode.QUEUE:
index = self.find_position(self._current_item) + 1
if self._loop_mode == LoopMode.QUEUE:
index = self.find_position(self._current_item) + 1
self._queue = self._queue[index:]
self._loop_mode = None
def shuffle(self) -> None:
"""Shuffles the queue."""

View File

@ -8,9 +8,7 @@ import orjson as json
from base64 import b64encode
from typing import TYPE_CHECKING
from .exceptions import InvalidSpotifyURL, SpotifyRequestException
from .objects import *
from .objects import *
GRANT_URL = "https://accounts.spotify.com/api/token"
@ -22,8 +20,8 @@ SPOTIFY_URL_REGEX = re.compile(
class Client:
"""The base client for the Spotify module of Pomice.
This class will do all the heavy lifting of getting all the metadata
for any Spotify URL you throw at it.
This class will do all the heavy lifting of getting all the metadata
for any Spotify URL you throw at it.
"""
def __init__(self, client_id: str, client_secret: str) -> None:
@ -34,7 +32,9 @@ class Client:
self._bearer_token: str = None
self._expiry = 0
self._auth_token = b64encode(f"{self._client_id}:{self._client_secret}".encode())
self._auth_token = b64encode(
f"{self._client_id}:{self._client_secret}".encode()
)
self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"}
self._bearer_headers = None
@ -44,7 +44,9 @@ class Client:
if not self.session:
self.session = aiohttp.ClientSession()
async with self.session.post(GRANT_URL, data=_data, headers=self._grant_headers) as resp:
async with self.session.post(
GRANT_URL, data=_data, headers=self._grant_headers
) as resp:
if resp.status != 200:
raise SpotifyRequestException(
f"Error fetching bearer token: {resp.status} {resp.reason}"
@ -82,28 +84,35 @@ class Client:
elif spotify_type == "album":
return Album(data)
elif spotify_type == "artist":
async with self.session.get(f"{request_url}/top-tracks?market=US", headers=self._bearer_headers) as resp:
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}"
)
async with self.session.get(
f"{request_url}/top-tracks?market=US", headers=self._bearer_headers
) as resp:
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}"
)
track_data: dict = await resp.json(loads=json.loads)
tracks = track_data['tracks']
return Artist(data, tracks)
track_data: dict = await resp.json(loads=json.loads)
tracks = track_data["tracks"]
return Artist(data, tracks)
else:
tracks = [
Track(track["track"])
for track in data["tracks"]["items"] if track["track"] is not None
for track in data["tracks"]["items"]
if track["track"] is not None
]
if not len(tracks):
raise SpotifyRequestException("This playlist is empty and therefore cannot be queued.")
raise SpotifyRequestException(
"This playlist is empty and therefore cannot be queued."
)
next_page_url = data["tracks"]["next"]
while next_page_url is not None:
async with self.session.get(next_page_url, headers=self._bearer_headers) as resp:
async with self.session.get(
next_page_url, headers=self._bearer_headers
) as resp:
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}"
@ -113,7 +122,8 @@ class Client:
tracks += [
Track(track["track"])
for track in next_data["items"] if track["track"] is not None
for track in next_data["items"]
if track["track"] is not None
]
next_page_url = next_data["next"]
@ -133,7 +143,9 @@ class Client:
if not spotify_type == "track":
raise InvalidSpotifyURL("The provided query is not a Spotify track.")
request_url = REQUEST_URL.format(type="recommendation", id=f"?seed_tracks={spotify_id}")
request_url = REQUEST_URL.format(
type="recommendation", id=f"?seed_tracks={spotify_id}"
)
async with self.session.get(request_url, headers=self._bearer_headers) as resp:
if resp.status != 200:
@ -145,4 +157,9 @@ class Client:
tracks = [Track(track) for track in data["tracks"]]
return tracks
return tracks
async def close(self) -> None:
if self.session:
await self.session.close()
self.session = None

View File

@ -43,7 +43,7 @@ class Playlist:
if data.get("images") and len(data["images"]):
self.image: str = data["images"][0]["url"]
else:
self.image = self.tracks[0].image
self.image = self.tracks[0].image
self.uri = data["external_urls"]["spotify"]
def __repr__(self) -> str:

View File

@ -93,7 +93,7 @@ class NodeStats:
class FailingIPBlock:
"""
The base class for the failing IP block object from the route planner stats.
Gives critical information about any failing addresses on the block
Gives critical information about any failing addresses on the block
and the time they failed.
"""
def __init__(self, data: dict) -> None:
@ -102,7 +102,7 @@ class FailingIPBlock:
def __repr__(self) -> str:
return f"<Pomice.FailingIPBlock address={self.address} failing_time={self.failing_time}>"
class RouteStats:
"""
@ -182,7 +182,7 @@ class Ping:
def get_ping(self):
s = self._create_socket(socket.AF_INET, socket.SOCK_STREAM)
cost_time = self.timer.cost(
(s.connect, s.shutdown),
((self._host, self._port), None))