switch formatting to black

This commit is contained in:
cloudwithax 2023-03-11 00:22:38 -05:00
parent a4a49c249e
commit 481e616414
20 changed files with 537 additions and 568 deletions

View File

@ -10,6 +10,7 @@ Licensed under GPL-3.0
import discord import discord
if not discord.version_info.major >= 2: if not discord.version_info.major >= 2:
class DiscordPyOutdated(Exception): class DiscordPyOutdated(Exception):
pass pass
@ -34,4 +35,3 @@ from .queue import *
from .player import * from .player import *
from .pool import * from .pool import *
from .routeplanner import * from .routeplanner import *

View File

@ -35,9 +35,7 @@ class Client:
if not self.session: if not self.session:
self.session = aiohttp.ClientSession() self.session = aiohttp.ClientSession()
async with self.session.get( async with self.session.get("https://music.apple.com/assets/index.919fe17f.js") as resp:
"https://music.apple.com/assets/index.919fe17f.js"
) as resp:
if resp.status != 200: if resp.status != 200:
raise AppleMusicRequestException( raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}" f"Error while fetching results: {resp.status} {resp.reason}"
@ -50,9 +48,7 @@ class Client:
"Origin": "https://apple.com", "Origin": "https://apple.com",
} }
token_split = self.token.split(".")[1] token_split = self.token.split(".")[1]
token_json = base64.b64decode( token_json = base64.b64decode(token_split + "=" * (-len(token_split) % 4)).decode()
token_split + "=" * (-len(token_split) % 4)
).decode()
token_data = json.loads(token_json) token_data = json.loads(token_json)
self.expiry = datetime.fromtimestamp(token_data["exp"]) self.expiry = datetime.fromtimestamp(token_data["exp"])
@ -105,7 +101,6 @@ class Client:
return Artist(data, tracks=tracks) return Artist(data, tracks=tracks)
else: else:
track_data: dict = data["relationships"]["tracks"] track_data: dict = data["relationships"]["tracks"]
tracks = [Song(track) for track in track_data.get("data")] tracks = [Song(track) for track in track_data.get("data")]
@ -119,9 +114,7 @@ class Client:
next_page_url = AM_BASE_URL + track_data.get("next") next_page_url = AM_BASE_URL + track_data.get("next")
while next_page_url is not None: while next_page_url is not None:
async with self.session.get( async with self.session.get(next_page_url, headers=self.headers) as resp:
next_page_url, headers=self.headers
) as resp:
if resp.status != 200: if resp.status != 200:
raise AppleMusicRequestException( raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}" f"Error while fetching results: {resp.status} {resp.reason}"

View File

@ -1,8 +1,10 @@
class AppleMusicRequestException(Exception): class AppleMusicRequestException(Exception):
"""An error occurred when making a request to the Apple Music API""" """An error occurred when making a request to the Apple Music API"""
pass pass
class InvalidAppleMusicURL(Exception): class InvalidAppleMusicURL(Exception):
"""An invalid Apple Music URL was passed""" """An invalid Apple Music URL was passed"""
pass pass

View File

@ -5,8 +5,8 @@ from typing import List
class Song: class Song:
"""The base class for an Apple Music song""" """The base class for an Apple Music song"""
def __init__(self, data: dict) -> None:
def __init__(self, data: dict) -> None:
self.name: str = data["attributes"]["name"] self.name: str = data["attributes"]["name"]
self.url: str = data["attributes"]["url"] self.url: str = data["attributes"]["url"]
self.isrc: str = data["attributes"]["isrc"] self.isrc: str = data["attributes"]["isrc"]
@ -15,7 +15,7 @@ class Song:
self.artists: str = data["attributes"]["artistName"] self.artists: str = data["attributes"]["artistName"]
self.image: str = data["attributes"]["artwork"]["url"].replace( self.image: str = data["attributes"]["artwork"]["url"].replace(
"{w}x{h}", "{w}x{h}",
f'{data["attributes"]["artwork"]["width"]}x{data["attributes"]["artwork"]["height"]}' f'{data["attributes"]["artwork"]["width"]}x{data["attributes"]["artwork"]["height"]}',
) )
def __repr__(self) -> str: def __repr__(self) -> str:
@ -27,6 +27,7 @@ class Song:
class Playlist: class Playlist:
"""The base class for an Apple Music playlist""" """The base class for an Apple Music playlist"""
def __init__(self, data: dict, tracks: List[Song]) -> None: def __init__(self, data: dict, tracks: List[Song]) -> None:
self.name: str = data["attributes"]["name"] self.name: str = data["attributes"]["name"]
self.owner: str = data["attributes"]["curatorName"] self.owner: str = data["attributes"]["curatorName"]
@ -47,6 +48,7 @@ class Playlist:
class Album: class Album:
"""The base class for an Apple Music album""" """The base class for an Apple Music album"""
def __init__(self, data: dict) -> None: def __init__(self, data: dict) -> None:
self.name: str = data["attributes"]["name"] self.name: str = data["attributes"]["name"]
self.url: str = data["attributes"]["url"] self.url: str = data["attributes"]["url"]
@ -56,7 +58,7 @@ class Album:
self.tracks: List[Song] = [Song(track) for track in data["relationships"]["tracks"]["data"]] self.tracks: List[Song] = [Song(track) for track in data["relationships"]["tracks"]["data"]]
self.image: str = data["attributes"]["artwork"]["url"].replace( self.image: str = data["attributes"]["artwork"]["url"].replace(
"{w}x{h}", "{w}x{h}",
f'{data["attributes"]["artwork"]["width"]}x{data["attributes"]["artwork"]["height"]}' f'{data["attributes"]["artwork"]["width"]}x{data["attributes"]["artwork"]["height"]}',
) )
def __repr__(self) -> str: def __repr__(self) -> str:
@ -66,9 +68,9 @@ class Album:
) )
class Artist: class Artist:
"""The base class for an Apple Music artist""" """The base class for an Apple Music artist"""
def __init__(self, data: dict, tracks: dict) -> None: def __init__(self, data: dict, tracks: dict) -> None:
self.name: str = f'Top tracks for {data["attributes"]["name"]}' self.name: str = f'Top tracks for {data["attributes"]["name"]}'
self.url: str = data["attributes"]["url"] self.url: str = data["attributes"]["url"]
@ -77,11 +79,8 @@ class Artist:
self.tracks: List[Song] = [Song(track) for track in tracks] self.tracks: List[Song] = [Song(track) for track in tracks]
self.image: str = data["attributes"]["artwork"]["url"].replace( self.image: str = data["attributes"]["artwork"]["url"].replace(
"{w}x{h}", "{w}x{h}",
f'{data["attributes"]["artwork"]["width"]}x{data["attributes"]["artwork"]["height"]}' f'{data["attributes"]["artwork"]["width"]}x{data["attributes"]["artwork"]["height"]}',
) )
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return f"<Pomice.applemusic.Artist name={self.name} id={self.id} " f"tracks={self.tracks}>"
f"<Pomice.applemusic.Artist name={self.name} id={self.id} "
f"tracks={self.tracks}>"
)

View File

@ -18,6 +18,7 @@ class SearchType(Enum):
SearchType.scsearch searches using SoundCloud, SearchType.scsearch searches using SoundCloud,
which is an alternative to YouTube or YouTube Music. which is an alternative to YouTube or YouTube Music.
""" """
ytsearch = "ytsearch" ytsearch = "ytsearch"
ytmsearch = "ytmsearch" ytmsearch = "ytmsearch"
scsearch = "scsearch" scsearch = "scsearch"
@ -51,6 +52,7 @@ class TrackType(Enum):
def __str__(self) -> str: def __str__(self) -> str:
return self.value return self.value
class PlaylistType(Enum): class PlaylistType(Enum):
""" """
The enum for the different playlist types for Pomice. The enum for the different playlist types for Pomice.
@ -74,7 +76,6 @@ class PlaylistType(Enum):
return self.value return self.value
class NodeAlgorithm(Enum): class NodeAlgorithm(Enum):
""" """
The enum for the different node algorithms in Pomice. The enum for the different node algorithms in Pomice.
@ -98,6 +99,7 @@ class NodeAlgorithm(Enum):
def __str__(self) -> str: def __str__(self) -> str:
return self.value return self.value
class LoopMode(Enum): class LoopMode(Enum):
""" """
The enum for the different loop modes. The enum for the different loop modes.
@ -109,11 +111,11 @@ class LoopMode(Enum):
LoopMode.QUEUE sets the queue loop to the whole queue. LoopMode.QUEUE sets the queue loop to the whole queue.
""" """
# We don't have to define anything special for these, since these just serve as flags # We don't have to define anything special for these, since these just serve as flags
TRACK = "track" TRACK = "track"
QUEUE = "queue" QUEUE = "queue"
def __str__(self) -> str: def __str__(self) -> str:
return self.value return self.value
@ -160,7 +162,7 @@ class RouteIPType(Enum):
IPV6 = "Inet6Address" IPV6 = "Inet6Address"
class URLRegex(): class URLRegex:
""" """
The enums for all the URL Regexes in use by Pomice. The enums for all the URL Regexes in use by Pomice.
@ -181,6 +183,7 @@ class URLRegex():
URLRegex.BASE_URL returns the standard URL Regex. URLRegex.BASE_URL returns the standard URL Regex.
""" """
SPOTIFY_URL = re.compile( SPOTIFY_URL = re.compile(
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)" r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)"
) )
@ -199,13 +202,9 @@ class URLRegex():
r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube\.com|youtu.be))/playlist\?list=.*" r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube\.com|youtu.be))/playlist\?list=.*"
) )
YOUTUBE_VID_IN_PLAYLIST = re.compile( YOUTUBE_VID_IN_PLAYLIST = re.compile(r"(?P<video>^.*?v.*?)(?P<list>&list.*)")
r"(?P<video>^.*?v.*?)(?P<list>&list.*)"
)
YOUTUBE_TIMESTAMP = re.compile( YOUTUBE_TIMESTAMP = re.compile(r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*")
r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*"
)
AM_URL = re.compile( AM_URL = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/" r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/"
@ -217,23 +216,14 @@ class URLRegex():
r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)" r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)"
) )
SOUNDCLOUD_URL = re.compile( SOUNDCLOUD_URL = re.compile(r"((?:https?:)?\/\/)?((?:www|m)\.)?soundcloud.com\/.*/.*")
r"((?:https?:)?\/\/)?((?:www|m)\.)?soundcloud.com\/.*/.*"
)
SOUNDCLOUD_PLAYLIST_URL = re.compile( SOUNDCLOUD_PLAYLIST_URL = re.compile(r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com\/.*/sets/.*")
r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com\/.*/sets/.*"
)
SOUNDCLOUD_TRACK_IN_SET_URL = re.compile( SOUNDCLOUD_TRACK_IN_SET_URL = re.compile(
r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com/[a-zA-Z0-9-._]+/[a-zA-Z0-9-._]+(\?in)" r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com/[a-zA-Z0-9-._]+/[a-zA-Z0-9-._]+(\?in)"
) )
LAVALINK_SEARCH = re.compile( LAVALINK_SEARCH = re.compile(r"(?P<type>ytm?|sc)search:")
r"(?P<type>ytm?|sc)search:"
)
BASE_URL = re.compile(
r"https?://(?:www\.)?.+"
)
BASE_URL = re.compile(r"https?://(?:www\.)?.+")

View File

@ -7,19 +7,21 @@ from .objects import Track
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Union
if TYPE_CHECKING: if TYPE_CHECKING:
from .player import Player from .player import Player
class PomiceEvent: class PomiceEvent:
"""The base class for all events dispatched by a node. """The base class for all events dispatched by a node.
Every event must be formatted within your bot's code as a listener. Every event must be formatted within your bot's code as a listener.
i.e: If you want to listen for when a track starts, the event would be: i.e: If you want to listen for when a track starts, the event would be:
```py ```py
@bot.listen @bot.listen
async def on_pomice_track_start(self, event): async def on_pomice_track_start(self, event):
``` ```
""" """
name = "event" name = "event"
handler_args = () handler_args = ()
@ -29,16 +31,13 @@ class PomiceEvent:
class TrackStartEvent(PomiceEvent): class TrackStartEvent(PomiceEvent):
"""Fired when a track has successfully started. """Fired when a track has successfully started.
Returns the player associated with the event and the pomice.Track object. Returns the player associated with the event and the pomice.Track object.
""" """
name = "track_start" name = "track_start"
def __init__(self, data: dict, player: Player): def __init__(self, data: dict, player: Player):
__slots__ = ("player", "track")
__slots__ = (
"player",
"track"
)
self.player: Player = player self.player: Player = player
self.track: Track = self.player._current self.track: Track = self.player._current
@ -52,17 +51,13 @@ class TrackStartEvent(PomiceEvent):
class TrackEndEvent(PomiceEvent): class TrackEndEvent(PomiceEvent):
"""Fired when a track has successfully ended. """Fired when a track has successfully ended.
Returns the player associated with the event along with the pomice.Track object and reason. Returns the player associated with the event along with the pomice.Track object and reason.
""" """
name = "track_end" name = "track_end"
def __init__(self, data: dict, player: Player): def __init__(self, data: dict, player: Player):
__slots__ = ("player", "track", "reason")
__slots__ = (
"player",
"track",
"reason"
)
self.player: Player = player self.player: Player = player
self.track: Track = self.player._ending_track self.track: Track = self.player._ending_track
@ -80,18 +75,14 @@ class TrackEndEvent(PomiceEvent):
class TrackStuckEvent(PomiceEvent): class TrackStuckEvent(PomiceEvent):
"""Fired when a track is stuck and cannot be played. Returns the player """Fired when a track is stuck and cannot be played. Returns the player
associated with the event along with the pomice.Track object associated with the event along with the pomice.Track object
to be further parsed by the end user. to be further parsed by the end user.
""" """
name = "track_stuck" name = "track_stuck"
def __init__(self, data: dict, player: Player): def __init__(self, data: dict, player: Player):
__slots__ = ("player", "track", "threshold")
__slots__ = (
"player",
"track",
"threshold"
)
self.player: Player = player self.player: Player = player
self.track: Track = self.player._ending_track self.track: Track = self.player._ending_track
@ -101,27 +92,25 @@ class TrackStuckEvent(PomiceEvent):
self.handler_args = self.player, self.track, self.threshold self.handler_args = self.player, self.track, self.threshold
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Pomice.TrackStuckEvent player={self.player!r} track={self.track!r} " \ return (
f"threshold={self.threshold!r}>" f"<Pomice.TrackStuckEvent player={self.player!r} track={self.track!r} "
f"threshold={self.threshold!r}>"
)
class TrackExceptionEvent(PomiceEvent): class TrackExceptionEvent(PomiceEvent):
"""Fired when a track error has occured. """Fired when a track error has occured.
Returns the player associated with the event along with the error code and exception. Returns the player associated with the event along with the error code and exception.
""" """
name = "track_exception" name = "track_exception"
def __init__(self, data: dict, player: Player): def __init__(self, data: dict, player: Player):
__slots__ = ("player", "track", "exception")
__slots__ = (
"player",
"track",
"exception"
)
self.player: Player = player self.player: Player = player
self.track: Track = self.player._ending_track self.track: Track = self.player._ending_track
if data.get('error'): if data.get("error"):
# User is running Lavalink <= 3.3 # User is running Lavalink <= 3.3
self.exception: str = data["error"] self.exception: str = data["error"]
else: else:
@ -137,13 +126,7 @@ class TrackExceptionEvent(PomiceEvent):
class WebSocketClosedPayload: class WebSocketClosedPayload:
def __init__(self, data: dict): def __init__(self, data: dict):
__slots__ = ("guild", "code", "reason", "by_remote")
__slots__ = (
"guild",
"code",
"reason",
"by_remote"
)
self.guild: Guild = NodePool.get_node().bot.get_guild(int(data["guildId"])) self.guild: Guild = NodePool.get_node().bot.get_guild(int(data["guildId"]))
self.code: int = data["code"] self.code: int = data["code"]
@ -151,21 +134,24 @@ class WebSocketClosedPayload:
self.by_remote: bool = data["byRemote"] self.by_remote: bool = data["byRemote"]
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Pomice.WebSocketClosedPayload guild={self.guild!r} code={self.code!r} " \ return (
f"reason={self.reason!r} by_remote={self.by_remote!r}>" f"<Pomice.WebSocketClosedPayload guild={self.guild!r} code={self.code!r} "
f"reason={self.reason!r} by_remote={self.by_remote!r}>"
)
class WebSocketClosedEvent(PomiceEvent): class WebSocketClosedEvent(PomiceEvent):
"""Fired when a websocket connection to a node has been closed. """Fired when a websocket connection to a node has been closed.
Returns the reason and the error code. Returns the reason and the error code.
""" """
name = "websocket_closed" name = "websocket_closed"
def __init__(self, data: dict, _): def __init__(self, data: dict, _):
self.payload: WebSocketClosedPayload = WebSocketClosedPayload(data) self.payload: WebSocketClosedPayload = WebSocketClosedPayload(data)
# on_pomice_websocket_closed(payload) # on_pomice_websocket_closed(payload)
self.handler_args = self.payload, self.handler_args = (self.payload,)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Pomice.WebsocketClosedEvent payload={self.payload!r}>" return f"<Pomice.WebsocketClosedEvent payload={self.payload!r}>"
@ -173,16 +159,13 @@ class WebSocketClosedEvent(PomiceEvent):
class WebSocketOpenEvent(PomiceEvent): class WebSocketOpenEvent(PomiceEvent):
"""Fired when a websocket connection to a node has been initiated. """Fired when a websocket connection to a node has been initiated.
Returns the target and the session SSRC. Returns the target and the session SSRC.
""" """
name = "websocket_open" name = "websocket_open"
def __init__(self, data: dict, _): def __init__(self, data: dict, _):
__slots__ = ("target", "ssrc")
__slots__ = (
"target",
"ssrc"
)
self.target: str = data["target"] self.target: str = data["target"]
self.ssrc: int = data["ssrc"] self.ssrc: int = data["ssrc"]
@ -192,4 +175,3 @@ class WebSocketOpenEvent(PomiceEvent):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Pomice.WebsocketOpenEvent target={self.target!r} ssrc={self.ssrc!r}>" return f"<Pomice.WebsocketOpenEvent target={self.target!r} ssrc={self.ssrc!r}>"

View File

@ -16,68 +16,89 @@ class NodeConnectionFailure(NodeException):
class NodeConnectionClosed(NodeException): class NodeConnectionClosed(NodeException):
"""The node's connection is closed.""" """The node's connection is closed."""
pass pass
class NodeRestException(NodeException): class NodeRestException(NodeException):
"""A request made using the node's REST uri failed""" """A request made using the node's REST uri failed"""
pass pass
class NodeNotAvailable(PomiceException): class NodeNotAvailable(PomiceException):
"""The node is currently unavailable.""" """The node is currently unavailable."""
pass pass
class NoNodesAvailable(PomiceException): class NoNodesAvailable(PomiceException):
"""There are no nodes currently available.""" """There are no nodes currently available."""
pass pass
class TrackInvalidPosition(PomiceException): class TrackInvalidPosition(PomiceException):
"""An invalid position was chosen for a track.""" """An invalid position was chosen for a track."""
pass pass
class TrackLoadError(PomiceException): class TrackLoadError(PomiceException):
"""There was an error while loading a track.""" """There was an error while loading a track."""
pass pass
class FilterInvalidArgument(PomiceException): class FilterInvalidArgument(PomiceException):
"""An invalid argument was passed to a filter.""" """An invalid argument was passed to a filter."""
pass pass
class FilterTagInvalid(PomiceException): class FilterTagInvalid(PomiceException):
"""An invalid tag was passed or Pomice was unable to find a filter tag""" """An invalid tag was passed or Pomice was unable to find a filter tag"""
pass pass
class FilterTagAlreadyInUse(PomiceException): class FilterTagAlreadyInUse(PomiceException):
"""A filter with a tag is already in use by another filter""" """A filter with a tag is already in use by another filter"""
pass pass
class InvalidSpotifyClientAuthorization(PomiceException): class InvalidSpotifyClientAuthorization(PomiceException):
"""No Spotify client authorization was provided for track searching.""" """No Spotify client authorization was provided for track searching."""
pass pass
class AppleMusicNotEnabled(PomiceException): class AppleMusicNotEnabled(PomiceException):
"""An Apple Music Link was passed in when Apple Music functionality was not enabled.""" """An Apple Music Link was passed in when Apple Music functionality was not enabled."""
pass pass
class QueueException(Exception): class QueueException(Exception):
"""Base Pomice queue exception.""" """Base Pomice queue exception."""
pass pass
class QueueFull(QueueException): class QueueFull(QueueException):
"""Exception raised when attempting to add to a full Queue.""" """Exception raised when attempting to add to a full Queue."""
pass pass
class QueueEmpty(QueueException): class QueueEmpty(QueueException):
"""Exception raised when attempting to retrieve from an empty Queue.""" """Exception raised when attempting to retrieve from an empty Queue."""
pass pass
class LavalinkVersionIncompatible(PomiceException): class LavalinkVersionIncompatible(PomiceException):
"""Lavalink version is incompatible. Must be using Lavalink > 3.7.0 to avoid this error.""" """Lavalink version is incompatible. Must be using Lavalink > 3.7.0 to avoid this error."""
pass pass

View File

@ -1,6 +1,7 @@
import collections import collections
from .exceptions import FilterInvalidArgument from .exceptions import FilterInvalidArgument
class Filter: class Filter:
""" """
The base class for all filters. The base class for all filters.
@ -11,12 +12,9 @@ class Filter:
You must specify a tag for each filter you put on. You must specify a tag for each filter you put on.
This is necessary for the removal of filters. This is necessary for the removal of filters.
""" """
def __init__(self, *, tag: str): def __init__(self, *, tag: str):
__slots__ = ( __slots__ = ("payload", "tag", "preload")
"payload",
"tag",
"preload"
)
self.payload: dict = None self.payload: dict = None
self.tag: str = tag self.tag: str = tag
@ -63,41 +61,77 @@ class Equalizer(Filter):
@classmethod @classmethod
def flat(cls): def flat(cls):
"""Equalizer preset which represents a flat EQ board, """Equalizer preset which represents a flat EQ board,
with all levels set to their default values. with all levels set to their default values.
""" """
levels = [ levels = [
(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0), (0, 0.0),
(5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0), (1, 0.0),
(10, 0.0), (11, 0.0), (12, 0.0), (13, 0.0), (14, 0.0) (2, 0.0),
(3, 0.0),
(4, 0.0),
(5, 0.0),
(6, 0.0),
(7, 0.0),
(8, 0.0),
(9, 0.0),
(10, 0.0),
(11, 0.0),
(12, 0.0),
(13, 0.0),
(14, 0.0),
] ]
return cls(tag="flat", levels=levels) return cls(tag="flat", levels=levels)
@classmethod @classmethod
def boost(cls): def boost(cls):
"""Equalizer preset which boosts the sound of a track, """Equalizer preset which boosts the sound of a track,
making it sound fun and energetic by increasing the bass making it sound fun and energetic by increasing the bass
and the highs. and the highs.
""" """
levels = [ levels = [
(0, -0.075), (1, 0.125), (2, 0.125), (3, 0.1), (4, 0.1), (0, -0.075),
(5, .05), (6, 0.075), (7, 0.0), (8, 0.0), (9, 0.0), (1, 0.125),
(10, 0.0), (11, 0.0), (12, 0.125), (13, 0.15), (14, 0.05) (2, 0.125),
(3, 0.1),
(4, 0.1),
(5, 0.05),
(6, 0.075),
(7, 0.0),
(8, 0.0),
(9, 0.0),
(10, 0.0),
(11, 0.0),
(12, 0.125),
(13, 0.15),
(14, 0.05),
] ]
return cls(tag="boost", levels=levels) return cls(tag="boost", levels=levels)
@classmethod @classmethod
def metal(cls): def metal(cls):
"""Equalizer preset which increases the mids of a track, """Equalizer preset which increases the mids of a track,
preferably one of the metal genre, to make it sound preferably one of the metal genre, to make it sound
more full and concert-like. more full and concert-like.
""" """
levels = [ levels = [
(0, 0.0), (1, 0.1), (2, 0.1), (3, 0.15), (4, 0.13), (0, 0.0),
(5, 0.1), (6, 0.0), (7, 0.125), (8, 0.175), (9, 0.175), (1, 0.1),
(10, 0.125), (11, 0.125), (12, 0.1), (13, 0.075), (14, 0.0) (2, 0.1),
(3, 0.15),
(4, 0.13),
(5, 0.1),
(6, 0.0),
(7, 0.125),
(8, 0.175),
(9, 0.175),
(10, 0.125),
(11, 0.125),
(12, 0.1),
(13, 0.075),
(14, 0.0),
] ]
return cls(tag="metal", levels=levels) return cls(tag="metal", levels=levels)
@ -105,40 +139,40 @@ class Equalizer(Filter):
@classmethod @classmethod
def piano(cls): def piano(cls):
"""Equalizer preset which increases the mids and highs """Equalizer preset which increases the mids and highs
of a track, preferably a piano based one, to make it of a track, preferably a piano based one, to make it
stand out. stand out.
""" """
levels = [ levels = [
(0, -0.25), (1, -0.25), (2, -0.125), (3, 0.0), (0, -0.25),
(4, 0.25), (5, 0.25), (6, 0.0), (7, -0.25), (8, -0.25), (1, -0.25),
(9, 0.0), (10, 0.0), (11, 0.5), (12, 0.25), (13, -0.025) (2, -0.125),
(3, 0.0),
(4, 0.25),
(5, 0.25),
(6, 0.0),
(7, -0.25),
(8, -0.25),
(9, 0.0),
(10, 0.0),
(11, 0.5),
(12, 0.25),
(13, -0.025),
] ]
return cls(tag="piano", levels=levels) return cls(tag="piano", levels=levels)
class Timescale(Filter): class Timescale(Filter):
"""Filter which changes the speed and pitch of a track. """Filter which changes the speed and pitch of a track.
You can make some very nice effects with this filter, You can make some very nice effects with this filter,
i.e: a vaporwave-esque filter which slows the track down i.e: a vaporwave-esque filter which slows the track down
a certain amount to produce said effect. a certain amount to produce said effect.
""" """
def __init__( def __init__(self, *, tag: str, speed: float = 1.0, pitch: float = 1.0, rate: float = 1.0):
self,
*,
tag: str,
speed: float = 1.0,
pitch: float = 1.0,
rate: float = 1.0
):
super().__init__(tag=tag) super().__init__(tag=tag)
__slots__ = ( __slots__ = ("speed", "pitch", "rate")
"speed",
"pitch",
"rate"
)
if speed < 0: if speed < 0:
raise FilterInvalidArgument("Timescale speed must be more than 0.") raise FilterInvalidArgument("Timescale speed must be more than 0.")
@ -151,9 +185,9 @@ class Timescale(Filter):
self.pitch: float = pitch self.pitch: float = pitch
self.rate: float = rate self.rate: float = rate
self.payload: dict = {"timescale": {"speed": self.speed, self.payload: dict = {
"pitch": self.pitch, "timescale": {"speed": self.speed, "pitch": self.pitch, "rate": self.rate}
"rate": self.rate}} }
@classmethod @classmethod
def vaporwave(cls): def vaporwave(cls):
@ -181,7 +215,7 @@ class Timescale(Filter):
class Karaoke(Filter): class Karaoke(Filter):
"""Filter which filters the vocal track from any song and leaves the instrumental. """Filter which filters the vocal track from any song and leaves the instrumental.
Best for karaoke as the filter implies. Best for karaoke as the filter implies.
""" """
def __init__( def __init__(
@ -191,26 +225,25 @@ class Karaoke(Filter):
level: float = 1.0, level: float = 1.0,
mono_level: float = 1.0, mono_level: float = 1.0,
filter_band: float = 220.0, filter_band: float = 220.0,
filter_width: float = 100.0 filter_width: float = 100.0,
): ):
super().__init__(tag=tag) super().__init__(tag=tag)
__slots__ = ( __slots__ = ("level", "mono_level", "filter_band", "filter_width")
"level",
"mono_level",
"filter_band",
"filter_width"
)
self.level: float = level self.level: float = level
self.mono_level: float = mono_level self.mono_level: float = mono_level
self.filter_band: float = filter_band self.filter_band: float = filter_band
self.filter_width: float = filter_width self.filter_width: float = filter_width
self.payload: dict = {"karaoke": {"level": self.level, self.payload: dict = {
"monoLevel": self.mono_level, "karaoke": {
"filterBand": self.filter_band, "level": self.level,
"filterWidth": self.filter_width}} "monoLevel": self.mono_level,
"filterBand": self.filter_band,
"filterWidth": self.filter_width,
}
}
def __repr__(self): def __repr__(self):
return ( return (
@ -221,74 +254,54 @@ class Karaoke(Filter):
class Tremolo(Filter): class Tremolo(Filter):
"""Filter which produces a wavering tone in the music, """Filter which produces a wavering tone in the music,
causing it to sound like the music is changing in volume rapidly. causing it to sound like the music is changing in volume rapidly.
""" """
def __init__( def __init__(self, *, tag: str, frequency: float = 2.0, depth: float = 0.5):
self,
*,
tag: str,
frequency: float = 2.0,
depth: float = 0.5
):
super().__init__(tag=tag) super().__init__(tag=tag)
__slots__ = ( __slots__ = ("frequency", "depth")
"frequency",
"depth"
)
if frequency < 0: if frequency < 0:
raise FilterInvalidArgument( raise FilterInvalidArgument("Tremolo frequency must be more than 0.")
"Tremolo frequency must be more than 0.")
if depth < 0 or depth > 1: if depth < 0 or depth > 1:
raise FilterInvalidArgument( raise FilterInvalidArgument("Tremolo depth must be between 0 and 1.")
"Tremolo depth must be between 0 and 1.")
self.frequency: float = frequency self.frequency: float = frequency
self.depth: float = depth self.depth: float = depth
self.payload: dict = {"tremolo": {"frequency": self.frequency, self.payload: dict = {"tremolo": {"frequency": self.frequency, "depth": self.depth}}
"depth": self.depth}}
def __repr__(self): def __repr__(self):
return f"<Pomice.TremoloFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>" return (
f"<Pomice.TremoloFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
)
class Vibrato(Filter): class Vibrato(Filter):
"""Filter which produces a wavering tone in the music, similar to the Tremolo filter, """Filter which produces a wavering tone in the music, similar to the Tremolo filter,
but changes in pitch rather than volume. but changes in pitch rather than volume.
""" """
def __init__( def __init__(self, *, tag: str, frequency: float = 2.0, depth: float = 0.5):
self,
*,
tag: str,
frequency: float = 2.0,
depth: float = 0.5
):
super().__init__(tag=tag) super().__init__(tag=tag)
__slots__ = ( __slots__ = ("frequency", "depth")
"frequency",
"depth"
)
if frequency < 0 or frequency > 14: if frequency < 0 or frequency > 14:
raise FilterInvalidArgument( raise FilterInvalidArgument("Vibrato frequency must be between 0 and 14.")
"Vibrato frequency must be between 0 and 14.")
if depth < 0 or depth > 1: if depth < 0 or depth > 1:
raise FilterInvalidArgument( raise FilterInvalidArgument("Vibrato depth must be between 0 and 1.")
"Vibrato depth must be between 0 and 1.")
self.frequency: float = frequency self.frequency: float = frequency
self.depth: float = depth self.depth: float = depth
self.payload: dict = {"vibrato": {"frequency": self.frequency, self.payload: dict = {"vibrato": {"frequency": self.frequency, "depth": self.depth}}
"depth": self.depth}}
def __repr__(self): def __repr__(self):
return f"<Pomice.VibratoFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>" return (
f"<Pomice.VibratoFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
)
class Rotation(Filter): class Rotation(Filter):
@ -299,7 +312,7 @@ class Rotation(Filter):
def __init__(self, *, tag: str, rotation_hertz: float = 5): def __init__(self, *, tag: str, rotation_hertz: float = 5):
super().__init__(tag=tag) super().__init__(tag=tag)
__slots__ = ("rotation_hertz") __slots__ = "rotation_hertz"
self.rotation_hertz: float = rotation_hertz self.rotation_hertz: float = rotation_hertz
self.payload: dict = {"rotation": {"rotationHz": self.rotation_hertz}} self.payload: dict = {"rotation": {"rotationHz": self.rotation_hertz}}
@ -320,48 +333,50 @@ class ChannelMix(Filter):
left_to_left: float = 1, left_to_left: float = 1,
right_to_right: float = 1, right_to_right: float = 1,
left_to_right: float = 0, left_to_right: float = 0,
right_to_left: float = 0 right_to_left: float = 0,
): ):
super().__init__(tag=tag) super().__init__(tag=tag)
__slots__ = ( __slots__ = ("left_to_left", "right_to_right", "left_to_right", "right_to_left")
"left_to_left",
"right_to_right",
"left_to_right",
"right_to_left"
)
if 0 > left_to_left > 1: if 0 > left_to_left > 1:
raise ValueError( raise ValueError(
"'left_to_left' value must be more than or equal to 0 or less than or equal to 1.") "'left_to_left' value must be more than or equal to 0 or less than or equal to 1."
)
if 0 > right_to_right > 1: if 0 > right_to_right > 1:
raise ValueError( raise ValueError(
"'right_to_right' value must be more than or equal to 0 or less than or equal to 1.") "'right_to_right' value must be more than or equal to 0 or less than or equal to 1."
)
if 0 > left_to_right > 1: if 0 > left_to_right > 1:
raise ValueError( raise ValueError(
"'left_to_right' value must be more than or equal to 0 or less than or equal to 1.") "'left_to_right' value must be more than or equal to 0 or less than or equal to 1."
)
if 0 > right_to_left > 1: if 0 > right_to_left > 1:
raise ValueError( raise ValueError(
"'right_to_left' value must be more than or equal to 0 or less than or equal to 1.") "'right_to_left' value must be more than or equal to 0 or less than or equal to 1."
)
self.left_to_left: float = left_to_left self.left_to_left: float = left_to_left
self.left_to_right: float = left_to_right self.left_to_right: float = left_to_right
self.right_to_left: float = right_to_left self.right_to_left: float = right_to_left
self.right_to_right: float = right_to_right self.right_to_right: float = right_to_right
self.payload: dict = {"channelMix": {"leftToLeft": self.left_to_left, self.payload: dict = {
"leftToRight": self.left_to_right, "channelMix": {
"rightToLeft": self.right_to_left, "leftToLeft": self.left_to_left,
"rightToRight": self.right_to_right} "leftToRight": self.left_to_right,
} "rightToLeft": self.right_to_left,
"rightToRight": self.right_to_right,
}
}
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"<Pomice.ChannelMix tag={self.tag} left_to_left={self.left_to_left} left_to_right={self.left_to_right} " f"<Pomice.ChannelMix tag={self.tag} left_to_left={self.left_to_left} left_to_right={self.left_to_right} "
f"right_to_left={self.right_to_left} right_to_right={self.right_to_right}>" f"right_to_left={self.right_to_left} right_to_right={self.right_to_right}>"
) )
class Distortion(Filter): class Distortion(Filter):
"""Filter which generates a distortion effect. Useful for certain filter implementations where """Filter which generates a distortion effect. Useful for certain filter implementations where
distortion is needed. distortion is needed.
@ -371,14 +386,14 @@ class Distortion(Filter):
self, self,
*, *,
tag: str, tag: str,
sin_offset: float = 0, sin_offset: float = 0,
sin_scale: float = 1, sin_scale: float = 1,
cos_offset: float = 0, cos_offset: float = 0,
cos_scale: float = 1, cos_scale: float = 1,
tan_offset: float = 0, tan_offset: float = 0,
tan_scale: float = 1, tan_scale: float = 1,
offset: float = 0, offset: float = 0,
scale: float = 1 scale: float = 1,
): ):
super().__init__(tag=tag) super().__init__(tag=tag)
@ -388,9 +403,8 @@ class Distortion(Filter):
"cos_offset", "cos_offset",
"cos_scale", "cos_scale",
"tan_offset", "tan_offset",
"tan_scale" "tan_scale" "offset",
"offset", "scale",
"scale"
) )
self.sin_offset: float = sin_offset self.sin_offset: float = sin_offset
@ -402,22 +416,24 @@ class Distortion(Filter):
self.offset: float = offset self.offset: float = offset
self.scale: float = scale self.scale: float = scale
self.payload: dict = {"distortion": { self.payload: dict = {
"sinOffset": self.sin_offset, "distortion": {
"sinScale": self.sin_scale, "sinOffset": self.sin_offset,
"cosOffset": self.cos_offset, "sinScale": self.sin_scale,
"cosScale": self.cos_scale, "cosOffset": self.cos_offset,
"tanOffset": self.tan_offset, "cosScale": self.cos_scale,
"tanScale": self.tan_scale, "tanOffset": self.tan_offset,
"offset": self.offset, "tanScale": self.tan_scale,
"scale": self.scale "offset": self.offset,
}} "scale": self.scale,
}
}
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"<Pomice.Distortion tag={self.tag} sin_offset={self.sin_offset} sin_scale={self.sin_scale}> " f"<Pomice.Distortion tag={self.tag} sin_offset={self.sin_offset} sin_scale={self.sin_scale}> "
f"cos_offset={self.cos_offset} cos_scale={self.cos_scale} tan_offset={self.tan_offset} " f"cos_offset={self.cos_offset} cos_scale={self.cos_scale} tan_offset={self.tan_offset} "
f"tan_scale={self.tan_scale} offset={self.offset} scale={self.scale}" f"tan_scale={self.tan_scale} offset={self.offset} scale={self.scale}"
) )
@ -425,15 +441,14 @@ class LowPass(Filter):
"""Filter which supresses higher frequencies and allows lower frequencies to pass. """Filter which supresses higher frequencies and allows lower frequencies to pass.
You can also do this with the Equalizer filter, but this is an easier way to do it. You can also do this with the Equalizer filter, but this is an easier way to do it.
""" """
def __init__(self, *, tag: str, smoothing: float = 20): def __init__(self, *, tag: str, smoothing: float = 20):
super().__init__(tag=tag) super().__init__(tag=tag)
__slots__ = ('smoothing') __slots__ = "smoothing"
self.smoothing: float = smoothing self.smoothing: float = smoothing
self.payload: dict = {"lowPass": {"smoothing": self.smoothing}} self.payload: dict = {"lowPass": {"smoothing": self.smoothing}}
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Pomice.LowPass tag={self.tag} smoothing={self.smoothing}>" return f"<Pomice.LowPass tag={self.tag} smoothing={self.smoothing}>"

View File

@ -10,7 +10,7 @@ from .filters import Filter
class Track: class Track:
"""The base track object. Returns critical track information needed for parsing by Lavalink. """The base track object. Returns critical track information needed for parsing by Lavalink.
You can also pass in commands.Context to get a discord.py Context object in your track. You can also pass in commands.Context to get a discord.py Context object in your track.
""" """
def __init__( def __init__(
@ -45,7 +45,7 @@ class Track:
"requester", "requester",
"is_stream", "is_stream",
"is_seekable", "is_seekable",
"position" "position",
) )
self.track_id: str = track_id self.track_id: str = track_id
@ -106,8 +106,8 @@ class Track:
class Playlist: class Playlist:
"""The base playlist object. """The base playlist object.
Returns critical playlist information needed for parsing by Lavalink. Returns critical playlist information needed for parsing by Lavalink.
You can also pass in commands.Context to get a discord.py Context object in your tracks. You can also pass in commands.Context to get a discord.py Context object in your tracks.
""" """
def __init__( def __init__(
@ -117,9 +117,8 @@ class Playlist:
tracks: list, tracks: list,
playlist_type: PlaylistType, playlist_type: PlaylistType,
thumbnail: Optional[str] = None, thumbnail: Optional[str] = None,
uri: Optional[str] = None uri: Optional[str] = None,
): ):
__slots__ = ( __slots__ = (
"playlist_info", "playlist_info",
"tracks", "tracks",
@ -128,7 +127,7 @@ class Playlist:
"_thumbnail", "_thumbnail",
"_uri", "_uri",
"selected_track", "selected_track",
"track_count" "track_count",
) )
self.playlist_info: dict = playlist_info self.playlist_info: dict = playlist_info

View File

@ -1,24 +1,19 @@
import time import time
from typing import ( from typing import Any, Dict, List, Optional, Union
Any,
Dict,
List,
Optional,
Union
)
from discord import ( from discord import Client, Guild, VoiceChannel, VoiceProtocol
Client,
Guild,
VoiceChannel,
VoiceProtocol
)
from discord.ext import commands from discord.ext import commands
from . import events from . import events
from .enums import SearchType from .enums import SearchType
from .events import PomiceEvent, TrackEndEvent, TrackStartEvent from .events import PomiceEvent, TrackEndEvent, TrackStartEvent
from .exceptions import FilterInvalidArgument, FilterTagAlreadyInUse, FilterTagInvalid, TrackInvalidPosition, TrackLoadError from .exceptions import (
FilterInvalidArgument,
FilterTagAlreadyInUse,
FilterTagInvalid,
TrackInvalidPosition,
TrackLoadError,
)
from .filters import Filter from .filters import Filter
from .objects import Track from .objects import Track
from .pool import Node, NodePool from .pool import Node, NodePool
@ -26,7 +21,8 @@ from .pool import Node, NodePool
class Filters: class Filters:
"""Helper class for filters""" """Helper class for filters"""
__slots__ = ('_filters')
__slots__ = "_filters"
def __init__(self): def __init__(self):
self._filters: List[Filter] = [] self._filters: List[Filter] = []
@ -41,27 +37,21 @@ class Filters:
"""Property which checks if any applied filters are global""" """Property which checks if any applied filters are global"""
return any(f for f in self._filters if f.preload == False) return any(f for f in self._filters if f.preload == False)
@property @property
def empty(self): def empty(self):
"""Property which checks if the filter list is empty""" """Property which checks if the filter list is empty"""
return len(self._filters) == 0 return len(self._filters) == 0
def add_filter(self, *, filter: Filter): def add_filter(self, *, filter: Filter):
"""Adds a filter to the list of filters applied""" """Adds a filter to the list of filters applied"""
if any(f for f in self._filters if f.tag == filter.tag): if any(f for f in self._filters if f.tag == filter.tag):
raise FilterTagAlreadyInUse( raise FilterTagAlreadyInUse("A filter with that tag is already in use.")
"A filter with that tag is already in use."
)
self._filters.append(filter) self._filters.append(filter)
def remove_filter(self, *, filter_tag: str): def remove_filter(self, *, filter_tag: str):
"""Removes a filter from the list of filters applied using its filter tag""" """Removes a filter from the list of filters applied using its filter tag"""
if not any(f for f in self._filters if f.tag == filter_tag): if not any(f for f in self._filters if f.tag == filter_tag):
raise FilterTagInvalid( raise FilterTagInvalid("A filter with that tag was not found.")
"A filter with that tag was not found."
)
for index, filter in enumerate(self._filters): for index, filter in enumerate(self._filters):
if filter.tag == filter_tag: if filter.tag == filter_tag:
@ -91,13 +81,12 @@ class Filters:
return self._filters return self._filters
class Player(VoiceProtocol): class Player(VoiceProtocol):
"""The base player class for Pomice. """The base player class for Pomice.
In order to initiate a player, you must pass it in as a cls when you connect to a channel. In order to initiate a player, you must pass it in as a cls when you connect to a channel.
i.e: ```py i.e: ```py
await ctx.author.voice.channel.connect(cls=pomice.Player) await ctx.author.voice.channel.connect(cls=pomice.Player)
``` ```
""" """
def __call__(self, client: Client, channel: VoiceChannel): def __call__(self, client: Client, channel: VoiceChannel):
@ -112,26 +101,26 @@ class Player(VoiceProtocol):
client: Optional[Client] = None, client: Optional[Client] = None,
channel: Optional[VoiceChannel] = None, channel: Optional[VoiceChannel] = None,
*, *,
node: Node = None node: Node = None,
): ):
__slots__ = ( __slots__ = (
'client', "client",
'channel', "channel",
'_bot', "_bot",
'_guild', "_guild",
'_node', "_node",
'_current', "_current",
'_filters', "_filters",
'_volume', "_volume",
'_paused', "_paused",
'_is_connected', "_is_connected",
'_position', "_position",
'_last_position', "_last_position",
'_last_update', "_last_update",
'_ending_track', "_ending_track",
'_voice_state', "_voice_state",
'_player_endpoint_uri', "_player_endpoint_uri",
'__dict__' "__dict__",
) )
self.client: Optional[Client] = client self.client: Optional[Client] = client
@ -153,7 +142,7 @@ class Player(VoiceProtocol):
self._voice_state: dict = {} self._voice_state: dict = {}
self._player_endpoint_uri: str = f'sessions/{self._node._session_id}/players' self._player_endpoint_uri: str = f"sessions/{self._node._session_id}/players"
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
@ -228,7 +217,7 @@ class Player(VoiceProtocol):
@property @property
def is_dead(self) -> bool: def is_dead(self) -> bool:
"""Returns a bool representing whether the player is dead or not. """Returns a bool representing whether the player is dead or not.
A player is considered dead if it has been destroyed and removed from stored players. A player is considered dead if it has been destroyed and removed from stored players.
""" """
return self.guild.id not in self._node._players return self.guild.id not in self._node._players
@ -245,16 +234,16 @@ class Player(VoiceProtocol):
state = voice_data or self._voice_state state = voice_data or self._voice_state
data = { data = {
"token": state['event']['token'], "token": state["event"]["token"],
"endpoint": state['event']['endpoint'], "endpoint": state["event"]["endpoint"],
"sessionId": state['sessionId'], "sessionId": state["sessionId"],
} }
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"voice": data} data={"voice": data},
) )
async def on_voice_server_update(self, data: dict): async def on_voice_server_update(self, data: dict):
@ -290,15 +279,15 @@ class Player(VoiceProtocol):
async def _swap_node(self, *, new_node: Node): async def _swap_node(self, *, new_node: Node):
data: dict = { data: dict = {
'encodedTrack': self.current.track_id, "encodedTrack": self.current.track_id,
'position': self.position, "position": self.position,
} }
del self._node._players[self._guild.id] del self._node._players[self._guild.id]
self._node = new_node self._node = new_node
self._node._players[self._guild.id] = self self._node._players[self._guild.id] = self
# reassign uri to update session id # reassign uri to update session id
self._player_endpoint_uri = f'sessions/{self._node._session_id}/players' self._player_endpoint_uri = f"sessions/{self._node._session_id}/players"
await self._dispatch_voice_update() await self._dispatch_voice_update()
await self._node.send( await self._node.send(
@ -314,7 +303,7 @@ class Player(VoiceProtocol):
*, *,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch, search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None filters: Optional[List[Filter]] = None,
): ):
"""Fetches tracks from the node's REST api to parse into Lavalink. """Fetches tracks from the node's REST api to parse into Lavalink.
@ -331,10 +320,7 @@ class Player(VoiceProtocol):
return await self._node.get_tracks(query, ctx=ctx, search_type=search_type, filters=filters) return await self._node.get_tracks(query, ctx=ctx, search_type=search_type, filters=filters)
async def get_recommendations( async def get_recommendations(
self, self, *, track: Track, ctx: Optional[commands.Context] = None
*,
track: Track,
ctx: Optional[commands.Context] = None
) -> Union[List[Track], None]: ) -> Union[List[Track], None]:
""" """
Gets recommendations from either YouTube or Spotify. Gets recommendations from either YouTube or Spotify.
@ -343,8 +329,12 @@ class Player(VoiceProtocol):
""" """
return await self._node.get_recommendations(track=track, ctx=ctx) return await self._node.get_recommendations(track=track, ctx=ctx)
async def connect(self, *, timeout: float, reconnect: bool, self_deaf: bool = False, self_mute: bool = False): async def connect(
await self.guild.change_voice_state(channel=self.channel, self_deaf=self_deaf, self_mute=self_mute) self, *, timeout: float, reconnect: bool, self_deaf: bool = False, self_mute: bool = False
):
await self.guild.change_voice_state(
channel=self.channel, self_deaf=self_deaf, self_mute=self_mute
)
self._node._players[self.guild.id] = self self._node._players[self.guild.id] = self
self._is_connected = True self._is_connected = True
@ -355,7 +345,7 @@ class Player(VoiceProtocol):
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={'encodedTrack': None} data={"encodedTrack": None},
) )
async def disconnect(self, *, force: bool = False): async def disconnect(self, *, force: bool = False):
@ -377,15 +367,12 @@ class Player(VoiceProtocol):
assert self.channel is None and not self.is_connected assert self.channel is None and not self.is_connected
self._node._players.pop(self.guild.id) self._node._players.pop(self.guild.id)
await self._node.send(method="DELETE", path=self._player_endpoint_uri, guild_id=self._guild.id) await self._node.send(
method="DELETE", path=self._player_endpoint_uri, guild_id=self._guild.id
)
async def play( async def play(
self, self, track: Track, *, start: int = 0, end: int = 0, ignore_if_playing: bool = False
track: Track,
*,
start: int = 0,
end: int = 0,
ignore_if_playing: bool = False
) -> Track: ) -> Track:
"""Plays a track. If a Spotify track is passed in, it will be handled accordingly.""" """Plays a track. If a Spotify track is passed in, it will be handled accordingly."""
@ -396,22 +383,24 @@ class Player(VoiceProtocol):
if not track.isrc: if not track.isrc:
# We have to bare raise here because theres no other way to skip this block feasibly # We have to bare raise here because theres no other way to skip this block feasibly
raise raise
search: Track = (await self._node.get_tracks( search: Track = (
f"{track._search_type}:{track.isrc}", ctx=track.ctx))[0] await self._node.get_tracks(f"{track._search_type}:{track.isrc}", ctx=track.ctx)
)[0]
except Exception: except Exception:
# First method didn't work, lets try just searching it up # First method didn't work, lets try just searching it up
try: try:
search: Track = (await self._node.get_tracks( search: Track = (
f"{track._search_type}:{track.title} - {track.author}", ctx=track.ctx))[0] await self._node.get_tracks(
f"{track._search_type}:{track.title} - {track.author}", ctx=track.ctx
)
)[0]
except: except:
# The song wasn't able to be found, raise error # The song wasn't able to be found, raise error
raise TrackLoadError ( raise TrackLoadError("No equivalent track was able to be found.")
"No equivalent track was able to be found."
)
data = { data = {
"encodedTrack": search.track_id, "encodedTrack": search.track_id,
"position": str(start), "position": str(start),
"endTime": str(track.length) "endTime": str(track.length),
} }
track.original = search track.original = search
track.track_id = search.track_id track.track_id = search.track_id
@ -420,10 +409,9 @@ class Player(VoiceProtocol):
data = { data = {
"encodedTrack": track.track_id, "encodedTrack": track.track_id,
"position": str(start), "position": str(start),
"endTime": str(track.length) "endTime": str(track.length),
} }
# Lets set the current track before we play it so any # Lets set the current track before we play it so any
# corresponding events can capture it correctly # corresponding events can capture it correctly
@ -459,7 +447,7 @@ class Player(VoiceProtocol):
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data=data, data=data,
query=f"noReplace={ignore_if_playing}" query=f"noReplace={ignore_if_playing}",
) )
return self._current return self._current
@ -467,15 +455,13 @@ class Player(VoiceProtocol):
async def seek(self, position: float) -> float: async def seek(self, position: float) -> float:
"""Seeks to a position in the currently playing track milliseconds""" """Seeks to a position in the currently playing track milliseconds"""
if position < 0 or position > self._current.original.length: if position < 0 or position > self._current.original.length:
raise TrackInvalidPosition( raise TrackInvalidPosition("Seek position must be between 0 and the track length")
"Seek position must be between 0 and the track length"
)
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"position": position} data={"position": position},
) )
return self._position return self._position
@ -485,7 +471,7 @@ class Player(VoiceProtocol):
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"paused": pause} data={"paused": pause},
) )
self._paused = pause self._paused = pause
return self._paused return self._paused
@ -496,17 +482,17 @@ class Player(VoiceProtocol):
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"volume": volume} data={"volume": volume},
) )
self._volume = volume self._volume = volume
return self._volume return self._volume
async def add_filter(self, filter: Filter, fast_apply: bool = False) -> Filter: async def add_filter(self, filter: Filter, fast_apply: bool = False) -> Filter:
"""Adds a filter to the player. Takes a pomice.Filter object. """Adds a filter to the player. Takes a pomice.Filter object.
This will only work if you are using a version of Lavalink that supports filters. This will only work if you are using a version of Lavalink that supports filters.
If you would like for the filter to apply instantly, set the `fast_apply` arg to `True`. If you would like for the filter to apply instantly, set the `fast_apply` arg to `True`.
(You must have a song playing in order for `fast_apply` to work.) (You must have a song playing in order for `fast_apply` to work.)
""" """
self._filters.add_filter(filter=filter) self._filters.add_filter(filter=filter)
@ -515,7 +501,7 @@ class Player(VoiceProtocol):
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"filters": payload} data={"filters": payload},
) )
if fast_apply: if fast_apply:
await self.seek(self.position) await self.seek(self.position)
@ -524,10 +510,10 @@ class Player(VoiceProtocol):
async def remove_filter(self, filter_tag: str, fast_apply: bool = False) -> Filter: async def remove_filter(self, filter_tag: str, fast_apply: bool = False) -> Filter:
"""Removes a filter from the player. Takes a filter tag. """Removes a filter from the player. Takes a filter tag.
This will only work if you are using a version of Lavalink that supports filters. This will only work if you are using a version of Lavalink that supports filters.
If you would like for the filter to apply instantly, set the `fast_apply` arg to `True`. If you would like for the filter to apply instantly, set the `fast_apply` arg to `True`.
(You must have a song playing in order for `fast_apply` to work.) (You must have a song playing in order for `fast_apply` to work.)
""" """
self._filters.remove_filter(filter_tag=filter_tag) self._filters.remove_filter(filter_tag=filter_tag)
@ -536,7 +522,7 @@ class Player(VoiceProtocol):
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"filters": payload} data={"filters": payload},
) )
if fast_apply: if fast_apply:
await self.seek(self.position) await self.seek(self.position)
@ -545,10 +531,10 @@ class Player(VoiceProtocol):
async def reset_filters(self, *, fast_apply: bool = False): async def reset_filters(self, *, fast_apply: bool = False):
"""Resets all currently applied filters to their default parameters. """Resets all currently applied filters to their default parameters.
You must have filters applied in order for this to work. You must have filters applied in order for this to work.
If you would like the filters to be removed instantly, set the `fast_apply` arg to `True`. If you would like the filters to be removed instantly, set the `fast_apply` arg to `True`.
(You must have a song playing in order for `fast_apply` to work.) (You must have a song playing in order for `fast_apply` to work.)
""" """
if not self._filters: if not self._filters:
@ -560,11 +546,8 @@ class Player(VoiceProtocol):
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"filters": {}} data={"filters": {}},
) )
if fast_apply: if fast_apply:
await self.seek(self.position) await self.seek(self.position)

View File

