added node algos, cleaned up events.py, add node arg to Player()
This commit is contained in:
parent
c3135b7798
commit
d249d84f9f
|
|
@ -11,7 +11,7 @@ if discord.__version__ != "2.0.0a":
|
||||||
"using 'pip install git+https://github.com/Rapptz/discord.py@master'"
|
"using 'pip install git+https://github.com/Rapptz/discord.py@master'"
|
||||||
)
|
)
|
||||||
|
|
||||||
__version__ = "1.1.4"
|
__version__ = "1.1.5"
|
||||||
__title__ = "pomice"
|
__title__ = "pomice"
|
||||||
__author__ = "cloudwithax"
|
__author__ = "cloudwithax"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from enum import Enum
|
from enum import Enum, auto
|
||||||
|
|
||||||
|
|
||||||
class SearchType(Enum):
|
class SearchType(Enum):
|
||||||
|
|
@ -21,3 +21,25 @@ class SearchType(Enum):
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
|
class NodeAlgorithm(Enum):
|
||||||
|
"""The enum for the different node algorithms in Pomice.
|
||||||
|
|
||||||
|
The enums in this class are to only differentiate different
|
||||||
|
methods, since the actual method is handled in the
|
||||||
|
get_best_node() method.
|
||||||
|
|
||||||
|
NodeAlgorithm.by_ping returns a node based on it's latency,
|
||||||
|
preferring a node with the lowest response time
|
||||||
|
|
||||||
|
NodeAlgorithm.by_region returns a node based on its voice region,
|
||||||
|
which the region is specified by the user in the method as an arg.
|
||||||
|
This method will only work if you set a voice region when you create a node.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We don't have to define anything special for these, since these just serve as flags
|
||||||
|
by_ping = auto()
|
||||||
|
by_region = auto()
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.value
|
||||||
|
|
@ -24,8 +24,8 @@ class TrackStartEvent(PomiceEvent):
|
||||||
"""
|
"""
|
||||||
name = "track_start"
|
name = "track_start"
|
||||||
|
|
||||||
def __init__(self, data: dict):
|
def __init__(self, data: dict, player):
|
||||||
self.player = NodePool.get_node().get_player(int(data["guildId"]))
|
self.player = player
|
||||||
self.track = self.player._current
|
self.track = self.player._current
|
||||||
|
|
||||||
# on_pomice_track_start(player, track)
|
# on_pomice_track_start(player, track)
|
||||||
|
|
@ -41,8 +41,8 @@ class TrackEndEvent(PomiceEvent):
|
||||||
"""
|
"""
|
||||||
name = "track_end"
|
name = "track_end"
|
||||||
|
|
||||||
def __init__(self, data: dict):
|
def __init__(self, data: dict, player):
|
||||||
self.player = NodePool.get_node().get_player(int(data["guildId"]))
|
self.player = player
|
||||||
self.track = self.player._ending_track
|
self.track = self.player._ending_track
|
||||||
self.reason: str = data["reason"]
|
self.reason: str = data["reason"]
|
||||||
|
|
||||||
|
|
@ -63,8 +63,8 @@ class TrackStuckEvent(PomiceEvent):
|
||||||
"""
|
"""
|
||||||
name = "track_stuck"
|
name = "track_stuck"
|
||||||
|
|
||||||
def __init__(self, data: dict):
|
def __init__(self, data: dict, player):
|
||||||
self.player = NodePool.get_node().get_player(int(data["guildId"]))
|
self.player = player
|
||||||
self.track = self.player._ending_track
|
self.track = self.player._ending_track
|
||||||
self.threshold: float = data["thresholdMs"]
|
self.threshold: float = data["thresholdMs"]
|
||||||
|
|
||||||
|
|
@ -82,8 +82,8 @@ class TrackExceptionEvent(PomiceEvent):
|
||||||
"""
|
"""
|
||||||
name = "track_exception"
|
name = "track_exception"
|
||||||
|
|
||||||
def __init__(self, data: dict):
|
def __init__(self, data: dict, player):
|
||||||
self.player = NodePool.get_node().get_player(int(data["guildId"]))
|
self.player = player
|
||||||
self.track = self.player._ending_track
|
self.track = self.player._ending_track
|
||||||
if data.get('error'):
|
if data.get('error'):
|
||||||
# User is running Lavalink <= 3.3
|
# User is running Lavalink <= 3.3
|
||||||
|
|
@ -117,7 +117,7 @@ class WebSocketClosedEvent(PomiceEvent):
|
||||||
"""
|
"""
|
||||||
name = "websocket_closed"
|
name = "websocket_closed"
|
||||||
|
|
||||||
def __init__(self, data: dict):
|
def __init__(self, data: dict, _):
|
||||||
self.payload = WebSocketClosedPayload(data)
|
self.payload = WebSocketClosedPayload(data)
|
||||||
|
|
||||||
# on_pomice_websocket_closed(payload)
|
# on_pomice_websocket_closed(payload)
|
||||||
|
|
@ -133,7 +133,7 @@ class WebSocketOpenEvent(PomiceEvent):
|
||||||
"""
|
"""
|
||||||
name = "websocket_open"
|
name = "websocket_open"
|
||||||
|
|
||||||
def __init__(self, data: dict):
|
def __init__(self, data: dict, _):
|
||||||
self.target: str = data["target"]
|
self.target: str = data["target"]
|
||||||
self.ssrc: int = data["ssrc"]
|
self.ssrc: int = data["ssrc"]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,13 +36,19 @@ class Player(VoiceProtocol):
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __init__(self, client: ClientType = None, channel: VoiceChannel = None, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
client: ClientType = None,
|
||||||
|
channel: VoiceChannel = None,
|
||||||
|
*,
|
||||||
|
node: Node = None
|
||||||
|
):
|
||||||
self.client = client
|
self.client = client
|
||||||
self._bot = client
|
self._bot = client
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
self._guild: Guild = self.channel.guild
|
self._guild: Guild = self.channel.guild
|
||||||
|
|
||||||
self._node = NodePool.get_node()
|
self._node = node if node else NodePool.get_node()
|
||||||
self._current: Track = None
|
self._current: Track = None
|
||||||
self._filter: Filter = None
|
self._filter: Filter = None
|
||||||
self._volume = 100
|
self._volume = 100
|
||||||
|
|
@ -55,7 +61,6 @@ class Player(VoiceProtocol):
|
||||||
self._ending_track: Optional[Track] = None
|
self._ending_track: Optional[Track] = None
|
||||||
|
|
||||||
self._voice_state = {}
|
self._voice_state = {}
|
||||||
self._extra = kwargs or {}
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (
|
return (
|
||||||
|
|
@ -171,7 +176,7 @@ class Player(VoiceProtocol):
|
||||||
|
|
||||||
async def _dispatch_event(self, data: dict):
|
async def _dispatch_event(self, data: dict):
|
||||||
event_type = data.get("type")
|
event_type = data.get("type")
|
||||||
event: PomiceEvent = getattr(events, event_type)(data)
|
event: PomiceEvent = getattr(events, event_type)(data, self)
|
||||||
|
|
||||||
if isinstance(event, TrackEndEvent) and event.reason != "REPLACED":
|
if isinstance(event, TrackEndEvent) and event.reason != "REPLACED":
|
||||||
self._current = None
|
self._current = None
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,15 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
import socket
|
import socket
|
||||||
from typing import Dict, Optional, TYPE_CHECKING
|
from typing import Dict, Optional, TYPE_CHECKING
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
from discord import VoiceRegion
|
||||||
|
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
|
|
@ -17,17 +20,18 @@ from . import (
|
||||||
spotify,
|
spotify,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .enums import SearchType
|
from .enums import SearchType, NodeAlgorithm
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
InvalidSpotifyClientAuthorization,
|
InvalidSpotifyClientAuthorization,
|
||||||
NodeConnectionFailure,
|
NodeConnectionFailure,
|
||||||
NodeCreationError,
|
NodeCreationError,
|
||||||
|
NodeException,
|
||||||
NodeNotAvailable,
|
NodeNotAvailable,
|
||||||
NoNodesAvailable,
|
NoNodesAvailable,
|
||||||
TrackLoadError
|
TrackLoadError
|
||||||
)
|
)
|
||||||
from .objects import Playlist, Track
|
from .objects import Playlist, Track
|
||||||
from .utils import ClientType, ExponentialBackoff, NodeStats
|
from .utils import ClientType, ExponentialBackoff, NodeStats, Ping
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .player import Player
|
from .player import Player
|
||||||
|
|
@ -46,6 +50,7 @@ URL_REGEX = re.compile(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Node:
|
class Node:
|
||||||
"""The base class for a node.
|
"""The base class for a node.
|
||||||
This node object represents a Lavalink node.
|
This node object represents a Lavalink node.
|
||||||
|
|
@ -62,6 +67,8 @@ class Node:
|
||||||
password: str,
|
password: str,
|
||||||
identifier: str,
|
identifier: str,
|
||||||
secure: bool = False,
|
secure: bool = False,
|
||||||
|
heartbeat: int = 30,
|
||||||
|
region: Optional[VoiceRegion],
|
||||||
session: Optional[aiohttp.ClientSession],
|
session: Optional[aiohttp.ClientSession],
|
||||||
spotify_client_id: Optional[str],
|
spotify_client_id: Optional[str],
|
||||||
spotify_client_secret: Optional[str],
|
spotify_client_secret: Optional[str],
|
||||||
|
|
@ -73,7 +80,9 @@ class Node:
|
||||||
self._pool = pool
|
self._pool = pool
|
||||||
self._password = password
|
self._password = password
|
||||||
self._identifier = identifier
|
self._identifier = identifier
|
||||||
|
self._heartbeat = heartbeat
|
||||||
self._secure = secure
|
self._secure = secure
|
||||||
|
self._region: Optional[VoiceRegion] = region
|
||||||
|
|
||||||
|
|
||||||
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
|
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
|
||||||
|
|
@ -127,6 +136,11 @@ class Node:
|
||||||
"""Property which returns a dict containing the guild ID and the player object."""
|
"""Property which returns a dict containing the guild ID and the player object."""
|
||||||
return self._players
|
return self._players
|
||||||
|
|
||||||
|
@property
|
||||||
|
def region(self) -> Optional[VoiceRegion]:
|
||||||
|
"""Property which returns the VoiceRegion of the node, if one is set"""
|
||||||
|
return self._region
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bot(self) -> ClientType:
|
def bot(self) -> ClientType:
|
||||||
"""Property which returns the discord.py client linked to this node"""
|
"""Property which returns the discord.py client linked to this node"""
|
||||||
|
|
@ -142,6 +156,11 @@ class Node:
|
||||||
"""Property which returns the pool this node is apart of"""
|
"""Property which returns the pool this node is apart of"""
|
||||||
return self._pool
|
return self._pool
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latency(self):
|
||||||
|
"""Property which returns the latency of the node"""
|
||||||
|
return Ping(self._host, port=self._port).get_ping()
|
||||||
|
|
||||||
async def _update_handler(self, data: dict):
|
async def _update_handler(self, data: dict):
|
||||||
await self._bot.wait_until_ready()
|
await self._bot.wait_until_ready()
|
||||||
|
|
||||||
|
|
@ -216,7 +235,7 @@ class Node:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._websocket = await self._session.ws_connect(
|
self._websocket = await self._session.ws_connect(
|
||||||
self._websocket_uri, headers=self._headers, heartbeat=30
|
self._websocket_uri, headers=self._headers, heartbeat=self._heartbeat
|
||||||
)
|
)
|
||||||
self._task = self._bot.loop.create_task(self._listen())
|
self._task = self._bot.loop.create_task(self._listen())
|
||||||
self._available = True
|
self._available = True
|
||||||
|
|
@ -414,6 +433,8 @@ class Node:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class NodePool:
|
class NodePool:
|
||||||
"""The base class for the node pool.
|
"""The base class for the node pool.
|
||||||
This holds all the nodes that are to be used by the bot.
|
This holds all the nodes that are to be used by the bot.
|
||||||
|
|
@ -433,6 +454,41 @@ class NodePool:
|
||||||
def node_count(self):
|
def node_count(self):
|
||||||
return len(self._nodes.values())
|
return len(self._nodes.values())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_best_node(cls, *, algorithm: NodeAlgorithm, voice_region: VoiceRegion = None) -> Node:
|
||||||
|
"""Fetches the best node based on an NodeAlgorithm.
|
||||||
|
This option is preferred if you want to choose the best node
|
||||||
|
from a multi-node setup using either the node's latency
|
||||||
|
or the node's voice region.
|
||||||
|
|
||||||
|
Use NodeAlgorithm.by_ping if you want to get the best node
|
||||||
|
based on the node's latency.
|
||||||
|
|
||||||
|
Use NodeAlgorithm.by_region if you want to get the best node
|
||||||
|
based on the node's voice region. This method will only work
|
||||||
|
if you set a voice region when you create a node.
|
||||||
|
"""
|
||||||
|
available_nodes = (node for node in cls._nodes.values() if node._available)
|
||||||
|
|
||||||
|
if not available_nodes:
|
||||||
|
raise NoNodesAvailable("There are no nodes available.")
|
||||||
|
|
||||||
|
if algorithm == NodeAlgorithm.by_ping:
|
||||||
|
tested_nodes = {node: node.latency for node in available_nodes}
|
||||||
|
return min(tested_nodes, key=tested_nodes.get)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if voice_region == None:
|
||||||
|
raise NodeException("You must specify a VoiceRegion in order to use this functionality.")
|
||||||
|
|
||||||
|
nodes = [node for node in available_nodes if node._region is voice_region]
|
||||||
|
if not nodes:
|
||||||
|
raise NoNodesAvailable(
|
||||||
|
f"No nodes for region {voice_region} exist in this pool."
|
||||||
|
)
|
||||||
|
|
||||||
|
return nodes[0]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_node(cls, *, identifier: str = None) -> Node:
|
def get_node(cls, *, identifier: str = None) -> Node:
|
||||||
"""Fetches a node from the node pool using it's identifier.
|
"""Fetches a node from the node pool using it's identifier.
|
||||||
|
|
@ -461,6 +517,8 @@ class NodePool:
|
||||||
password: str,
|
password: str,
|
||||||
identifier: str,
|
identifier: str,
|
||||||
secure: bool = False,
|
secure: bool = False,
|
||||||
|
heartbeat: int = 30,
|
||||||
|
region: Optional[VoiceRegion] = None,
|
||||||
spotify_client_id: Optional[str],
|
spotify_client_id: Optional[str],
|
||||||
spotify_client_secret: Optional[str],
|
spotify_client_secret: Optional[str],
|
||||||
session: Optional[aiohttp.ClientSession] = None,
|
session: Optional[aiohttp.ClientSession] = None,
|
||||||
|
|
@ -474,7 +532,8 @@ class NodePool:
|
||||||
|
|
||||||
node = Node(
|
node = Node(
|
||||||
pool=cls, bot=bot, host=host, port=port, password=password,
|
pool=cls, bot=bot, host=host, port=port, password=password,
|
||||||
identifier=identifier, secure=secure, spotify_client_id=spotify_client_id,
|
identifier=identifier, secure=secure, heartbeat=heartbeat,
|
||||||
|
region=region, spotify_client_id=spotify_client_id,
|
||||||
session=session, spotify_client_secret=spotify_client_secret
|
session=session, spotify_client_secret=spotify_client_secret
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
import socket
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
from timeit import default_timer as timer
|
||||||
|
from itertools import zip_longest
|
||||||
|
|
||||||
from discord import AutoShardedClient, Client
|
from discord import AutoShardedClient, Client
|
||||||
from discord.ext.commands import AutoShardedBot, Bot
|
from discord.ext.commands import AutoShardedBot, Bot
|
||||||
|
|
@ -86,3 +89,68 @@ class NodeStats:
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<Pomice.NodeStats total_players={self.players_total!r} playing_active={self.players_active!r}>"
|
return f"<Pomice.NodeStats total_players={self.players_total!r} playing_active={self.players_active!r}>"
|
||||||
|
|
||||||
|
|
||||||
|
class Ping:
|
||||||
|
# Thanks to https://github.com/zhengxiaowai/tcping for the nice ping impl
|
||||||
|
def __init__(self, host, port, timeout=5):
|
||||||
|
self.timer = self.Timer()
|
||||||
|
|
||||||
|
self._successed = 0
|
||||||
|
self._failed = 0
|
||||||
|
self._conn_time = None
|
||||||
|
self._host = host
|
||||||
|
self._port = port
|
||||||
|
self._timeout = timeout
|
||||||
|
|
||||||
|
class Socket(object):
|
||||||
|
def __init__(self, family, type_, timeout):
|
||||||
|
s = socket.socket(family, type_)
|
||||||
|
s.settimeout(timeout)
|
||||||
|
self._s = s
|
||||||
|
|
||||||
|
def connect(self, host, port):
|
||||||
|
self._s.connect((host, int(port)))
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
self._s.shutdown(socket.SHUT_RD)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self._s.close()
|
||||||
|
|
||||||
|
|
||||||
|
class Timer(object):
|
||||||
|
def __init__(self):
|
||||||
|
self._start = 0
|
||||||
|
self._stop = 0
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self._start = timer()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self._stop = timer()
|
||||||
|
|
||||||
|
def cost(self, funcs, args):
|
||||||
|
self.start()
|
||||||
|
for func, arg in zip_longest(funcs, args):
|
||||||
|
if arg:
|
||||||
|
func(*arg)
|
||||||
|
else:
|
||||||
|
func()
|
||||||
|
|
||||||
|
self.stop()
|
||||||
|
return self._stop - self._start
|
||||||
|
|
||||||
|
def _create_socket(self, family, type_):
|
||||||
|
return self.Socket(family, type_, self._timeout)
|
||||||
|
|
||||||
|
def get_ping(self):
|
||||||
|
s = self._create_socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
|
||||||
|
cost_time = self.timer.cost(
|
||||||
|
(s.connect, s.shutdown),
|
||||||
|
((self._host, self._port), None))
|
||||||
|
s_runtime = 1000 * (cost_time)
|
||||||
|
|
||||||
|
return s_runtime
|
||||||
|
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -6,7 +6,7 @@ with open("README.md") as f:
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name="pomice",
|
name="pomice",
|
||||||
author="cloudwithax",
|
author="cloudwithax",
|
||||||
version="1.1.4",
|
version="1.1.5",
|
||||||
url="https://github.com/cloudwithax/pomice",
|
url="https://github.com/cloudwithax/pomice",
|
||||||
packages=setuptools.find_packages(),
|
packages=setuptools.find_packages(),
|
||||||
license="GPL",
|
license="GPL",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue