Fixed some bugs and added new SearchType enum
This commit is contained in:
parent
bdf0a75055
commit
b53fe52331
|
|
@ -10,3 +10,4 @@ from .filters import *
|
||||||
from .node import Node, NodePool
|
from .node import Node, NodePool
|
||||||
from .objects import Track, Playlist
|
from .objects import Track, Playlist
|
||||||
from .player import Player
|
from .player import Player
|
||||||
|
from .enums import SearchType
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,23 @@
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class SearchType(Enum):
|
||||||
|
"""The base class for the different search types for Pomice.
|
||||||
|
This feature is exclusively for the Spotify search feature of Pomice.
|
||||||
|
If you are not using this feature, this class is not necessary.
|
||||||
|
|
||||||
|
SearchType.YTSEARCH searches for a Spotify track using regular Youtube, which is best for all scenarios
|
||||||
|
|
||||||
|
SearchType.YTMSEARCH searches for a Spotify track using YouTube Music, which is best for getting audio-only results.
|
||||||
|
|
||||||
|
SearchType.SCSEARCH searches for a Spotify track using SoundCloud, which is an alternative to YouTube or YouTube Music.
|
||||||
|
"""
|
||||||
|
YTSEARCH = f'ytsearch:{track.artist} - {track.title}'
|
||||||
|
YTMSEARCH = f'ytmsearch:{track.artist} - {track.title}'
|
||||||
|
SCSEARCH = f'scsearch:{track.artist} - {track.title}'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -4,14 +4,15 @@ import random
|
||||||
import re
|
import re
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type, Union
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
from . import __version__, spotify
|
|
||||||
|
from . import __version__, spotify, Player, SearchType
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
InvalidSpotifyClientAuthorization,
|
InvalidSpotifyClientAuthorization,
|
||||||
NodeConnectionFailure,
|
NodeConnectionFailure,
|
||||||
|
|
@ -41,24 +42,26 @@ class Node:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
pool,
|
pool,
|
||||||
bot: Type[commands.Bot],
|
bot: Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]],
|
||||||
host: str,
|
host: str,
|
||||||
port: int,
|
port: int,
|
||||||
password: str,
|
password: str,
|
||||||
identifier: str,
|
identifier: str,
|
||||||
spotify_client_id: Optional[str],
|
spotify_client_id: Optional[str],
|
||||||
spotify_client_secret: Optional[str]
|
spotify_client_secret: Optional[str],
|
||||||
|
session: Optional[aiohttp.ClientSession]
|
||||||
):
|
):
|
||||||
self._bot = bot
|
self._bot: Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]] = bot
|
||||||
self._host = host
|
self._host: str = host
|
||||||
self._port = port
|
self._port: int = port
|
||||||
self._password = password
|
self._pool: NodePool = pool
|
||||||
self._identifier = identifier
|
self._password: str = password
|
||||||
|
self._identifier: str = identifier
|
||||||
|
|
||||||
self._websocket_uri = f"ws://{self._host}:{self._port}"
|
self._websocket_uri: str = f"ws://{self._host}:{self._port}"
|
||||||
self._rest_uri = f"http://{self._host}:{self._port}"
|
self._rest_uri: str = f"http://{self._host}:{self._port}"
|
||||||
|
|
||||||
self._session = aiohttp.ClientSession()
|
self._session: aiohttp.ClientSession = session or aiohttp.ClientSession()
|
||||||
self._websocket: aiohttp.ClientWebSocketResponse = None
|
self._websocket: aiohttp.ClientWebSocketResponse = None
|
||||||
self._task: asyncio.Task = None
|
self._task: asyncio.Task = None
|
||||||
|
|
||||||
|
|
@ -72,10 +75,10 @@ class Node:
|
||||||
"Client-Name": f"Pomice/{__version__}"
|
"Client-Name": f"Pomice/{__version__}"
|
||||||
}
|
}
|
||||||
|
|
||||||
self._players = {}
|
self._players: dict = {}
|
||||||
|
|
||||||
self._spotify_client_id = spotify_client_id
|
self._spotify_client_id: str = spotify_client_id
|
||||||
self._spotify_client_secret = spotify_client_secret
|
self._spotify_client_secret: str = spotify_client_secret
|
||||||
|
|
||||||
if self._spotify_client_id and self._spotify_client_secret:
|
if self._spotify_client_id and self._spotify_client_secret:
|
||||||
self._spotify_client = spotify.Client(
|
self._spotify_client = spotify.Client(
|
||||||
|
|
@ -119,14 +122,20 @@ class Node:
|
||||||
return self._players
|
return self._players
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bot(self) -> commands.Bot:
|
def bot(self) -> Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]]:
|
||||||
"""Property which returns the discord.py client linked to this node"""
|
"""Property which returns the discord.py client linked to this node"""
|
||||||
return self._bot
|
return self._bot
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def player_count(self) -> int:
|
def player_count(self) -> int:
|
||||||
|
"""Property which returns how many players are connected to this node"""
|
||||||
return len(self.players)
|
return len(self.players)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pool(self):
|
||||||
|
"""Property which returns the pool this node is apart of"""
|
||||||
|
return self._pool
|
||||||
|
|
||||||
async def _update_handler(self, data: dict):
|
async def _update_handler(self, data: dict):
|
||||||
await self._bot.wait_until_ready()
|
await self._bot.wait_until_ready()
|
||||||
|
|
||||||
|
|
@ -190,7 +199,7 @@ class Node:
|
||||||
|
|
||||||
await self._websocket.send_str(json.dumps(data))
|
await self._websocket.send_str(json.dumps(data))
|
||||||
|
|
||||||
def get_player(self, guild_id: int):
|
def get_player(self, guild_id: int) -> Player:
|
||||||
"""Takes a guild ID as a parameter. Returns a pomice Player object."""
|
"""Takes a guild ID as a parameter. Returns a pomice Player object."""
|
||||||
return self._players.get(guild_id, None)
|
return self._players.get(guild_id, None)
|
||||||
|
|
||||||
|
|
@ -203,7 +212,7 @@ class Node:
|
||||||
self._websocket_uri, headers=self._headers, heartbeat=60
|
self._websocket_uri, headers=self._headers, heartbeat=60
|
||||||
)
|
)
|
||||||
self._task = self._bot.loop.create_task(self._listen())
|
self._task = self._bot.loop.create_task(self._listen())
|
||||||
self.available = True
|
self._available = True
|
||||||
return self
|
return self
|
||||||
except aiohttp.WSServerHandshakeError:
|
except aiohttp.WSServerHandshakeError:
|
||||||
raise NodeConnectionFailure(
|
raise NodeConnectionFailure(
|
||||||
|
|
@ -230,7 +239,7 @@ class Node:
|
||||||
self.available = False
|
self.available = False
|
||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
|
|
||||||
async def get_tracks(self, query: str, ctx: commands.Context = None):
|
async def get_tracks(self, query: str, ctx: commands.Context = None, search_type: SearchType = None):
|
||||||
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
||||||
|
|
||||||
If you passed in Spotify API credentials, you can also pass in a
|
If you passed in Spotify API credentials, you can also pass in a
|
||||||
|
|
@ -262,6 +271,7 @@ class Node:
|
||||||
Track(
|
Track(
|
||||||
track_id=track.id,
|
track_id=track.id,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
|
search_type=search_type,
|
||||||
spotify=True,
|
spotify=True,
|
||||||
info={
|
info={
|
||||||
"title": track.name or "Unknown",
|
"title": track.name or "Unknown",
|
||||||
|
|
@ -302,6 +312,7 @@ class Node:
|
||||||
Track(
|
Track(
|
||||||
track_id=track.id,
|
track_id=track.id,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
|
search_type=search_type,
|
||||||
spotify=True,
|
spotify=True,
|
||||||
info={
|
info={
|
||||||
"title": track.name or "Unknown",
|
"title": track.name or "Unknown",
|
||||||
|
|
@ -339,6 +350,7 @@ class Node:
|
||||||
Track(
|
Track(
|
||||||
track_id=results.id,
|
track_id=results.id,
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
|
search_type=search_type,
|
||||||
spotify=True,
|
spotify=True,
|
||||||
info={
|
info={
|
||||||
"title": results.name or "Unknown",
|
"title": results.name or "Unknown",
|
||||||
|
|
@ -397,7 +409,6 @@ class Node:
|
||||||
for track in data["tracks"]
|
for track in data["tracks"]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class NodePool:
|
class NodePool:
|
||||||
"""The base class for the node pool.
|
"""The base class for the node pool.
|
||||||
This holds all the nodes that are to be used by the bot.
|
This holds all the nodes that are to be used by the bot.
|
||||||
|
|
@ -409,7 +420,7 @@ class NodePool:
|
||||||
return f"<Pomice.NodePool node_count={self.node_count}>"
|
return f"<Pomice.NodePool node_count={self.node_count}>"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nodes(self):
|
def nodes(self) -> dict:
|
||||||
"""Property which returns a dict with the node identifier and the Node object."""
|
"""Property which returns a dict with the node identifier and the Node object."""
|
||||||
return self._nodes
|
return self._nodes
|
||||||
|
|
||||||
|
|
@ -433,8 +444,8 @@ class NodePool:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_node(
|
async def create_node(
|
||||||
cls,
|
bot: Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]],
|
||||||
bot: Type[discord.Client],
|
cls,
|
||||||
host: str,
|
host: str,
|
||||||
port: str,
|
port: str,
|
||||||
password: str,
|
password: str,
|
||||||
|
|
@ -449,7 +460,7 @@ class NodePool:
|
||||||
raise NodeCreationError(f"A node with identifier '{identifier}' already exists.")
|
raise NodeCreationError(f"A node with identifier '{identifier}' already exists.")
|
||||||
|
|
||||||
node = Node(
|
node = Node(
|
||||||
pool=cls, bot=bot, host=host, port=port, password=password,
|
bot=bot, pool=cls, host=host, port=port, password=password,
|
||||||
identifier=identifier, spotify_client_id=spotify_client_id,
|
identifier=identifier, spotify_client_id=spotify_client_id,
|
||||||
spotify_client_secret=spotify_client_secret
|
spotify_client_secret=spotify_client_secret
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from . import SearchType
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -19,6 +20,10 @@ class Track:
|
||||||
self.info = info
|
self.info = info
|
||||||
self.spotify = spotify
|
self.spotify = spotify
|
||||||
|
|
||||||
|
if self.spotify:
|
||||||
|
self.youtube_result = None
|
||||||
|
self.search_type: SearchType = None
|
||||||
|
|
||||||
self.title = info.get("title")
|
self.title = info.get("title")
|
||||||
self.author = info.get("author")
|
self.author = info.get("author")
|
||||||
self.length = info.get("length")
|
self.length = info.get("length")
|
||||||
|
|
@ -68,11 +73,13 @@ class Playlist:
|
||||||
self.name = playlist_info.get("name")
|
self.name = playlist_info.get("name")
|
||||||
self.selected_track = playlist_info.get("selectedTrack")
|
self.selected_track = playlist_info.get("selectedTrack")
|
||||||
|
|
||||||
|
|
||||||
self._thumbnail = thumbnail
|
self._thumbnail = thumbnail
|
||||||
self._uri = uri
|
self._uri = uri
|
||||||
|
|
||||||
if self.spotify:
|
if self.spotify:
|
||||||
self.tracks = tracks
|
self.tracks = tracks
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.tracks = [
|
self.tracks = [
|
||||||
Track(track_id=track["track"], info=track["info"], ctx=ctx)
|
Track(track_id=track["track"], info=track["info"], ctx=ctx)
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, Type
|
from typing import Any, Dict, Type, Union
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord import VoiceChannel, VoiceProtocol
|
from discord import VoiceChannel, VoiceProtocol, Guild, Member
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
from . import events, filters, NodePool, objects
|
from . import events, filters, NodePool, objects, Node
|
||||||
from .exceptions import TrackInvalidPosition
|
from .exceptions import TrackInvalidPosition
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -17,17 +17,17 @@ class Player(VoiceProtocol):
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, client: Type[commands.Bot], channel: VoiceChannel):
|
def __init__(self, client: Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]], channel: VoiceChannel):
|
||||||
super().__init__(client=client, channel=channel)
|
super().__init__(client=client, channel=channel)
|
||||||
|
|
||||||
self.client = client
|
self.client = client
|
||||||
self.bot = client
|
self._bot: Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]] = client
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
self.guild: discord.Guild = self.channel.guild
|
self._guild: discord.Guild = self.channel.guild
|
||||||
self._dj: discord.Member = None
|
self._dj: discord.Member = None
|
||||||
|
|
||||||
self._node = NodePool.get_node()
|
self._node = NodePool.get_node()
|
||||||
self.current: objects.Track = None
|
self._current: objects.Track = None
|
||||||
self._filter: filters.Filter = None
|
self._filter: filters.Filter = None
|
||||||
self._volume = 100
|
self._volume = 100
|
||||||
self._paused = False
|
self._paused = False
|
||||||
|
|
@ -73,6 +73,42 @@ class Player(VoiceProtocol):
|
||||||
"""Property which returns whether or not the player has a track which is paused or not."""
|
"""Property which returns whether or not the player has a track which is paused or not."""
|
||||||
return self._is_connected and self._paused
|
return self._is_connected and self._paused
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current(self) -> objects.Track:
|
||||||
|
"""Property which returns the currently playing track"""
|
||||||
|
return self._current
|
||||||
|
|
||||||
|
@property
|
||||||
|
def node(self) -> Node:
|
||||||
|
"""Property which returns the node the player is connected to"""
|
||||||
|
return self._node
|
||||||
|
|
||||||
|
@property
|
||||||
|
def guild(self) -> Guild:
|
||||||
|
"""Property which returns the guild associated with the player"""
|
||||||
|
return self._guild
|
||||||
|
|
||||||
|
@property
|
||||||
|
def volume(self) -> int:
|
||||||
|
"""Property which returns the players current volume"""
|
||||||
|
return self._volume
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dj(self) -> Member:
|
||||||
|
"""Property which returns the DJ for the player session"""
|
||||||
|
return self._dj
|
||||||
|
|
||||||
|
@property
|
||||||
|
def filter(self) -> filters.Filter:
|
||||||
|
"""Property which returns the currently applied filter, if one is applied"""
|
||||||
|
return self._filter
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bot(self) -> Type[Union[discord.Client, commands.Bot, commands.AutoShardedBot]]:
|
||||||
|
"""Property which returns the bot associated with this player instance"""
|
||||||
|
return self._bot
|
||||||
|
|
||||||
|
|
||||||
async def _update_state(self, data: dict):
|
async def _update_state(self, data: dict):
|
||||||
state: dict = data.get("state")
|
state: dict = data.get("state")
|
||||||
self._last_update = time.time() * 1000
|
self._last_update = time.time() * 1000
|
||||||
|
|
@ -112,7 +148,7 @@ class Player(VoiceProtocol):
|
||||||
async def get_tracks(self, query: str, ctx: commands.Context = None):
|
async def get_tracks(self, query: str, ctx: commands.Context = None):
|
||||||
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
"""Fetches tracks from the node's REST api to parse into Lavalink.
|
||||||
|
|
||||||
If you passed in Spotify API credentials, you can also pass in a Spotify URL of a playlist,
|
If you passed in Spotify API credentials when you created the node, you can also pass in a Spotify URL of a playlist,
|
||||||
album or track and it will be parsed accordingly.
|
album or track and it will be parsed accordingly.
|
||||||
|
|
||||||
You can also pass in a discord.py Context object to get a
|
You can also pass in a discord.py Context object to get a
|
||||||
|
|
@ -146,9 +182,11 @@ class Player(VoiceProtocol):
|
||||||
async def play(self, track: objects.Track, start_position: int = 0) -> objects.Track:
|
async def play(self, track: objects.Track, start_position: int = 0) -> objects.Track:
|
||||||
"""Plays a track. If a Spotify track is passed in, it will be handled accordingly."""
|
"""Plays a track. If a Spotify track is passed in, it will be handled accordingly."""
|
||||||
if track.spotify:
|
if track.spotify:
|
||||||
|
search_type = track.search_type or f"ytmsearch:{track.author} - {track.title}"
|
||||||
spotify_track: objects.Track = (await self._node.get_tracks(
|
spotify_track: objects.Track = (await self._node.get_tracks(
|
||||||
f"ytmsearch:{track.author} - {track.title}"
|
search_type
|
||||||
))[0]
|
))[0]
|
||||||
|
track.youtube_result = spotify_track
|
||||||
await self._node.send(
|
await self._node.send(
|
||||||
op="play",
|
op="play",
|
||||||
guildId=str(self.guild.id),
|
guildId=str(self.guild.id),
|
||||||
|
|
@ -166,8 +204,8 @@ class Player(VoiceProtocol):
|
||||||
endTime=track.length,
|
endTime=track.length,
|
||||||
noReplace=False
|
noReplace=False
|
||||||
)
|
)
|
||||||
self.current = track
|
self._current = track
|
||||||
return self.current
|
return self._current
|
||||||
|
|
||||||
async def seek(self, position: float) -> float:
|
async def seek(self, position: float) -> float:
|
||||||
"""Seeks to a position in the currently playing track milliseconds"""
|
"""Seeks to a position in the currently playing track milliseconds"""
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ExponentialBackoff',
|
'ExponentialBackoff',
|
||||||
|
|
@ -78,3 +79,30 @@ class NodeStats:
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'<Pomice.NodeStats total_players={self.players_total} playing_active={self.players_active}>'
|
return f'<Pomice.NodeStats total_players={self.players_total} playing_active={self.players_active}>'
|
||||||
|
|
||||||
|
|
||||||
|
class Queue:
|
||||||
|
"""Pomice's very own queue implementation with some added features like:
|
||||||
|
- Toggleable shuffle
|
||||||
|
- Loop queue functionality
|
||||||
|
- Music player style queue, which doesn't remove tracks, allowing for playback of previously played tracks.
|
||||||
|
"""
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._queue = []
|
||||||
|
self._shuffle = False
|
||||||
|
if self._shuffle is True:
|
||||||
|
self._original_queue = []
|
||||||
|
self._looping = False
|
||||||
|
|
||||||
|
def put(self, item: Any):
|
||||||
|
"""Puts an item into the queue"""
|
||||||
|
return self._queue.extend(item)
|
||||||
|
|
||||||
|
def remove(self, item: Any):
|
||||||
|
"""Removes an item from the queue"""
|
||||||
|
if type(item) == int:
|
||||||
|
return self._queue.remove(self._queue[item])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue