feat: add more payloads

This commit is contained in:
NiceAesth 2024-02-24 13:53:05 +02:00
parent bf92e37da6
commit bc19e07008
7 changed files with 96 additions and 233 deletions

View File

@ -9,3 +9,15 @@ from .version import *
class BaseModel(pydantic.BaseModel): class BaseModel(pydantic.BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True) model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
def model_dump(self, *args, **kwargs) -> dict:
by_alias = kwargs.pop("by_alias", True)
mode = kwargs.pop("mode", "json")
return super().model_dump(*args, **kwargs, by_alias=by_alias, mode=mode)
class VersionedModel(BaseModel):
version: LavalinkVersionType
def model_dump(self, *args, **kwargs) -> dict:
return super().model_dump(*args, **kwargs, exclude={"version"})

View File

@ -1,39 +1,66 @@
from __future__ import annotations
from typing import Optional
from typing import Union from typing import Union
from pydantic import AliasPath
from pydantic import Field from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from pydantic import TypeAdapter from pydantic import TypeAdapter
from pomice.models import BaseModel from pomice.models import BaseModel
from pomice.models import LavalinkVersion
from pomice.models import LavalinkVersion3Type from pomice.models import LavalinkVersion3Type
from pomice.models import LavalinkVersion4Type from pomice.models import LavalinkVersion4Type
from pomice.models import VersionedModel
__all__ = ( __all__ = (
"ResumePayload", "VoiceUpdatePayload",
"ResumePayloadV3", "TrackStartPayload",
"ResumePayloadV4", "TrackUpdatePayload",
"ResumePayloadType", "ResumePayloadType",
"ResumePayloadTypeAdapter", "ResumePayloadTypeAdapter",
) )
class ResumePayload(BaseModel): class VoiceUpdatePayload(BaseModel):
version: LavalinkVersion token: str = Field(validation_alias=AliasPath("event", "token"))
endpoint: str = Field(validation_alias=AliasPath("event", "endpoint"))
session_id: str = Field(alias="sessionId")
class TrackUpdatePayload(BaseModel):
encoded_track: Optional[str] = Field(default=None, alias="encodedTrack")
position: float
class TrackStartPayload(VersionedModel):
encoded_track: Optional[str] = Field(default=None, alias="encodedTrack")
position: float
end_time: str = Field(default="0", alias="endTime")
@field_validator("end_time", mode="before")
@classmethod
def cast_end_time(cls, value: object) -> str:
return str(value)
@model_validator(mode="after")
def adjust_end_time(self) -> TrackStartPayload:
if self.version >= LavalinkVersion3Type(3, 7, 5):
self.end_time = None
class ResumePayload(VersionedModel):
timeout: int timeout: int
def model_dump(self) -> dict:
return super().model_dump(by_alias=True, exclude={"version"})
class ResumePayloadV3(ResumePayload):
class ResumePayloadV3(BaseModel):
version: LavalinkVersion3Type version: LavalinkVersion3Type
timeout: int
resuming_key: str = Field(alias="resumingKey") resuming_key: str = Field(alias="resumingKey")
class ResumePayloadV4(BaseModel): class ResumePayloadV4(ResumePayload):
version: LavalinkVersion4Type version: LavalinkVersion4Type
timeout: int
resuming: bool = True resuming: bool = True

View File

@ -1,167 +0,0 @@
from __future__ import annotations
from typing import List
from typing import Optional
from typing import Union
from discord import ClientUser
from discord import Member
from discord import User
from discord.ext import commands
from .enums import PlaylistType
from .enums import SearchType
from .enums import TrackType
from .filters import Filter
__all__ = (
"Track",
"Playlist",
)
class Track:
"""The base track object. Returns critical track information needed for parsing by Lavalink.
You can also pass in commands.Context to get a discord.py Context object in your track.
"""
__slots__ = (
"track_id",
"info",
"track_type",
"filters",
"timestamp",
"original",
"_search_type",
"playlist",
"title",
"author",
"uri",
"identifier",
"isrc",
"thumbnail",
"length",
"ctx",
"requester",
"is_stream",
"is_seekable",
"position",
)
def __init__(
self,
*,
track_id: str,
info: dict,
ctx: Optional[commands.Context] = None,
track_type: TrackType,
search_type: SearchType = SearchType.YTSEARCH,
filters: Optional[List[Filter]] = None,
timestamp: Optional[float] = None,
requester: Optional[Union[Member, User, ClientUser]] = None,
):
self.track_id: str = track_id
self.info: dict = info
self.track_type: TrackType = track_type
self.filters: Optional[List[Filter]] = filters
self.timestamp: Optional[float] = timestamp
if self.track_type == TrackType.SPOTIFY or self.track_type == TrackType.APPLE_MUSIC:
self.original: Optional[Track] = None
else:
self.original = self
self._search_type: SearchType = search_type
self.playlist: Optional[Playlist] = None
self.title: str = info.get("title", "Unknown Title")
self.author: str = info.get("author", "Unknown Author")
self.uri: str = info.get("uri", "")
self.identifier: str = info.get("identifier", "")
self.isrc: Optional[str] = info.get("isrc", None)
self.thumbnail: Optional[str] = info.get("thumbnail")
if self.uri and self.track_type is TrackType.YOUTUBE:
self.thumbnail = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg"
self.length: int = info.get("length", 0)
self.is_stream: bool = info.get("isStream", False)
self.is_seekable: bool = info.get("isSeekable", False)
self.position: int = info.get("position", 0)
self.ctx: Optional[commands.Context] = ctx
self.requester: Optional[Union[Member, User, ClientUser]] = requester
if not self.requester and self.ctx:
self.requester = self.ctx.author
def __eq__(self, other: object) -> bool:
if not isinstance(other, Track):
return False
return other.track_id == self.track_id
def __str__(self) -> str:
return self.title
def __repr__(self) -> str:
return f"<Pomice.track title={self.title!r} uri=<{self.uri!r}> length={self.length}>"
class Playlist:
"""The base playlist object.
Returns critical playlist information needed for parsing by Lavalink.
You can also pass in commands.Context to get a discord.py Context object in your tracks.
"""
__slots__ = (
"playlist_info",
"tracks",
"name",
"playlist_type",
"_thumbnail",
"_uri",
"selected_track",
"track_count",
)
def __init__(
self,
*,
playlist_info: dict,
tracks: list,
playlist_type: PlaylistType,
thumbnail: Optional[str] = None,
uri: Optional[str] = None,
):
self.playlist_info: dict = playlist_info
self.tracks: List[Track] = tracks
self.name: str = playlist_info.get("name", "Unknown Playlist")
self.playlist_type: PlaylistType = playlist_type
self._thumbnail: Optional[str] = thumbnail
self._uri: Optional[str] = uri
for track in self.tracks:
track.playlist = self
self.selected_track: Optional[Track] = None
if (index := playlist_info.get("selectedTrack", -1)) != -1:
self.selected_track = self.tracks[index]
self.track_count: int = len(self.tracks)
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
return f"<Pomice.playlist name={self.name!r} track_count={len(self.tracks)}>"
@property
def uri(self) -> Optional[str]:
"""Returns either an Apple Music/Spotify URL/URI, or None if its neither of those."""
return self._uri
@property
def thumbnail(self) -> Optional[str]:
"""Returns either an Apple Music/Spotify album/playlist thumbnail, or None if its neither of those."""
return self._thumbnail

View File

@ -14,23 +14,24 @@ from discord import VoiceChannel
from discord import VoiceProtocol from discord import VoiceProtocol
from discord.ext import commands from discord.ext import commands
from . import events from pomice import events
from .enums import SearchType from pomice.enums import SearchType
from .exceptions import FilterInvalidArgument from pomice.exceptions import FilterInvalidArgument
from .exceptions import FilterTagAlreadyInUse from pomice.exceptions import FilterTagAlreadyInUse
from .exceptions import FilterTagInvalid from pomice.exceptions import FilterTagInvalid
from .exceptions import TrackInvalidPosition from pomice.exceptions import TrackInvalidPosition
from .exceptions import TrackLoadError from pomice.exceptions import TrackLoadError
from .filters import Filter from pomice.filters import Filter
from .filters import Timescale from pomice.filters import Timescale
from .objects import Playlist
from .objects import Track
from .pool import Node
from .pool import NodePool
from pomice.models.events import PomiceEvent from pomice.models.events import PomiceEvent
from pomice.models.events import TrackEndEvent from pomice.models.events import TrackEndEvent
from pomice.models.events import TrackStartEvent from pomice.models.events import TrackStartEvent
from pomice.models.version import LavalinkVersion from pomice.models.music import Playlist
from pomice.models.music import Track
from pomice.models.payloads import TrackUpdatePayload
from pomice.models.payloads import VoiceUpdatePayload
from pomice.pool import Node
from pomice.pool import NodePool
if TYPE_CHECKING: if TYPE_CHECKING:
from discord.types.voice import VoiceServerUpdate from discord.types.voice import VoiceServerUpdate
@ -287,12 +288,6 @@ class Player(VoiceProtocol):
""" """
return self.guild.id not in self._node._players return self.guild.id not in self._node._players
def _adjust_end_time(self) -> Optional[str]:
if self._node._version >= LavalinkVersion(3, 7, 5):
return None
return "0"
async def _update_state(self, data: dict) -> None: async def _update_state(self, data: dict) -> None:
state: dict = data.get("state", {}) state: dict = data.get("state", {})
self._last_update = int(state.get("time", 0)) self._last_update = int(state.get("time", 0))
@ -301,23 +296,18 @@ class Player(VoiceProtocol):
if self._log: if self._log:
self._log.debug(f"Got player update state with data {state}") self._log.debug(f"Got player update state with data {state}")
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None: async def _dispatch_voice_update(self, voice_data: Dict[str, Union[str, int]]) -> None:
if {"sessionId", "event"} != self._voice_state.keys(): state = voice_data or self._voice_state
if {"sessionId", "event"} != state.keys():
return return
state = voice_data or self._voice_state data = VoiceUpdatePayload.model_validate(state)
data = {
"token": state["event"]["token"],
"endpoint": state["event"]["endpoint"],
"sessionId": state["sessionId"],
}
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,
guild_id=self._guild.id, guild_id=self._guild.id,
data={"voice": data}, data={"voice": data.model_dump()},
) )
if self._log: if self._log:
@ -327,30 +317,26 @@ class Player(VoiceProtocol):
async def on_voice_server_update(self, data: VoiceServerUpdate) -> None: async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
self._voice_state.update({"event": data}) self._voice_state.update({"event": data})
await self._dispatch_voice_update(self._voice_state) await self._dispatch_voice_update()
async def on_voice_state_update(self, data: GuildVoiceState) -> None: async def on_voice_state_update(self, data: GuildVoiceState) -> None:
self._voice_state.update({"sessionId": data.get("session_id")}) self._voice_state.update({"sessionId": data["session_id"]})
channel_id = data.get("channel_id")
if not channel_id:
await self.disconnect()
self._voice_state.clear()
return
channel_id = data["session_id"]
channel = self.guild.get_channel(int(channel_id)) channel = self.guild.get_channel(int(channel_id))
if self.channel != channel:
self.channel = channel
if not channel: if not channel:
await self.disconnect() await self.disconnect()
self._voice_state.clear() self._voice_state.clear()
return return
if self.channel != channel:
self.channel = channel
if not data.get("token"): if not data.get("token"):
return return
self._voice_state.update({"event": data})
await self._dispatch_voice_update({**self._voice_state, "event": data}) await self._dispatch_voice_update({**self._voice_state, "event": data})
async def _dispatch_event(self, data: dict) -> None: async def _dispatch_event(self, data: dict) -> None:
@ -359,12 +345,11 @@ class Player(VoiceProtocol):
if isinstance(event, TrackEndEvent) and event.reason != "replaced": if isinstance(event, TrackEndEvent) and event.reason != "replaced":
self._current = None self._current = None
event.dispatch(self._bot)
if isinstance(event, TrackStartEvent): if isinstance(event, TrackStartEvent):
self._ending_track = self._current self._ending_track = self._current
event.dispatch(self._bot)
if self._log: if self._log:
self._log.debug(f"Dispatched event {data['type']} to player.") self._log.debug(f"Dispatched event {data['type']} to player.")
@ -373,7 +358,10 @@ class Player(VoiceProtocol):
async def _swap_node(self, *, new_node: Node) -> None: async def _swap_node(self, *, new_node: Node) -> None:
if self.current: if self.current:
data: dict = {"position": self.position, "encodedTrack": self.current.track_id} data: dict = TrackUpdatePayload(
encoded_track=self.current.track_id,
position=self.position,
).model_dump()
del self._node._players[self._guild.id] del self._node._players[self._guild.id]
self._node = new_node self._node = new_node
@ -629,6 +617,9 @@ class Player(VoiceProtocol):
async def set_volume(self, volume: int) -> int: async def set_volume(self, volume: int) -> int:
"""Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500.""" """Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500."""
if volume < 0 or volume > 500:
raise ValueError("Volume must be between 0 and 500")
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,

View File

@ -8,11 +8,11 @@ from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
from .enums import LoopMode from pomice.enums import LoopMode
from .exceptions import QueueEmpty from pomice.exceptions import QueueEmpty
from .exceptions import QueueException from pomice.exceptions import QueueException
from .exceptions import QueueFull from pomice.exceptions import QueueFull
from .objects import Track from pomice.models.music import Track
__all__ = ("Queue",) __all__ = ("Queue",)

View File

@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from .pool import Node from .pool import Node
from .utils import RouteStats from pomice.utils import RouteStats
__all__ = ("RoutePlanner",) __all__ = ("RoutePlanner",)

View File

@ -10,8 +10,8 @@ from typing import Dict
from typing import Iterable from typing import Iterable
from typing import Optional from typing import Optional
from .enums import RouteIPType from pomice.enums import RouteIPType
from .enums import RouteStrategy from pomice.enums import RouteStrategy
__all__ = ( __all__ = (
"ExponentialBackoff", "ExponentialBackoff",