add slots to all classes that require them

This commit is contained in:
cloudwithax 2023-03-10 21:50:23 -05:00
parent 9dc8a9098e
commit 9e0a5e0ad0
10 changed files with 255 additions and 136 deletions

View File

@ -8,7 +8,6 @@ sys.path.insert(0, os.path.abspath('..'))
project = 'Pomice' project = 'Pomice'
copyright = '2023, cloudwithax' copyright = '2023, cloudwithax'
author = 'cloudwithax' author = 'cloudwithax'
@ -86,6 +85,9 @@ html_theme_options: Dict[str, Any] = {
# so theres a point of reference # so theres a point of reference
def linkcode_resolve(domain, info): def linkcode_resolve(domain, info):
# i absolutely MUST add this here or else
# the docs will not build. fuck sphinx
try:
if domain != 'py': if domain != 'py':
return None return None
if not info['module']: if not info['module']:
@ -113,3 +115,6 @@ def linkcode_resolve(domain, info):
start, end = lines[1], lines[1] + len(lines[0]) - 1 start, end = lines[1], lines[1] + len(lines[0]) - 1
return f"https://github.com/cloudwithax/pomice/blob/main/{file}#L{start}-L{end}" return f"https://github.com/cloudwithax/pomice/blob/main/{file}#L{start}-L{end}"
except:
pass

View File

@ -2,17 +2,6 @@ import re
from enum import Enum from enum import Enum
__all__ = (
'SearchType',
'TrackType',
'PlaylistType',
'NodeAlgorithm',
'LoopMode',
'RouteStrategy',
'RouteIPType',
'URLRegex'
)
class SearchType(Enum): class SearchType(Enum):
""" """

View File

@ -10,17 +10,6 @@ from typing import TYPE_CHECKING, Union
if TYPE_CHECKING: if TYPE_CHECKING:
from .player import Player from .player import Player
__all__ = (
'PomiceEvent',
'TrackStartEvent',
'TrackEndEvent',
'TrackStuckEvent',
'TrackExceptionEvent',
'WebSocketClosedEvent',
'WebSocketOpenEvent'
)
class PomiceEvent: class PomiceEvent:
"""The base class for all events dispatched by a node. """The base class for all events dispatched by a node.
@ -45,6 +34,12 @@ class TrackStartEvent(PomiceEvent):
name = "track_start" name = "track_start"
def __init__(self, data: dict, player: Player): def __init__(self, data: dict, player: Player):
__slots__ = (
"player",
"track"
)
self.player: Player = player self.player: Player = player
self.track: Track = self.player._current self.track: Track = self.player._current
@ -62,6 +57,13 @@ class TrackEndEvent(PomiceEvent):
name = "track_end" name = "track_end"
def __init__(self, data: dict, player: Player): def __init__(self, data: dict, player: Player):
__slots__ = (
"player",
"track",
"reason"
)
self.player: Player = player self.player: Player = player
self.track: Track = self.player._ending_track self.track: Track = self.player._ending_track
self.reason: str = data["reason"] self.reason: str = data["reason"]
@ -84,6 +86,13 @@ class TrackStuckEvent(PomiceEvent):
name = "track_stuck" name = "track_stuck"
def __init__(self, data: dict, player: Player): def __init__(self, data: dict, player: Player):
__slots__ = (
"player",
"track",
"threshold"
)
self.player: Player = player self.player: Player = player
self.track: Track = self.player._ending_track self.track: Track = self.player._ending_track
self.threshold: float = data["thresholdMs"] self.threshold: float = data["thresholdMs"]
@ -103,6 +112,13 @@ class TrackExceptionEvent(PomiceEvent):
name = "track_exception" name = "track_exception"
def __init__(self, data: dict, player: Player): def __init__(self, data: dict, player: Player):
__slots__ = (
"player",
"track",
"exception"
)
self.player: Player = player self.player: Player = player
self.track: Track = self.player._ending_track self.track: Track = self.player._ending_track
if data.get('error'): if data.get('error'):
@ -121,6 +137,14 @@ class TrackExceptionEvent(PomiceEvent):
class WebSocketClosedPayload: class WebSocketClosedPayload:
def __init__(self, data: dict): def __init__(self, data: dict):
__slots__ = (
"guild",
"code",
"reason",
"by_remote"
)
self.guild: Guild = NodePool.get_node().bot.get_guild(int(data["guildId"])) self.guild: Guild = NodePool.get_node().bot.get_guild(int(data["guildId"]))
self.code: int = data["code"] self.code: int = data["code"]
self.reason: str = data["code"] self.reason: str = data["code"]
@ -154,6 +178,12 @@ class WebSocketOpenEvent(PomiceEvent):
name = "websocket_open" name = "websocket_open"
def __init__(self, data: dict, _): def __init__(self, data: dict, _):
__slots__ = (
"target",
"ssrc"
)
self.target: str = data["target"] self.target: str = data["target"]
self.ssrc: int = data["ssrc"] self.ssrc: int = data["ssrc"]

View File

@ -1,25 +1,3 @@
__all__ = (
'PomiceException',
'NodeException',
'NodeCreationError',
'NodeConnectionFailure',
'NodeConnectionClosed',
'NodeRestException',
'NodeNotAvailable',
'NoNodesAvailable',
'TrackInvalidPosition',
'TrackLoadError',
'FilterInvalidArgument',
'FilterTagInvalid',
'FilterTagAlreadyInUse',
'InvalidSpotifyClientAuthorization',
'AppleMusicNotEnabled',
'QueueException',
'QueueFull',
'QueueEmpty',
'LavalinkVersionIncompatible'
)
class PomiceException(Exception): class PomiceException(Exception):
"""Base of all Pomice exceptions.""" """Base of all Pomice exceptions."""

View File

@ -1,19 +1,6 @@
import collections import collections
from .exceptions import FilterInvalidArgument from .exceptions import FilterInvalidArgument
__all__ = (
'Filter',
'Equalizer',
'Timescale',
'Karaoke',
'Tremolo',
'Vibrato',
'Rotation',
'ChannelMix',
'Distortion',
'LowPass'
)
class Filter: class Filter:
""" """
The base class for all filters. The base class for all filters.
@ -24,9 +11,15 @@ class Filter:
You must specify a tag for each filter you put on. You must specify a tag for each filter you put on.
This is necessary for the removal of filters. This is necessary for the removal of filters.
""" """
def __init__(self): def __init__(self, *, tag: str):
__slots__ = (
"payload",
"tag",
"preload"
)
self.payload: dict = None self.payload: dict = None
self.tag: str = None self.tag: str = tag
self.preload: bool = False self.preload: bool = False
def set_preload(self) -> bool: def set_preload(self) -> bool:
@ -44,13 +37,17 @@ class Equalizer(Filter):
""" """
def __init__(self, *, tag: str, levels: list): def __init__(self, *, tag: str, levels: list):
super().__init__() super().__init__(tag=tag)
__slots__ = (
"eq",
"raw",
)
self.eq = self._factory(levels) self.eq = self._factory(levels)
self.raw = levels self.raw = levels
self.payload = {"equalizer": self.eq} self.payload = {"equalizer": self.eq}
self.tag = tag
def _factory(self, levels: list): def _factory(self, levels: list):
_dict = collections.defaultdict(int) _dict = collections.defaultdict(int)
@ -135,7 +132,13 @@ class Timescale(Filter):
pitch: float = 1.0, pitch: float = 1.0,
rate: float = 1.0 rate: float = 1.0
): ):
super().__init__() super().__init__(tag=tag)
__slots__ = (
"speed",
"pitch",
"rate"
)
if speed < 0: if speed < 0:
raise FilterInvalidArgument("Timescale speed must be more than 0.") raise FilterInvalidArgument("Timescale speed must be more than 0.")
@ -147,7 +150,6 @@ class Timescale(Filter):
self.speed: float = speed self.speed: float = speed
self.pitch: float = pitch self.pitch: float = pitch
self.rate: float = rate self.rate: float = rate
self.tag: str = tag
self.payload: dict = {"timescale": {"speed": self.speed, self.payload: dict = {"timescale": {"speed": self.speed,
"pitch": self.pitch, "pitch": self.pitch,
@ -191,13 +193,19 @@ class Karaoke(Filter):
filter_band: float = 220.0, filter_band: float = 220.0,
filter_width: float = 100.0 filter_width: float = 100.0
): ):
super().__init__() super().__init__(tag=tag)
__slots__ = (
"level",
"mono_level",
"filter_band",
"filter_width"
)
self.level: float = level self.level: float = level
self.mono_level: float = mono_level self.mono_level: float = mono_level
self.filter_band: float = filter_band self.filter_band: float = filter_band
self.filter_width: float = filter_width self.filter_width: float = filter_width
self.tag: str = tag
self.payload: dict = {"karaoke": {"level": self.level, self.payload: dict = {"karaoke": {"level": self.level,
"monoLevel": self.mono_level, "monoLevel": self.mono_level,
@ -223,7 +231,12 @@ class Tremolo(Filter):
frequency: float = 2.0, frequency: float = 2.0,
depth: float = 0.5 depth: float = 0.5
): ):
super().__init__() super().__init__(tag=tag)
__slots__ = (
"frequency",
"depth"
)
if frequency < 0: if frequency < 0:
raise FilterInvalidArgument( raise FilterInvalidArgument(
@ -234,7 +247,6 @@ class Tremolo(Filter):
self.frequency: float = frequency self.frequency: float = frequency
self.depth: float = depth self.depth: float = depth
self.tag: str = tag
self.payload: dict = {"tremolo": {"frequency": self.frequency, self.payload: dict = {"tremolo": {"frequency": self.frequency,
"depth": self.depth}} "depth": self.depth}}
@ -255,8 +267,13 @@ class Vibrato(Filter):
frequency: float = 2.0, frequency: float = 2.0,
depth: float = 0.5 depth: float = 0.5
): ):
super().__init__(tag=tag)
__slots__ = (
"frequency",
"depth"
)
super().__init__()
if frequency < 0 or frequency > 14: if frequency < 0 or frequency > 14:
raise FilterInvalidArgument( raise FilterInvalidArgument(
"Vibrato frequency must be between 0 and 14.") "Vibrato frequency must be between 0 and 14.")
@ -266,7 +283,6 @@ class Vibrato(Filter):
self.frequency: float = frequency self.frequency: float = frequency
self.depth: float = depth self.depth: float = depth
self.tag: str = tag
self.payload: dict = {"vibrato": {"frequency": self.frequency, self.payload: dict = {"vibrato": {"frequency": self.frequency,
"depth": self.depth}} "depth": self.depth}}
@ -281,10 +297,11 @@ class Rotation(Filter):
""" """
def __init__(self, *, tag: str, rotation_hertz: float = 5): def __init__(self, *, tag: str, rotation_hertz: float = 5):
super().__init__() super().__init__(tag=tag)
__slots__ = ("rotation_hertz")
self.rotation_hertz: float = rotation_hertz self.rotation_hertz: float = rotation_hertz
self.tag: str = tag
self.payload: dict = {"rotation": {"rotationHz": self.rotation_hertz}} self.payload: dict = {"rotation": {"rotationHz": self.rotation_hertz}}
def __repr__(self) -> str: def __repr__(self) -> str:
@ -305,7 +322,14 @@ class ChannelMix(Filter):
left_to_right: float = 0, left_to_right: float = 0,
right_to_left: float = 0 right_to_left: float = 0
): ):
super().__init__() super().__init__(tag=tag)
__slots__ = (
"left_to_left",
"right_to_right",
"left_to_right",
"right_to_left"
)
if 0 > left_to_left > 1: if 0 > left_to_left > 1:
raise ValueError( raise ValueError(
@ -324,7 +348,6 @@ class ChannelMix(Filter):
self.left_to_right: float = left_to_right self.left_to_right: float = left_to_right
self.right_to_left: float = right_to_left self.right_to_left: float = right_to_left
self.right_to_right: float = right_to_right self.right_to_right: float = right_to_right
self.tag: str = tag
self.payload: dict = {"channelMix": {"leftToLeft": self.left_to_left, self.payload: dict = {"channelMix": {"leftToLeft": self.left_to_left,
"leftToRight": self.left_to_right, "leftToRight": self.left_to_right,
@ -357,7 +380,18 @@ class Distortion(Filter):
offset: float = 0, offset: float = 0,
scale: float = 1 scale: float = 1
): ):
super().__init__() super().__init__(tag=tag)
__slots__ = (
"sin_offset",
"sin_scale",
"cos_offset",
"cos_scale",
"tan_offset",
"tan_scale"
"offset",
"scale"
)
self.sin_offset: float = sin_offset self.sin_offset: float = sin_offset
self.sin_scale: float = sin_scale self.sin_scale: float = sin_scale
@ -367,7 +401,6 @@ class Distortion(Filter):
self.tan_scale: float = tan_scale self.tan_scale: float = tan_scale
self.offset: float = offset self.offset: float = offset
self.scale: float = scale self.scale: float = scale
self.tag: str = tag
self.payload: dict = {"distortion": { self.payload: dict = {"distortion": {
"sinOffset": self.sin_offset, "sinOffset": self.sin_offset,
@ -393,10 +426,11 @@ class LowPass(Filter):
You can also do this with the Equalizer filter, but this is an easier way to do it. You can also do this with the Equalizer filter, but this is an easier way to do it.
""" """
def __init__(self, *, tag: str, smoothing: float = 20): def __init__(self, *, tag: str, smoothing: float = 20):
super().__init__() super().__init__(tag=tag)
__slots__ = ('smoothing')
self.smoothing: float = smoothing self.smoothing: float = smoothing
self.tag: str = tag
self.payload: dict = {"lowPass": {"smoothing": self.smoothing}} self.payload: dict = {"lowPass": {"smoothing": self.smoothing}}
def __repr__(self) -> str: def __repr__(self) -> str:

View File

@ -7,8 +7,6 @@ from discord.ext import commands
from .enums import SearchType, TrackType, PlaylistType from .enums import SearchType, TrackType, PlaylistType
from .filters import Filter from .filters import Filter
__all__ = ('Track', 'Playlist')
class Track: class Track:
"""The base track object. Returns critical track information needed for parsing by Lavalink. """The base track object. Returns critical track information needed for parsing by Lavalink.
@ -27,6 +25,29 @@ class Track:
timestamp: Optional[float] = None, timestamp: Optional[float] = None,
requester: Optional[Union[Member, User]] = None, requester: Optional[Union[Member, User]] = None,
): ):
__slots__ = (
"track_id",
"info",
"track_type",
"filters",
"timestamp",
"original",
"_search_type",
"playlist",
"title",
"author",
"uri",
"identifier",
"isrc",
"thumbnail",
"length",
"ctx",
"requester",
"is_stream",
"is_seekable",
"position"
)
self.track_id: str = track_id self.track_id: str = track_id
self.info: dict = info self.info: dict = info
self.track_type: TrackType = track_type self.track_type: TrackType = track_type
@ -98,6 +119,18 @@ class Playlist:
thumbnail: Optional[str] = None, thumbnail: Optional[str] = None,
uri: Optional[str] = None uri: Optional[str] = None
): ):
__slots__ = (
"playlist_info",
"tracks",
"name",
"playlist_type",
"_thumbnail",
"_uri",
"selected_track",
"track_count"
)
self.playlist_info: dict = playlist_info self.playlist_info: dict = playlist_info
self.tracks: List[Track] = tracks self.tracks: List[Track] = tracks
self.name: str = playlist_info.get("name") self.name: str = playlist_info.get("name")

View File

@ -23,7 +23,6 @@ from .filters import Filter
from .objects import Track from .objects import Track
from .pool import Node, NodePool from .pool import Node, NodePool
__all__ = ('Filters', 'Player')
class Filters: class Filters:
"""Helper class for filters""" """Helper class for filters"""
@ -134,10 +133,10 @@ class Player(VoiceProtocol):
node: Node = None node: Node = None
): ):
self.client: Optional[Client] = client self.client: Optional[Client] = client
self._bot: Union[Client, commands.Bot] = client
self.channel: Optional[VoiceChannel] = channel self.channel: Optional[VoiceChannel] = channel
self._guild: Guild = channel.guild if channel else None
self._bot: Union[Client, commands.Bot] = client
self._guild: Guild = channel.guild if channel else None
self._node: Node = node if node else NodePool.get_node() self._node: Node = node if node else NodePool.get_node()
self._current: Optional[Track] = None self._current: Optional[Track] = None
self._filters: Filters = Filters() self._filters: Filters = Filters()

View File

@ -37,8 +37,6 @@ from .routeplanner import RoutePlanner
if TYPE_CHECKING: if TYPE_CHECKING:
from .player import Player from .player import Player
__all__ = ('Node', 'NodePool')
class Node: class Node:
"""The base class for a node. """The base class for a node.
This node object represents a Lavalink node. This node object represents a Lavalink node.
@ -65,6 +63,33 @@ class Node:
fallback: bool = False fallback: bool = False
): ):
__slots__ = (
"_bot",
"_host",
"_port",
"_pool",
"_password",
"_identifier",
"_heartbeat",
"_secure",
"_fallback",
"_websocket_uri",
"_rest_uri",
"_session",
"_websocket",
"_task",
"_loop",
"_session_id",
"_available",
"_version",
"_headers",
"_players",
"_spotify_client_id",
"_spotify_client_secret",
"_spotify_client",
"_apple_music_client"
)
self._bot: Union[Client, commands.Bot] = bot self._bot: Union[Client, commands.Bot] = bot
self._host: str = host self._host: str = host
self._port: int = port self._port: int = port
@ -73,9 +98,7 @@ class Node:
self._identifier: str = identifier self._identifier: str = identifier
self._heartbeat: int = heartbeat self._heartbeat: int = heartbeat
self._secure: bool = secure self._secure: bool = secure
self.fallback: bool = fallback self._fallback: bool = fallback
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}" self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}" self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"

View File

@ -22,6 +22,15 @@ class Queue(Iterable[Track]):
*, *,
overflow: bool = True, overflow: bool = True,
): ):
__slots__ = (
"max_size",
"_queue",
"_overflow",
"_loop_mode",
"_current_item"
)
self.max_size: Optional[int] = max_size self.max_size: Optional[int] = max_size
self._queue: List[Track] = [] # type: ignore self._queue: List[Track] = [] # type: ignore
self._overflow: bool = overflow self._overflow: bool = overflow
@ -175,12 +184,13 @@ class Queue(Iterable[Track]):
return len(self._queue) return len(self._queue)
def get_queue(self) -> List: def get_queue(self) -> List:
"""Returns the queue as a List""" """Returns the queue as a List"""
return self._queue return self._queue
def get(self) -> Track: def get(self):
"""Return next immediately available item in queue if any. """Return next immediately available item in queue if any.
Raises QueueEmpty if no items in queue. Raises QueueEmpty if no items in queue.
""" """
@ -296,7 +306,7 @@ class Queue(Iterable[Track]):
"""Remove all items from the queue.""" """Remove all items from the queue."""
self._queue.clear() self._queue.clear()
def set_loop_mode(self, mode: LoopMode) -> None: def set_loop_mode(self, mode: LoopMode):
""" """
Sets the loop mode of the queue. Sets the loop mode of the queue.
Takes the LoopMode enum as an argument. Takes the LoopMode enum as an argument.
@ -312,7 +322,7 @@ class Queue(Iterable[Track]):
self._current_item = self._queue[index] self._current_item = self._queue[index]
def disable_loop(self) -> None: def disable_loop(self):
""" """
Disables loop mode if set. Disables loop mode if set.
Raises QueueException if loop mode is already None. Raises QueueException if loop mode is already None.
@ -327,17 +337,17 @@ class Queue(Iterable[Track]):
self._loop_mode = None self._loop_mode = None
def shuffle(self) -> None: def shuffle(self):
"""Shuffles the queue.""" """Shuffles the queue."""
return random.shuffle(self._queue) return random.shuffle(self._queue)
def clear_track_filters(self) -> None: def clear_track_filters(self):
"""Clears all filters applied to tracks""" """Clears all filters applied to tracks"""
for track in self._queue: for track in self._queue:
track.filters = None track.filters = None
def jump(self, item: Track) -> None: def jump(self, item: Track):
"""Mutates the queue so that all tracks before the specified track are removed.""" """Removes all tracks before the."""
index = self.find_position(item) index = self.find_position(item)
new_queue = self._queue[index:self.size] new_queue = self._queue[index:self.size]
self._queue = new_queue self._queue = new_queue

View File

@ -8,15 +8,6 @@ from itertools import zip_longest
from datetime import datetime from datetime import datetime
__all__ = (
'ExponentialBackoff',
'NodeStats',
'FailingIPBlock',
'RouteStats',
'Ping'
)
class ExponentialBackoff: class ExponentialBackoff:
""" """
The MIT License (MIT) The MIT License (MIT)
@ -72,6 +63,19 @@ class NodeStats:
def __init__(self, data: dict) -> None: def __init__(self, data: dict) -> None:
__slots__ = (
"used",
"free",
"reservable",
"allocated",
"cpu_cores",
"cpu_system_load",
"cpu_process_load",
"players_active",
"players_total",
"uptime"
)
memory: dict = data.get("memory") memory: dict = data.get("memory")
self.used = memory.get("used") self.used = memory.get("used")
self.free = memory.get("free") self.free = memory.get("free")
@ -97,6 +101,12 @@ class FailingIPBlock:
and the time they failed. and the time they failed.
""" """
def __init__(self, data: dict) -> None: def __init__(self, data: dict) -> None:
__slots__ = (
"address",
"failing_time"
)
self.address = data.get("address") self.address = data.get("address")
self.failing_time = datetime.fromtimestamp(float(data.get("failingTimestamp"))) self.failing_time = datetime.fromtimestamp(float(data.get("failingTimestamp")))
@ -111,6 +121,14 @@ class RouteStats:
""" """
def __init__(self, data: dict) -> None: def __init__(self, data: dict) -> None:
__slots__ = (
"strategy",
"ip_block_type",
"ip_block_size",
"failing_addresses"
)
self.strategy = RouteStrategy(data.get("class")) self.strategy = RouteStrategy(data.get("class"))
details: dict = data.get("details") details: dict = data.get("details")