add node fallback, wrap up 2.2

This commit is contained in:
cloudwithax 2023-03-10 23:44:51 -05:00
parent 9e0a5e0ad0
commit 2ded9d6205
2 changed files with 59 additions and 25 deletions

View File

@ -100,24 +100,6 @@ class Player(VoiceProtocol):
``` ```
""" """
__slots__ = (
'client',
'_bot',
'channel',
'_guild',
'_node',
'_current',
'_filters',
'_volume',
'_paused',
'_is_connected',
'_position',
'_last_position',
'_last_update',
'_ending_track',
'_player_endpoint_uri'
)
def __call__(self, client: Client, channel: VoiceChannel): def __call__(self, client: Client, channel: VoiceChannel):
self.client: Client = client self.client: Client = client
self.channel: VoiceChannel = channel self.channel: VoiceChannel = channel
@ -132,6 +114,26 @@ class Player(VoiceProtocol):
*, *,
node: Node = None node: Node = None
): ):
__slots__ = (
'client',
'channel',
'_bot',
'_guild',
'_node',
'_current',
'_filters',
'_volume',
'_paused',
'_is_connected',
'_position',
'_last_position',
'_last_update',
'_ending_track',
'_voice_state',
'_player_endpoint_uri',
'__dict__'
)
self.client: Optional[Client] = client self.client: Optional[Client] = client
self.channel: Optional[VoiceChannel] = channel self.channel: Optional[VoiceChannel] = channel
@ -236,14 +238,16 @@ class Player(VoiceProtocol):
self._is_connected = state.get("connected") self._is_connected = state.get("connected")
self._last_position = state.get("position") self._last_position = state.get("position")
async def _dispatch_voice_update(self, voice_data: Dict[str, Any]): async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None):
if {"sessionId", "event"} != self._voice_state.keys(): if {"sessionId", "event"} != self._voice_state.keys():
return return
state = voice_data or self._voice_state
data = { data = {
"token": voice_data['event']['token'], "token": state['event']['token'],
"endpoint": voice_data['event']['endpoint'], "endpoint": state['event']['endpoint'],
"sessionId": voice_data['sessionId'], "sessionId": state['sessionId'],
} }
await self._node.send( await self._node.send(
@ -284,6 +288,26 @@ class Player(VoiceProtocol):
if isinstance(event, TrackStartEvent): if isinstance(event, TrackStartEvent):
self._ending_track = self._current self._ending_track = self._current
async def _swap_node(self, *, new_node: Node):
data: dict = {
'encodedTrack': self.current.track_id,
'position': self.position,
}
del self._node._players[self._guild.id]
self._node = new_node
self._node._players[self._guild.id] = self
# reassign uri to update session id
self._player_endpoint_uri = f'sessions/{self._node._session_id}/players'
await self._dispatch_voice_update()
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data=data,
)
async def get_tracks( async def get_tracks(
self, self,
query: str, query: str,

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio import asyncio
import random import random
import re import re
import logging
import aiohttp import aiohttp
from discord import Client from discord import Client
@ -89,7 +88,7 @@ class Node:
"_spotify_client", "_spotify_client",
"_apple_music_client" "_apple_music_client"
) )
self._bot: Union[Client, commands.Bot] = bot self._bot: Union[Client, commands.Bot] = bot
self._host: str = host self._host: str = host
self._port: int = port self._port: int = port
@ -145,7 +144,7 @@ class Node:
@property @property
def is_connected(self) -> bool: def is_connected(self) -> bool:
""""Property which returns whether this node is connected or not""" """Property which returns whether this node is connected or not"""
return self._websocket is not None and not self._websocket.closed return self._websocket is not None and not self._websocket.closed
@ -210,6 +209,15 @@ class Node:
await player.on_voice_state_update(data["d"]) await player.on_voice_state_update(data["d"])
except KeyError: except KeyError:
return return
async def _handle_node_switch(self):
nodes = [node for node in self._pool._nodes.values() if node.is_connected]
new_node = random.choice(nodes)
for player in self._players.values():
await player._swap_node(new_node=new_node)
await self.disconnect()
async def _listen(self): async def _listen(self):
backoff = ExponentialBackoff(base=7) backoff = ExponentialBackoff(base=7)
@ -217,6 +225,8 @@ class Node:
while True: while True:
msg = await self._websocket.receive() msg = await self._websocket.receive()
if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING): if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
if self._fallback:
await self._handle_node_switch()
retry = backoff.delay() retry = backoff.delay()
await asyncio.sleep(retry) await asyncio.sleep(retry)
if not self.is_connected: if not self.is_connected: