added node algos, cleaned up events.py, add node arg to Player()

This commit is contained in:
cloudwithax 2021-11-23 13:59:20 -05:00
parent c3135b7798
commit d249d84f9f
7 changed files with 175 additions and 21 deletions

View File

@ -11,7 +11,7 @@ if discord.__version__ != "2.0.0a":
"using 'pip install git+https://github.com/Rapptz/discord.py@master'" "using 'pip install git+https://github.com/Rapptz/discord.py@master'"
) )
__version__ = "1.1.4" __version__ = "1.1.5"
__title__ = "pomice" __title__ = "pomice"
__author__ = "cloudwithax" __author__ = "cloudwithax"

View File

@ -1,4 +1,4 @@
from enum import Enum from enum import Enum, auto
class SearchType(Enum): class SearchType(Enum):
@ -21,3 +21,25 @@ class SearchType(Enum):
def __str__(self) -> str: def __str__(self) -> str:
return self.value return self.value
class NodeAlgorithm(Enum):
"""The enum for the different node algorithms in Pomice.
The enums in this class are to only differentiate different
methods, since the actual method is handled in the
get_best_node() method.
NodeAlgorithm.by_ping returns a node based on it's latency,
preferring a node with the lowest response time
NodeAlgorithm.by_region returns a node based on its voice region,
which the region is specified by the user in the method as an arg.
This method will only work if you set a voice region when you create a node.
"""
# We don't have to define anything special for these, since these just serve as flags
by_ping = auto()
by_region = auto()
def __str__(self) -> str:
return self.value

View File

@ -24,8 +24,8 @@ class TrackStartEvent(PomiceEvent):
""" """
name = "track_start" name = "track_start"
def __init__(self, data: dict): def __init__(self, data: dict, player):
self.player = NodePool.get_node().get_player(int(data["guildId"])) self.player = player
self.track = self.player._current self.track = self.player._current
# on_pomice_track_start(player, track) # on_pomice_track_start(player, track)
@ -41,8 +41,8 @@ class TrackEndEvent(PomiceEvent):
""" """
name = "track_end" name = "track_end"
def __init__(self, data: dict): def __init__(self, data: dict, player):
self.player = NodePool.get_node().get_player(int(data["guildId"])) self.player = player
self.track = self.player._ending_track self.track = self.player._ending_track
self.reason: str = data["reason"] self.reason: str = data["reason"]
@ -63,8 +63,8 @@ class TrackStuckEvent(PomiceEvent):
""" """
name = "track_stuck" name = "track_stuck"
def __init__(self, data: dict): def __init__(self, data: dict, player):
self.player = NodePool.get_node().get_player(int(data["guildId"])) self.player = player
self.track = self.player._ending_track self.track = self.player._ending_track
self.threshold: float = data["thresholdMs"] self.threshold: float = data["thresholdMs"]
@ -82,8 +82,8 @@ class TrackExceptionEvent(PomiceEvent):
""" """
name = "track_exception" name = "track_exception"
def __init__(self, data: dict): def __init__(self, data: dict, player):
self.player = NodePool.get_node().get_player(int(data["guildId"])) self.player = player
self.track = self.player._ending_track self.track = self.player._ending_track
if data.get('error'): if data.get('error'):
# User is running Lavalink <= 3.3 # User is running Lavalink <= 3.3
@ -117,7 +117,7 @@ class WebSocketClosedEvent(PomiceEvent):
""" """
name = "websocket_closed" name = "websocket_closed"
def __init__(self, data: dict): def __init__(self, data: dict, _):
self.payload = WebSocketClosedPayload(data) self.payload = WebSocketClosedPayload(data)
# on_pomice_websocket_closed(payload) # on_pomice_websocket_closed(payload)
@ -133,7 +133,7 @@ class WebSocketOpenEvent(PomiceEvent):
""" """
name = "websocket_open" name = "websocket_open"
def __init__(self, data: dict): def __init__(self, data: dict, _):
self.target: str = data["target"] self.target: str = data["target"]
self.ssrc: int = data["ssrc"] self.ssrc: int = data["ssrc"]

View File

@ -36,13 +36,19 @@ class Player(VoiceProtocol):
return self return self
def __init__(self, client: ClientType = None, channel: VoiceChannel = None, **kwargs): def __init__(
self,
client: ClientType = None,
channel: VoiceChannel = None,
*,
node: Node = None
):
self.client = client self.client = client
self._bot = client self._bot = client
self.channel = channel self.channel = channel
self._guild: Guild = self.channel.guild self._guild: Guild = self.channel.guild
self._node = NodePool.get_node() self._node = node if node else NodePool.get_node()
self._current: Track = None self._current: Track = None
self._filter: Filter = None self._filter: Filter = None
self._volume = 100 self._volume = 100
@ -55,7 +61,6 @@ class Player(VoiceProtocol):
self._ending_track: Optional[Track] = None self._ending_track: Optional[Track] = None
self._voice_state = {} self._voice_state = {}
self._extra = kwargs or {}
def __repr__(self): def __repr__(self):
return ( return (
@ -171,7 +176,7 @@ class Player(VoiceProtocol):
async def _dispatch_event(self, data: dict): async def _dispatch_event(self, data: dict):
event_type = data.get("type") event_type = data.get("type")
event: PomiceEvent = getattr(events, event_type)(data) event: PomiceEvent = getattr(events, event_type)(data, self)
if isinstance(event, TrackEndEvent) and event.reason != "REPLACED": if isinstance(event, TrackEndEvent) and event.reason != "REPLACED":
self._current = None self._current = None

View File

@ -4,12 +4,15 @@ import asyncio
import json import json
import random import random
import re import re
import time
import socket import socket
from typing import Dict, Optional, TYPE_CHECKING from typing import Dict, Optional, TYPE_CHECKING
from urllib.parse import quote from urllib.parse import quote
from enum import Enum
import aiohttp import aiohttp
from discord.ext import commands from discord.ext import commands
from discord import VoiceRegion
from . import ( from . import (
@ -17,17 +20,18 @@ from . import (
spotify, spotify,
) )
from .enums import SearchType from .enums import SearchType, NodeAlgorithm
from .exceptions import ( from .exceptions import (
InvalidSpotifyClientAuthorization, InvalidSpotifyClientAuthorization,
NodeConnectionFailure, NodeConnectionFailure,
NodeCreationError, NodeCreationError,
NodeException,
NodeNotAvailable, NodeNotAvailable,
NoNodesAvailable, NoNodesAvailable,
TrackLoadError TrackLoadError
) )
from .objects import Playlist, Track from .objects import Playlist, Track
from .utils import ClientType, ExponentialBackoff, NodeStats from .utils import ClientType, ExponentialBackoff, NodeStats, Ping
if TYPE_CHECKING: if TYPE_CHECKING:
from .player import Player from .player import Player
@ -46,6 +50,7 @@ URL_REGEX = re.compile(
) )
class Node: class Node:
"""The base class for a node. """The base class for a node.
This node object represents a Lavalink node. This node object represents a Lavalink node.
@ -62,6 +67,8 @@ class Node:
password: str, password: str,
identifier: str, identifier: str,
secure: bool = False, secure: bool = False,
heartbeat: int = 30,
region: Optional[VoiceRegion],
session: Optional[aiohttp.ClientSession], session: Optional[aiohttp.ClientSession],
spotify_client_id: Optional[str], spotify_client_id: Optional[str],
spotify_client_secret: Optional[str], spotify_client_secret: Optional[str],
@ -73,7 +80,9 @@ class Node:
self._pool = pool self._pool = pool
self._password = password self._password = password
self._identifier = identifier self._identifier = identifier
self._heartbeat = heartbeat
self._secure = secure self._secure = secure
self._region: Optional[VoiceRegion] = region
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}" self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
@ -127,6 +136,11 @@ class Node:
"""Property which returns a dict containing the guild ID and the player object.""" """Property which returns a dict containing the guild ID and the player object."""
return self._players return self._players
@property
def region(self) -> Optional[VoiceRegion]:
"""Property which returns the VoiceRegion of the node, if one is set"""
return self._region
@property @property
def bot(self) -> ClientType: def bot(self) -> ClientType:
"""Property which returns the discord.py client linked to this node""" """Property which returns the discord.py client linked to this node"""
@ -142,6 +156,11 @@ class Node:
"""Property which returns the pool this node is apart of""" """Property which returns the pool this node is apart of"""
return self._pool return self._pool
@property
def latency(self):
"""Property which returns the latency of the node"""
return Ping(self._host, port=self._port).get_ping()
async def _update_handler(self, data: dict): async def _update_handler(self, data: dict):
await self._bot.wait_until_ready() await self._bot.wait_until_ready()
@ -216,7 +235,7 @@ class Node:
try: try:
self._websocket = await self._session.ws_connect( self._websocket = await self._session.ws_connect(
self._websocket_uri, headers=self._headers, heartbeat=30 self._websocket_uri, headers=self._headers, heartbeat=self._heartbeat
) )
self._task = self._bot.loop.create_task(self._listen()) self._task = self._bot.loop.create_task(self._listen())
self._available = True self._available = True
@ -414,6 +433,8 @@ class Node:
] ]
class NodePool: class NodePool:
"""The base class for the node pool. """The base class for the node pool.
This holds all the nodes that are to be used by the bot. This holds all the nodes that are to be used by the bot.
@ -433,6 +454,41 @@ class NodePool:
def node_count(self): def node_count(self):
return len(self._nodes.values()) return len(self._nodes.values())
@classmethod
def get_best_node(cls, *, algorithm: NodeAlgorithm, voice_region: VoiceRegion = None) -> Node:
"""Fetches the best node based on an NodeAlgorithm.
This option is preferred if you want to choose the best node
from a multi-node setup using either the node's latency
or the node's voice region.
Use NodeAlgorithm.by_ping if you want to get the best node
based on the node's latency.
Use NodeAlgorithm.by_region if you want to get the best node
based on the node's voice region. This method will only work
if you set a voice region when you create a node.
"""
available_nodes = (node for node in cls._nodes.values() if node._available)
if not available_nodes:
raise NoNodesAvailable("There are no nodes available.")
if algorithm == NodeAlgorithm.by_ping:
tested_nodes = {node: node.latency for node in available_nodes}
return min(tested_nodes, key=tested_nodes.get)
else:
if voice_region == None:
raise NodeException("You must specify a VoiceRegion in order to use this functionality.")
nodes = [node for node in available_nodes if node._region is voice_region]
if not nodes:
raise NoNodesAvailable(
f"No nodes for region {voice_region} exist in this pool."
)
return nodes[0]
@classmethod @classmethod
def get_node(cls, *, identifier: str = None) -> Node: def get_node(cls, *, identifier: str = None) -> Node:
"""Fetches a node from the node pool using it's identifier. """Fetches a node from the node pool using it's identifier.
@ -461,6 +517,8 @@ class NodePool:
password: str, password: str,
identifier: str, identifier: str,
secure: bool = False, secure: bool = False,
heartbeat: int = 30,
region: Optional[VoiceRegion] = None,
spotify_client_id: Optional[str], spotify_client_id: Optional[str],
spotify_client_secret: Optional[str], spotify_client_secret: Optional[str],
session: Optional[aiohttp.ClientSession] = None, session: Optional[aiohttp.ClientSession] = None,
@ -474,7 +532,8 @@ class NodePool:
node = Node( node = Node(
pool=cls, bot=bot, host=host, port=port, password=password, pool=cls, bot=bot, host=host, port=port, password=password,
identifier=identifier, secure=secure, spotify_client_id=spotify_client_id, identifier=identifier, secure=secure, heartbeat=heartbeat,
region=region, spotify_client_id=spotify_client_id,
session=session, spotify_client_secret=spotify_client_secret session=session, spotify_client_secret=spotify_client_secret
) )

View File

@ -1,6 +1,9 @@
import random import random
import time import time
import socket
from typing import Union from typing import Union
from timeit import default_timer as timer
from itertools import zip_longest
from discord import AutoShardedClient, Client from discord import AutoShardedClient, Client
from discord.ext.commands import AutoShardedBot, Bot from discord.ext.commands import AutoShardedBot, Bot
@ -86,3 +89,68 @@ class NodeStats:
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Pomice.NodeStats total_players={self.players_total!r} playing_active={self.players_active!r}>" return f"<Pomice.NodeStats total_players={self.players_total!r} playing_active={self.players_active!r}>"
class Ping:
# Thanks to https://github.com/zhengxiaowai/tcping for the nice ping impl
def __init__(self, host, port, timeout=5):
self.timer = self.Timer()
self._successed = 0
self._failed = 0
self._conn_time = None
self._host = host
self._port = port
self._timeout = timeout
class Socket(object):
def __init__(self, family, type_, timeout):
s = socket.socket(family, type_)
s.settimeout(timeout)
self._s = s
def connect(self, host, port):
self._s.connect((host, int(port)))
def shutdown(self):
self._s.shutdown(socket.SHUT_RD)
def close(self):
self._s.close()
class Timer(object):
def __init__(self):
self._start = 0
self._stop = 0
def start(self):
self._start = timer()
def stop(self):
self._stop = timer()
def cost(self, funcs, args):
self.start()
for func, arg in zip_longest(funcs, args):
if arg:
func(*arg)
else:
func()
self.stop()
return self._stop - self._start
def _create_socket(self, family, type_):
return self.Socket(family, type_, self._timeout)
def get_ping(self):
s = self._create_socket(socket.AF_INET, socket.SOCK_STREAM)
cost_time = self.timer.cost(
(s.connect, s.shutdown),
((self._host, self._port), None))
s_runtime = 1000 * (cost_time)
return s_runtime

View File

@ -6,7 +6,7 @@ with open("README.md") as f:
setuptools.setup( setuptools.setup(
name="pomice", name="pomice",
author="cloudwithax", author="cloudwithax",
version="1.1.4", version="1.1.5",
url="https://github.com/cloudwithax/pomice", url="https://github.com/cloudwithax/pomice",
packages=setuptools.find_packages(), packages=setuptools.find_packages(),
license="GPL", license="GPL",