Fixed some bugs and added new SearchType enum

This commit is contained in:
cloudwithax 2021-10-07 20:34:32 -04:00
parent bdf0a75055
commit b53fe52331
6 changed files with 144 additions and 36 deletions

View File

@ -10,3 +10,4 @@ from .filters import *
from .node import Node, NodePool
from .objects import Track, Playlist
from .player import Player
from .enums import SearchType

23
pomice/enums.py Normal file
View File

@ -0,0 +1,23 @@
from enum import Enum
class SearchType(Enum):
"""The base class for the different search types for Pomice.
This feature is exclusively for the Spotify search feature of Pomice.
If you are not using this feature, this class is not necessary.
SearchType.YTSEARCH searches for a Spotify track using regular Youtube, which is best for all scenarios
SearchType.YTMSEARCH searches for a Spotify track using YouTube Music, which is best for getting audio-only results.
SearchType.SCSEARCH searches for a Spotify track using SoundCloud, which is an alternative to YouTube or YouTube Music.
"""
YTSEARCH = f'ytsearch:{track.artist} - {track.title}'
YTMSEARCH = f'ytmsearch:{track.artist} - {track.title}'
SCSEARCH = f'scsearch:{track.artist} - {track.title}'

View File

@ -4,14 +4,15 @@ import random
import re
import socket
import time
from typing import Optional, Type
from typing import Optional, Type, Union
from urllib.parse import quote
import aiohttp
import discord
from discord.ext import commands
from . import __version__, spotify
from . import __version__, spotify, Player, SearchType
from .exceptions import (
InvalidSpotifyClientAuthorization,
NodeConnectionFailure,
@ -41,24 +42,26 @@ class Node:
def __init__(
self,
pool,
bot: Type[commands.Bot],
bot: Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]],
host: str,
port: int,
password: str,
identifier: str,
spotify_client_id: Optional[str],
spotify_client_secret: Optional[str]
spotify_client_secret: Optional[str],
session: Optional[aiohttp.ClientSession]
):
self._bot = bot
self._host = host
self._port = port
self._password = password
self._identifier = identifier
self._bot: Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]] = bot
self._host: str = host
self._port: int = port
self._pool: NodePool = pool
self._password: str = password
self._identifier: str = identifier
self._websocket_uri = f"ws://{self._host}:{self._port}"
self._rest_uri = f"http://{self._host}:{self._port}"
self._websocket_uri: str = f"ws://{self._host}:{self._port}"
self._rest_uri: str = f"http://{self._host}:{self._port}"
self._session = aiohttp.ClientSession()
self._session: aiohttp.ClientSession = session or aiohttp.ClientSession()
self._websocket: aiohttp.ClientWebSocketResponse = None
self._task: asyncio.Task = None
@ -72,10 +75,10 @@ class Node:
"Client-Name": f"Pomice/{__version__}"
}
self._players = {}
self._players: dict = {}
self._spotify_client_id = spotify_client_id
self._spotify_client_secret = spotify_client_secret
self._spotify_client_id: str = spotify_client_id
self._spotify_client_secret: str = spotify_client_secret
if self._spotify_client_id and self._spotify_client_secret:
self._spotify_client = spotify.Client(
@ -119,14 +122,20 @@ class Node:
return self._players
@property
def bot(self) -> commands.Bot:
def bot(self) -> Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]]:
"""Property which returns the discord.py client linked to this node"""
return self._bot
@property
def player_count(self) -> int:
"""Property which returns how many players are connected to this node"""
return len(self.players)
@property
def pool(self):
"""Property which returns the pool this node is apart of"""
return self._pool
async def _update_handler(self, data: dict):
await self._bot.wait_until_ready()
@ -190,7 +199,7 @@ class Node:
await self._websocket.send_str(json.dumps(data))
def get_player(self, guild_id: int):
def get_player(self, guild_id: int) -> Player:
"""Takes a guild ID as a parameter. Returns a pomice Player object."""
return self._players.get(guild_id, None)
@ -203,7 +212,7 @@ class Node:
self._websocket_uri, headers=self._headers, heartbeat=60
)
self._task = self._bot.loop.create_task(self._listen())
self.available = True
self._available = True
return self
except aiohttp.WSServerHandshakeError:
raise NodeConnectionFailure(
@ -230,7 +239,7 @@ class Node:
self.available = False
self._task.cancel()
async def get_tracks(self, query: str, ctx: commands.Context = None):
async def get_tracks(self, query: str, ctx: commands.Context = None, search_type: SearchType = None):
"""Fetches tracks from the node's REST api to parse into Lavalink.
If you passed in Spotify API credentials, you can also pass in a
@ -262,6 +271,7 @@ class Node:
Track(
track_id=track.id,
ctx=ctx,
search_type=search_type,
spotify=True,
info={
"title": track.name or "Unknown",
@ -302,6 +312,7 @@ class Node:
Track(
track_id=track.id,
ctx=ctx,
search_type=search_type,
spotify=True,
info={
"title": track.name or "Unknown",
@ -339,6 +350,7 @@ class Node:
Track(
track_id=results.id,
ctx=ctx,
search_type=search_type,
spotify=True,
info={
"title": results.name or "Unknown",
@ -397,7 +409,6 @@ class Node:
for track in data["tracks"]
]
class NodePool:
"""The base class for the node pool.
This holds all the nodes that are to be used by the bot.
@ -409,7 +420,7 @@ class NodePool:
return f"<Pomice.NodePool node_count={self.node_count}>"
@property
def nodes(self):
def nodes(self) -> dict:
"""Property which returns a dict with the node identifier and the Node object."""
return self._nodes
@ -433,8 +444,8 @@ class NodePool:
@classmethod
async def create_node(
cls,
bot: Type[discord.Client],
bot: Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]],
cls,
host: str,
port: str,
password: str,
@ -449,7 +460,7 @@ class NodePool:
raise NodeCreationError(f"A node with identifier '{identifier}' already exists.")
node = Node(
pool=cls, bot=bot, host=host, port=port, password=password,
bot=bot, pool=cls, host=host, port=port, password=password,
identifier=identifier, spotify_client_id=spotify_client_id,
spotify_client_secret=spotify_client_secret
)

View File

@ -1,5 +1,6 @@
from typing import Optional
from . import SearchType
from discord.ext import commands
@ -19,6 +20,10 @@ class Track:
self.info = info
self.spotify = spotify
if self.spotify:
self.youtube_result = None
self.search_type: SearchType = None
self.title = info.get("title")
self.author = info.get("author")
self.length = info.get("length")
@ -68,11 +73,13 @@ class Playlist:
self.name = playlist_info.get("name")
self.selected_track = playlist_info.get("selectedTrack")
self._thumbnail = thumbnail
self._uri = uri
if self.spotify:
self.tracks = tracks
else:
self.tracks = [
Track(track_id=track["track"], info=track["info"], ctx=ctx)

View File

@ -1,11 +1,11 @@
import time
from typing import Any, Dict, Type
from typing import Any, Dict, Type, Union
import discord
from discord import VoiceChannel, VoiceProtocol
from discord import VoiceChannel, VoiceProtocol, Guild, Member
from discord.ext import commands
from . import events, filters, NodePool, objects
from . import events, filters, NodePool, objects, Node
from .exceptions import TrackInvalidPosition
@ -17,17 +17,17 @@ class Player(VoiceProtocol):
```
"""
def __init__(self, client: Type[commands.Bot], channel: VoiceChannel):
def __init__(self, client: Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]], channel: VoiceChannel):
super().__init__(client=client, channel=channel)
self.client = client
self.bot = client
self._bot: Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]] = client
self.channel = channel
self.guild: discord.Guild = self.channel.guild
self._guild: discord.Guild = self.channel.guild
self._dj: discord.Member = None
self._node = NodePool.get_node()
self.current: objects.Track = None
self._current: objects.Track = None
self._filter: filters.Filter = None
self._volume = 100
self._paused = False
@ -73,6 +73,42 @@ class Player(VoiceProtocol):
"""Property which returns whether or not the player has a track which is paused or not."""
return self._is_connected and self._paused
@property
def current(self) -> objects.Track:
"""Property which returns the currently playing track"""
return self._current
@property
def node(self) -> Node:
"""Property which returns the node the player is connected to"""
return self._node
@property
def guild(self) -> Guild:
"""Property which returns the guild associated with the player"""
return self._guild
@property
def volume(self) -> int:
"""Property which returns the players current volume"""
return self._volume
@property
def dj(self) -> Member:
"""Property which returns the DJ for the player session"""
return self._dj
@property
def filter(self) -> filters.Filter:
"""Property which returns the currently applied filter, if one is applied"""
return self._filter
@property
def bot(self) -> Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]]:
"""Property which returns the bot associated with this player instance"""
return self._bot
async def _update_state(self, data: dict):
state: dict = data.get("state")
self._last_update = time.time() * 1000
@ -112,7 +148,7 @@ class Player(VoiceProtocol):
async def get_tracks(self, query: str, ctx: commands.Context = None):
"""Fetches tracks from the node's REST api to parse into Lavalink.
If you passed in Spotify API credentials, you can also pass in a Spotify URL of a playlist,
If you passed in Spotify API credentials when you created the node, you can also pass in a Spotify URL of a playlist,
album or track and it will be parsed accordingly.
You can also pass in a discord.py Context object to get a
@ -146,9 +182,11 @@ class Player(VoiceProtocol):
async def play(self, track: objects.Track, start_position: int = 0) -> objects.Track:
"""Plays a track. If a Spotify track is passed in, it will be handled accordingly."""
if track.spotify:
search_type = track.search_type or f"ytmsearch:{track.author} - {track.title}"
spotify_track: objects.Track = (await self._node.get_tracks(
f"ytmsearch:{track.author} - {track.title}"
search_type
))[0]
track.youtube_result = spotify_track
await self._node.send(
op="play",
guildId=str(self.guild.id),
@ -166,8 +204,8 @@ class Player(VoiceProtocol):
endTime=track.length,
noReplace=False
)
self.current = track
return self.current
self._current = track
return self._current
async def seek(self, position: float) -> float:
"""Seeks to a position in the currently playing track milliseconds"""

View File

@ -20,6 +20,7 @@ DEALINGS IN THE SOFTWARE.
import random
import time
from typing import Any
__all__ = [
'ExponentialBackoff',
@ -78,3 +79,30 @@ class NodeStats:
def __repr__(self) -> str:
return f'<Pomice.NodeStats total_players={self.players_total} playing_active={self.players_active}>'
class Queue:
"""Pomice's very own queue implementation with some added features like:
- Toggleable shuffle
- Loop queue functionality
- Music player style queue, which doesn't remove tracks, allowing for playback of previously played tracks.
"""
def __init__(self) -> None:
self._queue = []
self._shuffle = False
if self._shuffle is True:
self._original_queue = []
self._looping = False
def put(self, item: Any):
"""Puts an item into the queue"""
return self._queue.extend(item)
def remove(self, item: Any):
"""Removes an item from the queue"""
if type(item) == int:
return self._queue.remove(self._queue[item])