fix a couple of outstanding bugs

This commit is contained in:
cloudwithax 2023-03-09 08:24:26 -05:00
parent 9d831d3ecd
commit de7385d8ff
8 changed files with 114 additions and 138 deletions

View File

@ -3,8 +3,9 @@ Pomice
~~~~~~ ~~~~~~
The modern Lavalink wrapper designed for discord.py. The modern Lavalink wrapper designed for discord.py.
:copyright: 2023, cloudwithax Copyright (c) 2023, cloudwithax
:license: GPL-3.0
Licensed under GPL-3.0
""" """
import discord import discord
@ -18,9 +19,11 @@ if not discord.version_info.major >= 2:
"using 'pip install discord.py'" "using 'pip install discord.py'"
) )
__version__ = "2.1.1" __version__ = "2.2a"
__title__ = "pomice" __title__ = "pomice"
__author__ = "cloudwithax" __author__ = "cloudwithax"
__license__ = "GPL-3.0"
__copyright__ = "Copyright (c) 2023, cloudwithax"
from .enums import * from .enums import *
from .events import * from .events import *

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import re import re
import aiohttp import aiohttp
import orjson as json import orjson as json
@ -6,6 +8,10 @@ import base64
from datetime import datetime from datetime import datetime
from .objects import * from .objects import *
from .exceptions import * from .exceptions import *
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..pool import Node
AM_URL_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)") AM_URL_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)")
AM_SINGLE_IN_ALBUM_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)") AM_SINGLE_IN_ALBUM_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)")
@ -18,10 +24,11 @@ class Client:
and translating it to a valid Lavalink track. No client auth is required here. and translating it to a valid Lavalink track. No client auth is required here.
""" """
def __init__(self) -> None: def __init__(self, node: Node) -> None:
self.token: str = None self.token: str = None
self.expiry: datetime = None self.expiry: datetime = None
self.session: aiohttp.ClientSession = aiohttp.ClientSession() self.node = node
self.session = self.node._session
self.headers = None self.headers = None

View File

@ -42,11 +42,11 @@ class TrackType(Enum):
""" """
# We don't have to define anything special for these, since these just serve as flags # We don't have to define anything special for these, since these just serve as flags
YOUTUBE = "youtube_track" YOUTUBE = "youtube"
SOUNDCLOUD = "soundcloud_track" SOUNDCLOUD = "soundcloud"
SPOTIFY = "spotify_track" SPOTIFY = "spotify"
APPLE_MUSIC = "apple_music_track" APPLE_MUSIC = "apple_music"
HTTP = "http_source" HTTP = "http"
def __str__(self) -> str: def __str__(self) -> str:
return self.value return self.value
@ -65,10 +65,10 @@ class PlaylistType(Enum):
""" """
# We don't have to define anything special for these, since these just serve as flags # We don't have to define anything special for these, since these just serve as flags
YOUTUBE = "youtube_playlist" YOUTUBE = "youtube"
SOUNDCLOUD = "soundcloud_playlist" SOUNDCLOUD = "soundcloud"
SPOTIFY = "spotify_playlist" SPOTIFY = "spotify"
APPLE_MUSIC = "apple_music_list" APPLE_MUSIC = "apple_music"
def __str__(self) -> str: def __str__(self) -> str:
return self.value return self.value
@ -114,27 +114,6 @@ class LoopMode(Enum):
QUEUE = "queue" QUEUE = "queue"
def __str__(self) -> str:
return self.value
class PlatformRecommendation(Enum):
"""
The enum for choosing what platform you want for recommendations.
This feature is exclusively for the recommendations function.
If you are not using this feature, this class is not necessary.
PlatformRecommendation.SPOTIFY sets the recommendations to come from Spotify
PlatformRecommendation.YOUTUBE sets the recommendations to come from YouTube
"""
# We don't have to define anything special for these, since these just serve as flags
SPOTIFY = "spotify"
YOUTUBE = "youtube"
def __str__(self) -> str: def __str__(self) -> str:
return self.value return self.value

View File

@ -7,10 +7,6 @@ from discord.ext import commands
from .enums import SearchType, TrackType, PlaylistType from .enums import SearchType, TrackType, PlaylistType
from .filters import Filter from .filters import Filter
from . import (
spotify,
applemusic
)
class Track: class Track:

View File

@ -16,11 +16,11 @@ from discord import (
from discord.ext import commands from discord.ext import commands
from . import events from . import events
from .enums import SearchType, PlatformRecommendation from .enums import SearchType
from .events import PomiceEvent, TrackEndEvent, TrackStartEvent from .events import PomiceEvent, TrackEndEvent, TrackStartEvent
from .exceptions import FilterInvalidArgument, FilterTagAlreadyInUse, FilterTagInvalid, TrackInvalidPosition, TrackLoadError from .exceptions import FilterInvalidArgument, FilterTagAlreadyInUse, FilterTagInvalid, TrackInvalidPosition, TrackLoadError
from .filters import Filter from .filters import Filter
from .objects import Track, Playlist from .objects import Track
from .pool import Node, NodePool from .pool import Node, NodePool
class Filters: class Filters:
@ -111,24 +111,24 @@ class Player(VoiceProtocol):
*, *,
node: Node = None node: Node = None
): ):
self.client = client self.client: Optional[Client] = client
self._bot: Union[Client, commands.Bot] = client self._bot: Union[Client, commands.Bot] = client
self.channel = channel self.channel: Optional[VoiceChannel] = channel
self._guild = channel.guild if channel else None self._guild: Guild = channel.guild if channel else None
self._node = node if node else NodePool.get_node() self._node: Node = node if node else NodePool.get_node()
self._current: Track = None self._current: Track = None
self._filters: Filters = Filters() self._filters: Filters = Filters()
self._volume = 100 self._volume: int = 100
self._paused = False self._paused: bool = False
self._is_connected = False self._is_connected: bool = False
self._position = 0 self._position: int = 0
self._last_position = 0 self._last_position: int = 0
self._last_update = 0 self._last_update: int = 0
self._ending_track: Optional[Track] = None self._ending_track: Optional[Track] = None
self._voice_state = {} self._voice_state: dict = {}
self._player_endpoint_uri = f'sessions/{self._node._session_id}/players' self._player_endpoint_uri = f'sessions/{self._node._session_id}/players'
@ -211,7 +211,7 @@ class Player(VoiceProtocol):
async def _update_state(self, data: dict): async def _update_state(self, data: dict):
state: dict = data.get("state") state: dict = data.get("state")
self._last_update = time.time() * 1000 self._last_update = int(state.get("time")) * 1000
self._is_connected = state.get("connected") self._is_connected = state.get("connected")
self._last_position = state.get("position") self._last_position = state.get("position")
@ -366,7 +366,7 @@ class Player(VoiceProtocol):
data = { data = {
"encodedTrack": search.track_id, "encodedTrack": search.track_id,
"position": str(start), "position": str(start),
"endTime": str(end) "endTime": str(track.length)
} }
track.original = search track.original = search
track.track_id = search.track_id track.track_id = search.track_id
@ -375,7 +375,7 @@ class Player(VoiceProtocol):
data = { data = {
"encodedTrack": track.track_id, "encodedTrack": track.track_id,
"position": str(start), "position": str(start),
"endTime": str(end) "endTime": str(track.length)
} }
@ -401,6 +401,11 @@ class Player(VoiceProtocol):
for filter in track.filters: for filter in track.filters:
await self.add_filter(filter=filter) await self.add_filter(filter=filter)
# Lavalink v4 changed the way the end time parameter works
# so now the end time cannot be zero.
# If it isnt zero, it'll match the length of the track,
# otherwise itll be set here:
if end > 0: if end > 0:
data["endTime"] = str(end) data["endTime"] = str(end)

View File

@ -3,13 +3,13 @@ from __future__ import annotations
import asyncio import asyncio
import random import random
import re import re
from typing import Dict, List, Optional, TYPE_CHECKING, Union import logging
from urllib.parse import quote
import aiohttp import aiohttp
from discord import Client from discord import Client
from discord.ext import commands from discord.ext import commands
from typing import Dict, List, Optional, TYPE_CHECKING, Union
from urllib.parse import quote
from . import ( from . import (
__version__, __version__,
@ -37,6 +37,8 @@ from .routeplanner import RoutePlanner
if TYPE_CHECKING: if TYPE_CHECKING:
from .player import Player from .player import Player
_log = logging.getLogger(__name__)
class Node: class Node:
"""The base class for a node. """The base class for a node.
@ -48,7 +50,7 @@ class Node:
def __init__( def __init__(
self, self,
*, *,
pool, pool: NodePool,
bot: Union[Client, commands.Bot], bot: Union[Client, commands.Bot],
host: str, host: str,
port: int, port: int,
@ -59,7 +61,8 @@ class Node:
session: Optional[aiohttp.ClientSession] = None, session: Optional[aiohttp.ClientSession] = None,
spotify_client_id: Optional[str] = None, spotify_client_id: Optional[str] = None,
spotify_client_secret: Optional[str] = None, spotify_client_secret: Optional[str] = None,
apple_music: bool = False apple_music: bool = False,
fallback: bool = False
): ):
self._bot = bot self._bot = bot
@ -70,6 +73,7 @@ class Node:
self._identifier = identifier self._identifier = identifier
self._heartbeat = heartbeat self._heartbeat = heartbeat
self._secure = secure self._secure = secure
self.fallback = fallback
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}/v3/websocket" self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}/v3/websocket"
@ -100,11 +104,11 @@ class Node:
if self._spotify_client_id and self._spotify_client_secret: if self._spotify_client_id and self._spotify_client_secret:
self._spotify_client = spotify.Client( self._spotify_client = spotify.Client(
self._spotify_client_id, self._spotify_client_secret self, self._spotify_client_id, self._spotify_client_secret
) )
if apple_music: if apple_music:
self._apple_music_client = applemusic.Client() self._apple_music_client = applemusic.Client(self)
self._bot.add_listener(self._update_handler, "on_socket_response") self._bot.add_listener(self._update_handler, "on_socket_response")
@ -165,6 +169,7 @@ class Node:
if data["t"] == "VOICE_SERVER_UPDATE": if data["t"] == "VOICE_SERVER_UPDATE":
guild_id = int(data["d"]["guild_id"]) guild_id = int(data["d"]["guild_id"])
_log.debug(f"Recieved voice server update message from guild ID: {guild_id}")
try: try:
player = self._players[guild_id] player = self._players[guild_id]
await player.on_voice_server_update(data["d"]) await player.on_voice_server_update(data["d"])
@ -176,6 +181,7 @@ class Node:
return return
guild_id = int(data["d"]["guild_id"]) guild_id = int(data["d"]["guild_id"])
_log.debug(f"Recieved voice state update message from guild ID: {guild_id}")
try: try:
player = self._players[guild_id] player = self._players[guild_id]
await player.on_voice_state_update(data["d"]) await player.on_voice_state_update(data["d"])
@ -187,13 +193,13 @@ class Node:
while True: while True:
msg = await self._websocket.receive() msg = await self._websocket.receive()
if msg.type == aiohttp.WSMsgType.CLOSED: if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
retry = backoff.delay() retry = backoff.delay()
await asyncio.sleep(retry) await asyncio.sleep(retry)
if not self.is_connected: if not self.is_connected:
self._bot.loop.create_task(self.connect()) asyncio.create_task(self.connect())
else: else:
self._bot.loop.create_task(self._handle_payload(msg.json())) asyncio.create_task(self._handle_payload(msg.json()))
async def _handle_payload(self, data: dict): async def _handle_payload(self, data: dict):
op = data.get("op", None) op = data.get("op", None)
@ -216,32 +222,6 @@ class Node:
elif op == "playerUpdate": elif op == "playerUpdate":
await player._update_state(data) await player._update_state(data)
def _get_type(self, query: str):
if match := URLRegex.LAVALINK_SEARCH.match(query):
type = match.group("type")
if type == "sc":
return TrackType.SOUNDCLOUD
return TrackType.YOUTUBE
elif URLRegex.YOUTUBE_URL.match(query):
if URLRegex.YOUTUBE_PLAYLIST_URL.match(query):
return PlaylistType.YOUTUBE
return TrackType.YOUTUBE
elif URLRegex.SOUNDCLOUD_URL.match(query):
if URLRegex.SOUNDCLOUD_TRACK_IN_SET_URL.match(query):
return TrackType.SOUNDCLOUD
if URLRegex.SOUNDCLOUD_PLAYLIST_URL.match(query):
return PlaylistType.SOUNDCLOUD
return TrackType.SOUNDCLOUD
else:
return TrackType.HTTP
async def send( async def send(
self, self,
method: str, method: str,
@ -249,9 +229,10 @@ class Node:
include_version: bool = True, include_version: bool = True,
guild_id: Optional[Union[int, str]] = None, guild_id: Optional[Union[int, str]] = None,
query: Optional[str] = None, query: Optional[str] = None,
data: Optional[Union[dict, str]] = None data: Optional[Union[dict, str]] = None,
ignore_if_available: bool = False,
): ):
if not self._available: if not ignore_if_available and not self._available:
raise NodeNotAvailable( raise NodeNotAvailable(
f"The node '{self._identifier}' is unavailable." f"The node '{self._identifier}' is unavailable."
) )
@ -264,7 +245,8 @@ class Node:
async with self._session.request(method=method, url=uri, headers=self._headers, json=data or {}) as resp: async with self._session.request(method=method, url=uri, headers=self._headers, json=data or {}) as resp:
if resp.status >= 300: if resp.status >= 300:
raise NodeRestException(f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}') data: dict = await resp.json()
raise NodeRestException(f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {data["message"]}')
if method == "DELETE" or resp.status == 204: if method == "DELETE" or resp.status == 204:
return await resp.json(content_type=None) return await resp.json(content_type=None)
@ -285,34 +267,42 @@ class Node:
await self._bot.wait_until_ready() await self._bot.wait_until_ready()
try: try:
self._websocket = await self._session.ws_connect( version = await self.send(method="GET", path="version", ignore_if_available=True, include_version=False)
self._websocket_uri, headers=self._headers, heartbeat=self._heartbeat
)
self._task = self._bot.loop.create_task(self._listen())
self._available = True
version = await self.send(method="GET", path="version", include_version=False)
version = version.replace(".", "") version = version.replace(".", "")
if int(version) < 370: if not version.endswith('-SNAPSHOT') and int(version) < 370:
self._available = False
raise LavalinkVersionIncompatible( raise LavalinkVersionIncompatible(
"The Lavalink version you're using is incompatible." "The Lavalink version you're using is incompatible. "
"Lavalink version 3.7.0 or above is required to use this library." "Lavalink version 3.7.0 or above is required to use this library."
) )
self._version = version[:1] self._websocket = await self._session.ws_connect(
self._websocket_uri, headers=self._headers, heartbeat=self._heartbeat
)
if not self._task:
self._task = asyncio.create_task(self._listen())
self._available = True
if version.endswith('-SNAPSHOT'):
# we're just gonna assume all snapshot versions correlate with v4
self._version = 4
else:
self._version = version[:1]
return self return self
except aiohttp.ClientConnectorError: except (aiohttp.ClientConnectorError, ConnectionRefusedError):
raise NodeConnectionFailure( raise NodeConnectionFailure(
f"The connection to node '{self._identifier}' failed." f"The connection to node '{self._identifier}' failed."
) ) from None
except aiohttp.WSServerHandshakeError: except aiohttp.WSServerHandshakeError:
raise NodeConnectionFailure( raise NodeConnectionFailure(
f"The password for node '{self._identifier}' is invalid." f"The password for node '{self._identifier}' is invalid."
) ) from None
except aiohttp.InvalidURL: except aiohttp.InvalidURL:
raise NodeConnectionFailure( raise NodeConnectionFailure(
f"The URI for node '{self._identifier}' is invalid." f"The URI for node '{self._identifier}' is invalid."
) ) from None
async def disconnect(self): async def disconnect(self):
"""Disconnects a connected Lavalink node and removes it from the node pool. """Disconnects a connected Lavalink node and removes it from the node pool.
@ -322,7 +312,8 @@ class Node:
await player.destroy() await player.destroy()
await self._websocket.close() await self._websocket.close()
del self._pool.nodes[self._identifier] await self._session.close()
del self._pool._nodes[self._identifier]
self.available = False self.available = False
self._task.cancel() self._task.cancel()
@ -338,8 +329,6 @@ class Node:
Context object on the track it builds. Context object on the track it builds.
""" """
data: dict = await self.send(method="GET", path="decodetrack", query=f"encodedTrack={identifier}") data: dict = await self.send(method="GET", path="decodetrack", query=f"encodedTrack={identifier}")
return Track(track_id=identifier, ctx=ctx, info=data) return Track(track_id=identifier, ctx=ctx, info=data)
@ -537,8 +526,6 @@ class Node:
load_type = data.get("loadType") load_type = data.get("loadType")
query_type = self._get_type(query)
if not load_type: if not load_type:
raise TrackLoadError("There was an error while trying to load this track.") raise TrackLoadError("There was an error while trying to load this track.")
@ -550,19 +537,14 @@ class Node:
return None return None
elif load_type == "PLAYLIST_LOADED": elif load_type == "PLAYLIST_LOADED":
if query_type == PlaylistType.SOUNDCLOUD:
track_type = TrackType.SOUNDCLOUD
else:
track_type = TrackType.YOUTUBE
tracks = [ tracks = [
Track(track_id=track["track"], info=track["info"], ctx=ctx, track_type=track_type) Track(track_id=track["encoded"], info=track["info"], ctx=ctx, track_type=TrackType(track["info"]["sourceName"]))
for track in data["tracks"] for track in data["tracks"]
] ]
return Playlist( return Playlist(
playlist_info=data["playlistInfo"], playlist_info=data["playlistInfo"],
tracks=tracks, tracks=tracks,
playlist_type=query_type, playlist_type=PlaylistType(tracks[0].track_type.value),
thumbnail=tracks[0].thumbnail, thumbnail=tracks[0].thumbnail,
uri=query uri=query
) )
@ -570,10 +552,10 @@ class Node:
elif load_type == "SEARCH_RESULT" or load_type == "TRACK_LOADED": elif load_type == "SEARCH_RESULT" or load_type == "TRACK_LOADED":
return [ return [
Track( Track(
track_id=track["track"], track_id=track["encoded"],
info=track["info"], info=track["info"],
ctx=ctx, ctx=ctx,
track_type=query_type, track_type=TrackType(track["info"]["sourceName"]),
filters=filters, filters=filters,
timestamp=timestamp timestamp=timestamp
) )
@ -629,7 +611,7 @@ class NodePool:
This holds all the nodes that are to be used by the bot. This holds all the nodes that are to be used by the bot.
""" """
_nodes: dict = {} _nodes: Dict[str, Node] = {}
def __repr__(self): def __repr__(self):
return f"<Pomice.NodePool node_count={self.node_count}>" return f"<Pomice.NodePool node_count={self.node_count}>"
@ -658,7 +640,7 @@ class NodePool:
based on how players it has. This method will return a node with based on how players it has. This method will return a node with
the least amount of players the least amount of players
""" """
available_nodes = [node for node in cls._nodes.values() if node._available] available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available]
if not available_nodes: if not available_nodes:
raise NoNodesAvailable("There are no nodes available.") raise NoNodesAvailable("There are no nodes available.")
@ -704,7 +686,8 @@ class NodePool:
spotify_client_id: Optional[str] = None, spotify_client_id: Optional[str] = None,
spotify_client_secret: Optional[str] = None, spotify_client_secret: Optional[str] = None,
session: Optional[aiohttp.ClientSession] = None, session: Optional[aiohttp.ClientSession] = None,
apple_music: bool = False apple_music: bool = False,
fallback: bool = False
) -> Node: ) -> Node:
"""Creates a Node object to be then added into the node pool. """Creates a Node object to be then added into the node pool.
@ -718,7 +701,7 @@ class NodePool:
identifier=identifier, secure=secure, heartbeat=heartbeat, identifier=identifier, secure=secure, heartbeat=heartbeat,
spotify_client_id=spotify_client_id, spotify_client_id=spotify_client_id,
session=session, spotify_client_secret=spotify_client_secret, session=session, spotify_client_secret=spotify_client_secret,
apple_music=apple_music apple_music=apple_music, fallback=fallback
) )
await node.connect() await node.connect()

View File

@ -1,14 +1,18 @@
from __future__ import annotations
import re import re
import time import time
from base64 import b64encode
import aiohttp
import orjson as json import orjson as json
from base64 import b64encode
from typing import TYPE_CHECKING
from .exceptions import InvalidSpotifyURL, SpotifyRequestException from .exceptions import InvalidSpotifyURL, SpotifyRequestException
from .objects import * from .objects import *
if TYPE_CHECKING:
from ..pool import Node
GRANT_URL = "https://accounts.spotify.com/api/token" GRANT_URL = "https://accounts.spotify.com/api/token"
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}" REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
SPOTIFY_URL_REGEX = re.compile( SPOTIFY_URL_REGEX = re.compile(
@ -22,11 +26,12 @@ class Client:
for any Spotify URL you throw at it. for any Spotify URL you throw at it.
""" """
def __init__(self, client_id: str, client_secret: str) -> None: def __init__(self, node: Node, client_id: str, client_secret: str) -> None:
self._client_id = client_id self._client_id = client_id
self._client_secret = client_secret self._client_secret = client_secret
self.node = node
self.session = aiohttp.ClientSession() self.session = self.node._session
self._bearer_token: str = None self._bearer_token: str = None
self._expiry = 0 self._expiry = 0

View File

@ -120,8 +120,6 @@ class RouteStats:
self.block_index = details.get("blockIndex") self.block_index = details.get("blockIndex")
self.address_index = details.get("currentAddressIndex") self.address_index = details.get("currentAddressIndex")
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Pomice.RouteStats route_strategy={self.strategy!r} failing_addresses={len(self.failing_addresses)}>" return f"<Pomice.RouteStats route_strategy={self.strategy!r} failing_addresses={len(self.failing_addresses)}>"