Compare commits
120 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
9bffdebe25 | |
|
|
720ba187ab | |
|
|
855bf4e0d7 | |
|
|
cd579becad | |
|
|
3a1ecf9eec | |
|
|
5227962228 | |
|
|
be7106616b | |
|
|
ba9534bc27 | |
|
|
2e0f5b365a | |
|
|
851f00aa97 | |
|
|
817295d321 | |
|
|
8ab3ae9ccd | |
|
|
094f2be181 | |
|
|
b60a6aec18 | |
|
|
80f7b77cd3 | |
|
|
8679d6d125 | |
|
|
ad01407fff | |
|
|
f1609f7049 | |
|
|
5fcfc73901 | |
|
|
86b35106b2 | |
|
|
519a14fbde | |
|
|
9a42093f64 | |
|
|
347a6e0b96 | |
|
|
83d5add134 | |
|
|
bb12e33584 | |
|
|
6817cd8e07 | |
|
|
179472bd6e | |
|
|
ba761743b9 | |
|
|
b3795102b8 | |
|
|
2a492c793f | |
|
|
705ac9feab | |
|
|
a926616028 | |
|
|
4507b50b8b | |
|
|
bd78f47585 | |
|
|
9b18759864 | |
|
|
001b801a15 | |
|
|
db1c66dd40 | |
|
|
341164a0d2 | |
|
|
7829086ae3 | |
|
|
f9cb48c48f | |
|
|
3401b669e8 | |
|
|
d7a7efb051 | |
|
|
0904196979 | |
|
|
7617ecf2d1 | |
|
|
1acc594467 | |
|
|
e48c31b7a9 | |
|
|
f3c5461854 | |
|
|
aa826c7da2 | |
|
|
bc71088092 | |
|
|
7eca4724da | |
|
|
b098b681be | |
|
|
c52a379b87 | |
|
|
50b5eab860 | |
|
|
18fed3a089 | |
|
|
ab432cc8e6 | |
|
|
223be29384 | |
|
|
c5f8ded0b1 | |
|
|
0b1d36cf64 | |
|
|
1f20ebf6c6 | |
|
|
af5418c958 | |
|
|
6670da76e8 | |
|
|
69d3bc9ce1 | |
|
|
e3fe1b52b2 | |
|
|
02d22f20b5 | |
|
|
cbb676e004 | |
|
|
2d8acf7800 | |
|
|
481b2079ed | |
|
|
952a3eff14 | |
|
|
4fc9bd8810 | |
|
|
28db38a00e | |
|
|
00ac166371 | |
|
|
dd3d43e702 | |
|
|
334d74095e | |
|
|
b461b91587 | |
|
|
56843c459c | |
|
|
d5cf16ac63 | |
|
|
394e3a3907 | |
|
|
380266f2c3 | |
|
|
4e720e3dc9 | |
|
|
b91f6ec04e | |
|
|
248cce6656 | |
|
|
d23fe6b8a4 | |
|
|
665d6c13a3 | |
|
|
f823786029 | |
|
|
e69349bca8 | |
|
|
5445661f42 | |
|
|
b75d2f580c | |
|
|
bb835fb173 | |
|
|
b0ef03d2d1 | |
|
|
f9bf268c89 | |
|
|
dfc516f8bd | |
|
|
00370cfbc7 | |
|
|
6ba2ea1d6d | |
|
|
2cab4cb7d0 | |
|
|
77a7246b6a | |
|
|
2461cbb831 | |
|
|
42d886554e | |
|
|
e7c627dcd2 | |
|
|
5c71e9a562 | |
|
|
ab374d4ba8 | |
|
|
02b62d493f | |
|
|
4caaff8b04 | |
|
|
74256dc5ac | |
|
|
bf144a783c | |
|
|
0b0b50f259 | |
|
|
cf3834d5c2 | |
|
|
3a8e622f89 | |
|
|
2ddbb5d91a | |
|
|
b73af37bbf | |
|
|
6ed2fd961b | |
|
|
c88f020280 | |
|
|
206adbd70b | |
|
|
14ba273d35 | |
|
|
6d96a9e53d | |
|
|
45d3e611a5 | |
|
|
9c262c7455 | |
|
|
a8a586bfb1 | |
|
|
b0e0bba27b | |
|
|
0d78b00342 | |
|
|
367a215b05 |
|
|
@ -1,39 +0,0 @@
|
||||||
# This workflow will upload a Python Package using Twine when a release is created
|
|
||||||
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
|
|
||||||
|
|
||||||
# This workflow uses actions that are not certified by GitHub.
|
|
||||||
# They are provided by a third-party and are governed by
|
|
||||||
# separate terms of service, privacy policy, and support
|
|
||||||
# documentation.
|
|
||||||
|
|
||||||
name: Upload Python Package
|
|
||||||
|
|
||||||
on:
|
|
||||||
release:
|
|
||||||
types: [published]
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
deploy:
|
|
||||||
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v3
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v3
|
|
||||||
with:
|
|
||||||
python-version: '3.x'
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install build
|
|
||||||
- name: Build package
|
|
||||||
run: python -m build
|
|
||||||
- name: Publish package
|
|
||||||
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
|
|
||||||
with:
|
|
||||||
user: __token__
|
|
||||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
|
||||||
|
|
@ -10,4 +10,8 @@ build/
|
||||||
Pipfile.lock
|
Pipfile.lock
|
||||||
.mypy_cache/
|
.mypy_cache/
|
||||||
.vscode/
|
.vscode/
|
||||||
|
.idea/
|
||||||
.venv/
|
.venv/
|
||||||
|
*.code-workspace
|
||||||
|
*.ini
|
||||||
|
.pypirc
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
# See https://pre-commit.com/hooks.html for more hooks
|
# See https://pre-commit.com/hooks.html for more hooks
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.4.0
|
rev: v4.5.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: check-ast
|
- id: check-ast
|
||||||
- id: check-builtin-literals
|
- id: check-builtin-literals
|
||||||
|
|
@ -11,31 +11,23 @@ repos:
|
||||||
- id: requirements-txt-fixer
|
- id: requirements-txt-fixer
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
- repo: https://github.com/psf/black
|
- repo: https://github.com/psf/black
|
||||||
rev: 23.1.0
|
rev: 23.10.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
language_version: python3.8
|
language_version: python3.13
|
||||||
- repo: https://github.com/asottile/blacken-docs
|
|
||||||
rev: 1.13.0
|
|
||||||
hooks:
|
|
||||||
- id: blacken-docs
|
|
||||||
- repo: https://github.com/asottile/pyupgrade
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
rev: v3.3.1
|
rev: v3.15.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
args: [--py37-plus, --keep-runtime-typing]
|
args: [--py37-plus, --keep-runtime-typing]
|
||||||
- repo: https://github.com/asottile/reorder_python_imports
|
- repo: https://github.com/asottile/reorder-python-imports
|
||||||
rev: v3.9.0
|
rev: v3.12.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: reorder-python-imports
|
- id: reorder-python-imports
|
||||||
- repo: https://github.com/asottile/add-trailing-comma
|
- repo: https://github.com/asottile/add-trailing-comma
|
||||||
rev: v2.4.0
|
rev: v3.1.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: add-trailing-comma
|
- id: add-trailing-comma
|
||||||
- repo: https://github.com/hadialqattan/pycln
|
|
||||||
rev: v2.1.3
|
|
||||||
hooks:
|
|
||||||
- id: pycln
|
|
||||||
|
|
||||||
default_language_version:
|
default_language_version:
|
||||||
python: python3.8
|
python: python3.13
|
||||||
|
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
3.10.9
|
|
||||||
3
Pipfile
3
Pipfile
|
|
@ -6,6 +6,7 @@ name = "pypi"
|
||||||
[packages]
|
[packages]
|
||||||
orjson = "*"
|
orjson = "*"
|
||||||
"discord.py" = {extras = ["voice"], version = "*"}
|
"discord.py" = {extras = ["voice"], version = "*"}
|
||||||
|
websockets = "*"
|
||||||
|
|
||||||
[dev-packages]
|
[dev-packages]
|
||||||
mypy = "*"
|
mypy = "*"
|
||||||
|
|
@ -13,6 +14,8 @@ pre-commit = "*"
|
||||||
furo = "*"
|
furo = "*"
|
||||||
sphinx = "*"
|
sphinx = "*"
|
||||||
myst-parser = "*"
|
myst-parser = "*"
|
||||||
|
black = "*"
|
||||||
|
typing-extensions = "*"
|
||||||
|
|
||||||
[requires]
|
[requires]
|
||||||
python_version = "3.8"
|
python_version = "3.8"
|
||||||
|
|
|
||||||
15
README.md
15
README.md
|
|
@ -3,11 +3,16 @@
|
||||||

|

|
||||||
|
|
||||||
|
|
||||||
[](https://github.com/cloudwithax/pomice/blob/main/LICENSE) 
|
[](https://github.com/cloudwithax/pomice/blob/main/LICENSE)  [](https://github.com/psf/black)
|
||||||
[](https://discord.gg/r64qjTSHG8) [](https://pomice.readthedocs.io/en/latest/)
|
[](https://discord.gg/r64qjTSHG8) [](https://pomice.readthedocs.io/en/latest/)
|
||||||
|
|
||||||
|
|
||||||
Pomice is a fully asynchronous Python library designed for communicating with [Lavalink](https://github.com/freyacodes/Lavalink) seamlessly within the [discord.py](https://github.com/Rapptz/discord.py) library. It features 100% API coverage of the entire [Lavalink](https://github.com/freyacodes/Lavalink) spec that can be accessed with easy-to-understand functions. We also include Spotify and Apple Music querying capabilites using built-in custom clients, making it easier to develop your next big music bot.
|
Pomice is a fully asynchronous Python library designed for communicating with [Lavalink](https://github.com/freyacodes/Lavalink) seamlessly within the [discord.py](https://github.com/Rapptz/discord.py) library. It features 100% coverage of the [Lavalink](https://github.com/freyacodes/Lavalink) spec that can be accessed with easy-to-understand functions along with Spotify and Apple Music querying capabilities using built-in custom clients, making it easier to develop your next big music bot.
|
||||||
|
|
||||||
|
## Quick Links
|
||||||
|
- [Discord Server](https://discord.gg/r64qjTSHG8)
|
||||||
|
- [Read the Docs](https://pomice.readthedocs.io/en/latest/)
|
||||||
|
- [PyPI Homepage](https://pypi.org/project/pomice/)
|
||||||
|
|
||||||
|
|
||||||
# Install
|
# Install
|
||||||
|
|
@ -23,7 +28,7 @@ pip install pomice
|
||||||
pip install git+https://github.com/cloudwithax/pomice
|
pip install git+https://github.com/cloudwithax/pomice
|
||||||
```
|
```
|
||||||
|
|
||||||
# Support
|
# Support And Documentation
|
||||||
|
|
||||||
The official documentation is [here](https://pomice.readthedocs.io/en/latest/)
|
The official documentation is [here](https://pomice.readthedocs.io/en/latest/)
|
||||||
|
|
||||||
|
|
@ -31,7 +36,7 @@ You can join our support server [here](https://discord.gg/r64qjTSHG8)
|
||||||
|
|
||||||
|
|
||||||
# Examples
|
# Examples
|
||||||
In-depth examples are located in the examples folder
|
In-depth examples are located in the [examples folder](https://github.com/cloudwithax/pomice/tree/main/examples)
|
||||||
|
|
||||||
Here's a quick example:
|
Here's a quick example:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
# type: ignore
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
|
|
@ -13,11 +13,9 @@ The classes listed here are as they appear in Pomice. When you use them within y
|
||||||
the way you use them will be different. Here's an example on how you would use the `TrackStartEvent` within an event listener in a cog:
|
the way you use them will be different. Here's an example on how you would use the `TrackStartEvent` within an event listener in a cog:
|
||||||
|
|
||||||
```py
|
```py
|
||||||
|
|
||||||
@commands.Cog.listener
|
@commands.Cog.listener
|
||||||
async def on_pomice_track_start(self, player: Player, track: Track):
|
async def on_pomice_track_start(self, player: Player, track: Track):
|
||||||
...
|
...
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Event definitions
|
## Event definitions
|
||||||
|
|
@ -34,7 +32,7 @@ your application. Here are all the definitions:
|
||||||
|
|
||||||
|
|
||||||
All events related to tracks carry a `Player` object so you can access player-specific functions
|
All events related to tracks carry a `Player` object so you can access player-specific functions
|
||||||
and properties for further evaluation. They also carry a `Track` object so you can access track-specific functions and properites for further evaluation as well.
|
and properties for further evaluation. They also carry a `Track` object so you can access track-specific functions and properties for further evaluation as well.
|
||||||
|
|
||||||
`Event.TrackEndEvent()` carries the reason for the track ending. If the track ends suddenly, you can use the reason provided to determine a solution.
|
`Event.TrackEndEvent()` carries the reason for the track ending. If the track ends suddenly, you can use the reason provided to determine a solution.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -361,6 +361,27 @@ await Player.stop()
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Moving the player to another channel
|
||||||
|
|
||||||
|
To move the player to another channel, we need to use `Player.move_to()`
|
||||||
|
|
||||||
|
```py
|
||||||
|
|
||||||
|
await Player.move_to(...)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
After you have initialized your function, we need to include the `channel` parameter, which is a `VoiceChannel`:
|
||||||
|
|
||||||
|
```py
|
||||||
|
|
||||||
|
await Player.move_to(channel)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
After running this function, your player should be in the new voice channel. All voice state updates should also be handled.
|
||||||
|
|
||||||
|
|
||||||
## Controlling filters
|
## Controlling filters
|
||||||
|
|
||||||
Pomice has an extensive suite of filter management tools to help you make the most of Lavalink and it's filters.
|
Pomice has an extensive suite of filter management tools to help you make the most of Lavalink and it's filters.
|
||||||
|
|
|
||||||
|
|
@ -66,9 +66,10 @@ After you have initialized your function, we need to fill in the proper paramete
|
||||||
- Set this value to `True` if you want Pomice to automatically switch all players to another available node if one disconnects.
|
- Set this value to `True` if you want Pomice to automatically switch all players to another available node if one disconnects.
|
||||||
You must have two or more nodes to be able to do this.
|
You must have two or more nodes to be able to do this.
|
||||||
|
|
||||||
* - `log_level`
|
* - `logger`
|
||||||
- `LogLevel`
|
- `Optional[logging.Logger]`
|
||||||
- The logging level for the node. The default logging level is `LogLevel.INFO`.
|
- If you would like to receive logging information from Pomice, set this to your logger class
|
||||||
|
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
|
@ -89,13 +90,13 @@ await NodePool.create_node(
|
||||||
spotify_client_secret="<your spotify client secret here>"
|
spotify_client_secret="<your spotify client secret here>"
|
||||||
apple_music=<True/False>,
|
apple_music=<True/False>,
|
||||||
fallback=<True/False>,
|
fallback=<True/False>,
|
||||||
log_level=<optiona LogLevel here>
|
logger=<your logger here>
|
||||||
)
|
)
|
||||||
|
|
||||||
```
|
```
|
||||||
:::{important}
|
:::{important}
|
||||||
|
|
||||||
For features like Spotify and Apple Music, you are **not required** to fill in anything for them if you do not want to use them. If you do end up queuing a Spotify or Apple Music track anyway, they will **not work** because these options are not enabled.
|
For features like Spotify and Apple Music, you are **not required** to fill in anything for them if you do not want to use them. If you do end up queuing a Spotify or Apple Music track, it is **up to you** on how you decide to handle it, whether it be through your own methods or a Lavalink plugin.
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,12 +10,17 @@ import re
|
||||||
|
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
URL_REG = re.compile(r'https?://(?:www\.)?.+')
|
URL_REG = re.compile(r"https?://(?:www\.)?.+")
|
||||||
|
|
||||||
|
|
||||||
class MyBot(commands.Bot):
|
class MyBot(commands.Bot):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__(command_prefix='!', activity=discord.Activity(type=discord.ActivityType.listening, name='to music!'))
|
super().__init__(
|
||||||
|
command_prefix="!",
|
||||||
|
activity=discord.Activity(
|
||||||
|
type=discord.ActivityType.listening, name="to music!"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
self.add_cog(Music(self))
|
self.add_cog(Music(self))
|
||||||
|
|
||||||
|
|
@ -25,44 +30,47 @@ class MyBot(commands.Bot):
|
||||||
|
|
||||||
|
|
||||||
class Music(commands.Cog):
|
class Music(commands.Cog):
|
||||||
|
|
||||||
def __init__(self, bot) -> None:
|
def __init__(self, bot) -> None:
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
|
||||||
self.pomice = pomice.NodePool()
|
self.pomice = pomice.NodePool()
|
||||||
|
|
||||||
async def start_nodes(self):
|
async def start_nodes(self):
|
||||||
await self.pomice.create_node(bot=self.bot, host='127.0.0.1', port='3030',
|
await self.pomice.create_node(
|
||||||
password='youshallnotpass', identifier='MAIN')
|
bot=self.bot,
|
||||||
|
host="127.0.0.1",
|
||||||
|
port="3030",
|
||||||
|
password="youshallnotpass",
|
||||||
|
identifier="MAIN",
|
||||||
|
)
|
||||||
print(f"Node is ready!")
|
print(f"Node is ready!")
|
||||||
|
|
||||||
|
@commands.command(name="join", aliases=["connect"])
|
||||||
|
async def join(
|
||||||
@commands.command(name='join', aliases=['connect'])
|
self, ctx: commands.Context, *, channel: discord.TextChannel = None
|
||||||
async def join(self, ctx: commands.Context, *, channel: discord.TextChannel = None) -> None:
|
) -> None:
|
||||||
|
|
||||||
if not channel:
|
if not channel:
|
||||||
channel = getattr(ctx.author.voice, 'channel', None)
|
channel = getattr(ctx.author.voice, "channel", None)
|
||||||
if not channel:
|
if not channel:
|
||||||
raise commands.CheckFailure('You must be in a voice channel to use this command'
|
raise commands.CheckFailure(
|
||||||
'without specifying the channel argument.')
|
"You must be in a voice channel to use this command"
|
||||||
|
"without specifying the channel argument."
|
||||||
|
)
|
||||||
|
|
||||||
await ctx.author.voice.channel.connect(cls=pomice.Player)
|
await ctx.author.voice.channel.connect(cls=pomice.Player)
|
||||||
await ctx.send(f'Joined the voice channel `{channel}`')
|
await ctx.send(f"Joined the voice channel `{channel}`")
|
||||||
|
|
||||||
@commands.command(name='play')
|
@commands.command(name="play")
|
||||||
async def play(self, ctx, *, search: str) -> None:
|
async def play(self, ctx, *, search: str) -> None:
|
||||||
|
|
||||||
if not ctx.voice_client:
|
if not ctx.voice_client:
|
||||||
await ctx.invoke(self.join)
|
await ctx.invoke(self.join)
|
||||||
|
|
||||||
player = ctx.voice_client
|
player = ctx.voice_client
|
||||||
|
|
||||||
results = await player.get_tracks(query=f'{search}')
|
results = await player.get_tracks(query=f"{search}")
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
raise commands.CommandError('No results were found for that search term.')
|
raise commands.CommandError("No results were found for that search term.")
|
||||||
|
|
||||||
if isinstance(results, pomice.Playlist):
|
if isinstance(results, pomice.Playlist):
|
||||||
await player.play(track=results.tracks[0])
|
await player.play(track=results.tracks[0])
|
||||||
|
|
|
||||||
|
|
@ -3,3 +3,4 @@ discord.py[voice]
|
||||||
furo
|
furo
|
||||||
myst_parser
|
myst_parser
|
||||||
orjson
|
orjson
|
||||||
|
websockets
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
# type: ignore
|
||||||
"""
|
"""
|
||||||
This example aims to show the full capabilities of the library.
|
This example aims to show the full capabilities of the library.
|
||||||
This is in the form of a drop-in cog you can use and modify to your liking.
|
This is in the form of a drop-in cog you can use and modify to your liking.
|
||||||
|
|
@ -100,13 +101,13 @@ class Music(commands.Cog):
|
||||||
await self.pomice.create_node(
|
await self.pomice.create_node(
|
||||||
bot=self.bot,
|
bot=self.bot,
|
||||||
host="127.0.0.1",
|
host="127.0.0.1",
|
||||||
port="3030",
|
port=3030,
|
||||||
password="youshallnotpass",
|
password="youshallnotpass",
|
||||||
identifier="MAIN",
|
identifier="MAIN",
|
||||||
)
|
)
|
||||||
print(f"Node is ready!")
|
print(f"Node is ready!")
|
||||||
|
|
||||||
async def required(self, ctx: commands.Context):
|
def required(self, ctx: commands.Context):
|
||||||
"""Method which returns required votes based on amount of members in a channel."""
|
"""Method which returns required votes based on amount of members in a channel."""
|
||||||
player: Player = ctx.voice_client
|
player: Player = ctx.voice_client
|
||||||
channel = self.bot.get_channel(int(player.channel.id))
|
channel = self.bot.get_channel(int(player.channel.id))
|
||||||
|
|
@ -118,7 +119,7 @@ class Music(commands.Cog):
|
||||||
|
|
||||||
return required
|
return required
|
||||||
|
|
||||||
async def is_privileged(self, ctx: commands.Context):
|
def is_privileged(self, ctx: commands.Context):
|
||||||
"""Check whether the user is an Admin or DJ."""
|
"""Check whether the user is an Admin or DJ."""
|
||||||
player: Player = ctx.voice_client
|
player: Player = ctx.voice_client
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
# type: ignore
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
|
|
@ -35,7 +36,7 @@ class Music(commands.Cog):
|
||||||
await self.pomice.create_node(
|
await self.pomice.create_node(
|
||||||
bot=self.bot,
|
bot=self.bot,
|
||||||
host="127.0.0.1",
|
host="127.0.0.1",
|
||||||
port="3030",
|
port=3030,
|
||||||
password="youshallnotpass",
|
password="youshallnotpass",
|
||||||
identifier="MAIN",
|
identifier="MAIN",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ Pomice
|
||||||
~~~~~~
|
~~~~~~
|
||||||
The modern Lavalink wrapper designed for discord.py.
|
The modern Lavalink wrapper designed for discord.py.
|
||||||
|
|
||||||
Copyright (c) 2023, cloudwithax
|
Copyright (c) 2024, cloudwithax
|
||||||
|
|
||||||
Licensed under GPL-3.0
|
Licensed under GPL-3.0
|
||||||
"""
|
"""
|
||||||
|
|
@ -20,7 +20,7 @@ if not discord.version_info.major >= 2:
|
||||||
"using 'pip install discord.py'",
|
"using 'pip install discord.py'",
|
||||||
)
|
)
|
||||||
|
|
||||||
__version__ = "2.3"
|
__version__ = "2.10.0"
|
||||||
__title__ = "pomice"
|
__title__ = "pomice"
|
||||||
__author__ = "cloudwithax"
|
__author__ = "cloudwithax"
|
||||||
__license__ = "GPL-3.0"
|
__license__ = "GPL-3.0"
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""Apple Music module for Pomice, made possible by cloudwithax 2023"""
|
"""Apple Music module for Pomice, made possible by cloudwithax 2023"""
|
||||||
from .client import Client
|
from .client import *
|
||||||
from .exceptions import *
|
from .exceptions import *
|
||||||
from .objects import *
|
from .objects import *
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,14 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import AsyncGenerator
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
@ -16,11 +20,14 @@ from .objects import *
|
||||||
__all__ = ("Client",)
|
__all__ = ("Client",)
|
||||||
|
|
||||||
AM_URL_REGEX = re.compile(
|
AM_URL_REGEX = re.compile(
|
||||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
|
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$",
|
||||||
)
|
)
|
||||||
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
|
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
|
||||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
|
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
AM_SCRIPT_REGEX = re.compile(r'<script.*?src="(/assets/index-.*?)"')
|
||||||
|
|
||||||
AM_REQ_URL = "https://api.music.apple.com/v1/catalog/{country}/{type}s/{id}"
|
AM_REQ_URL = "https://api.music.apple.com/v1/catalog/{country}/{type}s/{id}"
|
||||||
AM_BASE_URL = "https://api.music.apple.com"
|
AM_BASE_URL = "https://api.music.apple.com"
|
||||||
|
|
||||||
|
|
@ -31,21 +38,50 @@ class Client:
|
||||||
and translating it to a valid Lavalink track. No client auth is required here.
|
and translating it to a valid Lavalink track. No client auth is required here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, *, playlist_concurrency: int = 6) -> None:
|
||||||
self.expiry: datetime = datetime(1970, 1, 1)
|
self.expiry: datetime = datetime(1970, 1, 1)
|
||||||
self.token: str = ""
|
self.token: str = ""
|
||||||
self.headers: Dict[str, str] = {}
|
self.headers: Dict[str, str] = {}
|
||||||
self.session: aiohttp.ClientSession = None # type: ignore
|
self.session: aiohttp.ClientSession = None # type: ignore
|
||||||
|
self._log = logging.getLogger(__name__)
|
||||||
|
# Concurrency knob for parallel playlist page retrieval
|
||||||
|
self._playlist_concurrency = max(1, playlist_concurrency)
|
||||||
|
|
||||||
|
async def _set_session(self, session: aiohttp.ClientSession) -> None:
|
||||||
|
self.session = session
|
||||||
|
|
||||||
async def request_token(self) -> None:
|
async def request_token(self) -> None:
|
||||||
if not self.session:
|
# First lets get the raw response from the main page
|
||||||
self.session = aiohttp.ClientSession()
|
|
||||||
|
resp = await self.session.get("https://music.apple.com")
|
||||||
|
|
||||||
async with self.session.get("https://music.apple.com/assets/index.919fe17f.js") as resp:
|
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
raise AppleMusicRequestException(
|
raise AppleMusicRequestException(
|
||||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Looking for script tag that fits criteria
|
||||||
|
|
||||||
|
text = await resp.text()
|
||||||
|
match = re.search(AM_SCRIPT_REGEX, text)
|
||||||
|
|
||||||
|
if not match:
|
||||||
|
raise AppleMusicRequestException(
|
||||||
|
"Could not find valid script URL in response.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Found the script file, lets grab our token
|
||||||
|
|
||||||
|
result = match.group(1)
|
||||||
|
asset_url = result
|
||||||
|
|
||||||
|
resp = await self.session.get("https://music.apple.com" + asset_url)
|
||||||
|
|
||||||
|
if resp.status != 200:
|
||||||
|
raise AppleMusicRequestException(
|
||||||
|
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||||
|
)
|
||||||
|
|
||||||
text = await resp.text()
|
text = await resp.text()
|
||||||
match = re.search('"(eyJ.+?)"', text)
|
match = re.search('"(eyJ.+?)"', text)
|
||||||
if not match:
|
if not match:
|
||||||
|
|
@ -65,6 +101,8 @@ class Client:
|
||||||
).decode()
|
).decode()
|
||||||
token_data = json.loads(token_json)
|
token_data = json.loads(token_json)
|
||||||
self.expiry = datetime.fromtimestamp(token_data["exp"])
|
self.expiry = datetime.fromtimestamp(token_data["exp"])
|
||||||
|
if self._log:
|
||||||
|
self._log.debug(f"Fetched Apple Music bearer token successfully")
|
||||||
|
|
||||||
async def search(self, query: str) -> Union[Album, Playlist, Song, Artist]:
|
async def search(self, query: str) -> Union[Album, Playlist, Song, Artist]:
|
||||||
if not self.token or datetime.utcnow() > self.expiry:
|
if not self.token or datetime.utcnow() > self.expiry:
|
||||||
|
|
@ -90,35 +128,42 @@ class Client:
|
||||||
else:
|
else:
|
||||||
request_url = AM_REQ_URL.format(country=country, type=type, id=id)
|
request_url = AM_REQ_URL.format(country=country, type=type, id=id)
|
||||||
|
|
||||||
async with self.session.get(request_url, headers=self.headers) as resp:
|
resp = await self.session.get(request_url, headers=self.headers)
|
||||||
|
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
raise AppleMusicRequestException(
|
raise AppleMusicRequestException(
|
||||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||||
)
|
)
|
||||||
|
|
||||||
data: dict = await resp.json(loads=json.loads)
|
data: dict = await resp.json(loads=json.loads)
|
||||||
|
if self._log:
|
||||||
|
self._log.debug(
|
||||||
|
f"Made request to Apple Music API with status {resp.status} and response {data}",
|
||||||
|
)
|
||||||
|
|
||||||
data = data["data"][0]
|
data = data["data"][0]
|
||||||
|
|
||||||
if type == "song":
|
if type == "song":
|
||||||
return Song(data)
|
return Song(data)
|
||||||
|
|
||||||
if type == "album":
|
elif type == "album":
|
||||||
return Album(data)
|
return Album(data)
|
||||||
|
|
||||||
if type == "artist":
|
elif type == "artist":
|
||||||
async with self.session.get(
|
resp = await self.session.get(
|
||||||
f"{request_url}/view/top-songs",
|
f"{request_url}/view/top-songs",
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
) as resp:
|
)
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
raise AppleMusicRequestException(
|
raise AppleMusicRequestException(
|
||||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||||
)
|
)
|
||||||
|
|
||||||
top_tracks: dict = await resp.json(loads=json.loads)
|
top_tracks: dict = await resp.json(loads=json.loads)
|
||||||
artist_tracks: dict = top_tracks["data"]
|
artist_tracks: dict = top_tracks["data"]
|
||||||
|
|
||||||
return Artist(data, tracks=artist_tracks)
|
return Artist(data, tracks=artist_tracks)
|
||||||
|
else:
|
||||||
track_data: dict = data["relationships"]["tracks"]
|
track_data: dict = data["relationships"]["tracks"]
|
||||||
album_tracks: List[Song] = [Song(track) for track in track_data["data"]]
|
album_tracks: List[Song] = [Song(track) for track in track_data["data"]]
|
||||||
|
|
||||||
|
|
@ -127,30 +172,127 @@ class Client:
|
||||||
"This playlist is empty and therefore cannot be queued.",
|
"This playlist is empty and therefore cannot be queued.",
|
||||||
)
|
)
|
||||||
|
|
||||||
_next = track_data.get("next")
|
# Apple Music uses cursor pagination with 'next'. We'll fetch subsequent pages
|
||||||
if _next:
|
# concurrently by first collecting cursors in rolling waves.
|
||||||
next_page_url = AM_BASE_URL + _next
|
next_cursor = track_data.get("next")
|
||||||
|
semaphore = asyncio.Semaphore(self._playlist_concurrency)
|
||||||
|
|
||||||
while next_page_url is not None:
|
async def fetch_page(url: str) -> List[Song]:
|
||||||
async with self.session.get(next_page_url, headers=self.headers) as resp:
|
async with semaphore:
|
||||||
|
resp = await self.session.get(url, headers=self.headers)
|
||||||
|
if resp.status != 200:
|
||||||
|
if self._log:
|
||||||
|
self._log.warning(
|
||||||
|
f"Apple Music page fetch failed {resp.status} {resp.reason} for {url}",
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
pj: dict = await resp.json(loads=json.loads)
|
||||||
|
songs = [Song(track) for track in pj.get("data", [])]
|
||||||
|
# Return songs; we will look for pj.get('next') in streaming iterator variant
|
||||||
|
return songs, pj.get("next") # type: ignore
|
||||||
|
|
||||||
|
# We'll implement a wave-based approach similar to Spotify but need to follow cursors.
|
||||||
|
# Because we cannot know all cursors upfront, we'll iteratively fetch waves.
|
||||||
|
waves: List[List[Song]] = []
|
||||||
|
cursors: List[str] = []
|
||||||
|
if next_cursor:
|
||||||
|
cursors.append(next_cursor)
|
||||||
|
|
||||||
|
# Limit total waves to avoid infinite loops in malformed responses
|
||||||
|
max_waves = 50
|
||||||
|
wave_size = self._playlist_concurrency * 2
|
||||||
|
wave_counter = 0
|
||||||
|
while cursors and wave_counter < max_waves:
|
||||||
|
current = cursors[:wave_size]
|
||||||
|
cursors = cursors[wave_size:]
|
||||||
|
tasks = [
|
||||||
|
fetch_page(AM_BASE_URL + cursor) for cursor in current # type: ignore[arg-type]
|
||||||
|
]
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
for res in results:
|
||||||
|
if isinstance(res, tuple): # (songs, next)
|
||||||
|
songs, nxt = res
|
||||||
|
if songs:
|
||||||
|
waves.append(songs)
|
||||||
|
if nxt:
|
||||||
|
cursors.append(nxt)
|
||||||
|
wave_counter += 1
|
||||||
|
|
||||||
|
for w in waves:
|
||||||
|
album_tracks.extend(w)
|
||||||
|
|
||||||
|
return Playlist(data, album_tracks)
|
||||||
|
|
||||||
|
async def iter_playlist_tracks(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
query: str,
|
||||||
|
batch_size: int = 100,
|
||||||
|
) -> AsyncGenerator[List[Song], None]:
|
||||||
|
"""Stream Apple Music playlist tracks in batches.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
query: str
|
||||||
|
Apple Music playlist URL.
|
||||||
|
batch_size: int
|
||||||
|
Logical grouping size for yielded batches.
|
||||||
|
"""
|
||||||
|
if not self.token or datetime.utcnow() > self.expiry:
|
||||||
|
await self.request_token()
|
||||||
|
|
||||||
|
result = AM_URL_REGEX.match(query)
|
||||||
|
if not result or result.group("type") != "playlist":
|
||||||
|
raise InvalidAppleMusicURL("Provided query is not a valid Apple Music playlist URL.")
|
||||||
|
|
||||||
|
country = result.group("country")
|
||||||
|
playlist_id = result.group("id")
|
||||||
|
request_url = AM_REQ_URL.format(country=country, type="playlist", id=playlist_id)
|
||||||
|
resp = await self.session.get(request_url, headers=self.headers)
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
raise AppleMusicRequestException(
|
raise AppleMusicRequestException(
|
||||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||||
)
|
)
|
||||||
|
data: dict = await resp.json(loads=json.loads)
|
||||||
|
playlist_data = data["data"][0]
|
||||||
|
track_data: dict = playlist_data["relationships"]["tracks"]
|
||||||
|
|
||||||
next_data: dict = await resp.json(loads=json.loads)
|
first_page_tracks = [Song(track) for track in track_data["data"]]
|
||||||
|
for i in range(0, len(first_page_tracks), batch_size):
|
||||||
|
yield first_page_tracks[i : i + batch_size]
|
||||||
|
|
||||||
album_tracks.extend(Song(track) for track in next_data["data"])
|
next_cursor = track_data.get("next")
|
||||||
|
semaphore = asyncio.Semaphore(self._playlist_concurrency)
|
||||||
|
|
||||||
_next = next_data.get("next")
|
async def fetch(cursor: str) -> tuple[List[Song], Optional[str]]:
|
||||||
if _next:
|
url = AM_BASE_URL + cursor
|
||||||
next_page_url = AM_BASE_URL + _next
|
async with semaphore:
|
||||||
else:
|
r = await self.session.get(url, headers=self.headers)
|
||||||
next_page_url = None
|
if r.status != 200:
|
||||||
|
if self._log:
|
||||||
|
self._log.warning(
|
||||||
|
f"Skipping Apple Music page due to {r.status} {r.reason}",
|
||||||
|
)
|
||||||
|
return [], None
|
||||||
|
pj: dict = await r.json(loads=json.loads)
|
||||||
|
songs = [Song(track) for track in pj.get("data", [])]
|
||||||
|
return songs, pj.get("next")
|
||||||
|
|
||||||
return Playlist(data, album_tracks)
|
# Rolling waves of fetches following cursor chain
|
||||||
|
max_waves = 50
|
||||||
async def close(self) -> None:
|
wave_size = self._playlist_concurrency * 2
|
||||||
if self.session:
|
waves = 0
|
||||||
await self.session.close()
|
cursors: List[str] = []
|
||||||
self.session = None # type: ignore
|
if next_cursor:
|
||||||
|
cursors.append(next_cursor)
|
||||||
|
while cursors and waves < max_waves:
|
||||||
|
current = cursors[:wave_size]
|
||||||
|
cursors = cursors[wave_size:]
|
||||||
|
results = await asyncio.gather(*[fetch(c) for c in current])
|
||||||
|
for songs, nxt in results:
|
||||||
|
if songs:
|
||||||
|
for j in range(0, len(songs), batch_size):
|
||||||
|
yield songs[j : j + batch_size]
|
||||||
|
if nxt:
|
||||||
|
cursors.append(nxt)
|
||||||
|
waves += 1
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,11 @@ class SearchType(Enum):
|
||||||
ytsearch = "ytsearch"
|
ytsearch = "ytsearch"
|
||||||
ytmsearch = "ytmsearch"
|
ytmsearch = "ytmsearch"
|
||||||
scsearch = "scsearch"
|
scsearch = "scsearch"
|
||||||
|
other = "other"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _missing_(cls, value: object) -> "SearchType": # type: ignore[override]
|
||||||
|
return cls.other
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
@ -52,6 +57,10 @@ class TrackType(Enum):
|
||||||
TrackType.APPLE_MUSIC defines that the track is from Apple Music.
|
TrackType.APPLE_MUSIC defines that the track is from Apple Music.
|
||||||
|
|
||||||
TrackType.HTTP defines that the track is from an HTTP source.
|
TrackType.HTTP defines that the track is from an HTTP source.
|
||||||
|
|
||||||
|
TrackType.LOCAL defines that the track is from a local source.
|
||||||
|
|
||||||
|
TrackType.OTHER defines that the track is from an unknown source (possible from 3rd-party plugins).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# We don't have to define anything special for these, since these just serve as flags
|
# We don't have to define anything special for these, since these just serve as flags
|
||||||
|
|
@ -60,6 +69,12 @@ class TrackType(Enum):
|
||||||
SPOTIFY = "spotify"
|
SPOTIFY = "spotify"
|
||||||
APPLE_MUSIC = "apple_music"
|
APPLE_MUSIC = "apple_music"
|
||||||
HTTP = "http"
|
HTTP = "http"
|
||||||
|
LOCAL = "local"
|
||||||
|
OTHER = "other"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _missing_(cls, value: object) -> "TrackType": # type: ignore[override]
|
||||||
|
return cls.OTHER
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
@ -76,6 +91,8 @@ class PlaylistType(Enum):
|
||||||
PlaylistType.SPOTIFY defines that the playlist is from Spotify
|
PlaylistType.SPOTIFY defines that the playlist is from Spotify
|
||||||
|
|
||||||
PlaylistType.APPLE_MUSIC defines that the playlist is from Apple Music.
|
PlaylistType.APPLE_MUSIC defines that the playlist is from Apple Music.
|
||||||
|
|
||||||
|
PlaylistType.OTHER defines that the playlist is from an unknown source (possible from 3rd-party plugins).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# We don't have to define anything special for these, since these just serve as flags
|
# We don't have to define anything special for these, since these just serve as flags
|
||||||
|
|
@ -83,6 +100,11 @@ class PlaylistType(Enum):
|
||||||
SOUNDCLOUD = "soundcloud"
|
SOUNDCLOUD = "soundcloud"
|
||||||
SPOTIFY = "spotify"
|
SPOTIFY = "spotify"
|
||||||
APPLE_MUSIC = "apple_music"
|
APPLE_MUSIC = "apple_music"
|
||||||
|
OTHER = "other"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _missing_(cls, value: object) -> "PlaylistType": # type: ignore[override]
|
||||||
|
return cls.OTHER
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
@ -196,8 +218,12 @@ class URLRegex:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Spotify share links can include query parameters like ?si=XXXX, a trailing slash,
|
||||||
|
# or an intl locale segment (e.g. /intl-en/). Broaden the regex so we still capture
|
||||||
|
# the type and id while ignoring extra parameters. This prevents the URL from being
|
||||||
|
# treated as a generic Lavalink identifier and ensures internal Spotify handling runs.
|
||||||
SPOTIFY_URL = re.compile(
|
SPOTIFY_URL = re.compile(
|
||||||
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
|
r"https?://open\.spotify\.com/(?:intl-[a-zA-Z-]+/)?(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)(?:/)?(?:\?.*)?$",
|
||||||
)
|
)
|
||||||
|
|
||||||
DISCORD_MP3_URL = re.compile(
|
DISCORD_MP3_URL = re.compile(
|
||||||
|
|
@ -214,22 +240,21 @@ class URLRegex:
|
||||||
r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube\.com|youtu.be))/playlist\?list=.*",
|
r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube\.com|youtu.be))/playlist\?list=.*",
|
||||||
)
|
)
|
||||||
|
|
||||||
YOUTUBE_VID_IN_PLAYLIST = re.compile(
|
|
||||||
r"(?P<video>^.*?v.*?)(?P<list>&list.*)",
|
|
||||||
)
|
|
||||||
|
|
||||||
YOUTUBE_TIMESTAMP = re.compile(
|
YOUTUBE_TIMESTAMP = re.compile(
|
||||||
r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*",
|
r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apple Music links sometimes append additional query parameters (e.g. &l=en, &uo=4).
|
||||||
|
# Allow arbitrary query parameters so valid links are captured and parsed.
|
||||||
AM_URL = re.compile(
|
AM_URL = re.compile(
|
||||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/"
|
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/"
|
||||||
r"(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
|
r"(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Single-in-album links may also carry extra query params beyond the ?i=<trackid> token.
|
||||||
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
|
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
|
||||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
|
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
|
||||||
r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
|
r"(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$",
|
||||||
)
|
)
|
||||||
|
|
||||||
SOUNDCLOUD_URL = re.compile(
|
SOUNDCLOUD_URL = re.compile(
|
||||||
|
|
@ -274,3 +299,10 @@ class LogLevel(IntEnum):
|
||||||
WARN = 30
|
WARN = 30
|
||||||
ERROR = 40
|
ERROR = 40
|
||||||
CRITICAL = 50
|
CRITICAL = 50
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_str(cls, level_str):
|
||||||
|
try:
|
||||||
|
return cls[level_str.upper()]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(f"No such log level: {level_str}")
|
||||||
|
|
|
||||||
|
|
@ -59,14 +59,13 @@ class TrackStartEvent(PomiceEvent):
|
||||||
|
|
||||||
def __init__(self, data: dict, player: Player):
|
def __init__(self, data: dict, player: Player):
|
||||||
self.player: Player = player
|
self.player: Player = player
|
||||||
assert self.player._current is not None
|
self.track: Optional[Track] = self.player._current
|
||||||
self.track: Track = self.player._current
|
|
||||||
|
|
||||||
# on_pomice_track_start(player, track)
|
# on_pomice_track_start(player, track)
|
||||||
self.handler_args = self.player, self.track
|
self.handler_args = self.player, self.track
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<Pomice.TrackStartEvent player={self.player} track_id={self.track.track_id}>"
|
return f"<Pomice.TrackStartEvent player={self.player!r} track={self.track!r}>"
|
||||||
|
|
||||||
|
|
||||||
class TrackEndEvent(PomiceEvent):
|
class TrackEndEvent(PomiceEvent):
|
||||||
|
|
@ -80,8 +79,7 @@ class TrackEndEvent(PomiceEvent):
|
||||||
|
|
||||||
def __init__(self, data: dict, player: Player):
|
def __init__(self, data: dict, player: Player):
|
||||||
self.player: Player = player
|
self.player: Player = player
|
||||||
assert self.player._ending_track is not None
|
self.track: Optional[Track] = self.player._ending_track
|
||||||
self.track: Track = self.player._ending_track
|
|
||||||
self.reason: str = data["reason"]
|
self.reason: str = data["reason"]
|
||||||
|
|
||||||
# on_pomice_track_end(player, track, reason)
|
# on_pomice_track_end(player, track, reason)
|
||||||
|
|
@ -89,8 +87,8 @@ class TrackEndEvent(PomiceEvent):
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (
|
return (
|
||||||
f"<Pomice.TrackEndEvent player={self.player} track_id={self.track.track_id} "
|
f"<Pomice.TrackEndEvent player={self.player!r} track_id={self.track!r} "
|
||||||
f"reason={self.reason}>"
|
f"reason={self.reason!r}>"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -106,8 +104,7 @@ class TrackStuckEvent(PomiceEvent):
|
||||||
|
|
||||||
def __init__(self, data: dict, player: Player):
|
def __init__(self, data: dict, player: Player):
|
||||||
self.player: Player = player
|
self.player: Player = player
|
||||||
assert self.player._ending_track is not None
|
self.track: Optional[Track] = self.player._ending_track
|
||||||
self.track: Track = self.player._ending_track
|
|
||||||
self.threshold: float = data["thresholdMs"]
|
self.threshold: float = data["thresholdMs"]
|
||||||
|
|
||||||
# on_pomice_track_stuck(player, track, threshold)
|
# on_pomice_track_stuck(player, track, threshold)
|
||||||
|
|
@ -131,8 +128,7 @@ class TrackExceptionEvent(PomiceEvent):
|
||||||
|
|
||||||
def __init__(self, data: dict, player: Player):
|
def __init__(self, data: dict, player: Player):
|
||||||
self.player: Player = player
|
self.player: Player = player
|
||||||
assert self.player._ending_track is not None
|
self.track: Optional[Track] = self.player._ending_track
|
||||||
self.track: Track = self.player._ending_track
|
|
||||||
# Error is for Lavalink <= 3.3
|
# Error is for Lavalink <= 3.3
|
||||||
self.exception: str = data.get(
|
self.exception: str = data.get(
|
||||||
"error",
|
"error",
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,12 @@ class Equalizer(Filter):
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<Pomice.EqualizerFilter tag={self.tag} eq={self.eq} raw={self.raw}>"
|
return f"<Pomice.EqualizerFilter tag={self.tag} eq={self.eq} raw={self.raw}>"
|
||||||
|
|
||||||
|
def __eq__(self, __value: object) -> bool:
|
||||||
|
if not isinstance(__value, Equalizer):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.raw == __value.raw
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def flat(cls) -> "Equalizer":
|
def flat(cls) -> "Equalizer":
|
||||||
"""Equalizer preset which represents a flat EQ board,
|
"""Equalizer preset which represents a flat EQ board,
|
||||||
|
|
@ -231,6 +237,16 @@ class Timescale(Filter):
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<Pomice.TimescaleFilter tag={self.tag} speed={self.speed} pitch={self.pitch} rate={self.rate}>"
|
return f"<Pomice.TimescaleFilter tag={self.tag} speed={self.speed} pitch={self.pitch} rate={self.rate}>"
|
||||||
|
|
||||||
|
def __eq__(self, __value: object) -> bool:
|
||||||
|
if not isinstance(__value, Timescale):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return (
|
||||||
|
self.speed == __value.speed
|
||||||
|
and self.pitch == __value.pitch
|
||||||
|
and self.rate == __value.rate
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Karaoke(Filter):
|
class Karaoke(Filter):
|
||||||
"""Filter which filters the vocal track from any song and leaves the instrumental.
|
"""Filter which filters the vocal track from any song and leaves the instrumental.
|
||||||
|
|
@ -270,6 +286,17 @@ class Karaoke(Filter):
|
||||||
f"filter_band={self.filter_band} filter_width={self.filter_width}>"
|
f"filter_band={self.filter_band} filter_width={self.filter_width}>"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __eq__(self, __value: object) -> bool:
|
||||||
|
if not isinstance(__value, Karaoke):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return (
|
||||||
|
self.level == __value.level
|
||||||
|
and self.mono_level == __value.mono_level
|
||||||
|
and self.filter_band == __value.filter_band
|
||||||
|
and self.filter_width == __value.filter_width
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Tremolo(Filter):
|
class Tremolo(Filter):
|
||||||
"""Filter which produces a wavering tone in the music,
|
"""Filter which produces a wavering tone in the music,
|
||||||
|
|
@ -305,6 +332,12 @@ class Tremolo(Filter):
|
||||||
f"<Pomice.TremoloFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
|
f"<Pomice.TremoloFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __eq__(self, __value: object) -> bool:
|
||||||
|
if not isinstance(__value, Tremolo):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.frequency == __value.frequency and self.depth == __value.depth
|
||||||
|
|
||||||
|
|
||||||
class Vibrato(Filter):
|
class Vibrato(Filter):
|
||||||
"""Filter which produces a wavering tone in the music, similar to the Tremolo filter,
|
"""Filter which produces a wavering tone in the music, similar to the Tremolo filter,
|
||||||
|
|
@ -340,6 +373,12 @@ class Vibrato(Filter):
|
||||||
f"<Pomice.VibratoFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
|
f"<Pomice.VibratoFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __eq__(self, __value: object) -> bool:
|
||||||
|
if not isinstance(__value, Vibrato):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.frequency == __value.frequency and self.depth == __value.depth
|
||||||
|
|
||||||
|
|
||||||
class Rotation(Filter):
|
class Rotation(Filter):
|
||||||
"""Filter which produces a stereo-like panning effect, which sounds like
|
"""Filter which produces a stereo-like panning effect, which sounds like
|
||||||
|
|
@ -357,6 +396,12 @@ class Rotation(Filter):
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<Pomice.RotationFilter tag={self.tag} rotation_hertz={self.rotation_hertz}>"
|
return f"<Pomice.RotationFilter tag={self.tag} rotation_hertz={self.rotation_hertz}>"
|
||||||
|
|
||||||
|
def __eq__(self, __value: object) -> bool:
|
||||||
|
if not isinstance(__value, Rotation):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.rotation_hertz == __value.rotation_hertz
|
||||||
|
|
||||||
|
|
||||||
class ChannelMix(Filter):
|
class ChannelMix(Filter):
|
||||||
"""Filter which manually adjusts the panning of the audio, which can make
|
"""Filter which manually adjusts the panning of the audio, which can make
|
||||||
|
|
@ -418,6 +463,17 @@ class ChannelMix(Filter):
|
||||||
f"right_to_left={self.right_to_left} right_to_right={self.right_to_right}>"
|
f"right_to_left={self.right_to_left} right_to_right={self.right_to_right}>"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __eq__(self, __value: object) -> bool:
|
||||||
|
if not isinstance(__value, ChannelMix):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return (
|
||||||
|
self.left_to_left == __value.left_to_left
|
||||||
|
and self.left_to_right == __value.left_to_right
|
||||||
|
and self.right_to_left == __value.right_to_left
|
||||||
|
and self.right_to_right == __value.right_to_right
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Distortion(Filter):
|
class Distortion(Filter):
|
||||||
"""Filter which generates a distortion effect. Useful for certain filter implementations where
|
"""Filter which generates a distortion effect. Useful for certain filter implementations where
|
||||||
|
|
@ -479,6 +535,21 @@ class Distortion(Filter):
|
||||||
f"tan_scale={self.tan_scale} offset={self.offset} scale={self.scale}"
|
f"tan_scale={self.tan_scale} offset={self.offset} scale={self.scale}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __eq__(self, __value: object) -> bool:
|
||||||
|
if not isinstance(__value, Distortion):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return (
|
||||||
|
self.sin_offset == __value.sin_offset
|
||||||
|
and self.sin_scale == __value.sin_scale
|
||||||
|
and self.cos_offset == __value.cos_offset
|
||||||
|
and self.cos_scale == __value.cos_scale
|
||||||
|
and self.tan_offset == __value.tan_offset
|
||||||
|
and self.tan_scale == __value.tan_scale
|
||||||
|
and self.offset == __value.offset
|
||||||
|
and self.scale == __value.scale
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LowPass(Filter):
|
class LowPass(Filter):
|
||||||
"""Filter which supresses higher frequencies and allows lower frequencies to pass.
|
"""Filter which supresses higher frequencies and allows lower frequencies to pass.
|
||||||
|
|
@ -495,3 +566,9 @@ class LowPass(Filter):
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<Pomice.LowPass tag={self.tag} smoothing={self.smoothing}>"
|
return f"<Pomice.LowPass tag={self.tag} smoothing={self.smoothing}>"
|
||||||
|
|
||||||
|
def __eq__(self, __value: object) -> bool:
|
||||||
|
if not isinstance(__value, LowPass):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.smoothing == __value.smoothing
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ class Track:
|
||||||
self.author: str = info.get("author", "Unknown Author")
|
self.author: str = info.get("author", "Unknown Author")
|
||||||
self.uri: str = info.get("uri", "")
|
self.uri: str = info.get("uri", "")
|
||||||
self.identifier: str = info.get("identifier", "")
|
self.identifier: str = info.get("identifier", "")
|
||||||
self.isrc: str = info.get("isrc", "")
|
self.isrc: Optional[str] = info.get("isrc", None)
|
||||||
self.thumbnail: Optional[str] = info.get("thumbnail")
|
self.thumbnail: Optional[str] = info.get("thumbnail")
|
||||||
|
|
||||||
if self.uri and self.track_type is TrackType.YOUTUBE:
|
if self.uri and self.track_type is TrackType.YOUTUBE:
|
||||||
|
|
@ -98,9 +98,6 @@ class Track:
|
||||||
if not isinstance(other, Track):
|
if not isinstance(other, Track):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if self.ctx and other.ctx:
|
|
||||||
return other.track_id == self.track_id and other.ctx.message.id == self.ctx.message.id
|
|
||||||
|
|
||||||
return other.track_id == self.track_id
|
return other.track_id == self.track_id
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
|
|
||||||
156
pomice/player.py
156
pomice/player.py
|
|
@ -79,10 +79,35 @@ class Filters:
|
||||||
if filter.tag == filter_tag:
|
if filter.tag == filter_tag:
|
||||||
del self._filters[index]
|
del self._filters[index]
|
||||||
|
|
||||||
|
def edit_filter(self, *, filter_tag: str, to_apply: Filter) -> None:
|
||||||
|
"""Edits a filter in the list of filters applied using its filter tag and replaces it with the new filter."""
|
||||||
|
if not any(f for f in self._filters if f.tag == filter_tag):
|
||||||
|
raise FilterTagInvalid("A filter with that tag was not found.")
|
||||||
|
|
||||||
|
for index, filter in enumerate(self._filters):
|
||||||
|
if filter.tag == filter_tag:
|
||||||
|
if not type(filter) == type(to_apply):
|
||||||
|
raise FilterInvalidArgument(
|
||||||
|
"Edited filter is not the same type as the current filter.",
|
||||||
|
)
|
||||||
|
if self._filters[index] == to_apply:
|
||||||
|
raise FilterInvalidArgument("Edited filter is the same as the current filter.")
|
||||||
|
|
||||||
|
if to_apply.tag != filter_tag:
|
||||||
|
raise FilterInvalidArgument(
|
||||||
|
"Edited filter tag is not the same as the current filter tag.",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._filters[index] = to_apply
|
||||||
|
|
||||||
def has_filter(self, *, filter_tag: str) -> bool:
|
def has_filter(self, *, filter_tag: str) -> bool:
|
||||||
"""Checks if a filter exists in the list of filters using its filter tag"""
|
"""Checks if a filter exists in the list of filters using its filter tag"""
|
||||||
return any(f for f in self._filters if f.tag == filter_tag)
|
return any(f for f in self._filters if f.tag == filter_tag)
|
||||||
|
|
||||||
|
def has_filter_type(self, *, filter_type: Filter) -> bool:
|
||||||
|
"""Checks if any filters applied match the specified filter type."""
|
||||||
|
return any(f for f in self._filters if type(f) == type(filter_type))
|
||||||
|
|
||||||
def reset_filters(self) -> None:
|
def reset_filters(self) -> None:
|
||||||
"""Removes all filters from the list"""
|
"""Removes all filters from the list"""
|
||||||
self._filters = []
|
self._filters = []
|
||||||
|
|
@ -131,10 +156,10 @@ class Player(VoiceProtocol):
|
||||||
"_player_endpoint_uri",
|
"_player_endpoint_uri",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, client: Client, channel: VoiceChannel):
|
def __call__(self, client: Client, channel: VoiceChannel) -> Player:
|
||||||
self.client: Client = client
|
self.client = client
|
||||||
self.channel: VoiceChannel = channel
|
self.channel = channel
|
||||||
self._guild: Guild = channel.guild
|
self._guild = channel.guild
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -147,9 +172,9 @@ class Player(VoiceProtocol):
|
||||||
) -> None:
|
) -> None:
|
||||||
self.client: Client = client
|
self.client: Client = client
|
||||||
self.channel: VoiceChannel = channel
|
self.channel: VoiceChannel = channel
|
||||||
|
self._guild = channel.guild
|
||||||
|
|
||||||
self._bot: Client = client
|
self._bot: Client = client
|
||||||
self._guild: Guild = channel.guild
|
|
||||||
self._node: Node = node if node else NodePool.get_node()
|
self._node: Node = node if node else NodePool.get_node()
|
||||||
self._current: Optional[Track] = None
|
self._current: Optional[Track] = None
|
||||||
self._filters: Filters = Filters()
|
self._filters: Filters = Filters()
|
||||||
|
|
@ -188,7 +213,7 @@ class Player(VoiceProtocol):
|
||||||
difference = (time.time() * 1000) - self._last_update
|
difference = (time.time() * 1000) - self._last_update
|
||||||
position = self._last_position + difference
|
position = self._last_position + difference
|
||||||
|
|
||||||
return min(position, current.length)
|
return round(min(position, current.length))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rate(self) -> float:
|
def rate(self) -> float:
|
||||||
|
|
@ -262,7 +287,7 @@ class Player(VoiceProtocol):
|
||||||
"""
|
"""
|
||||||
return self.guild.id not in self._node._players
|
return self.guild.id not in self._node._players
|
||||||
|
|
||||||
def _adjust_end_time(self):
|
def _adjust_end_time(self) -> Optional[str]:
|
||||||
if self._node._version >= LavalinkVersion(3, 7, 5):
|
if self._node._version >= LavalinkVersion(3, 7, 5):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -273,6 +298,7 @@ class Player(VoiceProtocol):
|
||||||
self._last_update = int(state.get("time", 0))
|
self._last_update = int(state.get("time", 0))
|
||||||
self._is_connected = bool(state.get("connected"))
|
self._is_connected = bool(state.get("connected"))
|
||||||
self._last_position = int(state.get("position", 0))
|
self._last_position = int(state.get("position", 0))
|
||||||
|
if self._log:
|
||||||
self._log.debug(f"Got player update state with data {state}")
|
self._log.debug(f"Got player update state with data {state}")
|
||||||
|
|
||||||
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None:
|
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None:
|
||||||
|
|
@ -294,7 +320,10 @@ class Player(VoiceProtocol):
|
||||||
data={"voice": data},
|
data={"voice": data},
|
||||||
)
|
)
|
||||||
|
|
||||||
self._log.debug(f"Dispatched voice update to {state['event']['endpoint']} with data {data}")
|
if self._log:
|
||||||
|
self._log.debug(
|
||||||
|
f"Dispatched voice update to {state['event']['endpoint']} with data {data}",
|
||||||
|
)
|
||||||
|
|
||||||
async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
|
async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
|
||||||
self._voice_state.update({"event": data})
|
self._voice_state.update({"event": data})
|
||||||
|
|
@ -310,6 +339,10 @@ class Player(VoiceProtocol):
|
||||||
return
|
return
|
||||||
|
|
||||||
channel = self.guild.get_channel(int(channel_id))
|
channel = self.guild.get_channel(int(channel_id))
|
||||||
|
|
||||||
|
if self.channel != channel:
|
||||||
|
self.channel = channel
|
||||||
|
|
||||||
if not channel:
|
if not channel:
|
||||||
await self.disconnect()
|
await self.disconnect()
|
||||||
self._voice_state.clear()
|
self._voice_state.clear()
|
||||||
|
|
@ -324,7 +357,7 @@ class Player(VoiceProtocol):
|
||||||
event_type: str = data["type"]
|
event_type: str = data["type"]
|
||||||
event: PomiceEvent = getattr(events, event_type)(data, self)
|
event: PomiceEvent = getattr(events, event_type)(data, self)
|
||||||
|
|
||||||
if isinstance(event, TrackEndEvent) and event.reason != "REPLACED":
|
if isinstance(event, TrackEndEvent) and event.reason not in ("REPLACED", "replaced"):
|
||||||
self._current = None
|
self._current = None
|
||||||
|
|
||||||
event.dispatch(self._bot)
|
event.dispatch(self._bot)
|
||||||
|
|
@ -332,8 +365,12 @@ class Player(VoiceProtocol):
|
||||||
if isinstance(event, TrackStartEvent):
|
if isinstance(event, TrackStartEvent):
|
||||||
self._ending_track = self._current
|
self._ending_track = self._current
|
||||||
|
|
||||||
|
if self._log:
|
||||||
self._log.debug(f"Dispatched event {data['type']} to player.")
|
self._log.debug(f"Dispatched event {data['type']} to player.")
|
||||||
|
|
||||||
|
async def _refresh_endpoint_uri(self, session_id: Optional[str]) -> None:
|
||||||
|
self._player_endpoint_uri = f"sessions/{session_id}/players"
|
||||||
|
|
||||||
async def _swap_node(self, *, new_node: Node) -> None:
|
async def _swap_node(self, *, new_node: Node) -> None:
|
||||||
if self.current:
|
if self.current:
|
||||||
data: dict = {"position": self.position, "encodedTrack": self.current.track_id}
|
data: dict = {"position": self.position, "encodedTrack": self.current.track_id}
|
||||||
|
|
@ -342,16 +379,16 @@ class Player(VoiceProtocol):
|
||||||
self._node = new_node
|
self._node = new_node
|
||||||
self._node._players[self._guild.id] = self
|
self._node._players[self._guild.id] = self
|
||||||
# reassign uri to update session id
|
# reassign uri to update session id
|
||||||
self._player_endpoint_uri = f"sessions/{self._node._session_id}/players"
|
await self._refresh_endpoint_uri(new_node._session_id)
|
||||||
|
|
||||||
await self._dispatch_voice_update()
|
await self._dispatch_voice_update()
|
||||||
await self._node.send(
|
await self._node.send(
|
||||||
method="PATCH",
|
method="PATCH",
|
||||||
path=self._player_endpoint_uri,
|
path=self._player_endpoint_uri,
|
||||||
guild_id=self._guild.id,
|
guild_id=self._guild.id,
|
||||||
data=data,
|
data=data or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._log:
|
||||||
self._log.debug(f"Swapped all players to new node {new_node._identifier}.")
|
self._log.debug(f"Swapped all players to new node {new_node._identifier}.")
|
||||||
|
|
||||||
async def get_tracks(
|
async def get_tracks(
|
||||||
|
|
@ -359,7 +396,7 @@ class Player(VoiceProtocol):
|
||||||
query: str,
|
query: str,
|
||||||
*,
|
*,
|
||||||
ctx: Optional[commands.Context] = None,
|
ctx: Optional[commands.Context] = None,
|
||||||
search_type: SearchType = SearchType.ytsearch,
|
search_type: SearchType | None = SearchType.ytsearch,
|
||||||
filters: Optional[List[Filter]] = None,
|
filters: Optional[List[Filter]] = None,
|
||||||
) -> Optional[Union[List[Track], Playlist]]:
|
) -> Optional[Union[List[Track], Playlist]]:
|
||||||
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
||||||
|
|
@ -376,8 +413,21 @@ class Player(VoiceProtocol):
|
||||||
"""
|
"""
|
||||||
return await self._node.get_tracks(query, ctx=ctx, search_type=search_type, filters=filters)
|
return await self._node.get_tracks(query, ctx=ctx, search_type=search_type, filters=filters)
|
||||||
|
|
||||||
|
async def build_track(self, identifier: str, ctx: Optional[commands.Context] = None) -> Track:
|
||||||
|
"""
|
||||||
|
Builds a track using a valid track identifier
|
||||||
|
|
||||||
|
You can also pass in a discord.py Context object to get a
|
||||||
|
Context object on the track it builds.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return await self._node.build_track(identifier, ctx=ctx)
|
||||||
|
|
||||||
async def get_recommendations(
|
async def get_recommendations(
|
||||||
self, *, track: Track, ctx: Optional[commands.Context] = None
|
self,
|
||||||
|
*,
|
||||||
|
track: Track,
|
||||||
|
ctx: Optional[commands.Context] = None,
|
||||||
) -> Optional[Union[List[Track], Playlist]]:
|
) -> Optional[Union[List[Track], Playlist]]:
|
||||||
"""
|
"""
|
||||||
Gets recommendations from either YouTube or Spotify.
|
Gets recommendations from either YouTube or Spotify.
|
||||||
|
|
@ -387,7 +437,12 @@ class Player(VoiceProtocol):
|
||||||
return await self._node.get_recommendations(track=track, ctx=ctx)
|
return await self._node.get_recommendations(track=track, ctx=ctx)
|
||||||
|
|
||||||
async def connect(
|
async def connect(
|
||||||
self, *, timeout: float, reconnect: bool, self_deaf: bool = False, self_mute: bool = False
|
self,
|
||||||
|
*,
|
||||||
|
timeout: float,
|
||||||
|
reconnect: bool,
|
||||||
|
self_deaf: bool = False,
|
||||||
|
self_mute: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
await self.guild.change_voice_state(
|
await self.guild.change_voice_state(
|
||||||
channel=self.channel,
|
channel=self.channel,
|
||||||
|
|
@ -407,6 +462,7 @@ class Player(VoiceProtocol):
|
||||||
data={"encodedTrack": None},
|
data={"encodedTrack": None},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._log:
|
||||||
self._log.debug(f"Player has been stopped.")
|
self._log.debug(f"Player has been stopped.")
|
||||||
|
|
||||||
async def disconnect(self, *, force: bool = False) -> None:
|
async def disconnect(self, *, force: bool = False) -> None:
|
||||||
|
|
@ -425,24 +481,34 @@ class Player(VoiceProtocol):
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# 'NoneType' has no attribute '_get_voice_client_key' raised by self.cleanup() ->
|
# 'NoneType' has no attribute '_get_voice_client_key' raised by self.cleanup() ->
|
||||||
# assume we're already disconnected and cleaned up
|
# assume we're already disconnected and cleaned up
|
||||||
assert not self.is_connected and not self.channel
|
assert self.channel is None and not self.is_connected
|
||||||
|
|
||||||
self._node._players.pop(self.guild.id)
|
self._node._players.pop(self.guild.id)
|
||||||
|
if self.node.is_connected:
|
||||||
await self._node.send(
|
await self._node.send(
|
||||||
method="DELETE",
|
method="DELETE",
|
||||||
path=self._player_endpoint_uri,
|
path=self._player_endpoint_uri,
|
||||||
guild_id=self._guild.id,
|
guild_id=self._guild.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._log:
|
||||||
self._log.debug("Player has been destroyed.")
|
self._log.debug("Player has been destroyed.")
|
||||||
|
|
||||||
async def play(
|
async def play(
|
||||||
self, track: Track, *, start: int = 0, end: int = 0, ignore_if_playing: bool = False
|
self,
|
||||||
|
track: Track,
|
||||||
|
*,
|
||||||
|
start: int = 0,
|
||||||
|
end: int = 0,
|
||||||
|
ignore_if_playing: bool = False,
|
||||||
) -> Track:
|
) -> Track:
|
||||||
"""Plays a track. If a Spotify track is passed in, it will be handled accordingly."""
|
"""Plays a track. If a Spotify track is passed in, it will be handled accordingly."""
|
||||||
|
|
||||||
|
if not track._search_type:
|
||||||
|
track.original = track
|
||||||
|
|
||||||
# Make sure we've never searched the track before
|
# Make sure we've never searched the track before
|
||||||
if track.original is None:
|
if track._search_type and track.original is None:
|
||||||
# First lets try using the tracks ISRC, every track has one (hopefully)
|
# First lets try using the tracks ISRC, every track has one (hopefully)
|
||||||
try:
|
try:
|
||||||
if not track.isrc:
|
if not track.isrc:
|
||||||
|
|
@ -506,7 +572,7 @@ class Player(VoiceProtocol):
|
||||||
for filter in track.filters:
|
for filter in track.filters:
|
||||||
await self.add_filter(_filter=filter)
|
await self.add_filter(_filter=filter)
|
||||||
|
|
||||||
# Lavalink v4 changed the way the end time parameter works
|
# Lavalink v3.7.5 changed the way the end time parameter works
|
||||||
# so now the end time cannot be zero.
|
# so now the end time cannot be zero.
|
||||||
# If it isnt zero, it'll be set to None.
|
# If it isnt zero, it'll be set to None.
|
||||||
# Otherwise, it'll be set here:
|
# Otherwise, it'll be set here:
|
||||||
|
|
@ -522,6 +588,7 @@ class Player(VoiceProtocol):
|
||||||
query=f"noReplace={ignore_if_playing}",
|
query=f"noReplace={ignore_if_playing}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._log:
|
||||||
self._log.debug(
|
self._log.debug(
|
||||||
f"Playing {track.title} from uri {track.uri} with a length of {track.length}",
|
f"Playing {track.title} from uri {track.uri} with a length of {track.length}",
|
||||||
)
|
)
|
||||||
|
|
@ -545,6 +612,7 @@ class Player(VoiceProtocol):
|
||||||
data={"position": position},
|
data={"position": position},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._log:
|
||||||
self._log.debug(f"Seeking to {position}.")
|
self._log.debug(f"Seeking to {position}.")
|
||||||
return self.position
|
return self.position
|
||||||
|
|
||||||
|
|
@ -558,6 +626,7 @@ class Player(VoiceProtocol):
|
||||||
)
|
)
|
||||||
self._paused = pause
|
self._paused = pause
|
||||||
|
|
||||||
|
if self._log:
|
||||||
self._log.debug(f"Player has been {'paused' if pause else 'resumed'}.")
|
self._log.debug(f"Player has been {'paused' if pause else 'resumed'}.")
|
||||||
return self._paused
|
return self._paused
|
||||||
|
|
||||||
|
|
@ -571,9 +640,19 @@ class Player(VoiceProtocol):
|
||||||
)
|
)
|
||||||
self._volume = volume
|
self._volume = volume
|
||||||
|
|
||||||
|
if self._log:
|
||||||
self._log.debug(f"Player volume has been adjusted to {volume}")
|
self._log.debug(f"Player volume has been adjusted to {volume}")
|
||||||
return self._volume
|
return self._volume
|
||||||
|
|
||||||
|
async def move_to(self, channel: VoiceChannel) -> None:
|
||||||
|
"""Moves the player to a new voice channel."""
|
||||||
|
|
||||||
|
await self.guild.change_voice_state(channel=channel)
|
||||||
|
|
||||||
|
self.channel = channel
|
||||||
|
|
||||||
|
await self._dispatch_voice_update()
|
||||||
|
|
||||||
async def add_filter(self, _filter: Filter, fast_apply: bool = False) -> Filters:
|
async def add_filter(self, _filter: Filter, fast_apply: bool = False) -> Filters:
|
||||||
"""Adds a filter to the player. Takes a pomice.Filter object.
|
"""Adds a filter to the player. Takes a pomice.Filter object.
|
||||||
This will only work if you are using a version of Lavalink that supports filters.
|
This will only work if you are using a version of Lavalink that supports filters.
|
||||||
|
|
@ -591,8 +670,10 @@ class Player(VoiceProtocol):
|
||||||
data={"filters": payload},
|
data={"filters": payload},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._log:
|
||||||
self._log.debug(f"Filter has been applied to player with tag {_filter.tag}")
|
self._log.debug(f"Filter has been applied to player with tag {_filter.tag}")
|
||||||
if fast_apply:
|
if fast_apply:
|
||||||
|
if self._log:
|
||||||
self._log.debug(f"Fast apply passed, now applying filter instantly.")
|
self._log.debug(f"Fast apply passed, now applying filter instantly.")
|
||||||
await self.seek(self.position)
|
await self.seek(self.position)
|
||||||
|
|
||||||
|
|
@ -614,13 +695,48 @@ class Player(VoiceProtocol):
|
||||||
guild_id=self._guild.id,
|
guild_id=self._guild.id,
|
||||||
data={"filters": payload},
|
data={"filters": payload},
|
||||||
)
|
)
|
||||||
|
if self._log:
|
||||||
self._log.debug(f"Filter has been removed from player with tag {filter_tag}")
|
self._log.debug(f"Filter has been removed from player with tag {filter_tag}")
|
||||||
if fast_apply:
|
if fast_apply:
|
||||||
|
if self._log:
|
||||||
self._log.debug(f"Fast apply passed, now removing filter instantly.")
|
self._log.debug(f"Fast apply passed, now removing filter instantly.")
|
||||||
await self.seek(self.position)
|
await self.seek(self.position)
|
||||||
|
|
||||||
return self._filters
|
return self._filters
|
||||||
|
|
||||||
|
async def edit_filter(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
filter_tag: str,
|
||||||
|
edited_filter: Filter,
|
||||||
|
fast_apply: bool = False,
|
||||||
|
) -> Filters:
|
||||||
|
"""Edits a filter from the player using its filter tag and a new filter of the same type.
|
||||||
|
The filter to be replaced must have the same tag as the one you are replacing it with.
|
||||||
|
This will only work if you are using a version of Lavalink that supports filters.
|
||||||
|
|
||||||
|
If you would like for the filter to apply instantly, set the `fast_apply` arg to `True`.
|
||||||
|
|
||||||
|
(You must have a song playing in order for `fast_apply` to work.)
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._filters.edit_filter(filter_tag=filter_tag, to_apply=edited_filter)
|
||||||
|
payload = self._filters.get_all_payloads()
|
||||||
|
await self._node.send(
|
||||||
|
method="PATCH",
|
||||||
|
path=self._player_endpoint_uri,
|
||||||
|
guild_id=self._guild.id,
|
||||||
|
data={"filters": payload},
|
||||||
|
)
|
||||||
|
if self._log:
|
||||||
|
self._log.debug(f"Filter with tag {filter_tag} has been edited to {edited_filter!r}")
|
||||||
|
if fast_apply:
|
||||||
|
if self._log:
|
||||||
|
self._log.debug(f"Fast apply passed, now editing filter instantly.")
|
||||||
|
await self.seek(self.position)
|
||||||
|
|
||||||
|
return self._filters
|
||||||
|
|
||||||
async def reset_filters(self, *, fast_apply: bool = False) -> None:
|
async def reset_filters(self, *, fast_apply: bool = False) -> None:
|
||||||
"""Resets all currently applied filters to their default parameters.
|
"""Resets all currently applied filters to their default parameters.
|
||||||
You must have filters applied in order for this to work.
|
You must have filters applied in order for this to work.
|
||||||
|
|
@ -640,8 +756,10 @@ class Player(VoiceProtocol):
|
||||||
guild_id=self._guild.id,
|
guild_id=self._guild.id,
|
||||||
data={"filters": {}},
|
data={"filters": {}},
|
||||||
)
|
)
|
||||||
|
if self._log:
|
||||||
self._log.debug(f"All filters have been removed from player.")
|
self._log.debug(f"All filters have been removed from player.")
|
||||||
|
|
||||||
if fast_apply:
|
if fast_apply:
|
||||||
|
if self._log:
|
||||||
self._log.debug(f"Fast apply passed, now removing all filters instantly.")
|
self._log.debug(f"Fast apply passed, now removing all filters instantly.")
|
||||||
await self.seek(self.position)
|
await self.seek(self.position)
|
||||||
|
|
|
||||||
483
pomice/pool.py
483
pomice/pool.py
|
|
@ -5,6 +5,8 @@ import logging
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
from os import path
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
@ -15,15 +17,24 @@ from typing import Union
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import orjson as json
|
||||||
from discord import Client
|
from discord import Client
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
from discord.utils import MISSING
|
||||||
|
|
||||||
|
try:
|
||||||
|
from websockets.legacy import client # websockets >= 10.0
|
||||||
|
except ImportError:
|
||||||
|
import websockets.client as client # websockets < 10.0 # type: ignore
|
||||||
|
|
||||||
|
from websockets import exceptions
|
||||||
|
from websockets import typing as wstype
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from . import applemusic
|
from . import applemusic
|
||||||
from . import spotify
|
from . import spotify
|
||||||
from .enums import *
|
from .enums import *
|
||||||
from .enums import LogLevel
|
from .enums import LogLevel
|
||||||
from .exceptions import AppleMusicNotEnabled
|
|
||||||
from .exceptions import InvalidSpotifyClientAuthorization
|
from .exceptions import InvalidSpotifyClientAuthorization
|
||||||
from .exceptions import LavalinkVersionIncompatible
|
from .exceptions import LavalinkVersionIncompatible
|
||||||
from .exceptions import NodeConnectionFailure
|
from .exceptions import NodeConnectionFailure
|
||||||
|
|
@ -49,6 +60,8 @@ __all__ = (
|
||||||
"NodePool",
|
"NodePool",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
VERSION_REGEX = re.compile(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?(?:[a-zA-Z0-9_-]+)?")
|
||||||
|
|
||||||
|
|
||||||
class Node:
|
class Node:
|
||||||
"""The base class for a node.
|
"""The base class for a node.
|
||||||
|
|
@ -66,6 +79,8 @@ class Node:
|
||||||
"_password",
|
"_password",
|
||||||
"_identifier",
|
"_identifier",
|
||||||
"_heartbeat",
|
"_heartbeat",
|
||||||
|
"_resume_key",
|
||||||
|
"_resume_timeout",
|
||||||
"_secure",
|
"_secure",
|
||||||
"_fallback",
|
"_fallback",
|
||||||
"_log_level",
|
"_log_level",
|
||||||
|
|
@ -100,15 +115,20 @@ class Node:
|
||||||
password: str,
|
password: str,
|
||||||
identifier: str,
|
identifier: str,
|
||||||
secure: bool = False,
|
secure: bool = False,
|
||||||
heartbeat: int = 30,
|
heartbeat: int = 120,
|
||||||
|
resume_key: Optional[str] = None,
|
||||||
|
resume_timeout: int = 60,
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||||
session: Optional[aiohttp.ClientSession] = None,
|
session: Optional[aiohttp.ClientSession] = None,
|
||||||
spotify_client_id: Optional[int] = None,
|
spotify_client_id: Optional[str] = None,
|
||||||
spotify_client_secret: Optional[str] = None,
|
spotify_client_secret: Optional[str] = None,
|
||||||
apple_music: bool = False,
|
apple_music: bool = False,
|
||||||
fallback: bool = False,
|
fallback: bool = False,
|
||||||
log_level: LogLevel = LogLevel.INFO,
|
logger: Optional[logging.Logger] = None,
|
||||||
):
|
):
|
||||||
|
if not isinstance(port, int):
|
||||||
|
raise TypeError("Port must be an integer")
|
||||||
|
|
||||||
self._bot: commands.Bot = bot
|
self._bot: commands.Bot = bot
|
||||||
self._host: str = host
|
self._host: str = host
|
||||||
self._port: int = port
|
self._port: int = port
|
||||||
|
|
@ -116,24 +136,25 @@ class Node:
|
||||||
self._password: str = password
|
self._password: str = password
|
||||||
self._identifier: str = identifier
|
self._identifier: str = identifier
|
||||||
self._heartbeat: int = heartbeat
|
self._heartbeat: int = heartbeat
|
||||||
|
self._resume_key: Optional[str] = resume_key
|
||||||
|
self._resume_timeout: int = resume_timeout
|
||||||
self._secure: bool = secure
|
self._secure: bool = secure
|
||||||
self._fallback: bool = fallback
|
self._fallback: bool = fallback
|
||||||
self._log_level: LogLevel = log_level
|
|
||||||
|
|
||||||
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
|
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
|
||||||
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
|
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
|
||||||
|
|
||||||
self._session: aiohttp.ClientSession = session # type: ignore
|
self._session: aiohttp.ClientSession = session # type: ignore
|
||||||
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
|
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
|
||||||
self._websocket: aiohttp.ClientWebSocketResponse
|
self._websocket: client.WebSocketClientProtocol
|
||||||
self._task: asyncio.Task = None # type: ignore
|
self._task: asyncio.Task = None # type: ignore
|
||||||
|
|
||||||
self._session_id: Optional[str] = None
|
self._session_id: Optional[str] = None
|
||||||
self._available: bool = False
|
self._available: bool = False
|
||||||
self._version: LavalinkVersion = None
|
self._version: LavalinkVersion = LavalinkVersion(0, 0, 0)
|
||||||
|
|
||||||
self._route_planner = RoutePlanner(self)
|
self._route_planner = RoutePlanner(self)
|
||||||
self._log = self._setup_logging(self._log_level)
|
self._log = logger
|
||||||
|
|
||||||
if not self._bot.user:
|
if not self._bot.user:
|
||||||
raise NodeCreationError("Bot user is not ready yet.")
|
raise NodeCreationError("Bot user is not ready yet.")
|
||||||
|
|
@ -148,13 +169,14 @@ class Node:
|
||||||
|
|
||||||
self._players: Dict[int, Player] = {}
|
self._players: Dict[int, Player] = {}
|
||||||
|
|
||||||
self._spotify_client_id: Optional[int] = spotify_client_id
|
self._spotify_client: Optional[spotify.Client] = None
|
||||||
self._spotify_client_secret: Optional[str] = spotify_client_secret
|
|
||||||
|
|
||||||
self._apple_music_client: Optional[applemusic.Client] = None
|
self._apple_music_client: Optional[applemusic.Client] = None
|
||||||
|
|
||||||
|
self._spotify_client_id: Optional[str] = spotify_client_id
|
||||||
|
self._spotify_client_secret: Optional[str] = spotify_client_secret
|
||||||
|
|
||||||
if self._spotify_client_id and self._spotify_client_secret:
|
if self._spotify_client_id and self._spotify_client_secret:
|
||||||
self._spotify_client: spotify.Client = spotify.Client(
|
self._spotify_client = spotify.Client(
|
||||||
self._spotify_client_id,
|
self._spotify_client_id,
|
||||||
self._spotify_client_secret,
|
self._spotify_client_secret,
|
||||||
)
|
)
|
||||||
|
|
@ -193,7 +215,7 @@ class Node:
|
||||||
@property
|
@property
|
||||||
def player_count(self) -> int:
|
def player_count(self) -> int:
|
||||||
"""Property which returns how many players are connected to this node"""
|
"""Property which returns how many players are connected to this node"""
|
||||||
return len(self.players)
|
return len(self.players.values())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pool(self) -> Type[NodePool]:
|
def pool(self) -> Type[NodePool]:
|
||||||
|
|
@ -210,41 +232,44 @@ class Node:
|
||||||
"""Alias for `Node.latency`, returns the latency of the node"""
|
"""Alias for `Node.latency`, returns the latency of the node"""
|
||||||
return self.latency
|
return self.latency
|
||||||
|
|
||||||
def _setup_logging(self, level: LogLevel) -> logging.Logger:
|
async def _handle_version_check(self, version: str) -> None:
|
||||||
logger = logging.getLogger("pomice")
|
|
||||||
handler = logging.StreamHandler()
|
|
||||||
dt_fmt = "%Y-%m-%d %H:%M:%S"
|
|
||||||
formatter = logging.Formatter(
|
|
||||||
"[{asctime}] [{levelname:<8}] {name}: {message}",
|
|
||||||
dt_fmt,
|
|
||||||
style="{",
|
|
||||||
)
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
logger.setLevel(level)
|
|
||||||
logger.addHandler(handler)
|
|
||||||
return logger
|
|
||||||
|
|
||||||
async def _handle_version_check(self, version: str):
|
|
||||||
if version.endswith("-SNAPSHOT"):
|
if version.endswith("-SNAPSHOT"):
|
||||||
# we're just gonna assume all snapshot versions correlate with v4
|
# we're just gonna assume all snapshot versions correlate with v4
|
||||||
self._version = LavalinkVersion(major=4, minor=0, fix=0)
|
self._version = LavalinkVersion(major=4, minor=0, fix=0)
|
||||||
return
|
return
|
||||||
|
|
||||||
# this crazy ass line maps the split version string into
|
_version_rx = VERSION_REGEX.match(version)
|
||||||
# an iterable with ints instead of strings and then
|
if not _version_rx:
|
||||||
# turns that iterable into a tuple. yeah, i know
|
|
||||||
|
|
||||||
split = tuple(map(int, tuple(version.split("."))))
|
|
||||||
self._version = LavalinkVersion(*split)
|
|
||||||
if not version.endswith("-SNAPSHOT") and (
|
|
||||||
self._version.major == 3 and self._version.minor < 7
|
|
||||||
):
|
|
||||||
self._available = False
|
self._available = False
|
||||||
raise LavalinkVersionIncompatible(
|
raise LavalinkVersionIncompatible(
|
||||||
"The Lavalink version you're using is incompatible. "
|
"The Lavalink version you're using is incompatible. "
|
||||||
"Lavalink version 3.7.0 or above is required to use this library.",
|
"Lavalink version 3.7.0 or above is required to use this library.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_version_groups = _version_rx.groups()
|
||||||
|
major, minor, fix = (
|
||||||
|
int(_version_groups[0] or 0),
|
||||||
|
int(_version_groups[1] or 0),
|
||||||
|
int(_version_groups[2] or 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._log:
|
||||||
|
self._log.debug(f"Parsed Lavalink version: {major}.{minor}.{fix}")
|
||||||
|
self._version = LavalinkVersion(major=major, minor=minor, fix=fix)
|
||||||
|
if self._version < LavalinkVersion(3, 7, 0):
|
||||||
|
self._available = False
|
||||||
|
raise LavalinkVersionIncompatible(
|
||||||
|
"The Lavalink version you're using is incompatible. "
|
||||||
|
"Lavalink version 3.7.0 or above is required to use this library.",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _set_ext_client_session(self, session: aiohttp.ClientSession) -> None:
|
||||||
|
if self._spotify_client:
|
||||||
|
await self._spotify_client._set_session(session=session)
|
||||||
|
|
||||||
|
if self._apple_music_client:
|
||||||
|
await self._apple_music_client._set_session(session=session)
|
||||||
|
|
||||||
async def _update_handler(self, data: dict) -> None:
|
async def _update_handler(self, data: dict) -> None:
|
||||||
await self._bot.wait_until_ready()
|
await self._bot.wait_until_ready()
|
||||||
|
|
||||||
|
|
@ -279,47 +304,80 @@ class Node:
|
||||||
|
|
||||||
await self.disconnect()
|
await self.disconnect()
|
||||||
|
|
||||||
async def _listen(self) -> None:
|
async def _configure_resuming(self) -> None:
|
||||||
backoff = ExponentialBackoff(base=7)
|
if not self._resume_key:
|
||||||
|
|
||||||
while True:
|
|
||||||
msg = await self._websocket.receive()
|
|
||||||
if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
|
|
||||||
if self._fallback:
|
|
||||||
await self._handle_node_switch()
|
|
||||||
retry = backoff.delay()
|
|
||||||
await asyncio.sleep(retry)
|
|
||||||
if not self.is_connected:
|
|
||||||
self._loop.create_task(self.connect())
|
|
||||||
else:
|
|
||||||
self._loop.create_task(self._handle_payload(msg.json()))
|
|
||||||
|
|
||||||
async def _handle_payload(self, data: dict) -> None:
|
|
||||||
op = data.get("op", None)
|
|
||||||
if not op:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
data: Dict[str, Union[int, str, bool]] = {"timeout": self._resume_timeout}
|
||||||
|
|
||||||
|
if self._version.major == 3:
|
||||||
|
data["resumingKey"] = self._resume_key
|
||||||
|
elif self._version.major == 4:
|
||||||
|
if self._log:
|
||||||
|
self._log.warning("Using a resume key with Lavalink v4 is deprecated.")
|
||||||
|
data["resuming"] = True
|
||||||
|
|
||||||
|
await self.send(
|
||||||
|
method="PATCH",
|
||||||
|
path=f"sessions/{self._session_id}",
|
||||||
|
include_version=True,
|
||||||
|
data=data,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _listen(self) -> None:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
msg = await self._websocket.recv()
|
||||||
|
data = json.loads(msg)
|
||||||
|
if self._log:
|
||||||
|
self._log.debug(f"Recieved raw websocket message {msg}")
|
||||||
|
self._loop.create_task(self._handle_ws_msg(data=data))
|
||||||
|
except exceptions.ConnectionClosed:
|
||||||
|
if self.player_count > 0:
|
||||||
|
for _player in self.players.values():
|
||||||
|
self._loop.create_task(_player.destroy())
|
||||||
|
|
||||||
|
if self._fallback:
|
||||||
|
self._loop.create_task(self._handle_node_switch())
|
||||||
|
|
||||||
|
self._loop.create_task(self._websocket.close())
|
||||||
|
|
||||||
|
backoff = ExponentialBackoff(base=7)
|
||||||
|
retry = backoff.delay()
|
||||||
|
if self._log:
|
||||||
|
self._log.debug(
|
||||||
|
f"Retrying connection to Node {self._identifier} in {retry} secs",
|
||||||
|
)
|
||||||
|
await asyncio.sleep(retry)
|
||||||
|
|
||||||
|
if not self.is_connected:
|
||||||
|
self._loop.create_task(self.connect(reconnect=True))
|
||||||
|
|
||||||
|
async def _handle_ws_msg(self, data: dict) -> None:
|
||||||
|
if self._log:
|
||||||
|
self._log.debug(f"Recieved raw payload from Node {self._identifier} with data {data}")
|
||||||
|
op = data.get("op", None)
|
||||||
|
|
||||||
if op == "stats":
|
if op == "stats":
|
||||||
self._stats = NodeStats(data)
|
self._stats = NodeStats(data)
|
||||||
return
|
return
|
||||||
|
|
||||||
if op == "ready":
|
if op == "ready":
|
||||||
self._session_id = data["sessionId"]
|
self._session_id = data["sessionId"]
|
||||||
|
await self._configure_resuming()
|
||||||
|
|
||||||
if not "guildId" in data:
|
if not "guildId" in data:
|
||||||
return
|
return
|
||||||
|
|
||||||
player = self._players.get(int(data["guildId"]))
|
player: Optional[Player] = self._players.get(int(data["guildId"]))
|
||||||
if not player:
|
if not player:
|
||||||
return
|
return
|
||||||
|
|
||||||
if op == "event":
|
if op == "event":
|
||||||
await player._dispatch_event(data)
|
return await player._dispatch_event(data)
|
||||||
return
|
|
||||||
|
|
||||||
if op == "playerUpdate":
|
if op == "playerUpdate":
|
||||||
await player._update_state(data)
|
return await player._update_state(data)
|
||||||
return
|
|
||||||
|
|
||||||
async def send(
|
async def send(
|
||||||
self,
|
self,
|
||||||
|
|
@ -344,33 +402,39 @@ class Node:
|
||||||
f'{f"?{query}" if query else ""}'
|
f'{f"?{query}" if query else ""}'
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self._session.request(
|
resp = await self._session.request(
|
||||||
method=method,
|
method=method,
|
||||||
url=uri,
|
url=uri,
|
||||||
headers=self._headers,
|
headers=self._headers,
|
||||||
json=data or {},
|
json=data or {},
|
||||||
) as resp:
|
)
|
||||||
self._log.debug(f"Making REST request with method {method} to {uri}")
|
if self._log:
|
||||||
|
self._log.debug(
|
||||||
|
f"Making REST request to Node {self._identifier} with method {method} to {uri}",
|
||||||
|
)
|
||||||
if resp.status >= 300:
|
if resp.status >= 300:
|
||||||
resp_data: dict = await resp.json()
|
resp_data: dict = await resp.json()
|
||||||
raise NodeRestException(
|
raise NodeRestException(
|
||||||
f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {resp_data["message"]}',
|
f'Error from Node {self._identifier} fetching from Lavalink REST api: {resp.status} {resp.reason}: {resp_data["message"]}',
|
||||||
)
|
)
|
||||||
|
|
||||||
if method == "DELETE" or resp.status == 204:
|
if method == "DELETE" or resp.status == 204:
|
||||||
|
if self._log:
|
||||||
self._log.debug(
|
self._log.debug(
|
||||||
f"REST request with method {method} to {uri} completed sucessfully and returned no data.",
|
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned no data.",
|
||||||
)
|
)
|
||||||
return await resp.json(content_type=None)
|
return await resp.json(content_type=None)
|
||||||
|
|
||||||
if resp.content_type == "text/plain":
|
if resp.content_type == "text/plain":
|
||||||
|
if self._log:
|
||||||
self._log.debug(
|
self._log.debug(
|
||||||
f"REST request with method {method} to {uri} completed sucessfully and returned text with body {await resp.text()}",
|
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned text with body {await resp.text()}",
|
||||||
)
|
)
|
||||||
return await resp.text()
|
return await resp.text()
|
||||||
|
|
||||||
|
if self._log:
|
||||||
self._log.debug(
|
self._log.debug(
|
||||||
f"REST request with method {method} to {uri} completed sucessfully and returned JSON with body {await resp.json()}",
|
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned JSON with body {await resp.json()}",
|
||||||
)
|
)
|
||||||
return await resp.json()
|
return await resp.json()
|
||||||
|
|
||||||
|
|
@ -378,16 +442,27 @@ class Node:
|
||||||
"""Takes a guild ID as a parameter. Returns a pomice Player object or None."""
|
"""Takes a guild ID as a parameter. Returns a pomice Player object or None."""
|
||||||
return self._players.get(guild_id, None)
|
return self._players.get(guild_id, None)
|
||||||
|
|
||||||
async def connect(self) -> "Node":
|
async def connect(self, *, reconnect: bool = False) -> Node:
|
||||||
"""Initiates a connection with a Lavalink node and adds it to the node pool."""
|
"""Initiates a connection with a Lavalink node and adds it to the node pool."""
|
||||||
await self._bot.wait_until_ready()
|
await self._bot.wait_until_ready()
|
||||||
|
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
|
|
||||||
if not self._session:
|
if not self._session:
|
||||||
self._session = aiohttp.ClientSession()
|
# Configure connection pooling for optimal concurrent request performance
|
||||||
|
connector = aiohttp.TCPConnector(
|
||||||
|
limit=100, # Total connection limit
|
||||||
|
limit_per_host=30, # Per-host connection limit
|
||||||
|
ttl_dns_cache=300, # DNS cache TTL in seconds
|
||||||
|
)
|
||||||
|
timeout = aiohttp.ClientTimeout(total=30, connect=10)
|
||||||
|
self._session = aiohttp.ClientSession(
|
||||||
|
connector=connector,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not reconnect:
|
||||||
version: str = await self.send(
|
version: str = await self.send(
|
||||||
method="GET",
|
method="GET",
|
||||||
path="version",
|
path="version",
|
||||||
|
|
@ -396,17 +471,29 @@ class Node:
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._handle_version_check(version=version)
|
await self._handle_version_check(version=version)
|
||||||
|
await self._set_ext_client_session(session=self._session)
|
||||||
|
|
||||||
self._log.debug(f"Version check from node successful. Returned version {version}")
|
if self._log:
|
||||||
|
self._log.debug(
|
||||||
self._websocket = await self._session.ws_connect(
|
f"Version check from Node {self._identifier} successful. Returned version {version}",
|
||||||
f"{self._websocket_uri}/v{self._version.major}/websocket",
|
|
||||||
headers=self._headers,
|
|
||||||
heartbeat=self._heartbeat,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._websocket = await client.connect( # type: ignore
|
||||||
|
f"{self._websocket_uri}/v{self._version.major}/websocket",
|
||||||
|
extra_headers=self._headers,
|
||||||
|
ping_interval=self._heartbeat,
|
||||||
|
)
|
||||||
|
|
||||||
|
if reconnect:
|
||||||
|
if self._log:
|
||||||
|
self._log.debug(f"Trying to reconnect to Node {self._identifier}...")
|
||||||
|
if self.player_count:
|
||||||
|
for player in self.players.values():
|
||||||
|
await player._refresh_endpoint_uri(self._session_id)
|
||||||
|
|
||||||
|
if self._log:
|
||||||
self._log.debug(
|
self._log.debug(
|
||||||
f"Connected to node websocket using {self._websocket_uri}/v{self._version.major}/websocket",
|
f"Node {self._identifier} successfully connected to websocket using {self._websocket_uri}/v{self._version.major}/websocket",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self._task:
|
if not self._task:
|
||||||
|
|
@ -416,18 +503,19 @@ class Node:
|
||||||
|
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
|
|
||||||
|
if self._log:
|
||||||
self._log.info(f"Connected to node {self._identifier}. Took {end - start:.3f}s")
|
self._log.info(f"Connected to node {self._identifier}. Took {end - start:.3f}s")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
except (aiohttp.ClientConnectorError, ConnectionRefusedError):
|
except (aiohttp.ClientConnectorError, OSError, ConnectionRefusedError):
|
||||||
raise NodeConnectionFailure(
|
raise NodeConnectionFailure(
|
||||||
f"The connection to node '{self._identifier}' failed.",
|
f"The connection to node '{self._identifier}' failed.",
|
||||||
) from None
|
) from None
|
||||||
except aiohttp.WSServerHandshakeError:
|
except exceptions.InvalidHandshake:
|
||||||
raise NodeConnectionFailure(
|
raise NodeConnectionFailure(
|
||||||
f"The password for node '{self._identifier}' is invalid.",
|
f"The password for node '{self._identifier}' is invalid.",
|
||||||
) from None
|
) from None
|
||||||
except aiohttp.InvalidURL:
|
except exceptions.InvalidURI:
|
||||||
raise NodeConnectionFailure(
|
raise NodeConnectionFailure(
|
||||||
f"The URI for node '{self._identifier}' is invalid.",
|
f"The URI for node '{self._identifier}' is invalid.",
|
||||||
) from None
|
) from None
|
||||||
|
|
@ -441,25 +529,20 @@ class Node:
|
||||||
|
|
||||||
for player in self.players.copy().values():
|
for player in self.players.copy().values():
|
||||||
await player.destroy()
|
await player.destroy()
|
||||||
|
if self._log:
|
||||||
self._log.debug("All players disconnected from node.")
|
self._log.debug("All players disconnected from node.")
|
||||||
|
|
||||||
await self._websocket.close()
|
await self._websocket.close()
|
||||||
await self._session.close()
|
await self._session.close()
|
||||||
|
if self._log:
|
||||||
self._log.debug("Websocket and http session closed.")
|
self._log.debug("Websocket and http session closed.")
|
||||||
|
|
||||||
if self._spotify_client:
|
|
||||||
await self._spotify_client.close()
|
|
||||||
self._log.debug("Spotify client session closed.")
|
|
||||||
|
|
||||||
if self._apple_music_client:
|
|
||||||
await self._apple_music_client.close()
|
|
||||||
self._log.debug("Apple Music client session closed.")
|
|
||||||
|
|
||||||
del self._pool._nodes[self._identifier]
|
del self._pool._nodes[self._identifier]
|
||||||
self.available = False
|
self.available = False
|
||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
|
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
|
if self._log:
|
||||||
self._log.info(
|
self._log.info(
|
||||||
f"Successfully disconnected from node {self._identifier} and closed all sessions. Took {end - start:.3f}s",
|
f"Successfully disconnected from node {self._identifier} and closed all sessions. Took {end - start:.3f}s",
|
||||||
)
|
)
|
||||||
|
|
@ -475,13 +558,16 @@ class Node:
|
||||||
data: dict = await self.send(
|
data: dict = await self.send(
|
||||||
method="GET",
|
method="GET",
|
||||||
path="decodetrack",
|
path="decodetrack",
|
||||||
query=f"encodedTrack={identifier}",
|
query=f"encodedTrack={quote(identifier)}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
track_info = data["info"] if self._version.major >= 4 else data
|
||||||
|
|
||||||
return Track(
|
return Track(
|
||||||
track_id=identifier,
|
track_id=identifier,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
info=data,
|
info=track_info,
|
||||||
track_type=TrackType(data["sourceName"]),
|
track_type=TrackType(track_info["sourceName"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_tracks(
|
async def get_tracks(
|
||||||
|
|
@ -489,7 +575,7 @@ class Node:
|
||||||
query: str,
|
query: str,
|
||||||
*,
|
*,
|
||||||
ctx: Optional[commands.Context] = None,
|
ctx: Optional[commands.Context] = None,
|
||||||
search_type: SearchType = SearchType.ytsearch,
|
search_type: Optional[SearchType] = SearchType.ytsearch,
|
||||||
filters: Optional[List[Filter]] = None,
|
filters: Optional[List[Filter]] = None,
|
||||||
) -> Optional[Union[Playlist, List[Track]]]:
|
) -> Optional[Union[Playlist, List[Track]]]:
|
||||||
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
||||||
|
|
@ -506,20 +592,17 @@ class Node:
|
||||||
|
|
||||||
timestamp = None
|
timestamp = None
|
||||||
|
|
||||||
if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query):
|
|
||||||
query = f"{search_type}:{query}"
|
|
||||||
|
|
||||||
if filters:
|
if filters:
|
||||||
for filter in filters:
|
for filter in filters:
|
||||||
filter.set_preload()
|
filter.set_preload()
|
||||||
|
|
||||||
if URLRegex.AM_URL.match(query):
|
# Due to the inclusion of plugins in the v4 update
|
||||||
if not self._apple_music_client:
|
# we are doing away with raising an error if pomice detects
|
||||||
raise AppleMusicNotEnabled(
|
# either a Spotify or Apple Music URL and the respective client
|
||||||
"You must have Apple Music functionality enabled in order to play Apple Music tracks."
|
# is not enabled. Instead, we will just only parse the URL
|
||||||
"Please set apple_music to True in your Node class.",
|
# if the client is enabled and the URL is valid.
|
||||||
)
|
|
||||||
|
|
||||||
|
if self._apple_music_client and URLRegex.AM_URL.match(query):
|
||||||
apple_music_results = await self._apple_music_client.search(query=query)
|
apple_music_results = await self._apple_music_client.search(query=query)
|
||||||
if isinstance(apple_music_results, applemusic.Song):
|
if isinstance(apple_music_results, applemusic.Song):
|
||||||
return [
|
return [
|
||||||
|
|
@ -527,7 +610,7 @@ class Node:
|
||||||
track_id=apple_music_results.id,
|
track_id=apple_music_results.id,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
track_type=TrackType.APPLE_MUSIC,
|
track_type=TrackType.APPLE_MUSIC,
|
||||||
search_type=search_type,
|
search_type=search_type or SearchType.ytsearch,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
info={
|
info={
|
||||||
"title": apple_music_results.name,
|
"title": apple_music_results.name,
|
||||||
|
|
@ -549,7 +632,7 @@ class Node:
|
||||||
track_id=track.id,
|
track_id=track.id,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
track_type=TrackType.APPLE_MUSIC,
|
track_type=TrackType.APPLE_MUSIC,
|
||||||
search_type=search_type,
|
search_type=search_type or SearchType.ytsearch,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
info={
|
info={
|
||||||
"title": track.name,
|
"title": track.name,
|
||||||
|
|
@ -578,15 +661,8 @@ class Node:
|
||||||
uri=apple_music_results.url,
|
uri=apple_music_results.url,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif URLRegex.SPOTIFY_URL.match(query):
|
elif self._spotify_client and URLRegex.SPOTIFY_URL.match(query):
|
||||||
if not self._spotify_client_id and not self._spotify_client_secret:
|
spotify_results = await self._spotify_client.search(query=query) # type: ignore
|
||||||
raise InvalidSpotifyClientAuthorization(
|
|
||||||
"You did not provide proper Spotify client authorization credentials. "
|
|
||||||
"If you would like to use the Spotify searching feature, "
|
|
||||||
"please obtain Spotify API credentials here: https://developer.spotify.com/",
|
|
||||||
)
|
|
||||||
|
|
||||||
spotify_results = await self._spotify_client.search(query=query)
|
|
||||||
|
|
||||||
if isinstance(spotify_results, spotify.Track):
|
if isinstance(spotify_results, spotify.Track):
|
||||||
return [
|
return [
|
||||||
|
|
@ -594,7 +670,7 @@ class Node:
|
||||||
track_id=spotify_results.id,
|
track_id=spotify_results.id,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
track_type=TrackType.SPOTIFY,
|
track_type=TrackType.SPOTIFY,
|
||||||
search_type=search_type,
|
search_type=search_type or SearchType.ytsearch,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
info={
|
info={
|
||||||
"title": spotify_results.name,
|
"title": spotify_results.name,
|
||||||
|
|
@ -616,7 +692,7 @@ class Node:
|
||||||
track_id=track.id,
|
track_id=track.id,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
track_type=TrackType.SPOTIFY,
|
track_type=TrackType.SPOTIFY,
|
||||||
search_type=search_type,
|
search_type=search_type or SearchType.ytsearch,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
info={
|
info={
|
||||||
"title": track.name,
|
"title": track.name,
|
||||||
|
|
@ -645,45 +721,21 @@ class Node:
|
||||||
uri=spotify_results.uri,
|
uri=spotify_results.uri,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
|
|
||||||
data: dict = await self.send(
|
|
||||||
method="GET",
|
|
||||||
path="loadtracks",
|
|
||||||
query=f"identifier={quote(query)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
track: dict = data["tracks"][0]
|
|
||||||
info: dict = track["info"]
|
|
||||||
|
|
||||||
return [
|
|
||||||
Track(
|
|
||||||
track_id=track["track"],
|
|
||||||
info={
|
|
||||||
"title": discord_url.group("file"),
|
|
||||||
"author": "Unknown",
|
|
||||||
"length": info["length"],
|
|
||||||
"uri": info["uri"],
|
|
||||||
"position": info["position"],
|
|
||||||
"identifier": info["identifier"],
|
|
||||||
},
|
|
||||||
ctx=ctx,
|
|
||||||
track_type=TrackType.HTTP,
|
|
||||||
filters=filters,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
if (
|
||||||
|
search_type
|
||||||
|
and not URLRegex.BASE_URL.match(query)
|
||||||
|
and not re.match(r"(?:[a-z]+?)search:.", query)
|
||||||
|
and not URLRegex.DISCORD_MP3_URL.match(query)
|
||||||
|
and not path.exists(path.dirname(query))
|
||||||
|
):
|
||||||
|
query = f"{search_type}:{query}"
|
||||||
|
|
||||||
# If YouTube url contains a timestamp, capture it for use later.
|
# If YouTube url contains a timestamp, capture it for use later.
|
||||||
|
|
||||||
if match := URLRegex.YOUTUBE_TIMESTAMP.match(query):
|
if match := URLRegex.YOUTUBE_TIMESTAMP.match(query):
|
||||||
timestamp = float(match.group("time"))
|
timestamp = float(match.group("time"))
|
||||||
|
|
||||||
# If query is a video thats part of a playlist, get the video and queue that instead
|
|
||||||
# (I can't tell you how much i've wanted to implement this in here)
|
|
||||||
|
|
||||||
if match := URLRegex.YOUTUBE_VID_IN_PLAYLIST.match(query):
|
|
||||||
query = match.group("video")
|
|
||||||
|
|
||||||
data = await self.send(
|
data = await self.send(
|
||||||
method="GET",
|
method="GET",
|
||||||
path="loadtracks",
|
path="loadtracks",
|
||||||
|
|
@ -692,21 +744,31 @@ class Node:
|
||||||
|
|
||||||
load_type = data.get("loadType")
|
load_type = data.get("loadType")
|
||||||
|
|
||||||
|
# Lavalink v4 changed the name of the key from "tracks" to "data"
|
||||||
|
# so lets account for that
|
||||||
|
data_type = "data" if self._version.major >= 4 else "tracks"
|
||||||
|
|
||||||
if not load_type:
|
if not load_type:
|
||||||
raise TrackLoadError(
|
raise TrackLoadError(
|
||||||
"There was an error while trying to load this track.",
|
"There was an error while trying to load this track.",
|
||||||
)
|
)
|
||||||
|
|
||||||
elif load_type == "LOAD_FAILED":
|
elif load_type in ("LOAD_FAILED", "error"):
|
||||||
exception = data["exception"]
|
exception = data["data"] if self._version.major >= 4 else data["exception"]
|
||||||
raise TrackLoadError(
|
raise TrackLoadError(
|
||||||
f"{exception['message']} [{exception['severity']}]",
|
f"{exception['message']} [{exception['severity']}]",
|
||||||
)
|
)
|
||||||
|
|
||||||
elif load_type == "NO_MATCHES":
|
elif load_type in ("NO_MATCHES", "empty"):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
elif load_type == "PLAYLIST_LOADED":
|
elif load_type in ("PLAYLIST_LOADED", "playlist"):
|
||||||
|
if self._version.major >= 4:
|
||||||
|
track_list = data[data_type]["tracks"]
|
||||||
|
playlist_info = data[data_type]["info"]
|
||||||
|
else:
|
||||||
|
track_list = data[data_type]
|
||||||
|
playlist_info = data["playlistInfo"]
|
||||||
tracks = [
|
tracks = [
|
||||||
Track(
|
Track(
|
||||||
track_id=track["encoded"],
|
track_id=track["encoded"],
|
||||||
|
|
@ -714,17 +776,60 @@ class Node:
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
track_type=TrackType(track["info"]["sourceName"]),
|
track_type=TrackType(track["info"]["sourceName"]),
|
||||||
)
|
)
|
||||||
for track in data["tracks"]
|
for track in track_list
|
||||||
]
|
]
|
||||||
return Playlist(
|
return Playlist(
|
||||||
playlist_info=data["playlistInfo"],
|
playlist_info=playlist_info,
|
||||||
tracks=tracks,
|
tracks=tracks,
|
||||||
playlist_type=PlaylistType(tracks[0].track_type.value),
|
playlist_type=PlaylistType(tracks[0].track_type.value),
|
||||||
thumbnail=tracks[0].thumbnail,
|
thumbnail=tracks[0].thumbnail,
|
||||||
uri=query,
|
uri=query,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif load_type == "SEARCH_RESULT" or load_type == "TRACK_LOADED":
|
elif load_type in ("SEARCH_RESULT", "TRACK_LOADED", "track", "search"):
|
||||||
|
if self._version.major >= 4 and isinstance(data[data_type], dict):
|
||||||
|
data[data_type] = [data[data_type]]
|
||||||
|
|
||||||
|
if path.exists(path.dirname(query)):
|
||||||
|
local_file = Path(query)
|
||||||
|
|
||||||
|
return [
|
||||||
|
Track(
|
||||||
|
track_id=track["encoded"],
|
||||||
|
info={
|
||||||
|
"title": local_file.name,
|
||||||
|
"author": "Unknown",
|
||||||
|
"length": track["info"]["length"],
|
||||||
|
"uri": quote(local_file.as_uri()),
|
||||||
|
"position": track["info"]["position"],
|
||||||
|
"identifier": track["info"]["identifier"],
|
||||||
|
},
|
||||||
|
ctx=ctx,
|
||||||
|
track_type=TrackType.LOCAL,
|
||||||
|
filters=filters,
|
||||||
|
)
|
||||||
|
for track in data[data_type]
|
||||||
|
]
|
||||||
|
|
||||||
|
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
|
||||||
|
return [
|
||||||
|
Track(
|
||||||
|
track_id=track["encoded"],
|
||||||
|
info={
|
||||||
|
"title": discord_url.group("file"),
|
||||||
|
"author": "Unknown",
|
||||||
|
"length": track["info"]["length"],
|
||||||
|
"uri": track["info"]["uri"],
|
||||||
|
"position": track["info"]["position"],
|
||||||
|
"identifier": track["info"]["identifier"],
|
||||||
|
},
|
||||||
|
ctx=ctx,
|
||||||
|
track_type=TrackType.HTTP,
|
||||||
|
filters=filters,
|
||||||
|
)
|
||||||
|
for track in data[data_type]
|
||||||
|
]
|
||||||
|
|
||||||
return [
|
return [
|
||||||
Track(
|
Track(
|
||||||
track_id=track["encoded"],
|
track_id=track["encoded"],
|
||||||
|
|
@ -734,7 +839,7 @@ class Node:
|
||||||
filters=filters,
|
filters=filters,
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
)
|
)
|
||||||
for track in data["tracks"]
|
for track in data[data_type]
|
||||||
]
|
]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
@ -743,7 +848,10 @@ class Node:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_recommendations(
|
async def get_recommendations(
|
||||||
self, *, track: Track, ctx: Optional[commands.Context] = None
|
self,
|
||||||
|
*,
|
||||||
|
track: Track,
|
||||||
|
ctx: Optional[commands.Context] = None,
|
||||||
) -> Optional[Union[List[Track], Playlist]]:
|
) -> Optional[Union[List[Track], Playlist]]:
|
||||||
"""
|
"""
|
||||||
Gets recommendations from either YouTube or Spotify.
|
Gets recommendations from either YouTube or Spotify.
|
||||||
|
|
@ -753,7 +861,7 @@ class Node:
|
||||||
Context object on all tracks that get recommended.
|
Context object on all tracks that get recommended.
|
||||||
"""
|
"""
|
||||||
if track.track_type == TrackType.SPOTIFY:
|
if track.track_type == TrackType.SPOTIFY:
|
||||||
results = await self._spotify_client.get_recommendations(query=track.uri)
|
results = await self._spotify_client.get_recommendations(query=track.uri) # type: ignore
|
||||||
tracks = [
|
tracks = [
|
||||||
Track(
|
Track(
|
||||||
track_id=track.id,
|
track_id=track.id,
|
||||||
|
|
@ -788,6 +896,57 @@ class Node:
|
||||||
"The specfied track must be either a YouTube or Spotify track to recieve recommendations.",
|
"The specfied track must be either a YouTube or Spotify track to recieve recommendations.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def search_spotify_recommendations(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
ctx: Optional[commands.Context] = None,
|
||||||
|
filters: Optional[List[Filter]] = None,
|
||||||
|
) -> Optional[Union[List[Track], Playlist]]:
|
||||||
|
"""
|
||||||
|
Searches for recommendations on Spotify and returns a list of tracks based on the query.
|
||||||
|
You must have Spotify enabled for this to work.
|
||||||
|
You can pass in a discord.py Context object to get a
|
||||||
|
Context object on all tracks that get recommended.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not self._spotify_client:
|
||||||
|
raise InvalidSpotifyClientAuthorization(
|
||||||
|
"You must have Spotify enabled to use this feature.",
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await self._spotify_client.track_search(query=query) # type: ignore
|
||||||
|
if not results:
|
||||||
|
raise TrackLoadError(
|
||||||
|
"Unable to find any tracks based on the query.",
|
||||||
|
)
|
||||||
|
|
||||||
|
tracks = [
|
||||||
|
Track(
|
||||||
|
track_id=track.id,
|
||||||
|
ctx=ctx,
|
||||||
|
track_type=TrackType.SPOTIFY,
|
||||||
|
info={
|
||||||
|
"title": track.name,
|
||||||
|
"author": track.artists,
|
||||||
|
"length": track.length,
|
||||||
|
"identifier": track.id,
|
||||||
|
"uri": track.uri,
|
||||||
|
"isStream": False,
|
||||||
|
"isSeekable": True,
|
||||||
|
"position": 0,
|
||||||
|
"thumbnail": track.image,
|
||||||
|
"isrc": track.isrc,
|
||||||
|
},
|
||||||
|
requester=self.bot.user,
|
||||||
|
)
|
||||||
|
for track in results
|
||||||
|
]
|
||||||
|
|
||||||
|
track = tracks[0]
|
||||||
|
|
||||||
|
return await self.get_recommendations(track=track, ctx=ctx)
|
||||||
|
|
||||||
|
|
||||||
class NodePool:
|
class NodePool:
|
||||||
"""The base class for the node pool.
|
"""The base class for the node pool.
|
||||||
|
|
@ -869,14 +1028,16 @@ class NodePool:
|
||||||
password: str,
|
password: str,
|
||||||
identifier: str,
|
identifier: str,
|
||||||
secure: bool = False,
|
secure: bool = False,
|
||||||
heartbeat: int = 30,
|
heartbeat: int = 120,
|
||||||
|
resume_key: Optional[str] = None,
|
||||||
|
resume_timeout: int = 60,
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||||
spotify_client_id: Optional[int] = None,
|
spotify_client_id: Optional[str] = None,
|
||||||
spotify_client_secret: Optional[str] = None,
|
spotify_client_secret: Optional[str] = None,
|
||||||
session: Optional[aiohttp.ClientSession] = None,
|
session: Optional[aiohttp.ClientSession] = None,
|
||||||
apple_music: bool = False,
|
apple_music: bool = False,
|
||||||
fallback: bool = False,
|
fallback: bool = False,
|
||||||
log_level: LogLevel = LogLevel.INFO,
|
logger: Optional[logging.Logger] = None,
|
||||||
) -> Node:
|
) -> Node:
|
||||||
"""Creates a Node object to be then added into the node pool.
|
"""Creates a Node object to be then added into the node pool.
|
||||||
For Spotify searching capabilites, pass in valid Spotify API credentials.
|
For Spotify searching capabilites, pass in valid Spotify API credentials.
|
||||||
|
|
@ -895,13 +1056,15 @@ class NodePool:
|
||||||
identifier=identifier,
|
identifier=identifier,
|
||||||
secure=secure,
|
secure=secure,
|
||||||
heartbeat=heartbeat,
|
heartbeat=heartbeat,
|
||||||
|
resume_key=resume_key,
|
||||||
|
resume_timeout=resume_timeout,
|
||||||
loop=loop,
|
loop=loop,
|
||||||
spotify_client_id=spotify_client_id,
|
spotify_client_id=spotify_client_id,
|
||||||
session=session,
|
session=session,
|
||||||
spotify_client_secret=spotify_client_secret,
|
spotify_client_secret=spotify_client_secret,
|
||||||
apple_music=apple_music,
|
apple_music=apple_music,
|
||||||
fallback=fallback,
|
fallback=fallback,
|
||||||
log_level=log_level,
|
logger=logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
await node.connect()
|
await node.connect()
|
||||||
|
|
|
||||||
|
|
@ -203,9 +203,13 @@ class Queue(Iterable[Track]):
|
||||||
raise QueueEmpty("No items in the queue.")
|
raise QueueEmpty("No items in the queue.")
|
||||||
|
|
||||||
if self._loop_mode == LoopMode.QUEUE:
|
if self._loop_mode == LoopMode.QUEUE:
|
||||||
# recurse if the item isnt in the queue
|
# set current item to first track in queue if not set already
|
||||||
if self._current_item not in self._queue:
|
# otherwise exception will be raised
|
||||||
self.get()
|
if not self._current_item or self._current_item not in self._queue:
|
||||||
|
if self._queue:
|
||||||
|
item = self._queue[0]
|
||||||
|
else:
|
||||||
|
raise QueueEmpty("No items in the queue.")
|
||||||
|
|
||||||
# set current item to first track in queue if not set already
|
# set current item to first track in queue if not set already
|
||||||
if not self._current_item:
|
if not self._current_item:
|
||||||
|
|
@ -348,7 +352,23 @@ class Queue(Iterable[Track]):
|
||||||
track.filters = None
|
track.filters = None
|
||||||
|
|
||||||
def jump(self, item: Track) -> None:
|
def jump(self, item: Track) -> None:
|
||||||
"""Removes all tracks before the."""
|
"""
|
||||||
|
Jumps to the item specified in the queue.
|
||||||
|
|
||||||
|
If the queue is not looping, the queue will be mutated.
|
||||||
|
Otherwise, the current item will be adjusted to the item
|
||||||
|
before the specified track.
|
||||||
|
|
||||||
|
The queue is adjusted so that the next item that is retrieved
|
||||||
|
is the track that is specified, effectively 'jumping' the queue.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._loop_mode == LoopMode.TRACK:
|
||||||
|
raise QueueException("Jumping the queue whilst looping a track is not allowed.")
|
||||||
|
|
||||||
index = self.find_position(item)
|
index = self.find_position(item)
|
||||||
|
if self._loop_mode == LoopMode.QUEUE:
|
||||||
|
self._current_item = self._queue[index - 1]
|
||||||
|
else:
|
||||||
new_queue = self._queue[index : self.size]
|
new_queue = self._queue[index : self.size]
|
||||||
self._queue = new_queue
|
self._queue = new_queue
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,16 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
|
from typing import AsyncGenerator
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing import List
|
from typing import List
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import orjson as json
|
import orjson as json
|
||||||
|
|
@ -20,8 +24,10 @@ __all__ = ("Client",)
|
||||||
|
|
||||||
GRANT_URL = "https://accounts.spotify.com/api/token"
|
GRANT_URL = "https://accounts.spotify.com/api/token"
|
||||||
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
|
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
|
||||||
|
# Keep this in sync with URLRegex.SPOTIFY_URL (enums.py). Accept intl locale segment,
|
||||||
|
# optional trailing slash, and query parameters.
|
||||||
SPOTIFY_URL_REGEX = re.compile(
|
SPOTIFY_URL_REGEX = re.compile(
|
||||||
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
|
r"https?://open\.spotify\.com/(?:intl-[a-zA-Z-]+/)?(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)(?:/)?(?:\?.*)?$",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -31,35 +37,49 @@ class Client:
|
||||||
for any Spotify URL you throw at it.
|
for any Spotify URL you throw at it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, client_id: int, client_secret: str) -> None:
|
def __init__(
|
||||||
self._client_id: int = client_id
|
self,
|
||||||
self._client_secret: str = client_secret
|
client_id: str,
|
||||||
|
client_secret: str,
|
||||||
|
*,
|
||||||
|
playlist_concurrency: int = 10,
|
||||||
|
playlist_page_limit: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
self._client_id = client_id
|
||||||
|
self._client_secret = client_secret
|
||||||
|
|
||||||
self.session: aiohttp.ClientSession = None # type: ignore
|
# HTTP session will be injected by Node
|
||||||
|
self.session: Optional[aiohttp.ClientSession] = None
|
||||||
|
|
||||||
self._bearer_token: Optional[str] = None
|
self._bearer_token: Optional[str] = None
|
||||||
self._expiry: float = 0.0
|
self._expiry: float = 0.0
|
||||||
self._auth_token = b64encode(
|
self._auth_token = b64encode(f"{self._client_id}:{self._client_secret}".encode())
|
||||||
f"{self._client_id}:{self._client_secret}".encode(),
|
self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"}
|
||||||
)
|
|
||||||
self._grant_headers = {
|
|
||||||
"Authorization": f"Basic {self._auth_token.decode()}",
|
|
||||||
}
|
|
||||||
self._bearer_headers: Optional[Dict] = None
|
self._bearer_headers: Optional[Dict] = None
|
||||||
|
self._log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Performance tuning knobs
|
||||||
|
self._playlist_concurrency = max(1, playlist_concurrency)
|
||||||
|
self._playlist_page_limit = playlist_page_limit
|
||||||
|
|
||||||
|
async def _set_session(self, session: aiohttp.ClientSession) -> None:
|
||||||
|
self.session = session
|
||||||
|
|
||||||
async def _fetch_bearer_token(self) -> None:
|
async def _fetch_bearer_token(self) -> None:
|
||||||
_data = {"grant_type": "client_credentials"}
|
_data = {"grant_type": "client_credentials"}
|
||||||
|
|
||||||
if not self.session:
|
if not self.session:
|
||||||
self.session = aiohttp.ClientSession()
|
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
||||||
|
resp = await self.session.post(GRANT_URL, data=_data, headers=self._grant_headers)
|
||||||
|
|
||||||
async with self.session.post(GRANT_URL, data=_data, headers=self._grant_headers) as resp:
|
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
raise SpotifyRequestException(
|
raise SpotifyRequestException(
|
||||||
f"Error fetching bearer token: {resp.status} {resp.reason}",
|
f"Error fetching bearer token: {resp.status} {resp.reason}",
|
||||||
)
|
)
|
||||||
|
|
||||||
data: dict = await resp.json(loads=json.loads)
|
data: dict = await resp.json(loads=json.loads)
|
||||||
|
if self._log:
|
||||||
|
self._log.debug(f"Fetched Spotify bearer token successfully")
|
||||||
|
|
||||||
self._bearer_token = data["access_token"]
|
self._bearer_token = data["access_token"]
|
||||||
self._expiry = time.time() + (int(data["expires_in"]) - 10)
|
self._expiry = time.time() + (int(data["expires_in"]) - 10)
|
||||||
|
|
@ -80,23 +100,31 @@ class Client:
|
||||||
|
|
||||||
request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id)
|
request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id)
|
||||||
|
|
||||||
async with self.session.get(request_url, headers=self._bearer_headers) as resp:
|
if not self.session:
|
||||||
|
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
||||||
|
resp = await self.session.get(request_url, headers=self._bearer_headers)
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
raise SpotifyRequestException(
|
raise SpotifyRequestException(
|
||||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||||
)
|
)
|
||||||
|
|
||||||
data: dict = await resp.json(loads=json.loads)
|
data: dict = await resp.json(loads=json.loads)
|
||||||
|
if self._log:
|
||||||
|
self._log.debug(
|
||||||
|
f"Made request to Spotify API with status {resp.status} and response {data}",
|
||||||
|
)
|
||||||
|
|
||||||
if spotify_type == "track":
|
if spotify_type == "track":
|
||||||
return Track(data)
|
return Track(data)
|
||||||
elif spotify_type == "album":
|
elif spotify_type == "album":
|
||||||
return Album(data)
|
return Album(data)
|
||||||
elif spotify_type == "artist":
|
elif spotify_type == "artist":
|
||||||
async with self.session.get(
|
if not self.session:
|
||||||
|
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
||||||
|
resp = await self.session.get(
|
||||||
f"{request_url}/top-tracks?market=US",
|
f"{request_url}/top-tracks?market=US",
|
||||||
headers=self._bearer_headers,
|
headers=self._bearer_headers,
|
||||||
) as resp:
|
)
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
raise SpotifyRequestException(
|
raise SpotifyRequestException(
|
||||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||||
|
|
@ -106,36 +134,177 @@ class Client:
|
||||||
tracks = track_data["tracks"]
|
tracks = track_data["tracks"]
|
||||||
return Artist(data, tracks)
|
return Artist(data, tracks)
|
||||||
else:
|
else:
|
||||||
|
# For playlists we optionally use a reduced fields payload to shrink response sizes.
|
||||||
|
# NB: We cannot apply fields filter to initial request because original metadata is needed.
|
||||||
tracks = [
|
tracks = [
|
||||||
Track(track["track"])
|
Track(track["track"])
|
||||||
for track in data["tracks"]["items"]
|
for track in data["tracks"]["items"]
|
||||||
if track["track"] is not None
|
if track["track"] is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
if not len(tracks):
|
if not tracks:
|
||||||
raise SpotifyRequestException(
|
raise SpotifyRequestException(
|
||||||
"This playlist is empty and therefore cannot be queued.",
|
"This playlist is empty and therefore cannot be queued.",
|
||||||
)
|
)
|
||||||
|
|
||||||
next_page_url = data["tracks"]["next"]
|
total_tracks = data["tracks"]["total"]
|
||||||
|
limit = data["tracks"]["limit"]
|
||||||
|
|
||||||
while next_page_url is not None:
|
# Short‑circuit small playlists (single page)
|
||||||
async with self.session.get(next_page_url, headers=self._bearer_headers) as resp:
|
if total_tracks <= limit:
|
||||||
|
return Playlist(data, tracks)
|
||||||
|
|
||||||
|
# Build remaining page URLs; Spotify supports offset-based pagination.
|
||||||
|
remaining_offsets = range(limit, total_tracks, limit)
|
||||||
|
page_urls: List[str] = []
|
||||||
|
fields_filter = (
|
||||||
|
"items(track(name,duration_ms,id,is_local,external_urls,external_ids,artists(name),album(images)))"
|
||||||
|
",next"
|
||||||
|
)
|
||||||
|
for idx, offset in enumerate(remaining_offsets):
|
||||||
|
if self._playlist_page_limit is not None and idx >= self._playlist_page_limit:
|
||||||
|
break
|
||||||
|
page_urls.append(
|
||||||
|
f"{request_url}/tracks?offset={offset}&limit={limit}&fields={quote(fields_filter)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if page_urls:
|
||||||
|
semaphore = asyncio.Semaphore(self._playlist_concurrency)
|
||||||
|
|
||||||
|
async def fetch_page(url: str) -> Optional[List[Track]]:
|
||||||
|
async with semaphore:
|
||||||
|
if not self.session:
|
||||||
|
raise SpotifyRequestException(
|
||||||
|
"HTTP session not initialized for Spotify client.",
|
||||||
|
)
|
||||||
|
resp = await self.session.get(url, headers=self._bearer_headers)
|
||||||
|
if resp.status != 200:
|
||||||
|
if self._log:
|
||||||
|
self._log.warning(
|
||||||
|
f"Page fetch failed {resp.status} {resp.reason} for {url}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
page_json: dict = await resp.json(loads=json.loads)
|
||||||
|
return [
|
||||||
|
Track(item["track"])
|
||||||
|
for item in page_json.get("items", [])
|
||||||
|
if item.get("track") is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
# Chunk gather in waves to avoid creating thousands of tasks at once
|
||||||
|
aggregated: List[Track] = []
|
||||||
|
wave_size = self._playlist_concurrency * 2
|
||||||
|
for i in range(0, len(page_urls), wave_size):
|
||||||
|
wave = page_urls[i : i + wave_size]
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*[fetch_page(url) for url in wave],
|
||||||
|
return_exceptions=False,
|
||||||
|
)
|
||||||
|
for result in results:
|
||||||
|
if result:
|
||||||
|
aggregated.extend(result)
|
||||||
|
|
||||||
|
tracks.extend(aggregated)
|
||||||
|
|
||||||
|
return Playlist(data, tracks)
|
||||||
|
|
||||||
|
async def iter_playlist_tracks(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
query: str,
|
||||||
|
batch_size: int = 100,
|
||||||
|
) -> AsyncGenerator[List[Track], None]:
|
||||||
|
"""Stream playlist tracks in batches without waiting for full materialization.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
query: str
|
||||||
|
Spotify playlist URL.
|
||||||
|
batch_size: int
|
||||||
|
Number of tracks yielded per batch (logical grouping after fetch). Does not alter API page size.
|
||||||
|
"""
|
||||||
|
if not self._bearer_token or time.time() >= self._expiry:
|
||||||
|
await self._fetch_bearer_token()
|
||||||
|
|
||||||
|
match = SPOTIFY_URL_REGEX.match(query)
|
||||||
|
if not match or match.group("type") != "playlist":
|
||||||
|
raise InvalidSpotifyURL("Provided query is not a valid Spotify playlist URL.")
|
||||||
|
|
||||||
|
playlist_id = match.group("id")
|
||||||
|
request_url = REQUEST_URL.format(type="playlist", id=playlist_id)
|
||||||
|
if not self.session:
|
||||||
|
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
||||||
|
resp = await self.session.get(request_url, headers=self._bearer_headers)
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
raise SpotifyRequestException(
|
raise SpotifyRequestException(
|
||||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||||
)
|
)
|
||||||
|
data: dict = await resp.json(loads=json.loads)
|
||||||
|
|
||||||
next_data: dict = await resp.json(loads=json.loads)
|
# Yield first page immediately
|
||||||
|
first_page_tracks = [
|
||||||
tracks += [
|
Track(item["track"])
|
||||||
Track(track["track"])
|
for item in data["tracks"]["items"]
|
||||||
for track in next_data["items"]
|
if item.get("track") is not None
|
||||||
if track["track"] is not None
|
|
||||||
]
|
]
|
||||||
next_page_url = next_data["next"]
|
# Batch yield
|
||||||
|
for i in range(0, len(first_page_tracks), batch_size):
|
||||||
|
yield first_page_tracks[i : i + batch_size]
|
||||||
|
|
||||||
return Playlist(data, tracks)
|
total = data["tracks"]["total"]
|
||||||
|
limit = data["tracks"]["limit"]
|
||||||
|
remaining_offsets = range(limit, total, limit)
|
||||||
|
fields_filter = (
|
||||||
|
"items(track(name,duration_ms,id,is_local,external_urls,external_ids,artists(name),album(images)))"
|
||||||
|
",next"
|
||||||
|
)
|
||||||
|
|
||||||
|
semaphore = asyncio.Semaphore(self._playlist_concurrency)
|
||||||
|
|
||||||
|
async def fetch(offset: int) -> List[Track]:
|
||||||
|
url = (
|
||||||
|
f"{request_url}/tracks?offset={offset}&limit={limit}&fields={quote(fields_filter)}"
|
||||||
|
)
|
||||||
|
async with semaphore:
|
||||||
|
if not self.session:
|
||||||
|
raise SpotifyRequestException(
|
||||||
|
"HTTP session not initialized for Spotify client.",
|
||||||
|
)
|
||||||
|
r = await self.session.get(url, headers=self._bearer_headers)
|
||||||
|
if r.status != 200:
|
||||||
|
if self._log:
|
||||||
|
self._log.warning(
|
||||||
|
f"Skipping page offset={offset} due to {r.status} {r.reason}",
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
pj: dict = await r.json(loads=json.loads)
|
||||||
|
return [
|
||||||
|
Track(item["track"])
|
||||||
|
for item in pj.get("items", [])
|
||||||
|
if item.get("track") is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
# Fetch pages in rolling waves; yield promptly as soon as a wave completes.
|
||||||
|
wave_size = self._playlist_concurrency * 2
|
||||||
|
for i, offset in enumerate(remaining_offsets):
|
||||||
|
# Build wave
|
||||||
|
if i % wave_size == 0:
|
||||||
|
wave_offsets = list(
|
||||||
|
o for o in remaining_offsets if o >= offset and o < offset + wave_size
|
||||||
|
)
|
||||||
|
results = await asyncio.gather(*[fetch(o) for o in wave_offsets])
|
||||||
|
for page_tracks in results:
|
||||||
|
if not page_tracks:
|
||||||
|
continue
|
||||||
|
for j in range(0, len(page_tracks), batch_size):
|
||||||
|
yield page_tracks[j : j + batch_size]
|
||||||
|
# Skip ahead in iterator by adjusting enumerate drive (consume extras)
|
||||||
|
# Fast-forward the generator manually
|
||||||
|
for _ in range(len(wave_offsets) - 1):
|
||||||
|
try:
|
||||||
|
next(remaining_offsets) # type: ignore
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
|
||||||
async def get_recommendations(self, *, query: str) -> List[Track]:
|
async def get_recommendations(self, *, query: str) -> List[Track]:
|
||||||
if not self._bearer_token or time.time() >= self._expiry:
|
if not self._bearer_token or time.time() >= self._expiry:
|
||||||
|
|
@ -158,19 +327,34 @@ class Client:
|
||||||
id=f"?seed_tracks={spotify_id}",
|
id=f"?seed_tracks={spotify_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self.session.get(request_url, headers=self._bearer_headers) as resp:
|
if not self.session:
|
||||||
|
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
||||||
|
resp = await self.session.get(request_url, headers=self._bearer_headers)
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
raise SpotifyRequestException(
|
raise SpotifyRequestException(
|
||||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||||
)
|
)
|
||||||
|
|
||||||
data: dict = await resp.json(loads=json.loads)
|
data: dict = await resp.json(loads=json.loads)
|
||||||
|
|
||||||
tracks = [Track(track) for track in data["tracks"]]
|
tracks = [Track(track) for track in data["tracks"]]
|
||||||
|
|
||||||
return tracks
|
return tracks
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def track_search(self, *, query: str) -> List[Track]:
|
||||||
if self.session:
|
if not self._bearer_token or time.time() >= self._expiry:
|
||||||
await self.session.close()
|
await self._fetch_bearer_token()
|
||||||
self.session = None # type: ignore
|
|
||||||
|
request_url = f"https://api.spotify.com/v1/search?q={quote(query)}&type=track"
|
||||||
|
|
||||||
|
if not self.session:
|
||||||
|
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
|
||||||
|
resp = await self.session.get(request_url, headers=self._bearer_headers)
|
||||||
|
if resp.status != 200:
|
||||||
|
raise SpotifyRequestException(
|
||||||
|
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||||
|
)
|
||||||
|
|
||||||
|
data: dict = await resp.json(loads=json.loads)
|
||||||
|
tracks = [Track(track) for track in data["tracks"]["items"]]
|
||||||
|
|
||||||
|
return tracks
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ class Track:
|
||||||
self.length: float = data["duration_ms"]
|
self.length: float = data["duration_ms"]
|
||||||
self.id: str = data["id"]
|
self.id: str = data["id"]
|
||||||
|
|
||||||
self.issrc: Optional[str] = None
|
self.isrc: Optional[str] = None
|
||||||
if data.get("external_ids"):
|
if data.get("external_ids"):
|
||||||
self.isrc = data["external_ids"]["isrc"]
|
self.isrc = data["external_ids"]["isrc"]
|
||||||
|
|
||||||
|
|
|
||||||
3
setup.py
3
setup.py
|
|
@ -1,9 +1,10 @@
|
||||||
|
# type: ignore
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import setuptools
|
import setuptools
|
||||||
|
|
||||||
version = ""
|
version = ""
|
||||||
requirements = ["discord.py>=2.0.0", "aiohttp>=3.7.4,<4", "orjson"]
|
requirements = ["aiohttp>=3.7.4,<4", "orjson", "websockets"]
|
||||||
with open("pomice/__init__.py") as f:
|
with open("pomice/__init__.py") as f:
|
||||||
version = re.search(
|
version = re.search(
|
||||||
r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',
|
r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue