fix issues related to loading files/http links, added spotify recommendation querying, changed loglevel enum behavior

This commit is contained in:
cloudwithax 2024-02-06 17:31:51 -05:00
parent 705ac9feab
commit 2a492c793f
4 changed files with 119 additions and 56 deletions

View File

@ -20,7 +20,7 @@ if not discord.version_info.major >= 2:
"using 'pip install discord.py'", "using 'pip install discord.py'",
) )
__version__ = "2.8.1" __version__ = "2.9.0a"
__title__ = "pomice" __title__ = "pomice"
__author__ = "cloudwithax" __author__ = "cloudwithax"
__license__ = "GPL-3.0" __license__ = "GPL-3.0"

View File

@ -273,3 +273,10 @@ class LogLevel(IntEnum):
WARN = 30 WARN = 30
ERROR = 40 ERROR = 40
CRITICAL = 50 CRITICAL = 50
@classmethod
def from_str(cls, level_str):
try:
return cls[level_str.upper()]
except KeyError:
raise ValueError(f"No such log level: {level_str}")

View File

@ -30,6 +30,7 @@ from . import applemusic
from . import spotify from . import spotify
from .enums import * from .enums import *
from .enums import LogLevel from .enums import LogLevel
from .exceptions import InvalidSpotifyClientAuthorization
from .exceptions import LavalinkVersionIncompatible from .exceptions import LavalinkVersionIncompatible
from .exceptions import NodeConnectionFailure from .exceptions import NodeConnectionFailure
from .exceptions import NodeCreationError from .exceptions import NodeCreationError
@ -702,61 +703,6 @@ class Node:
uri=spotify_results.uri, uri=spotify_results.uri,
) )
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
data: dict = await self.send(
method="GET",
path="loadtracks",
query=f"identifier={quote(query)}",
)
track: dict = data["tracks"][0]
info: dict = track["info"]
return [
Track(
track_id=track["track"],
info={
"title": discord_url.group("file"),
"author": "Unknown",
"length": info["length"],
"uri": info["uri"],
"position": info["position"],
"identifier": info["identifier"],
},
ctx=ctx,
track_type=TrackType.HTTP,
filters=filters,
),
]
elif path.exists(path.dirname(query)):
local_file = Path(query)
data: dict = await self.send( # type: ignore
method="GET",
path="loadtracks",
query=f"identifier={quote(query)}",
)
track: dict = data["tracks"][0] # type: ignore
info: dict = track["info"] # type: ignore
return [
Track(
track_id=track["track"],
info={
"title": local_file.name,
"author": "Unknown",
"length": info["length"],
"uri": quote(local_file.as_uri()),
"position": info["position"],
"identifier": info["identifier"],
},
ctx=ctx,
track_type=TrackType.LOCAL,
filters=filters,
),
]
else: else:
if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query): if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query):
query = f"{search_type}:{query}" query = f"{search_type}:{query}"
@ -819,6 +765,47 @@ class Node:
elif load_type in ("SEARCH_RESULT", "TRACK_LOADED", "track", "search"): elif load_type in ("SEARCH_RESULT", "TRACK_LOADED", "track", "search"):
if self._version.major >= 4 and isinstance(data[data_type], dict): if self._version.major >= 4 and isinstance(data[data_type], dict):
data[data_type] = [data[data_type]] data[data_type] = [data[data_type]]
if path.exists(path.dirname(query)):
local_file = Path(query)
return [
Track(
track_id=track["track"],
info={
"title": local_file.name,
"author": "Unknown",
"length": track["info"]["length"],
"uri": quote(local_file.as_uri()),
"position": track["info"]["position"],
"identifier": track["info"]["identifier"],
},
ctx=ctx,
track_type=TrackType.LOCAL,
filters=filters,
)
for track in data[data_type]
]
elif discord_url := URLRegex.DISCORD_MP3_URL.match(query):
return [
Track(
track_id=track["encoded"],
info={
"title": discord_url.group("file"),
"author": "Unknown",
"length": track["info"]["length"],
"uri": track["info"]["uri"],
"position": track["info"]["position"],
"identifier": track["info"]["identifier"],
},
ctx=ctx,
track_type=TrackType.HTTP,
filters=filters,
)
for track in data[data_type]
]
return [ return [
Track( Track(
track_id=track["encoded"], track_id=track["encoded"],
@ -885,6 +872,57 @@ class Node:
"The specfied track must be either a YouTube or Spotify track to recieve recommendations.", "The specfied track must be either a YouTube or Spotify track to recieve recommendations.",
) )
async def search_spotify_recommendations(
self,
query: str,
*,
ctx: Optional[commands.Context] = None,
filters: Optional[List[Filter]] = None,
) -> Optional[Union[List[Track], Playlist]]:
"""
Searches for recommendations on Spotify and returns a list of tracks based on the query.
You must have Spotify enabled for this to work.
You can pass in a discord.py Context object to get a
Context object on all tracks that get recommended.
"""
if not self._spotify_client:
raise InvalidSpotifyClientAuthorization(
"You must have Spotify enabled to use this feature.",
)
results = await self._spotify_client.track_search(query=query) # type: ignore
if not results:
raise TrackLoadError(
"Unable to find any tracks based on the query.",
)
tracks = [
Track(
track_id=track.id,
ctx=ctx,
track_type=TrackType.SPOTIFY,
info={
"title": track.name,
"author": track.artists,
"length": track.length,
"identifier": track.id,
"uri": track.uri,
"isStream": False,
"isSeekable": True,
"position": 0,
"thumbnail": track.image,
"isrc": track.isrc,
},
requester=self.bot.user,
)
for track in results
]
track = tracks[0]
return await self.get_recommendations(track=track, ctx=ctx)
class NodePool: class NodePool:
"""The base class for the node pool. """The base class for the node pool.

View File

@ -8,6 +8,7 @@ from typing import Dict
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
from urllib.parse import quote
import aiohttp import aiohttp
import orjson as json import orjson as json
@ -177,3 +178,20 @@ class Client:
tracks = [Track(track) for track in data["tracks"]] tracks = [Track(track) for track in data["tracks"]]
return tracks return tracks
async def track_search(self, *, query: str) -> List[Track]:
if not self._bearer_token or time.time() >= self._expiry:
await self._fetch_bearer_token()
request_url = f"https://api.spotify.com/v1/search?q={quote(query)}&type=track"
resp = await self.session.get(request_url, headers=self._bearer_headers)
if resp.status != 200:
raise SpotifyRequestException(
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
tracks = [Track(track) for track in data["tracks"]["items"]]
return tracks