Compare commits
No commits in common. "main" and "2.8.0" have entirely different histories.
|
|
@ -10,7 +10,6 @@ build/
|
||||||
Pipfile.lock
|
Pipfile.lock
|
||||||
.mypy_cache/
|
.mypy_cache/
|
||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
|
||||||
.venv/
|
.venv/
|
||||||
*.code-workspace
|
*.code-workspace
|
||||||
*.ini
|
*.ini
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ repos:
|
||||||
rev: 23.10.1
|
rev: 23.10.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
language_version: python3.13
|
language_version: python3.11
|
||||||
- repo: https://github.com/asottile/pyupgrade
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
rev: v3.15.0
|
rev: v3.15.0
|
||||||
hooks:
|
hooks:
|
||||||
|
|
@ -28,6 +28,10 @@ repos:
|
||||||
rev: v3.1.0
|
rev: v3.1.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: add-trailing-comma
|
- id: add-trailing-comma
|
||||||
|
- repo: https://github.com/hadialqattan/pycln
|
||||||
|
rev: v2.3.0
|
||||||
|
hooks:
|
||||||
|
- id: pycln
|
||||||
|
|
||||||
default_language_version:
|
default_language_version:
|
||||||
python: python3.13
|
python: python3.11
|
||||||
|
|
|
||||||
1
Pipfile
1
Pipfile
|
|
@ -6,7 +6,6 @@ name = "pypi"
|
||||||
[packages]
|
[packages]
|
||||||
orjson = "*"
|
orjson = "*"
|
||||||
"discord.py" = {extras = ["voice"], version = "*"}
|
"discord.py" = {extras = ["voice"], version = "*"}
|
||||||
websockets = "*"
|
|
||||||
|
|
||||||
[dev-packages]
|
[dev-packages]
|
||||||
mypy = "*"
|
mypy = "*"
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ Pomice
|
||||||
~~~~~~
|
~~~~~~
|
||||||
The modern Lavalink wrapper designed for discord.py.
|
The modern Lavalink wrapper designed for discord.py.
|
||||||
|
|
||||||
Copyright (c) 2024, cloudwithax
|
Copyright (c) 2023, cloudwithax
|
||||||
|
|
||||||
Licensed under GPL-3.0
|
Licensed under GPL-3.0
|
||||||
"""
|
"""
|
||||||
|
|
@ -20,7 +20,7 @@ if not discord.version_info.major >= 2:
|
||||||
"using 'pip install discord.py'",
|
"using 'pip install discord.py'",
|
||||||
)
|
)
|
||||||
|
|
||||||
__version__ = "2.10.0"
|
__version__ = "2.8.0"
|
||||||
__title__ = "pomice"
|
__title__ = "pomice"
|
||||||
__author__ = "cloudwithax"
|
__author__ = "cloudwithax"
|
||||||
__license__ = "GPL-3.0"
|
__license__ = "GPL-3.0"
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,11 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import AsyncGenerator
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing import List
|
from typing import List
|
||||||
from typing import Optional
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
@ -20,10 +17,10 @@ from .objects import *
|
||||||
__all__ = ("Client",)
|
__all__ = ("Client",)
|
||||||
|
|
||||||
AM_URL_REGEX = re.compile(
|
AM_URL_REGEX = re.compile(
|
||||||
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$",
|
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
|
||||||
)
|
)
|
||||||
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
|
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
|
||||||
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$",
|
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
|
||||||
)
|
)
|
||||||
|
|
||||||
AM_SCRIPT_REGEX = re.compile(r'<script.*?src="(/assets/index-.*?)"')
|
AM_SCRIPT_REGEX = re.compile(r'<script.*?src="(/assets/index-.*?)"')
|
||||||
|
|
@ -38,14 +35,12 @@ class Client:
|
||||||
and translating it to a valid Lavalink track. No client auth is required here.
|
and translating it to a valid Lavalink track. No client auth is required here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *, playlist_concurrency: int = 6) -> None:
|
def __init__(self) -> None:
|
||||||
self.expiry: datetime = datetime(1970, 1, 1)
|
self.expiry: datetime = datetime(1970, 1, 1)
|
||||||
self.token: str = ""
|
self.token: str = ""
|
||||||
self.headers: Dict[str, str] = {}
|
self.headers: Dict[str, str] = {}
|
||||||
self.session: aiohttp.ClientSession = None # type: ignore
|
self.session: aiohttp.ClientSession = None # type: ignore
|
||||||
self._log = logging.getLogger(__name__)
|
self._log = logging.getLogger(__name__)
|
||||||
# Concurrency knob for parallel playlist page retrieval
|
|
||||||
self._playlist_concurrency = max(1, playlist_concurrency)
|
|
||||||
|
|
||||||
async def _set_session(self, session: aiohttp.ClientSession) -> None:
|
async def _set_session(self, session: aiohttp.ClientSession) -> None:
|
||||||
self.session = session
|
self.session = session
|
||||||
|
|
@ -101,8 +96,7 @@ class Client:
|
||||||
).decode()
|
).decode()
|
||||||
token_data = json.loads(token_json)
|
token_data = json.loads(token_json)
|
||||||
self.expiry = datetime.fromtimestamp(token_data["exp"])
|
self.expiry = datetime.fromtimestamp(token_data["exp"])
|
||||||
if self._log:
|
self._log.debug(f"Fetched Apple Music bearer token successfully")
|
||||||
self._log.debug(f"Fetched Apple Music bearer token successfully")
|
|
||||||
|
|
||||||
async def search(self, query: str) -> Union[Album, Playlist, Song, Artist]:
|
async def search(self, query: str) -> Union[Album, Playlist, Song, Artist]:
|
||||||
if not self.token or datetime.utcnow() > self.expiry:
|
if not self.token or datetime.utcnow() > self.expiry:
|
||||||
|
|
@ -136,10 +130,9 @@ class Client:
|
||||||
)
|
)
|
||||||
|
|
||||||
data: dict = await resp.json(loads=json.loads)
|
data: dict = await resp.json(loads=json.loads)
|
||||||
if self._log:
|
self._log.debug(
|
||||||
self._log.debug(
|
f"Made request to Apple Music API with status {resp.status} and response {data}",
|
||||||
f"Made request to Apple Music API with status {resp.status} and response {data}",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
data = data["data"][0]
|
data = data["data"][0]
|
||||||
|
|
||||||
|
|
@ -172,127 +165,25 @@ class Client:
|
||||||
"This playlist is empty and therefore cannot be queued.",
|
"This playlist is empty and therefore cannot be queued.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apple Music uses cursor pagination with 'next'. We'll fetch subsequent pages
|
_next = track_data.get("next")
|
||||||
# concurrently by first collecting cursors in rolling waves.
|
if _next:
|
||||||
next_cursor = track_data.get("next")
|
next_page_url = AM_BASE_URL + _next
|
||||||
semaphore = asyncio.Semaphore(self._playlist_concurrency)
|
|
||||||
|
while next_page_url is not None:
|
||||||
|
resp = await self.session.get(next_page_url, headers=self.headers)
|
||||||
|
|
||||||
async def fetch_page(url: str) -> List[Song]:
|
|
||||||
async with semaphore:
|
|
||||||
resp = await self.session.get(url, headers=self.headers)
|
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
if self._log:
|
raise AppleMusicRequestException(
|
||||||
self._log.warning(
|
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||||
f"Apple Music page fetch failed {resp.status} {resp.reason} for {url}",
|
)
|
||||||
)
|
|
||||||
return []
|
|
||||||
pj: dict = await resp.json(loads=json.loads)
|
|
||||||
songs = [Song(track) for track in pj.get("data", [])]
|
|
||||||
# Return songs; we will look for pj.get('next') in streaming iterator variant
|
|
||||||
return songs, pj.get("next") # type: ignore
|
|
||||||
|
|
||||||
# We'll implement a wave-based approach similar to Spotify but need to follow cursors.
|
next_data: dict = await resp.json(loads=json.loads)
|
||||||
# Because we cannot know all cursors upfront, we'll iteratively fetch waves.
|
album_tracks.extend(Song(track) for track in next_data["data"])
|
||||||
waves: List[List[Song]] = []
|
|
||||||
cursors: List[str] = []
|
|
||||||
if next_cursor:
|
|
||||||
cursors.append(next_cursor)
|
|
||||||
|
|
||||||
# Limit total waves to avoid infinite loops in malformed responses
|
_next = next_data.get("next")
|
||||||
max_waves = 50
|
if _next:
|
||||||
wave_size = self._playlist_concurrency * 2
|
next_page_url = AM_BASE_URL + _next
|
||||||
wave_counter = 0
|
else:
|
||||||
while cursors and wave_counter < max_waves:
|
next_page_url = None
|
||||||
current = cursors[:wave_size]
|
|
||||||
cursors = cursors[wave_size:]
|
|
||||||
tasks = [
|
|
||||||
fetch_page(AM_BASE_URL + cursor) for cursor in current # type: ignore[arg-type]
|
|
||||||
]
|
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
||||||
for res in results:
|
|
||||||
if isinstance(res, tuple): # (songs, next)
|
|
||||||
songs, nxt = res
|
|
||||||
if songs:
|
|
||||||
waves.append(songs)
|
|
||||||
if nxt:
|
|
||||||
cursors.append(nxt)
|
|
||||||
wave_counter += 1
|
|
||||||
|
|
||||||
for w in waves:
|
|
||||||
album_tracks.extend(w)
|
|
||||||
|
|
||||||
return Playlist(data, album_tracks)
|
return Playlist(data, album_tracks)
|
||||||
|
|
||||||
async def iter_playlist_tracks(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
query: str,
|
|
||||||
batch_size: int = 100,
|
|
||||||
) -> AsyncGenerator[List[Song], None]:
|
|
||||||
"""Stream Apple Music playlist tracks in batches.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
query: str
|
|
||||||
Apple Music playlist URL.
|
|
||||||
batch_size: int
|
|
||||||
Logical grouping size for yielded batches.
|
|
||||||
"""
|
|
||||||
if not self.token or datetime.utcnow() > self.expiry:
|
|
||||||
await self.request_token()
|
|
||||||
|
|
||||||
result = AM_URL_REGEX.match(query)
|
|
||||||
if not result or result.group("type") != "playlist":
|
|
||||||
raise InvalidAppleMusicURL("Provided query is not a valid Apple Music playlist URL.")
|
|
||||||
|
|
||||||
country = result.group("country")
|
|
||||||
playlist_id = result.group("id")
|
|
||||||
request_url = AM_REQ_URL.format(country=country, type="playlist", id=playlist_id)
|
|
||||||
resp = await self.session.get(request_url, headers=self.headers)
|
|
||||||
if resp.status != 200:
|
|
||||||
raise AppleMusicRequestException(
|
|
||||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
|
||||||
)
|
|
||||||
data: dict = await resp.json(loads=json.loads)
|
|
||||||
playlist_data = data["data"][0]
|
|
||||||
track_data: dict = playlist_data["relationships"]["tracks"]
|
|
||||||
|
|
||||||
first_page_tracks = [Song(track) for track in track_data["data"]]
|
|
||||||
for i in range(0, len(first_page_tracks), batch_size):
|
|
||||||
yield first_page_tracks[i : i + batch_size]
|
|
||||||
|
|
||||||
next_cursor = track_data.get("next")
|
|
||||||
semaphore = asyncio.Semaphore(self._playlist_concurrency)
|
|
||||||
|
|
||||||
async def fetch(cursor: str) -> tuple[List[Song], Optional[str]]:
|
|
||||||
url = AM_BASE_URL + cursor
|
|
||||||
async with semaphore:
|
|
||||||
r = await self.session.get(url, headers=self.headers)
|
|
||||||
if r.status != 200:
|
|
||||||
if self._log:
|
|
||||||
self._log.warning(
|
|
||||||
f"Skipping Apple Music page due to {r.status} {r.reason}",
|
|
||||||
)
|
|
||||||
return [], None
|
|
||||||
pj: dict = await r.json(loads=json.loads)
|
|
||||||
songs = [Song(track) for track in pj.get("data", [])]
|
|
||||||
return songs, pj.get("next")
|
|
||||||
|
|
||||||
# Rolling waves of fetches following cursor chain
|
|
||||||
max_waves = 50
|
|
||||||
wave_size = self._playlist_concurrency * 2
|
|
||||||
waves = 0
|
|
||||||
cursors: List[str] = []
|
|
||||||
if next_cursor:
|
|
||||||
cursors.append(next_cursor)
|
|
||||||
while cursors and waves < max_waves:
|
|
||||||
current = cursors[:wave_size]
|
|
||||||
cursors = cursors[wave_size:]
|
|
||||||
results = await asyncio.gather(*[fetch(c) for c in current])
|
|
||||||
for songs, nxt in results:
|
|
||||||
if songs:
|
|
||||||
for j in range(0, len(songs), batch_size):
|
|
||||||
yield songs[j : j + batch_size]
|
|
||||||
if nxt:
|
|
||||||
cursors.append(nxt)
|
|
||||||
waves += 1
|
|
||||||
|
|
|
||||||
|
|
@ -34,11 +34,6 @@ class SearchType(Enum):
|
||||||
ytsearch = "ytsearch"
|
ytsearch = "ytsearch"
|
||||||
ytmsearch = "ytmsearch"
|
ytmsearch = "ytmsearch"
|
||||||
scsearch = "scsearch"
|
scsearch = "scsearch"
|
||||||
other = "other"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _missing_(cls, value: object) -> "SearchType": # type: ignore[override]
|
|
||||||
return cls.other
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
@ -59,8 +54,6 @@ class TrackType(Enum):
|
||||||
TrackType.HTTP defines that the track is from an HTTP source.
|
TrackType.HTTP defines that the track is from an HTTP source.
|
||||||
|
|
||||||
TrackType.LOCAL defines that the track is from a local source.
|
TrackType.LOCAL defines that the track is from a local source.
|
||||||
|
|
||||||
TrackType.OTHER defines that the track is from an unknown source (possible from 3rd-party plugins).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# We don't have to define anything special for these, since these just serve as flags
|
# We don't have to define anything special for these, since these just serve as flags
|
||||||
|
|
@ -70,11 +63,6 @@ class TrackType(Enum):
|
||||||
APPLE_MUSIC = "apple_music"
|
APPLE_MUSIC = "apple_music"
|
||||||
HTTP = "http"
|
HTTP = "http"
|
||||||
LOCAL = "local"
|
LOCAL = "local"
|
||||||
OTHER = "other"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _missing_(cls, value: object) -> "TrackType": # type: ignore[override]
|
|
||||||
return cls.OTHER
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
@ -91,8 +79,6 @@ class PlaylistType(Enum):
|
||||||
PlaylistType.SPOTIFY defines that the playlist is from Spotify
|
PlaylistType.SPOTIFY defines that the playlist is from Spotify
|
||||||
|
|
||||||
PlaylistType.APPLE_MUSIC defines that the playlist is from Apple Music.
|
PlaylistType.APPLE_MUSIC defines that the playlist is from Apple Music.
|
||||||
|
|
||||||
PlaylistType.OTHER defines that the playlist is from an unknown source (possible from 3rd-party plugins).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# We don't have to define anything special for these, since these just serve as flags
|
# We don't have to define anything special for these, since these just serve as flags
|
||||||
|
|
@ -100,11 +86,6 @@ class PlaylistType(Enum):
|
||||||
SOUNDCLOUD = "soundcloud"
|
SOUNDCLOUD = "soundcloud"
|
||||||
SPOTIFY = "spotify"
|
SPOTIFY = "spotify"
|
||||||
APPLE_MUSIC = "apple_music"
|
APPLE_MUSIC = "apple_music"
|
||||||
OTHER = "other"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _missing_(cls, value: object) -> "PlaylistType": # type: ignore[override]
|
|
||||||
return cls.OTHER
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
@ -218,12 +199,8 @@ class URLRegex:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Spotify share links can include query parameters like ?si=XXXX, a trailing slash,
|
|
||||||
# or an intl locale segment (e.g. /intl-en/). Broaden the regex so we still capture
|
|
||||||
# the type and id while ignoring extra parameters. This prevents the URL from being
|
|
||||||
# treated as a generic Lavalink identifier and ensures internal Spotify handling runs.
|
|
||||||
SPOTIFY_URL = re.compile(
|
SPOTIFY_URL = re.compile(
|
||||||
r"https?://open\.spotify\.com/(?:intl-[a-zA-Z-]+/)?(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)(?:/)?(?:\?.*)?$",
|
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
|
||||||
)
|
)
|
||||||
|
|
||||||
DISCORD_MP3_URL = re.compile(
|
DISCORD_MP3_URL = re.compile(
|
||||||
|
|
@ -244,17 +221,14 @@ class URLRegex:
|
||||||
r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*",
|
r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apple Music links sometimes append additional query parameters (e.g. &l=en, &uo=4).
|
|
||||||
# Allow arbitrary query parameters so valid links are captured and parsed.
|
|
||||||
AM_URL = re.compile(
|
AM_URL = re.compile(
|
||||||
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/"
|
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/"
|
||||||
r"(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$",
|
r"(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Single-in-album links may also carry extra query params beyond the ?i=<trackid> token.
|
|
||||||
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
|
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
|
||||||
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
|
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
|
||||||
r"(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$",
|
r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
|
||||||
)
|
)
|
||||||
|
|
||||||
SOUNDCLOUD_URL = re.compile(
|
SOUNDCLOUD_URL = re.compile(
|
||||||
|
|
@ -299,10 +273,3 @@ class LogLevel(IntEnum):
|
||||||
WARN = 30
|
WARN = 30
|
||||||
ERROR = 40
|
ERROR = 40
|
||||||
CRITICAL = 50
|
CRITICAL = 50
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_str(cls, level_str):
|
|
||||||
try:
|
|
||||||
return cls[level_str.upper()]
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError(f"No such log level: {level_str}")
|
|
||||||
|
|
|
||||||
|
|
@ -128,6 +128,7 @@ class TrackExceptionEvent(PomiceEvent):
|
||||||
|
|
||||||
def __init__(self, data: dict, player: Player):
|
def __init__(self, data: dict, player: Player):
|
||||||
self.player: Player = player
|
self.player: Player = player
|
||||||
|
assert self.player._ending_track is not None
|
||||||
self.track: Optional[Track] = self.player._ending_track
|
self.track: Optional[Track] = self.player._ending_track
|
||||||
# Error is for Lavalink <= 3.3
|
# Error is for Lavalink <= 3.3
|
||||||
self.exception: str = data.get(
|
self.exception: str = data.get(
|
||||||
|
|
|
||||||
|
|
@ -298,8 +298,7 @@ class Player(VoiceProtocol):
|
||||||
self._last_update = int(state.get("time", 0))
|
self._last_update = int(state.get("time", 0))
|
||||||
self._is_connected = bool(state.get("connected"))
|
self._is_connected = bool(state.get("connected"))
|
||||||
self._last_position = int(state.get("position", 0))
|
self._last_position = int(state.get("position", 0))
|
||||||
if self._log:
|
self._log.debug(f"Got player update state with data {state}")
|
||||||
self._log.debug(f"Got player update state with data {state}")
|
|
||||||
|
|
||||||
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None:
|
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None:
|
||||||
if {"sessionId", "event"} != self._voice_state.keys():
|
if {"sessionId", "event"} != self._voice_state.keys():
|
||||||
|
|
@ -320,10 +319,7 @@ class Player(VoiceProtocol):
|
||||||
data={"voice": data},
|
data={"voice": data},
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._log:
|
self._log.debug(f"Dispatched voice update to {state['event']['endpoint']} with data {data}")
|
||||||
self._log.debug(
|
|
||||||
f"Dispatched voice update to {state['event']['endpoint']} with data {data}",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
|
async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
|
||||||
self._voice_state.update({"event": data})
|
self._voice_state.update({"event": data})
|
||||||
|
|
@ -365,8 +361,7 @@ class Player(VoiceProtocol):
|
||||||
if isinstance(event, TrackStartEvent):
|
if isinstance(event, TrackStartEvent):
|
||||||
self._ending_track = self._current
|
self._ending_track = self._current
|
||||||
|
|
||||||
if self._log:
|
self._log.debug(f"Dispatched event {data['type']} to player.")
|
||||||
self._log.debug(f"Dispatched event {data['type']} to player.")
|
|
||||||
|
|
||||||
async def _refresh_endpoint_uri(self, session_id: Optional[str]) -> None:
|
async def _refresh_endpoint_uri(self, session_id: Optional[str]) -> None:
|
||||||
self._player_endpoint_uri = f"sessions/{session_id}/players"
|
self._player_endpoint_uri = f"sessions/{session_id}/players"
|
||||||
|
|
@ -388,15 +383,14 @@ class Player(VoiceProtocol):
|
||||||
data=data or None,
|
data=data or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._log:
|
self._log.debug(f"Swapped all players to new node {new_node._identifier}.")
|
||||||
self._log.debug(f"Swapped all players to new node {new_node._identifier}.")
|
|
||||||
|
|
||||||
async def get_tracks(
|
async def get_tracks(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
*,
|
*,
|
||||||
ctx: Optional[commands.Context] = None,
|
ctx: Optional[commands.Context] = None,
|
||||||
search_type: SearchType | None = SearchType.ytsearch,
|
search_type: SearchType = SearchType.ytsearch,
|
||||||
filters: Optional[List[Filter]] = None,
|
filters: Optional[List[Filter]] = None,
|
||||||
) -> Optional[Union[List[Track], Playlist]]:
|
) -> Optional[Union[List[Track], Playlist]]:
|
||||||
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
||||||
|
|
@ -462,8 +456,7 @@ class Player(VoiceProtocol):
|
||||||
data={"encodedTrack": None},
|
data={"encodedTrack": None},
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._log:
|
self._log.debug(f"Player has been stopped.")
|
||||||
self._log.debug(f"Player has been stopped.")
|
|
||||||
|
|
||||||
async def disconnect(self, *, force: bool = False) -> None:
|
async def disconnect(self, *, force: bool = False) -> None:
|
||||||
"""Disconnects the player from voice."""
|
"""Disconnects the player from voice."""
|
||||||
|
|
@ -491,8 +484,7 @@ class Player(VoiceProtocol):
|
||||||
guild_id=self._guild.id,
|
guild_id=self._guild.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._log:
|
self._log.debug("Player has been destroyed.")
|
||||||
self._log.debug("Player has been destroyed.")
|
|
||||||
|
|
||||||
async def play(
|
async def play(
|
||||||
self,
|
self,
|
||||||
|
|
@ -504,11 +496,8 @@ class Player(VoiceProtocol):
|
||||||
) -> Track:
|
) -> Track:
|
||||||
"""Plays a track. If a Spotify track is passed in, it will be handled accordingly."""
|
"""Plays a track. If a Spotify track is passed in, it will be handled accordingly."""
|
||||||
|
|
||||||
if not track._search_type:
|
|
||||||
track.original = track
|
|
||||||
|
|
||||||
# Make sure we've never searched the track before
|
# Make sure we've never searched the track before
|
||||||
if track._search_type and track.original is None:
|
if track.original is None:
|
||||||
# First lets try using the tracks ISRC, every track has one (hopefully)
|
# First lets try using the tracks ISRC, every track has one (hopefully)
|
||||||
try:
|
try:
|
||||||
if not track.isrc:
|
if not track.isrc:
|
||||||
|
|
@ -588,10 +577,9 @@ class Player(VoiceProtocol):
|
||||||
query=f"noReplace={ignore_if_playing}",
|
query=f"noReplace={ignore_if_playing}",
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._log:
|
self._log.debug(
|
||||||
self._log.debug(
|
f"Playing {track.title} from uri {track.uri} with a length of {track.length}",
|
||||||
f"Playing {track.title} from uri {track.uri} with a length of {track.length}",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return self._current
|
return self._current
|
||||||
|
|
||||||
|
|
@ -612,8 +600,7 @@ class Player(VoiceProtocol):
|
||||||
data={"position": position},
|
data={"position": position},
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._log:
|
self._log.debug(f"Seeking to {position}.")
|
||||||
self._log.debug(f"Seeking to {position}.")
|
|
||||||
return self.position
|
return self.position
|
||||||
|
|
||||||
async def set_pause(self, pause: bool) -> bool:
|
async def set_pause(self, pause: bool) -> bool:
|
||||||
|
|
@ -626,8 +613,7 @@ class Player(VoiceProtocol):
|
||||||
)
|
)
|
||||||
self._paused = pause
|
self._paused = pause
|
||||||
|
|
||||||
if self._log:
|
self._log.debug(f"Player has been {'paused' if pause else 'resumed'}.")
|
||||||
self._log.debug(f"Player has been {'paused' if pause else 'resumed'}.")
|
|
||||||
return self._paused
|
return self._paused
|
||||||
|
|
||||||
async def set_volume(self, volume: int) -> int:
|
async def set_volume(self, volume: int) -> int:
|
||||||
|
|
@ -640,8 +626,7 @@ class Player(VoiceProtocol):
|
||||||
)
|
)
|
||||||
self._volume = volume
|
self._volume = volume
|
||||||
|
|
||||||
if self._log:
|
self._log.debug(f"Player volume has been adjusted to {volume}")
|
||||||
self._log.debug(f"Player volume has been adjusted to {volume}")
|
|
||||||
return self._volume
|
return self._volume
|
||||||
|
|
||||||
async def move_to(self, channel: VoiceChannel) -> None:
|
async def move_to(self, channel: VoiceChannel) -> None:
|
||||||
|
|
@ -670,11 +655,9 @@ class Player(VoiceProtocol):
|
||||||
data={"filters": payload},
|
data={"filters": payload},
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._log:
|
self._log.debug(f"Filter has been applied to player with tag {_filter.tag}")
|
||||||
self._log.debug(f"Filter has been applied to player with tag {_filter.tag}")
|
|
||||||
if fast_apply:
|
if fast_apply:
|
||||||
if self._log:
|
self._log.debug(f"Fast apply passed, now applying filter instantly.")
|
||||||
self._log.debug(f"Fast apply passed, now applying filter instantly.")
|
|
||||||
await self.seek(self.position)
|
await self.seek(self.position)
|
||||||
|
|
||||||
return self._filters
|
return self._filters
|
||||||
|
|
@ -695,11 +678,9 @@ class Player(VoiceProtocol):
|
||||||
guild_id=self._guild.id,
|
guild_id=self._guild.id,
|
||||||
data={"filters": payload},
|
data={"filters": payload},
|
||||||
)
|
)
|
||||||
if self._log:
|
self._log.debug(f"Filter has been removed from player with tag {filter_tag}")
|
||||||
self._log.debug(f"Filter has been removed from player with tag {filter_tag}")
|
|
||||||
if fast_apply:
|
if fast_apply:
|
||||||
if self._log:
|
self._log.debug(f"Fast apply passed, now removing filter instantly.")
|
||||||
self._log.debug(f"Fast apply passed, now removing filter instantly.")
|
|
||||||
await self.seek(self.position)
|
await self.seek(self.position)
|
||||||
|
|
||||||
return self._filters
|
return self._filters
|
||||||
|
|
@ -728,11 +709,9 @@ class Player(VoiceProtocol):
|
||||||
guild_id=self._guild.id,
|
guild_id=self._guild.id,
|
||||||
data={"filters": payload},
|
data={"filters": payload},
|
||||||
)
|
)
|
||||||
if self._log:
|
self._log.debug(f"Filter with tag {filter_tag} has been edited to {edited_filter!r}")
|
||||||
self._log.debug(f"Filter with tag {filter_tag} has been edited to {edited_filter!r}")
|
|
||||||
if fast_apply:
|
if fast_apply:
|
||||||
if self._log:
|
self._log.debug(f"Fast apply passed, now editing filter instantly.")
|
||||||
self._log.debug(f"Fast apply passed, now editing filter instantly.")
|
|
||||||
await self.seek(self.position)
|
await self.seek(self.position)
|
||||||
|
|
||||||
return self._filters
|
return self._filters
|
||||||
|
|
@ -756,10 +735,8 @@ class Player(VoiceProtocol):
|
||||||
guild_id=self._guild.id,
|
guild_id=self._guild.id,
|
||||||
data={"filters": {}},
|
data={"filters": {}},
|
||||||
)
|
)
|
||||||
if self._log:
|
self._log.debug(f"All filters have been removed from player.")
|
||||||
self._log.debug(f"All filters have been removed from player.")
|
|
||||||
|
|
||||||
if fast_apply:
|
if fast_apply:
|
||||||
if self._log:
|
self._log.debug(f"Fast apply passed, now removing all filters instantly.")
|
||||||
self._log.debug(f"Fast apply passed, now removing all filters instantly.")
|
|
||||||
await self.seek(self.position)
|
await self.seek(self.position)
|
||||||
|
|
|
||||||
276
pomice/pool.py
276
pomice/pool.py
|
|
@ -21,12 +21,7 @@ import orjson as json
|
||||||
from discord import Client
|
from discord import Client
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from discord.utils import MISSING
|
from discord.utils import MISSING
|
||||||
|
from websockets import client
|
||||||
try:
|
|
||||||
from websockets.legacy import client # websockets >= 10.0
|
|
||||||
except ImportError:
|
|
||||||
import websockets.client as client # websockets < 10.0 # type: ignore
|
|
||||||
|
|
||||||
from websockets import exceptions
|
from websockets import exceptions
|
||||||
from websockets import typing as wstype
|
from websockets import typing as wstype
|
||||||
|
|
||||||
|
|
@ -35,7 +30,6 @@ from . import applemusic
|
||||||
from . import spotify
|
from . import spotify
|
||||||
from .enums import *
|
from .enums import *
|
||||||
from .enums import LogLevel
|
from .enums import LogLevel
|
||||||
from .exceptions import InvalidSpotifyClientAuthorization
|
|
||||||
from .exceptions import LavalinkVersionIncompatible
|
from .exceptions import LavalinkVersionIncompatible
|
||||||
from .exceptions import NodeConnectionFailure
|
from .exceptions import NodeConnectionFailure
|
||||||
from .exceptions import NodeCreationError
|
from .exceptions import NodeCreationError
|
||||||
|
|
@ -253,8 +247,7 @@ class Node:
|
||||||
int(_version_groups[2] or 0),
|
int(_version_groups[2] or 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._log:
|
self._log.debug(f"Parsed Lavalink version: {major}.{minor}.{fix}")
|
||||||
self._log.debug(f"Parsed Lavalink version: {major}.{minor}.{fix}")
|
|
||||||
self._version = LavalinkVersion(major=major, minor=minor, fix=fix)
|
self._version = LavalinkVersion(major=major, minor=minor, fix=fix)
|
||||||
if self._version < LavalinkVersion(3, 7, 0):
|
if self._version < LavalinkVersion(3, 7, 0):
|
||||||
self._available = False
|
self._available = False
|
||||||
|
|
@ -308,13 +301,12 @@ class Node:
|
||||||
if not self._resume_key:
|
if not self._resume_key:
|
||||||
return
|
return
|
||||||
|
|
||||||
data: Dict[str, Union[int, str, bool]] = {"timeout": self._resume_timeout}
|
data = {"timeout": self._resume_timeout}
|
||||||
|
|
||||||
if self._version.major == 3:
|
if self._version.major == 3:
|
||||||
data["resumingKey"] = self._resume_key
|
data["resumingKey"] = self._resume_key
|
||||||
elif self._version.major == 4:
|
elif self._version.major == 4:
|
||||||
if self._log:
|
self._log.warning("Using a resume key with Lavalink v4 is deprecated.")
|
||||||
self._log.warning("Using a resume key with Lavalink v4 is deprecated.")
|
|
||||||
data["resuming"] = True
|
data["resuming"] = True
|
||||||
|
|
||||||
await self.send(
|
await self.send(
|
||||||
|
|
@ -329,8 +321,7 @@ class Node:
|
||||||
try:
|
try:
|
||||||
msg = await self._websocket.recv()
|
msg = await self._websocket.recv()
|
||||||
data = json.loads(msg)
|
data = json.loads(msg)
|
||||||
if self._log:
|
self._log.debug(f"Recieved raw websocket message {msg}")
|
||||||
self._log.debug(f"Recieved raw websocket message {msg}")
|
|
||||||
self._loop.create_task(self._handle_ws_msg(data=data))
|
self._loop.create_task(self._handle_ws_msg(data=data))
|
||||||
except exceptions.ConnectionClosed:
|
except exceptions.ConnectionClosed:
|
||||||
if self.player_count > 0:
|
if self.player_count > 0:
|
||||||
|
|
@ -344,18 +335,14 @@ class Node:
|
||||||
|
|
||||||
backoff = ExponentialBackoff(base=7)
|
backoff = ExponentialBackoff(base=7)
|
||||||
retry = backoff.delay()
|
retry = backoff.delay()
|
||||||
if self._log:
|
self._log.debug(f"Retrying connection to Node {self._identifier} in {retry} secs")
|
||||||
self._log.debug(
|
|
||||||
f"Retrying connection to Node {self._identifier} in {retry} secs",
|
|
||||||
)
|
|
||||||
await asyncio.sleep(retry)
|
await asyncio.sleep(retry)
|
||||||
|
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
self._loop.create_task(self.connect(reconnect=True))
|
self._loop.create_task(self.connect(reconnect=True))
|
||||||
|
|
||||||
async def _handle_ws_msg(self, data: dict) -> None:
|
async def _handle_ws_msg(self, data: dict) -> None:
|
||||||
if self._log:
|
self._log.debug(f"Recieved raw payload from Node {self._identifier} with data {data}")
|
||||||
self._log.debug(f"Recieved raw payload from Node {self._identifier} with data {data}")
|
|
||||||
op = data.get("op", None)
|
op = data.get("op", None)
|
||||||
|
|
||||||
if op == "stats":
|
if op == "stats":
|
||||||
|
|
@ -408,10 +395,9 @@ class Node:
|
||||||
headers=self._headers,
|
headers=self._headers,
|
||||||
json=data or {},
|
json=data or {},
|
||||||
)
|
)
|
||||||
if self._log:
|
self._log.debug(
|
||||||
self._log.debug(
|
f"Making REST request to Node {self._identifier} with method {method} to {uri}",
|
||||||
f"Making REST request to Node {self._identifier} with method {method} to {uri}",
|
)
|
||||||
)
|
|
||||||
if resp.status >= 300:
|
if resp.status >= 300:
|
||||||
resp_data: dict = await resp.json()
|
resp_data: dict = await resp.json()
|
||||||
raise NodeRestException(
|
raise NodeRestException(
|
||||||
|
|
@ -419,23 +405,20 @@ class Node:
|
||||||
)
|
)
|
||||||
|
|
||||||
if method == "DELETE" or resp.status == 204:
|
if method == "DELETE" or resp.status == 204:
|
||||||
if self._log:
|
self._log.debug(
|
||||||
self._log.debug(
|
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned no data.",
|
||||||
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned no data.",
|
)
|
||||||
)
|
|
||||||
return await resp.json(content_type=None)
|
return await resp.json(content_type=None)
|
||||||
|
|
||||||
if resp.content_type == "text/plain":
|
if resp.content_type == "text/plain":
|
||||||
if self._log:
|
self._log.debug(
|
||||||
self._log.debug(
|
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned text with body {await resp.text()}",
|
||||||
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned text with body {await resp.text()}",
|
)
|
||||||
)
|
|
||||||
return await resp.text()
|
return await resp.text()
|
||||||
|
|
||||||
if self._log:
|
self._log.debug(
|
||||||
self._log.debug(
|
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned JSON with body {await resp.json()}",
|
||||||
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned JSON with body {await resp.json()}",
|
)
|
||||||
)
|
|
||||||
return await resp.json()
|
return await resp.json()
|
||||||
|
|
||||||
def get_player(self, guild_id: int) -> Optional[Player]:
|
def get_player(self, guild_id: int) -> Optional[Player]:
|
||||||
|
|
@ -449,17 +432,7 @@ class Node:
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
|
|
||||||
if not self._session:
|
if not self._session:
|
||||||
# Configure connection pooling for optimal concurrent request performance
|
self._session = aiohttp.ClientSession()
|
||||||
connector = aiohttp.TCPConnector(
|
|
||||||
limit=100, # Total connection limit
|
|
||||||
limit_per_host=30, # Per-host connection limit
|
|
||||||
ttl_dns_cache=300, # DNS cache TTL in seconds
|
|
||||||
)
|
|
||||||
timeout = aiohttp.ClientTimeout(total=30, connect=10)
|
|
||||||
self._session = aiohttp.ClientSession(
|
|
||||||
connector=connector,
|
|
||||||
timeout=timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not reconnect:
|
if not reconnect:
|
||||||
|
|
@ -473,28 +446,25 @@ class Node:
|
||||||
await self._handle_version_check(version=version)
|
await self._handle_version_check(version=version)
|
||||||
await self._set_ext_client_session(session=self._session)
|
await self._set_ext_client_session(session=self._session)
|
||||||
|
|
||||||
if self._log:
|
self._log.debug(
|
||||||
self._log.debug(
|
f"Version check from Node {self._identifier} successful. Returned version {version}",
|
||||||
f"Version check from Node {self._identifier} successful. Returned version {version}",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
self._websocket = await client.connect( # type: ignore
|
self._websocket = await client.connect(
|
||||||
f"{self._websocket_uri}/v{self._version.major}/websocket",
|
f"{self._websocket_uri}/v{self._version.major}/websocket",
|
||||||
extra_headers=self._headers,
|
extra_headers=self._headers,
|
||||||
ping_interval=self._heartbeat,
|
ping_interval=self._heartbeat,
|
||||||
)
|
)
|
||||||
|
|
||||||
if reconnect:
|
if reconnect:
|
||||||
if self._log:
|
self._log.debug(f"Trying to reconnect to Node {self._identifier}...")
|
||||||
self._log.debug(f"Trying to reconnect to Node {self._identifier}...")
|
|
||||||
if self.player_count:
|
if self.player_count:
|
||||||
for player in self.players.values():
|
for player in self.players.values():
|
||||||
await player._refresh_endpoint_uri(self._session_id)
|
await player._refresh_endpoint_uri(self._session_id)
|
||||||
|
|
||||||
if self._log:
|
self._log.debug(
|
||||||
self._log.debug(
|
f"Node {self._identifier} successfully connected to websocket using {self._websocket_uri}/v{self._version.major}/websocket",
|
||||||
f"Node {self._identifier} successfully connected to websocket using {self._websocket_uri}/v{self._version.major}/websocket",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if not self._task:
|
if not self._task:
|
||||||
self._task = self._loop.create_task(self._listen())
|
self._task = self._loop.create_task(self._listen())
|
||||||
|
|
@ -503,8 +473,7 @@ class Node:
|
||||||
|
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
|
|
||||||
if self._log:
|
self._log.info(f"Connected to node {self._identifier}. Took {end - start:.3f}s")
|
||||||
self._log.info(f"Connected to node {self._identifier}. Took {end - start:.3f}s")
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
except (aiohttp.ClientConnectorError, OSError, ConnectionRefusedError):
|
except (aiohttp.ClientConnectorError, OSError, ConnectionRefusedError):
|
||||||
|
|
@ -529,23 +498,20 @@ class Node:
|
||||||
|
|
||||||
for player in self.players.copy().values():
|
for player in self.players.copy().values():
|
||||||
await player.destroy()
|
await player.destroy()
|
||||||
if self._log:
|
self._log.debug("All players disconnected from node.")
|
||||||
self._log.debug("All players disconnected from node.")
|
|
||||||
|
|
||||||
await self._websocket.close()
|
await self._websocket.close()
|
||||||
await self._session.close()
|
await self._session.close()
|
||||||
if self._log:
|
self._log.debug("Websocket and http session closed.")
|
||||||
self._log.debug("Websocket and http session closed.")
|
|
||||||
|
|
||||||
del self._pool._nodes[self._identifier]
|
del self._pool._nodes[self._identifier]
|
||||||
self.available = False
|
self.available = False
|
||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
|
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
if self._log:
|
self._log.info(
|
||||||
self._log.info(
|
f"Successfully disconnected from node {self._identifier} and closed all sessions. Took {end - start:.3f}s",
|
||||||
f"Successfully disconnected from node {self._identifier} and closed all sessions. Took {end - start:.3f}s",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
async def build_track(self, identifier: str, ctx: Optional[commands.Context] = None) -> Track:
|
async def build_track(self, identifier: str, ctx: Optional[commands.Context] = None) -> Track:
|
||||||
"""
|
"""
|
||||||
|
|
@ -560,14 +526,11 @@ class Node:
|
||||||
path="decodetrack",
|
path="decodetrack",
|
||||||
query=f"encodedTrack={quote(identifier)}",
|
query=f"encodedTrack={quote(identifier)}",
|
||||||
)
|
)
|
||||||
|
|
||||||
track_info = data["info"] if self._version.major >= 4 else data
|
|
||||||
|
|
||||||
return Track(
|
return Track(
|
||||||
track_id=identifier,
|
track_id=identifier,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
info=track_info,
|
info=data,
|
||||||
track_type=TrackType(track_info["sourceName"]),
|
track_type=TrackType(data["sourceName"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_tracks(
|
async def get_tracks(
|
||||||
|
|
@ -575,7 +538,7 @@ class Node:
|
||||||
query: str,
|
query: str,
|
||||||
*,
|
*,
|
||||||
ctx: Optional[commands.Context] = None,
|
ctx: Optional[commands.Context] = None,
|
||||||
search_type: Optional[SearchType] = SearchType.ytsearch,
|
search_type: SearchType = SearchType.ytsearch,
|
||||||
filters: Optional[List[Filter]] = None,
|
filters: Optional[List[Filter]] = None,
|
||||||
) -> Optional[Union[Playlist, List[Track]]]:
|
) -> Optional[Union[Playlist, List[Track]]]:
|
||||||
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
||||||
|
|
@ -610,7 +573,7 @@ class Node:
|
||||||
track_id=apple_music_results.id,
|
track_id=apple_music_results.id,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
track_type=TrackType.APPLE_MUSIC,
|
track_type=TrackType.APPLE_MUSIC,
|
||||||
search_type=search_type or SearchType.ytsearch,
|
search_type=search_type,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
info={
|
info={
|
||||||
"title": apple_music_results.name,
|
"title": apple_music_results.name,
|
||||||
|
|
@ -632,7 +595,7 @@ class Node:
|
||||||
track_id=track.id,
|
track_id=track.id,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
track_type=TrackType.APPLE_MUSIC,
|
track_type=TrackType.APPLE_MUSIC,
|
||||||
search_type=search_type or SearchType.ytsearch,
|
search_type=search_type,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
info={
|
info={
|
||||||
"title": track.name,
|
"title": track.name,
|
||||||
|
|
@ -670,7 +633,7 @@ class Node:
|
||||||
track_id=spotify_results.id,
|
track_id=spotify_results.id,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
track_type=TrackType.SPOTIFY,
|
track_type=TrackType.SPOTIFY,
|
||||||
search_type=search_type or SearchType.ytsearch,
|
search_type=search_type,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
info={
|
info={
|
||||||
"title": spotify_results.name,
|
"title": spotify_results.name,
|
||||||
|
|
@ -692,7 +655,7 @@ class Node:
|
||||||
track_id=track.id,
|
track_id=track.id,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
track_type=TrackType.SPOTIFY,
|
track_type=TrackType.SPOTIFY,
|
||||||
search_type=search_type or SearchType.ytsearch,
|
search_type=search_type,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
info={
|
info={
|
||||||
"title": track.name,
|
"title": track.name,
|
||||||
|
|
@ -721,14 +684,63 @@ class Node:
|
||||||
uri=spotify_results.uri,
|
uri=spotify_results.uri,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
|
||||||
|
data: dict = await self.send(
|
||||||
|
method="GET",
|
||||||
|
path="loadtracks",
|
||||||
|
query=f"identifier={quote(query)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
track: dict = data["tracks"][0]
|
||||||
|
info: dict = track["info"]
|
||||||
|
|
||||||
|
return [
|
||||||
|
Track(
|
||||||
|
track_id=track["track"],
|
||||||
|
info={
|
||||||
|
"title": discord_url.group("file"),
|
||||||
|
"author": "Unknown",
|
||||||
|
"length": info["length"],
|
||||||
|
"uri": info["uri"],
|
||||||
|
"position": info["position"],
|
||||||
|
"identifier": info["identifier"],
|
||||||
|
},
|
||||||
|
ctx=ctx,
|
||||||
|
track_type=TrackType.HTTP,
|
||||||
|
filters=filters,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
elif path.exists(path.dirname(query)):
|
||||||
|
local_file = Path(query)
|
||||||
|
data: dict = await self.send( # type: ignore
|
||||||
|
method="GET",
|
||||||
|
path="loadtracks",
|
||||||
|
query=f"identifier={quote(query)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
track: dict = data["tracks"][0] # type: ignore
|
||||||
|
info: dict = track["info"] # type: ignore
|
||||||
|
|
||||||
|
return [
|
||||||
|
Track(
|
||||||
|
track_id=track["track"],
|
||||||
|
info={
|
||||||
|
"title": local_file.name,
|
||||||
|
"author": "Unknown",
|
||||||
|
"length": info["length"],
|
||||||
|
"uri": quote(local_file.as_uri()),
|
||||||
|
"position": info["position"],
|
||||||
|
"identifier": info["identifier"],
|
||||||
|
},
|
||||||
|
ctx=ctx,
|
||||||
|
track_type=TrackType.LOCAL,
|
||||||
|
filters=filters,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if (
|
if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query):
|
||||||
search_type
|
|
||||||
and not URLRegex.BASE_URL.match(query)
|
|
||||||
and not re.match(r"(?:[a-z]+?)search:.", query)
|
|
||||||
and not URLRegex.DISCORD_MP3_URL.match(query)
|
|
||||||
and not path.exists(path.dirname(query))
|
|
||||||
):
|
|
||||||
query = f"{search_type}:{query}"
|
query = f"{search_type}:{query}"
|
||||||
|
|
||||||
# If YouTube url contains a timestamp, capture it for use later.
|
# If YouTube url contains a timestamp, capture it for use later.
|
||||||
|
|
@ -754,7 +766,7 @@ class Node:
|
||||||
)
|
)
|
||||||
|
|
||||||
elif load_type in ("LOAD_FAILED", "error"):
|
elif load_type in ("LOAD_FAILED", "error"):
|
||||||
exception = data["data"] if self._version.major >= 4 else data["exception"]
|
exception = data["exception"]
|
||||||
raise TrackLoadError(
|
raise TrackLoadError(
|
||||||
f"{exception['message']} [{exception['severity']}]",
|
f"{exception['message']} [{exception['severity']}]",
|
||||||
)
|
)
|
||||||
|
|
@ -789,47 +801,6 @@ class Node:
|
||||||
elif load_type in ("SEARCH_RESULT", "TRACK_LOADED", "track", "search"):
|
elif load_type in ("SEARCH_RESULT", "TRACK_LOADED", "track", "search"):
|
||||||
if self._version.major >= 4 and isinstance(data[data_type], dict):
|
if self._version.major >= 4 and isinstance(data[data_type], dict):
|
||||||
data[data_type] = [data[data_type]]
|
data[data_type] = [data[data_type]]
|
||||||
|
|
||||||
if path.exists(path.dirname(query)):
|
|
||||||
local_file = Path(query)
|
|
||||||
|
|
||||||
return [
|
|
||||||
Track(
|
|
||||||
track_id=track["encoded"],
|
|
||||||
info={
|
|
||||||
"title": local_file.name,
|
|
||||||
"author": "Unknown",
|
|
||||||
"length": track["info"]["length"],
|
|
||||||
"uri": quote(local_file.as_uri()),
|
|
||||||
"position": track["info"]["position"],
|
|
||||||
"identifier": track["info"]["identifier"],
|
|
||||||
},
|
|
||||||
ctx=ctx,
|
|
||||||
track_type=TrackType.LOCAL,
|
|
||||||
filters=filters,
|
|
||||||
)
|
|
||||||
for track in data[data_type]
|
|
||||||
]
|
|
||||||
|
|
||||||
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
|
|
||||||
return [
|
|
||||||
Track(
|
|
||||||
track_id=track["encoded"],
|
|
||||||
info={
|
|
||||||
"title": discord_url.group("file"),
|
|
||||||
"author": "Unknown",
|
|
||||||
"length": track["info"]["length"],
|
|
||||||
"uri": track["info"]["uri"],
|
|
||||||
"position": track["info"]["position"],
|
|
||||||
"identifier": track["info"]["identifier"],
|
|
||||||
},
|
|
||||||
ctx=ctx,
|
|
||||||
track_type=TrackType.HTTP,
|
|
||||||
filters=filters,
|
|
||||||
)
|
|
||||||
for track in data[data_type]
|
|
||||||
]
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
Track(
|
Track(
|
||||||
track_id=track["encoded"],
|
track_id=track["encoded"],
|
||||||
|
|
@ -896,57 +867,6 @@ class Node:
|
||||||
"The specfied track must be either a YouTube or Spotify track to recieve recommendations.",
|
"The specfied track must be either a YouTube or Spotify track to recieve recommendations.",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def search_spotify_recommendations(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
ctx: Optional[commands.Context] = None,
|
|
||||||
filters: Optional[List[Filter]] = None,
|
|
||||||
) -> Optional[Union[List[Track], Playlist]]:
|
|
||||||
"""
|
|
||||||
Searches for recommendations on Spotify and returns a list of tracks based on the query.
|
|
||||||
You must have Spotify enabled for this to work.
|
|
||||||
You can pass in a discord.py Context object to get a
|
|
||||||
Context object on all tracks that get recommended.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not self._spotify_client:
|
|
||||||
raise InvalidSpotifyClientAuthorization(
|
|
||||||
"You must have Spotify enabled to use this feature.",
|
|
||||||
)
|
|
||||||
|
|
||||||
results = await self._spotify_client.track_search(query=query) # type: ignore
|
|
||||||
if not results:
|
|
||||||
raise TrackLoadError(
|
|
||||||
"Unable to find any tracks based on the query.",
|
|
||||||
)
|
|
||||||
|
|
||||||
tracks = [
|
|
||||||
Track(
|
|
||||||
track_id=track.id,
|
|
||||||
ctx=ctx,
|
|
||||||
track_type=TrackType.SPOTIFY,
|
|
||||||
info={
|
|
||||||
"title": track.name,
|
|
||||||
"author": track.artists,
|
|
||||||
"length": track.length,
|
|
||||||
"identifier": track.id,
|
|
||||||
"uri": track.uri,
|
|
||||||
"isStream": False,
|
|
||||||
"isSeekable": True,
|
|
||||||
"position": 0,
|
|
||||||
"thumbnail": track.image,
|
|
||||||
"isrc": track.isrc,
|
|
||||||
},
|
|
||||||
requester=self.bot.user,
|
|
||||||
)
|
|
||||||
for track in results
|
|
||||||
]
|
|
||||||
|
|
||||||
track = tracks[0]
|
|
||||||
|
|
||||||
return await self.get_recommendations(track=track, ctx=ctx)
|
|
||||||
|
|
||||||
|
|
||||||
class NodePool:
|
class NodePool:
|
||||||
"""The base class for the node pool.
|
"""The base class for the node pool.
|
||||||
|
|
|
||||||
|
|
@ -203,13 +203,9 @@ class Queue(Iterable[Track]):
|
||||||
raise QueueEmpty("No items in the queue.")
|
raise QueueEmpty("No items in the queue.")
|
||||||
|
|
||||||
if self._loop_mode == LoopMode.QUEUE:
|
if self._loop_mode == LoopMode.QUEUE:
|
||||||
# set current item to first track in queue if not set already
|
# recurse if the item isnt in the queue
|
||||||
# otherwise exception will be raised
|
if self._current_item not in self._queue:
|
||||||
if not self._current_item or self._current_item not in self._queue:
|
self.get()
|
||||||
if self._queue:
|
|
||||||
item = self._queue[0]
|
|
||||||
else:
|
|
||||||
raise QueueEmpty("No items in the queue.")
|
|
||||||
|
|
||||||
# set current item to first track in queue if not set already
|
# set current item to first track in queue if not set already
|
||||||
if not self._current_item:
|
if not self._current_item:
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,13 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
from typing import AsyncGenerator
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing import List
|
from typing import List
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from urllib.parse import quote
|
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import orjson as json
|
import orjson as json
|
||||||
|
|
@ -24,10 +21,8 @@ __all__ = ("Client",)
|
||||||
|
|
||||||
GRANT_URL = "https://accounts.spotify.com/api/token"
|
GRANT_URL = "https://accounts.spotify.com/api/token"
|
||||||
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
|
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
|
||||||
# Keep this in sync with URLRegex.SPOTIFY_URL (enums.py). Accept intl locale segment,
|
|
||||||
# optional trailing slash, and query parameters.
|
|
||||||
SPOTIFY_URL_REGEX = re.compile(
|
SPOTIFY_URL_REGEX = re.compile(
|
||||||
r"https?://open\.spotify\.com/(?:intl-[a-zA-Z-]+/)?(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)(?:/)?(?:\?.*)?$",
|
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -37,39 +32,29 @@ class Client:
|
||||||
for any Spotify URL you throw at it.
|
for any Spotify URL you throw at it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, client_id: str, client_secret: str) -> None:
|
||||||
self,
|
self._client_id: str = client_id
|
||||||
client_id: str,
|
self._client_secret: str = client_secret
|
||||||
client_secret: str,
|
|
||||||
*,
|
|
||||||
playlist_concurrency: int = 10,
|
|
||||||
playlist_page_limit: Optional[int] = None,
|
|
||||||
) -> None:
|
|
||||||
self._client_id = client_id
|
|
||||||
self._client_secret = client_secret
|
|
||||||
|
|
||||||
# HTTP session will be injected by Node
|
self.session: aiohttp.ClientSession = None # type: ignore
|
||||||
self.session: Optional[aiohttp.ClientSession] = None
|
|
||||||
|
|
||||||
self._bearer_token: Optional[str] = None
|
self._bearer_token: Optional[str] = None
|
||||||
self._expiry: float = 0.0
|
self._expiry: float = 0.0
|
||||||
self._auth_token = b64encode(f"{self._client_id}:{self._client_secret}".encode())
|
self._auth_token = b64encode(
|
||||||
self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"}
|
f"{self._client_id}:{self._client_secret}".encode(),
|
||||||
|
)
|
||||||
|
self._grant_headers = {
|
||||||
|
"Authorization": f"Basic {self._auth_token.decode()}",
|
||||||
|
}
|
||||||
self._bearer_headers: Optional[Dict] = None
|
self._bearer_headers: Optional[Dict] = None
|
||||||
self._log = logging.getLogger(__name__)
|
self._log = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Performance tuning knobs
|
|
||||||
self._playlist_concurrency = max(1, playlist_concurrency)
|
|
||||||
self._playlist_page_limit = playlist_page_limit
|
|
||||||
|
|
||||||
async def _set_session(self, session: aiohttp.ClientSession) -> None:
|
async def _set_session(self, session: aiohttp.ClientSession) -> None:
|
||||||
self.session = session
|
self.session = session
|
||||||
|
|
||||||
async def _fetch_bearer_token(self) -> None:
|
async def _fetch_bearer_token(self) -> None:
|
||||||
_data = {"grant_type": "client_credentials"}
|
_data = {"grant_type": "client_credentials"}
|
||||||
|
|
||||||
if not self.session:
|
|
||||||
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
|
||||||
resp = await self.session.post(GRANT_URL, data=_data, headers=self._grant_headers)
|
resp = await self.session.post(GRANT_URL, data=_data, headers=self._grant_headers)
|
||||||
|
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
|
|
@ -78,8 +63,7 @@ class Client:
|
||||||
)
|
)
|
||||||
|
|
||||||
data: dict = await resp.json(loads=json.loads)
|
data: dict = await resp.json(loads=json.loads)
|
||||||
if self._log:
|
self._log.debug(f"Fetched Spotify bearer token successfully")
|
||||||
self._log.debug(f"Fetched Spotify bearer token successfully")
|
|
||||||
|
|
||||||
self._bearer_token = data["access_token"]
|
self._bearer_token = data["access_token"]
|
||||||
self._expiry = time.time() + (int(data["expires_in"]) - 10)
|
self._expiry = time.time() + (int(data["expires_in"]) - 10)
|
||||||
|
|
@ -100,8 +84,6 @@ class Client:
|
||||||
|
|
||||||
request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id)
|
request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id)
|
||||||
|
|
||||||
if not self.session:
|
|
||||||
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
|
||||||
resp = await self.session.get(request_url, headers=self._bearer_headers)
|
resp = await self.session.get(request_url, headers=self._bearer_headers)
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
raise SpotifyRequestException(
|
raise SpotifyRequestException(
|
||||||
|
|
@ -109,18 +91,15 @@ class Client:
|
||||||
)
|
)
|
||||||
|
|
||||||
data: dict = await resp.json(loads=json.loads)
|
data: dict = await resp.json(loads=json.loads)
|
||||||
if self._log:
|
self._log.debug(
|
||||||
self._log.debug(
|
f"Made request to Spotify API with status {resp.status} and response {data}",
|
||||||
f"Made request to Spotify API with status {resp.status} and response {data}",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if spotify_type == "track":
|
if spotify_type == "track":
|
||||||
return Track(data)
|
return Track(data)
|
||||||
elif spotify_type == "album":
|
elif spotify_type == "album":
|
||||||
return Album(data)
|
return Album(data)
|
||||||
elif spotify_type == "artist":
|
elif spotify_type == "artist":
|
||||||
if not self.session:
|
|
||||||
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
|
||||||
resp = await self.session.get(
|
resp = await self.session.get(
|
||||||
f"{request_url}/top-tracks?market=US",
|
f"{request_url}/top-tracks?market=US",
|
||||||
headers=self._bearer_headers,
|
headers=self._bearer_headers,
|
||||||
|
|
@ -134,178 +113,37 @@ class Client:
|
||||||
tracks = track_data["tracks"]
|
tracks = track_data["tracks"]
|
||||||
return Artist(data, tracks)
|
return Artist(data, tracks)
|
||||||
else:
|
else:
|
||||||
# For playlists we optionally use a reduced fields payload to shrink response sizes.
|
|
||||||
# NB: We cannot apply fields filter to initial request because original metadata is needed.
|
|
||||||
tracks = [
|
tracks = [
|
||||||
Track(track["track"])
|
Track(track["track"])
|
||||||
for track in data["tracks"]["items"]
|
for track in data["tracks"]["items"]
|
||||||
if track["track"] is not None
|
if track["track"] is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
if not tracks:
|
if not len(tracks):
|
||||||
raise SpotifyRequestException(
|
raise SpotifyRequestException(
|
||||||
"This playlist is empty and therefore cannot be queued.",
|
"This playlist is empty and therefore cannot be queued.",
|
||||||
)
|
)
|
||||||
|
|
||||||
total_tracks = data["tracks"]["total"]
|
next_page_url = data["tracks"]["next"]
|
||||||
limit = data["tracks"]["limit"]
|
|
||||||
|
|
||||||
# Short‑circuit small playlists (single page)
|
while next_page_url is not None:
|
||||||
if total_tracks <= limit:
|
resp = await self.session.get(next_page_url, headers=self._bearer_headers)
|
||||||
return Playlist(data, tracks)
|
if resp.status != 200:
|
||||||
|
raise SpotifyRequestException(
|
||||||
# Build remaining page URLs; Spotify supports offset-based pagination.
|
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||||
remaining_offsets = range(limit, total_tracks, limit)
|
|
||||||
page_urls: List[str] = []
|
|
||||||
fields_filter = (
|
|
||||||
"items(track(name,duration_ms,id,is_local,external_urls,external_ids,artists(name),album(images)))"
|
|
||||||
",next"
|
|
||||||
)
|
|
||||||
for idx, offset in enumerate(remaining_offsets):
|
|
||||||
if self._playlist_page_limit is not None and idx >= self._playlist_page_limit:
|
|
||||||
break
|
|
||||||
page_urls.append(
|
|
||||||
f"{request_url}/tracks?offset={offset}&limit={limit}&fields={quote(fields_filter)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
if page_urls:
|
|
||||||
semaphore = asyncio.Semaphore(self._playlist_concurrency)
|
|
||||||
|
|
||||||
async def fetch_page(url: str) -> Optional[List[Track]]:
|
|
||||||
async with semaphore:
|
|
||||||
if not self.session:
|
|
||||||
raise SpotifyRequestException(
|
|
||||||
"HTTP session not initialized for Spotify client.",
|
|
||||||
)
|
|
||||||
resp = await self.session.get(url, headers=self._bearer_headers)
|
|
||||||
if resp.status != 200:
|
|
||||||
if self._log:
|
|
||||||
self._log.warning(
|
|
||||||
f"Page fetch failed {resp.status} {resp.reason} for {url}",
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
page_json: dict = await resp.json(loads=json.loads)
|
|
||||||
return [
|
|
||||||
Track(item["track"])
|
|
||||||
for item in page_json.get("items", [])
|
|
||||||
if item.get("track") is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
# Chunk gather in waves to avoid creating thousands of tasks at once
|
|
||||||
aggregated: List[Track] = []
|
|
||||||
wave_size = self._playlist_concurrency * 2
|
|
||||||
for i in range(0, len(page_urls), wave_size):
|
|
||||||
wave = page_urls[i : i + wave_size]
|
|
||||||
results = await asyncio.gather(
|
|
||||||
*[fetch_page(url) for url in wave],
|
|
||||||
return_exceptions=False,
|
|
||||||
)
|
)
|
||||||
for result in results:
|
|
||||||
if result:
|
|
||||||
aggregated.extend(result)
|
|
||||||
|
|
||||||
tracks.extend(aggregated)
|
next_data: dict = await resp.json(loads=json.loads)
|
||||||
|
|
||||||
|
tracks += [
|
||||||
|
Track(track["track"])
|
||||||
|
for track in next_data["items"]
|
||||||
|
if track["track"] is not None
|
||||||
|
]
|
||||||
|
next_page_url = next_data["next"]
|
||||||
|
|
||||||
return Playlist(data, tracks)
|
return Playlist(data, tracks)
|
||||||
|
|
||||||
async def iter_playlist_tracks(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
query: str,
|
|
||||||
batch_size: int = 100,
|
|
||||||
) -> AsyncGenerator[List[Track], None]:
|
|
||||||
"""Stream playlist tracks in batches without waiting for full materialization.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
query: str
|
|
||||||
Spotify playlist URL.
|
|
||||||
batch_size: int
|
|
||||||
Number of tracks yielded per batch (logical grouping after fetch). Does not alter API page size.
|
|
||||||
"""
|
|
||||||
if not self._bearer_token or time.time() >= self._expiry:
|
|
||||||
await self._fetch_bearer_token()
|
|
||||||
|
|
||||||
match = SPOTIFY_URL_REGEX.match(query)
|
|
||||||
if not match or match.group("type") != "playlist":
|
|
||||||
raise InvalidSpotifyURL("Provided query is not a valid Spotify playlist URL.")
|
|
||||||
|
|
||||||
playlist_id = match.group("id")
|
|
||||||
request_url = REQUEST_URL.format(type="playlist", id=playlist_id)
|
|
||||||
if not self.session:
|
|
||||||
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
|
||||||
resp = await self.session.get(request_url, headers=self._bearer_headers)
|
|
||||||
if resp.status != 200:
|
|
||||||
raise SpotifyRequestException(
|
|
||||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
|
||||||
)
|
|
||||||
data: dict = await resp.json(loads=json.loads)
|
|
||||||
|
|
||||||
# Yield first page immediately
|
|
||||||
first_page_tracks = [
|
|
||||||
Track(item["track"])
|
|
||||||
for item in data["tracks"]["items"]
|
|
||||||
if item.get("track") is not None
|
|
||||||
]
|
|
||||||
# Batch yield
|
|
||||||
for i in range(0, len(first_page_tracks), batch_size):
|
|
||||||
yield first_page_tracks[i : i + batch_size]
|
|
||||||
|
|
||||||
total = data["tracks"]["total"]
|
|
||||||
limit = data["tracks"]["limit"]
|
|
||||||
remaining_offsets = range(limit, total, limit)
|
|
||||||
fields_filter = (
|
|
||||||
"items(track(name,duration_ms,id,is_local,external_urls,external_ids,artists(name),album(images)))"
|
|
||||||
",next"
|
|
||||||
)
|
|
||||||
|
|
||||||
semaphore = asyncio.Semaphore(self._playlist_concurrency)
|
|
||||||
|
|
||||||
async def fetch(offset: int) -> List[Track]:
|
|
||||||
url = (
|
|
||||||
f"{request_url}/tracks?offset={offset}&limit={limit}&fields={quote(fields_filter)}"
|
|
||||||
)
|
|
||||||
async with semaphore:
|
|
||||||
if not self.session:
|
|
||||||
raise SpotifyRequestException(
|
|
||||||
"HTTP session not initialized for Spotify client.",
|
|
||||||
)
|
|
||||||
r = await self.session.get(url, headers=self._bearer_headers)
|
|
||||||
if r.status != 200:
|
|
||||||
if self._log:
|
|
||||||
self._log.warning(
|
|
||||||
f"Skipping page offset={offset} due to {r.status} {r.reason}",
|
|
||||||
)
|
|
||||||
return []
|
|
||||||
pj: dict = await r.json(loads=json.loads)
|
|
||||||
return [
|
|
||||||
Track(item["track"])
|
|
||||||
for item in pj.get("items", [])
|
|
||||||
if item.get("track") is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
# Fetch pages in rolling waves; yield promptly as soon as a wave completes.
|
|
||||||
wave_size = self._playlist_concurrency * 2
|
|
||||||
for i, offset in enumerate(remaining_offsets):
|
|
||||||
# Build wave
|
|
||||||
if i % wave_size == 0:
|
|
||||||
wave_offsets = list(
|
|
||||||
o for o in remaining_offsets if o >= offset and o < offset + wave_size
|
|
||||||
)
|
|
||||||
results = await asyncio.gather(*[fetch(o) for o in wave_offsets])
|
|
||||||
for page_tracks in results:
|
|
||||||
if not page_tracks:
|
|
||||||
continue
|
|
||||||
for j in range(0, len(page_tracks), batch_size):
|
|
||||||
yield page_tracks[j : j + batch_size]
|
|
||||||
# Skip ahead in iterator by adjusting enumerate drive (consume extras)
|
|
||||||
# Fast-forward the generator manually
|
|
||||||
for _ in range(len(wave_offsets) - 1):
|
|
||||||
try:
|
|
||||||
next(remaining_offsets) # type: ignore
|
|
||||||
except StopIteration:
|
|
||||||
break
|
|
||||||
|
|
||||||
async def get_recommendations(self, *, query: str) -> List[Track]:
|
async def get_recommendations(self, *, query: str) -> List[Track]:
|
||||||
if not self._bearer_token or time.time() >= self._expiry:
|
if not self._bearer_token or time.time() >= self._expiry:
|
||||||
await self._fetch_bearer_token()
|
await self._fetch_bearer_token()
|
||||||
|
|
@ -327,8 +165,6 @@ class Client:
|
||||||
id=f"?seed_tracks={spotify_id}",
|
id=f"?seed_tracks={spotify_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.session:
|
|
||||||
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
|
||||||
resp = await self.session.get(request_url, headers=self._bearer_headers)
|
resp = await self.session.get(request_url, headers=self._bearer_headers)
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
raise SpotifyRequestException(
|
raise SpotifyRequestException(
|
||||||
|
|
@ -339,22 +175,3 @@ class Client:
|
||||||
tracks = [Track(track) for track in data["tracks"]]
|
tracks = [Track(track) for track in data["tracks"]]
|
||||||
|
|
||||||
return tracks
|
return tracks
|
||||||
|
|
||||||
async def track_search(self, *, query: str) -> List[Track]:
|
|
||||||
if not self._bearer_token or time.time() >= self._expiry:
|
|
||||||
await self._fetch_bearer_token()
|
|
||||||
|
|
||||||
request_url = f"https://api.spotify.com/v1/search?q={quote(query)}&type=track"
|
|
||||||
|
|
||||||
if not self.session:
|
|
||||||
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
|
||||||
resp = await self.session.get(request_url, headers=self._bearer_headers)
|
|
||||||
if resp.status != 200:
|
|
||||||
raise SpotifyRequestException(
|
|
||||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
|
||||||
)
|
|
||||||
|
|
||||||
data: dict = await resp.json(loads=json.loads)
|
|
||||||
tracks = [Track(track) for track in data["tracks"]["items"]]
|
|
||||||
|
|
||||||
return tracks
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue