finally done translating all methods over to REST

This commit is contained in:
cloudwithax 2023-02-01 22:06:54 -05:00
parent 6b513d1e67
commit fc2b12af4e
4 changed files with 134 additions and 32 deletions

View File

@ -0,0 +1 @@
"""Apple Music module for Pomice, made possible by cloudwithax 2023"""

View File

@ -18,6 +18,10 @@ class NodeConnectionClosed(NodeException):
"""The node's connection is closed.""" """The node's connection is closed."""
pass pass
class NodeRestException(NodeException):
"""A request made using the node's REST uri failed"""
pass
class NodeNotAvailable(PomiceException): class NodeNotAvailable(PomiceException):
"""The node is currently unavailable.""" """The node is currently unavailable."""

View File

@ -129,6 +129,8 @@ class Player(VoiceProtocol):
self._voice_state = {} self._voice_state = {}
self._player_endpoint_uri = f'sessions/{self._node._session_id}/players'
def __repr__(self): def __repr__(self):
return ( return (
f"<Pomice.player bot={self.bot} guildId={self.guild.id} " f"<Pomice.player bot={self.bot} guildId={self.guild.id} "
@ -217,9 +219,10 @@ class Player(VoiceProtocol):
return return
await self._node.send( await self._node.send(
op="voiceUpdate", method="PATCH",
guildId=str(self.guild.id), path=self._player_endpoint_uri,
**voice_data guild_id=self._guild.id,
data={"voice": voice_data}
) )
async def on_voice_server_update(self, data: dict): async def on_voice_server_update(self, data: dict):
@ -283,7 +286,12 @@ class Player(VoiceProtocol):
async def stop(self): async def stop(self):
"""Stops the currently playing track.""" """Stops the currently playing track."""
self._current = None self._current = None
await self._node.send(op="stop", guildId=str(self.guild.id)) await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={'encodedTrack': None}
)
async def disconnect(self, *, force: bool = False): async def disconnect(self, *, force: bool = False):
"""Disconnects the player from voice.""" """Disconnects the player from voice."""
@ -304,7 +312,7 @@ class Player(VoiceProtocol):
assert self.channel is None and not self.is_connected assert self.channel is None and not self.is_connected
self._node._players.pop(self.guild.id) self._node._players.pop(self.guild.id)
await self._node.send(op="destroy", guildId=str(self.guild.id)) await self._node.send(method="DELETE", path=self._player_endpoint_uri, guild_id=self._guild.id)
async def play( async def play(
self, self,
@ -336,22 +344,18 @@ class Player(VoiceProtocol):
"No equivalent track was able to be found." "No equivalent track was able to be found."
) )
data = { data = {
"op": "play", "encodedTrack": search.track_id,
"guildId": str(self.guild.id), "position": str(start),
"track": search.track_id, "endTime": str(end)
"startTime": str(start),
"noReplace": ignore_if_playing
} }
track.original = search track.original = search
track.track_id = search.track_id track.track_id = search.track_id
# Set track_id for later lavalink searches # Set track_id for later lavalink searches
else: else:
data = { data = {
"op": "play", "encodedTrack": track.track_id,
"guildId": str(self.guild.id), "position": str(start),
"track": track.track_id, "endTime": str(end)
"startTime": str(start),
"noReplace": ignore_if_playing
} }
@ -374,7 +378,13 @@ class Player(VoiceProtocol):
if end > 0: if end > 0:
data["endTime"] = str(end) data["endTime"] = str(end)
await self._node.send(**data) await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data=data,
query=f"noReplace={ignore_if_playing}"
)
self._current = track self._current = track
return self._current return self._current
@ -386,18 +396,33 @@ class Player(VoiceProtocol):
"Seek position must be between 0 and the track length" "Seek position must be between 0 and the track length"
) )
await self._node.send(op="seek", guildId=str(self.guild.id), position=position) await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"position": position}
)
return self._position return self._position
async def set_pause(self, pause: bool) -> bool: async def set_pause(self, pause: bool) -> bool:
"""Sets the pause state of the currently playing track.""" """Sets the pause state of the currently playing track."""
await self._node.send(op="pause", guildId=str(self.guild.id), pause=pause) await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"paused": pause}
)
self._paused = pause self._paused = pause
return self._paused return self._paused
async def set_volume(self, volume: int) -> int: async def set_volume(self, volume: int) -> int:
"""Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500.""" """Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500."""
await self._node.send(op="volume", guildId=str(self.guild.id), volume=volume) await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"volume": volume}
)
self._volume = volume self._volume = volume
return self._volume return self._volume
@ -411,7 +436,12 @@ class Player(VoiceProtocol):
self._filters.add_filter(filter=filter) self._filters.add_filter(filter=filter)
payload = self._filters.get_all_payloads() payload = self._filters.get_all_payloads()
await self._node.send(op="filters", guildId=str(self.guild.id), **payload) await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"filters": payload}
)
if fast_apply: if fast_apply:
await self.seek(self.position) await self.seek(self.position)
@ -427,7 +457,12 @@ class Player(VoiceProtocol):
self._filters.remove_filter(filter_tag=filter_tag) self._filters.remove_filter(filter_tag=filter_tag)
payload = self._filters.get_all_payloads() payload = self._filters.get_all_payloads()
await self._node.send(op="filters", guildId=str(self.guild.id), **payload) await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"filters": payload}
)
if fast_apply: if fast_apply:
await self.seek(self.position) await self.seek(self.position)
@ -446,14 +481,45 @@ class Player(VoiceProtocol):
"You must have filters applied first in order to use this method." "You must have filters applied first in order to use this method."
) )
self._filters.reset_filters() self._filters.reset_filters()
await self._node.send(op="filters", guildId=str(self.guild.id)) await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"filters": {}}
)
if fast_apply: if fast_apply:
await self.seek(self.position) await self.seek(self.position)
class QueuePlayer(Player):
"""Player class, but with pomice.Queue included"""
def __init__(self, client: Optional[Client] = None, channel: Optional[VoiceChannel] = None, *, node: Node = None):
super().__init__(client, channel, node=node)
self._queue = Queue
async def get_tracks(
self,
query: str,
*,
ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None
):
"""Fetches tracks from the node's REST api to parse into Lavalink.
If you passed in Spotify API credentials when you created the node,
you can also pass in a Spotify URL of a playlist, album or track and it will be parsed
accordingly.
You can pass in a discord.py Context object to get a
Context object on any track you search.
You may also pass in a List of filters
to be applied to your track once it plays.
"""
super()

View File

@ -3,8 +3,10 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import random import random
import secrets
import string
import re import re
from typing import Dict, List, Optional, TYPE_CHECKING from typing import Dict, List, Optional, TYPE_CHECKING, Union
from urllib.parse import quote from urllib.parse import quote
import aiohttp import aiohttp
@ -25,6 +27,7 @@ from .exceptions import (
NodeException, NodeException,
NodeNotAvailable, NodeNotAvailable,
NoNodesAvailable, NoNodesAvailable,
NodeRestException,
TrackLoadError TrackLoadError
) )
from .filters import Filter from .filters import Filter
@ -81,14 +84,14 @@ class Node:
self._secure = secure self._secure = secure
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}" self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}/v3/websocket"
self._rest_uri = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}" self._rest_uri = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
self._session = session or aiohttp.ClientSession() self._session = session or aiohttp.ClientSession()
self._websocket: aiohttp.ClientWebSocketResponse = None self._websocket: aiohttp.ClientWebSocketResponse = None
self._task: asyncio.Task = None self._task: asyncio.Task = None
self._connection_id = None self._session_id = None
self._metadata = None self._metadata = None
self._available = None self._available = None
@ -153,6 +156,7 @@ class Node:
"""Property which returns the latency of the node""" """Property which returns the latency of the node"""
return Ping(self._host, port=self._port).get_ping() return Ping(self._host, port=self._port).get_ping()
async def _update_handler(self, data: dict): async def _update_handler(self, data: dict):
await self._bot.wait_until_ready() await self._bot.wait_until_ready()
@ -208,13 +212,32 @@ class Node:
elif op == "playerUpdate": elif op == "playerUpdate":
await player._update_state(data) await player._update_state(data)
async def send(self, **data): async def send(
self,
method: str,
path: str,
guild_id: Optional[Union[int, str]],
query: Optional[str],
data: Optional[Union[dict, str]]
):
if not self._available: if not self._available:
raise NodeNotAvailable( raise NodeNotAvailable(
f"The node '{self._identifier}' is unavailable." f"The node '{self._identifier}' is unavailable."
) )
await self._websocket.send_str(json.dumps(data)) uri: str = f'{self._rest_uri}/' \
f'v3/' \
f'{path}' \
f'{f"/{guild_id}" if guild_id else ""}' \
f'{f"?{query}" if query else ""}'
async with self._session.request(method=method, url=uri, json=data or {}) as resp:
if resp.status >= 300:
raise NodeRestException(f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}')
return await resp.json()
def get_player(self, guild_id: int): def get_player(self, guild_id: int):
"""Takes a guild ID as a parameter. Returns a pomice Player object.""" """Takes a guild ID as a parameter. Returns a pomice Player object."""
@ -230,6 +253,14 @@ class Node:
) )
self._task = self._bot.loop.create_task(self._listen()) self._task = self._bot.loop.create_task(self._listen())
self._available = True self._available = True
self._session_id = f"pomice_{secrets.token_hex(20)}"
async with self._session.get(f'{self._host}/v3/version') as resp:
version: str = await resp.text()
# To make version comparasion easier, lets remove the periods
# from the version numbers and compare them like whole numbers
version = int(version.translate(str.maketrans('', '', string.punctuation)).replace(" ", ""))
print(version)
return self return self
except aiohttp.ClientConnectorError: except aiohttp.ClientConnectorError:
@ -270,7 +301,7 @@ class Node:
""" """
async with self._session.get( async with self._session.get(
f"{self._rest_uri}/decodetrack?", f"{self._rest_uri}/v3/decodetrack?",
headers={"Authorization": self._password}, headers={"Authorization": self._password},
params={"track": identifier} params={"track": identifier}
) as resp: ) as resp:
@ -376,7 +407,7 @@ class Node:
elif discord_url := DISCORD_MP3_URL_REGEX.match(query): elif discord_url := DISCORD_MP3_URL_REGEX.match(query):
async with self._session.get( async with self._session.get(
url=f"{self._rest_uri}/loadtracks?identifier={quote(query)}", url=f"{self._rest_uri}/v3/loadtracks?identifier={quote(query)}",
headers={"Authorization": self._password} headers={"Authorization": self._password}
) as response: ) as response:
data: dict = await response.json() data: dict = await response.json()
@ -402,7 +433,7 @@ class Node:
else: else:
async with self._session.get( async with self._session.get(
url=f"{self._rest_uri}/loadtracks?identifier={quote(query)}", url=f"{self._rest_uri}/v3/loadtracks?identifier={quote(query)}",
headers={"Authorization": self._password} headers={"Authorization": self._password}
) as response: ) as response:
data = await response.json() data = await response.json()