added node algos, cleaned up events.py, add node arg to Player()

This commit is contained in:
cloudwithax 2021-11-23 13:59:20 -05:00
parent c3135b7798
commit d249d84f9f
7 changed files with 175 additions and 21 deletions

View File

@ -11,7 +11,7 @@ if discord.__version__ != "2.0.0a":
"using 'pip install git+https://github.com/Rapptz/discord.py@master'"
)
__version__ = "1.1.4"
__version__ = "1.1.5"
__title__ = "pomice"
__author__ = "cloudwithax"

View File

@ -1,4 +1,4 @@
from enum import Enum
from enum import Enum, auto
class SearchType(Enum):
@ -21,3 +21,25 @@ class SearchType(Enum):
def __str__(self) -> str:
return self.value
class NodeAlgorithm(Enum):
"""The enum for the different node algorithms in Pomice.
The enums in this class are to only differentiate different
methods, since the actual method is handled in the
get_best_node() method.
NodeAlgorithm.by_ping returns a node based on it's latency,
preferring a node with the lowest response time
NodeAlgorithm.by_region returns a node based on its voice region,
which the region is specified by the user in the method as an arg.
This method will only work if you set a voice region when you create a node.
"""
# We don't have to define anything special for these, since these just serve as flags
by_ping = auto()
by_region = auto()
def __str__(self) -> str:
return self.value

View File

@ -24,8 +24,8 @@ class TrackStartEvent(PomiceEvent):
"""
name = "track_start"
def __init__(self, data: dict):
self.player = NodePool.get_node().get_player(int(data["guildId"]))
def __init__(self, data: dict, player):
self.player = player
self.track = self.player._current
# on_pomice_track_start(player, track)
@ -41,8 +41,8 @@ class TrackEndEvent(PomiceEvent):
"""
name = "track_end"
def __init__(self, data: dict):
self.player = NodePool.get_node().get_player(int(data["guildId"]))
def __init__(self, data: dict, player):
self.player = player
self.track = self.player._ending_track
self.reason: str = data["reason"]
@ -63,8 +63,8 @@ class TrackStuckEvent(PomiceEvent):
"""
name = "track_stuck"
def __init__(self, data: dict):
self.player = NodePool.get_node().get_player(int(data["guildId"]))
def __init__(self, data: dict, player):
self.player = player
self.track = self.player._ending_track
self.threshold: float = data["thresholdMs"]
@ -82,8 +82,8 @@ class TrackExceptionEvent(PomiceEvent):
"""
name = "track_exception"
def __init__(self, data: dict):
self.player = NodePool.get_node().get_player(int(data["guildId"]))
def __init__(self, data: dict, player):
self.player = player
self.track = self.player._ending_track
if data.get('error'):
# User is running Lavalink <= 3.3
@ -117,7 +117,7 @@ class WebSocketClosedEvent(PomiceEvent):
"""
name = "websocket_closed"
def __init__(self, data: dict):
def __init__(self, data: dict, _):
self.payload = WebSocketClosedPayload(data)
# on_pomice_websocket_closed(payload)
@ -133,7 +133,7 @@ class WebSocketOpenEvent(PomiceEvent):
"""
name = "websocket_open"
def __init__(self, data: dict):
def __init__(self, data: dict, _):
self.target: str = data["target"]
self.ssrc: int = data["ssrc"]

View File

@ -36,13 +36,19 @@ class Player(VoiceProtocol):
return self
def __init__(self, client: ClientType = None, channel: VoiceChannel = None, **kwargs):
def __init__(
self,
client: ClientType = None,
channel: VoiceChannel = None,
*,
node: Node = None
):
self.client = client
self._bot = client
self.channel = channel
self._guild: Guild = self.channel.guild
self._node = NodePool.get_node()
self._node = node if node else NodePool.get_node()
self._current: Track = None
self._filter: Filter = None
self._volume = 100
@ -55,7 +61,6 @@ class Player(VoiceProtocol):
self._ending_track: Optional[Track] = None
self._voice_state = {}
self._extra = kwargs or {}
def __repr__(self):
return (
@ -171,7 +176,7 @@ class Player(VoiceProtocol):
async def _dispatch_event(self, data: dict):
event_type = data.get("type")
event: PomiceEvent = getattr(events, event_type)(data)
event: PomiceEvent = getattr(events, event_type)(data, self)
if isinstance(event, TrackEndEvent) and event.reason != "REPLACED":
self._current = None

View File

@ -4,12 +4,15 @@ import asyncio
import json
import random
import re
import time
import socket
from typing import Dict, Optional, TYPE_CHECKING
from urllib.parse import quote
from enum import Enum
import aiohttp
from discord.ext import commands
from discord import VoiceRegion
from . import (
@ -17,17 +20,18 @@ from . import (
spotify,
)
from .enums import SearchType
from .enums import SearchType, NodeAlgorithm
from .exceptions import (
InvalidSpotifyClientAuthorization,
NodeConnectionFailure,
NodeCreationError,
NodeException,
NodeNotAvailable,
NoNodesAvailable,
TrackLoadError
)
from .objects import Playlist, Track
from .utils import ClientType, ExponentialBackoff, NodeStats
from .utils import ClientType, ExponentialBackoff, NodeStats, Ping
if TYPE_CHECKING:
from .player import Player
@ -46,6 +50,7 @@ URL_REGEX = re.compile(
)
class Node:
"""The base class for a node.
This node object represents a Lavalink node.
@ -62,6 +67,8 @@ class Node:
password: str,
identifier: str,
secure: bool = False,
heartbeat: int = 30,
region: Optional[VoiceRegion],
session: Optional[aiohttp.ClientSession],
spotify_client_id: Optional[str],
spotify_client_secret: Optional[str],
@ -73,7 +80,9 @@ class Node:
self._pool = pool
self._password = password
self._identifier = identifier
self._heartbeat = heartbeat
self._secure = secure
self._region: Optional[VoiceRegion] = region
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
@ -127,6 +136,11 @@ class Node:
"""Property which returns a dict containing the guild ID and the player object."""
return self._players
@property
def region(self) -> Optional[VoiceRegion]:
"""Property which returns the VoiceRegion of the node, if one is set"""
return self._region
@property
def bot(self) -> ClientType:
"""Property which returns the discord.py client linked to this node"""
@ -142,6 +156,11 @@ class Node:
"""Property which returns the pool this node is apart of"""
return self._pool
@property
def latency(self):
"""Property which returns the latency of the node"""
return Ping(self._host, port=self._port).get_ping()
async def _update_handler(self, data: dict):
await self._bot.wait_until_ready()
@ -216,7 +235,7 @@ class Node:
try:
self._websocket = await self._session.ws_connect(
self._websocket_uri, headers=self._headers, heartbeat=30
self._websocket_uri, headers=self._headers, heartbeat=self._heartbeat
)
self._task = self._bot.loop.create_task(self._listen())
self._available = True
@ -414,6 +433,8 @@ class Node:
]
class NodePool:
"""The base class for the node pool.
This holds all the nodes that are to be used by the bot.
@ -433,6 +454,41 @@ class NodePool:
def node_count(self):
return len(self._nodes.values())
@classmethod
def get_best_node(cls, *, algorithm: NodeAlgorithm, voice_region: VoiceRegion = None) -> Node:
"""Fetches the best node based on an NodeAlgorithm.
This option is preferred if you want to choose the best node
from a multi-node setup using either the node's latency
or the node's voice region.
Use NodeAlgorithm.by_ping if you want to get the best node
based on the node's latency.
Use NodeAlgorithm.by_region if you want to get the best node
based on the node's voice region. This method will only work
if you set a voice region when you create a node.
"""
available_nodes = (node for node in cls._nodes.values() if node._available)
if not available_nodes:
raise NoNodesAvailable("There are no nodes available.")
if algorithm == NodeAlgorithm.by_ping:
tested_nodes = {node: node.latency for node in available_nodes}
return min(tested_nodes, key=tested_nodes.get)
else:
if voice_region == None:
raise NodeException("You must specify a VoiceRegion in order to use this functionality.")
nodes = [node for node in available_nodes if node._region is voice_region]
if not nodes:
raise NoNodesAvailable(
f"No nodes for region {voice_region} exist in this pool."
)
return nodes[0]
@classmethod
def get_node(cls, *, identifier: str = None) -> Node:
"""Fetches a node from the node pool using it's identifier.
@ -461,6 +517,8 @@ class NodePool:
password: str,
identifier: str,
secure: bool = False,
heartbeat: int = 30,
region: Optional[VoiceRegion] = None,
spotify_client_id: Optional[str],
spotify_client_secret: Optional[str],
session: Optional[aiohttp.ClientSession] = None,
@ -474,7 +532,8 @@ class NodePool:
node = Node(
pool=cls, bot=bot, host=host, port=port, password=password,
identifier=identifier, secure=secure, spotify_client_id=spotify_client_id,
identifier=identifier, secure=secure, heartbeat=heartbeat,
region=region, spotify_client_id=spotify_client_id,
session=session, spotify_client_secret=spotify_client_secret
)

View File

@ -1,6 +1,9 @@
import random
import time
import socket
from typing import Union
from timeit import default_timer as timer
from itertools import zip_longest
from discord import AutoShardedClient, Client
from discord.ext.commands import AutoShardedBot, Bot
@ -86,3 +89,68 @@ class NodeStats:
def __repr__(self) -> str:
return f"<Pomice.NodeStats total_players={self.players_total!r} playing_active={self.players_active!r}>"
class Ping:
# Thanks to https://github.com/zhengxiaowai/tcping for the nice ping impl
def __init__(self, host, port, timeout=5):
self.timer = self.Timer()
self._successed = 0
self._failed = 0
self._conn_time = None
self._host = host
self._port = port
self._timeout = timeout
class Socket(object):
def __init__(self, family, type_, timeout):
s = socket.socket(family, type_)
s.settimeout(timeout)
self._s = s
def connect(self, host, port):
self._s.connect((host, int(port)))
def shutdown(self):
self._s.shutdown(socket.SHUT_RD)
def close(self):
self._s.close()
class Timer(object):
def __init__(self):
self._start = 0
self._stop = 0
def start(self):
self._start = timer()
def stop(self):
self._stop = timer()
def cost(self, funcs, args):
self.start()
for func, arg in zip_longest(funcs, args):
if arg:
func(*arg)
else:
func()
self.stop()
return self._stop - self._start
def _create_socket(self, family, type_):
return self.Socket(family, type_, self._timeout)
def get_ping(self):
s = self._create_socket(socket.AF_INET, socket.SOCK_STREAM)
cost_time = self.timer.cost(
(s.connect, s.shutdown),
((self._host, self._port), None))
s_runtime = 1000 * (cost_time)
return s_runtime

View File

@ -6,7 +6,7 @@ with open("README.md") as f:
setuptools.setup(
name="pomice",
author="cloudwithax",
version="1.1.4",
version="1.1.5",
url="https://github.com/cloudwithax/pomice",
packages=setuptools.find_packages(),
license="GPL",