cleanup + searchtype stuff

This commit is contained in:
vveeps 2021-10-08 17:15:18 +03:00
parent c1ff8d77c0
commit 543e2d9b86
5 changed files with 48 additions and 49 deletions

View File

@ -1,23 +1,21 @@
from enum import Enum from enum import Enum
from typing import Literal
class SearchType(Enum): class SearchType(Enum):
"""The base class for the different search types for Pomice. """The enum for the different search types for Pomice.
This feature is exclusively for the Spotify search feature of Pomice. This feature is exclusively for the Spotify search feature of Pomice.
If you are not using this feature, this class is not necessary. If you are not using this feature, this class is not necessary.
SearchType.YTSEARCH searches for a Spotify track using regular Youtube, which is best for all scenarios SearchType.ytsearch searches using regular Youtube, which is best for all scenarios.
SearchType.YTMSEARCH searches for a Spotify track using YouTube Music, which is best for getting audio-only results. SearchType.ytmsearch searches using YouTube Music, which is best for getting audio-only results.
SearchType.SCSEARCH searches for a Spotify track using SoundCloud, which is an alternative to YouTube or YouTube Music. SearchType.scsearch searches using SoundCloud, which is an alternative to YouTube or YouTube Music.
""" """
YTSEARCH = f'ytsearch:{track.artist} - {track.title}' ytsearch = "ytsearch"
YTMSEARCH = f'ytmsearch:{track.artist} - {track.title}' ytmsearch = "ytmsearch"
SCSEARCH = f'scsearch:{track.artist} - {track.title}' scsearch = "scsearch"
def __str__(self) -> str:
return self.value

View File

@ -11,8 +11,8 @@ import aiohttp
import discord import discord
from discord.ext import commands from discord.ext import commands
from . import __version__, spotify
from . import __version__, spotify, Player, SearchType from .enums import SearchType
from .exceptions import ( from .exceptions import (
InvalidSpotifyClientAuthorization, InvalidSpotifyClientAuthorization,
NodeConnectionFailure, NodeConnectionFailure,
@ -25,6 +25,7 @@ from .exceptions import (
TrackLoadError TrackLoadError
) )
from .objects import Playlist, Track from .objects import Playlist, Track
from .player import Player
from .spotify import SpotifyException from .spotify import SpotifyException
from .utils import ExponentialBackoff, NodeStats from .utils import ExponentialBackoff, NodeStats
@ -239,7 +240,12 @@ class Node:
self.available = False self.available = False
self._task.cancel() self._task.cancel()
async def get_tracks(self, query: str, ctx: commands.Context = None, search_type: SearchType = None): async def get_tracks(
self,
query: str,
ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch
):
"""Fetches tracks from the node's REST api to parse into Lavalink. """Fetches tracks from the node's REST api to parse into Lavalink.
If you passed in Spotify API credentials, you can also pass in a If you passed in Spotify API credentials, you can also pass in a
@ -256,10 +262,10 @@ class Node:
"please obtain Spotify API credentials here: https://developer.spotify.com/" "please obtain Spotify API credentials here: https://developer.spotify.com/"
) )
search_type = spotify_url_check.group("type") spotify_type = spotify_url_check.group("type")
spotify_id = spotify_url_check.group("id") spotify_id = spotify_url_check.group("id")
if search_type == "playlist": if spotify_type == "playlist":
results = spotify.Playlist( results = spotify.Playlist(
client=self._spotify_client, client=self._spotify_client,
data=await self._spotify_http_client.get_playlist(spotify_id) data=await self._spotify_http_client.get_playlist(spotify_id)
@ -303,7 +309,7 @@ class Node:
f"Unable to find results for {query}" f"Unable to find results for {query}"
) )
elif search_type == "album": elif spotify_type == "album":
results = await self._spotify_client.get_album(spotify_id=spotify_id) results = await self._spotify_client.get_album(spotify_id=spotify_id)
try: try:
@ -342,7 +348,7 @@ class Node:
except SpotifyException: except SpotifyException:
raise SpotifyAlbumLoadFailed(f"Unable to find results for {query}") raise SpotifyAlbumLoadFailed(f"Unable to find results for {query}")
elif search_type == 'track': elif spotify_type == 'track':
try: try:
results = await self._spotify_client.get_track(spotify_id=spotify_id) results = await self._spotify_client.get_track(spotify_id=spotify_id)
@ -409,6 +415,7 @@ class Node:
for track in data["tracks"] for track in data["tracks"]
] ]
class NodePool: class NodePool:
"""The base class for the node pool. """The base class for the node pool.
This holds all the nodes that are to be used by the bot. This holds all the nodes that are to be used by the bot.
@ -445,7 +452,7 @@ class NodePool:
@classmethod @classmethod
async def create_node( async def create_node(
bot: Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]], bot: Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]],
cls, cls,
host: str, host: str,
port: str, port: str,
password: str, password: str,

View File

@ -1,8 +1,9 @@
from typing import Optional from typing import Optional
from . import SearchType
from discord.ext import commands from discord.ext import commands
from .enums import SearchType
class Track: class Track:
"""The base track object. Returns critical track information needed for parsing by Lavalink. """The base track object. Returns critical track information needed for parsing by Lavalink.
@ -14,15 +15,15 @@ class Track:
track_id: str, track_id: str,
info: dict, info: dict,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
spotify: bool = False spotify: bool = False,
search_type: SearchType = SearchType.ytsearch
): ):
self.track_id = track_id self.track_id = track_id
self.info = info self.info = info
self.spotify = spotify self.spotify = spotify
if self.spotify: self.original: Optional[Track] = None if self.spotify else self
self.youtube_result = None self._search_type = search_type
self.search_type: SearchType = None
self.title = info.get("title") self.title = info.get("title")
self.author = info.get("author") self.author = info.get("author")
@ -73,10 +74,9 @@ class Playlist:
self.name = playlist_info.get("name") self.name = playlist_info.get("name")
self.selected_track = playlist_info.get("selectedTrack") self.selected_track = playlist_info.get("selectedTrack")
self._thumbnail = thumbnail self._thumbnail = thumbnail
self._uri = uri self._uri = uri
if self.spotify: if self.spotify:
self.tracks = tracks self.tracks = tracks

View File

@ -5,8 +5,11 @@ import discord
from discord import VoiceChannel, VoiceProtocol, Guild, Member from discord import VoiceChannel, VoiceProtocol, Guild, Member
from discord.ext import commands from discord.ext import commands
from . import events, filters, NodePool, objects, Node from . import events
from .exceptions import TrackInvalidPosition from .exceptions import TrackInvalidPosition
from .filters import Filter
from .node import Node, NodePool
from .objects import Track
class Player(VoiceProtocol): class Player(VoiceProtocol):
@ -27,8 +30,8 @@ class Player(VoiceProtocol):
self._dj: discord.Member = None self._dj: discord.Member = None
self._node = NodePool.get_node() self._node = NodePool.get_node()
self._current: objects.Track = None self._current: Track = None
self._filter: filters.Filter = None self._filter: Filter = None
self._volume = 100 self._volume = 100
self._paused = False self._paused = False
self._is_connected = False self._is_connected = False
@ -74,7 +77,7 @@ class Player(VoiceProtocol):
return self._is_connected and self._paused return self._is_connected and self._paused
@property @property
def current(self) -> objects.Track: def current(self) -> Track:
"""Property which returns the currently playing track""" """Property which returns the currently playing track"""
return self._current return self._current
@ -99,7 +102,7 @@ class Player(VoiceProtocol):
return self._dj return self._dj
@property @property
def filter(self) -> filters.Filter: def filter(self) -> Filter:
"""Property which returns the currently applied filter, if one is applied""" """Property which returns the currently applied filter, if one is applied"""
return self._filter return self._filter
@ -108,7 +111,6 @@ class Player(VoiceProtocol):
"""Property which returns the bot associated with this player instance""" """Property which returns the bot associated with this player instance"""
return self._bot return self._bot
async def _update_state(self, data: dict): async def _update_state(self, data: dict):
state: dict = data.get("state") state: dict = data.get("state")
self._last_update = time.time() * 1000 self._last_update = time.time() * 1000
@ -179,20 +181,19 @@ class Player(VoiceProtocol):
await self.disconnect() await self.disconnect()
await self._node.send(op="destroy", guildId=str(self.guild.id)) await self._node.send(op="destroy", guildId=str(self.guild.id))
async def play(self, track: objects.Track, start_position: int = 0) -> objects.Track: async def play(self, track: Track, start_position: int = 0) -> Track:
"""Plays a track. If a Spotify track is passed in, it will be handled accordingly.""" """Plays a track. If a Spotify track is passed in, it will be handled accordingly."""
if track.spotify: if track.spotify:
search_type = track.search_type or f"ytmsearch:{track.author} - {track.title}" search: Track = (await self._node.get_tracks(
spotify_track: objects.Track = (await self._node.get_tracks( f"{track._search_type}:{track.author} - {track.title}"
search_type
))[0] ))[0]
track.youtube_result = spotify_track track.original = search
await self._node.send( await self._node.send(
op="play", op="play",
guildId=str(self.guild.id), guildId=str(self.guild.id),
track=spotify_track.track_id, track=search.track_id,
startTime=start_position, startTime=start_position,
endTime=spotify_track.length, endTime=search.length,
noReplace=False noReplace=False
) )
else: else:
@ -230,7 +231,7 @@ class Player(VoiceProtocol):
self._volume = volume self._volume = volume
return self._volume return self._volume
async def set_filter(self, filter: filters.Filter) -> filters.Filter: async def set_filter(self, filter: Filter) -> Filter:
"""Sets a filter of the player. Takes a pomice.Filter object. """Sets a filter of the player. Takes a pomice.Filter object.
This will only work if you are using the development version of Lavalink. This will only work if you are using the development version of Lavalink.
""" """

View File

@ -20,7 +20,6 @@ DEALINGS IN THE SOFTWARE.
import random import random
import time import time
from typing import Any
__all__ = [ __all__ = [
'ExponentialBackoff', 'ExponentialBackoff',
@ -79,9 +78,3 @@ class NodeStats:
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Pomice.NodeStats total_players={self.players_total} playing_active={self.players_active}>' return f'<Pomice.NodeStats total_players={self.players_total} playing_active={self.players_active}>'