feat: add more payloads
This commit is contained in:
parent
bf92e37da6
commit
bc19e07008
|
|
@ -9,3 +9,15 @@ from .version import *
|
||||||
|
|
||||||
class BaseModel(pydantic.BaseModel):
|
class BaseModel(pydantic.BaseModel):
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
|
||||||
|
|
||||||
|
def model_dump(self, *args, **kwargs) -> dict:
|
||||||
|
by_alias = kwargs.pop("by_alias", True)
|
||||||
|
mode = kwargs.pop("mode", "json")
|
||||||
|
return super().model_dump(*args, **kwargs, by_alias=by_alias, mode=mode)
|
||||||
|
|
||||||
|
|
||||||
|
class VersionedModel(BaseModel):
|
||||||
|
version: LavalinkVersionType
|
||||||
|
|
||||||
|
def model_dump(self, *args, **kwargs) -> dict:
|
||||||
|
return super().model_dump(*args, **kwargs, exclude={"version"})
|
||||||
|
|
|
||||||
|
|
@ -1,39 +1,66 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
from pydantic import AliasPath
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from pydantic import field_validator
|
||||||
|
from pydantic import model_validator
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
from pomice.models import BaseModel
|
from pomice.models import BaseModel
|
||||||
from pomice.models import LavalinkVersion
|
|
||||||
from pomice.models import LavalinkVersion3Type
|
from pomice.models import LavalinkVersion3Type
|
||||||
from pomice.models import LavalinkVersion4Type
|
from pomice.models import LavalinkVersion4Type
|
||||||
|
from pomice.models import VersionedModel
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"ResumePayload",
|
"VoiceUpdatePayload",
|
||||||
"ResumePayloadV3",
|
"TrackStartPayload",
|
||||||
"ResumePayloadV4",
|
"TrackUpdatePayload",
|
||||||
"ResumePayloadType",
|
"ResumePayloadType",
|
||||||
"ResumePayloadTypeAdapter",
|
"ResumePayloadTypeAdapter",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ResumePayload(BaseModel):
|
class VoiceUpdatePayload(BaseModel):
|
||||||
version: LavalinkVersion
|
token: str = Field(validation_alias=AliasPath("event", "token"))
|
||||||
|
endpoint: str = Field(validation_alias=AliasPath("event", "endpoint"))
|
||||||
|
session_id: str = Field(alias="sessionId")
|
||||||
|
|
||||||
|
|
||||||
|
class TrackUpdatePayload(BaseModel):
|
||||||
|
encoded_track: Optional[str] = Field(default=None, alias="encodedTrack")
|
||||||
|
position: float
|
||||||
|
|
||||||
|
|
||||||
|
class TrackStartPayload(VersionedModel):
|
||||||
|
encoded_track: Optional[str] = Field(default=None, alias="encodedTrack")
|
||||||
|
position: float
|
||||||
|
end_time: str = Field(default="0", alias="endTime")
|
||||||
|
|
||||||
|
@field_validator("end_time", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def cast_end_time(cls, value: object) -> str:
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def adjust_end_time(self) -> TrackStartPayload:
|
||||||
|
if self.version >= LavalinkVersion3Type(3, 7, 5):
|
||||||
|
self.end_time = None
|
||||||
|
|
||||||
|
|
||||||
|
class ResumePayload(VersionedModel):
|
||||||
timeout: int
|
timeout: int
|
||||||
|
|
||||||
def model_dump(self) -> dict:
|
|
||||||
return super().model_dump(by_alias=True, exclude={"version"})
|
|
||||||
|
|
||||||
|
class ResumePayloadV3(ResumePayload):
|
||||||
class ResumePayloadV3(BaseModel):
|
|
||||||
version: LavalinkVersion3Type
|
version: LavalinkVersion3Type
|
||||||
timeout: int
|
|
||||||
resuming_key: str = Field(alias="resumingKey")
|
resuming_key: str = Field(alias="resumingKey")
|
||||||
|
|
||||||
|
|
||||||
class ResumePayloadV4(BaseModel):
|
class ResumePayloadV4(ResumePayload):
|
||||||
version: LavalinkVersion4Type
|
version: LavalinkVersion4Type
|
||||||
timeout: int
|
|
||||||
resuming: bool = True
|
resuming: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,167 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import List
|
|
||||||
from typing import Optional
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from discord import ClientUser
|
|
||||||
from discord import Member
|
|
||||||
from discord import User
|
|
||||||
from discord.ext import commands
|
|
||||||
|
|
||||||
from .enums import PlaylistType
|
|
||||||
from .enums import SearchType
|
|
||||||
from .enums import TrackType
|
|
||||||
from .filters import Filter
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"Track",
|
|
||||||
"Playlist",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Track:
|
|
||||||
"""The base track object. Returns critical track information needed for parsing by Lavalink.
|
|
||||||
You can also pass in commands.Context to get a discord.py Context object in your track.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = (
|
|
||||||
"track_id",
|
|
||||||
"info",
|
|
||||||
"track_type",
|
|
||||||
"filters",
|
|
||||||
"timestamp",
|
|
||||||
"original",
|
|
||||||
"_search_type",
|
|
||||||
"playlist",
|
|
||||||
"title",
|
|
||||||
"author",
|
|
||||||
"uri",
|
|
||||||
"identifier",
|
|
||||||
"isrc",
|
|
||||||
"thumbnail",
|
|
||||||
"length",
|
|
||||||
"ctx",
|
|
||||||
"requester",
|
|
||||||
"is_stream",
|
|
||||||
"is_seekable",
|
|
||||||
"position",
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
track_id: str,
|
|
||||||
info: dict,
|
|
||||||
ctx: Optional[commands.Context] = None,
|
|
||||||
track_type: TrackType,
|
|
||||||
search_type: SearchType = SearchType.YTSEARCH,
|
|
||||||
filters: Optional[List[Filter]] = None,
|
|
||||||
timestamp: Optional[float] = None,
|
|
||||||
requester: Optional[Union[Member, User, ClientUser]] = None,
|
|
||||||
):
|
|
||||||
self.track_id: str = track_id
|
|
||||||
self.info: dict = info
|
|
||||||
self.track_type: TrackType = track_type
|
|
||||||
self.filters: Optional[List[Filter]] = filters
|
|
||||||
self.timestamp: Optional[float] = timestamp
|
|
||||||
|
|
||||||
if self.track_type == TrackType.SPOTIFY or self.track_type == TrackType.APPLE_MUSIC:
|
|
||||||
self.original: Optional[Track] = None
|
|
||||||
else:
|
|
||||||
self.original = self
|
|
||||||
self._search_type: SearchType = search_type
|
|
||||||
|
|
||||||
self.playlist: Optional[Playlist] = None
|
|
||||||
|
|
||||||
self.title: str = info.get("title", "Unknown Title")
|
|
||||||
self.author: str = info.get("author", "Unknown Author")
|
|
||||||
self.uri: str = info.get("uri", "")
|
|
||||||
self.identifier: str = info.get("identifier", "")
|
|
||||||
self.isrc: Optional[str] = info.get("isrc", None)
|
|
||||||
self.thumbnail: Optional[str] = info.get("thumbnail")
|
|
||||||
|
|
||||||
if self.uri and self.track_type is TrackType.YOUTUBE:
|
|
||||||
self.thumbnail = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg"
|
|
||||||
|
|
||||||
self.length: int = info.get("length", 0)
|
|
||||||
self.is_stream: bool = info.get("isStream", False)
|
|
||||||
self.is_seekable: bool = info.get("isSeekable", False)
|
|
||||||
self.position: int = info.get("position", 0)
|
|
||||||
|
|
||||||
self.ctx: Optional[commands.Context] = ctx
|
|
||||||
self.requester: Optional[Union[Member, User, ClientUser]] = requester
|
|
||||||
if not self.requester and self.ctx:
|
|
||||||
self.requester = self.ctx.author
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, Track):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return other.track_id == self.track_id
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return self.title
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"<Pomice.track title={self.title!r} uri=<{self.uri!r}> length={self.length}>"
|
|
||||||
|
|
||||||
|
|
||||||
class Playlist:
|
|
||||||
"""The base playlist object.
|
|
||||||
Returns critical playlist information needed for parsing by Lavalink.
|
|
||||||
You can also pass in commands.Context to get a discord.py Context object in your tracks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = (
|
|
||||||
"playlist_info",
|
|
||||||
"tracks",
|
|
||||||
"name",
|
|
||||||
"playlist_type",
|
|
||||||
"_thumbnail",
|
|
||||||
"_uri",
|
|
||||||
"selected_track",
|
|
||||||
"track_count",
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
playlist_info: dict,
|
|
||||||
tracks: list,
|
|
||||||
playlist_type: PlaylistType,
|
|
||||||
thumbnail: Optional[str] = None,
|
|
||||||
uri: Optional[str] = None,
|
|
||||||
):
|
|
||||||
self.playlist_info: dict = playlist_info
|
|
||||||
self.tracks: List[Track] = tracks
|
|
||||||
self.name: str = playlist_info.get("name", "Unknown Playlist")
|
|
||||||
self.playlist_type: PlaylistType = playlist_type
|
|
||||||
|
|
||||||
self._thumbnail: Optional[str] = thumbnail
|
|
||||||
self._uri: Optional[str] = uri
|
|
||||||
|
|
||||||
for track in self.tracks:
|
|
||||||
track.playlist = self
|
|
||||||
|
|
||||||
self.selected_track: Optional[Track] = None
|
|
||||||
if (index := playlist_info.get("selectedTrack", -1)) != -1:
|
|
||||||
self.selected_track = self.tracks[index]
|
|
||||||
|
|
||||||
self.track_count: int = len(self.tracks)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return self.name
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"<Pomice.playlist name={self.name!r} track_count={len(self.tracks)}>"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def uri(self) -> Optional[str]:
|
|
||||||
"""Returns either an Apple Music/Spotify URL/URI, or None if its neither of those."""
|
|
||||||
return self._uri
|
|
||||||
|
|
||||||
@property
|
|
||||||
def thumbnail(self) -> Optional[str]:
|
|
||||||
"""Returns either an Apple Music/Spotify album/playlist thumbnail, or None if its neither of those."""
|
|
||||||
return self._thumbnail
|
|
||||||
|
|
@ -14,23 +14,24 @@ from discord import VoiceChannel
|
||||||
from discord import VoiceProtocol
|
from discord import VoiceProtocol
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
from . import events
|
from pomice import events
|
||||||
from .enums import SearchType
|
from pomice.enums import SearchType
|
||||||
from .exceptions import FilterInvalidArgument
|
from pomice.exceptions import FilterInvalidArgument
|
||||||
from .exceptions import FilterTagAlreadyInUse
|
from pomice.exceptions import FilterTagAlreadyInUse
|
||||||
from .exceptions import FilterTagInvalid
|
from pomice.exceptions import FilterTagInvalid
|
||||||
from .exceptions import TrackInvalidPosition
|
from pomice.exceptions import TrackInvalidPosition
|
||||||
from .exceptions import TrackLoadError
|
from pomice.exceptions import TrackLoadError
|
||||||
from .filters import Filter
|
from pomice.filters import Filter
|
||||||
from .filters import Timescale
|
from pomice.filters import Timescale
|
||||||
from .objects import Playlist
|
|
||||||
from .objects import Track
|
|
||||||
from .pool import Node
|
|
||||||
from .pool import NodePool
|
|
||||||
from pomice.models.events import PomiceEvent
|
from pomice.models.events import PomiceEvent
|
||||||
from pomice.models.events import TrackEndEvent
|
from pomice.models.events import TrackEndEvent
|
||||||
from pomice.models.events import TrackStartEvent
|
from pomice.models.events import TrackStartEvent
|
||||||
from pomice.models.version import LavalinkVersion
|
from pomice.models.music import Playlist
|
||||||
|
from pomice.models.music import Track
|
||||||
|
from pomice.models.payloads import TrackUpdatePayload
|
||||||
|
from pomice.models.payloads import VoiceUpdatePayload
|
||||||
|
from pomice.pool import Node
|
||||||
|
from pomice.pool import NodePool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from discord.types.voice import VoiceServerUpdate
|
from discord.types.voice import VoiceServerUpdate
|
||||||
|
|
@ -287,12 +288,6 @@ class Player(VoiceProtocol):
|
||||||
"""
|
"""
|
||||||
return self.guild.id not in self._node._players
|
return self.guild.id not in self._node._players
|
||||||
|
|
||||||
def _adjust_end_time(self) -> Optional[str]:
|
|
||||||
if self._node._version >= LavalinkVersion(3, 7, 5):
|
|
||||||
return None
|
|
||||||
|
|
||||||
return "0"
|
|
||||||
|
|
||||||
async def _update_state(self, data: dict) -> None:
|
async def _update_state(self, data: dict) -> None:
|
||||||
state: dict = data.get("state", {})
|
state: dict = data.get("state", {})
|
||||||
self._last_update = int(state.get("time", 0))
|
self._last_update = int(state.get("time", 0))
|
||||||
|
|
@ -301,23 +296,18 @@ class Player(VoiceProtocol):
|
||||||
if self._log:
|
if self._log:
|
||||||
self._log.debug(f"Got player update state with data {state}")
|
self._log.debug(f"Got player update state with data {state}")
|
||||||
|
|
||||||
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None:
|
async def _dispatch_voice_update(self, voice_data: Dict[str, Union[str, int]]) -> None:
|
||||||
if {"sessionId", "event"} != self._voice_state.keys():
|
state = voice_data or self._voice_state
|
||||||
|
if {"sessionId", "event"} != state.keys():
|
||||||
return
|
return
|
||||||
|
|
||||||
state = voice_data or self._voice_state
|
data = VoiceUpdatePayload.model_validate(state)
|
||||||
|
|
||||||
data = {
|
|
||||||
"token": state["event"]["token"],
|
|
||||||
"endpoint": state["event"]["endpoint"],
|
|
||||||
"sessionId": state["sessionId"],
|
|
||||||
}
|
|
||||||
|
|
||||||
await self._node.send(
|
await self._node.send(
|
||||||
method="PATCH",
|
method="PATCH",
|
||||||
path=self._player_endpoint_uri,
|
path=self._player_endpoint_uri,
|
||||||
guild_id=self._guild.id,
|
guild_id=self._guild.id,
|
||||||
data={"voice": data},
|
data={"voice": data.model_dump()},
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._log:
|
if self._log:
|
||||||
|
|
@ -327,30 +317,26 @@ class Player(VoiceProtocol):
|
||||||
|
|
||||||
async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
|
async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
|
||||||
self._voice_state.update({"event": data})
|
self._voice_state.update({"event": data})
|
||||||
await self._dispatch_voice_update(self._voice_state)
|
await self._dispatch_voice_update()
|
||||||
|
|
||||||
async def on_voice_state_update(self, data: GuildVoiceState) -> None:
|
async def on_voice_state_update(self, data: GuildVoiceState) -> None:
|
||||||
self._voice_state.update({"sessionId": data.get("session_id")})
|
self._voice_state.update({"sessionId": data["session_id"]})
|
||||||
|
|
||||||
channel_id = data.get("channel_id")
|
|
||||||
if not channel_id:
|
|
||||||
await self.disconnect()
|
|
||||||
self._voice_state.clear()
|
|
||||||
return
|
|
||||||
|
|
||||||
|
channel_id = data["session_id"]
|
||||||
channel = self.guild.get_channel(int(channel_id))
|
channel = self.guild.get_channel(int(channel_id))
|
||||||
|
|
||||||
if self.channel != channel:
|
|
||||||
self.channel = channel
|
|
||||||
|
|
||||||
if not channel:
|
if not channel:
|
||||||
await self.disconnect()
|
await self.disconnect()
|
||||||
self._voice_state.clear()
|
self._voice_state.clear()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if self.channel != channel:
|
||||||
|
self.channel = channel
|
||||||
|
|
||||||
if not data.get("token"):
|
if not data.get("token"):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
self._voice_state.update({"event": data})
|
||||||
await self._dispatch_voice_update({**self._voice_state, "event": data})
|
await self._dispatch_voice_update({**self._voice_state, "event": data})
|
||||||
|
|
||||||
async def _dispatch_event(self, data: dict) -> None:
|
async def _dispatch_event(self, data: dict) -> None:
|
||||||
|
|
@ -359,12 +345,11 @@ class Player(VoiceProtocol):
|
||||||
|
|
||||||
if isinstance(event, TrackEndEvent) and event.reason != "replaced":
|
if isinstance(event, TrackEndEvent) and event.reason != "replaced":
|
||||||
self._current = None
|
self._current = None
|
||||||
|
|
||||||
event.dispatch(self._bot)
|
|
||||||
|
|
||||||
if isinstance(event, TrackStartEvent):
|
if isinstance(event, TrackStartEvent):
|
||||||
self._ending_track = self._current
|
self._ending_track = self._current
|
||||||
|
|
||||||
|
event.dispatch(self._bot)
|
||||||
|
|
||||||
if self._log:
|
if self._log:
|
||||||
self._log.debug(f"Dispatched event {data['type']} to player.")
|
self._log.debug(f"Dispatched event {data['type']} to player.")
|
||||||
|
|
||||||
|
|
@ -373,7 +358,10 @@ class Player(VoiceProtocol):
|
||||||
|
|
||||||
async def _swap_node(self, *, new_node: Node) -> None:
|
async def _swap_node(self, *, new_node: Node) -> None:
|
||||||
if self.current:
|
if self.current:
|
||||||
data: dict = {"position": self.position, "encodedTrack": self.current.track_id}
|
data: dict = TrackUpdatePayload(
|
||||||
|
encoded_track=self.current.track_id,
|
||||||
|
position=self.position,
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
del self._node._players[self._guild.id]
|
del self._node._players[self._guild.id]
|
||||||
self._node = new_node
|
self._node = new_node
|
||||||
|
|
@ -629,6 +617,9 @@ class Player(VoiceProtocol):
|
||||||
|
|
||||||
async def set_volume(self, volume: int) -> int:
|
async def set_volume(self, volume: int) -> int:
|
||||||
"""Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500."""
|
"""Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500."""
|
||||||
|
if volume < 0 or volume > 500:
|
||||||
|
raise ValueError("Volume must be between 0 and 500")
|
||||||
|
|
||||||
await self._node.send(
|
await self._node.send(
|
||||||
method="PATCH",
|
method="PATCH",
|
||||||
path=self._player_endpoint_uri,
|
path=self._player_endpoint_uri,
|
||||||
|
|
|
||||||
|
|
@ -8,11 +8,11 @@ from typing import List
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from .enums import LoopMode
|
from pomice.enums import LoopMode
|
||||||
from .exceptions import QueueEmpty
|
from pomice.exceptions import QueueEmpty
|
||||||
from .exceptions import QueueException
|
from pomice.exceptions import QueueException
|
||||||
from .exceptions import QueueFull
|
from pomice.exceptions import QueueFull
|
||||||
from .objects import Track
|
from pomice.models.music import Track
|
||||||
|
|
||||||
__all__ = ("Queue",)
|
__all__ = ("Queue",)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .pool import Node
|
from .pool import Node
|
||||||
|
|
||||||
from .utils import RouteStats
|
from pomice.utils import RouteStats
|
||||||
|
|
||||||
__all__ = ("RoutePlanner",)
|
__all__ = ("RoutePlanner",)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,8 @@ from typing import Dict
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .enums import RouteIPType
|
from pomice.enums import RouteIPType
|
||||||
from .enums import RouteStrategy
|
from pomice.enums import RouteStrategy
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"ExponentialBackoff",
|
"ExponentialBackoff",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue