Compare commits

...

71 Commits
2.5.1 ... main

Author SHA1 Message Date
cloudwithax 9bffdebe25 2.10.0 2025-10-04 00:00:57 -04:00
cloudwithax 720ba187ab 2.10.0 2025-10-04 00:00:01 -04:00
cloudwithax 855bf4e0d7 2.9.2 2024-11-21 21:11:24 -05:00
cloudwithax cd579becad fixed file playing and recursion issue in queue looping 2024-11-21 21:06:32 -05:00
cloudwithax 3a1ecf9eec 2.9.1 2024-08-23 21:18:25 -04:00
clxud 5227962228
Merge pull request #69 from ZandercraftGames/fix/other-sources
Fix Support for Other Source and Playlist Types
2024-08-23 21:04:04 -04:00
Zander be7106616b
Typing fix from NiceAesth
Co-authored-by: Andrei Baciu <8437201+NiceAesth@users.noreply.github.com>
2024-08-18 00:34:43 -04:00
Zander ba9534bc27
Typing fix from NiceAesth
Co-authored-by: Andrei Baciu <8437201+NiceAesth@users.noreply.github.com>
2024-08-18 00:34:32 -04:00
pre-commit-ci[bot] 2e0f5b365a [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2024-08-17 05:44:34 +00:00
Zander M. 851f00aa97
Add support for other unsupported playlist types 2024-08-17 01:23:59 -04:00
Zander M. 817295d321
Add support for other unsupported source types 2024-08-17 00:46:32 -04:00
Zander M. 8ab3ae9ccd
Add websockets dependency to Pipenv 2024-08-17 00:44:51 -04:00
cloudwithax 094f2be181 refactor: set original track if search type is not defined in Player's play_track method 2024-06-10 21:54:21 -04:00
cloudwithax b60a6aec18 refactor: guard check for search type to prevent nulled search types getting searched 2024-06-10 21:30:57 -04:00
cloudwithax 80f7b77cd3 refactor: update search_type handling in Player and Node classes to be nullish to support lavasrc 2024-06-10 21:20:59 -04:00
cloudwithax 8679d6d125 Merge branch 'main' of https://github.com/cloudwithax/pomice 2024-06-10 21:17:57 -04:00
cloudwithax ad01407fff refactor: update query handling in Node class 2024-06-10 21:17:53 -04:00
Clxud f1609f7049
Merge pull request #67 from ZandercraftGames/fix/load-exceptions
Fix KeyError in exception handling on error when loading a track.
2024-03-27 22:10:37 -04:00
pre-commit-ci[bot] 5fcfc73901 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2024-03-28 02:08:22 +00:00
Zander M. 86b35106b2
Fix KeyError in exception handling on error when loading a track. 2024-03-27 22:04:39 -04:00
Clxud 519a14fbde
Merge pull request #66 from ZandercraftGames/main
Fix build_track failure with Lavalink v4 decodetrack format
2024-03-13 10:16:31 -04:00
Zander M. 9a42093f64
Merge remote-tracking branch 'origin/main' 2024-03-11 13:50:22 -04:00
Zander M. 347a6e0b96
Refactor sourceName to use track_info object. 2024-03-11 13:50:07 -04:00
pre-commit-ci[bot] 83d5add134 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2024-03-11 17:29:47 +00:00
Zander M. bb12e33584
Python Black Format 2024-03-11 13:24:54 -04:00
Zander M. 6817cd8e07
Fix build_track failure with Lavalink v4 decodetrack format. 2024-03-11 13:23:42 -04:00
Clxud 179472bd6e
Merge pull request #63 from NiceAesth/fix-assert
fix: remove unnecessary assert
2024-02-22 11:00:05 -05:00
NiceAesth ba761743b9 fix: remove unnecessary assert
Hit this in production (https://sunnycord.sentry.io/share/issue/e39efaaa16d64b4fbf4e3ec409406971/)
The assert is unnecessary since track is typed as optional either way.
2024-02-19 20:43:58 +02:00
cloudwithax b3795102b8 2.9.0 2024-02-06 17:32:17 -05:00
cloudwithax 2a492c793f fix issues related to loading files/http links, added spotify recommendation querying, changed loglevel enum behavior 2024-02-06 17:31:51 -05:00
cloudwithax 705ac9feab 2.8.1 2024-02-01 22:03:52 -05:00
cloudwithax a926616028 actually make logging optional lol 2024-02-01 22:03:30 -05:00
cloudwithax 4507b50b8b 2.8.0 2024-02-01 21:12:34 -05:00
cloudwithax bd78f47585 fix logging and made spotify + apple music optional because of v4 2024-02-01 21:11:42 -05:00
cloudwithax 9b18759864 update voice channel on state update + dont make logging enabled by default 2024-01-28 15:59:41 -05:00
Clxud 001b801a15
Merge pull request #57 from cloudwithax/pre-commit-ci-update-config
[pre-commit.ci] pre-commit autoupdate
2023-11-05 20:33:45 -05:00
pre-commit-ci[bot] db1c66dd40 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2023-10-30 17:50:55 +00:00
pre-commit-ci[bot] 341164a0d2
[pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/pre-commit/pre-commit-hooks: v4.4.0 → v4.5.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.4.0...v4.5.0)
- [github.com/psf/black: 23.7.0 → 23.10.1](https://github.com/psf/black/compare/23.7.0...23.10.1)
- [github.com/asottile/blacken-docs: 1.15.0 → 1.16.0](https://github.com/asottile/blacken-docs/compare/1.15.0...1.16.0)
- [github.com/asottile/pyupgrade: v3.10.1 → v3.15.0](https://github.com/asottile/pyupgrade/compare/v3.10.1...v3.15.0)
- [github.com/asottile/reorder-python-imports: v3.10.0 → v3.12.0](https://github.com/asottile/reorder-python-imports/compare/v3.10.0...v3.12.0)
- [github.com/asottile/add-trailing-comma: v3.0.1 → v3.1.0](https://github.com/asottile/add-trailing-comma/compare/v3.0.1...v3.1.0)
- [github.com/hadialqattan/pycln: v2.2.2 → v2.3.0](https://github.com/hadialqattan/pycln/compare/v2.2.2...v2.3.0)
2023-10-30 17:50:45 +00:00
Clxud 7829086ae3
Merge pull request #59 from corpnewt/patch-1
Account for Lavalink v4 changes when loading YT playlists
2023-09-17 11:33:42 -04:00
pre-commit-ci[bot] f9cb48c48f [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2023-09-15 23:28:59 +00:00
CorpNewt 3401b669e8
Fix for YT URL searches on Lavalink v4
Since the prior code for v3 uses list comprehension to build the tracks returned, we can check if we're using v4 and if the data[data_type] is a dictionary, and wrap it in a list to ensure the same behavior.
2023-09-15 18:28:50 -05:00
CorpNewt d7a7efb051
Account for Lavalink v4 changes when loading YT playlists 2023-09-15 10:35:56 -05:00
cloudwithax 0904196979 ver bump 2023-08-23 13:15:54 -04:00
Clxud 7617ecf2d1
Merge pull request #58 from NiceAesth/fix-resume
fix: undefined data in `_configure_resuming`
2023-08-23 13:14:56 -04:00
NiceAesth 1acc594467 fix: undefined data in `_configure_resuming` 2023-08-23 20:00:36 +03:00
cloudwithax e48c31b7a9 Merge branch 'main' of https://github.com/cloudwithax/pomice 2023-08-23 10:45:04 -04:00
cloudwithax f3c5461854 patch load types and track events for v4 2023-08-23 10:44:51 -04:00
Clxud aa826c7da2
Merge pull request #56 from cloudwithax/pre-commit-ci-update-config
[pre-commit.ci] pre-commit autoupdate
2023-08-16 01:26:36 -04:00
pre-commit-ci[bot] bc71088092
[pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/hadialqattan/pycln: v2.2.1 → v2.2.2](https://github.com/hadialqattan/pycln/compare/v2.2.1...v2.2.2)
2023-08-14 21:23:03 +00:00
cloudwithax 7eca4724da Merge branch 'main' of https://github.com/cloudwithax/pomice 2023-08-08 16:28:20 -04:00
cloudwithax b098b681be add missing dep for rtd 2023-08-08 16:28:11 -04:00
Clxud c52a379b87
Merge pull request #55 from cloudwithax/pre-commit-ci-update-config
[pre-commit.ci] pre-commit autoupdate
2023-08-07 20:51:22 -04:00
pre-commit-ci[bot] 50b5eab860 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2023-08-07 21:21:47 +00:00
pre-commit-ci[bot] 18fed3a089
[pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/hadialqattan/pycln: v2.2.0 → v2.2.1](https://github.com/hadialqattan/pycln/compare/v2.2.0...v2.2.1)
2023-08-07 21:21:38 +00:00
Clxud ab432cc8e6
Merge pull request #54 from cloudwithax/pre-commit-ci-update-config
[pre-commit.ci] pre-commit autoupdate
2023-07-31 20:11:56 -04:00
pre-commit-ci[bot] 223be29384 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2023-07-31 21:42:43 +00:00
pre-commit-ci[bot] c5f8ded0b1
[pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/psf/black: 23.3.0 → 23.7.0](https://github.com/psf/black/compare/23.3.0...23.7.0)
- [github.com/asottile/blacken-docs: 1.14.0 → 1.15.0](https://github.com/asottile/blacken-docs/compare/1.14.0...1.15.0)
- [github.com/asottile/pyupgrade: v3.7.0 → v3.10.1](https://github.com/asottile/pyupgrade/compare/v3.7.0...v3.10.1)
- [github.com/asottile/add-trailing-comma: v2.5.1 → v3.0.1](https://github.com/asottile/add-trailing-comma/compare/v2.5.1...v3.0.1)
- [github.com/hadialqattan/pycln: v2.1.5 → v2.2.0](https://github.com/hadialqattan/pycln/compare/v2.1.5...v2.2.0)
2023-07-31 21:42:32 +00:00
Clxud 0b1d36cf64
Merge pull request #52 from cloudwithax/pre-commit-ci-update-config
[pre-commit.ci] pre-commit autoupdate
2023-06-19 17:09:32 -04:00
pre-commit-ci[bot] 1f20ebf6c6
[pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/asottile/blacken-docs: 1.13.0 → 1.14.0](https://github.com/asottile/blacken-docs/compare/1.13.0...1.14.0)
- [github.com/asottile/pyupgrade: v3.6.0 → v3.7.0](https://github.com/asottile/pyupgrade/compare/v3.6.0...v3.7.0)
- [github.com/asottile/reorder-python-imports: v3.9.0 → v3.10.0](https://github.com/asottile/reorder-python-imports/compare/v3.9.0...v3.10.0)
2023-06-19 21:02:01 +00:00
Clxud af5418c958
Merge pull request #51 from cloudwithax/pre-commit-ci-update-config
[pre-commit.ci] pre-commit autoupdate
2023-06-12 18:10:23 -04:00
pre-commit-ci[bot] 6670da76e8 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2023-06-12 20:45:28 +00:00
pre-commit-ci[bot] 69d3bc9ce1
[pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/asottile/pyupgrade: v3.4.0 → v3.6.0](https://github.com/asottile/pyupgrade/compare/v3.4.0...v3.6.0)
- [github.com/asottile/add-trailing-comma: v2.4.0 → v2.5.1](https://github.com/asottile/add-trailing-comma/compare/v2.4.0...v2.5.1)
- [github.com/hadialqattan/pycln: v2.1.3 → v2.1.5](https://github.com/hadialqattan/pycln/compare/v2.1.3...v2.1.5)
2023-06-12 20:45:17 +00:00
cloudwithax e3fe1b52b2
2.7.0 2023-05-21 10:43:09 -04:00
cloudwithax 02d22f20b5
edit filters, log level matches handler, other fixes 2023-05-21 10:42:43 -04:00
cloudwithax cbb676e004
Merge branch 'main' of https://github.com/cloudwithax/pomice 2023-05-19 19:50:32 -04:00
cloudwithax 2d8acf7800
patch build track to escape chars properly 2023-05-19 19:50:10 -04:00
Clxud 481b2079ed
Merge pull request #50 from cloudwithax/pre-commit-ci-update-config
[pre-commit.ci] pre-commit autoupdate
2023-05-08 17:18:17 -04:00
pre-commit-ci[bot] 952a3eff14
[pre-commit.ci] pre-commit autoupdate
updates:
- [github.com/asottile/pyupgrade: v3.3.2 → v3.4.0](https://github.com/asottile/pyupgrade/compare/v3.3.2...v3.4.0)
- https://github.com/asottile/reorder_python_importshttps://github.com/asottile/reorder-python-imports
2023-05-08 20:53:40 +00:00
cloudwithax 4fc9bd8810
remove ctx comparison check from track __eq__ 2023-05-08 10:41:47 -04:00
cloudwithax 28db38a00e
2.6.0 2023-05-07 19:27:44 -04:00
cloudwithax 00ac166371
fix websocket issue and add node resuming 2023-05-07 19:27:11 -04:00
18 changed files with 941 additions and 338 deletions

1
.gitignore vendored
View File

@ -10,6 +10,7 @@ build/
Pipfile.lock Pipfile.lock
.mypy_cache/ .mypy_cache/
.vscode/ .vscode/
.idea/
.venv/ .venv/
*.code-workspace *.code-workspace
*.ini *.ini

View File

@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks # See https://pre-commit.com/hooks.html for more hooks
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0 rev: v4.5.0
hooks: hooks:
- id: check-ast - id: check-ast
- id: check-builtin-literals - id: check-builtin-literals
@ -11,31 +11,23 @@ repos:
- id: requirements-txt-fixer - id: requirements-txt-fixer
- id: trailing-whitespace - id: trailing-whitespace
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 23.3.0 rev: 23.10.1
hooks: hooks:
- id: black - id: black
language_version: python3.10 language_version: python3.13
- repo: https://github.com/asottile/blacken-docs
rev: 1.13.0
hooks:
- id: blacken-docs
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v3.3.2 rev: v3.15.0
hooks: hooks:
- id: pyupgrade - id: pyupgrade
args: [--py37-plus, --keep-runtime-typing] args: [--py37-plus, --keep-runtime-typing]
- repo: https://github.com/asottile/reorder_python_imports - repo: https://github.com/asottile/reorder-python-imports
rev: v3.9.0 rev: v3.12.0
hooks: hooks:
- id: reorder-python-imports - id: reorder-python-imports
- repo: https://github.com/asottile/add-trailing-comma - repo: https://github.com/asottile/add-trailing-comma
rev: v2.4.0 rev: v3.1.0
hooks: hooks:
- id: add-trailing-comma - id: add-trailing-comma
- repo: https://github.com/hadialqattan/pycln
rev: v2.1.3
hooks:
- id: pycln
default_language_version: default_language_version:
python: python3.10 python: python3.13

View File

@ -6,6 +6,7 @@ name = "pypi"
[packages] [packages]
orjson = "*" orjson = "*"
"discord.py" = {extras = ["voice"], version = "*"} "discord.py" = {extras = ["voice"], version = "*"}
websockets = "*"
[dev-packages] [dev-packages]
mypy = "*" mypy = "*"

View File

@ -13,11 +13,9 @@ The classes listed here are as they appear in Pomice. When you use them within y
the way you use them will be different. Here's an example on how you would use the `TrackStartEvent` within an event listener in a cog: the way you use them will be different. Here's an example on how you would use the `TrackStartEvent` within an event listener in a cog:
```py ```py
@commands.Cog.listener @commands.Cog.listener
async def on_pomice_track_start(self, player: Player, track: Track): async def on_pomice_track_start(self, player: Player, track: Track):
... ...
``` ```
## Event definitions ## Event definitions

View File

@ -66,13 +66,10 @@ After you have initialized your function, we need to fill in the proper paramete
- Set this value to `True` if you want Pomice to automatically switch all players to another available node if one disconnects. - Set this value to `True` if you want Pomice to automatically switch all players to another available node if one disconnects.
You must have two or more nodes to be able to do this. You must have two or more nodes to be able to do this.
* - `log_level` * - `logger`
- `LogLevel` - `Optional[logging.Logger]`
- The logging level for the node. The default logging level is `LogLevel.INFO`. - If you would like to receive logging information from Pomice, set this to your logger class
* - `log_handler`
- `Optional[logging.Handler]`
- The logging handler for the node. Set to `None` to default to the built-in logging handler.
::: :::
@ -93,13 +90,13 @@ await NodePool.create_node(
spotify_client_secret="<your spotify client secret here>" spotify_client_secret="<your spotify client secret here>"
apple_music=<True/False>, apple_music=<True/False>,
fallback=<True/False>, fallback=<True/False>,
log_level=<optional LogLevel here> logger=<your logger here>
) )
``` ```
:::{important} :::{important}
For features like Spotify and Apple Music, you are **not required** to fill in anything for them if you do not want to use them. If you do end up queuing a Spotify or Apple Music track anyway, they will **not work** because these options are not enabled. For features like Spotify and Apple Music, you are **not required** to fill in anything for them if you do not want to use them. If you do end up queuing a Spotify or Apple Music track, it is **up to you** on how you decide to handle it, whether it be through your own methods or a Lavalink plugin.
::: :::

View File

@ -10,63 +10,71 @@ import re
from discord.ext import commands from discord.ext import commands
URL_REG = re.compile(r'https?://(?:www\.)?.+') URL_REG = re.compile(r"https?://(?:www\.)?.+")
class MyBot(commands.Bot): class MyBot(commands.Bot):
def __init__(self) -> None:
super().__init__(
command_prefix="!",
activity=discord.Activity(
type=discord.ActivityType.listening, name="to music!"
),
)
def __init__(self) -> None: self.add_cog(Music(self))
super().__init__(command_prefix='!', activity=discord.Activity(type=discord.ActivityType.listening, name='to music!'))
self.add_cog(Music(self)) async def on_ready(self) -> None:
print("I'm online!")
async def on_ready(self) -> None: await self.cogs["Music"].start_nodes()
print("I'm online!")
await self.cogs["Music"].start_nodes()
class Music(commands.Cog): class Music(commands.Cog):
def __init__(self, bot) -> None:
self.bot = bot
def __init__(self, bot) -> None: self.pomice = pomice.NodePool()
self.bot = bot
self.pomice = pomice.NodePool() async def start_nodes(self):
await self.pomice.create_node(
bot=self.bot,
host="127.0.0.1",
port="3030",
password="youshallnotpass",
identifier="MAIN",
)
print(f"Node is ready!")
async def start_nodes(self): @commands.command(name="join", aliases=["connect"])
await self.pomice.create_node(bot=self.bot, host='127.0.0.1', port='3030', async def join(
password='youshallnotpass', identifier='MAIN') self, ctx: commands.Context, *, channel: discord.TextChannel = None
print(f"Node is ready!") ) -> None:
if not channel:
channel = getattr(ctx.author.voice, "channel", None)
@commands.command(name='join', aliases=['connect'])
async def join(self, ctx: commands.Context, *, channel: discord.TextChannel = None) -> None:
if not channel:
channel = getattr(ctx.author.voice, 'channel', None)
if not channel: if not channel:
raise commands.CheckFailure('You must be in a voice channel to use this command' raise commands.CheckFailure(
'without specifying the channel argument.') "You must be in a voice channel to use this command"
"without specifying the channel argument."
)
await ctx.author.voice.channel.connect(cls=pomice.Player)
await ctx.send(f"Joined the voice channel `{channel}`")
await ctx.author.voice.channel.connect(cls=pomice.Player) @commands.command(name="play")
await ctx.send(f'Joined the voice channel `{channel}`') async def play(self, ctx, *, search: str) -> None:
if not ctx.voice_client:
@commands.command(name='play')
async def play(self, ctx, *, search: str) -> None:
if not ctx.voice_client:
await ctx.invoke(self.join) await ctx.invoke(self.join)
player = ctx.voice_client player = ctx.voice_client
results = await player.get_tracks(query=f'{search}') results = await player.get_tracks(query=f"{search}")
if not results: if not results:
raise commands.CommandError('No results were found for that search term.') raise commands.CommandError("No results were found for that search term.")
if isinstance(results, pomice.Playlist): if isinstance(results, pomice.Playlist):
await player.play(track=results.tracks[0]) await player.play(track=results.tracks[0])
else: else:
await player.play(track=results[0]) await player.play(track=results[0])

View File

@ -3,3 +3,4 @@ discord.py[voice]
furo furo
myst_parser myst_parser
orjson orjson
websockets

View File

@ -3,7 +3,7 @@ Pomice
~~~~~~ ~~~~~~
The modern Lavalink wrapper designed for discord.py. The modern Lavalink wrapper designed for discord.py.
Copyright (c) 2023, cloudwithax Copyright (c) 2024, cloudwithax
Licensed under GPL-3.0 Licensed under GPL-3.0
""" """
@ -20,7 +20,7 @@ if not discord.version_info.major >= 2:
"using 'pip install discord.py'", "using 'pip install discord.py'",
) )
__version__ = "2.5.1" __version__ = "2.10.0"
__title__ = "pomice" __title__ = "pomice"
__author__ = "cloudwithax" __author__ = "cloudwithax"
__license__ = "GPL-3.0" __license__ = "GPL-3.0"

View File

@ -1,11 +1,14 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import base64 import base64
import logging import logging
import re import re
from datetime import datetime from datetime import datetime
from typing import AsyncGenerator
from typing import Dict from typing import Dict
from typing import List from typing import List
from typing import Optional
from typing import Union from typing import Union
import aiohttp import aiohttp
@ -17,10 +20,10 @@ from .objects import *
__all__ = ("Client",) __all__ = ("Client",)
AM_URL_REGEX = re.compile( AM_URL_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)", r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$",
) )
AM_SINGLE_IN_ALBUM_REGEX = re.compile( AM_SINGLE_IN_ALBUM_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)", r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$",
) )
AM_SCRIPT_REGEX = re.compile(r'<script.*?src="(/assets/index-.*?)"') AM_SCRIPT_REGEX = re.compile(r'<script.*?src="(/assets/index-.*?)"')
@ -35,12 +38,14 @@ class Client:
and translating it to a valid Lavalink track. No client auth is required here. and translating it to a valid Lavalink track. No client auth is required here.
""" """
def __init__(self) -> None: def __init__(self, *, playlist_concurrency: int = 6) -> None:
self.expiry: datetime = datetime(1970, 1, 1) self.expiry: datetime = datetime(1970, 1, 1)
self.token: str = "" self.token: str = ""
self.headers: Dict[str, str] = {} self.headers: Dict[str, str] = {}
self.session: aiohttp.ClientSession = None # type: ignore self.session: aiohttp.ClientSession = None # type: ignore
self._log = logging.getLogger(__name__) self._log = logging.getLogger(__name__)
# Concurrency knob for parallel playlist page retrieval
self._playlist_concurrency = max(1, playlist_concurrency)
async def _set_session(self, session: aiohttp.ClientSession) -> None: async def _set_session(self, session: aiohttp.ClientSession) -> None:
self.session = session self.session = session
@ -96,7 +101,8 @@ class Client:
).decode() ).decode()
token_data = json.loads(token_json) token_data = json.loads(token_json)
self.expiry = datetime.fromtimestamp(token_data["exp"]) self.expiry = datetime.fromtimestamp(token_data["exp"])
self._log.debug(f"Fetched Apple Music bearer token successfully") if self._log:
self._log.debug(f"Fetched Apple Music bearer token successfully")
async def search(self, query: str) -> Union[Album, Playlist, Song, Artist]: async def search(self, query: str) -> Union[Album, Playlist, Song, Artist]:
if not self.token or datetime.utcnow() > self.expiry: if not self.token or datetime.utcnow() > self.expiry:
@ -130,9 +136,10 @@ class Client:
) )
data: dict = await resp.json(loads=json.loads) data: dict = await resp.json(loads=json.loads)
self._log.debug( if self._log:
f"Made request to Apple Music API with status {resp.status} and response {data}", self._log.debug(
) f"Made request to Apple Music API with status {resp.status} and response {data}",
)
data = data["data"][0] data = data["data"][0]
@ -165,25 +172,127 @@ class Client:
"This playlist is empty and therefore cannot be queued.", "This playlist is empty and therefore cannot be queued.",
) )
_next = track_data.get("next") # Apple Music uses cursor pagination with 'next'. We'll fetch subsequent pages
if _next: # concurrently by first collecting cursors in rolling waves.
next_page_url = AM_BASE_URL + _next next_cursor = track_data.get("next")
semaphore = asyncio.Semaphore(self._playlist_concurrency)
while next_page_url is not None:
resp = await self.session.get(next_page_url, headers=self.headers)
async def fetch_page(url: str) -> List[Song]:
async with semaphore:
resp = await self.session.get(url, headers=self.headers)
if resp.status != 200: if resp.status != 200:
raise AppleMusicRequestException( if self._log:
f"Error while fetching results: {resp.status} {resp.reason}", self._log.warning(
) f"Apple Music page fetch failed {resp.status} {resp.reason} for {url}",
)
return []
pj: dict = await resp.json(loads=json.loads)
songs = [Song(track) for track in pj.get("data", [])]
# Return songs; we will look for pj.get('next') in streaming iterator variant
return songs, pj.get("next") # type: ignore
next_data: dict = await resp.json(loads=json.loads) # We'll implement a wave-based approach similar to Spotify but need to follow cursors.
album_tracks.extend(Song(track) for track in next_data["data"]) # Because we cannot know all cursors upfront, we'll iteratively fetch waves.
waves: List[List[Song]] = []
cursors: List[str] = []
if next_cursor:
cursors.append(next_cursor)
_next = next_data.get("next") # Limit total waves to avoid infinite loops in malformed responses
if _next: max_waves = 50
next_page_url = AM_BASE_URL + _next wave_size = self._playlist_concurrency * 2
else: wave_counter = 0
next_page_url = None while cursors and wave_counter < max_waves:
current = cursors[:wave_size]
cursors = cursors[wave_size:]
tasks = [
fetch_page(AM_BASE_URL + cursor) for cursor in current # type: ignore[arg-type]
]
results = await asyncio.gather(*tasks, return_exceptions=True)
for res in results:
if isinstance(res, tuple): # (songs, next)
songs, nxt = res
if songs:
waves.append(songs)
if nxt:
cursors.append(nxt)
wave_counter += 1
for w in waves:
album_tracks.extend(w)
return Playlist(data, album_tracks) return Playlist(data, album_tracks)
async def iter_playlist_tracks(
self,
*,
query: str,
batch_size: int = 100,
) -> AsyncGenerator[List[Song], None]:
"""Stream Apple Music playlist tracks in batches.
Parameters
----------
query: str
Apple Music playlist URL.
batch_size: int
Logical grouping size for yielded batches.
"""
if not self.token or datetime.utcnow() > self.expiry:
await self.request_token()
result = AM_URL_REGEX.match(query)
if not result or result.group("type") != "playlist":
raise InvalidAppleMusicURL("Provided query is not a valid Apple Music playlist URL.")
country = result.group("country")
playlist_id = result.group("id")
request_url = AM_REQ_URL.format(country=country, type="playlist", id=playlist_id)
resp = await self.session.get(request_url, headers=self.headers)
if resp.status != 200:
raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
playlist_data = data["data"][0]
track_data: dict = playlist_data["relationships"]["tracks"]
first_page_tracks = [Song(track) for track in track_data["data"]]
for i in range(0, len(first_page_tracks), batch_size):
yield first_page_tracks[i : i + batch_size]
next_cursor = track_data.get("next")
semaphore = asyncio.Semaphore(self._playlist_concurrency)
async def fetch(cursor: str) -> tuple[List[Song], Optional[str]]:
url = AM_BASE_URL + cursor
async with semaphore:
r = await self.session.get(url, headers=self.headers)
if r.status != 200:
if self._log:
self._log.warning(
f"Skipping Apple Music page due to {r.status} {r.reason}",
)
return [], None
pj: dict = await r.json(loads=json.loads)
songs = [Song(track) for track in pj.get("data", [])]
return songs, pj.get("next")
# Rolling waves of fetches following cursor chain
max_waves = 50
wave_size = self._playlist_concurrency * 2
waves = 0
cursors: List[str] = []
if next_cursor:
cursors.append(next_cursor)
while cursors and waves < max_waves:
current = cursors[:wave_size]
cursors = cursors[wave_size:]
results = await asyncio.gather(*[fetch(c) for c in current])
for songs, nxt in results:
if songs:
for j in range(0, len(songs), batch_size):
yield songs[j : j + batch_size]
if nxt:
cursors.append(nxt)
waves += 1

View File

@ -34,6 +34,11 @@ class SearchType(Enum):
ytsearch = "ytsearch" ytsearch = "ytsearch"
ytmsearch = "ytmsearch" ytmsearch = "ytmsearch"
scsearch = "scsearch" scsearch = "scsearch"
other = "other"
@classmethod
def _missing_(cls, value: object) -> "SearchType": # type: ignore[override]
return cls.other
def __str__(self) -> str: def __str__(self) -> str:
return self.value return self.value
@ -54,6 +59,8 @@ class TrackType(Enum):
TrackType.HTTP defines that the track is from an HTTP source. TrackType.HTTP defines that the track is from an HTTP source.
TrackType.LOCAL defines that the track is from a local source. TrackType.LOCAL defines that the track is from a local source.
TrackType.OTHER defines that the track is from an unknown source (possible from 3rd-party plugins).
""" """
# We don't have to define anything special for these, since these just serve as flags # We don't have to define anything special for these, since these just serve as flags
@ -63,6 +70,11 @@ class TrackType(Enum):
APPLE_MUSIC = "apple_music" APPLE_MUSIC = "apple_music"
HTTP = "http" HTTP = "http"
LOCAL = "local" LOCAL = "local"
OTHER = "other"
@classmethod
def _missing_(cls, value: object) -> "TrackType": # type: ignore[override]
return cls.OTHER
def __str__(self) -> str: def __str__(self) -> str:
return self.value return self.value
@ -79,6 +91,8 @@ class PlaylistType(Enum):
PlaylistType.SPOTIFY defines that the playlist is from Spotify PlaylistType.SPOTIFY defines that the playlist is from Spotify
PlaylistType.APPLE_MUSIC defines that the playlist is from Apple Music. PlaylistType.APPLE_MUSIC defines that the playlist is from Apple Music.
PlaylistType.OTHER defines that the playlist is from an unknown source (possible from 3rd-party plugins).
""" """
# We don't have to define anything special for these, since these just serve as flags # We don't have to define anything special for these, since these just serve as flags
@ -86,6 +100,11 @@ class PlaylistType(Enum):
SOUNDCLOUD = "soundcloud" SOUNDCLOUD = "soundcloud"
SPOTIFY = "spotify" SPOTIFY = "spotify"
APPLE_MUSIC = "apple_music" APPLE_MUSIC = "apple_music"
OTHER = "other"
@classmethod
def _missing_(cls, value: object) -> "PlaylistType": # type: ignore[override]
return cls.OTHER
def __str__(self) -> str: def __str__(self) -> str:
return self.value return self.value
@ -199,8 +218,12 @@ class URLRegex:
""" """
# Spotify share links can include query parameters like ?si=XXXX, a trailing slash,
# or an intl locale segment (e.g. /intl-en/). Broaden the regex so we still capture
# the type and id while ignoring extra parameters. This prevents the URL from being
# treated as a generic Lavalink identifier and ensures internal Spotify handling runs.
SPOTIFY_URL = re.compile( SPOTIFY_URL = re.compile(
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)", r"https?://open\.spotify\.com/(?:intl-[a-zA-Z-]+/)?(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)(?:/)?(?:\?.*)?$",
) )
DISCORD_MP3_URL = re.compile( DISCORD_MP3_URL = re.compile(
@ -221,14 +244,17 @@ class URLRegex:
r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*", r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*",
) )
# Apple Music links sometimes append additional query parameters (e.g. &l=en, &uo=4).
# Allow arbitrary query parameters so valid links are captured and parsed.
AM_URL = re.compile( AM_URL = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/" r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/"
r"(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)", r"(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$",
) )
# Single-in-album links may also carry extra query params beyond the ?i=<trackid> token.
AM_SINGLE_IN_ALBUM_REGEX = re.compile( AM_SINGLE_IN_ALBUM_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/" r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)", r"(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$",
) )
SOUNDCLOUD_URL = re.compile( SOUNDCLOUD_URL = re.compile(
@ -273,3 +299,10 @@ class LogLevel(IntEnum):
WARN = 30 WARN = 30
ERROR = 40 ERROR = 40
CRITICAL = 50 CRITICAL = 50
@classmethod
def from_str(cls, level_str):
try:
return cls[level_str.upper()]
except KeyError:
raise ValueError(f"No such log level: {level_str}")

View File

@ -128,7 +128,6 @@ class TrackExceptionEvent(PomiceEvent):
def __init__(self, data: dict, player: Player): def __init__(self, data: dict, player: Player):
self.player: Player = player self.player: Player = player
assert self.player._ending_track is not None
self.track: Optional[Track] = self.player._ending_track self.track: Optional[Track] = self.player._ending_track
# Error is for Lavalink <= 3.3 # Error is for Lavalink <= 3.3
self.exception: str = data.get( self.exception: str = data.get(

View File

@ -77,6 +77,12 @@ class Equalizer(Filter):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Pomice.EqualizerFilter tag={self.tag} eq={self.eq} raw={self.raw}>" return f"<Pomice.EqualizerFilter tag={self.tag} eq={self.eq} raw={self.raw}>"
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Equalizer):
return False
return self.raw == __value.raw
@classmethod @classmethod
def flat(cls) -> "Equalizer": def flat(cls) -> "Equalizer":
"""Equalizer preset which represents a flat EQ board, """Equalizer preset which represents a flat EQ board,
@ -231,6 +237,16 @@ class Timescale(Filter):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Pomice.TimescaleFilter tag={self.tag} speed={self.speed} pitch={self.pitch} rate={self.rate}>" return f"<Pomice.TimescaleFilter tag={self.tag} speed={self.speed} pitch={self.pitch} rate={self.rate}>"
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Timescale):
return False
return (
self.speed == __value.speed
and self.pitch == __value.pitch
and self.rate == __value.rate
)
class Karaoke(Filter): class Karaoke(Filter):
"""Filter which filters the vocal track from any song and leaves the instrumental. """Filter which filters the vocal track from any song and leaves the instrumental.
@ -270,6 +286,17 @@ class Karaoke(Filter):
f"filter_band={self.filter_band} filter_width={self.filter_width}>" f"filter_band={self.filter_band} filter_width={self.filter_width}>"
) )
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Karaoke):
return False
return (
self.level == __value.level
and self.mono_level == __value.mono_level
and self.filter_band == __value.filter_band
and self.filter_width == __value.filter_width
)
class Tremolo(Filter): class Tremolo(Filter):
"""Filter which produces a wavering tone in the music, """Filter which produces a wavering tone in the music,
@ -305,6 +332,12 @@ class Tremolo(Filter):
f"<Pomice.TremoloFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>" f"<Pomice.TremoloFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
) )
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Tremolo):
return False
return self.frequency == __value.frequency and self.depth == __value.depth
class Vibrato(Filter): class Vibrato(Filter):
"""Filter which produces a wavering tone in the music, similar to the Tremolo filter, """Filter which produces a wavering tone in the music, similar to the Tremolo filter,
@ -340,6 +373,12 @@ class Vibrato(Filter):
f"<Pomice.VibratoFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>" f"<Pomice.VibratoFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
) )
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Vibrato):
return False
return self.frequency == __value.frequency and self.depth == __value.depth
class Rotation(Filter): class Rotation(Filter):
"""Filter which produces a stereo-like panning effect, which sounds like """Filter which produces a stereo-like panning effect, which sounds like
@ -357,6 +396,12 @@ class Rotation(Filter):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Pomice.RotationFilter tag={self.tag} rotation_hertz={self.rotation_hertz}>" return f"<Pomice.RotationFilter tag={self.tag} rotation_hertz={self.rotation_hertz}>"
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Rotation):
return False
return self.rotation_hertz == __value.rotation_hertz
class ChannelMix(Filter): class ChannelMix(Filter):
"""Filter which manually adjusts the panning of the audio, which can make """Filter which manually adjusts the panning of the audio, which can make
@ -418,6 +463,17 @@ class ChannelMix(Filter):
f"right_to_left={self.right_to_left} right_to_right={self.right_to_right}>" f"right_to_left={self.right_to_left} right_to_right={self.right_to_right}>"
) )
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, ChannelMix):
return False
return (
self.left_to_left == __value.left_to_left
and self.left_to_right == __value.left_to_right
and self.right_to_left == __value.right_to_left
and self.right_to_right == __value.right_to_right
)
class Distortion(Filter): class Distortion(Filter):
"""Filter which generates a distortion effect. Useful for certain filter implementations where """Filter which generates a distortion effect. Useful for certain filter implementations where
@ -479,6 +535,21 @@ class Distortion(Filter):
f"tan_scale={self.tan_scale} offset={self.offset} scale={self.scale}" f"tan_scale={self.tan_scale} offset={self.offset} scale={self.scale}"
) )
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Distortion):
return False
return (
self.sin_offset == __value.sin_offset
and self.sin_scale == __value.sin_scale
and self.cos_offset == __value.cos_offset
and self.cos_scale == __value.cos_scale
and self.tan_offset == __value.tan_offset
and self.tan_scale == __value.tan_scale
and self.offset == __value.offset
and self.scale == __value.scale
)
class LowPass(Filter): class LowPass(Filter):
"""Filter which supresses higher frequencies and allows lower frequencies to pass. """Filter which supresses higher frequencies and allows lower frequencies to pass.
@ -495,3 +566,9 @@ class LowPass(Filter):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Pomice.LowPass tag={self.tag} smoothing={self.smoothing}>" return f"<Pomice.LowPass tag={self.tag} smoothing={self.smoothing}>"
def __eq__(self, __value: object) -> bool:
if not isinstance(__value, LowPass):
return False
return self.smoothing == __value.smoothing

View File

@ -98,9 +98,6 @@ class Track:
if not isinstance(other, Track): if not isinstance(other, Track):
return False return False
if self.ctx and other.ctx:
return other.track_id == self.track_id and other.ctx.message.id == self.ctx.message.id
return other.track_id == self.track_id return other.track_id == self.track_id
def __str__(self) -> str: def __str__(self) -> str:

View File

@ -79,6 +79,27 @@ class Filters:
if filter.tag == filter_tag: if filter.tag == filter_tag:
del self._filters[index] del self._filters[index]
def edit_filter(self, *, filter_tag: str, to_apply: Filter) -> None:
"""Edits a filter in the list of filters applied using its filter tag and replaces it with the new filter."""
if not any(f for f in self._filters if f.tag == filter_tag):
raise FilterTagInvalid("A filter with that tag was not found.")
for index, filter in enumerate(self._filters):
if filter.tag == filter_tag:
if not type(filter) == type(to_apply):
raise FilterInvalidArgument(
"Edited filter is not the same type as the current filter.",
)
if self._filters[index] == to_apply:
raise FilterInvalidArgument("Edited filter is the same as the current filter.")
if to_apply.tag != filter_tag:
raise FilterInvalidArgument(
"Edited filter tag is not the same as the current filter tag.",
)
self._filters[index] = to_apply
def has_filter(self, *, filter_tag: str) -> bool: def has_filter(self, *, filter_tag: str) -> bool:
"""Checks if a filter exists in the list of filters using its filter tag""" """Checks if a filter exists in the list of filters using its filter tag"""
return any(f for f in self._filters if f.tag == filter_tag) return any(f for f in self._filters if f.tag == filter_tag)
@ -192,7 +213,7 @@ class Player(VoiceProtocol):
difference = (time.time() * 1000) - self._last_update difference = (time.time() * 1000) - self._last_update
position = self._last_position + difference position = self._last_position + difference
return min(position, current.length) return round(min(position, current.length))
@property @property
def rate(self) -> float: def rate(self) -> float:
@ -277,7 +298,8 @@ class Player(VoiceProtocol):
self._last_update = int(state.get("time", 0)) self._last_update = int(state.get("time", 0))
self._is_connected = bool(state.get("connected")) self._is_connected = bool(state.get("connected"))
self._last_position = int(state.get("position", 0)) self._last_position = int(state.get("position", 0))
self._log.debug(f"Got player update state with data {state}") if self._log:
self._log.debug(f"Got player update state with data {state}")
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None: async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None:
if {"sessionId", "event"} != self._voice_state.keys(): if {"sessionId", "event"} != self._voice_state.keys():
@ -298,7 +320,10 @@ class Player(VoiceProtocol):
data={"voice": data}, data={"voice": data},
) )
self._log.debug(f"Dispatched voice update to {state['event']['endpoint']} with data {data}") if self._log:
self._log.debug(
f"Dispatched voice update to {state['event']['endpoint']} with data {data}",
)
async def on_voice_server_update(self, data: VoiceServerUpdate) -> None: async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
self._voice_state.update({"event": data}) self._voice_state.update({"event": data})
@ -314,6 +339,10 @@ class Player(VoiceProtocol):
return return
channel = self.guild.get_channel(int(channel_id)) channel = self.guild.get_channel(int(channel_id))
if self.channel != channel:
self.channel = channel
if not channel: if not channel:
await self.disconnect() await self.disconnect()
self._voice_state.clear() self._voice_state.clear()
@ -328,7 +357,7 @@ class Player(VoiceProtocol):
event_type: str = data["type"] event_type: str = data["type"]
event: PomiceEvent = getattr(events, event_type)(data, self) event: PomiceEvent = getattr(events, event_type)(data, self)
if isinstance(event, TrackEndEvent) and event.reason != "REPLACED": if isinstance(event, TrackEndEvent) and event.reason not in ("REPLACED", "replaced"):
self._current = None self._current = None
event.dispatch(self._bot) event.dispatch(self._bot)
@ -336,7 +365,8 @@ class Player(VoiceProtocol):
if isinstance(event, TrackStartEvent): if isinstance(event, TrackStartEvent):
self._ending_track = self._current self._ending_track = self._current
self._log.debug(f"Dispatched event {data['type']} to player.") if self._log:
self._log.debug(f"Dispatched event {data['type']} to player.")
async def _refresh_endpoint_uri(self, session_id: Optional[str]) -> None: async def _refresh_endpoint_uri(self, session_id: Optional[str]) -> None:
self._player_endpoint_uri = f"sessions/{session_id}/players" self._player_endpoint_uri = f"sessions/{session_id}/players"
@ -358,14 +388,15 @@ class Player(VoiceProtocol):
data=data or None, data=data or None,
) )
self._log.debug(f"Swapped all players to new node {new_node._identifier}.") if self._log:
self._log.debug(f"Swapped all players to new node {new_node._identifier}.")
async def get_tracks( async def get_tracks(
self, self,
query: str, query: str,
*, *,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch, search_type: SearchType | None = SearchType.ytsearch,
filters: Optional[List[Filter]] = None, filters: Optional[List[Filter]] = None,
) -> Optional[Union[List[Track], Playlist]]: ) -> Optional[Union[List[Track], Playlist]]:
"""Fetches tracks from the node's REST api to parse into Lavalink. """Fetches tracks from the node's REST api to parse into Lavalink.
@ -382,8 +413,21 @@ class Player(VoiceProtocol):
""" """
return await self._node.get_tracks(query, ctx=ctx, search_type=search_type, filters=filters) return await self._node.get_tracks(query, ctx=ctx, search_type=search_type, filters=filters)
async def build_track(self, identifier: str, ctx: Optional[commands.Context] = None) -> Track:
"""
Builds a track using a valid track identifier
You can also pass in a discord.py Context object to get a
Context object on the track it builds.
"""
return await self._node.build_track(identifier, ctx=ctx)
async def get_recommendations( async def get_recommendations(
self, *, track: Track, ctx: Optional[commands.Context] = None self,
*,
track: Track,
ctx: Optional[commands.Context] = None,
) -> Optional[Union[List[Track], Playlist]]: ) -> Optional[Union[List[Track], Playlist]]:
""" """
Gets recommendations from either YouTube or Spotify. Gets recommendations from either YouTube or Spotify.
@ -393,7 +437,12 @@ class Player(VoiceProtocol):
return await self._node.get_recommendations(track=track, ctx=ctx) return await self._node.get_recommendations(track=track, ctx=ctx)
async def connect( async def connect(
self, *, timeout: float, reconnect: bool, self_deaf: bool = False, self_mute: bool = False self,
*,
timeout: float,
reconnect: bool,
self_deaf: bool = False,
self_mute: bool = False,
) -> None: ) -> None:
await self.guild.change_voice_state( await self.guild.change_voice_state(
channel=self.channel, channel=self.channel,
@ -413,7 +462,8 @@ class Player(VoiceProtocol):
data={"encodedTrack": None}, data={"encodedTrack": None},
) )
self._log.debug(f"Player has been stopped.") if self._log:
self._log.debug(f"Player has been stopped.")
async def disconnect(self, *, force: bool = False) -> None: async def disconnect(self, *, force: bool = False) -> None:
"""Disconnects the player from voice.""" """Disconnects the player from voice."""
@ -431,24 +481,34 @@ class Player(VoiceProtocol):
except AttributeError: except AttributeError:
# 'NoneType' has no attribute '_get_voice_client_key' raised by self.cleanup() -> # 'NoneType' has no attribute '_get_voice_client_key' raised by self.cleanup() ->
# assume we're already disconnected and cleaned up # assume we're already disconnected and cleaned up
assert not self.is_connected and not self.channel assert self.channel is None and not self.is_connected
self._node._players.pop(self.guild.id) self._node._players.pop(self.guild.id)
await self._node.send( if self.node.is_connected:
method="DELETE", await self._node.send(
path=self._player_endpoint_uri, method="DELETE",
guild_id=self._guild.id, path=self._player_endpoint_uri,
) guild_id=self._guild.id,
)
self._log.debug("Player has been destroyed.") if self._log:
self._log.debug("Player has been destroyed.")
async def play( async def play(
self, track: Track, *, start: int = 0, end: int = 0, ignore_if_playing: bool = False self,
track: Track,
*,
start: int = 0,
end: int = 0,
ignore_if_playing: bool = False,
) -> Track: ) -> Track:
"""Plays a track. If a Spotify track is passed in, it will be handled accordingly.""" """Plays a track. If a Spotify track is passed in, it will be handled accordingly."""
if not track._search_type:
track.original = track
# Make sure we've never searched the track before # Make sure we've never searched the track before
if track.original is None: if track._search_type and track.original is None:
# First lets try using the tracks ISRC, every track has one (hopefully) # First lets try using the tracks ISRC, every track has one (hopefully)
try: try:
if not track.isrc: if not track.isrc:
@ -528,9 +588,10 @@ class Player(VoiceProtocol):
query=f"noReplace={ignore_if_playing}", query=f"noReplace={ignore_if_playing}",
) )
self._log.debug( if self._log:
f"Playing {track.title} from uri {track.uri} with a length of {track.length}", self._log.debug(
) f"Playing {track.title} from uri {track.uri} with a length of {track.length}",
)
return self._current return self._current
@ -551,7 +612,8 @@ class Player(VoiceProtocol):
data={"position": position}, data={"position": position},
) )
self._log.debug(f"Seeking to {position}.") if self._log:
self._log.debug(f"Seeking to {position}.")
return self.position return self.position
async def set_pause(self, pause: bool) -> bool: async def set_pause(self, pause: bool) -> bool:
@ -564,7 +626,8 @@ class Player(VoiceProtocol):
) )
self._paused = pause self._paused = pause
self._log.debug(f"Player has been {'paused' if pause else 'resumed'}.") if self._log:
self._log.debug(f"Player has been {'paused' if pause else 'resumed'}.")
return self._paused return self._paused
async def set_volume(self, volume: int) -> int: async def set_volume(self, volume: int) -> int:
@ -577,7 +640,8 @@ class Player(VoiceProtocol):
) )
self._volume = volume self._volume = volume
self._log.debug(f"Player volume has been adjusted to {volume}") if self._log:
self._log.debug(f"Player volume has been adjusted to {volume}")
return self._volume return self._volume
async def move_to(self, channel: VoiceChannel) -> None: async def move_to(self, channel: VoiceChannel) -> None:
@ -606,9 +670,11 @@ class Player(VoiceProtocol):
data={"filters": payload}, data={"filters": payload},
) )
self._log.debug(f"Filter has been applied to player with tag {_filter.tag}") if self._log:
self._log.debug(f"Filter has been applied to player with tag {_filter.tag}")
if fast_apply: if fast_apply:
self._log.debug(f"Fast apply passed, now applying filter instantly.") if self._log:
self._log.debug(f"Fast apply passed, now applying filter instantly.")
await self.seek(self.position) await self.seek(self.position)
return self._filters return self._filters
@ -629,9 +695,44 @@ class Player(VoiceProtocol):
guild_id=self._guild.id, guild_id=self._guild.id,
data={"filters": payload}, data={"filters": payload},
) )
self._log.debug(f"Filter has been removed from player with tag {filter_tag}") if self._log:
self._log.debug(f"Filter has been removed from player with tag {filter_tag}")
if fast_apply: if fast_apply:
self._log.debug(f"Fast apply passed, now removing filter instantly.") if self._log:
self._log.debug(f"Fast apply passed, now removing filter instantly.")
await self.seek(self.position)
return self._filters
async def edit_filter(
self,
*,
filter_tag: str,
edited_filter: Filter,
fast_apply: bool = False,
) -> Filters:
"""Edits a filter from the player using its filter tag and a new filter of the same type.
The filter to be replaced must have the same tag as the one you are replacing it with.
This will only work if you are using a version of Lavalink that supports filters.
If you would like for the filter to apply instantly, set the `fast_apply` arg to `True`.
(You must have a song playing in order for `fast_apply` to work.)
"""
self._filters.edit_filter(filter_tag=filter_tag, to_apply=edited_filter)
payload = self._filters.get_all_payloads()
await self._node.send(
method="PATCH",
path=self._player_endpoint_uri,
guild_id=self._guild.id,
data={"filters": payload},
)
if self._log:
self._log.debug(f"Filter with tag {filter_tag} has been edited to {edited_filter!r}")
if fast_apply:
if self._log:
self._log.debug(f"Fast apply passed, now editing filter instantly.")
await self.seek(self.position) await self.seek(self.position)
return self._filters return self._filters
@ -655,8 +756,10 @@ class Player(VoiceProtocol):
guild_id=self._guild.id, guild_id=self._guild.id,
data={"filters": {}}, data={"filters": {}},
) )
self._log.debug(f"All filters have been removed from player.") if self._log:
self._log.debug(f"All filters have been removed from player.")
if fast_apply: if fast_apply:
self._log.debug(f"Fast apply passed, now removing all filters instantly.") if self._log:
self._log.debug(f"Fast apply passed, now removing all filters instantly.")
await self.seek(self.position) await self.seek(self.position)

View File

@ -17,16 +17,24 @@ from typing import Union
from urllib.parse import quote from urllib.parse import quote
import aiohttp import aiohttp
import orjson as json
from discord import Client from discord import Client
from discord.ext import commands from discord.ext import commands
from discord.utils import MISSING from discord.utils import MISSING
try:
from websockets.legacy import client # websockets >= 10.0
except ImportError:
import websockets.client as client # websockets < 10.0 # type: ignore
from websockets import exceptions
from websockets import typing as wstype
from . import __version__ from . import __version__
from . import applemusic from . import applemusic
from . import spotify from . import spotify
from .enums import * from .enums import *
from .enums import LogLevel from .enums import LogLevel
from .exceptions import AppleMusicNotEnabled
from .exceptions import InvalidSpotifyClientAuthorization from .exceptions import InvalidSpotifyClientAuthorization
from .exceptions import LavalinkVersionIncompatible from .exceptions import LavalinkVersionIncompatible
from .exceptions import NodeConnectionFailure from .exceptions import NodeConnectionFailure
@ -71,6 +79,8 @@ class Node:
"_password", "_password",
"_identifier", "_identifier",
"_heartbeat", "_heartbeat",
"_resume_key",
"_resume_timeout",
"_secure", "_secure",
"_fallback", "_fallback",
"_log_level", "_log_level",
@ -91,7 +101,6 @@ class Node:
"_apple_music_client", "_apple_music_client",
"_route_planner", "_route_planner",
"_log", "_log",
"_log_handler",
"_stats", "_stats",
"available", "available",
) )
@ -106,15 +115,16 @@ class Node:
password: str, password: str,
identifier: str, identifier: str,
secure: bool = False, secure: bool = False,
heartbeat: int = 30, heartbeat: int = 120,
resume_key: Optional[str] = None,
resume_timeout: int = 60,
loop: Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None,
session: Optional[aiohttp.ClientSession] = None, session: Optional[aiohttp.ClientSession] = None,
spotify_client_id: Optional[str] = None, spotify_client_id: Optional[str] = None,
spotify_client_secret: Optional[str] = None, spotify_client_secret: Optional[str] = None,
apple_music: bool = False, apple_music: bool = False,
fallback: bool = False, fallback: bool = False,
log_level: LogLevel = LogLevel.INFO, logger: Optional[logging.Logger] = None,
log_handler: Optional[logging.Handler] = None,
): ):
if not isinstance(port, int): if not isinstance(port, int):
raise TypeError("Port must be an integer") raise TypeError("Port must be an integer")
@ -126,17 +136,17 @@ class Node:
self._password: str = password self._password: str = password
self._identifier: str = identifier self._identifier: str = identifier
self._heartbeat: int = heartbeat self._heartbeat: int = heartbeat
self._resume_key: Optional[str] = resume_key
self._resume_timeout: int = resume_timeout
self._secure: bool = secure self._secure: bool = secure
self._fallback: bool = fallback self._fallback: bool = fallback
self._log_level: LogLevel = log_level
self._log_handler = log_handler
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}" self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}" self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
self._session: aiohttp.ClientSession = session # type: ignore self._session: aiohttp.ClientSession = session # type: ignore
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
self._websocket: aiohttp.ClientWebSocketResponse self._websocket: client.WebSocketClientProtocol
self._task: asyncio.Task = None # type: ignore self._task: asyncio.Task = None # type: ignore
self._session_id: Optional[str] = None self._session_id: Optional[str] = None
@ -144,7 +154,7 @@ class Node:
self._version: LavalinkVersion = LavalinkVersion(0, 0, 0) self._version: LavalinkVersion = LavalinkVersion(0, 0, 0)
self._route_planner = RoutePlanner(self) self._route_planner = RoutePlanner(self)
self._log = self._setup_logging(self._log_level) self._log = logger
if not self._bot.user: if not self._bot.user:
raise NodeCreationError("Bot user is not ready yet.") raise NodeCreationError("Bot user is not ready yet.")
@ -205,7 +215,7 @@ class Node:
@property @property
def player_count(self) -> int: def player_count(self) -> int:
"""Property which returns how many players are connected to this node""" """Property which returns how many players are connected to this node"""
return len(self.players) return len(self.players.values())
@property @property
def pool(self) -> Type[NodePool]: def pool(self) -> Type[NodePool]:
@ -222,29 +232,6 @@ class Node:
"""Alias for `Node.latency`, returns the latency of the node""" """Alias for `Node.latency`, returns the latency of the node"""
return self.latency return self.latency
def _setup_logging(self, level: LogLevel) -> logging.Logger:
logger = logging.getLogger("pomice")
logger.setLevel(level)
handler = None
if self._log_handler:
handler = self._log_handler
else:
handler = logging.StreamHandler()
dt_fmt = "%Y-%m-%d %H:%M:%S"
formatter = logging.Formatter(
"[{asctime}] [{levelname:<8}] {name}: {message}",
dt_fmt,
style="{",
)
handler.setFormatter(formatter)
if handler:
logger.handlers.clear()
logger.addHandler(handler)
return logger
async def _handle_version_check(self, version: str) -> None: async def _handle_version_check(self, version: str) -> None:
if version.endswith("-SNAPSHOT"): if version.endswith("-SNAPSHOT"):
# we're just gonna assume all snapshot versions correlate with v4 # we're just gonna assume all snapshot versions correlate with v4
@ -266,7 +253,8 @@ class Node:
int(_version_groups[2] or 0), int(_version_groups[2] or 0),
) )
self._log.debug(f"Parsed Lavalink version: {major}.{minor}.{fix}") if self._log:
self._log.debug(f"Parsed Lavalink version: {major}.{minor}.{fix}")
self._version = LavalinkVersion(major=major, minor=minor, fix=fix) self._version = LavalinkVersion(major=major, minor=minor, fix=fix)
if self._version < LavalinkVersion(3, 7, 0): if self._version < LavalinkVersion(3, 7, 0):
self._available = False self._available = False
@ -316,25 +304,59 @@ class Node:
await self.disconnect() await self.disconnect()
async def _listen(self) -> None: async def _configure_resuming(self) -> None:
backoff = ExponentialBackoff(base=7) if not self._resume_key:
return
data: Dict[str, Union[int, str, bool]] = {"timeout": self._resume_timeout}
if self._version.major == 3:
data["resumingKey"] = self._resume_key
elif self._version.major == 4:
if self._log:
self._log.warning("Using a resume key with Lavalink v4 is deprecated.")
data["resuming"] = True
await self.send(
method="PATCH",
path=f"sessions/{self._session_id}",
include_version=True,
data=data,
)
async def _listen(self) -> None:
while True: while True:
msg = await self._websocket.receive() try:
if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING): msg = await self._websocket.recv()
data = json.loads(msg)
if self._log:
self._log.debug(f"Recieved raw websocket message {msg}")
self._loop.create_task(self._handle_ws_msg(data=data))
except exceptions.ConnectionClosed:
if self.player_count > 0:
for _player in self.players.values():
self._loop.create_task(_player.destroy())
if self._fallback: if self._fallback:
await self._handle_node_switch() self._loop.create_task(self._handle_node_switch())
self._loop.create_task(self._websocket.close())
backoff = ExponentialBackoff(base=7)
retry = backoff.delay() retry = backoff.delay()
if self._log:
self._log.debug(
f"Retrying connection to Node {self._identifier} in {retry} secs",
)
await asyncio.sleep(retry) await asyncio.sleep(retry)
if not self.is_connected: if not self.is_connected:
self._loop.create_task(self.connect(reconnect=True)) self._loop.create_task(self.connect(reconnect=True))
else:
self._loop.create_task(self._handle_payload(msg.json()))
async def _handle_payload(self, data: dict) -> None: async def _handle_ws_msg(self, data: dict) -> None:
if self._log:
self._log.debug(f"Recieved raw payload from Node {self._identifier} with data {data}")
op = data.get("op", None) op = data.get("op", None)
if not op:
return
if op == "stats": if op == "stats":
self._stats = NodeStats(data) self._stats = NodeStats(data)
@ -342,21 +364,20 @@ class Node:
if op == "ready": if op == "ready":
self._session_id = data["sessionId"] self._session_id = data["sessionId"]
await self._configure_resuming()
if not "guildId" in data: if not "guildId" in data:
return return
player = self._players.get(int(data["guildId"])) player: Optional[Player] = self._players.get(int(data["guildId"]))
if not player: if not player:
return return
if op == "event": if op == "event":
await player._dispatch_event(data) return await player._dispatch_event(data)
return
if op == "playerUpdate": if op == "playerUpdate":
await player._update_state(data) return await player._update_state(data)
return
async def send( async def send(
self, self,
@ -387,9 +408,10 @@ class Node:
headers=self._headers, headers=self._headers,
json=data or {}, json=data or {},
) )
self._log.debug( if self._log:
f"Making REST request to Node {self._identifier} with method {method} to {uri}", self._log.debug(
) f"Making REST request to Node {self._identifier} with method {method} to {uri}",
)
if resp.status >= 300: if resp.status >= 300:
resp_data: dict = await resp.json() resp_data: dict = await resp.json()
raise NodeRestException( raise NodeRestException(
@ -397,34 +419,47 @@ class Node:
) )
if method == "DELETE" or resp.status == 204: if method == "DELETE" or resp.status == 204:
self._log.debug( if self._log:
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned no data.", self._log.debug(
) f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned no data.",
)
return await resp.json(content_type=None) return await resp.json(content_type=None)
if resp.content_type == "text/plain": if resp.content_type == "text/plain":
self._log.debug( if self._log:
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned text with body {await resp.text()}", self._log.debug(
) f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned text with body {await resp.text()}",
)
return await resp.text() return await resp.text()
self._log.debug( if self._log:
f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned JSON with body {await resp.json()}", self._log.debug(
) f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned JSON with body {await resp.json()}",
)
return await resp.json() return await resp.json()
def get_player(self, guild_id: int) -> Optional[Player]: def get_player(self, guild_id: int) -> Optional[Player]:
"""Takes a guild ID as a parameter. Returns a pomice Player object or None.""" """Takes a guild ID as a parameter. Returns a pomice Player object or None."""
return self._players.get(guild_id, None) return self._players.get(guild_id, None)
async def connect(self, *, reconnect: bool = False) -> "Node": async def connect(self, *, reconnect: bool = False) -> Node:
"""Initiates a connection with a Lavalink node and adds it to the node pool.""" """Initiates a connection with a Lavalink node and adds it to the node pool."""
await self._bot.wait_until_ready() await self._bot.wait_until_ready()
start = time.perf_counter() start = time.perf_counter()
if not self._session: if not self._session:
self._session = aiohttp.ClientSession() # Configure connection pooling for optimal concurrent request performance
connector = aiohttp.TCPConnector(
limit=100, # Total connection limit
limit_per_host=30, # Per-host connection limit
ttl_dns_cache=300, # DNS cache TTL in seconds
)
timeout = aiohttp.ClientTimeout(total=30, connect=10)
self._session = aiohttp.ClientSession(
connector=connector,
timeout=timeout,
)
try: try:
if not reconnect: if not reconnect:
@ -438,23 +473,28 @@ class Node:
await self._handle_version_check(version=version) await self._handle_version_check(version=version)
await self._set_ext_client_session(session=self._session) await self._set_ext_client_session(session=self._session)
self._log.debug( if self._log:
f"Version check from Node {self._identifier} successful. Returned version {version}", self._log.debug(
) f"Version check from Node {self._identifier} successful. Returned version {version}",
)
self._websocket = await self._session.ws_connect( self._websocket = await client.connect( # type: ignore
f"{self._websocket_uri}/v{self._version.major}/websocket", f"{self._websocket_uri}/v{self._version.major}/websocket",
headers=self._headers, extra_headers=self._headers,
heartbeat=self._heartbeat, ping_interval=self._heartbeat,
) )
if reconnect: if reconnect:
for player in self.players.values(): if self._log:
await player._refresh_endpoint_uri(self._session_id) self._log.debug(f"Trying to reconnect to Node {self._identifier}...")
if self.player_count:
for player in self.players.values():
await player._refresh_endpoint_uri(self._session_id)
self._log.debug( if self._log:
f"Node {self._identifier} successfully connected to websocket using {self._websocket_uri}/v{self._version.major}/websocket", self._log.debug(
) f"Node {self._identifier} successfully connected to websocket using {self._websocket_uri}/v{self._version.major}/websocket",
)
if not self._task: if not self._task:
self._task = self._loop.create_task(self._listen()) self._task = self._loop.create_task(self._listen())
@ -463,18 +503,19 @@ class Node:
end = time.perf_counter() end = time.perf_counter()
self._log.info(f"Connected to node {self._identifier}. Took {end - start:.3f}s") if self._log:
self._log.info(f"Connected to node {self._identifier}. Took {end - start:.3f}s")
return self return self
except (aiohttp.ClientConnectorError, ConnectionRefusedError): except (aiohttp.ClientConnectorError, OSError, ConnectionRefusedError):
raise NodeConnectionFailure( raise NodeConnectionFailure(
f"The connection to node '{self._identifier}' failed.", f"The connection to node '{self._identifier}' failed.",
) from None ) from None
except aiohttp.WSServerHandshakeError: except exceptions.InvalidHandshake:
raise NodeConnectionFailure( raise NodeConnectionFailure(
f"The password for node '{self._identifier}' is invalid.", f"The password for node '{self._identifier}' is invalid.",
) from None ) from None
except aiohttp.InvalidURL: except exceptions.InvalidURI:
raise NodeConnectionFailure( raise NodeConnectionFailure(
f"The URI for node '{self._identifier}' is invalid.", f"The URI for node '{self._identifier}' is invalid.",
) from None ) from None
@ -488,20 +529,23 @@ class Node:
for player in self.players.copy().values(): for player in self.players.copy().values():
await player.destroy() await player.destroy()
self._log.debug("All players disconnected from node.") if self._log:
self._log.debug("All players disconnected from node.")
await self._websocket.close() await self._websocket.close()
await self._session.close() await self._session.close()
self._log.debug("Websocket and http session closed.") if self._log:
self._log.debug("Websocket and http session closed.")
del self._pool._nodes[self._identifier] del self._pool._nodes[self._identifier]
self.available = False self.available = False
self._task.cancel() self._task.cancel()
end = time.perf_counter() end = time.perf_counter()
self._log.info( if self._log:
f"Successfully disconnected from node {self._identifier} and closed all sessions. Took {end - start:.3f}s", self._log.info(
) f"Successfully disconnected from node {self._identifier} and closed all sessions. Took {end - start:.3f}s",
)
async def build_track(self, identifier: str, ctx: Optional[commands.Context] = None) -> Track: async def build_track(self, identifier: str, ctx: Optional[commands.Context] = None) -> Track:
""" """
@ -514,13 +558,16 @@ class Node:
data: dict = await self.send( data: dict = await self.send(
method="GET", method="GET",
path="decodetrack", path="decodetrack",
query=f"encodedTrack={identifier}", query=f"encodedTrack={quote(identifier)}",
) )
track_info = data["info"] if self._version.major >= 4 else data
return Track( return Track(
track_id=identifier, track_id=identifier,
ctx=ctx, ctx=ctx,
info=data, info=track_info,
track_type=TrackType(data["sourceName"]), track_type=TrackType(track_info["sourceName"]),
) )
async def get_tracks( async def get_tracks(
@ -528,7 +575,7 @@ class Node:
query: str, query: str,
*, *,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch, search_type: Optional[SearchType] = SearchType.ytsearch,
filters: Optional[List[Filter]] = None, filters: Optional[List[Filter]] = None,
) -> Optional[Union[Playlist, List[Track]]]: ) -> Optional[Union[Playlist, List[Track]]]:
"""Fetches tracks from the node's REST api to parse into Lavalink. """Fetches tracks from the node's REST api to parse into Lavalink.
@ -549,13 +596,13 @@ class Node:
for filter in filters: for filter in filters:
filter.set_preload() filter.set_preload()
if URLRegex.AM_URL.match(query): # Due to the inclusion of plugins in the v4 update
if not self._apple_music_client: # we are doing away with raising an error if pomice detects
raise AppleMusicNotEnabled( # either a Spotify or Apple Music URL and the respective client
"You must have Apple Music functionality enabled in order to play Apple Music tracks." # is not enabled. Instead, we will just only parse the URL
"Please set apple_music to True in your Node class.", # if the client is enabled and the URL is valid.
)
if self._apple_music_client and URLRegex.AM_URL.match(query):
apple_music_results = await self._apple_music_client.search(query=query) apple_music_results = await self._apple_music_client.search(query=query)
if isinstance(apple_music_results, applemusic.Song): if isinstance(apple_music_results, applemusic.Song):
return [ return [
@ -563,7 +610,7 @@ class Node:
track_id=apple_music_results.id, track_id=apple_music_results.id,
ctx=ctx, ctx=ctx,
track_type=TrackType.APPLE_MUSIC, track_type=TrackType.APPLE_MUSIC,
search_type=search_type, search_type=search_type or SearchType.ytsearch,
filters=filters, filters=filters,
info={ info={
"title": apple_music_results.name, "title": apple_music_results.name,
@ -585,7 +632,7 @@ class Node:
track_id=track.id, track_id=track.id,
ctx=ctx, ctx=ctx,
track_type=TrackType.APPLE_MUSIC, track_type=TrackType.APPLE_MUSIC,
search_type=search_type, search_type=search_type or SearchType.ytsearch,
filters=filters, filters=filters,
info={ info={
"title": track.name, "title": track.name,
@ -614,14 +661,7 @@ class Node:
uri=apple_music_results.url, uri=apple_music_results.url,
) )
elif URLRegex.SPOTIFY_URL.match(query): elif self._spotify_client and URLRegex.SPOTIFY_URL.match(query):
if not self._spotify_client_id and not self._spotify_client_secret:
raise InvalidSpotifyClientAuthorization(
"You did not provide proper Spotify client authorization credentials. "
"If you would like to use the Spotify searching feature, "
"please obtain Spotify API credentials here: https://developer.spotify.com/",
)
spotify_results = await self._spotify_client.search(query=query) # type: ignore spotify_results = await self._spotify_client.search(query=query) # type: ignore
if isinstance(spotify_results, spotify.Track): if isinstance(spotify_results, spotify.Track):
@ -630,7 +670,7 @@ class Node:
track_id=spotify_results.id, track_id=spotify_results.id,
ctx=ctx, ctx=ctx,
track_type=TrackType.SPOTIFY, track_type=TrackType.SPOTIFY,
search_type=search_type, search_type=search_type or SearchType.ytsearch,
filters=filters, filters=filters,
info={ info={
"title": spotify_results.name, "title": spotify_results.name,
@ -652,7 +692,7 @@ class Node:
track_id=track.id, track_id=track.id,
ctx=ctx, ctx=ctx,
track_type=TrackType.SPOTIFY, track_type=TrackType.SPOTIFY,
search_type=search_type, search_type=search_type or SearchType.ytsearch,
filters=filters, filters=filters,
info={ info={
"title": track.name, "title": track.name,
@ -681,63 +721,14 @@ class Node:
uri=spotify_results.uri, uri=spotify_results.uri,
) )
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
data: dict = await self.send(
method="GET",
path="loadtracks",
query=f"identifier={quote(query)}",
)
track: dict = data["tracks"][0]
info: dict = track["info"]
return [
Track(
track_id=track["track"],
info={
"title": discord_url.group("file"),
"author": "Unknown",
"length": info["length"],
"uri": info["uri"],
"position": info["position"],
"identifier": info["identifier"],
},
ctx=ctx,
track_type=TrackType.HTTP,
filters=filters,
),
]
elif path.exists(path.dirname(query)):
local_file = Path(query)
data: dict = await self.send( # type: ignore
method="GET",
path="loadtracks",
query=f"identifier={quote(query)}",
)
track: dict = data["tracks"][0] # type: ignore
info: dict = track["info"] # type: ignore
return [
Track(
track_id=track["track"],
info={
"title": local_file.name,
"author": "Unknown",
"length": info["length"],
"uri": quote(local_file.as_uri()),
"position": info["position"],
"identifier": info["identifier"],
},
ctx=ctx,
track_type=TrackType.LOCAL,
filters=filters,
),
]
else: else:
if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query): if (
search_type
and not URLRegex.BASE_URL.match(query)
and not re.match(r"(?:[a-z]+?)search:.", query)
and not URLRegex.DISCORD_MP3_URL.match(query)
and not path.exists(path.dirname(query))
):
query = f"{search_type}:{query}" query = f"{search_type}:{query}"
# If YouTube url contains a timestamp, capture it for use later. # If YouTube url contains a timestamp, capture it for use later.
@ -753,21 +744,31 @@ class Node:
load_type = data.get("loadType") load_type = data.get("loadType")
# Lavalink v4 changed the name of the key from "tracks" to "data"
# so lets account for that
data_type = "data" if self._version.major >= 4 else "tracks"
if not load_type: if not load_type:
raise TrackLoadError( raise TrackLoadError(
"There was an error while trying to load this track.", "There was an error while trying to load this track.",
) )
elif load_type == "LOAD_FAILED": elif load_type in ("LOAD_FAILED", "error"):
exception = data["exception"] exception = data["data"] if self._version.major >= 4 else data["exception"]
raise TrackLoadError( raise TrackLoadError(
f"{exception['message']} [{exception['severity']}]", f"{exception['message']} [{exception['severity']}]",
) )
elif load_type == "NO_MATCHES": elif load_type in ("NO_MATCHES", "empty"):
return None return None
elif load_type == "PLAYLIST_LOADED": elif load_type in ("PLAYLIST_LOADED", "playlist"):
if self._version.major >= 4:
track_list = data[data_type]["tracks"]
playlist_info = data[data_type]["info"]
else:
track_list = data[data_type]
playlist_info = data["playlistInfo"]
tracks = [ tracks = [
Track( Track(
track_id=track["encoded"], track_id=track["encoded"],
@ -775,17 +776,60 @@ class Node:
ctx=ctx, ctx=ctx,
track_type=TrackType(track["info"]["sourceName"]), track_type=TrackType(track["info"]["sourceName"]),
) )
for track in data["tracks"] for track in track_list
] ]
return Playlist( return Playlist(
playlist_info=data["playlistInfo"], playlist_info=playlist_info,
tracks=tracks, tracks=tracks,
playlist_type=PlaylistType(tracks[0].track_type.value), playlist_type=PlaylistType(tracks[0].track_type.value),
thumbnail=tracks[0].thumbnail, thumbnail=tracks[0].thumbnail,
uri=query, uri=query,
) )
elif load_type == "SEARCH_RESULT" or load_type == "TRACK_LOADED": elif load_type in ("SEARCH_RESULT", "TRACK_LOADED", "track", "search"):
if self._version.major >= 4 and isinstance(data[data_type], dict):
data[data_type] = [data[data_type]]
if path.exists(path.dirname(query)):
local_file = Path(query)
return [
Track(
track_id=track["encoded"],
info={
"title": local_file.name,
"author": "Unknown",
"length": track["info"]["length"],
"uri": quote(local_file.as_uri()),
"position": track["info"]["position"],
"identifier": track["info"]["identifier"],
},
ctx=ctx,
track_type=TrackType.LOCAL,
filters=filters,
)
for track in data[data_type]
]
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
return [
Track(
track_id=track["encoded"],
info={
"title": discord_url.group("file"),
"author": "Unknown",
"length": track["info"]["length"],
"uri": track["info"]["uri"],
"position": track["info"]["position"],
"identifier": track["info"]["identifier"],
},
ctx=ctx,
track_type=TrackType.HTTP,
filters=filters,
)
for track in data[data_type]
]
return [ return [
Track( Track(
track_id=track["encoded"], track_id=track["encoded"],
@ -795,7 +839,7 @@ class Node:
filters=filters, filters=filters,
timestamp=timestamp, timestamp=timestamp,
) )
for track in data["tracks"] for track in data[data_type]
] ]
else: else:
@ -804,7 +848,10 @@ class Node:
) )
async def get_recommendations( async def get_recommendations(
self, *, track: Track, ctx: Optional[commands.Context] = None self,
*,
track: Track,
ctx: Optional[commands.Context] = None,
) -> Optional[Union[List[Track], Playlist]]: ) -> Optional[Union[List[Track], Playlist]]:
""" """
Gets recommendations from either YouTube or Spotify. Gets recommendations from either YouTube or Spotify.
@ -849,6 +896,57 @@ class Node:
"The specfied track must be either a YouTube or Spotify track to recieve recommendations.", "The specfied track must be either a YouTube or Spotify track to recieve recommendations.",
) )
async def search_spotify_recommendations(
self,
query: str,
*,
ctx: Optional[commands.Context] = None,
filters: Optional[List[Filter]] = None,
) -> Optional[Union[List[Track], Playlist]]:
"""
Searches for recommendations on Spotify and returns a list of tracks based on the query.
You must have Spotify enabled for this to work.
You can pass in a discord.py Context object to get a
Context object on all tracks that get recommended.
"""
if not self._spotify_client:
raise InvalidSpotifyClientAuthorization(
"You must have Spotify enabled to use this feature.",
)
results = await self._spotify_client.track_search(query=query) # type: ignore
if not results:
raise TrackLoadError(
"Unable to find any tracks based on the query.",
)
tracks = [
Track(
track_id=track.id,
ctx=ctx,
track_type=TrackType.SPOTIFY,
info={
"title": track.name,
"author": track.artists,
"length": track.length,
"identifier": track.id,
"uri": track.uri,
"isStream": False,
"isSeekable": True,
"position": 0,
"thumbnail": track.image,
"isrc": track.isrc,
},
requester=self.bot.user,
)
for track in results
]
track = tracks[0]
return await self.get_recommendations(track=track, ctx=ctx)
class NodePool: class NodePool:
"""The base class for the node pool. """The base class for the node pool.
@ -930,15 +1028,16 @@ class NodePool:
password: str, password: str,
identifier: str, identifier: str,
secure: bool = False, secure: bool = False,
heartbeat: int = 30, heartbeat: int = 120,
resume_key: Optional[str] = None,
resume_timeout: int = 60,
loop: Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None,
spotify_client_id: Optional[str] = None, spotify_client_id: Optional[str] = None,
spotify_client_secret: Optional[str] = None, spotify_client_secret: Optional[str] = None,
session: Optional[aiohttp.ClientSession] = None, session: Optional[aiohttp.ClientSession] = None,
apple_music: bool = False, apple_music: bool = False,
fallback: bool = False, fallback: bool = False,
log_level: LogLevel = LogLevel.INFO, logger: Optional[logging.Logger] = None,
log_handler: Optional[logging.Handler] = None,
) -> Node: ) -> Node:
"""Creates a Node object to be then added into the node pool. """Creates a Node object to be then added into the node pool.
For Spotify searching capabilites, pass in valid Spotify API credentials. For Spotify searching capabilites, pass in valid Spotify API credentials.
@ -957,14 +1056,15 @@ class NodePool:
identifier=identifier, identifier=identifier,
secure=secure, secure=secure,
heartbeat=heartbeat, heartbeat=heartbeat,
resume_key=resume_key,
resume_timeout=resume_timeout,
loop=loop, loop=loop,
spotify_client_id=spotify_client_id, spotify_client_id=spotify_client_id,
session=session, session=session,
spotify_client_secret=spotify_client_secret, spotify_client_secret=spotify_client_secret,
apple_music=apple_music, apple_music=apple_music,
fallback=fallback, fallback=fallback,
log_level=log_level, logger=logger,
log_handler=log_handler,
) )
await node.connect() await node.connect()

View File

@ -203,9 +203,13 @@ class Queue(Iterable[Track]):
raise QueueEmpty("No items in the queue.") raise QueueEmpty("No items in the queue.")
if self._loop_mode == LoopMode.QUEUE: if self._loop_mode == LoopMode.QUEUE:
# recurse if the item isnt in the queue # set current item to first track in queue if not set already
if self._current_item not in self._queue: # otherwise exception will be raised
self.get() if not self._current_item or self._current_item not in self._queue:
if self._queue:
item = self._queue[0]
else:
raise QueueEmpty("No items in the queue.")
# set current item to first track in queue if not set already # set current item to first track in queue if not set already
if not self._current_item: if not self._current_item:

View File

@ -1,13 +1,16 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import re import re
import time import time
from base64 import b64encode from base64 import b64encode
from typing import AsyncGenerator
from typing import Dict from typing import Dict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
from urllib.parse import quote
import aiohttp import aiohttp
import orjson as json import orjson as json
@ -21,8 +24,10 @@ __all__ = ("Client",)
GRANT_URL = "https://accounts.spotify.com/api/token" GRANT_URL = "https://accounts.spotify.com/api/token"
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}" REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
# Keep this in sync with URLRegex.SPOTIFY_URL (enums.py). Accept intl locale segment,
# optional trailing slash, and query parameters.
SPOTIFY_URL_REGEX = re.compile( SPOTIFY_URL_REGEX = re.compile(
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)", r"https?://open\.spotify\.com/(?:intl-[a-zA-Z-]+/)?(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)(?:/)?(?:\?.*)?$",
) )
@ -32,29 +37,39 @@ class Client:
for any Spotify URL you throw at it. for any Spotify URL you throw at it.
""" """
def __init__(self, client_id: str, client_secret: str) -> None: def __init__(
self._client_id: str = client_id self,
self._client_secret: str = client_secret client_id: str,
client_secret: str,
*,
playlist_concurrency: int = 10,
playlist_page_limit: Optional[int] = None,
) -> None:
self._client_id = client_id
self._client_secret = client_secret
self.session: aiohttp.ClientSession = None # type: ignore # HTTP session will be injected by Node
self.session: Optional[aiohttp.ClientSession] = None
self._bearer_token: Optional[str] = None self._bearer_token: Optional[str] = None
self._expiry: float = 0.0 self._expiry: float = 0.0
self._auth_token = b64encode( self._auth_token = b64encode(f"{self._client_id}:{self._client_secret}".encode())
f"{self._client_id}:{self._client_secret}".encode(), self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"}
)
self._grant_headers = {
"Authorization": f"Basic {self._auth_token.decode()}",
}
self._bearer_headers: Optional[Dict] = None self._bearer_headers: Optional[Dict] = None
self._log = logging.getLogger(__name__) self._log = logging.getLogger(__name__)
# Performance tuning knobs
self._playlist_concurrency = max(1, playlist_concurrency)
self._playlist_page_limit = playlist_page_limit
async def _set_session(self, session: aiohttp.ClientSession) -> None: async def _set_session(self, session: aiohttp.ClientSession) -> None:
self.session = session self.session = session
async def _fetch_bearer_token(self) -> None: async def _fetch_bearer_token(self) -> None:
_data = {"grant_type": "client_credentials"} _data = {"grant_type": "client_credentials"}
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.post(GRANT_URL, data=_data, headers=self._grant_headers) resp = await self.session.post(GRANT_URL, data=_data, headers=self._grant_headers)
if resp.status != 200: if resp.status != 200:
@ -63,7 +78,8 @@ class Client:
) )
data: dict = await resp.json(loads=json.loads) data: dict = await resp.json(loads=json.loads)
self._log.debug(f"Fetched Spotify bearer token successfully") if self._log:
self._log.debug(f"Fetched Spotify bearer token successfully")
self._bearer_token = data["access_token"] self._bearer_token = data["access_token"]
self._expiry = time.time() + (int(data["expires_in"]) - 10) self._expiry = time.time() + (int(data["expires_in"]) - 10)
@ -84,6 +100,8 @@ class Client:
request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id) request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id)
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers) resp = await self.session.get(request_url, headers=self._bearer_headers)
if resp.status != 200: if resp.status != 200:
raise SpotifyRequestException( raise SpotifyRequestException(
@ -91,15 +109,18 @@ class Client:
) )
data: dict = await resp.json(loads=json.loads) data: dict = await resp.json(loads=json.loads)
self._log.debug( if self._log:
f"Made request to Spotify API with status {resp.status} and response {data}", self._log.debug(
) f"Made request to Spotify API with status {resp.status} and response {data}",
)
if spotify_type == "track": if spotify_type == "track":
return Track(data) return Track(data)
elif spotify_type == "album": elif spotify_type == "album":
return Album(data) return Album(data)
elif spotify_type == "artist": elif spotify_type == "artist":
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get( resp = await self.session.get(
f"{request_url}/top-tracks?market=US", f"{request_url}/top-tracks?market=US",
headers=self._bearer_headers, headers=self._bearer_headers,
@ -113,37 +134,178 @@ class Client:
tracks = track_data["tracks"] tracks = track_data["tracks"]
return Artist(data, tracks) return Artist(data, tracks)
else: else:
# For playlists we optionally use a reduced fields payload to shrink response sizes.
# NB: We cannot apply fields filter to initial request because original metadata is needed.
tracks = [ tracks = [
Track(track["track"]) Track(track["track"])
for track in data["tracks"]["items"] for track in data["tracks"]["items"]
if track["track"] is not None if track["track"] is not None
] ]
if not len(tracks): if not tracks:
raise SpotifyRequestException( raise SpotifyRequestException(
"This playlist is empty and therefore cannot be queued.", "This playlist is empty and therefore cannot be queued.",
) )
next_page_url = data["tracks"]["next"] total_tracks = data["tracks"]["total"]
limit = data["tracks"]["limit"]
while next_page_url is not None: # Shortcircuit small playlists (single page)
resp = await self.session.get(next_page_url, headers=self._bearer_headers) if total_tracks <= limit:
if resp.status != 200: return Playlist(data, tracks)
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}", # Build remaining page URLs; Spotify supports offset-based pagination.
remaining_offsets = range(limit, total_tracks, limit)
page_urls: List[str] = []
fields_filter = (
"items(track(name,duration_ms,id,is_local,external_urls,external_ids,artists(name),album(images)))"
",next"
)
for idx, offset in enumerate(remaining_offsets):
if self._playlist_page_limit is not None and idx >= self._playlist_page_limit:
break
page_urls.append(
f"{request_url}/tracks?offset={offset}&limit={limit}&fields={quote(fields_filter)}",
)
if page_urls:
semaphore = asyncio.Semaphore(self._playlist_concurrency)
async def fetch_page(url: str) -> Optional[List[Track]]:
async with semaphore:
if not self.session:
raise SpotifyRequestException(
"HTTP session not initialized for Spotify client.",
)
resp = await self.session.get(url, headers=self._bearer_headers)
if resp.status != 200:
if self._log:
self._log.warning(
f"Page fetch failed {resp.status} {resp.reason} for {url}",
)
return None
page_json: dict = await resp.json(loads=json.loads)
return [
Track(item["track"])
for item in page_json.get("items", [])
if item.get("track") is not None
]
# Chunk gather in waves to avoid creating thousands of tasks at once
aggregated: List[Track] = []
wave_size = self._playlist_concurrency * 2
for i in range(0, len(page_urls), wave_size):
wave = page_urls[i : i + wave_size]
results = await asyncio.gather(
*[fetch_page(url) for url in wave],
return_exceptions=False,
) )
for result in results:
if result:
aggregated.extend(result)
next_data: dict = await resp.json(loads=json.loads) tracks.extend(aggregated)
tracks += [
Track(track["track"])
for track in next_data["items"]
if track["track"] is not None
]
next_page_url = next_data["next"]
return Playlist(data, tracks) return Playlist(data, tracks)
async def iter_playlist_tracks(
self,
*,
query: str,
batch_size: int = 100,
) -> AsyncGenerator[List[Track], None]:
"""Stream playlist tracks in batches without waiting for full materialization.
Parameters
----------
query: str
Spotify playlist URL.
batch_size: int
Number of tracks yielded per batch (logical grouping after fetch). Does not alter API page size.
"""
if not self._bearer_token or time.time() >= self._expiry:
await self._fetch_bearer_token()
match = SPOTIFY_URL_REGEX.match(query)
if not match or match.group("type") != "playlist":
raise InvalidSpotifyURL("Provided query is not a valid Spotify playlist URL.")
playlist_id = match.group("id")
request_url = REQUEST_URL.format(type="playlist", id=playlist_id)
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers)
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
# Yield first page immediately
first_page_tracks = [
Track(item["track"])
for item in data["tracks"]["items"]
if item.get("track") is not None
]
# Batch yield
for i in range(0, len(first_page_tracks), batch_size):
yield first_page_tracks[i : i + batch_size]
total = data["tracks"]["total"]
limit = data["tracks"]["limit"]
remaining_offsets = range(limit, total, limit)
fields_filter = (
"items(track(name,duration_ms,id,is_local,external_urls,external_ids,artists(name),album(images)))"
",next"
)
semaphore = asyncio.Semaphore(self._playlist_concurrency)
async def fetch(offset: int) -> List[Track]:
url = (
f"{request_url}/tracks?offset={offset}&limit={limit}&fields={quote(fields_filter)}"
)
async with semaphore:
if not self.session:
raise SpotifyRequestException(
"HTTP session not initialized for Spotify client.",
)
r = await self.session.get(url, headers=self._bearer_headers)
if r.status != 200:
if self._log:
self._log.warning(
f"Skipping page offset={offset} due to {r.status} {r.reason}",
)
return []
pj: dict = await r.json(loads=json.loads)
return [
Track(item["track"])
for item in pj.get("items", [])
if item.get("track") is not None
]
# Fetch pages in rolling waves; yield promptly as soon as a wave completes.
wave_size = self._playlist_concurrency * 2
for i, offset in enumerate(remaining_offsets):
# Build wave
if i % wave_size == 0:
wave_offsets = list(
o for o in remaining_offsets if o >= offset and o < offset + wave_size
)
results = await asyncio.gather(*[fetch(o) for o in wave_offsets])
for page_tracks in results:
if not page_tracks:
continue
for j in range(0, len(page_tracks), batch_size):
yield page_tracks[j : j + batch_size]
# Skip ahead in iterator by adjusting enumerate drive (consume extras)
# Fast-forward the generator manually
for _ in range(len(wave_offsets) - 1):
try:
next(remaining_offsets) # type: ignore
except StopIteration:
break
async def get_recommendations(self, *, query: str) -> List[Track]: async def get_recommendations(self, *, query: str) -> List[Track]:
if not self._bearer_token or time.time() >= self._expiry: if not self._bearer_token or time.time() >= self._expiry:
await self._fetch_bearer_token() await self._fetch_bearer_token()
@ -165,6 +327,8 @@ class Client:
id=f"?seed_tracks={spotify_id}", id=f"?seed_tracks={spotify_id}",
) )
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers) resp = await self.session.get(request_url, headers=self._bearer_headers)
if resp.status != 200: if resp.status != 200:
raise SpotifyRequestException( raise SpotifyRequestException(
@ -175,3 +339,22 @@ class Client:
tracks = [Track(track) for track in data["tracks"]] tracks = [Track(track) for track in data["tracks"]]
return tracks return tracks
async def track_search(self, *, query: str) -> List[Track]:
if not self._bearer_token or time.time() >= self._expiry:
await self._fetch_bearer_token()
request_url = f"https://api.spotify.com/v1/search?q={quote(query)}&type=track"
if not self.session:
raise SpotifyRequestException("HTTP session not initialized for Spotify client.")
resp = await self.session.get(request_url, headers=self._bearer_headers)
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
tracks = [Track(track) for track in data["tracks"]["items"]]
return tracks

View File

@ -4,7 +4,7 @@ import re
import setuptools import setuptools
version = "" version = ""
requirements = ["aiohttp>=3.7.4,<4", "orjson"] requirements = ["aiohttp>=3.7.4,<4", "orjson", "websockets"]
with open("pomice/__init__.py") as f: with open("pomice/__init__.py") as f:
version = re.search( version = re.search(
r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',