Compare commits
3 Commits
594b79151e
...
9dbcffb113
| Author | SHA1 | Date |
|---|---|---|
|
|
9dbcffb113 | |
|
|
9bffdebe25 | |
|
|
720ba187ab |
|
|
@ -10,11 +10,11 @@ repos:
|
|||
- id: end-of-file-fixer
|
||||
- id: requirements-txt-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/psf/black
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 25.9.0
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3.12
|
||||
language_version: python3.13
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.20.0
|
||||
hooks:
|
||||
|
|
@ -30,4 +30,4 @@ repos:
|
|||
- id: add-trailing-comma
|
||||
|
||||
default_language_version:
|
||||
python: python3.12
|
||||
python: python3.13
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ if not discord.version_info.major >= 2:
|
|||
"using 'pip install discord.py'",
|
||||
)
|
||||
|
||||
__version__ = "2.9.2"
|
||||
__version__ = "2.10.0"
|
||||
__title__ = "pomice"
|
||||
__author__ = "cloudwithax"
|
||||
__license__ = "GPL-3.0"
|
||||
|
|
|
|||
|
|
@ -1,11 +1,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import aiohttp
|
||||
|
|
@ -17,10 +20,10 @@ from .objects import *
|
|||
__all__ = ("Client",)
|
||||
|
||||
AM_URL_REGEX = re.compile(
|
||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
|
||||
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$",
|
||||
)
|
||||
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
|
||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
|
||||
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$",
|
||||
)
|
||||
|
||||
AM_SCRIPT_REGEX = re.compile(r'<script.*?src="(/assets/index-.*?)"')
|
||||
|
|
@ -35,12 +38,14 @@ class Client:
|
|||
and translating it to a valid Lavalink track. No client auth is required here.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, *, playlist_concurrency: int = 6) -> None:
|
||||
self.expiry: datetime = datetime(1970, 1, 1)
|
||||
self.token: str = ""
|
||||
self.headers: Dict[str, str] = {}
|
||||
self.session: aiohttp.ClientSession = None # type: ignore
|
||||
self._log = logging.getLogger(__name__)
|
||||
# Concurrency knob for parallel playlist page retrieval
|
||||
self._playlist_concurrency = max(1, playlist_concurrency)
|
||||
|
||||
async def _set_session(self, session: aiohttp.ClientSession) -> None:
|
||||
self.session = session
|
||||
|
|
@ -167,25 +172,127 @@ class Client:
|
|||
"This playlist is empty and therefore cannot be queued.",
|
||||
)
|
||||
|
||||
_next = track_data.get("next")
|
||||
if _next:
|
||||
next_page_url = AM_BASE_URL + _next
|
||||
# Apple Music uses cursor pagination with 'next'. We'll fetch subsequent pages
|
||||
# concurrently by first collecting cursors in rolling waves.
|
||||
next_cursor = track_data.get("next")
|
||||
semaphore = asyncio.Semaphore(self._playlist_concurrency)
|
||||
|
||||
while next_page_url is not None:
|
||||
resp = await self.session.get(next_page_url, headers=self.headers)
|
||||
async def fetch_page(url: str) -> tuple[List[Song], Optional[str]]:
    """Fetch a single follow-up playlist page from Apple Music.

    Parameters
    ----------
    url: str
        Absolute URL of the page to fetch.

    Returns
    -------
    tuple[List[Song], Optional[str]]
        The songs parsed from the page's ``data`` array and the value of the
        page's ``next`` key (``None`` when absent). A non-200 response is
        logged and reported as ``([], None)`` so the result always has the
        same shape; the consuming loop only acts on non-empty songs and a
        truthy cursor, so a failed page still contributes nothing.
    """
    async with semaphore:
        # Using the request as a context manager releases the connection even
        # on the non-200 branch, where the body is never read.
        async with self.session.get(url, headers=self.headers) as resp:
            if resp.status != 200:
                if self._log:
                    self._log.warning(
                        f"Apple Music page fetch failed {resp.status} {resp.reason} for {url}",
                    )
                return [], None
            pj: dict = await resp.json(loads=json.loads)

    songs = [Song(track) for track in pj.get("data", [])]
    return songs, pj.get("next")
|
||||
|
||||
# We'll implement a wave-based approach similar to Spotify but need to follow cursors.
|
||||
# Because we cannot know all cursors upfront, we'll iteratively fetch waves.
|
||||
waves: List[List[Song]] = []
|
||||
cursors: List[str] = []
|
||||
if next_cursor:
|
||||
cursors.append(next_cursor)
|
||||
|
||||
# Limit total waves to avoid infinite loops in malformed responses
|
||||
max_waves = 50
|
||||
wave_size = self._playlist_concurrency * 2
|
||||
wave_counter = 0
|
||||
while cursors and wave_counter < max_waves:
|
||||
current = cursors[:wave_size]
|
||||
cursors = cursors[wave_size:]
|
||||
tasks = [
|
||||
fetch_page(AM_BASE_URL + cursor) for cursor in current # type: ignore[arg-type]
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for res in results:
|
||||
if isinstance(res, tuple): # (songs, next)
|
||||
songs, nxt = res
|
||||
if songs:
|
||||
waves.append(songs)
|
||||
if nxt:
|
||||
cursors.append(nxt)
|
||||
wave_counter += 1
|
||||
|
||||
for w in waves:
|
||||
album_tracks.extend(w)
|
||||
|
||||
return Playlist(data, album_tracks)
|
||||
|
||||
async def iter_playlist_tracks(
    self,
    *,
    query: str,
    batch_size: int = 100,
) -> AsyncGenerator[List[Song], None]:
    """Stream Apple Music playlist tracks in batches.

    Parameters
    ----------
    query: str
        Apple Music playlist URL.
    batch_size: int
        Logical grouping size for yielded batches. Values below 1 are
        clamped to 1.

    Yields
    ------
    List[Song]
        Consecutive slices of at most ``batch_size`` songs, in playlist order.

    Raises
    ------
    InvalidAppleMusicURL
        If ``query`` does not match a playlist URL.
    AppleMusicRequestException
        If the initial playlist request does not return HTTP 200.
    """
    # range() rejects a zero step and yields nothing for a negative one, so
    # clamp the same way the constructor clamps ``playlist_concurrency``.
    batch_size = max(1, batch_size)

    # NOTE(review): utcnow() is kept because ``self.expiry`` is initialised as
    # a naive datetime; switching to an aware "now" needs the token code too.
    if not self.token or datetime.utcnow() > self.expiry:
        await self.request_token()

    result = AM_URL_REGEX.match(query)
    if not result or result.group("type") != "playlist":
        raise InvalidAppleMusicURL("Provided query is not a valid Apple Music playlist URL.")

    request_url = AM_REQ_URL.format(
        country=result.group("country"),
        type="playlist",
        id=result.group("id"),
    )
    async with self.session.get(request_url, headers=self.headers) as resp:
        if resp.status != 200:
            raise AppleMusicRequestException(
                f"Error while fetching results: {resp.status} {resp.reason}",
            )
        data: dict = await resp.json(loads=json.loads)

    track_data: dict = data["data"][0]["relationships"]["tracks"]

    # The first page arrives with the playlist payload; hand it out right away.
    first_page_tracks = [Song(track) for track in track_data["data"]]
    for start in range(0, len(first_page_tracks), batch_size):
        yield first_page_tracks[start : start + batch_size]

    # Each page only reveals the cursor of the page after it, so follow-up
    # pages can only be fetched one at a time. Two guards protect against
    # malformed responses: a hard page cap and a repeated-cursor check.
    max_pages = 50
    pages_fetched = 0
    seen_cursors = set()
    cursor: Optional[str] = track_data.get("next")

    while cursor and pages_fetched < max_pages:
        if cursor in seen_cursors:
            self._log.warning("Apple Music returned a repeated page cursor; stopping.")
            return
        seen_cursors.add(cursor)

        async with self.session.get(AM_BASE_URL + cursor, headers=self.headers) as page_resp:
            if page_resp.status != 200:
                # Without this page there is no cursor to continue from.
                self._log.warning(
                    f"Skipping Apple Music page due to {page_resp.status} {page_resp.reason}",
                )
                return
            page: dict = await page_resp.json(loads=json.loads)

        songs = [Song(track) for track in page.get("data", [])]
        for start in range(0, len(songs), batch_size):
            yield songs[start : start + batch_size]

        cursor = page.get("next")
        pages_fetched += 1

    if cursor:
        # The cap was reached with pages still pending; make that visible.
        self._log.warning(
            f"Apple Music playlist truncated after {max_pages} follow-up pages.",
        )
|
||||
|
|
|
|||
|
|
@ -34,6 +34,11 @@ class SearchType(Enum):
|
|||
ytsearch = "ytsearch"
|
||||
ytmsearch = "ytmsearch"
|
||||
scsearch = "scsearch"
|
||||
other = "other"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> "SearchType": # type: ignore[override]
|
||||
return cls.other
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
|
@ -68,7 +73,7 @@ class TrackType(Enum):
|
|||
OTHER = "other"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, _: object) -> "TrackType":
|
||||
def _missing_(cls, value: object) -> "TrackType": # type: ignore[override]
|
||||
return cls.OTHER
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
|
@ -98,7 +103,7 @@ class PlaylistType(Enum):
|
|||
OTHER = "other"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, _: object) -> "PlaylistType":
|
||||
def _missing_(cls, value: object) -> "PlaylistType": # type: ignore[override]
|
||||
return cls.OTHER
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
|
@ -213,8 +218,12 @@ class URLRegex:
|
|||
|
||||
"""
|
||||
|
||||
# Spotify share links can include query parameters like ?si=XXXX, a trailing slash,
|
||||
# or an intl locale segment (e.g. /intl-en/). Broaden the regex so we still capture
|
||||
# the type and id while ignoring extra parameters. This prevents the URL from being
|
||||
# treated as a generic Lavalink identifier and ensures internal Spotify handling runs.
|
||||
SPOTIFY_URL = re.compile(
|
||||
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
|
||||
r"https?://open\.spotify\.com/(?:intl-[a-zA-Z-]+/)?(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)(?:/)?(?:\?.*)?$",
|
||||
)
|
||||
|
||||
DISCORD_MP3_URL = re.compile(
|
||||
|
|
@ -235,14 +244,17 @@ class URLRegex:
|
|||
r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*",
|
||||
)
|
||||
|
||||
# Apple Music links sometimes append additional query parameters (e.g. &l=en, &uo=4).
|
||||
# Allow arbitrary query parameters so valid links are captured and parsed.
|
||||
AM_URL = re.compile(
|
||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/"
|
||||
r"(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
|
||||
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/"
|
||||
r"(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$",
|
||||
)
|
||||
|
||||
# Single-in-album links may also carry extra query params beyond the ?i=<trackid> token.
|
||||
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
|
||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
|
||||
r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
|
||||
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
|
||||
r"(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$",
|
||||
)
|
||||
|
||||
SOUNDCLOUD_URL = re.compile(
|
||||
|
|
|
|||
|
|
@ -21,7 +21,12 @@ import orjson as json
|
|||
from discord import Client
|
||||
from discord.ext import commands
|
||||
from discord.utils import MISSING
|
||||
from websockets import client
|
||||
|
||||
try:
|
||||
from websockets.legacy import client # websockets >= 10.0
|
||||
except ImportError:
|
||||
import websockets.client as client # websockets < 10.0 # type: ignore
|
||||
|
||||
from websockets import exceptions
|
||||
from websockets import typing as wstype
|
||||
|
||||
|
|
@ -303,7 +308,7 @@ class Node:
|
|||
if not self._resume_key:
|
||||
return
|
||||
|
||||
data = {"timeout": self._resume_timeout}
|
||||
data: Dict[str, Union[int, str, bool]] = {"timeout": self._resume_timeout}
|
||||
|
||||
if self._version.major == 3:
|
||||
data["resumingKey"] = self._resume_key
|
||||
|
|
@ -444,7 +449,17 @@ class Node:
|
|||
start = time.perf_counter()
|
||||
|
||||
if not self._session:
|
||||
self._session = aiohttp.ClientSession()
|
||||
# Configure connection pooling for optimal concurrent request performance
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit=100, # Total connection limit
|
||||
limit_per_host=30, # Per-host connection limit
|
||||
ttl_dns_cache=300, # DNS cache TTL in seconds
|
||||
)
|
||||
timeout = aiohttp.ClientTimeout(total=30, connect=10)
|
||||
self._session = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
try:
|
||||
if not reconnect:
|
||||
|
|
@ -463,7 +478,7 @@ class Node:
|
|||
f"Version check from Node {self._identifier} successful. Returned version {version}",
|
||||
)
|
||||
|
||||
self._websocket = await client.connect(
|
||||
self._websocket = await client.connect( # type: ignore
|
||||
f"{self._websocket_uri}/v{self._version.major}/websocket",
|
||||
extra_headers=self._headers,
|
||||
ping_interval=self._heartbeat,
|
||||
|
|
@ -560,7 +575,7 @@ class Node:
|
|||
query: str,
|
||||
*,
|
||||
ctx: Optional[commands.Context] = None,
|
||||
search_type: SearchType | None = SearchType.ytsearch,
|
||||
search_type: Optional[SearchType] = SearchType.ytsearch,
|
||||
filters: Optional[List[Filter]] = None,
|
||||
) -> Optional[Union[Playlist, List[Track]]]:
|
||||
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
||||
|
|
@ -595,7 +610,7 @@ class Node:
|
|||
track_id=apple_music_results.id,
|
||||
ctx=ctx,
|
||||
track_type=TrackType.APPLE_MUSIC,
|
||||
search_type=search_type,
|
||||
search_type=search_type or SearchType.ytsearch,
|
||||
filters=filters,
|
||||
info={
|
||||
"title": apple_music_results.name,
|
||||
|
|
@ -617,7 +632,7 @@ class Node:
|
|||
track_id=track.id,
|
||||
ctx=ctx,
|
||||
track_type=TrackType.APPLE_MUSIC,
|
||||
search_type=search_type,
|
||||
search_type=search_type or SearchType.ytsearch,
|
||||
filters=filters,
|
||||
info={
|
||||
"title": track.name,
|
||||
|
|
@ -655,7 +670,7 @@ class Node:
|
|||
track_id=spotify_results.id,
|
||||
ctx=ctx,
|
||||
track_type=TrackType.SPOTIFY,
|
||||
search_type=search_type,
|
||||
search_type=search_type or SearchType.ytsearch,
|
||||
filters=filters,
|
||||
info={
|
||||
"title": spotify_results.name,
|
||||
|
|
@ -677,7 +692,7 @@ class Node:
|
|||
track_id=track.id,
|
||||
ctx=ctx,
|
||||
track_type=TrackType.SPOTIFY,
|
||||
search_type=search_type,
|
||||
search_type=search_type or SearchType.ytsearch,
|
||||
filters=filters,
|
||||
info={
|
||||
"title": track.name,
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from base64 import b64encode
|
||||
from typing import AsyncGenerator
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
|
@ -22,8 +24,10 @@ __all__ = ("Client",)
|
|||
|
||||
GRANT_URL = "https://accounts.spotify.com/api/token"
|
||||
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
|
||||
# Keep this in sync with URLRegex.SPOTIFY_URL (enums.py). Accept intl locale segment,
|
||||
# optional trailing slash, and query parameters.
|
||||
SPOTIFY_URL_REGEX = re.compile(
|
||||
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
|
||||
r"https?://open\.spotify\.com/(?:intl-[a-zA-Z-]+/)?(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)(?:/)?(?:\?.*)?$",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -33,29 +37,39 @@ class Client:
|
|||
for any Spotify URL you throw at it.
|
||||
"""
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str) -> None:
|
||||
self._client_id: str = client_id
|
||||
self._client_secret: str = client_secret
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
*,
|
||||
playlist_concurrency: int = 10,
|
||||
playlist_page_limit: Optional[int] = None,
|
||||
) -> None:
|
||||
self._client_id = client_id
|
||||
self._client_secret = client_secret
|
||||
|
||||
self.session: aiohttp.ClientSession = None # type: ignore
|
||||
# HTTP session will be injected by Node
|
||||
self.session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
self._bearer_token: Optional[str] = None
|
||||
self._expiry: float = 0.0
|
||||
self._auth_token = b64encode(
|
||||
f"{self._client_id}:{self._client_secret}".encode(),
|
||||
)
|
||||
self._grant_headers = {
|
||||
"Authorization": f"Basic {self._auth_token.decode()}",
|
||||
}
|
||||
self._auth_token = b64encode(f"{self._client_id}:{self._client_secret}".encode())
|
||||
self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"}
|
||||
self._bearer_headers: Optional[Dict] = None
|
||||
self._log = logging.getLogger(__name__)
|
||||
|
||||
# Performance tuning knobs
|
||||
self._playlist_concurrency = max(1, playlist_concurrency)
|
||||
self._playlist_page_limit = playlist_page_limit
|
||||
|
||||
async def _set_session(self, session: aiohttp.ClientSession) -> None:
|
||||
self.session = session
|
||||
|
||||
async def _fetch_bearer_token(self) -> None:
|
||||
_data = {"grant_type": "client_credentials"}
|
||||
|
||||
if not self.session:
|
||||
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
||||
resp = await self.session.post(GRANT_URL, data=_data, headers=self._grant_headers)
|
||||
|
||||
if resp.status != 200:
|
||||
|
|
@ -86,6 +100,8 @@ class Client:
|
|||
|
||||
request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id)
|
||||
|
||||
if not self.session:
|
||||
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
||||
resp = await self.session.get(request_url, headers=self._bearer_headers)
|
||||
if resp.status != 200:
|
||||
raise SpotifyRequestException(
|
||||
|
|
@ -103,6 +119,8 @@ class Client:
|
|||
elif spotify_type == "album":
|
||||
return Album(data)
|
||||
elif spotify_type == "artist":
|
||||
if not self.session:
|
||||
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
||||
resp = await self.session.get(
|
||||
f"{request_url}/top-tracks?market=US",
|
||||
headers=self._bearer_headers,
|
||||
|
|
@ -116,36 +134,177 @@ class Client:
|
|||
tracks = track_data["tracks"]
|
||||
return Artist(data, tracks)
|
||||
else:
|
||||
# For playlists we optionally use a reduced fields payload to shrink response sizes.
|
||||
# NB: We cannot apply fields filter to initial request because original metadata is needed.
|
||||
tracks = [
|
||||
Track(track["track"])
|
||||
for track in data["tracks"]["items"]
|
||||
if track["track"] is not None
|
||||
]
|
||||
|
||||
if not len(tracks):
|
||||
if not tracks:
|
||||
raise SpotifyRequestException(
|
||||
"This playlist is empty and therefore cannot be queued.",
|
||||
)
|
||||
|
||||
next_page_url = data["tracks"]["next"]
|
||||
total_tracks = data["tracks"]["total"]
|
||||
limit = data["tracks"]["limit"]
|
||||
|
||||
while next_page_url is not None:
|
||||
resp = await self.session.get(next_page_url, headers=self._bearer_headers)
|
||||
# Short‑circuit small playlists (single page)
|
||||
if total_tracks <= limit:
|
||||
return Playlist(data, tracks)
|
||||
|
||||
# Build remaining page URLs; Spotify supports offset-based pagination.
|
||||
remaining_offsets = range(limit, total_tracks, limit)
|
||||
page_urls: List[str] = []
|
||||
fields_filter = (
|
||||
"items(track(name,duration_ms,id,is_local,external_urls,external_ids,artists(name),album(images)))"
|
||||
",next"
|
||||
)
|
||||
for idx, offset in enumerate(remaining_offsets):
|
||||
if self._playlist_page_limit is not None and idx >= self._playlist_page_limit:
|
||||
break
|
||||
page_urls.append(
|
||||
f"{request_url}/tracks?offset={offset}&limit={limit}&fields={quote(fields_filter)}",
|
||||
)
|
||||
|
||||
if page_urls:
|
||||
semaphore = asyncio.Semaphore(self._playlist_concurrency)
|
||||
|
||||
async def fetch_page(url: str) -> Optional[List[Track]]:
    """Fetch one playlist page from Spotify and parse its tracks.

    Parameters
    ----------
    url: str
        Fully built ``.../tracks?offset=...&limit=...`` page URL.

    Returns
    -------
    Optional[List[Track]]
        Tracks parsed from the page's ``items`` (entries whose ``track`` is
        missing or ``None`` are skipped), or ``None`` when the response
        status is not 200. The failure is logged, not raised.

    Raises
    ------
    SpotifyRequestException
        If no HTTP session has been set on the client.
    """
    async with semaphore:
        if not self.session:
            raise SpotifyRequestException(
                "HTTP session not initialized for Spotify client.",
            )
        # Using the request as a context manager releases the connection even
        # on the non-200 branch, where the body is never read.
        async with self.session.get(url, headers=self._bearer_headers) as resp:
            if resp.status != 200:
                if self._log:
                    self._log.warning(
                        f"Page fetch failed {resp.status} {resp.reason} for {url}",
                    )
                return None
            page_json: dict = await resp.json(loads=json.loads)

    return [
        Track(item["track"])
        for item in page_json.get("items", [])
        if item.get("track") is not None
    ]
|
||||
|
||||
# Chunk gather in waves to avoid creating thousands of tasks at once
|
||||
aggregated: List[Track] = []
|
||||
wave_size = self._playlist_concurrency * 2
|
||||
for i in range(0, len(page_urls), wave_size):
|
||||
wave = page_urls[i : i + wave_size]
|
||||
results = await asyncio.gather(
|
||||
*[fetch_page(url) for url in wave],
|
||||
return_exceptions=False,
|
||||
)
|
||||
for result in results:
|
||||
if result:
|
||||
aggregated.extend(result)
|
||||
|
||||
tracks.extend(aggregated)
|
||||
|
||||
return Playlist(data, tracks)
|
||||
|
||||
async def iter_playlist_tracks(
    self,
    *,
    query: str,
    batch_size: int = 100,
) -> AsyncGenerator[List[Track], None]:
    """Stream playlist tracks in batches without waiting for full materialization.

    Parameters
    ----------
    query: str
        Spotify playlist URL.
    batch_size: int
        Number of tracks yielded per batch (logical grouping after fetch).
        Does not alter API page size. Values below 1 are clamped to 1.

    Yields
    ------
    List[Track]
        Consecutive slices of at most ``batch_size`` tracks, in playlist order.

    Raises
    ------
    InvalidSpotifyURL
        If ``query`` does not match a playlist URL.
    SpotifyRequestException
        If no HTTP session is set, or the initial playlist request does not
        return HTTP 200. Failed follow-up pages are logged and skipped.
    """
    # range() rejects a zero step and yields nothing for a negative one.
    batch_size = max(1, batch_size)

    if not self._bearer_token or time.time() >= self._expiry:
        await self._fetch_bearer_token()

    match = SPOTIFY_URL_REGEX.match(query)
    if not match or match.group("type") != "playlist":
        raise InvalidSpotifyURL("Provided query is not a valid Spotify playlist URL.")

    request_url = REQUEST_URL.format(type="playlist", id=match.group("id"))
    session = self.session
    if not session:
        raise SpotifyRequestException("HTTP session not initialized for Spotify client.")

    async with session.get(request_url, headers=self._bearer_headers) as resp:
        if resp.status != 200:
            raise SpotifyRequestException(
                f"Error while fetching results: {resp.status} {resp.reason}",
            )
        data: dict = await resp.json(loads=json.loads)

    # The first page arrives with the playlist payload; yield it immediately.
    first_page_tracks = [
        Track(item["track"])
        for item in data["tracks"]["items"]
        if item.get("track") is not None
    ]
    for i in range(0, len(first_page_tracks), batch_size):
        yield first_page_tracks[i : i + batch_size]

    total = data["tracks"]["total"]
    limit = data["tracks"]["limit"]

    # One offset per remaining page. Materialised as a list so waves can be
    # taken by slicing page *indices*; offsets themselves advance by ``limit``.
    offsets = list(range(limit, total, limit)) if limit > 0 else []
    if self._playlist_page_limit is not None:
        # Honour the same follow-up page cap as the non-streaming fetch path.
        offsets = offsets[: max(0, self._playlist_page_limit)]
    if not offsets:
        return

    fields_filter = (
        "items(track(name,duration_ms,id,is_local,external_urls,external_ids,artists(name),album(images)))"
        ",next"
    )
    fields_query = quote(fields_filter)  # invariant across pages; encode once
    semaphore = asyncio.Semaphore(self._playlist_concurrency)

    async def fetch(offset: int) -> List[Track]:
        """Fetch the page at ``offset``; return ``[]`` (after logging) on a non-200."""
        url = f"{request_url}/tracks?offset={offset}&limit={limit}&fields={fields_query}"
        async with semaphore:
            async with session.get(url, headers=self._bearer_headers) as r:
                if r.status != 200:
                    if self._log:
                        self._log.warning(
                            f"Skipping page offset={offset} due to {r.status} {r.reason}",
                        )
                    return []
                pj: dict = await r.json(loads=json.loads)
        return [
            Track(item["track"])
            for item in pj.get("items", [])
            if item.get("track") is not None
        ]

    # Fetch pages in waves and yield as soon as each wave completes.
    # gather() returns results in argument order, so playlist order is kept.
    wave_size = self._playlist_concurrency * 2
    for wave_start in range(0, len(offsets), wave_size):
        wave_offsets = offsets[wave_start : wave_start + wave_size]
        results = await asyncio.gather(*(fetch(o) for o in wave_offsets))
        for page_tracks in results:
            for j in range(0, len(page_tracks), batch_size):
                yield page_tracks[j : j + batch_size]
|
||||
|
||||
async def get_recommendations(self, *, query: str) -> List[Track]:
|
||||
if not self._bearer_token or time.time() >= self._expiry:
|
||||
|
|
@ -168,6 +327,8 @@ class Client:
|
|||
id=f"?seed_tracks={spotify_id}",
|
||||
)
|
||||
|
||||
if not self.session:
|
||||
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
||||
resp = await self.session.get(request_url, headers=self._bearer_headers)
|
||||
if resp.status != 200:
|
||||
raise SpotifyRequestException(
|
||||
|
|
@ -185,6 +346,8 @@ class Client:
|
|||
|
||||
request_url = f"https://api.spotify.com/v1/search?q={quote(query)}&type=track"
|
||||
|
||||
if not self.session:
|
||||
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
||||
resp = await self.session.get(request_url, headers=self._bearer_headers)
|
||||
if resp.status != 200:
|
||||
raise SpotifyRequestException(
|
||||
|
|
|
|||
Loading…
Reference in New Issue