fix websocket issue and add node resuming

This commit is contained in:
cloudwithax 2023-05-07 19:27:11 -04:00
parent dd3d43e702
commit 00ac166371
No known key found for this signature in database
GPG Key ID: 5DBE54E45794983E
4 changed files with 68 additions and 33 deletions

View File

@ -20,7 +20,7 @@ if not discord.version_info.major >= 2:
"using 'pip install discord.py'", "using 'pip install discord.py'",
) )
__version__ = "2.5.1" __version__ = "2.6.0a"
__title__ = "pomice" __title__ = "pomice"
__author__ = "cloudwithax" __author__ = "cloudwithax"
__license__ = "GPL-3.0" __license__ = "GPL-3.0"

View File

@ -434,6 +434,7 @@ class Player(VoiceProtocol):
assert not self.is_connected and not self.channel assert not self.is_connected and not self.channel
self._node._players.pop(self.guild.id) self._node._players.pop(self.guild.id)
if self.node.is_connected:
await self._node.send( await self._node.send(
method="DELETE", method="DELETE",
path=self._player_endpoint_uri, path=self._player_endpoint_uri,

View File

@ -17,9 +17,13 @@ from typing import Union
from urllib.parse import quote from urllib.parse import quote
import aiohttp import aiohttp
import orjson as json
from discord import Client from discord import Client
from discord.ext import commands from discord.ext import commands
from discord.utils import MISSING from discord.utils import MISSING
from websockets import client
from websockets import exceptions
from websockets import typing as wstype
from . import __version__ from . import __version__
from . import applemusic from . import applemusic
@ -71,6 +75,8 @@ class Node:
"_password", "_password",
"_identifier", "_identifier",
"_heartbeat", "_heartbeat",
"_resume_key",
"_resume_timeout",
"_secure", "_secure",
"_fallback", "_fallback",
"_log_level", "_log_level",
@ -106,7 +112,9 @@ class Node:
password: str, password: str,
identifier: str, identifier: str,
secure: bool = False, secure: bool = False,
heartbeat: int = 30, heartbeat: int = 60,
resume_key: Optional[str] = None,
resume_timeout: int = 60,
loop: Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None,
session: Optional[aiohttp.ClientSession] = None, session: Optional[aiohttp.ClientSession] = None,
spotify_client_id: Optional[str] = None, spotify_client_id: Optional[str] = None,
@ -126,6 +134,8 @@ class Node:
self._password: str = password self._password: str = password
self._identifier: str = identifier self._identifier: str = identifier
self._heartbeat: int = heartbeat self._heartbeat: int = heartbeat
self._resume_key: Optional[str] = resume_key
self._resume_timeout: int = resume_timeout
self._secure: bool = secure self._secure: bool = secure
self._fallback: bool = fallback self._fallback: bool = fallback
self._log_level: LogLevel = log_level self._log_level: LogLevel = log_level
@ -136,7 +146,7 @@ class Node:
self._session: aiohttp.ClientSession = session # type: ignore self._session: aiohttp.ClientSession = session # type: ignore
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
self._websocket: aiohttp.ClientWebSocketResponse self._websocket: client.WebSocketClientProtocol
self._task: asyncio.Task = None # type: ignore self._task: asyncio.Task = None # type: ignore
self._session_id: Optional[str] = None self._session_id: Optional[str] = None
@ -205,7 +215,7 @@ class Node:
@property @property
def player_count(self) -> int: def player_count(self) -> int:
"""Property which returns how many players are connected to this node""" """Property which returns how many players are connected to this node"""
return len(self.players) return len(self.players.values())
@property @property
def pool(self) -> Type[NodePool]: def pool(self) -> Type[NodePool]:
@ -316,25 +326,44 @@ class Node:
await self.disconnect() await self.disconnect()
async def _listen(self) -> None: async def _configure_resuming(self) -> None:
backoff = ExponentialBackoff(base=7) if self._resume_key:
data = {"resumingKey": self._resume_key, "timeout": self._resume_timeout}
await self.send(
method="PATCH",
path=f"sessions/{self._session_id}",
include_version=True,
data=data,
)
async def _listen(self) -> None:
while True: while True:
msg = await self._websocket.receive() try:
if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING): msg = await self._websocket.recv()
data = json.loads(msg)
self._log.debug(f"Recieved raw websocket message {msg}")
self._loop.create_task(self._handle_ws_msg(data=data))
except exceptions.ConnectionClosed:
if self.player_count > 0:
for _player in self.players.values():
self._loop.create_task(_player.destroy())
if self._fallback: if self._fallback:
await self._handle_node_switch() self._loop.create_task(self._handle_node_switch())
self._loop.create_task(self._websocket.close())
backoff = ExponentialBackoff(base=7)
retry = backoff.delay() retry = backoff.delay()
self._log.debug(f"Retrying connection to Node {self._identifier} in {retry} secs")
await asyncio.sleep(retry) await asyncio.sleep(retry)
if not self.is_connected: if not self.is_connected:
self._loop.create_task(self.connect(reconnect=True)) self._loop.create_task(self.connect(reconnect=True))
else:
self._loop.create_task(self._handle_payload(msg.json()))
async def _handle_payload(self, data: dict) -> None: async def _handle_ws_msg(self, data: dict) -> None:
self._log.debug(f"Recieved raw payload from Node {self._identifier} with data {data}")
op = data.get("op", None) op = data.get("op", None)
if not op:
return
if op == "stats": if op == "stats":
self._stats = NodeStats(data) self._stats = NodeStats(data)
@ -342,21 +371,20 @@ class Node:
if op == "ready": if op == "ready":
self._session_id = data["sessionId"] self._session_id = data["sessionId"]
await self._configure_resuming()
if not "guildId" in data: if not "guildId" in data:
return return
player = self._players.get(int(data["guildId"])) player: Optional[Player] = self._players.get(int(data["guildId"]))
if not player: if not player:
return return
if op == "event": if op == "event":
await player._dispatch_event(data) return await player._dispatch_event(data)
return
if op == "playerUpdate": if op == "playerUpdate":
await player._update_state(data) return await player._update_state(data)
return
async def send( async def send(
self, self,
@ -442,13 +470,15 @@ class Node:
f"Version check from Node {self._identifier} successful. Returned version {version}", f"Version check from Node {self._identifier} successful. Returned version {version}",
) )
self._websocket = await self._session.ws_connect( self._websocket = await client.connect(
f"{self._websocket_uri}/v{self._version.major}/websocket", f"{self._websocket_uri}/v{self._version.major}/websocket",
headers=self._headers, extra_headers=self._headers,
heartbeat=self._heartbeat, ping_interval=self._heartbeat,
) )
if reconnect: if reconnect:
self._log.debug(f"Trying to reconnect to Node {self._identifier}...")
if self.player_count:
for player in self.players.values(): for player in self.players.values():
await player._refresh_endpoint_uri(self._session_id) await player._refresh_endpoint_uri(self._session_id)
@ -466,15 +496,15 @@ class Node:
self._log.info(f"Connected to node {self._identifier}. Took {end - start:.3f}s") self._log.info(f"Connected to node {self._identifier}. Took {end - start:.3f}s")
return self return self
except (aiohttp.ClientConnectorError, ConnectionRefusedError): except (aiohttp.ClientConnectorError, OSError, ConnectionRefusedError):
raise NodeConnectionFailure( raise NodeConnectionFailure(
f"The connection to node '{self._identifier}' failed.", f"The connection to node '{self._identifier}' failed.",
) from None ) from None
except aiohttp.WSServerHandshakeError: except exceptions.InvalidHandshake:
raise NodeConnectionFailure( raise NodeConnectionFailure(
f"The password for node '{self._identifier}' is invalid.", f"The password for node '{self._identifier}' is invalid.",
) from None ) from None
except aiohttp.InvalidURL: except exceptions.InvalidURI:
raise NodeConnectionFailure( raise NodeConnectionFailure(
f"The URI for node '{self._identifier}' is invalid.", f"The URI for node '{self._identifier}' is invalid.",
) from None ) from None
@ -931,6 +961,8 @@ class NodePool:
identifier: str, identifier: str,
secure: bool = False, secure: bool = False,
heartbeat: int = 30, heartbeat: int = 30,
resume_key: Optional[str] = None,
resume_timeout: int = 60,
loop: Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None,
spotify_client_id: Optional[str] = None, spotify_client_id: Optional[str] = None,
spotify_client_secret: Optional[str] = None, spotify_client_secret: Optional[str] = None,
@ -957,6 +989,8 @@ class NodePool:
identifier=identifier, identifier=identifier,
secure=secure, secure=secure,
heartbeat=heartbeat, heartbeat=heartbeat,
resume_key=resume_key,
resume_timeout=resume_timeout,
loop=loop, loop=loop,
spotify_client_id=spotify_client_id, spotify_client_id=spotify_client_id,
session=session, session=session,

View File

@ -4,7 +4,7 @@ import re
import setuptools import setuptools
version = "" version = ""
requirements = ["aiohttp>=3.7.4,<4", "orjson"] requirements = ["aiohttp>=3.7.4,<4", "orjson", "websockets"]
with open("pomice/__init__.py") as f: with open("pomice/__init__.py") as f:
version = re.search( version = re.search(
r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',