Merge pull request #35 from NiceAesth/typing

feat: add typing; add makefile; add pipfile
This commit is contained in:
Clxud 2023-03-11 10:21:23 -05:00 committed by GitHub
commit 3949c5b1a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
52 changed files with 1106 additions and 728 deletions

4
.gitignore vendored
View File

@ -7,3 +7,7 @@ docs/_build/
build/ build/
.gitpod.yml .gitpod.yml
.python-verson .python-verson
Pipfile.lock
.mypy_cache/
.vscode/
.venv/

36
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,36 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-ast
- id: check-builtin-literals
- id: debug-statements
- id: end-of-file-fixer
- id: requirements-txt-fixer
- id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-autopep8
rev: v2.0.2
hooks:
- id: autopep8
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
hooks:
- id: pyupgrade
args: [--py37-plus, --keep-runtime-typing]
- repo: https://github.com/asottile/reorder_python_imports
rev: v3.9.0
hooks:
- id: reorder-python-imports
- repo: https://github.com/asottile/add-trailing-comma
rev: v2.4.0
hooks:
- id: add-trailing-comma
- repo: https://github.com/hadialqattan/pycln
rev: v2.1.3
hooks:
- id: pycln
default_language_version:
python: python3.8

18
Makefile Normal file
View File

@ -0,0 +1,18 @@
prepare:
pipenv install --dev
pipenv run pre-commit install
shell:
pipenv shell
lint:
pipenv run pre-commit run --all-files
test:
pipenv run mypy
serve-docs:
@cd docs;\
make html;\
cd _build/html;\
python -m http.server;\

18
Pipfile Normal file
View File

@ -0,0 +1,18 @@
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"
[packages]
orjson = "*"
"discord.py" = {extras = ["voice"], version = "*"}
[dev-packages]
mypy = "*"
pre-commit = "*"
furo = "*"
sphinx = "*"
myst-parser = "*"
[requires]
python_version = "3.8"

View File

