fix websocket issue and add node resuming
This commit is contained in:
parent
dd3d43e702
commit
00ac166371
|
|
@ -20,7 +20,7 @@ if not discord.version_info.major >= 2:
|
|||
"using 'pip install discord.py'",
|
||||
)
|
||||
|
||||
__version__ = "2.5.1"
|
||||
__version__ = "2.6.0a"
|
||||
__title__ = "pomice"
|
||||
__author__ = "cloudwithax"
|
||||
__license__ = "GPL-3.0"
|
||||
|
|
|
|||
|
|
@ -434,6 +434,7 @@ class Player(VoiceProtocol):
|
|||
assert not self.is_connected and not self.channel
|
||||
|
||||
self._node._players.pop(self.guild.id)
|
||||
if self.node.is_connected:
|
||||
await self._node.send(
|
||||
method="DELETE",
|
||||
path=self._player_endpoint_uri,
|
||||
|
|
|
|||
|
|
@ -17,9 +17,13 @@ from typing import Union
|
|||
from urllib.parse import quote
|
||||
|
||||
import aiohttp
|
||||
import orjson as json
|
||||
from discord import Client
|
||||
from discord.ext import commands
|
||||
from discord.utils import MISSING
|
||||
from websockets import client
|
||||
from websockets import exceptions
|
||||
from websockets import typing as wstype
|
||||
|
||||
from . import __version__
|
||||
from . import applemusic
|
||||
|
|
@ -71,6 +75,8 @@ class Node:
|
|||
"_password",
|
||||
"_identifier",
|
||||
"_heartbeat",
|
||||
"_resume_key",
|
||||
"_resume_timeout",
|
||||
"_secure",
|
||||
"_fallback",
|
||||
"_log_level",
|
||||
|
|
@ -106,7 +112,9 @@ class Node:
|
|||
password: str,
|
||||
identifier: str,
|
||||
secure: bool = False,
|
||||
heartbeat: int = 30,
|
||||
heartbeat: int = 60,
|
||||
resume_key: Optional[str] = None,
|
||||
resume_timeout: int = 60,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
session: Optional[aiohttp.ClientSession] = None,
|
||||
spotify_client_id: Optional[str] = None,
|
||||
|
|
@ -126,6 +134,8 @@ class Node:
|
|||
self._password: str = password
|
||||
self._identifier: str = identifier
|
||||
self._heartbeat: int = heartbeat
|
||||
self._resume_key: Optional[str] = resume_key
|
||||
self._resume_timeout: int = resume_timeout
|
||||
self._secure: bool = secure
|
||||
self._fallback: bool = fallback
|
||||
self._log_level: LogLevel = log_level
|
||||
|
|
@ -136,7 +146,7 @@ class Node:
|
|||
|
||||
self._session: aiohttp.ClientSession = session # type: ignore
|
||||
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
|
||||
self._websocket: aiohttp.ClientWebSocketResponse
|
||||
self._websocket: client.WebSocketClientProtocol
|
||||
self._task: asyncio.Task = None # type: ignore
|
||||
|
||||
self._session_id: Optional[str] = None
|
||||
|
|
@ -205,7 +215,7 @@ class Node:
|
|||
@property
|
||||
def player_count(self) -> int:
|
||||
"""Property which returns how many players are connected to this node"""
|
||||
return len(self.players)
|
||||
return len(self.players.values())
|
||||
|
||||
@property
|
||||
def pool(self) -> Type[NodePool]:
|
||||
|
|
@ -316,25 +326,44 @@ class Node:
|
|||
|
||||
await self.disconnect()
|
||||
|
||||
async def _listen(self) -> None:
|
||||
backoff = ExponentialBackoff(base=7)
|
||||
async def _configure_resuming(self) -> None:
|
||||
if self._resume_key:
|
||||
data = {"resumingKey": self._resume_key, "timeout": self._resume_timeout}
|
||||
await self.send(
|
||||
method="PATCH",
|
||||
path=f"sessions/{self._session_id}",
|
||||
include_version=True,
|
||||
data=data,
|
||||
)
|
||||
|
||||
async def _listen(self) -> None:
|
||||
while True:
|
||||
msg = await self._websocket.receive()
|
||||
if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
||||
try:
|
||||
msg = await self._websocket.recv()
|
||||
data = json.loads(msg)
|
||||
self._log.debug(f"Recieved raw websocket message {msg}")
|
||||
self._loop.create_task(self._handle_ws_msg(data=data))
|
||||
except exceptions.ConnectionClosed:
|
||||
if self.player_count > 0:
|
||||
for _player in self.players.values():
|
||||
self._loop.create_task(_player.destroy())
|
||||
|
||||
if self._fallback:
|
||||
await self._handle_node_switch()
|
||||
self._loop.create_task(self._handle_node_switch())
|
||||
|
||||
self._loop.create_task(self._websocket.close())
|
||||
|
||||
backoff = ExponentialBackoff(base=7)
|
||||
retry = backoff.delay()
|
||||
self._log.debug(f"Retrying connection to Node {self._identifier} in {retry} secs")
|
||||
await asyncio.sleep(retry)
|
||||
|
||||
if not self.is_connected:
|
||||
self._loop.create_task(self.connect(reconnect=True))
|
||||
else:
|
||||
self._loop.create_task(self._handle_payload(msg.json()))
|
||||
|
||||
async def _handle_payload(self, data: dict) -> None:
|
||||
async def _handle_ws_msg(self, data: dict) -> None:
|
||||
self._log.debug(f"Recieved raw payload from Node {self._identifier} with data {data}")
|
||||
op = data.get("op", None)
|
||||
if not op:
|
||||
return
|
||||
|
||||
if op == "stats":
|
||||
self._stats = NodeStats(data)
|
||||
|
|
@ -342,21 +371,20 @@ class Node:
|
|||
|
||||
if op == "ready":
|
||||
self._session_id = data["sessionId"]
|
||||
await self._configure_resuming()
|
||||
|
||||
if not "guildId" in data:
|
||||
return
|
||||
|
||||
player = self._players.get(int(data["guildId"]))
|
||||
player: Optional[Player] = self._players.get(int(data["guildId"]))
|
||||
if not player:
|
||||
return
|
||||
|
||||
if op == "event":
|
||||
await player._dispatch_event(data)
|
||||
return
|
||||
return await player._dispatch_event(data)
|
||||
|
||||
if op == "playerUpdate":
|
||||
await player._update_state(data)
|
||||
return
|
||||
return await player._update_state(data)
|
||||
|
||||
async def send(
|
||||
self,
|
||||
|
|
@ -442,13 +470,15 @@ class Node:
|
|||
f"Version check from Node {self._identifier} successful. Returned version {version}",
|
||||
)
|
||||
|
||||
self._websocket = await self._session.ws_connect(
|
||||
self._websocket = await client.connect(
|
||||
f"{self._websocket_uri}/v{self._version.major}/websocket",
|
||||
headers=self._headers,
|
||||
heartbeat=self._heartbeat,
|
||||
extra_headers=self._headers,
|
||||
ping_interval=self._heartbeat,
|
||||
)
|
||||
|
||||
if reconnect:
|
||||
self._log.debug(f"Trying to reconnect to Node {self._identifier}...")
|
||||
if self.player_count:
|
||||
for player in self.players.values():
|
||||
await player._refresh_endpoint_uri(self._session_id)
|
||||
|
||||
|
|
@ -466,15 +496,15 @@ class Node:
|
|||
self._log.info(f"Connected to node {self._identifier}. Took {end - start:.3f}s")
|
||||
return self
|
||||
|
||||
except (aiohttp.ClientConnectorError, ConnectionRefusedError):
|
||||
except (aiohttp.ClientConnectorError, OSError, ConnectionRefusedError):
|
||||
raise NodeConnectionFailure(
|
||||
f"The connection to node '{self._identifier}' failed.",
|
||||
) from None
|
||||
except aiohttp.WSServerHandshakeError:
|
||||
except exceptions.InvalidHandshake:
|
||||
raise NodeConnectionFailure(
|
||||
f"The password for node '{self._identifier}' is invalid.",
|
||||
) from None
|
||||
except aiohttp.InvalidURL:
|
||||
except exceptions.InvalidURI:
|
||||
raise NodeConnectionFailure(
|
||||
f"The URI for node '{self._identifier}' is invalid.",
|
||||
) from None
|
||||
|
|
@ -931,6 +961,8 @@ class NodePool:
|
|||
identifier: str,
|
||||
secure: bool = False,
|
||||
heartbeat: int = 30,
|
||||
resume_key: Optional[str] = None,
|
||||
resume_timeout: int = 60,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
spotify_client_id: Optional[str] = None,
|
||||
spotify_client_secret: Optional[str] = None,
|
||||
|
|
@ -957,6 +989,8 @@ class NodePool:
|
|||
identifier=identifier,
|
||||
secure=secure,
|
||||
heartbeat=heartbeat,
|
||||
resume_key=resume_key,
|
||||
resume_timeout=resume_timeout,
|
||||
loop=loop,
|
||||
spotify_client_id=spotify_client_id,
|
||||
session=session,
|
||||
|
|
|
|||
Loading…
Reference in New Issue