pomice/pomice/pool.py

507 lines
18 KiB
Python

import asyncio
import json
import random
import re
import socket
import time
from typing import Dict, Optional, Type
from urllib.parse import quote
import aiohttp
import discord
from discord.ext import commands
from . import __version__, spotify
from .enums import SearchType
from .exceptions import (
InvalidSpotifyClientAuthorization,
NodeConnectionFailure,
NodeCreationError,
NodeNotAvailable,
NoNodesAvailable,
SpotifyAlbumLoadFailed,
SpotifyPlaylistLoadFailed,
SpotifyTrackLoadFailed,
TrackLoadError
)
from .objects import Playlist, Track
from .player import Player
from .spotify import SpotifyException
from .utils import ExponentialBackoff, NodeStats
# Matches Spotify album/playlist/track URLs; ``type`` is the resource kind and
# ``id`` its base-62 identifier. Dots in the host are escaped so that only the
# real ``open.spotify.com`` host matches.
SPOTIFY_URL_REGEX = re.compile(
    r"https?://open\.spotify\.com/(?P<type>album|playlist|track)/(?P<id>[a-zA-Z0-9]+)"
)
# Matches Discord CDN attachment URLs. ``file`` is used as the track title, so
# it accepts the characters commonly found in attachment names (including "-").
DISCORD_MP3_URL_REGEX = re.compile(
    r"https?://cdn\.discordapp\.com/attachments/(?P<channel_id>[0-9]+)/(?P<message_id>[0-9]+)/(?P<file>[a-zA-Z0-9_.-]+)"
)
class Node:
    """The base class for a node.

    This node object represents a Lavalink node.
    To enable Spotify searching, pass in a proper Spotify Client ID and Spotify Client Secret.
    """

    def __init__(
        self,
        pool,
        bot: Type[discord.Client],
        host: str,
        port: int,
        password: str,
        identifier: str,
        session: Optional[aiohttp.ClientSession],
        spotify_client_id: Optional[str],
        spotify_client_secret: Optional[str],
    ):
        self._bot = bot
        self._host = host
        self._port = port
        self._pool = pool
        self._password = password
        self._identifier = identifier

        self._websocket_uri = f"ws://{self._host}:{self._port}"
        self._rest_uri = f"http://{self._host}:{self._port}"

        # Fall back to a private session when the caller did not supply one.
        self._session = session or aiohttp.ClientSession()
        self._websocket: aiohttp.ClientWebSocketResponse = None
        self._task: asyncio.Task = None

        self._connection_id = None
        self._metadata = None
        # None until connect() succeeds; True while the websocket is usable,
        # False after a disconnect or while reconnecting.
        self._available = None
        # Most recent payload of the Lavalink "stats" op (see _handle_payload).
        self._stats: Optional[NodeStats] = None

        self._headers = {
            "Authorization": self._password,
            "User-Id": str(self._bot.user.id),
            "Client-Name": f"Pomice/{__version__}"
        }

        self._players = {}

        self._spotify_client_id = spotify_client_id
        self._spotify_client_secret = spotify_client_secret
        # The Spotify clients only exist when BOTH credentials were given;
        # get_tracks() checks for both before using them.
        self._spotify_client = None
        self._spotify_http_client = None
        if self._spotify_client_id and self._spotify_client_secret:
            self._spotify_client = spotify.Client(
                self._spotify_client_id, self._spotify_client_secret
            )
            self._spotify_http_client = spotify.HTTPClient(
                self._spotify_client_id, self._spotify_client_secret
            )

        # Raw gateway events carry the voice server/state updates players need.
        self._bot.add_listener(self._update_handler, "on_socket_response")

    def __repr__(self):
        return (
            f"<Pomice.node ws_uri={self._websocket_uri} rest_uri={self._rest_uri} "
            f"player_count={len(self._players)}>"
        )

    @property
    def identifier(self) -> str:
        """Property which returns the identifier this node was created with."""
        return self._identifier

    @property
    def is_connected(self) -> bool:
        """Property which returns whether this node is connected or not."""
        return self._websocket is not None and not self._websocket.closed

    @property
    async def latency(self) -> float:
        """Property which returns the latency of the node in milliseconds.

        Must be awaited: ``await node.latency``.
        """
        # NOTE(review): nothing in this module dispatches "node_ping"; unless
        # another component does, this wait never completes — confirm.
        start_time = time.time()
        await self.send(op="ping")
        end_time = await self._bot.wait_for("node_ping")
        return (end_time - start_time) * 1000

    @property
    async def stats(self) -> NodeStats:
        """Property which returns the node stats.

        Must be awaited: ``await node.stats``.
        """
        # NOTE(review): nothing in this module dispatches "node_stats"
        # (_handle_payload only caches stats in self._stats) — confirm.
        await self.send(op="get-stats")
        node_stats = await self._bot.wait_for("node_stats")
        return node_stats

    @property
    def players(self) -> Dict[int, Player]:
        """Property which returns a dict containing the guild ID and the player object."""
        return self._players

    @property
    def bot(self) -> Type[discord.Client]:
        """Property which returns the discord.py client linked to this node."""
        return self._bot

    @property
    def player_count(self) -> int:
        """Property which returns how many players are connected to this node."""
        return len(self.players)

    @property
    def pool(self):
        """Property which returns the pool this node is a part of."""
        return self._pool

    async def _update_handler(self, data: dict):
        """Forward the bot's own voice gateway events to the matching player.

        ``VOICE_SERVER_UPDATE`` payloads go to ``player.on_voice_server_update``
        and ``VOICE_STATE_UPDATE`` payloads for the bot user go to
        ``player.on_voice_state_update``. Events for guilds without a player
        on this node are ignored. Exceptions raised by the player handlers
        propagate instead of being swallowed.
        """
        await self._bot.wait_until_ready()

        if not data:
            return

        event = data.get("t")
        if event == "VOICE_SERVER_UPDATE":
            payload = data["d"]
            player = self._players.get(int(payload["guild_id"]))
            if player is not None:
                await player.on_voice_server_update(payload)

        elif event == "VOICE_STATE_UPDATE":
            payload = data["d"]
            # Only the bot's own voice state matters to its players.
            if int(payload["user_id"]) != self._bot.user.id:
                return
            player = self._players.get(int(payload["guild_id"]))
            if player is not None:
                await player.on_voice_state_update(payload)

    async def _listen(self):
        """Read websocket messages until the connection drops, then reconnect.

        Each TEXT frame is decoded as JSON and handled in its own task by
        ``_handle_payload``. A close or error frame marks the node unavailable
        and starts reconnection with exponential backoff. A successful
        ``connect()`` starts a fresh listener task. This one then returns, so
        two listeners never read the same socket.
        """
        backoff = ExponentialBackoff(base=7)

        while True:
            msg = await self._websocket.receive()

            if msg.type in (
                aiohttp.WSMsgType.CLOSE,
                aiohttp.WSMsgType.CLOSING,
                aiohttp.WSMsgType.CLOSED,
                aiohttp.WSMsgType.ERROR,
            ):
                self._available = False
                await self._reconnect(backoff)
                return

            if msg.type == aiohttp.WSMsgType.TEXT:
                self._bot.loop.create_task(self._handle_payload(msg.json()))

    async def _reconnect(self, backoff: ExponentialBackoff):
        """Retry ``connect()`` after backoff delays until the websocket is open."""
        while not self.is_connected:
            await asyncio.sleep(backoff.delay())
            try:
                await self.connect()
            except (NodeConnectionFailure, aiohttp.ClientError, asyncio.TimeoutError, OSError):
                # Still unreachable; try again after the next backoff delay.
                continue

    async def _handle_payload(self, data: dict):
        """Route one decoded Lavalink message.

        ``stats`` payloads are cached in ``self._stats``. ``event`` and
        ``playerUpdate`` payloads go to the player for their ``guildId``.
        Other messages, and guilds without a player here, are ignored.
        """
        op = data.get("op")
        if not op:
            return

        if op == "stats":
            self._stats = NodeStats(data)
            return

        guild_id = data.get("guildId")
        if guild_id is None:
            return

        player = self._players.get(int(guild_id))
        if player is None:
            return

        if op == "event":
            await player._dispatch_event(data)
        elif op == "playerUpdate":
            await player._update_state(data)

    async def send(self, **data):
        """Serialize the keyword arguments as JSON and send them to the node.

        Raises:
            NodeNotAvailable: the node is not connected/available.
        """
        if not self._available:
            raise NodeNotAvailable(
                f"The node '{self.identifier}' is not currently available.")

        await self._websocket.send_str(json.dumps(data))

    def get_player(self, guild_id: int):
        """Takes a guild ID as a parameter. Returns a pomice Player object, or None."""
        return self._players.get(guild_id, None)

    async def connect(self):
        """Initiates a connection with a Lavalink node and adds it to the node pool.

        Returns this node on success and starts the websocket listener task.

        Raises:
            NodeConnectionFailure: the handshake was rejected (bad password),
                the URI was invalid, or the host could not be reached.
        """
        await self._bot.wait_until_ready()

        try:
            self._websocket = await self._session.ws_connect(
                self._websocket_uri, headers=self._headers, heartbeat=60
            )
        except aiohttp.WSServerHandshakeError as exc:
            raise NodeConnectionFailure(
                f"The password for node '{self.identifier}' is invalid."
            ) from exc
        except aiohttp.InvalidURL as exc:
            raise NodeConnectionFailure(
                f"The URI for node '{self.identifier}' is invalid."
            ) from exc
        except (aiohttp.ClientConnectorError, socket.gaierror) as exc:
            # aiohttp wraps refused connections and DNS failures in
            # ClientConnectorError; map both to the documented error type.
            raise NodeConnectionFailure(
                f"The node '{self.identifier}' failed to connect."
            ) from exc

        self._task = self._bot.loop.create_task(self._listen())
        self._available = True
        return self

    async def disconnect(self):
        """Disconnects a connected Lavalink node and removes it from the node pool.

        This also destroys any players connected to the node.
        """
        for player in self.players.copy().values():
            await player.destroy()

        # Stop the listener first so it does not treat the close below as a
        # dropped connection and start reconnecting.
        if self._task is not None:
            self._task.cancel()

        if self._websocket is not None:
            await self._websocket.close()

        # NodePool.create_node passes the NodePool class itself as ``pool``;
        # its ``nodes`` property is unusable on the class, so use ``_nodes``.
        self._pool._nodes.pop(self._identifier, None)
        self._available = False

    async def get_tracks(
        self,
        query: str,
        ctx: Optional[commands.Context] = None,
        search_type: SearchType = SearchType.ytsearch
    ):
        """Fetches tracks from the node's REST api to parse into Lavalink.

        If you passed in Spotify API credentials, you can also pass in a
        Spotify URL of a playlist, album or track and it will be parsed accordingly.
        You can also pass in a discord.py Context object to get a
        Context object on any track you search.

        Returns a ``Playlist`` for Spotify playlists/albums and Lavalink
        ``PLAYLIST_LOADED`` results, a list of ``Track`` objects for single
        tracks and search results, or ``None`` for ``NO_MATCHES`` (and for any
        unrecognised load type).

        Raises:
            InvalidSpotifyClientAuthorization: a Spotify URL was given but the
                Spotify credentials were not both provided.
            SpotifyPlaylistLoadFailed, SpotifyAlbumLoadFailed,
            SpotifyTrackLoadFailed: the Spotify lookup failed or was empty.
            TrackLoadError: Lavalink returned no load type, reported
                ``LOAD_FAILED``, or returned no tracks for a Discord attachment.
        """
        if spotify_url_check := SPOTIFY_URL_REGEX.match(query):
            return await self._get_spotify_tracks(query, spotify_url_check, ctx, search_type)

        if discord_url := DISCORD_MP3_URL_REGEX.match(query):
            return await self._get_discord_tracks(query, discord_url, ctx)

        data = await self._request_load_tracks(query)
        load_type = data.get("loadType")

        if not load_type:
            raise TrackLoadError("There was an error while trying to load this track.")

        if load_type == "LOAD_FAILED":
            raise TrackLoadError(
                f"There was an error of severity '{data['severity']}' "
                f"while loading tracks.\n\n{data['cause']}"
            )

        if load_type == "NO_MATCHES":
            return None

        if load_type == "PLAYLIST_LOADED":
            return Playlist(
                playlist_info=data["playlistInfo"],
                tracks=data["tracks"],
                ctx=ctx
            )

        if load_type in ("SEARCH_RESULT", "TRACK_LOADED"):
            return [
                Track(
                    track_id=track["track"],
                    info=track["info"],
                    ctx=ctx
                )
                for track in data["tracks"]
            ]

        return None

    async def _request_load_tracks(self, query: str) -> dict:
        """GET ``/loadtracks`` for *query* from the node's REST API; return the decoded JSON."""
        async with self._session.get(
            url=f"{self._rest_uri}/loadtracks?identifier={quote(query)}",
            headers={"Authorization": self._password}
        ) as response:
            return await response.json()

    async def _get_spotify_tracks(self, query, url_match, ctx, search_type):
        """Resolve a Spotify playlist, album or track URL into pomice objects."""
        # Both credentials are required: the Spotify clients are only built
        # in __init__ when both are present.
        if not (self._spotify_client_id and self._spotify_client_secret):
            raise InvalidSpotifyClientAuthorization(
                "You did not provide proper Spotify client authorization credentials. "
                "If you would like to use the Spotify searching feature, "
                "please obtain Spotify API credentials here: https://developer.spotify.com/"
            )

        spotify_search_type = url_match.group("type")
        spotify_id = url_match.group("id")

        if spotify_search_type == "playlist":
            try:
                results = spotify.Playlist(
                    client=self._spotify_client,
                    data=await self._spotify_http_client.get_playlist(spotify_id)
                )
                search_tracks = await results.get_all_tracks()
                tracks = [
                    self._spotify_track_to_track(track, ctx, search_type)
                    for track in search_tracks
                ]
                if not tracks:
                    raise SpotifyPlaylistLoadFailed(f"Unable to find results for {query}")
                return self._build_spotify_playlist(results, tracks, ctx)
            except SpotifyException as exc:
                raise SpotifyPlaylistLoadFailed(
                    f"Unable to find results for {query}"
                ) from exc

        if spotify_search_type == "album":
            try:
                results = await self._spotify_client.get_album(spotify_id=spotify_id)
                search_tracks = await results.get_all_tracks()
                tracks = [
                    self._spotify_track_to_track(track, ctx, search_type)
                    for track in search_tracks
                ]
                if not tracks:
                    raise SpotifyAlbumLoadFailed(f"Unable to find results for {query}")
                return self._build_spotify_playlist(results, tracks, ctx)
            except SpotifyException as exc:
                raise SpotifyAlbumLoadFailed(f"Unable to find results for {query}") from exc

        if spotify_search_type == "track":
            try:
                results = await self._spotify_client.get_track(spotify_id=spotify_id)
                return [self._spotify_track_to_track(results, ctx, search_type)]
            except SpotifyException as exc:
                raise SpotifyTrackLoadFailed(f"Unable to find results for {query}") from exc

        # SPOTIFY_URL_REGEX restricts "type" to the three kinds handled above.
        return None

    @staticmethod
    def _spotify_track_to_track(track, ctx, search_type) -> Track:
        """Convert one Spotify track object into a pomice ``Track`` flagged ``spotify=True``.

        ``search_type`` on the result is the "<first artist> - <title>" query,
        prefixed with *search_type* when it is truthy and with "ytmsearch:"
        otherwise.
        """
        query = f"{track.artists[0].name} - {track.name}"
        return Track(
            track_id=track.id,
            ctx=ctx,
            search_type=f"{search_type}{query}" if search_type else f"ytmsearch:{query}",
            spotify=True,
            info={
                "title": track.name or "Unknown",
                "author": ", ".join(
                    artist.name for artist in track.artists
                ) or "Unknown",
                "length": track.duration or 0,
                "identifier": track.id or "Unknown",
                "uri": track.url or "spotify",
                "isStream": False,
                "isSeekable": False,
                "position": 0,
                "thumbnail": track.images[0].url if track.images else None
            },
        )

    @staticmethod
    def _build_spotify_playlist(results, tracks, ctx) -> Playlist:
        """Wrap converted Spotify *tracks* (non-empty) in a pomice ``Playlist``."""
        return Playlist(
            playlist_info={"name": results.name, "selectedTrack": tracks[0]},
            tracks=tracks,
            ctx=ctx,
            spotify=True,
            # Playlists/albums without artwork get no thumbnail instead of
            # failing with IndexError.
            thumbnail=results.images[0].url if results.images else None,
            uri=results.url,
        )

    async def _get_discord_tracks(self, query, url_match, ctx):
        """Load a Discord CDN attachment URL, titling the track with the file name."""
        data = await self._request_load_tracks(query)
        tracks = data.get("tracks") or []
        if not tracks:
            raise TrackLoadError("There was an error while trying to load this track.")

        track: dict = tracks[0]
        info: dict = track.get("info") or {}
        return [Track(
            track_id=track["track"],
            info={
                "title": url_match.group("file"),
                "author": "Unknown",
                "length": info.get("length"),
                "uri": info.get("uri"),
                "position": info.get("position"),
                "identifier": info.get("identifier")
            },
            ctx=ctx)]
