feat: add more payloads
This commit is contained in:
parent
bf92e37da6
commit
bc19e07008
|
|
@ -9,3 +9,15 @@ from .version import *
|
|||
|
||||
class BaseModel(pydantic.BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
|
||||
|
||||
def model_dump(self, *args, **kwargs) -> dict:
|
||||
by_alias = kwargs.pop("by_alias", True)
|
||||
mode = kwargs.pop("mode", "json")
|
||||
return super().model_dump(*args, **kwargs, by_alias=by_alias, mode=mode)
|
||||
|
||||
|
||||
class VersionedModel(BaseModel):
|
||||
version: LavalinkVersionType
|
||||
|
||||
def model_dump(self, *args, **kwargs) -> dict:
|
||||
return super().model_dump(*args, **kwargs, exclude={"version"})
|
||||
|
|
|
|||
|
|
@ -1,39 +1,66 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from pydantic import AliasPath
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from pomice.models import BaseModel
|
||||
from pomice.models import LavalinkVersion
|
||||
from pomice.models import LavalinkVersion3Type
|
||||
from pomice.models import LavalinkVersion4Type
|
||||
from pomice.models import VersionedModel
|
||||
|
||||
__all__ = (
|
||||
"ResumePayload",
|
||||
"ResumePayloadV3",
|
||||
"ResumePayloadV4",
|
||||
"VoiceUpdatePayload",
|
||||
"TrackStartPayload",
|
||||
"TrackUpdatePayload",
|
||||
"ResumePayloadType",
|
||||
"ResumePayloadTypeAdapter",
|
||||
)
|
||||
|
||||
|
||||
class ResumePayload(BaseModel):
|
||||
version: LavalinkVersion
|
||||
class VoiceUpdatePayload(BaseModel):
|
||||
token: str = Field(validation_alias=AliasPath("event", "token"))
|
||||
endpoint: str = Field(validation_alias=AliasPath("event", "endpoint"))
|
||||
session_id: str = Field(alias="sessionId")
|
||||
|
||||
|
||||
class TrackUpdatePayload(BaseModel):
|
||||
encoded_track: Optional[str] = Field(default=None, alias="encodedTrack")
|
||||
position: float
|
||||
|
||||
|
||||
class TrackStartPayload(VersionedModel):
|
||||
encoded_track: Optional[str] = Field(default=None, alias="encodedTrack")
|
||||
position: float
|
||||
end_time: str = Field(default="0", alias="endTime")
|
||||
|
||||
@field_validator("end_time", mode="before")
|
||||
@classmethod
|
||||
def cast_end_time(cls, value: object) -> str:
|
||||
return str(value)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def adjust_end_time(self) -> TrackStartPayload:
|
||||
if self.version >= LavalinkVersion3Type(3, 7, 5):
|
||||
self.end_time = None
|
||||
|
||||
|
||||
class ResumePayload(VersionedModel):
|
||||
timeout: int
|
||||
|
||||
def model_dump(self) -> dict:
|
||||
return super().model_dump(by_alias=True, exclude={"version"})
|
||||
|
||||
|
||||
class ResumePayloadV3(BaseModel):
|
||||
class ResumePayloadV3(ResumePayload):
|
||||
version: LavalinkVersion3Type
|
||||
timeout: int
|
||||
resuming_key: str = Field(alias="resumingKey")
|
||||
|
||||
|
||||
class ResumePayloadV4(BaseModel):
|
||||
class ResumePayloadV4(ResumePayload):
|
||||
version: LavalinkVersion4Type
|
||||
timeout: int
|
||||
resuming: bool = True
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,167 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from discord import ClientUser
|
||||
from discord import Member
|
||||
from discord import User
|
||||
from discord.ext import commands
|
||||
|
||||
from .enums import PlaylistType
|
||||
from .enums import SearchType
|
||||
from .enums import TrackType
|
||||
from .filters import Filter
|
||||
|
||||
__all__ = (
|
||||
"Track",
|
||||
"Playlist",
|
||||
)
|
||||
|
||||
|
||||
class Track:
|
||||
"""The base track object. Returns critical track information needed for parsing by Lavalink.
|
||||
You can also pass in commands.Context to get a discord.py Context object in your track.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"track_id",
|
||||
"info",
|
||||
"track_type",
|
||||
"filters",
|
||||
"timestamp",
|
||||
"original",
|
||||
"_search_type",
|
||||
"playlist",
|
||||
"title",
|
||||
"author",
|
||||
"uri",
|
||||
"identifier",
|
||||
"isrc",
|
||||
"thumbnail",
|
||||
"length",
|
||||
"ctx",
|
||||
"requester",
|
||||
"is_stream",
|
||||
"is_seekable",
|
||||
"position",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
track_id: str,
|
||||
info: dict,
|
||||
ctx: Optional[commands.Context] = None,
|
||||
track_type: TrackType,
|
||||
search_type: SearchType = SearchType.YTSEARCH,
|
||||
filters: Optional[List[Filter]] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
requester: Optional[Union[Member, User, ClientUser]] = None,
|
||||
):
|
||||
self.track_id: str = track_id
|
||||
self.info: dict = info
|
||||
self.track_type: TrackType = track_type
|
||||
self.filters: Optional[List[Filter]] = filters
|
||||
self.timestamp: Optional[float] = timestamp
|
||||
|
||||
if self.track_type == TrackType.SPOTIFY or self.track_type == TrackType.APPLE_MUSIC:
|
||||
self.original: Optional[Track] = None
|
||||
else:
|
||||
self.original = self
|
||||
self._search_type: SearchType = search_type
|
||||
|
||||
self.playlist: Optional[Playlist] = None
|
||||
|
||||
self.title: str = info.get("title", "Unknown Title")
|
||||
self.author: str = info.get("author", "Unknown Author")
|
||||
self.uri: str = info.get("uri", "")
|
||||
self.identifier: str = info.get("identifier", "")
|
||||
self.isrc: Optional[str] = info.get("isrc", None)
|
||||
self.thumbnail: Optional[str] = info.get("thumbnail")
|
||||
|
||||
if self.uri and self.track_type is TrackType.YOUTUBE:
|
||||
self.thumbnail = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg"
|
||||
|
||||
self.length: int = info.get("length", 0)
|
||||
self.is_stream: bool = info.get("isStream", False)
|
||||
self.is_seekable: bool = info.get("isSeekable", False)
|
||||
self.position: int = info.get("position", 0)
|
||||
|
||||
self.ctx: Optional[commands.Context] = ctx
|
||||
self.requester: Optional[Union[Member, User, ClientUser]] = requester
|
||||
if not self.requester and self.ctx:
|
||||
self.requester = self.ctx.author
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Track):
|
||||
return False
|
||||
|
||||
return other.track_id == self.track_id
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.title
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.track title={self.title!r} uri=<{self.uri!r}> length={self.length}>"
|
||||
|
||||
|
||||
class Playlist:
|
||||
"""The base playlist object.
|
||||
Returns critical playlist information needed for parsing by Lavalink.
|
||||
You can also pass in commands.Context to get a discord.py Context object in your tracks.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"playlist_info",
|
||||
"tracks",
|
||||
"name",
|
||||
"playlist_type",
|
||||
"_thumbnail",
|
||||
"_uri",
|
||||
"selected_track",
|
||||
"track_count",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
playlist_info: dict,
|
||||
tracks: list,
|
||||
playlist_type: PlaylistType,
|
||||
thumbnail: Optional[str] = None,
|
||||
uri: Optional[str] = None,
|
||||
):
|
||||
self.playlist_info: dict = playlist_info
|
||||
self.tracks: List[Track] = tracks
|
||||
self.name: str = playlist_info.get("name", "Unknown Playlist")
|
||||
self.playlist_type: PlaylistType = playlist_type
|
||||
|
||||
self._thumbnail: Optional[str] = thumbnail
|
||||
self._uri: Optional[str] = uri
|
||||
|
||||
for track in self.tracks:
|
||||
track.playlist = self
|
||||
|
||||
self.selected_track: Optional[Track] = None
|
||||
if (index := playlist_info.get("selectedTrack", -1)) != -1:
|
||||
self.selected_track = self.tracks[index]
|
||||
|
||||
self.track_count: int = len(self.tracks)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.playlist name={self.name!r} track_count={len(self.tracks)}>"
|
||||
|
||||
@property
|
||||
def uri(self) -> Optional[str]:
|
||||
"""Returns either an Apple Music/Spotify URL/URI, or None if its neither of those."""
|
||||
return self._uri
|
||||
|
||||
@property
|
||||
def thumbnail(self) -> Optional[str]:
|
||||
"""Returns either an Apple Music/Spotify album/playlist thumbnail, or None if its neither of those."""
|
||||
return self._thumbnail
|
||||
|
|
@ -14,23 +14,24 @@ from discord import VoiceChannel
|
|||
from discord import VoiceProtocol
|
||||
from discord.ext import commands
|
||||
|
||||
from . import events
|
||||
from .enums import SearchType
|
||||
from .exceptions import FilterInvalidArgument
|
||||
from .exceptions import FilterTagAlreadyInUse
|
||||
from .exceptions import FilterTagInvalid
|
||||
from .exceptions import TrackInvalidPosition
|
||||
from .exceptions import TrackLoadError
|
||||
from .filters import Filter
|
||||
from .filters import Timescale
|
||||
from .objects import Playlist
|
||||
from .objects import Track
|
||||
from .pool import Node
|
||||
from .pool import NodePool
|
||||
from pomice import events
|
||||
from pomice.enums import SearchType
|
||||
from pomice.exceptions import FilterInvalidArgument
|
||||
from pomice.exceptions import FilterTagAlreadyInUse
|
||||
from pomice.exceptions import FilterTagInvalid
|
||||
from pomice.exceptions import TrackInvalidPosition
|
||||
from pomice.exceptions import TrackLoadError
|
||||
from pomice.filters import Filter
|
||||
from pomice.filters import Timescale
|
||||
from pomice.models.events import PomiceEvent
|
||||
from pomice.models.events import TrackEndEvent
|
||||
from pomice.models.events import TrackStartEvent
|
||||
from pomice.models.version import LavalinkVersion
|
||||
from pomice.models.music import Playlist
|
||||
from pomice.models.music import Track
|
||||
from pomice.models.payloads import TrackUpdatePayload
|
||||
from pomice.models.payloads import VoiceUpdatePayload
|
||||
from pomice.pool import Node
|
||||
from pomice.pool import NodePool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from discord.types.voice import VoiceServerUpdate
|
||||
|
|
@ -287,12 +288,6 @@ class Player(VoiceProtocol):
|
|||
"""
|
||||
return self.guild.id not in self._node._players
|
||||
|
||||
def _adjust_end_time(self) -> Optional[str]:
|
||||
if self._node._version >= LavalinkVersion(3, 7, 5):
|
||||
return None
|
||||
|
||||
return "0"
|
||||
|
||||
async def _update_state(self, data: dict) -> None:
|
||||
state: dict = data.get("state", {})
|
||||
self._last_update = int(state.get("time", 0))
|
||||
|
|
@ -301,23 +296,18 @@ class Player(VoiceProtocol):
|
|||
if self._log:
|
||||
self._log.debug(f"Got player update state with data {state}")
|
||||
|
||||
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None:
|
||||
if {"sessionId", "event"} != self._voice_state.keys():
|
||||
async def _dispatch_voice_update(self, voice_data: Dict[str, Union[str, int]]) -> None:
|
||||
state = voice_data or self._voice_state
|
||||
if {"sessionId", "event"} != state.keys():
|
||||
return
|
||||
|
||||
state = voice_data or self._voice_state
|
||||
|
||||
data = {
|
||||
"token": state["event"]["token"],
|
||||
"endpoint": state["event"]["endpoint"],
|
||||
"sessionId": state["sessionId"],
|
||||
}
|
||||
data = VoiceUpdatePayload.model_validate(state)
|
||||
|
||||
await self._node.send(
|
||||
method="PATCH",
|
||||
path=self._player_endpoint_uri,
|
||||
guild_id=self._guild.id,
|
||||
data={"voice": data},
|
||||
data={"voice": data.model_dump()},
|
||||
)
|
||||
|
||||
if self._log:
|
||||
|
|
@ -327,30 +317,26 @@ class Player(VoiceProtocol):
|
|||
|
||||
async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
|
||||
self._voice_state.update({"event": data})
|
||||
await self._dispatch_voice_update(self._voice_state)
|
||||
await self._dispatch_voice_update()
|
||||
|
||||
async def on_voice_state_update(self, data: GuildVoiceState) -> None:
|
||||
self._voice_state.update({"sessionId": data.get("session_id")})
|
||||
|
||||
channel_id = data.get("channel_id")
|
||||
if not channel_id:
|
||||
await self.disconnect()
|
||||
self._voice_state.clear()
|
||||
return
|
||||
self._voice_state.update({"sessionId": data["session_id"]})
|
||||
|
||||
channel_id = data["session_id"]
|
||||
channel = self.guild.get_channel(int(channel_id))
|
||||
|
||||
if self.channel != channel:
|
||||
self.channel = channel
|
||||
|
||||
if not channel:
|
||||
await self.disconnect()
|
||||
self._voice_state.clear()
|
||||
return
|
||||
|
||||
if self.channel != channel:
|
||||
self.channel = channel
|
||||
|
||||
if not data.get("token"):
|
||||
return
|
||||
|
||||
self._voice_state.update({"event": data})
|
||||
await self._dispatch_voice_update({**self._voice_state, "event": data})
|
||||
|
||||
async def _dispatch_event(self, data: dict) -> None:
|
||||
|
|
@ -359,12 +345,11 @@ class Player(VoiceProtocol):
|
|||
|
||||
if isinstance(event, TrackEndEvent) and event.reason != "replaced":
|
||||
self._current = None
|
||||
|
||||
event.dispatch(self._bot)
|
||||
|
||||
if isinstance(event, TrackStartEvent):
|
||||
self._ending_track = self._current
|
||||
|
||||
event.dispatch(self._bot)
|
||||
|
||||
if self._log:
|
||||
self._log.debug(f"Dispatched event {data['type']} to player.")
|
||||
|
||||
|
|
@ -373,7 +358,10 @@ class Player(VoiceProtocol):
|
|||
|
||||
async def _swap_node(self, *, new_node: Node) -> None:
|
||||
if self.current:
|
||||
data: dict = {"position": self.position, "encodedTrack": self.current.track_id}
|
||||
data: dict = TrackUpdatePayload(
|
||||
encoded_track=self.current.track_id,
|
||||
position=self.position,
|
||||
).model_dump()
|
||||
|
||||
del self._node._players[self._guild.id]
|
||||
self._node = new_node
|
||||
|
|
@ -629,6 +617,9 @@ class Player(VoiceProtocol):
|
|||
|
||||
async def set_volume(self, volume: int) -> int:
|
||||
"""Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500."""
|
||||
if volume < 0 or volume > 500:
|
||||
raise ValueError("Volume must be between 0 and 500")
|
||||
|
||||
await self._node.send(
|
||||
method="PATCH",
|
||||
path=self._player_endpoint_uri,
|
||||
|
|
|
|||
|
|
@ -8,11 +8,11 @@ from typing import List
|
|||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from .enums import LoopMode
|
||||
from .exceptions import QueueEmpty
|
||||
from .exceptions import QueueException
|
||||
from .exceptions import QueueFull
|
||||
from .objects import Track
|
||||
from pomice.enums import LoopMode
|
||||
from pomice.exceptions import QueueEmpty
|
||||
from pomice.exceptions import QueueException
|
||||
from pomice.exceptions import QueueFull
|
||||
from pomice.models.music import Track
|
||||
|
||||
__all__ = ("Queue",)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
|
|||
if TYPE_CHECKING:
|
||||
from .pool import Node
|
||||
|
||||
from .utils import RouteStats
|
||||
from pomice.utils import RouteStats
|
||||
|
||||
__all__ = ("RoutePlanner",)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ from typing import Dict
|
|||
from typing import Iterable
|
||||
from typing import Optional
|
||||
|
||||
from .enums import RouteIPType
|
||||
from .enums import RouteStrategy
|
||||
from pomice.enums import RouteIPType
|
||||
from pomice.enums import RouteStrategy
|
||||
|
||||
__all__ = (
|
||||
"ExponentialBackoff",
|
||||
|
|
|
|||
Loading…
Reference in New Issue