feat: add typing; add makefile; add pipfile
This commit is contained in:
parent
481e616414
commit
987de07fc5
|
|
@ -7,3 +7,7 @@ docs/_build/
|
|||
build/
|
||||
.gitpod.yml
|
||||
.python-verson
|
||||
Pipfile.lock
|
||||
.mypy_cache/
|
||||
.vscode/
|
||||
.venv/
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
# See https://pre-commit.com for more information
|
||||
# See https://pre-commit.com/hooks.html for more hooks
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: check-ast
|
||||
- id: check-builtin-literals
|
||||
- id: debug-statements
|
||||
- id: end-of-file-fixer
|
||||
- id: requirements-txt-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/pre-commit/mirrors-autopep8
|
||||
rev: v2.0.2
|
||||
hooks:
|
||||
- id: autopep8
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.3.1
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py37-plus, --keep-runtime-typing]
|
||||
- repo: https://github.com/asottile/reorder_python_imports
|
||||
rev: v3.9.0
|
||||
hooks:
|
||||
- id: reorder-python-imports
|
||||
- repo: https://github.com/asottile/add-trailing-comma
|
||||
rev: v2.4.0
|
||||
hooks:
|
||||
- id: add-trailing-comma
|
||||
- repo: https://github.com/hadialqattan/pycln
|
||||
rev: v2.1.3
|
||||
hooks:
|
||||
- id: pycln
|
||||
|
||||
default_language_version:
|
||||
python: python3.8
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
prepare:
|
||||
pipenv install --dev
|
||||
pipenv run pre-commit install
|
||||
|
||||
shell:
|
||||
pipenv shell
|
||||
|
||||
lint:
|
||||
pipenv run pre-commit run --all-files
|
||||
|
||||
test:
|
||||
pipenv run mypy
|
||||
|
||||
serve-docs:
|
||||
@cd docs;\
|
||||
make html;\
|
||||
cd build/html;\
|
||||
python -m http.server;\
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
[[source]]
|
||||
url = "https://pypi.org/simple"
|
||||
verify_ssl = true
|
||||
name = "pypi"
|
||||
|
||||
[packages]
|
||||
orjson = "*"
|
||||
"discord.py" = {extras = ["voice"], version = "*"}
|
||||
|
||||
[dev-packages]
|
||||
mypy = "*"
|
||||
pre-commit = "*"
|
||||
furo = "*"
|
||||
sphinx = "*"
|
||||
myst-parser = "*"
|
||||
|
||||
[requires]
|
||||
python_version = "3.8"
|
||||
|
|
@ -2,12 +2,12 @@ import importlib
|
|||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
sys.path.insert(0, os.path.abspath('.'))
|
||||
sys.path.insert(0, os.path.abspath('..'))
|
||||
|
||||
|
||||
|
||||
project = 'Pomice'
|
||||
copyright = '2023, cloudwithax'
|
||||
author = 'cloudwithax'
|
||||
|
|
@ -19,7 +19,7 @@ extensions = [
|
|||
'sphinx.ext.autodoc',
|
||||
'sphinx.ext.autosummary',
|
||||
'sphinx.ext.linkcode',
|
||||
'myst_parser'
|
||||
'myst_parser',
|
||||
]
|
||||
|
||||
myst_enable_extensions = [
|
||||
|
|
@ -84,6 +84,7 @@ html_theme_options: Dict[str, Any] = {
|
|||
# Grab lines from source files and embed into the docs
|
||||
# so theres a point of reference
|
||||
|
||||
|
||||
def linkcode_resolve(domain, info):
|
||||
# i absolutely MUST add this here or else
|
||||
# the docs will not build. fuck sphinx
|
||||
|
|
@ -93,7 +94,6 @@ def linkcode_resolve(domain, info):
|
|||
if not info['module']:
|
||||
return None
|
||||
|
||||
|
||||
mod = importlib.import_module(info["module"])
|
||||
if "." in info["fullname"]:
|
||||
objname, attrname = info["fullname"].split(".")
|
||||
|
|
@ -117,4 +117,3 @@ def linkcode_resolve(domain, info):
|
|||
return f"https://github.com/cloudwithax/pomice/blob/main/{file}#L{start}-L{end}"
|
||||
except:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,3 @@ remote, or the node.
|
|||
|
||||
`Event.WebsocketOpenEvent()` carries a target, which is usually the node IP, and the SSRC, a 32-bit integer uniquely identifying the source of the RTP packets sent from
|
||||
Lavalink.
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -184,4 +184,3 @@ After you have initialized your function, you can optionally include the `fast_a
|
|||
await Player.reset_filters(fast_apply=<True/False>)
|
||||
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -14,4 +14,3 @@ filters.md
|
|||
queue.md
|
||||
events.md
|
||||
``
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ There are also properties the `Node` class has to access certain values:
|
|||
- Description
|
||||
|
||||
* - `Node.bot`
|
||||
- `Union[Client, Bot]`
|
||||
- `Client`
|
||||
- Returns the discord.py client linked to this node.
|
||||
|
||||
* - `Node.is_connected`
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ There are also properties the `Player` class has to access certain values:
|
|||
- Description
|
||||
|
||||
* - `Player.bot`
|
||||
- `Union[Client, commands.Bot]`
|
||||
- `Client`
|
||||
- Returns the bot associated with this player instance.
|
||||
|
||||
* - `Player.current`
|
||||
|
|
@ -466,15 +466,3 @@ After you have initialized your function, you can optionally include the `fast_a
|
|||
await Player.reset_filters(fast_apply=<True/False>)
|
||||
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -148,15 +148,3 @@ await NodePool.disconnect()
|
|||
```
|
||||
|
||||
After running this function, all nodes in the pool should disconnect and no longer be available to use.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -222,13 +222,3 @@ Your `Track` object must be in the queue if you want to jump to it. Make sure yo
|
|||
:::
|
||||
|
||||
After running this function, any items before the specified item will be removed, effectively "jumping" to the specified item in the queue. The next item obtained using `Queue.get()` will be your specified track.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,3 @@ hdi/index.md
|
|||
:hidden:
|
||||
api/index.md
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -28,5 +28,3 @@ You are free to use this as a base to add on to for any music features you want
|
|||
If you want to jump into the library and learn how to do everything you need, refer to the [How Do I?](hdi/index.md) section.
|
||||
|
||||
If you want a deeper look into how the library works beyond the [How Do I?](hdi/index.md) guide, refer to the [API Reference](api/index.md) section.
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
discord.py[voice]
|
||||
aiohttp
|
||||
orjson
|
||||
myst_parser
|
||||
discord.py[voice]
|
||||
furo
|
||||
myst_parser
|
||||
orjson
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ This is in the form of a drop-in cog you can use and modify to your liking.
|
|||
This example aims to include everything you would need to make a fully functioning music bot,
|
||||
from a queue system, advanced queue control and more.
|
||||
"""
|
||||
import math
|
||||
from contextlib import suppress
|
||||
|
||||
import discord
|
||||
import pomice
|
||||
import math
|
||||
|
||||
from discord.ext import commands
|
||||
from contextlib import suppress
|
||||
|
||||
import pomice
|
||||
|
||||
|
||||
class Player(pomice.Player):
|
||||
|
|
@ -44,7 +44,6 @@ class Player(pomice.Player):
|
|||
with suppress(discord.HTTPException):
|
||||
await self.controller.delete()
|
||||
|
||||
|
||||
# Queue up the next track, else teardown the player
|
||||
try:
|
||||
track: pomice.Track = self.queue.get()
|
||||
|
|
@ -56,13 +55,16 @@ class Player(pomice.Player):
|
|||
# Call the controller (a.k.a: The "Now Playing" embed) and check if one exists
|
||||
|
||||
if track.is_stream:
|
||||
embed = discord.Embed(title="Now playing", description=f":red_circle: **LIVE** [{track.title}]({track.uri}) [{track.requester.mention}]")
|
||||
embed = discord.Embed(
|
||||
title="Now playing", description=f":red_circle: **LIVE** [{track.title}]({track.uri}) [{track.requester.mention}]",
|
||||
)
|
||||
self.controller = await self.context.send(embed=embed)
|
||||
else:
|
||||
embed = discord.Embed(title=f"Now playing", description=f"[{track.title}]({track.uri}) [{track.requester.mention}]")
|
||||
embed = discord.Embed(
|
||||
title=f"Now playing", description=f"[{track.title}]({track.uri}) [{track.requester.mention}]",
|
||||
)
|
||||
self.controller = await self.context.send(embed=embed)
|
||||
|
||||
|
||||
async def teardown(self):
|
||||
"""Clear internal states, remove player controller and disconnect."""
|
||||
with suppress((discord.HTTPException), (KeyError)):
|
||||
|
|
@ -76,8 +78,6 @@ class Player(pomice.Player):
|
|||
self.dj = ctx.author
|
||||
|
||||
|
||||
|
||||
|
||||
class Music(commands.Cog):
|
||||
def __init__(self, bot: commands.Bot) -> None:
|
||||
self.bot = bot
|
||||
|
|
@ -100,7 +100,7 @@ class Music(commands.Cog):
|
|||
host="127.0.0.1",
|
||||
port="3030",
|
||||
password="youshallnotpass",
|
||||
identifier="MAIN"
|
||||
identifier="MAIN",
|
||||
)
|
||||
print(f"Node is ready!")
|
||||
|
||||
|
|
@ -122,7 +122,6 @@ class Music(commands.Cog):
|
|||
|
||||
return player.dj == ctx.author or ctx.author.guild_permissions.kick_members
|
||||
|
||||
|
||||
# The following are events from pomice.events
|
||||
# We are using these so that if the track either stops or errors,
|
||||
# we can just skip to the next track
|
||||
|
|
@ -195,8 +194,6 @@ class Music(commands.Cog):
|
|||
if not player.is_playing:
|
||||
await player.do_next()
|
||||
|
||||
|
||||
|
||||
@commands.command(aliases=['pau', 'pa'])
|
||||
async def pause(self, ctx: commands.Context):
|
||||
"""Pause the currently playing song."""
|
||||
|
|
@ -345,6 +342,6 @@ class Music(commands.Cog):
|
|||
await player.set_volume(vol)
|
||||
await ctx.send(f'Set the volume to **{vol}**%', delete_after=7)
|
||||
|
||||
|
||||
async def setup(bot: commands.Bot):
|
||||
await bot.add_cog(Music(bot))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,16 @@
|
|||
import discord
|
||||
import pomice
|
||||
from discord.ext import commands
|
||||
|
||||
import pomice
|
||||
|
||||
|
||||
class MyBot(commands.Bot):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
command_prefix="!",
|
||||
activity=discord.Activity(type=discord.ActivityType.listening, name="to music!")
|
||||
activity=discord.Activity(
|
||||
type=discord.ActivityType.listening, name="to music!",
|
||||
),
|
||||
)
|
||||
|
||||
self.add_cog(Music(self))
|
||||
|
|
@ -33,7 +36,7 @@ class Music(commands.Cog):
|
|||
host="127.0.0.1",
|
||||
port="3030",
|
||||
password="youshallnotpass",
|
||||
identifier="MAIN"
|
||||
identifier="MAIN",
|
||||
)
|
||||
print(f"Node is ready!")
|
||||
|
||||
|
|
@ -44,7 +47,7 @@ class Music(commands.Cog):
|
|||
if not channel:
|
||||
raise commands.CheckFailure(
|
||||
"You must be in a voice channel to use this command "
|
||||
"without specifying the channel argument."
|
||||
"without specifying the channel argument.",
|
||||
)
|
||||
|
||||
# With the release of discord.py 1.7, you can now add a compatible
|
||||
|
|
@ -78,7 +81,9 @@ class Music(commands.Cog):
|
|||
results = await player.get_tracks(search)
|
||||
|
||||
if not results:
|
||||
raise commands.CommandError("No results were found for that search term.")
|
||||
raise commands.CommandError(
|
||||
"No results were found for that search term.",
|
||||
)
|
||||
|
||||
if isinstance(results, pomice.Playlist):
|
||||
await player.play(track=results.tracks[0])
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ if not discord.version_info.major >= 2:
|
|||
raise DiscordPyOutdated(
|
||||
"You must have discord.py (v2.0 or greater) to use this library. "
|
||||
"Uninstall your current version and install discord.py 2.0 "
|
||||
"using 'pip install discord.py'"
|
||||
"using 'pip install discord.py'",
|
||||
)
|
||||
|
||||
__version__ = "2.2a"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
"""Apple Music module for Pomice, made possible by cloudwithax 2023"""
|
||||
|
||||
from .client import Client
|
||||
from .exceptions import *
|
||||
from .objects import *
|
||||
from .client import Client
|
||||
|
|
|
|||
|
|
@ -1,19 +1,27 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import aiohttp
|
||||
import orjson as json
|
||||
import base64
|
||||
|
||||
from datetime import datetime
|
||||
from .objects import *
|
||||
from .exceptions import *
|
||||
from .objects import *
|
||||
|
||||
__all__ = (
|
||||
"Client",
|
||||
)
|
||||
|
||||
AM_URL_REGEX = re.compile(
|
||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)"
|
||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
|
||||
)
|
||||
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
|
||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)"
|
||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
|
||||
)
|
||||
AM_REQ_URL = "https://api.music.apple.com/v1/catalog/{country}/{type}s/{id}"
|
||||
AM_BASE_URL = "https://api.music.apple.com"
|
||||
|
|
@ -26,37 +34,49 @@ class Client:
|
|||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.token: str = None
|
||||
self.expiry: datetime = None
|
||||
self.session: aiohttp.ClientSession = None
|
||||
self.headers = None
|
||||
self.expiry: datetime = datetime(1970, 1, 1)
|
||||
self.token: str = ""
|
||||
self.headers: Dict[str, str] = {}
|
||||
self.session: aiohttp.ClientSession = None # type: ignore
|
||||
|
||||
async def request_token(self):
|
||||
async def request_token(self) -> None:
|
||||
if not self.session:
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
async with self.session.get("https://music.apple.com/assets/index.919fe17f.js") as resp:
|
||||
if resp.status != 200:
|
||||
raise AppleMusicRequestException(
|
||||
f"Error while fetching results: {resp.status} {resp.reason}"
|
||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||
)
|
||||
text = await resp.text()
|
||||
result = re.search('"(eyJ.+?)"', text).group(1)
|
||||
match = re.search('"(eyJ.+?)"', text)
|
||||
if not match:
|
||||
raise AppleMusicRequestException(
|
||||
"Could not find token in response.",
|
||||
)
|
||||
result = match.group(1)
|
||||
|
||||
self.token = result
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {result}",
|
||||
"Origin": "https://apple.com",
|
||||
}
|
||||
token_split = self.token.split(".")[1]
|
||||
token_json = base64.b64decode(token_split + "=" * (-len(token_split) % 4)).decode()
|
||||
token_json = base64.b64decode(
|
||||
token_split + "=" * (-len(token_split) % 4),
|
||||
).decode()
|
||||
token_data = json.loads(token_json)
|
||||
self.expiry = datetime.fromtimestamp(token_data["exp"])
|
||||
|
||||
async def search(self, query: str):
|
||||
async def search(self, query: str) -> Union[Album, Playlist, Song, Artist]:
|
||||
if not self.token or datetime.utcnow() > self.expiry:
|
||||
await self.request_token()
|
||||
|
||||
result = AM_URL_REGEX.match(query)
|
||||
if not result:
|
||||
raise InvalidAppleMusicURL(
|
||||
"The Apple Music link provided is not valid.",
|
||||
)
|
||||
|
||||
country = result.group("country")
|
||||
type = result.group("type")
|
||||
|
|
@ -75,7 +95,7 @@ class Client:
|
|||
async with self.session.get(request_url, headers=self.headers) as resp:
|
||||
if resp.status != 200:
|
||||
raise AppleMusicRequestException(
|
||||
f"Error while fetching results: {resp.status} {resp.reason}"
|
||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||
)
|
||||
data: dict = await resp.json(loads=json.loads)
|
||||
|
||||
|
|
@ -84,53 +104,57 @@ class Client:
|
|||
if type == "song":
|
||||
return Song(data)
|
||||
|
||||
elif type == "album":
|
||||
if type == "album":
|
||||
return Album(data)
|
||||
|
||||
elif type == "artist":
|
||||
if type == "artist":
|
||||
async with self.session.get(
|
||||
f"{request_url}/view/top-songs", headers=self.headers
|
||||
f"{request_url}/view/top-songs", headers=self.headers,
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise AppleMusicRequestException(
|
||||
f"Error while fetching results: {resp.status} {resp.reason}"
|
||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||
)
|
||||
top_tracks: dict = await resp.json(loads=json.loads)
|
||||
tracks: dict = top_tracks["data"]
|
||||
artist_tracks: dict = top_tracks["data"]
|
||||
|
||||
return Artist(data, tracks=tracks)
|
||||
return Artist(data, tracks=artist_tracks)
|
||||
|
||||
else:
|
||||
track_data: dict = data["relationships"]["tracks"]
|
||||
track_data: dict = data["relationships"]["tracks"]
|
||||
album_tracks: List[Song] = [
|
||||
Song(track)
|
||||
for track in track_data["data"]
|
||||
]
|
||||
|
||||
tracks = [Song(track) for track in track_data.get("data")]
|
||||
if not len(album_tracks):
|
||||
raise AppleMusicRequestException(
|
||||
"This playlist is empty and therefore cannot be queued.",
|
||||
)
|
||||
|
||||
if not len(tracks):
|
||||
raise AppleMusicRequestException(
|
||||
"This playlist is empty and therefore cannot be queued."
|
||||
)
|
||||
_next = track_data.get("next")
|
||||
if _next:
|
||||
next_page_url = AM_BASE_URL + _next
|
||||
|
||||
if track_data.get("next"):
|
||||
next_page_url = AM_BASE_URL + track_data.get("next")
|
||||
while next_page_url is not None:
|
||||
async with self.session.get(next_page_url, headers=self.headers) as resp:
|
||||
if resp.status != 200:
|
||||
raise AppleMusicRequestException(
|
||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||
)
|
||||
|
||||
while next_page_url is not None:
|
||||
async with self.session.get(next_page_url, headers=self.headers) as resp:
|
||||
if resp.status != 200:
|
||||
raise AppleMusicRequestException(
|
||||
f"Error while fetching results: {resp.status} {resp.reason}"
|
||||
)
|
||||
next_data: dict = await resp.json(loads=json.loads)
|
||||
|
||||
next_data: dict = await resp.json(loads=json.loads)
|
||||
album_tracks.extend(Song(track) for track in next_data["data"])
|
||||
|
||||
tracks += [Song(track) for track in next_data["data"]]
|
||||
if next_data.get("next"):
|
||||
next_page_url = AM_BASE_URL + next_data.get("next")
|
||||
else:
|
||||
next_page_url = None
|
||||
_next = next_data.get("next")
|
||||
if _next:
|
||||
next_page_url = AM_BASE_URL + _next
|
||||
else:
|
||||
next_page_url = None
|
||||
|
||||
return Playlist(data, tracks)
|
||||
return Playlist(data, album_tracks)
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
self.session = None # type: ignore
|
||||
|
|
|
|||
|
|
@ -1,3 +1,9 @@
|
|||
__all__ = (
|
||||
"AppleMusicRequestException",
|
||||
"InvalidAppleMusicURL",
|
||||
)
|
||||
|
||||
|
||||
class AppleMusicRequestException(Exception):
|
||||
"""An error occurred when making a request to the Apple Music API"""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,13 @@
|
|||
"""Module for managing Apple Music objects"""
|
||||
|
||||
from typing import List
|
||||
|
||||
__all__ = (
|
||||
"Song",
|
||||
"Playlist",
|
||||
"Album",
|
||||
"Artist",
|
||||
)
|
||||
|
||||
|
||||
class Song:
|
||||
"""The base class for an Apple Music song"""
|
||||
|
|
@ -55,7 +61,9 @@ class Album:
|
|||
self.id: str = data["id"]
|
||||
self.artists: str = data["attributes"]["artistName"]
|
||||
self.total_tracks: int = data["attributes"]["trackCount"]
|
||||
self.tracks: List[Song] = [Song(track) for track in data["relationships"]["tracks"]["data"]]
|
||||
self.tracks: List[Song] = [
|
||||
Song(track) for track in data["relationships"]["tracks"]["data"]
|
||||
]
|
||||
self.image: str = data["attributes"]["artwork"]["url"].replace(
|
||||
"{w}x{h}",
|
||||
f'{data["attributes"]["artwork"]["width"]}x{data["attributes"]["artwork"]["height"]}',
|
||||
|
|
@ -75,7 +83,9 @@ class Artist:
|
|||
self.name: str = f'Top tracks for {data["attributes"]["name"]}'
|
||||
self.url: str = data["attributes"]["url"]
|
||||
self.id: str = data["id"]
|
||||
self.genres: str = ", ".join(genre for genre in data["attributes"]["genreNames"])
|
||||
self.genres: str = ", ".join(
|
||||
genre for genre in data["attributes"]["genreNames"]
|
||||
)
|
||||
self.tracks: List[Song] = [Song(track) for track in tracks]
|
||||
self.image: str = data["attributes"]["artwork"]["url"].replace(
|
||||
"{w}x{h}",
|
||||
|
|
|
|||
|
|
@ -1,7 +1,17 @@
|
|||
import re
|
||||
|
||||
from enum import Enum
|
||||
|
||||
__all__ = (
|
||||
"SearchType",
|
||||
"TrackType",
|
||||
"PlaylistType",
|
||||
"NodeAlgorithm",
|
||||
"LoopMode",
|
||||
"RouteStrategy",
|
||||
"RouteIPType",
|
||||
"URLRegex",
|
||||
)
|
||||
|
||||
|
||||
class SearchType(Enum):
|
||||
"""
|
||||
|
|
@ -185,43 +195,51 @@ class URLRegex:
|
|||
"""
|
||||
|
||||
SPOTIFY_URL = re.compile(
|
||||
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)"
|
||||
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
|
||||
)
|
||||
|
||||
DISCORD_MP3_URL = re.compile(
|
||||
r"https?://cdn.discordapp.com/attachments/(?P<channel_id>[0-9]+)/"
|
||||
r"(?P<message_id>[0-9]+)/(?P<file>[a-zA-Z0-9_.]+)+"
|
||||
r"(?P<message_id>[0-9]+)/(?P<file>[a-zA-Z0-9_.]+)+",
|
||||
)
|
||||
|
||||
YOUTUBE_URL = re.compile(
|
||||
r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube\.com|youtu.be))"
|
||||
r"(\/(?:[\w\-]+\?v=|embed\/|v\/)?)([\w\-]+)(\S+)?$"
|
||||
r"(\/(?:[\w\-]+\?v=|embed\/|v\/)?)([\w\-]+)(\S+)?$",
|
||||
)
|
||||
|
||||
YOUTUBE_PLAYLIST_URL = re.compile(
|
||||
r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube\.com|youtu.be))/playlist\?list=.*"
|
||||
r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube\.com|youtu.be))/playlist\?list=.*",
|
||||
)
|
||||
|
||||
YOUTUBE_VID_IN_PLAYLIST = re.compile(r"(?P<video>^.*?v.*?)(?P<list>&list.*)")
|
||||
YOUTUBE_VID_IN_PLAYLIST = re.compile(
|
||||
r"(?P<video>^.*?v.*?)(?P<list>&list.*)",
|
||||
)
|
||||
|
||||
YOUTUBE_TIMESTAMP = re.compile(r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*")
|
||||
YOUTUBE_TIMESTAMP = re.compile(
|
||||
r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*",
|
||||
)
|
||||
|
||||
AM_URL = re.compile(
|
||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/"
|
||||
r"(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)"
|
||||
r"(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
|
||||
)
|
||||
|
||||
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
|
||||
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
|
||||
r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)"
|
||||
r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
|
||||
)
|
||||
|
||||
SOUNDCLOUD_URL = re.compile(r"((?:https?:)?\/\/)?((?:www|m)\.)?soundcloud.com\/.*/.*")
|
||||
SOUNDCLOUD_URL = re.compile(
|
||||
r"((?:https?:)?\/\/)?((?:www|m)\.)?soundcloud.com\/.*/.*",
|
||||
)
|
||||
|
||||
SOUNDCLOUD_PLAYLIST_URL = re.compile(r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com\/.*/sets/.*")
|
||||
SOUNDCLOUD_PLAYLIST_URL = re.compile(
|
||||
r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com\/.*/sets/.*",
|
||||
)
|
||||
|
||||
SOUNDCLOUD_TRACK_IN_SET_URL = re.compile(
|
||||
r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com/[a-zA-Z0-9-._]+/[a-zA-Z0-9-._]+(\?in)"
|
||||
r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com/[a-zA-Z0-9-._]+/[a-zA-Z0-9-._]+(\?in)",
|
||||
)
|
||||
|
||||
LAVALINK_SEARCH = re.compile(r"(?P<type>ytm?|sc)search:")
|
||||
|
|
|
|||
|
|
@ -1,18 +1,32 @@
|
|||
from __future__ import annotations
|
||||
from discord import Client, Guild
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional, Tuple
|
||||
from typing import Union
|
||||
from abc import ABC
|
||||
|
||||
from discord import Client
|
||||
from discord import Guild
|
||||
from discord.ext import commands
|
||||
|
||||
from .pool import NodePool
|
||||
from .objects import Track
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from .pool import NodePool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .player import Player
|
||||
|
||||
__all__ = (
|
||||
"PomiceEvent",
|
||||
"TrackStartEvent",
|
||||
"TrackEndEvent",
|
||||
"TrackStuckEvent",
|
||||
"TrackExceptionEvent",
|
||||
"WebSocketClosedPayload",
|
||||
"WebSocketClosedEvent",
|
||||
"WebSocketOpenEvent",
|
||||
)
|
||||
|
||||
class PomiceEvent:
|
||||
|
||||
class PomiceEvent(ABC):
|
||||
"""The base class for all events dispatched by a node.
|
||||
Every event must be formatted within your bot's code as a listener.
|
||||
i.e: If you want to listen for when a track starts, the event would be:
|
||||
|
|
@ -22,10 +36,12 @@ class PomiceEvent:
|
|||
```
|
||||
"""
|
||||
|
||||
name = "event"
|
||||
handler_args = ()
|
||||
__slots__ = ("name", "handler_args")
|
||||
|
||||
def dispatch(self, bot: Union[Client, commands.Bot]):
|
||||
name = "event"
|
||||
handler_args: Tuple
|
||||
|
||||
def dispatch(self, bot: Client) -> None:
|
||||
bot.dispatch(f"pomice_{self.name}", *self.handler_args)
|
||||
|
||||
|
||||
|
|
@ -36,10 +52,15 @@ class TrackStartEvent(PomiceEvent):
|
|||
|
||||
name = "track_start"
|
||||
|
||||
__slots__ = (
|
||||
"player",
|
||||
"track",
|
||||
)
|
||||
|
||||
def __init__(self, data: dict, player: Player):
|
||||
__slots__ = ("player", "track")
|
||||
|
||||
self.player: Player = player
|
||||
assert self.player._current is not None
|
||||
self.track: Track = self.player._current
|
||||
|
||||
# on_pomice_track_start(player, track)
|
||||
|
|
@ -56,10 +77,12 @@ class TrackEndEvent(PomiceEvent):
|
|||
|
||||
name = "track_end"
|
||||
|
||||
__slots__ = ("player", "track", "reason")
|
||||
|
||||
def __init__(self, data: dict, player: Player):
|
||||
__slots__ = ("player", "track", "reason")
|
||||
|
||||
self.player: Player = player
|
||||
assert self.player._ending_track is not None
|
||||
self.track: Track = self.player._ending_track
|
||||
self.reason: str = data["reason"]
|
||||
|
||||
|
|
@ -81,10 +104,12 @@ class TrackStuckEvent(PomiceEvent):
|
|||
|
||||
name = "track_stuck"
|
||||
|
||||
__slots__ = ("player", "track", "threshold")
|
||||
|
||||
def __init__(self, data: dict, player: Player):
|
||||
__slots__ = ("player", "track", "threshold")
|
||||
|
||||
self.player: Player = player
|
||||
assert self.player._ending_track is not None
|
||||
self.track: Track = self.player._ending_track
|
||||
self.threshold: float = data["thresholdMs"]
|
||||
|
||||
|
|
@ -105,17 +130,16 @@ class TrackExceptionEvent(PomiceEvent):
|
|||
|
||||
name = "track_exception"
|
||||
|
||||
__slots__ = ("player", "track", "exception")
|
||||
|
||||
def __init__(self, data: dict, player: Player):
|
||||
__slots__ = ("player", "track", "exception")
|
||||
|
||||
self.player: Player = player
|
||||
assert self.player._ending_track is not None
|
||||
self.track: Track = self.player._ending_track
|
||||
if data.get("error"):
|
||||
# User is running Lavalink <= 3.3
|
||||
self.exception: str = data["error"]
|
||||
else:
|
||||
# User is running Lavalink >=3.4
|
||||
self.exception: str = data["exception"]
|
||||
# Error is for Lavalink <= 3.3
|
||||
self.exception: str = data.get(
|
||||
"error", "") or data.get("exception", "")
|
||||
|
||||
# on_pomice_track_exception(player, track, error)
|
||||
self.handler_args = self.player, self.track, self.exception
|
||||
|
|
@ -125,10 +149,12 @@ class TrackExceptionEvent(PomiceEvent):
|
|||
|
||||
|
||||
class WebSocketClosedPayload:
|
||||
def __init__(self, data: dict):
|
||||
__slots__ = ("guild", "code", "reason", "by_remote")
|
||||
__slots__ = ("guild", "code", "reason", "by_remote")
|
||||
|
||||
self.guild: Guild = NodePool.get_node().bot.get_guild(int(data["guildId"]))
|
||||
def __init__(self, data: dict):
|
||||
|
||||
self.guild: Optional[Guild] = NodePool.get_node(
|
||||
).bot.get_guild(int(data["guildId"]))
|
||||
self.code: int = data["code"]
|
||||
self.reason: str = data["code"]
|
||||
self.by_remote: bool = data["byRemote"]
|
||||
|
|
@ -147,7 +173,9 @@ class WebSocketClosedEvent(PomiceEvent):
|
|||
|
||||
name = "websocket_closed"
|
||||
|
||||
def __init__(self, data: dict, _):
|
||||
__slots__ = ("payload",)
|
||||
|
||||
def __init__(self, data: dict, _: Any) -> None:
|
||||
self.payload: WebSocketClosedPayload = WebSocketClosedPayload(data)
|
||||
|
||||
# on_pomice_websocket_closed(payload)
|
||||
|
|
@ -164,8 +192,9 @@ class WebSocketOpenEvent(PomiceEvent):
|
|||
|
||||
name = "websocket_open"
|
||||
|
||||
def __init__(self, data: dict, _):
|
||||
__slots__ = ("target", "ssrc")
|
||||
__slots__ = ("target", "ssrc")
|
||||
|
||||
def __init__(self, data: dict, _: Any) -> None:
|
||||
|
||||
self.target: str = data["target"]
|
||||
self.ssrc: int = data["ssrc"]
|
||||
|
|
|
|||
|
|
@ -1,3 +1,26 @@
|
|||
__all__ = (
|
||||
"PomiceException",
|
||||
"NodeException",
|
||||
"NodeCreationError",
|
||||
"NodeConnectionFailure",
|
||||
"NodeConnectionClosed",
|
||||
"NodeRestException",
|
||||
"NodeNotAvailable",
|
||||
"NoNodesAvailable",
|
||||
"TrackInvalidPosition",
|
||||
"TrackLoadError",
|
||||
"FilterInvalidArgument",
|
||||
"FilterTagInvalid",
|
||||
"FilterTagAlreadyInUse",
|
||||
"InvalidSpotifyClientAuthorization",
|
||||
"AppleMusicNotEnabled",
|
||||
"QueueException",
|
||||
"QueueFull",
|
||||
"QueueEmpty",
|
||||
"LavalinkVersionIncompatible",
|
||||
)
|
||||
|
||||
|
||||
class PomiceException(Exception):
|
||||
"""Base of all Pomice exceptions."""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,23 @@
|
|||
from typing import Any, Dict, Tuple
|
||||
import collections
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from .exceptions import FilterInvalidArgument
|
||||
|
||||
__all__ = (
|
||||
"Filter",
|
||||
"Equalizer",
|
||||
"Timescale",
|
||||
"Karaoke",
|
||||
"Tremolo",
|
||||
"Vibrato",
|
||||
"Rotation",
|
||||
"Distortion",
|
||||
"ChannelMix",
|
||||
"LowPass",
|
||||
)
|
||||
|
||||
|
||||
class Filter:
|
||||
"""
|
||||
|
|
@ -13,10 +30,10 @@ class Filter:
|
|||
This is necessary for the removal of filters.
|
||||
"""
|
||||
|
||||
def __init__(self, *, tag: str):
|
||||
__slots__ = ("payload", "tag", "preload")
|
||||
__slots__ = ("payload", "tag", "preload")
|
||||
|
||||
self.payload: dict = None
|
||||
def __init__(self, *, tag: str):
|
||||
self.payload: Optional[Dict] = None
|
||||
self.tag: str = tag
|
||||
self.preload: bool = False
|
||||
|
||||
|
|
@ -34,32 +51,32 @@ class Equalizer(Filter):
|
|||
The format for the levels is: List[Tuple[int, float]]
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"eq",
|
||||
"raw",
|
||||
)
|
||||
|
||||
def __init__(self, *, tag: str, levels: list):
|
||||
super().__init__(tag=tag)
|
||||
|
||||
__slots__ = (
|
||||
"eq",
|
||||
"raw",
|
||||
)
|
||||
|
||||
self.eq = self._factory(levels)
|
||||
self.raw = levels
|
||||
|
||||
self.payload = {"equalizer": self.eq}
|
||||
|
||||
def _factory(self, levels: list):
|
||||
_dict = collections.defaultdict(int)
|
||||
def _factory(self, levels: List[Tuple[Any, Any]]) -> List[Dict]:
|
||||
_dict: Dict = collections.defaultdict(int)
|
||||
|
||||
_dict.update(levels)
|
||||
_dict = [{"band": i, "gain": _dict[i]} for i in range(15)]
|
||||
data = [{"band": i, "gain": _dict[i]} for i in range(15)]
|
||||
|
||||
return _dict
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.EqualizerFilter tag={self.tag} eq={self.eq} raw={self.raw}>"
|
||||
|
||||
@classmethod
|
||||
def flat(cls):
|
||||
def flat(cls) -> "Equalizer":
|
||||
"""Equalizer preset which represents a flat EQ board,
|
||||
with all levels set to their default values.
|
||||
"""
|
||||
|
|
@ -84,7 +101,7 @@ class Equalizer(Filter):
|
|||
return cls(tag="flat", levels=levels)
|
||||
|
||||
@classmethod
|
||||
def boost(cls):
|
||||
def boost(cls) -> "Equalizer":
|
||||
"""Equalizer preset which boosts the sound of a track,
|
||||
making it sound fun and energetic by increasing the bass
|
||||
and the highs.
|
||||
|
|
@ -110,7 +127,7 @@ class Equalizer(Filter):
|
|||
return cls(tag="boost", levels=levels)
|
||||
|
||||
@classmethod
|
||||
def metal(cls):
|
||||
def metal(cls) -> "Equalizer":
|
||||
"""Equalizer preset which increases the mids of a track,
|
||||
preferably one of the metal genre, to make it sound
|
||||
more full and concert-like.
|
||||
|
|
@ -137,7 +154,7 @@ class Equalizer(Filter):
|
|||
return cls(tag="metal", levels=levels)
|
||||
|
||||
@classmethod
|
||||
def piano(cls):
|
||||
def piano(cls) -> "Equalizer":
|
||||
"""Equalizer preset which increases the mids and highs
|
||||
of a track, preferably a piano based one, to make it
|
||||
stand out.
|
||||
|
|
@ -169,11 +186,11 @@ class Timescale(Filter):
|
|||
a certain amount to produce said effect.
|
||||
"""
|
||||
|
||||
__slots__ = ("speed", "pitch", "rate")
|
||||
|
||||
def __init__(self, *, tag: str, speed: float = 1.0, pitch: float = 1.0, rate: float = 1.0):
|
||||
super().__init__(tag=tag)
|
||||
|
||||
__slots__ = ("speed", "pitch", "rate")
|
||||
|
||||
if speed < 0:
|
||||
raise FilterInvalidArgument("Timescale speed must be more than 0.")
|
||||
if pitch < 0:
|
||||
|
|
@ -186,11 +203,11 @@ class Timescale(Filter):
|
|||
self.rate: float = rate
|
||||
|
||||
self.payload: dict = {
|
||||
"timescale": {"speed": self.speed, "pitch": self.pitch, "rate": self.rate}
|
||||
"timescale": {"speed": self.speed, "pitch": self.pitch, "rate": self.rate},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def vaporwave(cls):
|
||||
def vaporwave(cls) -> "Timescale":
|
||||
"""Timescale preset which slows down the currently playing track,
|
||||
giving it the effect of a half-speed record/casette playing.
|
||||
|
||||
|
|
@ -200,7 +217,7 @@ class Timescale(Filter):
|
|||
return cls(tag="vaporwave", speed=0.8, pitch=0.8)
|
||||
|
||||
@classmethod
|
||||
def nightcore(cls):
|
||||
def nightcore(cls) -> "Timescale":
|
||||
"""Timescale preset which speeds up the currently playing track,
|
||||
which matches up to nightcore, a genre of sped-up music
|
||||
|
||||
|
|
@ -209,7 +226,7 @@ class Timescale(Filter):
|
|||
|
||||
return cls(tag="nightcore", speed=1.25, pitch=1.3)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.TimescaleFilter tag={self.tag} speed={self.speed} pitch={self.pitch} rate={self.rate}>"
|
||||
|
||||
|
||||
|
|
@ -217,6 +234,7 @@ class Karaoke(Filter):
|
|||
"""Filter which filters the vocal track from any song and leaves the instrumental.
|
||||
Best for karaoke as the filter implies.
|
||||
"""
|
||||
__slots__ = ("level", "mono_level", "filter_band", "filter_width")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -229,8 +247,6 @@ class Karaoke(Filter):
|
|||
):
|
||||
super().__init__(tag=tag)
|
||||
|
||||
__slots__ = ("level", "mono_level", "filter_band", "filter_width")
|
||||
|
||||
self.level: float = level
|
||||
self.mono_level: float = mono_level
|
||||
self.filter_band: float = filter_band
|
||||
|
|
@ -242,10 +258,10 @@ class Karaoke(Filter):
|
|||
"monoLevel": self.mono_level,
|
||||
"filterBand": self.filter_band,
|
||||
"filterWidth": self.filter_width,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Pomice.KaraokeFilter tag={self.tag} level={self.level} mono_level={self.mono_level} "
|
||||
f"filter_band={self.filter_band} filter_width={self.filter_width}>"
|
||||
|
|
@ -256,23 +272,30 @@ class Tremolo(Filter):
|
|||
"""Filter which produces a wavering tone in the music,
|
||||
causing it to sound like the music is changing in volume rapidly.
|
||||
"""
|
||||
__slots__ = ("frequency", "depth")
|
||||
|
||||
def __init__(self, *, tag: str, frequency: float = 2.0, depth: float = 0.5):
|
||||
super().__init__(tag=tag)
|
||||
|
||||
__slots__ = ("frequency", "depth")
|
||||
|
||||
if frequency < 0:
|
||||
raise FilterInvalidArgument("Tremolo frequency must be more than 0.")
|
||||
raise FilterInvalidArgument(
|
||||
"Tremolo frequency must be more than 0.",
|
||||
)
|
||||
if depth < 0 or depth > 1:
|
||||
raise FilterInvalidArgument("Tremolo depth must be between 0 and 1.")
|
||||
raise FilterInvalidArgument(
|
||||
"Tremolo depth must be between 0 and 1.",
|
||||
)
|
||||
|
||||
self.frequency: float = frequency
|
||||
self.depth: float = depth
|
||||
|
||||
self.payload: dict = {"tremolo": {"frequency": self.frequency, "depth": self.depth}}
|
||||
self.payload: dict = {
|
||||
"tremolo": {
|
||||
"frequency": self.frequency, "depth": self.depth,
|
||||
},
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Pomice.TremoloFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
|
||||
)
|
||||
|
|
@ -282,23 +305,30 @@ class Vibrato(Filter):
|
|||
"""Filter which produces a wavering tone in the music, similar to the Tremolo filter,
|
||||
but changes in pitch rather than volume.
|
||||
"""
|
||||
__slots__ = ("frequency", "depth")
|
||||
|
||||
def __init__(self, *, tag: str, frequency: float = 2.0, depth: float = 0.5):
|
||||
super().__init__(tag=tag)
|
||||
|
||||
__slots__ = ("frequency", "depth")
|
||||
|
||||
if frequency < 0 or frequency > 14:
|
||||
raise FilterInvalidArgument("Vibrato frequency must be between 0 and 14.")
|
||||
raise FilterInvalidArgument(
|
||||
"Vibrato frequency must be between 0 and 14.",
|
||||
)
|
||||
if depth < 0 or depth > 1:
|
||||
raise FilterInvalidArgument("Vibrato depth must be between 0 and 1.")
|
||||
raise FilterInvalidArgument(
|
||||
"Vibrato depth must be between 0 and 1.",
|
||||
)
|
||||
|
||||
self.frequency: float = frequency
|
||||
self.depth: float = depth
|
||||
|
||||
self.payload: dict = {"vibrato": {"frequency": self.frequency, "depth": self.depth}}
|
||||
self.payload: dict = {
|
||||
"vibrato": {
|
||||
"frequency": self.frequency, "depth": self.depth,
|
||||
},
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Pomice.VibratoFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
|
||||
)
|
||||
|
|
@ -309,11 +339,11 @@ class Rotation(Filter):
|
|||
the audio is being rotated around the listener's head
|
||||
"""
|
||||
|
||||
__slots__ = ("rotation_hertz",)
|
||||
|
||||
def __init__(self, *, tag: str, rotation_hertz: float = 5):
|
||||
super().__init__(tag=tag)
|
||||
|
||||
__slots__ = "rotation_hertz"
|
||||
|
||||
self.rotation_hertz: float = rotation_hertz
|
||||
self.payload: dict = {"rotation": {"rotationHz": self.rotation_hertz}}
|
||||
|
||||
|
|
@ -326,6 +356,13 @@ class ChannelMix(Filter):
|
|||
for some cool effects when done correctly.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"left_to_left",
|
||||
"right_to_right",
|
||||
"left_to_right",
|
||||
"right_to_left",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -337,23 +374,21 @@ class ChannelMix(Filter):
|
|||
):
|
||||
super().__init__(tag=tag)
|
||||
|
||||
__slots__ = ("left_to_left", "right_to_right", "left_to_right", "right_to_left")
|
||||
|
||||
if 0 > left_to_left > 1:
|
||||
raise ValueError(
|
||||
"'left_to_left' value must be more than or equal to 0 or less than or equal to 1."
|
||||
"'left_to_left' value must be more than or equal to 0 or less than or equal to 1.",
|
||||
)
|
||||
if 0 > right_to_right > 1:
|
||||
raise ValueError(
|
||||
"'right_to_right' value must be more than or equal to 0 or less than or equal to 1."
|
||||
"'right_to_right' value must be more than or equal to 0 or less than or equal to 1.",
|
||||
)
|
||||
if 0 > left_to_right > 1:
|
||||
raise ValueError(
|
||||
"'left_to_right' value must be more than or equal to 0 or less than or equal to 1."
|
||||
"'left_to_right' value must be more than or equal to 0 or less than or equal to 1.",
|
||||
)
|
||||
if 0 > right_to_left > 1:
|
||||
raise ValueError(
|
||||
"'right_to_left' value must be more than or equal to 0 or less than or equal to 1."
|
||||
"'right_to_left' value must be more than or equal to 0 or less than or equal to 1.",
|
||||
)
|
||||
|
||||
self.left_to_left: float = left_to_left
|
||||
|
|
@ -367,7 +402,7 @@ class ChannelMix(Filter):
|
|||
"leftToRight": self.left_to_right,
|
||||
"rightToLeft": self.right_to_left,
|
||||
"rightToRight": self.right_to_right,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
|
@ -382,6 +417,17 @@ class Distortion(Filter):
|
|||
distortion is needed.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"sin_offset",
|
||||
"sin_scale",
|
||||
"cos_offset",
|
||||
"cos_scale",
|
||||
"tan_offset",
|
||||
"tan_scale",
|
||||
"offset",
|
||||
"scale",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -397,16 +443,6 @@ class Distortion(Filter):
|
|||
):
|
||||
super().__init__(tag=tag)
|
||||
|
||||
__slots__ = (
|
||||
"sin_offset",
|
||||
"sin_scale",
|
||||
"cos_offset",
|
||||
"cos_scale",
|
||||
"tan_offset",
|
||||
"tan_scale" "offset",
|
||||
"scale",
|
||||
)
|
||||
|
||||
self.sin_offset: float = sin_offset
|
||||
self.sin_scale: float = sin_scale
|
||||
self.cos_offset: float = cos_offset
|
||||
|
|
@ -426,7 +462,7 @@ class Distortion(Filter):
|
|||
"tanScale": self.tan_scale,
|
||||
"offset": self.offset,
|
||||
"scale": self.scale,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
|
@ -441,12 +477,11 @@ class LowPass(Filter):
|
|||
"""Filter which supresses higher frequencies and allows lower frequencies to pass.
|
||||
You can also do this with the Equalizer filter, but this is an easier way to do it.
|
||||
"""
|
||||
__slots__ = ("smoothing", "payload")
|
||||
|
||||
def __init__(self, *, tag: str, smoothing: float = 20):
|
||||
super().__init__(tag=tag)
|
||||
|
||||
__slots__ = "smoothing"
|
||||
|
||||
self.smoothing: float = smoothing
|
||||
self.payload: dict = {"lowPass": {"smoothing": self.smoothing}}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,18 +1,53 @@
|
|||
from __future__ import annotations
|
||||
from typing import List, Optional, Union
|
||||
from discord import Member, User
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from discord import ClientUser
|
||||
from discord import Member
|
||||
from discord import User
|
||||
from discord.ext import commands
|
||||
|
||||
from .enums import SearchType, TrackType, PlaylistType
|
||||
from .enums import PlaylistType
|
||||
from .enums import SearchType
|
||||
from .enums import TrackType
|
||||
from .filters import Filter
|
||||
|
||||
__all__ = (
|
||||
"Track",
|
||||
"Playlist",
|
||||
)
|
||||
|
||||
|
||||
class Track:
|
||||
"""The base track object. Returns critical track information needed for parsing by Lavalink.
|
||||
You can also pass in commands.Context to get a discord.py Context object in your track.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"track_id",
|
||||
"info",
|
||||
"track_type",
|
||||
"filters",
|
||||
"timestamp",
|
||||
"original",
|
||||
"_search_type",
|
||||
"playlist",
|
||||
"title",
|
||||
"author",
|
||||
"uri",
|
||||
"identifier",
|
||||
"isrc",
|
||||
"thumbnail",
|
||||
"length",
|
||||
"ctx",
|
||||
"requester",
|
||||
"is_stream",
|
||||
"is_seekable",
|
||||
"position",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -23,30 +58,8 @@ class Track:
|
|||
search_type: SearchType = SearchType.ytsearch,
|
||||
filters: Optional[List[Filter]] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
requester: Optional[Union[Member, User]] = None,
|
||||
requester: Optional[Union[Member, User, ClientUser]] = None,
|
||||
):
|
||||
__slots__ = (
|
||||
"track_id",
|
||||
"info",
|
||||
"track_type",
|
||||
"filters",
|
||||
"timestamp",
|
||||
"original",
|
||||
"_search_type",
|
||||
"playlist",
|
||||
"title",
|
||||
"author",
|
||||
"uri",
|
||||
"identifier",
|
||||
"isrc",
|
||||
"thumbnail",
|
||||
"length",
|
||||
"ctx",
|
||||
"requester",
|
||||
"is_stream",
|
||||
"is_seekable",
|
||||
"position",
|
||||
)
|
||||
|
||||
self.track_id: str = track_id
|
||||
self.info: dict = info
|
||||
|
|
@ -60,35 +73,29 @@ class Track:
|
|||
self.original = self
|
||||
self._search_type: SearchType = search_type
|
||||
|
||||
self.playlist: Playlist = None
|
||||
self.playlist: Optional[Playlist] = None
|
||||
|
||||
self.title: str = info.get("title")
|
||||
self.author: str = info.get("author")
|
||||
self.uri: str = info.get("uri")
|
||||
self.identifier: str = info.get("identifier")
|
||||
self.isrc: str = info.get("isrc")
|
||||
self.title: str = info.get("title", "Unknown Title")
|
||||
self.author: str = info.get("author", "Unknown Author")
|
||||
self.uri: str = info.get("uri", "")
|
||||
self.identifier: str = info.get("identifier", "")
|
||||
self.isrc: str = info.get("isrc", "")
|
||||
self.thumbnail: Optional[str] = info.get("thumbnail")
|
||||
|
||||
if self.uri:
|
||||
if info.get("thumbnail"):
|
||||
self.thumbnail: str = info.get("thumbnail")
|
||||
elif self.track_type == TrackType.SOUNDCLOUD:
|
||||
# ok so theres no feasible way of getting a Soundcloud image URL
|
||||
# so we're just gonna leave it blank for brevity
|
||||
self.thumbnail = None
|
||||
else:
|
||||
self.thumbnail: str = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg"
|
||||
if self.uri and self.track_type is TrackType.YOUTUBE:
|
||||
self.thumbnail = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg"
|
||||
|
||||
self.length: int = info.get("length")
|
||||
self.ctx: commands.Context = ctx
|
||||
if requester:
|
||||
self.requester: Optional[Union[Member, User]] = requester
|
||||
else:
|
||||
self.requester: Optional[Union[Member, User]] = self.ctx.author if ctx else None
|
||||
self.is_stream: bool = info.get("isStream")
|
||||
self.is_seekable: bool = info.get("isSeekable")
|
||||
self.position: int = info.get("position")
|
||||
self.length: int = info.get("length", 0)
|
||||
self.is_stream: bool = info.get("isStream", False)
|
||||
self.is_seekable: bool = info.get("isSeekable", False)
|
||||
self.position: int = info.get("position", 0)
|
||||
|
||||
def __eq__(self, other):
|
||||
self.ctx: Optional[commands.Context] = ctx
|
||||
self.requester: Optional[Union[Member, User, ClientUser]] = requester
|
||||
if not self.requester and self.ctx:
|
||||
self.requester = self.ctx.author
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Track):
|
||||
return False
|
||||
|
||||
|
|
@ -97,10 +104,10 @@ class Track:
|
|||
|
||||
return other.track_id == self.track_id
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.title
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.track title={self.title!r} uri=<{self.uri!r}> length={self.length}>"
|
||||
|
||||
|
||||
|
|
@ -110,6 +117,17 @@ class Playlist:
|
|||
You can also pass in commands.Context to get a discord.py Context object in your tracks.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"playlist_info",
|
||||
"tracks",
|
||||
"name",
|
||||
"playlist_type",
|
||||
"_thumbnail",
|
||||
"_uri",
|
||||
"selected_track",
|
||||
"track_count",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -119,39 +137,28 @@ class Playlist:
|
|||
thumbnail: Optional[str] = None,
|
||||
uri: Optional[str] = None,
|
||||
):
|
||||
__slots__ = (
|
||||
"playlist_info",
|
||||
"tracks",
|
||||
"name",
|
||||
"playlist_type",
|
||||
"_thumbnail",
|
||||
"_uri",
|
||||
"selected_track",
|
||||
"track_count",
|
||||
)
|
||||
|
||||
self.playlist_info: dict = playlist_info
|
||||
self.tracks: List[Track] = tracks
|
||||
self.name: str = playlist_info.get("name")
|
||||
self.name: str = playlist_info.get("name", "Unknown Playlist")
|
||||
self.playlist_type: PlaylistType = playlist_type
|
||||
|
||||
self._thumbnail: str = thumbnail
|
||||
self._uri: str = uri
|
||||
self._thumbnail: Optional[str] = thumbnail
|
||||
self._uri: Optional[str] = uri
|
||||
|
||||
for track in self.tracks:
|
||||
track.playlist = self
|
||||
|
||||
if (index := playlist_info.get("selectedTrack")) == -1:
|
||||
self.selected_track = None
|
||||
else:
|
||||
self.selected_track: Track = self.tracks[index]
|
||||
self.selected_track: Optional[Track] = None
|
||||
if (index := playlist_info.get("selectedTrack", -1)) != -1:
|
||||
self.selected_track = self.tracks[index]
|
||||
|
||||
self.track_count: int = len(self.tracks)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.playlist name={self.name!r} track_count={len(self.tracks)}>"
|
||||
|
||||
@property
|
||||
|
|
|
|||
230
pomice/player.py
230
pomice/player.py
|
|
@ -1,54 +1,67 @@
|
|||
import time
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from discord import Client, Guild, VoiceChannel, VoiceProtocol
|
||||
from discord import Client
|
||||
from discord import Guild
|
||||
from discord import VoiceChannel
|
||||
from discord import VoiceProtocol
|
||||
from discord.ext import commands
|
||||
from discord.types.voice import VoiceServerUpdate, GuildVoiceState
|
||||
|
||||
from . import events
|
||||
from .enums import SearchType
|
||||
from .events import PomiceEvent, TrackEndEvent, TrackStartEvent
|
||||
from .exceptions import (
|
||||
FilterInvalidArgument,
|
||||
FilterTagAlreadyInUse,
|
||||
FilterTagInvalid,
|
||||
TrackInvalidPosition,
|
||||
TrackLoadError,
|
||||
)
|
||||
from .events import PomiceEvent
|
||||
from .events import TrackEndEvent
|
||||
from .events import TrackStartEvent
|
||||
from .exceptions import FilterInvalidArgument
|
||||
from .exceptions import FilterTagAlreadyInUse
|
||||
from .exceptions import FilterTagInvalid
|
||||
from .exceptions import TrackInvalidPosition
|
||||
from .exceptions import TrackLoadError
|
||||
from .filters import Filter
|
||||
from .objects import Track
|
||||
from .pool import Node, NodePool
|
||||
from .objects import Track, Playlist
|
||||
from .pool import Node
|
||||
from .pool import NodePool
|
||||
|
||||
__all__ = ("Filters", "Player")
|
||||
|
||||
|
||||
class Filters:
|
||||
"""Helper class for filters"""
|
||||
|
||||
__slots__ = "_filters"
|
||||
__slots__ = ("_filters",)
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._filters: List[Filter] = []
|
||||
|
||||
@property
|
||||
def has_preload(self):
|
||||
def has_preload(self) -> bool:
|
||||
"""Property which checks if any applied filters were preloaded"""
|
||||
return any(f for f in self._filters if f.preload == True)
|
||||
|
||||
@property
|
||||
def has_global(self):
|
||||
def has_global(self) -> bool:
|
||||
"""Property which checks if any applied filters are global"""
|
||||
return any(f for f in self._filters if f.preload == False)
|
||||
|
||||
@property
|
||||
def empty(self):
|
||||
def empty(self) -> bool:
|
||||
"""Property which checks if the filter list is empty"""
|
||||
return len(self._filters) == 0
|
||||
|
||||
def add_filter(self, *, filter: Filter):
|
||||
def add_filter(self, *, filter: Filter) -> None:
|
||||
"""Adds a filter to the list of filters applied"""
|
||||
if any(f for f in self._filters if f.tag == filter.tag):
|
||||
raise FilterTagAlreadyInUse("A filter with that tag is already in use.")
|
||||
raise FilterTagAlreadyInUse(
|
||||
"A filter with that tag is already in use.",
|
||||
)
|
||||
self._filters.append(filter)
|
||||
|
||||
def remove_filter(self, *, filter_tag: str):
|
||||
def remove_filter(self, *, filter_tag: str) -> None:
|
||||
"""Removes a filter from the list of filters applied using its filter tag"""
|
||||
if not any(f for f in self._filters if f.tag == filter_tag):
|
||||
raise FilterTagInvalid("A filter with that tag was not found.")
|
||||
|
|
@ -57,26 +70,27 @@ class Filters:
|
|||
if filter.tag == filter_tag:
|
||||
del self._filters[index]
|
||||
|
||||
def has_filter(self, *, filter_tag: str):
|
||||
def has_filter(self, *, filter_tag: str) -> bool:
|
||||
"""Checks if a filter exists in the list of filters using its filter tag"""
|
||||
return any(f for f in self._filters if f.tag == filter_tag)
|
||||
|
||||
def reset_filters(self):
|
||||
def reset_filters(self) -> None:
|
||||
"""Removes all filters from the list"""
|
||||
self._filters = []
|
||||
|
||||
def get_preload_filters(self):
|
||||
def get_preload_filters(self) -> List[Filter]:
|
||||
"""Get all preloaded filters"""
|
||||
return [f for f in self._filters if f.preload == True]
|
||||
|
||||
def get_all_payloads(self):
|
||||
def get_all_payloads(self) -> Dict[str, Any]:
|
||||
"""Returns a formatted dict of all the filter payloads"""
|
||||
payload = {}
|
||||
for filter in self._filters:
|
||||
payload.update(filter.payload)
|
||||
payload: Dict[str, Any] = {}
|
||||
for _filter in self._filters:
|
||||
if _filter.payload:
|
||||
payload.update(_filter.payload)
|
||||
return payload
|
||||
|
||||
def get_filters(self):
|
||||
def get_filters(self) -> List[Filter]:
|
||||
"""Returns the current list of applied filters"""
|
||||
return self._filters
|
||||
|
||||
|
|
@ -89,45 +103,43 @@ class Player(VoiceProtocol):
|
|||
```
|
||||
"""
|
||||
|
||||
def __call__(self, client: Client, channel: VoiceChannel):
|
||||
self.client: Client = client
|
||||
self.channel: VoiceChannel = channel
|
||||
self._guild: Guild = channel.guild
|
||||
__slots__ = (
|
||||
"client",
|
||||
"channel",
|
||||
"_bot",
|
||||
"_guild",
|
||||
"_node",
|
||||
"_current",
|
||||
"_filters",
|
||||
"_volume",
|
||||
"_paused",
|
||||
"_is_connected",
|
||||
"_position",
|
||||
"_last_position",
|
||||
"_last_update",
|
||||
"_ending_track",
|
||||
"_voice_state",
|
||||
"_player_endpoint_uri",
|
||||
"__dict__",
|
||||
)
|
||||
|
||||
def __call__(self, client: Client, channel: VoiceChannel) -> "Player":
|
||||
self.__init__(client, channel) # type: ignore
|
||||
return self
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Optional[Client] = None,
|
||||
channel: Optional[VoiceChannel] = None,
|
||||
client: Client,
|
||||
channel: VoiceChannel,
|
||||
*,
|
||||
node: Node = None,
|
||||
):
|
||||
__slots__ = (
|
||||
"client",
|
||||
"channel",
|
||||
"_bot",
|
||||
"_guild",
|
||||
"_node",
|
||||
"_current",
|
||||
"_filters",
|
||||
"_volume",
|
||||
"_paused",
|
||||
"_is_connected",
|
||||
"_position",
|
||||
"_last_position",
|
||||
"_last_update",
|
||||
"_ending_track",
|
||||
"_voice_state",
|
||||
"_player_endpoint_uri",
|
||||
"__dict__",
|
||||
)
|
||||
node: Optional[Node] = None,
|
||||
) -> None:
|
||||
|
||||
self.client: Optional[Client] = client
|
||||
self.channel: Optional[VoiceChannel] = channel
|
||||
self.client: Client = client
|
||||
self.channel: VoiceChannel = channel
|
||||
|
||||
self._bot: Union[Client, commands.Bot] = client
|
||||
self._guild: Guild = channel.guild if channel else None
|
||||
self._bot: Client = client
|
||||
self._guild: Guild = channel.guild
|
||||
self._node: Node = node if node else NodePool.get_node()
|
||||
self._current: Optional[Track] = None
|
||||
self._filters: Filters = Filters()
|
||||
|
|
@ -137,7 +149,7 @@ class Player(VoiceProtocol):
|
|||
|
||||
self._position: int = 0
|
||||
self._last_position: int = 0
|
||||
self._last_update: int = 0
|
||||
self._last_update: float = 0
|
||||
self._ending_track: Optional[Track] = None
|
||||
|
||||
self._voice_state: dict = {}
|
||||
|
|
@ -153,11 +165,13 @@ class Player(VoiceProtocol):
|
|||
@property
|
||||
def position(self) -> float:
|
||||
"""Property which returns the player's position in a track in milliseconds"""
|
||||
current = self._current.original
|
||||
|
||||
if not self.is_playing or not self._current:
|
||||
return 0
|
||||
|
||||
current = self._current.original
|
||||
if not current:
|
||||
return 0
|
||||
|
||||
if self.is_paused:
|
||||
return min(self._last_position, current.length)
|
||||
|
||||
|
|
@ -185,7 +199,7 @@ class Player(VoiceProtocol):
|
|||
return self._is_connected and self._paused
|
||||
|
||||
@property
|
||||
def current(self) -> Track:
|
||||
def current(self) -> Optional[Track]:
|
||||
"""Property which returns the currently playing track"""
|
||||
return self._current
|
||||
|
||||
|
|
@ -210,7 +224,7 @@ class Player(VoiceProtocol):
|
|||
return self._filters
|
||||
|
||||
@property
|
||||
def bot(self) -> Union[Client, commands.Bot]:
|
||||
def bot(self) -> Client:
|
||||
"""Property which returns the bot associated with this player instance"""
|
||||
return self._bot
|
||||
|
||||
|
|
@ -221,13 +235,14 @@ class Player(VoiceProtocol):
|
|||
"""
|
||||
return self.guild.id not in self._node._players
|
||||
|
||||
async def _update_state(self, data: dict):
|
||||
state: dict = data.get("state")
|
||||
self._last_update = time.time() * 1000
|
||||
self._is_connected = state.get("connected")
|
||||
self._last_position = state.get("position")
|
||||
async def _update_state(self, data: dict) -> None:
|
||||
state: dict = data.get("state", {})
|
||||
self._last_update = time.time() * 1000.0
|
||||
self._is_connected = bool(state.get("connected"))
|
||||
position = state.get("position")
|
||||
self._position = int(position) if position else 0
|
||||
|
||||
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None):
|
||||
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None:
|
||||
if {"sessionId", "event"} != self._voice_state.keys():
|
||||
return
|
||||
|
||||
|
|
@ -246,27 +261,32 @@ class Player(VoiceProtocol):
|
|||
data={"voice": data},
|
||||
)
|
||||
|
||||
async def on_voice_server_update(self, data: dict):
|
||||
async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
|
||||
self._voice_state.update({"event": data})
|
||||
await self._dispatch_voice_update(self._voice_state)
|
||||
|
||||
async def on_voice_state_update(self, data: dict):
|
||||
async def on_voice_state_update(self, data: GuildVoiceState) -> None:
|
||||
self._voice_state.update({"sessionId": data.get("session_id")})
|
||||
|
||||
if not (channel_id := data.get("channel_id")):
|
||||
channel_id = data.get("channel_id")
|
||||
if not channel_id:
|
||||
await self.disconnect()
|
||||
self._voice_state.clear()
|
||||
return
|
||||
|
||||
self.channel = self.guild.get_channel(int(channel_id))
|
||||
channel = self.guild.get_channel(int(channel_id))
|
||||
if not channel:
|
||||
await self.disconnect()
|
||||
self._voice_state.clear()
|
||||
return
|
||||
|
||||
if not data.get("token"):
|
||||
return
|
||||
|
||||
await self._dispatch_voice_update({**self._voice_state, "event": data})
|
||||
|
||||
async def _dispatch_event(self, data: dict):
|
||||
event_type = data.get("type")
|
||||
async def _dispatch_event(self, data: dict) -> None:
|
||||
event_type: str = data["type"]
|
||||
event: PomiceEvent = getattr(events, event_type)(data, self)
|
||||
|
||||
if isinstance(event, TrackEndEvent) and event.reason != "REPLACED":
|
||||
|
|
@ -277,11 +297,12 @@ class Player(VoiceProtocol):
|
|||
if isinstance(event, TrackStartEvent):
|
||||
self._ending_track = self._current
|
||||
|
||||
async def _swap_node(self, *, new_node: Node):
|
||||
async def _swap_node(self, *, new_node: Node) -> None:
|
||||
data: dict = {
|
||||
"encodedTrack": self.current.track_id,
|
||||
"position": self.position,
|
||||
}
|
||||
if self.current:
|
||||
data["encodedTrack"] = self.current.track_id
|
||||
|
||||
del self._node._players[self._guild.id]
|
||||
self._node = new_node
|
||||
|
|
@ -304,7 +325,7 @@ class Player(VoiceProtocol):
|
|||
ctx: Optional[commands.Context] = None,
|
||||
search_type: SearchType = SearchType.ytsearch,
|
||||
filters: Optional[List[Filter]] = None,
|
||||
):
|
||||
) -> Optional[Union[List[Track], Playlist]]:
|
||||
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
||||
|
||||
If you passed in Spotify API credentials when you created the node,
|
||||
|
|
@ -321,7 +342,7 @@ class Player(VoiceProtocol):
|
|||
|
||||
async def get_recommendations(
|
||||
self, *, track: Track, ctx: Optional[commands.Context] = None
|
||||
) -> Union[List[Track], None]:
|
||||
) -> Optional[Union[List[Track], Playlist]]:
|
||||
"""
|
||||
Gets recommendations from either YouTube or Spotify.
|
||||
You can pass in a discord.py Context object to get a
|
||||
|
|
@ -331,14 +352,14 @@ class Player(VoiceProtocol):
|
|||
|
||||
async def connect(
|
||||
self, *, timeout: float, reconnect: bool, self_deaf: bool = False, self_mute: bool = False
|
||||
):
|
||||
) -> None:
|
||||
await self.guild.change_voice_state(
|
||||
channel=self.channel, self_deaf=self_deaf, self_mute=self_mute
|
||||
channel=self.channel, self_deaf=self_deaf, self_mute=self_mute,
|
||||
)
|
||||
self._node._players[self.guild.id] = self
|
||||
self._is_connected = True
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self) -> None:
|
||||
"""Stops the currently playing track."""
|
||||
self._current = None
|
||||
await self._node.send(
|
||||
|
|
@ -348,27 +369,27 @@ class Player(VoiceProtocol):
|
|||
data={"encodedTrack": None},
|
||||
)
|
||||
|
||||
async def disconnect(self, *, force: bool = False):
|
||||
async def disconnect(self, *, force: bool = False) -> None:
|
||||
"""Disconnects the player from voice."""
|
||||
try:
|
||||
await self.guild.change_voice_state(channel=None)
|
||||
finally:
|
||||
self.cleanup()
|
||||
self._is_connected = False
|
||||
self.channel = None
|
||||
del self.channel
|
||||
|
||||
async def destroy(self):
|
||||
async def destroy(self) -> None:
|
||||
"""Disconnects and destroys the player, and runs internal cleanup."""
|
||||
try:
|
||||
await self.disconnect()
|
||||
except AttributeError:
|
||||
# 'NoneType' has no attribute '_get_voice_client_key' raised by self.cleanup() ->
|
||||
# assume we're already disconnected and cleaned up
|
||||
assert self.channel is None and not self.is_connected
|
||||
assert not self.is_connected and not self.channel
|
||||
|
||||
self._node._players.pop(self.guild.id)
|
||||
await self._node.send(
|
||||
method="DELETE", path=self._player_endpoint_uri, guild_id=self._guild.id
|
||||
method="DELETE", path=self._player_endpoint_uri, guild_id=self._guild.id,
|
||||
)
|
||||
|
||||
async def play(
|
||||
|
|
@ -383,20 +404,22 @@ class Player(VoiceProtocol):
|
|||
if not track.isrc:
|
||||
# We have to bare raise here because theres no other way to skip this block feasibly
|
||||
raise
|
||||
search: Track = (
|
||||
search = (
|
||||
await self._node.get_tracks(f"{track._search_type}:{track.isrc}", ctx=track.ctx)
|
||||
)[0]
|
||||
)[0] # type: ignore
|
||||
except Exception:
|
||||
# First method didn't work, lets try just searching it up
|
||||
try:
|
||||
search: Track = (
|
||||
search = (
|
||||
await self._node.get_tracks(
|
||||
f"{track._search_type}:{track.title} - {track.author}", ctx=track.ctx
|
||||
f"{track._search_type}:{track.title} - {track.author}", ctx=track.ctx,
|
||||
)
|
||||
)[0]
|
||||
)[0] # type: ignore
|
||||
except:
|
||||
# The song wasn't able to be found, raise error
|
||||
raise TrackLoadError("No equivalent track was able to be found.")
|
||||
raise TrackLoadError(
|
||||
"No equivalent track was able to be found.",
|
||||
)
|
||||
data = {
|
||||
"encodedTrack": search.track_id,
|
||||
"position": str(start),
|
||||
|
|
@ -432,7 +455,7 @@ class Player(VoiceProtocol):
|
|||
if track.filters and not self.filters.has_global:
|
||||
# Now apply all filters
|
||||
for filter in track.filters:
|
||||
await self.add_filter(filter=filter)
|
||||
await self.add_filter(_filter=filter)
|
||||
|
||||
# Lavalink v4 changed the way the end time parameter works
|
||||
# so now the end time cannot be zero.
|
||||
|
|
@ -454,8 +477,13 @@ class Player(VoiceProtocol):
|
|||
|
||||
async def seek(self, position: float) -> float:
|
||||
"""Seeks to a position in the currently playing track milliseconds"""
|
||||
if not self._current or not self._current.original:
|
||||
return 0.0
|
||||
|
||||
if position < 0 or position > self._current.original.length:
|
||||
raise TrackInvalidPosition("Seek position must be between 0 and the track length")
|
||||
raise TrackInvalidPosition(
|
||||
"Seek position must be between 0 and the track length",
|
||||
)
|
||||
|
||||
await self._node.send(
|
||||
method="PATCH",
|
||||
|
|
@ -487,7 +515,7 @@ class Player(VoiceProtocol):
|
|||
self._volume = volume
|
||||
return self._volume
|
||||
|
||||
async def add_filter(self, filter: Filter, fast_apply: bool = False) -> Filter:
|
||||
async def add_filter(self, _filter: Filter, fast_apply: bool = False) -> Filters:
|
||||
"""Adds a filter to the player. Takes a pomice.Filter object.
|
||||
This will only work if you are using a version of Lavalink that supports filters.
|
||||
If you would like for the filter to apply instantly, set the `fast_apply` arg to `True`.
|
||||
|
|
@ -495,7 +523,7 @@ class Player(VoiceProtocol):
|
|||
(You must have a song playing in order for `fast_apply` to work.)
|
||||
"""
|
||||
|
||||
self._filters.add_filter(filter=filter)
|
||||
self._filters.add_filter(filter=_filter)
|
||||
payload = self._filters.get_all_payloads()
|
||||
await self._node.send(
|
||||
method="PATCH",
|
||||
|
|
@ -508,7 +536,7 @@ class Player(VoiceProtocol):
|
|||
|
||||
return self._filters
|
||||
|
||||
async def remove_filter(self, filter_tag: str, fast_apply: bool = False) -> Filter:
|
||||
async def remove_filter(self, filter_tag: str, fast_apply: bool = False) -> Filters:
|
||||
"""Removes a filter from the player. Takes a filter tag.
|
||||
This will only work if you are using a version of Lavalink that supports filters.
|
||||
If you would like for the filter to apply instantly, set the `fast_apply` arg to `True`.
|
||||
|
|
@ -529,7 +557,7 @@ class Player(VoiceProtocol):
|
|||
|
||||
return self._filters
|
||||
|
||||
async def reset_filters(self, *, fast_apply: bool = False):
|
||||
async def reset_filters(self, *, fast_apply: bool = False) -> None:
|
||||
"""Resets all currently applied filters to their default parameters.
|
||||
You must have filters applied in order for this to work.
|
||||
If you would like the filters to be removed instantly, set the `fast_apply` arg to `True`.
|
||||
|
|
@ -539,7 +567,7 @@ class Player(VoiceProtocol):
|
|||
|
||||
if not self._filters:
|
||||
raise FilterInvalidArgument(
|
||||
"You must have filters applied first in order to use this method."
|
||||
"You must have filters applied first in order to use this method.",
|
||||
)
|
||||
self._filters.reset_filters()
|
||||
await self._node.send(
|
||||
|
|
|
|||
316
pomice/pool.py
316
pomice/pool.py
|
|
@ -3,35 +3,48 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import random
|
||||
import re
|
||||
import aiohttp
|
||||
|
||||
from discord import Client
|
||||
from discord.ext import commands
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
from urllib.parse import quote
|
||||
|
||||
from . import __version__, spotify, applemusic
|
||||
import aiohttp
|
||||
from discord import Client
|
||||
from discord.ext import commands
|
||||
|
||||
from . import __version__
|
||||
from . import applemusic
|
||||
from . import spotify
|
||||
from .enums import *
|
||||
from .exceptions import (
|
||||
AppleMusicNotEnabled,
|
||||
InvalidSpotifyClientAuthorization,
|
||||
LavalinkVersionIncompatible,
|
||||
NodeConnectionFailure,
|
||||
NodeCreationError,
|
||||
NodeNotAvailable,
|
||||
NoNodesAvailable,
|
||||
NodeRestException,
|
||||
TrackLoadError,
|
||||
)
|
||||
from .exceptions import AppleMusicNotEnabled
|
||||
from .exceptions import InvalidSpotifyClientAuthorization
|
||||
from .exceptions import LavalinkVersionIncompatible
|
||||
from .exceptions import NodeConnectionFailure
|
||||
from .exceptions import NodeCreationError
|
||||
from .exceptions import NodeNotAvailable
|
||||
from .exceptions import NodeRestException
|
||||
from .exceptions import NoNodesAvailable
|
||||
from .exceptions import TrackLoadError
|
||||
from .filters import Filter
|
||||
from .objects import Playlist, Track
|
||||
from .utils import ExponentialBackoff, NodeStats, Ping
|
||||
from .objects import Playlist
|
||||
from .objects import Track
|
||||
from .routeplanner import RoutePlanner
|
||||
from .utils import ExponentialBackoff
|
||||
from .utils import NodeStats
|
||||
from .utils import Ping
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .player import Player
|
||||
|
||||
__all__ = (
|
||||
"Node",
|
||||
"NodePool",
|
||||
)
|
||||
|
||||
|
||||
class Node:
|
||||
"""The base class for a node.
|
||||
|
|
@ -40,11 +53,42 @@ class Node:
|
|||
To enable Apple music, set the "apple_music" parameter to "True"
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"_bot",
|
||||
"_bot_user",
|
||||
"_host",
|
||||
"_port",
|
||||
"_pool",
|
||||
"_password",
|
||||
"_identifier",
|
||||
"_heartbeat",
|
||||
"_secure",
|
||||
"_fallback",
|
||||
"_websocket_uri",
|
||||
"_rest_uri",
|
||||
"_session",
|
||||
"_websocket",
|
||||
"_task",
|
||||
"_loop",
|
||||
"_session_id",
|
||||
"_available",
|
||||
"_version",
|
||||
"_headers",
|
||||
"_players",
|
||||
"_spotify_client_id",
|
||||
"_spotify_client_secret",
|
||||
"_spotify_client",
|
||||
"_apple_music_client",
|
||||
"_route_planner",
|
||||
"_stats",
|
||||
"available",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pool: NodePool,
|
||||
bot: Union[Client, commands.Bot],
|
||||
pool: Type[NodePool],
|
||||
bot: commands.Bot,
|
||||
host: str,
|
||||
port: int,
|
||||
password: str,
|
||||
|
|
@ -53,42 +97,16 @@ class Node:
|
|||
heartbeat: int = 30,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
session: Optional[aiohttp.ClientSession] = None,
|
||||
spotify_client_id: Optional[str] = None,
|
||||
spotify_client_id: Optional[int] = None,
|
||||
spotify_client_secret: Optional[str] = None,
|
||||
apple_music: bool = False,
|
||||
fallback: bool = False,
|
||||
):
|
||||
__slots__ = (
|
||||
"_bot",
|
||||
"_host",
|
||||
"_port",
|
||||
"_pool",
|
||||
"_password",
|
||||
"_identifier",
|
||||
"_heartbeat",
|
||||
"_secure",
|
||||
"_fallback",
|
||||
"_websocket_uri",
|
||||
"_rest_uri",
|
||||
"_session",
|
||||
"_websocket",
|
||||
"_task",
|
||||
"_loop",
|
||||
"_session_id",
|
||||
"_available",
|
||||
"_version",
|
||||
"_headers",
|
||||
"_players",
|
||||
"_spotify_client_id",
|
||||
"_spotify_client_secret",
|
||||
"_spotify_client",
|
||||
"_apple_music_client",
|
||||
)
|
||||
|
||||
self._bot: Union[Client, commands.Bot] = bot
|
||||
self._bot: commands.Bot = bot
|
||||
self._host: str = host
|
||||
self._port: int = port
|
||||
self._pool: NodePool = pool
|
||||
self._pool: Type[NodePool] = pool
|
||||
self._password: str = password
|
||||
self._identifier: str = identifier
|
||||
self._heartbeat: int = heartbeat
|
||||
|
|
@ -98,33 +116,38 @@ class Node:
|
|||
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
|
||||
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
|
||||
|
||||
self._session: Optional[aiohttp.ClientSession] = session
|
||||
self._websocket = None
|
||||
self._task: asyncio.Task = None
|
||||
self._session: aiohttp.ClientSession = session # type: ignore
|
||||
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
|
||||
self._websocket: aiohttp.ClientWebSocketResponse
|
||||
self._task: asyncio.Task
|
||||
|
||||
self._session_id: str = None
|
||||
self._session_id: Optional[str] = None
|
||||
self._available: bool = False
|
||||
self._version: str = None
|
||||
self._version: int
|
||||
|
||||
self._route_planner = RoutePlanner(self)
|
||||
|
||||
if not self._bot.user:
|
||||
raise NodeCreationError("Bot user is not ready yet.")
|
||||
|
||||
self._bot_user = self._bot.user
|
||||
|
||||
self._headers = {
|
||||
"Authorization": self._password,
|
||||
"User-Id": str(self._bot.user.id),
|
||||
"User-Id": str(self._bot_user.id),
|
||||
"Client-Name": f"Pomice/{__version__}",
|
||||
}
|
||||
|
||||
self._players: Dict[int, Player] = {}
|
||||
|
||||
self._spotify_client_id: str = spotify_client_id
|
||||
self._spotify_client_secret: str = spotify_client_secret
|
||||
self._spotify_client_id: Optional[int] = spotify_client_id
|
||||
self._spotify_client_secret: Optional[str] = spotify_client_secret
|
||||
|
||||
self._apple_music_client: Optional[applemusic.Client] = None
|
||||
|
||||
if self._spotify_client_id and self._spotify_client_secret:
|
||||
self._spotify_client: spotify.Client = spotify.Client(
|
||||
self._spotify_client_id, self._spotify_client_secret
|
||||
self._spotify_client_id, self._spotify_client_secret,
|
||||
)
|
||||
|
||||
if apple_music:
|
||||
|
|
@ -132,7 +155,7 @@ class Node:
|
|||
|
||||
self._bot.add_listener(self._update_handler, "on_socket_response")
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Pomice.node ws_uri={self._websocket_uri} rest_uri={self._rest_uri} "
|
||||
f"player_count={len(self._players)}>"
|
||||
|
|
@ -154,7 +177,7 @@ class Node:
|
|||
return self._players
|
||||
|
||||
@property
|
||||
def bot(self) -> Union[Client, commands.Bot]:
|
||||
def bot(self) -> Client:
|
||||
"""Property which returns the discord.py client linked to this node"""
|
||||
return self._bot
|
||||
|
||||
|
|
@ -164,21 +187,21 @@ class Node:
|
|||
return len(self.players)
|
||||
|
||||
@property
|
||||
def pool(self):
|
||||
def pool(self) -> Type[NodePool]:
|
||||
"""Property which returns the pool this node is apart of"""
|
||||
return self._pool
|
||||
|
||||
@property
|
||||
def latency(self):
|
||||
def latency(self) -> float:
|
||||
"""Property which returns the latency of the node"""
|
||||
return Ping(self._host, port=self._port).get_ping()
|
||||
|
||||
@property
|
||||
def ping(self):
|
||||
def ping(self) -> float:
|
||||
"""Alias for `Node.latency`, returns the latency of the node"""
|
||||
return self.latency
|
||||
|
||||
async def _update_handler(self, data: dict):
|
||||
async def _update_handler(self, data: dict) -> None:
|
||||
await self._bot.wait_until_ready()
|
||||
|
||||
if not data:
|
||||
|
|
@ -193,7 +216,7 @@ class Node:
|
|||
return
|
||||
|
||||
elif data["t"] == "VOICE_STATE_UPDATE":
|
||||
if int(data["d"]["user_id"]) != self._bot.user.id:
|
||||
if int(data["d"]["user_id"]) != self._bot_user.id:
|
||||
return
|
||||
|
||||
guild_id = int(data["d"]["guild_id"])
|
||||
|
|
@ -203,8 +226,11 @@ class Node:
|
|||
except KeyError:
|
||||
return
|
||||
|
||||
async def _handle_node_switch(self):
|
||||
nodes = [node for node in self.pool.nodes.copy().values() if node.is_connected]
|
||||
async def _handle_node_switch(self) -> None:
|
||||
nodes = [
|
||||
node for node in self.pool._nodes.copy().values()
|
||||
if node.is_connected
|
||||
]
|
||||
new_node = random.choice(nodes)
|
||||
|
||||
for player in self.players.copy().values():
|
||||
|
|
@ -212,7 +238,7 @@ class Node:
|
|||
|
||||
await self.disconnect()
|
||||
|
||||
async def _listen(self):
|
||||
async def _listen(self) -> None:
|
||||
backoff = ExponentialBackoff(base=7)
|
||||
|
||||
while True:
|
||||
|
|
@ -227,7 +253,7 @@ class Node:
|
|||
else:
|
||||
self._loop.create_task(self._handle_payload(msg.json()))
|
||||
|
||||
async def _handle_payload(self, data: dict):
|
||||
async def _handle_payload(self, data: dict) -> None:
|
||||
op = data.get("op", None)
|
||||
if not op:
|
||||
return
|
||||
|
|
@ -239,14 +265,20 @@ class Node:
|
|||
if op == "ready":
|
||||
self._session_id = data["sessionId"]
|
||||
|
||||
if "guildId" in data:
|
||||
if not (player := self._players.get(int(data["guildId"]))):
|
||||
return
|
||||
if not "guildId" in data:
|
||||
return
|
||||
|
||||
player = self._players.get(int(data["guildId"]))
|
||||
if not player:
|
||||
return
|
||||
|
||||
if op == "event":
|
||||
await player._dispatch_event(data)
|
||||
elif op == "playerUpdate":
|
||||
return
|
||||
|
||||
if op == "playerUpdate":
|
||||
await player._update_state(data)
|
||||
return
|
||||
|
||||
async def send(
|
||||
self,
|
||||
|
|
@ -255,11 +287,13 @@ class Node:
|
|||
include_version: bool = True,
|
||||
guild_id: Optional[Union[int, str]] = None,
|
||||
query: Optional[str] = None,
|
||||
data: Optional[Union[dict, str]] = None,
|
||||
data: Optional[Union[Dict, str]] = None,
|
||||
ignore_if_available: bool = False,
|
||||
):
|
||||
) -> Any:
|
||||
if not ignore_if_available and not self._available:
|
||||
raise NodeNotAvailable(f"The node '{self._identifier}' is unavailable.")
|
||||
raise NodeNotAvailable(
|
||||
f"The node '{self._identifier}' is unavailable.",
|
||||
)
|
||||
|
||||
uri: str = (
|
||||
f"{self._rest_uri}/"
|
||||
|
|
@ -270,12 +304,12 @@ class Node:
|
|||
)
|
||||
|
||||
async with self._session.request(
|
||||
method=method, url=uri, headers=self._headers, json=data or {}
|
||||
method=method, url=uri, headers=self._headers, json=data or {},
|
||||
) as resp:
|
||||
if resp.status >= 300:
|
||||
data: dict = await resp.json()
|
||||
resp_data: dict = await resp.json()
|
||||
raise NodeRestException(
|
||||
f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {data["message"]}'
|
||||
f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {resp_data["message"]}',
|
||||
)
|
||||
|
||||
if method == "DELETE" or resp.status == 204:
|
||||
|
|
@ -286,11 +320,11 @@ class Node:
|
|||
|
||||
return await resp.json()
|
||||
|
||||
def get_player(self, guild_id: int):
|
||||
"""Takes a guild ID as a parameter. Returns a pomice Player object."""
|
||||
def get_player(self, guild_id: int) -> Optional[Player]:
|
||||
"""Takes a guild ID as a parameter. Returns a pomice Player object or None."""
|
||||
return self._players.get(guild_id, None)
|
||||
|
||||
async def connect(self):
|
||||
async def connect(self) -> "Node":
|
||||
"""Initiates a connection with a Lavalink node and adds it to the node pool."""
|
||||
await self._bot.wait_until_ready()
|
||||
|
||||
|
|
@ -298,7 +332,7 @@ class Node:
|
|||
self._session = aiohttp.ClientSession()
|
||||
|
||||
try:
|
||||
version = await self.send(
|
||||
version: str = await self.send(
|
||||
method="GET",
|
||||
path="version",
|
||||
ignore_if_available=True,
|
||||
|
|
@ -309,14 +343,14 @@ class Node:
|
|||
self._available = False
|
||||
raise LavalinkVersionIncompatible(
|
||||
"The Lavalink version you're using is incompatible. "
|
||||
"Lavalink version 3.7.0 or above is required to use this library."
|
||||
"Lavalink version 3.7.0 or above is required to use this library.",
|
||||
)
|
||||
|
||||
if version.endswith("-SNAPSHOT"):
|
||||
# we're just gonna assume all snapshot versions correlate with v4
|
||||
self._version = 4
|
||||
else:
|
||||
self._version = version[:1]
|
||||
self._version = int(version[:1])
|
||||
|
||||
self._websocket = await self._session.ws_connect(
|
||||
f"{self._websocket_uri}/v{self._version}/websocket",
|
||||
|
|
@ -332,18 +366,18 @@ class Node:
|
|||
|
||||
except (aiohttp.ClientConnectorError, ConnectionRefusedError):
|
||||
raise NodeConnectionFailure(
|
||||
f"The connection to node '{self._identifier}' failed."
|
||||
f"The connection to node '{self._identifier}' failed.",
|
||||
) from None
|
||||
except aiohttp.WSServerHandshakeError:
|
||||
raise NodeConnectionFailure(
|
||||
f"The password for node '{self._identifier}' is invalid."
|
||||
f"The password for node '{self._identifier}' is invalid.",
|
||||
) from None
|
||||
except aiohttp.InvalidURL:
|
||||
raise NodeConnectionFailure(
|
||||
f"The URI for node '{self._identifier}' is invalid."
|
||||
f"The URI for node '{self._identifier}' is invalid.",
|
||||
) from None
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnects a connected Lavalink node and removes it from the node pool.
|
||||
This also destroys any players connected to the node.
|
||||
"""
|
||||
|
|
@ -371,7 +405,7 @@ class Node:
|
|||
"""
|
||||
|
||||
data: dict = await self.send(
|
||||
method="GET", path="decodetrack", query=f"encodedTrack={identifier}"
|
||||
method="GET", path="decodetrack", query=f"encodedTrack={identifier}",
|
||||
)
|
||||
return Track(
|
||||
track_id=identifier,
|
||||
|
|
@ -387,7 +421,7 @@ class Node:
|
|||
ctx: Optional[commands.Context] = None,
|
||||
search_type: SearchType = SearchType.ytsearch,
|
||||
filters: Optional[List[Filter]] = None,
|
||||
):
|
||||
) -> Optional[Union[Playlist, List[Track]]]:
|
||||
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
||||
|
||||
If you passed in Spotify API credentials, you can also pass in a
|
||||
|
|
@ -413,7 +447,7 @@ class Node:
|
|||
if not self._apple_music_client:
|
||||
raise AppleMusicNotEnabled(
|
||||
"You must have Apple Music functionality enabled in order to play Apple Music tracks."
|
||||
"Please set apple_music to True in your Node class."
|
||||
"Please set apple_music to True in your Node class.",
|
||||
)
|
||||
|
||||
apple_music_results = await self._apple_music_client.search(query=query)
|
||||
|
|
@ -437,7 +471,7 @@ class Node:
|
|||
"thumbnail": apple_music_results.image,
|
||||
"isrc": apple_music_results.isrc,
|
||||
},
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
tracks = [
|
||||
|
|
@ -464,7 +498,9 @@ class Node:
|
|||
]
|
||||
|
||||
return Playlist(
|
||||
playlist_info={"name": apple_music_results.name, "selectedTrack": 0},
|
||||
playlist_info={
|
||||
"name": apple_music_results.name, "selectedTrack": 0,
|
||||
},
|
||||
tracks=tracks,
|
||||
playlist_type=PlaylistType.APPLE_MUSIC,
|
||||
thumbnail=apple_music_results.image,
|
||||
|
|
@ -476,7 +512,7 @@ class Node:
|
|||
raise InvalidSpotifyClientAuthorization(
|
||||
"You did not provide proper Spotify client authorization credentials. "
|
||||
"If you would like to use the Spotify searching feature, "
|
||||
"please obtain Spotify API credentials here: https://developer.spotify.com/"
|
||||
"please obtain Spotify API credentials here: https://developer.spotify.com/",
|
||||
)
|
||||
|
||||
spotify_results = await self._spotify_client.search(query=query)
|
||||
|
|
@ -501,7 +537,7 @@ class Node:
|
|||
"thumbnail": spotify_results.image,
|
||||
"isrc": spotify_results.isrc,
|
||||
},
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
tracks = [
|
||||
|
|
@ -528,7 +564,9 @@ class Node:
|
|||
]
|
||||
|
||||
return Playlist(
|
||||
playlist_info={"name": spotify_results.name, "selectedTrack": 0},
|
||||
playlist_info={
|
||||
"name": spotify_results.name, "selectedTrack": 0,
|
||||
},
|
||||
tracks=tracks,
|
||||
playlist_type=PlaylistType.SPOTIFY,
|
||||
thumbnail=spotify_results.image,
|
||||
|
|
@ -537,11 +575,11 @@ class Node:
|
|||
|
||||
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
|
||||
data: dict = await self.send(
|
||||
method="GET", path="loadtracks", query=f"identifier={quote(query)}"
|
||||
method="GET", path="loadtracks", query=f"identifier={quote(query)}",
|
||||
)
|
||||
|
||||
track: dict = data["tracks"][0]
|
||||
info: dict = track.get("info")
|
||||
info: dict = track["info"]
|
||||
|
||||
return [
|
||||
Track(
|
||||
|
|
@ -549,15 +587,15 @@ class Node:
|
|||
info={
|
||||
"title": discord_url.group("file"),
|
||||
"author": "Unknown",
|
||||
"length": info.get("length"),
|
||||
"uri": info.get("uri"),
|
||||
"position": info.get("position"),
|
||||
"identifier": info.get("identifier"),
|
||||
"length": info["length"],
|
||||
"uri": info["uri"],
|
||||
"position": info["position"],
|
||||
"identifier": info["identifier"],
|
||||
},
|
||||
ctx=ctx,
|
||||
track_type=TrackType.HTTP,
|
||||
filters=filters,
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
else:
|
||||
|
|
@ -572,18 +610,22 @@ class Node:
|
|||
if match := URLRegex.YOUTUBE_VID_IN_PLAYLIST.match(query):
|
||||
query = match.group("video")
|
||||
|
||||
data: dict = await self.send(
|
||||
method="GET", path="loadtracks", query=f"identifier={quote(query)}"
|
||||
data = await self.send(
|
||||
method="GET", path="loadtracks", query=f"identifier={quote(query)}",
|
||||
)
|
||||
|
||||
load_type = data.get("loadType")
|
||||
|
||||
if not load_type:
|
||||
raise TrackLoadError("There was an error while trying to load this track.")
|
||||
raise TrackLoadError(
|
||||
"There was an error while trying to load this track.",
|
||||
)
|
||||
|
||||
elif load_type == "LOAD_FAILED":
|
||||
exception = data["exception"]
|
||||
raise TrackLoadError(f"{exception['message']} [{exception['severity']}]")
|
||||
raise TrackLoadError(
|
||||
f"{exception['message']} [{exception['severity']}]",
|
||||
)
|
||||
|
||||
elif load_type == "NO_MATCHES":
|
||||
return None
|
||||
|
|
@ -619,9 +661,14 @@ class Node:
|
|||
for track in data["tracks"]
|
||||
]
|
||||
|
||||
else:
|
||||
raise TrackLoadError(
|
||||
"There was an error while trying to load this track.",
|
||||
)
|
||||
|
||||
async def get_recommendations(
|
||||
self, *, track: Track, ctx: Optional[commands.Context] = None
|
||||
) -> Union[List[Track], None]:
|
||||
) -> Optional[Union[List[Track], Playlist]]:
|
||||
"""
|
||||
Gets recommendations from either YouTube or Spotify.
|
||||
The track that is passed in must be either from
|
||||
|
|
@ -652,17 +699,17 @@ class Node:
|
|||
)
|
||||
for track in results
|
||||
]
|
||||
|
||||
return tracks
|
||||
|
||||
elif track.track_type == TrackType.YOUTUBE:
|
||||
tracks = await self.get_tracks(
|
||||
return await self.get_tracks(
|
||||
query=f"ytsearch:https://www.youtube.com/watch?v={track.identifier}&list=RD{track.identifier}",
|
||||
ctx=ctx,
|
||||
)
|
||||
return tracks
|
||||
|
||||
else:
|
||||
raise TrackLoadError(
|
||||
"The specfied track must be either a YouTube or Spotify track to recieve recommendations."
|
||||
"The specfied track must be either a YouTube or Spotify track to recieve recommendations.",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -671,9 +718,10 @@ class NodePool:
|
|||
This holds all the nodes that are to be used by the bot.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
_nodes: Dict[str, Node] = {}
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.NodePool node_count={self.node_count}>"
|
||||
|
||||
@property
|
||||
|
|
@ -682,7 +730,7 @@ class NodePool:
|
|||
return self._nodes
|
||||
|
||||
@property
|
||||
def node_count(self):
|
||||
def node_count(self) -> int:
|
||||
return len(self._nodes.values())
|
||||
|
||||
@classmethod
|
||||
|
|
@ -700,21 +748,31 @@ class NodePool:
|
|||
based on how players it has. This method will return a node with
|
||||
the least amount of players
|
||||
"""
|
||||
available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available]
|
||||
available_nodes: List[Node] = [
|
||||
node for node in cls._nodes.values() if node._available
|
||||
]
|
||||
|
||||
if not available_nodes:
|
||||
raise NoNodesAvailable("There are no nodes available.")
|
||||
|
||||
if algorithm == NodeAlgorithm.by_ping:
|
||||
tested_nodes = {node: node.latency for node in available_nodes}
|
||||
return min(tested_nodes, key=tested_nodes.get)
|
||||
return min(tested_nodes, key=tested_nodes.get) # type: ignore
|
||||
|
||||
elif algorithm == NodeAlgorithm.by_players:
|
||||
tested_nodes = {node: len(node.players.keys()) for node in available_nodes}
|
||||
return min(tested_nodes, key=tested_nodes.get)
|
||||
tested_nodes = {
|
||||
node: len(node.players.keys())
|
||||
for node in available_nodes
|
||||
}
|
||||
return min(tested_nodes, key=tested_nodes.get) # type: ignore
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"The algorithm provided is not a valid NodeAlgorithm.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_node(cls, *, identifier: str = None) -> Node:
|
||||
def get_node(cls, *, identifier: Optional[str] = None) -> Node:
|
||||
"""Fetches a node from the node pool using it's identifier.
|
||||
If no identifier is provided, it will choose a node at random.
|
||||
"""
|
||||
|
|
@ -728,21 +786,21 @@ class NodePool:
|
|||
if identifier is None:
|
||||
return random.choice(list(available_nodes.values()))
|
||||
|
||||
return available_nodes.get(identifier, None)
|
||||
return available_nodes[identifier]
|
||||
|
||||
@classmethod
|
||||
async def create_node(
|
||||
cls,
|
||||
*,
|
||||
bot: Client,
|
||||
bot: commands.Bot,
|
||||
host: str,
|
||||
port: str,
|
||||
port: int,
|
||||
password: str,
|
||||
identifier: str,
|
||||
secure: bool = False,
|
||||
heartbeat: int = 30,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
spotify_client_id: Optional[str] = None,
|
||||
spotify_client_id: Optional[int] = None,
|
||||
spotify_client_secret: Optional[str] = None,
|
||||
session: Optional[aiohttp.ClientSession] = None,
|
||||
apple_music: bool = False,
|
||||
|
|
@ -752,7 +810,9 @@ class NodePool:
|
|||
For Spotify searching capabilites, pass in valid Spotify API credentials.
|
||||
"""
|
||||
if identifier in cls._nodes.keys():
|
||||
raise NodeCreationError(f"A node with identifier '{identifier}' already exists.")
|
||||
raise NodeCreationError(
|
||||
f"A node with identifier '{identifier}' already exists.",
|
||||
)
|
||||
|
||||
node = Node(
|
||||
pool=cls,
|
||||
|
|
@ -779,7 +839,9 @@ class NodePool:
|
|||
async def disconnect(cls) -> None:
|
||||
"""Disconnects all available nodes from the node pool."""
|
||||
|
||||
available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available]
|
||||
available_nodes: List[Node] = [
|
||||
node for node in cls._nodes.values() if node._available
|
||||
]
|
||||
|
||||
for node in available_nodes:
|
||||
await node.disconnect()
|
||||
|
|
|
|||
|
|
@ -1,35 +1,47 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from copy import copy
|
||||
from typing import (
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
from typing import Iterable
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from .objects import Track
|
||||
from .enums import LoopMode
|
||||
from .exceptions import QueueEmpty, QueueException, QueueFull
|
||||
from .exceptions import QueueEmpty
|
||||
from .exceptions import QueueException
|
||||
from .exceptions import QueueFull
|
||||
from .objects import Track
|
||||
|
||||
__all__ = (
|
||||
"Queue",
|
||||
)
|
||||
|
||||
|
||||
class Queue(Iterable[Track]):
|
||||
"""Queue for Pomice. This queue takes pomice.Track as an input and includes looping and shuffling."""
|
||||
|
||||
__slots__ = (
|
||||
"max_size",
|
||||
"_queue",
|
||||
"_overflow",
|
||||
"_loop_mode",
|
||||
"_current_item",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_size: Optional[int] = None,
|
||||
*,
|
||||
overflow: bool = True,
|
||||
):
|
||||
__slots__ = ("max_size", "_queue", "_overflow", "_loop_mode", "_current_item")
|
||||
|
||||
self.max_size: Optional[int] = max_size
|
||||
self._queue: List[Track] = [] # type: ignore
|
||||
self._current_item: Track
|
||||
self._queue: List[Track] = []
|
||||
self._overflow: bool = overflow
|
||||
self._loop_mode: Optional[LoopMode] = None
|
||||
self._current_item: Optional[Track] = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String showing all Track objects appearing as a list."""
|
||||
|
|
@ -60,7 +72,7 @@ class Queue(Iterable[Track]):
|
|||
|
||||
return self._queue[index]
|
||||
|
||||
def __setitem__(self, index: int, item: Track):
|
||||
def __setitem__(self, index: int, item: Track) -> None:
|
||||
"""Inserts an item at given position."""
|
||||
if not isinstance(index, int):
|
||||
raise ValueError("'int' type required.'")
|
||||
|
|
@ -90,7 +102,9 @@ class Queue(Iterable[Track]):
|
|||
The new queue will have the same max_size as the original.
|
||||
"""
|
||||
if not isinstance(other, Iterable):
|
||||
raise TypeError(f"Adding with the '{type(other)}' type is not supported.")
|
||||
raise TypeError(
|
||||
f"Adding with the '{type(other)}' type is not supported.",
|
||||
)
|
||||
|
||||
new_queue = self.copy()
|
||||
new_queue.extend(other)
|
||||
|
|
@ -106,7 +120,9 @@ class Queue(Iterable[Track]):
|
|||
self.extend(other)
|
||||
return self
|
||||
|
||||
raise TypeError(f"Adding '{type(other)}' type to the queue is not supported.")
|
||||
raise TypeError(
|
||||
f"Adding '{type(other)}' type to the queue is not supported.",
|
||||
)
|
||||
|
||||
def _get(self) -> Track:
|
||||
return self._queue.pop(0)
|
||||
|
|
@ -165,7 +181,7 @@ class Queue(Iterable[Track]):
|
|||
return bool(self._loop_mode)
|
||||
|
||||
@property
|
||||
def loop_mode(self) -> LoopMode:
|
||||
def loop_mode(self) -> Optional[LoopMode]:
|
||||
"""Returns the LoopMode enum set in the queue object"""
|
||||
return self._loop_mode
|
||||
|
||||
|
|
@ -178,7 +194,7 @@ class Queue(Iterable[Track]):
|
|||
"""Returns the queue as a List"""
|
||||
return self._queue
|
||||
|
||||
def get(self):
|
||||
def get(self) -> Track:
|
||||
"""Return next immediately available item in queue if any.
|
||||
Raises QueueEmpty if no items in queue.
|
||||
"""
|
||||
|
|
@ -239,7 +255,9 @@ class Queue(Iterable[Track]):
|
|||
"""Put the given item into the back of the queue."""
|
||||
if self.is_full:
|
||||
if not self._overflow:
|
||||
raise QueueFull(f"Queue max_size of {self.max_size} has been reached.")
|
||||
raise QueueFull(
|
||||
f"Queue max_size of {self.max_size} has been reached.",
|
||||
)
|
||||
|
||||
self._drop()
|
||||
|
||||
|
|
@ -249,7 +267,9 @@ class Queue(Iterable[Track]):
|
|||
"""Put the given item into the queue at the specified index."""
|
||||
if self.is_full:
|
||||
if not self._overflow:
|
||||
raise QueueFull(f"Queue max_size of {self.max_size} has been reached.")
|
||||
raise QueueFull(
|
||||
f"Queue max_size of {self.max_size} has been reached.",
|
||||
)
|
||||
|
||||
self._drop()
|
||||
|
||||
|
|
@ -275,7 +295,7 @@ class Queue(Iterable[Track]):
|
|||
if (new_len + self.count) > self.max_size:
|
||||
raise QueueFull(
|
||||
f"Queue has {self.count}/{self.max_size} items, "
|
||||
f"cannot add {new_len} more."
|
||||
f"cannot add {new_len} more.",
|
||||
)
|
||||
|
||||
for item in iterable:
|
||||
|
|
@ -292,7 +312,7 @@ class Queue(Iterable[Track]):
|
|||
"""Remove all items from the queue."""
|
||||
self._queue.clear()
|
||||
|
||||
def set_loop_mode(self, mode: LoopMode):
|
||||
def set_loop_mode(self, mode: LoopMode) -> None:
|
||||
"""
|
||||
Sets the loop mode of the queue.
|
||||
Takes the LoopMode enum as an argument.
|
||||
|
|
@ -307,7 +327,7 @@ class Queue(Iterable[Track]):
|
|||
self._queue.insert(index, self._current_item)
|
||||
self._current_item = self._queue[index]
|
||||
|
||||
def disable_loop(self):
|
||||
def disable_loop(self) -> None:
|
||||
"""
|
||||
Disables loop mode if set.
|
||||
Raises QueueException if loop mode is already None.
|
||||
|
|
@ -321,17 +341,17 @@ class Queue(Iterable[Track]):
|
|||
|
||||
self._loop_mode = None
|
||||
|
||||
def shuffle(self):
|
||||
def shuffle(self) -> None:
|
||||
"""Shuffles the queue."""
|
||||
return random.shuffle(self._queue)
|
||||
|
||||
def clear_track_filters(self):
|
||||
def clear_track_filters(self) -> None:
|
||||
"""Clears all filters applied to tracks"""
|
||||
for track in self._queue:
|
||||
track.filters = None
|
||||
|
||||
def jump(self, item: Track):
|
||||
def jump(self, item: Track) -> None:
|
||||
"""Removes all tracks before the."""
|
||||
index = self.find_position(item)
|
||||
new_queue = self._queue[index : self.size]
|
||||
new_queue = self._queue[index: self.size]
|
||||
self._queue = new_queue
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pool import Node
|
||||
|
||||
from .utils import RouteStats
|
||||
from aiohttp import ClientSession
|
||||
|
||||
__all__ = ("RoutePlanner",)
|
||||
|
||||
|
||||
class RoutePlanner:
|
||||
|
|
@ -16,17 +18,16 @@ class RoutePlanner:
|
|||
|
||||
def __init__(self, node: Node) -> None:
|
||||
self.node: Node = node
|
||||
self.session: ClientSession = node._session
|
||||
|
||||
async def get_status(self) -> RouteStats:
|
||||
"""Gets the status of the route planner API."""
|
||||
data: dict = await self.node.send(method="GET", path="routeplanner/status")
|
||||
return RouteStats(data)
|
||||
|
||||
async def free_address(self, ip: str):
|
||||
async def free_address(self, ip: str) -> None:
|
||||
"""Frees an address using the route planner API"""
|
||||
await self.node.send(method="POST", path="routeplanner/free/address", data={"address": ip})
|
||||
|
||||
async def free_all_addresses(self):
|
||||
async def free_all_addresses(self) -> None:
|
||||
"""Frees all available addresses using the route planner api"""
|
||||
await self.node.send(method="POST", path="routeplanner/free/address/all")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
"""Spotify module for Pomice, made possible by cloudwithax 2023"""
|
||||
|
||||
from .client import Client
|
||||
from .exceptions import *
|
||||
from .objects import *
|
||||
from .client import Client
|
||||
|
|
|
|||
|
|
@ -2,19 +2,27 @@ from __future__ import annotations
|
|||
|
||||
import re
|
||||
import time
|
||||
from base64 import b64encode
|
||||
from typing import Dict, List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import aiohttp
|
||||
import orjson as json
|
||||
|
||||
from base64 import b64encode
|
||||
from typing import TYPE_CHECKING
|
||||
from .exceptions import InvalidSpotifyURL, SpotifyRequestException
|
||||
from .exceptions import InvalidSpotifyURL
|
||||
from .exceptions import SpotifyRequestException
|
||||
from .objects import *
|
||||
|
||||
__all__ = (
|
||||
"Client",
|
||||
)
|
||||
|
||||
|
||||
GRANT_URL = "https://accounts.spotify.com/api/token"
|
||||
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
|
||||
SPOTIFY_URL_REGEX = re.compile(
|
||||
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)"
|
||||
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -24,17 +32,21 @@ class Client:
|
|||
for any Spotify URL you throw at it.
|
||||
"""
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str) -> None:
|
||||
self._client_id = client_id
|
||||
self._client_secret = client_secret
|
||||
def __init__(self, client_id: int, client_secret: str) -> None:
|
||||
self._client_id: int = client_id
|
||||
self._client_secret: str = client_secret
|
||||
|
||||
self.session: aiohttp.ClientSession = None
|
||||
self.session: aiohttp.ClientSession = None # type: ignore
|
||||
|
||||
self._bearer_token: str = None
|
||||
self._expiry = 0
|
||||
self._auth_token = b64encode(f"{self._client_id}:{self._client_secret}".encode())
|
||||
self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"}
|
||||
self._bearer_headers = None
|
||||
self._bearer_token: Optional[str] = None
|
||||
self._expiry: float = 0.0
|
||||
self._auth_token = b64encode(
|
||||
f"{self._client_id}:{self._client_secret}".encode(),
|
||||
)
|
||||
self._grant_headers = {
|
||||
"Authorization": f"Basic {self._auth_token.decode()}",
|
||||
}
|
||||
self._bearer_headers: Optional[Dict] = None
|
||||
|
||||
async def _fetch_bearer_token(self) -> None:
|
||||
_data = {"grant_type": "client_credentials"}
|
||||
|
|
@ -45,32 +57,34 @@ class Client:
|
|||
async with self.session.post(GRANT_URL, data=_data, headers=self._grant_headers) as resp:
|
||||
if resp.status != 200:
|
||||
raise SpotifyRequestException(
|
||||
f"Error fetching bearer token: {resp.status} {resp.reason}"
|
||||
f"Error fetching bearer token: {resp.status} {resp.reason}",
|
||||
)
|
||||
|
||||
data: dict = await resp.json(loads=json.loads)
|
||||
|
||||
self._bearer_token = data["access_token"]
|
||||
self._expiry = time.time() + (int(data["expires_in"]) - 10)
|
||||
self._bearer_headers = {"Authorization": f"Bearer {self._bearer_token}"}
|
||||
self._bearer_headers = {
|
||||
"Authorization": f"Bearer {self._bearer_token}",
|
||||
}
|
||||
|
||||
async def search(self, *, query: str):
|
||||
async def search(self, *, query: str) -> Union[Track, Album, Artist, Playlist]:
|
||||
if not self._bearer_token or time.time() >= self._expiry:
|
||||
await self._fetch_bearer_token()
|
||||
|
||||
result = SPOTIFY_URL_REGEX.match(query)
|
||||
spotify_type = result.group("type")
|
||||
spotify_id = result.group("id")
|
||||
|
||||
if not result:
|
||||
raise InvalidSpotifyURL("The Spotify link provided is not valid.")
|
||||
|
||||
spotify_type = result.group("type")
|
||||
spotify_id = result.group("id")
|
||||
|
||||
request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id)
|
||||
|
||||
async with self.session.get(request_url, headers=self._bearer_headers) as resp:
|
||||
if resp.status != 200:
|
||||
raise SpotifyRequestException(
|
||||
f"Error while fetching results: {resp.status} {resp.reason}"
|
||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||
)
|
||||
|
||||
data: dict = await resp.json(loads=json.loads)
|
||||
|
|
@ -81,11 +95,11 @@ class Client:
|
|||
return Album(data)
|
||||
elif spotify_type == "artist":
|
||||
async with self.session.get(
|
||||
f"{request_url}/top-tracks?market=US", headers=self._bearer_headers
|
||||
f"{request_url}/top-tracks?market=US", headers=self._bearer_headers,
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise SpotifyRequestException(
|
||||
f"Error while fetching results: {resp.status} {resp.reason}"
|
||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||
)
|
||||
|
||||
track_data: dict = await resp.json(loads=json.loads)
|
||||
|
|
@ -100,7 +114,7 @@ class Client:
|
|||
|
||||
if not len(tracks):
|
||||
raise SpotifyRequestException(
|
||||
"This playlist is empty and therefore cannot be queued."
|
||||
"This playlist is empty and therefore cannot be queued.",
|
||||
)
|
||||
|
||||
next_page_url = data["tracks"]["next"]
|
||||
|
|
@ -109,7 +123,7 @@ class Client:
|
|||
async with self.session.get(next_page_url, headers=self._bearer_headers) as resp:
|
||||
if resp.status != 200:
|
||||
raise SpotifyRequestException(
|
||||
f"Error while fetching results: {resp.status} {resp.reason}"
|
||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||
)
|
||||
|
||||
next_data: dict = await resp.json(loads=json.loads)
|
||||
|
|
@ -123,26 +137,30 @@ class Client:
|
|||
|
||||
return Playlist(data, tracks)
|
||||
|
||||
async def get_recommendations(self, *, query: str):
|
||||
async def get_recommendations(self, *, query: str) -> List[Track]:
|
||||
if not self._bearer_token or time.time() >= self._expiry:
|
||||
await self._fetch_bearer_token()
|
||||
|
||||
result = SPOTIFY_URL_REGEX.match(query)
|
||||
spotify_type = result.group("type")
|
||||
spotify_id = result.group("id")
|
||||
|
||||
if not result:
|
||||
raise InvalidSpotifyURL("The Spotify link provided is not valid.")
|
||||
|
||||
if not spotify_type == "track":
|
||||
raise InvalidSpotifyURL("The provided query is not a Spotify track.")
|
||||
spotify_type = result.group("type")
|
||||
spotify_id = result.group("id")
|
||||
|
||||
request_url = REQUEST_URL.format(type="recommendation", id=f"?seed_tracks={spotify_id}")
|
||||
if not spotify_type == "track":
|
||||
raise InvalidSpotifyURL(
|
||||
"The provided query is not a Spotify track.",
|
||||
)
|
||||
|
||||
request_url = REQUEST_URL.format(
|
||||
type="recommendation", id=f"?seed_tracks={spotify_id}",
|
||||
)
|
||||
|
||||
async with self.session.get(request_url, headers=self._bearer_headers) as resp:
|
||||
if resp.status != 200:
|
||||
raise SpotifyRequestException(
|
||||
f"Error while fetching results: {resp.status} {resp.reason}"
|
||||
f"Error while fetching results: {resp.status} {resp.reason}",
|
||||
)
|
||||
|
||||
data: dict = await resp.json(loads=json.loads)
|
||||
|
|
@ -154,4 +172,4 @@ class Client:
|
|||
async def close(self) -> None:
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
self.session = None # type: ignore
|
||||
|
|
|
|||
|
|
@ -1,3 +1,9 @@
|
|||
__all__ = (
|
||||
"SpotifyRequestException",
|
||||
"InvalidSpotifyURL",
|
||||
)
|
||||
|
||||
|
||||
class SpotifyRequestException(Exception):
|
||||
"""An error occurred when making a request to the Spotify API"""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,29 +1,36 @@
|
|||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
__all__ = (
|
||||
"Track",
|
||||
"Playlist",
|
||||
"Album",
|
||||
"Artist",
|
||||
)
|
||||
|
||||
|
||||
class Track:
|
||||
"""The base class for a Spotify Track"""
|
||||
|
||||
def __init__(self, data: dict, image=None) -> None:
|
||||
def __init__(self, data: dict, image: Optional[str] = None) -> None:
|
||||
self.name: str = data["name"]
|
||||
self.artists: str = ", ".join(artist["name"] for artist in data["artists"])
|
||||
self.artists: str = ", ".join(
|
||||
artist["name"] for artist in data["artists"]
|
||||
)
|
||||
self.length: float = data["duration_ms"]
|
||||
self.id: str = data["id"]
|
||||
|
||||
self.issrc: Optional[str] = None
|
||||
if data.get("external_ids"):
|
||||
self.isrc: str = data["external_ids"]["isrc"]
|
||||
else:
|
||||
self.isrc = None
|
||||
self.isrc = data["external_ids"]["isrc"]
|
||||
|
||||
self.image: Optional[str] = image
|
||||
if data.get("album") and data["album"].get("images"):
|
||||
self.image: str = data["album"]["images"][0]["url"]
|
||||
else:
|
||||
self.image: str = image
|
||||
self.image = data["album"]["images"][0]["url"]
|
||||
|
||||
if data["is_local"]:
|
||||
self.uri = None
|
||||
else:
|
||||
self.uri: str = data["external_urls"]["spotify"]
|
||||
self.uri: Optional[str] = None
|
||||
if not data["is_local"]:
|
||||
self.uri = data["external_urls"]["spotify"]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
|
|
@ -42,7 +49,7 @@ class Playlist:
|
|||
self.total_tracks: int = data["tracks"]["total"]
|
||||
self.id: str = data["id"]
|
||||
if data.get("images") and len(data["images"]):
|
||||
self.image: str = data["images"][0]["url"]
|
||||
self.image = data["images"][0]["url"]
|
||||
else:
|
||||
self.image = self.tracks[0].image
|
||||
self.uri = data["external_urls"]["spotify"]
|
||||
|
|
@ -59,9 +66,14 @@ class Album:
|
|||
|
||||
def __init__(self, data: dict) -> None:
|
||||
self.name: str = data["name"]
|
||||
self.artists: str = ", ".join(artist["name"] for artist in data["artists"])
|
||||
self.artists: str = ", ".join(
|
||||
artist["name"] for artist in data["artists"]
|
||||
)
|
||||
self.image: str = data["images"][0]["url"]
|
||||
self.tracks = [Track(track, image=self.image) for track in data["tracks"]["items"]]
|
||||
self.tracks = [
|
||||
Track(track, image=self.image)
|
||||
for track in data["tracks"]["items"]
|
||||
]
|
||||
self.total_tracks: int = data["total_tracks"]
|
||||
self.id: str = data["id"]
|
||||
self.uri: str = data["external_urls"]["spotify"]
|
||||
|
|
@ -78,7 +90,8 @@ class Artist:
|
|||
|
||||
def __init__(self, data: dict, tracks: dict) -> None:
|
||||
self.name: str = (
|
||||
f"Top tracks for {data['name']}" # Setting that because its only playing top tracks
|
||||
# Setting that because its only playing top tracks
|
||||
f"Top tracks for {data['name']}"
|
||||
)
|
||||
self.genres: str = ", ".join(genre for genre in data["genres"])
|
||||
self.followers: int = data["followers"]["total"]
|
||||
|
|
|
|||
121
pomice/utils.py
121
pomice/utils.py
|
|
@ -1,11 +1,24 @@
|
|||
import random
|
||||
import time
|
||||
import socket
|
||||
|
||||
from .enums import RouteStrategy, RouteIPType
|
||||
from timeit import default_timer as timer
|
||||
from itertools import zip_longest
|
||||
import time
|
||||
from datetime import datetime
|
||||
from itertools import zip_longest
|
||||
from timeit import default_timer as timer
|
||||
from typing import Any, Dict
|
||||
from typing import Callable
|
||||
from typing import Iterable
|
||||
from typing import Optional
|
||||
|
||||
from .enums import RouteIPType
|
||||
from .enums import RouteStrategy
|
||||
|
||||
__all__ = (
|
||||
"ExponentialBackoff",
|
||||
"NodeStats",
|
||||
"FailingIPBlock",
|
||||
"RouteStats",
|
||||
"Ping",
|
||||
)
|
||||
|
||||
|
||||
class ExponentialBackoff:
|
||||
|
|
@ -51,7 +64,7 @@ class ExponentialBackoff:
|
|||
self._exp = 0
|
||||
|
||||
self._exp = min(self._exp + 1, self._max)
|
||||
return self._randfunc(0, self._base * 2**self._exp)
|
||||
return self._randfunc(0, self._base * 2**self._exp) # type: ignore
|
||||
|
||||
|
||||
class NodeStats:
|
||||
|
|
@ -59,27 +72,28 @@ class NodeStats:
|
|||
Gives critical information on the node, which is updated every minute.
|
||||
"""
|
||||
|
||||
def __init__(self, data: dict) -> None:
|
||||
__slots__ = (
|
||||
"used",
|
||||
"free",
|
||||
"reservable",
|
||||
"allocated",
|
||||
"cpu_cores",
|
||||
"cpu_system_load",
|
||||
"cpu_process_load",
|
||||
"players_active",
|
||||
"players_total",
|
||||
"uptime",
|
||||
)
|
||||
__slots__ = (
|
||||
"used",
|
||||
"free",
|
||||
"reservable",
|
||||
"allocated",
|
||||
"cpu_cores",
|
||||
"cpu_system_load",
|
||||
"cpu_process_load",
|
||||
"players_active",
|
||||
"players_total",
|
||||
"uptime",
|
||||
)
|
||||
|
||||
memory: dict = data.get("memory")
|
||||
def __init__(self, data: Dict[str, Any]) -> None:
|
||||
|
||||
memory: dict = data.get("memory", {})
|
||||
self.used = memory.get("used")
|
||||
self.free = memory.get("free")
|
||||
self.reservable = memory.get("reservable")
|
||||
self.allocated = memory.get("allocated")
|
||||
|
||||
cpu: dict = data.get("cpu")
|
||||
cpu: dict = data.get("cpu", {})
|
||||
self.cpu_cores = cpu.get("cores")
|
||||
self.cpu_system_load = cpu.get("systemLoad")
|
||||
self.cpu_process_load = cpu.get("lavalinkLoad")
|
||||
|
|
@ -99,11 +113,14 @@ class FailingIPBlock:
|
|||
and the time they failed.
|
||||
"""
|
||||
|
||||
__slots__ = ("address", "failing_time")
|
||||
|
||||
def __init__(self, data: dict) -> None:
|
||||
__slots__ = ("address", "failing_time")
|
||||
|
||||
self.address = data.get("address")
|
||||
self.failing_time = datetime.fromtimestamp(float(data.get("failingTimestamp")))
|
||||
self.failing_time = datetime.fromtimestamp(
|
||||
float(data.get("failingTimestamp", 0)),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Pomice.FailingIPBlock address={self.address} failing_time={self.failing_time}>"
|
||||
|
|
@ -115,17 +132,29 @@ class RouteStats:
|
|||
Gives critical information about the route planner strategy on the node.
|
||||
"""
|
||||
|
||||
def __init__(self, data: dict) -> None:
|
||||
__slots__ = ("strategy", "ip_block_type", "ip_block_size", "failing_addresses")
|
||||
__slots__ = (
|
||||
"strategy",
|
||||
"ip_block_type",
|
||||
"ip_block_size",
|
||||
"failing_addresses",
|
||||
"block_index",
|
||||
"address_index",
|
||||
)
|
||||
|
||||
def __init__(self, data: Dict[str, Any]) -> None:
|
||||
|
||||
self.strategy = RouteStrategy(data.get("class"))
|
||||
|
||||
details: dict = data.get("details")
|
||||
details: dict = data.get("details", {})
|
||||
|
||||
ip_block: dict = details.get("ipBlock")
|
||||
ip_block: dict = details.get("ipBlock", {})
|
||||
self.ip_block_type = RouteIPType(ip_block.get("type"))
|
||||
self.ip_block_size = ip_block.get("size")
|
||||
self.failing_addresses = [FailingIPBlock(data) for data in details.get("failingAddresses")]
|
||||
self.failing_addresses = [
|
||||
FailingIPBlock(
|
||||
data,
|
||||
) for data in details.get("failingAddresses", [])
|
||||
]
|
||||
|
||||
self.block_index = details.get("blockIndex")
|
||||
self.address_index = details.get("currentAddressIndex")
|
||||
|
|
@ -136,7 +165,7 @@ class RouteStats:
|
|||
|
||||
class Ping:
|
||||
# Thanks to https://github.com/zhengxiaowai/tcping for the nice ping impl
|
||||
def __init__(self, host, port, timeout=5):
|
||||
def __init__(self, host: str, port: int, timeout: int = 5) -> None:
|
||||
self.timer = self.Timer()
|
||||
|
||||
self._successed = 0
|
||||
|
|
@ -146,33 +175,33 @@ class Ping:
|
|||
self._port = port
|
||||
self._timeout = timeout
|
||||
|
||||
class Socket(object):
|
||||
def __init__(self, family, type_, timeout):
|
||||
class Socket:
|
||||
def __init__(self, family: int, type_: int, timeout: Optional[float]) -> None:
|
||||
s = socket.socket(family, type_)
|
||||
s.settimeout(timeout)
|
||||
self._s = s
|
||||
|
||||
def connect(self, host, port):
|
||||
self._s.connect((host, int(port)))
|
||||
def connect(self, host: str, port: int) -> None:
|
||||
self._s.connect((host, port))
|
||||
|
||||
def shutdown(self):
|
||||
def shutdown(self) -> None:
|
||||
self._s.shutdown(socket.SHUT_RD)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
self._s.close()
|
||||
|
||||
class Timer(object):
|
||||
def __init__(self):
|
||||
self._start = 0
|
||||
self._stop = 0
|
||||
class Timer:
|
||||
def __init__(self) -> None:
|
||||
self._start: float = 0.0
|
||||
self._stop: float = 0.0
|
||||
|
||||
def start(self):
|
||||
def start(self) -> None:
|
||||
self._start = timer()
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
self._stop = timer()
|
||||
|
||||
def cost(self, funcs, args):
|
||||
def cost(self, funcs: Iterable[Callable], args: Any) -> float:
|
||||
self.start()
|
||||
for func, arg in zip_longest(funcs, args):
|
||||
if arg:
|
||||
|
|
@ -183,13 +212,15 @@ class Ping:
|
|||
self.stop()
|
||||
return self._stop - self._start
|
||||
|
||||
def _create_socket(self, family, type_):
|
||||
def _create_socket(self, family: int, type_: int) -> Socket:
|
||||
return self.Socket(family, type_, self._timeout)
|
||||
|
||||
def get_ping(self):
|
||||
def get_ping(self) -> float:
|
||||
s = self._create_socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
|
||||
cost_time = self.timer.cost((s.connect, s.shutdown), ((self._host, self._port), None))
|
||||
cost_time = self.timer.cost(
|
||||
(s.connect, s.shutdown), ((self._host, self._port), None),
|
||||
)
|
||||
s_runtime = 1000 * (cost_time)
|
||||
|
||||
return s_runtime
|
||||
|
|
|
|||
|
|
@ -7,3 +7,13 @@ build-backend = "setuptools.build_meta"
|
|||
|
||||
[tool.black]
|
||||
line-length = 100
|
||||
|
||||
[tool.mypy]
|
||||
mypy_path = "./"
|
||||
files = ["pomice"]
|
||||
disallow_untyped_defs = true
|
||||
disallow_any_unimported = true
|
||||
no_implicit_optional = true
|
||||
check_untyped_defs = true
|
||||
warn_unused_ignores = true
|
||||
show_error_codes = true
|
||||
|
|
|
|||
11
setup.py
11
setup.py
|
|
@ -1,10 +1,13 @@
|
|||
import setuptools
|
||||
import re
|
||||
|
||||
import setuptools
|
||||
|
||||
version = ""
|
||||
requirements = ["discord.py>=2.0.0", "aiohttp>=3.7.4,<4", "orjson"]
|
||||
with open("pomice/__init__.py") as f:
|
||||
version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE).group(1)
|
||||
version = re.search(
|
||||
r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE,
|
||||
).group(1)
|
||||
|
||||
if not version:
|
||||
raise RuntimeError("version is not set")
|
||||
|
|
@ -15,13 +18,13 @@ if version.endswith(("a", "b", "rc")):
|
|||
import subprocess
|
||||
|
||||
p = subprocess.Popen(
|
||||
["git", "rev-list", "--count", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
["git", "rev-list", "--count", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
||||
)
|
||||
out, err = p.communicate()
|
||||
if out:
|
||||
version += out.decode("utf-8").strip()
|
||||
p = subprocess.Popen(
|
||||
["git", "rev-parse", "--short", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
["git", "rev-parse", "--short", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
||||
)
|
||||
out, err = p.communicate()
|
||||
if out:
|
||||
|
|
|
|||
Loading…
Reference in New Issue