from __future__ import annotations import asyncio import json import random import re import socket import time from typing import Dict, Optional, TYPE_CHECKING from urllib.parse import quote import aiohttp from discord.ext import commands from . import __version__, spotify from .enums import SearchType from .exceptions import ( InvalidSpotifyClientAuthorization, NodeConnectionFailure, NodeCreationError, NodeNotAvailable, NoNodesAvailable, TrackLoadError ) from .objects import Playlist, Track from .utils import ClientType, ExponentialBackoff, NodeStats if TYPE_CHECKING: from .player import Player SPOTIFY_URL_REGEX = re.compile( r"https?://open.spotify.com/(?Palbum|playlist|track)/(?P[a-zA-Z0-9]+)" ) DISCORD_MP3_URL_REGEX = re.compile( r"https?://cdn.discordapp.com/attachments/(?P[0-9]+)/" r"(?P[0-9]+)/(?P[a-zA-Z0-9_.]+)+" ) URL_REGEX = re.compile( r"https?://(?:www\.)?.+" ) class Node: """The base class for a node. This node object represents a Lavalink node. To enable Spotify searching, pass in a proper Spotify Client ID and Spotify Client Secret """ def __init__( self, *, pool, bot: ClientType, host: str, port: int, password: str, identifier: str, session: Optional[aiohttp.ClientSession], spotify_client_id: Optional[str], spotify_client_secret: Optional[str], ): self._bot = bot self._host = host self._port = port self._pool = pool self._password = password self._identifier = identifier self._websocket_uri = f"ws://{self._host}:{self._port}" self._rest_uri = f"http://{self._host}:{self._port}" self._session = session or aiohttp.ClientSession() self._websocket: aiohttp.ClientWebSocketResponse = None self._task: asyncio.Task = None self._connection_id = None self._metadata = None self._available = None self._headers = { "Authorization": self._password, "User-Id": str(self._bot.user.id), "Client-Name": f"Pomice/{__version__}" } self._players: Dict[int, Player] = {} self._spotify_client_id = spotify_client_id self._spotify_client_secret = spotify_client_secret if self._spotify_client_id and self._spotify_client_secret: self._spotify_client = spotify.Client( self._spotify_client_id, self._spotify_client_secret ) self._bot.add_listener(self._update_handler, "on_socket_response") def __repr__(self): return ( f"" ) @property def is_connected(self) -> bool: """"Property which returns whether this node is connected or not""" return self._websocket is not None and not self._websocket.closed @property async def latency(self) -> int: """Property which returns the latency of the node in milliseconds""" start_time = time.time() await self.send(op="ping") end_time = await self._bot.wait_for("node_ping") return (end_time - start_time) * 1000 @property async def stats(self) -> NodeStats: """Property which returns the node stats.""" await self.send(op="get-stats") node_stats = await self._bot.wait_for("node_stats") return node_stats @property def players(self) -> Dict[int, Player]: """Property which returns a dict containing the guild ID and the player object.""" return self._players @property def bot(self) -> ClientType: """Property which returns the discord.py client linked to this node""" return self._bot @property def player_count(self) -> int: """Property which returns how many players are connected to this node""" return len(self.players) @property def pool(self): """Property which returns the pool this node is apart of""" return self._pool async def _update_handler(self, data: dict): await self._bot.wait_until_ready() if not data: return if data["t"] == "VOICE_SERVER_UPDATE": guild_id = int(data["d"]["guild_id"]) try: player = self._players[guild_id] await player.on_voice_server_update(data["d"]) except KeyError: return elif data["t"] == "VOICE_STATE_UPDATE": if int(data["d"]["user_id"]) != self._bot.user.id: return guild_id = int(data["d"]["guild_id"]) try: player = self._players[guild_id] await player.on_voice_state_update(data["d"]) except KeyError: return async def _listen(self): backoff = ExponentialBackoff(base=7) while True: msg = await self._websocket.receive() if msg.type == aiohttp.WSMsgType.CLOSED: retry = backoff.delay() await asyncio.sleep(retry) if not self.is_connected: self._bot.loop.create_task(self.connect()) else: self._bot.loop.create_task(self._handle_payload(msg.json())) async def _handle_payload(self, data: dict): op = data.get("op", None) if not op: return if op == "stats": self._stats = NodeStats(data) return if not (player := self._players.get(int(data["guildId"]))): return if op == "event": await player._dispatch_event(data) elif op == "playerUpdate": await player._update_state(data) async def send(self, **data): if not self._available: raise NodeNotAvailable( f"The node '{self.identifier}' is unavailable." ) await self._websocket.send_str(json.dumps(data)) def get_player(self, guild_id: int): """Takes a guild ID as a parameter. Returns a pomice Player object.""" return self._players.get(guild_id, None) async def connect(self): """Initiates a connection with a Lavalink node and adds it to the node pool.""" await self._bot.wait_until_ready() try: self._websocket = await self._session.ws_connect( self._websocket_uri, headers=self._headers, heartbeat=60 ) self._task = self._bot.loop.create_task(self._listen()) self._available = True return self except aiohttp.WSServerHandshakeError: raise NodeConnectionFailure( f"The password for node '{self.identifier}' is invalid." ) except aiohttp.InvalidURL: raise NodeConnectionFailure( f"The URI for node '{self.identifier}' is invalid." ) except socket.gaierror: raise NodeConnectionFailure( f"The node '{self.identifier}' failed to connect." ) async def disconnect(self): """Disconnects a connected Lavalink node and removes it from the node pool. This also destroys any players connected to the node. """ for player in self.players.copy().values(): await player.destroy() await self._websocket.close() del self._pool.nodes[self._identifier] self.available = False self._task.cancel() async def build_track( self, identifier: str, ctx: Optional[commands.Context] = None ) -> Track: """ Builds a track using a valid track identifier You can also pass in a discord.py Context object to get a Context object on the track it builds. """ async with self._session.get( f"{self._rest_uri}/decodetrack?", headers={"Authorization": self._password}, params={"track": identifier} ) as resp: if not resp.status == 200: raise TrackLoadError( f"Failed to build track. Check the identifier is correct and try again." ) data: dict = await resp.json() return Track(track_id=identifier, ctx=ctx, info=data) async def get_tracks( self, query: str, *, ctx: Optional[commands.Context] = None, search_type: SearchType = SearchType.ytsearch ): """Fetches tracks from the node's REST api to parse into Lavalink. If you passed in Spotify API credentials, you can also pass in a Spotify URL of a playlist, album or track and it will be parsed accordingly. You can also pass in a discord.py Context object to get a Context object on any track you search. """ if not URL_REGEX.match(query) and not re.match(r"(?:ytm?|sc)search:.", query): query = f"{search_type}:{query}" if SPOTIFY_URL_REGEX.match(query): if not self._spotify_client_id and not self._spotify_client_secret: raise InvalidSpotifyClientAuthorization( "You did not provide proper Spotify client authorization credentials. " "If you would like to use the Spotify searching feature, " "please obtain Spotify API credentials here: https://developer.spotify.com/" ) spotify_results = await self._spotify_client.search(query=query) if isinstance(spotify_results, spotify.Track): return [ Track( track_id=spotify_results.id, ctx=ctx, search_type=search_type, spotify=True, info={ "title": spotify_results.name, "author": spotify_results.artists, "length": spotify_results.length, "identifier": spotify_results.id, "uri": spotify_results.uri, "isStream": False, "isSeekable": False, "position": 0, "thumbnail": spotify_results.image } ) ] tracks = [ Track( track_id=track.id, ctx=ctx, search_type=search_type, spotify=True, info={ "title": track.name, "author": track.artists, "length": track.length, "identifier": track.id, "uri": track.uri, "isStream": False, "isSeekable": False, "position": 0, "thumbnail": track.image } ) for track in spotify_results.tracks ] return Playlist( playlist_info={"name": spotify_results.name, "selectedTrack": tracks[0]}, tracks=tracks, ctx=ctx, spotify=True, thumbnail=spotify_results.image, uri=spotify_results.uri ) elif discord_url := DISCORD_MP3_URL_REGEX.match(query): async with self._session.get( url=f"{self._rest_uri}/loadtracks?identifier={quote(query)}", headers={"Authorization": self._password} ) as response: data: dict = await response.json() track: dict = data["tracks"][0] info: dict = track.get("info") return [ Track( track_id=track["track"], info={ "title": discord_url.group("file"), "author": "Unknown", "length": info.get("length"), "uri": info.get("uri"), "position": info.get("position"), "identifier": info.get("identifier") }, ctx=ctx ) ] else: async with self._session.get( url=f"{self._rest_uri}/loadtracks?identifier={quote(query)}", headers={"Authorization": self._password} ) as response: data = await response.json() load_type = data.get("loadType") if not load_type: raise TrackLoadError("There was an error while trying to load this track.") elif load_type == "LOAD_FAILED": exception = data["exception"] raise TrackLoadError(f"{exception['message']} [{exception['severity']}]") elif load_type == "NO_MATCHES": return None elif load_type == "PLAYLIST_LOADED": return Playlist( playlist_info=data["playlistInfo"], tracks=data["tracks"], ctx=ctx ) elif load_type == "SEARCH_RESULT" or load_type == "TRACK_LOADED": return [ Track( track_id=track["track"], info=track["info"], ctx=ctx ) for track in data["tracks"] ] class NodePool: """The base class for the node pool. This holds all the nodes that are to be used by the bot. """ _nodes = {} def __repr__(self): return f"" @property def nodes(self) -> Dict[str, Node]: """Property which returns a dict with the node identifier and the Node object.""" return self._nodes @property def node_count(self): return len(self._nodes.values()) @classmethod def get_node(cls, *, identifier: str = None) -> Node: """Fetches a node from the node pool using it's identifier. If no identifier is provided, it will choose a node at random. """ available_nodes = { identifier: node for identifier, node in cls._nodes.items() if node._available } if not available_nodes: raise NoNodesAvailable("There are no nodes available.") if identifier is None: return random.choice(list(available_nodes.values())) return available_nodes.get(identifier, None) @classmethod async def create_node( cls, *, bot: ClientType, host: str, port: str, password: str, identifier: str, spotify_client_id: Optional[str], spotify_client_secret: Optional[str], session: Optional[aiohttp.ClientSession] = None, ) -> Node: """Creates a Node object to be then added into the node pool. For Spotify searching capabilites, pass in valid Spotify API credentials. """ if identifier in cls._nodes.keys(): raise NodeCreationError(f"A node with identifier '{identifier}' already exists.") node = Node( pool=cls, bot=bot, host=host, port=port, password=password, identifier=identifier, spotify_client_id=spotify_client_id, session=session, spotify_client_secret=spotify_client_secret ) await node.connect() cls._nodes[node._identifier] = node return node