Compare commits

..

2 Commits

Author SHA1 Message Date
pre-commit-ci[bot] ff97597cf2
Merge 594b79151e into 855bf4e0d7 2025-09-22 17:50:47 +00:00
pre-commit-ci[bot] 594b79151e
[pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/pre-commit/pre-commit-hooks: v4.5.0 → v6.0.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.5.0...v6.0.0)
- [github.com/psf/black: 23.10.1 → 25.9.0](https://github.com/psf/black/compare/23.10.1...25.9.0)
- [github.com/asottile/pyupgrade: v3.15.0 → v3.20.0](https://github.com/asottile/pyupgrade/compare/v3.15.0...v3.20.0)
- [github.com/asottile/reorder-python-imports: v3.12.0 → v3.15.0](https://github.com/asottile/reorder-python-imports/compare/v3.12.0...v3.15.0)
- [github.com/asottile/add-trailing-comma: v3.1.0 → v3.2.0](https://github.com/asottile/add-trailing-comma/compare/v3.1.0...v3.2.0)
2025-09-22 17:50:44 +00:00
6 changed files with 65 additions and 362 deletions

View File

@ -10,11 +10,11 @@ repos:
- id: end-of-file-fixer - id: end-of-file-fixer
- id: requirements-txt-fixer - id: requirements-txt-fixer
- id: trailing-whitespace - id: trailing-whitespace
- repo: https://github.com/psf/black-pre-commit-mirror - repo: https://github.com/psf/black
rev: 25.9.0 rev: 25.9.0
hooks: hooks:
- id: black - id: black
language_version: python3.13 language_version: python3.12
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v3.20.0 rev: v3.20.0
hooks: hooks:
@ -30,4 +30,4 @@ repos:
- id: add-trailing-comma - id: add-trailing-comma
default_language_version: default_language_version:
python: python3.13 python: python3.12

View File

@ -20,7 +20,7 @@ if not discord.version_info.major >= 2:
"using 'pip install discord.py'", "using 'pip install discord.py'",
) )
__version__ = "2.10.0" __version__ = "2.9.2"
__title__ = "pomice" __title__ = "pomice"
__author__ = "cloudwithax" __author__ = "cloudwithax"
__license__ = "GPL-3.0" __license__ = "GPL-3.0"

View File

@ -1,14 +1,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import base64 import base64
import logging import logging
import re import re
from datetime import datetime from datetime import datetime
from typing import AsyncGenerator
from typing import Dict from typing import Dict
from typing import List from typing import List
from typing import Optional
from typing import Union from typing import Union
import aiohttp import aiohttp
@ -20,10 +17,10 @@ from .objects import *
__all__ = ("Client",) __all__ = ("Client",)
AM_URL_REGEX = re.compile( AM_URL_REGEX = re.compile(
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$", r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
) )
AM_SINGLE_IN_ALBUM_REGEX = re.compile( AM_SINGLE_IN_ALBUM_REGEX = re.compile(
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$", r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
) )
AM_SCRIPT_REGEX = re.compile(r'<script.*?src="(/assets/index-.*?)"') AM_SCRIPT_REGEX = re.compile(r'<script.*?src="(/assets/index-.*?)"')
@ -38,14 +35,12 @@ class Client:
and translating it to a valid Lavalink track. No client auth is required here. and translating it to a valid Lavalink track. No client auth is required here.
""" """
def __init__(self, *, playlist_concurrency: int = 6) -> None: def __init__(self) -> None:
self.expiry: datetime = datetime(1970, 1, 1) self.expiry: datetime = datetime(1970, 1, 1)
self.token: str = "" self.token: str = ""
self.headers: Dict[str, str] = {} self.headers: Dict[str, str] = {}
self.session: aiohttp.ClientSession = None # type: ignore self.session: aiohttp.ClientSession = None # type: ignore
self._log = logging.getLogger(__name__) self._log = logging.getLogger(__name__)
# Concurrency knob for parallel playlist page retrieval
self._playlist_concurrency = max(1, playlist_concurrency)
async def _set_session(self, session: aiohttp.ClientSession) -> None: async def _set_session(self, session: aiohttp.ClientSession) -> None:
self.session = session self.session = session
@ -172,127 +167,25 @@ class Client:
"This playlist is empty and therefore cannot be queued.", "This playlist is empty and therefore cannot be queued.",
) )
# Apple Music uses cursor pagination with 'next'. We'll fetch subsequent pages _next = track_data.get("next")
# concurrently by first collecting cursors in rolling waves. if _next:
next_cursor = track_data.get("next") next_page_url = AM_BASE_URL + _next
semaphore = asyncio.Semaphore(self._playlist_concurrency)
async def fetch_page(url: str) -> List[Song]: while next_page_url is not None:
async with semaphore: resp = await self.session.get(next_page_url, headers=self.headers)
resp = await self.session.get(url, headers=self.headers)
if resp.status != 200:
if self._log:
self._log.warning(
f"Apple Music page fetch failed {resp.status} {resp.reason} for {url}",
)
return []
pj: dict = await resp.json(loads=json.loads)
songs = [Song(track) for track in pj.get("data", [])]
# Return songs; we will look for pj.get('next') in streaming iterator variant
return songs, pj.get("next") # type: ignore
# We'll implement a wave-based approach similar to Spotify but need to follow cursors.
# Because we cannot know all cursors upfront, we'll iteratively fetch waves.
waves: List[List[Song]] = []
cursors: List[str] = []
if next_cursor:
cursors.append(next_cursor)
# Limit total waves to avoid infinite loops in malformed responses
max_waves = 50
wave_size = self._playlist_concurrency * 2
wave_counter = 0
while cursors and wave_counter < max_waves:
current = cursors[:wave_size]
cursors = cursors[wave_size:]
tasks = [
fetch_page(AM_BASE_URL + cursor) for cursor in current # type: ignore[arg-type]
]
results = await asyncio.gather(*tasks, return_exceptions=True)
for res in results:
if isinstance(res, tuple): # (songs, next)
songs, nxt = res
if songs:
waves.append(songs)
if nxt:
cursors.append(nxt)
wave_counter += 1
for w in waves:
album_tracks.extend(w)
return Playlist(data, album_tracks)
async def iter_playlist_tracks(
self,
*,
query: str,
batch_size: int = 100,
) -> AsyncGenerator[List[Song], None]:
"""Stream Apple Music playlist tracks in batches.
Parameters
----------
query: str
Apple Music playlist URL.
batch_size: int
Logical grouping size for yielded batches.
"""
if not self.token or datetime.utcnow() > self.expiry:
await self.request_token()
result = AM_URL_REGEX.match(query)
if not result or result.group("type") != "playlist":
raise InvalidAppleMusicURL("Provided query is not a valid Apple Music playlist URL.")
country = result.group("country")
playlist_id = result.group("id")
request_url = AM_REQ_URL.format(country=country, type="playlist", id=playlist_id)
resp = await self.session.get(request_url, headers=self.headers)
if resp.status != 200: if resp.status != 200:
raise AppleMusicRequestException( raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}", f"Error while fetching results: {resp.status} {resp.reason}",
) )
data: dict = await resp.json(loads=json.loads)
playlist_data = data["data"][0]
track_data: dict = playlist_data["relationships"]["tracks"]
first_page_tracks = [Song(track) for track in track_data["data"]] next_data: dict = await resp.json(loads=json.loads)
for i in range(0, len(first_page_tracks), batch_size): album_tracks.extend(Song(track) for track in next_data["data"])
yield first_page_tracks[i : i + batch_size]
next_cursor = track_data.get("next") _next = next_data.get("next")
semaphore = asyncio.Semaphore(self._playlist_concurrency) if _next:
next_page_url = AM_BASE_URL + _next
else:
next_page_url = None
async def fetch(cursor: str) -> tuple[List[Song], Optional[str]]: return Playlist(data, album_tracks)
url = AM_BASE_URL + cursor
async with semaphore:
r = await self.session.get(url, headers=self.headers)
if r.status != 200:
if self._log:
self._log.warning(
f"Skipping Apple Music page due to {r.status} {r.reason}",
)
return [], None
pj: dict = await r.json(loads=json.loads)
songs = [Song(track) for track in pj.get("data", [])]
return songs, pj.get("next")
# Rolling waves of fetches following cursor chain
max_waves = 50
wave_size = self._playlist_concurrency * 2
waves = 0
cursors: List[str] = []
if next_cursor:
cursors.append(next_cursor)
while cursors and waves < max_waves:
current = cursors[:wave_size]
cursors = cursors[wave_size:]
results = await asyncio.gather(*[fetch(c) for c in current])
for songs, nxt in results:
if songs:
for j in range(0, len(songs), batch_size):
yield songs[j : j + batch_size]
if nxt:
cursors.append(nxt)
waves += 1

View File

@ -34,11 +34,6 @@ class SearchType(Enum):
ytsearch = "ytsearch" ytsearch = "ytsearch"
ytmsearch = "ytmsearch" ytmsearch = "ytmsearch"
scsearch = "scsearch" scsearch = "scsearch"
other = "other"
@classmethod
def _missing_(cls, value: object) -> "SearchType": # type: ignore[override]
return cls.other
def __str__(self) -> str: def __str__(self) -> str:
return self.value return self.value
@ -73,7 +68,7 @@ class TrackType(Enum):
OTHER = "other" OTHER = "other"
@classmethod @classmethod
def _missing_(cls, value: object) -> "TrackType": # type: ignore[override] def _missing_(cls, _: object) -> "TrackType":
return cls.OTHER return cls.OTHER
def __str__(self) -> str: def __str__(self) -> str:
@ -103,7 +98,7 @@ class PlaylistType(Enum):
OTHER = "other" OTHER = "other"
@classmethod @classmethod
def _missing_(cls, value: object) -> "PlaylistType": # type: ignore[override] def _missing_(cls, _: object) -> "PlaylistType":
return cls.OTHER return cls.OTHER
def __str__(self) -> str: def __str__(self) -> str:
@ -218,12 +213,8 @@ class URLRegex:
""" """
# Spotify share links can include query parameters like ?si=XXXX, a trailing slash,
# or an intl locale segment (e.g. /intl-en/). Broaden the regex so we still capture
# the type and id while ignoring extra parameters. This prevents the URL from being
# treated as a generic Lavalink identifier and ensures internal Spotify handling runs.
SPOTIFY_URL = re.compile( SPOTIFY_URL = re.compile(
r"https?://open\.spotify\.com/(?:intl-[a-zA-Z-]+/)?(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)(?:/)?(?:\?.*)?$", r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
) )
DISCORD_MP3_URL = re.compile( DISCORD_MP3_URL = re.compile(
@ -244,17 +235,14 @@ class URLRegex:
r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*", r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*",
) )
# Apple Music links sometimes append additional query parameters (e.g. &l=en, &uo=4).
# Allow arbitrary query parameters so valid links are captured and parsed.
AM_URL = re.compile( AM_URL = re.compile(
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/" r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/"
r"(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$", r"(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
) )
# Single-in-album links may also carry extra query params beyond the ?i=<trackid> token.
AM_SINGLE_IN_ALBUM_REGEX = re.compile( AM_SINGLE_IN_ALBUM_REGEX = re.compile(
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/" r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
r"(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$", r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
) )
SOUNDCLOUD_URL = re.compile( SOUNDCLOUD_URL = re.compile(

View File

@ -21,12 +21,7 @@ import orjson as json
from discord import Client from discord import Client
from discord.ext import commands from discord.ext import commands
from discord.utils import MISSING from discord.utils import MISSING
from websockets import client
try:
from websockets.legacy import client # websockets >= 10.0
except ImportError:
import websockets.client as client # websockets < 10.0 # type: ignore
from websockets import exceptions from websockets import exceptions
from websockets import typing as wstype from websockets import typing as wstype
@ -308,7 +303,7 @@ class Node:
if not self._resume_key: if not self._resume_key:
return return
data: Dict[str, Union[int, str, bool]] = {"timeout": self._resume_timeout} data = {"timeout": self._resume_timeout}
if self._version.major == 3: if self._version.major == 3:
data["resumingKey"] = self._resume_key data["resumingKey"] = self._resume_key
@ -449,17 +444,7 @@ class Node:
start = time.perf_counter() start = time.perf_counter()
if not self._session: if not self._session:
# Configure connection pooling for optimal concurrent request performance self._session = aiohttp.ClientSession()
connector = aiohttp.TCPConnector(
limit=100, # Total connection limit
limit_per_host=30, # Per-host connection limit
ttl_dns_cache=300, # DNS cache TTL in seconds
)
timeout = aiohttp.ClientTimeout(total=30, connect=10)
self._session = aiohttp.ClientSession(
connector=connector,
timeout=timeout,
)
try: try:
if not reconnect: if not reconnect:
@ -478,7 +463,7 @@ class Node:
f"Version check from Node {self._identifier} successful. Returned version {version}", f"Version check from Node {self._identifier} successful. Returned version {version}",
) )
self._websocket = await client.connect( # type: ignore self._websocket = await client.connect(
f"{self._websocket_uri}/v{self._version.major}/websocket", f"{self._websocket_uri}/v{self._version.major}/websocket",
extra_headers=self._headers, extra_headers=self._headers,
ping_interval=self._heartbeat, ping_interval=self._heartbeat,
@ -575,7 +560,7 @@ class Node:
query: str, query: str,
*, *,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
search_type: Optional[SearchType] = SearchType.ytsearch, search_type: SearchType | None = SearchType.ytsearch,
filters: Optional[List[Filter]] = None, filters: Optional[List[Filter]] = None,
) -> Optional[Union[Playlist, List[Track]]]: ) -> Optional[Union[Playlist, List[Track]]]:
"""Fetches tracks from the node's REST api to parse into Lavalink. """Fetches tracks from the node's REST api to parse into Lavalink.
@ -610,7 +595,7 @@ class Node:
track_id=apple_music_results.id, track_id=apple_music_results.id,
ctx=ctx, ctx=ctx,
track_type=TrackType.APPLE_MUSIC, track_type=TrackType.APPLE_MUSIC,
search_type=search_type or SearchType.ytsearch, search_type=search_type,
filters=filters, filters=filters,
info={ info={
"title": apple_music_results.name, "title": apple_music_results.name,
@ -632,7 +617,7 @@ class Node:
track_id=track.id, track_id=track.id,
ctx=ctx, ctx=ctx,
track_type=TrackType.APPLE_MUSIC, track_type=TrackType.APPLE_MUSIC,
search_type=search_type or SearchType.ytsearch, search_type=search_type,
filters=filters, filters=filters,
info={ info={
"title": track.name, "title": track.name,
@ -670,7 +655,7 @@ class Node:
track_id=spotify_results.id, track_id=spotify_results.id,
ctx=ctx, ctx=ctx,
track_type=TrackType.SPOTIFY, track_type=TrackType.SPOTIFY,
search_type=search_type or SearchType.ytsearch, search_type=search_type,
filters=filters, filters=filters,
info={ info={
"title": spotify_results.name, "title": spotify_results.name,
@ -692,7 +677,7 @@ class Node:
track_id=track.id, track_id=track.id,
ctx=ctx, ctx=ctx,
track_type=TrackType.SPOTIFY, track_type=TrackType.SPOTIFY,
search_type=search_type or SearchType.ytsearch, search_type=search_type,
filters=filters, filters=filters,
info={ info={
"title": track.name, "title": track.name,

View File

@ -1,11 +1,9 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import re import re
import time import time
from base64 import b64encode from base64 import b64encode
from typing import AsyncGenerator
from typing import Dict from typing import Dict
from typing import List from typing import List
from typing import Optional from typing import Optional
@ -24,10 +22,8 @@ __all__ = ("Client",)
GRANT_URL = "https://accounts.spotify.com/api/token" GRANT_URL = "https://accounts.spotify.com/api/token"
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}" REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
# Keep this in sync with URLRegex.SPOTIFY_URL (enums.py). Accept intl locale segment,
# optional trailing slash, and query parameters.
SPOTIFY_URL_REGEX = re.compile( SPOTIFY_URL_REGEX = re.compile(
r"https?://open\.spotify\.com/(?:intl-[a-zA-Z-]+/)?(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)(?:/)?(?:\?.*)?$", r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
) )
@ -37,39 +33,29 @@ class Client:
for any Spotify URL you throw at it. for any Spotify URL you throw at it.
""" """
def __init__( def __init__(self, client_id: str, client_secret: str) -> None:
self, self._client_id: str = client_id
client_id: str, self._client_secret: str = client_secret
client_secret: str,
*,
playlist_concurrency: int = 10,
playlist_page_limit: Optional[int] = None,
) -> None:
self._client_id = client_id
self._client_secret = client_secret
# HTTP session will be injected by Node self.session: aiohttp.ClientSession = None # type: ignore
self.session: Optional[aiohttp.ClientSession] = None
self._bearer_token: Optional[str] = None self._bearer_token: Optional[str] = None
self._expiry: float = 0.0 self._expiry: float = 0.0
self._auth_token = b64encode(f"{self._client_id}:{self._client_secret}".encode()) self._auth_token = b64encode(
self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"} f"{self._client_id}:{self._client_secret}".encode(),
)
self._grant_headers = {
"Authorization": f"Basic {self._auth_token.decode()}",
}
self._bearer_headers: Optional[Dict] = None self._bearer_headers: Optional[Dict] = None
self._log = logging.getLogger(__name__) self._log = logging.getLogger(__name__)
# Performance tuning knobs
self._playlist_concurrency = max(1, playlist_concurrency)
self._playlist_page_limit = playlist_page_limit
async def _set_session(self, session: aiohttp.ClientSession) -> None: async def _set_session(self, session: aiohttp.ClientSession) -> None:
self.session = session self.session = session
async def _fetch_bearer_token(self) -> None: async def _fetch_bearer_token(self) -> None:
_data = {"grant_type": "client_credentials"} _data = {"grant_type": "client_credentials"}
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.post(GRANT_URL, data=_data, headers=self._grant_headers) resp = await self.session.post(GRANT_URL, data=_data, headers=self._grant_headers)
if resp.status != 200: if resp.status != 200:
@ -100,8 +86,6 @@ class Client:
request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id) request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id)
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers) resp = await self.session.get(request_url, headers=self._bearer_headers)
if resp.status != 200: if resp.status != 200:
raise SpotifyRequestException( raise SpotifyRequestException(
@ -119,8 +103,6 @@ class Client:
elif spotify_type == "album": elif spotify_type == "album":
return Album(data) return Album(data)
elif spotify_type == "artist": elif spotify_type == "artist":
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get( resp = await self.session.get(
f"{request_url}/top-tracks?market=US", f"{request_url}/top-tracks?market=US",
headers=self._bearer_headers, headers=self._bearer_headers,
@ -134,177 +116,36 @@ class Client:
tracks = track_data["tracks"] tracks = track_data["tracks"]
return Artist(data, tracks) return Artist(data, tracks)
else: else:
# For playlists we optionally use a reduced fields payload to shrink response sizes.
# NB: We cannot apply fields filter to initial request because original metadata is needed.
tracks = [ tracks = [
Track(track["track"]) Track(track["track"])
for track in data["tracks"]["items"] for track in data["tracks"]["items"]
if track["track"] is not None if track["track"] is not None
] ]
if not tracks: if not len(tracks):
raise SpotifyRequestException( raise SpotifyRequestException(
"This playlist is empty and therefore cannot be queued.", "This playlist is empty and therefore cannot be queued.",
) )
total_tracks = data["tracks"]["total"] next_page_url = data["tracks"]["next"]
limit = data["tracks"]["limit"]
# Shortcircuit small playlists (single page) while next_page_url is not None:
if total_tracks <= limit: resp = await self.session.get(next_page_url, headers=self._bearer_headers)
return Playlist(data, tracks)
# Build remaining page URLs; Spotify supports offset-based pagination.
remaining_offsets = range(limit, total_tracks, limit)
page_urls: List[str] = []
fields_filter = (
"items(track(name,duration_ms,id,is_local,external_urls,external_ids,artists(name),album(images)))"
",next"
)
for idx, offset in enumerate(remaining_offsets):
if self._playlist_page_limit is not None and idx >= self._playlist_page_limit:
break
page_urls.append(
f"{request_url}/tracks?offset={offset}&limit={limit}&fields={quote(fields_filter)}",
)
if page_urls:
semaphore = asyncio.Semaphore(self._playlist_concurrency)
async def fetch_page(url: str) -> Optional[List[Track]]:
async with semaphore:
if not self.session:
raise SpotifyRequestException(
"HTTP session not initialized for Spotify client.",
)
resp = await self.session.get(url, headers=self._bearer_headers)
if resp.status != 200:
if self._log:
self._log.warning(
f"Page fetch failed {resp.status} {resp.reason} for {url}",
)
return None
page_json: dict = await resp.json(loads=json.loads)
return [
Track(item["track"])
for item in page_json.get("items", [])
if item.get("track") is not None
]
# Chunk gather in waves to avoid creating thousands of tasks at once
aggregated: List[Track] = []
wave_size = self._playlist_concurrency * 2
for i in range(0, len(page_urls), wave_size):
wave = page_urls[i : i + wave_size]
results = await asyncio.gather(
*[fetch_page(url) for url in wave],
return_exceptions=False,
)
for result in results:
if result:
aggregated.extend(result)
tracks.extend(aggregated)
return Playlist(data, tracks)
async def iter_playlist_tracks(
self,
*,
query: str,
batch_size: int = 100,
) -> AsyncGenerator[List[Track], None]:
"""Stream playlist tracks in batches without waiting for full materialization.
Parameters
----------
query: str
Spotify playlist URL.
batch_size: int
Number of tracks yielded per batch (logical grouping after fetch). Does not alter API page size.
"""
if not self._bearer_token or time.time() >= self._expiry:
await self._fetch_bearer_token()
match = SPOTIFY_URL_REGEX.match(query)
if not match or match.group("type") != "playlist":
raise InvalidSpotifyURL("Provided query is not a valid Spotify playlist URL.")
playlist_id = match.group("id")
request_url = REQUEST_URL.format(type="playlist", id=playlist_id)
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers)
if resp.status != 200: if resp.status != 200:
raise SpotifyRequestException( raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}", f"Error while fetching results: {resp.status} {resp.reason}",
) )
data: dict = await resp.json(loads=json.loads)
# Yield first page immediately next_data: dict = await resp.json(loads=json.loads)
first_page_tracks = [
Track(item["track"]) tracks += [
for item in data["tracks"]["items"] Track(track["track"])
if item.get("track") is not None for track in next_data["items"]
if track["track"] is not None
] ]
# Batch yield next_page_url = next_data["next"]
for i in range(0, len(first_page_tracks), batch_size):
yield first_page_tracks[i : i + batch_size]
total = data["tracks"]["total"] return Playlist(data, tracks)
limit = data["tracks"]["limit"]
remaining_offsets = range(limit, total, limit)
fields_filter = (
"items(track(name,duration_ms,id,is_local,external_urls,external_ids,artists(name),album(images)))"
",next"
)
semaphore = asyncio.Semaphore(self._playlist_concurrency)
async def fetch(offset: int) -> List[Track]:
url = (
f"{request_url}/tracks?offset={offset}&limit={limit}&fields={quote(fields_filter)}"
)
async with semaphore:
if not self.session:
raise SpotifyRequestException(
"HTTP session not initialized for Spotify client.",
)
r = await self.session.get(url, headers=self._bearer_headers)
if r.status != 200:
if self._log:
self._log.warning(
f"Skipping page offset={offset} due to {r.status} {r.reason}",
)
return []
pj: dict = await r.json(loads=json.loads)
return [
Track(item["track"])
for item in pj.get("items", [])
if item.get("track") is not None
]
# Fetch pages in rolling waves; yield promptly as soon as a wave completes.
wave_size = self._playlist_concurrency * 2
for i, offset in enumerate(remaining_offsets):
# Build wave
if i % wave_size == 0:
wave_offsets = list(
o for o in remaining_offsets if o >= offset and o < offset + wave_size
)
results = await asyncio.gather(*[fetch(o) for o in wave_offsets])
for page_tracks in results:
if not page_tracks:
continue
for j in range(0, len(page_tracks), batch_size):
yield page_tracks[j : j + batch_size]
# Skip ahead in iterator by adjusting enumerate drive (consume extras)
# Fast-forward the generator manually
for _ in range(len(wave_offsets) - 1):
try:
next(remaining_offsets) # type: ignore
except StopIteration:
break
async def get_recommendations(self, *, query: str) -> List[Track]: async def get_recommendations(self, *, query: str) -> List[Track]:
if not self._bearer_token or time.time() >= self._expiry: if not self._bearer_token or time.time() >= self._expiry:
@ -327,8 +168,6 @@ class Client:
id=f"?seed_tracks={spotify_id}", id=f"?seed_tracks={spotify_id}",
) )
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers) resp = await self.session.get(request_url, headers=self._bearer_headers)
if resp.status != 200: if resp.status != 200:
raise SpotifyRequestException( raise SpotifyRequestException(
@ -346,8 +185,6 @@ class Client:
request_url = f"https://api.spotify.com/v1/search?q={quote(query)}&type=track" request_url = f"https://api.spotify.com/v1/search?q={quote(query)}&type=track"
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers) resp = await self.session.get(request_url, headers=self._bearer_headers)
if resp.status != 200: if resp.status != 200:
raise SpotifyRequestException( raise SpotifyRequestException(