from __future__ import annotations import asyncio import logging import random import re import time from os import path from pathlib import Path from typing import TYPE_CHECKING from typing import Any from typing import Dict from typing import List from typing import Optional from typing import Type from typing import Union from urllib.parse import quote import aiohttp import orjson as json from discord import Client from discord.ext import commands from discord.utils import MISSING from websockets import client from websockets import exceptions from websockets import typing as wstype from . import __version__ from . import applemusic from . import spotify from .enums import LogLevel from .enums import * from .exceptions import InvalidSpotifyClientAuthorization from .exceptions import LavalinkVersionIncompatible from .exceptions import NodeConnectionFailure from .exceptions import NodeCreationError from .exceptions import NodeNotAvailable from .exceptions import NodeRestException from .exceptions import NoNodesAvailable from .exceptions import TrackLoadError from .filters import Filter from .objects import Playlist from .objects import Track from .routeplanner import RoutePlanner from .utils import ExponentialBackoff from .utils import LavalinkVersion from .utils import NodeStats from .utils import Ping if TYPE_CHECKING: from .player import Player __all__ = ( "Node", "NodePool", ) VERSION_REGEX = re.compile(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?(?:[a-zA-Z0-9_-]+)?") class Node: """The base class for a node. This node object represents a Lavalink node. To enable Spotify searching, pass in a proper Spotify Client ID and Spotify Client Secret To enable Apple music, set the "apple_music" parameter to "True" """ __slots__ = ( "_bot", "_bot_user", "_host", "_port", "_pool", "_password", "_identifier", "_heartbeat", "_resume_key", "_resume_timeout", "_secure", "_fallback", "_log_level", "_websocket_uri", "_rest_uri", "_session", "_websocket", "_task", "_loop", "_session_id", "_available", "_version", "_headers", "_players", "_spotify_client_id", "_spotify_client_secret", "_spotify_client", "_apple_music_client", "_route_planner", "_log", "_stats", "available", ) def __init__( self, *, pool: Type[NodePool], bot: commands.Bot, host: str, port: int, password: str, identifier: str, secure: bool = False, heartbeat: int = 120, resume_key: Optional[str] = None, resume_timeout: int = 60, loop: Optional[asyncio.AbstractEventLoop] = None, session: Optional[aiohttp.ClientSession] = None, spotify_client_id: Optional[str] = None, spotify_client_secret: Optional[str] = None, apple_music: bool = False, fallback: bool = False, logger: Optional[logging.Logger] = None, ): if not isinstance(port, int): raise TypeError("Port must be an integer") self._bot: commands.Bot = bot self._host: str = host self._port: int = port self._pool: Type[NodePool] = pool self._password: str = password self._identifier: str = identifier self._heartbeat: int = heartbeat self._resume_key: Optional[str] = resume_key self._resume_timeout: int = resume_timeout self._secure: bool = secure self._fallback: bool = fallback self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}" self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}" self._session: aiohttp.ClientSession = session # type: ignore self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() self._websocket: client.WebSocketClientProtocol self._task: asyncio.Task = None # type: ignore self._session_id: Optional[str] = None self._available: bool = False self._version: LavalinkVersion = LavalinkVersion(0, 0, 0) self._route_planner = RoutePlanner(self) self._log = logger if not self._bot.user: raise NodeCreationError("Bot user is not ready yet.") self._bot_user = self._bot.user self._headers = { "Authorization": self._password, "User-Id": str(self._bot_user.id), "Client-Name": f"Pomice/{__version__}", } self._players: Dict[int, Player] = {} self._spotify_client: Optional[spotify.Client] = None self._apple_music_client: Optional[applemusic.Client] = None self._spotify_client_id: Optional[str] = spotify_client_id self._spotify_client_secret: Optional[str] = spotify_client_secret if self._spotify_client_id and self._spotify_client_secret: self._spotify_client = spotify.Client( self._spotify_client_id, self._spotify_client_secret, ) if apple_music: self._apple_music_client = applemusic.Client() self._bot.add_listener(self._update_handler, "on_socket_response") def __repr__(self) -> str: return ( f"" ) @property def is_connected(self) -> bool: """Property which returns whether this node is connected or not""" return self._websocket is not None and not self._websocket.closed @property def stats(self) -> NodeStats: """Property which returns the node stats.""" return self._stats @property def players(self) -> Dict[int, Player]: """Property which returns a dict containing the guild ID and the player object.""" return self._players @property def bot(self) -> Client: """Property which returns the discord.py client linked to this node""" return self._bot @property def player_count(self) -> int: """Property which returns how many players are connected to this node""" return len(self.players.values()) @property def pool(self) -> Type[NodePool]: """Property which returns the pool this node is apart of""" return self._pool @property def latency(self) -> float: """Property which returns the latency of the node""" return Ping(self._host, port=self._port).get_ping() @property def ping(self) -> float: """Alias for `Node.latency`, returns the latency of the node""" return self.latency async def _handle_version_check(self, version: str) -> None: if version.endswith("-SNAPSHOT"): # we're just gonna assume all snapshot versions correlate with v4 self._version = LavalinkVersion(major=4, minor=0, fix=0) return _version_rx = VERSION_REGEX.match(version) if not _version_rx: self._available = False raise LavalinkVersionIncompatible( "The Lavalink version you're using is incompatible. " "Lavalink version 3.7.0 or above is required to use this library.", ) _version_groups = _version_rx.groups() major, minor, fix = ( int(_version_groups[0] or 0), int(_version_groups[1] or 0), int(_version_groups[2] or 0), ) if self._log: self._log.debug(f"Parsed Lavalink version: {major}.{minor}.{fix}") self._version = LavalinkVersion(major=major, minor=minor, fix=fix) if self._version < LavalinkVersion(3, 7, 0): self._available = False raise LavalinkVersionIncompatible( "The Lavalink version you're using is incompatible. " "Lavalink version 3.7.0 or above is required to use this library.", ) async def _set_ext_client_session(self, session: aiohttp.ClientSession) -> None: if self._spotify_client: await self._spotify_client._set_session(session=session) if self._apple_music_client: await self._apple_music_client._set_session(session=session) async def _update_handler(self, data: dict) -> None: await self._bot.wait_until_ready() if not data: return if data["t"] == "VOICE_SERVER_UPDATE": guild_id = int(data["d"]["guild_id"]) try: player = self._players[guild_id] await player.on_voice_server_update(data["d"]) except KeyError: return elif data["t"] == "VOICE_STATE_UPDATE": if int(data["d"]["user_id"]) != self._bot_user.id: return guild_id = int(data["d"]["guild_id"]) try: player = self._players[guild_id] await player.on_voice_state_update(data["d"]) except KeyError: return async def _handle_node_switch(self) -> None: nodes = [node for node in self.pool._nodes.copy().values() if node.is_connected] new_node = random.choice(nodes) for player in self.players.copy().values(): await player._swap_node(new_node=new_node) await self.disconnect() async def _configure_resuming(self) -> None: if not self._resume_key: return data = {"timeout": self._resume_timeout} if self._version.major == 3: data["resumingKey"] = self._resume_key elif self._version.major == 4: if self._log: self._log.warning("Using a resume key with Lavalink v4 is deprecated.") data["resuming"] = True await self.send( method="PATCH", path=f"sessions/{self._session_id}", include_version=True, data=data, ) async def _listen(self) -> None: while True: try: msg = await self._websocket.recv() data = json.loads(msg) if self._log: self._log.debug(f"Recieved raw websocket message {msg}") self._loop.create_task(self._handle_ws_msg(data=data)) except exceptions.ConnectionClosed: if self.player_count > 0: for _player in self.players.values(): self._loop.create_task(_player.destroy()) if self._fallback: self._loop.create_task(self._handle_node_switch()) self._loop.create_task(self._websocket.close()) backoff = ExponentialBackoff(base=7) retry = backoff.delay() if self._log: self._log.debug( f"Retrying connection to Node {self._identifier} in {retry} secs", ) await asyncio.sleep(retry) if not self.is_connected: self._loop.create_task(self.connect(reconnect=True)) async def _handle_ws_msg(self, data: dict) -> None: if self._log: self._log.debug(f"Recieved raw payload from Node {self._identifier} with data {data}") op = data.get("op", None) if op == "stats": self._stats = NodeStats(data) return if op == "ready": self._session_id = data["sessionId"] await self._configure_resuming() if not "guildId" in data: return player: Optional[Player] = self._players.get(int(data["guildId"])) if not player: return if op == "event": return await player._dispatch_event(data) if op == "playerUpdate": return await player._update_state(data) async def send( self, method: str, path: str, include_version: bool = True, guild_id: Optional[Union[int, str]] = None, query: Optional[str] = None, data: Optional[Union[Dict, str]] = None, ignore_if_available: bool = False, ) -> Any: if not ignore_if_available and not self._available: raise NodeNotAvailable( f"The node '{self._identifier}' is unavailable.", ) uri: str = ( f"{self._rest_uri}/" f'{f"v{self._version.major}/" if include_version else ""}' f"{path}" f'{f"/{guild_id}" if guild_id else ""}' f'{f"?{query}" if query else ""}' ) resp = await self._session.request( method=method, url=uri, headers=self._headers, json=data or {}, ) if self._log: self._log.debug( f"Making REST request to Node {self._identifier} with method {method} to {uri}", ) if resp.status >= 300: resp_data: dict = await resp.json() raise NodeRestException( f'Error from Node {self._identifier} fetching from Lavalink REST api: {resp.status} {resp.reason}: {resp_data["message"]}', ) if method == "DELETE" or resp.status == 204: if self._log: self._log.debug( f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned no data.", ) return await resp.json(content_type=None) if resp.content_type == "text/plain": if self._log: self._log.debug( f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned text with body {await resp.text()}", ) return await resp.text() if self._log: self._log.debug( f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned JSON with body {await resp.json()}", ) return await resp.json() def get_player(self, guild_id: int) -> Optional[Player]: """Takes a guild ID as a parameter. Returns a pomice Player object or None.""" return self._players.get(guild_id, None) async def connect(self, *, reconnect: bool = False) -> Node: """Initiates a connection with a Lavalink node and adds it to the node pool.""" await self._bot.wait_until_ready() start = time.perf_counter() if not self._session: self._session = aiohttp.ClientSession() try: if not reconnect: version: str = await self.send( method="GET", path="version", ignore_if_available=True, include_version=False, ) await self._handle_version_check(version=version) await self._set_ext_client_session(session=self._session) if self._log: self._log.debug( f"Version check from Node {self._identifier} successful. Returned version {version}", ) self._websocket = await client.connect( f"{self._websocket_uri}/v{self._version.major}/websocket", extra_headers=self._headers, ping_interval=self._heartbeat, ) if reconnect: if self._log: self._log.debug(f"Trying to reconnect to Node {self._identifier}...") if self.player_count: for player in self.players.values(): await player._refresh_endpoint_uri(self._session_id) if self._log: self._log.debug( f"Node {self._identifier} successfully connected to websocket using {self._websocket_uri}/v{self._version.major}/websocket", ) if not self._task: self._task = self._loop.create_task(self._listen()) self._available = True end = time.perf_counter() if self._log: self._log.info(f"Connected to node {self._identifier}. Took {end - start:.3f}s") return self except (aiohttp.ClientConnectorError, OSError, ConnectionRefusedError): raise NodeConnectionFailure( f"The connection to node '{self._identifier}' failed.", ) from None except exceptions.InvalidHandshake: raise NodeConnectionFailure( f"The password for node '{self._identifier}' is invalid.", ) from None except exceptions.InvalidURI: raise NodeConnectionFailure( f"The URI for node '{self._identifier}' is invalid.", ) from None async def disconnect(self) -> None: """Disconnects a connected Lavalink node and removes it from the node pool. This also destroys any players connected to the node. """ start = time.perf_counter() for player in self.players.copy().values(): await player.destroy() if self._log: self._log.debug("All players disconnected from node.") await self._websocket.close() await self._session.close() if self._log: self._log.debug("Websocket and http session closed.") del self._pool._nodes[self._identifier] self.available = False self._task.cancel() end = time.perf_counter() if self._log: self._log.info( f"Successfully disconnected from node {self._identifier} and closed all sessions. Took {end - start:.3f}s", ) async def build_track(self, identifier: str, ctx: Optional[commands.Context] = None) -> Track: """ Builds a track using a valid track identifier You can also pass in a discord.py Context object to get a Context object on the track it builds. """ data: dict = await self.send( method="GET", path="decodetrack", query=f"encodedTrack={quote(identifier)}", ) track_info = data["info"] if self._version.major >= 4 else data return Track( track_id=identifier, ctx=ctx, info=track_info, track_type=TrackType(track_info["sourceName"]), ) async def get_tracks( self, query: str, *, ctx: Optional[commands.Context] = None, search_type: SearchType = SearchType.ytsearch, filters: Optional[List[Filter]] = None, ) -> Optional[Union[Playlist, List[Track]]]: """Fetches tracks from the node's REST api to parse into Lavalink. If you passed in Spotify API credentials, you can also pass in a Spotify URL of a playlist, album or track and it will be parsed accordingly. You can pass in a discord.py Context object to get a Context object on any track you search. You may also pass in a List of filters to be applied to your track once it plays. """ timestamp = None if filters: for filter in filters: filter.set_preload() # Due to the inclusion of plugins in the v4 update # we are doing away with raising an error if pomice detects # either a Spotify or Apple Music URL and the respective client # is not enabled. Instead, we will just only parse the URL # if the client is enabled and the URL is valid. if self._apple_music_client and URLRegex.AM_URL.match(query): apple_music_results = await self._apple_music_client.search(query=query) if isinstance(apple_music_results, applemusic.Song): return [ Track( track_id=apple_music_results.id, ctx=ctx, track_type=TrackType.APPLE_MUSIC, search_type=search_type, filters=filters, info={ "title": apple_music_results.name, "author": apple_music_results.artists, "length": apple_music_results.length, "identifier": apple_music_results.id, "uri": apple_music_results.url, "isStream": False, "isSeekable": True, "position": 0, "thumbnail": apple_music_results.image, "isrc": apple_music_results.isrc, }, ), ] tracks = [ Track( track_id=track.id, ctx=ctx, track_type=TrackType.APPLE_MUSIC, search_type=search_type, filters=filters, info={ "title": track.name, "author": track.artists, "length": track.length, "identifier": track.id, "uri": track.url, "isStream": False, "isSeekable": True, "position": 0, "thumbnail": track.image, "isrc": track.isrc, }, ) for track in apple_music_results.tracks ] return Playlist( playlist_info={ "name": apple_music_results.name, "selectedTrack": 0, }, tracks=tracks, playlist_type=PlaylistType.APPLE_MUSIC, thumbnail=apple_music_results.image, uri=apple_music_results.url, ) elif self._spotify_client and URLRegex.SPOTIFY_URL.match(query): spotify_results = await self._spotify_client.search(query=query) # type: ignore if isinstance(spotify_results, spotify.Track): return [ Track( track_id=spotify_results.id, ctx=ctx, track_type=TrackType.SPOTIFY, search_type=search_type, filters=filters, info={ "title": spotify_results.name, "author": spotify_results.artists, "length": spotify_results.length, "identifier": spotify_results.id, "uri": spotify_results.uri, "isStream": False, "isSeekable": True, "position": 0, "thumbnail": spotify_results.image, "isrc": spotify_results.isrc, }, ), ] tracks = [ Track( track_id=track.id, ctx=ctx, track_type=TrackType.SPOTIFY, search_type=search_type, filters=filters, info={ "title": track.name, "author": track.artists, "length": track.length, "identifier": track.id, "uri": track.uri, "isStream": False, "isSeekable": True, "position": 0, "thumbnail": track.image, "isrc": track.isrc, }, ) for track in spotify_results.tracks ] return Playlist( playlist_info={ "name": spotify_results.name, "selectedTrack": 0, }, tracks=tracks, playlist_type=PlaylistType.SPOTIFY, thumbnail=spotify_results.image, uri=spotify_results.uri, ) else: if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query): query = f"{search_type}:{query}" # If YouTube url contains a timestamp, capture it for use later. if match := URLRegex.YOUTUBE_TIMESTAMP.match(query): timestamp = float(match.group("time")) data = await self.send( method="GET", path="loadtracks", query=f"identifier={quote(query)}", ) load_type = data.get("loadType") # Lavalink v4 changed the name of the key from "tracks" to "data" # so lets account for that data_type = "data" if self._version.major >= 4 else "tracks" if not load_type: raise TrackLoadError( "There was an error while trying to load this track.", ) elif load_type in ("LOAD_FAILED", "error"): exception = data["exception"] raise TrackLoadError( f"{exception['message']} [{exception['severity']}]", ) elif load_type in ("NO_MATCHES", "empty"): return None elif load_type in ("PLAYLIST_LOADED", "playlist"): if self._version.major >= 4: track_list = data[data_type]["tracks"] playlist_info = data[data_type]["info"] else: track_list = data[data_type] playlist_info = data["playlistInfo"] tracks = [ Track( track_id=track["encoded"], info=track["info"], ctx=ctx, track_type=TrackType(track["info"]["sourceName"]), ) for track in track_list ] return Playlist( playlist_info=playlist_info, tracks=tracks, playlist_type=PlaylistType(tracks[0].track_type.value), thumbnail=tracks[0].thumbnail, uri=query, ) elif load_type in ("SEARCH_RESULT", "TRACK_LOADED", "track", "search"): if self._version.major >= 4 and isinstance(data[data_type], dict): data[data_type] = [data[data_type]] if path.exists(path.dirname(query)): local_file = Path(query) return [ Track( track_id=track["track"], info={ "title": local_file.name, "author": "Unknown", "length": track["info"]["length"], "uri": quote(local_file.as_uri()), "position": track["info"]["position"], "identifier": track["info"]["identifier"], }, ctx=ctx, track_type=TrackType.LOCAL, filters=filters, ) for track in data[data_type] ] elif discord_url := URLRegex.DISCORD_MP3_URL.match(query): return [ Track( track_id=track["encoded"], info={ "title": discord_url.group("file"), "author": "Unknown", "length": track["info"]["length"], "uri": track["info"]["uri"], "position": track["info"]["position"], "identifier": track["info"]["identifier"], }, ctx=ctx, track_type=TrackType.HTTP, filters=filters, ) for track in data[data_type] ] return [ Track( track_id=track["encoded"], info=track["info"], ctx=ctx, track_type=TrackType(track["info"]["sourceName"]), filters=filters, timestamp=timestamp, ) for track in data[data_type] ] else: raise TrackLoadError( "There was an error while trying to load this track.", ) async def get_recommendations( self, *, track: Track, ctx: Optional[commands.Context] = None, ) -> Optional[Union[List[Track], Playlist]]: """ Gets recommendations from either YouTube or Spotify. The track that is passed in must be either from YouTube or Spotify or else this will not work. You can pass in a discord.py Context object to get a Context object on all tracks that get recommended. """ if track.track_type == TrackType.SPOTIFY: results = await self._spotify_client.get_recommendations(query=track.uri) # type: ignore tracks = [ Track( track_id=track.id, ctx=ctx, track_type=TrackType.SPOTIFY, info={ "title": track.name, "author": track.artists, "length": track.length, "identifier": track.id, "uri": track.uri, "isStream": False, "isSeekable": True, "position": 0, "thumbnail": track.image, "isrc": track.isrc, }, requester=self.bot.user, ) for track in results ] return tracks elif track.track_type == TrackType.YOUTUBE: return await self.get_tracks( query=f"ytsearch:https://www.youtube.com/watch?v={track.identifier}&list=RD{track.identifier}", ctx=ctx, ) else: raise TrackLoadError( "The specfied track must be either a YouTube or Spotify track to recieve recommendations.", ) async def search_spotify_recommendations( self, query: str, *, ctx: Optional[commands.Context] = None, filters: Optional[List[Filter]] = None, ) -> Optional[Union[List[Track], Playlist]]: """ Searches for recommendations on Spotify and returns a list of tracks based on the query. You must have Spotify enabled for this to work. You can pass in a discord.py Context object to get a Context object on all tracks that get recommended. """ if not self._spotify_client: raise InvalidSpotifyClientAuthorization( "You must have Spotify enabled to use this feature.", ) results = await self._spotify_client.track_search(query=query) # type: ignore if not results: raise TrackLoadError( "Unable to find any tracks based on the query.", ) tracks = [ Track( track_id=track.id, ctx=ctx, track_type=TrackType.SPOTIFY, info={ "title": track.name, "author": track.artists, "length": track.length, "identifier": track.id, "uri": track.uri, "isStream": False, "isSeekable": True, "position": 0, "thumbnail": track.image, "isrc": track.isrc, }, requester=self.bot.user, ) for track in results ] track = tracks[0] return await self.get_recommendations(track=track, ctx=ctx) class NodePool: """The base class for the node pool. This holds all the nodes that are to be used by the bot. """ __slots__ = () _nodes: Dict[str, Node] = {} def __repr__(self) -> str: return f"" @property def nodes(self) -> Dict[str, Node]: """Property which returns a dict with the node identifier and the Node object.""" return self._nodes @property def node_count(self) -> int: return len(self._nodes.values()) @classmethod def get_best_node(cls, *, algorithm: NodeAlgorithm) -> Node: """Fetches the best node based on an NodeAlgorithm. This option is preferred if you want to choose the best node from a multi-node setup using either the node's latency or the node's voice region. Use NodeAlgorithm.by_ping if you want to get the best node based on the node's latency. Use NodeAlgorithm.by_players if you want to get the best node based on how players it has. This method will return a node with the least amount of players """ available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available] if not available_nodes: raise NoNodesAvailable("There are no nodes available.") if algorithm == NodeAlgorithm.by_ping: tested_nodes = {node: node.latency for node in available_nodes} return min(tested_nodes, key=tested_nodes.get) # type: ignore elif algorithm == NodeAlgorithm.by_players: tested_nodes = {node: len(node.players.keys()) for node in available_nodes} return min(tested_nodes, key=tested_nodes.get) # type: ignore else: raise ValueError( "The algorithm provided is not a valid NodeAlgorithm.", ) @classmethod def get_node(cls, *, identifier: Optional[str] = None) -> Node: """Fetches a node from the node pool using it's identifier. If no identifier is provided, it will choose a node at random. """ available_nodes = { identifier: node for identifier, node in cls._nodes.items() if node._available } if not available_nodes: raise NoNodesAvailable("There are no nodes available.") if identifier is None: return random.choice(list(available_nodes.values())) return available_nodes[identifier] @classmethod async def create_node( cls, *, bot: commands.Bot, host: str, port: int, password: str, identifier: str, secure: bool = False, heartbeat: int = 120, resume_key: Optional[str] = None, resume_timeout: int = 60, loop: Optional[asyncio.AbstractEventLoop] = None, spotify_client_id: Optional[str] = None, spotify_client_secret: Optional[str] = None, session: Optional[aiohttp.ClientSession] = None, apple_music: bool = False, fallback: bool = False, logger: Optional[logging.Logger] = None, ) -> Node: """Creates a Node object to be then added into the node pool. For Spotify searching capabilites, pass in valid Spotify API credentials. """ if identifier in cls._nodes.keys(): raise NodeCreationError( f"A node with identifier '{identifier}' already exists.", ) node = Node( pool=cls, bot=bot, host=host, port=port, password=password, identifier=identifier, secure=secure, heartbeat=heartbeat, resume_key=resume_key, resume_timeout=resume_timeout, loop=loop, spotify_client_id=spotify_client_id, session=session, spotify_client_secret=spotify_client_secret, apple_music=apple_music, fallback=fallback, logger=logger, ) await node.connect() cls._nodes[node._identifier] = node return node @classmethod async def disconnect(cls) -> None: """Disconnects all available nodes from the node pool.""" available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available] for node in available_nodes: await node.disconnect()