1067 lines
37 KiB
Python
1067 lines
37 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import random
|
|
import re
|
|
import time
|
|
from os import path
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from typing import Dict
|
|
from typing import List
|
|
from typing import Optional
|
|
from typing import Type
|
|
from typing import TYPE_CHECKING
|
|
from typing import Union
|
|
from urllib.parse import quote
|
|
|
|
import aiohttp
|
|
import orjson as json
|
|
from discord import Client
|
|
from discord.ext import commands
|
|
from discord.utils import MISSING
|
|
from websockets import client
|
|
from websockets import exceptions
|
|
from websockets import typing as wstype
|
|
|
|
from . import __version__
|
|
from . import applemusic
|
|
from . import spotify
|
|
from .enums import *
|
|
from .enums import LogLevel
|
|
from .exceptions import InvalidSpotifyClientAuthorization
|
|
from .exceptions import LavalinkVersionIncompatible
|
|
from .exceptions import NodeConnectionFailure
|
|
from .exceptions import NodeCreationError
|
|
from .exceptions import NodeNotAvailable
|
|
from .exceptions import NodeRestException
|
|
from .exceptions import NoNodesAvailable
|
|
from .exceptions import TrackLoadError
|
|
from .filters import Filter
|
|
from .objects import Playlist
|
|
from .objects import Track
|
|
from .routeplanner import RoutePlanner
|
|
from .utils import ExponentialBackoff
|
|
from .utils import LavalinkVersion
|
|
from .utils import NodeStats
|
|
from .utils import Ping
|
|
|
|
if TYPE_CHECKING:
|
|
from .player import Player
|
|
|
|
__all__ = (
|
|
"Node",
|
|
"NodePool",
|
|
)
|
|
|
|
VERSION_REGEX = re.compile(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?(?:[a-zA-Z0-9_-]+)?")
|
|
|
|
|
|
class Node:
|
|
"""The base class for a node.
|
|
This node object represents a Lavalink node.
|
|
To enable Spotify searching, pass in a proper Spotify Client ID and Spotify Client Secret
|
|
To enable Apple music, set the "apple_music" parameter to "True"
|
|
"""
|
|
|
|
__slots__ = (
|
|
"_bot",
|
|
"_bot_user",
|
|
"_host",
|
|
"_port",
|
|
"_pool",
|
|
"_password",
|
|
"_identifier",
|
|
"_heartbeat",
|
|
"_resume_key",
|
|
"_resume_timeout",
|
|
"_secure",
|
|
"_fallback",
|
|
"_log_level",
|
|
"_websocket_uri",
|
|
"_rest_uri",
|
|
"_session",
|
|
"_websocket",
|
|
"_task",
|
|
"_loop",
|
|
"_session_id",
|
|
"_available",
|
|
"_version",
|
|
"_headers",
|
|
"_players",
|
|
"_spotify_client_id",
|
|
"_spotify_client_secret",
|
|
"_spotify_client",
|
|
"_apple_music_client",
|
|
"_route_planner",
|
|
"_log",
|
|
"_stats",
|
|
"available",
|
|
)
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
pool: Type[NodePool],
|
|
bot: commands.Bot,
|
|
host: str,
|
|
port: int,
|
|
password: str,
|
|
identifier: str,
|
|
secure: bool = False,
|
|
heartbeat: int = 120,
|
|
resume_key: Optional[str] = None,
|
|
resume_timeout: int = 60,
|
|
loop: Optional[asyncio.AbstractEventLoop] = None,
|
|
session: Optional[aiohttp.ClientSession] = None,
|
|
spotify_client_id: Optional[str] = None,
|
|
spotify_client_secret: Optional[str] = None,
|
|
apple_music: bool = False,
|
|
fallback: bool = False,
|
|
logger: Optional[logging.Logger] = None,
|
|
):
|
|
if not isinstance(port, int):
|
|
raise TypeError("Port must be an integer")
|
|
|
|
self._bot: commands.Bot = bot
|
|
self._host: str = host
|
|
self._port: int = port
|
|
self._pool: Type[NodePool] = pool
|
|
self._password: str = password
|
|
self._identifier: str = identifier
|
|
self._heartbeat: int = heartbeat
|
|
self._resume_key: Optional[str] = resume_key
|
|
self._resume_timeout: int = resume_timeout
|
|
self._secure: bool = secure
|
|
self._fallback: bool = fallback
|
|
|
|
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
|
|
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
|
|
|
|
self._session: aiohttp.ClientSession = session # type: ignore
|
|
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
|
|
self._websocket: client.WebSocketClientProtocol
|
|
self._task: asyncio.Task = None # type: ignore
|
|
|
|
self._session_id: Optional[str] = None
|
|
self._available: bool = False
|
|
self._version: LavalinkVersion = LavalinkVersion(0, 0, 0)
|
|
|
|
self._route_planner = RoutePlanner(self)
|
|
self._log = logger
|
|
|
|
if not self._bot.user:
|
|
raise NodeCreationError("Bot user is not ready yet.")
|
|
|
|
self._bot_user = self._bot.user
|
|
|
|
self._headers = {
|
|
"Authorization": self._password,
|
|
"User-Id": str(self._bot_user.id),
|
|
"Client-Name": f"Pomice/{__version__}",
|
|
}
|
|
|
|
self._players: Dict[int, Player] = {}
|
|
|
|
self._spotify_client: Optional[spotify.Client] = None
|
|
self._apple_music_client: Optional[applemusic.Client] = None
|
|
|
|
self._spotify_client_id: Optional[str] = spotify_client_id
|
|
self._spotify_client_secret: Optional[str] = spotify_client_secret
|
|
|
|
if self._spotify_client_id and self._spotify_client_secret:
|
|
self._spotify_client = spotify.Client(
|
|
self._spotify_client_id,
|
|
self._spotify_client_secret,
|
|
)
|
|
|
|
if apple_music:
|
|
self._apple_music_client = applemusic.Client()
|
|
|
|
self._bot.add_listener(self._update_handler, "on_socket_response")
|
|
|
|
def __repr__(self) -> str:
|
|
return (
|
|
f"<Pomice.node ws_uri={self._websocket_uri} rest_uri={self._rest_uri} "
|
|
f"player_count={len(self._players)}>"
|
|
)
|
|
|
|
@property
|
|
def is_connected(self) -> bool:
|
|
"""Property which returns whether this node is connected or not"""
|
|
return self._websocket is not None and not self._websocket.closed
|
|
|
|
@property
|
|
def stats(self) -> NodeStats:
|
|
"""Property which returns the node stats."""
|
|
return self._stats
|
|
|
|
@property
|
|
def players(self) -> Dict[int, Player]:
|
|
"""Property which returns a dict containing the guild ID and the player object."""
|
|
return self._players
|
|
|
|
@property
|
|
def bot(self) -> Client:
|
|
"""Property which returns the discord.py client linked to this node"""
|
|
return self._bot
|
|
|
|
@property
|
|
def player_count(self) -> int:
|
|
"""Property which returns how many players are connected to this node"""
|
|
return len(self.players.values())
|
|
|
|
@property
|
|
def pool(self) -> Type[NodePool]:
|
|
"""Property which returns the pool this node is apart of"""
|
|
return self._pool
|
|
|
|
@property
|
|
def latency(self) -> float:
|
|
"""Property which returns the latency of the node"""
|
|
return Ping(self._host, port=self._port).get_ping()
|
|
|
|
@property
|
|
def ping(self) -> float:
|
|
"""Alias for `Node.latency`, returns the latency of the node"""
|
|
return self.latency
|
|
|
|
async def _handle_version_check(self, version: str) -> None:
|
|
if version.endswith("-SNAPSHOT"):
|
|
# we're just gonna assume all snapshot versions correlate with v4
|
|
self._version = LavalinkVersion(major=4, minor=0, fix=0)
|
|
return
|
|
|
|
_version_rx = VERSION_REGEX.match(version)
|
|
if not _version_rx:
|
|
self._available = False
|
|
raise LavalinkVersionIncompatible(
|
|
"The Lavalink version you're using is incompatible. "
|
|
"Lavalink version 3.7.0 or above is required to use this library.",
|
|
)
|
|
|
|
_version_groups = _version_rx.groups()
|
|
major, minor, fix = (
|
|
int(_version_groups[0] or 0),
|
|
int(_version_groups[1] or 0),
|
|
int(_version_groups[2] or 0),
|
|
)
|
|
|
|
if self._log:
|
|
self._log.debug(f"Parsed Lavalink version: {major}.{minor}.{fix}")
|
|
self._version = LavalinkVersion(major=major, minor=minor, fix=fix)
|
|
if self._version < LavalinkVersion(3, 7, 0):
|
|
self._available = False
|
|
raise LavalinkVersionIncompatible(
|
|
"The Lavalink version you're using is incompatible. "
|
|
"Lavalink version 3.7.0 or above is required to use this library.",
|
|
)
|
|
|
|
async def _set_ext_client_session(self, session: aiohttp.ClientSession) -> None:
|
|
if self._spotify_client:
|
|
await self._spotify_client._set_session(session=session)
|
|
|
|
if self._apple_music_client:
|
|
await self._apple_music_client._set_session(session=session)
|
|
|
|
async def _update_handler(self, data: dict) -> None:
|
|
await self._bot.wait_until_ready()
|
|
|
|
if not data:
|
|
return
|
|
|
|
if data["t"] == "VOICE_SERVER_UPDATE":
|
|
guild_id = int(data["d"]["guild_id"])
|
|
try:
|
|
player = self._players[guild_id]
|
|
await player.on_voice_server_update(data["d"])
|
|
except KeyError:
|
|
return
|
|
|
|
elif data["t"] == "VOICE_STATE_UPDATE":
|
|
if int(data["d"]["user_id"]) != self._bot_user.id:
|
|
return
|
|
|
|
guild_id = int(data["d"]["guild_id"])
|
|
try:
|
|
player = self._players[guild_id]
|
|
await player.on_voice_state_update(data["d"])
|
|
except KeyError:
|
|
return
|
|
|
|
async def _handle_node_switch(self) -> None:
|
|
nodes = [node for node in self.pool._nodes.copy().values() if node.is_connected]
|
|
new_node = random.choice(nodes)
|
|
|
|
for player in self.players.copy().values():
|
|
await player._swap_node(new_node=new_node)
|
|
|
|
await self.disconnect()
|
|
|
|
async def _configure_resuming(self) -> None:
|
|
if not self._resume_key:
|
|
return
|
|
|
|
data = {"timeout": self._resume_timeout}
|
|
|
|
if self._version.major == 3:
|
|
data["resumingKey"] = self._resume_key
|
|
elif self._version.major == 4:
|
|
if self._log:
|
|
self._log.warning("Using a resume key with Lavalink v4 is deprecated.")
|
|
data["resuming"] = True
|
|
|
|
await self.send(
|
|
method="PATCH",
|
|
path=f"sessions/{self._session_id}",
|
|
include_version=True,
|
|
data=data,
|
|
)
|
|
|
|
async def _listen(self) -> None:
|
|
while True:
|
|
try:
|
|
msg = await self._websocket.recv()
|
|
data = json.loads(msg)
|
|
if self._log:
|
|
self._log.debug(f"Recieved raw websocket message {msg}")
|
|
self._loop.create_task(self._handle_ws_msg(data=data))
|
|
except exceptions.ConnectionClosed:
|
|
if self.player_count > 0:
|
|
for _player in self.players.values():
|
|
self._loop.create_task(_player.destroy())
|
|
|
|
if self._fallback:
|
|
self._loop.create_task(self._handle_node_switch())
|
|
|
|
self._loop.create_task(self._websocket.close())
|
|
|
|
backoff = ExponentialBackoff(base=7)
|
|
retry = backoff.delay()
|
|
if self._log:
|
|
self._log.debug(
|
|
f"Retrying connection to Node {self._identifier} in {retry} secs",
|
|
)
|
|
await asyncio.sleep(retry)
|
|
|
|
if not self.is_connected:
|
|
self._loop.create_task(self.connect(reconnect=True))
|
|
|
|
async def _handle_ws_msg(self, data: dict) -> None:
|
|
if self._log:
|
|
self._log.debug(f"Recieved raw payload from Node {self._identifier} with data {data}")
|
|
op = data.get("op", None)
|
|
|
|
if op == "stats":
|
|
self._stats = NodeStats(data)
|
|
return
|
|
|
|
if op == "ready":
|
|
self._session_id = data["sessionId"]
|
|
await self._configure_resuming()
|
|
|
|
if not "guildId" in data:
|
|
return
|
|
|
|
player: Optional[Player] = self._players.get(int(data["guildId"]))
|
|
if not player:
|
|
return
|
|
|
|
if op == "event":
|
|
return await player._dispatch_event(data)
|
|
|
|
if op == "playerUpdate":
|
|
return await player._update_state(data)
|
|
|
|
async def send(
|
|
self,
|
|
method: str,
|
|
path: str,
|
|
include_version: bool = True,
|
|
guild_id: Optional[Union[int, str]] = None,
|
|
query: Optional[str] = None,
|
|
data: Optional[Union[Dict, str]] = None,
|
|
ignore_if_available: bool = False,
|
|
) -> Any:
|
|
if not ignore_if_available and not self._available:
|
|
raise NodeNotAvailable(
|
|
f"The node '{self._identifier}' is unavailable.",
|
|
)
|
|
|
|
uri: str = (
|
|
f"{self._rest_uri}/"
|
|
f'{f"v{self._version.major}/" if include_version else ""}'
|
|
f"{path}"
|
|
f'{f"/{guild_id}" if guild_id else ""}'
|
|
f'{f"?{query}" if query else ""}'
|
|
)
|
|
|
|
resp = await self._session.request(
|
|
method=method,
|
|
url=uri,
|
|
headers=self._headers,
|
|
json=data or {},
|
|
)
|
|
if self._log:
|
|
self._log.debug(
|
|
f"Making REST request to Node {self._identifier} with method {method} to {uri}",
|
|
)
|
|
if resp.status >= 300:
|
|
resp_data: dict = await resp.json()
|
|
raise NodeRestException(
|
|
f'Error from Node {self._identifier} fetching from Lavalink REST api: {resp.status} {resp.reason}: {resp_data["message"]}',
|
|
)
|
|
|
|
if method == "DELETE" or resp.status == 204:
|
|
if self._log:
|
|
self._log.debug(
|
|
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned no data.",
|
|
)
|
|
return await resp.json(content_type=None)
|
|
|
|
if resp.content_type == "text/plain":
|
|
if self._log:
|
|
self._log.debug(
|
|
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned text with body {await resp.text()}",
|
|
)
|
|
return await resp.text()
|
|
|
|
if self._log:
|
|
self._log.debug(
|
|
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned JSON with body {await resp.json()}",
|
|
)
|
|
return await resp.json()
|
|
|
|
def get_player(self, guild_id: int) -> Optional[Player]:
|
|
"""Takes a guild ID as a parameter. Returns a pomice Player object or None."""
|
|
return self._players.get(guild_id, None)
|
|
|
|
async def connect(self, *, reconnect: bool = False) -> Node:
|
|
"""Initiates a connection with a Lavalink node and adds it to the node pool."""
|
|
await self._bot.wait_until_ready()
|
|
|
|
start = time.perf_counter()
|
|
|
|
if not self._session:
|
|
self._session = aiohttp.ClientSession()
|
|
|
|
try:
|
|
if not reconnect:
|
|
version: str = await self.send(
|
|
method="GET",
|
|
path="version",
|
|
ignore_if_available=True,
|
|
include_version=False,
|
|
)
|
|
|
|
await self._handle_version_check(version=version)
|
|
await self._set_ext_client_session(session=self._session)
|
|
|
|
if self._log:
|
|
self._log.debug(
|
|
f"Version check from Node {self._identifier} successful. Returned version {version}",
|
|
)
|
|
|
|
self._websocket = await client.connect(
|
|
f"{self._websocket_uri}/v{self._version.major}/websocket",
|
|
extra_headers=self._headers,
|
|
ping_interval=self._heartbeat,
|
|
)
|
|
|
|
if reconnect:
|
|
if self._log:
|
|
self._log.debug(f"Trying to reconnect to Node {self._identifier}...")
|
|
if self.player_count:
|
|
for player in self.players.values():
|
|
await player._refresh_endpoint_uri(self._session_id)
|
|
|
|
if self._log:
|
|
self._log.debug(
|
|
f"Node {self._identifier} successfully connected to websocket using {self._websocket_uri}/v{self._version.major}/websocket",
|
|
)
|
|
|
|
if not self._task:
|
|
self._task = self._loop.create_task(self._listen())
|
|
|
|
self._available = True
|
|
|
|
end = time.perf_counter()
|
|
|
|
if self._log:
|
|
self._log.info(f"Connected to node {self._identifier}. Took {end - start:.3f}s")
|
|
return self
|
|
|
|
except (aiohttp.ClientConnectorError, OSError, ConnectionRefusedError):
|
|
raise NodeConnectionFailure(
|
|
f"The connection to node '{self._identifier}' failed.",
|
|
) from None
|
|
except exceptions.InvalidHandshake:
|
|
raise NodeConnectionFailure(
|
|
f"The password for node '{self._identifier}' is invalid.",
|
|
) from None
|
|
except exceptions.InvalidURI:
|
|
raise NodeConnectionFailure(
|
|
f"The URI for node '{self._identifier}' is invalid.",
|
|
) from None
|
|
|
|
async def disconnect(self) -> None:
|
|
"""Disconnects a connected Lavalink node and removes it from the node pool.
|
|
This also destroys any players connected to the node.
|
|
"""
|
|
|
|
start = time.perf_counter()
|
|
|
|
for player in self.players.copy().values():
|
|
await player.destroy()
|
|
if self._log:
|
|
self._log.debug("All players disconnected from node.")
|
|
|
|
await self._websocket.close()
|
|
await self._session.close()
|
|
if self._log:
|
|
self._log.debug("Websocket and http session closed.")
|
|
|
|
del self._pool._nodes[self._identifier]
|
|
self.available = False
|
|
self._task.cancel()
|
|
|
|
end = time.perf_counter()
|
|
if self._log:
|
|
self._log.info(
|
|
f"Successfully disconnected from node {self._identifier} and closed all sessions. Took {end - start:.3f}s",
|
|
)
|
|
|
|
async def build_track(self, identifier: str, ctx: Optional[commands.Context] = None) -> Track:
|
|
"""
|
|
Builds a track using a valid track identifier
|
|
|
|
You can also pass in a discord.py Context object to get a
|
|
Context object on the track it builds.
|
|
"""
|
|
|
|
data: dict = await self.send(
|
|
method="GET",
|
|
path="decodetrack",
|
|
query=f"encodedTrack={quote(identifier)}",
|
|
)
|
|
|
|
track_info = data["info"] if self._version.major >= 4 else data
|
|
|
|
return Track(
|
|
track_id=identifier,
|
|
ctx=ctx,
|
|
info=track_info,
|
|
track_type=TrackType(track_info["sourceName"]),
|
|
)
|
|
|
|
async def get_tracks(
|
|
self,
|
|
query: str,
|
|
*,
|
|
ctx: Optional[commands.Context] = None,
|
|
search_type: SearchType | None = SearchType.ytsearch,
|
|
filters: Optional[List[Filter]] = None,
|
|
) -> Optional[Union[Playlist, List[Track]]]:
|
|
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
|
|
|
If you passed in Spotify API credentials, you can also pass in a
|
|
Spotify URL of a playlist, album or track and it will be parsed accordingly.
|
|
|
|
You can pass in a discord.py Context object to get a
|
|
Context object on any track you search.
|
|
|
|
You may also pass in a List of filters
|
|
to be applied to your track once it plays.
|
|
"""
|
|
|
|
timestamp = None
|
|
|
|
if filters:
|
|
for filter in filters:
|
|
filter.set_preload()
|
|
|
|
# Due to the inclusion of plugins in the v4 update
|
|
# we are doing away with raising an error if pomice detects
|
|
# either a Spotify or Apple Music URL and the respective client
|
|
# is not enabled. Instead, we will just only parse the URL
|
|
# if the client is enabled and the URL is valid.
|
|
|
|
if self._apple_music_client and URLRegex.AM_URL.match(query):
|
|
apple_music_results = await self._apple_music_client.search(query=query)
|
|
if isinstance(apple_music_results, applemusic.Song):
|
|
return [
|
|
Track(
|
|
track_id=apple_music_results.id,
|
|
ctx=ctx,
|
|
track_type=TrackType.APPLE_MUSIC,
|
|
search_type=search_type,
|
|
filters=filters,
|
|
info={
|
|
"title": apple_music_results.name,
|
|
"author": apple_music_results.artists,
|
|
"length": apple_music_results.length,
|
|
"identifier": apple_music_results.id,
|
|
"uri": apple_music_results.url,
|
|
"isStream": False,
|
|
"isSeekable": True,
|
|
"position": 0,
|
|
"thumbnail": apple_music_results.image,
|
|
"isrc": apple_music_results.isrc,
|
|
},
|
|
),
|
|
]
|
|
|
|
tracks = [
|
|
Track(
|
|
track_id=track.id,
|
|
ctx=ctx,
|
|
track_type=TrackType.APPLE_MUSIC,
|
|
search_type=search_type,
|
|
filters=filters,
|
|
info={
|
|
"title": track.name,
|
|
"author": track.artists,
|
|
"length": track.length,
|
|
"identifier": track.id,
|
|
"uri": track.url,
|
|
"isStream": False,
|
|
"isSeekable": True,
|
|
"position": 0,
|
|
"thumbnail": track.image,
|
|
"isrc": track.isrc,
|
|
},
|
|
)
|
|
for track in apple_music_results.tracks
|
|
]
|
|
|
|
return Playlist(
|
|
playlist_info={
|
|
"name": apple_music_results.name,
|
|
"selectedTrack": 0,
|
|
},
|
|
tracks=tracks,
|
|
playlist_type=PlaylistType.APPLE_MUSIC,
|
|
thumbnail=apple_music_results.image,
|
|
uri=apple_music_results.url,
|
|
)
|
|
|
|
elif self._spotify_client and URLRegex.SPOTIFY_URL.match(query):
|
|
spotify_results = await self._spotify_client.search(query=query) # type: ignore
|
|
|
|
if isinstance(spotify_results, spotify.Track):
|
|
return [
|
|
Track(
|
|
track_id=spotify_results.id,
|
|
ctx=ctx,
|
|
track_type=TrackType.SPOTIFY,
|
|
search_type=search_type,
|
|
filters=filters,
|
|
info={
|
|
"title": spotify_results.name,
|
|
"author": spotify_results.artists,
|
|
"length": spotify_results.length,
|
|
"identifier": spotify_results.id,
|
|
"uri": spotify_results.uri,
|
|
"isStream": False,
|
|
"isSeekable": True,
|
|
"position": 0,
|
|
"thumbnail": spotify_results.image,
|
|
"isrc": spotify_results.isrc,
|
|
},
|
|
),
|
|
]
|
|
|
|
tracks = [
|
|
Track(
|
|
track_id=track.id,
|
|
ctx=ctx,
|
|
track_type=TrackType.SPOTIFY,
|
|
search_type=search_type,
|
|
filters=filters,
|
|
info={
|
|
"title": track.name,
|
|
"author": track.artists,
|
|
"length": track.length,
|
|
"identifier": track.id,
|
|
"uri": track.uri,
|
|
"isStream": False,
|
|
"isSeekable": True,
|
|
"position": 0,
|
|
"thumbnail": track.image,
|
|
"isrc": track.isrc,
|
|
},
|
|
)
|
|
for track in spotify_results.tracks
|
|
]
|
|
|
|
return Playlist(
|
|
playlist_info={
|
|
"name": spotify_results.name,
|
|
"selectedTrack": 0,
|
|
},
|
|
tracks=tracks,
|
|
playlist_type=PlaylistType.SPOTIFY,
|
|
thumbnail=spotify_results.image,
|
|
uri=spotify_results.uri,
|
|
)
|
|
|
|
else:
|
|
if (
|
|
search_type
|
|
and not URLRegex.BASE_URL.match(query)
|
|
and not re.match(r"(?:[a-z]+?)search:.", query)
|
|
and not URLRegex.DISCORD_MP3_URL.match(query)
|
|
and not path.exists(path.dirname(query))
|
|
):
|
|
query = f"{search_type}:{query}"
|
|
|
|
# If YouTube url contains a timestamp, capture it for use later.
|
|
|
|
if match := URLRegex.YOUTUBE_TIMESTAMP.match(query):
|
|
timestamp = float(match.group("time"))
|
|
|
|
data = await self.send(
|
|
method="GET",
|
|
path="loadtracks",
|
|
query=f"identifier={quote(query)}",
|
|
)
|
|
|
|
load_type = data.get("loadType")
|
|
|
|
# Lavalink v4 changed the name of the key from "tracks" to "data"
|
|
# so lets account for that
|
|
data_type = "data" if self._version.major >= 4 else "tracks"
|
|
|
|
if not load_type:
|
|
raise TrackLoadError(
|
|
"There was an error while trying to load this track.",
|
|
)
|
|
|
|
elif load_type in ("LOAD_FAILED", "error"):
|
|
exception = data["data"] if self._version.major >= 4 else data["exception"]
|
|
raise TrackLoadError(
|
|
f"{exception['message']} [{exception['severity']}]",
|
|
)
|
|
|
|
elif load_type in ("NO_MATCHES", "empty"):
|
|
return None
|
|
|
|
elif load_type in ("PLAYLIST_LOADED", "playlist"):
|
|
if self._version.major >= 4:
|
|
track_list = data[data_type]["tracks"]
|
|
playlist_info = data[data_type]["info"]
|
|
else:
|
|
track_list = data[data_type]
|
|
playlist_info = data["playlistInfo"]
|
|
tracks = [
|
|
Track(
|
|
track_id=track["encoded"],
|
|
info=track["info"],
|
|
ctx=ctx,
|
|
track_type=TrackType(track["info"]["sourceName"]),
|
|
)
|
|
for track in track_list
|
|
]
|
|
return Playlist(
|
|
playlist_info=playlist_info,
|
|
tracks=tracks,
|
|
playlist_type=PlaylistType(tracks[0].track_type.value),
|
|
thumbnail=tracks[0].thumbnail,
|
|
uri=query,
|
|
)
|
|
|
|
elif load_type in ("SEARCH_RESULT", "TRACK_LOADED", "track", "search"):
|
|
if self._version.major >= 4 and isinstance(data[data_type], dict):
|
|
data[data_type] = [data[data_type]]
|
|
|
|
if path.exists(path.dirname(query)):
|
|
local_file = Path(query)
|
|
|
|
return [
|
|
Track(
|
|
track_id=track["encoded"],
|
|
info={
|
|
"title": local_file.name,
|
|
"author": "Unknown",
|
|
"length": track["info"]["length"],
|
|
"uri": quote(local_file.as_uri()),
|
|
"position": track["info"]["position"],
|
|
"identifier": track["info"]["identifier"],
|
|
},
|
|
ctx=ctx,
|
|
track_type=TrackType.LOCAL,
|
|
filters=filters,
|
|
)
|
|
for track in data[data_type]
|
|
]
|
|
|
|
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
|
|
return [
|
|
Track(
|
|
track_id=track["encoded"],
|
|
info={
|
|
"title": discord_url.group("file"),
|
|
"author": "Unknown",
|
|
"length": track["info"]["length"],
|
|
"uri": track["info"]["uri"],
|
|
"position": track["info"]["position"],
|
|
"identifier": track["info"]["identifier"],
|
|
},
|
|
ctx=ctx,
|
|
track_type=TrackType.HTTP,
|
|
filters=filters,
|
|
)
|
|
for track in data[data_type]
|
|
]
|
|
|
|
return [
|
|
Track(
|
|
track_id=track["encoded"],
|
|
info=track["info"],
|
|
ctx=ctx,
|
|
track_type=TrackType(track["info"]["sourceName"]),
|
|
filters=filters,
|
|
timestamp=timestamp,
|
|
)
|
|
for track in data[data_type]
|
|
]
|
|
|
|
else:
|
|
raise TrackLoadError(
|
|
"There was an error while trying to load this track.",
|
|
)
|
|
|
|
async def get_recommendations(
|
|
self,
|
|
*,
|
|
track: Track,
|
|
ctx: Optional[commands.Context] = None,
|
|
) -> Optional[Union[List[Track], Playlist]]:
|
|
"""
|
|
Gets recommendations from either YouTube or Spotify.
|
|
The track that is passed in must be either from
|
|
YouTube or Spotify or else this will not work.
|
|
You can pass in a discord.py Context object to get a
|
|
Context object on all tracks that get recommended.
|
|
"""
|
|
if track.track_type == TrackType.SPOTIFY:
|
|
results = await self._spotify_client.get_recommendations(query=track.uri) # type: ignore
|
|
tracks = [
|
|
Track(
|
|
track_id=track.id,
|
|
ctx=ctx,
|
|
track_type=TrackType.SPOTIFY,
|
|
info={
|
|
"title": track.name,
|
|
"author": track.artists,
|
|
"length": track.length,
|
|
"identifier": track.id,
|
|
"uri": track.uri,
|
|
"isStream": False,
|
|
"isSeekable": True,
|
|
"position": 0,
|
|
"thumbnail": track.image,
|
|
"isrc": track.isrc,
|
|
},
|
|
requester=self.bot.user,
|
|
)
|
|
for track in results
|
|
]
|
|
return tracks
|
|
|
|
elif track.track_type == TrackType.YOUTUBE:
|
|
return await self.get_tracks(
|
|
query=f"ytsearch:https://www.youtube.com/watch?v={track.identifier}&list=RD{track.identifier}",
|
|
ctx=ctx,
|
|
)
|
|
|
|
else:
|
|
raise TrackLoadError(
|
|
"The specfied track must be either a YouTube or Spotify track to recieve recommendations.",
|
|
)
|
|
|
|
async def search_spotify_recommendations(
|
|
self,
|
|
query: str,
|
|
*,
|
|
ctx: Optional[commands.Context] = None,
|
|
filters: Optional[List[Filter]] = None,
|
|
) -> Optional[Union[List[Track], Playlist]]:
|
|
"""
|
|
Searches for recommendations on Spotify and returns a list of tracks based on the query.
|
|
You must have Spotify enabled for this to work.
|
|
You can pass in a discord.py Context object to get a
|
|
Context object on all tracks that get recommended.
|
|
"""
|
|
|
|
if not self._spotify_client:
|
|
raise InvalidSpotifyClientAuthorization(
|
|
"You must have Spotify enabled to use this feature.",
|
|
)
|
|
|
|
results = await self._spotify_client.track_search(query=query) # type: ignore
|
|
if not results:
|
|
raise TrackLoadError(
|
|
"Unable to find any tracks based on the query.",
|
|
)
|
|
|
|
tracks = [
|
|
Track(
|
|
track_id=track.id,
|
|
ctx=ctx,
|
|
track_type=TrackType.SPOTIFY,
|
|
info={
|
|
"title": track.name,
|
|
"author": track.artists,
|
|
"length": track.length,
|
|
"identifier": track.id,
|
|
"uri": track.uri,
|
|
"isStream": False,
|
|
"isSeekable": True,
|
|
"position": 0,
|
|
"thumbnail": track.image,
|
|
"isrc": track.isrc,
|
|
},
|
|
requester=self.bot.user,
|
|
)
|
|
for track in results
|
|
]
|
|
|
|
track = tracks[0]
|
|
|
|
return await self.get_recommendations(track=track, ctx=ctx)
|
|
|
|
|
|
class NodePool:
|
|
"""The base class for the node pool.
|
|
This holds all the nodes that are to be used by the bot.
|
|
"""
|
|
|
|
__slots__ = ()
|
|
_nodes: Dict[str, Node] = {}
|
|
|
|
def __repr__(self) -> str:
|
|
return f"<Pomice.NodePool node_count={self.node_count}>"
|
|
|
|
@property
|
|
def nodes(self) -> Dict[str, Node]:
|
|
"""Property which returns a dict with the node identifier and the Node object."""
|
|
return self._nodes
|
|
|
|
@property
|
|
def node_count(self) -> int:
|
|
return len(self._nodes.values())
|
|
|
|
@classmethod
|
|
def get_best_node(cls, *, algorithm: NodeAlgorithm) -> Node:
|
|
"""Fetches the best node based on an NodeAlgorithm.
|
|
This option is preferred if you want to choose the best node
|
|
from a multi-node setup using either the node's latency
|
|
or the node's voice region.
|
|
|
|
Use NodeAlgorithm.by_ping if you want to get the best node
|
|
based on the node's latency.
|
|
|
|
|
|
Use NodeAlgorithm.by_players if you want to get the best node
|
|
based on how players it has. This method will return a node with
|
|
the least amount of players
|
|
"""
|
|
available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available]
|
|
|
|
if not available_nodes:
|
|
raise NoNodesAvailable("There are no nodes available.")
|
|
|
|
if algorithm == NodeAlgorithm.by_ping:
|
|
tested_nodes = {node: node.latency for node in available_nodes}
|
|
return min(tested_nodes, key=tested_nodes.get) # type: ignore
|
|
|
|
elif algorithm == NodeAlgorithm.by_players:
|
|
tested_nodes = {node: len(node.players.keys()) for node in available_nodes}
|
|
return min(tested_nodes, key=tested_nodes.get) # type: ignore
|
|
|
|
else:
|
|
raise ValueError(
|
|
"The algorithm provided is not a valid NodeAlgorithm.",
|
|
)
|
|
|
|
@classmethod
|
|
def get_node(cls, *, identifier: Optional[str] = None) -> Node:
|
|
"""Fetches a node from the node pool using it's identifier.
|
|
If no identifier is provided, it will choose a node at random.
|
|
"""
|
|
available_nodes = {
|
|
identifier: node for identifier, node in cls._nodes.items() if node._available
|
|
}
|
|
|
|
if not available_nodes:
|
|
raise NoNodesAvailable("There are no nodes available.")
|
|
|
|
if identifier is None:
|
|
return random.choice(list(available_nodes.values()))
|
|
|
|
return available_nodes[identifier]
|
|
|
|
@classmethod
|
|
async def create_node(
|
|
cls,
|
|
*,
|
|
bot: commands.Bot,
|
|
host: str,
|
|
port: int,
|
|
password: str,
|
|
identifier: str,
|
|
secure: bool = False,
|
|
heartbeat: int = 120,
|
|
resume_key: Optional[str] = None,
|
|
resume_timeout: int = 60,
|
|
loop: Optional[asyncio.AbstractEventLoop] = None,
|
|
spotify_client_id: Optional[str] = None,
|
|
spotify_client_secret: Optional[str] = None,
|
|
session: Optional[aiohttp.ClientSession] = None,
|
|
apple_music: bool = False,
|
|
fallback: bool = False,
|
|
logger: Optional[logging.Logger] = None,
|
|
) -> Node:
|
|
"""Creates a Node object to be then added into the node pool.
|
|
For Spotify searching capabilites, pass in valid Spotify API credentials.
|
|
"""
|
|
if identifier in cls._nodes.keys():
|
|
raise NodeCreationError(
|
|
f"A node with identifier '{identifier}' already exists.",
|
|
)
|
|
|
|
node = Node(
|
|
pool=cls,
|
|
bot=bot,
|
|
host=host,
|
|
port=port,
|
|
password=password,
|
|
identifier=identifier,
|
|
secure=secure,
|
|
heartbeat=heartbeat,
|
|
resume_key=resume_key,
|
|
resume_timeout=resume_timeout,
|
|
loop=loop,
|
|
spotify_client_id=spotify_client_id,
|
|
session=session,
|
|
spotify_client_secret=spotify_client_secret,
|
|
apple_music=apple_music,
|
|
fallback=fallback,
|
|
logger=logger,
|
|
)
|
|
|
|
await node.connect()
|
|
cls._nodes[node._identifier] = node
|
|
return node
|
|
|
|
@classmethod
|
|
async def disconnect(cls) -> None:
|
|
"""Disconnects all available nodes from the node pool."""
|
|
|
|
available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available]
|
|
|
|
for node in available_nodes:
|
|
await node.disconnect()
|