fix websocket issue and add node resuming

This commit is contained in:
cloudwithax 2023-05-07 19:27:11 -04:00
parent dd3d43e702
commit 00ac166371
No known key found for this signature in database
GPG Key ID: 5DBE54E45794983E
4 changed files with 68 additions and 33 deletions

View File

@ -20,7 +20,7 @@ if not discord.version_info.major >= 2:
"using 'pip install discord.py'",
)
__version__ = "2.5.1"
__version__ = "2.6.0a"
__title__ = "pomice"
__author__ = "cloudwithax"
__license__ = "GPL-3.0"

View File

@ -434,6 +434,7 @@ class Player(VoiceProtocol):
assert not self.is_connected and not self.channel
self._node._players.pop(self.guild.id)
if self.node.is_connected:
await self._node.send(
method="DELETE",
path=self._player_endpoint_uri,

View File

@ -17,9 +17,13 @@ from typing import Union
from urllib.parse import quote
import aiohttp
import orjson as json
from discord import Client
from discord.ext import commands
from discord.utils import MISSING
from websockets import client
from websockets import exceptions
from websockets import typing as wstype
from . import __version__
from . import applemusic
@ -71,6 +75,8 @@ class Node:
"_password",
"_identifier",
"_heartbeat",
"_resume_key",
"_resume_timeout",
"_secure",
"_fallback",
"_log_level",
@ -106,7 +112,9 @@ class Node:
password: str,
identifier: str,
secure: bool = False,
heartbeat: int = 30,
heartbeat: int = 60,
resume_key: Optional[str] = None,
resume_timeout: int = 60,
loop: Optional[asyncio.AbstractEventLoop] = None,
session: Optional[aiohttp.ClientSession] = None,
spotify_client_id: Optional[str] = None,
@ -126,6 +134,8 @@ class Node:
self._password: str = password
self._identifier: str = identifier
self._heartbeat: int = heartbeat
self._resume_key: Optional[str] = resume_key
self._resume_timeout: int = resume_timeout
self._secure: bool = secure
self._fallback: bool = fallback
self._log_level: LogLevel = log_level
@ -136,7 +146,7 @@ class Node:
self._session: aiohttp.ClientSession = session # type: ignore
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
self._websocket: aiohttp.ClientWebSocketResponse
self._websocket: client.WebSocketClientProtocol
self._task: asyncio.Task = None # type: ignore
self._session_id: Optional[str] = None
@ -205,7 +215,7 @@ class Node:
@property
def player_count(self) -> int:
"""Property which returns how many players are connected to this node"""
return len(self.players)
return len(self.players.values())
@property
def pool(self) -> Type[NodePool]:
@ -316,25 +326,44 @@ class Node:
await self.disconnect()
async def _listen(self) -> None:
backoff = ExponentialBackoff(base=7)
async def _configure_resuming(self) -> None:
if self._resume_key:
data = {"resumingKey": self._resume_key, "timeout": self._resume_timeout}
await self.send(
method="PATCH",
path=f"sessions/{self._session_id}",
include_version=True,
data=data,
)
async def _listen(self) -> None:
while True:
msg = await self._websocket.receive()
if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
try:
msg = await self._websocket.recv()
data = json.loads(msg)
self._log.debug(f"Recieved raw websocket message {msg}")
self._loop.create_task(self._handle_ws_msg(data=data))
except exceptions.ConnectionClosed:
if self.player_count > 0:
for _player in self.players.values():
self._loop.create_task(_player.destroy())
if self._fallback:
await self._handle_node_switch()
self._loop.create_task(self._handle_node_switch())
self._loop.create_task(self._websocket.close())
backoff = ExponentialBackoff(base=7)
retry = backoff.delay()
self._log.debug(f"Retrying connection to Node {self._identifier} in {retry} secs")
await asyncio.sleep(retry)
if not self.is_connected:
self._loop.create_task(self.connect(reconnect=True))
else:
self._loop.create_task(self._handle_payload(msg.json()))
async def _handle_payload(self, data: dict) -> None:
async def _handle_ws_msg(self, data: dict) -> None:
self._log.debug(f"Recieved raw payload from Node {self._identifier} with data {data}")
op = data.get("op", None)
if not op:
return
if op == "stats":
self._stats = NodeStats(data)
@ -342,21 +371,20 @@ class Node:
if op == "ready":
self._session_id = data["sessionId"]
await self._configure_resuming()
if not "guildId" in data:
return
player = self._players.get(int(data["guildId"]))
player: Optional[Player] = self._players.get(int(data["guildId"]))
if not player:
return
if op == "event":
await player._dispatch_event(data)
return
return await player._dispatch_event(data)
if op == "playerUpdate":
await player._update_state(data)
return
return await player._update_state(data)
async def send(
self,
@ -442,13 +470,15 @@ class Node:
f"Version check from Node {self._identifier} successful. Returned version {version}",
)
self._websocket = await self._session.ws_connect(
self._websocket = await client.connect(
f"{self._websocket_uri}/v{self._version.major}/websocket",
headers=self._headers,
heartbeat=self._heartbeat,
extra_headers=self._headers,
ping_interval=self._heartbeat,
)
if reconnect:
self._log.debug(f"Trying to reconnect to Node {self._identifier}...")
if self.player_count:
for player in self.players.values():
await player._refresh_endpoint_uri(self._session_id)
@ -466,15 +496,15 @@ class Node:
self._log.info(f"Connected to node {self._identifier}. Took {end - start:.3f}s")
return self
except (aiohttp.ClientConnectorError, ConnectionRefusedError):
except (aiohttp.ClientConnectorError, OSError, ConnectionRefusedError):
raise NodeConnectionFailure(
f"The connection to node '{self._identifier}' failed.",
) from None
except aiohttp.WSServerHandshakeError:
except exceptions.InvalidHandshake:
raise NodeConnectionFailure(
f"The password for node '{self._identifier}' is invalid.",
) from None
except aiohttp.InvalidURL:
except exceptions.InvalidURI:
raise NodeConnectionFailure(
f"The URI for node '{self._identifier}' is invalid.",
) from None
@ -931,6 +961,8 @@ class NodePool:
identifier: str,
secure: bool = False,
heartbeat: int = 30,
resume_key: Optional[str] = None,
resume_timeout: int = 60,
loop: Optional[asyncio.AbstractEventLoop] = None,
spotify_client_id: Optional[str] = None,
spotify_client_secret: Optional[str] = None,
@ -957,6 +989,8 @@ class NodePool:
identifier=identifier,
secure=secure,
heartbeat=heartbeat,
resume_key=resume_key,
resume_timeout=resume_timeout,
loop=loop,
spotify_client_id=spotify_client_id,
session=session,

View File

@ -4,7 +4,7 @@ import re
import setuptools
version = ""
requirements = ["aiohttp>=3.7.4,<4", "orjson"]
requirements = ["aiohttp>=3.7.4,<4", "orjson", "websockets"]
with open("pomice/__init__.py") as f:
version = re.search(
r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',