class NodePool:
    """The base class for the node pool.

    This holds all the nodes that are to be used by the bot.
    """

    # Class-level registry shared by every NodePool reference:
    # node identifier -> Node. create_node() adds entries here.
    _nodes = {}

    def __repr__(self):
        return f"<Pomice.NodePool node_count={self.node_count}>"

    @property
    def nodes(self) -> Dict[str, Node]:
        """Property which returns a dict with the node identifier and the Node object."""
        return self._nodes

    @property
    def node_count(self):
        """Property which returns how many nodes are in the pool."""
        return len(self._nodes)

    @classmethod
    def get_node(cls, *, identifier: str = None) -> Node:
        """Fetches a node from the node pool using its identifier.

        If no identifier is provided, it will choose a node at random.
        Returns None when *identifier* is given but not registered.

        Raises:
            NoNodesAvailable: the pool is empty.
        """
        if not cls._nodes:
            raise NoNodesAvailable('There are no nodes available.')

        if identifier is None:
            return random.choice(list(cls._nodes.values()))

        return cls._nodes.get(identifier, None)

    @classmethod
    async def create_node(
        cls,
        bot: Type[discord.Client],
        host: str,
        port: int,
        password: str,
        identifier: str,
        spotify_client_id: Optional[str] = None,
        spotify_client_secret: Optional[str] = None,
        session: Optional[aiohttp.ClientSession] = None,
    ) -> Node:
        """Creates a Node object to be then added into the node pool.

        For Spotify searching capabilities, pass in valid Spotify API credentials.
        The node is only registered in the pool after it connects successfully.

        Raises:
            NodeCreationError: a node with *identifier* already exists.
            NodeConnectionFailure: (from ``Node.connect``) the node could not connect.
        """
        if identifier in cls._nodes:
            raise NodeCreationError(f"A node with identifier '{identifier}' already exists.")

        node = Node(
            pool=cls, bot=bot, host=host, port=port, password=password,
            identifier=identifier, spotify_client_id=spotify_client_id,
            session=session, spotify_client_secret=spotify_client_secret
        )

        try:
            await node.connect()
        except Exception:
            # Undo what Node.__init__ set up so a failed node leaks nothing:
            # the gateway listener it registered, and the ClientSession it
            # created itself when the caller did not pass one in.
            bot.remove_listener(node._update_handler, "on_socket_response")
            if session is None:
                await node._session.close()
            raise

        cls._nodes[node._identifier] = node
        return node