Merge e48ed8fd4b into f1609f7049
This commit is contained in:
commit
1d162df60d
2
Pipfile
2
Pipfile
|
|
@ -5,7 +5,9 @@ name = "pypi"
|
|||
|
||||
[packages]
|
||||
orjson = "*"
|
||||
pydantic = ">=2"
|
||||
"discord.py" = {extras = ["voice"], version = "*"}
|
||||
websockets = "*"
|
||||
|
||||
[dev-packages]
|
||||
mypy = "*"
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ class Music(commands.Cog):
|
|||
|
||||
return player.dj == ctx.author or ctx.author.guild_permissions.kick_members
|
||||
|
||||
# The following are events from pomice.events
|
||||
# The following are events from pomice.models.events
|
||||
# We are using these so that if the track either stops or errors,
|
||||
# we can just skip to the next track
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ __license__ = "GPL-3.0"
|
|||
__copyright__ = "Copyright (c) 2023, cloudwithax"
|
||||
|
||||
from .enums import *
|
||||
from .events import *
|
||||
from .models import *
|
||||
from .exceptions import *
|
||||
from .filters import *
|
||||
from .objects import *
|
||||
|
|
|
|||
|
|
@ -11,17 +11,12 @@ from typing import Union
|
|||
import aiohttp
|
||||
import orjson as json
|
||||
|
||||
from .exceptions import *
|
||||
from .objects import *
|
||||
from pomice.applemusic.exceptions import *
|
||||
from pomice.applemusic.objects import *
|
||||
from pomice.enums import URLRegex
|
||||
|
||||
__all__ = ("Client",)
|
||||
|
||||
AM_URL_REGEX = re.compile(
|
||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
|
||||
)
|
||||
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
|
||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
|
||||
)
|
||||
|
||||
AM_SCRIPT_REGEX = re.compile(r'<script.*?src="(/assets/index-.*?)"')
|
||||
|
||||
|
|
@ -103,7 +98,7 @@ class Client:
|
|||
if not self.token or datetime.utcnow() > self.expiry:
|
||||
await self.request_token()
|
||||
|
||||
result = AM_URL_REGEX.match(query)
|
||||
result = URLRegex.AM_URL.match(query)
|
||||
if not result:
|
||||
raise InvalidAppleMusicURL(
|
||||
"The Apple Music link provided is not valid.",
|
||||
|
|
@ -113,7 +108,7 @@ class Client:
|
|||
type = result.group("type")
|
||||
id = result.group("id")
|
||||
|
||||
if type == "album" and (sia_result := AM_SINGLE_IN_ALBUM_REGEX.match(query)):
|
||||
if type == "album" and (sia_result := URLRegex.AM_SINGLE_IN_ALBUM_REGEX.match(query)):
|
||||
# apple music likes to generate links for singles off an album
|
||||
# by adding a param at the end of the url
|
||||
# so we're gonna scan for that and correct it
|
||||
|
|
|
|||
121
pomice/enums.py
121
pomice/enums.py
|
|
@ -1,6 +1,7 @@
|
|||
import re
|
||||
from enum import Enum
|
||||
from enum import IntEnum
|
||||
from enum import unique
|
||||
|
||||
__all__ = (
|
||||
"SearchType",
|
||||
|
|
@ -15,7 +16,13 @@ __all__ = (
|
|||
)
|
||||
|
||||
|
||||
class SearchType(Enum):
|
||||
class BaseStrEnum(str, Enum):
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
@unique
|
||||
class SearchType(BaseStrEnum):
|
||||
"""
|
||||
The enum for the different search types for Pomice.
|
||||
This feature is exclusively for the Spotify search feature of Pomice.
|
||||
|
|
@ -31,15 +38,13 @@ class SearchType(Enum):
|
|||
which is an alternative to YouTube or YouTube Music.
|
||||
"""
|
||||
|
||||
ytsearch = "ytsearch"
|
||||
ytmsearch = "ytmsearch"
|
||||
scsearch = "scsearch"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
YTSEARCH = "ytsearch"
|
||||
YTMSEARCH = "ytmsearch"
|
||||
SCSEARCH = "scsearch"
|
||||
|
||||
|
||||
class TrackType(Enum):
|
||||
@unique
|
||||
class TrackType(BaseStrEnum):
|
||||
"""
|
||||
The enum for the different track types for Pomice.
|
||||
|
||||
|
|
@ -64,11 +69,9 @@ class TrackType(Enum):
|
|||
HTTP = "http"
|
||||
LOCAL = "local"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class PlaylistType(Enum):
|
||||
@unique
|
||||
class PlaylistType(BaseStrEnum):
|
||||
"""
|
||||
The enum for the different playlist types for Pomice.
|
||||
|
||||
|
|
@ -87,11 +90,9 @@ class PlaylistType(Enum):
|
|||
SPOTIFY = "spotify"
|
||||
APPLE_MUSIC = "apple_music"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class NodeAlgorithm(Enum):
|
||||
@unique
|
||||
class NodeAlgorithm(BaseStrEnum):
|
||||
"""
|
||||
The enum for the different node algorithms in Pomice.
|
||||
|
||||
|
|
@ -111,11 +112,9 @@ class NodeAlgorithm(Enum):
|
|||
by_ping = "BY_PING"
|
||||
by_players = "BY_PLAYERS"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class LoopMode(Enum):
|
||||
@unique
|
||||
class LoopMode(BaseStrEnum):
|
||||
"""
|
||||
The enum for the different loop modes.
|
||||
This feature is exclusively for the queue utility of pomice.
|
||||
|
|
@ -124,18 +123,15 @@ class LoopMode(Enum):
|
|||
LoopMode.TRACK sets the queue loop to the current track.
|
||||
|
||||
LoopMode.QUEUE sets the queue loop to the whole queue.
|
||||
|
||||
"""
|
||||
|
||||
# We don't have to define anything special for these, since these just serve as flags
|
||||
TRACK = "track"
|
||||
QUEUE = "queue"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class RouteStrategy(Enum):
|
||||
@unique
|
||||
class RouteStrategy(BaseStrEnum):
|
||||
"""
|
||||
The enum for specifying the route planner strategy for Lavalink.
|
||||
This feature is exclusively for the RoutePlanner class.
|
||||
|
|
@ -153,7 +149,6 @@ class RouteStrategy(Enum):
|
|||
RouteStrategy.ROTATING_NANO_SWITCH specifies that the node is switching
|
||||
between IPs every CPU clock cycle and is rotating between IP blocks on
|
||||
ban.
|
||||
|
||||
"""
|
||||
|
||||
ROTATE_ON_BAN = "RotatingIpRoutePlanner"
|
||||
|
|
@ -162,7 +157,8 @@ class RouteStrategy(Enum):
|
|||
ROTATING_NANO_SWITCH = "RotatingNanoIpRoutePlanner"
|
||||
|
||||
|
||||
class RouteIPType(Enum):
|
||||
@unique
|
||||
class RouteIPType(BaseStrEnum):
|
||||
"""
|
||||
The enum for specifying the route planner IP block type for Lavalink.
|
||||
This feature is exclusively for the RoutePlanner class.
|
||||
|
|
@ -177,9 +173,43 @@ class RouteIPType(Enum):
|
|||
IPV6 = "Inet6Address"
|
||||
|
||||
|
||||
@unique
|
||||
class LogLevel(IntEnum):
|
||||
"""
|
||||
The enum for specifying the logging level within Pomice.
|
||||
This class serves as shorthand for logging.<level>
|
||||
This enum is exclusively for the logging feature in Pomice.
|
||||
If you are not using this feature, this class is not necessary.
|
||||
|
||||
|
||||
LogLevel.DEBUG sets the logging level to "debug".
|
||||
|
||||
LogLevel.INFO sets the logging level to "info".
|
||||
|
||||
LogLevel.WARN sets the logging level to "warn".
|
||||
|
||||
LogLevel.ERROR sets the logging level to "error".
|
||||
|
||||
LogLevel.CRITICAL sets the logging level to "CRITICAL".
|
||||
"""
|
||||
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
WARN = 30
|
||||
ERROR = 40
|
||||
CRITICAL = 50
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, level_str):
|
||||
try:
|
||||
return cls[level_str.upper()]
|
||||
except KeyError:
|
||||
raise ValueError(f"No such log level: {level_str}")
|
||||
|
||||
|
||||
class URLRegex:
|
||||
"""
|
||||
The enum for all the URL Regexes in use by Pomice.
|
||||
The class for all the URL Regexes in use by Pomice.
|
||||
|
||||
URLRegex.SPOTIFY_URL returns the Spotify URL Regex.
|
||||
|
||||
|
|
@ -196,7 +226,6 @@ class URLRegex:
|
|||
URLRegex.SOUNDCLOUD_URL returns the SoundCloud URL Regex.
|
||||
|
||||
URLRegex.BASE_URL returns the standard URL Regex.
|
||||
|
||||
"""
|
||||
|
||||
SPOTIFY_URL = re.compile(
|
||||
|
|
@ -246,37 +275,3 @@ class URLRegex:
|
|||
LAVALINK_SEARCH = re.compile(r"(?P<type>ytm?|sc)search:")
|
||||
|
||||
BASE_URL = re.compile(r"https?://(?:www\.)?.+")
|
||||
|
||||
|
||||
class LogLevel(IntEnum):
|
||||
"""
|
||||
The enum for specifying the logging level within Pomice.
|
||||
This class serves as shorthand for logging.<level>
|
||||
This enum is exclusively for the logging feature in Pomice.
|
||||
If you are not using this feature, this class is not necessary.
|
||||
|
||||
|
||||
LogLevel.DEBUG sets the logging level to "debug".
|
||||
|
||||
LogLevel.INFO sets the logging level to "info".
|
||||
|
||||
LogLevel.WARN sets the logging level to "warn".
|
||||
|
||||
LogLevel.ERROR sets the logging level to "error".
|
||||
|
||||
LogLevel.CRITICAL sets the logging level to "CRITICAL".
|
||||
|
||||
"""
|
||||
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
WARN = 30
|
||||
ERROR = 40
|
||||
CRITICAL = 50
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, level_str):
|
||||
try:
|
||||
return cls[level_str.upper()]
|
||||
except KeyError:
|
||||
raise ValueError(f"No such log level: {level_str}")
|
||||
|
|
|
|||
197
pomice/events.py
197
pomice/events.py
|
|
@ -1,197 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from discord import Client
|
||||
from discord import Guild
|
||||
from discord.ext import commands
|
||||
|
||||
from .objects import Track
|
||||
from .pool import NodePool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .player import Player
|
||||
|
||||
__all__ = (
|
||||
"PomiceEvent",
|
||||
"TrackStartEvent",
|
||||
"TrackEndEvent",
|
||||
"TrackStuckEvent",
|
||||
"TrackExceptionEvent",
|
||||
"WebSocketClosedPayload",
|
||||
"WebSocketClosedEvent",
|
||||
"WebSocketOpenEvent",
|
||||
)
|
||||
|
||||
|
||||
class PomiceEvent(ABC):
|
||||
"""The base class for all events dispatched by a node.
|
||||
Every event must be formatted within your bot's code as a listener.
|
||||
i.e: If you want to listen for when a track starts, the event would be:
|
||||
```py
|
||||
@bot.listen
|
||||
async def on_pomice_track_start(self, event):
|
||||
```
|
||||
"""
|
||||
|
||||
name = "event"
|
||||
handler_args: Tuple
|
||||
|
||||
def dispatch(self, bot: Client) -> None:
|
||||
bot.dispatch(f"pomice_{self.name}", *self.handler_args)
|
||||
|
||||
|
||||
class TrackStartEvent(PomiceEvent):
|
||||
"""Fired when a track has successfully started.
|
||||
Returns the player associated with the event and the pomice.Track object.
|
||||
"""
|
||||
|
||||
name = "track_start"
|
||||
|
||||
__slots__ = (
|
||||
"player",
|
||||
"track",
|
||||
)
|
||||
|
||||
def __init__(self, data: dict, player: Player):
|
||||
self.player: Player = player
|
||||
self.track: Optional[Track] = self.player._current
|
||||
|
||||
# on_pomice_track_start(player, track)
|
||||
self.handler_args = self.player, self.track
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.TrackStartEvent player={self.player!r} track={self.track!r}>"
|
||||
|
||||
|
||||
class TrackEndEvent(PomiceEvent):
|
||||
"""Fired when a track has successfully ended.
|
||||
Returns the player associated with the event along with the pomice.Track object and reason.
|
||||
"""
|
||||
|
||||
name = "track_end"
|
||||
|
||||
__slots__ = ("player", "track", "reason")
|
||||
|
||||
def __init__(self, data: dict, player: Player):
|
||||
self.player: Player = player
|
||||
self.track: Optional[Track] = self.player._ending_track
|
||||
self.reason: str = data["reason"]
|
||||
|
||||
# on_pomice_track_end(player, track, reason)
|
||||
self.handler_args = self.player, self.track, self.reason
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Pomice.TrackEndEvent player={self.player!r} track_id={self.track!r} "
|
||||
f"reason={self.reason!r}>"
|
||||
)
|
||||
|
||||
|
||||
class TrackStuckEvent(PomiceEvent):
|
||||
"""Fired when a track is stuck and cannot be played. Returns the player
|
||||
associated with the event along with the pomice.Track object
|
||||
to be further parsed by the end user.
|
||||
"""
|
||||
|
||||
name = "track_stuck"
|
||||
|
||||
__slots__ = ("player", "track", "threshold")
|
||||
|
||||
def __init__(self, data: dict, player: Player):
|
||||
self.player: Player = player
|
||||
self.track: Optional[Track] = self.player._ending_track
|
||||
self.threshold: float = data["thresholdMs"]
|
||||
|
||||
# on_pomice_track_stuck(player, track, threshold)
|
||||
self.handler_args = self.player, self.track, self.threshold
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Pomice.TrackStuckEvent player={self.player!r} track={self.track!r} "
|
||||
f"threshold={self.threshold!r}>"
|
||||
)
|
||||
|
||||
|
||||
class TrackExceptionEvent(PomiceEvent):
|
||||
"""Fired when a track error has occured.
|
||||
Returns the player associated with the event along with the error code and exception.
|
||||
"""
|
||||
|
||||
name = "track_exception"
|
||||
|
||||
__slots__ = ("player", "track", "exception")
|
||||
|
||||
def __init__(self, data: dict, player: Player):
|
||||
self.player: Player = player
|
||||
self.track: Optional[Track] = self.player._ending_track
|
||||
# Error is for Lavalink <= 3.3
|
||||
self.exception: str = data.get(
|
||||
"error",
|
||||
"",
|
||||
) or data.get("exception", "")
|
||||
|
||||
# on_pomice_track_exception(player, track, error)
|
||||
self.handler_args = self.player, self.track, self.exception
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.TrackExceptionEvent player={self.player!r} exception={self.exception!r}>"
|
||||
|
||||
|
||||
class WebSocketClosedPayload:
|
||||
__slots__ = ("guild", "code", "reason", "by_remote")
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self.guild: Optional[Guild] = NodePool.get_node().bot.get_guild(int(data["guildId"]))
|
||||
self.code: int = data["code"]
|
||||
self.reason: str = data["code"]
|
||||
self.by_remote: bool = data["byRemote"]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Pomice.WebSocketClosedPayload guild={self.guild!r} code={self.code!r} "
|
||||
f"reason={self.reason!r} by_remote={self.by_remote!r}>"
|
||||
)
|
||||
|
||||
|
||||
class WebSocketClosedEvent(PomiceEvent):
|
||||
"""Fired when a websocket connection to a node has been closed.
|
||||
Returns the reason and the error code.
|
||||
"""
|
||||
|
||||
name = "websocket_closed"
|
||||
|
||||
__slots__ = ("payload",)
|
||||
|
||||
def __init__(self, data: dict, _: Any) -> None:
|
||||
self.payload: WebSocketClosedPayload = WebSocketClosedPayload(data)
|
||||
|
||||
# on_pomice_websocket_closed(payload)
|
||||
self.handler_args = (self.payload,)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.WebsocketClosedEvent payload={self.payload!r}>"
|
||||
|
||||
|
||||
class WebSocketOpenEvent(PomiceEvent):
|
||||
"""Fired when a websocket connection to a node has been initiated.
|
||||
Returns the target and the session SSRC.
|
||||
"""
|
||||
|
||||
name = "websocket_open"
|
||||
|
||||
__slots__ = ("target", "ssrc")
|
||||
|
||||
def __init__(self, data: dict, _: Any) -> None:
|
||||
self.target: str = data["target"]
|
||||
self.ssrc: int = data["ssrc"]
|
||||
|
||||
# on_pomice_websocket_open(target, ssrc)
|
||||
self.handler_args = self.target, self.ssrc
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.WebsocketOpenEvent target={self.target!r} ssrc={self.ssrc!r}>"
|
||||
|
|
@ -61,7 +61,7 @@ class NoNodesAvailable(PomiceException):
|
|||
pass
|
||||
|
||||
|
||||
class TrackInvalidPosition(PomiceException):
|
||||
class TrackInvalidPosition(PomiceException, ValueError):
|
||||
"""An invalid position was chosen for a track."""
|
||||
|
||||
pass
|
||||
|
|
@ -73,19 +73,19 @@ class TrackLoadError(PomiceException):
|
|||
pass
|
||||
|
||||
|
||||
class FilterInvalidArgument(PomiceException):
|
||||
class FilterInvalidArgument(PomiceException, ValueError):
|
||||
"""An invalid argument was passed to a filter."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class FilterTagInvalid(PomiceException):
|
||||
class FilterTagInvalid(PomiceException, ValueError):
|
||||
"""An invalid tag was passed or Pomice was unable to find a filter tag"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class FilterTagAlreadyInUse(PomiceException):
|
||||
class FilterTagAlreadyInUse(PomiceException, ValueError):
|
||||
"""A filter with a tag is already in use by another filter"""
|
||||
|
||||
pass
|
||||
|
|
@ -97,7 +97,7 @@ class InvalidSpotifyClientAuthorization(PomiceException):
|
|||
pass
|
||||
|
||||
|
||||
class AppleMusicNotEnabled(PomiceException):
|
||||
class AppleMusicNotEnabled(PomiceException, ValueError):
|
||||
"""An Apple Music Link was passed in when Apple Music functionality was not enabled."""
|
||||
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
|
@ -5,7 +7,7 @@ from typing import List
|
|||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
from .exceptions import FilterInvalidArgument
|
||||
from pomice.exceptions import FilterInvalidArgument
|
||||
|
||||
__all__ = (
|
||||
"Filter",
|
||||
|
|
@ -84,7 +86,7 @@ class Equalizer(Filter):
|
|||
return self.raw == __value.raw
|
||||
|
||||
@classmethod
|
||||
def flat(cls) -> "Equalizer":
|
||||
def flat(cls) -> Equalizer:
|
||||
"""Equalizer preset which represents a flat EQ board,
|
||||
with all levels set to their default values.
|
||||
"""
|
||||
|
|
@ -109,7 +111,7 @@ class Equalizer(Filter):
|
|||
return cls(tag="flat", levels=levels)
|
||||
|
||||
@classmethod
|
||||
def boost(cls) -> "Equalizer":
|
||||
def boost(cls) -> Equalizer:
|
||||
"""Equalizer preset which boosts the sound of a track,
|
||||
making it sound fun and energetic by increasing the bass
|
||||
and the highs.
|
||||
|
|
@ -135,7 +137,7 @@ class Equalizer(Filter):
|
|||
return cls(tag="boost", levels=levels)
|
||||
|
||||
@classmethod
|
||||
def metal(cls) -> "Equalizer":
|
||||
def metal(cls) -> Equalizer:
|
||||
"""Equalizer preset which increases the mids of a track,
|
||||
preferably one of the metal genre, to make it sound
|
||||
more full and concert-like.
|
||||
|
|
@ -162,7 +164,7 @@ class Equalizer(Filter):
|
|||
return cls(tag="metal", levels=levels)
|
||||
|
||||
@classmethod
|
||||
def piano(cls) -> "Equalizer":
|
||||
def piano(cls) -> Equalizer:
|
||||
"""Equalizer preset which increases the mids and highs
|
||||
of a track, preferably a piano based one, to make it
|
||||
stand out.
|
||||
|
|
@ -215,7 +217,7 @@ class Timescale(Filter):
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def vaporwave(cls) -> "Timescale":
|
||||
def vaporwave(cls) -> Timescale:
|
||||
"""Timescale preset which slows down the currently playing track,
|
||||
giving it the effect of a half-speed record/casette playing.
|
||||
|
||||
|
|
@ -225,7 +227,7 @@ class Timescale(Filter):
|
|||
return cls(tag="vaporwave", speed=0.8, pitch=0.8)
|
||||
|
||||
@classmethod
|
||||
def nightcore(cls) -> "Timescale":
|
||||
def nightcore(cls) -> Timescale:
|
||||
"""Timescale preset which speeds up the currently playing track,
|
||||
which matches up to nightcore, a genre of sped-up music
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
import pydantic
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from .events import *
|
||||
from .music import *
|
||||
from .payloads import *
|
||||
from .version import *
|
||||
|
||||
|
||||
class BaseModel(pydantic.BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
|
||||
|
||||
def model_dump(self, *args, **kwargs) -> dict:
|
||||
by_alias = kwargs.pop("by_alias", True)
|
||||
mode = kwargs.pop("mode", "json")
|
||||
return super().model_dump(*args, **kwargs, by_alias=by_alias, mode=mode)
|
||||
|
||||
|
||||
class VersionedModel(BaseModel):
|
||||
version: LavalinkVersionType
|
||||
|
||||
def model_dump(self, *args, **kwargs) -> dict:
|
||||
return super().model_dump(*args, **kwargs, exclude={"version"})
|
||||
|
|
@ -0,0 +1,179 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from enum import Enum
|
||||
from enum import unique
|
||||
from typing import Literal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from discord import Guild
|
||||
from pydantic import computed_field
|
||||
from pydantic import Field
|
||||
|
||||
from pomice.models import BaseModel
|
||||
from pomice.objects import Track
|
||||
from pomice.player import Player
|
||||
from pomice.pool import NodePool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from discord import Client
|
||||
|
||||
__all__ = (
|
||||
"PomiceEvent",
|
||||
"TrackStartEvent",
|
||||
"TrackEndEvent",
|
||||
"TrackStuckEvent",
|
||||
"TrackExceptionEvent",
|
||||
"WebSocketClosedPayload",
|
||||
"WebSocketClosedEvent",
|
||||
"WebSocketOpenEvent",
|
||||
)
|
||||
|
||||
|
||||
class PomiceEvent(BaseModel, abc.ABC):
|
||||
"""The base class for all events dispatched by a node.
|
||||
Every event must be formatted within your bot's code as a listener.
|
||||
i.e: If you want to listen for when a track starts, the event would be:
|
||||
```py
|
||||
@bot.listen
|
||||
async def on_pomice_track_start(self, event):
|
||||
```
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
||||
@abc.abstractmethod
|
||||
def dispatch(self, bot: Client) -> None:
|
||||
...
|
||||
|
||||
|
||||
class TrackStartEvent(PomiceEvent):
|
||||
"""Fired when a track has successfully started.
|
||||
Returns the player associated with the event and the pomice.Track object.
|
||||
"""
|
||||
|
||||
name: Literal["track_start"]
|
||||
player: Player
|
||||
track: Track
|
||||
|
||||
def dispatch(self, bot: Client) -> None:
|
||||
bot.dispatch(f"pomice_{self.name}", self.player, self.track)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.TrackStartEvent player={self.player!r} track={self.track!r}>"
|
||||
|
||||
|
||||
@unique
|
||||
class TrackEndEventReason(str, Enum):
|
||||
FINISHED = "finished"
|
||||
LOAD_FAILED = "loadfailed"
|
||||
STOPPED = "stopped"
|
||||
REPLACED = "replaced"
|
||||
CLEANUP = "cleanup"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> TrackEndEventReason:
|
||||
if isinstance(value, str):
|
||||
return TrackEndEventReason(value.casefold())
|
||||
|
||||
|
||||
class TrackEndEvent(PomiceEvent):
|
||||
"""Fired when a track has successfully ended.
|
||||
Returns the player associated with the event along with the pomice.Track object and reason.
|
||||
"""
|
||||
|
||||
name: Literal["track_end"]
|
||||
player: Player
|
||||
track: Track
|
||||
reason: TrackEndEventReason
|
||||
|
||||
def dispatch(self, bot: Client) -> None:
|
||||
bot.dispatch(f"pomice_{self.name}", self.player, self.track, self.reason)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.TrackEndEvent player={self.player!r} track={self.track!r} reason={self.reason!r}>"
|
||||
|
||||
|
||||
class TrackStuckEvent(PomiceEvent):
|
||||
"""Fired when a track has been stuck for a while.
|
||||
Returns the player associated with the event along with the pomice.Track object and threshold.
|
||||
"""
|
||||
|
||||
name: Literal["track_stuck"]
|
||||
player: Player
|
||||
track: Track
|
||||
threshold: float = Field(alias="thresholdMs")
|
||||
|
||||
def dispatch(self, bot: Client) -> None:
|
||||
bot.dispatch(f"pomice_{self.name}", self.player, self.track, self.threshold)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.TrackStuckEvent player={self.player!r} track={self.track!r} threshold={self.threshold!r}>"
|
||||
|
||||
|
||||
class TrackExceptionEvent(PomiceEvent):
|
||||
"""Fired when there is an exception while playing a track.
|
||||
Returns the player associated with the event along with the pomice.Track object and exception.
|
||||
"""
|
||||
|
||||
name: Literal["track_exception"]
|
||||
player: Player
|
||||
track: Track
|
||||
exception: str = Field(alias="error")
|
||||
|
||||
def dispatch(self, bot: Client) -> None:
|
||||
bot.dispatch(f"pomice_{self.name}", self.player, self.track, self.exception)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.TrackExceptionEvent player={self.player!r} track={self.track!r} exception={self.exception!r}>"
|
||||
|
||||
|
||||
class WebSocketClosedPayload(BaseModel):
|
||||
"""The payload for the WebSocketClosedEvent."""
|
||||
|
||||
guild_id: int = Field(alias="guildId")
|
||||
code: int
|
||||
reason: str
|
||||
by_remote: bool = Field(alias="byRemote")
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def guild(self) -> Guild:
|
||||
return NodePool.get_node().bot.get_guild(self.guild_id)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Pomice.WebSocketClosedPayload guild_id={self.guild_id!r} code={self.code!r} "
|
||||
f"reason={self.reason!r} by_remote={self.by_remote!r}>"
|
||||
)
|
||||
|
||||
|
||||
class WebSocketClosedEvent(PomiceEvent):
|
||||
"""Fired when the websocket connection to the node is closed.
|
||||
Returns the player associated with the event and the code and reason for the closure.
|
||||
"""
|
||||
|
||||
name: Literal["websocket_closed"]
|
||||
payload: WebSocketClosedPayload
|
||||
|
||||
def dispatch(self, bot: Client) -> None:
|
||||
bot.dispatch(f"pomice_{self.name}", self.payload)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.WebSocketClosedEvent payload={self.payload!r}>"
|
||||
|
||||
|
||||
class WebSocketOpenEvent(PomiceEvent):
|
||||
"""Fired when the websocket connection to the node is opened.
|
||||
Returns the player associated with the event.
|
||||
"""
|
||||
|
||||
name: Literal["websocket_open"]
|
||||
target: str
|
||||
ssrc: str
|
||||
|
||||
def dispatch(self, bot: Client) -> None:
|
||||
bot.dispatch(f"pomice_{self.name}", self.target, self.ssrc)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.WebSocketOpenEvent target={self.target!r} ssrc={self.ssrc!r}>"
|
||||
|
|
@ -0,0 +1,160 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from discord.ext.commands import Context
|
||||
from discord.user import _UserTag
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from pomice.enums import PlaylistType
|
||||
from pomice.enums import SearchType
|
||||
from pomice.enums import TrackType
|
||||
from pomice.filters import Filter
|
||||
from pomice.models import BaseModel
|
||||
|
||||
__all__ = (
|
||||
"Track",
|
||||
"TrackInfo",
|
||||
"Playlist",
|
||||
"PlaylistInfo",
|
||||
"PlaylistExtended",
|
||||
"PlaylistModelAdapter",
|
||||
)
|
||||
|
||||
|
||||
class TrackInfo(BaseModel):
|
||||
identifier: str
|
||||
title: str
|
||||
author: str
|
||||
length: int
|
||||
position: int = 0
|
||||
is_stream: bool = Field(default=False, alias="isStream")
|
||||
is_seekable: bool = Field(default=False, alias="isSeekable")
|
||||
uri: Optional[str] = None
|
||||
isrc: Optional[str] = None
|
||||
source_name: Optional[str] = Field(default=None, alias="sourceName")
|
||||
artwork_url: Optional[str] = Field(default=None, alias="artworkUrl")
|
||||
|
||||
|
||||
class Track(BaseModel):
|
||||
"""The base track object. Returns critical track information needed for parsing by Lavalink.
|
||||
You can also pass in commands.Context to get a discord.py Context object in your track.
|
||||
"""
|
||||
|
||||
track_id: str = Field(alias="encoded")
|
||||
track_type: TrackType
|
||||
info: TrackInfo
|
||||
search_type: SearchType = SearchType.YTSEARCH
|
||||
filters: List[Filter] = Field(default_factory=list)
|
||||
timestamp: Optional[float] = None
|
||||
playlist: Optional[Playlist] = None
|
||||
original: Optional[Track] = None
|
||||
ctx: Optional[Context] = None
|
||||
requester: Optional[_UserTag] = None
|
||||
|
||||
@property
|
||||
def title(self) -> str:
|
||||
return self.info.title
|
||||
|
||||
@property
|
||||
def author(self) -> str:
|
||||
return self.info.author
|
||||
|
||||
@property
|
||||
def uri(self) -> Optional[str]:
|
||||
return self.info.uri
|
||||
|
||||
@property
|
||||
def identifier(self) -> str:
|
||||
return self.info.identifier
|
||||
|
||||
@property
|
||||
def isrc(self) -> Optional[str]:
|
||||
return self.info.isrc
|
||||
|
||||
@property
|
||||
def thumbnail(self) -> Optional[str]:
|
||||
return self.info.artwork_url
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Track):
|
||||
return False
|
||||
|
||||
return self.track_id == other.track_id
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.info.title
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.Track title={self.info.title!r} uri=<{self.info.uri!r}> length={self.info.length}>"
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_thumbnail_url(self) -> Track:
|
||||
if self.track_type is TrackType.YOUTUBE and not self.info.artwork_url:
|
||||
self.info.artwork_url = (
|
||||
f"https://img.youtube.com/vi/{self.info.identifier}/mqdefault.jpg"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class PlaylistInfo(BaseModel):
|
||||
name: str
|
||||
selected_track: int = Field(default=0, alias="selectedTrack")
|
||||
|
||||
|
||||
class Playlist(BaseModel):
|
||||
"""The base playlist object.
|
||||
Returns critical playlist information needed for parsing by Lavalink.
|
||||
"""
|
||||
|
||||
info: PlaylistInfo
|
||||
tracks: List[Track]
|
||||
playlist_type: PlaylistType
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.info.name
|
||||
|
||||
@property
|
||||
def selected_track(self) -> Optional[Track]:
|
||||
if self.track_count <= 0:
|
||||
return None
|
||||
|
||||
return self.tracks[self.info.selected_track]
|
||||
|
||||
@property
|
||||
def track_count(self) -> int:
|
||||
return len(self.tracks)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.info.name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.Playlist name={self.info.name!r} total_tracks={self.track_count}>"
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_playlist(self) -> Playlist:
|
||||
for track in self.tracks:
|
||||
track.playlist = self
|
||||
return self
|
||||
|
||||
|
||||
class PlaylistExtended(Playlist):
|
||||
"""Playlist object with additional information for external services."""
|
||||
|
||||
playlist_type: Union[Literal[PlaylistType.APPLE_MUSIC, PlaylistType.SPOTIFY]]
|
||||
uri: str
|
||||
artwork_url: str
|
||||
|
||||
@property
|
||||
def thumbnail(self) -> Optional[str]:
|
||||
return self.artwork_url
|
||||
|
||||
|
||||
PlaylistModelType = Union[Playlist, PlaylistExtended]
|
||||
PlaylistModelAdapter = lambda **kwargs: TypeAdapter(PlaylistModelType).validate_python(kwargs)
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from pydantic import AliasPath
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from pomice.models import BaseModel
|
||||
from pomice.models import LavalinkVersion3Type
|
||||
from pomice.models import LavalinkVersion4Type
|
||||
from pomice.models import VersionedModel
|
||||
|
||||
__all__ = (
|
||||
"VoiceUpdatePayload",
|
||||
"TrackStartPayload",
|
||||
"TrackUpdatePayload",
|
||||
"ResumePayloadType",
|
||||
"ResumePayloadTypeAdapter",
|
||||
)
|
||||
|
||||
|
||||
class VoiceUpdatePayload(BaseModel):
|
||||
token: str = Field(validation_alias=AliasPath("event", "token"))
|
||||
endpoint: str = Field(validation_alias=AliasPath("event", "endpoint"))
|
||||
session_id: str = Field(alias="sessionId")
|
||||
|
||||
|
||||
class TrackUpdatePayload(BaseModel):
|
||||
encoded_track: Optional[str] = Field(default=None, alias="encodedTrack")
|
||||
position: float
|
||||
|
||||
|
||||
class TrackStartPayload(VersionedModel):
|
||||
encoded_track: Optional[str] = Field(default=None, alias="encodedTrack")
|
||||
position: float
|
||||
end_time: str = Field(default="0", alias="endTime")
|
||||
|
||||
@field_validator("end_time", mode="before")
|
||||
@classmethod
|
||||
def cast_end_time(cls, value: object) -> str:
|
||||
return str(value)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def adjust_end_time(self) -> TrackStartPayload:
|
||||
if self.version >= LavalinkVersion3Type(3, 7, 5):
|
||||
self.end_time = None
|
||||
|
||||
|
||||
class ResumePayload(VersionedModel):
|
||||
timeout: int
|
||||
|
||||
|
||||
class ResumePayloadV3(ResumePayload):
|
||||
version: LavalinkVersion3Type
|
||||
resuming_key: str = Field(alias="resumingKey")
|
||||
|
||||
|
||||
class ResumePayloadV4(ResumePayload):
|
||||
version: LavalinkVersion4Type
|
||||
resuming: bool = True
|
||||
|
||||
|
||||
ResumePayloadType = Union[ResumePayloadV3, ResumePayloadV4]
|
||||
ResumePayloadTypeAdapter = lambda **kwargs: TypeAdapter(ResumePayloadType).validate_python(kwargs)
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
from typing import Literal
|
||||
from typing import NamedTuple
|
||||
from typing import Union
|
||||
|
||||
__all__ = (
|
||||
"LavalinkVersion",
|
||||
"LavalinkVersion3Type",
|
||||
"LavalinkVersion4Type",
|
||||
"LavalinkVersionType",
|
||||
)
|
||||
|
||||
|
||||
class LavalinkVersion(NamedTuple):
|
||||
major: int
|
||||
minor: int
|
||||
fix: int
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, LavalinkVersion):
|
||||
return False
|
||||
|
||||
return (
|
||||
(self.major == other.major) and (self.minor == other.minor) and (self.fix == other.fix)
|
||||
)
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if not isinstance(other, LavalinkVersion):
|
||||
return False
|
||||
|
||||
if self.major > other.major:
|
||||
return False
|
||||
if self.minor > other.minor:
|
||||
return False
|
||||
if self.fix > other.fix:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class LavalinkVersion3Type(LavalinkVersion):
|
||||
major: Literal[3]
|
||||
minor: int
|
||||
fix: int
|
||||
|
||||
|
||||
class LavalinkVersion4Type(LavalinkVersion):
|
||||
major: Literal[4]
|
||||
minor: int
|
||||
fix: int
|
||||
|
||||
|
||||
LavalinkVersionType = Union[LavalinkVersion3Type, LavalinkVersion4Type, LavalinkVersion]
|
||||
|
|
@ -1,167 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from discord import ClientUser
|
||||
from discord import Member
|
||||
from discord import User
|
||||
from discord.ext import commands
|
||||
|
||||
from .enums import PlaylistType
|
||||
from .enums import SearchType
|
||||
from .enums import TrackType
|
||||
from .filters import Filter
|
||||
|
||||
__all__ = (
|
||||
"Track",
|
||||
"Playlist",
|
||||
)
|
||||
|
||||
|
||||
class Track:
|
||||
"""The base track object. Returns critical track information needed for parsing by Lavalink.
|
||||
You can also pass in commands.Context to get a discord.py Context object in your track.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"track_id",
|
||||
"info",
|
||||
"track_type",
|
||||
"filters",
|
||||
"timestamp",
|
||||
"original",
|
||||
"_search_type",
|
||||
"playlist",
|
||||
"title",
|
||||
"author",
|
||||
"uri",
|
||||
"identifier",
|
||||
"isrc",
|
||||
"thumbnail",
|
||||
"length",
|
||||
"ctx",
|
||||
"requester",
|
||||
"is_stream",
|
||||
"is_seekable",
|
||||
"position",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
track_id: str,
|
||||
info: dict,
|
||||
ctx: Optional[commands.Context] = None,
|
||||
track_type: TrackType,
|
||||
search_type: SearchType = SearchType.ytsearch,
|
||||
filters: Optional[List[Filter]] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
requester: Optional[Union[Member, User, ClientUser]] = None,
|
||||
):
|
||||
self.track_id: str = track_id
|
||||
self.info: dict = info
|
||||
self.track_type: TrackType = track_type
|
||||
self.filters: Optional[List[Filter]] = filters
|
||||
self.timestamp: Optional[float] = timestamp
|
||||
|
||||
if self.track_type == TrackType.SPOTIFY or self.track_type == TrackType.APPLE_MUSIC:
|
||||
self.original: Optional[Track] = None
|
||||
else:
|
||||
self.original = self
|
||||
self._search_type: SearchType = search_type
|
||||
|
||||
self.playlist: Optional[Playlist] = None
|
||||
|
||||
self.title: str = info.get("title", "Unknown Title")
|
||||
self.author: str = info.get("author", "Unknown Author")
|
||||
self.uri: str = info.get("uri", "")
|
||||
self.identifier: str = info.get("identifier", "")
|
||||
self.isrc: Optional[str] = info.get("isrc", None)
|
||||
self.thumbnail: Optional[str] = info.get("thumbnail")
|
||||
|
||||
if self.uri and self.track_type is TrackType.YOUTUBE:
|
||||
self.thumbnail = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg"
|
||||
|
||||
self.length: int = info.get("length", 0)
|
||||
self.is_stream: bool = info.get("isStream", False)
|
||||
self.is_seekable: bool = info.get("isSeekable", False)
|
||||
self.position: int = info.get("position", 0)
|
||||
|
||||
self.ctx: Optional[commands.Context] = ctx
|
||||
self.requester: Optional[Union[Member, User, ClientUser]] = requester
|
||||
if not self.requester and self.ctx:
|
||||
self.requester = self.ctx.author
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Track):
|
||||
return False
|
||||
|
||||
return other.track_id == self.track_id
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.title
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.track title={self.title!r} uri=<{self.uri!r}> length={self.length}>"
|
||||
|
||||
|
||||
class Playlist:
|
||||
"""The base playlist object.
|
||||
Returns critical playlist information needed for parsing by Lavalink.
|
||||
You can also pass in commands.Context to get a discord.py Context object in your tracks.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"playlist_info",
|
||||
"tracks",
|
||||
"name",
|
||||
"playlist_type",
|
||||
"_thumbnail",
|
||||
"_uri",
|
||||
"selected_track",
|
||||
"track_count",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
playlist_info: dict,
|
||||
tracks: list,
|
||||
playlist_type: PlaylistType,
|
||||
thumbnail: Optional[str] = None,
|
||||
uri: Optional[str] = None,
|
||||
):
|
||||
self.playlist_info: dict = playlist_info
|
||||
self.tracks: List[Track] = tracks
|
||||
self.name: str = playlist_info.get("name", "Unknown Playlist")
|
||||
self.playlist_type: PlaylistType = playlist_type
|
||||
|
||||
self._thumbnail: Optional[str] = thumbnail
|
||||
self._uri: Optional[str] = uri
|
||||
|
||||
for track in self.tracks:
|
||||
track.playlist = self
|
||||
|
||||
self.selected_track: Optional[Track] = None
|
||||
if (index := playlist_info.get("selectedTrack", -1)) != -1:
|
||||
self.selected_track = self.tracks[index]
|
||||
|
||||
self.track_count: int = len(self.tracks)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.playlist name={self.name!r} track_count={len(self.tracks)}>"
|
||||
|
||||
@property
|
||||
def uri(self) -> Optional[str]:
|
||||
"""Returns either an Apple Music/Spotify URL/URI, or None if its neither of those."""
|
||||
return self._uri
|
||||
|
||||
@property
|
||||
def thumbnail(self) -> Optional[str]:
|
||||
"""Returns either an Apple Music/Spotify album/playlist thumbnail, or None if its neither of those."""
|
||||
return self._thumbnail
|
||||
101
pomice/player.py
101
pomice/player.py
|
|
@ -14,23 +14,24 @@ from discord import VoiceChannel
|
|||
from discord import VoiceProtocol
|
||||
from discord.ext import commands
|
||||
|
||||
from . import events
|
||||
from .enums import SearchType
|
||||
from .events import PomiceEvent
|
||||
from .events import TrackEndEvent
|
||||
from .events import TrackStartEvent
|
||||
from .exceptions import FilterInvalidArgument
|
||||
from .exceptions import FilterTagAlreadyInUse
|
||||
from .exceptions import FilterTagInvalid
|
||||
from .exceptions import TrackInvalidPosition
|
||||
from .exceptions import TrackLoadError
|
||||
from .filters import Filter
|
||||
from .filters import Timescale
|
||||
from .objects import Playlist
|
||||
from .objects import Track
|
||||
from .pool import Node
|
||||
from .pool import NodePool
|
||||
from pomice.utils import LavalinkVersion
|
||||
from pomice import events
|
||||
from pomice.enums import SearchType
|
||||
from pomice.exceptions import FilterInvalidArgument
|
||||
from pomice.exceptions import FilterTagAlreadyInUse
|
||||
from pomice.exceptions import FilterTagInvalid
|
||||
from pomice.exceptions import TrackInvalidPosition
|
||||
from pomice.exceptions import TrackLoadError
|
||||
from pomice.filters import Filter
|
||||
from pomice.filters import Timescale
|
||||
from pomice.models.events import PomiceEvent
|
||||
from pomice.models.events import TrackEndEvent
|
||||
from pomice.models.events import TrackStartEvent
|
||||
from pomice.models.music import Playlist
|
||||
from pomice.models.music import Track
|
||||
from pomice.models.payloads import TrackUpdatePayload
|
||||
from pomice.models.payloads import VoiceUpdatePayload
|
||||
from pomice.pool import Node
|
||||
from pomice.pool import NodePool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from discord.types.voice import VoiceServerUpdate
|
||||
|
|
@ -200,10 +201,10 @@ class Player(VoiceProtocol):
|
|||
@property
|
||||
def position(self) -> float:
|
||||
"""Property which returns the player's position in a track in milliseconds"""
|
||||
if not self.is_playing:
|
||||
if not self.is_playing or not self._current:
|
||||
return 0
|
||||
|
||||
current: Track = self._current # type: ignore
|
||||
current: Track = self._current
|
||||
if current.original:
|
||||
current = current.original
|
||||
|
||||
|
|
@ -230,10 +231,10 @@ class Player(VoiceProtocol):
|
|||
@property
|
||||
def adjusted_length(self) -> float:
|
||||
"""Property which returns the player's track length in milliseconds adjusted for rate"""
|
||||
if not self.is_playing:
|
||||
if not self.is_playing or not self._current:
|
||||
return 0
|
||||
|
||||
return self.current.length / self.rate # type: ignore
|
||||
return self.current.length / self.rate
|
||||
|
||||
@property
|
||||
def is_playing(self) -> bool:
|
||||
|
|
@ -287,12 +288,6 @@ class Player(VoiceProtocol):
|
|||
"""
|
||||
return self.guild.id not in self._node._players
|
||||
|
||||
def _adjust_end_time(self) -> Optional[str]:
|
||||
if self._node._version >= LavalinkVersion(3, 7, 5):
|
||||
return None
|
||||
|
||||
return "0"
|
||||
|
||||
async def _update_state(self, data: dict) -> None:
|
||||
state: dict = data.get("state", {})
|
||||
self._last_update = int(state.get("time", 0))
|
||||
|
|
@ -301,23 +296,18 @@ class Player(VoiceProtocol):
|
|||
if self._log:
|
||||
self._log.debug(f"Got player update state with data {state}")
|
||||
|
||||
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None:
|
||||
if {"sessionId", "event"} != self._voice_state.keys():
|
||||
async def _dispatch_voice_update(self, voice_data: Dict[str, Union[str, int]]) -> None:
|
||||
state = voice_data or self._voice_state
|
||||
if {"sessionId", "event"} != state.keys():
|
||||
return
|
||||
|
||||
state = voice_data or self._voice_state
|
||||
|
||||
data = {
|
||||
"token": state["event"]["token"],
|
||||
"endpoint": state["event"]["endpoint"],
|
||||
"sessionId": state["sessionId"],
|
||||
}
|
||||
data = VoiceUpdatePayload.model_validate(state)
|
||||
|
||||
await self._node.send(
|
||||
method="PATCH",
|
||||
path=self._player_endpoint_uri,
|
||||
guild_id=self._guild.id,
|
||||
data={"voice": data},
|
||||
data={"voice": data.model_dump()},
|
||||
)
|
||||
|
||||
if self._log:
|
||||
|
|
@ -327,44 +317,39 @@ class Player(VoiceProtocol):
|
|||
|
||||
async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
|
||||
self._voice_state.update({"event": data})
|
||||
await self._dispatch_voice_update(self._voice_state)
|
||||
await self._dispatch_voice_update()
|
||||
|
||||
async def on_voice_state_update(self, data: GuildVoiceState) -> None:
|
||||
self._voice_state.update({"sessionId": data.get("session_id")})
|
||||
|
||||
channel_id = data.get("channel_id")
|
||||
if not channel_id:
|
||||
await self.disconnect()
|
||||
self._voice_state.clear()
|
||||
return
|
||||
self._voice_state.update({"sessionId": data["session_id"]})
|
||||
|
||||
channel_id = data["session_id"]
|
||||
channel = self.guild.get_channel(int(channel_id))
|
||||
|
||||
if self.channel != channel:
|
||||
self.channel = channel
|
||||
|
||||
if not channel:
|
||||
await self.disconnect()
|
||||
self._voice_state.clear()
|
||||
return
|
||||
|
||||
if self.channel != channel:
|
||||
self.channel = channel
|
||||
|
||||
if not data.get("token"):
|
||||
return
|
||||
|
||||
self._voice_state.update({"event": data})
|
||||
await self._dispatch_voice_update({**self._voice_state, "event": data})
|
||||
|
||||
async def _dispatch_event(self, data: dict) -> None:
|
||||
event_type: str = data["type"]
|
||||
event: PomiceEvent = getattr(events, event_type)(data, self)
|
||||
event: PomiceEvent = getattr(events, event_type)(player=self, **data)
|
||||
|
||||
if isinstance(event, TrackEndEvent) and event.reason not in ("REPLACED", "replaced"):
|
||||
if isinstance(event, TrackEndEvent) and event.reason != "replaced":
|
||||
self._current = None
|
||||
|
||||
event.dispatch(self._bot)
|
||||
|
||||
if isinstance(event, TrackStartEvent):
|
||||
self._ending_track = self._current
|
||||
|
||||
event.dispatch(self._bot)
|
||||
|
||||
if self._log:
|
||||
self._log.debug(f"Dispatched event {data['type']} to player.")
|
||||
|
||||
|
|
@ -373,7 +358,10 @@ class Player(VoiceProtocol):
|
|||
|
||||
async def _swap_node(self, *, new_node: Node) -> None:
|
||||
if self.current:
|
||||
data: dict = {"position": self.position, "encodedTrack": self.current.track_id}
|
||||
data: dict = TrackUpdatePayload(
|
||||
encoded_track=self.current.track_id,
|
||||
position=self.position,
|
||||
).model_dump()
|
||||
|
||||
del self._node._players[self._guild.id]
|
||||
self._node = new_node
|
||||
|
|
@ -396,7 +384,7 @@ class Player(VoiceProtocol):
|
|||
query: str,
|
||||
*,
|
||||
ctx: Optional[commands.Context] = None,
|
||||
search_type: SearchType = SearchType.ytsearch,
|
||||
search_type: SearchType = SearchType.YTSEARCH,
|
||||
filters: Optional[List[Filter]] = None,
|
||||
) -> Optional[Union[List[Track], Playlist]]:
|
||||
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
||||
|
|
@ -629,6 +617,9 @@ class Player(VoiceProtocol):
|
|||
|
||||
async def set_volume(self, volume: int) -> int:
|
||||
"""Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500."""
|
||||
if volume < 0 or volume > 500:
|
||||
raise ValueError("Volume must be between 0 and 500")
|
||||
|
||||
await self._node.send(
|
||||
method="PATCH",
|
||||
path=self._player_endpoint_uri,
|
||||
|
|
|
|||
|
|
@ -20,35 +20,34 @@ import aiohttp
|
|||
import orjson as json
|
||||
from discord import Client
|
||||
from discord.ext import commands
|
||||
from discord.utils import MISSING
|
||||
from websockets import client
|
||||
from websockets import exceptions
|
||||
from websockets import typing as wstype
|
||||
|
||||
from . import __version__
|
||||
from . import applemusic
|
||||
from . import spotify
|
||||
from .enums import *
|
||||
from .enums import LogLevel
|
||||
from .exceptions import InvalidSpotifyClientAuthorization
|
||||
from .exceptions import LavalinkVersionIncompatible
|
||||
from .exceptions import NodeConnectionFailure
|
||||
from .exceptions import NodeCreationError
|
||||
from .exceptions import NodeNotAvailable
|
||||
from .exceptions import NodeRestException
|
||||
from .exceptions import NoNodesAvailable
|
||||
from .exceptions import TrackLoadError
|
||||
from .filters import Filter
|
||||
from .objects import Playlist
|
||||
from .objects import Track
|
||||
from .routeplanner import RoutePlanner
|
||||
from .utils import ExponentialBackoff
|
||||
from .utils import LavalinkVersion
|
||||
from .utils import NodeStats
|
||||
from .utils import Ping
|
||||
from pomice import __version__
|
||||
from pomice import applemusic
|
||||
from pomice import spotify
|
||||
from pomice.enums import *
|
||||
from pomice.exceptions import InvalidSpotifyClientAuthorization
|
||||
from pomice.exceptions import LavalinkVersionIncompatible
|
||||
from pomice.exceptions import NodeConnectionFailure
|
||||
from pomice.exceptions import NodeCreationError
|
||||
from pomice.exceptions import NodeNotAvailable
|
||||
from pomice.exceptions import NodeRestException
|
||||
from pomice.exceptions import NoNodesAvailable
|
||||
from pomice.exceptions import TrackLoadError
|
||||
from pomice.filters import Filter
|
||||
from pomice.models.music import Playlist
|
||||
from pomice.models.music import Track
|
||||
from pomice.models.payloads import ResumePayloadTypeAdapter
|
||||
from pomice.models.payloads import ResumePayloadV4
|
||||
from pomice.models.version import LavalinkVersion
|
||||
from pomice.routeplanner import RoutePlanner
|
||||
from pomice.utils import ExponentialBackoff
|
||||
from pomice.utils import NodeStats
|
||||
from pomice.utils import Ping
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .player import Player
|
||||
from pomice.player import Player
|
||||
|
||||
__all__ = (
|
||||
"Node",
|
||||
|
|
@ -167,20 +166,14 @@ class Node:
|
|||
self._spotify_client: Optional[spotify.Client] = None
|
||||
self._apple_music_client: Optional[applemusic.Client] = None
|
||||
|
||||
self._spotify_client_id: Optional[str] = spotify_client_id
|
||||
self._spotify_client_secret: Optional[str] = spotify_client_secret
|
||||
|
||||
if self._spotify_client_id and self._spotify_client_secret:
|
||||
if spotify_client_id and spotify_client_secret:
|
||||
self._spotify_client = spotify.Client(
|
||||
self._spotify_client_id,
|
||||
self._spotify_client_secret,
|
||||
spotify_client_id,
|
||||
spotify_client_secret,
|
||||
)
|
||||
|
||||
if apple_music:
|
||||
self._apple_music_client = applemusic.Client()
|
||||
|
||||
self._bot.add_listener(self._update_handler, "on_socket_response")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Pomice.node ws_uri={self._websocket_uri} rest_uri={self._rest_uri} "
|
||||
|
|
@ -265,31 +258,6 @@ class Node:
|
|||
if self._apple_music_client:
|
||||
await self._apple_music_client._set_session(session=session)
|
||||
|
||||
async def _update_handler(self, data: dict) -> None:
|
||||
await self._bot.wait_until_ready()
|
||||
|
||||
if not data:
|
||||
return
|
||||
|
||||
if data["t"] == "VOICE_SERVER_UPDATE":
|
||||
guild_id = int(data["d"]["guild_id"])
|
||||
try:
|
||||
player = self._players[guild_id]
|
||||
await player.on_voice_server_update(data["d"])
|
||||
except KeyError:
|
||||
return
|
||||
|
||||
elif data["t"] == "VOICE_STATE_UPDATE":
|
||||
if int(data["d"]["user_id"]) != self._bot_user.id:
|
||||
return
|
||||
|
||||
guild_id = int(data["d"]["guild_id"])
|
||||
try:
|
||||
player = self._players[guild_id]
|
||||
await player.on_voice_state_update(data["d"])
|
||||
except KeyError:
|
||||
return
|
||||
|
||||
async def _handle_node_switch(self) -> None:
|
||||
nodes = [node for node in self.pool._nodes.copy().values() if node.is_connected]
|
||||
new_node = random.choice(nodes)
|
||||
|
|
@ -303,14 +271,15 @@ class Node:
|
|||
if not self._resume_key:
|
||||
return
|
||||
|
||||
data = {"timeout": self._resume_timeout}
|
||||
data = ResumePayloadTypeAdapter(
|
||||
version=self._version,
|
||||
timeout=self._resume_timeout,
|
||||
resuming_key=self._resume_key,
|
||||
).model_dump()
|
||||
|
||||
if self._version.major == 3:
|
||||
data["resumingKey"] = self._resume_key
|
||||
elif self._version.major == 4:
|
||||
if isinstance(data, ResumePayloadV4):
|
||||
if self._log:
|
||||
self._log.warning("Using a resume key with Lavalink v4 is deprecated.")
|
||||
data["resuming"] = True
|
||||
|
||||
await self.send(
|
||||
method="PATCH",
|
||||
|
|
@ -560,7 +529,7 @@ class Node:
|
|||
query: str,
|
||||
*,
|
||||
ctx: Optional[commands.Context] = None,
|
||||
search_type: SearchType = SearchType.ytsearch,
|
||||
search_type: SearchType = SearchType.YTSEARCH,
|
||||
filters: Optional[List[Filter]] = None,
|
||||
) -> Optional[Union[Playlist, List[Track]]]:
|
||||
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
||||
|
|
|
|||
|
|
@ -8,11 +8,11 @@ from typing import List
|
|||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from .enums import LoopMode
|
||||
from .exceptions import QueueEmpty
|
||||
from .exceptions import QueueException
|
||||
from .exceptions import QueueFull
|
||||
from .objects import Track
|
||||
from pomice.enums import LoopMode
|
||||
from pomice.exceptions import QueueEmpty
|
||||
from pomice.exceptions import QueueException
|
||||
from pomice.exceptions import QueueFull
|
||||
from pomice.models.music import Track
|
||||
|
||||
__all__ = ("Queue",)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ from __future__ import annotations
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pool import Node
|
||||
from pomice.pool import Node
|
||||
|
||||
from .utils import RouteStats
|
||||
from pomice.utils import RouteStats
|
||||
|
||||
__all__ = ("RoutePlanner",)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""Spotify module for Pomice, made possible by cloudwithax 2023"""
|
||||
from .client import Client
|
||||
from .client import *
|
||||
from .exceptions import *
|
||||
from .models import *
|
||||
from .objects import *
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from base64 import b64encode
|
||||
from typing import Dict
|
||||
|
|
@ -13,18 +12,15 @@ from urllib.parse import quote
|
|||
import aiohttp
|
||||
import orjson as json
|
||||
|
||||
from .exceptions import InvalidSpotifyURL
|
||||
from .exceptions import SpotifyRequestException
|
||||
from .objects import *
|
||||
from pomice.enums import URLRegex
|
||||
from pomice.spotify.exceptions import *
|
||||
from pomice.spotify.models import *
|
||||
|
||||
__all__ = ("Client",)
|
||||
|
||||
|
||||
GRANT_URL = "https://accounts.spotify.com/api/token"
|
||||
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
|
||||
SPOTIFY_URL_REGEX = re.compile(
|
||||
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
|
||||
)
|
||||
|
||||
|
||||
class Client:
|
||||
|
|
@ -34,15 +30,12 @@ class Client:
|
|||
"""
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str) -> None:
|
||||
self._client_id: str = client_id
|
||||
self._client_secret: str = client_secret
|
||||
|
||||
self.session: aiohttp.ClientSession = None # type: ignore
|
||||
|
||||
self._bearer_token: Optional[str] = None
|
||||
self._expiry: float = 0.0
|
||||
self._auth_token = b64encode(
|
||||
f"{self._client_id}:{self._client_secret}".encode(),
|
||||
f"{client_id}:{client_secret}".encode(),
|
||||
)
|
||||
self._grant_headers = {
|
||||
"Authorization": f"Basic {self._auth_token.decode()}",
|
||||
|
|
@ -77,7 +70,7 @@ class Client:
|
|||
if not self._bearer_token or time.time() >= self._expiry:
|
||||
await self._fetch_bearer_token()
|
||||
|
||||
result = SPOTIFY_URL_REGEX.match(query)
|
||||
result = URLRegex.SPOTIFY_URL.match(query)
|
||||
if not result:
|
||||
raise InvalidSpotifyURL("The Spotify link provided is not valid.")
|
||||
|
||||
|
|
@ -151,7 +144,7 @@ class Client:
|
|||
if not self._bearer_token or time.time() >= self._expiry:
|
||||
await self._fetch_bearer_token()
|
||||
|
||||
result = SPOTIFY_URL_REGEX.match(query)
|
||||
result = URLRegex.SPOTIFY_URL.match(query)
|
||||
if not result:
|
||||
raise InvalidSpotifyURL("The Spotify link provided is not valid.")
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,53 @@
|
|||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from discord.ext.commands import Context
|
||||
from discord.user import _UserTag
|
||||
from pydantic import Field
|
||||
|
||||
from pomice.enums import SearchType
|
||||
from pomice.enums import TrackType
|
||||
from pomice.filters import Filter
|
||||
from pomice.models import BaseModel
|
||||
from pomice.models.music import Track
|
||||
from pomice.models.music import TrackInfo
|
||||
|
||||
|
||||
class SpotifyTrackRaw(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
artists: List[Dict[str, str]]
|
||||
duration_ms: float
|
||||
external_ids: Dict[str, str] = Field(default_factory=dict)
|
||||
external_urls: Dict[str, str] = Field(default_factory=dict)
|
||||
album: Dict[str, List[Dict[str, str]]] = Field(default_factory=dict)
|
||||
|
||||
def build_track(
|
||||
self,
|
||||
image: Optional[str] = None,
|
||||
filters: Optional[List[Filter]] = None,
|
||||
ctx: Optional[Context] = None,
|
||||
requester: Optional[_UserTag] = None,
|
||||
) -> Track:
|
||||
if self.album:
|
||||
image = self.album["images"][0]["url"]
|
||||
|
||||
return Track(
|
||||
track_id=self.id,
|
||||
track_type=TrackType.SPOTIFY,
|
||||
search_type=SearchType.YTMSEARCH,
|
||||
filters=filters,
|
||||
ctx=ctx,
|
||||
requester=requester,
|
||||
info=TrackInfo(
|
||||
identifier=self.id,
|
||||
title=self.name,
|
||||
author=", ".join(artist["name"] for artist in self.artists),
|
||||
length=self.duration_ms,
|
||||
is_seekable=True,
|
||||
uri=self.external_urls.get("spotify", ""),
|
||||
artwork_url=image,
|
||||
isrc=self.external_ids.get("isrc"),
|
||||
),
|
||||
)
|
||||
|
|
@ -8,11 +8,10 @@ from typing import Any
|
|||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import NamedTuple
|
||||
from typing import Optional
|
||||
|
||||
from .enums import RouteIPType
|
||||
from .enums import RouteStrategy
|
||||
from pomice.enums import RouteIPType
|
||||
from pomice.enums import RouteStrategy
|
||||
|
||||
__all__ = (
|
||||
"ExponentialBackoff",
|
||||
|
|
@ -20,7 +19,6 @@ __all__ = (
|
|||
"FailingIPBlock",
|
||||
"RouteStats",
|
||||
"Ping",
|
||||
"LavalinkVersion",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -226,53 +224,3 @@ class Ping:
|
|||
s_runtime = 1000 * (cost_time)
|
||||
|
||||
return s_runtime
|
||||
|
||||
|
||||
class LavalinkVersion(NamedTuple):
|
||||
major: int
|
||||
minor: int
|
||||
fix: int
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, LavalinkVersion):
|
||||
return False
|
||||
|
||||
return (
|
||||
(self.major == other.major) and (self.minor == other.minor) and (self.fix == other.fix)
|
||||
)
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
if not isinstance(other, LavalinkVersion):
|
||||
return False
|
||||
|
||||
return not (self == other)
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if not isinstance(other, LavalinkVersion):
|
||||
return False
|
||||
|
||||
if self.major > other.major:
|
||||
return False
|
||||
if self.minor > other.minor:
|
||||
return False
|
||||
if self.fix > other.fix:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __gt__(self, other: object) -> bool:
|
||||
if not isinstance(other, LavalinkVersion):
|
||||
return False
|
||||
|
||||
return not (self < other)
|
||||
|
||||
def __le__(self, other: object) -> bool:
|
||||
if not isinstance(other, LavalinkVersion):
|
||||
return False
|
||||
|
||||
return (self < other) or (self == other)
|
||||
|
||||
def __ge__(self, other: object) -> bool:
|
||||
if not isinstance(other, LavalinkVersion):
|
||||
return False
|
||||
|
||||
return (self > other) or (self == other)
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -4,7 +4,7 @@ import re
|
|||
import setuptools
|
||||
|
||||
version = ""
|
||||
requirements = ["aiohttp>=3.7.4,<4", "orjson", "websockets"]
|
||||
requirements = ["aiohttp>=3.7.4,<4", "orjson", "websockets", "pydantic>=2"]
|
||||
with open("pomice/__init__.py") as f:
|
||||
version = re.search(
|
||||
r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',
|
||||
|
|
|
|||
Loading…
Reference in New Issue