@ -2,12 +2,12 @@ import importlib
import inspect import inspect
import os import os
import sys import sys
from typing import Any, Dict from typing import Any
from typing import Dict
sys.path.insert(0, os.path.abspath('.')) sys.path.insert(0, os.path.abspath('.'))
sys.path.insert(0, os.path.abspath('..')) sys.path.insert(0, os.path.abspath('..'))
project = 'Pomice' project = 'Pomice'
copyright = '2023, cloudwithax' copyright = '2023, cloudwithax'
author = 'cloudwithax' author = 'cloudwithax'
@ -19,7 +19,7 @@ extensions = [
'sphinx.ext.autodoc', 'sphinx.ext.autodoc',
'sphinx.ext.autosummary', 'sphinx.ext.autosummary',
'sphinx.ext.linkcode', 'sphinx.ext.linkcode',
'myst_parser' 'myst_parser',
] ]
myst_enable_extensions = [ myst_enable_extensions = [
@ -84,6 +84,7 @@ html_theme_options: Dict[str, Any] = {
# Grab lines from source files and embed into the docs # Grab lines from source files and embed into the docs
# so theres a point of reference # so theres a point of reference
def linkcode_resolve(domain, info): def linkcode_resolve(domain, info):
# i absolutely MUST add this here or else # i absolutely MUST add this here or else
# the docs will not build. fuck sphinx # the docs will not build. fuck sphinx
@ -93,7 +94,6 @@ def linkcode_resolve(domain, info):
if not info['module']: if not info['module']:
return None return None
mod = importlib.import_module(info["module"]) mod = importlib.import_module(info["module"])
if "." in info["fullname"]: if "." in info["fullname"]:
objname, attrname = info["fullname"].split(".") objname, attrname = info["fullname"].split(".")
@ -117,4 +117,3 @@ def linkcode_resolve(domain, info):
return f"https://github.com/cloudwithax/pomice/blob/main/{file}#L{start}-L{end}" return f"https://github.com/cloudwithax/pomice/blob/main/{file}#L{start}-L{end}"
except: except:
pass pass

View File

@ -47,6 +47,3 @@ remote, or the node.
`Event.WebsocketOpenEvent()` carries a target, which is usually the node IP, and the SSRC, a 32-bit integer uniquely identifying the source of the RTP packets sent from `Event.WebsocketOpenEvent()` carries a target, which is usually the node IP, and the SSRC, a 32-bit integer uniquely identifying the source of the RTP packets sent from
Lavalink. Lavalink.

View File

@ -184,4 +184,3 @@ After you have initialized your function, you can optionally include the `fast_a
await Player.reset_filters(fast_apply=<True/False>) await Player.reset_filters(fast_apply=<True/False>)
``` ```

View File

@ -13,5 +13,4 @@ player.md
filters.md filters.md
queue.md queue.md
events.md events.md
`` ```

View File

@ -19,7 +19,7 @@ There are also properties the `Node` class has to access certain values:
- Description - Description
* - `Node.bot` * - `Node.bot`
- `Union[Client, Bot]` - `Client`
- Returns the discord.py client linked to this node. - Returns the discord.py client linked to this node.
* - `Node.is_connected` * - `Node.is_connected`

View File

@ -28,7 +28,7 @@ There are also properties the `Player` class has to access certain values:
- Description - Description
* - `Player.bot` * - `Player.bot`
- `Union[Client, commands.Bot]` - `Client`
- Returns the bot associated with this player instance. - Returns the bot associated with this player instance.
* - `Player.current` * - `Player.current`
@ -466,15 +466,3 @@ After you have initialized your function, you can optionally include the `fast_a
await Player.reset_filters(fast_apply=<True/False>) await Player.reset_filters(fast_apply=<True/False>)
``` ```

View File

@ -148,15 +148,3 @@ await NodePool.disconnect()
``` ```
After running this function, all nodes in the pool should disconnect and no longer be available to use. After running this function, all nodes in the pool should disconnect and no longer be available to use.

View File

@ -222,13 +222,3 @@ Your `Track` object must be in the queue if you want to jump to it. Make sure yo
::: :::
After running this function, any items before the specified item will be removed, effectively "jumping" to the specified item in the queue. The next item obtained using `Queue.get()` will be your specified track. After running this function, any items before the specified item will be removed, effectively "jumping" to the specified item in the queue. The next item obtained using `Queue.get()` will be your specified track.

View File

@ -38,7 +38,3 @@ hdi/index.md
:hidden: :hidden:
api/index.md api/index.md
``` ```

View File

@ -28,5 +28,3 @@ You are free to use this as a base to add on to for any music features you want
If you want to jump into the library and learn how to do everything you need, refer to the [How Do I?](hdi/index.md) section. If you want to jump into the library and learn how to do everything you need, refer to the [How Do I?](hdi/index.md) section.
If you want a deeper look into how the library works beyond the [How Do I?](hdi/index.md) guide, refer to the [API Reference](api/index.md) section. If you want a deeper look into how the library works beyond the [How Do I?](hdi/index.md) guide, refer to the [API Reference](api/index.md) section.

View File

@ -1,5 +1,5 @@
discord.py[voice]
aiohttp aiohttp
orjson discord.py[voice]
myst_parser
furo furo
myst_parser
orjson

View File

@ -4,13 +4,13 @@ This is in the form of a drop-in cog you can use and modify to your liking.
This example aims to include everything you would need to make a fully functioning music bot, This example aims to include everything you would need to make a fully functioning music bot,
from a queue system, advanced queue control and more. from a queue system, advanced queue control and more.
""" """
import math
from contextlib import suppress
import discord import discord
import pomice
import math
from discord.ext import commands from discord.ext import commands
from contextlib import suppress
import pomice
class Player(pomice.Player): class Player(pomice.Player):
@ -44,7 +44,6 @@ class Player(pomice.Player):
with suppress(discord.HTTPException): with suppress(discord.HTTPException):
await self.controller.delete() await self.controller.delete()
# Queue up the next track, else teardown the player # Queue up the next track, else teardown the player
try: try:
track: pomice.Track = self.queue.get() track: pomice.Track = self.queue.get()
@ -56,13 +55,16 @@ class Player(pomice.Player):
# Call the controller (a.k.a: The "Now Playing" embed) and check if one exists # Call the controller (a.k.a: The "Now Playing" embed) and check if one exists
if track.is_stream: if track.is_stream:
embed = discord.Embed(title="Now playing", description=f":red_circle: **LIVE** [{track.title}]({track.uri}) [{track.requester.mention}]") embed = discord.Embed(
title="Now playing", description=f":red_circle: **LIVE** [{track.title}]({track.uri}) [{track.requester.mention}]",
)
self.controller = await self.context.send(embed=embed) self.controller = await self.context.send(embed=embed)
else: else:
embed = discord.Embed(title=f"Now playing", description=f"[{track.title}]({track.uri}) [{track.requester.mention}]") embed = discord.Embed(
title=f"Now playing", description=f"[{track.title}]({track.uri}) [{track.requester.mention}]",
)
self.controller = await self.context.send(embed=embed) self.controller = await self.context.send(embed=embed)
async def teardown(self): async def teardown(self):
"""Clear internal states, remove player controller and disconnect.""" """Clear internal states, remove player controller and disconnect."""
with suppress((discord.HTTPException), (KeyError)): with suppress((discord.HTTPException), (KeyError)):
@ -76,8 +78,6 @@ class Player(pomice.Player):
self.dj = ctx.author self.dj = ctx.author
class Music(commands.Cog): class Music(commands.Cog):
def __init__(self, bot: commands.Bot) -> None: def __init__(self, bot: commands.Bot) -> None:
self.bot = bot self.bot = bot
@ -100,7 +100,7 @@ class Music(commands.Cog):
host="127.0.0.1", host="127.0.0.1",
port="3030", port="3030",
password="youshallnotpass", password="youshallnotpass",
identifier="MAIN" identifier="MAIN",
) )
print(f"Node is ready!") print(f"Node is ready!")
@ -122,7 +122,6 @@ class Music(commands.Cog):
return player.dj == ctx.author or ctx.author.guild_permissions.kick_members return player.dj == ctx.author or ctx.author.guild_permissions.kick_members
# The following are events from pomice.events # The following are events from pomice.events
# We are using these so that if the track either stops or errors, # We are using these so that if the track either stops or errors,
# we can just skip to the next track # we can just skip to the next track
@ -195,8 +194,6 @@ class Music(commands.Cog):
if not player.is_playing: if not player.is_playing:
await player.do_next() await player.do_next()
@commands.command(aliases=['pau', 'pa']) @commands.command(aliases=['pau', 'pa'])
async def pause(self, ctx: commands.Context): async def pause(self, ctx: commands.Context):
"""Pause the currently playing song.""" """Pause the currently playing song."""
@ -345,6 +342,6 @@ class Music(commands.Cog):
await player.set_volume(vol) await player.set_volume(vol)
await ctx.send(f'Set the volume to **{vol}**%', delete_after=7) await ctx.send(f'Set the volume to **{vol}**%', delete_after=7)
async def setup(bot: commands.Bot): async def setup(bot: commands.Bot):
await bot.add_cog(Music(bot)) await bot.add_cog(Music(bot))

View File

@ -1,13 +1,16 @@
import discord import discord
import pomice
from discord.ext import commands from discord.ext import commands
import pomice
class MyBot(commands.Bot): class MyBot(commands.Bot):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__( super().__init__(
command_prefix="!", command_prefix="!",
activity=discord.Activity(type=discord.ActivityType.listening, name="to music!") activity=discord.Activity(
type=discord.ActivityType.listening, name="to music!",
),
) )
self.add_cog(Music(self)) self.add_cog(Music(self))
@ -33,7 +36,7 @@ class Music(commands.Cog):
host="127.0.0.1", host="127.0.0.1",
port="3030", port="3030",
password="youshallnotpass", password="youshallnotpass",
identifier="MAIN" identifier="MAIN",
) )
print(f"Node is ready!") print(f"Node is ready!")
@ -44,7 +47,7 @@ class Music(commands.Cog):
if not channel: if not channel:
raise commands.CheckFailure( raise commands.CheckFailure(
"You must be in a voice channel to use this command " "You must be in a voice channel to use this command "
"without specifying the channel argument." "without specifying the channel argument.",
) )
# With the release of discord.py 1.7, you can now add a compatible # With the release of discord.py 1.7, you can now add a compatible
@ -78,7 +81,9 @@ class Music(commands.Cog):
results = await player.get_tracks(search) results = await player.get_tracks(search)
if not results: if not results:
raise commands.CommandError("No results were found for that search term.") raise commands.CommandError(
"No results were found for that search term.",
)
if isinstance(results, pomice.Playlist): if isinstance(results, pomice.Playlist):
await player.play(track=results.tracks[0]) await player.play(track=results.tracks[0])

View File

@ -17,7 +17,7 @@ if not discord.version_info.major >= 2:
raise DiscordPyOutdated( raise DiscordPyOutdated(
"You must have discord.py (v2.0 or greater) to use this library. " "You must have discord.py (v2.0 or greater) to use this library. "
"Uninstall your current version and install discord.py 2.0 " "Uninstall your current version and install discord.py 2.0 "
"using 'pip install discord.py'" "using 'pip install discord.py'",
) )
__version__ = "2.2a" __version__ = "2.2a"

View File

@ -1,5 +1,4 @@
"""Apple Music module for Pomice, made possible by cloudwithax 2023""" """Apple Music module for Pomice, made possible by cloudwithax 2023"""
from .client import Client
from .exceptions import * from .exceptions import *
from .objects import * from .objects import *
from .client import Client

View File

@ -1,19 +1,27 @@
from __future__ import annotations from __future__ import annotations
import base64
import re import re
from datetime import datetime
from typing import Dict
from typing import List
from typing import Union
import aiohttp import aiohttp
import orjson as json import orjson as json
import base64
from datetime import datetime
from .objects import *
from .exceptions import * from .exceptions import *
from .objects import *
__all__ = (
"Client",
)
AM_URL_REGEX = re.compile( AM_URL_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)" r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
) )
AM_SINGLE_IN_ALBUM_REGEX = re.compile( AM_SINGLE_IN_ALBUM_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)" r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
) )
AM_REQ_URL = "https://api.music.apple.com/v1/catalog/{country}/{type}s/{id}" AM_REQ_URL = "https://api.music.apple.com/v1/catalog/{country}/{type}s/{id}"
AM_BASE_URL = "https://api.music.apple.com" AM_BASE_URL = "https://api.music.apple.com"
@ -26,37 +34,49 @@ class Client:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.token: str = None self.expiry: datetime = datetime(1970, 1, 1)
self.expiry: datetime = None self.token: str = ""
self.session: aiohttp.ClientSession = None self.headers: Dict[str, str] = {}
self.headers = None self.session: aiohttp.ClientSession = None # type: ignore
async def request_token(self): async def request_token(self) -> None:
if not self.session: if not self.session:
self.session = aiohttp.ClientSession() self.session = aiohttp.ClientSession()
async with self.session.get("https://music.apple.com/assets/index.919fe17f.js") as resp: async with self.session.get("https://music.apple.com/assets/index.919fe17f.js") as resp:
if resp.status != 200: if resp.status != 200:
raise AppleMusicRequestException( raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}" f"Error while fetching results: {resp.status} {resp.reason}",
) )
text = await resp.text() text = await resp.text()
result = re.search('"(eyJ.+?)"', text).group(1) match = re.search('"(eyJ.+?)"', text)
if not match:
raise AppleMusicRequestException(
"Could not find token in response.",
)
result = match.group(1)
self.token = result self.token = result
self.headers = { self.headers = {
"Authorization": f"Bearer {result}", "Authorization": f"Bearer {result}",
"Origin": "https://apple.com", "Origin": "https://apple.com",
} }
token_split = self.token.split(".")[1] token_split = self.token.split(".")[1]
token_json = base64.b64decode(token_split + "=" * (-len(token_split) % 4)).decode() token_json = base64.b64decode(
token_split + "=" * (-len(token_split) % 4),
).decode()
token_data = json.loads(token_json) token_data = json.loads(token_json)
self.expiry = datetime.fromtimestamp(token_data["exp"]) self.expiry = datetime.fromtimestamp(token_data["exp"])
async def search(self, query: str): async def search(self, query: str) -> Union[Album, Playlist, Song, Artist]:
if not self.token or datetime.utcnow() > self.expiry: if not self.token or datetime.utcnow() > self.expiry:
await self.request_token() await self.request_token()
result = AM_URL_REGEX.match(query) result = AM_URL_REGEX.match(query)
if not result:
raise InvalidAppleMusicURL(
"The Apple Music link provided is not valid.",
)
country = result.group("country") country = result.group("country")
type = result.group("type") type = result.group("type")
@ -75,7 +95,7 @@ class Client:
async with self.session.get(request_url, headers=self.headers) as resp: async with self.session.get(request_url, headers=self.headers) as resp:
if resp.status != 200: if resp.status != 200:
raise AppleMusicRequestException( raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}" f"Error while fetching results: {resp.status} {resp.reason}",
) )
data: dict = await resp.json(loads=json.loads) data: dict = await resp.json(loads=json.loads)
@ -84,53 +104,57 @@ class Client:
if type == "song": if type == "song":
return Song(data) return Song(data)
elif type == "album": if type == "album":
return Album(data) return Album(data)
elif type == "artist": if type == "artist":
async with self.session.get( async with self.session.get(
f"{request_url}/view/top-songs", headers=self.headers f"{request_url}/view/top-songs", headers=self.headers,
) as resp: ) as resp:
if resp.status != 200: if resp.status != 200:
raise AppleMusicRequestException( raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}" f"Error while fetching results: {resp.status} {resp.reason}",
) )
top_tracks: dict = await resp.json(loads=json.loads) top_tracks: dict = await resp.json(loads=json.loads)
tracks: dict = top_tracks["data"] artist_tracks: dict = top_tracks["data"]
return Artist(data, tracks=tracks) return Artist(data, tracks=artist_tracks)
else:
track_data: dict = data["relationships"]["tracks"] track_data: dict = data["relationships"]["tracks"]
album_tracks: List[Song] = [
Song(track)
for track in track_data["data"]
]
tracks = [Song(track) for track in track_data.get("data")] if not len(album_tracks):
if not len(tracks):
raise AppleMusicRequestException( raise AppleMusicRequestException(
"This playlist is empty and therefore cannot be queued." "This playlist is empty and therefore cannot be queued.",
) )
if track_data.get("next"): _next = track_data.get("next")
next_page_url = AM_BASE_URL + track_data.get("next") if _next:
next_page_url = AM_BASE_URL + _next
while next_page_url is not None: while next_page_url is not None:
async with self.session.get(next_page_url, headers=self.headers) as resp: async with self.session.get(next_page_url, headers=self.headers) as resp:
if resp.status != 200: if resp.status != 200:
raise AppleMusicRequestException( raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}" f"Error while fetching results: {resp.status} {resp.reason}",
) )
next_data: dict = await resp.json(loads=json.loads) next_data: dict = await resp.json(loads=json.loads)
tracks += [Song(track) for track in next_data["data"]] album_tracks.extend(Song(track) for track in next_data["data"])
if next_data.get("next"):
next_page_url = AM_BASE_URL + next_data.get("next") _next = next_data.get("next")
if _next:
next_page_url = AM_BASE_URL + _next
else: else:
next_page_url = None next_page_url = None
return Playlist(data, tracks) return Playlist(data, album_tracks)
async def close(self): async def close(self) -> None:
if self.session: if self.session:
await self.session.close() await self.session.close()
self.session = None self.session = None # type: ignore

View File

@ -1,3 +1,9 @@
__all__ = (
"AppleMusicRequestException",
"InvalidAppleMusicURL",
)
class AppleMusicRequestException(Exception): class AppleMusicRequestException(Exception):
"""An error occurred when making a request to the Apple Music API""" """An error occurred when making a request to the Apple Music API"""

View File

@ -1,7 +1,13 @@
"""Module for managing Apple Music objects""" """Module for managing Apple Music objects"""
from typing import List from typing import List
__all__ = (
"Song",
"Playlist",
"Album",
"Artist",
)
class Song: class Song:
"""The base class for an Apple Music song""" """The base class for an Apple Music song"""
@ -55,7 +61,9 @@ class Album:
self.id: str = data["id"] self.id: str = data["id"]
self.artists: str = data["attributes"]["artistName"] self.artists: str = data["attributes"]["artistName"]
self.total_tracks: int = data["attributes"]["trackCount"] self.total_tracks: int = data["attributes"]["trackCount"]
self.tracks: List[Song] = [Song(track) for track in data["relationships"]["tracks"]["data"]] self.tracks: List[Song] = [
Song(track) for track in data["relationships"]["tracks"]["data"]
]
self.image: str = data["attributes"]["artwork"]["url"].replace( self.image: str = data["attributes"]["artwork"]["url"].replace(
"{w}x{h}", "{w}x{h}",
f'{data["attributes"]["artwork"]["width"]}x{data["attributes"]["artwork"]["height"]}', f'{data["attributes"]["artwork"]["width"]}x{data["attributes"]["artwork"]["height"]}',
@ -75,7 +83,9 @@ class Artist:
self.name: str = f'Top tracks for {data["attributes"]["name"]}' self.name: str = f'Top tracks for {data["attributes"]["name"]}'
self.url: str = data["attributes"]["url"] self.url: str = data["attributes"]["url"]
self.id: str = data["id"] self.id: str = data["id"]
self.genres: str = ", ".join(genre for genre in data["attributes"]["genreNames"]) self.genres: str = ", ".join(
genre for genre in data["attributes"]["genreNames"]
)
self.tracks: List[Song] = [Song(track) for track in tracks] self.tracks: List[Song] = [Song(track) for track in tracks]
self.image: str = data["attributes"]["artwork"]["url"].replace( self.image: str = data["attributes"]["artwork"]["url"].replace(
"{w}x{h}", "{w}x{h}",

View File

@ -1,7 +1,17 @@
import re import re
from enum import Enum from enum import Enum
__all__ = (
"SearchType",
"TrackType",
"PlaylistType",
"NodeAlgorithm",
"LoopMode",
"RouteStrategy",
"RouteIPType",
"URLRegex",
)
class SearchType(Enum): class SearchType(Enum):
""" """
@ -185,43 +195,51 @@ class URLRegex:
""" """
SPOTIFY_URL = re.compile( SPOTIFY_URL = re.compile(
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)" r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
) )
DISCORD_MP3_URL = re.compile( DISCORD_MP3_URL = re.compile(
r"https?://cdn.discordapp.com/attachments/(?P<channel_id>[0-9]+)/" r"https?://cdn.discordapp.com/attachments/(?P<channel_id>[0-9]+)/"
r"(?P<message_id>[0-9]+)/(?P<file>[a-zA-Z0-9_.]+)+" r"(?P<message_id>[0-9]+)/(?P<file>[a-zA-Z0-9_.]+)+",
) )
YOUTUBE_URL = re.compile( YOUTUBE_URL = re.compile(
r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube\.com|youtu.be))" r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube\.com|youtu.be))"
r"(\/(?:[\w\-]+\?v=|embed\/|v\/)?)([\w\-]+)(\S+)?$" r"(\/(?:[\w\-]+\?v=|embed\/|v\/)?)([\w\-]+)(\S+)?$",
) )
YOUTUBE_PLAYLIST_URL = re.compile( YOUTUBE_PLAYLIST_URL = re.compile(
r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube\.com|youtu.be))/playlist\?list=.*" r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube\.com|youtu.be))/playlist\?list=.*",
) )
YOUTUBE_VID_IN_PLAYLIST = re.compile(r"(?P<video>^.*?v.*?)(?P<list>&list.*)") YOUTUBE_VID_IN_PLAYLIST = re.compile(
r"(?P<video>^.*?v.*?)(?P<list>&list.*)",
)
YOUTUBE_TIMESTAMP = re.compile(r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*") YOUTUBE_TIMESTAMP = re.compile(
r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*",
)
AM_URL = re.compile( AM_URL = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/" r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/"
r"(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)" r"(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
) )
AM_SINGLE_IN_ALBUM_REGEX = re.compile( AM_SINGLE_IN_ALBUM_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/" r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)" r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
) )
SOUNDCLOUD_URL = re.compile(r"((?:https?:)?\/\/)?((?:www|m)\.)?soundcloud.com\/.*/.*") SOUNDCLOUD_URL = re.compile(
r"((?:https?:)?\/\/)?((?:www|m)\.)?soundcloud.com\/.*/.*",
)
SOUNDCLOUD_PLAYLIST_URL = re.compile(r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com\/.*/sets/.*") SOUNDCLOUD_PLAYLIST_URL = re.compile(
r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com\/.*/sets/.*",
)
SOUNDCLOUD_TRACK_IN_SET_URL = re.compile( SOUNDCLOUD_TRACK_IN_SET_URL = re.compile(
r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com/[a-zA-Z0-9-._]+/[a-zA-Z0-9-._]+(\?in)" r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com/[a-zA-Z0-9-._]+/[a-zA-Z0-9-._]+(\?in)",
) )
LAVALINK_SEARCH = re.compile(r"(?P<type>ytm?|sc)search:") LAVALINK_SEARCH = re.compile(r"(?P<type>ytm?|sc)search:")

View File

@ -1,18 +1,34 @@
from __future__ import annotations from __future__ import annotations
from discord import Client, Guild
from abc import ABC
from typing import Any
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from discord import Client
from discord import Guild
from discord.ext import commands from discord.ext import commands
from .pool import NodePool
from .objects import Track from .objects import Track
from .pool import NodePool
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING: if TYPE_CHECKING:
from .player import Player from .player import Player
__all__ = (
"PomiceEvent",
"TrackStartEvent",
"TrackEndEvent",
"TrackStuckEvent",
"TrackExceptionEvent",
"WebSocketClosedPayload",
"WebSocketClosedEvent",
"WebSocketOpenEvent",
)
class PomiceEvent:
class PomiceEvent(ABC):
"""The base class for all events dispatched by a node. """The base class for all events dispatched by a node.
Every event must be formatted within your bot's code as a listener. Every event must be formatted within your bot's code as a listener.
i.e: If you want to listen for when a track starts, the event would be: i.e: If you want to listen for when a track starts, the event would be:
@ -23,9 +39,9 @@ class PomiceEvent:
""" """
name = "event" name = "event"
handler_args = () handler_args: Tuple
def dispatch(self, bot: Union[Client, commands.Bot]): def dispatch(self, bot: Client) -> None:
bot.dispatch(f"pomice_{self.name}", *self.handler_args) bot.dispatch(f"pomice_{self.name}", *self.handler_args)
@ -36,10 +52,15 @@ class TrackStartEvent(PomiceEvent):
name = "track_start" name = "track_start"
__slots__ = (
"player",
"track",
)
def __init__(self, data: dict, player: Player): def __init__(self, data: dict, player: Player):
__slots__ = ("player", "track")
self.player: Player = player self.player: Player = player
assert self.player._current is not None
self.track: Track = self.player._current self.track: Track = self.player._current
# on_pomice_track_start(player, track) # on_pomice_track_start(player, track)
@ -56,10 +77,12 @@ class TrackEndEvent(PomiceEvent):
name = "track_end" name = "track_end"
def __init__(self, data: dict, player: Player):
__slots__ = ("player", "track", "reason") __slots__ = ("player", "track", "reason")
def __init__(self, data: dict, player: Player):
self.player: Player = player self.player: Player = player
assert self.player._ending_track is not None
self.track: Track = self.player._ending_track self.track: Track = self.player._ending_track
self.reason: str = data["reason"] self.reason: str = data["reason"]
@ -81,10 +104,12 @@ class TrackStuckEvent(PomiceEvent):
name = "track_stuck" name = "track_stuck"
def __init__(self, data: dict, player: Player):
__slots__ = ("player", "track", "threshold") __slots__ = ("player", "track", "threshold")
def __init__(self, data: dict, player: Player):
self.player: Player = player self.player: Player = player
assert self.player._ending_track is not None
self.track: Track = self.player._ending_track self.track: Track = self.player._ending_track
self.threshold: float = data["thresholdMs"] self.threshold: float = data["thresholdMs"]
@ -105,17 +130,17 @@ class TrackExceptionEvent(PomiceEvent):
name = "track_exception" name = "track_exception"
def __init__(self, data: dict, player: Player):
__slots__ = ("player", "track", "exception") __slots__ = ("player", "track", "exception")
def __init__(self, data: dict, player: Player):
self.player: Player = player self.player: Player = player
assert self.player._ending_track is not None
self.track: Track = self.player._ending_track self.track: Track = self.player._ending_track
if data.get("error"): # Error is for Lavalink <= 3.3
# User is running Lavalink <= 3.3 self.exception: str = data.get(
self.exception: str = data["error"] "error", "",
else: ) or data.get("exception", "")
# User is running Lavalink >=3.4
self.exception: str = data["exception"]
# on_pomice_track_exception(player, track, error) # on_pomice_track_exception(player, track, error)
self.handler_args = self.player, self.track, self.exception self.handler_args = self.player, self.track, self.exception
@ -125,10 +150,12 @@ class TrackExceptionEvent(PomiceEvent):
class WebSocketClosedPayload: class WebSocketClosedPayload:
def __init__(self, data: dict):
__slots__ = ("guild", "code", "reason", "by_remote") __slots__ = ("guild", "code", "reason", "by_remote")
self.guild: Guild = NodePool.get_node().bot.get_guild(int(data["guildId"])) def __init__(self, data: dict):
self.guild: Optional[Guild] = NodePool.get_node(
).bot.get_guild(int(data["guildId"]))
self.code: int = data["code"] self.code: int = data["code"]
self.reason: str = data["code"] self.reason: str = data["code"]
self.by_remote: bool = data["byRemote"] self.by_remote: bool = data["byRemote"]
@ -147,7 +174,9 @@ class WebSocketClosedEvent(PomiceEvent):
name = "websocket_closed" name = "websocket_closed"
def __init__(self, data: dict, _): __slots__ = ("payload",)
def __init__(self, data: dict, _: Any) -> None:
self.payload: WebSocketClosedPayload = WebSocketClosedPayload(data) self.payload: WebSocketClosedPayload = WebSocketClosedPayload(data)
# on_pomice_websocket_closed(payload) # on_pomice_websocket_closed(payload)
@ -164,9 +193,10 @@ class WebSocketOpenEvent(PomiceEvent):
name = "websocket_open" name = "websocket_open"
def __init__(self, data: dict, _):
__slots__ = ("target", "ssrc") __slots__ = ("target", "ssrc")
def __init__(self, data: dict, _: Any) -> None:
self.target: str = data["target"] self.target: str = data["target"]
self.ssrc: int = data["ssrc"] self.ssrc: int = data["ssrc"]

View File

@ -1,3 +1,26 @@
__all__ = (
"PomiceException",
"NodeException",
"NodeCreationError",
"NodeConnectionFailure",
"NodeConnectionClosed",
"NodeRestException",
"NodeNotAvailable",
"NoNodesAvailable",
"TrackInvalidPosition",
"TrackLoadError",
"FilterInvalidArgument",
"FilterTagInvalid",
"FilterTagAlreadyInUse",
"InvalidSpotifyClientAuthorization",
"AppleMusicNotEnabled",
"QueueException",
"QueueFull",
"QueueEmpty",
"LavalinkVersionIncompatible",
)
class PomiceException(Exception): class PomiceException(Exception):
"""Base of all Pomice exceptions.""" """Base of all Pomice exceptions."""

View File

@ -1,6 +1,25 @@
import collections import collections
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from .exceptions import FilterInvalidArgument from .exceptions import FilterInvalidArgument
__all__ = (
"Filter",
"Equalizer",
"Timescale",
"Karaoke",
"Tremolo",
"Vibrato",
"Rotation",
"Distortion",
"ChannelMix",
"LowPass",
)
class Filter: class Filter:
""" """
@ -13,10 +32,10 @@ class Filter:
This is necessary for the removal of filters. This is necessary for the removal of filters.
""" """
def __init__(self, *, tag: str):
__slots__ = ("payload", "tag", "preload") __slots__ = ("payload", "tag", "preload")
self.payload: dict = None def __init__(self, *, tag: str):
self.payload: Optional[Dict] = None
self.tag: str = tag self.tag: str = tag
self.preload: bool = False self.preload: bool = False
@ -34,32 +53,32 @@ class Equalizer(Filter):
The format for the levels is: List[Tuple[int, float]] The format for the levels is: List[Tuple[int, float]]
""" """
def __init__(self, *, tag: str, levels: list):
super().__init__(tag=tag)
__slots__ = ( __slots__ = (
"eq", "eq",
"raw", "raw",
) )
def __init__(self, *, tag: str, levels: list):
super().__init__(tag=tag)
self.eq = self._factory(levels) self.eq = self._factory(levels)
self.raw = levels self.raw = levels
self.payload = {"equalizer": self.eq} self.payload = {"equalizer": self.eq}
def _factory(self, levels: list): def _factory(self, levels: List[Tuple[Any, Any]]) -> List[Dict]:
_dict = collections.defaultdict(int) _dict: Dict = collections.defaultdict(int)
_dict.update(levels) _dict.update(levels)
_dict = [{"band": i, "gain": _dict[i]} for i in range(15)] data = [{"band": i, "gain": _dict[i]} for i in range(15)]
return _dict return data
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Pomice.EqualizerFilter tag={self.tag} eq={self.eq} raw={self.raw}>" return f"<Pomice.EqualizerFilter tag={self.tag} eq={self.eq} raw={self.raw}>"
@classmethod @classmethod
def flat(cls): def flat(cls) -> "Equalizer":
"""Equalizer preset which represents a flat EQ board, """Equalizer preset which represents a flat EQ board,
with all levels set to their default values. with all levels set to their default values.
""" """
@ -84,7 +103,7 @@ class Equalizer(Filter):
return cls(tag="flat", levels=levels) return cls(tag="flat", levels=levels)
@classmethod @classmethod
def boost(cls): def boost(cls) -> "Equalizer":
"""Equalizer preset which boosts the sound of a track, """Equalizer preset which boosts the sound of a track,
making it sound fun and energetic by increasing the bass making it sound fun and energetic by increasing the bass
and the highs. and the highs.
@ -110,7 +129,7 @@ class Equalizer(Filter):
return cls(tag="boost", levels=levels) return cls(tag="boost", levels=levels)
@classmethod @classmethod
def metal(cls): def metal(cls) -> "Equalizer":
"""Equalizer preset which increases the mids of a track, """Equalizer preset which increases the mids of a track,
preferably one of the metal genre, to make it sound preferably one of the metal genre, to make it sound
more full and concert-like. more full and concert-like.
@ -137,7 +156,7 @@ class Equalizer(Filter):
return cls(tag="metal", levels=levels) return cls(tag="metal", levels=levels)
@classmethod @classmethod
def piano(cls): def piano(cls) -> "Equalizer":
"""Equalizer preset which increases the mids and highs """Equalizer preset which increases the mids and highs
of a track, preferably a piano based one, to make it of a track, preferably a piano based one, to make it
stand out. stand out.
@ -169,11 +188,11 @@ class Timescale(Filter):
a certain amount to produce said effect. a certain amount to produce said effect.
""" """
__slots__ = ("speed", "pitch", "rate")
def __init__(self, *, tag: str, speed: float = 1.0, pitch: float = 1.0, rate: float = 1.0): def __init__(self, *, tag: str, speed: float = 1.0, pitch: float = 1.0, rate: float = 1.0):
super().__init__(tag=tag) super().__init__(tag=tag)
__slots__ = ("speed", "pitch", "rate")
if speed < 0: if speed < 0:
raise FilterInvalidArgument("Timescale speed must be more than 0.") raise FilterInvalidArgument("Timescale speed must be more than 0.")
if pitch < 0: if pitch < 0:
@ -186,11 +205,11 @@ class Timescale(Filter):
self.rate: float = rate self.rate: float = rate
self.payload: dict = { self.payload: dict = {
"timescale": {"speed": self.speed, "pitch": self.pitch, "rate": self.rate} "timescale": {"speed": self.speed, "pitch": self.pitch, "rate": self.rate},
} }
@classmethod @classmethod
def vaporwave(cls): def vaporwave(cls) -> "Timescale":
"""Timescale preset which slows down the currently playing track, """Timescale preset which slows down the currently playing track,
giving it the effect of a half-speed record/casette playing. giving it the effect of a half-speed record/casette playing.
@ -200,7 +219,7 @@ class Timescale(Filter):
return cls(tag="vaporwave", speed=0.8, pitch=0.8) return cls(tag="vaporwave", speed=0.8, pitch=0.8)
@classmethod @classmethod
def nightcore(cls): def nightcore(cls) -> "Timescale":
"""Timescale preset which speeds up the currently playing track, """Timescale preset which speeds up the currently playing track,
which matches up to nightcore, a genre of sped-up music which matches up to nightcore, a genre of sped-up music
@ -209,7 +228,7 @@ class Timescale(Filter):
return cls(tag="nightcore", speed=1.25, pitch=1.3) return cls(tag="nightcore", speed=1.25, pitch=1.3)
def __repr__(self): def __repr__(self) -> str:
return f"<Pomice.TimescaleFilter tag={self.tag} speed={self.speed} pitch={self.pitch} rate={self.rate}>" return f"<Pomice.TimescaleFilter tag={self.tag} speed={self.speed} pitch={self.pitch} rate={self.rate}>"
@ -217,6 +236,7 @@ class Karaoke(Filter):
"""Filter which filters the vocal track from any song and leaves the instrumental. """Filter which filters the vocal track from any song and leaves the instrumental.
Best for karaoke as the filter implies. Best for karaoke as the filter implies.
""" """
__slots__ = ("level", "mono_level", "filter_band", "filter_width")
def __init__( def __init__(
self, self,
@ -229,8 +249,6 @@ class Karaoke(Filter):
): ):
super().__init__(tag=tag) super().__init__(tag=tag)
__slots__ = ("level", "mono_level", "filter_band", "filter_width")
self.level: float = level self.level: float = level
self.mono_level: float = mono_level self.mono_level: float = mono_level
self.filter_band: float = filter_band self.filter_band: float = filter_band
@ -242,10 +260,10 @@ class Karaoke(Filter):
"monoLevel": self.mono_level, "monoLevel": self.mono_level,
"filterBand": self.filter_band, "filterBand": self.filter_band,
"filterWidth": self.filter_width, "filterWidth": self.filter_width,
} },
} }
def __repr__(self): def __repr__(self) -> str:
return ( return (
f"<Pomice.KaraokeFilter tag={self.tag} level={self.level} mono_level={self.mono_level} " f"<Pomice.KaraokeFilter tag={self.tag} level={self.level} mono_level={self.mono_level} "
f"filter_band={self.filter_band} filter_width={self.filter_width}>" f"filter_band={self.filter_band} filter_width={self.filter_width}>"
@ -256,23 +274,30 @@ class Tremolo(Filter):
"""Filter which produces a wavering tone in the music, """Filter which produces a wavering tone in the music,
causing it to sound like the music is changing in volume rapidly. causing it to sound like the music is changing in volume rapidly.
""" """
__slots__ = ("frequency", "depth")
def __init__(self, *, tag: str, frequency: float = 2.0, depth: float = 0.5): def __init__(self, *, tag: str, frequency: float = 2.0, depth: float = 0.5):
super().__init__(tag=tag) super().__init__(tag=tag)
__slots__ = ("frequency", "depth")
if frequency < 0: if frequency < 0:
raise FilterInvalidArgument("Tremolo frequency must be more than 0.") raise FilterInvalidArgument(
"Tremolo frequency must be more than 0.",
)
if depth < 0 or depth > 1: if depth < 0 or depth > 1:
raise FilterInvalidArgument("Tremolo depth must be between 0 and 1.") raise FilterInvalidArgument(
"Tremolo depth must be between 0 and 1.",
)
self.frequency: float = frequency self.frequency: float = frequency
self.depth: float = depth self.depth: float = depth
self.payload: dict = {"tremolo": {"frequency": self.frequency, "depth": self.depth}} self.payload: dict = {
"tremolo": {
"frequency": self.frequency, "depth": self.depth,
},
}
def __repr__(self): def __repr__(self) -> str:
return ( return (
f"<Pomice.TremoloFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>" f"<Pomice.TremoloFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
) )
@ -282,23 +307,30 @@ class Vibrato(Filter):
"""Filter which produces a wavering tone in the music, similar to the Tremolo filter, """Filter which produces a wavering tone in the music, similar to the Tremolo filter,
but changes in pitch rather than volume. but changes in pitch rather than volume.
""" """
__slots__ = ("frequency", "depth")
def __init__(self, *, tag: str, frequency: float = 2.0, depth: float = 0.5): def __init__(self, *, tag: str, frequency: float = 2.0, depth: float = 0.5):
super().__init__(tag=tag) super().__init__(tag=tag)
__slots__ = ("frequency", "depth")
if frequency < 0 or frequency > 14: if frequency < 0 or frequency > 14:
raise FilterInvalidArgument("Vibrato frequency must be between 0 and 14.") raise FilterInvalidArgument(
"Vibrato frequency must be between 0 and 14.",
)
if depth < 0 or depth > 1: if depth < 0 or depth > 1:
raise FilterInvalidArgument("Vibrato depth must be between 0 and 1.") raise FilterInvalidArgument(
"Vibrato depth must be between 0 and 1.",
)
self.frequency: float = frequency self.frequency: float = frequency
self.depth: float = depth self.depth: float = depth
self.payload: dict = {"vibrato": {"frequency": self.frequency, "depth": self.depth}} self.payload: dict = {
"vibrato": {
"frequency": self.frequency, "depth": self.depth,
},
}
def __repr__(self): def __repr__(self) -> str:
return ( return (
f"<Pomice.VibratoFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>" f"<Pomice.VibratoFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
) )
@ -309,11 +341,11 @@ class Rotation(Filter):
the audio is being rotated around the listener's head the audio is being rotated around the listener's head
""" """
__slots__ = ("rotation_hertz",)
def __init__(self, *, tag: str, rotation_hertz: float = 5): def __init__(self, *, tag: str, rotation_hertz: float = 5):
super().__init__(tag=tag) super().__init__(tag=tag)
__slots__ = "rotation_hertz"
self.rotation_hertz: float = rotation_hertz self.rotation_hertz: float = rotation_hertz
self.payload: dict = {"rotation": {"rotationHz": self.rotation_hertz}} self.payload: dict = {"rotation": {"rotationHz": self.rotation_hertz}}
@ -326,6 +358,13 @@ class ChannelMix(Filter):
for some cool effects when done correctly. for some cool effects when done correctly.
""" """
__slots__ = (
"left_to_left",
"right_to_right",
"left_to_right",
"right_to_left",
)
def __init__( def __init__(
self, self,
*, *,
@ -337,23 +376,21 @@ class ChannelMix(Filter):
): ):
super().__init__(tag=tag) super().__init__(tag=tag)
__slots__ = ("left_to_left", "right_to_right", "left_to_right", "right_to_left")
if 0 > left_to_left > 1: if 0 > left_to_left > 1:
raise ValueError( raise ValueError(
"'left_to_left' value must be more than or equal to 0 or less than or equal to 1." "'left_to_left' value must be more than or equal to 0 or less than or equal to 1.",
) )
if 0 > right_to_right > 1: if 0 > right_to_right > 1:
raise ValueError( raise ValueError(
"'right_to_right' value must be more than or equal to 0 or less than or equal to 1." "'right_to_right' value must be more than or equal to 0 or less than or equal to 1.",
) )
if 0 > left_to_right > 1: if 0 > left_to_right > 1:
raise ValueError( raise ValueError(
"'left_to_right' value must be more than or equal to 0 or less than or equal to 1." "'left_to_right' value must be more than or equal to 0 or less than or equal to 1.",
) )
if 0 > right_to_left > 1: if 0 > right_to_left > 1:
raise ValueError( raise ValueError(
"'right_to_left' value must be more than or equal to 0 or less than or equal to 1." "'right_to_left' value must be more than or equal to 0 or less than or equal to 1.",
) )
self.left_to_left: float = left_to_left self.left_to_left: float = left_to_left
@ -367,7 +404,7 @@ class ChannelMix(Filter):
"leftToRight": self.left_to_right, "leftToRight": self.left_to_right,
"rightToLeft": self.right_to_left, "rightToLeft": self.right_to_left,
"rightToRight": self.right_to_right, "rightToRight": self.right_to_right,
} },
} }
def __repr__(self) -> str: def __repr__(self) -> str:
@ -382,6 +419,17 @@ class Distortion(Filter):
distortion is needed. distortion is needed.
""" """
__slots__ = (
"sin_offset",
"sin_scale",
"cos_offset",
"cos_scale",
"tan_offset",
"tan_scale",
"offset",
"scale",
)
def __init__( def __init__(
self, self,
*, *,
@ -397,16 +445,6 @@ class Distortion(Filter):
): ):
super().__init__(tag=tag) super().__init__(tag=tag)
__slots__ = (
"sin_offset",
"sin_scale",
"cos_offset",
"cos_scale",
"tan_offset",
"tan_scale" "offset",
"scale",
)
self.sin_offset: float = sin_offset self.sin_offset: float = sin_offset
self.sin_scale: float = sin_scale self.sin_scale: float = sin_scale
self.cos_offset: float = cos_offset self.cos_offset: float = cos_offset
@ -426,7 +464,7 @@ class Distortion(Filter):
"tanScale": self.tan_scale, "tanScale": self.tan_scale,
"offset": self.offset, "offset": self.offset,
"scale": self.scale, "scale": self.scale,
} },
} }
def __repr__(self) -> str: def __repr__(self) -> str:
@ -441,12 +479,11 @@ class LowPass(Filter):
"""Filter which supresses higher frequencies and allows lower frequencies to pass. """Filter which supresses higher frequencies and allows lower frequencies to pass.
You can also do this with the Equalizer filter, but this is an easier way to do it. You can also do this with the Equalizer filter, but this is an easier way to do it.
""" """
__slots__ = ("smoothing", "payload")
def __init__(self, *, tag: str, smoothing: float = 20): def __init__(self, *, tag: str, smoothing: float = 20):
super().__init__(tag=tag) super().__init__(tag=tag)
__slots__ = "smoothing"
self.smoothing: float = smoothing self.smoothing: float = smoothing
self.payload: dict = {"lowPass": {"smoothing": self.smoothing}} self.payload: dict = {"lowPass": {"smoothing": self.smoothing}}

View File

@ -1,30 +1,30 @@
from __future__ import annotations from __future__ import annotations
from typing import List, Optional, Union
from discord import Member, User
from typing import List
from typing import Optional
from typing import Union
from discord import ClientUser
from discord import Member
from discord import User
from discord.ext import commands from discord.ext import commands
from .enums import SearchType, TrackType, PlaylistType from .enums import PlaylistType
from .enums import SearchType
from .enums import TrackType
from .filters import Filter from .filters import Filter
__all__ = (
"Track",
"Playlist",
)
class Track: class Track:
"""The base track object. Returns critical track information needed for parsing by Lavalink. """The base track object. Returns critical track information needed for parsing by Lavalink.
You can also pass in commands.Context to get a discord.py Context object in your track. You can also pass in commands.Context to get a discord.py Context object in your track.
""" """
def __init__(
self,
*,
track_id: str,
info: dict,
ctx: Optional[commands.Context] = None,
track_type: TrackType,
search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None,
timestamp: Optional[float] = None,
requester: Optional[Union[Member, User]] = None,
):
__slots__ = ( __slots__ = (
"track_id", "track_id",
"info", "info",
@ -48,6 +48,19 @@ class Track:
"position", "position",
) )
def __init__(
self,
*,
track_id: str,
info: dict,
ctx: Optional[commands.Context] = None,
track_type: TrackType,
search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None,
timestamp: Optional[float] = None,
requester: Optional[Union[Member, User, ClientUser]] = None,
):
self.track_id: str = track_id self.track_id: str = track_id
self.info: dict = info self.info: dict = info
self.track_type: TrackType = track_type self.track_type: TrackType = track_type
@ -60,35 +73,29 @@ class Track:
self.original = self self.original = self
self._search_type: SearchType = search_type self._search_type: SearchType = search_type
self.playlist: Playlist = None self.playlist: Optional[Playlist] = None
self.title: str = info.get("title") self.title: str = info.get("title", "Unknown Title")
self.author: str = info.get("author") self.author: str = info.get("author", "Unknown Author")
self.uri: str = info.get("uri") self.uri: str = info.get("uri", "")
self.identifier: str = info.get("identifier") self.identifier: str = info.get("identifier", "")
self.isrc: str = info.get("isrc") self.isrc: str = info.get("isrc", "")
self.thumbnail: Optional[str] = info.get("thumbnail")
if self.uri: if self.uri and self.track_type is TrackType.YOUTUBE:
if info.get("thumbnail"): self.thumbnail = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg"
self.thumbnail: str = info.get("thumbnail")
elif self.track_type == TrackType.SOUNDCLOUD:
# ok so theres no feasible way of getting a Soundcloud image URL
# so we're just gonna leave it blank for brevity
self.thumbnail = None
else:
self.thumbnail: str = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg"
self.length: int = info.get("length") self.length: int = info.get("length", 0)
self.ctx: commands.Context = ctx self.is_stream: bool = info.get("isStream", False)
if requester: self.is_seekable: bool = info.get("isSeekable", False)
self.requester: Optional[Union[Member, User]] = requester self.position: int = info.get("position", 0)
else:
self.requester: Optional[Union[Member, User]] = self.ctx.author if ctx else None
self.is_stream: bool = info.get("isStream")
self.is_seekable: bool = info.get("isSeekable")
self.position: int = info.get("position")
def __eq__(self, other): self.ctx: Optional[commands.Context] = ctx
self.requester: Optional[Union[Member, User, ClientUser]] = requester
if not self.requester and self.ctx:
self.requester = self.ctx.author
def __eq__(self, other: object) -> bool:
if not isinstance(other, Track): if not isinstance(other, Track):
return False return False
@ -97,10 +104,10 @@ class Track:
return other.track_id == self.track_id return other.track_id == self.track_id
def __str__(self): def __str__(self) -> str:
return self.title return self.title
def __repr__(self): def __repr__(self) -> str:
return f"<Pomice.track title={self.title!r} uri=<{self.uri!r}> length={self.length}>" return f"<Pomice.track title={self.title!r} uri=<{self.uri!r}> length={self.length}>"
@ -110,15 +117,6 @@ class Playlist:
You can also pass in commands.Context to get a discord.py Context object in your tracks. You can also pass in commands.Context to get a discord.py Context object in your tracks.
""" """
def __init__(
self,
*,
playlist_info: dict,
tracks: list,
playlist_type: PlaylistType,
thumbnail: Optional[str] = None,
uri: Optional[str] = None,
):
__slots__ = ( __slots__ = (
"playlist_info", "playlist_info",
"tracks", "tracks",
@ -130,28 +128,37 @@ class Playlist:
"track_count", "track_count",
) )
def __init__(
self,
*,
playlist_info: dict,
tracks: list,
playlist_type: PlaylistType,
thumbnail: Optional[str] = None,
uri: Optional[str] = None,
):
self.playlist_info: dict = playlist_info self.playlist_info: dict = playlist_info
self.tracks: List[Track] = tracks self.tracks: List[Track] = tracks
self.name: str = playlist_info.get("name") self.name: str = playlist_info.get("name", "Unknown Playlist")
self.playlist_type: PlaylistType = playlist_type self.playlist_type: PlaylistType = playlist_type
self._thumbnail: str = thumbnail self._thumbnail: Optional[str] = thumbnail
self._uri: str = uri self._uri: Optional[str] = uri
for track in self.tracks: for track in self.tracks:
track.playlist = self track.playlist = self
if (index := playlist_info.get("selectedTrack")) == -1: self.selected_track: Optional[Track] = None
self.selected_track = None if (index := playlist_info.get("selectedTrack", -1)) != -1:
else: self.selected_track = self.tracks[index]
self.selected_track: Track = self.tracks[index]
self.track_count: int = len(self.tracks) self.track_count: int = len(self.tracks)
def __str__(self): def __str__(self) -> str:
return self.name return self.name
def __repr__(self): def __repr__(self) -> str:
return f"<Pomice.playlist name={self.name!r} track_count={len(self.tracks)}>" return f"<Pomice.playlist name={self.name!r} track_count={len(self.tracks)}>"
@property @property

View File

@ -1,54 +1,69 @@
import time import time
from typing import Any, Dict, List, Optional, Union from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from discord import Client, Guild, VoiceChannel, VoiceProtocol from discord import Client
from discord import Guild
from discord import VoiceChannel
from discord import VoiceProtocol
from discord.ext import commands from discord.ext import commands
from discord.types.voice import GuildVoiceState
from discord.types.voice import VoiceServerUpdate
from . import events from . import events
from .enums import SearchType from .enums import SearchType
from .events import PomiceEvent, TrackEndEvent, TrackStartEvent from .events import PomiceEvent
from .exceptions import ( from .events import TrackEndEvent
FilterInvalidArgument, from .events import TrackStartEvent
FilterTagAlreadyInUse, from .exceptions import FilterInvalidArgument
FilterTagInvalid, from .exceptions import FilterTagAlreadyInUse
TrackInvalidPosition, from .exceptions import FilterTagInvalid
TrackLoadError, from .exceptions import TrackInvalidPosition
) from .exceptions import TrackLoadError
from .filters import Filter from .filters import Filter
from .objects import Playlist
from .objects import Track from .objects import Track
from .pool import Node, NodePool from .pool import Node
from .pool import NodePool
__all__ = ("Filters", "Player")
class Filters: class Filters:
"""Helper class for filters""" """Helper class for filters"""
__slots__ = "_filters" __slots__ = ("_filters",)
def __init__(self): def __init__(self) -> None:
self._filters: List[Filter] = [] self._filters: List[Filter] = []
@property @property
def has_preload(self): def has_preload(self) -> bool:
"""Property which checks if any applied filters were preloaded""" """Property which checks if any applied filters were preloaded"""
return any(f for f in self._filters if f.preload == True) return any(f for f in self._filters if f.preload == True)
@property @property
def has_global(self): def has_global(self) -> bool:
"""Property which checks if any applied filters are global""" """Property which checks if any applied filters are global"""
return any(f for f in self._filters if f.preload == False) return any(f for f in self._filters if f.preload == False)
@property @property
def empty(self): def empty(self) -> bool:
"""Property which checks if the filter list is empty""" """Property which checks if the filter list is empty"""
return len(self._filters) == 0 return len(self._filters) == 0
def add_filter(self, *, filter: Filter): def add_filter(self, *, filter: Filter) -> None:
"""Adds a filter to the list of filters applied""" """Adds a filter to the list of filters applied"""
if any(f for f in self._filters if f.tag == filter.tag): if any(f for f in self._filters if f.tag == filter.tag):
raise FilterTagAlreadyInUse("A filter with that tag is already in use.") raise FilterTagAlreadyInUse(
"A filter with that tag is already in use.",
)
self._filters.append(filter) self._filters.append(filter)
def remove_filter(self, *, filter_tag: str): def remove_filter(self, *, filter_tag: str) -> None:
"""Removes a filter from the list of filters applied using its filter tag""" """Removes a filter from the list of filters applied using its filter tag"""
if not any(f for f in self._filters if f.tag == filter_tag): if not any(f for f in self._filters if f.tag == filter_tag):
raise FilterTagInvalid("A filter with that tag was not found.") raise FilterTagInvalid("A filter with that tag was not found.")
@ -57,26 +72,27 @@ class Filters:
if filter.tag == filter_tag: if filter.tag == filter_tag:
del self._filters[index] del self._filters[index]
def has_filter(self, *, filter_tag: str): def has_filter(self, *, filter_tag: str) -> bool:
"""Checks if a filter exists in the list of filters using its filter tag""" """Checks if a filter exists in the list of filters using its filter tag"""
return any(f for f in self._filters if f.tag == filter_tag) return any(f for f in self._filters if f.tag == filter_tag)
def reset_filters(self): def reset_filters(self) -> None:
"""Removes all filters from the list""" """Removes all filters from the list"""
self._filters = [] self._filters = []
def get_preload_filters(self): def get_preload_filters(self) -> List[Filter]:
"""Get all preloaded filters""" """Get all preloaded filters"""
return [f for f in self._filters if f.preload == True] return [f for f in self._filters if f.preload == True]
def get_all_payloads(self): def get_all_payloads(self) -> Dict[str, Any]:
"""Returns a formatted dict of all the filter payloads""" """Returns a formatted dict of all the filter payloads"""
payload = {} payload: Dict[str, Any] = {}
for filter in self._filters: for _filter in self._filters:
payload.update(filter.payload) if _filter.payload:
payload.update(_filter.payload)
return payload return payload
def get_filters(self): def get_filters(self) -> List[Filter]:
"""Returns the current list of applied filters""" """Returns the current list of applied filters"""
return self._filters return self._filters
@ -89,20 +105,6 @@ class Player(VoiceProtocol):
``` ```
""" """
def __call__(self, client: Client, channel: VoiceChannel):
self.client: Client = client
self.channel: VoiceChannel = channel
self._guild: Guild = channel.guild
return self
def __init__(
self,
client: Optional[Client] = None,
channel: Optional[VoiceChannel] = None,
*,
node: Node = None,
):
__slots__ = ( __slots__ = (
"client", "client",
"channel", "channel",
@ -120,14 +122,25 @@ class Player(VoiceProtocol):
"_ending_track", "_ending_track",
"_voice_state", "_voice_state",
"_player_endpoint_uri", "_player_endpoint_uri",
"__dict__",
) )
self.client: Optional[Client] = client def __call__(self, client: Client, channel: VoiceChannel) -> "Player":
self.channel: Optional[VoiceChannel] = channel self.__init__(client, channel) # type: ignore
return self
self._bot: Union[Client, commands.Bot] = client def __init__(
self._guild: Guild = channel.guild if channel else None self,
client: Client,
channel: VoiceChannel,
*,
node: Optional[Node] = None,
) -> None:
self.client: Client = client
self.channel: VoiceChannel = channel
self._bot: Client = client
self._guild: Guild = channel.guild
self._node: Node = node if node else NodePool.get_node() self._node: Node = node if node else NodePool.get_node()
self._current: Optional[Track] = None self._current: Optional[Track] = None
self._filters: Filters = Filters() self._filters: Filters = Filters()
@ -137,7 +150,7 @@ class Player(VoiceProtocol):
self._position: int = 0 self._position: int = 0
self._last_position: int = 0 self._last_position: int = 0
self._last_update: int = 0 self._last_update: float = 0
self._ending_track: Optional[Track] = None self._ending_track: Optional[Track] = None
self._voice_state: dict = {} self._voice_state: dict = {}
@ -153,11 +166,11 @@ class Player(VoiceProtocol):
@property @property
def position(self) -> float: def position(self) -> float:
"""Property which returns the player's position in a track in milliseconds""" """Property which returns the player's position in a track in milliseconds"""
current = self._current.original
if not self.is_playing or not self._current: if not self.is_playing or not self._current:
return 0 return 0
current = getattr(self._current, "original", self._current)
if self.is_paused: if self.is_paused:
return min(self._last_position, current.length) return min(self._last_position, current.length)
@ -185,7 +198,7 @@ class Player(VoiceProtocol):
return self._is_connected and self._paused return self._is_connected and self._paused
@property @property
def current(self) -> Track: def current(self) -> Optional[Track]:
"""Property which returns the currently playing track""" """Property which returns the currently playing track"""
return self._current return self._current
@ -210,7 +223,7 @@ class Player(VoiceProtocol):
return self._filters return self._filters
@property @property
def bot(self) -> Union[Client, commands.Bot]: def bot(self) -> Client:
"""Property which returns the bot associated with this player instance""" """Property which returns the bot associated with this player instance"""
return self._bot return self._bot
@ -221,13 +234,14 @@ class Player(VoiceProtocol):
""" """
return self.guild.id not in self._node._players return self.guild.id not in self._node._players
async def _update_state(self, data: dict): async def _update_state(self, data: dict) -> None:
state: dict = data.get("state") state: dict = data.get("state", {})
self._last_update = time.time() * 1000 self._last_update = time.time() * 1000.0
self._is_connected = state.get("connected") self._is_connected = bool(state.get("connected"))
self._last_position = state.get("position") position = state.get("position")
self._last_position = int(position) if position else 0
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None): async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None:
if {"sessionId", "event"} != self._voice_state.keys(): if {"sessionId", "event"} != self._voice_state.keys():
return return
@ -246,27 +260,32 @@ class Player(VoiceProtocol):
data={"voice": data}, data={"voice": data},
) )
async def on_voice_server_update(self, data: dict): async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
self._voice_state.update({"event": data}) self._voice_state.update({"event": data})
await self._dispatch_voice_update(self._voice_state) await self._dispatch_voice_update(self._voice_state)
async def on_voice_state_update(self, data: dict): async def on_voice_state_update(self, data: GuildVoiceState) -> None:
self._voice_state.update({"sessionId": data.get("session_id")}) self._voice_state.update({"sessionId": data.get("session_id")})
if not (channel_id := data.get("channel_id")): channel_id = data.get("channel_id")
if not channel_id:
await self.disconnect() await self.disconnect()
self._voice_state.clear() self._voice_state.clear()
return return
self.channel = self.guild.get_channel(int(channel_id)) channel = self.guild.get_channel(int(channel_id))
if not channel:
await self.disconnect()
self._voice_state.clear()
return
if not data.get("token"): if not data.get("token"):
return return
await self._dispatch_voice_update({**self._voice_state, "event": data}) await self._dispatch_voice_update({**self._voice_state, "event": data})
async def _dispatch_event(self, data: dict): async def _dispatch_event(self, data: dict) -> None:
event_type = data.get("type") event_type: str = data["type"]
event: PomiceEvent = getattr(events, event_type)(data, self) event: PomiceEvent = getattr(events, event_type)(data, self)
if isinstance(event, TrackEndEvent) and event.reason != "REPLACED": if isinstance(event, TrackEndEvent) and event.reason != "REPLACED":
@ -277,11 +296,12 @@ class Player(VoiceProtocol):
if isinstance(event, TrackStartEvent): if isinstance(event, TrackStartEvent):
self._ending_track = self._current self._ending_track = self._current
async def _swap_node(self, *, new_node: Node): async def _swap_node(self, *, new_node: Node) -> None:
data: dict = { data: dict = {
"encodedTrack": self.current.track_id,
"position": self.position, "position": self.position,
} }
if self.current:
data["encodedTrack"] = self.current.track_id
del self._node._players[self._guild.id] del self._node._players[self._guild.id]
self._node = new_node self._node = new_node
@ -304,7 +324,7 @@ class Player(VoiceProtocol):
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch, search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None, filters: Optional[List[Filter]] = None,
): ) -> Optional[Union[List[Track], Playlist]]:
"""Fetches tracks from the node's REST api to parse into Lavalink. """Fetches tracks from the node's REST api to parse into Lavalink.
If you passed in Spotify API credentials when you created the node, If you passed in Spotify API credentials when you created the node,
@ -321,7 +341,7 @@ class Player(VoiceProtocol):
async def get_recommendations( async def get_recommendations(
self, *, track: Track, ctx: Optional[commands.Context] = None self, *, track: Track, ctx: Optional[commands.Context] = None
) -> Union[List[Track], None]: ) -> Optional[Union[List[Track], Playlist]]:
""" """
Gets recommendations from either YouTube or Spotify. Gets recommendations from either YouTube or Spotify.
You can pass in a discord.py Context object to get a You can pass in a discord.py Context object to get a
@ -331,14 +351,14 @@ class Player(VoiceProtocol):
async def connect( async def connect(
self, *, timeout: float, reconnect: bool, self_deaf: bool = False, self_mute: bool = False self, *, timeout: float, reconnect: bool, self_deaf: bool = False, self_mute: bool = False
): ) -> None:
await self.guild.change_voice_state( await self.guild.change_voice_state(
channel=self.channel, self_deaf=self_deaf, self_mute=self_mute channel=self.channel, self_deaf=self_deaf, self_mute=self_mute,
) )
self._node._players[self.guild.id] = self self._node._players[self.guild.id] = self
self._is_connected = True self._is_connected = True
async def stop(self): async def stop(self) -> None:
"""Stops the currently playing track.""" """Stops the currently playing track."""
self._current = None self._current = None
await self._node.send( await self._node.send(
@ -348,27 +368,27 @@ class Player(VoiceProtocol):
data={"encodedTrack": None}, data={"encodedTrack": None},
) )
async def disconnect(self, *, force: bool = False): async def disconnect(self, *, force: bool = False) -> None:
"""Disconnects the player from voice.""" """Disconnects the player from voice."""
try: try:
await self.guild.change_voice_state(channel=None) await self.guild.change_voice_state(channel=None)
finally: finally:
self.cleanup() self.cleanup()
self._is_connected = False self._is_connected = False
self.channel = None del self.channel
async def destroy(self): async def destroy(self) -> None:
"""Disconnects and destroys the player, and runs internal cleanup.""" """Disconnects and destroys the player, and runs internal cleanup."""
try: try:
await self.disconnect() await self.disconnect()
except AttributeError: except AttributeError:
# 'NoneType' has no attribute '_get_voice_client_key' raised by self.cleanup() -> # 'NoneType' has no attribute '_get_voice_client_key' raised by self.cleanup() ->
# assume we're already disconnected and cleaned up # assume we're already disconnected and cleaned up
assert self.channel is None and not self.is_connected assert not self.is_connected and not self.channel
self._node._players.pop(self.guild.id) self._node._players.pop(self.guild.id)
await self._node.send( await self._node.send(
method="DELETE", path=self._player_endpoint_uri, guild_id=self._guild.id method="DELETE", path=self._player_endpoint_uri, guild_id=self._guild.id,
) )
async def play( async def play(
@ -383,20 +403,22 @@ class Player(VoiceProtocol):
if not track.isrc: if not track.isrc:
# We have to bare raise here because theres no other way to skip this block feasibly # We have to bare raise here because theres no other way to skip this block feasibly
raise raise
search: Track = ( search = (
await self._node.get_tracks(f"{track._search_type}:{track.isrc}", ctx=track.ctx) await self._node.get_tracks(f"{track._search_type}:{track.isrc}", ctx=track.ctx)
)[0] )[0] # type: ignore
except Exception: except Exception:
# First method didn't work, lets try just searching it up # First method didn't work, lets try just searching it up
try: try:
search: Track = ( search = (
await self._node.get_tracks( await self._node.get_tracks(
f"{track._search_type}:{track.title} - {track.author}", ctx=track.ctx f"{track._search_type}:{track.title} - {track.author}", ctx=track.ctx,
) )
)[0] )[0] # type: ignore
except: except:
# The song wasn't able to be found, raise error # The song wasn't able to be found, raise error
raise TrackLoadError("No equivalent track was able to be found.") raise TrackLoadError(
"No equivalent track was able to be found.",
)
data = { data = {
"encodedTrack": search.track_id, "encodedTrack": search.track_id,
"position": str(start), "position": str(start),
@ -432,7 +454,7 @@ class Player(VoiceProtocol):
if track.filters and not self.filters.has_global: if track.filters and not self.filters.has_global:
# Now apply all filters # Now apply all filters
for filter in track.filters: for filter in track.filters:
await self.add_filter(filter=filter) await self.add_filter(_filter=filter)
# Lavalink v4 changed the way the end time parameter works # Lavalink v4 changed the way the end time parameter works
# so now the end time cannot be zero. # so now the end time cannot be zero.
@ -454,8 +476,13 @@ class Player(VoiceProtocol):
async def seek(self, position: float) -> float: async def seek(self, position: float) -> float:
"""Seeks to a position in the currently playing track milliseconds""" """Seeks to a position in the currently playing track milliseconds"""
if not self._current or not self._current.original:
return 0.0
if position < 0 or position > self._current.original.length: if position < 0 or position > self._current.original.length:
raise TrackInvalidPosition("Seek position must be between 0 and the track length") raise TrackInvalidPosition(
"Seek position must be between 0 and the track length",
)
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
@ -487,7 +514,7 @@ class Player(VoiceProtocol):
self._volume = volume self._volume = volume
return self._volume return self._volume
async def add_filter(self, filter: Filter, fast_apply: bool = False) -> Filter: async def add_filter(self, _filter: Filter, fast_apply: bool = False) -> Filters:
"""Adds a filter to the player. Takes a pomice.Filter object. """Adds a filter to the player. Takes a pomice.Filter object.
This will only work if you are using a version of Lavalink that supports filters. This will only work if you are using a version of Lavalink that supports filters.
If you would like for the filter to apply instantly, set the `fast_apply` arg to `True`. If you would like for the filter to apply instantly, set the `fast_apply` arg to `True`.
@ -495,7 +522,7 @@ class Player(VoiceProtocol):
(You must have a song playing in order for `fast_apply` to work.) (You must have a song playing in order for `fast_apply` to work.)
""" """
self._filters.add_filter(filter=filter) self._filters.add_filter(filter=_filter)
payload = self._filters.get_all_payloads() payload = self._filters.get_all_payloads()
await self._node.send( await self._node.send(
method="PATCH", method="PATCH",
@ -508,7 +535,7 @@ class Player(VoiceProtocol):
return self._filters return self._filters
async def remove_filter(self, filter_tag: str, fast_apply: bool = False) -> Filter: async def remove_filter(self, filter_tag: str, fast_apply: bool = False) -> Filters:
"""Removes a filter from the player. Takes a filter tag. """Removes a filter from the player. Takes a filter tag.
This will only work if you are using a version of Lavalink that supports filters. This will only work if you are using a version of Lavalink that supports filters.
If you would like for the filter to apply instantly, set the `fast_apply` arg to `True`. If you would like for the filter to apply instantly, set the `fast_apply` arg to `True`.
@ -529,7 +556,7 @@ class Player(VoiceProtocol):
return self._filters return self._filters
async def reset_filters(self, *, fast_apply: bool = False): async def reset_filters(self, *, fast_apply: bool = False) -> None:
"""Resets all currently applied filters to their default parameters. """Resets all currently applied filters to their default parameters.
You must have filters applied in order for this to work. You must have filters applied in order for this to work.
If you would like the filters to be removed instantly, set the `fast_apply` arg to `True`. If you would like the filters to be removed instantly, set the `fast_apply` arg to `True`.
@ -539,7 +566,7 @@ class Player(VoiceProtocol):
if not self._filters: if not self._filters:
raise FilterInvalidArgument( raise FilterInvalidArgument(
"You must have filters applied first in order to use this method." "You must have filters applied first in order to use this method.",
) )
self._filters.reset_filters() self._filters.reset_filters()
await self._node.send( await self._node.send(

View File

@ -3,35 +3,48 @@ from __future__ import annotations
import asyncio import asyncio
import random import random
import re import re
import aiohttp from typing import Any
from typing import Dict
from discord import Client from typing import List
from discord.ext import commands from typing import Optional
from typing import Dict, List, Optional, TYPE_CHECKING, Union from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from urllib.parse import quote from urllib.parse import quote
from . import __version__, spotify, applemusic import aiohttp
from discord import Client
from discord.ext import commands
from . import __version__
from . import applemusic
from . import spotify
from .enums import * from .enums import *
from .exceptions import ( from .exceptions import AppleMusicNotEnabled
AppleMusicNotEnabled, from .exceptions import InvalidSpotifyClientAuthorization
InvalidSpotifyClientAuthorization, from .exceptions import LavalinkVersionIncompatible
LavalinkVersionIncompatible, from .exceptions import NodeConnectionFailure
NodeConnectionFailure, from .exceptions import NodeCreationError
NodeCreationError, from .exceptions import NodeNotAvailable
NodeNotAvailable, from .exceptions import NodeRestException
NoNodesAvailable, from .exceptions import NoNodesAvailable
NodeRestException, from .exceptions import TrackLoadError
TrackLoadError,
)
from .filters import Filter from .filters import Filter
from .objects import Playlist, Track from .objects import Playlist
from .utils import ExponentialBackoff, NodeStats, Ping from .objects import Track
from .routeplanner import RoutePlanner from .routeplanner import RoutePlanner
from .utils import ExponentialBackoff
from .utils import NodeStats
from .utils import Ping
if TYPE_CHECKING: if TYPE_CHECKING:
from .player import Player from .player import Player
__all__ = (
"Node",
"NodePool",
)
class Node: class Node:
"""The base class for a node. """The base class for a node.
@ -40,26 +53,9 @@ class Node:
To enable Apple music, set the "apple_music" parameter to "True" To enable Apple music, set the "apple_music" parameter to "True"
""" """
def __init__(
self,
*,
pool: NodePool,
bot: Union[Client, commands.Bot],
host: str,
port: int,
password: str,
identifier: str,
secure: bool = False,
heartbeat: int = 30,
loop: Optional[asyncio.AbstractEventLoop] = None,
session: Optional[aiohttp.ClientSession] = None,
spotify_client_id: Optional[str] = None,
spotify_client_secret: Optional[str] = None,
apple_music: bool = False,
fallback: bool = False,
):
__slots__ = ( __slots__ = (
"_bot", "_bot",
"_bot_user",
"_host", "_host",
"_port", "_port",
"_pool", "_pool",
@ -83,12 +79,34 @@ class Node:
"_spotify_client_secret", "_spotify_client_secret",
"_spotify_client", "_spotify_client",
"_apple_music_client", "_apple_music_client",
"_route_planner",
"_stats",
"available",
) )
self._bot: Union[Client, commands.Bot] = bot def __init__(
self,
*,
pool: Type[NodePool],
bot: commands.Bot,
host: str,
port: int,
password: str,
identifier: str,
secure: bool = False,
heartbeat: int = 30,
loop: Optional[asyncio.AbstractEventLoop] = None,
session: Optional[aiohttp.ClientSession] = None,
spotify_client_id: Optional[int] = None,
spotify_client_secret: Optional[str] = None,
apple_music: bool = False,
fallback: bool = False,
):
self._bot: commands.Bot = bot
self._host: str = host self._host: str = host
self._port: int = port self._port: int = port
self._pool: NodePool = pool self._pool: Type[NodePool] = pool
self._password: str = password self._password: str = password
self._identifier: str = identifier self._identifier: str = identifier
self._heartbeat: int = heartbeat self._heartbeat: int = heartbeat
@ -98,33 +116,38 @@ class Node:
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}" self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}" self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
self._session: Optional[aiohttp.ClientSession] = session self._session: aiohttp.ClientSession = session # type: ignore
self._websocket = None
self._task: asyncio.Task = None
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
self._websocket: aiohttp.ClientWebSocketResponse
self._task: asyncio.Task = None # type: ignore
self._session_id: str = None self._session_id: Optional[str] = None
self._available: bool = False self._available: bool = False
self._version: str = None self._version: int = 0
self._route_planner = RoutePlanner(self) self._route_planner = RoutePlanner(self)
if not self._bot.user:
raise NodeCreationError("Bot user is not ready yet.")
self._bot_user = self._bot.user
self._headers = { self._headers = {
"Authorization": self._password, "Authorization": self._password,
"User-Id": str(self._bot.user.id), "User-Id": str(self._bot_user.id),
"Client-Name": f"Pomice/{__version__}", "Client-Name": f"Pomice/{__version__}",
} }
self._players: Dict[int, Player] = {} self._players: Dict[int, Player] = {}
self._spotify_client_id: str = spotify_client_id self._spotify_client_id: Optional[int] = spotify_client_id
self._spotify_client_secret: str = spotify_client_secret self._spotify_client_secret: Optional[str] = spotify_client_secret
self._apple_music_client: Optional[applemusic.Client] = None self._apple_music_client: Optional[applemusic.Client] = None
if self._spotify_client_id and self._spotify_client_secret: if self._spotify_client_id and self._spotify_client_secret:
self._spotify_client: spotify.Client = spotify.Client( self._spotify_client: spotify.Client = spotify.Client(
self._spotify_client_id, self._spotify_client_secret self._spotify_client_id, self._spotify_client_secret,
) )
if apple_music: if apple_music:
@ -132,7 +155,7 @@ class Node:
self._bot.add_listener(self._update_handler, "on_socket_response") self._bot.add_listener(self._update_handler, "on_socket_response")
def __repr__(self): def __repr__(self) -> str:
return ( return (
f"<Pomice.node ws_uri={self._websocket_uri} rest_uri={self._rest_uri} " f"<Pomice.node ws_uri={self._websocket_uri} rest_uri={self._rest_uri} "
f"player_count={len(self._players)}>" f"player_count={len(self._players)}>"
@ -154,7 +177,7 @@ class Node:
return self._players return self._players
@property @property
def bot(self) -> Union[Client, commands.Bot]: def bot(self) -> Client:
"""Property which returns the discord.py client linked to this node""" """Property which returns the discord.py client linked to this node"""
return self._bot return self._bot
@ -164,21 +187,21 @@ class Node:
return len(self.players) return len(self.players)
@property @property
def pool(self): def pool(self) -> Type[NodePool]:
"""Property which returns the pool this node is apart of""" """Property which returns the pool this node is apart of"""
return self._pool return self._pool
@property @property
def latency(self): def latency(self) -> float:
"""Property which returns the latency of the node""" """Property which returns the latency of the node"""
return Ping(self._host, port=self._port).get_ping() return Ping(self._host, port=self._port).get_ping()
@property @property
def ping(self): def ping(self) -> float:
"""Alias for `Node.latency`, returns the latency of the node""" """Alias for `Node.latency`, returns the latency of the node"""
return self.latency return self.latency
async def _update_handler(self, data: dict): async def _update_handler(self, data: dict) -> None:
await self._bot.wait_until_ready() await self._bot.wait_until_ready()
if not data: if not data:
@ -193,7 +216,7 @@ class Node:
return return
elif data["t"] == "VOICE_STATE_UPDATE": elif data["t"] == "VOICE_STATE_UPDATE":
if int(data["d"]["user_id"]) != self._bot.user.id: if int(data["d"]["user_id"]) != self._bot_user.id:
return return
guild_id = int(data["d"]["guild_id"]) guild_id = int(data["d"]["guild_id"])
@ -203,8 +226,11 @@ class Node:
except KeyError: except KeyError:
return return
async def _handle_node_switch(self): async def _handle_node_switch(self) -> None:
nodes = [node for node in self.pool.nodes.copy().values() if node.is_connected] nodes = [
node for node in self.pool._nodes.copy().values()
if node.is_connected
]
new_node = random.choice(nodes) new_node = random.choice(nodes)
for player in self.players.copy().values(): for player in self.players.copy().values():
@ -212,7 +238,7 @@ class Node:
await self.disconnect() await self.disconnect()
async def _listen(self): async def _listen(self) -> None:
backoff = ExponentialBackoff(base=7) backoff = ExponentialBackoff(base=7)
while True: while True:
@ -227,7 +253,7 @@ class Node:
else: else:
self._loop.create_task(self._handle_payload(msg.json())) self._loop.create_task(self._handle_payload(msg.json()))
async def _handle_payload(self, data: dict): async def _handle_payload(self, data: dict) -> None:
op = data.get("op", None) op = data.get("op", None)
if not op: if not op:
return return
@ -239,14 +265,20 @@ class Node:
if op == "ready": if op == "ready":
self._session_id = data["sessionId"] self._session_id = data["sessionId"]
if "guildId" in data: if not "guildId" in data:
if not (player := self._players.get(int(data["guildId"]))): return
player = self._players.get(int(data["guildId"]))
if not player:
return return
if op == "event": if op == "event":
await player._dispatch_event(data) await player._dispatch_event(data)
elif op == "playerUpdate": return
if op == "playerUpdate":
await player._update_state(data) await player._update_state(data)
return
async def send( async def send(
self, self,
@ -255,11 +287,13 @@ class Node:
include_version: bool = True, include_version: bool = True,
guild_id: Optional[Union[int, str]] = None, guild_id: Optional[Union[int, str]] = None,
query: Optional[str] = None, query: Optional[str] = None,
data: Optional[Union[dict, str]] = None, data: Optional[Union[Dict, str]] = None,
ignore_if_available: bool = False, ignore_if_available: bool = False,
): ) -> Any:
if not ignore_if_available and not self._available: if not ignore_if_available and not self._available:
raise NodeNotAvailable(f"The node '{self._identifier}' is unavailable.") raise NodeNotAvailable(
f"The node '{self._identifier}' is unavailable.",
)
uri: str = ( uri: str = (
f"{self._rest_uri}/" f"{self._rest_uri}/"
@ -270,12 +304,12 @@ class Node:
) )
async with self._session.request( async with self._session.request(
method=method, url=uri, headers=self._headers, json=data or {} method=method, url=uri, headers=self._headers, json=data or {},
) as resp: ) as resp:
if resp.status >= 300: if resp.status >= 300:
data: dict = await resp.json() resp_data: dict = await resp.json()
raise NodeRestException( raise NodeRestException(
f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {data["message"]}' f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {resp_data["message"]}',
) )
if method == "DELETE" or resp.status == 204: if method == "DELETE" or resp.status == 204:
@ -286,11 +320,11 @@ class Node:
return await resp.json() return await resp.json()
def get_player(self, guild_id: int): def get_player(self, guild_id: int) -> Optional[Player]:
"""Takes a guild ID as a parameter. Returns a pomice Player object.""" """Takes a guild ID as a parameter. Returns a pomice Player object or None."""
return self._players.get(guild_id, None) return self._players.get(guild_id, None)
async def connect(self): async def connect(self) -> "Node":
"""Initiates a connection with a Lavalink node and adds it to the node pool.""" """Initiates a connection with a Lavalink node and adds it to the node pool."""
await self._bot.wait_until_ready() await self._bot.wait_until_ready()
@ -298,7 +332,7 @@ class Node:
self._session = aiohttp.ClientSession() self._session = aiohttp.ClientSession()
try: try:
version = await self.send( version: str = await self.send(
method="GET", method="GET",
path="version", path="version",
ignore_if_available=True, ignore_if_available=True,
@ -309,14 +343,14 @@ class Node:
self._available = False self._available = False
raise LavalinkVersionIncompatible( raise LavalinkVersionIncompatible(
"The Lavalink version you're using is incompatible. " "The Lavalink version you're using is incompatible. "
"Lavalink version 3.7.0 or above is required to use this library." "Lavalink version 3.7.0 or above is required to use this library.",
) )
if version.endswith("-SNAPSHOT"): if version.endswith("-SNAPSHOT"):
# we're just gonna assume all snapshot versions correlate with v4 # we're just gonna assume all snapshot versions correlate with v4
self._version = 4 self._version = 4
else: else:
self._version = version[:1] self._version = int(version[:1])
self._websocket = await self._session.ws_connect( self._websocket = await self._session.ws_connect(
f"{self._websocket_uri}/v{self._version}/websocket", f"{self._websocket_uri}/v{self._version}/websocket",
@ -332,18 +366,18 @@ class Node:
except (aiohttp.ClientConnectorError, ConnectionRefusedError): except (aiohttp.ClientConnectorError, ConnectionRefusedError):
raise NodeConnectionFailure( raise NodeConnectionFailure(
f"The connection to node '{self._identifier}' failed." f"The connection to node '{self._identifier}' failed.",
) from None ) from None
except aiohttp.WSServerHandshakeError: except aiohttp.WSServerHandshakeError:
raise NodeConnectionFailure( raise NodeConnectionFailure(
f"The password for node '{self._identifier}' is invalid." f"The password for node '{self._identifier}' is invalid.",
) from None ) from None
except aiohttp.InvalidURL: except aiohttp.InvalidURL:
raise NodeConnectionFailure( raise NodeConnectionFailure(
f"The URI for node '{self._identifier}' is invalid." f"The URI for node '{self._identifier}' is invalid.",
) from None ) from None
async def disconnect(self): async def disconnect(self) -> None:
"""Disconnects a connected Lavalink node and removes it from the node pool. """Disconnects a connected Lavalink node and removes it from the node pool.
This also destroys any players connected to the node. This also destroys any players connected to the node.
""" """
@ -371,7 +405,7 @@ class Node:
""" """
data: dict = await self.send( data: dict = await self.send(
method="GET", path="decodetrack", query=f"encodedTrack={identifier}" method="GET", path="decodetrack", query=f"encodedTrack={identifier}",
) )
return Track( return Track(
track_id=identifier, track_id=identifier,
@ -387,7 +421,7 @@ class Node:
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch, search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None, filters: Optional[List[Filter]] = None,
): ) -> Optional[Union[Playlist, List[Track]]]:
"""Fetches tracks from the node's REST api to parse into Lavalink. """Fetches tracks from the node's REST api to parse into Lavalink.
If you passed in Spotify API credentials, you can also pass in a If you passed in Spotify API credentials, you can also pass in a
@ -413,7 +447,7 @@ class Node:
if not self._apple_music_client: if not self._apple_music_client:
raise AppleMusicNotEnabled( raise AppleMusicNotEnabled(
"You must have Apple Music functionality enabled in order to play Apple Music tracks." "You must have Apple Music functionality enabled in order to play Apple Music tracks."
"Please set apple_music to True in your Node class." "Please set apple_music to True in your Node class.",
) )
apple_music_results = await self._apple_music_client.search(query=query) apple_music_results = await self._apple_music_client.search(query=query)
@ -437,7 +471,7 @@ class Node:
"thumbnail": apple_music_results.image, "thumbnail": apple_music_results.image,
"isrc": apple_music_results.isrc, "isrc": apple_music_results.isrc,
}, },
) ),
] ]
tracks = [ tracks = [
@ -464,7 +498,9 @@ class Node:
] ]
return Playlist( return Playlist(
playlist_info={"name": apple_music_results.name, "selectedTrack": 0}, playlist_info={
"name": apple_music_results.name, "selectedTrack": 0,
},
tracks=tracks, tracks=tracks,
playlist_type=PlaylistType.APPLE_MUSIC, playlist_type=PlaylistType.APPLE_MUSIC,
thumbnail=apple_music_results.image, thumbnail=apple_music_results.image,
@ -476,7 +512,7 @@ class Node:
raise InvalidSpotifyClientAuthorization( raise InvalidSpotifyClientAuthorization(
"You did not provide proper Spotify client authorization credentials. " "You did not provide proper Spotify client authorization credentials. "
"If you would like to use the Spotify searching feature, " "If you would like to use the Spotify searching feature, "
"please obtain Spotify API credentials here: https://developer.spotify.com/" "please obtain Spotify API credentials here: https://developer.spotify.com/",
) )
spotify_results = await self._spotify_client.search(query=query) spotify_results = await self._spotify_client.search(query=query)
@ -501,7 +537,7 @@ class Node:
"thumbnail": spotify_results.image, "thumbnail": spotify_results.image,
"isrc": spotify_results.isrc, "isrc": spotify_results.isrc,
}, },
) ),
] ]
tracks = [ tracks = [
@ -528,7 +564,9 @@ class Node:
] ]
return Playlist( return Playlist(
playlist_info={"name": spotify_results.name, "selectedTrack": 0}, playlist_info={
"name": spotify_results.name, "selectedTrack": 0,
},
tracks=tracks, tracks=tracks,
playlist_type=PlaylistType.SPOTIFY, playlist_type=PlaylistType.SPOTIFY,
thumbnail=spotify_results.image, thumbnail=spotify_results.image,
@ -537,11 +575,11 @@ class Node:
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query): elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
data: dict = await self.send( data: dict = await self.send(
method="GET", path="loadtracks", query=f"identifier={quote(query)}" method="GET", path="loadtracks", query=f"identifier={quote(query)}",
) )
track: dict = data["tracks"][0] track: dict = data["tracks"][0]
info: dict = track.get("info") info: dict = track["info"]
return [ return [
Track( Track(
@ -549,15 +587,15 @@ class Node:
info={ info={
"title": discord_url.group("file"), "title": discord_url.group("file"),
"author": "Unknown", "author": "Unknown",
"length": info.get("length"), "length": info["length"],
"uri": info.get("uri"), "uri": info["uri"],
"position": info.get("position"), "position": info["position"],
"identifier": info.get("identifier"), "identifier": info["identifier"],
}, },
ctx=ctx, ctx=ctx,
track_type=TrackType.HTTP, track_type=TrackType.HTTP,
filters=filters, filters=filters,
) ),
] ]
else: else:
@ -572,18 +610,22 @@ class Node:
if match := URLRegex.YOUTUBE_VID_IN_PLAYLIST.match(query): if match := URLRegex.YOUTUBE_VID_IN_PLAYLIST.match(query):
query = match.group("video") query = match.group("video")
data: dict = await self.send( data = await self.send(
method="GET", path="loadtracks", query=f"identifier={quote(query)}" method="GET", path="loadtracks", query=f"identifier={quote(query)}",
) )
load_type = data.get("loadType") load_type = data.get("loadType")
if not load_type: if not load_type:
raise TrackLoadError("There was an error while trying to load this track.") raise TrackLoadError(
"There was an error while trying to load this track.",
)
elif load_type == "LOAD_FAILED": elif load_type == "LOAD_FAILED":
exception = data["exception"] exception = data["exception"]
raise TrackLoadError(f"{exception['message']} [{exception['severity']}]") raise TrackLoadError(
f"{exception['message']} [{exception['severity']}]",
)
elif load_type == "NO_MATCHES": elif load_type == "NO_MATCHES":
return None return None
@ -619,9 +661,14 @@ class Node:
for track in data["tracks"] for track in data["tracks"]
] ]
else:
raise TrackLoadError(
"There was an error while trying to load this track.",
)
async def get_recommendations( async def get_recommendations(
self, *, track: Track, ctx: Optional[commands.Context] = None self, *, track: Track, ctx: Optional[commands.Context] = None
) -> Union[List[Track], None]: ) -> Optional[Union[List[Track], Playlist]]:
""" """
Gets recommendations from either YouTube or Spotify. Gets recommendations from either YouTube or Spotify.
The track that is passed in must be either from The track that is passed in must be either from
@ -652,17 +699,17 @@ class Node:
) )
for track in results for track in results
] ]
return tracks return tracks
elif track.track_type == TrackType.YOUTUBE: elif track.track_type == TrackType.YOUTUBE:
tracks = await self.get_tracks( return await self.get_tracks(
query=f"ytsearch:https://www.youtube.com/watch?v={track.identifier}&list=RD{track.identifier}", query=f"ytsearch:https://www.youtube.com/watch?v={track.identifier}&list=RD{track.identifier}",
ctx=ctx, ctx=ctx,
) )
return tracks
else: else:
raise TrackLoadError( raise TrackLoadError(
"The specfied track must be either a YouTube or Spotify track to recieve recommendations." "The specfied track must be either a YouTube or Spotify track to recieve recommendations.",
) )
@ -671,9 +718,10 @@ class NodePool:
This holds all the nodes that are to be used by the bot. This holds all the nodes that are to be used by the bot.
""" """
__slots__ = ()
_nodes: Dict[str, Node] = {} _nodes: Dict[str, Node] = {}
def __repr__(self): def __repr__(self) -> str:
return f"<Pomice.NodePool node_count={self.node_count}>" return f"<Pomice.NodePool node_count={self.node_count}>"
@property @property
@ -682,7 +730,7 @@ class NodePool:
return self._nodes return self._nodes
@property @property
def node_count(self): def node_count(self) -> int:
return len(self._nodes.values()) return len(self._nodes.values())
@classmethod @classmethod
@ -700,21 +748,31 @@ class NodePool:
based on how players it has. This method will return a node with based on how players it has. This method will return a node with
the least amount of players the least amount of players
""" """
available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available] available_nodes: List[Node] = [
node for node in cls._nodes.values() if node._available
]
if not available_nodes: if not available_nodes:
raise NoNodesAvailable("There are no nodes available.") raise NoNodesAvailable("There are no nodes available.")
if algorithm == NodeAlgorithm.by_ping: if algorithm == NodeAlgorithm.by_ping:
tested_nodes = {node: node.latency for node in available_nodes} tested_nodes = {node: node.latency for node in available_nodes}
return min(tested_nodes, key=tested_nodes.get) return min(tested_nodes, key=tested_nodes.get) # type: ignore
elif algorithm == NodeAlgorithm.by_players: elif algorithm == NodeAlgorithm.by_players:
tested_nodes = {node: len(node.players.keys()) for node in available_nodes} tested_nodes = {
return min(tested_nodes, key=tested_nodes.get) node: len(node.players.keys())
for node in available_nodes
}
return min(tested_nodes, key=tested_nodes.get) # type: ignore
else:
raise ValueError(
"The algorithm provided is not a valid NodeAlgorithm.",
)
@classmethod @classmethod
def get_node(cls, *, identifier: str = None) -> Node: def get_node(cls, *, identifier: Optional[str] = None) -> Node:
"""Fetches a node from the node pool using it's identifier. """Fetches a node from the node pool using it's identifier.
If no identifier is provided, it will choose a node at random. If no identifier is provided, it will choose a node at random.
""" """
@ -728,21 +786,21 @@ class NodePool:
if identifier is None: if identifier is None:
return random.choice(list(available_nodes.values())) return random.choice(list(available_nodes.values()))
return available_nodes.get(identifier, None) return available_nodes[identifier]
@classmethod @classmethod
async def create_node( async def create_node(
cls, cls,
*, *,
bot: Client, bot: commands.Bot,
host: str, host: str,
port: str, port: int,
password: str, password: str,
identifier: str, identifier: str,
secure: bool = False, secure: bool = False,
heartbeat: int = 30, heartbeat: int = 30,
loop: Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None,
spotify_client_id: Optional[str] = None, spotify_client_id: Optional[int] = None,
spotify_client_secret: Optional[str] = None, spotify_client_secret: Optional[str] = None,
session: Optional[aiohttp.ClientSession] = None, session: Optional[aiohttp.ClientSession] = None,
apple_music: bool = False, apple_music: bool = False,
@ -752,7 +810,9 @@ class NodePool:
For Spotify searching capabilites, pass in valid Spotify API credentials. For Spotify searching capabilites, pass in valid Spotify API credentials.
""" """
if identifier in cls._nodes.keys(): if identifier in cls._nodes.keys():
raise NodeCreationError(f"A node with identifier '{identifier}' already exists.") raise NodeCreationError(
f"A node with identifier '{identifier}' already exists.",
)
node = Node( node = Node(
pool=cls, pool=cls,
@ -779,7 +839,9 @@ class NodePool:
async def disconnect(cls) -> None: async def disconnect(cls) -> None:
"""Disconnects all available nodes from the node pool.""" """Disconnects all available nodes from the node pool."""
available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available] available_nodes: List[Node] = [
node for node in cls._nodes.values() if node._available
]
for node in available_nodes: for node in available_nodes:
await node.disconnect() await node.disconnect()

View File

@ -1,35 +1,47 @@
from __future__ import annotations from __future__ import annotations
import random import random
from copy import copy from copy import copy
from typing import ( from typing import Iterable
Iterable, from typing import Iterator
Iterator, from typing import List
List, from typing import Optional
Optional, from typing import Union
Union,
)
from .objects import Track
from .enums import LoopMode from .enums import LoopMode
from .exceptions import QueueEmpty, QueueException, QueueFull from .exceptions import QueueEmpty
from .exceptions import QueueException
from .exceptions import QueueFull
from .objects import Track
__all__ = (
"Queue",
)
class Queue(Iterable[Track]): class Queue(Iterable[Track]):
"""Queue for Pomice. This queue takes pomice.Track as an input and includes looping and shuffling.""" """Queue for Pomice. This queue takes pomice.Track as an input and includes looping and shuffling."""
__slots__ = (
"max_size",
"_queue",
"_overflow",
"_loop_mode",
"_current_item",
)
def __init__( def __init__(
self, self,
max_size: Optional[int] = None, max_size: Optional[int] = None,
*, *,
overflow: bool = True, overflow: bool = True,
): ):
__slots__ = ("max_size", "_queue", "_overflow", "_loop_mode", "_current_item")
self.max_size: Optional[int] = max_size self.max_size: Optional[int] = max_size
self._queue: List[Track] = [] # type: ignore self._current_item: Track
self._queue: List[Track] = []
self._overflow: bool = overflow self._overflow: bool = overflow
self._loop_mode: Optional[LoopMode] = None self._loop_mode: Optional[LoopMode] = None
self._current_item: Optional[Track] = None
def __str__(self) -> str: def __str__(self) -> str:
"""String showing all Track objects appearing as a list.""" """String showing all Track objects appearing as a list."""
@ -60,7 +72,7 @@ class Queue(Iterable[Track]):
return self._queue[index] return self._queue[index]
def __setitem__(self, index: int, item: Track): def __setitem__(self, index: int, item: Track) -> None:
"""Inserts an item at given position.""" """Inserts an item at given position."""
if not isinstance(index, int): if not isinstance(index, int):
raise ValueError("'int' type required.'") raise ValueError("'int' type required.'")
@ -90,7 +102,9 @@ class Queue(Iterable[Track]):
The new queue will have the same max_size as the original. The new queue will have the same max_size as the original.
""" """
if not isinstance(other, Iterable): if not isinstance(other, Iterable):
raise TypeError(f"Adding with the '{type(other)}' type is not supported.") raise TypeError(
f"Adding with the '{type(other)}' type is not supported.",
)
new_queue = self.copy() new_queue = self.copy()
new_queue.extend(other) new_queue.extend(other)
@ -106,7 +120,9 @@ class Queue(Iterable[Track]):
self.extend(other) self.extend(other)
return self return self
raise TypeError(f"Adding '{type(other)}' type to the queue is not supported.") raise TypeError(
f"Adding '{type(other)}' type to the queue is not supported.",
)
def _get(self) -> Track: def _get(self) -> Track:
return self._queue.pop(0) return self._queue.pop(0)
@ -165,7 +181,7 @@ class Queue(Iterable[Track]):
return bool(self._loop_mode) return bool(self._loop_mode)
@property @property
def loop_mode(self) -> LoopMode: def loop_mode(self) -> Optional[LoopMode]:
"""Returns the LoopMode enum set in the queue object""" """Returns the LoopMode enum set in the queue object"""
return self._loop_mode return self._loop_mode
@ -178,7 +194,7 @@ class Queue(Iterable[Track]):
"""Returns the queue as a List""" """Returns the queue as a List"""
return self._queue return self._queue
def get(self): def get(self) -> Track:
"""Return next immediately available item in queue if any. """Return next immediately available item in queue if any.
Raises QueueEmpty if no items in queue. Raises QueueEmpty if no items in queue.
""" """
@ -239,7 +255,9 @@ class Queue(Iterable[Track]):
"""Put the given item into the back of the queue.""" """Put the given item into the back of the queue."""
if self.is_full: if self.is_full:
if not self._overflow: if not self._overflow:
raise QueueFull(f"Queue max_size of {self.max_size} has been reached.") raise QueueFull(
f"Queue max_size of {self.max_size} has been reached.",
)
self._drop() self._drop()
@ -249,7 +267,9 @@ class Queue(Iterable[Track]):
"""Put the given item into the queue at the specified index.""" """Put the given item into the queue at the specified index."""
if self.is_full: if self.is_full:
if not self._overflow: if not self._overflow:
raise QueueFull(f"Queue max_size of {self.max_size} has been reached.") raise QueueFull(
f"Queue max_size of {self.max_size} has been reached.",
)
self._drop() self._drop()
@ -275,7 +295,7 @@ class Queue(Iterable[Track]):
if (new_len + self.count) > self.max_size: if (new_len + self.count) > self.max_size:
raise QueueFull( raise QueueFull(
f"Queue has {self.count}/{self.max_size} items, " f"Queue has {self.count}/{self.max_size} items, "
f"cannot add {new_len} more." f"cannot add {new_len} more.",
) )
for item in iterable: for item in iterable:
@ -292,7 +312,7 @@ class Queue(Iterable[Track]):
"""Remove all items from the queue.""" """Remove all items from the queue."""
self._queue.clear() self._queue.clear()
def set_loop_mode(self, mode: LoopMode): def set_loop_mode(self, mode: LoopMode) -> None:
""" """
Sets the loop mode of the queue. Sets the loop mode of the queue.
Takes the LoopMode enum as an argument. Takes the LoopMode enum as an argument.
@ -307,7 +327,7 @@ class Queue(Iterable[Track]):
self._queue.insert(index, self._current_item) self._queue.insert(index, self._current_item)
self._current_item = self._queue[index] self._current_item = self._queue[index]
def disable_loop(self): def disable_loop(self) -> None:
""" """
Disables loop mode if set. Disables loop mode if set.
Raises QueueException if loop mode is already None. Raises QueueException if loop mode is already None.
@ -321,17 +341,17 @@ class Queue(Iterable[Track]):
self._loop_mode = None self._loop_mode = None
def shuffle(self): def shuffle(self) -> None:
"""Shuffles the queue.""" """Shuffles the queue."""
return random.shuffle(self._queue) return random.shuffle(self._queue)
def clear_track_filters(self): def clear_track_filters(self) -> None:
"""Clears all filters applied to tracks""" """Clears all filters applied to tracks"""
for track in self._queue: for track in self._queue:
track.filters = None track.filters = None
def jump(self, item: Track): def jump(self, item: Track) -> None:
"""Removes all tracks before the.""" """Removes all tracks before the."""
index = self.find_position(item) index = self.find_position(item)
new_queue = self._queue[index : self.size] new_queue = self._queue[index: self.size]
self._queue = new_queue self._queue = new_queue

View File

@ -1,11 +1,13 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from .pool import Node from .pool import Node
from .utils import RouteStats from .utils import RouteStats
from aiohttp import ClientSession
__all__ = ("RoutePlanner",)
class RoutePlanner: class RoutePlanner:
@ -16,17 +18,16 @@ class RoutePlanner:
def __init__(self, node: Node) -> None: def __init__(self, node: Node) -> None:
self.node: Node = node self.node: Node = node
self.session: ClientSession = node._session
async def get_status(self) -> RouteStats: async def get_status(self) -> RouteStats:
"""Gets the status of the route planner API.""" """Gets the status of the route planner API."""
data: dict = await self.node.send(method="GET", path="routeplanner/status") data: dict = await self.node.send(method="GET", path="routeplanner/status")
return RouteStats(data) return RouteStats(data)
async def free_address(self, ip: str): async def free_address(self, ip: str) -> None:
"""Frees an address using the route planner API""" """Frees an address using the route planner API"""
await self.node.send(method="POST", path="routeplanner/free/address", data={"address": ip}) await self.node.send(method="POST", path="routeplanner/free/address", data={"address": ip})
async def free_all_addresses(self): async def free_all_addresses(self) -> None:
"""Frees all available addresses using the route planner api""" """Frees all available addresses using the route planner api"""
await self.node.send(method="POST", path="routeplanner/free/address/all") await self.node.send(method="POST", path="routeplanner/free/address/all")

View File

@ -1,5 +1,4 @@
"""Spotify module for Pomice, made possible by cloudwithax 2023""" """Spotify module for Pomice, made possible by cloudwithax 2023"""
from .client import Client
from .exceptions import * from .exceptions import *
from .objects import * from .objects import *
from .client import Client

View File

@ -2,19 +2,28 @@ from __future__ import annotations
import re import re
import time import time
from base64 import b64encode
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
import aiohttp import aiohttp
import orjson as json import orjson as json
from base64 import b64encode from .exceptions import InvalidSpotifyURL
from typing import TYPE_CHECKING from .exceptions import SpotifyRequestException
from .exceptions import InvalidSpotifyURL, SpotifyRequestException
from .objects import * from .objects import *
__all__ = (
"Client",
)
GRANT_URL = "https://accounts.spotify.com/api/token" GRANT_URL = "https://accounts.spotify.com/api/token"
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}" REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
SPOTIFY_URL_REGEX = re.compile( SPOTIFY_URL_REGEX = re.compile(
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)" r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
) )
@ -24,17 +33,21 @@ class Client:
for any Spotify URL you throw at it. for any Spotify URL you throw at it.
""" """
def __init__(self, client_id: str, client_secret: str) -> None: def __init__(self, client_id: int, client_secret: str) -> None:
self._client_id = client_id self._client_id: int = client_id
self._client_secret = client_secret self._client_secret: str = client_secret
self.session: aiohttp.ClientSession = None self.session: aiohttp.ClientSession = None # type: ignore
self._bearer_token: str = None self._bearer_token: Optional[str] = None
self._expiry = 0 self._expiry: float = 0.0
self._auth_token = b64encode(f"{self._client_id}:{self._client_secret}".encode()) self._auth_token = b64encode(
self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"} f"{self._client_id}:{self._client_secret}".encode(),
self._bearer_headers = None )
self._grant_headers = {
"Authorization": f"Basic {self._auth_token.decode()}",
}
self._bearer_headers: Optional[Dict] = None
async def _fetch_bearer_token(self) -> None: async def _fetch_bearer_token(self) -> None:
_data = {"grant_type": "client_credentials"} _data = {"grant_type": "client_credentials"}
@ -45,32 +58,34 @@ class Client:
async with self.session.post(GRANT_URL, data=_data, headers=self._grant_headers) as resp: async with self.session.post(GRANT_URL, data=_data, headers=self._grant_headers) as resp:
if resp.status != 200: if resp.status != 200:
raise SpotifyRequestException( raise SpotifyRequestException(
f"Error fetching bearer token: {resp.status} {resp.reason}" f"Error fetching bearer token: {resp.status} {resp.reason}",
) )
data: dict = await resp.json(loads=json.loads) data: dict = await resp.json(loads=json.loads)
self._bearer_token = data["access_token"] self._bearer_token = data["access_token"]
self._expiry = time.time() + (int(data["expires_in"]) - 10) self._expiry = time.time() + (int(data["expires_in"]) - 10)
self._bearer_headers = {"Authorization": f"Bearer {self._bearer_token}"} self._bearer_headers = {
"Authorization": f"Bearer {self._bearer_token}",
}
async def search(self, *, query: str): async def search(self, *, query: str) -> Union[Track, Album, Artist, Playlist]:
if not self._bearer_token or time.time() >= self._expiry: if not self._bearer_token or time.time() >= self._expiry:
await self._fetch_bearer_token() await self._fetch_bearer_token()
result = SPOTIFY_URL_REGEX.match(query) result = SPOTIFY_URL_REGEX.match(query)
spotify_type = result.group("type")
spotify_id = result.group("id")
if not result: if not result:
raise InvalidSpotifyURL("The Spotify link provided is not valid.") raise InvalidSpotifyURL("The Spotify link provided is not valid.")
spotify_type = result.group("type")
spotify_id = result.group("id")
request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id) request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id)
async with self.session.get(request_url, headers=self._bearer_headers) as resp: async with self.session.get(request_url, headers=self._bearer_headers) as resp:
if resp.status != 200: if resp.status != 200:
raise SpotifyRequestException( raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}" f"Error while fetching results: {resp.status} {resp.reason}",
) )
data: dict = await resp.json(loads=json.loads) data: dict = await resp.json(loads=json.loads)
@ -81,11 +96,11 @@ class Client:
return Album(data) return Album(data)
elif spotify_type == "artist": elif spotify_type == "artist":
async with self.session.get( async with self.session.get(
f"{request_url}/top-tracks?market=US", headers=self._bearer_headers f"{request_url}/top-tracks?market=US", headers=self._bearer_headers,
) as resp: ) as resp:
if resp.status != 200: if resp.status != 200:
raise SpotifyRequestException( raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}" f"Error while fetching results: {resp.status} {resp.reason}",
) )
track_data: dict = await resp.json(loads=json.loads) track_data: dict = await resp.json(loads=json.loads)
@ -100,7 +115,7 @@ class Client:
if not len(tracks): if not len(tracks):
raise SpotifyRequestException( raise SpotifyRequestException(
"This playlist is empty and therefore cannot be queued." "This playlist is empty and therefore cannot be queued.",
) )
next_page_url = data["tracks"]["next"] next_page_url = data["tracks"]["next"]
@ -109,7 +124,7 @@ class Client:
async with self.session.get(next_page_url, headers=self._bearer_headers) as resp: async with self.session.get(next_page_url, headers=self._bearer_headers) as resp:
if resp.status != 200: if resp.status != 200:
raise SpotifyRequestException( raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}" f"Error while fetching results: {resp.status} {resp.reason}",
) )
next_data: dict = await resp.json(loads=json.loads) next_data: dict = await resp.json(loads=json.loads)
@ -123,26 +138,30 @@ class Client:
return Playlist(data, tracks) return Playlist(data, tracks)
async def get_recommendations(self, *, query: str): async def get_recommendations(self, *, query: str) -> List[Track]:
if not self._bearer_token or time.time() >= self._expiry: if not self._bearer_token or time.time() >= self._expiry:
await self._fetch_bearer_token() await self._fetch_bearer_token()
result = SPOTIFY_URL_REGEX.match(query) result = SPOTIFY_URL_REGEX.match(query)
spotify_type = result.group("type")
spotify_id = result.group("id")
if not result: if not result:
raise InvalidSpotifyURL("The Spotify link provided is not valid.") raise InvalidSpotifyURL("The Spotify link provided is not valid.")
if not spotify_type == "track": spotify_type = result.group("type")
raise InvalidSpotifyURL("The provided query is not a Spotify track.") spotify_id = result.group("id")
request_url = REQUEST_URL.format(type="recommendation", id=f"?seed_tracks={spotify_id}") if not spotify_type == "track":
raise InvalidSpotifyURL(
"The provided query is not a Spotify track.",
)
request_url = REQUEST_URL.format(
type="recommendation", id=f"?seed_tracks={spotify_id}",
)
async with self.session.get(request_url, headers=self._bearer_headers) as resp: async with self.session.get(request_url, headers=self._bearer_headers) as resp:
if resp.status != 200: if resp.status != 200:
raise SpotifyRequestException( raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}" f"Error while fetching results: {resp.status} {resp.reason}",
) )
data: dict = await resp.json(loads=json.loads) data: dict = await resp.json(loads=json.loads)
@ -154,4 +173,4 @@ class Client:
async def close(self) -> None: async def close(self) -> None:
if self.session: if self.session:
await self.session.close() await self.session.close()
self.session = None self.session = None # type: ignore

View File

@ -1,3 +1,9 @@
__all__ = (
"SpotifyRequestException",
"InvalidSpotifyURL",
)
class SpotifyRequestException(Exception): class SpotifyRequestException(Exception):
"""An error occurred when making a request to the Spotify API""" """An error occurred when making a request to the Spotify API"""

View File

@ -1,29 +1,36 @@
from typing import List from typing import List
from typing import Optional
__all__ = (
"Track",
"Playlist",
"Album",
"Artist",
)
class Track: class Track:
"""The base class for a Spotify Track""" """The base class for a Spotify Track"""
def __init__(self, data: dict, image=None) -> None: def __init__(self, data: dict, image: Optional[str] = None) -> None:
self.name: str = data["name"] self.name: str = data["name"]
self.artists: str = ", ".join(artist["name"] for artist in data["artists"]) self.artists: str = ", ".join(
artist["name"] for artist in data["artists"]
)
self.length: float = data["duration_ms"] self.length: float = data["duration_ms"]
self.id: str = data["id"] self.id: str = data["id"]
self.issrc: Optional[str] = None
if data.get("external_ids"): if data.get("external_ids"):
self.isrc: str = data["external_ids"]["isrc"] self.isrc = data["external_ids"]["isrc"]
else:
self.isrc = None
self.image: Optional[str] = image
if data.get("album") and data["album"].get("images"): if data.get("album") and data["album"].get("images"):
self.image: str = data["album"]["images"][0]["url"] self.image = data["album"]["images"][0]["url"]
else:
self.image: str = image
if data["is_local"]: self.uri: Optional[str] = None
self.uri = None if not data["is_local"]:
else: self.uri = data["external_urls"]["spotify"]
self.uri: str = data["external_urls"]["spotify"]
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
@ -42,7 +49,7 @@ class Playlist:
self.total_tracks: int = data["tracks"]["total"] self.total_tracks: int = data["tracks"]["total"]
self.id: str = data["id"] self.id: str = data["id"]
if data.get("images") and len(data["images"]): if data.get("images") and len(data["images"]):
self.image: str = data["images"][0]["url"] self.image = data["images"][0]["url"]
else: else:
self.image = self.tracks[0].image self.image = self.tracks[0].image
self.uri = data["external_urls"]["spotify"] self.uri = data["external_urls"]["spotify"]
@ -59,9 +66,14 @@ class Album:
def __init__(self, data: dict) -> None: def __init__(self, data: dict) -> None:
self.name: str = data["name"] self.name: str = data["name"]
self.artists: str = ", ".join(artist["name"] for artist in data["artists"]) self.artists: str = ", ".join(
artist["name"] for artist in data["artists"]
)
self.image: str = data["images"][0]["url"] self.image: str = data["images"][0]["url"]
self.tracks = [Track(track, image=self.image) for track in data["tracks"]["items"]] self.tracks = [
Track(track, image=self.image)
for track in data["tracks"]["items"]
]
self.total_tracks: int = data["total_tracks"] self.total_tracks: int = data["total_tracks"]
self.id: str = data["id"] self.id: str = data["id"]
self.uri: str = data["external_urls"]["spotify"] self.uri: str = data["external_urls"]["spotify"]
@ -78,7 +90,8 @@ class Artist:
def __init__(self, data: dict, tracks: dict) -> None: def __init__(self, data: dict, tracks: dict) -> None:
self.name: str = ( self.name: str = (
f"Top tracks for {data['name']}" # Setting that because its only playing top tracks # Setting that because its only playing top tracks
f"Top tracks for {data['name']}"
) )
self.genres: str = ", ".join(genre for genre in data["genres"]) self.genres: str = ", ".join(genre for genre in data["genres"])
self.followers: int = data["followers"]["total"] self.followers: int = data["followers"]["total"]

View File

@ -1,11 +1,25 @@
import random import random
import time
import socket import socket
import time
from .enums import RouteStrategy, RouteIPType
from timeit import default_timer as timer
from itertools import zip_longest
from datetime import datetime from datetime import datetime
from itertools import zip_longest
from timeit import default_timer as timer
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
from .enums import RouteIPType
from .enums import RouteStrategy
__all__ = (
"ExponentialBackoff",
"NodeStats",
"FailingIPBlock",
"RouteStats",
"Ping",
)
class ExponentialBackoff: class ExponentialBackoff:
@ -51,7 +65,7 @@ class ExponentialBackoff:
self._exp = 0 self._exp = 0
self._exp = min(self._exp + 1, self._max) self._exp = min(self._exp + 1, self._max)
return self._randfunc(0, self._base * 2**self._exp) return self._randfunc(0, self._base * 2**self._exp) # type: ignore
class NodeStats: class NodeStats:
@ -59,7 +73,6 @@ class NodeStats:
Gives critical information on the node, which is updated every minute. Gives critical information on the node, which is updated every minute.
""" """
def __init__(self, data: dict) -> None:
__slots__ = ( __slots__ = (
"used", "used",
"free", "free",
@ -73,13 +86,15 @@ class NodeStats:
"uptime", "uptime",
) )
memory: dict = data.get("memory") def __init__(self, data: Dict[str, Any]) -> None:
memory: dict = data.get("memory", {})
self.used = memory.get("used") self.used = memory.get("used")
self.free = memory.get("free") self.free = memory.get("free")
self.reservable = memory.get("reservable") self.reservable = memory.get("reservable")
self.allocated = memory.get("allocated") self.allocated = memory.get("allocated")
cpu: dict = data.get("cpu") cpu: dict = data.get("cpu", {})
self.cpu_cores = cpu.get("cores") self.cpu_cores = cpu.get("cores")
self.cpu_system_load = cpu.get("systemLoad") self.cpu_system_load = cpu.get("systemLoad")
self.cpu_process_load = cpu.get("lavalinkLoad") self.cpu_process_load = cpu.get("lavalinkLoad")
@ -99,11 +114,14 @@ class FailingIPBlock:
and the time they failed. and the time they failed.
""" """
def __init__(self, data: dict) -> None:
__slots__ = ("address", "failing_time") __slots__ = ("address", "failing_time")
def __init__(self, data: dict) -> None:
self.address = data.get("address") self.address = data.get("address")
self.failing_time = datetime.fromtimestamp(float(data.get("failingTimestamp"))) self.failing_time = datetime.fromtimestamp(
float(data.get("failingTimestamp", 0)),
)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Pomice.FailingIPBlock address={self.address} failing_time={self.failing_time}>" return f"<Pomice.FailingIPBlock address={self.address} failing_time={self.failing_time}>"
@ -115,17 +133,29 @@ class RouteStats:
Gives critical information about the route planner strategy on the node. Gives critical information about the route planner strategy on the node.
""" """
def __init__(self, data: dict) -> None: __slots__ = (
__slots__ = ("strategy", "ip_block_type", "ip_block_size", "failing_addresses") "strategy",
"ip_block_type",
"ip_block_size",
"failing_addresses",
"block_index",
"address_index",
)
def __init__(self, data: Dict[str, Any]) -> None:
self.strategy = RouteStrategy(data.get("class")) self.strategy = RouteStrategy(data.get("class"))
details: dict = data.get("details") details: dict = data.get("details", {})
ip_block: dict = details.get("ipBlock") ip_block: dict = details.get("ipBlock", {})
self.ip_block_type = RouteIPType(ip_block.get("type")) self.ip_block_type = RouteIPType(ip_block.get("type"))
self.ip_block_size = ip_block.get("size") self.ip_block_size = ip_block.get("size")
self.failing_addresses = [FailingIPBlock(data) for data in details.get("failingAddresses")] self.failing_addresses = [
FailingIPBlock(
data,
) for data in details.get("failingAddresses", [])
]
self.block_index = details.get("blockIndex") self.block_index = details.get("blockIndex")
self.address_index = details.get("currentAddressIndex") self.address_index = details.get("currentAddressIndex")
@ -136,7 +166,7 @@ class RouteStats:
class Ping: class Ping:
# Thanks to https://github.com/zhengxiaowai/tcping for the nice ping impl # Thanks to https://github.com/zhengxiaowai/tcping for the nice ping impl
def __init__(self, host, port, timeout=5): def __init__(self, host: str, port: int, timeout: int = 5) -> None:
self.timer = self.Timer() self.timer = self.Timer()
self._successed = 0 self._successed = 0
@ -146,33 +176,33 @@ class Ping:
self._port = port self._port = port
self._timeout = timeout self._timeout = timeout
class Socket(object): class Socket:
def __init__(self, family, type_, timeout): def __init__(self, family: int, type_: int, timeout: Optional[float]) -> None:
s = socket.socket(family, type_) s = socket.socket(family, type_)
s.settimeout(timeout) s.settimeout(timeout)
self._s = s self._s = s
def connect(self, host, port): def connect(self, host: str, port: int) -> None:
self._s.connect((host, int(port))) self._s.connect((host, port))
def shutdown(self): def shutdown(self) -> None:
self._s.shutdown(socket.SHUT_RD) self._s.shutdown(socket.SHUT_RD)
def close(self): def close(self) -> None:
self._s.close() self._s.close()
class Timer(object): class Timer:
def __init__(self): def __init__(self) -> None:
self._start = 0 self._start: float = 0.0
self._stop = 0 self._stop: float = 0.0
def start(self): def start(self) -> None:
self._start = timer() self._start = timer()
def stop(self): def stop(self) -> None:
self._stop = timer() self._stop = timer()
def cost(self, funcs, args): def cost(self, funcs: Iterable[Callable], args: Any) -> float:
self.start() self.start()
for func, arg in zip_longest(funcs, args): for func, arg in zip_longest(funcs, args):
if arg: if arg:
@ -183,13 +213,15 @@ class Ping:
self.stop() self.stop()
return self._stop - self._start return self._stop - self._start
def _create_socket(self, family, type_): def _create_socket(self, family: int, type_: int) -> Socket:
return self.Socket(family, type_, self._timeout) return self.Socket(family, type_, self._timeout)
def get_ping(self): def get_ping(self) -> float:
s = self._create_socket(socket.AF_INET, socket.SOCK_STREAM) s = self._create_socket(socket.AF_INET, socket.SOCK_STREAM)
cost_time = self.timer.cost((s.connect, s.shutdown), ((self._host, self._port), None)) cost_time = self.timer.cost(
(s.connect, s.shutdown), ((self._host, self._port), None),
)
s_runtime = 1000 * (cost_time) s_runtime = 1000 * (cost_time)
return s_runtime return s_runtime

View File

@ -7,3 +7,13 @@ build-backend = "setuptools.build_meta"
[tool.black] [tool.black]
line-length = 100 line-length = 100
[tool.mypy]
mypy_path = "./"
files = ["pomice"]
disallow_untyped_defs = true
disallow_any_unimported = true
no_implicit_optional = true
check_untyped_defs = true
warn_unused_ignores = true
show_error_codes = true

View File

@ -1,10 +1,13 @@
import setuptools
import re import re
import setuptools
version = "" version = ""
requirements = ["discord.py>=2.0.0", "aiohttp>=3.7.4,<4", "orjson"] requirements = ["discord.py>=2.0.0", "aiohttp>=3.7.4,<4", "orjson"]
with open("pomice/__init__.py") as f: with open("pomice/__init__.py") as f:
version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE).group(1) version = re.search(
r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE,
).group(1)
if not version: if not version:
raise RuntimeError("version is not set") raise RuntimeError("version is not set")
@ -15,13 +18,13 @@ if version.endswith(("a", "b", "rc")):
import subprocess import subprocess
p = subprocess.Popen( p = subprocess.Popen(
["git", "rev-list", "--count", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE ["git", "rev-list", "--count", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
) )
out, err = p.communicate() out, err = p.communicate()
if out: if out:
version += out.decode("utf-8").strip() version += out.decode("utf-8").strip()
p = subprocess.Popen( p = subprocess.Popen(
["git", "rev-parse", "--short", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE ["git", "rev-parse", "--short", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
) )
out, err = p.communicate() out, err = p.communicate()
if out: if out: