From bc19e07008ab56f0f2275f4ff08db212e250659c Mon Sep 17 00:00:00 2001 From: NiceAesth Date: Sat, 24 Feb 2024 13:53:05 +0200 Subject: [PATCH] feat: add more payloads --- pomice/models/__init__.py | 12 +++ pomice/models/payloads.py | 53 +++++++++--- pomice/objects.py | 167 -------------------------------------- pomice/player.py | 81 ++++++++---------- pomice/queue.py | 10 +-- pomice/routeplanner.py | 2 +- pomice/utils.py | 4 +- 7 files changed, 96 insertions(+), 233 deletions(-) delete mode 100644 pomice/objects.py diff --git a/pomice/models/__init__.py b/pomice/models/__init__.py index 532d545..3959143 100644 --- a/pomice/models/__init__.py +++ b/pomice/models/__init__.py @@ -9,3 +9,15 @@ from .version import * class BaseModel(pydantic.BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True) + + def model_dump(self, *args, **kwargs) -> dict: + by_alias = kwargs.pop("by_alias", True) + mode = kwargs.pop("mode", "json") + return super().model_dump(*args, **kwargs, by_alias=by_alias, mode=mode) + + +class VersionedModel(BaseModel): + version: LavalinkVersionType + + def model_dump(self, *args, **kwargs) -> dict: + return super().model_dump(*args, **kwargs, exclude={"version"}) diff --git a/pomice/models/payloads.py b/pomice/models/payloads.py index 31ce66b..96d97ef 100644 --- a/pomice/models/payloads.py +++ b/pomice/models/payloads.py @@ -1,39 +1,66 @@ +from __future__ import annotations + +from typing import Optional from typing import Union +from pydantic import AliasPath from pydantic import Field +from pydantic import field_validator +from pydantic import model_validator from pydantic import TypeAdapter from pomice.models import BaseModel -from pomice.models import LavalinkVersion from pomice.models import LavalinkVersion3Type from pomice.models import LavalinkVersion4Type +from pomice.models import VersionedModel __all__ = ( - "ResumePayload", - "ResumePayloadV3", - "ResumePayloadV4", + "VoiceUpdatePayload", + "TrackStartPayload", + "TrackUpdatePayload", "ResumePayloadType", "ResumePayloadTypeAdapter", ) -class ResumePayload(BaseModel): - version: LavalinkVersion +class VoiceUpdatePayload(BaseModel): + token: str = Field(validation_alias=AliasPath("event", "token")) + endpoint: str = Field(validation_alias=AliasPath("event", "endpoint")) + session_id: str = Field(alias="sessionId") + + +class TrackUpdatePayload(BaseModel): + encoded_track: Optional[str] = Field(default=None, alias="encodedTrack") + position: float + + +class TrackStartPayload(VersionedModel): + encoded_track: Optional[str] = Field(default=None, alias="encodedTrack") + position: float + end_time: str = Field(default="0", alias="endTime") + + @field_validator("end_time", mode="before") + @classmethod + def cast_end_time(cls, value: object) -> str: + return str(value) + + @model_validator(mode="after") + def adjust_end_time(self) -> TrackStartPayload: + if self.version >= LavalinkVersion3Type(3, 7, 5): + self.end_time = None + + +class ResumePayload(VersionedModel): timeout: int - def model_dump(self) -> dict: - return super().model_dump(by_alias=True, exclude={"version"}) - -class ResumePayloadV3(BaseModel): +class ResumePayloadV3(ResumePayload): version: LavalinkVersion3Type - timeout: int resuming_key: str = Field(alias="resumingKey") -class ResumePayloadV4(BaseModel): +class ResumePayloadV4(ResumePayload): version: LavalinkVersion4Type - timeout: int resuming: bool = True diff --git a/pomice/objects.py b/pomice/objects.py deleted file mode 100644 index 1576308..0000000 --- a/pomice/objects.py +++ /dev/null @@ -1,167 +0,0 @@ -from __future__ import annotations - -from typing import List -from typing import Optional -from typing import Union - -from discord import ClientUser -from discord import Member -from discord import User -from discord.ext import commands - -from .enums import PlaylistType -from .enums import SearchType -from .enums import TrackType -from .filters import Filter - -__all__ = ( - "Track", - "Playlist", -) - - -class Track: - """The base track object. Returns critical track information needed for parsing by Lavalink. - You can also pass in commands.Context to get a discord.py Context object in your track. - """ - - __slots__ = ( - "track_id", - "info", - "track_type", - "filters", - "timestamp", - "original", - "_search_type", - "playlist", - "title", - "author", - "uri", - "identifier", - "isrc", - "thumbnail", - "length", - "ctx", - "requester", - "is_stream", - "is_seekable", - "position", - ) - - def __init__( - self, - *, - track_id: str, - info: dict, - ctx: Optional[commands.Context] = None, - track_type: TrackType, - search_type: SearchType = SearchType.YTSEARCH, - filters: Optional[List[Filter]] = None, - timestamp: Optional[float] = None, - requester: Optional[Union[Member, User, ClientUser]] = None, - ): - self.track_id: str = track_id - self.info: dict = info - self.track_type: TrackType = track_type - self.filters: Optional[List[Filter]] = filters - self.timestamp: Optional[float] = timestamp - - if self.track_type == TrackType.SPOTIFY or self.track_type == TrackType.APPLE_MUSIC: - self.original: Optional[Track] = None - else: - self.original = self - self._search_type: SearchType = search_type - - self.playlist: Optional[Playlist] = None - - self.title: str = info.get("title", "Unknown Title") - self.author: str = info.get("author", "Unknown Author") - self.uri: str = info.get("uri", "") - self.identifier: str = info.get("identifier", "") - self.isrc: Optional[str] = info.get("isrc", None) - self.thumbnail: Optional[str] = info.get("thumbnail") - - if self.uri and self.track_type is TrackType.YOUTUBE: - self.thumbnail = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg" - - self.length: int = info.get("length", 0) - self.is_stream: bool = info.get("isStream", False) - self.is_seekable: bool = info.get("isSeekable", False) - self.position: int = info.get("position", 0) - - self.ctx: Optional[commands.Context] = ctx - self.requester: Optional[Union[Member, User, ClientUser]] = requester - if not self.requester and self.ctx: - self.requester = self.ctx.author - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Track): - return False - - return other.track_id == self.track_id - - def __str__(self) -> str: - return self.title - - def __repr__(self) -> str: - return f" length={self.length}>" - - -class Playlist: - """The base playlist object. - Returns critical playlist information needed for parsing by Lavalink. - You can also pass in commands.Context to get a discord.py Context object in your tracks. - """ - - __slots__ = ( - "playlist_info", - "tracks", - "name", - "playlist_type", - "_thumbnail", - "_uri", - "selected_track", - "track_count", - ) - - def __init__( - self, - *, - playlist_info: dict, - tracks: list, - playlist_type: PlaylistType, - thumbnail: Optional[str] = None, - uri: Optional[str] = None, - ): - self.playlist_info: dict = playlist_info - self.tracks: List[Track] = tracks - self.name: str = playlist_info.get("name", "Unknown Playlist") - self.playlist_type: PlaylistType = playlist_type - - self._thumbnail: Optional[str] = thumbnail - self._uri: Optional[str] = uri - - for track in self.tracks: - track.playlist = self - - self.selected_track: Optional[Track] = None - if (index := playlist_info.get("selectedTrack", -1)) != -1: - self.selected_track = self.tracks[index] - - self.track_count: int = len(self.tracks) - - def __str__(self) -> str: - return self.name - - def __repr__(self) -> str: - return f"" - - @property - def uri(self) -> Optional[str]: - """Returns either an Apple Music/Spotify URL/URI, or None if its neither of those.""" - return self._uri - - @property - def thumbnail(self) -> Optional[str]: - """Returns either an Apple Music/Spotify album/playlist thumbnail, or None if its neither of those.""" - return self._thumbnail diff --git a/pomice/player.py b/pomice/player.py index 4cd005b..d15db08 100644 --- a/pomice/player.py +++ b/pomice/player.py @@ -14,23 +14,24 @@ from discord import VoiceChannel from discord import VoiceProtocol from discord.ext import commands -from . import events -from .enums import SearchType -from .exceptions import FilterInvalidArgument -from .exceptions import FilterTagAlreadyInUse -from .exceptions import FilterTagInvalid -from .exceptions import TrackInvalidPosition -from .exceptions import TrackLoadError -from .filters import Filter -from .filters import Timescale -from .objects import Playlist -from .objects import Track -from .pool import Node -from .pool import NodePool +from pomice import events +from pomice.enums import SearchType +from pomice.exceptions import FilterInvalidArgument +from pomice.exceptions import FilterTagAlreadyInUse +from pomice.exceptions import FilterTagInvalid +from pomice.exceptions import TrackInvalidPosition +from pomice.exceptions import TrackLoadError +from pomice.filters import Filter +from pomice.filters import Timescale from pomice.models.events import PomiceEvent from pomice.models.events import TrackEndEvent from pomice.models.events import TrackStartEvent -from pomice.models.version import LavalinkVersion +from pomice.models.music import Playlist +from pomice.models.music import Track +from pomice.models.payloads import TrackUpdatePayload +from pomice.models.payloads import VoiceUpdatePayload +from pomice.pool import Node +from pomice.pool import NodePool if TYPE_CHECKING: from discord.types.voice import VoiceServerUpdate @@ -287,12 +288,6 @@ class Player(VoiceProtocol): """ return self.guild.id not in self._node._players - def _adjust_end_time(self) -> Optional[str]: - if self._node._version >= LavalinkVersion(3, 7, 5): - return None - - return "0" - async def _update_state(self, data: dict) -> None: state: dict = data.get("state", {}) self._last_update = int(state.get("time", 0)) @@ -301,23 +296,18 @@ class Player(VoiceProtocol): if self._log: self._log.debug(f"Got player update state with data {state}") - async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None: - if {"sessionId", "event"} != self._voice_state.keys(): + async def _dispatch_voice_update(self, voice_data: Dict[str, Union[str, int]]) -> None: + state = voice_data or self._voice_state + if {"sessionId", "event"} != state.keys(): return - state = voice_data or self._voice_state - - data = { - "token": state["event"]["token"], - "endpoint": state["event"]["endpoint"], - "sessionId": state["sessionId"], - } + data = VoiceUpdatePayload.model_validate(state) await self._node.send( method="PATCH", path=self._player_endpoint_uri, guild_id=self._guild.id, - data={"voice": data}, + data={"voice": data.model_dump()}, ) if self._log: @@ -327,30 +317,26 @@ class Player(VoiceProtocol): async def on_voice_server_update(self, data: VoiceServerUpdate) -> None: self._voice_state.update({"event": data}) - await self._dispatch_voice_update(self._voice_state) + await self._dispatch_voice_update() async def on_voice_state_update(self, data: GuildVoiceState) -> None: - self._voice_state.update({"sessionId": data.get("session_id")}) - - channel_id = data.get("channel_id") - if not channel_id: - await self.disconnect() - self._voice_state.clear() - return + self._voice_state.update({"sessionId": data["session_id"]}) + channel_id = data["session_id"] channel = self.guild.get_channel(int(channel_id)) - if self.channel != channel: - self.channel = channel - if not channel: await self.disconnect() self._voice_state.clear() return + if self.channel != channel: + self.channel = channel + if not data.get("token"): return + self._voice_state.update({"event": data}) await self._dispatch_voice_update({**self._voice_state, "event": data}) async def _dispatch_event(self, data: dict) -> None: @@ -359,12 +345,11 @@ class Player(VoiceProtocol): if isinstance(event, TrackEndEvent) and event.reason != "replaced": self._current = None - - event.dispatch(self._bot) - if isinstance(event, TrackStartEvent): self._ending_track = self._current + event.dispatch(self._bot) + if self._log: self._log.debug(f"Dispatched event {data['type']} to player.") @@ -373,7 +358,10 @@ class Player(VoiceProtocol): async def _swap_node(self, *, new_node: Node) -> None: if self.current: - data: dict = {"position": self.position, "encodedTrack": self.current.track_id} + data: dict = TrackUpdatePayload( + encoded_track=self.current.track_id, + position=self.position, + ).model_dump() del self._node._players[self._guild.id] self._node = new_node @@ -629,6 +617,9 @@ class Player(VoiceProtocol): async def set_volume(self, volume: int) -> int: """Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500.""" + if volume < 0 or volume > 500: + raise ValueError("Volume must be between 0 and 500") + await self._node.send( method="PATCH", path=self._player_endpoint_uri, diff --git a/pomice/queue.py b/pomice/queue.py index 3ea6e8b..0a54edf 100644 --- a/pomice/queue.py +++ b/pomice/queue.py @@ -8,11 +8,11 @@ from typing import List from typing import Optional from typing import Union -from .enums import LoopMode -from .exceptions import QueueEmpty -from .exceptions import QueueException -from .exceptions import QueueFull -from .objects import Track +from pomice.enums import LoopMode +from pomice.exceptions import QueueEmpty +from pomice.exceptions import QueueException +from pomice.exceptions import QueueFull +from pomice.models.music import Track __all__ = ("Queue",) diff --git a/pomice/routeplanner.py b/pomice/routeplanner.py index 9a3d06e..f7e77b4 100644 --- a/pomice/routeplanner.py +++ b/pomice/routeplanner.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from .pool import Node -from .utils import RouteStats +from pomice.utils import RouteStats __all__ = ("RoutePlanner",) diff --git a/pomice/utils.py b/pomice/utils.py index 9eed7eb..aa1a623 100644 --- a/pomice/utils.py +++ b/pomice/utils.py @@ -10,8 +10,8 @@ from typing import Dict from typing import Iterable from typing import Optional -from .enums import RouteIPType -from .enums import RouteStrategy +from pomice.enums import RouteIPType +from pomice.enums import RouteStrategy __all__ = ( "ExponentialBackoff",