finally done translating all methods over to REST
This commit is contained in:
parent
6b513d1e67
commit
fc2b12af4e
|
|
@ -0,0 +1 @@
|
|||
"""Apple Music module for Pomice, made possible by cloudwithax 2023"""
|
||||
|
|
@ -18,6 +18,10 @@ class NodeConnectionClosed(NodeException):
|
|||
"""The node's connection is closed."""
|
||||
pass
|
||||
|
||||
class NodeRestException(NodeException):
|
||||
"""A request made using the node's REST uri failed"""
|
||||
pass
|
||||
|
||||
|
||||
class NodeNotAvailable(PomiceException):
|
||||
"""The node is currently unavailable."""
|
||||
|
|
|
|||
112
pomice/player.py
112
pomice/player.py
|
|
@ -129,6 +129,8 @@ class Player(VoiceProtocol):
|
|||
|
||||
self._voice_state = {}
|
||||
|
||||
self._player_endpoint_uri = f'sessions/{self._node._session_id}/players'
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<Pomice.player bot={self.bot} guildId={self.guild.id} "
|
||||
|
|
@ -217,9 +219,10 @@ class Player(VoiceProtocol):
|
|||
return
|
||||
|
||||
await self._node.send(
|
||||
op="voiceUpdate",
|
||||
guildId=str(self.guild.id),
|
||||
**voice_data
|
||||
method="PATCH",
|
||||
path=self._player_endpoint_uri,
|
||||
guild_id=self._guild.id,
|
||||
data={"voice": voice_data}
|
||||
)
|
||||
|
||||
async def on_voice_server_update(self, data: dict):
|
||||
|
|
@ -283,7 +286,12 @@ class Player(VoiceProtocol):
|
|||
async def stop(self):
|
||||
"""Stops the currently playing track."""
|
||||
self._current = None
|
||||
await self._node.send(op="stop", guildId=str(self.guild.id))
|
||||
await self._node.send(
|
||||
method="PATCH",
|
||||
path=self._player_endpoint_uri,
|
||||
guild_id=self._guild.id,
|
||||
data={'encodedTrack': None}
|
||||
)
|
||||
|
||||
async def disconnect(self, *, force: bool = False):
|
||||
"""Disconnects the player from voice."""
|
||||
|
|
@ -304,7 +312,7 @@ class Player(VoiceProtocol):
|
|||
assert self.channel is None and not self.is_connected
|
||||
|
||||
self._node._players.pop(self.guild.id)
|
||||
await self._node.send(op="destroy", guildId=str(self.guild.id))
|
||||
await self._node.send(method="DELETE", path=self._player_endpoint_uri, guild_id=self._guild.id)
|
||||
|
||||
async def play(
|
||||
self,
|
||||
|
|
@ -336,22 +344,18 @@ class Player(VoiceProtocol):
|
|||
"No equivalent track was able to be found."
|
||||
)
|
||||
data = {
|
||||
"op": "play",
|
||||
"guildId": str(self.guild.id),
|
||||
"track": search.track_id,
|
||||
"startTime": str(start),
|
||||
"noReplace": ignore_if_playing
|
||||
"encodedTrack": search.track_id,
|
||||
"position": str(start),
|
||||
"endTime": str(end)
|
||||
}
|
||||
track.original = search
|
||||
track.track_id = search.track_id
|
||||
# Set track_id for later lavalink searches
|
||||
else:
|
||||
data = {
|
||||
"op": "play",
|
||||
"guildId": str(self.guild.id),
|
||||
"track": track.track_id,
|
||||
"startTime": str(start),
|
||||
"noReplace": ignore_if_playing
|
||||
"encodedTrack": track.track_id,
|
||||
"position": str(start),
|
||||
"endTime": str(end)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -374,7 +378,13 @@ class Player(VoiceProtocol):
|
|||
if end > 0:
|
||||
data["endTime"] = str(end)
|
||||
|
||||
await self._node.send(**data)
|
||||
await self._node.send(
|
||||
method="PATCH",
|
||||
path=self._player_endpoint_uri,
|
||||
guild_id=self._guild.id,
|
||||
data=data,
|
||||
query=f"noReplace={ignore_if_playing}"
|
||||
)
|
||||
|
||||
self._current = track
|
||||
return self._current
|
||||
|
|
@ -386,18 +396,33 @@ class Player(VoiceProtocol):
|
|||
"Seek position must be between 0 and the track length"
|
||||
)
|
||||
|
||||
await self._node.send(op="seek", guildId=str(self.guild.id), position=position)
|
||||
await self._node.send(
|
||||
method="PATCH",
|
||||
path=self._player_endpoint_uri,
|
||||
guild_id=self._guild.id,
|
||||
data={"position": position}
|
||||
)
|
||||
return self._position
|
||||
|
||||
async def set_pause(self, pause: bool) -> bool:
|
||||
"""Sets the pause state of the currently playing track."""
|
||||
await self._node.send(op="pause", guildId=str(self.guild.id), pause=pause)
|
||||
await self._node.send(
|
||||
method="PATCH",
|
||||
path=self._player_endpoint_uri,
|
||||
guild_id=self._guild.id,
|
||||
data={"paused": pause}
|
||||
)
|
||||
self._paused = pause
|
||||
return self._paused
|
||||
|
||||
async def set_volume(self, volume: int) -> int:
|
||||
"""Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500."""
|
||||
await self._node.send(op="volume", guildId=str(self.guild.id), volume=volume)
|
||||
await self._node.send(
|
||||
method="PATCH",
|
||||
path=self._player_endpoint_uri,
|
||||
guild_id=self._guild.id,
|
||||
data={"volume": volume}
|
||||
)
|
||||
self._volume = volume
|
||||
return self._volume
|
||||
|
||||
|
|
@ -411,7 +436,12 @@ class Player(VoiceProtocol):
|
|||
|
||||
self._filters.add_filter(filter=filter)
|
||||
payload = self._filters.get_all_payloads()
|
||||
await self._node.send(op="filters", guildId=str(self.guild.id), **payload)
|
||||
await self._node.send(
|
||||
method="PATCH",
|
||||
path=self._player_endpoint_uri,
|
||||
guild_id=self._guild.id,
|
||||
data={"filters": payload}
|
||||
)
|
||||
if fast_apply:
|
||||
await self.seek(self.position)
|
||||
|
||||
|
|
@ -427,7 +457,12 @@ class Player(VoiceProtocol):
|
|||
|
||||
self._filters.remove_filter(filter_tag=filter_tag)
|
||||
payload = self._filters.get_all_payloads()
|
||||
await self._node.send(op="filters", guildId=str(self.guild.id), **payload)
|
||||
await self._node.send(
|
||||
method="PATCH",
|
||||
path=self._player_endpoint_uri,
|
||||
guild_id=self._guild.id,
|
||||
data={"filters": payload}
|
||||
)
|
||||
if fast_apply:
|
||||
await self.seek(self.position)
|
||||
|
||||
|
|
@ -446,14 +481,45 @@ class Player(VoiceProtocol):
|
|||
"You must have filters applied first in order to use this method."
|
||||
)
|
||||
self._filters.reset_filters()
|
||||
await self._node.send(op="filters", guildId=str(self.guild.id))
|
||||
|
||||
await self._node.send(
|
||||
method="PATCH",
|
||||
path=self._player_endpoint_uri,
|
||||
guild_id=self._guild.id,
|
||||
data={"filters": {}}
|
||||
)
|
||||
|
||||
if fast_apply:
|
||||
await self.seek(self.position)
|
||||
|
||||
|
||||
|
||||
class QueuePlayer(Player):
|
||||
"""Player class, but with pomice.Queue included"""
|
||||
|
||||
def __init__(self, client: Optional[Client] = None, channel: Optional[VoiceChannel] = None, *, node: Node = None):
|
||||
super().__init__(client, channel, node=node)
|
||||
self._queue = Queue
|
||||
|
||||
|
||||
async def get_tracks(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
ctx: Optional[commands.Context] = None,
|
||||
search_type: SearchType = SearchType.ytsearch,
|
||||
filters: Optional[List[Filter]] = None
|
||||
):
|
||||
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
||||
|
||||
If you passed in Spotify API credentials when you created the node,
|
||||
you can also pass in a Spotify URL of a playlist, album or track and it will be parsed
|
||||
accordingly.
|
||||
|
||||
You can pass in a discord.py Context object to get a
|
||||
Context object on any track you search.
|
||||
|
||||
You may also pass in a List of filters
|
||||
to be applied to your track once it plays.
|
||||
"""
|
||||
super()
|
||||
|
||||
|
|
|
|||
|
|
@ -3,8 +3,10 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import json
|
||||
import random
|
||||
import secrets
|
||||
import string
|
||||
import re
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
import aiohttp
|
||||
|
|
@ -25,6 +27,7 @@ from .exceptions import (
|
|||
NodeException,
|
||||
NodeNotAvailable,
|
||||
NoNodesAvailable,
|
||||
NodeRestException,
|
||||
TrackLoadError
|
||||
)
|
||||
from .filters import Filter
|
||||
|
|
@ -81,14 +84,14 @@ class Node:
|
|||
self._secure = secure
|
||||
|
||||
|
||||
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
|
||||
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}/v3/websocket"
|
||||
self._rest_uri = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
|
||||
|
||||
self._session = session or aiohttp.ClientSession()
|
||||
self._websocket: aiohttp.ClientWebSocketResponse = None
|
||||
self._task: asyncio.Task = None
|
||||
|
||||
self._connection_id = None
|
||||
self._session_id = None
|
||||
self._metadata = None
|
||||
self._available = None
|
||||
|
||||
|
|
@ -153,6 +156,7 @@ class Node:
|
|||
"""Property which returns the latency of the node"""
|
||||
return Ping(self._host, port=self._port).get_ping()
|
||||
|
||||
|
||||
async def _update_handler(self, data: dict):
|
||||
await self._bot.wait_until_ready()
|
||||
|
||||
|
|
@ -208,13 +212,32 @@ class Node:
|
|||
elif op == "playerUpdate":
|
||||
await player._update_state(data)
|
||||
|
||||
async def send(self, **data):
|
||||
async def send(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
guild_id: Optional[Union[int, str]],
|
||||
query: Optional[str],
|
||||
data: Optional[Union[dict, str]]
|
||||
):
|
||||
if not self._available:
|
||||
raise NodeNotAvailable(
|
||||
f"The node '{self._identifier}' is unavailable."
|
||||
)
|
||||
|
||||
await self._websocket.send_str(json.dumps(data))
|
||||
uri: str = f'{self._rest_uri}/' \
|
||||
f'v3/' \
|
||||
f'{path}' \
|
||||
f'{f"/{guild_id}" if guild_id else ""}' \
|
||||
f'{f"?{query}" if query else ""}'
|
||||
|
||||
async with self._session.request(method=method, url=uri, json=data or {}) as resp:
|
||||
if resp.status >= 300:
|
||||
raise NodeRestException(f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}')
|
||||
|
||||
return await resp.json()
|
||||
|
||||
|
||||
|
||||
def get_player(self, guild_id: int):
|
||||
"""Takes a guild ID as a parameter. Returns a pomice Player object."""
|
||||
|
|
@ -230,6 +253,14 @@ class Node:
|
|||
)
|
||||
self._task = self._bot.loop.create_task(self._listen())
|
||||
self._available = True
|
||||
self._session_id = f"pomice_{secrets.token_hex(20)}"
|
||||
async with self._session.get(f'{self._host}/v3/version') as resp:
|
||||
version: str = await resp.text()
|
||||
# To make version comparasion easier, lets remove the periods
|
||||
# from the version numbers and compare them like whole numbers
|
||||
version = int(version.translate(str.maketrans('', '', string.punctuation)).replace(" ", ""))
|
||||
print(version)
|
||||
|
||||
return self
|
||||
|
||||
except aiohttp.ClientConnectorError:
|
||||
|
|
@ -270,7 +301,7 @@ class Node:
|
|||
"""
|
||||
|
||||
async with self._session.get(
|
||||
f"{self._rest_uri}/decodetrack?",
|
||||
f"{self._rest_uri}/v3/decodetrack?",
|
||||
headers={"Authorization": self._password},
|
||||
params={"track": identifier}
|
||||
) as resp:
|
||||
|
|
@ -376,7 +407,7 @@ class Node:
|
|||
|
||||
elif discord_url := DISCORD_MP3_URL_REGEX.match(query):
|
||||
async with self._session.get(
|
||||
url=f"{self._rest_uri}/loadtracks?identifier={quote(query)}",
|
||||
url=f"{self._rest_uri}/v3/loadtracks?identifier={quote(query)}",
|
||||
headers={"Authorization": self._password}
|
||||
) as response:
|
||||
data: dict = await response.json()
|
||||
|
|
@ -402,7 +433,7 @@ class Node:
|
|||
|
||||
else:
|
||||
async with self._session.get(
|
||||
url=f"{self._rest_uri}/loadtracks?identifier={quote(query)}",
|
||||
url=f"{self._rest_uri}/v3/loadtracks?identifier={quote(query)}",
|
||||
headers={"Authorization": self._password}
|
||||
) as response:
|
||||
data = await response.json()
|
||||
|
|
|
|||
Loading…
Reference in New Issue