finally done translating all methods over to REST

This commit is contained in:
cloudwithax 2023-02-01 22:06:54 -05:00
parent 6b513d1e67
commit fc2b12af4e
4 changed files with 134 additions and 32 deletions

View File

@ -0,0 +1 @@
"""Apple Music module for Pomice, made possible by cloudwithax 2023"""

View File

@ -18,6 +18,10 @@ class NodeConnectionClosed(NodeException):
"""The node's connection is closed."""
pass
class NodeRestException(NodeException):
"""A request made using the node's REST uri failed"""
pass
class NodeNotAvailable(PomiceException):
"""The node is currently unavailable."""

View File

@ -129,6 +129,8 @@ class Player(VoiceProtocol):
self._voice_state = {}
self._player_endpoint_uri = f'sessions/{self._node._session_id}/players'
def __repr__(self):
return (
f"<Pomice.player bot={self.bot} guildId={self.guild.id} "
@ -217,9 +219,10 @@ class Player(VoiceProtocol):
return
await self._node.send(
op="voiceUpdate",
guildId=str(self.guild.id),
**voice_data
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"voice": voice_data}
)
async def on_voice_server_update(self, data: dict):
@ -283,7 +286,12 @@ class Player(VoiceProtocol):
async def stop(self):
"""Stops the currently playing track."""
self._current = None
await self._node.send(op="stop", guildId=str(self.guild.id))
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={'encodedTrack': None}
)
async def disconnect(self, *, force: bool = False):
"""Disconnects the player from voice."""
@ -304,7 +312,7 @@ class Player(VoiceProtocol):
assert self.channel is None and not self.is_connected
self._node._players.pop(self.guild.id)
await self._node.send(op="destroy", guildId=str(self.guild.id))
await self._node.send(method="DELETE", path=self._player_endpoint_uri, guild_id=self._guild.id)
async def play(
self,
@ -336,22 +344,18 @@ class Player(VoiceProtocol):
"No equivalent track was able to be found."
)
data = {
"op": "play",
"guildId": str(self.guild.id),
"track": search.track_id,
"startTime": str(start),
"noReplace": ignore_if_playing
"encodedTrack": search.track_id,
"position": str(start),
"endTime": str(end)
}
track.original = search
track.track_id = search.track_id
# Set track_id for later lavalink searches
else:
data = {
"op": "play",
"guildId": str(self.guild.id),
"track": track.track_id,
"startTime": str(start),
"noReplace": ignore_if_playing
"encodedTrack": track.track_id,
"position": str(start),
"endTime": str(end)
}
@ -374,7 +378,13 @@ class Player(VoiceProtocol):
if end > 0:
data["endTime"] = str(end)
await self._node.send(**data)
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data=data,
query=f"noReplace={ignore_if_playing}"
)
self._current = track
return self._current
@ -386,18 +396,33 @@ class Player(VoiceProtocol):
"Seek position must be between 0 and the track length"
)
await self._node.send(op="seek", guildId=str(self.guild.id), position=position)
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"position": position}
)
return self._position
async def set_pause(self, pause: bool) -> bool:
"""Sets the pause state of the currently playing track."""
await self._node.send(op="pause", guildId=str(self.guild.id), pause=pause)
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"paused": pause}
)
self._paused = pause
return self._paused
async def set_volume(self, volume: int) -> int:
"""Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500."""
await self._node.send(op="volume", guildId=str(self.guild.id), volume=volume)
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"volume": volume}
)
self._volume = volume
return self._volume
@ -411,7 +436,12 @@ class Player(VoiceProtocol):
self._filters.add_filter(filter=filter)
payload = self._filters.get_all_payloads()
await self._node.send(op="filters", guildId=str(self.guild.id), **payload)
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"filters": payload}
)
if fast_apply:
await self.seek(self.position)
@ -427,7 +457,12 @@ class Player(VoiceProtocol):
self._filters.remove_filter(filter_tag=filter_tag)
payload = self._filters.get_all_payloads()
await self._node.send(op="filters", guildId=str(self.guild.id), **payload)
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"filters": payload}
)
if fast_apply:
await self.seek(self.position)
@ -446,14 +481,45 @@ class Player(VoiceProtocol):
"You must have filters applied first in order to use this method."
)
self._filters.reset_filters()
await self._node.send(op="filters", guildId=str(self.guild.id))
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"filters": {}}
)
if fast_apply:
await self.seek(self.position)
class QueuePlayer(Player):
"""Player class, but with pomice.Queue included"""
def __init__(self, client: Optional[Client] = None, channel: Optional[VoiceChannel] = None, *, node: Node = None):
super().__init__(client, channel, node=node)
self._queue = Queue
async def get_tracks(
self,
query: str,
*,
ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None
):
"""Fetches tracks from the node's REST api to parse into Lavalink.
If you passed in Spotify API credentials when you created the node,
you can also pass in a Spotify URL of a playlist, album or track and it will be parsed
accordingly.
You can pass in a discord.py Context object to get a
Context object on any track you search.
You may also pass in a List of filters
to be applied to your track once it plays.
"""
super()

