added node algos, cleaned up events.py, add node arg to Player()
This commit is contained in:
parent
c3135b7798
commit
d249d84f9f
|
|
@ -11,7 +11,7 @@ if discord.__version__ != "2.0.0a":
|
|||
"using 'pip install git+https://github.com/Rapptz/discord.py@master'"
|
||||
)
|
||||
|
||||
__version__ = "1.1.4"
|
||||
__version__ = "1.1.5"
|
||||
__title__ = "pomice"
|
||||
__author__ = "cloudwithax"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from enum import Enum
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
class SearchType(Enum):
|
||||
|
|
@ -21,3 +21,25 @@ class SearchType(Enum):
|
|||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
class NodeAlgorithm(Enum):
|
||||
"""The enum for the different node algorithms in Pomice.
|
||||
|
||||
The enums in this class are to only differentiate different
|
||||
methods, since the actual method is handled in the
|
||||
get_best_node() method.
|
||||
|
||||
NodeAlgorithm.by_ping returns a node based on it's latency,
|
||||
preferring a node with the lowest response time
|
||||
|
||||
NodeAlgorithm.by_region returns a node based on its voice region,
|
||||
which the region is specified by the user in the method as an arg.
|
||||
This method will only work if you set a voice region when you create a node.
|
||||
"""
|
||||
|
||||
# We don't have to define anything special for these, since these just serve as flags
|
||||
by_ping = auto()
|
||||
by_region = auto()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
|
@ -24,8 +24,8 @@ class TrackStartEvent(PomiceEvent):
|
|||
"""
|
||||
name = "track_start"
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self.player = NodePool.get_node().get_player(int(data["guildId"]))
|
||||
def __init__(self, data: dict, player):
|
||||
self.player = player
|
||||
self.track = self.player._current
|
||||
|
||||
# on_pomice_track_start(player, track)
|
||||
|
|
@ -41,8 +41,8 @@ class TrackEndEvent(PomiceEvent):
|
|||
"""
|
||||
name = "track_end"
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self.player = NodePool.get_node().get_player(int(data["guildId"]))
|
||||
def __init__(self, data: dict, player):
|
||||
self.player = player
|
||||
self.track = self.player._ending_track
|
||||
self.reason: str = data["reason"]
|
||||
|
||||
|
|
@ -63,8 +63,8 @@ class TrackStuckEvent(PomiceEvent):
|
|||
"""
|
||||
name = "track_stuck"
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self.player = NodePool.get_node().get_player(int(data["guildId"]))
|
||||
def __init__(self, data: dict, player):
|
||||
self.player = player
|
||||
self.track = self.player._ending_track
|
||||
self.threshold: float = data["thresholdMs"]
|
||||
|
||||
|
|
@ -82,8 +82,8 @@ class TrackExceptionEvent(PomiceEvent):
|
|||
"""
|
||||
name = "track_exception"
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self.player = NodePool.get_node().get_player(int(data["guildId"]))
|
||||
def __init__(self, data: dict, player):
|
||||
self.player = player
|
||||
self.track = self.player._ending_track
|
||||
if data.get('error'):
|
||||
# User is running Lavalink <= 3.3
|
||||
|
|
@ -117,7 +117,7 @@ class WebSocketClosedEvent(PomiceEvent):
|
|||
"""
|
||||
name = "websocket_closed"
|
||||
|
||||
def __init__(self, data: dict):
|
||||
def __init__(self, data: dict, _):
|
||||
self.payload = WebSocketClosedPayload(data)
|
||||
|
||||
# on_pomice_websocket_closed(payload)
|
||||
|
|
@ -133,7 +133,7 @@ class WebSocketOpenEvent(PomiceEvent):
|
|||
"""
|
||||
name = "websocket_open"
|
||||
|
||||
def __init__(self, data: dict):
|
||||
def __init__(self, data: dict, _):
|
||||
self.target: str = data["target"]
|
||||
self.ssrc: int = data["ssrc"]
|
||||
|
||||
|
|
|
|||
|
|
@ -36,13 +36,19 @@ class Player(VoiceProtocol):
|
|||
|
||||
return self
|
||||
|
||||
def __init__(self, client: ClientType = None, channel: VoiceChannel = None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
client: ClientType = None,
|
||||
channel: VoiceChannel = None,
|
||||
*,
|
||||
node: Node = None
|
||||
):
|
||||
self.client = client
|
||||
self._bot = client
|
||||
self.channel = channel
|
||||
self._guild: Guild = self.channel.guild
|
||||
|
||||
self._node = NodePool.get_node()
|
||||
self._node = node if node else NodePool.get_node()
|
||||
self._current: Track = None
|
||||
self._filter: Filter = None
|
||||
self._volume = 100
|
||||
|
|
@ -55,7 +61,6 @@ class Player(VoiceProtocol):
|
|||
self._ending_track: Optional[Track] = None
|
||||
|
||||
self._voice_state = {}
|
||||
self._extra = kwargs or {}
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
|
|
@ -171,7 +176,7 @@ class Player(VoiceProtocol):
|
|||
|
||||
async def _dispatch_event(self, data: dict):
|
||||
event_type = data.get("type")
|
||||
event: PomiceEvent = getattr(events, event_type)(data)
|
||||
event: PomiceEvent = getattr(events, event_type)(data, self)
|
||||
|
||||
if isinstance(event, TrackEndEvent) and event.reason != "REPLACED":
|
||||
self._current = None
|
||||
|
|
|
|||
|
|
@ -4,12 +4,15 @@ import asyncio
|
|||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import socket
|
||||
from typing import Dict, Optional, TYPE_CHECKING
|
||||
from urllib.parse import quote
|
||||
from enum import Enum
|
||||
|
||||
import aiohttp
|
||||
from discord.ext import commands
|
||||
from discord import VoiceRegion
|
||||
|
||||
|
||||
from . import (
|
||||
|
|
@ -17,17 +20,18 @@ from . import (
|
|||
spotify,
|
||||
)
|
||||
|
||||
from .enums import SearchType
|
||||
from .enums import SearchType, NodeAlgorithm
|
||||
from .exceptions import (
|
||||
InvalidSpotifyClientAuthorization,
|
||||
NodeConnectionFailure,
|
||||
NodeCreationError,
|
||||
NodeException,
|
||||
NodeNotAvailable,
|
||||
NoNodesAvailable,
|
||||
TrackLoadError
|
||||
)
|
||||
from .objects import Playlist, Track
|
||||
from .utils import ClientType, ExponentialBackoff, NodeStats
|
||||
from .utils import ClientType, ExponentialBackoff, NodeStats, Ping
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .player import Player
|
||||
|
|
@ -46,6 +50,7 @@ URL_REGEX = re.compile(
|
|||
)
|
||||
|
||||
|
||||
|
||||
class Node:
|
||||
"""The base class for a node.
|
||||
This node object represents a Lavalink node.
|
||||
|
|
@ -62,6 +67,8 @@ class Node:
|
|||
password: str,
|
||||
identifier: str,
|
||||
secure: bool = False,
|
||||
heartbeat: int = 30,
|
||||
region: Optional[VoiceRegion],
|
||||
session: Optional[aiohttp.ClientSession],
|
||||
spotify_client_id: Optional[str],
|
||||
spotify_client_secret: Optional[str],
|
||||
|
|
@ -73,7 +80,9 @@ class Node:
|
|||
self._pool = pool
|
||||
self._password = password
|
||||
self._identifier = identifier
|
||||
self._heartbeat = heartbeat
|
||||
self._secure = secure
|
||||
self._region: Optional[VoiceRegion] = region
|
||||
|
||||
|
||||
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
|
||||
|
|
@ -127,6 +136,11 @@ class Node:
|
|||
"""Property which returns a dict containing the guild ID and the player object."""
|
||||
return self._players
|
||||
|
||||
@property
|
||||
def region(self) -> Optional[VoiceRegion]:
|
||||
"""Property which returns the VoiceRegion of the node, if one is set"""
|
||||
return self._region
|
||||
|
||||
@property
|
||||
def bot(self) -> ClientType:
|
||||
"""Property which returns the discord.py client linked to this node"""
|
||||
|
|
@ -142,6 +156,11 @@ class Node:
|
|||
"""Property which returns the pool this node is apart of"""
|
||||
return self._pool
|
||||
|
||||
@property
|
||||
def latency(self):
|
||||
"""Property which returns the latency of the node"""
|
||||
return Ping(self._host, port=self._port).get_ping()
|
||||
|
||||
async def _update_handler(self, data: dict):
|
||||
await self._bot.wait_until_ready()
|
||||
|
||||
|
|
@ -216,7 +235,7 @@ class Node:
|
|||
|
||||
try:
|
||||
self._websocket = await self._session.ws_connect(
|
||||
self._websocket_uri, headers=self._headers, heartbeat=30
|
||||
self._websocket_uri, headers=self._headers, heartbeat=self._heartbeat
|
||||
)
|
||||
self._task = self._bot.loop.create_task(self._listen())
|
||||
self._available = True
|
||||
|
|
@ -414,6 +433,8 @@ class Node:
|
|||
]
|
||||
|
||||
|
||||
|
||||
|
||||
class NodePool:
|
||||
"""The base class for the node pool.
|
||||
This holds all the nodes that are to be used by the bot.
|
||||
|
|
@ -433,6 +454,41 @@ class NodePool:
|
|||
def node_count(self):
|
||||
return len(self._nodes.values())
|
||||
|
||||
@classmethod
|
||||
def get_best_node(cls, *, algorithm: NodeAlgorithm, voice_region: VoiceRegion = None) -> Node:
|
||||
"""Fetches the best node based on an NodeAlgorithm.
|
||||
This option is preferred if you want to choose the best node
|
||||
from a multi-node setup using either the node's latency
|
||||
or the node's voice region.
|
||||
|
||||
Use NodeAlgorithm.by_ping if you want to get the best node
|
||||
based on the node's latency.
|
||||
|
||||
Use NodeAlgorithm.by_region if you want to get the best node
|
||||
based on the node's voice region. This method will only work
|
||||
if you set a voice region when you create a node.
|
||||
"""
|
||||
available_nodes = (node for node in cls._nodes.values() if node._available)
|
||||
|
||||
if not available_nodes:
|
||||
raise NoNodesAvailable("There are no nodes available.")
|
||||
|
||||
if algorithm == NodeAlgorithm.by_ping:
|
||||
tested_nodes = {node: node.latency for node in available_nodes}
|
||||
return min(tested_nodes, key=tested_nodes.get)
|
||||
|
||||
else:
|
||||
if voice_region == None:
|
||||
raise NodeException("You must specify a VoiceRegion in order to use this functionality.")
|
||||
|
||||
nodes = [node for node in available_nodes if node._region is voice_region]
|
||||
if not nodes:
|
||||
raise NoNodesAvailable(
|
||||
f"No nodes for region {voice_region} exist in this pool."
|
||||
)
|
||||
|
||||
return nodes[0]
|
||||
|
||||
@classmethod
|
||||
def get_node(cls, *, identifier: str = None) -> Node:
|
||||
"""Fetches a node from the node pool using it's identifier.
|
||||
|
|
@ -461,6 +517,8 @@ class NodePool:
|
|||
password: str,
|
||||
identifier: str,
|
||||
secure: bool = False,
|
||||
heartbeat: int = 30,
|
||||
region: Optional[VoiceRegion] = None,
|
||||
spotify_client_id: Optional[str],
|
||||
spotify_client_secret: Optional[str],
|
||||
session: Optional[aiohttp.ClientSession] = None,
|
||||
|
|
@ -474,7 +532,8 @@ class NodePool:
|
|||
|
||||
node = Node(
|
||||
pool=cls, bot=bot, host=host, port=port, password=password,
|
||||
identifier=identifier, secure=secure, spotify_client_id=spotify_client_id,
|
||||
identifier=identifier, secure=secure, heartbeat=heartbeat,
|
||||
region=region, spotify_client_id=spotify_client_id,
|
||||
session=session, spotify_client_secret=spotify_client_secret
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
import random
|
||||
import time
|
||||
import socket
|
||||
from typing import Union
|
||||
from timeit import default_timer as timer
|
||||
from itertools import zip_longest
|
||||
|
||||
from discord import AutoShardedClient, Client
|
||||
from discord.ext.commands import AutoShardedBot, Bot
|
||||
|
|
@ -86,3 +89,68 @@ class NodeStats:
|
|||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.NodeStats total_players={self.players_total!r} playing_active={self.players_active!r}>"
|
||||
|
||||
|
||||
class Ping:
|
||||
# Thanks to https://github.com/zhengxiaowai/tcping for the nice ping impl
|
||||
def __init__(self, host, port, timeout=5):
|
||||
self.timer = self.Timer()
|
||||
|
||||
self._successed = 0
|
||||
self._failed = 0
|
||||
self._conn_time = None
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._timeout = timeout
|
||||
|
||||
class Socket(object):
|
||||
def __init__(self, family, type_, timeout):
|
||||
s = socket.socket(family, type_)
|
||||
s.settimeout(timeout)
|
||||
self._s = s
|
||||
|
||||
def connect(self, host, port):
|
||||
self._s.connect((host, int(port)))
|
||||
|
||||
def shutdown(self):
|
||||
self._s.shutdown(socket.SHUT_RD)
|
||||
|
||||
def close(self):
|
||||
self._s.close()
|
||||
|
||||
|
||||
class Timer(object):
|
||||
def __init__(self):
|
||||
self._start = 0
|
||||
self._stop = 0
|
||||
|
||||
def start(self):
|
||||
self._start = timer()
|
||||
|
||||
def stop(self):
|
||||
self._stop = timer()
|
||||
|
||||
def cost(self, funcs, args):
|
||||
self.start()
|
||||
for func, arg in zip_longest(funcs, args):
|
||||
if arg:
|
||||
func(*arg)
|
||||
else:
|
||||
func()
|
||||
|
||||
self.stop()
|
||||
return self._stop - self._start
|
||||
|
||||
def _create_socket(self, family, type_):
|
||||
return self.Socket(family, type_, self._timeout)
|
||||
|
||||
def get_ping(self):
|
||||
s = self._create_socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
|
||||
cost_time = self.timer.cost(
|
||||
(s.connect, s.shutdown),
|
||||
((self._host, self._port), None))
|
||||
s_runtime = 1000 * (cost_time)
|
||||
|
||||
return s_runtime
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue