diff --git a/pomice/events.py b/pomice/events.py index 4102871..90cd4ed 100644 --- a/pomice/events.py +++ b/pomice/events.py @@ -1,8 +1,5 @@ -import asyncio - -from pomice import exceptions from .pool import NodePool - +from .utils import ClientType class PomiceEvent: """The base class for all events dispatched by a node. @@ -14,20 +11,24 @@ class PomiceEvent: ``` """ name = "event" - + handler_args = () + + def dispatch(self, bot: ClientType): + bot.dispatch(f"pomice_{self.name}", *self.handler_args) class TrackStartEvent(PomiceEvent): """Fired when a track has successfully started. Returns the player associated with the event and the pomice.Track object. """ + name = "track_start" - def __init__(self, player, track): - super().__init__() + def __init__(self, data: dict): + self.player = NodePool.get_node().get_player(int(data["guildId"])) + self.track = self.player._current - self.name = "track_start" - self.player = player - self.track = track + # on_pomice_track_start(player, track) + self.handler_args = self.player, self.track def __repr__(self) -> str: return f"" @@ -37,14 +38,15 @@ class TrackEndEvent(PomiceEvent): """Fired when a track has successfully ended. Returns the player associated with the event along with the pomice.Track object and reason. """ + name = "track_end" - def __init__(self, player, track, reason): - super().__init__() + def __init__(self, data: dict): + self.player = NodePool.get_node().get_player(int(data["guildId"])) + self.track = self.player._ending_track + self.reason: str = data["reason"] - self.name = "track_end" - self.player = player - self.track = track - self.reason = reason + # on_pomice_track_end(player, track, reason) + self.handler_args = self.player, self.track, self.reason def __repr__(self) -> str: return f"" @@ -54,66 +56,78 @@ class TrackStuckEvent(PomiceEvent): """Fired when a track is stuck and cannot be played. Returns the player associated with the event along with the pomice.Track object to be further parsed by the end user. """ + name = "track_stuck" - def __init__(self, player, track, threshold): - super().__init__() + def __init__(self, data: dict): + self.player = NodePool.get_node().get_player(int(data["guildId"])) + self.track = self.player._ending_track + self.threshold: float = data["thresholdMs"] - self.name = "track_stuck" - self.player = player - - self.track = track - self.threshold = threshold + # on_pomice_track_stuck(player, track, threshold) + self.handler_args = self.player, self.track, self.threshold def __repr__(self) -> str: - return f"" + return f"" class TrackExceptionEvent(PomiceEvent): """Fired when a track error has occured. Returns the player associated with the event along with the error code and exception. """ + name = "track_exception" - def __init__(self, player, track, error): - super().__init__() + def __init__(self, data: dict): + self.player = NodePool.get_node().get_player(int(data["guildId"])) + self.track = self.player._ending_track + self.error: str = data["error"] - self.name = "track_exception" - self.player = player - self.track = track - self.error = error + # on_pomice_track_exception(player, track, error) + self.handler_args = self.player, self.track, self.error def __repr__(self) -> str: - return f"" + return f"" +class WebSocketClosedPayload: + def __init__(self, data: dict): + self.guild = NodePool.get_node().get_player(int(data["guildId"]))._guild + self.code: int = data["code"] + self.reason: str = data["code"] + self.by_remote: bool = data["byRemote"] + + def __repr__(self) -> str: + return f"" + class WebSocketClosedEvent(PomiceEvent): """Fired when a websocket connection to a node has been closed. Returns the reason and the error code. """ + name = "websocket_closed" - def __init__(self, guild, reason, code): - super().__init__() + def __init__(self, data: dict): + self.payload = WebSocketClosedPayload(data) - self.name = "websocket_closed" - self.guild = guild - self.reason = reason - self.code = code + # on_pomice_websocket_closed(payload) + self.handler_args = self.payload, def __repr__(self) -> str: - return f"" + return f"" class WebSocketOpenEvent(PomiceEvent): """Fired when a websocket connection to a node has been initiated. Returns the target and the session SSRC. """ + name = "websocket_open" - def __init__(self, target, ssrc): - super().__init__() + def __init__(self, data: dict): + self.target: str = data["target"] + self.ssrc: int = data["ssrc"] - self.name = "websocket_open" - - self.target: str = target - self.ssrc: int = ssrc + # on_pomice_websocket_open(target, ssrc) + self.handler_args = self.target, self.ssrc def __repr__(self) -> str: - return f"" + return f"" diff --git a/pomice/player.py b/pomice/player.py index 566f2b2..b5c40fc 100644 --- a/pomice/player.py +++ b/pomice/player.py @@ -1,27 +1,25 @@ import time from typing import ( - Any, - Dict, - Optional, - Type, - Union + Any, + Dict, + Optional ) from discord import ( - Client, - Guild, - VoiceChannel, + Guild, + VoiceChannel, VoiceProtocol ) from discord.ext import commands -from pomice.enums import SearchType - from . import events +from .enums import SearchType +from .events import PomiceEvent, TrackStartEvent from .exceptions import TrackInvalidPosition from .filters import Filter -from .pool import Node, NodePool from .objects import Track +from .pool import Node, NodePool +from .utils import ClientType class Player(VoiceProtocol): @@ -32,7 +30,7 @@ class Player(VoiceProtocol): ``` """ - def __init__(self, client: Type[Client], channel: VoiceChannel): + def __init__(self, client: ClientType, channel: VoiceChannel): super().__init__(client=client, channel=channel) self.client = client @@ -50,6 +48,7 @@ class Player(VoiceProtocol): self._position = 0 self._last_position = 0 self._last_update = 0 + self._ending_track: Optional[Track] = None self._voice_state = {} @@ -118,7 +117,7 @@ class Player(VoiceProtocol): return self._filter @property - def bot(self) -> Type[Client]: + def bot(self) -> ClientType: """Property which returns the bot associated with this player instance""" return self._bot @@ -159,32 +158,11 @@ class Player(VoiceProtocol): async def _dispatch_event(self, data: dict): event_type = data.get("type") + event: PomiceEvent = getattr(events, event_type)(data) + event.dispatch(self._bot) - if event_type == "TrackStartEvent": - track = await self._node.build_track(data["track"]) - event = events.TrackStartEvent(self, track) - self.dispatch(event, self, track) - elif event_type == "TrackEndEvent": - track = await self._node.build_track(data["track"]) - event = events.TrackEndEvent(self, track, data["reason"]) - self.dispatch(event, self, track, data["reason"]) - elif event_type == "TrackExceptionEvent": - track = await self._node.build_track(data["track"]) - event = events.TrackExceptionEvent(self, track, data["error"]) - self.dispatch(event, self, track, data["error"]) - elif event_type == "TrackStuckEvent": - track = await self._node.build_track(data["track"]) - event = events.TrackStuckEvent(self, track, data["thresholdMs"]) - self.dispatch(event, self, track, data["thresholdMs"]) - elif event_type == "WebSocketOpenEvent": - event = events.WebSocketOpenEvent(data["target"], data["ssrc"]) - self.dispatch(event, data["target"], data["ssrc"]) - elif event_type == "WebSocketClosedEvent": - event = events.WebSocketClosedEvent(self._guild, data["reason"], data["code"]) - self.dispatch(event, self._guild, data["reason"], data["code"]) - - def dispatch(self, event, *args, **kwargs): - self.bot.dispatch(f"pomice_{event.name}", event, *args, **kwargs) + if isinstance(event, TrackStartEvent): + self._ending_track = self._current async def get_tracks( self,