View File

@ -3,8 +3,10 @@ from __future__ import annotations
import asyncio
import json
import random
import secrets
import string
import re
from typing import Dict, List, Optional, TYPE_CHECKING
from typing import Dict, List, Optional, TYPE_CHECKING, Union
from urllib.parse import quote
import aiohttp
@ -25,6 +27,7 @@ from .exceptions import (
NodeException,
NodeNotAvailable,
NoNodesAvailable,
NodeRestException,
TrackLoadError
)
from .filters import Filter
@ -81,14 +84,14 @@ class Node:
self._secure = secure
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}/v3/websocket"
self._rest_uri = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
self._session = session or aiohttp.ClientSession()
self._websocket: aiohttp.ClientWebSocketResponse = None
self._task: asyncio.Task = None
self._connection_id = None
self._session_id = None
self._metadata = None
self._available = None
@ -153,6 +156,7 @@ class Node:
"""Property which returns the latency of the node"""
return Ping(self._host, port=self._port).get_ping()
async def _update_handler(self, data: dict):
await self._bot.wait_until_ready()
@ -208,13 +212,32 @@ class Node:
elif op == "playerUpdate":
await player._update_state(data)
async def send(self, **data):
async def send(
self,
method: str,
path: str,
guild_id: Optional[Union[int, str]],
query: Optional[str],
data: Optional[Union[dict, str]]
):
if not self._available:
raise NodeNotAvailable(
f"The node '{self._identifier}' is unavailable."
)
await self._websocket.send_str(json.dumps(data))
uri: str = f'{self._rest_uri}/' \
f'v3/' \
f'{path}' \
f'{f"/{guild_id}" if guild_id else ""}' \
f'{f"?{query}" if query else ""}'
async with self._session.request(method=method, url=uri, json=data or {}) as resp:
if resp.status >= 300:
raise NodeRestException(f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}')
return await resp.json()
def get_player(self, guild_id: int):
"""Takes a guild ID as a parameter. Returns a pomice Player object."""
@ -230,6 +253,14 @@ class Node:
)
self._task = self._bot.loop.create_task(self._listen())
self._available = True
self._session_id = f"pomice_{secrets.token_hex(20)}"
async with self._session.get(f'{self._host}/v3/version') as resp:
version: str = await resp.text()
# To make version comparasion easier, lets remove the periods
# from the version numbers and compare them like whole numbers
version = int(version.translate(str.maketrans('', '', string.punctuation)).replace(" ", ""))
print(version)
return self
except aiohttp.ClientConnectorError:
@ -270,7 +301,7 @@ class Node:
"""
async with self._session.get(
f"{self._rest_uri}/decodetrack?",
f"{self._rest_uri}/v3/decodetrack?",
headers={"Authorization": self._password},
params={"track": identifier}
) as resp:
@ -376,7 +407,7 @@ class Node:
elif discord_url := DISCORD_MP3_URL_REGEX.match(query):
async with self._session.get(
url=f"{self._rest_uri}/loadtracks?identifier={quote(query)}",
url=f"{self._rest_uri}/v3/loadtracks?identifier={quote(query)}",
headers={"Authorization": self._password}
) as response:
data: dict = await response.json()
@ -402,7 +433,7 @@ class Node:
else:
async with self._session.get(
url=f"{self._rest_uri}/loadtracks?identifier={quote(query)}",
url=f"{self._rest_uri}/v3/loadtracks?identifier={quote(query)}",
headers={"Authorization": self._password}
) as response:
data = await response.json()