Compare commits

..

No commits in common. "main" and "2.3" have entirely different histories.
main ... 2.3

27 changed files with 498 additions and 1224 deletions

39
.github/workflows/python-publish.yml vendored Normal file
View File

@ -0,0 +1,39 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
name: Upload Python Package
on:
release:
types: [published]
permissions:
contents: read
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

4
.gitignore vendored
View File

@ -10,8 +10,4 @@ build/
Pipfile.lock
.mypy_cache/
.vscode/
.idea/
.venv/
*.code-workspace
*.ini
.pypirc

View File

@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.4.0
hooks:
- id: check-ast
- id: check-builtin-literals
@ -11,23 +11,31 @@ repos:
- id: requirements-txt-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.10.1
rev: 23.1.0
hooks:
- id: black
language_version: python3.13
language_version: python3.8
- repo: https://github.com/asottile/blacken-docs
rev: 1.13.0
hooks:
- id: blacken-docs
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
rev: v3.3.1
hooks:
- id: pyupgrade
args: [--py37-plus, --keep-runtime-typing]
- repo: https://github.com/asottile/reorder-python-imports
rev: v3.12.0
- repo: https://github.com/asottile/reorder_python_imports
rev: v3.9.0
hooks:
- id: reorder-python-imports
- repo: https://github.com/asottile/add-trailing-comma
rev: v3.1.0
rev: v2.4.0
hooks:
- id: add-trailing-comma
- repo: https://github.com/hadialqattan/pycln
rev: v2.1.3
hooks:
- id: pycln
default_language_version:
python: python3.13
python: python3.8

1
.python-version Normal file
View File

@ -0,0 +1 @@
3.10.9

View File

@ -6,7 +6,6 @@ name = "pypi"
[packages]
orjson = "*"
"discord.py" = {extras = ["voice"], version = "*"}
websockets = "*"
[dev-packages]
mypy = "*"
@ -14,8 +13,6 @@ pre-commit = "*"
furo = "*"
sphinx = "*"
myst-parser = "*"
black = "*"
typing-extensions = "*"
[requires]
python_version = "3.8"

View File

@ -3,16 +3,11 @@
![](https://raw.githubusercontent.com/cloudwithax/pomice/main/banner.jpg)
[![GPL](https://img.shields.io/github/license/cloudwithax/pomice?color=2f2f2f)](https://github.com/cloudwithax/pomice/blob/main/LICENSE) ![](https://img.shields.io/pypi/pyversions/pomice?color=2f2f2f) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Discord](https://img.shields.io/discord/899324069235810315?color=%237289DA&label=Pomice%20Support&logo=discord&logoColor=white)](https://discord.gg/r64qjTSHG8) [![Read the Docs](https://readthedocs.org/projects/pomice/badge/?version=latest)](https://pomice.readthedocs.io/en/latest/)
[![GPL](https://img.shields.io/badge/license-GPL-2f2f2f)](https://github.com/cloudwithax/pomice/blob/main/LICENSE) ![](https://img.shields.io/badge/python-3.8-2f2f2f)
[![Discord](https://img.shields.io/discord/899324069235810315)](https://discord.gg/r64qjTSHG8) [![Read the Docs](https://readthedocs.org/projects/pomice/badge/?version=latest)](https://pomice.readthedocs.io/en/latest/)
Pomice is a fully asynchronous Python library designed for communicating with [Lavalink](https://github.com/freyacodes/Lavalink) seamlessly within the [discord.py](https://github.com/Rapptz/discord.py) library. It features 100% coverage of the [Lavalink](https://github.com/freyacodes/Lavalink) spec that can be accessed with easy-to-understand functions along with Spotify and Apple Music querying capabilities using built-in custom clients, making it easier to develop your next big music bot.
## Quick Links
- [Discord Server](https://discord.gg/r64qjTSHG8)
- [Read the Docs](https://pomice.readthedocs.io/en/latest/)
- [PyPI Homepage](https://pypi.org/project/pomice/)
Pomice is a fully asynchronous Python library designed for communicating with [Lavalink](https://github.com/freyacodes/Lavalink) seamlessly within the [discord.py](https://github.com/Rapptz/discord.py) library. It features 100% API coverage of the entire [Lavalink](https://github.com/freyacodes/Lavalink) spec that can be accessed with easy-to-understand functions. We also include Spotify and Apple Music querying capabilites using built-in custom clients, making it easier to develop your next big music bot.
# Install
@ -28,7 +23,7 @@ pip install pomice
pip install git+https://github.com/cloudwithax/pomice
```
# Support And Documentation
# Support
The official documentation is [here](https://pomice.readthedocs.io/en/latest/)
@ -36,7 +31,7 @@ You can join our support server [here](https://discord.gg/r64qjTSHG8)
# Examples
In-depth examples are located in the [examples folder](https://github.com/cloudwithax/pomice/tree/main/examples)
In-depth examples are located in the examples folder
Here's a quick example:

View File

@ -1,4 +1,3 @@
# type: ignore
import importlib
import inspect
import os

View File

@ -13,9 +13,11 @@ The classes listed here are as they appear in Pomice. When you use them within y
the way you use them will be different. Here's an example on how you would use the `TrackStartEvent` within an event listener in a cog:
```py
@commands.Cog.listener
async def on_pomice_track_start(self, player: Player, track: Track):
...
```
## Event definitions
@ -32,7 +34,7 @@ your application. Here are all the definitions:
All events related to tracks carry a `Player` object so you can access player-specific functions
and properties for further evaluation. They also carry a `Track` object so you can access track-specific functions and properties for further evaluation as well.
and properties for further evaluation. They also carry a `Track` object so you can access track-specific functions and properites for further evaluation as well.
`Event.TrackEndEvent()` carries the reason for the track ending. If the track ends suddenly, you can use the reason provided to determine a solution.

View File

@ -361,27 +361,6 @@ await Player.stop()
```
### Moving the player to another channel
To move the player to another channel, we need to use `Player.move_to()`
```py
await Player.move_to(...)
```
After you have initialized your function, we need to include the `channel` parameter, which is a `VoiceChannel`:
```py
await Player.move_to(channel)
```
After running this function, your player should be in the new voice channel. All voice state updates should also be handled.
## Controlling filters
Pomice has an extensive suite of filter management tools to help you make the most of Lavalink and it's filters.

View File

@ -66,10 +66,9 @@ After you have initialized your function, we need to fill in the proper paramete
- Set this value to `True` if you want Pomice to automatically switch all players to another available node if one disconnects.
You must have two or more nodes to be able to do this.
* - `logger`
- `Optional[logging.Logger]`
- If you would like to receive logging information from Pomice, set this to your logger class
* - `log_level`
- `LogLevel`
- The logging level for the node. The default logging level is `LogLevel.INFO`.
:::
@ -90,13 +89,13 @@ await NodePool.create_node(
spotify_client_secret="<your spotify client secret here>"
apple_music=<True/False>,
fallback=<True/False>,
logger=<your logger here>
log_level=<optiona LogLevel here>
)
```
:::{important}
For features like Spotify and Apple Music, you are **not required** to fill in anything for them if you do not want to use them. If you do end up queuing a Spotify or Apple Music track, it is **up to you** on how you decide to handle it, whether it be through your own methods or a Lavalink plugin.
For features like Spotify and Apple Music, you are **not required** to fill in anything for them if you do not want to use them. If you do end up queuing a Spotify or Apple Music track anyway, they will **not work** because these options are not enabled.
:::

View File

@ -10,17 +10,12 @@ import re
from discord.ext import commands
URL_REG = re.compile(r"https?://(?:www\.)?.+")
URL_REG = re.compile(r'https?://(?:www\.)?.+')
class MyBot(commands.Bot):
def __init__(self) -> None:
super().__init__(
command_prefix="!",
activity=discord.Activity(
type=discord.ActivityType.listening, name="to music!"
),
)
super().__init__(command_prefix='!', activity=discord.Activity(type=discord.ActivityType.listening, name='to music!'))
self.add_cog(Music(self))
@ -30,47 +25,44 @@ class MyBot(commands.Bot):
class Music(commands.Cog):
def __init__(self, bot) -> None:
self.bot = bot
self.pomice = pomice.NodePool()
async def start_nodes(self):
await self.pomice.create_node(
bot=self.bot,
host="127.0.0.1",
port="3030",
password="youshallnotpass",
identifier="MAIN",
)
await self.pomice.create_node(bot=self.bot, host='127.0.0.1', port='3030',
password='youshallnotpass', identifier='MAIN')
print(f"Node is ready!")
@commands.command(name="join", aliases=["connect"])
async def join(
self, ctx: commands.Context, *, channel: discord.TextChannel = None
) -> None:
@commands.command(name='join', aliases=['connect'])
async def join(self, ctx: commands.Context, *, channel: discord.TextChannel = None) -> None:
if not channel:
channel = getattr(ctx.author.voice, "channel", None)
channel = getattr(ctx.author.voice, 'channel', None)
if not channel:
raise commands.CheckFailure(
"You must be in a voice channel to use this command"
"without specifying the channel argument."
)
raise commands.CheckFailure('You must be in a voice channel to use this command'
'without specifying the channel argument.')
await ctx.author.voice.channel.connect(cls=pomice.Player)
await ctx.send(f"Joined the voice channel `{channel}`")
await ctx.send(f'Joined the voice channel `{channel}`')
@commands.command(name="play")
@commands.command(name='play')
async def play(self, ctx, *, search: str) -> None:
if not ctx.voice_client:
await ctx.invoke(self.join)
player = ctx.voice_client
results = await player.get_tracks(query=f"{search}")
results = await player.get_tracks(query=f'{search}')
if not results:
raise commands.CommandError("No results were found for that search term.")
raise commands.CommandError('No results were found for that search term.')
if isinstance(results, pomice.Playlist):
await player.play(track=results.tracks[0])

View File

@ -3,4 +3,3 @@ discord.py[voice]
furo
myst_parser
orjson
websockets

View File

@ -1,4 +1,3 @@
# type: ignore
"""
This example aims to show the full capabilities of the library.
This is in the form of a drop-in cog you can use and modify to your liking.
@ -101,13 +100,13 @@ class Music(commands.Cog):
await self.pomice.create_node(
bot=self.bot,
host="127.0.0.1",
port=3030,
port="3030",
password="youshallnotpass",
identifier="MAIN",
)
print(f"Node is ready!")
def required(self, ctx: commands.Context):
async def required(self, ctx: commands.Context):
"""Method which returns required votes based on amount of members in a channel."""
player: Player = ctx.voice_client
channel = self.bot.get_channel(int(player.channel.id))
@ -119,7 +118,7 @@ class Music(commands.Cog):
return required
def is_privileged(self, ctx: commands.Context):
async def is_privileged(self, ctx: commands.Context):
"""Check whether the user is an Admin or DJ."""
player: Player = ctx.voice_client

View File

@ -1,4 +1,3 @@
# type: ignore
import discord
from discord.ext import commands
@ -36,7 +35,7 @@ class Music(commands.Cog):
await self.pomice.create_node(
bot=self.bot,
host="127.0.0.1",
port=3030,
port="3030",
password="youshallnotpass",
identifier="MAIN",
)

View File

@ -3,7 +3,7 @@ Pomice
~~~~~~
The modern Lavalink wrapper designed for discord.py.
Copyright (c) 2024, cloudwithax
Copyright (c) 2023, cloudwithax
Licensed under GPL-3.0
"""
@ -20,7 +20,7 @@ if not discord.version_info.major >= 2:
"using 'pip install discord.py'",
)
__version__ = "2.10.0"
__version__ = "2.3"
__title__ = "pomice"
__author__ = "cloudwithax"
__license__ = "GPL-3.0"

View File

@ -1,4 +1,4 @@
"""Apple Music module for Pomice, made possible by cloudwithax 2023"""
from .client import *
from .client import Client
from .exceptions import *
from .objects import *

View File

@ -1,14 +1,10 @@
from __future__ import annotations
import asyncio
import base64
import logging
import re
from datetime import datetime
from typing import AsyncGenerator
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
import aiohttp
@ -20,14 +16,11 @@ from .objects import *
__all__ = ("Client",)
AM_URL_REGEX = re.compile(
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$",
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
)
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$",
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
)
AM_SCRIPT_REGEX = re.compile(r'<script.*?src="(/assets/index-.*?)"')
AM_REQ_URL = "https://api.music.apple.com/v1/catalog/{country}/{type}s/{id}"
AM_BASE_URL = "https://api.music.apple.com"
@ -38,50 +31,21 @@ class Client:
and translating it to a valid Lavalink track. No client auth is required here.
"""
def __init__(self, *, playlist_concurrency: int = 6) -> None:
def __init__(self) -> None:
self.expiry: datetime = datetime(1970, 1, 1)
self.token: str = ""
self.headers: Dict[str, str] = {}
self.session: aiohttp.ClientSession = None # type: ignore
self._log = logging.getLogger(__name__)
# Concurrency knob for parallel playlist page retrieval
self._playlist_concurrency = max(1, playlist_concurrency)
async def _set_session(self, session: aiohttp.ClientSession) -> None:
self.session = session
async def request_token(self) -> None:
# First lets get the raw response from the main page
resp = await self.session.get("https://music.apple.com")
if not self.session:
self.session = aiohttp.ClientSession()
async with self.session.get("https://music.apple.com/assets/index.919fe17f.js") as resp:
if resp.status != 200:
raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
# Looking for script tag that fits criteria
text = await resp.text()
match = re.search(AM_SCRIPT_REGEX, text)
if not match:
raise AppleMusicRequestException(
"Could not find valid script URL in response.",
)
# Found the script file, lets grab our token
result = match.group(1)
asset_url = result
resp = await self.session.get("https://music.apple.com" + asset_url)
if resp.status != 200:
raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
text = await resp.text()
match = re.search('"(eyJ.+?)"', text)
if not match:
@ -101,8 +65,6 @@ class Client:
).decode()
token_data = json.loads(token_json)
self.expiry = datetime.fromtimestamp(token_data["exp"])
if self._log:
self._log.debug(f"Fetched Apple Music bearer token successfully")
async def search(self, query: str) -> Union[Album, Playlist, Song, Artist]:
if not self.token or datetime.utcnow() > self.expiry:
@ -128,42 +90,35 @@ class Client:
else:
request_url = AM_REQ_URL.format(country=country, type=type, id=id)
resp = await self.session.get(request_url, headers=self.headers)
async with self.session.get(request_url, headers=self.headers) as resp:
if resp.status != 200:
raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
if self._log:
self._log.debug(
f"Made request to Apple Music API with status {resp.status} and response {data}",
)
data = data["data"][0]
if type == "song":
return Song(data)
elif type == "album":
if type == "album":
return Album(data)
elif type == "artist":
resp = await self.session.get(
if type == "artist":
async with self.session.get(
f"{request_url}/view/top-songs",
headers=self.headers,
)
) as resp:
if resp.status != 200:
raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
top_tracks: dict = await resp.json(loads=json.loads)
artist_tracks: dict = top_tracks["data"]
return Artist(data, tracks=artist_tracks)
else:
track_data: dict = data["relationships"]["tracks"]
album_tracks: List[Song] = [Song(track) for track in track_data["data"]]
@ -172,127 +127,30 @@ class Client:
"This playlist is empty and therefore cannot be queued.",
)
# Apple Music uses cursor pagination with 'next'. We'll fetch subsequent pages
# concurrently by first collecting cursors in rolling waves.
next_cursor = track_data.get("next")
semaphore = asyncio.Semaphore(self._playlist_concurrency)
_next = track_data.get("next")
if _next:
next_page_url = AM_BASE_URL + _next
async def fetch_page(url: str) -> List[Song]:
async with semaphore:
resp = await self.session.get(url, headers=self.headers)
if resp.status != 200:
if self._log:
self._log.warning(
f"Apple Music page fetch failed {resp.status} {resp.reason} for {url}",
)
return []
pj: dict = await resp.json(loads=json.loads)
songs = [Song(track) for track in pj.get("data", [])]
# Return songs; we will look for pj.get('next') in streaming iterator variant
return songs, pj.get("next") # type: ignore
# We'll implement a wave-based approach similar to Spotify but need to follow cursors.
# Because we cannot know all cursors upfront, we'll iteratively fetch waves.
waves: List[List[Song]] = []
cursors: List[str] = []
if next_cursor:
cursors.append(next_cursor)
# Limit total waves to avoid infinite loops in malformed responses
max_waves = 50
wave_size = self._playlist_concurrency * 2
wave_counter = 0
while cursors and wave_counter < max_waves:
current = cursors[:wave_size]
cursors = cursors[wave_size:]
tasks = [
fetch_page(AM_BASE_URL + cursor) for cursor in current # type: ignore[arg-type]
]
results = await asyncio.gather(*tasks, return_exceptions=True)
for res in results:
if isinstance(res, tuple): # (songs, next)
songs, nxt = res
if songs:
waves.append(songs)
if nxt:
cursors.append(nxt)
wave_counter += 1
for w in waves:
album_tracks.extend(w)
return Playlist(data, album_tracks)
async def iter_playlist_tracks(
self,
*,
query: str,
batch_size: int = 100,
) -> AsyncGenerator[List[Song], None]:
"""Stream Apple Music playlist tracks in batches.
Parameters
----------
query: str
Apple Music playlist URL.
batch_size: int
Logical grouping size for yielded batches.
"""
if not self.token or datetime.utcnow() > self.expiry:
await self.request_token()
result = AM_URL_REGEX.match(query)
if not result or result.group("type") != "playlist":
raise InvalidAppleMusicURL("Provided query is not a valid Apple Music playlist URL.")
country = result.group("country")
playlist_id = result.group("id")
request_url = AM_REQ_URL.format(country=country, type="playlist", id=playlist_id)
resp = await self.session.get(request_url, headers=self.headers)
while next_page_url is not None:
async with self.session.get(next_page_url, headers=self.headers) as resp:
if resp.status != 200:
raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
playlist_data = data["data"][0]
track_data: dict = playlist_data["relationships"]["tracks"]
first_page_tracks = [Song(track) for track in track_data["data"]]
for i in range(0, len(first_page_tracks), batch_size):
yield first_page_tracks[i : i + batch_size]
next_data: dict = await resp.json(loads=json.loads)
next_cursor = track_data.get("next")
semaphore = asyncio.Semaphore(self._playlist_concurrency)
album_tracks.extend(Song(track) for track in next_data["data"])
async def fetch(cursor: str) -> tuple[List[Song], Optional[str]]:
url = AM_BASE_URL + cursor
async with semaphore:
r = await self.session.get(url, headers=self.headers)
if r.status != 200:
if self._log:
self._log.warning(
f"Skipping Apple Music page due to {r.status} {r.reason}",
)
return [], None
pj: dict = await r.json(loads=json.loads)
songs = [Song(track) for track in pj.get("data", [])]
return songs, pj.get("next")
_next = next_data.get("next")
if _next:
next_page_url = AM_BASE_URL + _next
else:
next_page_url = None
# Rolling waves of fetches following cursor chain
max_waves = 50
wave_size = self._playlist_concurrency * 2
waves = 0
cursors: List[str] = []
if next_cursor:
cursors.append(next_cursor)
while cursors and waves < max_waves:
current = cursors[:wave_size]
cursors = cursors[wave_size:]
results = await asyncio.gather(*[fetch(c) for c in current])
for songs, nxt in results:
if songs:
for j in range(0, len(songs), batch_size):
yield songs[j : j + batch_size]
if nxt:
cursors.append(nxt)
waves += 1
return Playlist(data, album_tracks)
async def close(self) -> None:
if self.session:
await self.session.close()
self.session = None # type: ignore

View File

@ -34,11 +34,6 @@ class SearchType(Enum):
ytsearch = "ytsearch"
ytmsearch = "ytmsearch"
scsearch = "scsearch"
other = "other"
@classmethod
def _missing_(cls, value: object) -> "SearchType": # type: ignore[override]
return cls.other
def __str__(self) -> str:
return self.value
@ -57,10 +52,6 @@ class TrackType(Enum):
TrackType.APPLE_MUSIC defines that the track is from Apple Music.
TrackType.HTTP defines that the track is from an HTTP source.
TrackType.LOCAL defines that the track is from a local source.
TrackType.OTHER defines that the track is from an unknown source (possible from 3rd-party plugins).
"""
# We don't have to define anything special for these, since these just serve as flags
@ -69,12 +60,6 @@ class TrackType(Enum):
SPOTIFY = "spotify"
APPLE_MUSIC = "apple_music"
HTTP = "http"
LOCAL = "local"
OTHER = "other"
@classmethod
def _missing_(cls, value: object) -> "TrackType": # type: ignore[override]
return cls.OTHER
def __str__(self) -> str:
return self.value
@ -91,8 +76,6 @@ class PlaylistType(Enum):
PlaylistType.SPOTIFY defines that the playlist is from Spotify
PlaylistType.APPLE_MUSIC defines that the playlist is from Apple Music.
PlaylistType.OTHER defines that the playlist is from an unknown source (possible from 3rd-party plugins).
"""
# We don't have to define anything special for these, since these just serve as flags
@ -100,11 +83,6 @@ class PlaylistType(Enum):
SOUNDCLOUD = "soundcloud"
SPOTIFY = "spotify"
APPLE_MUSIC = "apple_music"
OTHER = "other"
@classmethod
def _missing_(cls, value: object) -> "PlaylistType": # type: ignore[override]
return cls.OTHER
def __str__(self) -> str:
return self.value
@ -218,12 +196,8 @@ class URLRegex:
"""
# Spotify share links can include query parameters like ?si=XXXX, a trailing slash,
# or an intl locale segment (e.g. /intl-en/). Broaden the regex so we still capture
# the type and id while ignoring extra parameters. This prevents the URL from being
# treated as a generic Lavalink identifier and ensures internal Spotify handling runs.
SPOTIFY_URL = re.compile(
r"https?://open\.spotify\.com/(?:intl-[a-zA-Z-]+/)?(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)(?:/)?(?:\?.*)?$",
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
)
DISCORD_MP3_URL = re.compile(
@ -240,21 +214,22 @@ class URLRegex:
r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube\.com|youtu.be))/playlist\?list=.*",
)
YOUTUBE_VID_IN_PLAYLIST = re.compile(
r"(?P<video>^.*?v.*?)(?P<list>&list.*)",
)
YOUTUBE_TIMESTAMP = re.compile(
r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*",
)
# Apple Music links sometimes append additional query parameters (e.g. &l=en, &uo=4).
# Allow arbitrary query parameters so valid links are captured and parsed.
AM_URL = re.compile(
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/"
r"(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$",
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/"
r"(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
)
# Single-in-album links may also carry extra query params beyond the ?i=<trackid> token.
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
r"(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$",
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
)
SOUNDCLOUD_URL = re.compile(
@ -299,10 +274,3 @@ class LogLevel(IntEnum):
WARN = 30
ERROR = 40
CRITICAL = 50
@classmethod
def from_str(cls, level_str):
try:
return cls[level_str.upper()]
except KeyError:
raise ValueError(f"No such log level: {level_str}")

View File

@ -59,13 +59,14 @@ class TrackStartEvent(PomiceEvent):
def __init__(self, data: dict, player: Player):
self.player: Player = player
self.track: Optional[Track] = self.player._current
assert self.player._current is not None
self.track: Track = self.player._current
# on_pomice_track_start(player, track)
self.handler_args = self.player, self.track
def __repr__(self) -> str:
return f"<Pomice.TrackStartEvent player={self.player!r} track={self.track!r}>"
return f"<Pomice.TrackStartEvent player={self.player} track_id={self.track.track_id}>"
class TrackEndEvent(PomiceEvent):
@ -79,7 +80,8 @@ class TrackEndEvent(PomiceEvent):
def __init__(self, data: dict, player: Player):
self.player: Player = player
self.track: Optional[Track] = self.player._ending_track
assert self.player._ending_track is not None
self.track: Track = self.player._ending_track
self.reason: str = data["reason"]
# on_pomice_track_end(player, track, reason)
@ -87,8 +89,8 @@ class TrackEndEvent(PomiceEvent):
def __repr__(self) -> str:
return (
f"<Pomice.TrackEndEvent player={self.player!r} track_id={self.track!r} "
f"reason={self.reason!r}>"
f"<Pomice.TrackEndEvent player={self.player} track_id={self.track.track_id} "
f"reason={self.reason}>"
)
@ -104,7 +106,8 @@ class TrackStuckEvent(PomiceEvent):
def __init__(self, data: dict, player: Player):
self.player: Player = player
self.track: Optional[Track] = self.player._ending_track
assert self.player._ending_track is not None
self.track: Track = self.player._ending_track
self.threshold: float = data["thresholdMs"]
# on_pomice_track_stuck(player, track, threshold)
@ -128,7 +131,8 @@ class TrackExceptionEvent(PomiceEvent):
def __init__(self, data: dict, player: Player):
self.player: Player = player
self.track: Optional[Track] = self.player._ending_track
assert self.player._ending_track is not None
self.track: Track = self.player._ending_track
# Error is for Lavalink <= 3.3
self.exception: str = data.get(
"error",

View File

@ -77,12 +77,6 @@ class Equalizer(Filter):
def __repr__(self) -> str:
return f"<Pomice.EqualizerFilter tag={self.tag} eq={self.eq} raw={self.raw}>"
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Equalizer):
return False
return self.raw == __value.raw
@classmethod
def flat(cls) -> "Equalizer":
"""Equalizer preset which represents a flat EQ board,
@ -237,16 +231,6 @@ class Timescale(Filter):
def __repr__(self) -> str:
return f"<Pomice.TimescaleFilter tag={self.tag} speed={self.speed} pitch={self.pitch} rate={self.rate}>"
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Timescale):
return False
return (
self.speed == __value.speed
and self.pitch == __value.pitch
and self.rate == __value.rate
)
class Karaoke(Filter):
"""Filter which filters the vocal track from any song and leaves the instrumental.
@ -286,17 +270,6 @@ class Karaoke(Filter):
f"filter_band={self.filter_band} filter_width={self.filter_width}>"
)
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Karaoke):
return False
return (
self.level == __value.level
and self.mono_level == __value.mono_level
and self.filter_band == __value.filter_band
and self.filter_width == __value.filter_width
)
class Tremolo(Filter):
"""Filter which produces a wavering tone in the music,
@ -332,12 +305,6 @@ class Tremolo(Filter):
f"<Pomice.TremoloFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
)
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Tremolo):
return False
return self.frequency == __value.frequency and self.depth == __value.depth
class Vibrato(Filter):
"""Filter which produces a wavering tone in the music, similar to the Tremolo filter,
@ -373,12 +340,6 @@ class Vibrato(Filter):
f"<Pomice.VibratoFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
)
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Vibrato):
return False
return self.frequency == __value.frequency and self.depth == __value.depth
class Rotation(Filter):
"""Filter which produces a stereo-like panning effect, which sounds like
@ -396,12 +357,6 @@ class Rotation(Filter):
def __repr__(self) -> str:
return f"<Pomice.RotationFilter tag={self.tag} rotation_hertz={self.rotation_hertz}>"
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Rotation):
return False
return self.rotation_hertz == __value.rotation_hertz
class ChannelMix(Filter):
"""Filter which manually adjusts the panning of the audio, which can make
@ -463,17 +418,6 @@ class ChannelMix(Filter):
f"right_to_left={self.right_to_left} right_to_right={self.right_to_right}>"
)
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, ChannelMix):
return False
return (
self.left_to_left == __value.left_to_left
and self.left_to_right == __value.left_to_right
and self.right_to_left == __value.right_to_left
and self.right_to_right == __value.right_to_right
)
class Distortion(Filter):
"""Filter which generates a distortion effect. Useful for certain filter implementations where
@ -535,21 +479,6 @@ class Distortion(Filter):
f"tan_scale={self.tan_scale} offset={self.offset} scale={self.scale}"
)
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Distortion):
return False
return (
self.sin_offset == __value.sin_offset
and self.sin_scale == __value.sin_scale
and self.cos_offset == __value.cos_offset
and self.cos_scale == __value.cos_scale
and self.tan_offset == __value.tan_offset
and self.tan_scale == __value.tan_scale
and self.offset == __value.offset
and self.scale == __value.scale
)
class LowPass(Filter):
"""Filter which supresses higher frequencies and allows lower frequencies to pass.
@ -566,9 +495,3 @@ class LowPass(Filter):
def __repr__(self) -> str:
return f"<Pomice.LowPass tag={self.tag} smoothing={self.smoothing}>"
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, LowPass):
return False
return self.smoothing == __value.smoothing

View File

@ -78,7 +78,7 @@ class Track:
self.author: str = info.get("author", "Unknown Author")
self.uri: str = info.get("uri", "")
self.identifier: str = info.get("identifier", "")
self.isrc: Optional[str] = info.get("isrc", None)
self.isrc: str = info.get("isrc", "")
self.thumbnail: Optional[str] = info.get("thumbnail")
if self.uri and self.track_type is TrackType.YOUTUBE:
@ -98,6 +98,9 @@ class Track:
if not isinstance(other, Track):
return False
if self.ctx and other.ctx:
return other.track_id == self.track_id and other.ctx.message.id == self.ctx.message.id
return other.track_id == self.track_id
def __str__(self) -> str:

View File

@ -79,35 +79,10 @@ class Filters:
if filter.tag == filter_tag:
del self._filters[index]
def edit_filter(self, *, filter_tag: str, to_apply: Filter) -> None:
"""Edits a filter in the list of filters applied using its filter tag and replaces it with the new filter."""
if not any(f for f in self._filters if f.tag == filter_tag):
raise FilterTagInvalid("A filter with that tag was not found.")
for index, filter in enumerate(self._filters):
if filter.tag == filter_tag:
if not type(filter) == type(to_apply):
raise FilterInvalidArgument(
"Edited filter is not the same type as the current filter.",
)
if self._filters[index] == to_apply:
raise FilterInvalidArgument("Edited filter is the same as the current filter.")
if to_apply.tag != filter_tag:
raise FilterInvalidArgument(
"Edited filter tag is not the same as the current filter tag.",
)
self._filters[index] = to_apply
def has_filter(self, *, filter_tag: str) -> bool:
"""Checks if a filter exists in the list of filters using its filter tag"""
return any(f for f in self._filters if f.tag == filter_tag)
def has_filter_type(self, *, filter_type: Filter) -> bool:
"""Checks if any filters applied match the specified filter type."""
return any(f for f in self._filters if type(f) == type(filter_type))
def reset_filters(self) -> None:
"""Removes all filters from the list"""
self._filters = []
@ -156,10 +131,10 @@ class Player(VoiceProtocol):
"_player_endpoint_uri",
)
def __call__(self, client: Client, channel: VoiceChannel) -> Player:
self.client = client
self.channel = channel
self._guild = channel.guild
def __call__(self, client: Client, channel: VoiceChannel):
self.client: Client = client
self.channel: VoiceChannel = channel
self._guild: Guild = channel.guild
return self
@ -172,9 +147,9 @@ class Player(VoiceProtocol):
) -> None:
self.client: Client = client
self.channel: VoiceChannel = channel
self._guild = channel.guild
self._bot: Client = client
self._guild: Guild = channel.guild
self._node: Node = node if node else NodePool.get_node()
self._current: Optional[Track] = None
self._filters: Filters = Filters()
@ -213,7 +188,7 @@ class Player(VoiceProtocol):
difference = (time.time() * 1000) - self._last_update
position = self._last_position + difference
return round(min(position, current.length))
return min(position, current.length)
@property
def rate(self) -> float:
@ -287,7 +262,7 @@ class Player(VoiceProtocol):
"""
return self.guild.id not in self._node._players
def _adjust_end_time(self) -> Optional[str]:
def _adjust_end_time(self):
if self._node._version >= LavalinkVersion(3, 7, 5):
return None
@ -298,7 +273,6 @@ class Player(VoiceProtocol):
self._last_update = int(state.get("time", 0))
self._is_connected = bool(state.get("connected"))
self._last_position = int(state.get("position", 0))
if self._log:
self._log.debug(f"Got player update state with data {state}")
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None:
@ -320,10 +294,7 @@ class Player(VoiceProtocol):
data={"voice": data},
)
if self._log:
self._log.debug(
f"Dispatched voice update to {state['event']['endpoint']} with data {data}",
)
self._log.debug(f"Dispatched voice update to {state['event']['endpoint']} with data {data}")
async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
self._voice_state.update({"event": data})
@ -339,10 +310,6 @@ class Player(VoiceProtocol):
return
channel = self.guild.get_channel(int(channel_id))
if self.channel != channel:
self.channel = channel
if not channel:
await self.disconnect()
self._voice_state.clear()
@ -357,7 +324,7 @@ class Player(VoiceProtocol):
event_type: str = data["type"]
event: PomiceEvent = getattr(events, event_type)(data, self)
if isinstance(event, TrackEndEvent) and event.reason not in ("REPLACED", "replaced"):
if isinstance(event, TrackEndEvent) and event.reason != "REPLACED":
self._current = None
event.dispatch(self._bot)
@ -365,12 +332,8 @@ class Player(VoiceProtocol):
if isinstance(event, TrackStartEvent):
self._ending_track = self._current
if self._log:
self._log.debug(f"Dispatched event {data['type']} to player.")
async def _refresh_endpoint_uri(self, session_id: Optional[str]) -> None:
self._player_endpoint_uri = f"sessions/{session_id}/players"
async def _swap_node(self, *, new_node: Node) -> None:
if self.current:
data: dict = {"position": self.position, "encodedTrack": self.current.track_id}
@ -379,16 +342,16 @@ class Player(VoiceProtocol):
self._node = new_node
self._node._players[self._guild.id] = self
# reassign uri to update session id
await self._refresh_endpoint_uri(new_node._session_id)
self._player_endpoint_uri = f"sessions/{self._node._session_id}/players"
await self._dispatch_voice_update()
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data=data or None,
data=data,
)
if self._log:
self._log.debug(f"Swapped all players to new node {new_node._identifier}.")
async def get_tracks(
@ -396,7 +359,7 @@ class Player(VoiceProtocol):
query: str,
*,
ctx: Optional[commands.Context] = None,
search_type: SearchType | None = SearchType.ytsearch,
search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None,
) -> Optional[Union[List[Track], Playlist]]:
"""Fetches tracks from the node's REST api to parse into Lavalink.
@ -413,21 +376,8 @@ class Player(VoiceProtocol):
"""
return await self._node.get_tracks(query, ctx=ctx, search_type=search_type, filters=filters)
async def build_track(self, identifier: str, ctx: Optional[commands.Context] = None) -> Track:
"""
Builds a track using a valid track identifier
You can also pass in a discord.py Context object to get a
Context object on the track it builds.
"""
return await self._node.build_track(identifier, ctx=ctx)
async def get_recommendations(
self,
*,
track: Track,
ctx: Optional[commands.Context] = None,
self, *, track: Track, ctx: Optional[commands.Context] = None
) -> Optional[Union[List[Track], Playlist]]:
"""
Gets recommendations from either YouTube or Spotify.
@ -437,12 +387,7 @@ class Player(VoiceProtocol):
return await self._node.get_recommendations(track=track, ctx=ctx)
async def connect(
self,
*,
timeout: float,
reconnect: bool,
self_deaf: bool = False,
self_mute: bool = False,
self, *, timeout: float, reconnect: bool, self_deaf: bool = False, self_mute: bool = False
) -> None:
await self.guild.change_voice_state(
channel=self.channel,
@ -462,7 +407,6 @@ class Player(VoiceProtocol):
data={"encodedTrack": None},
)
if self._log:
self._log.debug(f"Player has been stopped.")
async def disconnect(self, *, force: bool = False) -> None:
@ -481,34 +425,24 @@ class Player(VoiceProtocol):
except AttributeError:
# 'NoneType' has no attribute '_get_voice_client_key' raised by self.cleanup() ->
# assume we're already disconnected and cleaned up
assert self.channel is None and not self.is_connected
assert not self.is_connected and not self.channel
self._node._players.pop(self.guild.id)
if self.node.is_connected:
await self._node.send(
method="DELETE",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
)
if self._log:
self._log.debug("Player has been destroyed.")
async def play(
self,
track: Track,
*,
start: int = 0,
end: int = 0,
ignore_if_playing: bool = False,
self, track: Track, *, start: int = 0, end: int = 0, ignore_if_playing: bool = False
) -> Track:
"""Plays a track. If a Spotify track is passed in, it will be handled accordingly."""
if not track._search_type:
track.original = track
# Make sure we've never searched the track before
if track._search_type and track.original is None:
if track.original is None:
# First lets try using the tracks ISRC, every track has one (hopefully)
try:
if not track.isrc:
@ -572,7 +506,7 @@ class Player(VoiceProtocol):
for filter in track.filters:
await self.add_filter(_filter=filter)
# Lavalink v3.7.5 changed the way the end time parameter works
# Lavalink v4 changed the way the end time parameter works
# so now the end time cannot be zero.
# If it isnt zero, it'll be set to None.
# Otherwise, it'll be set here:
@ -588,7 +522,6 @@ class Player(VoiceProtocol):
query=f"noReplace={ignore_if_playing}",
)
if self._log:
self._log.debug(
f"Playing {track.title} from uri {track.uri} with a length of {track.length}",
)
@ -612,7 +545,6 @@ class Player(VoiceProtocol):
data={"position": position},
)
if self._log:
self._log.debug(f"Seeking to {position}.")
return self.position
@ -626,7 +558,6 @@ class Player(VoiceProtocol):
)
self._paused = pause
if self._log:
self._log.debug(f"Player has been {'paused' if pause else 'resumed'}.")
return self._paused
@ -640,19 +571,9 @@ class Player(VoiceProtocol):
)
self._volume = volume
if self._log:
self._log.debug(f"Player volume has been adjusted to {volume}")
return self._volume
async def move_to(self, channel: VoiceChannel) -> None:
"""Moves the player to a new voice channel."""
await self.guild.change_voice_state(channel=channel)
self.channel = channel
await self._dispatch_voice_update()
async def add_filter(self, _filter: Filter, fast_apply: bool = False) -> Filters:
"""Adds a filter to the player. Takes a pomice.Filter object.
This will only work if you are using a version of Lavalink that supports filters.
@ -670,10 +591,8 @@ class Player(VoiceProtocol):
data={"filters": payload},
)
if self._log:
self._log.debug(f"Filter has been applied to player with tag {_filter.tag}")
if fast_apply:
if self._log:
self._log.debug(f"Fast apply passed, now applying filter instantly.")
await self.seek(self.position)
@ -695,48 +614,13 @@ class Player(VoiceProtocol):
guild_id=self._guild.id,
data={"filters": payload},
)
if self._log:
self._log.debug(f"Filter has been removed from player with tag {filter_tag}")
if fast_apply:
if self._log:
self._log.debug(f"Fast apply passed, now removing filter instantly.")
await self.seek(self.position)
return self._filters
async def edit_filter(
self,
*,
filter_tag: str,
edited_filter: Filter,
fast_apply: bool = False,
) -> Filters:
"""Edits a filter from the player using its filter tag and a new filter of the same type.
The filter to be replaced must have the same tag as the one you are replacing it with.
This will only work if you are using a version of Lavalink that supports filters.
If you would like for the filter to apply instantly, set the `fast_apply` arg to `True`.
(You must have a song playing in order for `fast_apply` to work.)
"""
self._filters.edit_filter(filter_tag=filter_tag, to_apply=edited_filter)
payload = self._filters.get_all_payloads()
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"filters": payload},
)
if self._log:
self._log.debug(f"Filter with tag {filter_tag} has been edited to {edited_filter!r}")
if fast_apply:
if self._log:
self._log.debug(f"Fast apply passed, now editing filter instantly.")
await self.seek(self.position)
return self._filters
async def reset_filters(self, *, fast_apply: bool = False) -> None:
"""Resets all currently applied filters to their default parameters.
You must have filters applied in order for this to work.
@ -756,10 +640,8 @@ class Player(VoiceProtocol):
guild_id=self._guild.id,
data={"filters": {}},
)
if self._log:
self._log.debug(f"All filters have been removed from player.")
if fast_apply:
if self._log:
self._log.debug(f"Fast apply passed, now removing all filters instantly.")
await self.seek(self.position)

View File

@ -5,8 +5,6 @@ import logging
import random
import re
import time
from os import path
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
@ -17,24 +15,15 @@ from typing import Union
from urllib.parse import quote
import aiohttp
import orjson as json
from discord import Client
from discord.ext import commands
from discord.utils import MISSING
try:
from websockets.legacy import client # websockets >= 10.0
except ImportError:
import websockets.client as client # websockets < 10.0 # type: ignore
from websockets import exceptions
from websockets import typing as wstype
from . import __version__
from . import applemusic
from . import spotify
from .enums import *
from .enums import LogLevel
from .exceptions import AppleMusicNotEnabled
from .exceptions import InvalidSpotifyClientAuthorization
from .exceptions import LavalinkVersionIncompatible
from .exceptions import NodeConnectionFailure
@ -60,8 +49,6 @@ __all__ = (
"NodePool",
)
VERSION_REGEX = re.compile(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?(?:[a-zA-Z0-9_-]+)?")
class Node:
"""The base class for a node.
@ -79,8 +66,6 @@ class Node:
"_password",
"_identifier",
"_heartbeat",
"_resume_key",
"_resume_timeout",
"_secure",
"_fallback",
"_log_level",
@ -115,20 +100,15 @@ class Node:
password: str,
identifier: str,
secure: bool = False,
heartbeat: int = 120,
resume_key: Optional[str] = None,
resume_timeout: int = 60,
heartbeat: int = 30,
loop: Optional[asyncio.AbstractEventLoop] = None,
session: Optional[aiohttp.ClientSession] = None,
spotify_client_id: Optional[str] = None,
spotify_client_id: Optional[int] = None,
spotify_client_secret: Optional[str] = None,
apple_music: bool = False,
fallback: bool = False,
logger: Optional[logging.Logger] = None,
log_level: LogLevel = LogLevel.INFO,
):
if not isinstance(port, int):
raise TypeError("Port must be an integer")
self._bot: commands.Bot = bot
self._host: str = host
self._port: int = port
@ -136,25 +116,24 @@ class Node:
self._password: str = password
self._identifier: str = identifier
self._heartbeat: int = heartbeat
self._resume_key: Optional[str] = resume_key
self._resume_timeout: int = resume_timeout
self._secure: bool = secure
self._fallback: bool = fallback
self._log_level: LogLevel = log_level
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
self._session: aiohttp.ClientSession = session # type: ignore
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
self._websocket: client.WebSocketClientProtocol
self._websocket: aiohttp.ClientWebSocketResponse
self._task: asyncio.Task = None # type: ignore
self._session_id: Optional[str] = None
self._available: bool = False
self._version: LavalinkVersion = LavalinkVersion(0, 0, 0)
self._version: LavalinkVersion = None
self._route_planner = RoutePlanner(self)
self._log = logger
self._log = self._setup_logging(self._log_level)
if not self._bot.user:
raise NodeCreationError("Bot user is not ready yet.")
@ -169,14 +148,13 @@ class Node:
self._players: Dict[int, Player] = {}
self._spotify_client: Optional[spotify.Client] = None
self._apple_music_client: Optional[applemusic.Client] = None
self._spotify_client_id: Optional[str] = spotify_client_id
self._spotify_client_id: Optional[int] = spotify_client_id
self._spotify_client_secret: Optional[str] = spotify_client_secret
self._apple_music_client: Optional[applemusic.Client] = None
if self._spotify_client_id and self._spotify_client_secret:
self._spotify_client = spotify.Client(
self._spotify_client: spotify.Client = spotify.Client(
self._spotify_client_id,
self._spotify_client_secret,
)
@ -215,7 +193,7 @@ class Node:
@property
def player_count(self) -> int:
"""Property which returns how many players are connected to this node"""
return len(self.players.values())
return len(self.players)
@property
def pool(self) -> Type[NodePool]:
@ -232,44 +210,41 @@ class Node:
"""Alias for `Node.latency`, returns the latency of the node"""
return self.latency
async def _handle_version_check(self, version: str) -> None:
def _setup_logging(self, level: LogLevel) -> logging.Logger:
logger = logging.getLogger("pomice")
handler = logging.StreamHandler()
dt_fmt = "%Y-%m-%d %H:%M:%S"
formatter = logging.Formatter(
"[{asctime}] [{levelname:<8}] {name}: {message}",
dt_fmt,
style="{",
)
handler.setFormatter(formatter)
logger.setLevel(level)
logger.addHandler(handler)
return logger
async def _handle_version_check(self, version: str):
if version.endswith("-SNAPSHOT"):
# we're just gonna assume all snapshot versions correlate with v4
self._version = LavalinkVersion(major=4, minor=0, fix=0)
return
_version_rx = VERSION_REGEX.match(version)
if not _version_rx:
# this crazy ass line maps the split version string into
# an iterable with ints instead of strings and then
# turns that iterable into a tuple. yeah, i know
split = tuple(map(int, tuple(version.split("."))))
self._version = LavalinkVersion(*split)
if not version.endswith("-SNAPSHOT") and (
self._version.major == 3 and self._version.minor < 7
):
self._available = False
raise LavalinkVersionIncompatible(
"The Lavalink version you're using is incompatible. "
"Lavalink version 3.7.0 or above is required to use this library.",
)
_version_groups = _version_rx.groups()
major, minor, fix = (
int(_version_groups[0] or 0),
int(_version_groups[1] or 0),
int(_version_groups[2] or 0),
)
if self._log:
self._log.debug(f"Parsed Lavalink version: {major}.{minor}.{fix}")
self._version = LavalinkVersion(major=major, minor=minor, fix=fix)
if self._version < LavalinkVersion(3, 7, 0):
self._available = False
raise LavalinkVersionIncompatible(
"The Lavalink version you're using is incompatible. "
"Lavalink version 3.7.0 or above is required to use this library.",
)
async def _set_ext_client_session(self, session: aiohttp.ClientSession) -> None:
if self._spotify_client:
await self._spotify_client._set_session(session=session)
if self._apple_music_client:
await self._apple_music_client._set_session(session=session)
async def _update_handler(self, data: dict) -> None:
await self._bot.wait_until_ready()
@ -304,59 +279,25 @@ class Node:
await self.disconnect()
async def _configure_resuming(self) -> None:
if not self._resume_key:
return
data: Dict[str, Union[int, str, bool]] = {"timeout": self._resume_timeout}
if self._version.major == 3:
data["resumingKey"] = self._resume_key
elif self._version.major == 4:
if self._log:
self._log.warning("Using a resume key with Lavalink v4 is deprecated.")
data["resuming"] = True
await self.send(
method="PATCH",
path=f"sessions/{self._session_id}",
include_version=True,
data=data,
)
async def _listen(self) -> None:
while True:
try:
msg = await self._websocket.recv()
data = json.loads(msg)
if self._log:
self._log.debug(f"Recieved raw websocket message {msg}")
self._loop.create_task(self._handle_ws_msg(data=data))
except exceptions.ConnectionClosed:
if self.player_count > 0:
for _player in self.players.values():
self._loop.create_task(_player.destroy())
if self._fallback:
self._loop.create_task(self._handle_node_switch())
self._loop.create_task(self._websocket.close())
backoff = ExponentialBackoff(base=7)
while True:
msg = await self._websocket.receive()
if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
if self._fallback:
await self._handle_node_switch()
retry = backoff.delay()
if self._log:
self._log.debug(
f"Retrying connection to Node {self._identifier} in {retry} secs",
)
await asyncio.sleep(retry)
if not self.is_connected:
self._loop.create_task(self.connect(reconnect=True))
self._loop.create_task(self.connect())
else:
self._loop.create_task(self._handle_payload(msg.json()))
async def _handle_ws_msg(self, data: dict) -> None:
if self._log:
self._log.debug(f"Recieved raw payload from Node {self._identifier} with data {data}")
async def _handle_payload(self, data: dict) -> None:
op = data.get("op", None)
if not op:
return
if op == "stats":
self._stats = NodeStats(data)
@ -364,20 +305,21 @@ class Node:
if op == "ready":
self._session_id = data["sessionId"]
await self._configure_resuming()
if not "guildId" in data:
return
player: Optional[Player] = self._players.get(int(data["guildId"]))
player = self._players.get(int(data["guildId"]))
if not player:
return
if op == "event":
return await player._dispatch_event(data)
await player._dispatch_event(data)
return
if op == "playerUpdate":
return await player._update_state(data)
await player._update_state(data)
return
async def send(
self,
@ -402,39 +344,33 @@ class Node:
f'{f"?{query}" if query else ""}'
)
resp = await self._session.request(
async with self._session.request(
method=method,
url=uri,
headers=self._headers,
json=data or {},
)
if self._log:
self._log.debug(
f"Making REST request to Node {self._identifier} with method {method} to {uri}",
)
) as resp:
self._log.debug(f"Making REST request with method {method} to {uri}")
if resp.status >= 300:
resp_data: dict = await resp.json()
raise NodeRestException(
f'Error from Node {self._identifier} fetching from Lavalink REST api: {resp.status} {resp.reason}: {resp_data["message"]}',
f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {resp_data["message"]}',
)
if method == "DELETE" or resp.status == 204:
if self._log:
self._log.debug(
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned no data.",
f"REST request with method {method} to {uri} completed sucessfully and returned no data.",
)
return await resp.json(content_type=None)
if resp.content_type == "text/plain":
if self._log:
self._log.debug(
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned text with body {await resp.text()}",
f"REST request with method {method} to {uri} completed sucessfully and returned text with body {await resp.text()}",
)
return await resp.text()
if self._log:
self._log.debug(
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned JSON with body {await resp.json()}",
f"REST request with method {method} to {uri} completed sucessfully and returned JSON with body {await resp.json()}",
)
return await resp.json()
@ -442,27 +378,16 @@ class Node:
"""Takes a guild ID as a parameter. Returns a pomice Player object or None."""
return self._players.get(guild_id, None)
async def connect(self, *, reconnect: bool = False) -> Node:
async def connect(self) -> "Node":
"""Initiates a connection with a Lavalink node and adds it to the node pool."""
await self._bot.wait_until_ready()
start = time.perf_counter()
if not self._session:
# Configure connection pooling for optimal concurrent request performance
connector = aiohttp.TCPConnector(
limit=100, # Total connection limit
limit_per_host=30, # Per-host connection limit
ttl_dns_cache=300, # DNS cache TTL in seconds
)
timeout = aiohttp.ClientTimeout(total=30, connect=10)
self._session = aiohttp.ClientSession(
connector=connector,
timeout=timeout,
)
self._session = aiohttp.ClientSession()
try:
if not reconnect:
version: str = await self.send(
method="GET",
path="version",
@ -471,29 +396,17 @@ class Node:
)
await self._handle_version_check(version=version)
await self._set_ext_client_session(session=self._session)
if self._log:
self._log.debug(
f"Version check from Node {self._identifier} successful. Returned version {version}",
)
self._log.debug(f"Version check from node successful. Returned version {version}")
self._websocket = await client.connect( # type: ignore
self._websocket = await self._session.ws_connect(
f"{self._websocket_uri}/v{self._version.major}/websocket",
extra_headers=self._headers,
ping_interval=self._heartbeat,
headers=self._headers,
heartbeat=self._heartbeat,
)
if reconnect:
if self._log:
self._log.debug(f"Trying to reconnect to Node {self._identifier}...")
if self.player_count:
for player in self.players.values():
await player._refresh_endpoint_uri(self._session_id)
if self._log:
self._log.debug(
f"Node {self._identifier} successfully connected to websocket using {self._websocket_uri}/v{self._version.major}/websocket",
f"Connected to node websocket using {self._websocket_uri}/v{self._version.major}/websocket",
)
if not self._task:
@ -503,19 +416,18 @@ class Node:
end = time.perf_counter()
if self._log:
self._log.info(f"Connected to node {self._identifier}. Took {end - start:.3f}s")
return self
except (aiohttp.ClientConnectorError, OSError, ConnectionRefusedError):
except (aiohttp.ClientConnectorError, ConnectionRefusedError):
raise NodeConnectionFailure(
f"The connection to node '{self._identifier}' failed.",
) from None
except exceptions.InvalidHandshake:
except aiohttp.WSServerHandshakeError:
raise NodeConnectionFailure(
f"The password for node '{self._identifier}' is invalid.",
) from None
except exceptions.InvalidURI:
except aiohttp.InvalidURL:
raise NodeConnectionFailure(
f"The URI for node '{self._identifier}' is invalid.",
) from None
@ -529,20 +441,25 @@ class Node:
for player in self.players.copy().values():
await player.destroy()
if self._log:
self._log.debug("All players disconnected from node.")
await self._websocket.close()
await self._session.close()
if self._log:
self._log.debug("Websocket and http session closed.")
if self._spotify_client:
await self._spotify_client.close()
self._log.debug("Spotify client session closed.")
if self._apple_music_client:
await self._apple_music_client.close()
self._log.debug("Apple Music client session closed.")
del self._pool._nodes[self._identifier]
self.available = False
self._task.cancel()
end = time.perf_counter()
if self._log:
self._log.info(
f"Successfully disconnected from node {self._identifier} and closed all sessions. Took {end - start:.3f}s",
)
@ -558,16 +475,13 @@ class Node:
data: dict = await self.send(
method="GET",
path="decodetrack",
query=f"encodedTrack={quote(identifier)}",
query=f"encodedTrack={identifier}",
)
track_info = data["info"] if self._version.major >= 4 else data
return Track(
track_id=identifier,
ctx=ctx,
info=track_info,
track_type=TrackType(track_info["sourceName"]),
info=data,
track_type=TrackType(data["sourceName"]),
)
async def get_tracks(
@ -575,7 +489,7 @@ class Node:
query: str,
*,
ctx: Optional[commands.Context] = None,
search_type: Optional[SearchType] = SearchType.ytsearch,
search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None,
) -> Optional[Union[Playlist, List[Track]]]:
"""Fetches tracks from the node's REST api to parse into Lavalink.
@ -592,17 +506,20 @@ class Node:
timestamp = None
if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query):
query = f"{search_type}:{query}"
if filters:
for filter in filters:
filter.set_preload()
# Due to the inclusion of plugins in the v4 update
# we are doing away with raising an error if pomice detects
# either a Spotify or Apple Music URL and the respective client
# is not enabled. Instead, we will just only parse the URL
# if the client is enabled and the URL is valid.
if URLRegex.AM_URL.match(query):
if not self._apple_music_client:
raise AppleMusicNotEnabled(
"You must have Apple Music functionality enabled in order to play Apple Music tracks."
"Please set apple_music to True in your Node class.",
)
if self._apple_music_client and URLRegex.AM_URL.match(query):
apple_music_results = await self._apple_music_client.search(query=query)
if isinstance(apple_music_results, applemusic.Song):
return [
@ -610,7 +527,7 @@ class Node:
track_id=apple_music_results.id,
ctx=ctx,
track_type=TrackType.APPLE_MUSIC,
search_type=search_type or SearchType.ytsearch,
search_type=search_type,
filters=filters,
info={
"title": apple_music_results.name,
@ -632,7 +549,7 @@ class Node:
track_id=track.id,
ctx=ctx,
track_type=TrackType.APPLE_MUSIC,
search_type=search_type or SearchType.ytsearch,
search_type=search_type,
filters=filters,
info={
"title": track.name,
@ -661,8 +578,15 @@ class Node:
uri=apple_music_results.url,
)
elif self._spotify_client and URLRegex.SPOTIFY_URL.match(query):
spotify_results = await self._spotify_client.search(query=query) # type: ignore
elif URLRegex.SPOTIFY_URL.match(query):
if not self._spotify_client_id and not self._spotify_client_secret:
raise InvalidSpotifyClientAuthorization(
"You did not provide proper Spotify client authorization credentials. "
"If you would like to use the Spotify searching feature, "
"please obtain Spotify API credentials here: https://developer.spotify.com/",
)
spotify_results = await self._spotify_client.search(query=query)
if isinstance(spotify_results, spotify.Track):
return [
@ -670,7 +594,7 @@ class Node:
track_id=spotify_results.id,
ctx=ctx,
track_type=TrackType.SPOTIFY,
search_type=search_type or SearchType.ytsearch,
search_type=search_type,
filters=filters,
info={
"title": spotify_results.name,
@ -692,7 +616,7 @@ class Node:
track_id=track.id,
ctx=ctx,
track_type=TrackType.SPOTIFY,
search_type=search_type or SearchType.ytsearch,
search_type=search_type,
filters=filters,
info={
"title": track.name,
@ -721,21 +645,45 @@ class Node:
uri=spotify_results.uri,
)
else:
if (
search_type
and not URLRegex.BASE_URL.match(query)
and not re.match(r"(?:[a-z]+?)search:.", query)
and not URLRegex.DISCORD_MP3_URL.match(query)
and not path.exists(path.dirname(query))
):
query = f"{search_type}:{query}"
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
data: dict = await self.send(
method="GET",
path="loadtracks",
query=f"identifier={quote(query)}",
)
track: dict = data["tracks"][0]
info: dict = track["info"]
return [
Track(
track_id=track["track"],
info={
"title": discord_url.group("file"),
"author": "Unknown",
"length": info["length"],
"uri": info["uri"],
"position": info["position"],
"identifier": info["identifier"],
},
ctx=ctx,
track_type=TrackType.HTTP,
filters=filters,
),
]
else:
# If YouTube url contains a timestamp, capture it for use later.
if match := URLRegex.YOUTUBE_TIMESTAMP.match(query):
timestamp = float(match.group("time"))
# If query is a video thats part of a playlist, get the video and queue that instead
# (I can't tell you how much i've wanted to implement this in here)
if match := URLRegex.YOUTUBE_VID_IN_PLAYLIST.match(query):
query = match.group("video")
data = await self.send(
method="GET",
path="loadtracks",
@ -744,31 +692,21 @@ class Node:
load_type = data.get("loadType")
# Lavalink v4 changed the name of the key from "tracks" to "data"
# so lets account for that
data_type = "data" if self._version.major >= 4 else "tracks"
if not load_type:
raise TrackLoadError(
"There was an error while trying to load this track.",
)
elif load_type in ("LOAD_FAILED", "error"):
exception = data["data"] if self._version.major >= 4 else data["exception"]
elif load_type == "LOAD_FAILED":
exception = data["exception"]
raise TrackLoadError(
f"{exception['message']} [{exception['severity']}]",
)
elif load_type in ("NO_MATCHES", "empty"):
elif load_type == "NO_MATCHES":
return None
elif load_type in ("PLAYLIST_LOADED", "playlist"):
if self._version.major >= 4:
track_list = data[data_type]["tracks"]
playlist_info = data[data_type]["info"]
else:
track_list = data[data_type]
playlist_info = data["playlistInfo"]
elif load_type == "PLAYLIST_LOADED":
tracks = [
Track(
track_id=track["encoded"],
@ -776,60 +714,17 @@ class Node:
ctx=ctx,
track_type=TrackType(track["info"]["sourceName"]),
)
for track in track_list
for track in data["tracks"]
]
return Playlist(
playlist_info=playlist_info,
playlist_info=data["playlistInfo"],
tracks=tracks,
playlist_type=PlaylistType(tracks[0].track_type.value),
thumbnail=tracks[0].thumbnail,
uri=query,
)
elif load_type in ("SEARCH_RESULT", "TRACK_LOADED", "track", "search"):
if self._version.major >= 4 and isinstance(data[data_type], dict):
data[data_type] = [data[data_type]]
if path.exists(path.dirname(query)):
local_file = Path(query)
return [
Track(
track_id=track["encoded"],
info={
"title": local_file.name,
"author": "Unknown",
"length": track["info"]["length"],
"uri": quote(local_file.as_uri()),
"position": track["info"]["position"],
"identifier": track["info"]["identifier"],
},
ctx=ctx,
track_type=TrackType.LOCAL,
filters=filters,
)
for track in data[data_type]
]
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
return [
Track(
track_id=track["encoded"],
info={
"title": discord_url.group("file"),
"author": "Unknown",
"length": track["info"]["length"],
"uri": track["info"]["uri"],
"position": track["info"]["position"],
"identifier": track["info"]["identifier"],
},
ctx=ctx,
track_type=TrackType.HTTP,
filters=filters,
)
for track in data[data_type]
]
elif load_type == "SEARCH_RESULT" or load_type == "TRACK_LOADED":
return [
Track(
track_id=track["encoded"],
@ -839,7 +734,7 @@ class Node:
filters=filters,
timestamp=timestamp,
)
for track in data[data_type]
for track in data["tracks"]
]
else:
@ -848,10 +743,7 @@ class Node:
)
async def get_recommendations(
self,
*,
track: Track,
ctx: Optional[commands.Context] = None,
self, *, track: Track, ctx: Optional[commands.Context] = None
) -> Optional[Union[List[Track], Playlist]]:
"""
Gets recommendations from either YouTube or Spotify.
@ -861,7 +753,7 @@ class Node:
Context object on all tracks that get recommended.
"""
if track.track_type == TrackType.SPOTIFY:
results = await self._spotify_client.get_recommendations(query=track.uri) # type: ignore
results = await self._spotify_client.get_recommendations(query=track.uri)
tracks = [
Track(
track_id=track.id,
@ -896,57 +788,6 @@ class Node:
"The specfied track must be either a YouTube or Spotify track to recieve recommendations.",
)
async def search_spotify_recommendations(
self,
query: str,
*,
ctx: Optional[commands.Context] = None,
filters: Optional[List[Filter]] = None,
) -> Optional[Union[List[Track], Playlist]]:
"""
Searches for recommendations on Spotify and returns a list of tracks based on the query.
You must have Spotify enabled for this to work.
You can pass in a discord.py Context object to get a
Context object on all tracks that get recommended.
"""
if not self._spotify_client:
raise InvalidSpotifyClientAuthorization(
"You must have Spotify enabled to use this feature.",
)
results = await self._spotify_client.track_search(query=query) # type: ignore
if not results:
raise TrackLoadError(
"Unable to find any tracks based on the query.",
)
tracks = [
Track(
track_id=track.id,
ctx=ctx,
track_type=TrackType.SPOTIFY,
info={
"title": track.name,
"author": track.artists,
"length": track.length,
"identifier": track.id,
"uri": track.uri,
"isStream": False,
"isSeekable": True,
"position": 0,
"thumbnail": track.image,
"isrc": track.isrc,
},
requester=self.bot.user,
)
for track in results
]
track = tracks[0]
return await self.get_recommendations(track=track, ctx=ctx)
class NodePool:
"""The base class for the node pool.
@ -1028,16 +869,14 @@ class NodePool:
password: str,
identifier: str,
secure: bool = False,
heartbeat: int = 120,
resume_key: Optional[str] = None,
resume_timeout: int = 60,
heartbeat: int = 30,
loop: Optional[asyncio.AbstractEventLoop] = None,
spotify_client_id: Optional[str] = None,
spotify_client_id: Optional[int] = None,
spotify_client_secret: Optional[str] = None,
session: Optional[aiohttp.ClientSession] = None,
apple_music: bool = False,
fallback: bool = False,
logger: Optional[logging.Logger] = None,
log_level: LogLevel = LogLevel.INFO,
) -> Node:
"""Creates a Node object to be then added into the node pool.
For Spotify searching capabilites, pass in valid Spotify API credentials.
@ -1056,15 +895,13 @@ class NodePool:
identifier=identifier,
secure=secure,
heartbeat=heartbeat,
resume_key=resume_key,
resume_timeout=resume_timeout,
loop=loop,
spotify_client_id=spotify_client_id,
session=session,
spotify_client_secret=spotify_client_secret,
apple_music=apple_music,
fallback=fallback,
logger=logger,
log_level=log_level,
)
await node.connect()

View File

@ -203,13 +203,9 @@ class Queue(Iterable[Track]):
raise QueueEmpty("No items in the queue.")
if self._loop_mode == LoopMode.QUEUE:
# set current item to first track in queue if not set already
# otherwise exception will be raised
if not self._current_item or self._current_item not in self._queue:
if self._queue:
item = self._queue[0]
else:
raise QueueEmpty("No items in the queue.")
# recurse if the item isnt in the queue
if self._current_item not in self._queue:
self.get()
# set current item to first track in queue if not set already
if not self._current_item:
@ -352,23 +348,7 @@ class Queue(Iterable[Track]):
track.filters = None
def jump(self, item: Track) -> None:
"""
Jumps to the item specified in the queue.
If the queue is not looping, the queue will be mutated.
Otherwise, the current item will be adjusted to the item
before the specified track.
The queue is adjusted so that the next item that is retrieved
is the track that is specified, effectively 'jumping' the queue.
"""
if self._loop_mode == LoopMode.TRACK:
raise QueueException("Jumping the queue whilst looping a track is not allowed.")
"""Removes all tracks before the."""
index = self.find_position(item)
if self._loop_mode == LoopMode.QUEUE:
self._current_item = self._queue[index - 1]
else:
new_queue = self._queue[index : self.size]
self._queue = new_queue

View File

@ -1,16 +1,12 @@
from __future__ import annotations
import asyncio
import logging
import re
import time
from base64 import b64encode
from typing import AsyncGenerator
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from urllib.parse import quote
import aiohttp
import orjson as json
@ -24,10 +20,8 @@ __all__ = ("Client",)
GRANT_URL = "https://accounts.spotify.com/api/token"
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
# Keep this in sync with URLRegex.SPOTIFY_URL (enums.py). Accept intl locale segment,
# optional trailing slash, and query parameters.
SPOTIFY_URL_REGEX = re.compile(
r"https?://open\.spotify\.com/(?:intl-[a-zA-Z-]+/)?(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)(?:/)?(?:\?.*)?$",
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
)
@ -37,49 +31,35 @@ class Client:
for any Spotify URL you throw at it.
"""
def __init__(
self,
client_id: str,
client_secret: str,
*,
playlist_concurrency: int = 10,
playlist_page_limit: Optional[int] = None,
) -> None:
self._client_id = client_id
self._client_secret = client_secret
def __init__(self, client_id: int, client_secret: str) -> None:
self._client_id: int = client_id
self._client_secret: str = client_secret
# HTTP session will be injected by Node
self.session: Optional[aiohttp.ClientSession] = None
self.session: aiohttp.ClientSession = None # type: ignore
self._bearer_token: Optional[str] = None
self._expiry: float = 0.0
self._auth_token = b64encode(f"{self._client_id}:{self._client_secret}".encode())
self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"}
self._auth_token = b64encode(
f"{self._client_id}:{self._client_secret}".encode(),
)
self._grant_headers = {
"Authorization": f"Basic {self._auth_token.decode()}",
}
self._bearer_headers: Optional[Dict] = None
self._log = logging.getLogger(__name__)
# Performance tuning knobs
self._playlist_concurrency = max(1, playlist_concurrency)
self._playlist_page_limit = playlist_page_limit
async def _set_session(self, session: aiohttp.ClientSession) -> None:
self.session = session
async def _fetch_bearer_token(self) -> None:
_data = {"grant_type": "client_credentials"}
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.post(GRANT_URL, data=_data, headers=self._grant_headers)
self.session = aiohttp.ClientSession()
async with self.session.post(GRANT_URL, data=_data, headers=self._grant_headers) as resp:
if resp.status != 200:
raise SpotifyRequestException(
f"Error fetching bearer token: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
if self._log:
self._log.debug(f"Fetched Spotify bearer token successfully")
self._bearer_token = data["access_token"]
self._expiry = time.time() + (int(data["expires_in"]) - 10)
@ -100,31 +80,23 @@ class Client:
request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id)
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers)
async with self.session.get(request_url, headers=self._bearer_headers) as resp:
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
if self._log:
self._log.debug(
f"Made request to Spotify API with status {resp.status} and response {data}",
)
if spotify_type == "track":
return Track(data)
elif spotify_type == "album":
return Album(data)
elif spotify_type == "artist":
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(
async with self.session.get(
f"{request_url}/top-tracks?market=US",
headers=self._bearer_headers,
)
) as resp:
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
@ -134,177 +106,36 @@ class Client:
tracks = track_data["tracks"]
return Artist(data, tracks)
else:
# For playlists we optionally use a reduced fields payload to shrink response sizes.
# NB: We cannot apply fields filter to initial request because original metadata is needed.
tracks = [
Track(track["track"])
for track in data["tracks"]["items"]
if track["track"] is not None
]
if not tracks:
if not len(tracks):
raise SpotifyRequestException(
"This playlist is empty and therefore cannot be queued.",
)
total_tracks = data["tracks"]["total"]
limit = data["tracks"]["limit"]
next_page_url = data["tracks"]["next"]
# Shortcircuit small playlists (single page)
if total_tracks <= limit:
return Playlist(data, tracks)
# Build remaining page URLs; Spotify supports offset-based pagination.
remaining_offsets = range(limit, total_tracks, limit)
page_urls: List[str] = []
fields_filter = (
"items(track(name,duration_ms,id,is_local,external_urls,external_ids,artists(name),album(images)))"
",next"
)
for idx, offset in enumerate(remaining_offsets):
if self._playlist_page_limit is not None and idx >= self._playlist_page_limit:
break
page_urls.append(
f"{request_url}/tracks?offset={offset}&limit={limit}&fields={quote(fields_filter)}",
)
if page_urls:
semaphore = asyncio.Semaphore(self._playlist_concurrency)
async def fetch_page(url: str) -> Optional[List[Track]]:
async with semaphore:
if not self.session:
raise SpotifyRequestException(
"HTTP session not initialized for Spotify client.",
)
resp = await self.session.get(url, headers=self._bearer_headers)
if resp.status != 200:
if self._log:
self._log.warning(
f"Page fetch failed {resp.status} {resp.reason} for {url}",
)
return None
page_json: dict = await resp.json(loads=json.loads)
return [
Track(item["track"])
for item in page_json.get("items", [])
if item.get("track") is not None
]
# Chunk gather in waves to avoid creating thousands of tasks at once
aggregated: List[Track] = []
wave_size = self._playlist_concurrency * 2
for i in range(0, len(page_urls), wave_size):
wave = page_urls[i : i + wave_size]
results = await asyncio.gather(
*[fetch_page(url) for url in wave],
return_exceptions=False,
)
for result in results:
if result:
aggregated.extend(result)
tracks.extend(aggregated)
return Playlist(data, tracks)
async def iter_playlist_tracks(
self,
*,
query: str,
batch_size: int = 100,
) -> AsyncGenerator[List[Track], None]:
"""Stream playlist tracks in batches without waiting for full materialization.
Parameters
----------
query: str
Spotify playlist URL.
batch_size: int
Number of tracks yielded per batch (logical grouping after fetch). Does not alter API page size.
"""
if not self._bearer_token or time.time() >= self._expiry:
await self._fetch_bearer_token()
match = SPOTIFY_URL_REGEX.match(query)
if not match or match.group("type") != "playlist":
raise InvalidSpotifyURL("Provided query is not a valid Spotify playlist URL.")
playlist_id = match.group("id")
request_url = REQUEST_URL.format(type="playlist", id=playlist_id)
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers)
while next_page_url is not None:
async with self.session.get(next_page_url, headers=self._bearer_headers) as resp:
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
# Yield first page immediately
first_page_tracks = [
Track(item["track"])
for item in data["tracks"]["items"]
if item.get("track") is not None
next_data: dict = await resp.json(loads=json.loads)
tracks += [
Track(track["track"])
for track in next_data["items"]
if track["track"] is not None
]
# Batch yield
for i in range(0, len(first_page_tracks), batch_size):
yield first_page_tracks[i : i + batch_size]
next_page_url = next_data["next"]
total = data["tracks"]["total"]
limit = data["tracks"]["limit"]
remaining_offsets = range(limit, total, limit)
fields_filter = (
"items(track(name,duration_ms,id,is_local,external_urls,external_ids,artists(name),album(images)))"
",next"
)
semaphore = asyncio.Semaphore(self._playlist_concurrency)
async def fetch(offset: int) -> List[Track]:
url = (
f"{request_url}/tracks?offset={offset}&limit={limit}&fields={quote(fields_filter)}"
)
async with semaphore:
if not self.session:
raise SpotifyRequestException(
"HTTP session not initialized for Spotify client.",
)
r = await self.session.get(url, headers=self._bearer_headers)
if r.status != 200:
if self._log:
self._log.warning(
f"Skipping page offset={offset} due to {r.status} {r.reason}",
)
return []
pj: dict = await r.json(loads=json.loads)
return [
Track(item["track"])
for item in pj.get("items", [])
if item.get("track") is not None
]
# Fetch pages in rolling waves; yield promptly as soon as a wave completes.
wave_size = self._playlist_concurrency * 2
for i, offset in enumerate(remaining_offsets):
# Build wave
if i % wave_size == 0:
wave_offsets = list(
o for o in remaining_offsets if o >= offset and o < offset + wave_size
)
results = await asyncio.gather(*[fetch(o) for o in wave_offsets])
for page_tracks in results:
if not page_tracks:
continue
for j in range(0, len(page_tracks), batch_size):
yield page_tracks[j : j + batch_size]
# Skip ahead in iterator by adjusting enumerate drive (consume extras)
# Fast-forward the generator manually
for _ in range(len(wave_offsets) - 1):
try:
next(remaining_offsets) # type: ignore
except StopIteration:
break
return Playlist(data, tracks)
async def get_recommendations(self, *, query: str) -> List[Track]:
if not self._bearer_token or time.time() >= self._expiry:
@ -327,34 +158,19 @@ class Client:
id=f"?seed_tracks={spotify_id}",
)
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers)
async with self.session.get(request_url, headers=self._bearer_headers) as resp:
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
tracks = [Track(track) for track in data["tracks"]]
return tracks
async def track_search(self, *, query: str) -> List[Track]:
if not self._bearer_token or time.time() >= self._expiry:
await self._fetch_bearer_token()
request_url = f"https://api.spotify.com/v1/search?q={quote(query)}&type=track"
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers)
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
tracks = [Track(track) for track in data["tracks"]["items"]]
return tracks
async def close(self) -> None:
if self.session:
await self.session.close()
self.session = None # type: ignore

View File

@ -18,7 +18,7 @@ class Track:
self.length: float = data["duration_ms"]
self.id: str = data["id"]
self.isrc: Optional[str] = None
self.issrc: Optional[str] = None
if data.get("external_ids"):
self.isrc = data["external_ids"]["isrc"]

View File

@ -1,10 +1,9 @@
# type: ignore
import re
import setuptools
version = ""
requirements = ["aiohttp>=3.7.4,<4", "orjson", "websockets"]
requirements = ["discord.py>=2.0.0", "aiohttp>=3.7.4,<4", "orjson"]
with open("pomice/__init__.py") as f:
version = re.search(
r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',