Compare commits

..

30 Commits
2.8.1 ... main

Author SHA1 Message Date
cloudwithax 9bffdebe25 2.10.0 2025-10-04 00:00:57 -04:00
cloudwithax 720ba187ab 2.10.0 2025-10-04 00:00:01 -04:00
cloudwithax 855bf4e0d7 2.9.2 2024-11-21 21:11:24 -05:00
cloudwithax cd579becad fixed file playing and recursion issue in queue looping 2024-11-21 21:06:32 -05:00
cloudwithax 3a1ecf9eec 2.9.1 2024-08-23 21:18:25 -04:00
clxud 5227962228
Merge pull request #69 from ZandercraftGames/fix/other-sources
Fix Support for Other Source and Playlist Types
2024-08-23 21:04:04 -04:00
Zander be7106616b
Typing fix from NiceAesth
Co-authored-by: Andrei Baciu <8437201+NiceAesth@users.noreply.github.com>
2024-08-18 00:34:43 -04:00
Zander ba9534bc27
Typing fix from NiceAesth
Co-authored-by: Andrei Baciu <8437201+NiceAesth@users.noreply.github.com>
2024-08-18 00:34:32 -04:00
pre-commit-ci[bot] 2e0f5b365a [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2024-08-17 05:44:34 +00:00
Zander M. 851f00aa97
Add support for other unsupported playlist types 2024-08-17 01:23:59 -04:00
Zander M. 817295d321
Add support for other unsupported source types 2024-08-17 00:46:32 -04:00
Zander M. 8ab3ae9ccd
Add websockets dependency to Pipenv 2024-08-17 00:44:51 -04:00
cloudwithax 094f2be181 refactor: set original track if search type is not defined in Player's play_track method 2024-06-10 21:54:21 -04:00
cloudwithax b60a6aec18 refactor: guard check for search type to prevent nulled search types getting searched 2024-06-10 21:30:57 -04:00
cloudwithax 80f7b77cd3 refactor: update search_type handling in Player and Node classes to be nullish to support lavasrc 2024-06-10 21:20:59 -04:00
cloudwithax 8679d6d125 Merge branch 'main' of https://github.com/cloudwithax/pomice 2024-06-10 21:17:57 -04:00
cloudwithax ad01407fff refactor: update query handling in Node class 2024-06-10 21:17:53 -04:00
Clxud f1609f7049
Merge pull request #67 from ZandercraftGames/fix/load-exceptions
Fix KeyError in exception handling on error when loading a track.
2024-03-27 22:10:37 -04:00
pre-commit-ci[bot] 5fcfc73901 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2024-03-28 02:08:22 +00:00
Zander M. 86b35106b2
Fix KeyError in exception handling on error when loading a track. 2024-03-27 22:04:39 -04:00
Clxud 519a14fbde
Merge pull request #66 from ZandercraftGames/main
Fix build_track failure with Lavalink v4 decodetrack format
2024-03-13 10:16:31 -04:00
Zander M. 9a42093f64
Merge remote-tracking branch 'origin/main' 2024-03-11 13:50:22 -04:00
Zander M. 347a6e0b96
Refactor sourceName to use track_info object. 2024-03-11 13:50:07 -04:00
pre-commit-ci[bot] 83d5add134 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2024-03-11 17:29:47 +00:00
Zander M. bb12e33584
Python Black Format 2024-03-11 13:24:54 -04:00
Zander M. 6817cd8e07
Fix build_track failure with Lavalink v4 decodetrack format. 2024-03-11 13:23:42 -04:00
Clxud 179472bd6e
Merge pull request #63 from NiceAesth/fix-assert
fix: remove unnecessary assert
2024-02-22 11:00:05 -05:00
NiceAesth ba761743b9 fix: remove unnecessary assert
Hit this in production (https://sunnycord.sentry.io/share/issue/e39efaaa16d64b4fbf4e3ec409406971/)
The assert is unnecessary since track is typed as optional either way.
2024-02-19 20:43:58 +02:00
cloudwithax b3795102b8 2.9.0 2024-02-06 17:32:17 -05:00
cloudwithax 2a492c793f fix issues related to loading files/http links, added spotify recommendation querying, changed loglevel enum behavior 2024-02-06 17:31:51 -05:00
11 changed files with 519 additions and 132 deletions

1
.gitignore vendored
View File

@ -10,6 +10,7 @@ build/
Pipfile.lock
.mypy_cache/
.vscode/
.idea/
.venv/
*.code-workspace
*.ini

View File

@ -14,7 +14,7 @@ repos:
rev: 23.10.1
hooks:
- id: black
language_version: python3.11
language_version: python3.13
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
hooks:
@ -28,10 +28,6 @@ repos:
rev: v3.1.0
hooks:
- id: add-trailing-comma
- repo: https://github.com/hadialqattan/pycln
rev: v2.3.0
hooks:
- id: pycln
default_language_version:
python: python3.11
python: python3.13

View File

@ -6,6 +6,7 @@ name = "pypi"
[packages]
orjson = "*"
"discord.py" = {extras = ["voice"], version = "*"}
websockets = "*"
[dev-packages]
mypy = "*"

View File

@ -3,7 +3,7 @@ Pomice
~~~~~~
The modern Lavalink wrapper designed for discord.py.
Copyright (c) 2023, cloudwithax
Copyright (c) 2024, cloudwithax
Licensed under GPL-3.0
"""
@ -20,7 +20,7 @@ if not discord.version_info.major >= 2:
"using 'pip install discord.py'",
)
__version__ = "2.8.1"
__version__ = "2.10.0"
__title__ = "pomice"
__author__ = "cloudwithax"
__license__ = "GPL-3.0"

View File

@ -1,11 +1,14 @@
from __future__ import annotations
import asyncio
import base64
import logging
import re
from datetime import datetime
from typing import AsyncGenerator
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
import aiohttp
@ -17,10 +20,10 @@ from .objects import *
__all__ = ("Client",)
AM_URL_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$",
)
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$",
)
AM_SCRIPT_REGEX = re.compile(r'<script.*?src="(/assets/index-.*?)"')
@ -35,12 +38,14 @@ class Client:
and translating it to a valid Lavalink track. No client auth is required here.
"""
def __init__(self) -> None:
def __init__(self, *, playlist_concurrency: int = 6) -> None:
self.expiry: datetime = datetime(1970, 1, 1)
self.token: str = ""
self.headers: Dict[str, str] = {}
self.session: aiohttp.ClientSession = None # type: ignore
self._log = logging.getLogger(__name__)
# Concurrency knob for parallel playlist page retrieval
self._playlist_concurrency = max(1, playlist_concurrency)
async def _set_session(self, session: aiohttp.ClientSession) -> None:
self.session = session
@ -167,25 +172,127 @@ class Client:
"This playlist is empty and therefore cannot be queued.",
)
_next = track_data.get("next")
if _next:
next_page_url = AM_BASE_URL + _next
# Apple Music uses cursor pagination with 'next'. We'll fetch subsequent pages
# concurrently by first collecting cursors in rolling waves.
next_cursor = track_data.get("next")
semaphore = asyncio.Semaphore(self._playlist_concurrency)
while next_page_url is not None:
resp = await self.session.get(next_page_url, headers=self.headers)
async def fetch_page(url: str) -> List[Song]:
async with semaphore:
resp = await self.session.get(url, headers=self.headers)
if resp.status != 200:
if self._log:
self._log.warning(
f"Apple Music page fetch failed {resp.status} {resp.reason} for {url}",
)
return []
pj: dict = await resp.json(loads=json.loads)
songs = [Song(track) for track in pj.get("data", [])]
# Return songs; we will look for pj.get('next') in streaming iterator variant
return songs, pj.get("next") # type: ignore
# We'll implement a wave-based approach similar to Spotify but need to follow cursors.
# Because we cannot know all cursors upfront, we'll iteratively fetch waves.
waves: List[List[Song]] = []
cursors: List[str] = []
if next_cursor:
cursors.append(next_cursor)
# Limit total waves to avoid infinite loops in malformed responses
max_waves = 50
wave_size = self._playlist_concurrency * 2
wave_counter = 0
while cursors and wave_counter < max_waves:
current = cursors[:wave_size]
cursors = cursors[wave_size:]
tasks = [
fetch_page(AM_BASE_URL + cursor) for cursor in current # type: ignore[arg-type]
]
results = await asyncio.gather(*tasks, return_exceptions=True)
for res in results:
if isinstance(res, tuple): # (songs, next)
songs, nxt = res
if songs:
waves.append(songs)
if nxt:
cursors.append(nxt)
wave_counter += 1
for w in waves:
album_tracks.extend(w)
return Playlist(data, album_tracks)
async def iter_playlist_tracks(
self,
*,
query: str,
batch_size: int = 100,
) -> AsyncGenerator[List[Song], None]:
"""Stream Apple Music playlist tracks in batches.
Parameters
----------
query: str
Apple Music playlist URL.
batch_size: int
Logical grouping size for yielded batches.
"""
if not self.token or datetime.utcnow() > self.expiry:
await self.request_token()
result = AM_URL_REGEX.match(query)
if not result or result.group("type") != "playlist":
raise InvalidAppleMusicURL("Provided query is not a valid Apple Music playlist URL.")
country = result.group("country")
playlist_id = result.group("id")
request_url = AM_REQ_URL.format(country=country, type="playlist", id=playlist_id)
resp = await self.session.get(request_url, headers=self.headers)
if resp.status != 200:
raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
playlist_data = data["data"][0]
track_data: dict = playlist_data["relationships"]["tracks"]
next_data: dict = await resp.json(loads=json.loads)
album_tracks.extend(Song(track) for track in next_data["data"])
first_page_tracks = [Song(track) for track in track_data["data"]]
for i in range(0, len(first_page_tracks), batch_size):
yield first_page_tracks[i : i + batch_size]
_next = next_data.get("next")
if _next:
next_page_url = AM_BASE_URL + _next
else:
next_page_url = None
next_cursor = track_data.get("next")
semaphore = asyncio.Semaphore(self._playlist_concurrency)
return Playlist(data, album_tracks)
async def fetch(cursor: str) -> tuple[List[Song], Optional[str]]:
url = AM_BASE_URL + cursor
async with semaphore:
r = await self.session.get(url, headers=self.headers)
if r.status != 200:
if self._log:
self._log.warning(
f"Skipping Apple Music page due to {r.status} {r.reason}",
)
return [], None
pj: dict = await r.json(loads=json.loads)
songs = [Song(track) for track in pj.get("data", [])]
return songs, pj.get("next")
# Rolling waves of fetches following cursor chain
max_waves = 50
wave_size = self._playlist_concurrency * 2
waves = 0
cursors: List[str] = []
if next_cursor:
cursors.append(next_cursor)
while cursors and waves < max_waves:
current = cursors[:wave_size]
cursors = cursors[wave_size:]
results = await asyncio.gather(*[fetch(c) for c in current])
for songs, nxt in results:
if songs:
for j in range(0, len(songs), batch_size):
yield songs[j : j + batch_size]
if nxt:
cursors.append(nxt)
waves += 1

View File

@ -34,6 +34,11 @@ class SearchType(Enum):
ytsearch = "ytsearch"
ytmsearch = "ytmsearch"
scsearch = "scsearch"
other = "other"
@classmethod
def _missing_(cls, value: object) -> "SearchType": # type: ignore[override]
return cls.other
def __str__(self) -> str:
return self.value
@ -54,6 +59,8 @@ class TrackType(Enum):
TrackType.HTTP defines that the track is from an HTTP source.
TrackType.LOCAL defines that the track is from a local source.
TrackType.OTHER defines that the track is from an unknown source (possible from 3rd-party plugins).
"""
# We don't have to define anything special for these, since these just serve as flags
@ -63,6 +70,11 @@ class TrackType(Enum):
APPLE_MUSIC = "apple_music"
HTTP = "http"
LOCAL = "local"
OTHER = "other"
@classmethod
def _missing_(cls, value: object) -> "TrackType": # type: ignore[override]
return cls.OTHER
def __str__(self) -> str:
return self.value
@ -79,6 +91,8 @@ class PlaylistType(Enum):
PlaylistType.SPOTIFY defines that the playlist is from Spotify
PlaylistType.APPLE_MUSIC defines that the playlist is from Apple Music.
PlaylistType.OTHER defines that the playlist is from an unknown source (possible from 3rd-party plugins).
"""
# We don't have to define anything special for these, since these just serve as flags
@ -86,6 +100,11 @@ class PlaylistType(Enum):
SOUNDCLOUD = "soundcloud"
SPOTIFY = "spotify"
APPLE_MUSIC = "apple_music"
OTHER = "other"
@classmethod
def _missing_(cls, value: object) -> "PlaylistType": # type: ignore[override]
return cls.OTHER
def __str__(self) -> str:
return self.value
@ -199,8 +218,12 @@ class URLRegex:
"""
# Spotify share links can include query parameters like ?si=XXXX, a trailing slash,
# or an intl locale segment (e.g. /intl-en/). Broaden the regex so we still capture
# the type and id while ignoring extra parameters. This prevents the URL from being
# treated as a generic Lavalink identifier and ensures internal Spotify handling runs.
SPOTIFY_URL = re.compile(
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
r"https?://open\.spotify\.com/(?:intl-[a-zA-Z-]+/)?(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)(?:/)?(?:\?.*)?$",
)
DISCORD_MP3_URL = re.compile(
@ -221,14 +244,17 @@ class URLRegex:
r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*",
)
# Apple Music links sometimes append additional query parameters (e.g. &l=en, &uo=4).
# Allow arbitrary query parameters so valid links are captured and parsed.
AM_URL = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/"
r"(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/"
r"(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$",
)
# Single-in-album links may also carry extra query params beyond the ?i=<trackid> token.
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
r"(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$",
)
SOUNDCLOUD_URL = re.compile(
@ -273,3 +299,10 @@ class LogLevel(IntEnum):
WARN = 30
ERROR = 40
CRITICAL = 50
@classmethod
def from_str(cls, level_str):
try:
return cls[level_str.upper()]
except KeyError:
raise ValueError(f"No such log level: {level_str}")

View File

@ -128,7 +128,6 @@ class TrackExceptionEvent(PomiceEvent):
def __init__(self, data: dict, player: Player):
self.player: Player = player
assert self.player._ending_track is not None
self.track: Optional[Track] = self.player._ending_track
# Error is for Lavalink <= 3.3
self.exception: str = data.get(

View File

@ -396,7 +396,7 @@ class Player(VoiceProtocol):
query: str,
*,
ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch,
search_type: SearchType | None = SearchType.ytsearch,
filters: Optional[List[Filter]] = None,
) -> Optional[Union[List[Track], Playlist]]:
"""Fetches tracks from the node's REST api to parse into Lavalink.
@ -504,8 +504,11 @@ class Player(VoiceProtocol):
) -> Track:
"""Plays a track. If a Spotify track is passed in, it will be handled accordingly."""
if not track._search_type:
track.original = track
# Make sure we've never searched the track before
if track.original is None:
if track._search_type and track.original is None:
# First lets try using the tracks ISRC, every track has one (hopefully)
try:
if not track.isrc:

View File

@ -21,7 +21,12 @@ import orjson as json
from discord import Client
from discord.ext import commands
from discord.utils import MISSING
from websockets import client
try:
from websockets.legacy import client # websockets >= 10.0
except ImportError:
import websockets.client as client # websockets < 10.0 # type: ignore
from websockets import exceptions
from websockets import typing as wstype
@ -30,6 +35,7 @@ from . import applemusic
from . import spotify
from .enums import *
from .enums import LogLevel
from .exceptions import InvalidSpotifyClientAuthorization
from .exceptions import LavalinkVersionIncompatible
from .exceptions import NodeConnectionFailure
from .exceptions import NodeCreationError
@ -302,7 +308,7 @@ class Node:
if not self._resume_key:
return
data = {"timeout": self._resume_timeout}
data: Dict[str, Union[int, str, bool]] = {"timeout": self._resume_timeout}
if self._version.major == 3:
data["resumingKey"] = self._resume_key
@ -443,7 +449,17 @@ class Node:
start = time.perf_counter()
if not self._session:
self._session = aiohttp.ClientSession()
# Configure connection pooling for optimal concurrent request performance
connector = aiohttp.TCPConnector(
limit=100, # Total connection limit
limit_per_host=30, # Per-host connection limit
ttl_dns_cache=300, # DNS cache TTL in seconds
)
timeout = aiohttp.ClientTimeout(total=30, connect=10)
self._session = aiohttp.ClientSession(
connector=connector,
timeout=timeout,
)
try:
if not reconnect:
@ -462,7 +478,7 @@ class Node:
f"Version check from Node {self._identifier} successful. Returned version {version}",
)
self._websocket = await client.connect(
self._websocket = await client.connect( # type: ignore
f"{self._websocket_uri}/v{self._version.major}/websocket",
extra_headers=self._headers,
ping_interval=self._heartbeat,
@ -544,11 +560,14 @@ class Node:
path="decodetrack",
query=f"encodedTrack={quote(identifier)}",
)
track_info = data["info"] if self._version.major >= 4 else data
return Track(
track_id=identifier,
ctx=ctx,
info=data,
track_type=TrackType(data["sourceName"]),
info=track_info,
track_type=TrackType(track_info["sourceName"]),
)
async def get_tracks(
@ -556,7 +575,7 @@ class Node:
query: str,
*,
ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch,
search_type: Optional[SearchType] = SearchType.ytsearch,
filters: Optional[List[Filter]] = None,
) -> Optional[Union[Playlist, List[Track]]]:
"""Fetches tracks from the node's REST api to parse into Lavalink.
@ -591,7 +610,7 @@ class Node:
track_id=apple_music_results.id,
ctx=ctx,
track_type=TrackType.APPLE_MUSIC,
search_type=search_type,
search_type=search_type or SearchType.ytsearch,
filters=filters,
info={
"title": apple_music_results.name,
@ -613,7 +632,7 @@ class Node:
track_id=track.id,
ctx=ctx,
track_type=TrackType.APPLE_MUSIC,
search_type=search_type,
search_type=search_type or SearchType.ytsearch,
filters=filters,
info={
"title": track.name,
@ -651,7 +670,7 @@ class Node:
track_id=spotify_results.id,
ctx=ctx,
track_type=TrackType.SPOTIFY,
search_type=search_type,
search_type=search_type or SearchType.ytsearch,
filters=filters,
info={
"title": spotify_results.name,
@ -673,7 +692,7 @@ class Node:
track_id=track.id,
ctx=ctx,
track_type=TrackType.SPOTIFY,
search_type=search_type,
search_type=search_type or SearchType.ytsearch,
filters=filters,
info={
"title": track.name,
@ -702,63 +721,14 @@ class Node:
uri=spotify_results.uri,
)
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
data: dict = await self.send(
method="GET",
path="loadtracks",
query=f"identifier={quote(query)}",
)
track: dict = data["tracks"][0]
info: dict = track["info"]
return [
Track(
track_id=track["track"],
info={
"title": discord_url.group("file"),
"author": "Unknown",
"length": info["length"],
"uri": info["uri"],
"position": info["position"],
"identifier": info["identifier"],
},
ctx=ctx,
track_type=TrackType.HTTP,
filters=filters,
),
]
elif path.exists(path.dirname(query)):
local_file = Path(query)
data: dict = await self.send( # type: ignore
method="GET",
path="loadtracks",
query=f"identifier={quote(query)}",
)
track: dict = data["tracks"][0] # type: ignore
info: dict = track["info"] # type: ignore
return [
Track(
track_id=track["track"],
info={
"title": local_file.name,
"author": "Unknown",
"length": info["length"],
"uri": quote(local_file.as_uri()),
"position": info["position"],
"identifier": info["identifier"],
},
ctx=ctx,
track_type=TrackType.LOCAL,
filters=filters,
),
]
else:
if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query):
if (
search_type
and not URLRegex.BASE_URL.match(query)
and not re.match(r"(?:[a-z]+?)search:.", query)
and not URLRegex.DISCORD_MP3_URL.match(query)
and not path.exists(path.dirname(query))
):
query = f"{search_type}:{query}"
# If YouTube url contains a timestamp, capture it for use later.
@ -784,7 +754,7 @@ class Node:
)
elif load_type in ("LOAD_FAILED", "error"):
exception = data["exception"]
exception = data["data"] if self._version.major >= 4 else data["exception"]
raise TrackLoadError(
f"{exception['message']} [{exception['severity']}]",
)
@ -819,6 +789,47 @@ class Node:
elif load_type in ("SEARCH_RESULT", "TRACK_LOADED", "track", "search"):
if self._version.major >= 4 and isinstance(data[data_type], dict):
data[data_type] = [data[data_type]]
if path.exists(path.dirname(query)):
local_file = Path(query)
return [
Track(
track_id=track["encoded"],
info={
"title": local_file.name,
"author": "Unknown",
"length": track["info"]["length"],
"uri": quote(local_file.as_uri()),
"position": track["info"]["position"],
"identifier": track["info"]["identifier"],
},
ctx=ctx,
track_type=TrackType.LOCAL,
filters=filters,
)
for track in data[data_type]
]
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
return [
Track(
track_id=track["encoded"],
info={
"title": discord_url.group("file"),
"author": "Unknown",
"length": track["info"]["length"],
"uri": track["info"]["uri"],
"position": track["info"]["position"],
"identifier": track["info"]["identifier"],
},
ctx=ctx,
track_type=TrackType.HTTP,
filters=filters,
)
for track in data[data_type]
]
return [
Track(
track_id=track["encoded"],
@ -885,6 +896,57 @@ class Node:
"The specfied track must be either a YouTube or Spotify track to recieve recommendations.",
)
async def search_spotify_recommendations(
self,
query: str,
*,
ctx: Optional[commands.Context] = None,
filters: Optional[List[Filter]] = None,
) -> Optional[Union[List[Track], Playlist]]:
"""
Searches for recommendations on Spotify and returns a list of tracks based on the query.
You must have Spotify enabled for this to work.
You can pass in a discord.py Context object to get a
Context object on all tracks that get recommended.
"""
if not self._spotify_client:
raise InvalidSpotifyClientAuthorization(
"You must have Spotify enabled to use this feature.",
)
results = await self._spotify_client.track_search(query=query) # type: ignore
if not results:
raise TrackLoadError(
"Unable to find any tracks based on the query.",
)
tracks = [
Track(
track_id=track.id,
ctx=ctx,
track_type=TrackType.SPOTIFY,
info={
"title": track.name,
"author": track.artists,
"length": track.length,
"identifier": track.id,
"uri": track.uri,
"isStream": False,
"isSeekable": True,
"position": 0,
"thumbnail": track.image,
"isrc": track.isrc,
},
requester=self.bot.user,
)
for track in results
]
track = tracks[0]
return await self.get_recommendations(track=track, ctx=ctx)
class NodePool:
"""The base class for the node pool.

View File

@ -203,9 +203,13 @@ class Queue(Iterable[Track]):
raise QueueEmpty("No items in the queue.")
if self._loop_mode == LoopMode.QUEUE:
# recurse if the item isnt in the queue
if self._current_item not in self._queue:
self.get()
# set current item to first track in queue if not set already
# otherwise exception will be raised
if not self._current_item or self._current_item not in self._queue:
if self._queue:
item = self._queue[0]
else:
raise QueueEmpty("No items in the queue.")
# set current item to first track in queue if not set already
if not self._current_item:

View File

@ -1,13 +1,16 @@
from __future__ import annotations
import asyncio
import logging
import re
import time
from base64 import b64encode
from typing import AsyncGenerator
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from urllib.parse import quote
import aiohttp
import orjson as json
@ -21,8 +24,10 @@ __all__ = ("Client",)
GRANT_URL = "https://accounts.spotify.com/api/token"
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
# Keep this in sync with URLRegex.SPOTIFY_URL (enums.py). Accept intl locale segment,
# optional trailing slash, and query parameters.
SPOTIFY_URL_REGEX = re.compile(
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
r"https?://open\.spotify\.com/(?:intl-[a-zA-Z-]+/)?(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)(?:/)?(?:\?.*)?$",
)
@ -32,29 +37,39 @@ class Client:
for any Spotify URL you throw at it.
"""
def __init__(self, client_id: str, client_secret: str) -> None:
self._client_id: str = client_id
self._client_secret: str = client_secret
def __init__(
self,
client_id: str,
client_secret: str,
*,
playlist_concurrency: int = 10,
playlist_page_limit: Optional[int] = None,
) -> None:
self._client_id = client_id
self._client_secret = client_secret
self.session: aiohttp.ClientSession = None # type: ignore
# HTTP session will be injected by Node
self.session: Optional[aiohttp.ClientSession] = None
self._bearer_token: Optional[str] = None
self._expiry: float = 0.0
self._auth_token = b64encode(
f"{self._client_id}:{self._client_secret}".encode(),
)
self._grant_headers = {
"Authorization": f"Basic {self._auth_token.decode()}",
}
self._auth_token = b64encode(f"{self._client_id}:{self._client_secret}".encode())
self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"}
self._bearer_headers: Optional[Dict] = None
self._log = logging.getLogger(__name__)
# Performance tuning knobs
self._playlist_concurrency = max(1, playlist_concurrency)
self._playlist_page_limit = playlist_page_limit
async def _set_session(self, session: aiohttp.ClientSession) -> None:
self.session = session
async def _fetch_bearer_token(self) -> None:
_data = {"grant_type": "client_credentials"}
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.post(GRANT_URL, data=_data, headers=self._grant_headers)
if resp.status != 200:
@ -85,6 +100,8 @@ class Client:
request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id)
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers)
if resp.status != 200:
raise SpotifyRequestException(
@ -102,6 +119,8 @@ class Client:
elif spotify_type == "album":
return Album(data)
elif spotify_type == "artist":
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(
f"{request_url}/top-tracks?market=US",
headers=self._bearer_headers,
@ -115,36 +134,177 @@ class Client:
tracks = track_data["tracks"]
return Artist(data, tracks)
else:
# For playlists we optionally use a reduced fields payload to shrink response sizes.
# NB: We cannot apply fields filter to initial request because original metadata is needed.
tracks = [
Track(track["track"])
for track in data["tracks"]["items"]
if track["track"] is not None
]
if not len(tracks):
if not tracks:
raise SpotifyRequestException(
"This playlist is empty and therefore cannot be queued.",
)
next_page_url = data["tracks"]["next"]
total_tracks = data["tracks"]["total"]
limit = data["tracks"]["limit"]
while next_page_url is not None:
resp = await self.session.get(next_page_url, headers=self._bearer_headers)
# Shortcircuit small playlists (single page)
if total_tracks <= limit:
return Playlist(data, tracks)
# Build remaining page URLs; Spotify supports offset-based pagination.
remaining_offsets = range(limit, total_tracks, limit)
page_urls: List[str] = []
fields_filter = (
"items(track(name,duration_ms,id,is_local,external_urls,external_ids,artists(name),album(images)))"
",next"
)
for idx, offset in enumerate(remaining_offsets):
if self._playlist_page_limit is not None and idx >= self._playlist_page_limit:
break
page_urls.append(
f"{request_url}/tracks?offset={offset}&limit={limit}&fields={quote(fields_filter)}",
)
if page_urls:
semaphore = asyncio.Semaphore(self._playlist_concurrency)
async def fetch_page(url: str) -> Optional[List[Track]]:
async with semaphore:
if not self.session:
raise SpotifyRequestException(
"HTTP session not initialized for Spotify client.",
)
resp = await self.session.get(url, headers=self._bearer_headers)
if resp.status != 200:
if self._log:
self._log.warning(
f"Page fetch failed {resp.status} {resp.reason} for {url}",
)
return None
page_json: dict = await resp.json(loads=json.loads)
return [
Track(item["track"])
for item in page_json.get("items", [])
if item.get("track") is not None
]
# Chunk gather in waves to avoid creating thousands of tasks at once
aggregated: List[Track] = []
wave_size = self._playlist_concurrency * 2
for i in range(0, len(page_urls), wave_size):
wave = page_urls[i : i + wave_size]
results = await asyncio.gather(
*[fetch_page(url) for url in wave],
return_exceptions=False,
)
for result in results:
if result:
aggregated.extend(result)
tracks.extend(aggregated)
return Playlist(data, tracks)
async def iter_playlist_tracks(
self,
*,
query: str,
batch_size: int = 100,
) -> AsyncGenerator[List[Track], None]:
"""Stream playlist tracks in batches without waiting for full materialization.
Parameters
----------
query: str
Spotify playlist URL.
batch_size: int
Number of tracks yielded per batch (logical grouping after fetch). Does not alter API page size.
"""
if not self._bearer_token or time.time() >= self._expiry:
await self._fetch_bearer_token()
match = SPOTIFY_URL_REGEX.match(query)
if not match or match.group("type") != "playlist":
raise InvalidSpotifyURL("Provided query is not a valid Spotify playlist URL.")
playlist_id = match.group("id")
request_url = REQUEST_URL.format(type="playlist", id=playlist_id)
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers)
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
next_data: dict = await resp.json(loads=json.loads)
tracks += [
Track(track["track"])
for track in next_data["items"]
if track["track"] is not None
# Yield first page immediately
first_page_tracks = [
Track(item["track"])
for item in data["tracks"]["items"]
if item.get("track") is not None
]
next_page_url = next_data["next"]
# Batch yield
for i in range(0, len(first_page_tracks), batch_size):
yield first_page_tracks[i : i + batch_size]
return Playlist(data, tracks)
total = data["tracks"]["total"]
limit = data["tracks"]["limit"]
remaining_offsets = range(limit, total, limit)
fields_filter = (
"items(track(name,duration_ms,id,is_local,external_urls,external_ids,artists(name),album(images)))"
",next"
)
semaphore = asyncio.Semaphore(self._playlist_concurrency)
async def fetch(offset: int) -> List[Track]:
url = (
f"{request_url}/tracks?offset={offset}&limit={limit}&fields={quote(fields_filter)}"
)
async with semaphore:
if not self.session:
raise SpotifyRequestException(
"HTTP session not initialized for Spotify client.",
)
r = await self.session.get(url, headers=self._bearer_headers)
if r.status != 200:
if self._log:
self._log.warning(
f"Skipping page offset={offset} due to {r.status} {r.reason}",
)
return []
pj: dict = await r.json(loads=json.loads)
return [
Track(item["track"])
for item in pj.get("items", [])
if item.get("track") is not None
]
# Fetch pages in rolling waves; yield promptly as soon as a wave completes.
wave_size = self._playlist_concurrency * 2
for i, offset in enumerate(remaining_offsets):
# Build wave
if i % wave_size == 0:
wave_offsets = list(
o for o in remaining_offsets if o >= offset and o < offset + wave_size
)
results = await asyncio.gather(*[fetch(o) for o in wave_offsets])
for page_tracks in results:
if not page_tracks:
continue
for j in range(0, len(page_tracks), batch_size):
yield page_tracks[j : j + batch_size]
# Skip ahead in iterator by adjusting enumerate drive (consume extras)
# Fast-forward the generator manually
for _ in range(len(wave_offsets) - 1):
try:
next(remaining_offsets) # type: ignore
except StopIteration:
break
async def get_recommendations(self, *, query: str) -> List[Track]:
if not self._bearer_token or time.time() >= self._expiry:
@ -167,6 +327,8 @@ class Client:
id=f"?seed_tracks={spotify_id}",
)
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers)
if resp.status != 200:
raise SpotifyRequestException(
@ -177,3 +339,22 @@ class Client:
tracks = [Track(track) for track in data["tracks"]]
return tracks
async def track_search(self, *, query: str) -> List[Track]:
if not self._bearer_token or time.time() >= self._expiry:
await self._fetch_bearer_token()
request_url = f"https://api.spotify.com/v1/search?q={quote(query)}&type=track"
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers)
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
tracks = [Track(track) for track in data["tracks"]["items"]]
return tracks