diff --git a/pomice/player.py b/pomice/player.py index fd24a6f..2891a13 100644 --- a/pomice/player.py +++ b/pomice/player.py @@ -100,24 +100,6 @@ class Player(VoiceProtocol): ``` """ - __slots__ = ( - 'client', - '_bot', - 'channel', - '_guild', - '_node', - '_current', - '_filters', - '_volume', - '_paused', - '_is_connected', - '_position', - '_last_position', - '_last_update', - '_ending_track', - '_player_endpoint_uri' - ) - def __call__(self, client: Client, channel: VoiceChannel): self.client: Client = client self.channel: VoiceChannel = channel @@ -132,6 +114,26 @@ class Player(VoiceProtocol): *, node: Node = None ): + __slots__ = ( + 'client', + 'channel', + '_bot', + '_guild', + '_node', + '_current', + '_filters', + '_volume', + '_paused', + '_is_connected', + '_position', + '_last_position', + '_last_update', + '_ending_track', + '_voice_state', + '_player_endpoint_uri', + '__dict__' + ) + self.client: Optional[Client] = client self.channel: Optional[VoiceChannel] = channel @@ -236,14 +238,16 @@ class Player(VoiceProtocol): self._is_connected = state.get("connected") self._last_position = state.get("position") - async def _dispatch_voice_update(self, voice_data: Dict[str, Any]): + async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None): if {"sessionId", "event"} != self._voice_state.keys(): return + state = voice_data or self._voice_state + data = { - "token": voice_data['event']['token'], - "endpoint": voice_data['event']['endpoint'], - "sessionId": voice_data['sessionId'], + "token": state['event']['token'], + "endpoint": state['event']['endpoint'], + "sessionId": state['sessionId'], } await self._node.send( @@ -284,6 +288,26 @@ class Player(VoiceProtocol): if isinstance(event, TrackStartEvent): self._ending_track = self._current + async def _swap_node(self, *, new_node: Node): + data: dict = { + 'encodedTrack': self.current.track_id, + 'position': self.position, + } + + del self._node._players[self._guild.id] + self._node = new_node + self._node._players[self._guild.id] = self + # reassign uri to update session id + self._player_endpoint_uri = f'sessions/{self._node._session_id}/players' + + await self._dispatch_voice_update() + await self._node.send( + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, + data=data, + ) + async def get_tracks( self, query: str, diff --git a/pomice/pool.py b/pomice/pool.py index 16d094a..9438fdf 100644 --- a/pomice/pool.py +++ b/pomice/pool.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio import random import re -import logging import aiohttp from discord import Client @@ -89,7 +88,7 @@ class Node: "_spotify_client", "_apple_music_client" ) - + self._bot: Union[Client, commands.Bot] = bot self._host: str = host self._port: int = port @@ -145,7 +144,7 @@ class Node: @property def is_connected(self) -> bool: - """"Property which returns whether this node is connected or not""" + """Property which returns whether this node is connected or not""" return self._websocket is not None and not self._websocket.closed @@ -210,6 +209,15 @@ class Node: await player.on_voice_state_update(data["d"]) except KeyError: return + + async def _handle_node_switch(self): + nodes = [node for node in self._pool._nodes.values() if node.is_connected] + new_node = random.choice(nodes) + + for player in self._players.values(): + await player._swap_node(new_node=new_node) + + await self.disconnect() async def _listen(self): backoff = ExponentialBackoff(base=7) @@ -217,6 +225,8 @@ class Node: while True: msg = await self._websocket.receive() if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING): + if self._fallback: + await self._handle_node_switch() retry = backoff.delay() await asyncio.sleep(retry) if not self.is_connected: