start typehinting and correcting the library

This commit is contained in:
Clxud 2023-03-09 15:40:50 +00:00
parent 31d4e1aca2
commit c9a331b278
14 changed files with 151 additions and 150 deletions

2
.gitignore vendored
View File

@ -5,3 +5,5 @@ dist/
pomice.egg-info/
docs/_build/
build/
.gitpod.yml
.python-verson

View File

@ -1,10 +0,0 @@
# This configuration file was automatically generated by Gitpod.
# Please adjust to your needs (see https://www.gitpod.io/docs/introduction/learn-gitpod/gitpod-yaml)
# and commit this file to your remote git repository to share the goodness with others.
# Learn more from ready-to-use templates: https://www.gitpod.io/docs/introduction/getting-started/quickstart
tasks:
- init: pip install .

1
.python-version Normal file
View File

@ -0,0 +1 @@
3.10.9

View File

@ -1,15 +1,3 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import importlib
import inspect
import os
@ -20,21 +8,14 @@ sys.path.insert(0, os.path.abspath('..'))
# -- Project information -----------------------------------------------------
project = 'Pomice'
copyright = '2023, cloudwithax'
author = 'cloudwithax'
# The full version, including alpha/beta/rc tags
release = '2.1.1'
release = '2.2'
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
@ -60,25 +41,23 @@ myst_enable_extensions = [
myst_heading_anchors = 3
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# We need to include this because discord.py has special tags
# they inlcude within their docstrings that dont parse
# right within our docs
# -- Options for HTML output -------------------------------------------------
rst_prolog = """
.. |coro| replace:: This function is a |coroutine_link|_.
.. |maybecoro| replace:: This function *could be a* |coroutine_link|_.
.. |coroutine_link| replace:: *coroutine*
.. _coroutine_link: https://docs.python.org/3/library/asyncio-task.html#coroutine
"""
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'furo'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_title = "Pomice"
@ -103,6 +82,9 @@ html_theme_options: Dict[str, Any] = {
"source_directory": "docs/",
}
# Grab lines from source files and embed into the docs
# so theres a point of reference
def linkcode_resolve(domain, info):
if domain != 'py':
return None

View File

@ -2,6 +2,12 @@ import re
from enum import Enum
__all__ = (
'SearchType',
'TrackType',
'PlaylistType'
)
class SearchType(Enum):
"""

View File

@ -3,6 +3,7 @@ from discord import Client
from discord.ext import commands
from .pool import NodePool
from .objects import Track
from typing import TYPE_CHECKING, Union
@ -34,8 +35,8 @@ class TrackStartEvent(PomiceEvent):
name = "track_start"
def __init__(self, data: dict, player: Player):
self.player = player
self.track = self.player._current
self.player: Player = player
self.track: Track = self.player._current
# on_pomice_track_start(player, track)
self.handler_args = self.player, self.track
@ -51,8 +52,8 @@ class TrackEndEvent(PomiceEvent):
name = "track_end"
def __init__(self, data: dict, player: Player):
self.player = player
self.track = self.player._ending_track
self.player: Player = player
self.track: Track = self.player._ending_track
self.reason: str = data["reason"]
# on_pomice_track_end(player, track, reason)
@ -73,8 +74,8 @@ class TrackStuckEvent(PomiceEvent):
name = "track_stuck"
def __init__(self, data: dict, player: Player):
self.player = player
self.track = self.player._ending_track
self.player: Player = player
self.track: Track = self.player._ending_track
self.threshold: float = data["thresholdMs"]
# on_pomice_track_stuck(player, track, threshold)
@ -92,8 +93,8 @@ class TrackExceptionEvent(PomiceEvent):
name = "track_exception"
def __init__(self, data: dict, player: Player):
self.player = player
self.track = self.player._ending_track
self.player: Player = player
self.track: Track = self.player._ending_track
if data.get('error'):
# User is running Lavalink <= 3.3
self.exception: str = data["error"]
@ -110,7 +111,7 @@ class TrackExceptionEvent(PomiceEvent):
class WebSocketClosedPayload:
def __init__(self, data: dict):
self.guild = NodePool.get_node().bot.get_guild(int(data["guildId"]))
self.guild: Guild = NodePool.get_node().bot.get_guild(int(data["guildId"]))
self.code: int = data["code"]
self.reason: str = data["code"]
self.by_remote: bool = data["byRemote"]
@ -127,7 +128,7 @@ class WebSocketClosedEvent(PomiceEvent):
name = "websocket_closed"
def __init__(self, data: dict, _):
self.payload = WebSocketClosedPayload(data)
self.payload: WebSocketClosedPayload = WebSocketClosedPayload(data)
# on_pomice_websocket_closed(payload)
self.handler_args = self.payload,

View File

@ -13,7 +13,7 @@ class Filter:
This is necessary for the removal of filters.
"""
def __init__(self):
self.payload = None
self.payload: dict = None
self.tag: str = None
self.preload: bool = False
@ -132,12 +132,12 @@ class Timescale(Filter):
if rate < 0:
raise FilterInvalidArgument("Timescale rate must be more than 0.")
self.speed = speed
self.pitch = pitch
self.rate = rate
self.tag = tag
self.speed: float = speed
self.pitch: float = pitch
self.rate: float = rate
self.tag: str = tag
self.payload = {"timescale": {"speed": self.speed,
self.payload: dict = {"timescale": {"speed": self.speed,
"pitch": self.pitch,
"rate": self.rate}}
@ -181,13 +181,13 @@ class Karaoke(Filter):
):
super().__init__()
self.level = level
self.mono_level = mono_level
self.filter_band = filter_band
self.filter_width = filter_width
self.tag = tag
self.level: float = level
self.mono_level: float = mono_level
self.filter_band: float = filter_band
self.filter_width: float = filter_width
self.tag: str = tag
self.payload = {"karaoke": {"level": self.level,
self.payload: dict = {"karaoke": {"level": self.level,
"monoLevel": self.mono_level,
"filterBand": self.filter_band,
"filterWidth": self.filter_width}}
@ -220,11 +220,11 @@ class Tremolo(Filter):
raise FilterInvalidArgument(
"Tremolo depth must be between 0 and 1.")
self.frequency = frequency
self.depth = depth
self.tag = tag
self.frequency: float = frequency
self.depth: float = depth
self.tag: str = tag
self.payload = {"tremolo": {"frequency": self.frequency,
self.payload: dict = {"tremolo": {"frequency": self.frequency,
"depth": self.depth}}
def __repr__(self):
@ -252,11 +252,11 @@ class Vibrato(Filter):
raise FilterInvalidArgument(
"Vibrato depth must be between 0 and 1.")
self.frequency = frequency
self.depth = depth
self.tag = tag
self.frequency: float = frequency
self.depth: float = depth
self.tag: str = tag
self.payload = {"vibrato": {"frequency": self.frequency,
self.payload: dict = {"vibrato": {"frequency": self.frequency,
"depth": self.depth}}
def __repr__(self):
@ -271,9 +271,9 @@ class Rotation(Filter):
def __init__(self, *, tag: str, rotation_hertz: float = 5):
super().__init__()
self.rotation_hertz = rotation_hertz
self.tag = tag
self.payload = {"rotation": {"rotationHz": self.rotation_hertz}}
self.rotation_hertz: float = rotation_hertz
self.tag: str = tag
self.payload: dict = {"rotation": {"rotationHz": self.rotation_hertz}}
def __repr__(self) -> str:
return f"<Pomice.RotationFilter tag={self.tag} rotation_hertz={self.rotation_hertz}>"
@ -308,13 +308,13 @@ class ChannelMix(Filter):
raise ValueError(
"'right_to_left' value must be more than or equal to 0 or less than or equal to 1.")
self.left_to_left = left_to_left
self.left_to_right = left_to_right
self.right_to_left = right_to_left
self.right_to_right = right_to_right
self.tag = tag
self.left_to_left: float = left_to_left
self.left_to_right: float = left_to_right
self.right_to_left: float = right_to_left
self.right_to_right: float = right_to_right
self.tag: str = tag
self.payload = {"channelMix": {"leftToLeft": self.left_to_left,
self.payload: dict = {"channelMix": {"leftToLeft": self.left_to_left,
"leftToRight": self.left_to_right,
"rightToLeft": self.right_to_left,
"rightToRight": self.right_to_right}
@ -347,17 +347,17 @@ class Distortion(Filter):
):
super().__init__()
self.sin_offset = sin_offset
self.sin_scale = sin_scale
self.cos_offset = cos_offset
self.cos_scale = cos_scale
self.tan_offset = tan_offset
self.tan_scale = tan_scale
self.offset = offset
self.scale = scale
self.tag = tag
self.sin_offset: float = sin_offset
self.sin_scale: float = sin_scale
self.cos_offset: float = cos_offset
self.cos_scale: float = cos_scale
self.tan_offset: float = tan_offset
self.tan_scale: float = tan_scale
self.offset: float = offset
self.scale: float = scale
self.tag: str = tag
self.payload = {"distortion": {
self.payload: dict = {"distortion": {
"sinOffset": self.sin_offset,
"sinScale": self.sin_scale,
"cosOffset": self.cos_offset,
@ -383,9 +383,9 @@ class LowPass(Filter):
def __init__(self, *, tag: str, smoothing: float = 20):
super().__init__()
self.smoothing = smoothing
self.tag = tag
self.payload = {"lowPass": {"smoothing": self.smoothing}}
self.smoothing: float = smoothing
self.tag: str = tag
self.payload: dict = {"lowPass": {"smoothing": self.smoothing}}
def __repr__(self) -> str:
return f"<Pomice.LowPass tag={self.tag} smoothing={self.smoothing}>"

View File

@ -26,8 +26,8 @@ class Track:
timestamp: Optional[float] = None,
requester: Optional[Union[Member, User]] = None,
):
self.track_id = track_id
self.info = info
self.track_id: str = track_id
self.info: dict = info
self.track_type: TrackType = track_type
self.filters: Optional[List[Filter]] = filters
self.timestamp: Optional[float] = timestamp
@ -36,35 +36,35 @@ class Track:
self.original: Optional[Track] = None
else:
self.original = self
self._search_type = search_type
self._search_type: SearchType = search_type
self.playlist: Playlist = None
self.title = info.get("title")
self.author = info.get("author")
self.uri = info.get("uri")
self.identifier = info.get("identifier")
self.isrc = info.get("isrc")
self.title: str = info.get("title")
self.author: str = info.get("author")
self.uri: str = info.get("uri")
self.identifier: str = info.get("identifier")
self.isrc: str = info.get("isrc")
if self.uri:
if info.get("thumbnail"):
self.thumbnail = info.get("thumbnail")
self.thumbnail: str = info.get("thumbnail")
elif self.track_type == TrackType.SOUNDCLOUD:
# ok so theres no feasible way of getting a Soundcloud image URL
# so we're just gonna leave it blank for brevity
self.thumbnail = None
else:
self.thumbnail = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg"
self.thumbnail: str = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg"
self.length = info.get("length")
self.ctx = ctx
self.length: int = info.get("length")
self.ctx: commands.Context = ctx
if requester:
self.requester = requester
self.requester: Optional[Union[Member, User]] = requester
else:
self.requester = self.ctx.author if ctx else None
self.is_stream = info.get("isStream")
self.is_seekable = info.get("isSeekable")
self.position = info.get("position")
self.requester: Optional[Union[Member, User]] = self.ctx.author if ctx else None
self.is_stream: bool = info.get("isStream")
self.is_seekable: bool = info.get("isSeekable")
self.position: int = info.get("position")
def __eq__(self, other):
if not isinstance(other, Track):
@ -97,13 +97,13 @@ class Playlist:
thumbnail: Optional[str] = None,
uri: Optional[str] = None
):
self.playlist_info = playlist_info
self.playlist_info: dict = playlist_info
self.tracks: List[Track] = tracks
self.name = playlist_info.get("name")
self.playlist_type = playlist_type
self.name: str = playlist_info.get("name")
self.playlist_type: PlaylistType = playlist_type
self._thumbnail = thumbnail
self._uri = uri
self._thumbnail: str = thumbnail
self._uri: str = uri
for track in self.tracks:
track.playlist = self
@ -111,9 +111,9 @@ class Playlist:
if (index := playlist_info.get("selectedTrack")) == -1:
self.selected_track = None
else:
self.selected_track = self.tracks[index]
self.selected_track: Track = self.tracks[index]
self.track_count = len(self.tracks)
self.track_count: int = len(self.tracks)
def __str__(self):
return self.name

View File

@ -25,6 +25,8 @@ from .pool import Node, NodePool
class Filters:
"""Helper class for filters"""
__slots__ = ('_filters')
def __init__(self):
self._filters: List[Filter] = []
@ -97,6 +99,24 @@ class Player(VoiceProtocol):
```
"""
__slots__ = (
'client',
'_bot',
'channel',
'_guild',
'_node',
'_current',
'_filters',
'_volume',
'_paused',
'_is_connected',
'_position',
'_last_position',
'_last_update',
'_ending_track',
'_player_endpoint_uri'
)
def __call__(self, client: Client, channel: VoiceChannel):
self.client: Client = client
self.channel: VoiceChannel = channel
@ -117,7 +137,7 @@ class Player(VoiceProtocol):
self._guild: Guild = channel.guild if channel else None
self._node: Node = node if node else NodePool.get_node()
self._current: Track = None
self._current: Optional[Track] = None
self._filters: Filters = Filters()
self._volume: int = 100
self._paused: bool = False
@ -130,9 +150,9 @@ class Player(VoiceProtocol):
self._voice_state: dict = {}
self._player_endpoint_uri = f'sessions/{self._node._session_id}/players'
self._player_endpoint_uri: str = f'sessions/{self._node._session_id}/players'
def __repr__(self):
def __repr__(self) -> str:
return (
f"<Pomice.player bot={self.bot} guildId={self.guild.id} "
f"is_connected={self.is_connected} is_playing={self.is_playing}>"

View File

@ -65,27 +65,26 @@ class Node:
fallback: bool = False
):
self._bot = bot
self._host = host
self._port = port
self._pool = pool
self._password = password
self._identifier = identifier
self._heartbeat = heartbeat
self._secure = secure
self.fallback = fallback
self._bot: Union[Client, commands.Bot] = bot
self._host: str = host
self._port: int = port
self._pool: NodePool = pool
self._password: str = password
self._identifier: str = identifier
self._heartbeat: int = heartbeat
self._secure: bool = secure
self.fallback: bool = fallback
self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}/v3/websocket"
self._rest_uri = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}/v3/websocket"
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
self._session = session or aiohttp.ClientSession()
self._session: ClientSession = session or aiohttp.ClientSession()
self._websocket: aiohttp.ClientWebSocketResponse = None
self._task: asyncio.Task = None
self._session_id: str = None
self._metadata = None
self._available = None
self._available: bool = False
self._version: str = None
self._route_planner = RoutePlanner(self)
@ -97,13 +96,13 @@ class Node:
self._players: Dict[int, Player] = {}
self._spotify_client_id = spotify_client_id
self._spotify_client_secret = spotify_client_secret
self._spotify_client_id: str = spotify_client_id
self._spotify_client_secret: str = spotify_client_secret
self._apple_music_client = None
self._apple_music_client: Optional[applemusic.Client] = None
if self._spotify_client_id and self._spotify_client_secret:
self._spotify_client = spotify.Client(
self._spotify_client: spotify.Client = spotify.Client(
self, self._spotify_client_id, self._spotify_client_secret
)

0
pomice/py.typed Normal file
View File

View File

@ -175,13 +175,12 @@ class Queue(Iterable[Track]):
return len(self._queue)
def get_queue(self) -> List:
"""Returns the queue as a List"""
return self._queue
def get(self):
def get(self) -> Track:
"""Return next immediately available item in queue if any.
Raises QueueEmpty if no items in queue.
"""
@ -297,7 +296,7 @@ class Queue(Iterable[Track]):
"""Remove all items from the queue."""
self._queue.clear()
def set_loop_mode(self, mode: LoopMode):
def set_loop_mode(self, mode: LoopMode) -> None:
"""
Sets the loop mode of the queue.
Takes the LoopMode enum as an argument.
@ -313,7 +312,7 @@ class Queue(Iterable[Track]):
self._current_item = self._queue[index]
def disable_loop(self):
def disable_loop(self) -> None:
"""
Disables loop mode if set.
Raises QueueException if loop mode is already None.
@ -328,17 +327,17 @@ class Queue(Iterable[Track]):
self._loop_mode = None
def shuffle(self):
def shuffle(self) -> None:
"""Shuffles the queue."""
return random.shuffle(self._queue)
def clear_track_filters(self):
def clear_track_filters(self) -> None:
"""Clears all filters applied to tracks"""
for track in self._queue:
track.filters = None
def jump(self, item: Track):
"""Returns a new queue with the specified track at the beginning."""
def jump(self, item: Track) -> None:
"""Mutates the queue so that all tracks before the specified track are removed."""
index = self.find_position(item)
new_queue = self._queue[index:self.size]
self._queue = new_queue

View File

@ -4,6 +4,7 @@ if TYPE_CHECKING:
from .pool import Node
from .utils import RouteStats
from aiohttp import ClientSession
class RoutePlanner:
"""
@ -12,15 +13,14 @@ class RoutePlanner:
"""
def __init__(self, node: Node) -> None:
self.node = node
self.session = node._session
self.node: Node = node
self.session: ClientSession = node._session
async def get_status(self):
async def get_status(self) -> RouteStats:
"""Gets the status of the route planner API."""
data: dict = await self.node.send(method="GET", path="routeplanner/status")
return RouteStats(data)
async def free_address(self, ip: str):
"""Frees an address using the route planner API"""
await self.node.send(method="POST", path="routeplanner/free/address", data={"address": ip})

View File

@ -41,6 +41,7 @@ setuptools.setup(
description="The modern Lavalink wrapper designed for Discord.py",
long_description=readme,
long_description_content_type="text/markdown",
package_data={"pomice": ["py.typed"]},
include_package_data=True,
install_requires=requirements,
extra_require=None,