diff --git a/pomice/player.py b/pomice/player.py index d73ea60..8314d91 100644 --- a/pomice/player.py +++ b/pomice/player.py @@ -334,6 +334,9 @@ class Player(VoiceProtocol): self._log.debug(f"Dispatched event {data['type']} to player.") + async def _refresh_endpoint_uri(self, session_id: Optional[str]) -> None: + self._player_endpoint_uri = f"sessions/{session_id}/players" + async def _swap_node(self, *, new_node: Node) -> None: if self.current: data: dict = {"position": self.position, "encodedTrack": self.current.track_id} @@ -342,8 +345,7 @@ class Player(VoiceProtocol): self._node = new_node self._node._players[self._guild.id] = self # reassign uri to update session id - self._player_endpoint_uri = f"sessions/{self._node._session_id}/players" - + await self._refresh_endpoint_uri(new_node._session_id) await self._dispatch_voice_update() await self._node.send( method="PATCH", diff --git a/pomice/pool.py b/pomice/pool.py index 225eb65..6a0a9d2 100644 --- a/pomice/pool.py +++ b/pomice/pool.py @@ -319,7 +319,7 @@ class Node: retry = backoff.delay() await asyncio.sleep(retry) if not self.is_connected: - self._loop.create_task(self.connect()) + self._loop.create_task(self.connect(reconnect=True)) else: self._loop.create_task(self._handle_payload(msg.json())) @@ -379,27 +379,29 @@ class Node: headers=self._headers, json=data or {}, ) as resp: - self._log.debug(f"Making REST request with method {method} to {uri}") + self._log.debug( + f"Making REST request to Node {self._identifier} with method {method} to {uri}", + ) if resp.status >= 300: resp_data: dict = await resp.json() raise NodeRestException( - f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {resp_data["message"]}', + f'Error from Node {self._identifier} fetching from Lavalink REST api: {resp.status} {resp.reason}: {resp_data["message"]}', ) if method == "DELETE" or resp.status == 204: self._log.debug( - f"REST request with method {method} to {uri} completed sucessfully and returned no data.", + f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned no data.", ) return await resp.json(content_type=None) if resp.content_type == "text/plain": self._log.debug( - f"REST request with method {method} to {uri} completed sucessfully and returned text with body {await resp.text()}", + f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned text with body {await resp.text()}", ) return await resp.text() self._log.debug( - f"REST request with method {method} to {uri} completed sucessfully and returned JSON with body {await resp.json()}", + f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned JSON with body {await resp.json()}", ) return await resp.json() @@ -407,7 +409,7 @@ class Node: """Takes a guild ID as a parameter. Returns a pomice Player object or None.""" return self._players.get(guild_id, None) - async def connect(self) -> "Node": + async def connect(self, *, reconnect: bool = False) -> "Node": """Initiates a connection with a Lavalink node and adds it to the node pool.""" await self._bot.wait_until_ready() @@ -417,16 +419,19 @@ class Node: self._session = aiohttp.ClientSession() try: - version: str = await self.send( - method="GET", - path="version", - ignore_if_available=True, - include_version=False, - ) + if not reconnect: + version: str = await self.send( + method="GET", + path="version", + ignore_if_available=True, + include_version=False, + ) - await self._handle_version_check(version=version) + await self._handle_version_check(version=version) - self._log.debug(f"Version check from node successful. Returned version {version}") + self._log.debug( + f"Version check from Node {self._identifier} successful. Returned version {version}", + ) self._websocket = await self._session.ws_connect( f"{self._websocket_uri}/v{self._version.major}/websocket", @@ -434,8 +439,12 @@ class Node: heartbeat=self._heartbeat, ) + if reconnect: + for player in self.players.values(): + await player._refresh_endpoint_uri(self._session_id) + self._log.debug( - f"Connected to node websocket using {self._websocket_uri}/v{self._version.major}/websocket", + f"Node {self._identifier} successfully connected to websocket using {self._websocket_uri}/v{self._version.major}/websocket", ) if not self._task: