From 7d2600ed7f7c0736bc105a52d07c1ae29e9e6aaf Mon Sep 17 00:00:00 2001 From: vveeps <54472340+vveeps@users.noreply.github.com> Date: Sun, 10 Oct 2021 00:13:48 +0300 Subject: [PATCH] cleanup + make some params kw only --- pomice/objects.py | 3 ++- pomice/player.py | 31 ++++++++++++++++++++----------- pomice/pool.py | 3 +++ 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/pomice/objects.py b/pomice/objects.py index 8a8bdcf..dc81c8a 100644 --- a/pomice/objects.py +++ b/pomice/objects.py @@ -13,6 +13,7 @@ class Track: def __init__( self, + *, track_id: str, info: dict, ctx: Optional[commands.Context] = None, @@ -61,6 +62,7 @@ class Playlist: def __init__( self, + *, playlist_info: dict, tracks: list, ctx: Optional[commands.Context] = None, @@ -80,7 +82,6 @@ class Playlist: if self.spotify: self.tracks = tracks - else: self.tracks = [ Track(track_id=track["track"], info=track["info"], ctx=ctx) diff --git a/pomice/player.py b/pomice/player.py index cf8c0c7..1bfea85 100644 --- a/pomice/player.py +++ b/pomice/player.py @@ -1,10 +1,12 @@ import time -from typing import Any, Dict, Type, Union +from typing import Any, Dict, Optional, Type, Union import discord -from discord import VoiceChannel, VoiceProtocol, Guild +from discord import Client, Guild, VoiceChannel, VoiceProtocol from discord.ext import commands +from pomice.enums import SearchType + from . import events from .exceptions import TrackInvalidPosition from .filters import Filter @@ -20,13 +22,13 @@ class Player(VoiceProtocol): ``` """ - def __init__(self, client: Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]], channel: VoiceChannel): + def __init__(self, client: Type[Client], channel: VoiceChannel): super().__init__(client=client, channel=channel) self.client = client - self._bot: Union[discord.Client, commands.Bot, commands.AutoShardedBot] = client + self._bot = client self.channel = channel - self._guild: discord.Guild = self.channel.guild + self._guild: Guild = self.channel.guild self._node = NodePool.get_node() self._current: Track = None @@ -106,7 +108,7 @@ class Player(VoiceProtocol): return self._filter @property - def bot(self) -> Type[discord.Client]: + def bot(self) -> Type[Client]: """Property which returns the bot associated with this player instance""" return self._bot @@ -146,16 +148,23 @@ class Player(VoiceProtocol): event = event(data) self.bot.dispatch(f"pomice_{event.name}", event) - async def get_tracks(self, query: str, ctx: commands.Context = None): + async def get_tracks( + self, + query: str, + *, + ctx: Optional[commands.Context] = None, + search_type: SearchType = SearchType.ytsearch + ): """Fetches tracks from the node's REST api to parse into Lavalink. - If you passed in Spotify API credentials when you created the node, you can also pass in a Spotify URL of a playlist, - album or track and it will be parsed accordingly. + If you passed in Spotify API credentials when you created the node, + you can also pass in a Spotify URL of a playlist, album or track and it will be parsed + accordingly. You can also pass in a discord.py Context object to get a Context object on any track you search. """ - return await self._node.get_tracks(query, ctx) + return await self._node.get_tracks(query, ctx=ctx, search_type=search_type) async def connect(self, *, timeout: float, reconnect: bool): await self.guild.change_voice_state(channel=self.channel) @@ -180,7 +189,7 @@ class Player(VoiceProtocol): await self.disconnect() await self._node.send(op="destroy", guildId=str(self.guild.id)) - async def play(self, track: Track, start_position: int = 0) -> Track: + async def play(self, track: Track, *, start_position: int = 0) -> Track: """Plays a track. If a Spotify track is passed in, it will be handled accordingly.""" if track.spotify: search: Track = (await self._node.get_tracks( diff --git a/pomice/pool.py b/pomice/pool.py index 60d4dae..bca5b00 100644 --- a/pomice/pool.py +++ b/pomice/pool.py @@ -50,6 +50,7 @@ class Node: def __init__( self, + *, pool, bot: Type[discord.Client], host: str, @@ -252,6 +253,7 @@ class Node: async def get_tracks( self, query: str, + *, ctx: Optional[commands.Context] = None, search_type: SearchType = SearchType.ytsearch ): @@ -483,6 +485,7 @@ class NodePool: @classmethod async def create_node( cls, + *, bot: Type[discord.Client], host: str, port: str,