feat: add more payloads

This commit is contained in:
NiceAesth 2024-02-24 13:53:05 +02:00
parent bf92e37da6
commit bc19e07008
7 changed files with 96 additions and 233 deletions

View File

@ -9,3 +9,15 @@ from .version import *
class BaseModel(pydantic.BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
def model_dump(self, *args, **kwargs) -> dict:
by_alias = kwargs.pop("by_alias", True)
mode = kwargs.pop("mode", "json")
return super().model_dump(*args, **kwargs, by_alias=by_alias, mode=mode)
class VersionedModel(BaseModel):
version: LavalinkVersionType
def model_dump(self, *args, **kwargs) -> dict:
return super().model_dump(*args, **kwargs, exclude={"version"})

View File

@ -1,39 +1,66 @@
from __future__ import annotations
from typing import Optional
from typing import Union
from pydantic import AliasPath
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from pydantic import TypeAdapter
from pomice.models import BaseModel
from pomice.models import LavalinkVersion
from pomice.models import LavalinkVersion3Type
from pomice.models import LavalinkVersion4Type
from pomice.models import VersionedModel
__all__ = (
"ResumePayload",
"ResumePayloadV3",
"ResumePayloadV4",
"VoiceUpdatePayload",
"TrackStartPayload",
"TrackUpdatePayload",
"ResumePayloadType",
"ResumePayloadTypeAdapter",
)
class ResumePayload(BaseModel):
version: LavalinkVersion
class VoiceUpdatePayload(BaseModel):
token: str = Field(validation_alias=AliasPath("event", "token"))
endpoint: str = Field(validation_alias=AliasPath("event", "endpoint"))
session_id: str = Field(alias="sessionId")
class TrackUpdatePayload(BaseModel):
encoded_track: Optional[str] = Field(default=None, alias="encodedTrack")
position: float
class TrackStartPayload(VersionedModel):
encoded_track: Optional[str] = Field(default=None, alias="encodedTrack")
position: float
end_time: str = Field(default="0", alias="endTime")
@field_validator("end_time", mode="before")
@classmethod
def cast_end_time(cls, value: object) -> str:
return str(value)
@model_validator(mode="after")
def adjust_end_time(self) -> TrackStartPayload:
if self.version >= LavalinkVersion3Type(3, 7, 5):
self.end_time = None
class ResumePayload(VersionedModel):
timeout: int
def model_dump(self) -> dict:
return super().model_dump(by_alias=True, exclude={"version"})
class ResumePayloadV3(BaseModel):
class ResumePayloadV3(ResumePayload):
version: LavalinkVersion3Type
timeout: int
resuming_key: str = Field(alias="resumingKey")
class ResumePayloadV4(BaseModel):
class ResumePayloadV4(ResumePayload):
version: LavalinkVersion4Type
timeout: int
resuming: bool = True

View File

@ -1,167 +0,0 @@
from __future__ import annotations
from typing import List
from typing import Optional
from typing import Union
from discord import ClientUser
from discord import Member
from discord import User
from discord.ext import commands
from .enums import PlaylistType
from .enums import SearchType
from .enums import TrackType
from .filters import Filter
__all__ = (
"Track",
"Playlist",
)
class Track:
"""The base track object. Returns critical track information needed for parsing by Lavalink.
You can also pass in commands.Context to get a discord.py Context object in your track.
"""
__slots__ = (
"track_id",
"info",
"track_type",
"filters",
"timestamp",
"original",
"_search_type",
"playlist",
"title",
"author",
"uri",
"identifier",
"isrc",
"thumbnail",
"length",
"ctx",
"requester",
"is_stream",
"is_seekable",
"position",
)
def __init__(
self,
*,
track_id: str,
info: dict,
ctx: Optional[commands.Context] = None,
track_type: TrackType,
search_type: SearchType = SearchType.YTSEARCH,
filters: Optional[List[Filter]] = None,
timestamp: Optional[float] = None,
requester: Optional[Union[Member, User, ClientUser]] = None,
):
self.track_id: str = track_id
self.info: dict = info
self.track_type: TrackType = track_type
self.filters: Optional[List[Filter]] = filters
self.timestamp: Optional[float] = timestamp
if self.track_type == TrackType.SPOTIFY or self.track_type == TrackType.APPLE_MUSIC:
self.original: Optional[Track] = None
else:
self.original = self
self._search_type: SearchType = search_type
self.playlist: Optional[Playlist] = None
self.title: str = info.get("title", "Unknown Title")
self.author: str = info.get("author", "Unknown Author")
self.uri: str = info.get("uri", "")
self.identifier: str = info.get("identifier", "")
self.isrc: Optional[str] = info.get("isrc", None)
self.thumbnail: Optional[str] = info.get("thumbnail")
if self.uri and self.track_type is TrackType.YOUTUBE:
self.thumbnail = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg"
self.length: int = info.get("length", 0)
self.is_stream: bool = info.get("isStream", False)
self.is_seekable: bool = info.get("isSeekable", False)
self.position: int = info.get("position", 0)
self.ctx: Optional[commands.Context] = ctx
self.requester: Optional[Union[Member, User, ClientUser]] = requester
if not self.requester and self.ctx:
self.requester = self.ctx.author
def __eq__(self, other: object) -> bool:
if not isinstance(other, Track):
return False
return other.track_id == self.track_id
def __str__(self) -> str:
return self.title
def __repr__(self) -> str:
return f"<Pomice.track title={self.title!r} uri=<{self.uri!r}> length={self.length}>"
class Playlist:
"""The base playlist object.
Returns critical playlist information needed for parsing by Lavalink.
You can also pass in commands.Context to get a discord.py Context object in your tracks.
"""
__slots__ = (
"playlist_info",
"tracks",
"name",
"playlist_type",
"_thumbnail",
"_uri",
"selected_track",
"track_count",
)
def __init__(
self,
*,
playlist_info: dict,
tracks: list,
playlist_type: PlaylistType,
thumbnail: Optional[str] = None,
uri: Optional[str] = None,
):
self.playlist_info: dict = playlist_info
self.tracks: List[Track] = tracks
self.name: str = playlist_info.get("name", "Unknown Playlist")
self.playlist_type: PlaylistType = playlist_type
self._thumbnail: Optional[str] = thumbnail
self._uri: Optional[str] = uri
for track in self.tracks:
track.playlist = self
self.selected_track: Optional[Track] = None
if (index := playlist_info.get("selectedTrack", -1)) != -1:
self.selected_track = self.tracks[index]
self.track_count: int = len(self.tracks)
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
return f"<Pomice.playlist name={self.name!r} track_count={len(self.tracks)}>"
@property
def uri(self) -> Optional[str]:
"""Returns either an Apple Music/Spotify URL/URI, or None if its neither of those."""
return self._uri
@property
def thumbnail(self) -> Optional[str]:
"""Returns either an Apple Music/Spotify album/playlist thumbnail, or None if its neither of those."""
return self._thumbnail

View File

@ -14,23 +14,24 @@ from discord import VoiceChannel
from discord import VoiceProtocol
from discord.ext import commands
from . import events
from .enums import SearchType
from .exceptions import FilterInvalidArgument
from .exceptions import FilterTagAlreadyInUse
from .exceptions import FilterTagInvalid
from .exceptions import TrackInvalidPosition
from .exceptions import TrackLoadError
from .filters import Filter
from .filters import Timescale
from .objects import Playlist
from .objects import Track
from .pool import Node
from .pool import NodePool
from pomice import events
from pomice.enums import SearchType
from pomice.exceptions import FilterInvalidArgument
from pomice.exceptions import FilterTagAlreadyInUse
from pomice.exceptions import FilterTagInvalid
from pomice.exceptions import TrackInvalidPosition
from pomice.exceptions import TrackLoadError
from pomice.filters import Filter
from pomice.filters import Timescale
from pomice.models.events import PomiceEvent
from pomice.models.events import TrackEndEvent
from pomice.models.events import TrackStartEvent
from pomice.models.version import LavalinkVersion
from pomice.models.music import Playlist
from pomice.models.music import Track
from pomice.models.payloads import TrackUpdatePayload
from pomice.models.payloads import VoiceUpdatePayload
from pomice.pool import Node
from pomice.pool import NodePool
if TYPE_CHECKING:
from discord.types.voice import VoiceServerUpdate
@ -287,12 +288,6 @@ class Player(VoiceProtocol):
"""
return self.guild.id not in self._node._players
def _adjust_end_time(self) -> Optional[str]:
if self._node._version >= LavalinkVersion(3, 7, 5):
return None
return "0"
async def _update_state(self, data: dict) -> None:
state: dict = data.get("state", {})
self._last_update = int(state.get("time", 0))
@ -301,23 +296,18 @@ class Player(VoiceProtocol):
if self._log:
self._log.debug(f"Got player update state with data {state}")
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None:
if {"sessionId", "event"} != self._voice_state.keys():
async def _dispatch_voice_update(self, voice_data: Dict[str, Union[str, int]]) -> None:
state = voice_data or self._voice_state
if {"sessionId", "event"} != state.keys():
return
state = voice_data or self._voice_state
data = {
"token": state["event"]["token"],
"endpoint": state["event"]["endpoint"],
"sessionId": state["sessionId"],
}
data = VoiceUpdatePayload.model_validate(state)
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"voice": data},
data={"voice": data.model_dump()},
)
if self._log:
@ -327,30 +317,26 @@ class Player(VoiceProtocol):
async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
self._voice_state.update({"event": data})
await self._dispatch_voice_update(self._voice_state)
await self._dispatch_voice_update()
async def on_voice_state_update(self, data: GuildVoiceState) -> None:
self._voice_state.update({"sessionId": data.get("session_id")})
channel_id = data.get("channel_id")
if not channel_id:
await self.disconnect()
self._voice_state.clear()
return
self._voice_state.update({"sessionId": data["session_id"]})
channel_id = data["session_id"]
channel = self.guild.get_channel(int(channel_id))
if self.channel != channel:
self.channel = channel
if not channel:
await self.disconnect()
self._voice_state.clear()
return
if self.channel != channel:
self.channel = channel
if not data.get("token"):
return
self._voice_state.update({"event": data})
await self._dispatch_voice_update({**self._voice_state, "event": data})
async def _dispatch_event(self, data: dict) -> None:
@ -359,12 +345,11 @@ class Player(VoiceProtocol):
if isinstance(event, TrackEndEvent) and event.reason != "replaced":
self._current = None
event.dispatch(self._bot)
if isinstance(event, TrackStartEvent):
self._ending_track = self._current
event.dispatch(self._bot)
if self._log:
self._log.debug(f"Dispatched event {data['type']} to player.")
@ -373,7 +358,10 @@ class Player(VoiceProtocol):
async def _swap_node(self, *, new_node: Node) -> None:
if self.current:
data: dict = {"position": self.position, "encodedTrack": self.current.track_id}
data: dict = TrackUpdatePayload(
encoded_track=self.current.track_id,
position=self.position,
).model_dump()
del self._node._players[self._guild.id]
self._node = new_node
@ -629,6 +617,9 @@ class Player(VoiceProtocol):
async def set_volume(self, volume: int) -> int:
"""Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500."""
if volume < 0 or volume > 500:
raise ValueError("Volume must be between 0 and 500")
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,

View File

@ -8,11 +8,11 @@ from typing import List
from typing import Optional
from typing import Union
from .enums import LoopMode
from .exceptions import QueueEmpty
from .exceptions import QueueException
from .exceptions import QueueFull
from .objects import Track
from pomice.enums import LoopMode
from pomice.exceptions import QueueEmpty
from pomice.exceptions import QueueException
from pomice.exceptions import QueueFull
from pomice.models.music import Track
__all__ = ("Queue",)

View File

@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .pool import Node
from .utils import RouteStats
from pomice.utils import RouteStats
__all__ = ("RoutePlanner",)

View File

@ -10,8 +10,8 @@ from typing import Dict
from typing import Iterable
from typing import Optional
from .enums import RouteIPType
from .enums import RouteStrategy
from pomice.enums import RouteIPType
from pomice.enums import RouteStrategy
__all__ = (
"ExponentialBackoff",