feat: add typing; add makefile; add pipfile

This commit is contained in:
NiceAesth 2023-03-11 15:44:46 +02:00
parent 481e616414
commit 987de07fc5
52 changed files with 1103 additions and 729 deletions

4
.gitignore vendored
View File

@ -7,3 +7,7 @@ docs/_build/
build/
.gitpod.yml
.python-verson
Pipfile.lock
.mypy_cache/
.vscode/
.venv/

36
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,36 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-ast
- id: check-builtin-literals
- id: debug-statements
- id: end-of-file-fixer
- id: requirements-txt-fixer
- id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-autopep8
rev: v2.0.2
hooks:
- id: autopep8
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
hooks:
- id: pyupgrade
args: [--py37-plus, --keep-runtime-typing]
- repo: https://github.com/asottile/reorder_python_imports
rev: v3.9.0
hooks:
- id: reorder-python-imports
- repo: https://github.com/asottile/add-trailing-comma
rev: v2.4.0
hooks:
- id: add-trailing-comma
- repo: https://github.com/hadialqattan/pycln
rev: v2.1.3
hooks:
- id: pycln
default_language_version:
python: python3.8

18
Makefile Normal file
View File

@ -0,0 +1,18 @@
prepare:
pipenv install --dev
pipenv run pre-commit install
shell:
pipenv shell
lint:
pipenv run pre-commit run --all-files
test:
pipenv run mypy
serve-docs:
@cd docs;\
make html;\
cd build/html;\
python -m http.server;\

18
Pipfile Normal file
View File

@ -0,0 +1,18 @@
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"
[packages]
orjson = "*"
"discord.py" = {extras = ["voice"], version = "*"}
[dev-packages]
mypy = "*"
pre-commit = "*"
furo = "*"
sphinx = "*"
myst-parser = "*"
[requires]
python_version = "3.8"

View File

@ -2,12 +2,12 @@ import importlib
import inspect
import os
import sys
from typing import Any, Dict
from typing import Any
from typing import Dict
sys.path.insert(0, os.path.abspath('.'))
sys.path.insert(0, os.path.abspath('..'))
project = 'Pomice'
copyright = '2023, cloudwithax'
author = 'cloudwithax'
@ -19,7 +19,7 @@ extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.linkcode',
'myst_parser'
'myst_parser',
]
myst_enable_extensions = [
@ -84,6 +84,7 @@ html_theme_options: Dict[str, Any] = {
# Grab lines from source files and embed into the docs
# so theres a point of reference
def linkcode_resolve(domain, info):
# i absolutely MUST add this here or else
# the docs will not build. fuck sphinx
@ -93,7 +94,6 @@ def linkcode_resolve(domain, info):
if not info['module']:
return None
mod = importlib.import_module(info["module"])
if "." in info["fullname"]:
objname, attrname = info["fullname"].split(".")
@ -117,4 +117,3 @@ def linkcode_resolve(domain, info):
return f"https://github.com/cloudwithax/pomice/blob/main/{file}#L{start}-L{end}"
except:
pass

View File

@ -47,6 +47,3 @@ remote, or the node.
`Event.WebsocketOpenEvent()` carries a target, which is usually the node IP, and the SSRC, a 32-bit integer uniquely identifying the source of the RTP packets sent from
Lavalink.

View File

@ -184,4 +184,3 @@ After you have initialized your function, you can optionally include the `fast_a
await Player.reset_filters(fast_apply=<True/False>)
```

View File

@ -14,4 +14,3 @@ filters.md
queue.md
events.md
``

View File

@ -19,7 +19,7 @@ There are also properties the `Node` class has to access certain values:
- Description
* - `Node.bot`
- `Union[Client, Bot]`
- `Client`
- Returns the discord.py client linked to this node.
* - `Node.is_connected`

View File

@ -28,7 +28,7 @@ There are also properties the `Player` class has to access certain values:
- Description
* - `Player.bot`
- `Union[Client, commands.Bot]`
- `Client`
- Returns the bot associated with this player instance.
* - `Player.current`
@ -466,15 +466,3 @@ After you have initialized your function, you can optionally include the `fast_a
await Player.reset_filters(fast_apply=<True/False>)
```

View File

@ -148,15 +148,3 @@ await NodePool.disconnect()
```
After running this function, all nodes in the pool should disconnect and no longer be available to use.

View File

@ -222,13 +222,3 @@ Your `Track` object must be in the queue if you want to jump to it. Make sure yo
:::
After running this function, any items before the specified item will be removed, effectively "jumping" to the specified item in the queue. The next item obtained using `Queue.get()` will be your specified track.

View File

@ -38,7 +38,3 @@ hdi/index.md
:hidden:
api/index.md
```

View File

@ -28,5 +28,3 @@ You are free to use this as a base to add on to for any music features you want
If you want to jump into the library and learn how to do everything you need, refer to the [How Do I?](hdi/index.md) section.
If you want a deeper look into how the library works beyond the [How Do I?](hdi/index.md) guide, refer to the [API Reference](api/index.md) section.

View File

@ -1,5 +1,5 @@
discord.py[voice]
aiohttp
orjson
myst_parser
discord.py[voice]
furo
myst_parser
orjson

View File

@ -4,13 +4,13 @@ This is in the form of a drop-in cog you can use and modify to your liking.
This example aims to include everything you would need to make a fully functioning music bot,
from a queue system, advanced queue control and more.
"""
import math
from contextlib import suppress
import discord
import pomice
import math
from discord.ext import commands
from contextlib import suppress
import pomice
class Player(pomice.Player):
@ -44,7 +44,6 @@ class Player(pomice.Player):
with suppress(discord.HTTPException):
await self.controller.delete()
# Queue up the next track, else teardown the player
try:
track: pomice.Track = self.queue.get()
@ -56,13 +55,16 @@ class Player(pomice.Player):
# Call the controller (a.k.a: The "Now Playing" embed) and check if one exists
if track.is_stream:
embed = discord.Embed(title="Now playing", description=f":red_circle: **LIVE** [{track.title}]({track.uri}) [{track.requester.mention}]")
embed = discord.Embed(
title="Now playing", description=f":red_circle: **LIVE** [{track.title}]({track.uri}) [{track.requester.mention}]",
)
self.controller = await self.context.send(embed=embed)
else:
embed = discord.Embed(title=f"Now playing", description=f"[{track.title}]({track.uri}) [{track.requester.mention}]")
embed = discord.Embed(
title=f"Now playing", description=f"[{track.title}]({track.uri}) [{track.requester.mention}]",
)
self.controller = await self.context.send(embed=embed)
async def teardown(self):
"""Clear internal states, remove player controller and disconnect."""
with suppress((discord.HTTPException), (KeyError)):
@ -76,8 +78,6 @@ class Player(pomice.Player):
self.dj = ctx.author
class Music(commands.Cog):
def __init__(self, bot: commands.Bot) -> None:
self.bot = bot
@ -100,7 +100,7 @@ class Music(commands.Cog):
host="127.0.0.1",
port="3030",
password="youshallnotpass",
identifier="MAIN"
identifier="MAIN",
)
print(f"Node is ready!")
@ -122,7 +122,6 @@ class Music(commands.Cog):
return player.dj == ctx.author or ctx.author.guild_permissions.kick_members
# The following are events from pomice.events
# We are using these so that if the track either stops or errors,
# we can just skip to the next track
@ -195,8 +194,6 @@ class Music(commands.Cog):
if not player.is_playing:
await player.do_next()
@commands.command(aliases=['pau', 'pa'])
async def pause(self, ctx: commands.Context):
"""Pause the currently playing song."""
@ -345,6 +342,6 @@ class Music(commands.Cog):
await player.set_volume(vol)
await ctx.send(f'Set the volume to **{vol}**%', delete_after=7)
async def setup(bot: commands.Bot):
await bot.add_cog(Music(bot))

View File

@ -1,13 +1,16 @@
import discord
import pomice
from discord.ext import commands
import pomice
class MyBot(commands.Bot):
def __init__(self) -> None:
super().__init__(
command_prefix="!",
activity=discord.Activity(type=discord.ActivityType.listening, name="to music!")
activity=discord.Activity(
type=discord.ActivityType.listening, name="to music!",
),
)
self.add_cog(Music(self))
@ -33,7 +36,7 @@ class Music(commands.Cog):
host="127.0.0.1",
port="3030",
password="youshallnotpass",
identifier="MAIN"
identifier="MAIN",
)
print(f"Node is ready!")
@ -44,7 +47,7 @@ class Music(commands.Cog):
if not channel:
raise commands.CheckFailure(
"You must be in a voice channel to use this command "
"without specifying the channel argument."
"without specifying the channel argument.",
)
# With the release of discord.py 1.7, you can now add a compatible
@ -78,7 +81,9 @@ class Music(commands.Cog):
results = await player.get_tracks(search)
if not results:
raise commands.CommandError("No results were found for that search term.")
raise commands.CommandError(
"No results were found for that search term.",
)
if isinstance(results, pomice.Playlist):
await player.play(track=results.tracks[0])

View File

@ -17,7 +17,7 @@ if not discord.version_info.major >= 2:
raise DiscordPyOutdated(
"You must have discord.py (v2.0 or greater) to use this library. "
"Uninstall your current version and install discord.py 2.0 "
"using 'pip install discord.py'"
"using 'pip install discord.py'",
)
__version__ = "2.2a"

View File

@ -1,5 +1,4 @@
"""Apple Music module for Pomice, made possible by cloudwithax 2023"""
from .client import Client
from .exceptions import *
from .objects import *
from .client import Client

View File

@ -1,19 +1,27 @@
from __future__ import annotations
import base64
import re
from datetime import datetime
from typing import Dict
from typing import List
from typing import Union
import aiohttp
import orjson as json
import base64
from datetime import datetime
from .objects import *
from .exceptions import *
from .objects import *
__all__ = (
"Client",
)
AM_URL_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)"
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
)
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)"
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
)
AM_REQ_URL = "https://api.music.apple.com/v1/catalog/{country}/{type}s/{id}"
AM_BASE_URL = "https://api.music.apple.com"
@ -26,37 +34,49 @@ class Client:
"""
def __init__(self) -> None:
self.token: str = None
self.expiry: datetime = None
self.session: aiohttp.ClientSession = None
self.headers = None
self.expiry: datetime = datetime(1970, 1, 1)
self.token: str = ""
self.headers: Dict[str, str] = {}
self.session: aiohttp.ClientSession = None # type: ignore
async def request_token(self):
async def request_token(self) -> None:
if not self.session:
self.session = aiohttp.ClientSession()
async with self.session.get("https://music.apple.com/assets/index.919fe17f.js") as resp:
if resp.status != 200:
raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}"
f"Error while fetching results: {resp.status} {resp.reason}",
)
text = await resp.text()
result = re.search('"(eyJ.+?)"', text).group(1)
match = re.search('"(eyJ.+?)"', text)
if not match:
raise AppleMusicRequestException(
"Could not find token in response.",
)
result = match.group(1)
self.token = result
self.headers = {
"Authorization": f"Bearer {result}",
"Origin": "https://apple.com",
}
token_split = self.token.split(".")[1]
token_json = base64.b64decode(token_split + "=" * (-len(token_split) % 4)).decode()
token_json = base64.b64decode(
token_split + "=" * (-len(token_split) % 4),
).decode()
token_data = json.loads(token_json)
self.expiry = datetime.fromtimestamp(token_data["exp"])
async def search(self, query: str):
async def search(self, query: str) -> Union[Album, Playlist, Song, Artist]:
if not self.token or datetime.utcnow() > self.expiry:
await self.request_token()
result = AM_URL_REGEX.match(query)
if not result:
raise InvalidAppleMusicURL(
"The Apple Music link provided is not valid.",
)
country = result.group("country")
type = result.group("type")
@ -75,7 +95,7 @@ class Client:
async with self.session.get(request_url, headers=self.headers) as resp:
if resp.status != 200:
raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}"
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
@ -84,53 +104,57 @@ class Client:
if type == "song":
return Song(data)
elif type == "album":
if type == "album":
return Album(data)
elif type == "artist":
if type == "artist":
async with self.session.get(
f"{request_url}/view/top-songs", headers=self.headers
f"{request_url}/view/top-songs", headers=self.headers,
) as resp:
if resp.status != 200:
raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}"
f"Error while fetching results: {resp.status} {resp.reason}",
)
top_tracks: dict = await resp.json(loads=json.loads)
tracks: dict = top_tracks["data"]
artist_tracks: dict = top_tracks["data"]
return Artist(data, tracks=tracks)
return Artist(data, tracks=artist_tracks)
else:
track_data: dict = data["relationships"]["tracks"]
album_tracks: List[Song] = [
Song(track)
for track in track_data["data"]
]
tracks = [Song(track) for track in track_data.get("data")]
if not len(tracks):
if not len(album_tracks):
raise AppleMusicRequestException(
"This playlist is empty and therefore cannot be queued."
"This playlist is empty and therefore cannot be queued.",
)
if track_data.get("next"):
next_page_url = AM_BASE_URL + track_data.get("next")
_next = track_data.get("next")
if _next:
next_page_url = AM_BASE_URL + _next
while next_page_url is not None:
async with self.session.get(next_page_url, headers=self.headers) as resp:
if resp.status != 200:
raise AppleMusicRequestException(
f"Error while fetching results: {resp.status} {resp.reason}"
f"Error while fetching results: {resp.status} {resp.reason}",
)
next_data: dict = await resp.json(loads=json.loads)
tracks += [Song(track) for track in next_data["data"]]
if next_data.get("next"):
next_page_url = AM_BASE_URL + next_data.get("next")
album_tracks.extend(Song(track) for track in next_data["data"])
_next = next_data.get("next")
if _next:
next_page_url = AM_BASE_URL + _next
else:
next_page_url = None
return Playlist(data, tracks)
return Playlist(data, album_tracks)
async def close(self):
async def close(self) -> None:
if self.session:
await self.session.close()
self.session = None
self.session = None # type: ignore

View File

@ -1,3 +1,9 @@
__all__ = (
"AppleMusicRequestException",
"InvalidAppleMusicURL",
)
class AppleMusicRequestException(Exception):
"""An error occurred when making a request to the Apple Music API"""

View File

@ -1,7 +1,13 @@
"""Module for managing Apple Music objects"""
from typing import List
__all__ = (
"Song",
"Playlist",
"Album",
"Artist",
)
class Song:
"""The base class for an Apple Music song"""
@ -55,7 +61,9 @@ class Album:
self.id: str = data["id"]
self.artists: str = data["attributes"]["artistName"]
self.total_tracks: int = data["attributes"]["trackCount"]
self.tracks: List[Song] = [Song(track) for track in data["relationships"]["tracks"]["data"]]
self.tracks: List[Song] = [
Song(track) for track in data["relationships"]["tracks"]["data"]
]
self.image: str = data["attributes"]["artwork"]["url"].replace(
"{w}x{h}",
f'{data["attributes"]["artwork"]["width"]}x{data["attributes"]["artwork"]["height"]}',
@ -75,7 +83,9 @@ class Artist:
self.name: str = f'Top tracks for {data["attributes"]["name"]}'
self.url: str = data["attributes"]["url"]
self.id: str = data["id"]
self.genres: str = ", ".join(genre for genre in data["attributes"]["genreNames"])
self.genres: str = ", ".join(
genre for genre in data["attributes"]["genreNames"]
)
self.tracks: List[Song] = [Song(track) for track in tracks]
self.image: str = data["attributes"]["artwork"]["url"].replace(
"{w}x{h}",

View File

@ -1,7 +1,17 @@
import re
from enum import Enum
__all__ = (
"SearchType",
"TrackType",
"PlaylistType",
"NodeAlgorithm",
"LoopMode",
"RouteStrategy",
"RouteIPType",
"URLRegex",
)
class SearchType(Enum):
"""
@ -185,43 +195,51 @@ class URLRegex:
"""
SPOTIFY_URL = re.compile(
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)"
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
)
DISCORD_MP3_URL = re.compile(
r"https?://cdn.discordapp.com/attachments/(?P<channel_id>[0-9]+)/"
r"(?P<message_id>[0-9]+)/(?P<file>[a-zA-Z0-9_.]+)+"
r"(?P<message_id>[0-9]+)/(?P<file>[a-zA-Z0-9_.]+)+",
)
YOUTUBE_URL = re.compile(
r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube\.com|youtu.be))"
r"(\/(?:[\w\-]+\?v=|embed\/|v\/)?)([\w\-]+)(\S+)?$"
r"(\/(?:[\w\-]+\?v=|embed\/|v\/)?)([\w\-]+)(\S+)?$",
)
YOUTUBE_PLAYLIST_URL = re.compile(
r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube\.com|youtu.be))/playlist\?list=.*"
r"^((?:https?:)?\/\/)?((?:www|m)\.)?((?:youtube\.com|youtu.be))/playlist\?list=.*",
)
YOUTUBE_VID_IN_PLAYLIST = re.compile(r"(?P<video>^.*?v.*?)(?P<list>&list.*)")
YOUTUBE_VID_IN_PLAYLIST = re.compile(
r"(?P<video>^.*?v.*?)(?P<list>&list.*)",
)
YOUTUBE_TIMESTAMP = re.compile(r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*")
YOUTUBE_TIMESTAMP = re.compile(
r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*",
)
AM_URL = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/"
r"(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)"
r"(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
)
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)"
r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
)
SOUNDCLOUD_URL = re.compile(r"((?:https?:)?\/\/)?((?:www|m)\.)?soundcloud.com\/.*/.*")
SOUNDCLOUD_URL = re.compile(
r"((?:https?:)?\/\/)?((?:www|m)\.)?soundcloud.com\/.*/.*",
)
SOUNDCLOUD_PLAYLIST_URL = re.compile(r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com\/.*/sets/.*")
SOUNDCLOUD_PLAYLIST_URL = re.compile(
r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com\/.*/sets/.*",
)
SOUNDCLOUD_TRACK_IN_SET_URL = re.compile(
r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com/[a-zA-Z0-9-._]+/[a-zA-Z0-9-._]+(\?in)"
r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com/[a-zA-Z0-9-._]+/[a-zA-Z0-9-._]+(\?in)",
)
LAVALINK_SEARCH = re.compile(r"(?P<type>ytm?|sc)search:")

View File

@ -1,18 +1,32 @@
from __future__ import annotations
from discord import Client, Guild
from typing import TYPE_CHECKING, Any, Optional, Tuple
from typing import Union
from abc import ABC
from discord import Client
from discord import Guild
from discord.ext import commands
from .pool import NodePool
from .objects import Track
from typing import TYPE_CHECKING, Union
from .pool import NodePool
if TYPE_CHECKING:
from .player import Player
__all__ = (
"PomiceEvent",
"TrackStartEvent",
"TrackEndEvent",
"TrackStuckEvent",
"TrackExceptionEvent",
"WebSocketClosedPayload",
"WebSocketClosedEvent",
"WebSocketOpenEvent",
)
class PomiceEvent:
class PomiceEvent(ABC):
"""The base class for all events dispatched by a node.
Every event must be formatted within your bot's code as a listener.
i.e: If you want to listen for when a track starts, the event would be:
@ -22,10 +36,12 @@ class PomiceEvent:
```
"""
name = "event"
handler_args = ()
__slots__ = ("name", "handler_args")
def dispatch(self, bot: Union[Client, commands.Bot]):
name = "event"
handler_args: Tuple
def dispatch(self, bot: Client) -> None:
bot.dispatch(f"pomice_{self.name}", *self.handler_args)
@ -36,10 +52,15 @@ class TrackStartEvent(PomiceEvent):
name = "track_start"
__slots__ = (
"player",
"track",
)
def __init__(self, data: dict, player: Player):
__slots__ = ("player", "track")
self.player: Player = player
assert self.player._current is not None
self.track: Track = self.player._current
# on_pomice_track_start(player, track)
@ -56,10 +77,12 @@ class TrackEndEvent(PomiceEvent):
name = "track_end"
def __init__(self, data: dict, player: Player):
__slots__ = ("player", "track", "reason")
def __init__(self, data: dict, player: Player):
self.player: Player = player
assert self.player._ending_track is not None
self.track: Track = self.player._ending_track
self.reason: str = data["reason"]
@ -81,10 +104,12 @@ class TrackStuckEvent(PomiceEvent):
name = "track_stuck"
def __init__(self, data: dict, player: Player):
__slots__ = ("player", "track", "threshold")
def __init__(self, data: dict, player: Player):
self.player: Player = player
assert self.player._ending_track is not None
self.track: Track = self.player._ending_track
self.threshold: float = data["thresholdMs"]
@ -105,17 +130,16 @@ class TrackExceptionEvent(PomiceEvent):
name = "track_exception"
def __init__(self, data: dict, player: Player):
__slots__ = ("player", "track", "exception")
def __init__(self, data: dict, player: Player):
self.player: Player = player
assert self.player._ending_track is not None
self.track: Track = self.player._ending_track
if data.get("error"):
# User is running Lavalink <= 3.3
self.exception: str = data["error"]
else:
# User is running Lavalink >=3.4
self.exception: str = data["exception"]
# Error is for Lavalink <= 3.3
self.exception: str = data.get(
"error", "") or data.get("exception", "")
# on_pomice_track_exception(player, track, error)
self.handler_args = self.player, self.track, self.exception
@ -125,10 +149,12 @@ class TrackExceptionEvent(PomiceEvent):
class WebSocketClosedPayload:
def __init__(self, data: dict):
__slots__ = ("guild", "code", "reason", "by_remote")
self.guild: Guild = NodePool.get_node().bot.get_guild(int(data["guildId"]))
def __init__(self, data: dict):
self.guild: Optional[Guild] = NodePool.get_node(
).bot.get_guild(int(data["guildId"]))
self.code: int = data["code"]
self.reason: str = data["code"]
self.by_remote: bool = data["byRemote"]
@ -147,7 +173,9 @@ class WebSocketClosedEvent(PomiceEvent):
name = "websocket_closed"
def __init__(self, data: dict, _):
__slots__ = ("payload",)
def __init__(self, data: dict, _: Any) -> None:
self.payload: WebSocketClosedPayload = WebSocketClosedPayload(data)
# on_pomice_websocket_closed(payload)
@ -164,9 +192,10 @@ class WebSocketOpenEvent(PomiceEvent):
name = "websocket_open"
def __init__(self, data: dict, _):
__slots__ = ("target", "ssrc")
def __init__(self, data: dict, _: Any) -> None:
self.target: str = data["target"]
self.ssrc: int = data["ssrc"]

View File

@ -1,3 +1,26 @@
__all__ = (
"PomiceException",
"NodeException",
"NodeCreationError",
"NodeConnectionFailure",
"NodeConnectionClosed",
"NodeRestException",
"NodeNotAvailable",
"NoNodesAvailable",
"TrackInvalidPosition",
"TrackLoadError",
"FilterInvalidArgument",
"FilterTagInvalid",
"FilterTagAlreadyInUse",
"InvalidSpotifyClientAuthorization",
"AppleMusicNotEnabled",
"QueueException",
"QueueFull",
"QueueEmpty",
"LavalinkVersionIncompatible",
)
class PomiceException(Exception):
"""Base of all Pomice exceptions."""

View File

@ -1,6 +1,23 @@
from typing import Any, Dict, Tuple
import collections
from typing import List
from typing import Optional
from .exceptions import FilterInvalidArgument
__all__ = (
"Filter",
"Equalizer",
"Timescale",
"Karaoke",
"Tremolo",
"Vibrato",
"Rotation",
"Distortion",
"ChannelMix",
"LowPass",
)
class Filter:
"""
@ -13,10 +30,10 @@ class Filter:
This is necessary for the removal of filters.
"""
def __init__(self, *, tag: str):
__slots__ = ("payload", "tag", "preload")
self.payload: dict = None
def __init__(self, *, tag: str):
self.payload: Optional[Dict] = None
self.tag: str = tag
self.preload: bool = False
@ -34,32 +51,32 @@ class Equalizer(Filter):
The format for the levels is: List[Tuple[int, float]]
"""
def __init__(self, *, tag: str, levels: list):
super().__init__(tag=tag)
__slots__ = (
"eq",
"raw",
)
def __init__(self, *, tag: str, levels: list):
super().__init__(tag=tag)
self.eq = self._factory(levels)
self.raw = levels
self.payload = {"equalizer": self.eq}
def _factory(self, levels: list):
_dict = collections.defaultdict(int)
def _factory(self, levels: List[Tuple[Any, Any]]) -> List[Dict]:
_dict: Dict = collections.defaultdict(int)
_dict.update(levels)
_dict = [{"band": i, "gain": _dict[i]} for i in range(15)]
data = [{"band": i, "gain": _dict[i]} for i in range(15)]
return _dict
return data
def __repr__(self) -> str:
return f"<Pomice.EqualizerFilter tag={self.tag} eq={self.eq} raw={self.raw}>"
@classmethod
def flat(cls):
def flat(cls) -> "Equalizer":
"""Equalizer preset which represents a flat EQ board,
with all levels set to their default values.
"""
@ -84,7 +101,7 @@ class Equalizer(Filter):
return cls(tag="flat", levels=levels)
@classmethod
def boost(cls):
def boost(cls) -> "Equalizer":
"""Equalizer preset which boosts the sound of a track,
making it sound fun and energetic by increasing the bass
and the highs.
@ -110,7 +127,7 @@ class Equalizer(Filter):
return cls(tag="boost", levels=levels)
@classmethod
def metal(cls):
def metal(cls) -> "Equalizer":
"""Equalizer preset which increases the mids of a track,
preferably one of the metal genre, to make it sound
more full and concert-like.
@ -137,7 +154,7 @@ class Equalizer(Filter):
return cls(tag="metal", levels=levels)
@classmethod
def piano(cls):
def piano(cls) -> "Equalizer":
"""Equalizer preset which increases the mids and highs
of a track, preferably a piano based one, to make it
stand out.
@ -169,11 +186,11 @@ class Timescale(Filter):
a certain amount to produce said effect.
"""
__slots__ = ("speed", "pitch", "rate")
def __init__(self, *, tag: str, speed: float = 1.0, pitch: float = 1.0, rate: float = 1.0):
super().__init__(tag=tag)
__slots__ = ("speed", "pitch", "rate")
if speed < 0:
raise FilterInvalidArgument("Timescale speed must be more than 0.")
if pitch < 0:
@ -186,11 +203,11 @@ class Timescale(Filter):
self.rate: float = rate
self.payload: dict = {
"timescale": {"speed": self.speed, "pitch": self.pitch, "rate": self.rate}
"timescale": {"speed": self.speed, "pitch": self.pitch, "rate": self.rate},
}
@classmethod
def vaporwave(cls):
def vaporwave(cls) -> "Timescale":
"""Timescale preset which slows down the currently playing track,
giving it the effect of a half-speed record/casette playing.
@ -200,7 +217,7 @@ class Timescale(Filter):
return cls(tag="vaporwave", speed=0.8, pitch=0.8)
@classmethod
def nightcore(cls):
def nightcore(cls) -> "Timescale":
"""Timescale preset which speeds up the currently playing track,
which matches up to nightcore, a genre of sped-up music
@ -209,7 +226,7 @@ class Timescale(Filter):
return cls(tag="nightcore", speed=1.25, pitch=1.3)
def __repr__(self):
def __repr__(self) -> str:
return f"<Pomice.TimescaleFilter tag={self.tag} speed={self.speed} pitch={self.pitch} rate={self.rate}>"
@ -217,6 +234,7 @@ class Karaoke(Filter):
"""Filter which filters the vocal track from any song and leaves the instrumental.
Best for karaoke as the filter implies.
"""
__slots__ = ("level", "mono_level", "filter_band", "filter_width")
def __init__(
self,
@ -229,8 +247,6 @@ class Karaoke(Filter):
):
super().__init__(tag=tag)
__slots__ = ("level", "mono_level", "filter_band", "filter_width")
self.level: float = level
self.mono_level: float = mono_level
self.filter_band: float = filter_band
@ -242,10 +258,10 @@ class Karaoke(Filter):
"monoLevel": self.mono_level,
"filterBand": self.filter_band,
"filterWidth": self.filter_width,
}
},
}
def __repr__(self):
def __repr__(self) -> str:
return (
f"<Pomice.KaraokeFilter tag={self.tag} level={self.level} mono_level={self.mono_level} "
f"filter_band={self.filter_band} filter_width={self.filter_width}>"
@ -256,23 +272,30 @@ class Tremolo(Filter):
"""Filter which produces a wavering tone in the music,
causing it to sound like the music is changing in volume rapidly.
"""
__slots__ = ("frequency", "depth")
def __init__(self, *, tag: str, frequency: float = 2.0, depth: float = 0.5):
super().__init__(tag=tag)
__slots__ = ("frequency", "depth")
if frequency < 0:
raise FilterInvalidArgument("Tremolo frequency must be more than 0.")
raise FilterInvalidArgument(
"Tremolo frequency must be more than 0.",
)
if depth < 0 or depth > 1:
raise FilterInvalidArgument("Tremolo depth must be between 0 and 1.")
raise FilterInvalidArgument(
"Tremolo depth must be between 0 and 1.",
)
self.frequency: float = frequency
self.depth: float = depth
self.payload: dict = {"tremolo": {"frequency": self.frequency, "depth": self.depth}}
self.payload: dict = {
"tremolo": {
"frequency": self.frequency, "depth": self.depth,
},
}
def __repr__(self):
def __repr__(self) -> str:
return (
f"<Pomice.TremoloFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
)
@ -282,23 +305,30 @@ class Vibrato(Filter):
"""Filter which produces a wavering tone in the music, similar to the Tremolo filter,
but changes in pitch rather than volume.
"""
__slots__ = ("frequency", "depth")
def __init__(self, *, tag: str, frequency: float = 2.0, depth: float = 0.5):
super().__init__(tag=tag)
__slots__ = ("frequency", "depth")
if frequency < 0 or frequency > 14:
raise FilterInvalidArgument("Vibrato frequency must be between 0 and 14.")
raise FilterInvalidArgument(
"Vibrato frequency must be between 0 and 14.",
)
if depth < 0 or depth > 1:
raise FilterInvalidArgument("Vibrato depth must be between 0 and 1.")
raise FilterInvalidArgument(
"Vibrato depth must be between 0 and 1.",
)
self.frequency: float = frequency
self.depth: float = depth
self.payload: dict = {"vibrato": {"frequency": self.frequency, "depth": self.depth}}
self.payload: dict = {
"vibrato": {
"frequency": self.frequency, "depth": self.depth,
},
}
def __repr__(self):
def __repr__(self) -> str:
return (
f"<Pomice.VibratoFilter tag={self.tag} frequency={self.frequency} depth={self.depth}>"
)
@ -309,11 +339,11 @@ class Rotation(Filter):
the audio is being rotated around the listener's head
"""
__slots__ = ("rotation_hertz",)
def __init__(self, *, tag: str, rotation_hertz: float = 5):
super().__init__(tag=tag)
__slots__ = "rotation_hertz"
self.rotation_hertz: float = rotation_hertz
self.payload: dict = {"rotation": {"rotationHz": self.rotation_hertz}}
@ -326,6 +356,13 @@ class ChannelMix(Filter):
for some cool effects when done correctly.
"""
__slots__ = (
"left_to_left",
"right_to_right",
"left_to_right",
"right_to_left",
)
def __init__(
self,
*,
@ -337,23 +374,21 @@ class ChannelMix(Filter):
):
super().__init__(tag=tag)
__slots__ = ("left_to_left", "right_to_right", "left_to_right", "right_to_left")
if 0 > left_to_left > 1:
raise ValueError(
"'left_to_left' value must be more than or equal to 0 or less than or equal to 1."
"'left_to_left' value must be more than or equal to 0 or less than or equal to 1.",
)
if 0 > right_to_right > 1:
raise ValueError(
"'right_to_right' value must be more than or equal to 0 or less than or equal to 1."
"'right_to_right' value must be more than or equal to 0 or less than or equal to 1.",
)
if 0 > left_to_right > 1:
raise ValueError(
"'left_to_right' value must be more than or equal to 0 or less than or equal to 1."
"'left_to_right' value must be more than or equal to 0 or less than or equal to 1.",
)
if 0 > right_to_left > 1:
raise ValueError(
"'right_to_left' value must be more than or equal to 0 or less than or equal to 1."
"'right_to_left' value must be more than or equal to 0 or less than or equal to 1.",
)
self.left_to_left: float = left_to_left
@ -367,7 +402,7 @@ class ChannelMix(Filter):
"leftToRight": self.left_to_right,
"rightToLeft": self.right_to_left,
"rightToRight": self.right_to_right,
}
},
}
def __repr__(self) -> str:
@ -382,6 +417,17 @@ class Distortion(Filter):
distortion is needed.
"""
__slots__ = (
"sin_offset",
"sin_scale",
"cos_offset",
"cos_scale",
"tan_offset",
"tan_scale",
"offset",
"scale",
)
def __init__(
self,
*,
@ -397,16 +443,6 @@ class Distortion(Filter):
):
super().__init__(tag=tag)
__slots__ = (
"sin_offset",
"sin_scale",
"cos_offset",
"cos_scale",
"tan_offset",
"tan_scale" "offset",
"scale",
)
self.sin_offset: float = sin_offset
self.sin_scale: float = sin_scale
self.cos_offset: float = cos_offset
@ -426,7 +462,7 @@ class Distortion(Filter):
"tanScale": self.tan_scale,
"offset": self.offset,
"scale": self.scale,
}
},
}
def __repr__(self) -> str:
@ -441,12 +477,11 @@ class LowPass(Filter):
"""Filter which supresses higher frequencies and allows lower frequencies to pass.
You can also do this with the Equalizer filter, but this is an easier way to do it.
"""
__slots__ = ("smoothing", "payload")
def __init__(self, *, tag: str, smoothing: float = 20):
super().__init__(tag=tag)
__slots__ = "smoothing"
self.smoothing: float = smoothing
self.payload: dict = {"lowPass": {"smoothing": self.smoothing}}

View File

@ -1,30 +1,30 @@
from __future__ import annotations
from typing import List, Optional, Union
from discord import Member, User
from typing import List
from typing import Optional
from typing import Union
from discord import ClientUser
from discord import Member
from discord import User
from discord.ext import commands
from .enums import SearchType, TrackType, PlaylistType
from .enums import PlaylistType
from .enums import SearchType
from .enums import TrackType
from .filters import Filter
__all__ = (
"Track",
"Playlist",
)
class Track:
"""The base track object. Returns critical track information needed for parsing by Lavalink.
You can also pass in commands.Context to get a discord.py Context object in your track.
"""
def __init__(
self,
*,
track_id: str,
info: dict,
ctx: Optional[commands.Context] = None,
track_type: TrackType,
search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None,
timestamp: Optional[float] = None,
requester: Optional[Union[Member, User]] = None,
):
__slots__ = (
"track_id",
"info",
@ -48,6 +48,19 @@ class Track:
"position",
)
def __init__(
self,
*,
track_id: str,
info: dict,
ctx: Optional[commands.Context] = None,
track_type: TrackType,
search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None,
timestamp: Optional[float] = None,
requester: Optional[Union[Member, User, ClientUser]] = None,
):
self.track_id: str = track_id
self.info: dict = info
self.track_type: TrackType = track_type
@ -60,35 +73,29 @@ class Track:
self.original = self
self._search_type: SearchType = search_type
self.playlist: Playlist = None
self.playlist: Optional[Playlist] = None
self.title: str = info.get("title")
self.author: str = info.get("author")
self.uri: str = info.get("uri")
self.identifier: str = info.get("identifier")
self.isrc: str = info.get("isrc")
self.title: str = info.get("title", "Unknown Title")
self.author: str = info.get("author", "Unknown Author")
self.uri: str = info.get("uri", "")
self.identifier: str = info.get("identifier", "")
self.isrc: str = info.get("isrc", "")
self.thumbnail: Optional[str] = info.get("thumbnail")
if self.uri:
if info.get("thumbnail"):
self.thumbnail: str = info.get("thumbnail")
elif self.track_type == TrackType.SOUNDCLOUD:
# ok so theres no feasible way of getting a Soundcloud image URL
# so we're just gonna leave it blank for brevity
self.thumbnail = None
else:
self.thumbnail: str = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg"
if self.uri and self.track_type is TrackType.YOUTUBE:
self.thumbnail = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg"
self.length: int = info.get("length")
self.ctx: commands.Context = ctx
if requester:
self.requester: Optional[Union[Member, User]] = requester
else:
self.requester: Optional[Union[Member, User]] = self.ctx.author if ctx else None
self.is_stream: bool = info.get("isStream")
self.is_seekable: bool = info.get("isSeekable")
self.position: int = info.get("position")
self.length: int = info.get("length", 0)
self.is_stream: bool = info.get("isStream", False)
self.is_seekable: bool = info.get("isSeekable", False)
self.position: int = info.get("position", 0)
def __eq__(self, other):
self.ctx: Optional[commands.Context] = ctx
self.requester: Optional[Union[Member, User, ClientUser]] = requester
if not self.requester and self.ctx:
self.requester = self.ctx.author
def __eq__(self, other: object) -> bool:
if not isinstance(other, Track):
return False
@ -97,10 +104,10 @@ class Track:
return other.track_id == self.track_id
def __str__(self):
def __str__(self) -> str:
return self.title
def __repr__(self):
def __repr__(self) -> str:
return f"<Pomice.track title={self.title!r} uri=<{self.uri!r}> length={self.length}>"
@ -110,15 +117,6 @@ class Playlist:
You can also pass in commands.Context to get a discord.py Context object in your tracks.
"""
def __init__(
self,
*,
playlist_info: dict,
tracks: list,
playlist_type: PlaylistType,
thumbnail: Optional[str] = None,
uri: Optional[str] = None,
):
__slots__ = (
"playlist_info",
"tracks",
@ -130,28 +128,37 @@ class Playlist:
"track_count",
)
def __init__(
self,
*,
playlist_info: dict,
tracks: list,
playlist_type: PlaylistType,
thumbnail: Optional[str] = None,
uri: Optional[str] = None,
):
self.playlist_info: dict = playlist_info
self.tracks: List[Track] = tracks
self.name: str = playlist_info.get("name")
self.name: str = playlist_info.get("name", "Unknown Playlist")
self.playlist_type: PlaylistType = playlist_type
self._thumbnail: str = thumbnail
self._uri: str = uri
self._thumbnail: Optional[str] = thumbnail
self._uri: Optional[str] = uri
for track in self.tracks:
track.playlist = self
if (index := playlist_info.get("selectedTrack")) == -1:
self.selected_track = None
else:
self.selected_track: Track = self.tracks[index]
self.selected_track: Optional[Track] = None
if (index := playlist_info.get("selectedTrack", -1)) != -1:
self.selected_track = self.tracks[index]
self.track_count: int = len(self.tracks)
def __str__(self):
def __str__(self) -> str:
return self.name
def __repr__(self):
def __repr__(self) -> str:
return f"<Pomice.playlist name={self.name!r} track_count={len(self.tracks)}>"
@property

View File

@ -1,54 +1,67 @@
import time
from typing import Any, Dict, List, Optional, Union
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from discord import Client, Guild, VoiceChannel, VoiceProtocol
from discord import Client
from discord import Guild
from discord import VoiceChannel
from discord import VoiceProtocol
from discord.ext import commands
from discord.types.voice import VoiceServerUpdate, GuildVoiceState
from . import events
from .enums import SearchType
from .events import PomiceEvent, TrackEndEvent, TrackStartEvent
from .exceptions import (
FilterInvalidArgument,
FilterTagAlreadyInUse,
FilterTagInvalid,
TrackInvalidPosition,
TrackLoadError,
)
from .events import PomiceEvent
from .events import TrackEndEvent
from .events import TrackStartEvent
from .exceptions import FilterInvalidArgument
from .exceptions import FilterTagAlreadyInUse
from .exceptions import FilterTagInvalid
from .exceptions import TrackInvalidPosition
from .exceptions import TrackLoadError
from .filters import Filter
from .objects import Track
from .pool import Node, NodePool
from .objects import Track, Playlist
from .pool import Node
from .pool import NodePool
__all__ = ("Filters", "Player")
class Filters:
"""Helper class for filters"""
__slots__ = "_filters"
__slots__ = ("_filters",)
def __init__(self):
def __init__(self) -> None:
self._filters: List[Filter] = []
@property
def has_preload(self):
def has_preload(self) -> bool:
"""Property which checks if any applied filters were preloaded"""
return any(f for f in self._filters if f.preload == True)
@property
def has_global(self):
def has_global(self) -> bool:
"""Property which checks if any applied filters are global"""
return any(f for f in self._filters if f.preload == False)
@property
def empty(self):
def empty(self) -> bool:
"""Property which checks if the filter list is empty"""
return len(self._filters) == 0
def add_filter(self, *, filter: Filter):
def add_filter(self, *, filter: Filter) -> None:
"""Adds a filter to the list of filters applied"""
if any(f for f in self._filters if f.tag == filter.tag):
raise FilterTagAlreadyInUse("A filter with that tag is already in use.")
raise FilterTagAlreadyInUse(
"A filter with that tag is already in use.",
)
self._filters.append(filter)
def remove_filter(self, *, filter_tag: str):
def remove_filter(self, *, filter_tag: str) -> None:
"""Removes a filter from the list of filters applied using its filter tag"""
if not any(f for f in self._filters if f.tag == filter_tag):
raise FilterTagInvalid("A filter with that tag was not found.")
@ -57,26 +70,27 @@ class Filters:
if filter.tag == filter_tag:
del self._filters[index]
def has_filter(self, *, filter_tag: str):
def has_filter(self, *, filter_tag: str) -> bool:
"""Checks if a filter exists in the list of filters using its filter tag"""
return any(f for f in self._filters if f.tag == filter_tag)
def reset_filters(self):
def reset_filters(self) -> None:
"""Removes all filters from the list"""
self._filters = []
def get_preload_filters(self):
def get_preload_filters(self) -> List[Filter]:
"""Get all preloaded filters"""
return [f for f in self._filters if f.preload == True]
def get_all_payloads(self):
def get_all_payloads(self) -> Dict[str, Any]:
"""Returns a formatted dict of all the filter payloads"""
payload = {}
for filter in self._filters:
payload.update(filter.payload)
payload: Dict[str, Any] = {}
for _filter in self._filters:
if _filter.payload:
payload.update(_filter.payload)
return payload
def get_filters(self):
def get_filters(self) -> List[Filter]:
"""Returns the current list of applied filters"""
return self._filters
@ -89,20 +103,6 @@ class Player(VoiceProtocol):
```
"""
def __call__(self, client: Client, channel: VoiceChannel):
self.client: Client = client
self.channel: VoiceChannel = channel
self._guild: Guild = channel.guild
return self
def __init__(
self,
client: Optional[Client] = None,
channel: Optional[VoiceChannel] = None,
*,
node: Node = None,
):
__slots__ = (
"client",
"channel",
@ -123,11 +123,23 @@ class Player(VoiceProtocol):
"__dict__",
)
self.client: Optional[Client] = client
self.channel: Optional[VoiceChannel] = channel
def __call__(self, client: Client, channel: VoiceChannel) -> "Player":
self.__init__(client, channel) # type: ignore
return self
self._bot: Union[Client, commands.Bot] = client
self._guild: Guild = channel.guild if channel else None
def __init__(
self,
client: Client,
channel: VoiceChannel,
*,
node: Optional[Node] = None,
) -> None:
self.client: Client = client
self.channel: VoiceChannel = channel
self._bot: Client = client
self._guild: Guild = channel.guild
self._node: Node = node if node else NodePool.get_node()
self._current: Optional[Track] = None
self._filters: Filters = Filters()
@ -137,7 +149,7 @@ class Player(VoiceProtocol):
self._position: int = 0
self._last_position: int = 0
self._last_update: int = 0
self._last_update: float = 0
self._ending_track: Optional[Track] = None
self._voice_state: dict = {}
@ -153,11 +165,13 @@ class Player(VoiceProtocol):
@property
def position(self) -> float:
"""Property which returns the player's position in a track in milliseconds"""
current = self._current.original
if not self.is_playing or not self._current:
return 0
current = self._current.original
if not current:
return 0
if self.is_paused:
return min(self._last_position, current.length)
@ -185,7 +199,7 @@ class Player(VoiceProtocol):
return self._is_connected and self._paused
@property
def current(self) -> Track:
def current(self) -> Optional[Track]:
"""Property which returns the currently playing track"""
return self._current
@ -210,7 +224,7 @@ class Player(VoiceProtocol):
return self._filters
@property
def bot(self) -> Union[Client, commands.Bot]:
def bot(self) -> Client:
"""Property which returns the bot associated with this player instance"""
return self._bot
@ -221,13 +235,14 @@ class Player(VoiceProtocol):
"""
return self.guild.id not in self._node._players
async def _update_state(self, data: dict):
state: dict = data.get("state")
self._last_update = time.time() * 1000
self._is_connected = state.get("connected")
self._last_position = state.get("position")
async def _update_state(self, data: dict) -> None:
state: dict = data.get("state", {})
self._last_update = time.time() * 1000.0
self._is_connected = bool(state.get("connected"))
position = state.get("position")
self._position = int(position) if position else 0
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None):
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None:
if {"sessionId", "event"} != self._voice_state.keys():
return
@ -246,27 +261,32 @@ class Player(VoiceProtocol):
data={"voice": data},
)
async def on_voice_server_update(self, data: dict):
async def on_voice_server_update(self, data: VoiceServerUpdate) -> None:
self._voice_state.update({"event": data})
await self._dispatch_voice_update(self._voice_state)
async def on_voice_state_update(self, data: dict):
async def on_voice_state_update(self, data: GuildVoiceState) -> None:
self._voice_state.update({"sessionId": data.get("session_id")})
if not (channel_id := data.get("channel_id")):
channel_id = data.get("channel_id")
if not channel_id:
await self.disconnect()
self._voice_state.clear()
return
self.channel = self.guild.get_channel(int(channel_id))
channel = self.guild.get_channel(int(channel_id))
if not channel:
await self.disconnect()
self._voice_state.clear()
return
if not data.get("token"):
return
await self._dispatch_voice_update({**self._voice_state, "event": data})
async def _dispatch_event(self, data: dict):
event_type = data.get("type")
async def _dispatch_event(self, data: dict) -> None:
event_type: str = data["type"]
event: PomiceEvent = getattr(events, event_type)(data, self)
if isinstance(event, TrackEndEvent) and event.reason != "REPLACED":
@ -277,11 +297,12 @@ class Player(VoiceProtocol):
if isinstance(event, TrackStartEvent):
self._ending_track = self._current
async def _swap_node(self, *, new_node: Node):
async def _swap_node(self, *, new_node: Node) -> None:
data: dict = {
"encodedTrack": self.current.track_id,
"position": self.position,
}
if self.current:
data["encodedTrack"] = self.current.track_id
del self._node._players[self._guild.id]
self._node = new_node
@ -304,7 +325,7 @@ class Player(VoiceProtocol):
ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None,
):
) -> Optional[Union[List[Track], Playlist]]:
"""Fetches tracks from the node's REST api to parse into Lavalink.
If you passed in Spotify API credentials when you created the node,
@ -321,7 +342,7 @@ class Player(VoiceProtocol):
async def get_recommendations(
self, *, track: Track, ctx: Optional[commands.Context] = None
) -> Union[List[Track], None]:
) -> Optional[Union[List[Track], Playlist]]:
"""
Gets recommendations from either YouTube or Spotify.
You can pass in a discord.py Context object to get a
@ -331,14 +352,14 @@ class Player(VoiceProtocol):
async def connect(
self, *, timeout: float, reconnect: bool, self_deaf: bool = False, self_mute: bool = False
):
) -> None:
await self.guild.change_voice_state(
channel=self.channel, self_deaf=self_deaf, self_mute=self_mute
channel=self.channel, self_deaf=self_deaf, self_mute=self_mute,
)
self._node._players[self.guild.id] = self
self._is_connected = True
async def stop(self):
async def stop(self) -> None:
"""Stops the currently playing track."""
self._current = None
await self._node.send(
@ -348,27 +369,27 @@ class Player(VoiceProtocol):
data={"encodedTrack": None},
)
async def disconnect(self, *, force: bool = False):
async def disconnect(self, *, force: bool = False) -> None:
"""Disconnects the player from voice."""
try:
await self.guild.change_voice_state(channel=None)
finally:
self.cleanup()
self._is_connected = False
self.channel = None
del self.channel
async def destroy(self):
async def destroy(self) -> None:
"""Disconnects and destroys the player, and runs internal cleanup."""
try:
await self.disconnect()
except AttributeError:
# 'NoneType' has no attribute '_get_voice_client_key' raised by self.cleanup() ->
# assume we're already disconnected and cleaned up
assert self.channel is None and not self.is_connected
assert not self.is_connected and not self.channel
self._node._players.pop(self.guild.id)
await self._node.send(
method="DELETE", path=self._player_endpoint_uri, guild_id=self._guild.id
method="DELETE", path=self._player_endpoint_uri, guild_id=self._guild.id,
)
async def play(
@ -383,20 +404,22 @@ class Player(VoiceProtocol):
if not track.isrc:
# We have to bare raise here because theres no other way to skip this block feasibly
raise
search: Track = (
search = (
await self._node.get_tracks(f"{track._search_type}:{track.isrc}", ctx=track.ctx)
)[0]
)[0] # type: ignore
except Exception:
# First method didn't work, lets try just searching it up
try:
search: Track = (
search = (
await self._node.get_tracks(
f"{track._search_type}:{track.title} - {track.author}", ctx=track.ctx
f"{track._search_type}:{track.title} - {track.author}", ctx=track.ctx,
)
)[0]
)[0] # type: ignore
except:
# The song wasn't able to be found, raise error
raise TrackLoadError("No equivalent track was able to be found.")
raise TrackLoadError(
"No equivalent track was able to be found.",
)
data = {
"encodedTrack": search.track_id,
"position": str(start),
@ -432,7 +455,7 @@ class Player(VoiceProtocol):
if track.filters and not self.filters.has_global:
# Now apply all filters
for filter in track.filters:
await self.add_filter(filter=filter)
await self.add_filter(_filter=filter)
# Lavalink v4 changed the way the end time parameter works
# so now the end time cannot be zero.
@ -454,8 +477,13 @@ class Player(VoiceProtocol):
async def seek(self, position: float) -> float:
"""Seeks to a position in the currently playing track milliseconds"""
if not self._current or not self._current.original:
return 0.0
if position < 0 or position > self._current.original.length:
raise TrackInvalidPosition("Seek position must be between 0 and the track length")
raise TrackInvalidPosition(
"Seek position must be between 0 and the track length",
)
await self._node.send(
method="PATCH",
@ -487,7 +515,7 @@ class Player(VoiceProtocol):
self._volume = volume
return self._volume
async def add_filter(self, filter: Filter, fast_apply: bool = False) -> Filter:
async def add_filter(self, _filter: Filter, fast_apply: bool = False) -> Filters:
"""Adds a filter to the player. Takes a pomice.Filter object.
This will only work if you are using a version of Lavalink that supports filters.
If you would like for the filter to apply instantly, set the `fast_apply` arg to `True`.
@ -495,7 +523,7 @@ class Player(VoiceProtocol):
(You must have a song playing in order for `fast_apply` to work.)
"""
self._filters.add_filter(filter=filter)
self._filters.add_filter(filter=_filter)
payload = self._filters.get_all_payloads()
await self._node.send(
method="PATCH",
@ -508,7 +536,7 @@ class Player(VoiceProtocol):
return self._filters
async def remove_filter(self, filter_tag: str, fast_apply: bool = False) -> Filter:
async def remove_filter(self, filter_tag: str, fast_apply: bool = False) -> Filters:
"""Removes a filter from the player. Takes a filter tag.
This will only work if you are using a version of Lavalink that supports filters.
If you would like for the filter to apply instantly, set the `fast_apply` arg to `True`.
@ -529,7 +557,7 @@ class Player(VoiceProtocol):
return self._filters
async def reset_filters(self, *, fast_apply: bool = False):
async def reset_filters(self, *, fast_apply: bool = False) -> None:
"""Resets all currently applied filters to their default parameters.
You must have filters applied in order for this to work.
If you would like the filters to be removed instantly, set the `fast_apply` arg to `True`.
@ -539,7 +567,7 @@ class Player(VoiceProtocol):
if not self._filters:
raise FilterInvalidArgument(
"You must have filters applied first in order to use this method."
"You must have filters applied first in order to use this method.",
)
self._filters.reset_filters()
await self._node.send(

View File

@ -3,35 +3,48 @@ from __future__ import annotations
import asyncio
import random
import re
import aiohttp
from discord import Client
from discord.ext import commands
from typing import Dict, List, Optional, TYPE_CHECKING, Union
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from urllib.parse import quote
from . import __version__, spotify, applemusic
import aiohttp
from discord import Client
from discord.ext import commands
from . import __version__
from . import applemusic
from . import spotify
from .enums import *
from .exceptions import (
AppleMusicNotEnabled,
InvalidSpotifyClientAuthorization,
LavalinkVersionIncompatible,
NodeConnectionFailure,
NodeCreationError,
NodeNotAvailable,
NoNodesAvailable,
NodeRestException,
TrackLoadError,
)
from .exceptions import AppleMusicNotEnabled
from .exceptions import InvalidSpotifyClientAuthorization
from .exceptions import LavalinkVersionIncompatible
from .exceptions import NodeConnectionFailure
from .exceptions import NodeCreationError
from .exceptions import NodeNotAvailable
from .exceptions import NodeRestException
from .exceptions import NoNodesAvailable
from .exceptions import TrackLoadError
from .filters import Filter
from .objects import Playlist, Track
from .utils import ExponentialBackoff, NodeStats, Ping
from .objects import Playlist
from .objects import Track
from .routeplanner import RoutePlanner
from .utils import ExponentialBackoff
from .utils import NodeStats
from .utils import Ping
if TYPE_CHECKING:
from .player import Player
__all__ = (
"Node",
"NodePool",
)
class Node:
"""The base class for a node.
@ -40,26 +53,9 @@ class Node:
To enable Apple music, set the "apple_music" parameter to "True"
"""
def __init__(
self,
*,
pool: NodePool,
bot: Union[Client, commands.Bot],
host: str,
port: int,
password: str,
identifier: str,
secure: bool = False,
heartbeat: int = 30,
loop: Optional[asyncio.AbstractEventLoop] = None,
session: Optional[aiohttp.ClientSession] = None,
spotify_client_id: Optional[str] = None,
spotify_client_secret: Optional[str] = None,
apple_music: bool = False,
fallback: bool = False,
):
__slots__ = (
"_bot",
"_bot_user",
"_host",
"_port",
"_pool",
@ -83,12 +79,34 @@ class Node:
"_spotify_client_secret",
"_spotify_client",
"_apple_music_client",
"_route_planner",
"_stats",
"available",
)
self._bot: Union[Client, commands.Bot] = bot
def __init__(
self,
*,
pool: Type[NodePool],
bot: commands.Bot,
host: str,
port: int,
password: str,
identifier: str,
secure: bool = False,
heartbeat: int = 30,
loop: Optional[asyncio.AbstractEventLoop] = None,
session: Optional[aiohttp.ClientSession] = None,
spotify_client_id: Optional[int] = None,
spotify_client_secret: Optional[str] = None,
apple_music: bool = False,
fallback: bool = False,
):
self._bot: commands.Bot = bot
self._host: str = host
self._port: int = port
self._pool: NodePool = pool
self._pool: Type[NodePool] = pool
self._password: str = password
self._identifier: str = identifier
self._heartbeat: int = heartbeat
@ -98,33 +116,38 @@ class Node:
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
self._session: Optional[aiohttp.ClientSession] = session
self._websocket = None
self._task: asyncio.Task = None
self._session: aiohttp.ClientSession = session # type: ignore
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
self._websocket: aiohttp.ClientWebSocketResponse
self._task: asyncio.Task
self._session_id: str = None
self._session_id: Optional[str] = None
self._available: bool = False
self._version: str = None
self._version: int
self._route_planner = RoutePlanner(self)
if not self._bot.user:
raise NodeCreationError("Bot user is not ready yet.")
self._bot_user = self._bot.user
self._headers = {
"Authorization": self._password,
"User-Id": str(self._bot.user.id),
"User-Id": str(self._bot_user.id),
"Client-Name": f"Pomice/{__version__}",
}
self._players: Dict[int, Player] = {}
self._spotify_client_id: str = spotify_client_id
self._spotify_client_secret: str = spotify_client_secret
self._spotify_client_id: Optional[int] = spotify_client_id
self._spotify_client_secret: Optional[str] = spotify_client_secret
self._apple_music_client: Optional[applemusic.Client] = None
if self._spotify_client_id and self._spotify_client_secret:
self._spotify_client: spotify.Client = spotify.Client(
self._spotify_client_id, self._spotify_client_secret
self._spotify_client_id, self._spotify_client_secret,
)
if apple_music:
@ -132,7 +155,7 @@ class Node:
self._bot.add_listener(self._update_handler, "on_socket_response")
def __repr__(self):
def __repr__(self) -> str:
return (
f"<Pomice.node ws_uri={self._websocket_uri} rest_uri={self._rest_uri} "
f"player_count={len(self._players)}>"
@ -154,7 +177,7 @@ class Node:
return self._players
@property
def bot(self) -> Union[Client, commands.Bot]:
def bot(self) -> Client:
"""Property which returns the discord.py client linked to this node"""
return self._bot
@ -164,21 +187,21 @@ class Node:
return len(self.players)
@property
def pool(self):
def pool(self) -> Type[NodePool]:
"""Property which returns the pool this node is apart of"""
return self._pool
@property
def latency(self):
def latency(self) -> float:
"""Property which returns the latency of the node"""
return Ping(self._host, port=self._port).get_ping()
@property
def ping(self):
def ping(self) -> float:
"""Alias for `Node.latency`, returns the latency of the node"""
return self.latency
async def _update_handler(self, data: dict):
async def _update_handler(self, data: dict) -> None:
await self._bot.wait_until_ready()
if not data:
@ -193,7 +216,7 @@ class Node:
return
elif data["t"] == "VOICE_STATE_UPDATE":
if int(data["d"]["user_id"]) != self._bot.user.id:
if int(data["d"]["user_id"]) != self._bot_user.id:
return
guild_id = int(data["d"]["guild_id"])
@ -203,8 +226,11 @@ class Node:
except KeyError:
return
async def _handle_node_switch(self):
nodes = [node for node in self.pool.nodes.copy().values() if node.is_connected]
async def _handle_node_switch(self) -> None:
nodes = [
node for node in self.pool._nodes.copy().values()
if node.is_connected
]
new_node = random.choice(nodes)
for player in self.players.copy().values():
@ -212,7 +238,7 @@ class Node:
await self.disconnect()
async def _listen(self):
async def _listen(self) -> None:
backoff = ExponentialBackoff(base=7)
while True:
@ -227,7 +253,7 @@ class Node:
else:
self._loop.create_task(self._handle_payload(msg.json()))
async def _handle_payload(self, data: dict):
async def _handle_payload(self, data: dict) -> None:
op = data.get("op", None)
if not op:
return
@ -239,14 +265,20 @@ class Node:
if op == "ready":
self._session_id = data["sessionId"]
if "guildId" in data:
if not (player := self._players.get(int(data["guildId"]))):
if not "guildId" in data:
return
player = self._players.get(int(data["guildId"]))
if not player:
return
if op == "event":
await player._dispatch_event(data)
elif op == "playerUpdate":
return
if op == "playerUpdate":
await player._update_state(data)
return
async def send(
self,
@ -255,11 +287,13 @@ class Node:
include_version: bool = True,
guild_id: Optional[Union[int, str]] = None,
query: Optional[str] = None,
data: Optional[Union[dict, str]] = None,
data: Optional[Union[Dict, str]] = None,
ignore_if_available: bool = False,
):
) -> Any:
if not ignore_if_available and not self._available:
raise NodeNotAvailable(f"The node '{self._identifier}' is unavailable.")
raise NodeNotAvailable(
f"The node '{self._identifier}' is unavailable.",
)
uri: str = (
f"{self._rest_uri}/"
@ -270,12 +304,12 @@ class Node:
)
async with self._session.request(
method=method, url=uri, headers=self._headers, json=data or {}
method=method, url=uri, headers=self._headers, json=data or {},
) as resp:
if resp.status >= 300:
data: dict = await resp.json()
resp_data: dict = await resp.json()
raise NodeRestException(
f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {data["message"]}'
f'Error fetching from Lavalink REST api: {resp.status} {resp.reason}: {resp_data["message"]}',
)
if method == "DELETE" or resp.status == 204:
@ -286,11 +320,11 @@ class Node:
return await resp.json()
def get_player(self, guild_id: int):
"""Takes a guild ID as a parameter. Returns a pomice Player object."""
def get_player(self, guild_id: int) -> Optional[Player]:
"""Takes a guild ID as a parameter. Returns a pomice Player object or None."""
return self._players.get(guild_id, None)
async def connect(self):
async def connect(self) -> "Node":
"""Initiates a connection with a Lavalink node and adds it to the node pool."""
await self._bot.wait_until_ready()
@ -298,7 +332,7 @@ class Node:
self._session = aiohttp.ClientSession()
try:
version = await self.send(
version: str = await self.send(
method="GET",
path="version",
ignore_if_available=True,
@ -309,14 +343,14 @@ class Node:
self._available = False
raise LavalinkVersionIncompatible(
"The Lavalink version you're using is incompatible. "
"Lavalink version 3.7.0 or above is required to use this library."
"Lavalink version 3.7.0 or above is required to use this library.",
)
if version.endswith("-SNAPSHOT"):
# we're just gonna assume all snapshot versions correlate with v4
self._version = 4
else:
self._version = version[:1]
self._version = int(version[:1])
self._websocket = await self._session.ws_connect(
f"{self._websocket_uri}/v{self._version}/websocket",
@ -332,18 +366,18 @@ class Node:
except (aiohttp.ClientConnectorError, ConnectionRefusedError):
raise NodeConnectionFailure(
f"The connection to node '{self._identifier}' failed."
f"The connection to node '{self._identifier}' failed.",
) from None
except aiohttp.WSServerHandshakeError:
raise NodeConnectionFailure(
f"The password for node '{self._identifier}' is invalid."
f"The password for node '{self._identifier}' is invalid.",
) from None
except aiohttp.InvalidURL:
raise NodeConnectionFailure(
f"The URI for node '{self._identifier}' is invalid."
f"The URI for node '{self._identifier}' is invalid.",
) from None
async def disconnect(self):
async def disconnect(self) -> None:
"""Disconnects a connected Lavalink node and removes it from the node pool.
This also destroys any players connected to the node.
"""
@ -371,7 +405,7 @@ class Node:
"""
data: dict = await self.send(
method="GET", path="decodetrack", query=f"encodedTrack={identifier}"
method="GET", path="decodetrack", query=f"encodedTrack={identifier}",
)
return Track(
track_id=identifier,
@ -387,7 +421,7 @@ class Node:
ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None,
):
) -> Optional[Union[Playlist, List[Track]]]:
"""Fetches tracks from the node's REST api to parse into Lavalink.
If you passed in Spotify API credentials, you can also pass in a
@ -413,7 +447,7 @@ class Node:
if not self._apple_music_client:
raise AppleMusicNotEnabled(
"You must have Apple Music functionality enabled in order to play Apple Music tracks."
"Please set apple_music to True in your Node class."
"Please set apple_music to True in your Node class.",
)
apple_music_results = await self._apple_music_client.search(query=query)
@ -437,7 +471,7 @@ class Node:
"thumbnail": apple_music_results.image,
"isrc": apple_music_results.isrc,
},
)
),
]
tracks = [
@ -464,7 +498,9 @@ class Node:
]
return Playlist(
playlist_info={"name": apple_music_results.name, "selectedTrack": 0},
playlist_info={
"name": apple_music_results.name, "selectedTrack": 0,
},
tracks=tracks,
playlist_type=PlaylistType.APPLE_MUSIC,
thumbnail=apple_music_results.image,
@ -476,7 +512,7 @@ class Node:
raise InvalidSpotifyClientAuthorization(
"You did not provide proper Spotify client authorization credentials. "
"If you would like to use the Spotify searching feature, "
"please obtain Spotify API credentials here: https://developer.spotify.com/"
"please obtain Spotify API credentials here: https://developer.spotify.com/",
)
spotify_results = await self._spotify_client.search(query=query)
@ -501,7 +537,7 @@ class Node:
"thumbnail": spotify_results.image,
"isrc": spotify_results.isrc,
},
)
),
]
tracks = [
@ -528,7 +564,9 @@ class Node:
]
return Playlist(
playlist_info={"name": spotify_results.name, "selectedTrack": 0},
playlist_info={
"name": spotify_results.name, "selectedTrack": 0,
},
tracks=tracks,
playlist_type=PlaylistType.SPOTIFY,
thumbnail=spotify_results.image,
@ -537,11 +575,11 @@ class Node:
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
data: dict = await self.send(
method="GET", path="loadtracks", query=f"identifier={quote(query)}"
method="GET", path="loadtracks", query=f"identifier={quote(query)}",
)
track: dict = data["tracks"][0]
info: dict = track.get("info")
info: dict = track["info"]
return [
Track(
@ -549,15 +587,15 @@ class Node:
info={
"title": discord_url.group("file"),
"author": "Unknown",
"length": info.get("length"),
"uri": info.get("uri"),
"position": info.get("position"),
"identifier": info.get("identifier"),
"length": info["length"],
"uri": info["uri"],
"position": info["position"],
"identifier": info["identifier"],
},
ctx=ctx,
track_type=TrackType.HTTP,
filters=filters,
)
),
]
else:
@ -572,18 +610,22 @@ class Node:
if match := URLRegex.YOUTUBE_VID_IN_PLAYLIST.match(query):
query = match.group("video")
data: dict = await self.send(
method="GET", path="loadtracks", query=f"identifier={quote(query)}"
data = await self.send(
method="GET", path="loadtracks", query=f"identifier={quote(query)}",
)
load_type = data.get("loadType")
if not load_type:
raise TrackLoadError("There was an error while trying to load this track.")
raise TrackLoadError(
"There was an error while trying to load this track.",
)
elif load_type == "LOAD_FAILED":
exception = data["exception"]
raise TrackLoadError(f"{exception['message']} [{exception['severity']}]")
raise TrackLoadError(
f"{exception['message']} [{exception['severity']}]",
)
elif load_type == "NO_MATCHES":
return None
@ -619,9 +661,14 @@ class Node:
for track in data["tracks"]
]
else:
raise TrackLoadError(
"There was an error while trying to load this track.",
)
async def get_recommendations(
self, *, track: Track, ctx: Optional[commands.Context] = None
) -> Union[List[Track], None]:
) -> Optional[Union[List[Track], Playlist]]:
"""
Gets recommendations from either YouTube or Spotify.
The track that is passed in must be either from
@ -652,17 +699,17 @@ class Node:
)
for track in results
]
return tracks
elif track.track_type == TrackType.YOUTUBE:
tracks = await self.get_tracks(
return await self.get_tracks(
query=f"ytsearch:https://www.youtube.com/watch?v={track.identifier}&list=RD{track.identifier}",
ctx=ctx,
)
return tracks
else:
raise TrackLoadError(
"The specfied track must be either a YouTube or Spotify track to recieve recommendations."
"The specfied track must be either a YouTube or Spotify track to recieve recommendations.",
)
@ -671,9 +718,10 @@ class NodePool:
This holds all the nodes that are to be used by the bot.
"""
__slots__ = ()
_nodes: Dict[str, Node] = {}
def __repr__(self):
def __repr__(self) -> str:
return f"<Pomice.NodePool node_count={self.node_count}>"
@property
@ -682,7 +730,7 @@ class NodePool:
return self._nodes
@property
def node_count(self):
def node_count(self) -> int:
return len(self._nodes.values())
@classmethod
@ -700,21 +748,31 @@ class NodePool:
based on how players it has. This method will return a node with
the least amount of players
"""
available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available]
available_nodes: List[Node] = [
node for node in cls._nodes.values() if node._available
]
if not available_nodes:
raise NoNodesAvailable("There are no nodes available.")
if algorithm == NodeAlgorithm.by_ping:
tested_nodes = {node: node.latency for node in available_nodes}
return min(tested_nodes, key=tested_nodes.get)
return min(tested_nodes, key=tested_nodes.get) # type: ignore
elif algorithm == NodeAlgorithm.by_players:
tested_nodes = {node: len(node.players.keys()) for node in available_nodes}
return min(tested_nodes, key=tested_nodes.get)
tested_nodes = {
node: len(node.players.keys())
for node in available_nodes
}
return min(tested_nodes, key=tested_nodes.get) # type: ignore
else:
raise ValueError(
"The algorithm provided is not a valid NodeAlgorithm.",
)
@classmethod
def get_node(cls, *, identifier: str = None) -> Node:
def get_node(cls, *, identifier: Optional[str] = None) -> Node:
"""Fetches a node from the node pool using it's identifier.
If no identifier is provided, it will choose a node at random.
"""
@ -728,21 +786,21 @@ class NodePool:
if identifier is None:
return random.choice(list(available_nodes.values()))
return available_nodes.get(identifier, None)
return available_nodes[identifier]
@classmethod
async def create_node(
cls,
*,
bot: Client,
bot: commands.Bot,
host: str,
port: str,
port: int,
password: str,
identifier: str,
secure: bool = False,
heartbeat: int = 30,
loop: Optional[asyncio.AbstractEventLoop] = None,
spotify_client_id: Optional[str] = None,
spotify_client_id: Optional[int] = None,
spotify_client_secret: Optional[str] = None,
session: Optional[aiohttp.ClientSession] = None,
apple_music: bool = False,
@ -752,7 +810,9 @@ class NodePool:
For Spotify searching capabilites, pass in valid Spotify API credentials.
"""
if identifier in cls._nodes.keys():
raise NodeCreationError(f"A node with identifier '{identifier}' already exists.")
raise NodeCreationError(
f"A node with identifier '{identifier}' already exists.",
)
node = Node(
pool=cls,
@ -779,7 +839,9 @@ class NodePool:
async def disconnect(cls) -> None:
"""Disconnects all available nodes from the node pool."""
available_nodes: List[Node] = [node for node in cls._nodes.values() if node._available]
available_nodes: List[Node] = [
node for node in cls._nodes.values() if node._available
]
for node in available_nodes:
await node.disconnect()

View File

@ -1,35 +1,47 @@
from __future__ import annotations
import random
from copy import copy
from typing import (
Iterable,
Iterator,
List,
Optional,
Union,
)
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Union
from .objects import Track
from .enums import LoopMode
from .exceptions import QueueEmpty, QueueException, QueueFull
from .exceptions import QueueEmpty
from .exceptions import QueueException
from .exceptions import QueueFull
from .objects import Track
__all__ = (
"Queue",
)
class Queue(Iterable[Track]):
"""Queue for Pomice. This queue takes pomice.Track as an input and includes looping and shuffling."""
__slots__ = (
"max_size",
"_queue",
"_overflow",
"_loop_mode",
"_current_item",
)
def __init__(
self,
max_size: Optional[int] = None,
*,
overflow: bool = True,
):
__slots__ = ("max_size", "_queue", "_overflow", "_loop_mode", "_current_item")
self.max_size: Optional[int] = max_size
self._queue: List[Track] = [] # type: ignore
self._current_item: Track
self._queue: List[Track] = []
self._overflow: bool = overflow
self._loop_mode: Optional[LoopMode] = None
self._current_item: Optional[Track] = None
def __str__(self) -> str:
"""String showing all Track objects appearing as a list."""
@ -60,7 +72,7 @@ class Queue(Iterable[Track]):
return self._queue[index]
def __setitem__(self, index: int, item: Track):
def __setitem__(self, index: int, item: Track) -> None:
"""Inserts an item at given position."""
if not isinstance(index, int):
raise ValueError("'int' type required.'")
@ -90,7 +102,9 @@ class Queue(Iterable[Track]):
The new queue will have the same max_size as the original.
"""
if not isinstance(other, Iterable):
raise TypeError(f"Adding with the '{type(other)}' type is not supported.")
raise TypeError(
f"Adding with the '{type(other)}' type is not supported.",
)
new_queue = self.copy()
new_queue.extend(other)
@ -106,7 +120,9 @@ class Queue(Iterable[Track]):
self.extend(other)
return self
raise TypeError(f"Adding '{type(other)}' type to the queue is not supported.")
raise TypeError(
f"Adding '{type(other)}' type to the queue is not supported.",
)
def _get(self) -> Track:
return self._queue.pop(0)
@ -165,7 +181,7 @@ class Queue(Iterable[Track]):
return bool(self._loop_mode)
@property
def loop_mode(self) -> LoopMode:
def loop_mode(self) -> Optional[LoopMode]:
"""Returns the LoopMode enum set in the queue object"""
return self._loop_mode
@ -178,7 +194,7 @@ class Queue(Iterable[Track]):
"""Returns the queue as a List"""
return self._queue
def get(self):
def get(self) -> Track:
"""Return next immediately available item in queue if any.
Raises QueueEmpty if no items in queue.
"""
@ -239,7 +255,9 @@ class Queue(Iterable[Track]):
"""Put the given item into the back of the queue."""
if self.is_full:
if not self._overflow:
raise QueueFull(f"Queue max_size of {self.max_size} has been reached.")
raise QueueFull(
f"Queue max_size of {self.max_size} has been reached.",
)
self._drop()
@ -249,7 +267,9 @@ class Queue(Iterable[Track]):
"""Put the given item into the queue at the specified index."""
if self.is_full:
if not self._overflow:
raise QueueFull(f"Queue max_size of {self.max_size} has been reached.")
raise QueueFull(
f"Queue max_size of {self.max_size} has been reached.",
)
self._drop()
@ -275,7 +295,7 @@ class Queue(Iterable[Track]):
if (new_len + self.count) > self.max_size:
raise QueueFull(
f"Queue has {self.count}/{self.max_size} items, "
f"cannot add {new_len} more."
f"cannot add {new_len} more.",
)
for item in iterable:
@ -292,7 +312,7 @@ class Queue(Iterable[Track]):
"""Remove all items from the queue."""
self._queue.clear()
def set_loop_mode(self, mode: LoopMode):
def set_loop_mode(self, mode: LoopMode) -> None:
"""
Sets the loop mode of the queue.
Takes the LoopMode enum as an argument.
@ -307,7 +327,7 @@ class Queue(Iterable[Track]):
self._queue.insert(index, self._current_item)
self._current_item = self._queue[index]
def disable_loop(self):
def disable_loop(self) -> None:
"""
Disables loop mode if set.
Raises QueueException if loop mode is already None.
@ -321,17 +341,17 @@ class Queue(Iterable[Track]):
self._loop_mode = None
def shuffle(self):
def shuffle(self) -> None:
"""Shuffles the queue."""
return random.shuffle(self._queue)
def clear_track_filters(self):
def clear_track_filters(self) -> None:
"""Clears all filters applied to tracks"""
for track in self._queue:
track.filters = None
def jump(self, item: Track):
def jump(self, item: Track) -> None:
"""Removes all tracks before the."""
index = self.find_position(item)
new_queue = self._queue[index : self.size]
new_queue = self._queue[index: self.size]
self._queue = new_queue

View File

@ -1,11 +1,13 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .pool import Node
from .utils import RouteStats
from aiohttp import ClientSession
__all__ = ("RoutePlanner",)
class RoutePlanner:
@ -16,17 +18,16 @@ class RoutePlanner:
def __init__(self, node: Node) -> None:
self.node: Node = node
self.session: ClientSession = node._session
async def get_status(self) -> RouteStats:
"""Gets the status of the route planner API."""
data: dict = await self.node.send(method="GET", path="routeplanner/status")
return RouteStats(data)
async def free_address(self, ip: str):
async def free_address(self, ip: str) -> None:
"""Frees an address using the route planner API"""
await self.node.send(method="POST", path="routeplanner/free/address", data={"address": ip})
async def free_all_addresses(self):
async def free_all_addresses(self) -> None:
"""Frees all available addresses using the route planner api"""
await self.node.send(method="POST", path="routeplanner/free/address/all")

View File

@ -1,5 +1,4 @@
"""Spotify module for Pomice, made possible by cloudwithax 2023"""
from .client import Client
from .exceptions import *
from .objects import *
from .client import Client

View File

@ -2,19 +2,27 @@ from __future__ import annotations
import re
import time
from base64 import b64encode
from typing import Dict, List
from typing import Optional
from typing import Union
import aiohttp
import orjson as json
from base64 import b64encode
from typing import TYPE_CHECKING
from .exceptions import InvalidSpotifyURL, SpotifyRequestException
from .exceptions import InvalidSpotifyURL
from .exceptions import SpotifyRequestException
from .objects import *
__all__ = (
"Client",
)
GRANT_URL = "https://accounts.spotify.com/api/token"
REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}"
SPOTIFY_URL_REGEX = re.compile(
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)"
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
)
@ -24,17 +32,21 @@ class Client:
for any Spotify URL you throw at it.
"""
def __init__(self, client_id: str, client_secret: str) -> None:
self._client_id = client_id
self._client_secret = client_secret
def __init__(self, client_id: int, client_secret: str) -> None:
self._client_id: int = client_id
self._client_secret: str = client_secret
self.session: aiohttp.ClientSession = None
self.session: aiohttp.ClientSession = None # type: ignore
self._bearer_token: str = None
self._expiry = 0
self._auth_token = b64encode(f"{self._client_id}:{self._client_secret}".encode())
self._grant_headers = {"Authorization": f"Basic {self._auth_token.decode()}"}
self._bearer_headers = None
self._bearer_token: Optional[str] = None
self._expiry: float = 0.0
self._auth_token = b64encode(
f"{self._client_id}:{self._client_secret}".encode(),
)
self._grant_headers = {
"Authorization": f"Basic {self._auth_token.decode()}",
}
self._bearer_headers: Optional[Dict] = None
async def _fetch_bearer_token(self) -> None:
_data = {"grant_type": "client_credentials"}
@ -45,32 +57,34 @@ class Client:
async with self.session.post(GRANT_URL, data=_data, headers=self._grant_headers) as resp:
if resp.status != 200:
raise SpotifyRequestException(
f"Error fetching bearer token: {resp.status} {resp.reason}"
f"Error fetching bearer token: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
self._bearer_token = data["access_token"]
self._expiry = time.time() + (int(data["expires_in"]) - 10)
self._bearer_headers = {"Authorization": f"Bearer {self._bearer_token}"}
self._bearer_headers = {
"Authorization": f"Bearer {self._bearer_token}",
}
async def search(self, *, query: str):
async def search(self, *, query: str) -> Union[Track, Album, Artist, Playlist]:
if not self._bearer_token or time.time() >= self._expiry:
await self._fetch_bearer_token()
result = SPOTIFY_URL_REGEX.match(query)
spotify_type = result.group("type")
spotify_id = result.group("id")
if not result:
raise InvalidSpotifyURL("The Spotify link provided is not valid.")
spotify_type = result.group("type")
spotify_id = result.group("id")
request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id)
async with self.session.get(request_url, headers=self._bearer_headers) as resp:
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}"
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
@ -81,11 +95,11 @@ class Client:
return Album(data)
elif spotify_type == "artist":
async with self.session.get(
f"{request_url}/top-tracks?market=US", headers=self._bearer_headers
f"{request_url}/top-tracks?market=US", headers=self._bearer_headers,
) as resp:
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}"
f"Error while fetching results: {resp.status} {resp.reason}",
)
track_data: dict = await resp.json(loads=json.loads)
@ -100,7 +114,7 @@ class Client:
if not len(tracks):
raise SpotifyRequestException(
"This playlist is empty and therefore cannot be queued."
"This playlist is empty and therefore cannot be queued.",
)
next_page_url = data["tracks"]["next"]
@ -109,7 +123,7 @@ class Client:
async with self.session.get(next_page_url, headers=self._bearer_headers) as resp:
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}"
f"Error while fetching results: {resp.status} {resp.reason}",
)
next_data: dict = await resp.json(loads=json.loads)
@ -123,26 +137,30 @@ class Client:
return Playlist(data, tracks)
async def get_recommendations(self, *, query: str):
async def get_recommendations(self, *, query: str) -> List[Track]:
if not self._bearer_token or time.time() >= self._expiry:
await self._fetch_bearer_token()
result = SPOTIFY_URL_REGEX.match(query)
spotify_type = result.group("type")
spotify_id = result.group("id")
if not result:
raise InvalidSpotifyURL("The Spotify link provided is not valid.")
if not spotify_type == "track":
raise InvalidSpotifyURL("The provided query is not a Spotify track.")
spotify_type = result.group("type")
spotify_id = result.group("id")
request_url = REQUEST_URL.format(type="recommendation", id=f"?seed_tracks={spotify_id}")
if not spotify_type == "track":
raise InvalidSpotifyURL(
"The provided query is not a Spotify track.",
)
request_url = REQUEST_URL.format(
type="recommendation", id=f"?seed_tracks={spotify_id}",
)
async with self.session.get(request_url, headers=self._bearer_headers) as resp:
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}"
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
@ -154,4 +172,4 @@ class Client:
async def close(self) -> None:
if self.session:
await self.session.close()
self.session = None
self.session = None # type: ignore

View File

@ -1,3 +1,9 @@
__all__ = (
"SpotifyRequestException",
"InvalidSpotifyURL",
)
class SpotifyRequestException(Exception):
"""An error occurred when making a request to the Spotify API"""

View File

@ -1,29 +1,36 @@
from typing import List
from typing import Optional
__all__ = (
"Track",
"Playlist",
"Album",
"Artist",
)
class Track:
"""The base class for a Spotify Track"""
def __init__(self, data: dict, image=None) -> None:
def __init__(self, data: dict, image: Optional[str] = None) -> None:
self.name: str = data["name"]
self.artists: str = ", ".join(artist["name"] for artist in data["artists"])
self.artists: str = ", ".join(
artist["name"] for artist in data["artists"]
)
self.length: float = data["duration_ms"]
self.id: str = data["id"]
self.issrc: Optional[str] = None
if data.get("external_ids"):
self.isrc: str = data["external_ids"]["isrc"]
else:
self.isrc = None
self.isrc = data["external_ids"]["isrc"]
self.image: Optional[str] = image
if data.get("album") and data["album"].get("images"):
self.image: str = data["album"]["images"][0]["url"]
else:
self.image: str = image
self.image = data["album"]["images"][0]["url"]
if data["is_local"]:
self.uri = None
else:
self.uri: str = data["external_urls"]["spotify"]
self.uri: Optional[str] = None
if not data["is_local"]:
self.uri = data["external_urls"]["spotify"]
def __repr__(self) -> str:
return (
@ -42,7 +49,7 @@ class Playlist:
self.total_tracks: int = data["tracks"]["total"]
self.id: str = data["id"]
if data.get("images") and len(data["images"]):
self.image: str = data["images"][0]["url"]
self.image = data["images"][0]["url"]
else:
self.image = self.tracks[0].image
self.uri = data["external_urls"]["spotify"]
@ -59,9 +66,14 @@ class Album:
def __init__(self, data: dict) -> None:
self.name: str = data["name"]
self.artists: str = ", ".join(artist["name"] for artist in data["artists"])
self.artists: str = ", ".join(
artist["name"] for artist in data["artists"]
)
self.image: str = data["images"][0]["url"]
self.tracks = [Track(track, image=self.image) for track in data["tracks"]["items"]]
self.tracks = [
Track(track, image=self.image)
for track in data["tracks"]["items"]
]
self.total_tracks: int = data["total_tracks"]
self.id: str = data["id"]
self.uri: str = data["external_urls"]["spotify"]
@ -78,7 +90,8 @@ class Artist:
def __init__(self, data: dict, tracks: dict) -> None:
self.name: str = (
f"Top tracks for {data['name']}" # Setting that because its only playing top tracks
# Setting that because its only playing top tracks
f"Top tracks for {data['name']}"
)
self.genres: str = ", ".join(genre for genre in data["genres"])
self.followers: int = data["followers"]["total"]

View File

@ -1,11 +1,24 @@
import random
import time
import socket
from .enums import RouteStrategy, RouteIPType
from timeit import default_timer as timer
from itertools import zip_longest
import time
from datetime import datetime
from itertools import zip_longest
from timeit import default_timer as timer
from typing import Any, Dict
from typing import Callable
from typing import Iterable
from typing import Optional
from .enums import RouteIPType
from .enums import RouteStrategy
__all__ = (
"ExponentialBackoff",
"NodeStats",
"FailingIPBlock",
"RouteStats",
"Ping",
)
class ExponentialBackoff:
@ -51,7 +64,7 @@ class ExponentialBackoff:
self._exp = 0
self._exp = min(self._exp + 1, self._max)
return self._randfunc(0, self._base * 2**self._exp)
return self._randfunc(0, self._base * 2**self._exp) # type: ignore
class NodeStats:
@ -59,7 +72,6 @@ class NodeStats:
Gives critical information on the node, which is updated every minute.
"""
def __init__(self, data: dict) -> None:
__slots__ = (
"used",
"free",
@ -73,13 +85,15 @@ class NodeStats:
"uptime",
)
memory: dict = data.get("memory")
def __init__(self, data: Dict[str, Any]) -> None:
memory: dict = data.get("memory", {})
self.used = memory.get("used")
self.free = memory.get("free")
self.reservable = memory.get("reservable")
self.allocated = memory.get("allocated")
cpu: dict = data.get("cpu")
cpu: dict = data.get("cpu", {})
self.cpu_cores = cpu.get("cores")
self.cpu_system_load = cpu.get("systemLoad")
self.cpu_process_load = cpu.get("lavalinkLoad")
@ -99,11 +113,14 @@ class FailingIPBlock:
and the time they failed.
"""
def __init__(self, data: dict) -> None:
__slots__ = ("address", "failing_time")
def __init__(self, data: dict) -> None:
self.address = data.get("address")
self.failing_time = datetime.fromtimestamp(float(data.get("failingTimestamp")))
self.failing_time = datetime.fromtimestamp(
float(data.get("failingTimestamp", 0)),
)
def __repr__(self) -> str:
return f"<Pomice.FailingIPBlock address={self.address} failing_time={self.failing_time}>"
@ -115,17 +132,29 @@ class RouteStats:
Gives critical information about the route planner strategy on the node.
"""
def __init__(self, data: dict) -> None:
__slots__ = ("strategy", "ip_block_type", "ip_block_size", "failing_addresses")
__slots__ = (
"strategy",
"ip_block_type",
"ip_block_size",
"failing_addresses",
"block_index",
"address_index",
)
def __init__(self, data: Dict[str, Any]) -> None:
self.strategy = RouteStrategy(data.get("class"))
details: dict = data.get("details")
details: dict = data.get("details", {})
ip_block: dict = details.get("ipBlock")
ip_block: dict = details.get("ipBlock", {})
self.ip_block_type = RouteIPType(ip_block.get("type"))
self.ip_block_size = ip_block.get("size")
self.failing_addresses = [FailingIPBlock(data) for data in details.get("failingAddresses")]
self.failing_addresses = [
FailingIPBlock(
data,
) for data in details.get("failingAddresses", [])
]
self.block_index = details.get("blockIndex")
self.address_index = details.get("currentAddressIndex")
@ -136,7 +165,7 @@ class RouteStats:
class Ping:
# Thanks to https://github.com/zhengxiaowai/tcping for the nice ping impl
def __init__(self, host, port, timeout=5):
def __init__(self, host: str, port: int, timeout: int = 5) -> None:
self.timer = self.Timer()
self._successed = 0
@ -146,33 +175,33 @@ class Ping:
self._port = port
self._timeout = timeout
class Socket(object):
def __init__(self, family, type_, timeout):
class Socket:
def __init__(self, family: int, type_: int, timeout: Optional[float]) -> None:
s = socket.socket(family, type_)
s.settimeout(timeout)
self._s = s
def connect(self, host, port):
self._s.connect((host, int(port)))
def connect(self, host: str, port: int) -> None:
self._s.connect((host, port))
def shutdown(self):
def shutdown(self) -> None:
self._s.shutdown(socket.SHUT_RD)
def close(self):
def close(self) -> None:
self._s.close()
class Timer(object):
def __init__(self):
self._start = 0
self._stop = 0
class Timer:
def __init__(self) -> None:
self._start: float = 0.0
self._stop: float = 0.0
def start(self):
def start(self) -> None:
self._start = timer()
def stop(self):
def stop(self) -> None:
self._stop = timer()
def cost(self, funcs, args):
def cost(self, funcs: Iterable[Callable], args: Any) -> float:
self.start()
for func, arg in zip_longest(funcs, args):
if arg:
@ -183,13 +212,15 @@ class Ping:
self.stop()
return self._stop - self._start
def _create_socket(self, family, type_):
def _create_socket(self, family: int, type_: int) -> Socket:
return self.Socket(family, type_, self._timeout)
def get_ping(self):
def get_ping(self) -> float:
s = self._create_socket(socket.AF_INET, socket.SOCK_STREAM)
cost_time = self.timer.cost((s.connect, s.shutdown), ((self._host, self._port), None))
cost_time = self.timer.cost(
(s.connect, s.shutdown), ((self._host, self._port), None),
)
s_runtime = 1000 * (cost_time)
return s_runtime

View File

@ -7,3 +7,13 @@ build-backend = "setuptools.build_meta"
[tool.black]
line-length = 100
[tool.mypy]
mypy_path = "./"
files = ["pomice"]
disallow_untyped_defs = true
disallow_any_unimported = true
no_implicit_optional = true
check_untyped_defs = true
warn_unused_ignores = true
show_error_codes = true

View File

@ -1,10 +1,13 @@
import setuptools
import re
import setuptools
version = ""
requirements = ["discord.py>=2.0.0", "aiohttp>=3.7.4,<4", "orjson"]
with open("pomice/__init__.py") as f:
version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE).group(1)
version = re.search(
r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE,
).group(1)
if not version:
raise RuntimeError("version is not set")
@ -15,13 +18,13 @@ if version.endswith(("a", "b", "rc")):
import subprocess
p = subprocess.Popen(
["git", "rev-list", "--count", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
["git", "rev-list", "--count", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
)
out, err = p.communicate()
if out:
version += out.decode("utf-8").strip()
p = subprocess.Popen(
["git", "rev-parse", "--short", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
["git", "rev-parse", "--short", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
)
out, err = p.communicate()
if out: