event rewrite 2

This commit is contained in:
VP 2021-10-20 14:36:01 +03:00
parent 7d53934697
commit 584f6e5286
2 changed files with 75 additions and 83 deletions

View File

@ -1,8 +1,5 @@
import asyncio
from pomice import exceptions
from .pool import NodePool
from .utils import ClientType
class PomiceEvent:
"""The base class for all events dispatched by a node.
@ -14,20 +11,24 @@ class PomiceEvent:
```
"""
name = "event"
handler_args = ()
def dispatch(self, bot: ClientType):
bot.dispatch(f"pomice_{self.name}", *self.handler_args)
class TrackStartEvent(PomiceEvent):
"""Fired when a track has successfully started.
Returns the player associated with the event and the pomice.Track object.
"""
name = "track_start"
def __init__(self, player, track):
super().__init__()
def __init__(self, data: dict):
self.player = NodePool.get_node().get_player(int(data["guildId"]))
self.track = self.player._current
self.name = "track_start"
self.player = player
self.track = track
# on_pomice_track_start(player, track)
self.handler_args = self.player, self.track
def __repr__(self) -> str:
return f"<Pomice.TrackStartEvent player={self.player} track_id={self.track.track_id}>"
@ -37,14 +38,15 @@ class TrackEndEvent(PomiceEvent):
"""Fired when a track has successfully ended.
Returns the player associated with the event along with the pomice.Track object and reason.
"""
name = "track_end"
def __init__(self, player, track, reason):
super().__init__()
def __init__(self, data: dict):
self.player = NodePool.get_node().get_player(int(data["guildId"]))
self.track = self.player._ending_track
self.reason: str = data["reason"]
self.name = "track_end"
self.player = player
self.track = track
self.reason = reason
# on_pomice_track_end(player, track, reason)
self.handler_args = self.player, self.track, self.reason
def __repr__(self) -> str:
return f"<Pomice.TrackEndEvent player={self.player} track_id={self.track.track_id} reason={self.reason}>"
@ -54,66 +56,78 @@ class TrackStuckEvent(PomiceEvent):
"""Fired when a track is stuck and cannot be played. Returns the player
associated with the event along with the pomice.Track object to be further parsed by the end user.
"""
name = "track_stuck"
def __init__(self, player, track, threshold):
super().__init__()
def __init__(self, data: dict):
self.player = NodePool.get_node().get_player(int(data["guildId"]))
self.track = self.player._ending_track
self.threshold: float = data["thresholdMs"]
self.name = "track_stuck"
self.player = player
self.track = track
self.threshold = threshold
# on_pomice_track_stuck(player, track, threshold)
self.handler_args = self.player, self.track, self.threshold
def __repr__(self) -> str:
return f"<Pomice.TrackStuckEvent player={self.player} track_id={self.track.track_id} threshold={self.threshold}>"
return f"<Pomice.TrackStuckEvent player={self.player!r} track={self.track!r} " \
f"threshold={self.threshold!r}>"
class TrackExceptionEvent(PomiceEvent):
"""Fired when a track error has occured.
Returns the player associated with the event along with the error code and exception.
"""
name = "track_exception"
def __init__(self, player, track, error):
super().__init__()
def __init__(self, data: dict):
self.player = NodePool.get_node().get_player(int(data["guildId"]))
self.track = self.player._ending_track
self.error: str = data["error"]
self.name = "track_exception"
self.player = player
self.track = track
self.error = error
# on_pomice_track_exception(player, track, error)
self.handler_args = self.player, self.track, self.error
def __repr__(self) -> str:
return f"<Pomice.TrackExceptionEvent player={self.player} error={self.error} exeception={self.exception}>"
return f"<Pomice.TrackExceptionEvent player={self.player!r} error={self.error!r}>"
class WebSocketClosedPayload:
def __init__(self, data: dict):
self.guild = NodePool.get_node().get_player(int(data["guildId"]))._guild
self.code: int = data["code"]
self.reason: str = data["code"]
self.by_remote: bool = data["byRemote"]
def __repr__(self) -> str:
return f"<Pomice.WebSocketClosedPayload guild={self.guild!r} code={self.code!r} " \
f"reason={self.reason!r} by_remote={self.by_remote!r}>"
class WebSocketClosedEvent(PomiceEvent):
"""Fired when a websocket connection to a node has been closed.
Returns the reason and the error code.
"""
name = "websocket_closed"
def __init__(self, guild, reason, code):
super().__init__()
def __init__(self, data: dict):
self.payload = WebSocketClosedPayload(data)
self.name = "websocket_closed"
self.guild = guild
self.reason = reason
self.code = code
# on_pomice_websocket_closed(payload)
self.handler_args = self.payload,
def __repr__(self) -> str:
return f"<Pomice.WebsocketClosedEvent guild_id={self.guild.id} reason={self.reason} code={self.code}>"
return f"<Pomice.WebsocketClosedEvent payload={self.payload!r}>"
class WebSocketOpenEvent(PomiceEvent):
"""Fired when a websocket connection to a node has been initiated.
Returns the target and the session SSRC.
"""
name = "websocket_open"
def __init__(self, target, ssrc):
super().__init__()
def __init__(self, data: dict):
self.target: str = data["target"]
self.ssrc: int = data["ssrc"]
self.name = "websocket_open"
self.target: str = target
self.ssrc: int = ssrc
# on_pomice_websocket_open(target, ssrc)
self.handler_args = self.target, self.ssrc
def __repr__(self) -> str:
return f"<Pomice.WebsocketOpenEvent target={self.target} ssrc={self.ssrc}>"
return f"<Pomice.WebsocketOpenEvent target={self.target!r} ssrc={self.ssrc!r}>"

View File

@ -2,26 +2,24 @@ import time
from typing import (
Any,
Dict,
Optional,
Type,
Union
Optional
)
from discord import (
Client,
Guild,
VoiceChannel,
VoiceProtocol
)
from discord.ext import commands
from pomice.enums import SearchType
from . import events
from .enums import SearchType
from .events import PomiceEvent, TrackStartEvent
from .exceptions import TrackInvalidPosition
from .filters import Filter
from .pool import Node, NodePool
from .objects import Track
from .pool import Node, NodePool
from .utils import ClientType
class Player(VoiceProtocol):
@ -32,7 +30,7 @@ class Player(VoiceProtocol):
```
"""
def __init__(self, client: Type[Client], channel: VoiceChannel):
def __init__(self, client: ClientType, channel: VoiceChannel):
super().__init__(client=client, channel=channel)
self.client = client
@ -50,6 +48,7 @@ class Player(VoiceProtocol):
self._position = 0
self._last_position = 0
self._last_update = 0
self._ending_track: Optional[Track] = None
self._voice_state = {}
@ -118,7 +117,7 @@ class Player(VoiceProtocol):
return self._filter
@property
def bot(self) -> Type[Client]:
def bot(self) -> ClientType:
"""Property which returns the bot associated with this player instance"""
return self._bot
@ -159,32 +158,11 @@ class Player(VoiceProtocol):
async def _dispatch_event(self, data: dict):
event_type = data.get("type")
event: PomiceEvent = getattr(events, event_type)(data)
event.dispatch(self._bot)
if event_type == "TrackStartEvent":
track = await self._node.build_track(data["track"])
event = events.TrackStartEvent(self, track)
self.dispatch(event, self, track)
elif event_type == "TrackEndEvent":
track = await self._node.build_track(data["track"])
event = events.TrackEndEvent(self, track, data["reason"])
self.dispatch(event, self, track, data["reason"])
elif event_type == "TrackExceptionEvent":
track = await self._node.build_track(data["track"])
event = events.TrackExceptionEvent(self, track, data["error"])
self.dispatch(event, self, track, data["error"])
elif event_type == "TrackStuckEvent":
track = await self._node.build_track(data["track"])
event = events.TrackStuckEvent(self, track, data["thresholdMs"])
self.dispatch(event, self, track, data["thresholdMs"])
elif event_type == "WebSocketOpenEvent":
event = events.WebSocketOpenEvent(data["target"], data["ssrc"])
self.dispatch(event, data["target"], data["ssrc"])
elif event_type == "WebSocketClosedEvent":
event = events.WebSocketClosedEvent(self._guild, data["reason"], data["code"])
self.dispatch(event, self._guild, data["reason"], data["code"])
def dispatch(self, event, *args, **kwargs):
self.bot.dispatch(f"pomice_{event.name}", event, *args, **kwargs)
if isinstance(event, TrackStartEvent):
self._ending_track = self._current
async def get_tracks(
self,