add slots to all classes that require them

This commit is contained in:
cloudwithax 2023-03-10 21:50:23 -05:00
parent 9dc8a9098e
commit 9e0a5e0ad0
10 changed files with 255 additions and 136 deletions

View File

@ -8,7 +8,6 @@ sys.path.insert(0, os.path.abspath('..'))
project = 'Pomice'
copyright = '2023, cloudwithax'
author = 'cloudwithax'
@ -86,30 +85,36 @@ html_theme_options: Dict[str, Any] = {
# so theres a point of reference
def linkcode_resolve(domain, info):
if domain != 'py':
return None
if not info['module']:
return None
mod = importlib.import_module(info["module"])
if "." in info["fullname"]:
objname, attrname = info["fullname"].split(".")
obj = getattr(mod, objname)
try:
obj = getattr(obj, attrname)
except AttributeError:
return None
else:
obj = getattr(mod, info["fullname"])
# i absolutely MUST add this here or else
# the docs will not build. fuck sphinx
try:
file = inspect.getsourcefile(obj)
lines = inspect.getsourcelines(obj)
except TypeError:
# e.g. object is a typing.Union
return None
file = os.path.relpath(file, os.path.abspath(".."))
start, end = lines[1], lines[1] + len(lines[0]) - 1
if domain != 'py':
return None
if not info['module']:
return None
mod = importlib.import_module(info["module"])
if "." in info["fullname"]:
objname, attrname = info["fullname"].split(".")
obj = getattr(mod, objname)
try:
obj = getattr(obj, attrname)
except AttributeError:
return None
else:
obj = getattr(mod, info["fullname"])
try:
file = inspect.getsourcefile(obj)
lines = inspect.getsourcelines(obj)
except TypeError:
# e.g. object is a typing.Union
return None
file = os.path.relpath(file, os.path.abspath(".."))
start, end = lines[1], lines[1] + len(lines[0]) - 1
return f"https://github.com/cloudwithax/pomice/blob/main/{file}#L{start}-L{end}"
except:
pass
return f"https://github.com/cloudwithax/pomice/blob/main/{file}#L{start}-L{end}"

View File

@ -2,17 +2,6 @@ import re
from enum import Enum
__all__ = (
'SearchType',
'TrackType',
'PlaylistType',
'NodeAlgorithm',
'LoopMode',
'RouteStrategy',
'RouteIPType',
'URLRegex'
)
class SearchType(Enum):
"""

View File

@ -10,17 +10,6 @@ from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from .player import Player
__all__ = (
'PomiceEvent',
'TrackStartEvent',
'TrackEndEvent',
'TrackStuckEvent',
'TrackExceptionEvent',
'WebSocketClosedEvent',
'WebSocketOpenEvent'
)
class PomiceEvent:
"""The base class for all events dispatched by a node.
@ -45,6 +34,12 @@ class TrackStartEvent(PomiceEvent):
name = "track_start"
def __init__(self, data: dict, player: Player):
__slots__ = (
"player",
"track"
)
self.player: Player = player
self.track: Track = self.player._current
@ -62,6 +57,13 @@ class TrackEndEvent(PomiceEvent):
name = "track_end"
def __init__(self, data: dict, player: Player):
__slots__ = (
"player",
"track",
"reason"
)
self.player: Player = player
self.track: Track = self.player._ending_track
self.reason: str = data["reason"]
@ -84,6 +86,13 @@ class TrackStuckEvent(PomiceEvent):
name = "track_stuck"
def __init__(self, data: dict, player: Player):
__slots__ = (
"player",
"track",
"threshold"
)
self.player: Player = player
self.track: Track = self.player._ending_track
self.threshold: float = data["thresholdMs"]
@ -103,6 +112,13 @@ class TrackExceptionEvent(PomiceEvent):
name = "track_exception"
def __init__(self, data: dict, player: Player):
__slots__ = (
"player",
"track",
"exception"
)
self.player: Player = player
self.track: Track = self.player._ending_track
if data.get('error'):
@ -121,6 +137,14 @@ class TrackExceptionEvent(PomiceEvent):
class WebSocketClosedPayload:
def __init__(self, data: dict):
__slots__ = (
"guild",
"code",
"reason",
"by_remote"
)
self.guild: Guild = NodePool.get_node().bot.get_guild(int(data["guildId"]))
self.code: int = data["code"]
self.reason: str = data["code"]
@ -154,6 +178,12 @@ class WebSocketOpenEvent(PomiceEvent):
name = "websocket_open"
def __init__(self, data: dict, _):
__slots__ = (
"target",
"ssrc"
)
self.target: str = data["target"]
self.ssrc: int = data["ssrc"]

View File

@ -1,25 +1,3 @@
__all__ = (
'PomiceException',
'NodeException',
'NodeCreationError',
'NodeConnectionFailure',
'NodeConnectionClosed',
'NodeRestException',
'NodeNotAvailable',
'NoNodesAvailable',
'TrackInvalidPosition',
'TrackLoadError',
'FilterInvalidArgument',
'FilterTagInvalid',
'FilterTagAlreadyInUse',
'InvalidSpotifyClientAuthorization',
'AppleMusicNotEnabled',
'QueueException',
'QueueFull',
'QueueEmpty',
'LavalinkVersionIncompatible'
)
class PomiceException(Exception):
"""Base of all Pomice exceptions."""

View File

@ -1,19 +1,6 @@
import collections
from .exceptions import FilterInvalidArgument
__all__ = (
'Filter',
'Equalizer',
'Timescale',
'Karaoke',
'Tremolo',
'Vibrato',
'Rotation',
'ChannelMix',
'Distortion',
'LowPass'
)
class Filter:
"""
The base class for all filters.
@ -24,9 +11,15 @@ class Filter:
You must specify a tag for each filter you put on.
This is necessary for the removal of filters.
"""
def __init__(self):
def __init__(self, *, tag: str):
__slots__ = (
"payload",
"tag",
"preload"
)
self.payload: dict = None
self.tag: str = None
self.tag: str = tag
self.preload: bool = False
def set_preload(self) -> bool:
@ -44,13 +37,17 @@ class Equalizer(Filter):
"""
def __init__(self, *, tag: str, levels: list):
super().__init__()
super().__init__(tag=tag)
__slots__ = (
"eq",
"raw",
)
self.eq = self._factory(levels)
self.raw = levels
self.payload = {"equalizer": self.eq}
self.tag = tag
def _factory(self, levels: list):
_dict = collections.defaultdict(int)
@ -135,7 +132,13 @@ class Timescale(Filter):
pitch: float = 1.0,
rate: float = 1.0
):
super().__init__()
super().__init__(tag=tag)
__slots__ = (
"speed",
"pitch",
"rate"
)
if speed < 0:
raise FilterInvalidArgument("Timescale speed must be more than 0.")
@ -147,7 +150,6 @@ class Timescale(Filter):
self.speed: float = speed
self.pitch: float = pitch
self.rate: float = rate
self.tag: str = tag
self.payload: dict = {"timescale": {"speed": self.speed,
"pitch": self.pitch,
@ -191,13 +193,19 @@ class Karaoke(Filter):
filter_band: float = 220.0,
filter_width: float = 100.0
):
super().__init__()
super().__init__(tag=tag)
__slots__ = (
"level",
"mono_level",
"filter_band",
"filter_width"
)
self.level: float = level
self.mono_level: float = mono_level
self.filter_band: float = filter_band
self.filter_width: float = filter_width
self.tag: str = tag
self.payload: dict = {"karaoke": {"level": self.level,
"monoLevel": self.mono_level,
@ -223,7 +231,12 @@ class Tremolo(Filter):
frequency: float = 2.0,
depth: float = 0.5
):
super().__init__()
super().__init__(tag=tag)
__slots__ = (
"frequency",
"depth"
)
if frequency < 0:
raise FilterInvalidArgument(
@ -234,7 +247,6 @@ class Tremolo(Filter):
self.frequency: float = frequency
self.depth: float = depth
self.tag: str = tag
self.payload: dict = {"tremolo": {"frequency": self.frequency,
"depth": self.depth}}
@ -255,8 +267,13 @@ class Vibrato(Filter):
frequency: float = 2.0,
depth: float = 0.5
):
super().__init__(tag=tag)
__slots__ = (
"frequency",
"depth"
)
super().__init__()
if frequency < 0 or frequency > 14:
raise FilterInvalidArgument(
"Vibrato frequency must be between 0 and 14.")
@ -266,7 +283,6 @@ class Vibrato(Filter):
self.frequency: float = frequency
self.depth: float = depth
self.tag: str = tag
self.payload: dict = {"vibrato": {"frequency": self.frequency,
"depth": self.depth}}
@ -281,10 +297,11 @@ class Rotation(Filter):
"""
def __init__(self, *, tag: str, rotation_hertz: float = 5):
super().__init__()
super().__init__(tag=tag)
__slots__ = ("rotation_hertz")
self.rotation_hertz: float = rotation_hertz
self.tag: str = tag
self.payload: dict = {"rotation": {"rotationHz": self.rotation_hertz}}
def __repr__(self) -> str:
@ -305,7 +322,14 @@ class ChannelMix(Filter):
left_to_right: float = 0,
right_to_left: float = 0
):
super().__init__()
super().__init__(tag=tag)
__slots__ = (
"left_to_left",
"right_to_right",
"left_to_right",
"right_to_left"
)
if 0 > left_to_left > 1:
raise ValueError(
@ -324,7 +348,6 @@ class ChannelMix(Filter):
self.left_to_right: float = left_to_right
self.right_to_left: float = right_to_left
self.right_to_right: float = right_to_right
self.tag: str = tag
self.payload: dict = {"channelMix": {"leftToLeft": self.left_to_left,
"leftToRight": self.left_to_right,
@ -357,7 +380,18 @@ class Distortion(Filter):
offset: float = 0,
scale: float = 1
):
super().__init__()
super().__init__(tag=tag)
__slots__ = (
"sin_offset",
"sin_scale",
"cos_offset",
"cos_scale",
"tan_offset",
"tan_scale"
"offset",
"scale"
)
self.sin_offset: float = sin_offset
self.sin_scale: float = sin_scale
@ -367,7 +401,6 @@ class Distortion(Filter):
self.tan_scale: float = tan_scale
self.offset: float = offset
self.scale: float = scale
self.tag: str = tag
self.payload: dict = {"distortion": {
"sinOffset": self.sin_offset,
@ -393,10 +426,11 @@ class LowPass(Filter):
You can also do this with the Equalizer filter, but this is an easier way to do it.
"""
def __init__(self, *, tag: str, smoothing: float = 20):
super().__init__()
super().__init__(tag=tag)
__slots__ = ('smoothing')
self.smoothing: float = smoothing
self.tag: str = tag
self.payload: dict = {"lowPass": {"smoothing": self.smoothing}}
def __repr__(self) -> str:

View File

@ -7,8 +7,6 @@ from discord.ext import commands
from .enums import SearchType, TrackType, PlaylistType
from .filters import Filter
__all__ = ('Track', 'Playlist')
class Track:
"""The base track object. Returns critical track information needed for parsing by Lavalink.
@ -27,6 +25,29 @@ class Track:
timestamp: Optional[float] = None,
requester: Optional[Union[Member, User]] = None,
):
__slots__ = (
"track_id",
"info",
"track_type",
"filters",
"timestamp",
"original",
"_search_type",
"playlist",
"title",
"author",
"uri",
"identifier",
"isrc",
"thumbnail",
"length",
"ctx",
"requester",
"is_stream",
"is_seekable",
"position"
)
self.track_id: str = track_id
self.info: dict = info
self.track_type: TrackType = track_type
@ -98,6 +119,18 @@ class Playlist:
thumbnail: Optional[str] = None,
uri: Optional[str] = None
):
__slots__ = (
"playlist_info",
"tracks",
"name",
"playlist_type",
"_thumbnail",
"_uri",
"selected_track",
"track_count"
)
self.playlist_info: dict = playlist_info
self.tracks: List[Track] = tracks
self.name: str = playlist_info.get("name")

View File

@ -23,7 +23,6 @@ from .filters import Filter
from .objects import Track
from .pool import Node, NodePool
__all__ = ('Filters', 'Player')
class Filters:
"""Helper class for filters"""
@ -134,10 +133,10 @@ class Player(VoiceProtocol):
node: Node = None
):
self.client: Optional[Client] = client
self._bot: Union[Client, commands.Bot] = client
self.channel: Optional[VoiceChannel] = channel
self._guild: Guild = channel.guild if channel else None
self._bot: Union[Client, commands.Bot] = client
self._guild: Guild = channel.guild if channel else None
self._node: Node = node if node else NodePool.get_node()
self._current: Optional[Track] = None
self._filters: Filters = Filters()

View File

@ -37,8 +37,6 @@ from .routeplanner import RoutePlanner
if TYPE_CHECKING:
from .player import Player
__all__ = ('Node', 'NodePool')
class Node:
"""The base class for a node.
This node object represents a Lavalink node.
@ -65,6 +63,33 @@ class Node:
fallback: bool = False
):
__slots__ = (
"_bot",
"_host",
"_port",
"_pool",
"_password",
"_identifier",
"_heartbeat",
"_secure",
"_fallback",
"_websocket_uri",
"_rest_uri",
"_session",
"_websocket",
"_task",
"_loop",
"_session_id",
"_available",
"_version",
"_headers",
"_players",
"_spotify_client_id",
"_spotify_client_secret",
"_spotify_client",
"_apple_music_client"
)
self._bot: Union[Client, commands.Bot] = bot
self._host: str = host
self._port: int = port
@ -73,9 +98,7 @@ class Node:
self._identifier: str = identifier
self._heartbeat: int = heartbeat
self._secure: bool = secure
self.fallback: bool = fallback
self._fallback: bool = fallback
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"

View File

@ -22,6 +22,15 @@ class Queue(Iterable[Track]):
*,
overflow: bool = True,
):
__slots__ = (
"max_size",
"_queue",
"_overflow",
"_loop_mode",
"_current_item"
)
self.max_size: Optional[int] = max_size
self._queue: List[Track] = [] # type: ignore
self._overflow: bool = overflow
@ -175,12 +184,13 @@ class Queue(Iterable[Track]):
return len(self._queue)
def get_queue(self) -> List:
"""Returns the queue as a List"""
return self._queue
def get(self) -> Track:
def get(self):
"""Return next immediately available item in queue if any.
Raises QueueEmpty if no items in queue.
"""
@ -296,7 +306,7 @@ class Queue(Iterable[Track]):
"""Remove all items from the queue."""
self._queue.clear()
def set_loop_mode(self, mode: LoopMode) -> None:
def set_loop_mode(self, mode: LoopMode):
"""
Sets the loop mode of the queue.
Takes the LoopMode enum as an argument.
@ -312,7 +322,7 @@ class Queue(Iterable[Track]):
self._current_item = self._queue[index]
def disable_loop(self) -> None:
def disable_loop(self):
"""
Disables loop mode if set.
Raises QueueException if loop mode is already None.
@ -327,17 +337,17 @@ class Queue(Iterable[Track]):
self._loop_mode = None
def shuffle(self) -> None:
def shuffle(self):
"""Shuffles the queue."""
return random.shuffle(self._queue)
def clear_track_filters(self) -> None:
def clear_track_filters(self):
"""Clears all filters applied to tracks"""
for track in self._queue:
track.filters = None
def jump(self, item: Track) -> None:
"""Mutates the queue so that all tracks before the specified track are removed."""
def jump(self, item: Track):
"""Removes all tracks before the."""
index = self.find_position(item)
new_queue = self._queue[index:self.size]
self._queue = new_queue

View File

@ -8,15 +8,6 @@ from itertools import zip_longest
from datetime import datetime
__all__ = (
'ExponentialBackoff',
'NodeStats',
'FailingIPBlock',
'RouteStats',
'Ping'
)
class ExponentialBackoff:
"""
The MIT License (MIT)
@ -72,6 +63,19 @@ class NodeStats:
def __init__(self, data: dict) -> None:
__slots__ = (
"used",
"free",
"reservable",
"allocated",
"cpu_cores",
"cpu_system_load",
"cpu_process_load",
"players_active",
"players_total",
"uptime"
)
memory: dict = data.get("memory")
self.used = memory.get("used")
self.free = memory.get("free")
@ -97,6 +101,12 @@ class FailingIPBlock:
and the time they failed.
"""
def __init__(self, data: dict) -> None:
__slots__ = (
"address",
"failing_time"
)
self.address = data.get("address")
self.failing_time = datetime.fromtimestamp(float(data.get("failingTimestamp")))
@ -111,6 +121,14 @@ class RouteStats:
"""
def __init__(self, data: dict) -> None:
__slots__ = (
"strategy",
"ip_block_type",
"ip_block_size",
"failing_addresses"
)
self.strategy = RouteStrategy(data.get("class"))
details: dict = data.get("details")