From fc2b12af4e660a586a86265e800340dd450ad7ea Mon Sep 17 00:00:00 2001 From: cloudwithax Date: Wed, 1 Feb 2023 22:06:54 -0500 Subject: [PATCH] finally done translating all methods over to REST --- pomice/applemusic/__init__.py | 1 + pomice/exceptions.py | 4 ++ pomice/player.py | 114 +++++++++++++++++++++++++++------- pomice/pool.py | 47 +++++++++++--- 4 files changed, 134 insertions(+), 32 deletions(-) create mode 100644 pomice/applemusic/__init__.py diff --git a/pomice/applemusic/__init__.py b/pomice/applemusic/__init__.py new file mode 100644 index 0000000..27c1284 --- /dev/null +++ b/pomice/applemusic/__init__.py @@ -0,0 +1 @@ +"""Apple Music module for Pomice, made possible by cloudwithax 2023""" diff --git a/pomice/exceptions.py b/pomice/exceptions.py index f244191..4af483f 100644 --- a/pomice/exceptions.py +++ b/pomice/exceptions.py @@ -18,6 +18,10 @@ class NodeConnectionClosed(NodeException): """The node's connection is closed.""" pass +class NodeRestException(NodeException): + """A request made using the node's REST uri failed""" + pass + class NodeNotAvailable(PomiceException): """The node is currently unavailable.""" diff --git a/pomice/player.py b/pomice/player.py index 22db950..1977be7 100644 --- a/pomice/player.py +++ b/pomice/player.py @@ -129,6 +129,8 @@ class Player(VoiceProtocol): self._voice_state = {} + self._player_endpoint_uri = f'sessions/{self._node._session_id}/players' + def __repr__(self): return ( f" 0: data["endTime"] = str(end) - await self._node.send(**data) + await self._node.send( + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, + data=data, + query=f"noReplace={ignore_if_playing}" + ) self._current = track return self._current @@ -386,18 +396,33 @@ class Player(VoiceProtocol): "Seek position must be between 0 and the track length" ) - await self._node.send(op="seek", guildId=str(self.guild.id), position=position) + await self._node.send( + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, + data={"position": position} + ) return self._position async def set_pause(self, pause: bool) -> bool: """Sets the pause state of the currently playing track.""" - await self._node.send(op="pause", guildId=str(self.guild.id), pause=pause) + await self._node.send( + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, + data={"paused": pause} + ) self._paused = pause return self._paused async def set_volume(self, volume: int) -> int: """Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500.""" - await self._node.send(op="volume", guildId=str(self.guild.id), volume=volume) + await self._node.send( + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, + data={"volume": volume} + ) self._volume = volume return self._volume @@ -411,7 +436,12 @@ class Player(VoiceProtocol): self._filters.add_filter(filter=filter) payload = self._filters.get_all_payloads() - await self._node.send(op="filters", guildId=str(self.guild.id), **payload) + await self._node.send( + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, + data={"filters": payload} + ) if fast_apply: await self.seek(self.position) @@ -427,7 +457,12 @@ class Player(VoiceProtocol): self._filters.remove_filter(filter_tag=filter_tag) payload = self._filters.get_all_payloads() - await self._node.send(op="filters", guildId=str(self.guild.id), **payload) + await self._node.send( + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, + data={"filters": payload} + ) if fast_apply: await self.seek(self.position) @@ -446,14 +481,45 @@ class Player(VoiceProtocol): "You must have filters applied first in order to use this method." ) self._filters.reset_filters() - await self._node.send(op="filters", guildId=str(self.guild.id)) - + await self._node.send( + method="PATCH", + path=self._player_endpoint_uri, + guild_id=self._guild.id, + data={"filters": {}} + ) if fast_apply: await self.seek(self.position) +class QueuePlayer(Player): + """Player class, but with pomice.Queue included""" - + def __init__(self, client: Optional[Client] = None, channel: Optional[VoiceChannel] = None, *, node: Node = None): + super().__init__(client, channel, node=node) + self._queue = Queue + + + async def get_tracks( + self, + query: str, + *, + ctx: Optional[commands.Context] = None, + search_type: SearchType = SearchType.ytsearch, + filters: Optional[List[Filter]] = None + ): + """Fetches tracks from the node's REST api to parse into Lavalink. + + If you passed in Spotify API credentials when you created the node, + you can also pass in a Spotify URL of a playlist, album or track and it will be parsed + accordingly. + + You can pass in a discord.py Context object to get a + Context object on any track you search. + + You may also pass in a List of filters + to be applied to your track once it plays. + """ + super() diff --git a/pomice/pool.py b/pomice/pool.py index 9eee0ed..7776693 100644 --- a/pomice/pool.py +++ b/pomice/pool.py @@ -3,8 +3,10 @@ from __future__ import annotations import asyncio import json import random +import secrets +import string import re -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import Dict, List, Optional, TYPE_CHECKING, Union from urllib.parse import quote import aiohttp @@ -25,6 +27,7 @@ from .exceptions import ( NodeException, NodeNotAvailable, NoNodesAvailable, + NodeRestException, TrackLoadError ) from .filters import Filter @@ -81,14 +84,14 @@ class Node: self._secure = secure - self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}" + self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}/v3/websocket" self._rest_uri = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}" self._session = session or aiohttp.ClientSession() self._websocket: aiohttp.ClientWebSocketResponse = None self._task: asyncio.Task = None - self._connection_id = None + self._session_id = None self._metadata = None self._available = None @@ -153,6 +156,7 @@ class Node: """Property which returns the latency of the node""" return Ping(self._host, port=self._port).get_ping() + async def _update_handler(self, data: dict): await self._bot.wait_until_ready() @@ -208,13 +212,32 @@ class Node: elif op == "playerUpdate": await player._update_state(data) - async def send(self, **data): + async def send( + self, + method: str, + path: str, + guild_id: Optional[Union[int, str]], + query: Optional[str], + data: Optional[Union[dict, str]] + ): if not self._available: raise NodeNotAvailable( f"The node '{self._identifier}' is unavailable." ) - await self._websocket.send_str(json.dumps(data)) + uri: str = f'{self._rest_uri}/' \ + f'v3/' \ + f'{path}' \ + f'{f"/{guild_id}" if guild_id else ""}' \ + f'{f"?{query}" if query else ""}' + + async with self._session.request(method=method, url=uri, json=data or {}) as resp: + if resp.status >= 300: + raise NodeRestException(f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}') + + return await resp.json() + + def get_player(self, guild_id: int): """Takes a guild ID as a parameter. Returns a pomice Player object.""" @@ -230,6 +253,14 @@ class Node: ) self._task = self._bot.loop.create_task(self._listen()) self._available = True + self._session_id = f"pomice_{secrets.token_hex(20)}" + async with self._session.get(f'{self._host}/v3/version') as resp: + version: str = await resp.text() + # To make version comparasion easier, lets remove the periods + # from the version numbers and compare them like whole numbers + version = int(version.translate(str.maketrans('', '', string.punctuation)).replace(" ", "")) + print(version) + return self except aiohttp.ClientConnectorError: @@ -270,7 +301,7 @@ class Node: """ async with self._session.get( - f"{self._rest_uri}/decodetrack?", + f"{self._rest_uri}/v3/decodetrack?", headers={"Authorization": self._password}, params={"track": identifier} ) as resp: @@ -376,7 +407,7 @@ class Node: elif discord_url := DISCORD_MP3_URL_REGEX.match(query): async with self._session.get( - url=f"{self._rest_uri}/loadtracks?identifier={quote(query)}", + url=f"{self._rest_uri}/v3/loadtracks?identifier={quote(query)}", headers={"Authorization": self._password} ) as response: data: dict = await response.json() @@ -402,7 +433,7 @@ class Node: else: async with self._session.get( - url=f"{self._rest_uri}/loadtracks?identifier={quote(query)}", + url=f"{self._rest_uri}/v3/loadtracks?identifier={quote(query)}", headers={"Authorization": self._password} ) as response: data = await response.json()