diff --git a/pomice/__init__.py b/pomice/__init__.py index 9112ff2..02de6c5 100644 --- a/pomice/__init__.py +++ b/pomice/__init__.py @@ -11,7 +11,7 @@ if discord.__version__ != "2.0.0a": "using 'pip install git+https://github.com/Rapptz/discord.py@master'" ) -__version__ = "1.1.4" +__version__ = "1.1.5" __title__ = "pomice" __author__ = "cloudwithax" diff --git a/pomice/enums.py b/pomice/enums.py index f85232f..b7e3d22 100644 --- a/pomice/enums.py +++ b/pomice/enums.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import Enum, auto class SearchType(Enum): @@ -21,3 +21,25 @@ class SearchType(Enum): def __str__(self) -> str: return self.value + +class NodeAlgorithm(Enum): + """The enum for the different node algorithms in Pomice. + + The enums in this class are to only differentiate different + methods, since the actual method is handled in the + get_best_node() method. + + NodeAlgorithm.by_ping returns a node based on it's latency, + preferring a node with the lowest response time + + NodeAlgorithm.by_region returns a node based on its voice region, + which the region is specified by the user in the method as an arg. + This method will only work if you set a voice region when you create a node. + """ + + # We don't have to define anything special for these, since these just serve as flags + by_ping = auto() + by_region = auto() + + def __str__(self) -> str: + return self.value \ No newline at end of file diff --git a/pomice/events.py b/pomice/events.py index 6aeb45d..1cd7641 100644 --- a/pomice/events.py +++ b/pomice/events.py @@ -24,8 +24,8 @@ class TrackStartEvent(PomiceEvent): """ name = "track_start" - def __init__(self, data: dict): - self.player = NodePool.get_node().get_player(int(data["guildId"])) + def __init__(self, data: dict, player): + self.player = player self.track = self.player._current # on_pomice_track_start(player, track) @@ -41,8 +41,8 @@ class TrackEndEvent(PomiceEvent): """ name = "track_end" - def __init__(self, data: dict): - self.player = NodePool.get_node().get_player(int(data["guildId"])) + def __init__(self, data: dict, player): + self.player = player self.track = self.player._ending_track self.reason: str = data["reason"] @@ -63,8 +63,8 @@ class TrackStuckEvent(PomiceEvent): """ name = "track_stuck" - def __init__(self, data: dict): - self.player = NodePool.get_node().get_player(int(data["guildId"])) + def __init__(self, data: dict, player): + self.player = player self.track = self.player._ending_track self.threshold: float = data["thresholdMs"] @@ -82,8 +82,8 @@ class TrackExceptionEvent(PomiceEvent): """ name = "track_exception" - def __init__(self, data: dict): - self.player = NodePool.get_node().get_player(int(data["guildId"])) + def __init__(self, data: dict, player): + self.player = player self.track = self.player._ending_track if data.get('error'): # User is running Lavalink <= 3.3 @@ -117,7 +117,7 @@ class WebSocketClosedEvent(PomiceEvent): """ name = "websocket_closed" - def __init__(self, data: dict): + def __init__(self, data: dict, _): self.payload = WebSocketClosedPayload(data) # on_pomice_websocket_closed(payload) @@ -133,7 +133,7 @@ class WebSocketOpenEvent(PomiceEvent): """ name = "websocket_open" - def __init__(self, data: dict): + def __init__(self, data: dict, _): self.target: str = data["target"] self.ssrc: int = data["ssrc"] diff --git a/pomice/player.py b/pomice/player.py index 00d385b..850cc4b 100644 --- a/pomice/player.py +++ b/pomice/player.py @@ -36,13 +36,19 @@ class Player(VoiceProtocol): return self - def __init__(self, client: ClientType = None, channel: VoiceChannel = None, **kwargs): + def __init__( + self, + client: ClientType = None, + channel: VoiceChannel = None, + *, + node: Node = None + ): self.client = client self._bot = client self.channel = channel self._guild: Guild = self.channel.guild - self._node = NodePool.get_node() + self._node = node if node else NodePool.get_node() self._current: Track = None self._filter: Filter = None self._volume = 100 @@ -55,7 +61,6 @@ class Player(VoiceProtocol): self._ending_track: Optional[Track] = None self._voice_state = {} - self._extra = kwargs or {} def __repr__(self): return ( @@ -171,7 +176,7 @@ class Player(VoiceProtocol): async def _dispatch_event(self, data: dict): event_type = data.get("type") - event: PomiceEvent = getattr(events, event_type)(data) + event: PomiceEvent = getattr(events, event_type)(data, self) if isinstance(event, TrackEndEvent) and event.reason != "REPLACED": self._current = None diff --git a/pomice/pool.py b/pomice/pool.py index c91cd83..9459422 100644 --- a/pomice/pool.py +++ b/pomice/pool.py @@ -4,12 +4,15 @@ import asyncio import json import random import re +import time import socket from typing import Dict, Optional, TYPE_CHECKING from urllib.parse import quote +from enum import Enum import aiohttp from discord.ext import commands +from discord import VoiceRegion from . import ( @@ -17,17 +20,18 @@ from . import ( spotify, ) -from .enums import SearchType +from .enums import SearchType, NodeAlgorithm from .exceptions import ( InvalidSpotifyClientAuthorization, NodeConnectionFailure, NodeCreationError, + NodeException, NodeNotAvailable, NoNodesAvailable, TrackLoadError ) from .objects import Playlist, Track -from .utils import ClientType, ExponentialBackoff, NodeStats +from .utils import ClientType, ExponentialBackoff, NodeStats, Ping if TYPE_CHECKING: from .player import Player @@ -46,6 +50,7 @@ URL_REGEX = re.compile( ) + class Node: """The base class for a node. This node object represents a Lavalink node. @@ -62,6 +67,8 @@ class Node: password: str, identifier: str, secure: bool = False, + heartbeat: int = 30, + region: Optional[VoiceRegion], session: Optional[aiohttp.ClientSession], spotify_client_id: Optional[str], spotify_client_secret: Optional[str], @@ -73,7 +80,9 @@ class Node: self._pool = pool self._password = password self._identifier = identifier + self._heartbeat = heartbeat self._secure = secure + self._region: Optional[VoiceRegion] = region self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}" @@ -127,6 +136,11 @@ class Node: """Property which returns a dict containing the guild ID and the player object.""" return self._players + @property + def region(self) -> Optional[VoiceRegion]: + """Property which returns the VoiceRegion of the node, if one is set""" + return self._region + @property def bot(self) -> ClientType: """Property which returns the discord.py client linked to this node""" @@ -142,6 +156,11 @@ class Node: """Property which returns the pool this node is apart of""" return self._pool + @property + def latency(self): + """Property which returns the latency of the node""" + return Ping(self._host, port=self._port).get_ping() + async def _update_handler(self, data: dict): await self._bot.wait_until_ready() @@ -216,7 +235,7 @@ class Node: try: self._websocket = await self._session.ws_connect( - self._websocket_uri, headers=self._headers, heartbeat=30 + self._websocket_uri, headers=self._headers, heartbeat=self._heartbeat ) self._task = self._bot.loop.create_task(self._listen()) self._available = True @@ -414,6 +433,8 @@ class Node: ] + + class NodePool: """The base class for the node pool. This holds all the nodes that are to be used by the bot. @@ -433,6 +454,41 @@ class NodePool: def node_count(self): return len(self._nodes.values()) + @classmethod + def get_best_node(cls, *, algorithm: NodeAlgorithm, voice_region: VoiceRegion = None) -> Node: + """Fetches the best node based on an NodeAlgorithm. + This option is preferred if you want to choose the best node + from a multi-node setup using either the node's latency + or the node's voice region. + + Use NodeAlgorithm.by_ping if you want to get the best node + based on the node's latency. + + Use NodeAlgorithm.by_region if you want to get the best node + based on the node's voice region. This method will only work + if you set a voice region when you create a node. + """ + available_nodes = (node for node in cls._nodes.values() if node._available) + + if not available_nodes: + raise NoNodesAvailable("There are no nodes available.") + + if algorithm == NodeAlgorithm.by_ping: + tested_nodes = {node: node.latency for node in available_nodes} + return min(tested_nodes, key=tested_nodes.get) + + else: + if voice_region == None: + raise NodeException("You must specify a VoiceRegion in order to use this functionality.") + + nodes = [node for node in available_nodes if node._region is voice_region] + if not nodes: + raise NoNodesAvailable( + f"No nodes for region {voice_region} exist in this pool." + ) + + return nodes[0] + @classmethod def get_node(cls, *, identifier: str = None) -> Node: """Fetches a node from the node pool using it's identifier. @@ -461,6 +517,8 @@ class NodePool: password: str, identifier: str, secure: bool = False, + heartbeat: int = 30, + region: Optional[VoiceRegion] = None, spotify_client_id: Optional[str], spotify_client_secret: Optional[str], session: Optional[aiohttp.ClientSession] = None, @@ -474,7 +532,8 @@ class NodePool: node = Node( pool=cls, bot=bot, host=host, port=port, password=password, - identifier=identifier, secure=secure, spotify_client_id=spotify_client_id, + identifier=identifier, secure=secure, heartbeat=heartbeat, + region=region, spotify_client_id=spotify_client_id, session=session, spotify_client_secret=spotify_client_secret ) diff --git a/pomice/utils.py b/pomice/utils.py index 096844a..a26b721 100644 --- a/pomice/utils.py +++ b/pomice/utils.py @@ -1,6 +1,9 @@ import random import time +import socket from typing import Union +from timeit import default_timer as timer +from itertools import zip_longest from discord import AutoShardedClient, Client from discord.ext.commands import AutoShardedBot, Bot @@ -86,3 +89,68 @@ class NodeStats: def __repr__(self) -> str: return f"" + + +class Ping: + # Thanks to https://github.com/zhengxiaowai/tcping for the nice ping impl + def __init__(self, host, port, timeout=5): + self.timer = self.Timer() + + self._successed = 0 + self._failed = 0 + self._conn_time = None + self._host = host + self._port = port + self._timeout = timeout + + class Socket(object): + def __init__(self, family, type_, timeout): + s = socket.socket(family, type_) + s.settimeout(timeout) + self._s = s + + def connect(self, host, port): + self._s.connect((host, int(port))) + + def shutdown(self): + self._s.shutdown(socket.SHUT_RD) + + def close(self): + self._s.close() + + + class Timer(object): + def __init__(self): + self._start = 0 + self._stop = 0 + + def start(self): + self._start = timer() + + def stop(self): + self._stop = timer() + + def cost(self, funcs, args): + self.start() + for func, arg in zip_longest(funcs, args): + if arg: + func(*arg) + else: + func() + + self.stop() + return self._stop - self._start + + def _create_socket(self, family, type_): + return self.Socket(family, type_, self._timeout) + + def get_ping(self): + s = self._create_socket(socket.AF_INET, socket.SOCK_STREAM) + + cost_time = self.timer.cost( + (s.connect, s.shutdown), + ((self._host, self._port), None)) + s_runtime = 1000 * (cost_time) + + return s_runtime + diff --git a/setup.py b/setup.py index 66a9e80..0b32f5a 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ with open("README.md") as f: setuptools.setup( name="pomice", author="cloudwithax", - version="1.1.4", + version="1.1.5", url="https://github.com/cloudwithax/pomice", packages=setuptools.find_packages(), license="GPL",