diff --git a/pomice/__init__.py b/pomice/__init__.py index c916d6b..2dc1826 100644 --- a/pomice/__init__.py +++ b/pomice/__init__.py @@ -20,7 +20,7 @@ if not discord.version_info.major >= 2: "using 'pip install discord.py'", ) -__version__ = "2.5.1" +__version__ = "2.6.0a" __title__ = "pomice" __author__ = "cloudwithax" __license__ = "GPL-3.0" diff --git a/pomice/player.py b/pomice/player.py index ef754d2..276a9df 100644 --- a/pomice/player.py +++ b/pomice/player.py @@ -434,11 +434,12 @@ class Player(VoiceProtocol): assert not self.is_connected and not self.channel self._node._players.pop(self.guild.id) - await self._node.send( - method="DELETE", - path=self._player_endpoint_uri, - guild_id=self._guild.id, - ) + if self.node.is_connected: + await self._node.send( + method="DELETE", + path=self._player_endpoint_uri, + guild_id=self._guild.id, + ) self._log.debug("Player has been destroyed.") diff --git a/pomice/pool.py b/pomice/pool.py index 29d0031..515feda 100644 --- a/pomice/pool.py +++ b/pomice/pool.py @@ -17,9 +17,13 @@ from typing import Union from urllib.parse import quote import aiohttp +import orjson as json from discord import Client from discord.ext import commands from discord.utils import MISSING +from websockets import client +from websockets import exceptions +from websockets import typing as wstype from . import __version__ from . import applemusic @@ -71,6 +75,8 @@ class Node: "_password", "_identifier", "_heartbeat", + "_resume_key", + "_resume_timeout", "_secure", "_fallback", "_log_level", @@ -106,7 +112,9 @@ class Node: password: str, identifier: str, secure: bool = False, - heartbeat: int = 30, + heartbeat: int = 60, + resume_key: Optional[str] = None, + resume_timeout: int = 60, loop: Optional[asyncio.AbstractEventLoop] = None, session: Optional[aiohttp.ClientSession] = None, spotify_client_id: Optional[str] = None, @@ -126,6 +134,8 @@ class Node: self._password: str = password self._identifier: str = identifier self._heartbeat: int = heartbeat + self._resume_key: Optional[str] = resume_key + self._resume_timeout: int = resume_timeout self._secure: bool = secure self._fallback: bool = fallback self._log_level: LogLevel = log_level @@ -136,7 +146,7 @@ class Node: self._session: aiohttp.ClientSession = session # type: ignore self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() - self._websocket: aiohttp.ClientWebSocketResponse + self._websocket: client.WebSocketClientProtocol self._task: asyncio.Task = None # type: ignore self._session_id: Optional[str] = None @@ -205,7 +215,7 @@ class Node: @property def player_count(self) -> int: """Property which returns how many players are connected to this node""" - return len(self.players) + return len(self.players.values()) @property def pool(self) -> Type[NodePool]: @@ -316,25 +326,44 @@ class Node: await self.disconnect() - async def _listen(self) -> None: - backoff = ExponentialBackoff(base=7) + async def _configure_resuming(self) -> None: + if self._resume_key: + data = {"resumingKey": self._resume_key, "timeout": self._resume_timeout} + await self.send( + method="PATCH", + path=f"sessions/{self._session_id}", + include_version=True, + data=data, + ) + async def _listen(self) -> None: while True: - msg = await self._websocket.receive() - if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING): + try: + msg = await self._websocket.recv() + data = json.loads(msg) + self._log.debug(f"Recieved raw websocket message {msg}") + self._loop.create_task(self._handle_ws_msg(data=data)) + except exceptions.ConnectionClosed: + if self.player_count > 0: + for _player in self.players.values(): + self._loop.create_task(_player.destroy()) + if self._fallback: - await self._handle_node_switch() + self._loop.create_task(self._handle_node_switch()) + + self._loop.create_task(self._websocket.close()) + + backoff = ExponentialBackoff(base=7) retry = backoff.delay() + self._log.debug(f"Retrying connection to Node {self._identifier} in {retry} secs") await asyncio.sleep(retry) + if not self.is_connected: self._loop.create_task(self.connect(reconnect=True)) - else: - self._loop.create_task(self._handle_payload(msg.json())) - async def _handle_payload(self, data: dict) -> None: + async def _handle_ws_msg(self, data: dict) -> None: + self._log.debug(f"Recieved raw payload from Node {self._identifier} with data {data}") op = data.get("op", None) - if not op: - return if op == "stats": self._stats = NodeStats(data) @@ -342,21 +371,20 @@ class Node: if op == "ready": self._session_id = data["sessionId"] + await self._configure_resuming() if not "guildId" in data: return - player = self._players.get(int(data["guildId"])) + player: Optional[Player] = self._players.get(int(data["guildId"])) if not player: return if op == "event": - await player._dispatch_event(data) - return + return await player._dispatch_event(data) if op == "playerUpdate": - await player._update_state(data) - return + return await player._update_state(data) async def send( self, @@ -442,15 +470,17 @@ class Node: f"Version check from Node {self._identifier} successful. Returned version {version}", ) - self._websocket = await self._session.ws_connect( + self._websocket = await client.connect( f"{self._websocket_uri}/v{self._version.major}/websocket", - headers=self._headers, - heartbeat=self._heartbeat, + extra_headers=self._headers, + ping_interval=self._heartbeat, ) if reconnect: - for player in self.players.values(): - await player._refresh_endpoint_uri(self._session_id) + self._log.debug(f"Trying to reconnect to Node {self._identifier}...") + if self.player_count: + for player in self.players.values(): + await player._refresh_endpoint_uri(self._session_id) self._log.debug( f"Node {self._identifier} successfully connected to websocket using {self._websocket_uri}/v{self._version.major}/websocket", @@ -466,15 +496,15 @@ class Node: self._log.info(f"Connected to node {self._identifier}. Took {end - start:.3f}s") return self - except (aiohttp.ClientConnectorError, ConnectionRefusedError): + except (aiohttp.ClientConnectorError, OSError, ConnectionRefusedError): raise NodeConnectionFailure( f"The connection to node '{self._identifier}' failed.", ) from None - except aiohttp.WSServerHandshakeError: + except exceptions.InvalidHandshake: raise NodeConnectionFailure( f"The password for node '{self._identifier}' is invalid.", ) from None - except aiohttp.InvalidURL: + except exceptions.InvalidURI: raise NodeConnectionFailure( f"The URI for node '{self._identifier}' is invalid.", ) from None @@ -931,6 +961,8 @@ class NodePool: identifier: str, secure: bool = False, heartbeat: int = 30, + resume_key: Optional[str] = None, + resume_timeout: int = 60, loop: Optional[asyncio.AbstractEventLoop] = None, spotify_client_id: Optional[str] = None, spotify_client_secret: Optional[str] = None, @@ -957,6 +989,8 @@ class NodePool: identifier=identifier, secure=secure, heartbeat=heartbeat, + resume_key=resume_key, + resume_timeout=resume_timeout, loop=loop, spotify_client_id=spotify_client_id, session=session, diff --git a/setup.py b/setup.py index c7a9aac..81fa4f5 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ import re import setuptools version = "" -requirements = ["aiohttp>=3.7.4,<4", "orjson"] +requirements = ["aiohttp>=3.7.4,<4", "orjson", "websockets"] with open("pomice/__init__.py") as f: version = re.search( r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',