This commit is contained in:
Andrei Baciu 2024-04-12 19:44:29 +02:00 committed by GitHub
commit 1d162df60d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 712 additions and 646 deletions

View File

@ -5,7 +5,9 @@ name = "pypi"
[packages] [packages]
orjson = "*" orjson = "*"
pydantic = ">=2"
"discord.py" = {extras = ["voice"], version = "*"} "discord.py" = {extras = ["voice"], version = "*"}
websockets = "*"
[dev-packages] [dev-packages]
mypy = "*" mypy = "*"

View File

@ -125,7 +125,7 @@ class Music(commands.Cog):
return player.dj == ctx.author or ctx.author.guild_permissions.kick_members return player.dj == ctx.author or ctx.author.guild_permissions.kick_members
# The following are events from pomice.events # The following are events from pomice.models.events
# We are using these so that if the track either stops or errors, # We are using these so that if the track either stops or errors,
# we can just skip to the next track # we can just skip to the next track

View File

@ -27,7 +27,7 @@ __license__ = "GPL-3.0"
__copyright__ = "Copyright (c) 2023, cloudwithax" __copyright__ = "Copyright (c) 2023, cloudwithax"
from .enums import * from .enums import *
from .events import * from .models import *
from .exceptions import * from .exceptions import *
from .filters import * from .filters import *
from .objects import * from .objects import *

View File

@ -11,17 +11,12 @@ from typing import Union
import aiohttp import aiohttp
import orjson as json import orjson as json
from .exceptions import * from pomice.applemusic.exceptions import *
from .objects import * from pomice.applemusic.objects import *
from pomice.enums import URLRegex
__all__ = ("Client",) __all__ = ("Client",)
AM_URL_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
)
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
)
AM_SCRIPT_REGEX = re.compile(r'<script.*?src="(/assets/index-.*?)"') AM_SCRIPT_REGEX = re.compile(r'<script.*?src="(/assets/index-.*?)"')
@ -103,7 +98,7 @@ class Client:
if not self.token or datetime.utcnow() > self.expiry: if not self.token or datetime.utcnow() > self.expiry:
await self.request_token() await self.request_token()
result = AM_URL_REGEX.match(query) result = URLRegex.AM_URL.match(query)
if not result: if not result:
raise InvalidAppleMusicURL( raise InvalidAppleMusicURL(
"The Apple Music link provided is not valid.", "The Apple Music link provided is not valid.",
@ -113,7 +108,7 @@ class Client:
type = result.group("type") type = result.group("type")
id = result.group("id") id = result.group("id")
if type == "album" and (sia_result := AM_SINGLE_IN_ALBUM_REGEX.match(query)): if type == "album" and (sia_result := URLRegex.AM_SINGLE_IN_ALBUM_REGEX.match(query)):
# apple music likes to generate links for singles off an album # apple music likes to generate links for singles off an album
# by adding a param at the end of the url # by adding a param at the end of the url
# so we're gonna scan for that and correct it # so we're gonna scan for that and correct it

View File

@ -1,6 +1,7 @@
import re import re
from enum import Enum from enum import Enum
from enum import IntEnum from enum import IntEnum
from enum import unique
__all__ = ( __all__ = (
"SearchType", "SearchType",
@ -15,7 +16,13 @@ __all__ = (
) )
class SearchType(Enum): class BaseStrEnum(str, Enum):
def __str__(self):
return self.value
@unique
class SearchType(BaseStrEnum):
""" """
The enum for the different search types for Pomice. The enum for the different search types for Pomice.
This feature is exclusively for the Spotify search feature of Pomice. This feature is exclusively for the Spotify search feature of Pomice.
@ -31,15 +38,13 @@ class SearchType(Enum):
which is an alternative to YouTube or YouTube Music. which is an alternative to YouTube or YouTube Music.
""" """
ytsearch = "ytsearch" YTSEARCH = "ytsearch"
ytmsearch = "ytmsearch" YTMSEARCH = "ytmsearch"
scsearch = "scsearch" SCSEARCH = "scsearch"
def __str__(self) -> str:
return self.value
class TrackType(Enum): @unique
class TrackType(BaseStrEnum):
""" """
The enum for the different track types for Pomice. The enum for the different track types for Pomice.
@ -64,11 +69,9 @@ class TrackType(Enum):
HTTP = "http" HTTP = "http"
LOCAL = "local" LOCAL = "local"
def __str__(self) -> str:
return self.value
@unique
class PlaylistType(Enum): class PlaylistType(BaseStrEnum):
""" """
The enum for the different playlist types for Pomice. The enum for the different playlist types for Pomice.
@ -87,11 +90,9 @@ class PlaylistType(Enum):
SPOTIFY = "spotify" SPOTIFY = "spotify"
APPLE_MUSIC = "apple_music" APPLE_MUSIC = "apple_music"
def __str__(self) -> str:
return self.value
@unique
class NodeAlgorithm(Enum): class NodeAlgorithm(BaseStrEnum):
""" """
The enum for the different node algorithms in Pomice. The enum for the different node algorithms in Pomice.
@ -111,11 +112,9 @@ class NodeAlgorithm(Enum):
by_ping = "BY_PING" by_ping = "BY_PING"
by_players = "BY_PLAYERS" by_players = "BY_PLAYERS"
def __str__(self) -> str:
return self.value
@unique
class LoopMode(Enum): class LoopMode(BaseStrEnum):
""" """
The enum for the different loop modes. The enum for the different loop modes.
This feature is exclusively for the queue utility of pomice. This feature is exclusively for the queue utility of pomice.
@ -124,18 +123,15 @@ class LoopMode(Enum):
LoopMode.TRACK sets the queue loop to the current track. LoopMode.TRACK sets the queue loop to the current track.
LoopMode.QUEUE sets the queue loop to the whole queue. LoopMode.QUEUE sets the queue loop to the whole queue.
""" """
# We don't have to define anything special for these, since these just serve as flags # We don't have to define anything special for these, since these just serve as flags
TRACK = "track" TRACK = "track"
QUEUE = "queue" QUEUE = "queue"
def __str__(self) -> str:
return self.value
@unique
class RouteStrategy(Enum): class RouteStrategy(BaseStrEnum):
""" """
The enum for specifying the route planner strategy for Lavalink. The enum for specifying the route planner strategy for Lavalink.
This feature is exclusively for the RoutePlanner class. This feature is exclusively for the RoutePlanner class.
@ -153,7 +149,6 @@ class RouteStrategy(Enum):
RouteStrategy.ROTATING_NANO_SWITCH specifies that the node is switching RouteStrategy.ROTATING_NANO_SWITCH specifies that the node is switching
between IPs every CPU clock cycle and is rotating between IP blocks on between IPs every CPU clock cycle and is rotating between IP blocks on
ban. ban.
""" """
ROTATE_ON_BAN = "RotatingIpRoutePlanner" ROTATE_ON_BAN = "RotatingIpRoutePlanner"
@ -162,7 +157,8 @@ class RouteStrategy(Enum):
ROTATING_NANO_SWITCH = "RotatingNanoIpRoutePlanner" ROTATING_NANO_SWITCH = "RotatingNanoIpRoutePlanner"
class RouteIPType(Enum): @unique
class RouteIPType(BaseStrEnum):
""" """
The enum for specifying the route planner IP block type for Lavalink. The enum for specifying the route planner IP block type for Lavalink.
This feature is exclusively for the RoutePlanner class. This feature is exclusively for the RoutePlanner class.
@ -177,9 +173,43 @@ class RouteIPType(Enum):
IPV6 = "Inet6Address" IPV6 = "Inet6Address"
@unique
class LogLevel(IntEnum):
"""
The enum for specifying the logging level within Pomice.
This class serves as shorthand for logging.<level>
This enum is exclusively for the logging feature in Pomice.
If you are not using this feature, this class is not necessary.
LogLevel.DEBUG sets the logging level to "debug".
LogLevel.INFO sets the logging level to "info".
LogLevel.WARN sets the logging level to "warn".
LogLevel.ERROR sets the logging level to "error".
LogLevel.CRITICAL sets the logging level to "CRITICAL".
"""
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40
CRITICAL = 50
@classmethod
def from_str(cls, level_str):
try:
return cls[level_str.upper()]
except KeyError:
raise ValueError(f"No such log level: {level_str}")
class URLRegex: class URLRegex:
""" """
The enum for all the URL Regexes in use by Pomice. The class for all the URL Regexes in use by Pomice.
URLRegex.SPOTIFY_URL returns the Spotify URL Regex. URLRegex.SPOTIFY_URL returns the Spotify URL Regex.
@ -196,7 +226,6 @@ class URLRegex:
URLRegex.SOUNDCLOUD_URL returns the SoundCloud URL Regex. URLRegex.SOUNDCLOUD_URL returns the SoundCloud URL Regex.
URLRegex.BASE_URL returns the standard URL Regex. URLRegex.BASE_URL returns the standard URL Regex.
""" """
SPOTIFY_URL = re.compile( SPOTIFY_URL = re.compile(
@ -246,37 +275,3 @@ class URLRegex:
LAVALINK_SEARCH = re.compile(r"(?P<type>ytm?|sc)search:") LAVALINK_SEARCH = re.compile(r"(?P<type>ytm?|sc)search:")
BASE_URL = re.compile(r"https?://(?:www\.)?.+") BASE_URL = re.compile(r"https?://(?:www\.)?.+")
class LogLevel(IntEnum):
"""
The enum for specifying the logging level within Pomice.
This class serves as shorthand for logging.<level>
This enum is exclusively for the logging feature in Pomice.
If you are not using this feature, this class is not necessary.
LogLevel.DEBUG sets the logging level to "debug".
LogLevel.INFO sets the logging level to "info".
LogLevel.WARN sets the logging level to "warn".
LogLevel.ERROR sets the logging level to "error".
LogLevel.CRITICAL sets the logging level to "CRITICAL".
"""
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40
CRITICAL = 50
@classmethod
def from_str(cls, level_str):
try:
return cls[level_str.upper()]
except KeyError:
raise ValueError(f"No such log level: {level_str}")

View File

@ -1,197 +0,0 @@
from __future__ import annotations
from abc import ABC
from typing import Any
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from discord import Client
from discord import Guild
from discord.ext import commands
from .objects import Track
from .pool import NodePool
if TYPE_CHECKING:
from .player import Player
__all__ = (
"PomiceEvent",
"TrackStartEvent",
"TrackEndEvent",
"TrackStuckEvent",
"TrackExceptionEvent",
"WebSocketClosedPayload",
"WebSocketClosedEvent",
"WebSocketOpenEvent",
)
class PomiceEvent(ABC):
"""The base class for all events dispatched by a node.
Every event must be formatted within your bot's code as a listener.
i.e: If you want to listen for when a track starts, the event would be:
```py
@bot.listen
async def on_pomice_track_start(self, event):
```
"""
name = "event"
handler_args: Tuple
def dispatch(self, bot: Client) -> None:
bot.dispatch(f"pomice_{self.name}", *self.handler_args)
class TrackStartEvent(PomiceEvent):
"""Fired when a track has successfully started.
Returns the player associated with the event and the pomice.Track object.
"""
name = "track_start"
__slots__ = (
"player",
"track",
)
def __init__(self, data: dict, player: Player):
self.player: Player = player
self.track: Optional[Track] = self.player._current
# on_pomice_track_start(player, track)
self.handler_args = self.player, self.track
def __repr__(self) -> str:
return f"<Pomice.TrackStartEvent player={self.player!r} track={self.track!r}>"
class TrackEndEvent(PomiceEvent):
"""Fired when a track has successfully ended.
Returns the player associated with the event along with the pomice.Track object and reason.
"""
name = "track_end"
__slots__ = ("player", "track", "reason")
def __init__(self, data: dict, player: Player):
self.player: Player = player
self.track: Optional[Track] = self.player._ending_track
self.reason: str = data["reason"]
# on_pomice_track_end(player, track, reason)
self.handler_args = self.player, self.track, self.reason
def __repr__(self) -> str:
return (
f"<Pomice.TrackEndEvent player={self.player!r} track_id={self.track!r} "
f"reason={self.reason!r}>"
)
class TrackStuckEvent(PomiceEvent):
"""Fired when a track is stuck and cannot be played. Returns the player
associated with the event along with the pomice.Track object
to be further parsed by the end user.
"""
name = "track_stuck"
__slots__ = ("player", "track", "threshold")
def __init__(self, data: dict, player: Player):
self.player: Player = player
self.track: Optional[Track] = self.player._ending_track
self.threshold: float = data["thresholdMs"]
# on_pomice_track_stuck(player, track, threshold)
self.handler_args = self.player, self.track, self.threshold
def __repr__(self) -> str:
return (
f"<Pomice.TrackStuckEvent player={self.player!r} track={self.track!r} "
f"threshold={self.threshold!r}>"
)
class TrackExceptionEvent(PomiceEvent):
"""Fired when a track error has occured.
Returns the player associated with the event along with the error code and exception.
"""
name = "track_exception"
__slots__ = ("player", "track", "exception")
def __init__(self, data: dict, player: Player):
self.player: Player = player
self.track: Optional[Track] = self.player._ending_track
# Error is for Lavalink <= 3.3
self.exception: str = data.get(
"error",
"",
) or data.get("exception", "")
# on_pomice_track_exception(player, track, error)
self.handler_args = self.player, self.track, self.exception
def __repr__(self) -> str:
return f"<Pomice.TrackExceptionEvent player={self.player!r} exception={self.exception!r}>"
class WebSocketClosedPayload:
__slots__ = ("guild", "code", "reason", "by_remote")
def __init__(self, data: dict):
self.guild: Optional[Guild] = NodePool.get_node().bot.get_guild(int(data["guildId"]))
self.code: int = data["code"]
self.reason: str = data["code"]
self.by_remote: bool = data["byRemote"]
def __repr__(self) -> str:
return (
f"<Pomice.WebSocketClosedPayload guild={self.guild!r} code={self.code!r} "
f"reason={self.reason!r} by_remote={self.by_remote!r}>"
)
class WebSocketClosedEvent(PomiceEvent):
"""Fired when a websocket connection to a node has been closed.
Returns the reason and the error code.
"""
name = "websocket_closed"
__slots__ = ("payload",)
def __init__(self, data: dict, _: Any) -> None:
self.payload: WebSocketClosedPayload = WebSocketClosedPayload(data)
# on_pomice_websocket_closed(payload)
self.handler_args = (self.payload,)
def __repr__(self) -> str:
return f"<Pomice.WebsocketClosedEvent payload={self.payload!r}>"
class WebSocketOpenEvent(PomiceEvent):
"""Fired when a websocket connection to a node has been initiated.
Returns the target and the session SSRC.
"""
name = "websocket_open"
__slots__ = ("target", "ssrc")
def __init__(self, data: dict, _: Any) -> None:
self.target: str = data["target"]
self.ssrc: int = data["ssrc"]
# on_pomice_websocket_open(target, ssrc)
self.handler_args = self.target, self.ssrc
def __repr__(self) -> str:
return f"<Pomice.WebsocketOpenEvent target={self.target!r} ssrc={self.ssrc!r}>"

View File

@ -61,7 +61,7 @@ class NoNodesAvailable(PomiceException):
pass pass
class TrackInvalidPosition(PomiceException): class TrackInvalidPosition(PomiceException, ValueError):
"""An invalid position was chosen for a track.""" """An invalid position was chosen for a track."""
pass pass
@ -73,19 +73,19 @@ class TrackLoadError(PomiceException):
pass pass
class FilterInvalidArgument(PomiceException): class FilterInvalidArgument(PomiceException, ValueError):
"""An invalid argument was passed to a filter.""" """An invalid argument was passed to a filter."""
pass pass
class FilterTagInvalid(PomiceException): class FilterTagInvalid(PomiceException, ValueError):
"""An invalid tag was passed or Pomice was unable to find a filter tag""" """An invalid tag was passed or Pomice was unable to find a filter tag"""
pass pass
class FilterTagAlreadyInUse(PomiceException): class FilterTagAlreadyInUse(PomiceException, ValueError):
"""A filter with a tag is already in use by another filter""" """A filter with a tag is already in use by another filter"""
pass pass
@ -97,7 +97,7 @@ class InvalidSpotifyClientAuthorization(PomiceException):
pass pass
class AppleMusicNotEnabled(PomiceException): class AppleMusicNotEnabled(PomiceException, ValueError):
"""An Apple Music Link was passed in when Apple Music functionality was not enabled.""" """An Apple Music Link was passed in when Apple Music functionality was not enabled."""
pass pass

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import collections import collections
from typing import Any from typing import Any
from typing import Dict from typing import Dict
@ -5,7 +7,7 @@ from typing import List
from typing import Optional from typing import Optional
from typing import Tuple from typing import Tuple
from .exceptions import FilterInvalidArgument from pomice.exceptions import FilterInvalidArgument
__all__ = ( __all__ = (
"Filter", "Filter",
@ -84,7 +86,7 @@ class Equalizer(Filter):
return self.raw == __value.raw return self.raw == __value.raw
@classmethod @classmethod
def flat(cls) -> "Equalizer": def flat(cls) -> Equalizer:
"""Equalizer preset which represents a flat EQ board, """Equalizer preset which represents a flat EQ board,
with all levels set to their default values. with all levels set to their default values.
""" """
@ -109,7 +111,7 @@ class Equalizer(Filter):
return cls(tag="flat", levels=levels) return cls(tag="flat", levels=levels)
@classmethod @classmethod
def boost(cls) -> "Equalizer": def boost(cls) -> Equalizer:
"""Equalizer preset which boosts the sound of a track, """Equalizer preset which boosts the sound of a track,
making it sound fun and energetic by increasing the bass making it sound fun and energetic by increasing the bass
and the highs. and the highs.
@ -135,7 +137,7 @@ class Equalizer(Filter):
return cls(tag="boost", levels=levels) return cls(tag="boost", levels=levels)
@classmethod @classmethod
def metal(cls) -> "Equalizer": def metal(cls) -> Equalizer:
"""Equalizer preset which increases the mids of a track, """Equalizer preset which increases the mids of a track,
preferably one of the metal genre, to make it sound preferably one of the metal genre, to make it sound
more full and concert-like. more full and concert-like.
@ -162,7 +164,7 @@ class Equalizer(Filter):
return cls(tag="metal", levels=levels) return cls(tag="metal", levels=levels)
@classmethod @classmethod
def piano(cls) -> "Equalizer": def piano(cls) -> Equalizer:
"""Equalizer preset which increases the mids and highs """Equalizer preset which increases the mids and highs
of a track, preferably a piano based one, to make it of a track, preferably a piano based one, to make it
stand out. stand out.
@ -215,7 +217,7 @@ class Timescale(Filter):
} }
@classmethod @classmethod
def vaporwave(cls) -> "Timescale": def vaporwave(cls) -> Timescale:
"""Timescale preset which slows down the currently playing track, """Timescale preset which slows down the currently playing track,
giving it the effect of a half-speed record/casette playing. giving it the effect of a half-speed record/casette playing.
@ -225,7 +227,7 @@ class Timescale(Filter):
return cls(tag="vaporwave", speed=0.8, pitch=0.8) return cls(tag="vaporwave", speed=0.8, pitch=0.8)
@classmethod @classmethod
def nightcore(cls) -> "Timescale": def nightcore(cls) -> Timescale:
"""Timescale preset which speeds up the currently playing track, """Timescale preset which speeds up the currently playing track,
which matches up to nightcore, a genre of sped-up music which matches up to nightcore, a genre of sped-up music

23
pomice/models/__init__.py Normal file
View File

@ -0,0 +1,23 @@
import pydantic
from pydantic import ConfigDict
from .events import *
from .music import *
from .payloads import *
from .version import *
class BaseModel(pydantic.BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
def model_dump(self, *args, **kwargs) -> dict:
by_alias = kwargs.pop("by_alias", True)
mode = kwargs.pop("mode", "json")
return super().model_dump(*args, **kwargs, by_alias=by_alias, mode=mode)
class VersionedModel(BaseModel):
version: LavalinkVersionType
def model_dump(self, *args, **kwargs) -> dict:
return super().model_dump(*args, **kwargs, exclude={"version"})

179
pomice/models/events.py Normal file
View File

@ -0,0 +1,179 @@
from __future__ import annotations
import abc
from enum import Enum
from enum import unique
from typing import Literal
from typing import TYPE_CHECKING
from discord import Guild
from pydantic import computed_field
from pydantic import Field
from pomice.models import BaseModel
from pomice.objects import Track
from pomice.player import Player
from pomice.pool import NodePool
if TYPE_CHECKING:
from discord import Client
__all__ = (
"PomiceEvent",
"TrackStartEvent",
"TrackEndEvent",
"TrackStuckEvent",
"TrackExceptionEvent",
"WebSocketClosedPayload",
"WebSocketClosedEvent",
"WebSocketOpenEvent",
)
class PomiceEvent(BaseModel, abc.ABC):
"""The base class for all events dispatched by a node.
Every event must be formatted within your bot's code as a listener.
i.e: If you want to listen for when a track starts, the event would be:
```py
@bot.listen
async def on_pomice_track_start(self, event):
```
"""
name: str
@abc.abstractmethod
def dispatch(self, bot: Client) -> None:
...
class TrackStartEvent(PomiceEvent):
"""Fired when a track has successfully started.
Returns the player associated with the event and the pomice.Track object.
"""
name: Literal["track_start"]
player: Player
track: Track
def dispatch(self, bot: Client) -> None:
bot.dispatch(f"pomice_{self.name}", self.player, self.track)
def __repr__(self) -> str:
return f"<Pomice.TrackStartEvent player={self.player!r} track={self.track!r}>"
@unique
class TrackEndEventReason(str, Enum):
FINISHED = "finished"
LOAD_FAILED = "loadfailed"
STOPPED = "stopped"
REPLACED = "replaced"
CLEANUP = "cleanup"
@classmethod
def _missing_(cls, value: object) -> TrackEndEventReason:
if isinstance(value, str):
return TrackEndEventReason(value.casefold())
class TrackEndEvent(PomiceEvent):
"""Fired when a track has successfully ended.
Returns the player associated with the event along with the pomice.Track object and reason.
"""
name: Literal["track_end"]
player: Player
track: Track
reason: TrackEndEventReason
def dispatch(self, bot: Client) -> None:
bot.dispatch(f"pomice_{self.name}", self.player, self.track, self.reason)
def __repr__(self) -> str:
return f"<Pomice.TrackEndEvent player={self.player!r} track={self.track!r} reason={self.reason!r}>"
class TrackStuckEvent(PomiceEvent):
"""Fired when a track has been stuck for a while.
Returns the player associated with the event along with the pomice.Track object and threshold.
"""
name: Literal["track_stuck"]
player: Player
track: Track
threshold: float = Field(alias="thresholdMs")
def dispatch(self, bot: Client) -> None:
bot.dispatch(f"pomice_{self.name}", self.player, self.track, self.threshold)
def __repr__(self) -> str:
return f"<Pomice.TrackStuckEvent player={self.player!r} track={self.track!r} threshold={self.threshold!r}>"
class TrackExceptionEvent(PomiceEvent):
"""Fired when there is an exception while playing a track.
Returns the player associated with the event along with the pomice.Track object and exception.
"""
name: Literal["track_exception"]
player: Player
track: Track
exception: str = Field(alias="error")
def dispatch(self, bot: Client) -> None:
bot.dispatch(f"pomice_{self.name}", self.player, self.track, self.exception)
def __repr__(self) -> str:
return f"<Pomice.TrackExceptionEvent player={self.player!r} track={self.track!r} exception={self.exception!r}>"
class WebSocketClosedPayload(BaseModel):
"""The payload for the WebSocketClosedEvent."""
guild_id: int = Field(alias="guildId")
code: int
reason: str
by_remote: bool = Field(alias="byRemote")
@computed_field
@property
def guild(self) -> Guild:
return NodePool.get_node().bot.get_guild(self.guild_id)
def __repr__(self) -> str:
return (
f"<Pomice.WebSocketClosedPayload guild_id={self.guild_id!r} code={self.code!r} "
f"reason={self.reason!r} by_remote={self.by_remote!r}>"
)
class WebSocketClosedEvent(PomiceEvent):
"""Fired when the websocket connection to the node is closed.
Returns the player associated with the event and the code and reason for the closure.
"""
name: Literal["websocket_closed"]
payload: WebSocketClosedPayload
def dispatch(self, bot: Client) -> None:
bot.dispatch(f"pomice_{self.name}", self.payload)
def __repr__(self) -> str:
return f"<Pomice.WebSocketClosedEvent payload={self.payload!r}>"
class WebSocketOpenEvent(PomiceEvent):
"""Fired when the websocket connection to the node is opened.
Returns the player associated with the event.
"""
name: Literal["websocket_open"]
target: str
ssrc: str
def dispatch(self, bot: Client) -> None:
bot.dispatch(f"pomice_{self.name}", self.target, self.ssrc)
def __repr__(self) -> str:
return f"<Pomice.WebSocketOpenEvent target={self.target!r} ssrc={self.ssrc!r}>"

160
pomice/models/music.py Normal file
View File

@ -0,0 +1,160 @@
from __future__ import annotations
from typing import List
from typing import Literal
from typing import Optional
from typing import Union
from discord.ext.commands import Context
from discord.user import _UserTag
from pydantic import Field
from pydantic import model_validator
from pydantic import TypeAdapter
from pomice.enums import PlaylistType
from pomice.enums import SearchType
from pomice.enums import TrackType
from pomice.filters import Filter
from pomice.models import BaseModel
__all__ = (
"Track",
"TrackInfo",
"Playlist",
"PlaylistInfo",
"PlaylistExtended",
"PlaylistModelAdapter",
)
class TrackInfo(BaseModel):
identifier: str
title: str
author: str
length: int
position: int = 0
is_stream: bool = Field(default=False, alias="isStream")
is_seekable: bool = Field(default=False, alias="isSeekable")
uri: Optional[str] = None
isrc: Optional[str] = None
source_name: Optional[str] = Field(default=None, alias="sourceName")
artwork_url: Optional[str] = Field(default=None, alias="artworkUrl")
class Track(BaseModel):
"""The base track object. Returns critical track information needed for parsing by Lavalink.
You can also pass in commands.Context to get a discord.py Context object in your track.
"""
track_id: str = Field(alias="encoded")
track_type: TrackType
info: TrackInfo
search_type: SearchType = SearchType.YTSEARCH
filters: List[Filter] = Field(default_factory=list)
timestamp: Optional[float] = None
playlist: Optional[Playlist] = None
original: Optional[Track] = None
ctx: Optional[Context] = None
requester: Optional[_UserTag] = None
@property
def title(self) -> str:
return self.info.title
@property
def author(self) -> str:
return self.info.author
@property
def uri(self) -> Optional[str]:
return self.info.uri
@property
def identifier(self) -> str:
return self.info.identifier
@property
def isrc(self) -> Optional[str]:
return self.info.isrc
@property
def thumbnail(self) -> Optional[str]:
return self.info.artwork_url
def __eq__(self, other: object) -> bool:
if not isinstance(other, Track):
return False
return self.track_id == other.track_id
def __str__(self) -> str:
return self.info.title
def __repr__(self) -> str:
return f"<Pomice.Track title={self.info.title!r} uri=<{self.info.uri!r}> length={self.info.length}>"
@model_validator(mode="after")
def _set_thumbnail_url(self) -> Track:
if self.track_type is TrackType.YOUTUBE and not self.info.artwork_url:
self.info.artwork_url = (
f"https://img.youtube.com/vi/{self.info.identifier}/mqdefault.jpg"
)
return self
class PlaylistInfo(BaseModel):
name: str
selected_track: int = Field(default=0, alias="selectedTrack")
class Playlist(BaseModel):
"""The base playlist object.
Returns critical playlist information needed for parsing by Lavalink.
"""
info: PlaylistInfo
tracks: List[Track]
playlist_type: PlaylistType
@property
def name(self) -> str:
return self.info.name
@property
def selected_track(self) -> Optional[Track]:
if self.track_count <= 0:
return None
return self.tracks[self.info.selected_track]
@property
def track_count(self) -> int:
return len(self.tracks)
def __str__(self) -> str:
return self.info.name
def __repr__(self) -> str:
return f"<Pomice.Playlist name={self.info.name!r} total_tracks={self.track_count}>"
@model_validator(mode="after")
def _set_playlist(self) -> Playlist:
for track in self.tracks:
track.playlist = self
return self
class PlaylistExtended(Playlist):
"""Playlist object with additional information for external services."""
playlist_type: Union[Literal[PlaylistType.APPLE_MUSIC, PlaylistType.SPOTIFY]]
uri: str
artwork_url: str
@property
def thumbnail(self) -> Optional[str]:
return self.artwork_url
PlaylistModelType = Union[Playlist, PlaylistExtended]
PlaylistModelAdapter = lambda **kwargs: TypeAdapter(PlaylistModelType).validate_python(kwargs)

68
pomice/models/payloads.py Normal file
View File

@ -0,0 +1,68 @@
from __future__ import annotations
from typing import Optional
from typing import Union
from pydantic import AliasPath
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from pydantic import TypeAdapter
from pomice.models import BaseModel
from pomice.models import LavalinkVersion3Type
from pomice.models import LavalinkVersion4Type
from pomice.models import VersionedModel
__all__ = (
"VoiceUpdatePayload",
"TrackStartPayload",
"TrackUpdatePayload",
"ResumePayloadType",
"ResumePayloadTypeAdapter",
)
class VoiceUpdatePayload(BaseModel):
token: str = Field(validation_alias=AliasPath("event", "token"))
endpoint: str = Field(validation_alias=AliasPath("event", "endpoint"))
session_id: str = Field(alias="sessionId")
class TrackUpdatePayload(BaseModel):
encoded_track: Optional[str] = Field(default=None, alias="encodedTrack")
position: float
class TrackStartPayload(VersionedModel):
encoded_track: Optional[str] = Field(default=None, alias="encodedTrack")
position: float
end_time: str = Field(default="0", alias="endTime")
@field_validator("end_time", mode="before")
@classmethod
def cast_end_time(cls, value: object) -> str:
return str(value)
@model_validator(mode="after")
def adjust_end_time(self) -> TrackStartPayload:
if self.version >= LavalinkVersion3Type(3, 7, 5):
self.end_time = None
class ResumePayload(VersionedModel):
timeout: int
class ResumePayloadV3(ResumePayload):
version: LavalinkVersion3Type
resuming_key: str = Field(alias="resumingKey")
class ResumePayloadV4(ResumePayload):
version: LavalinkVersion4Type
resuming: bool = True
ResumePayloadType = Union[ResumePayloadV3, ResumePayloadV4]
ResumePayloadTypeAdapter = lambda **kwargs: TypeAdapter(ResumePayloadType).validate_python(kwargs)

51
pomice/models/version.py Normal file
View File

@ -0,0 +1,51 @@
from typing import Literal
from typing import NamedTuple
from typing import Union
__all__ = (
"LavalinkVersion",
"LavalinkVersion3Type",
"LavalinkVersion4Type",
"LavalinkVersionType",
)
class LavalinkVersion(NamedTuple):
major: int
minor: int
fix: int
def __eq__(self, other: object) -> bool:
if not isinstance(other, LavalinkVersion):
return False
return (
(self.major == other.major) and (self.minor == other.minor) and (self.fix == other.fix)
)
def __lt__(self, other: object) -> bool:
if not isinstance(other, LavalinkVersion):
return False
if self.major > other.major:
return False
if self.minor > other.minor:
return False
if self.fix > other.fix:
return False
return True
class LavalinkVersion3Type(LavalinkVersion):
major: Literal[3]
minor: int
fix: int
class LavalinkVersion4Type(LavalinkVersion):
major: Literal[4]
minor: int
fix: int
LavalinkVersionType = Union[LavalinkVersion3Type, LavalinkVersion4Type, LavalinkVersion]

View File

@ -1,167 +0,0 @@
from __future__ import annotations
from typing import List
from typing import Optional
from typing import Union
from discord import ClientUser
from discord import Member
from discord import User
from discord.ext import commands
from .enums import PlaylistType
from .enums import SearchType
from .enums import TrackType
from .filters import Filter
__all__ = (
"Track",
"Playlist",
)
class Track:
"""The base track object. Returns critical track information needed for parsing by Lavalink.
You can also pass in commands.Context to get a discord.py Context object in your track.
"""
__slots__ = (
"track_id",
"info",
"track_type",
"filters",
"timestamp",
"original",
"_search_type",
"playlist",
"title",
"author",
"uri",
"identifier",
"isrc",
"thumbnail",
"length",
"ctx",
"requester",
"is_stream",
"is_seekable",
"position",
)
def __init__(
self,
*,
track_id: str,
info: dict,
ctx: Optional[commands.Context] = None,
track_type: TrackType,
search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None,
timestamp: Optional[float] = None,
requester: Optional[Union[Member, User, ClientUser]] = None,
):
self.track_id: str = track_id
self.info: dict = info
self.track_type: TrackType = track_type
self.filters: Optional[List[Filter]] = filters
self.timestamp: Optional[float] = timestamp
if self.track_type == TrackType.SPOTIFY or self.track_type == TrackType.APPLE_MUSIC:
self.original: Optional[Track] = None
else:
self.original = self
self._search_type: SearchType = search_type
self.playlist: Optional[Playlist] = None
self.title: str = info.get("title", "Unknown Title")
self.author: str = info.get("author", "Unknown Author")
self.uri: str = info.get("uri", "")
self.identifier: str = info.get("identifier", "")
self.isrc: Optional[str] = info.get("isrc", None)
self.thumbnail: Optional[str] = info.get("thumbnail")
if self.uri and self.track_type is TrackType.YOUTUBE:
self.thumbnail = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg"
self.length: int = info.get("length", 0)
self.is_stream: bool = info.get("isStream", False)
self.is_seekable: bool = info.get("isSeekable", False)
self.position: int = info.get("position", 0)
self.ctx: Optional[commands.Context] = ctx
self.requester: Optional[Union[Member, User, ClientUser]] = requester
if not self.requester and self.ctx:
self.requester = self.ctx.author
def __eq__(self, other: object) -> bool:
if not isinstance(other, Track):
return False
return other.track_id == self.track_id
def __str__(self) -> str:
return self.title
def __repr__(self) -> str:
return f"<Pomice.track title={self.title!r} uri=<{self.uri!r}> length={self.length}>"
class Playlist:
"""The base playlist object.
Returns critical playlist information needed for parsing by Lavalink.
You can also pass in commands.Context to get a discord.py Context object in your tracks.
"""
__slots__ = (
"playlist_info",
"tracks",
"name",
"playlist_type",
"_thumbnail",
"_uri",
"selected_track",
"track_count",
)
def __init__(
self,
*,
playlist_info: dict,
tracks: list,
playlist_type: PlaylistType,
thumbnail: Optional[str] = None,
uri: Optional[str] = None,
):
self.playlist_info: dict = playlist_info
self.tracks: List[Track] = tracks
self.name: str = playlist_info.get("name", "Unknown Playlist")
self.playlist_type: PlaylistType = playlist_type
self._thumbnail: Optional[str] = thumbnail
self._uri: Optional[str] = uri
for track in self.tracks:
track.playlist = self
self.selected_track: Optional[Track] = None
if (index := playlist_info.get("selectedTrack", -1)) != -1:
self.selected_track = self.tracks[index]
self.track_count: int = len(self.tracks)
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
return f"<Pomice.playlist name={self.name!r} track_count={len(self.tracks)}>"
@property
def uri(self) -> Optional[str]:
"""Returns either an Apple Music/Spotify URL/URI, or None if its neither of those."""
return self._uri
@property
def thumbnail(self) -> Optional[str]:
"""Returns either an Apple Music/Spotify album/playlist thumbnail, or None if its neither of those."""
return self._thumbnail

View File

@ -14,23 +14,24 @@ from discord import VoiceChannel
from discord import VoiceProtocol from discord import VoiceProtocol
from discord.ext import commands from discord.ext import commands
from . import events from pomice import events
from .enums import SearchType from pomice.enums import SearchType
from .events import PomiceEvent from pomice.exceptions import FilterInvalidArgument
from .events import TrackEndEvent from pomice.exceptions import FilterTagAlreadyInUse
from .events import TrackStartEvent from pomice.exceptions import FilterTagInvalid
from .exceptions import FilterInvalidArgument from pomice.exceptions import TrackInvalidPosition
from .exceptions import FilterTagAlreadyInUse from pomice.exceptions import TrackLoadError
from .exceptions import FilterTagInvalid from pomice.filters import Filter
from .exceptions import TrackInvalidPosition from pomice.filters import Timescale
from .exceptions import TrackLoadError from pomice.models.events import PomiceEvent
from .filters import Filter from pomice.models.events import TrackEndEvent
from .filters import Timescale from pomice.models.events import TrackStartEvent
from .objects import Playlist from pomice.models.music import Playlist
from .objects import Track from pomice.models.music import Track
from .pool import Node from pomice.models.payloads import TrackUpdatePayload
from .pool import NodePool from pomice.models.payloads import VoiceUpdatePayload
from pomice.utils import LavalinkVersion from pomice.pool import Node
from pomice.pool import NodePool
if TYPE_CHECKING: if TYPE_CHECKING:
from discord.types.voice import VoiceServerUpdate from discord.types.voice import VoiceServerUpdate
@ -200,10 +201,10 @@ class Player(VoiceProtocol):
@property @property
def position(self) -> float: def position(self) -> float:
"""Property which returns the player's position in a track in milliseconds""" """Property which returns the player's position in a track in milliseconds"""
if not self.is_playing: if not self.is_playing or not self._current:
return 0 return 0
current: Track = self._current # type: ignore current: Track = self._current
if current.original: if current.original:
current = current.original current = current.original
@ -230,10 +231,10 @@ class Player(VoiceProtocol):
@property @property
def adjusted_length(self) -> float: def adjusted_length(self) -> float:
"""Property which returns the player's track length in milliseconds adjusted for rate""" """Property which returns the player's track length in milliseconds adjusted for rate"""
if not self.is_playing: if not self.is_playing or not self._current:
return 0 return 0
return self.current.length / self.rate # type: ignore return self.current.length / self.rate
@property @property
def is_playing(self) -> bool: def is_playing(self) -> bool:
@ -287,12 +288,6 @@ class Player(VoiceProtocol):
""" """
return self.guild.id not in self._node._players return self.guild.id not in self._node._players
def _adjust_end_time(self) -> Optional[str]:
if self._node._version >= LavalinkVersion(3, 7, 5):
return None
return "0"
async def _update_state(self, data: dict) -> None: async def _update_state(self, data: dict) -> None:
state: dict = data.get("state", {}) state: dict = data.get("state", {})
self._last_update = int(state.get("time", 0)) self._last_update = int(state.get("time", 0))
@ -301,23 +296,18 @@ class Player(VoiceProtocol):
if self._log: if self._log:
self._log.debug(f"Got player update state with data {state}") self._log.debug(f"Got player update state with data {state}")
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None: async def _dispatch_voice_update(self, voice_data: Dict[str, Union[str, int]]) -> None:
if {"sessionId", "event"} != self._voice_state.keys(): state = voice_data or self._voice_state
if {"sessionId", "event"} != state.keys():
return return
state = voice_data or self._voice_state data = VoiceUpdatePayload.model_validate(state)
data = {
"token": state["event"]["token"],
"endpoint": state["event"]["endpoint"],
"sessionId": state["sessionId"],
}
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"voice": data}, data={"voice": data.model_dump()},
) )
if self._log: if self._log:
@ -327,44 +317,39 @@ class Player(VoiceProtocol):
async def on_voice_server_update(self, data: VoiceServerUpdate) -> None: async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
self._voice_state.update({"event": data}) self._voice_state.update({"event": data})
await self._dispatch_voice_update(self._voice_state) await self._dispatch_voice_update()
async def on_voice_state_update(self, data: GuildVoiceState) -> None: async def on_voice_state_update(self, data: GuildVoiceState) -> None:
self._voice_state.update({"sessionId": data.get("session_id")}) self._voice_state.update({"sessionId": data["session_id"]})
channel_id = data.get("channel_id")
if not channel_id:
await self.disconnect()
self._voice_state.clear()
return
channel_id = data["session_id"]
channel = self.guild.get_channel(int(channel_id)) channel = self.guild.get_channel(int(channel_id))
if self.channel != channel:
self.channel = channel
if not channel: if not channel:
await self.disconnect() await self.disconnect()
self._voice_state.clear() self._voice_state.clear()
return return
if self.channel != channel:
self.channel = channel
if not data.get("token"): if not data.get("token"):
return return
self._voice_state.update({"event": data})
await self._dispatch_voice_update({**self._voice_state, "event": data}) await self._dispatch_voice_update({**self._voice_state, "event": data})
async def _dispatch_event(self, data: dict) -> None: async def _dispatch_event(self, data: dict) -> None:
event_type: str = data["type"] event_type: str = data["type"]
event: PomiceEvent = getattr(events, event_type)(data, self) event: PomiceEvent = getattr(events, event_type)(player=self, **data)
if isinstance(event, TrackEndEvent) and event.reason not in ("REPLACED", "replaced"): if isinstance(event, TrackEndEvent) and event.reason != "replaced":
self._current = None self._current = None
event.dispatch(self._bot)
if isinstance(event, TrackStartEvent): if isinstance(event, TrackStartEvent):
self._ending_track = self._current self._ending_track = self._current
event.dispatch(self._bot)
if self._log: if self._log:
self._log.debug(f"Dispatched event {data['type']} to player.") self._log.debug(f"Dispatched event {data['type']} to player.")
@ -373,7 +358,10 @@ class Player(VoiceProtocol):
async def _swap_node(self, *, new_node: Node) -> None: async def _swap_node(self, *, new_node: Node) -> None:
if self.current: if self.current:
data: dict = {"position": self.position, "encodedTrack": self.current.track_id} data: dict = TrackUpdatePayload(
encoded_track=self.current.track_id,
position=self.position,
).model_dump()
del self._node._players[self._guild.id] del self._node._players[self._guild.id]
self._node = new_node self._node = new_node
@ -396,7 +384,7 @@ class Player(VoiceProtocol):
query: str, query: str,
*, *,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch, search_type: SearchType = SearchType.YTSEARCH,
filters: Optional[List[Filter]] = None, filters: Optional[List[Filter]] = None,
) -> Optional[Union[List[Track], Playlist]]: ) -> Optional[Union[List[Track], Playlist]]:
"""Fetches tracks from the node's REST api to parse into Lavalink. """Fetches tracks from the node's REST api to parse into Lavalink.
@ -629,6 +617,9 @@ class Player(VoiceProtocol):
async def set_volume(self, volume: int) -> int: async def set_volume(self, volume: int) -> int:
"""Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500.""" """Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500."""
if volume < 0 or volume > 500:
raise ValueError("Volume must be between 0 and 500")
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,

View File

@ -20,35 +20,34 @@ import aiohttp
import orjson as json import orjson as json
from discord import Client from discord import Client
from discord.ext import commands from discord.ext import commands
from discord.utils import MISSING
from websockets import client from websockets import client
from websockets import exceptions from websockets import exceptions
from websockets import typing as wstype
from . import __version__ from pomice import __version__
from . import applemusic from pomice import applemusic
from . import spotify from pomice import spotify
from .enums import * from pomice.enums import *
from .enums import LogLevel from pomice.exceptions import InvalidSpotifyClientAuthorization
from .exceptions import InvalidSpotifyClientAuthorization from pomice.exceptions import LavalinkVersionIncompatible
from .exceptions import LavalinkVersionIncompatible from pomice.exceptions import NodeConnectionFailure
from .exceptions import NodeConnectionFailure from pomice.exceptions import NodeCreationError
from .exceptions import NodeCreationError from pomice.exceptions import NodeNotAvailable
from .exceptions import NodeNotAvailable from pomice.exceptions import NodeRestException
from .exceptions import NodeRestException from pomice.exceptions import NoNodesAvailable
from .exceptions import NoNodesAvailable from pomice.exceptions import TrackLoadError
from .exceptions import TrackLoadError from pomice.filters import Filter
from .filters import Filter from pomice.models.music import Playlist
from .objects import Playlist from pomice.models.music import Track
from .objects import Track from pomice.models.payloads import ResumePayloadTypeAdapter
from .routeplanner import RoutePlanner from pomice.models.payloads import ResumePayloadV4
from .utils import ExponentialBackoff from pomice.models.version import LavalinkVersion
from .utils import LavalinkVersion from pomice.routeplanner import RoutePlanner
from .utils import NodeStats from pomice.utils import ExponentialBackoff
from .utils import Ping from pomice.utils import NodeStats
from pomice.utils import Ping
if TYPE_CHECKING: if TYPE_CHECKING:
from .player import Player from pomice.player import Player
__all__ = ( __all__ = (
"Node", "Node",
@ -167,20 +166,14 @@ class Node:
self._spotify_client: Optional[spotify.Client] = None self._spotify_client: Optional[spotify.Client] = None
self._apple_music_client: Optional[applemusic.Client] = None self._apple_music_client: Optional[applemusic.Client] = None
self._spotify_client_id: Optional[str] = spotify_client_id if spotify_client_id and spotify_client_secret:
self._spotify_client_secret: Optional[str] = spotify_client_secret
if self._spotify_client_id and self._spotify_client_secret:
self._spotify_client = spotify.Client( self._spotify_client = spotify.Client(
self._spotify_client_id, spotify_client_id,
self._spotify_client_secret, spotify_client_secret,
) )
if apple_music: if apple_music:
self._apple_music_client = applemusic.Client() self._apple_music_client = applemusic.Client()
self._bot.add_listener(self._update_handler, "on_socket_response")
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"<Pomice.node ws_uri={self._websocket_uri} rest_uri={self._rest_uri} " f"<Pomice.node ws_uri={self._websocket_uri} rest_uri={self._rest_uri} "
@ -265,31 +258,6 @@ class Node:
if self._apple_music_client: if self._apple_music_client:
await self._apple_music_client._set_session(session=session) await self._apple_music_client._set_session(session=session)
async def _update_handler(self, data: dict) -> None:
await self._bot.wait_until_ready()
if not data:
return
if data["t"] == "VOICE_SERVER_UPDATE":
guild_id = int(data["d"]["guild_id"])
try:
player = self._players[guild_id]
await player.on_voice_server_update(data["d"])
except KeyError:
return
elif data["t"] == "VOICE_STATE_UPDATE":
if int(data["d"]["user_id"]) != self._bot_user.id:
return
guild_id = int(data["d"]["guild_id"])
try:
player = self._players[guild_id]
await player.on_voice_state_update(data["d"])
except KeyError:
return
async def _handle_node_switch(self) -> None: async def _handle_node_switch(self) -> None:
nodes = [node for node in self.pool._nodes.copy().values() if node.is_connected] nodes = [node for node in self.pool._nodes.copy().values() if node.is_connected]
new_node = random.choice(nodes) new_node = random.choice(nodes)
@ -303,14 +271,15 @@ class Node:
if not self._resume_key: if not self._resume_key:
return return
data = {"timeout": self._resume_timeout} data = ResumePayloadTypeAdapter(
version=self._version,
timeout=self._resume_timeout,
resuming_key=self._resume_key,
).model_dump()
if self._version.major == 3: if isinstance(data, ResumePayloadV4):
data["resumingKey"] = self._resume_key
elif self._version.major == 4:
if self._log: if self._log:
self._log.warning("Using a resume key with Lavalink v4 is deprecated.") self._log.warning("Using a resume key with Lavalink v4 is deprecated.")
data["resuming"] = True
await self.send( await self.send(
method="PATCH", method="PATCH",
@ -560,7 +529,7 @@ class Node:
query: str, query: str,
*, *,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch, search_type: SearchType = SearchType.YTSEARCH,
filters: Optional[List[Filter]] = None, filters: Optional[List[Filter]] = None,
) -> Optional[Union[Playlist, List[Track]]]: ) -> Optional[Union[Playlist, List[Track]]]:
"""Fetches tracks from the node's REST api to parse into Lavalink. """Fetches tracks from the node's REST api to parse into Lavalink.

View File

@ -8,11 +8,11 @@ from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
from .enums import LoopMode from pomice.enums import LoopMode
from .exceptions import QueueEmpty from pomice.exceptions import QueueEmpty
from .exceptions import QueueException from pomice.exceptions import QueueException
from .exceptions import QueueFull from pomice.exceptions import QueueFull
from .objects import Track from pomice.models.music import Track
__all__ = ("Queue",) __all__ = ("Queue",)

View File

@ -3,9 +3,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from .pool import Node from pomice.pool import Node
from .utils import RouteStats from pomice.utils import RouteStats
__all__ = ("RoutePlanner",) __all__ = ("RoutePlanner",)

View File

@ -1,4 +1,5 @@
"""Spotify module for Pomice, made possible by cloudwithax 2023""" """Spotify module for Pomice, made possible by cloudwithax 2023"""
from .client import Client from .client import *
from .exceptions import * from .exceptions import *
from .models import *
from .objects import * from .objects import *

View File

@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import re
import time import time
from base64 import b64encode from base64 import b64encode
from typing import Dict from typing import Dict
@ -13,18 +12,15 @@ from urllib.parse import quote
import aiohttp import aiohttp
import orjson as json import orjson as json
from .exceptions import InvalidSpotifyURL from pomice.enums import URLRegex
from .exceptions import SpotifyRequestException from pomice.spotify.exceptions import *
from .objects import * from pomice.spotify.models import *
__all__ = ("Client",) __all__ = ("Client",)
GRANT_URL = "https://accounts.spotify.com/api/token" GRANT_URL = "https://accounts.spotify.com/api/token"
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}" REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
SPOTIFY_URL_REGEX = re.compile(
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
)
class Client: class Client:
@ -34,15 +30,12 @@ class Client:
""" """
def __init__(self, client_id: str, client_secret: str) -> None: def __init__(self, client_id: str, client_secret: str) -> None:
self._client_id: str = client_id
self._client_secret: str = client_secret
self.session: aiohttp.ClientSession = None # type: ignore self.session: aiohttp.ClientSession = None # type: ignore
self._bearer_token: Optional[str] = None self._bearer_token: Optional[str] = None
self._expiry: float = 0.0 self._expiry: float = 0.0
self._auth_token = b64encode( self._auth_token = b64encode(
f"{self._client_id}:{self._client_secret}".encode(), f"{client_id}:{client_secret}".encode(),
) )
self._grant_headers = { self._grant_headers = {
"Authorization": f"Basic {self._auth_token.decode()}", "Authorization": f"Basic {self._auth_token.decode()}",
@ -77,7 +70,7 @@ class Client:
if not self._bearer_token or time.time() >= self._expiry: if not self._bearer_token or time.time() >= self._expiry:
await self._fetch_bearer_token() await self._fetch_bearer_token()
result = SPOTIFY_URL_REGEX.match(query) result = URLRegex.SPOTIFY_URL.match(query)
if not result: if not result:
raise InvalidSpotifyURL("The Spotify link provided is not valid.") raise InvalidSpotifyURL("The Spotify link provided is not valid.")
@ -151,7 +144,7 @@ class Client:
if not self._bearer_token or time.time() >= self._expiry: if not self._bearer_token or time.time() >= self._expiry:
await self._fetch_bearer_token() await self._fetch_bearer_token()
result = SPOTIFY_URL_REGEX.match(query) result = URLRegex.SPOTIFY_URL.match(query)
if not result: if not result:
raise InvalidSpotifyURL("The Spotify link provided is not valid.") raise InvalidSpotifyURL("The Spotify link provided is not valid.")

53
pomice/spotify/models.py Normal file
View File

@ -0,0 +1,53 @@
from typing import Dict
from typing import List
from typing import Optional
from discord.ext.commands import Context
from discord.user import _UserTag
from pydantic import Field
from pomice.enums import SearchType
from pomice.enums import TrackType
from pomice.filters import Filter
from pomice.models import BaseModel
from pomice.models.music import Track
from pomice.models.music import TrackInfo
class SpotifyTrackRaw(BaseModel):
id: str
name: str
artists: List[Dict[str, str]]
duration_ms: float
external_ids: Dict[str, str] = Field(default_factory=dict)
external_urls: Dict[str, str] = Field(default_factory=dict)
album: Dict[str, List[Dict[str, str]]] = Field(default_factory=dict)
def build_track(
self,
image: Optional[str] = None,
filters: Optional[List[Filter]] = None,
ctx: Optional[Context] = None,
requester: Optional[_UserTag] = None,
) -> Track:
if self.album:
image = self.album["images"][0]["url"]
return Track(
track_id=self.id,
track_type=TrackType.SPOTIFY,
search_type=SearchType.YTMSEARCH,
filters=filters,
ctx=ctx,
requester=requester,
info=TrackInfo(
identifier=self.id,
title=self.name,
author=", ".join(artist["name"] for artist in self.artists),
length=self.duration_ms,
is_seekable=True,
uri=self.external_urls.get("spotify", ""),
artwork_url=image,
isrc=self.external_ids.get("isrc"),
),
)

View File

@ -8,11 +8,10 @@ from typing import Any
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
from typing import Iterable from typing import Iterable
from typing import NamedTuple
from typing import Optional from typing import Optional
from .enums import RouteIPType from pomice.enums import RouteIPType
from .enums import RouteStrategy from pomice.enums import RouteStrategy
__all__ = ( __all__ = (
"ExponentialBackoff", "ExponentialBackoff",
@ -20,7 +19,6 @@ __all__ = (
"FailingIPBlock", "FailingIPBlock",
"RouteStats", "RouteStats",
"Ping", "Ping",
"LavalinkVersion",
) )
@ -226,53 +224,3 @@ class Ping:
s_runtime = 1000 * (cost_time) s_runtime = 1000 * (cost_time)
return s_runtime return s_runtime
class LavalinkVersion(NamedTuple):
major: int
minor: int
fix: int
def __eq__(self, other: object) -> bool:
if not isinstance(other, LavalinkVersion):
return False
return (
(self.major == other.major) and (self.minor == other.minor) and (self.fix == other.fix)
)
def __ne__(self, other: object) -> bool:
if not isinstance(other, LavalinkVersion):
return False
return not (self == other)
def __lt__(self, other: object) -> bool:
if not isinstance(other, LavalinkVersion):
return False
if self.major > other.major:
return False
if self.minor > other.minor:
return False
if self.fix > other.fix:
return False
return True
def __gt__(self, other: object) -> bool:
if not isinstance(other, LavalinkVersion):
return False
return not (self < other)
def __le__(self, other: object) -> bool:
if not isinstance(other, LavalinkVersion):
return False
return (self < other) or (self == other)
def __ge__(self, other: object) -> bool:
if not isinstance(other, LavalinkVersion):
return False
return (self > other) or (self == other)

View File

@ -4,7 +4,7 @@ import re
import setuptools import setuptools
version = "" version = ""
requirements = ["aiohttp>=3.7.4,<4", "orjson", "websockets"] requirements = ["aiohttp>=3.7.4,<4", "orjson", "websockets", "pydantic>=2"]
with open("pomice/__init__.py") as f: with open("pomice/__init__.py") as f:
version = re.search( version = re.search(
r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',