add local sources, fix queue jumping on loop

This commit is contained in:
cloudwithax 2023-03-27 00:08:46 -04:00
parent 6ed2fd961b
commit b73af37bbf
6 changed files with 68 additions and 7 deletions

View File

@ -1 +1 @@
3.10.9 3.8.10

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import base64 import base64
import logging
import re import re
from datetime import datetime from datetime import datetime
from typing import Dict from typing import Dict
@ -36,6 +37,7 @@ class Client:
self.token: str = "" self.token: str = ""
self.headers: Dict[str, str] = {} self.headers: Dict[str, str] = {}
self.session: aiohttp.ClientSession = None # type: ignore self.session: aiohttp.ClientSession = None # type: ignore
self._log = logging.getLogger(__name__)
async def request_token(self) -> None: async def request_token(self) -> None:
if not self.session: if not self.session:
@ -65,6 +67,7 @@ class Client:
).decode() ).decode()
token_data = json.loads(token_json) token_data = json.loads(token_json)
self.expiry = datetime.fromtimestamp(token_data["exp"]) self.expiry = datetime.fromtimestamp(token_data["exp"])
self._log.debug(f"Fetched Apple Music bearer token successfully")
async def search(self, query: str) -> Union[Album, Playlist, Song, Artist]: async def search(self, query: str) -> Union[Album, Playlist, Song, Artist]:
if not self.token or datetime.utcnow() > self.expiry: if not self.token or datetime.utcnow() > self.expiry:
@ -96,6 +99,9 @@ class Client:
f"Error while fetching results: {resp.status} {resp.reason}", f"Error while fetching results: {resp.status} {resp.reason}",
) )
data: dict = await resp.json(loads=json.loads) data: dict = await resp.json(loads=json.loads)
self._log.debug(
f"Made request to Apple Music API with status {resp.status} and response {data}",
)
data = data["data"][0] data = data["data"][0]

View File

@ -52,6 +52,8 @@ class TrackType(Enum):
TrackType.APPLE_MUSIC defines that the track is from Apple Music. TrackType.APPLE_MUSIC defines that the track is from Apple Music.
TrackType.HTTP defines that the track is from an HTTP source. TrackType.HTTP defines that the track is from an HTTP source.
TrackType.LOCAL defines that the track is from a local source.
""" """
# We don't have to define anything special for these, since these just serve as flags # We don't have to define anything special for these, since these just serve as flags
@ -60,6 +62,7 @@ class TrackType(Enum):
SPOTIFY = "spotify" SPOTIFY = "spotify"
APPLE_MUSIC = "apple_music" APPLE_MUSIC = "apple_music"
HTTP = "http" HTTP = "http"
LOCAL = "local"
def __str__(self) -> str: def __str__(self) -> str:
return self.value return self.value

View File

@ -5,6 +5,8 @@ import logging
import random import random
import re import re
import time import time
from os import path
from pathlib import Path
from typing import Any from typing import Any
from typing import Dict from typing import Dict
from typing import List from typing import List
@ -530,9 +532,6 @@ class Node:
timestamp = None timestamp = None
if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query):
query = f"{search_type}:{query}"
if filters: if filters:
for filter in filters: for filter in filters:
filter.set_preload() filter.set_preload()
@ -696,7 +695,38 @@ class Node:
), ),
] ]
elif path.exists(path.dirname(query)):
local_file = Path(query)
data: dict = await self.send( # type: ignore
method="GET",
path="loadtracks",
query=f"identifier={quote(query)}",
)
track: dict = data["tracks"][0] # type: ignore
info: dict = track["info"] # type: ignore
return [
Track(
track_id=track["track"],
info={
"title": local_file.name,
"author": "Unknown",
"length": info["length"],
"uri": quote(local_file.as_uri()),
"position": info["position"],
"identifier": info["identifier"],
},
ctx=ctx,
track_type=TrackType.LOCAL,
filters=filters,
),
]
else: else:
if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query):
query = f"{search_type}:{query}"
# If YouTube url contains a timestamp, capture it for use later. # If YouTube url contains a timestamp, capture it for use later.
if match := URLRegex.YOUTUBE_TIMESTAMP.match(query): if match := URLRegex.YOUTUBE_TIMESTAMP.match(query):

View File

@ -348,7 +348,23 @@ class Queue(Iterable[Track]):
track.filters = None track.filters = None
def jump(self, item: Track) -> None: def jump(self, item: Track) -> None:
"""Removes all tracks before the.""" """
Jumps to the item specified in the queue.
If the queue is not looping, the queue will be mutated.
Otherwise, the current item will be adjusted to the item
before the specified track.
The queue is adjusted so that the next item that is retrieved
is the track that is specified, effectively 'jumping' the queue.
"""
if self._loop_mode == LoopMode.TRACK:
raise QueueException("Jumping the queue whilst looping a track is not allowed.")
index = self.find_position(item) index = self.find_position(item)
if self._loop_mode == LoopMode.QUEUE:
self._current_item = self._queue[index - 1]
else:
new_queue = self._queue[index : self.size] new_queue = self._queue[index : self.size]
self._queue = new_queue self._queue = new_queue

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import logging
import re import re
import time import time
from base64 import b64encode from base64 import b64encode
@ -46,6 +47,7 @@ class Client:
"Authorization": f"Basic {self._auth_token.decode()}", "Authorization": f"Basic {self._auth_token.decode()}",
} }
self._bearer_headers: Optional[Dict] = None self._bearer_headers: Optional[Dict] = None
self._log = logging.getLogger(__name__)
async def _fetch_bearer_token(self) -> None: async def _fetch_bearer_token(self) -> None:
_data = {"grant_type": "client_credentials"} _data = {"grant_type": "client_credentials"}
@ -60,6 +62,7 @@ class Client:
) )
data: dict = await resp.json(loads=json.loads) data: dict = await resp.json(loads=json.loads)
self._log.debug(f"Fetched Spotify bearer token successfully")
self._bearer_token = data["access_token"] self._bearer_token = data["access_token"]
self._expiry = time.time() + (int(data["expires_in"]) - 10) self._expiry = time.time() + (int(data["expires_in"]) - 10)
@ -87,6 +90,9 @@ class Client:
) )
data: dict = await resp.json(loads=json.loads) data: dict = await resp.json(loads=json.loads)
self._log.debug(
f"Made request to Spotify API with status {resp.status} and response {data}",
)
if spotify_type == "track": if spotify_type == "track":
return Track(data) return Track(data)