@ -10,11 +10,7 @@ from discord.ext import commands
from typing import Dict, List, Optional, TYPE_CHECKING, Union from typing import Dict, List, Optional, TYPE_CHECKING, Union
from urllib.parse import quote from urllib.parse import quote
from . import ( from . import __version__, spotify, applemusic
__version__,
spotify,
applemusic
)
from .enums import * from .enums import *
from .exceptions import ( from .exceptions import (
@ -26,7 +22,7 @@ from .exceptions import (
NodeNotAvailable, NodeNotAvailable,
NoNodesAvailable, NoNodesAvailable,
NodeRestException, NodeRestException,
TrackLoadError TrackLoadError,
) )
from .filters import Filter from .filters import Filter
from .objects import Playlist, Track from .objects import Playlist, Track
@ -36,11 +32,12 @@ from .routeplanner import RoutePlanner
if TYPE_CHECKING: if TYPE_CHECKING:
from .player import Player from .player import Player
class Node: class Node:
"""The base class for a node. """The base class for a node.
This node object represents a Lavalink node. This node object represents a Lavalink node.
To enable Spotify searching, pass in a proper Spotify Client ID and Spotify Client Secret To enable Spotify searching, pass in a proper Spotify Client ID and Spotify Client Secret
To enable Apple music, set the "apple_music" parameter to "True" To enable Apple music, set the "apple_music" parameter to "True"
""" """
def __init__( def __init__(
@ -59,8 +56,7 @@ class Node:
spotify_client_id: Optional[str] = None, spotify_client_id: Optional[str] = None,
spotify_client_secret: Optional[str] = None, spotify_client_secret: Optional[str] = None,
apple_music: bool = False, apple_music: bool = False,
fallback: bool = False fallback: bool = False,
): ):
__slots__ = ( __slots__ = (
"_bot", "_bot",
@ -86,7 +82,7 @@ class Node:
"_spotify_client_id", "_spotify_client_id",
"_spotify_client_secret", "_spotify_client_secret",
"_spotify_client", "_spotify_client",
"_apple_music_client" "_apple_music_client",
) )
self._bot: Union[Client, commands.Bot] = bot self._bot: Union[Client, commands.Bot] = bot
@ -116,7 +112,7 @@ class Node:
self._headers = { self._headers = {
"Authorization": self._password, "Authorization": self._password,
"User-Id": str(self._bot.user.id), "User-Id": str(self._bot.user.id),
"Client-Name": f"Pomice/{__version__}" "Client-Name": f"Pomice/{__version__}",
} }
self._players: Dict[int, Player] = {} self._players: Dict[int, Player] = {}
@ -147,7 +143,6 @@ class Node:
"""Property which returns whether this node is connected or not""" """Property which returns whether this node is connected or not"""
return self._websocket is not None and not self._websocket.closed return self._websocket is not None and not self._websocket.closed
@property @property
def stats(self) -> NodeStats: def stats(self) -> NodeStats:
"""Property which returns the node stats.""" """Property which returns the node stats."""
@ -158,7 +153,6 @@ class Node:
"""Property which returns a dict containing the guild ID and the player object.""" """Property which returns a dict containing the guild ID and the player object."""
return self._players return self._players
@property @property
def bot(self) -> Union[Client, commands.Bot]: def bot(self) -> Union[Client, commands.Bot]:
"""Property which returns the discord.py client linked to this node""" """Property which returns the discord.py client linked to this node"""
@ -184,7 +178,6 @@ class Node:
"""Alias for `Node.latency`, returns the latency of the node""" """Alias for `Node.latency`, returns the latency of the node"""
return self.latency return self.latency
async def _update_handler(self, data: dict): async def _update_handler(self, data: dict):
await self._bot.wait_until_ready() await self._bot.wait_until_ready()
@ -211,10 +204,10 @@ class Node:
return return
async def _handle_node_switch(self): async def _handle_node_switch(self):
nodes = [node for node in self._pool._nodes.values() if node.is_connected] nodes = [node for node in self.pool.nodes.copy().values() if node.is_connected]
new_node = random.choice(nodes) new_node = random.choice(nodes)
for player in self._players.values(): for player in self.players.copy().values():
await player._swap_node(new_node=new_node) await player._swap_node(new_node=new_node)
await self.disconnect() await self.disconnect()
@ -266,20 +259,24 @@ class Node:
ignore_if_available: bool = False, ignore_if_available: bool = False,
): ):
if not ignore_if_available and not self._available: if not ignore_if_available and not self._available:
raise NodeNotAvailable( raise NodeNotAvailable(f"The node '{self._identifier}' is unavailable.")
f"The node '{self._identifier}' is unavailable."
)
uri: str = f'{self._rest_uri}/' \ uri: str = (
f'{f"v{self._version}/" if include_version else ""}' \ f"{self._rest_uri}/"
f'{path}' \ f'{f"v{self._version}/" if include_version else ""}'
f'{f"/{guild_id}" if guild_id else ""}' \ f"{path}"
f'{f"?{query}" if query else ""}' f'{f"/{guild_id}" if guild_id else ""}'
f'{f"?{query}" if query else ""}'
)
async with self._session.request(method=method, url=uri, headers=self._headers, json=data or {}) as resp: async with self._session.request(
method=method, url=uri, headers=self._headers, json=data or {}
) as resp:
if resp.status >= 300: if resp.status >= 300:
data: dict = await resp.json() data: dict = await resp.json()
raise NodeRestException(f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {data["message"]}') raise NodeRestException(
f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {data["message"]}'
)
if method == "DELETE" or resp.status == 204: if method == "DELETE" or resp.status == 204:
return await resp.json(content_type=None) return await resp.json(content_type=None)
@ -289,8 +286,6 @@ class Node:
return await resp.json() return await resp.json()
def get_player(self, guild_id: int): def get_player(self, guild_id: int):
"""Takes a guild ID as a parameter. Returns a pomice Player object.""" """Takes a guild ID as a parameter. Returns a pomice Player object."""
return self._players.get(guild_id, None) return self._players.get(guild_id, None)
@ -303,26 +298,30 @@ class Node:
self._session = aiohttp.ClientSession() self._session = aiohttp.ClientSession()
try: try:
version = await self.send(method="GET", path="version", ignore_if_available=True, include_version=False) version = await self.send(
method="GET",
path="version",
ignore_if_available=True,
include_version=False,
)
version = version.replace(".", "") version = version.replace(".", "")
if not version.endswith('-SNAPSHOT') and int(version) < 370: if not version.endswith("-SNAPSHOT") and int(version) < 370:
self._available = False self._available = False
raise LavalinkVersionIncompatible( raise LavalinkVersionIncompatible(
"The Lavalink version you're using is incompatible. " "The Lavalink version you're using is incompatible. "
"Lavalink version 3.7.0 or above is required to use this library." "Lavalink version 3.7.0 or above is required to use this library."
) )
if version.endswith('-SNAPSHOT'): if version.endswith("-SNAPSHOT"):
# we're just gonna assume all snapshot versions correlate with v4 # we're just gonna assume all snapshot versions correlate with v4
self._version = 4 self._version = 4
else: else:
self._version = version[:1] self._version = version[:1]
self._websocket = await self._session.ws_connect( self._websocket = await self._session.ws_connect(
f"{self._websocket_uri}/v{self._version}/websocket", f"{self._websocket_uri}/v{self._version}/websocket",
headers=self._headers, headers=self._headers,
heartbeat=self._heartbeat heartbeat=self._heartbeat,
) )
if not self._task: if not self._task:
@ -344,10 +343,9 @@ class Node:
f"The URI for node '{self._identifier}' is invalid." f"The URI for node '{self._identifier}' is invalid."
) from None ) from None
async def disconnect(self): async def disconnect(self):
"""Disconnects a connected Lavalink node and removes it from the node pool. """Disconnects a connected Lavalink node and removes it from the node pool.
This also destroys any players connected to the node. This also destroys any players connected to the node.
""" """
for player in self.players.copy().values(): for player in self.players.copy().values():
await player.destroy() await player.destroy()
@ -364,11 +362,7 @@ class Node:
self.available = False self.available = False
self._task.cancel() self._task.cancel()
async def build_track( async def build_track(self, identifier: str, ctx: Optional[commands.Context] = None) -> Track:
self,
identifier: str,
ctx: Optional[commands.Context] = None
) -> Track:
""" """
Builds a track using a valid track identifier Builds a track using a valid track identifier
@ -376,8 +370,15 @@ class Node:
Context object on the track it builds. Context object on the track it builds.
""" """
data: dict = await self.send(method="GET", path="decodetrack", query=f"encodedTrack={identifier}") data: dict = await self.send(
return Track(track_id=identifier, ctx=ctx, info=data, track_type=TrackType(data['sourceName'])) method="GET", path="decodetrack", query=f"encodedTrack={identifier}"
)
return Track(
track_id=identifier,
ctx=ctx,
info=data,
track_type=TrackType(data["sourceName"]),
)
async def get_tracks( async def get_tracks(
self, self,
@ -385,18 +386,18 @@ class Node:
*, *,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch, search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None filters: Optional[List[Filter]] = None,
): ):
"""Fetches tracks from the node's REST api to parse into Lavalink. """Fetches tracks from the node's REST api to parse into Lavalink.
If you passed in Spotify API credentials, you can also pass in a If you passed in Spotify API credentials, you can also pass in a
Spotify URL of a playlist, album or track and it will be parsed accordingly. Spotify URL of a playlist, album or track and it will be parsed accordingly.
You can pass in a discord.py Context object to get a You can pass in a discord.py Context object to get a
Context object on any track you search. Context object on any track you search.
You may also pass in a List of filters You may also pass in a List of filters
to be applied to your track once it plays. to be applied to your track once it plays.
""" """
timestamp = None timestamp = None
@ -434,8 +435,8 @@ class Node:
"isSeekable": True, "isSeekable": True,
"position": 0, "position": 0,
"thumbnail": apple_music_results.image, "thumbnail": apple_music_results.image,
"isrc": apple_music_results.isrc "isrc": apple_music_results.isrc,
} },
) )
] ]
@ -456,9 +457,10 @@ class Node:
"isSeekable": True, "isSeekable": True,
"position": 0, "position": 0,
"thumbnail": track.image, "thumbnail": track.image,
"isrc": track.isrc "isrc": track.isrc,
} },
) for track in apple_music_results.tracks )
for track in apple_music_results.tracks
] ]
return Playlist( return Playlist(
@ -466,10 +468,9 @@ class Node:
tracks=tracks, tracks=tracks,
playlist_type=PlaylistType.APPLE_MUSIC, playlist_type=PlaylistType.APPLE_MUSIC,
thumbnail=apple_music_results.image, thumbnail=apple_music_results.image,
uri=apple_music_results.url uri=apple_music_results.url,
) )
elif URLRegex.SPOTIFY_URL.match(query): elif URLRegex.SPOTIFY_URL.match(query):
if not self._spotify_client_id and not self._spotify_client_secret: if not self._spotify_client_id and not self._spotify_client_secret:
raise InvalidSpotifyClientAuthorization( raise InvalidSpotifyClientAuthorization(
@ -498,8 +499,8 @@ class Node:
"isSeekable": True, "isSeekable": True,
"position": 0, "position": 0,
"thumbnail": spotify_results.image, "thumbnail": spotify_results.image,
"isrc": spotify_results.isrc "isrc": spotify_results.isrc,
} },
) )
] ]
@ -520,9 +521,10 @@ class Node:
"isSeekable": True, "isSeekable": True,
"position": 0, "position": 0,
"thumbnail": track.image, "thumbnail": track.image,
"isrc": track.isrc "isrc": track.isrc,
} },
) for track in spotify_results.tracks )
for track in spotify_results.tracks
] ]
return Playlist( return Playlist(
@ -530,12 +532,13 @@ class Node:
tracks=tracks, tracks=tracks,
playlist_type=PlaylistType.SPOTIFY, playlist_type=PlaylistType.SPOTIFY,
thumbnail=spotify_results.image, thumbnail=spotify_results.image,
uri=spotify_results.uri uri=spotify_results.uri,
) )
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query): elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
data: dict = await self.send(
data: dict = await self.send(method="GET", path="loadtracks", query=f"identifier={quote(query)}") method="GET", path="loadtracks", query=f"identifier={quote(query)}"
)
track: dict = data["tracks"][0] track: dict = data["tracks"][0]
info: dict = track.get("info") info: dict = track.get("info")
@ -549,27 +552,29 @@ class Node:
"length": info.get("length"), "length": info.get("length"),
"uri": info.get("uri"), "uri": info.get("uri"),
"position": info.get("position"), "position": info.get("position"),
"identifier": info.get("identifier") "identifier": info.get("identifier"),
}, },
ctx=ctx, ctx=ctx,
track_type=TrackType.HTTP, track_type=TrackType.HTTP,
filters=filters filters=filters,
) )
] ]
else: else:
# If YouTube url contains a timestamp, capture it for use later. # If YouTube url contains a timestamp, capture it for use later.
if (match := URLRegex.YOUTUBE_TIMESTAMP.match(query)): if match := URLRegex.YOUTUBE_TIMESTAMP.match(query):
timestamp = float(match.group("time")) timestamp = float(match.group("time"))
# If query is a video thats part of a playlist, get the video and queue that instead # If query is a video thats part of a playlist, get the video and queue that instead
# (I can't tell you how much i've wanted to implement this in here) # (I can't tell you how much i've wanted to implement this in here)
if (match := URLRegex.YOUTUBE_VID_IN_PLAYLIST.match(query)): if match := URLRegex.YOUTUBE_VID_IN_PLAYLIST.match(query):
query = match.group("video") query = match.group("video")
data: dict = await self.send(method="GET", path="loadtracks", query=f"identifier={quote(query)}") data: dict = await self.send(
method="GET", path="loadtracks", query=f"identifier={quote(query)}"
)
load_type = data.get("loadType") load_type = data.get("loadType")
@ -585,15 +590,20 @@ class Node:
elif load_type == "PLAYLIST_LOADED": elif load_type == "PLAYLIST_LOADED":
tracks = [ tracks = [
Track(track_id=track["encoded"], info=track["info"], ctx=ctx, track_type=TrackType(track["info"]["sourceName"])) Track(
for track in data["tracks"] track_id=track["encoded"],
info=track["info"],
ctx=ctx,
track_type=TrackType(track["info"]["sourceName"]),
)
for track in data["tracks"]
] ]
return Playlist( return Playlist(
playlist_info=data["playlistInfo"], playlist_info=data["playlistInfo"],
tracks=tracks, tracks=tracks,
playlist_type=PlaylistType(tracks[0].track_type.value), playlist_type=PlaylistType(tracks[0].track_type.value),
thumbnail=tracks[0].thumbnail, thumbnail=tracks[0].thumbnail,
uri=query uri=query,
) )
elif load_type == "SEARCH_RESULT" or load_type == "TRACK_LOADED": elif load_type == "SEARCH_RESULT" or load_type == "TRACK_LOADED":
@ -604,16 +614,13 @@ class Node:
ctx=ctx, ctx=ctx,
track_type=TrackType(track["info"]["sourceName"]), track_type=TrackType(track["info"]["sourceName"]),
filters=filters, filters=filters,
timestamp=timestamp timestamp=timestamp,
) )
for track in data["tracks"] for track in data["tracks"]
] ]
async def get_recommendations( async def get_recommendations(
self, self, *, track: Track, ctx: Optional[commands.Context] = None
*,
track: Track,
ctx: Optional[commands.Context] = None
) -> Union[List[Track], None]: ) -> Union[List[Track], None]:
""" """
Gets recommendations from either YouTube or Spotify. Gets recommendations from either YouTube or Spotify.
@ -625,37 +632,43 @@ class Node:
if track.track_type == TrackType.SPOTIFY: if track.track_type == TrackType.SPOTIFY:
results = await self._spotify_client.get_recommendations(query=track.uri) results = await self._spotify_client.get_recommendations(query=track.uri)
tracks = [ tracks = [
Track( Track(
track_id=track.id, track_id=track.id,
ctx=ctx, ctx=ctx,
track_type=TrackType.SPOTIFY, track_type=TrackType.SPOTIFY,
info={ info={
"title": track.name, "title": track.name,
"author": track.artists, "author": track.artists,
"length": track.length, "length": track.length,
"identifier": track.id, "identifier": track.id,
"uri": track.uri, "uri": track.uri,
"isStream": False, "isStream": False,
"isSeekable": True, "isSeekable": True,
"position": 0, "position": 0,
"thumbnail": track.image, "thumbnail": track.image,
"isrc": track.isrc "isrc": track.isrc,
}, },
requester=self.bot.user requester=self.bot.user,
) for track in results )
] for track in results
]
return tracks return tracks
elif track.track_type == TrackType.YOUTUBE: elif track.track_type == TrackType.YOUTUBE:
tracks = await self.get_tracks(query=f"ytsearch:https://www.youtube.com/watch?v={track.identifier}&list=RD{track.identifier}", ctx=ctx) tracks = await self.get_tracks(
query=f"ytsearch:https://www.youtube.com/watch?v={track.identifier}&list=RD{track.identifier}",
ctx=ctx,
)
return tracks return tracks
else: else:
raise TrackLoadError("The specfied track must be either a YouTube or Spotify track to recieve recommendations.") raise TrackLoadError(
"The specfied track must be either a YouTube or Spotify track to recieve recommendations."
)
class NodePool: class NodePool:
"""The base class for the node pool. """The base class for the node pool.
This holds all the nodes that are to be used by the bot. This holds all the nodes that are to be used by the bot.
""" """
_nodes: Dict[str, Node] = {} _nodes: Dict[str, Node] = {}
@ -675,17 +688,17 @@ class NodePool:
@classmethod @classmethod
def get_best_node(cls, *, algorithm: NodeAlgorithm) -> Node: def get_best_node(cls, *, algorithm: NodeAlgorithm) -> Node:
"""Fetches the best node based on an NodeAlgorithm. """Fetches the best node based on an NodeAlgorithm.
This option is preferred if you want to choose the best node This option is preferred if you want to choose the best node
from a multi-node setup using either the node's latency from a multi-node setup using either the node's latency
or the node's voice region. or the node's voice region.
Use NodeAlgorithm.by_ping if you want to get the best node Use NodeAlgorithm.by_ping if you want to get the best node
based on the node's latency. based on the node's latency.
Use NodeAlgorithm.by_players if you want to get the best node Use NodeAlgorithm.by_players if you want to get the best node
based on how players it has. This method will return a node with based on how players it has. This method will return a node with
the least amount of players the least amount of players
""" """
available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available] available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available]
@ -700,15 +713,13 @@ class NodePool:
tested_nodes = {node: len(node.players.keys()) for node in available_nodes} tested_nodes = {node: len(node.players.keys()) for node in available_nodes}
return min(tested_nodes, key=tested_nodes.get) return min(tested_nodes, key=tested_nodes.get)
@classmethod @classmethod
def get_node(cls, *, identifier: str = None) -> Node: def get_node(cls, *, identifier: str = None) -> Node:
"""Fetches a node from the node pool using it's identifier. """Fetches a node from the node pool using it's identifier.
If no identifier is provided, it will choose a node at random. If no identifier is provided, it will choose a node at random.
""" """
available_nodes = { available_nodes = {
identifier: node identifier: node for identifier, node in cls._nodes.items() if node._available
for identifier, node in cls._nodes.items() if node._available
} }
if not available_nodes: if not available_nodes:
@ -735,21 +746,29 @@ class NodePool:
spotify_client_secret: Optional[str] = None, spotify_client_secret: Optional[str] = None,
session: Optional[aiohttp.ClientSession] = None, session: Optional[aiohttp.ClientSession] = None,
apple_music: bool = False, apple_music: bool = False,
fallback: bool = False fallback: bool = False,
) -> Node: ) -> Node:
"""Creates a Node object to be then added into the node pool. """Creates a Node object to be then added into the node pool.
For Spotify searching capabilites, pass in valid Spotify API credentials. For Spotify searching capabilites, pass in valid Spotify API credentials.
""" """
if identifier in cls._nodes.keys(): if identifier in cls._nodes.keys():
raise NodeCreationError(f"A node with identifier '{identifier}' already exists.") raise NodeCreationError(f"A node with identifier '{identifier}' already exists.")
node = Node( node = Node(
pool=cls, bot=bot, host=host, port=port, password=password, pool=cls,
identifier=identifier, secure=secure, heartbeat=heartbeat, bot=bot,
loop=loop, spotify_client_id=spotify_client_id, host=host,
session=session, spotify_client_secret=spotify_client_secret, port=port,
apple_music=apple_music, fallback=fallback password=password,
identifier=identifier,
secure=secure,
heartbeat=heartbeat,
loop=loop,
spotify_client_id=spotify_client_id,
session=session,
spotify_client_secret=spotify_client_secret,
apple_music=apple_music,
fallback=fallback,
) )
await node.connect() await node.connect()

View File

@ -16,23 +16,17 @@ from .exceptions import QueueEmpty, QueueException, QueueFull
class Queue(Iterable[Track]): class Queue(Iterable[Track]):
"""Queue for Pomice. This queue takes pomice.Track as an input and includes looping and shuffling.""" """Queue for Pomice. This queue takes pomice.Track as an input and includes looping and shuffling."""
def __init__( def __init__(
self, self,
max_size: Optional[int] = None, max_size: Optional[int] = None,
*, *,
overflow: bool = True, overflow: bool = True,
): ):
__slots__ = ("max_size", "_queue", "_overflow", "_loop_mode", "_current_item")
__slots__ = (
"max_size",
"_queue",
"_overflow",
"_loop_mode",
"_current_item"
)
self.max_size: Optional[int] = max_size self.max_size: Optional[int] = max_size
self._queue: List[Track] = [] # type: ignore self._queue: List[Track] = [] # type: ignore
self._overflow: bool = overflow self._overflow: bool = overflow
self._loop_mode: Optional[LoopMode] = None self._loop_mode: Optional[LoopMode] = None
self._current_item: Optional[Track] = None self._current_item: Optional[Track] = None
@ -43,9 +37,7 @@ class Queue(Iterable[Track]):
def __repr__(self) -> str: def __repr__(self) -> str:
"""Official representation with max_size and member count.""" """Official representation with max_size and member count."""
return ( return f"<{self.__class__.__name__} max_size={self.max_size} members={self.count}>"
f"<{self.__class__.__name__} max_size={self.max_size} members={self.count}>"
)
def __bool__(self) -> bool: def __bool__(self) -> bool:
"""Treats the queue as a bool, with it evaluating True when it contains members.""" """Treats the queue as a bool, with it evaluating True when it contains members."""
@ -125,7 +117,6 @@ class Queue(Iterable[Track]):
def _index(self, item: Track) -> int: def _index(self, item: Track) -> int:
return self._queue.index(item) return self._queue.index(item)
def _put(self, item: Track) -> None: def _put(self, item: Track) -> None:
self._queue.append(item) self._queue.append(item)
@ -183,13 +174,10 @@ class Queue(Iterable[Track]):
"""Returns the amount of items in the queue""" """Returns the amount of items in the queue"""
return len(self._queue) return len(self._queue)
def get_queue(self) -> List: def get_queue(self) -> List:
"""Returns the queue as a List""" """Returns the queue as a List"""
return self._queue return self._queue
def get(self): def get(self):
"""Return next immediately available item in queue if any. """Return next immediately available item in queue if any.
Raises QueueEmpty if no items in queue. Raises QueueEmpty if no items in queue.
@ -202,7 +190,6 @@ class Queue(Iterable[Track]):
raise QueueEmpty("No items in the queue.") raise QueueEmpty("No items in the queue.")
if self._loop_mode == LoopMode.QUEUE: if self._loop_mode == LoopMode.QUEUE:
# recurse if the item isnt in the queue # recurse if the item isnt in the queue
if self._current_item not in self._queue: if self._current_item not in self._queue:
self.get() self.get()
@ -242,7 +229,6 @@ class Queue(Iterable[Track]):
""" """
return self._remove(self._check_track(item)) return self._remove(self._check_track(item))
def find_position(self, item: Track) -> int: def find_position(self, item: Track) -> int:
"""Find the position a given item within the queue. """Find the position a given item within the queue.
Raises ValueError if item is not in queue. Raises ValueError if item is not in queue.
@ -321,7 +307,6 @@ class Queue(Iterable[Track]):
self._queue.insert(index, self._current_item) self._queue.insert(index, self._current_item)
self._current_item = self._queue[index] self._current_item = self._queue[index]
def disable_loop(self): def disable_loop(self):
""" """
Disables loop mode if set. Disables loop mode if set.
@ -336,7 +321,6 @@ class Queue(Iterable[Track]):
self._loop_mode = None self._loop_mode = None
def shuffle(self): def shuffle(self):
"""Shuffles the queue.""" """Shuffles the queue."""
return random.shuffle(self._queue) return random.shuffle(self._queue)
@ -349,5 +333,5 @@ class Queue(Iterable[Track]):
def jump(self, item: Track): def jump(self, item: Track):
"""Removes all tracks before the.""" """Removes all tracks before the."""
index = self.find_position(item) index = self.find_position(item)
new_queue = self._queue[index:self.size] new_queue = self._queue[index : self.size]
self._queue = new_queue self._queue = new_queue

View File

@ -1,11 +1,13 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from .pool import Node from .pool import Node
from .utils import RouteStats from .utils import RouteStats
from aiohttp import ClientSession from aiohttp import ClientSession
class RoutePlanner: class RoutePlanner:
""" """
The base route planner class for Pomice. The base route planner class for Pomice.

View File

@ -32,9 +32,7 @@ class Client:
self._bearer_token: str = None self._bearer_token: str = None
self._expiry = 0 self._expiry = 0
self._auth_token = b64encode( self._auth_token = b64encode(f"{self._client_id}:{self._client_secret}".encode())
f"{self._client_id}:{self._client_secret}".encode()
)
self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"} self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"}
self._bearer_headers = None self._bearer_headers = None
@ -44,9 +42,7 @@ class Client:
if not self.session: if not self.session:
self.session = aiohttp.ClientSession() self.session = aiohttp.ClientSession()
async with self.session.post( async with self.session.post(GRANT_URL, data=_data, headers=self._grant_headers) as resp:
GRANT_URL, data=_data, headers=self._grant_headers
) as resp:
if resp.status != 200: if resp.status != 200:
raise SpotifyRequestException( raise SpotifyRequestException(
f"Error fetching bearer token: {resp.status} {resp.reason}" f"Error fetching bearer token: {resp.status} {resp.reason}"
@ -110,9 +106,7 @@ class Client:
next_page_url = data["tracks"]["next"] next_page_url = data["tracks"]["next"]
while next_page_url is not None: while next_page_url is not None:
async with self.session.get( async with self.session.get(next_page_url, headers=self._bearer_headers) as resp:
next_page_url, headers=self._bearer_headers
) as resp:
if resp.status != 200: if resp.status != 200:
raise SpotifyRequestException( raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}" f"Error while fetching results: {resp.status} {resp.reason}"
@ -143,9 +137,7 @@ class Client:
if not spotify_type == "track": if not spotify_type == "track":
raise InvalidSpotifyURL("The provided query is not a Spotify track.") raise InvalidSpotifyURL("The provided query is not a Spotify track.")
request_url = REQUEST_URL.format( request_url = REQUEST_URL.format(type="recommendation", id=f"?seed_tracks={spotify_id}")
type="recommendation", id=f"?seed_tracks={spotify_id}"
)
async with self.session.get(request_url, headers=self._bearer_headers) as resp: async with self.session.get(request_url, headers=self._bearer_headers) as resp:
if resp.status != 200: if resp.status != 200:

View File

@ -1,8 +1,10 @@
class SpotifyRequestException(Exception): class SpotifyRequestException(Exception):
"""An error occurred when making a request to the Spotify API""" """An error occurred when making a request to the Spotify API"""
pass pass
class InvalidSpotifyURL(Exception): class InvalidSpotifyURL(Exception):
"""An invalid Spotify URL was passed""" """An invalid Spotify URL was passed"""
pass pass

View File

@ -4,7 +4,7 @@ from typing import List
class Track: class Track:
"""The base class for a Spotify Track""" """The base class for a Spotify Track"""
def __init__(self, data: dict, image = None) -> None: def __init__(self, data: dict, image=None) -> None:
self.name: str = data["name"] self.name: str = data["name"]
self.artists: str = ", ".join(artist["name"] for artist in data["artists"]) self.artists: str = ", ".join(artist["name"] for artist in data["artists"])
self.length: float = data["duration_ms"] self.length: float = data["duration_ms"]
@ -31,6 +31,7 @@ class Track:
f"length={self.length} id={self.id} isrc={self.isrc}>" f"length={self.length} id={self.id} isrc={self.isrc}>"
) )
class Playlist: class Playlist:
"""The base class for a Spotify playlist""" """The base class for a Spotify playlist"""
@ -52,6 +53,7 @@ class Playlist:
f"total_tracks={self.total_tracks} tracks={self.tracks}>" f"total_tracks={self.total_tracks} tracks={self.tracks}>"
) )
class Album: class Album:
"""The base class for a Spotify album""" """The base class for a Spotify album"""
@ -70,11 +72,14 @@ class Album:
f"total_tracks={self.total_tracks} tracks={self.tracks}>" f"total_tracks={self.total_tracks} tracks={self.tracks}>"
) )
class Artist: class Artist:
"""The base class for a Spotify artist""" """The base class for a Spotify artist"""
def __init__(self, data: dict, tracks: dict) -> None: def __init__(self, data: dict, tracks: dict) -> None:
self.name: str = f"Top tracks for {data['name']}" # Setting that because its only playing top tracks self.name: str = (
f"Top tracks for {data['name']}" # Setting that because its only playing top tracks
)
self.genres: str = ", ".join(genre for genre in data["genres"]) self.genres: str = ", ".join(genre for genre in data["genres"])
self.followers: int = data["followers"]["total"] self.followers: int = data["followers"]["total"]
self.image: str = data["images"][0]["url"] self.image: str = data["images"][0]["url"]
@ -83,7 +88,4 @@ class Artist:
self.uri: str = data["external_urls"]["spotify"] self.uri: str = data["external_urls"]["spotify"]
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return f"<Pomice.spotify.Artist name={self.name} id={self.id} " f"tracks={self.tracks}>"
f"<Pomice.spotify.Artist name={self.name} id={self.id} "
f"tracks={self.tracks}>"
)

View File

@ -30,12 +30,11 @@ class ExponentialBackoff:
""" """
def __init__(self, base: int = 1, *, integral: bool = False) -> None: def __init__(self, base: int = 1, *, integral: bool = False) -> None:
self._base = base self._base = base
self._exp = 0 self._exp = 0
self._max = 10 self._max = 10
self._reset_time = base * 2 ** 11 self._reset_time = base * 2**11
self._last_invocation = time.monotonic() self._last_invocation = time.monotonic()
rand = random.Random() rand = random.Random()
@ -44,7 +43,6 @@ class ExponentialBackoff:
self._randfunc = rand.randrange if integral else rand.uniform self._randfunc = rand.randrange if integral else rand.uniform
def delay(self) -> float: def delay(self) -> float:
invocation = time.monotonic() invocation = time.monotonic()
interval = invocation - self._last_invocation interval = invocation - self._last_invocation
self._last_invocation = invocation self._last_invocation = invocation
@ -53,16 +51,15 @@ class ExponentialBackoff:
self._exp = 0 self._exp = 0
self._exp = min(self._exp + 1, self._max) self._exp = min(self._exp + 1, self._max)
return self._randfunc(0, self._base * 2 ** self._exp) return self._randfunc(0, self._base * 2**self._exp)
class NodeStats: class NodeStats:
"""The base class for the node stats object. """The base class for the node stats object.
Gives critical information on the node, which is updated every minute. Gives critical information on the node, which is updated every minute.
""" """
def __init__(self, data: dict) -> None: def __init__(self, data: dict) -> None:
__slots__ = ( __slots__ = (
"used", "used",
"free", "free",
@ -73,7 +70,7 @@ class NodeStats:
"cpu_process_load", "cpu_process_load",
"players_active", "players_active",
"players_total", "players_total",
"uptime" "uptime",
) )
memory: dict = data.get("memory") memory: dict = data.get("memory")
@ -94,18 +91,16 @@ class NodeStats:
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Pomice.NodeStats total_players={self.players_total!r} playing_active={self.players_active!r}>" return f"<Pomice.NodeStats total_players={self.players_total!r} playing_active={self.players_active!r}>"
class FailingIPBlock: class FailingIPBlock:
""" """
The base class for the failing IP block object from the route planner stats. The base class for the failing IP block object from the route planner stats.
Gives critical information about any failing addresses on the block Gives critical information about any failing addresses on the block
and the time they failed. and the time they failed.
""" """
def __init__(self, data: dict) -> None:
__slots__ = ( def __init__(self, data: dict) -> None:
"address", __slots__ = ("address", "failing_time")
"failing_time"
)
self.address = data.get("address") self.address = data.get("address")
self.failing_time = datetime.fromtimestamp(float(data.get("failingTimestamp"))) self.failing_time = datetime.fromtimestamp(float(data.get("failingTimestamp")))
@ -121,13 +116,7 @@ class RouteStats:
""" """
def __init__(self, data: dict) -> None: def __init__(self, data: dict) -> None:
__slots__ = ("strategy", "ip_block_type", "ip_block_size", "failing_addresses")
__slots__ = (
"strategy",
"ip_block_type",
"ip_block_size",
"failing_addresses"
)
self.strategy = RouteStrategy(data.get("class")) self.strategy = RouteStrategy(data.get("class"))
@ -172,7 +161,6 @@ class Ping:
def close(self): def close(self):
self._s.close() self._s.close()
class Timer(object): class Timer(object):
def __init__(self): def __init__(self):
self._start = 0 self._start = 0
@ -201,10 +189,7 @@ class Ping:
def get_ping(self): def get_ping(self):
s = self._create_socket(socket.AF_INET, socket.SOCK_STREAM) s = self._create_socket(socket.AF_INET, socket.SOCK_STREAM)
cost_time = self.timer.cost( cost_time = self.timer.cost((s.connect, s.shutdown), ((self._host, self._port), None))
(s.connect, s.shutdown),
((self._host, self._port), None))
s_runtime = 1000 * (cost_time) s_runtime = 1000 * (cost_time)
return s_runtime return s_runtime

View File

@ -5,8 +5,5 @@ requires = [
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[tool.autopep8] [tool.black]
max_line_length = 100 line-length = 100
in-place = true
recursive = true
aggressive = 3

View File

@ -1,33 +1,35 @@
import setuptools import setuptools
import re import re
version = '' version = ""
requirements = ['discord.py>=2.0.0', 'aiohttp>=3.7.4,<4', 'orjson'] requirements = ["discord.py>=2.0.0", "aiohttp>=3.7.4,<4", "orjson"]
with open('pomice/__init__.py') as f: with open("pomice/__init__.py") as f:
version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE).group(1) version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE).group(1)
if not version: if not version:
raise RuntimeError('version is not set') raise RuntimeError("version is not set")
if version.endswith(('a', 'b', 'rc')): if version.endswith(("a", "b", "rc")):
# append version identifier based on commit count # append version identifier based on commit count
try: try:
import subprocess import subprocess
p = subprocess.Popen(['git', 'rev-list', '--count', 'HEAD'],
stdout=subprocess.PIPE, stderr=subprocess.PIPE) p = subprocess.Popen(
["git", "rev-list", "--count", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
out, err = p.communicate() out, err = p.communicate()
if out: if out:
version += out.decode('utf-8').strip() version += out.decode("utf-8").strip()
p = subprocess.Popen(['git', 'rev-parse', '--short', 'HEAD'], p = subprocess.Popen(
stdout=subprocess.PIPE, stderr=subprocess.PIPE) ["git", "rev-parse", "--short", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
out, err = p.communicate() out, err = p.communicate()
if out: if out:
version += '+g' + out.decode('utf-8').strip() version += "+g" + out.decode("utf-8").strip()
except Exception: except Exception:
pass pass
with open("README.md") as f: with open("README.md") as f:
readme = f.read() readme = f.read()
@ -47,15 +49,15 @@ setuptools.setup(
extra_require=None, extra_require=None,
classifiers=[ classifiers=[
"Framework :: AsyncIO", "Framework :: AsyncIO",
'Operating System :: OS Independent', "Operating System :: OS Independent",
'Natural Language :: English', "Natural Language :: English",
'Intended Audience :: Developers', "Intended Audience :: Developers",
"Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
'Topic :: Software Development :: Libraries :: Python Modules', "Topic :: Software Development :: Libraries :: Python Modules",
'Topic :: Software Development :: Libraries', "Topic :: Software Development :: Libraries",
"Topic :: Internet" "Topic :: Internet",
], ],
python_requires='>=3.8', python_requires=">=3.8",
keywords=['pomice', 'lavalink', "discord.py"], keywords=["pomice", "lavalink", "discord.py"],
) )