fix a couple of outstanding bugs

This commit is contained in:
cloudwithax 2023-03-09 08:24:26 -05:00
parent 9d831d3ecd
commit de7385d8ff
8 changed files with 114 additions and 138 deletions

View File

@ -3,8 +3,9 @@ Pomice
~~~~~~
The modern Lavalink wrapper designed for discord.py.
:copyright: 2023, cloudwithax
:license: GPL-3.0
Copyright (c) 2023, cloudwithax
Licensed under GPL-3.0
"""
import discord
@ -18,9 +19,11 @@ if not discord.version_info.major >= 2:
"using 'pip install discord.py'"
)
__version__ = "2.1.1"
__version__ = "2.2a"
__title__ = "pomice"
__author__ = "cloudwithax"
__license__ = "GPL-3.0"
__copyright__ = "Copyright (c) 2023, cloudwithax"
from .enums import *
from .events import *

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import re
import aiohttp
import orjson as json
@ -6,6 +8,10 @@ import base64
from datetime import datetime
from .objects import *
from .exceptions import *
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..pool import Node
AM_URL_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)")
AM_SINGLE_IN_ALBUM_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)")
@ -18,10 +24,11 @@ class Client:
and translating it to a valid Lavalink track. No client auth is required here.
"""
def __init__(self) -> None:
def __init__(self, node: Node) -> None:
self.token: str = None
self.expiry: datetime = None
self.session: aiohttp.ClientSession = aiohttp.ClientSession()
self.node = node
self.session = self.node._session
self.headers = None

View File

@ -42,11 +42,11 @@ class TrackType(Enum):
"""
# We don't have to define anything special for these, since these just serve as flags
YOUTUBE = "youtube_track"
SOUNDCLOUD = "soundcloud_track"
SPOTIFY = "spotify_track"
APPLE_MUSIC = "apple_music_track"
HTTP = "http_source"
YOUTUBE = "youtube"
SOUNDCLOUD = "soundcloud"
SPOTIFY = "spotify"
APPLE_MUSIC = "apple_music"
HTTP = "http"
def __str__(self) -> str:
return self.value
@ -65,10 +65,10 @@ class PlaylistType(Enum):
"""
# We don't have to define anything special for these, since these just serve as flags
YOUTUBE = "youtube_playlist"
SOUNDCLOUD = "soundcloud_playlist"
SPOTIFY = "spotify_playlist"
APPLE_MUSIC = "apple_music_list"
YOUTUBE = "youtube"
SOUNDCLOUD = "soundcloud"
SPOTIFY = "spotify"
APPLE_MUSIC = "apple_music"
def __str__(self) -> str:
return self.value
@ -114,27 +114,6 @@ class LoopMode(Enum):
QUEUE = "queue"
def __str__(self) -> str:
return self.value
class PlatformRecommendation(Enum):
"""
The enum for choosing what platform you want for recommendations.
This feature is exclusively for the recommendations function.
If you are not using this feature, this class is not necessary.
PlatformRecommendation.SPOTIFY sets the recommendations to come from Spotify
PlatformRecommendation.YOUTUBE sets the recommendations to come from YouTube
"""
# We don't have to define anything special for these, since these just serve as flags
SPOTIFY = "spotify"
YOUTUBE = "youtube"
def __str__(self) -> str:
return self.value

View File

@ -7,10 +7,6 @@ from discord.ext import commands
from .enums import SearchType, TrackType, PlaylistType
from .filters import Filter
from . import (
spotify,
applemusic
)
class Track:

View File

@ -16,11 +16,11 @@ from discord import (
from discord.ext import commands
from . import events
from .enums import SearchType, PlatformRecommendation
from .enums import SearchType
from .events import PomiceEvent, TrackEndEvent, TrackStartEvent
from .exceptions import FilterInvalidArgument, FilterTagAlreadyInUse, FilterTagInvalid, TrackInvalidPosition, TrackLoadError
from .filters import Filter
from .objects import Track, Playlist
from .objects import Track
from .pool import Node, NodePool
class Filters:
@ -111,24 +111,24 @@ class Player(VoiceProtocol):
*,
node: Node = None
):
self.client = client
self.client: Optional[Client] = client
self._bot: Union[Client, commands.Bot] = client
self.channel = channel
self._guild = channel.guild if channel else None
self.channel: Optional[VoiceChannel] = channel
self._guild: Guild = channel.guild if channel else None
self._node = node if node else NodePool.get_node()
self._node: Node = node if node else NodePool.get_node()
self._current: Track = None
self._filters: Filters = Filters()
self._volume = 100
self._paused = False
self._is_connected = False
self._volume: int = 100
self._paused: bool = False
self._is_connected: bool = False
self._position = 0
self._last_position = 0
self._last_update = 0
self._position: int = 0
self._last_position: int = 0
self._last_update: int = 0
self._ending_track: Optional[Track] = None
self._voice_state = {}
self._voice_state: dict = {}
self._player_endpoint_uri = f'sessions/{self._node._session_id}/players'
@ -211,7 +211,7 @@ class Player(VoiceProtocol):
async def _update_state(self, data: dict):
state: dict = data.get("state")
self._last_update = time.time() * 1000
self._last_update = int(state.get("time")) * 1000
self._is_connected = state.get("connected")
self._last_position = state.get("position")
@ -366,7 +366,7 @@ class Player(VoiceProtocol):
data = {
"encodedTrack": search.track_id,
"position": str(start),
"endTime": str(end)
"endTime": str(track.length)
}
track.original = search
track.track_id = search.track_id
@ -375,7 +375,7 @@ class Player(VoiceProtocol):
data = {
"encodedTrack": track.track_id,
"position": str(start),
"endTime": str(end)
"endTime": str(track.length)
}
@ -400,7 +400,12 @@ class Player(VoiceProtocol):
# Now apply all filters
for filter in track.filters:
await self.add_filter(filter=filter)
# Lavalink v4 changed the way the end time parameter works
# so now the end time cannot be zero.
# If it isnt zero, it'll match the length of the track,
# otherwise itll be set here:
if end > 0:
data["endTime"] = str(end)

View File

@ -3,13 +3,13 @@ from __future__ import annotations
import asyncio
import random
import re
from typing import Dict, List, Optional, TYPE_CHECKING, Union
from urllib.parse import quote
import logging
import aiohttp
from discord import Client
from discord.ext import commands
from typing import Dict, List, Optional, TYPE_CHECKING, Union
from urllib.parse import quote
from . import (
__version__,
@ -37,6 +37,8 @@ from .routeplanner import RoutePlanner
if TYPE_CHECKING:
from .player import Player
_log = logging.getLogger(__name__)
class Node:
"""The base class for a node.
@ -48,7 +50,7 @@ class Node:
def __init__(
self,
*,
pool,
pool: NodePool,
bot: Union[Client, commands.Bot],
host: str,
port: int,
@ -59,7 +61,8 @@ class Node:
session: Optional[aiohttp.ClientSession] = None,
spotify_client_id: Optional[str] = None,
spotify_client_secret: Optional[str] = None,
apple_music: bool = False
apple_music: bool = False,
fallback: bool = False
):
self._bot = bot
@ -70,6 +73,7 @@ class Node:
self._identifier = identifier
self._heartbeat = heartbeat
self._secure = secure
self.fallback = fallback
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}/v3/websocket"
@ -100,11 +104,11 @@ class Node:
if self._spotify_client_id and self._spotify_client_secret:
self._spotify_client = spotify.Client(
self._spotify_client_id, self._spotify_client_secret
self, self._spotify_client_id, self._spotify_client_secret
)
if apple_music:
self._apple_music_client = applemusic.Client()
self._apple_music_client = applemusic.Client(self)
self._bot.add_listener(self._update_handler, "on_socket_response")
@ -165,6 +169,7 @@ class Node:
if data["t"] == "VOICE_SERVER_UPDATE":
guild_id = int(data["d"]["guild_id"])
_log.debug(f"Recieved voice server update message from guild ID: {guild_id}")
try:
player = self._players[guild_id]
await player.on_voice_server_update(data["d"])
@ -176,6 +181,7 @@ class Node:
return
guild_id = int(data["d"]["guild_id"])
_log.debug(f"Recieved voice state update message from guild ID: {guild_id}")
try:
player = self._players[guild_id]
await player.on_voice_state_update(data["d"])
@ -187,13 +193,13 @@ class Node:
while True:
msg = await self._websocket.receive()
if msg.type == aiohttp.WSMsgType.CLOSED:
if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
retry = backoff.delay()
await asyncio.sleep(retry)
if not self.is_connected:
self._bot.loop.create_task(self.connect())
if not self.is_connected:
asyncio.create_task(self.connect())
else:
self._bot.loop.create_task(self._handle_payload(msg.json()))
asyncio.create_task(self._handle_payload(msg.json()))
async def _handle_payload(self, data: dict):
op = data.get("op", None)
@ -216,32 +222,6 @@ class Node:
elif op == "playerUpdate":
await player._update_state(data)
def _get_type(self, query: str):
if match := URLRegex.LAVALINK_SEARCH.match(query):
type = match.group("type")
if type == "sc":
return TrackType.SOUNDCLOUD
return TrackType.YOUTUBE
elif URLRegex.YOUTUBE_URL.match(query):
if URLRegex.YOUTUBE_PLAYLIST_URL.match(query):
return PlaylistType.YOUTUBE
return TrackType.YOUTUBE
elif URLRegex.SOUNDCLOUD_URL.match(query):
if URLRegex.SOUNDCLOUD_TRACK_IN_SET_URL.match(query):
return TrackType.SOUNDCLOUD
if URLRegex.SOUNDCLOUD_PLAYLIST_URL.match(query):
return PlaylistType.SOUNDCLOUD
return TrackType.SOUNDCLOUD
else:
return TrackType.HTTP
async def send(
self,
method: str,
@ -249,9 +229,10 @@ class Node:
include_version: bool = True,
guild_id: Optional[Union[int, str]] = None,
query: Optional[str] = None,
data: Optional[Union[dict, str]] = None
data: Optional[Union[dict, str]] = None,
ignore_if_available: bool = False,
):
if not self._available:
if not ignore_if_available and not self._available:
raise NodeNotAvailable(
f"The node '{self._identifier}' is unavailable."
)
@ -264,7 +245,8 @@ class Node:
async with self._session.request(method=method, url=uri, headers=self._headers, json=data or {}) as resp:
if resp.status >= 300:
raise NodeRestException(f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}')
data: dict = await resp.json()
raise NodeRestException(f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {data["message"]}')
if method == "DELETE" or resp.status == 204:
return await resp.json(content_type=None)
@ -285,34 +267,42 @@ class Node:
await self._bot.wait_until_ready()
try:
version = await self.send(method="GET", path="version", ignore_if_available=True, include_version=False)
version = version.replace(".", "")
if not version.endswith('-SNAPSHOT') and int(version) < 370:
self._available = False
raise LavalinkVersionIncompatible(
"The Lavalink version you're using is incompatible. "
"Lavalink version 3.7.0 or above is required to use this library."
)
self._websocket = await self._session.ws_connect(
self._websocket_uri, headers=self._headers, heartbeat=self._heartbeat
)
self._task = self._bot.loop.create_task(self._listen())
self._available = True
version = await self.send(method="GET", path="version", include_version=False)
version = version.replace(".", "")
if int(version) < 370:
raise LavalinkVersionIncompatible(
"The Lavalink version you're using is incompatible."
"Lavalink version 3.7.0 or above is required to use this library."
)
self._version = version[:1]
if not self._task:
self._task = asyncio.create_task(self._listen())
self._available = True
if version.endswith('-SNAPSHOT'):
# we're just gonna assume all snapshot versions correlate with v4
self._version = 4
else:
self._version = version[:1]
return self
except aiohttp.ClientConnectorError:
except (aiohttp.ClientConnectorError, ConnectionRefusedError):
raise NodeConnectionFailure(
f"The connection to node '{self._identifier}' failed."
)
) from None
except aiohttp.WSServerHandshakeError:
raise NodeConnectionFailure(
f"The password for node '{self._identifier}' is invalid."
)
) from None
except aiohttp.InvalidURL:
raise NodeConnectionFailure(
f"The URI for node '{self._identifier}' is invalid."
)
) from None
async def disconnect(self):
"""Disconnects a connected Lavalink node and removes it from the node pool.
@ -322,7 +312,8 @@ class Node:
await player.destroy()
await self._websocket.close()
del self._pool.nodes[self._identifier]
await self._session.close()
del self._pool._nodes[self._identifier]
self.available = False
self._task.cancel()
@ -338,8 +329,6 @@ class Node:
Context object on the track it builds.
"""
data: dict = await self.send(method="GET", path="decodetrack", query=f"encodedTrack={identifier}")
return Track(track_id=identifier, ctx=ctx, info=data)
@ -537,8 +526,6 @@ class Node:
load_type = data.get("loadType")
query_type = self._get_type(query)
if not load_type:
raise TrackLoadError("There was an error while trying to load this track.")
@ -550,19 +537,14 @@ class Node:
return None
elif load_type == "PLAYLIST_LOADED":
if query_type == PlaylistType.SOUNDCLOUD:
track_type = TrackType.SOUNDCLOUD
else:
track_type = TrackType.YOUTUBE
tracks = [
Track(track_id=track["track"], info=track["info"], ctx=ctx, track_type=track_type)
Track(track_id=track["encoded"], info=track["info"], ctx=ctx, track_type=TrackType(track["info"]["sourceName"]))
for track in data["tracks"]
]
return Playlist(
playlist_info=data["playlistInfo"],
tracks=tracks,
playlist_type=query_type,
playlist_type=PlaylistType(tracks[0].track_type.value),
thumbnail=tracks[0].thumbnail,
uri=query
)
@ -570,10 +552,10 @@ class Node:
elif load_type == "SEARCH_RESULT" or load_type == "TRACK_LOADED":
return [
Track(
track_id=track["track"],
track_id=track["encoded"],
info=track["info"],
ctx=ctx,
track_type=query_type,
track_type=TrackType(track["info"]["sourceName"]),
filters=filters,
timestamp=timestamp
)
@ -629,7 +611,7 @@ class NodePool:
This holds all the nodes that are to be used by the bot.
"""
_nodes: dict = {}
_nodes: Dict[str, Node] = {}
def __repr__(self):
return f"<Pomice.NodePool node_count={self.node_count}>"
@ -658,7 +640,7 @@ class NodePool:
based on how players it has. This method will return a node with
the least amount of players
"""
available_nodes = [node for node in cls._nodes.values() if node._available]
available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available]
if not available_nodes:
raise NoNodesAvailable("There are no nodes available.")
@ -704,7 +686,8 @@ class NodePool:
spotify_client_id: Optional[str] = None,
spotify_client_secret: Optional[str] = None,
session: Optional[aiohttp.ClientSession] = None,
apple_music: bool = False
apple_music: bool = False,
fallback: bool = False
) -> Node:
"""Creates a Node object to be then added into the node pool.
@ -718,7 +701,7 @@ class NodePool:
identifier=identifier, secure=secure, heartbeat=heartbeat,
spotify_client_id=spotify_client_id,
session=session, spotify_client_secret=spotify_client_secret,
apple_music=apple_music
apple_music=apple_music, fallback=fallback
)
await node.connect()

View File

@ -1,14 +1,18 @@
from __future__ import annotations
import re
import time
from base64 import b64encode
import aiohttp
import orjson as json
from base64 import b64encode
from typing import TYPE_CHECKING
from .exceptions import InvalidSpotifyURL, SpotifyRequestException
from .objects import *
if TYPE_CHECKING:
from ..pool import Node
GRANT_URL = "https://accounts.spotify.com/api/token"
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
SPOTIFY_URL_REGEX = re.compile(
@ -22,11 +26,12 @@ class Client:
for any Spotify URL you throw at it.
"""
def __init__(self, client_id: str, client_secret: str) -> None:
def __init__(self, node: Node, client_id: str, client_secret: str) -> None:
self._client_id = client_id
self._client_secret = client_secret
self.node = node
self.session = aiohttp.ClientSession()
self.session = self.node._session
self._bearer_token: str = None
self._expiry = 0

View File

@ -120,8 +120,6 @@ class RouteStats:
self.block_index = details.get("blockIndex")
self.address_index = details.get("currentAddressIndex")
def __repr__(self) -> str:
return f"<Pomice.RouteStats route_strategy={self.strategy!r} failing_addresses={len(self.failing_addresses)}>"