add local sources, fix queue jumping on loop

This commit is contained in:
cloudwithax 2023-03-27 00:08:46 -04:00
parent 6ed2fd961b
commit b73af37bbf
6 changed files with 68 additions and 7 deletions

View File

@ -1 +1 @@
3.10.9
3.8.10

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import base64
import logging
import re
from datetime import datetime
from typing import Dict
@ -36,6 +37,7 @@ class Client:
self.token: str = ""
self.headers: Dict[str, str] = {}
self.session: aiohttp.ClientSession = None # type: ignore
self._log = logging.getLogger(__name__)
async def request_token(self) -> None:
if not self.session:
@ -65,6 +67,7 @@ class Client:
).decode()
token_data = json.loads(token_json)
self.expiry = datetime.fromtimestamp(token_data["exp"])
self._log.debug(f"Fetched Apple Music bearer token successfully")
async def search(self, query: str) -> Union[Album, Playlist, Song, Artist]:
if not self.token or datetime.utcnow() > self.expiry:
@ -96,6 +99,9 @@ class Client:
f"Error while fetching results: {resp.status} {resp.reason}",
)
data: dict = await resp.json(loads=json.loads)
self._log.debug(
f"Made request to Apple Music API with status {resp.status} and response {data}",
)
data = data["data"][0]

View File

@ -52,6 +52,8 @@ class TrackType(Enum):
TrackType.APPLE_MUSIC defines that the track is from Apple Music.
TrackType.HTTP defines that the track is from an HTTP source.
TrackType.LOCAL defines that the track is from a local source.
"""
# We don't have to define anything special for these, since these just serve as flags
@ -60,6 +62,7 @@ class TrackType(Enum):
SPOTIFY = "spotify"
APPLE_MUSIC = "apple_music"
HTTP = "http"
LOCAL = "local"
def __str__(self) -> str:
return self.value

View File

@ -5,6 +5,8 @@ import logging
import random
import re
import time
from os import path
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
@ -530,9 +532,6 @@ class Node:
timestamp = None
if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query):
query = f"{search_type}:{query}"
if filters:
for filter in filters:
filter.set_preload()
@ -696,7 +695,38 @@ class Node:
),
]
elif path.exists(path.dirname(query)):
local_file = Path(query)
data: dict = await self.send( # type: ignore
method="GET",
path="loadtracks",
query=f"identifier={quote(query)}",
)
track: dict = data["tracks"][0] # type: ignore
info: dict = track["info"] # type: ignore
return [
Track(
track_id=track["track"],
info={
"title": local_file.name,
"author": "Unknown",
"length": info["length"],
"uri": quote(local_file.as_uri()),
"position": info["position"],
"identifier": info["identifier"],
},
ctx=ctx,
track_type=TrackType.LOCAL,
filters=filters,
),
]
else:
if not URLRegex.BASE_URL.match(query) and not re.match(r"(?:ytm?|sc)search:.", query):
query = f"{search_type}:{query}"
# If YouTube url contains a timestamp, capture it for use later.
if match := URLRegex.YOUTUBE_TIMESTAMP.match(query):

View File

@ -348,7 +348,23 @@ class Queue(Iterable[Track]):
track.filters = None
def jump(self, item: Track) -> None:
"""Removes all tracks before the."""
"""
Jumps to the item specified in the queue.
If the queue is not looping, the queue will be mutated.
Otherwise, the current item will be adjusted to the item
before the specified track.
The queue is adjusted so that the next item that is retrieved
is the track that is specified, effectively 'jumping' the queue.
"""
if self._loop_mode == LoopMode.TRACK:
raise QueueException("Jumping the queue whilst looping a track is not allowed.")
index = self.find_position(item)
new_queue = self._queue[index : self.size]
self._queue = new_queue
if self._loop_mode == LoopMode.QUEUE:
self._current_item = self._queue[index - 1]
else:
new_queue = self._queue[index : self.size]
self._queue = new_queue

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import logging
import re
import time
from base64 import b64encode
@ -46,6 +47,7 @@ class Client:
"Authorization": f"Basic {self._auth_token.decode()}",
}
self._bearer_headers: Optional[Dict] = None
self._log = logging.getLogger(__name__)
async def _fetch_bearer_token(self) -> None:
_data = {"grant_type": "client_credentials"}
@ -60,6 +62,7 @@ class Client:
)
data: dict = await resp.json(loads=json.loads)
self._log.debug(f"Fetched Spotify bearer token successfully")
self._bearer_token = data["access_token"]
self._expiry = time.time() + (int(data["expires_in"]) - 10)
@ -87,6 +90,9 @@ class Client:
)
data: dict = await resp.json(loads=json.loads)
self._log.debug(
f"Made request to Spotify API with status {resp.status} and response {data}",
)
if spotify_type == "track":
return Track(data)