add per-track filter applying

This commit is contained in:
cloudwithax 2022-10-10 23:25:53 -04:00
parent c9bba65f48
commit 1e12d79d68
3 changed files with 35 additions and 9 deletions

View File

@ -1,9 +1,10 @@
import re
from typing import Optional
from typing import List, Optional
from discord.ext import commands
from .enums import SearchType
from .filters import Filter
SOUNDCLOUD_URL_REGEX = re.compile(
r"^(https?:\/\/)?(www.)?(m\.)?soundcloud\.com\/[\w\-\.]+(\/)+[\w\-\.]+/?$"
@ -24,10 +25,12 @@ class Track:
spotify: bool = False,
search_type: SearchType = SearchType.ytsearch,
spotify_track = None,
filters: Optional[List[Filter]] = None
):
self.track_id = track_id
self.info = info
self.spotify = spotify
self.filters: List[Filter] = filters
self.original: Optional[Track] = None if spotify else self
self._search_type = search_type

View File

@ -238,7 +238,8 @@ class Player(VoiceProtocol):
query: str,
*,
ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch
search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None
):
"""Fetches tracks from the node's REST api to parse into Lavalink.
@ -246,10 +247,13 @@ class Player(VoiceProtocol):
you can also pass in a Spotify URL of a playlist, album or track and it will be parsed
accordingly.
You can also pass in a discord.py Context object to get a
You can pass in a discord.py Context object to get a
Context object on any track you search.
You may also pass in a List of filters
to be applied to your track once it plays.
"""
return await self._node.get_tracks(query, ctx=ctx, search_type=search_type)
return await self._node.get_tracks(query, ctx=ctx, search_type=search_type, filters=filters)
async def connect(self, *, timeout: float, reconnect: bool, self_deaf: bool = False, self_mute: bool = False):
await self.guild.change_voice_state(channel=self.channel, self_deaf=self_deaf, self_mute=self_mute)
@ -291,6 +295,7 @@ class Player(VoiceProtocol):
ignore_if_playing: bool = False
) -> Track:
"""Plays a track. If a Spotify track is passed in, it will be handled accordingly."""
# Make sure we've never searched the track before
if track.original is None:
# First lets try using the tracks ISRC, every track has one (hopefully)
@ -329,6 +334,15 @@ class Player(VoiceProtocol):
"noReplace": ignore_if_playing
}
# Apply track filters
if track.filters:
# First lets remove all filters quickly
await self.reset_filters()
# Now apply all filters
for filter in track.filters:
await self.add_filter(filter=filter)
if end > 0:
data["endTime"] = str(end)

View File

@ -4,7 +4,7 @@ import asyncio
import json
import random
import re
from typing import Dict, Optional, TYPE_CHECKING
from typing import Dict, List, Optional, TYPE_CHECKING
from urllib.parse import quote
import aiohttp
@ -27,6 +27,7 @@ from .exceptions import (
NoNodesAvailable,
TrackLoadError
)
from .filters import Filter
from .objects import Playlist, Track
from .utils import ExponentialBackoff, NodeStats, Ping
@ -286,15 +287,19 @@ class Node:
query: str,
*,
ctx: Optional[commands.Context] = None,
search_type: SearchType = SearchType.ytsearch
search_type: SearchType = SearchType.ytsearch,
filters: Optional[List[Filter]] = None
):
"""Fetches tracks from the node's REST api to parse into Lavalink.
If you passed in Spotify API credentials, you can also pass in a
Spotify URL of a playlist, album or track and it will be parsed accordingly.
You can also pass in a discord.py Context object to get a
You can pass in a discord.py Context object to get a
Context object on any track you search.
You may also pass in a List of filters
to be applied to your track once it plays.
"""
if not URL_REGEX.match(query) and not re.match(r"(?:ytm?|sc)search:.", query):
@ -318,6 +323,7 @@ class Node:
search_type=search_type,
spotify=True,
spotify_track=spotify_results,
filters=filters,
info={
"title": spotify_results.name,
"author": spotify_results.artists,
@ -340,6 +346,7 @@ class Node:
search_type=search_type,
spotify=True,
spotify_track=track,
filters=filters,
info={
"title": track.name,
"author": track.artists,
@ -384,7 +391,8 @@ class Node:
"position": info.get("position"),
"identifier": info.get("identifier")
},
ctx=ctx
ctx=ctx,
filters=filters
)
]
@ -420,7 +428,8 @@ class Node:
Track(
track_id=track["track"],
info=track["info"],
ctx=ctx
ctx=ctx,
filters=filters
)
for track in data["tracks"]
]