fix a couple of outstanding bugs
This commit is contained in:
parent
9d831d3ecd
commit
de7385d8ff
|
|
@ -3,8 +3,9 @@ Pomice
|
||||||
~~~~~~
|
~~~~~~
|
||||||
The modern Lavalink wrapper designed for discord.py.
|
The modern Lavalink wrapper designed for discord.py.
|
||||||
|
|
||||||
:copyright: 2023, cloudwithax
|
Copyright (c) 2023, cloudwithax
|
||||||
:license: GPL-3.0
|
|
||||||
|
Licensed under GPL-3.0
|
||||||
"""
|
"""
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
|
|
@ -18,9 +19,11 @@ if not discord.version_info.major >= 2:
|
||||||
"using 'pip install discord.py'"
|
"using 'pip install discord.py'"
|
||||||
)
|
)
|
||||||
|
|
||||||
__version__ = "2.1.1"
|
__version__ = "2.2a"
|
||||||
__title__ = "pomice"
|
__title__ = "pomice"
|
||||||
__author__ = "cloudwithax"
|
__author__ = "cloudwithax"
|
||||||
|
__license__ = "GPL-3.0"
|
||||||
|
__copyright__ = "Copyright (c) 2023, cloudwithax"
|
||||||
|
|
||||||
from .enums import *
|
from .enums import *
|
||||||
from .events import *
|
from .events import *
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import orjson as json
|
import orjson as json
|
||||||
|
|
@ -6,6 +8,10 @@ import base64
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from .objects import *
|
from .objects import *
|
||||||
from .exceptions import *
|
from .exceptions import *
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..pool import Node
|
||||||
|
|
||||||
AM_URL_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)")
|
AM_URL_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)")
|
||||||
AM_SINGLE_IN_ALBUM_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)")
|
AM_SINGLE_IN_ALBUM_REGEX = re.compile(r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)")
|
||||||
|
|
@ -18,10 +24,11 @@ class Client:
|
||||||
and translating it to a valid Lavalink track. No client auth is required here.
|
and translating it to a valid Lavalink track. No client auth is required here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, node: Node) -> None:
|
||||||
self.token: str = None
|
self.token: str = None
|
||||||
self.expiry: datetime = None
|
self.expiry: datetime = None
|
||||||
self.session: aiohttp.ClientSession = aiohttp.ClientSession()
|
self.node = node
|
||||||
|
self.session = self.node._session
|
||||||
self.headers = None
|
self.headers = None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,11 +42,11 @@ class TrackType(Enum):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# We don't have to define anything special for these, since these just serve as flags
|
# We don't have to define anything special for these, since these just serve as flags
|
||||||
YOUTUBE = "youtube_track"
|
YOUTUBE = "youtube"
|
||||||
SOUNDCLOUD = "soundcloud_track"
|
SOUNDCLOUD = "soundcloud"
|
||||||
SPOTIFY = "spotify_track"
|
SPOTIFY = "spotify"
|
||||||
APPLE_MUSIC = "apple_music_track"
|
APPLE_MUSIC = "apple_music"
|
||||||
HTTP = "http_source"
|
HTTP = "http"
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
@ -65,10 +65,10 @@ class PlaylistType(Enum):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# We don't have to define anything special for these, since these just serve as flags
|
# We don't have to define anything special for these, since these just serve as flags
|
||||||
YOUTUBE = "youtube_playlist"
|
YOUTUBE = "youtube"
|
||||||
SOUNDCLOUD = "soundcloud_playlist"
|
SOUNDCLOUD = "soundcloud"
|
||||||
SPOTIFY = "spotify_playlist"
|
SPOTIFY = "spotify"
|
||||||
APPLE_MUSIC = "apple_music_list"
|
APPLE_MUSIC = "apple_music"
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
@ -114,27 +114,6 @@ class LoopMode(Enum):
|
||||||
QUEUE = "queue"
|
QUEUE = "queue"
|
||||||
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return self.value
|
|
||||||
|
|
||||||
class PlatformRecommendation(Enum):
|
|
||||||
|
|
||||||
"""
|
|
||||||
The enum for choosing what platform you want for recommendations.
|
|
||||||
This feature is exclusively for the recommendations function.
|
|
||||||
If you are not using this feature, this class is not necessary.
|
|
||||||
|
|
||||||
PlatformRecommendation.SPOTIFY sets the recommendations to come from Spotify
|
|
||||||
|
|
||||||
PlatformRecommendation.YOUTUBE sets the recommendations to come from YouTube
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
# We don't have to define anything special for these, since these just serve as flags
|
|
||||||
SPOTIFY = "spotify"
|
|
||||||
YOUTUBE = "youtube"
|
|
||||||
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,6 @@ from discord.ext import commands
|
||||||
from .enums import SearchType, TrackType, PlaylistType
|
from .enums import SearchType, TrackType, PlaylistType
|
||||||
from .filters import Filter
|
from .filters import Filter
|
||||||
|
|
||||||
from . import (
|
|
||||||
spotify,
|
|
||||||
applemusic
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Track:
|
class Track:
|
||||||
|
|
|
||||||
|
|
@ -16,11 +16,11 @@ from discord import (
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
from . import events
|
from . import events
|
||||||
from .enums import SearchType, PlatformRecommendation
|
from .enums import SearchType
|
||||||
from .events import PomiceEvent, TrackEndEvent, TrackStartEvent
|
from .events import PomiceEvent, TrackEndEvent, TrackStartEvent
|
||||||
from .exceptions import FilterInvalidArgument, FilterTagAlreadyInUse, FilterTagInvalid, TrackInvalidPosition, TrackLoadError
|
from .exceptions import FilterInvalidArgument, FilterTagAlreadyInUse, FilterTagInvalid, TrackInvalidPosition, TrackLoadError
|
||||||
from .filters import Filter
|
from .filters import Filter
|
||||||
from .objects import Track, Playlist
|
from .objects import Track
|
||||||
from .pool import Node, NodePool
|
from .pool import Node, NodePool
|
||||||
|
|
||||||
class Filters:
|
class Filters:
|
||||||
|
|
@ -111,24 +111,24 @@ class Player(VoiceProtocol):
|
||||||
*,
|
*,
|
||||||
node: Node = None
|
node: Node = None
|
||||||
):
|
):
|
||||||
self.client = client
|
self.client: Optional[Client] = client
|
||||||
self._bot: Union[Client, commands.Bot] = client
|
self._bot: Union[Client, commands.Bot] = client
|
||||||
self.channel = channel
|
self.channel: Optional[VoiceChannel] = channel
|
||||||
self._guild = channel.guild if channel else None
|
self._guild: Guild = channel.guild if channel else None
|
||||||
|
|
||||||
self._node = node if node else NodePool.get_node()
|
self._node: Node = node if node else NodePool.get_node()
|
||||||
self._current: Track = None
|
self._current: Track = None
|
||||||
self._filters: Filters = Filters()
|
self._filters: Filters = Filters()
|
||||||
self._volume = 100
|
self._volume: int = 100
|
||||||
self._paused = False
|
self._paused: bool = False
|
||||||
self._is_connected = False
|
self._is_connected: bool = False
|
||||||
|
|
||||||
self._position = 0
|
self._position: int = 0
|
||||||
self._last_position = 0
|
self._last_position: int = 0
|
||||||
self._last_update = 0
|
self._last_update: int = 0
|
||||||
self._ending_track: Optional[Track] = None
|
self._ending_track: Optional[Track] = None
|
||||||
|
|
||||||
self._voice_state = {}
|
self._voice_state: dict = {}
|
||||||
|
|
||||||
self._player_endpoint_uri = f'sessions/{self._node._session_id}/players'
|
self._player_endpoint_uri = f'sessions/{self._node._session_id}/players'
|
||||||
|
|
||||||
|
|
@ -211,7 +211,7 @@ class Player(VoiceProtocol):
|
||||||
|
|
||||||
async def _update_state(self, data: dict):
|
async def _update_state(self, data: dict):
|
||||||
state: dict = data.get("state")
|
state: dict = data.get("state")
|
||||||
self._last_update = time.time() * 1000
|
self._last_update = int(state.get("time")) * 1000
|
||||||
self._is_connected = state.get("connected")
|
self._is_connected = state.get("connected")
|
||||||
self._last_position = state.get("position")
|
self._last_position = state.get("position")
|
||||||
|
|
||||||
|
|
@ -366,7 +366,7 @@ class Player(VoiceProtocol):
|
||||||
data = {
|
data = {
|
||||||
"encodedTrack": search.track_id,
|
"encodedTrack": search.track_id,
|
||||||
"position": str(start),
|
"position": str(start),
|
||||||
"endTime": str(end)
|
"endTime": str(track.length)
|
||||||
}
|
}
|
||||||
track.original = search
|
track.original = search
|
||||||
track.track_id = search.track_id
|
track.track_id = search.track_id
|
||||||
|
|
@ -375,7 +375,7 @@ class Player(VoiceProtocol):
|
||||||
data = {
|
data = {
|
||||||
"encodedTrack": track.track_id,
|
"encodedTrack": track.track_id,
|
||||||
"position": str(start),
|
"position": str(start),
|
||||||
"endTime": str(end)
|
"endTime": str(track.length)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -401,6 +401,11 @@ class Player(VoiceProtocol):
|
||||||
for filter in track.filters:
|
for filter in track.filters:
|
||||||
await self.add_filter(filter=filter)
|
await self.add_filter(filter=filter)
|
||||||
|
|
||||||
|
# Lavalink v4 changed the way the end time parameter works
|
||||||
|
# so now the end time cannot be zero.
|
||||||
|
# If it isnt zero, it'll match the length of the track,
|
||||||
|
# otherwise itll be set here:
|
||||||
|
|
||||||
if end > 0:
|
if end > 0:
|
||||||
data["endTime"] = str(end)
|
data["endTime"] = str(end)
|
||||||
|
|
||||||
|
|
|
||||||
123
pomice/pool.py
123
pomice/pool.py
|
|
@ -3,13 +3,13 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
import logging
|
||||||
from urllib.parse import quote
|
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from discord import Client
|
from discord import Client
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
__version__,
|
__version__,
|
||||||
|
|
@ -37,6 +37,8 @@ from .routeplanner import RoutePlanner
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .player import Player
|
from .player import Player
|
||||||
|
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Node:
|
class Node:
|
||||||
"""The base class for a node.
|
"""The base class for a node.
|
||||||
|
|
@ -48,7 +50,7 @@ class Node:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
pool,
|
pool: NodePool,
|
||||||
bot: Union[Client, commands.Bot],
|
bot: Union[Client, commands.Bot],
|
||||||
host: str,
|
host: str,
|
||||||
port: int,
|
port: int,
|
||||||
|
|
@ -59,7 +61,8 @@ class Node:
|
||||||
session: Optional[aiohttp.ClientSession] = None,
|
session: Optional[aiohttp.ClientSession] = None,
|
||||||
spotify_client_id: Optional[str] = None,
|
spotify_client_id: Optional[str] = None,
|
||||||
spotify_client_secret: Optional[str] = None,
|
spotify_client_secret: Optional[str] = None,
|
||||||
apple_music: bool = False
|
apple_music: bool = False,
|
||||||
|
fallback: bool = False
|
||||||
|
|
||||||
):
|
):
|
||||||
self._bot = bot
|
self._bot = bot
|
||||||
|
|
@ -70,6 +73,7 @@ class Node:
|
||||||
self._identifier = identifier
|
self._identifier = identifier
|
||||||
self._heartbeat = heartbeat
|
self._heartbeat = heartbeat
|
||||||
self._secure = secure
|
self._secure = secure
|
||||||
|
self.fallback = fallback
|
||||||
|
|
||||||
|
|
||||||
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}/v3/websocket"
|
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}/v3/websocket"
|
||||||
|
|
@ -100,11 +104,11 @@ class Node:
|
||||||
|
|
||||||
if self._spotify_client_id and self._spotify_client_secret:
|
if self._spotify_client_id and self._spotify_client_secret:
|
||||||
self._spotify_client = spotify.Client(
|
self._spotify_client = spotify.Client(
|
||||||
self._spotify_client_id, self._spotify_client_secret
|
self, self._spotify_client_id, self._spotify_client_secret
|
||||||
)
|
)
|
||||||
|
|
||||||
if apple_music:
|
if apple_music:
|
||||||
self._apple_music_client = applemusic.Client()
|
self._apple_music_client = applemusic.Client(self)
|
||||||
|
|
||||||
self._bot.add_listener(self._update_handler, "on_socket_response")
|
self._bot.add_listener(self._update_handler, "on_socket_response")
|
||||||
|
|
||||||
|
|
@ -165,6 +169,7 @@ class Node:
|
||||||
|
|
||||||
if data["t"] == "VOICE_SERVER_UPDATE":
|
if data["t"] == "VOICE_SERVER_UPDATE":
|
||||||
guild_id = int(data["d"]["guild_id"])
|
guild_id = int(data["d"]["guild_id"])
|
||||||
|
_log.debug(f"Recieved voice server update message from guild ID: {guild_id}")
|
||||||
try:
|
try:
|
||||||
player = self._players[guild_id]
|
player = self._players[guild_id]
|
||||||
await player.on_voice_server_update(data["d"])
|
await player.on_voice_server_update(data["d"])
|
||||||
|
|
@ -176,6 +181,7 @@ class Node:
|
||||||
return
|
return
|
||||||
|
|
||||||
guild_id = int(data["d"]["guild_id"])
|
guild_id = int(data["d"]["guild_id"])
|
||||||
|
_log.debug(f"Recieved voice state update message from guild ID: {guild_id}")
|
||||||
try:
|
try:
|
||||||
player = self._players[guild_id]
|
player = self._players[guild_id]
|
||||||
await player.on_voice_state_update(data["d"])
|
await player.on_voice_state_update(data["d"])
|
||||||
|
|
@ -187,13 +193,13 @@ class Node:
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
msg = await self._websocket.receive()
|
msg = await self._websocket.receive()
|
||||||
if msg.type == aiohttp.WSMsgType.CLOSED:
|
if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
||||||
retry = backoff.delay()
|
retry = backoff.delay()
|
||||||
await asyncio.sleep(retry)
|
await asyncio.sleep(retry)
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
self._bot.loop.create_task(self.connect())
|
asyncio.create_task(self.connect())
|
||||||
else:
|
else:
|
||||||
self._bot.loop.create_task(self._handle_payload(msg.json()))
|
asyncio.create_task(self._handle_payload(msg.json()))
|
||||||
|
|
||||||
async def _handle_payload(self, data: dict):
|
async def _handle_payload(self, data: dict):
|
||||||
op = data.get("op", None)
|
op = data.get("op", None)
|
||||||
|
|
@ -216,32 +222,6 @@ class Node:
|
||||||
elif op == "playerUpdate":
|
elif op == "playerUpdate":
|
||||||
await player._update_state(data)
|
await player._update_state(data)
|
||||||
|
|
||||||
def _get_type(self, query: str):
|
|
||||||
if match := URLRegex.LAVALINK_SEARCH.match(query):
|
|
||||||
type = match.group("type")
|
|
||||||
if type == "sc":
|
|
||||||
return TrackType.SOUNDCLOUD
|
|
||||||
|
|
||||||
return TrackType.YOUTUBE
|
|
||||||
|
|
||||||
|
|
||||||
elif URLRegex.YOUTUBE_URL.match(query):
|
|
||||||
if URLRegex.YOUTUBE_PLAYLIST_URL.match(query):
|
|
||||||
return PlaylistType.YOUTUBE
|
|
||||||
|
|
||||||
return TrackType.YOUTUBE
|
|
||||||
|
|
||||||
elif URLRegex.SOUNDCLOUD_URL.match(query):
|
|
||||||
if URLRegex.SOUNDCLOUD_TRACK_IN_SET_URL.match(query):
|
|
||||||
return TrackType.SOUNDCLOUD
|
|
||||||
if URLRegex.SOUNDCLOUD_PLAYLIST_URL.match(query):
|
|
||||||
return PlaylistType.SOUNDCLOUD
|
|
||||||
|
|
||||||
return TrackType.SOUNDCLOUD
|
|
||||||
|
|
||||||
else:
|
|
||||||
return TrackType.HTTP
|
|
||||||
|
|
||||||
async def send(
|
async def send(
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
|
|
@ -249,9 +229,10 @@ class Node:
|
||||||
include_version: bool = True,
|
include_version: bool = True,
|
||||||
guild_id: Optional[Union[int, str]] = None,
|
guild_id: Optional[Union[int, str]] = None,
|
||||||
query: Optional[str] = None,
|
query: Optional[str] = None,
|
||||||
data: Optional[Union[dict, str]] = None
|
data: Optional[Union[dict, str]] = None,
|
||||||
|
ignore_if_available: bool = False,
|
||||||
):
|
):
|
||||||
if not self._available:
|
if not ignore_if_available and not self._available:
|
||||||
raise NodeNotAvailable(
|
raise NodeNotAvailable(
|
||||||
f"The node '{self._identifier}' is unavailable."
|
f"The node '{self._identifier}' is unavailable."
|
||||||
)
|
)
|
||||||
|
|
@ -264,7 +245,8 @@ class Node:
|
||||||
|
|
||||||
async with self._session.request(method=method, url=uri, headers=self._headers, json=data or {}) as resp:
|
async with self._session.request(method=method, url=uri, headers=self._headers, json=data or {}) as resp:
|
||||||
if resp.status >= 300:
|
if resp.status >= 300:
|
||||||
raise NodeRestException(f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}')
|
data: dict = await resp.json()
|
||||||
|
raise NodeRestException(f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {data["message"]}')
|
||||||
|
|
||||||
if method == "DELETE" or resp.status == 204:
|
if method == "DELETE" or resp.status == 204:
|
||||||
return await resp.json(content_type=None)
|
return await resp.json(content_type=None)
|
||||||
|
|
@ -285,34 +267,42 @@ class Node:
|
||||||
await self._bot.wait_until_ready()
|
await self._bot.wait_until_ready()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._websocket = await self._session.ws_connect(
|
version = await self.send(method="GET", path="version", ignore_if_available=True, include_version=False)
|
||||||
self._websocket_uri, headers=self._headers, heartbeat=self._heartbeat
|
|
||||||
)
|
|
||||||
self._task = self._bot.loop.create_task(self._listen())
|
|
||||||
self._available = True
|
|
||||||
version = await self.send(method="GET", path="version", include_version=False)
|
|
||||||
version = version.replace(".", "")
|
version = version.replace(".", "")
|
||||||
if int(version) < 370:
|
if not version.endswith('-SNAPSHOT') and int(version) < 370:
|
||||||
|
self._available = False
|
||||||
raise LavalinkVersionIncompatible(
|
raise LavalinkVersionIncompatible(
|
||||||
"The Lavalink version you're using is incompatible."
|
"The Lavalink version you're using is incompatible. "
|
||||||
"Lavalink version 3.7.0 or above is required to use this library."
|
"Lavalink version 3.7.0 or above is required to use this library."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._websocket = await self._session.ws_connect(
|
||||||
|
self._websocket_uri, headers=self._headers, heartbeat=self._heartbeat
|
||||||
|
)
|
||||||
|
if not self._task:
|
||||||
|
self._task = asyncio.create_task(self._listen())
|
||||||
|
|
||||||
|
self._available = True
|
||||||
|
if version.endswith('-SNAPSHOT'):
|
||||||
|
# we're just gonna assume all snapshot versions correlate with v4
|
||||||
|
self._version = 4
|
||||||
|
else:
|
||||||
self._version = version[:1]
|
self._version = version[:1]
|
||||||
return self
|
return self
|
||||||
|
|
||||||
except aiohttp.ClientConnectorError:
|
except (aiohttp.ClientConnectorError, ConnectionRefusedError):
|
||||||
raise NodeConnectionFailure(
|
raise NodeConnectionFailure(
|
||||||
f"The connection to node '{self._identifier}' failed."
|
f"The connection to node '{self._identifier}' failed."
|
||||||
)
|
) from None
|
||||||
except aiohttp.WSServerHandshakeError:
|
except aiohttp.WSServerHandshakeError:
|
||||||
raise NodeConnectionFailure(
|
raise NodeConnectionFailure(
|
||||||
f"The password for node '{self._identifier}' is invalid."
|
f"The password for node '{self._identifier}' is invalid."
|
||||||
)
|
) from None
|
||||||
except aiohttp.InvalidURL:
|
except aiohttp.InvalidURL:
|
||||||
raise NodeConnectionFailure(
|
raise NodeConnectionFailure(
|
||||||
f"The URI for node '{self._identifier}' is invalid."
|
f"The URI for node '{self._identifier}' is invalid."
|
||||||
)
|
) from None
|
||||||
|
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
"""Disconnects a connected Lavalink node and removes it from the node pool.
|
"""Disconnects a connected Lavalink node and removes it from the node pool.
|
||||||
|
|
@ -322,7 +312,8 @@ class Node:
|
||||||
await player.destroy()
|
await player.destroy()
|
||||||
|
|
||||||
await self._websocket.close()
|
await self._websocket.close()
|
||||||
del self._pool.nodes[self._identifier]
|
await self._session.close()
|
||||||
|
del self._pool._nodes[self._identifier]
|
||||||
self.available = False
|
self.available = False
|
||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
|
|
||||||
|
|
@ -338,8 +329,6 @@ class Node:
|
||||||
Context object on the track it builds.
|
Context object on the track it builds.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
data: dict = await self.send(method="GET", path="decodetrack", query=f"encodedTrack={identifier}")
|
data: dict = await self.send(method="GET", path="decodetrack", query=f"encodedTrack={identifier}")
|
||||||
return Track(track_id=identifier, ctx=ctx, info=data)
|
return Track(track_id=identifier, ctx=ctx, info=data)
|
||||||
|
|
||||||
|
|
@ -537,8 +526,6 @@ class Node:
|
||||||
|
|
||||||
load_type = data.get("loadType")
|
load_type = data.get("loadType")
|
||||||
|
|
||||||
query_type = self._get_type(query)
|
|
||||||
|
|
||||||
if not load_type:
|
if not load_type:
|
||||||
raise TrackLoadError("There was an error while trying to load this track.")
|
raise TrackLoadError("There was an error while trying to load this track.")
|
||||||
|
|
||||||
|
|
@ -550,19 +537,14 @@ class Node:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
elif load_type == "PLAYLIST_LOADED":
|
elif load_type == "PLAYLIST_LOADED":
|
||||||
if query_type == PlaylistType.SOUNDCLOUD:
|
|
||||||
track_type = TrackType.SOUNDCLOUD
|
|
||||||
else:
|
|
||||||
track_type = TrackType.YOUTUBE
|
|
||||||
|
|
||||||
tracks = [
|
tracks = [
|
||||||
Track(track_id=track["track"], info=track["info"], ctx=ctx, track_type=track_type)
|
Track(track_id=track["encoded"], info=track["info"], ctx=ctx, track_type=TrackType(track["info"]["sourceName"]))
|
||||||
for track in data["tracks"]
|
for track in data["tracks"]
|
||||||
]
|
]
|
||||||
return Playlist(
|
return Playlist(
|
||||||
playlist_info=data["playlistInfo"],
|
playlist_info=data["playlistInfo"],
|
||||||
tracks=tracks,
|
tracks=tracks,
|
||||||
playlist_type=query_type,
|
playlist_type=PlaylistType(tracks[0].track_type.value),
|
||||||
thumbnail=tracks[0].thumbnail,
|
thumbnail=tracks[0].thumbnail,
|
||||||
uri=query
|
uri=query
|
||||||
)
|
)
|
||||||
|
|
@ -570,10 +552,10 @@ class Node:
|
||||||
elif load_type == "SEARCH_RESULT" or load_type == "TRACK_LOADED":
|
elif load_type == "SEARCH_RESULT" or load_type == "TRACK_LOADED":
|
||||||
return [
|
return [
|
||||||
Track(
|
Track(
|
||||||
track_id=track["track"],
|
track_id=track["encoded"],
|
||||||
info=track["info"],
|
info=track["info"],
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
track_type=query_type,
|
track_type=TrackType(track["info"]["sourceName"]),
|
||||||
filters=filters,
|
filters=filters,
|
||||||
timestamp=timestamp
|
timestamp=timestamp
|
||||||
)
|
)
|
||||||
|
|
@ -629,7 +611,7 @@ class NodePool:
|
||||||
This holds all the nodes that are to be used by the bot.
|
This holds all the nodes that are to be used by the bot.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_nodes: dict = {}
|
_nodes: Dict[str, Node] = {}
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<Pomice.NodePool node_count={self.node_count}>"
|
return f"<Pomice.NodePool node_count={self.node_count}>"
|
||||||
|
|
@ -658,7 +640,7 @@ class NodePool:
|
||||||
based on how players it has. This method will return a node with
|
based on how players it has. This method will return a node with
|
||||||
the least amount of players
|
the least amount of players
|
||||||
"""
|
"""
|
||||||
available_nodes = [node for node in cls._nodes.values() if node._available]
|
available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available]
|
||||||
|
|
||||||
if not available_nodes:
|
if not available_nodes:
|
||||||
raise NoNodesAvailable("There are no nodes available.")
|
raise NoNodesAvailable("There are no nodes available.")
|
||||||
|
|
@ -704,7 +686,8 @@ class NodePool:
|
||||||
spotify_client_id: Optional[str] = None,
|
spotify_client_id: Optional[str] = None,
|
||||||
spotify_client_secret: Optional[str] = None,
|
spotify_client_secret: Optional[str] = None,
|
||||||
session: Optional[aiohttp.ClientSession] = None,
|
session: Optional[aiohttp.ClientSession] = None,
|
||||||
apple_music: bool = False
|
apple_music: bool = False,
|
||||||
|
fallback: bool = False
|
||||||
|
|
||||||
) -> Node:
|
) -> Node:
|
||||||
"""Creates a Node object to be then added into the node pool.
|
"""Creates a Node object to be then added into the node pool.
|
||||||
|
|
@ -718,7 +701,7 @@ class NodePool:
|
||||||
identifier=identifier, secure=secure, heartbeat=heartbeat,
|
identifier=identifier, secure=secure, heartbeat=heartbeat,
|
||||||
spotify_client_id=spotify_client_id,
|
spotify_client_id=spotify_client_id,
|
||||||
session=session, spotify_client_secret=spotify_client_secret,
|
session=session, spotify_client_secret=spotify_client_secret,
|
||||||
apple_music=apple_music
|
apple_music=apple_music, fallback=fallback
|
||||||
)
|
)
|
||||||
|
|
||||||
await node.connect()
|
await node.connect()
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,18 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from base64 import b64encode
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import orjson as json
|
import orjson as json
|
||||||
|
|
||||||
|
from base64 import b64encode
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
from .exceptions import InvalidSpotifyURL, SpotifyRequestException
|
from .exceptions import InvalidSpotifyURL, SpotifyRequestException
|
||||||
from .objects import *
|
from .objects import *
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..pool import Node
|
||||||
|
|
||||||
|
|
||||||
GRANT_URL = "https://accounts.spotify.com/api/token"
|
GRANT_URL = "https://accounts.spotify.com/api/token"
|
||||||
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
|
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
|
||||||
SPOTIFY_URL_REGEX = re.compile(
|
SPOTIFY_URL_REGEX = re.compile(
|
||||||
|
|
@ -22,11 +26,12 @@ class Client:
|
||||||
for any Spotify URL you throw at it.
|
for any Spotify URL you throw at it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, client_id: str, client_secret: str) -> None:
|
def __init__(self, node: Node, client_id: str, client_secret: str) -> None:
|
||||||
self._client_id = client_id
|
self._client_id = client_id
|
||||||
self._client_secret = client_secret
|
self._client_secret = client_secret
|
||||||
|
self.node = node
|
||||||
|
|
||||||
self.session = aiohttp.ClientSession()
|
self.session = self.node._session
|
||||||
|
|
||||||
self._bearer_token: str = None
|
self._bearer_token: str = None
|
||||||
self._expiry = 0
|
self._expiry = 0
|
||||||
|
|
|
||||||
|
|
@ -120,8 +120,6 @@ class RouteStats:
|
||||||
self.block_index = details.get("blockIndex")
|
self.block_index = details.get("blockIndex")
|
||||||
self.address_index = details.get("currentAddressIndex")
|
self.address_index = details.get("currentAddressIndex")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<Pomice.RouteStats route_strategy={self.strategy!r} failing_addresses={len(self.failing_addresses)}>"
|
return f"<Pomice.RouteStats route_strategy={self.strategy!r} failing_addresses={len(self.failing_addresses)}>"
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue