fix player channel bug and fix endtime

This commit is contained in:
cloudwithax 2023-03-12 20:51:05 -04:00
parent 8ee1a39cb5
commit c35fd650f7
4 changed files with 59 additions and 29 deletions

View File

@ -20,7 +20,7 @@ if not discord.version_info.major >= 2:
"using 'pip install discord.py'", "using 'pip install discord.py'",
) )
__version__ = "2.2" __version__ = "2.3a"
__title__ = "pomice" __title__ = "pomice"
__author__ = "cloudwithax" __author__ = "cloudwithax"
__license__ = "GPL-3.0" __license__ = "GPL-3.0"

View File

@ -129,8 +129,11 @@ class Player(VoiceProtocol):
"_player_endpoint_uri", "_player_endpoint_uri",
) )
def __call__(self, client: Client, channel: VoiceChannel) -> Player: def __call__(self, client: Client, channel: VoiceChannel):
self.__init__(client, channel) # type: ignore self.client: Client = client
self.channel: VoiceChannel = channel
self._guild: Guild = channel.guild
return self return self
def __init__( def __init__(
@ -256,12 +259,21 @@ class Player(VoiceProtocol):
""" """
return self.guild.id not in self._node._players return self.guild.id not in self._node._players
def _adjust_end_time(self):
version = self._node._version
if version.major == 4:
return None
if version.major == 3:
if version.minor == 7 and version.fix == 5:
return None
else:
return "0"
async def _update_state(self, data: dict) -> None: async def _update_state(self, data: dict) -> None:
state: dict = data.get("state", {}) state: dict = data.get("state", {})
self._last_update = time.time() * 1000.0 self._last_update = int(state.get("time", 0))
self._is_connected = bool(state.get("connected")) self._is_connected = bool(state.get("connected"))
position = state.get("position") self._last_position = int(state.get("position", 0))
self._last_position = int(position) if position else 0
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None: async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None:
if {"sessionId", "event"} != self._voice_state.keys(): if {"sessionId", "event"} != self._voice_state.keys():
@ -319,11 +331,8 @@ class Player(VoiceProtocol):
self._ending_track = self._current self._ending_track = self._current
async def _swap_node(self, *, new_node: Node) -> None: async def _swap_node(self, *, new_node: Node) -> None:
data: dict = {
"position": self.position,
}
if self.current: if self.current:
data["encodedTrack"] = self.current.track_id data: dict = {"position": self.position, "encodedTrack": self.current.track_id}
del self._node._players[self._guild.id] del self._node._players[self._guild.id]
self._node = new_node self._node = new_node
@ -422,6 +431,9 @@ class Player(VoiceProtocol):
) -> Track: ) -> Track:
"""Plays a track. If a Spotify track is passed in, it will be handled accordingly.""" """Plays a track. If a Spotify track is passed in, it will be handled accordingly."""
end_time = self._adjust_end_time()
print(f"got end time of {end_time}")
# Make sure we've never searched the track before # Make sure we've never searched the track before
if track.original is None: if track.original is None:
# First lets try using the tracks ISRC, every track has one (hopefully) # First lets try using the tracks ISRC, every track has one (hopefully)
@ -453,7 +465,7 @@ class Player(VoiceProtocol):
data = { data = {
"encodedTrack": search.track_id, "encodedTrack": search.track_id,
"position": str(start), "position": str(start),
"endTime": str(track.length), "endTime": end_time,
} }
track.original = search track.original = search
track.track_id = search.track_id track.track_id = search.track_id
@ -462,7 +474,7 @@ class Player(VoiceProtocol):
data = { data = {
"encodedTrack": track.track_id, "encodedTrack": track.track_id,
"position": str(start), "position": str(start),
"endTime": str(track.length), "endTime": end_time,
} }
# Lets set the current track before we play it so any # Lets set the current track before we play it so any
@ -489,8 +501,8 @@ class Player(VoiceProtocol):
# Lavalink v4 changed the way the end time parameter works # Lavalink v4 changed the way the end time parameter works
# so now the end time cannot be zero. # so now the end time cannot be zero.
# If it isnt zero, it'll match the length of the track, # If it isnt zero, it'll be set to None.
# otherwise itll be set here: # Otherwise, it'll be set here:
if end > 0: if end > 0:
data["endTime"] = str(end) data["endTime"] = str(end)

View File

@ -34,6 +34,7 @@ from .objects import Playlist
from .objects import Track from .objects import Track
from .routeplanner import RoutePlanner from .routeplanner import RoutePlanner
from .utils import ExponentialBackoff from .utils import ExponentialBackoff
from .utils import LavalinkVersion
from .utils import NodeStats from .utils import NodeStats
from .utils import Ping from .utils import Ping
@ -122,7 +123,7 @@ class Node:
self._session_id: Optional[str] = None self._session_id: Optional[str] = None
self._available: bool = False self._available: bool = False
self._version: int = 0 self._version: LavalinkVersion = None
self._route_planner = RoutePlanner(self) self._route_planner = RoutePlanner(self)
@ -201,6 +202,27 @@ class Node:
"""Alias for `Node.latency`, returns the latency of the node""" """Alias for `Node.latency`, returns the latency of the node"""
return self.latency return self.latency
async def _handle_version_check(self, version: str):
if version.endswith("-SNAPSHOT"):
# we're just gonna assume all snapshot versions correlate with v4
self._version = LavalinkVersion(major=4, minor=0, fix=0)
return
# this crazy ass line maps the split version string into
# an iterable with ints instead of strings and then
# turns that iterable into a tuple. yeah, i know
split = tuple(map(int, tuple(version.split("."))))
self._version = LavalinkVersion(*split)
if not version.endswith("-SNAPSHOT") and (
self._version.major == 3 and self._version.minor < 7
):
self._available = False
raise LavalinkVersionIncompatible(
"The Lavalink version you're using is incompatible. "
"Lavalink version 3.7.0 or above is required to use this library.",
)
async def _update_handler(self, data: dict) -> None: async def _update_handler(self, data: dict) -> None:
await self._bot.wait_until_ready() await self._bot.wait_until_ready()
@ -294,7 +316,7 @@ class Node:
uri: str = ( uri: str = (
f"{self._rest_uri}/" f"{self._rest_uri}/"
f'{f"v{self._version}/" if include_version else ""}' f'{f"v{self._version.major}/" if include_version else ""}'
f"{path}" f"{path}"
f'{f"/{guild_id}" if guild_id else ""}' f'{f"/{guild_id}" if guild_id else ""}'
f'{f"?{query}" if query else ""}' f'{f"?{query}" if query else ""}'
@ -338,22 +360,11 @@ class Node:
ignore_if_available=True, ignore_if_available=True,
include_version=False, include_version=False,
) )
version = version.replace(".", "")
if not version.endswith("-SNAPSHOT") and int(version) < 370:
self._available = False
raise LavalinkVersionIncompatible(
"The Lavalink version you're using is incompatible. "
"Lavalink version 3.7.0 or above is required to use this library.",
)
if version.endswith("-SNAPSHOT"): await self._handle_version_check(version=version)
# we're just gonna assume all snapshot versions correlate with v4
self._version = 4
else:
self._version = int(version[:1])
self._websocket = await self._session.ws_connect( self._websocket = await self._session.ws_connect(
f"{self._websocket_uri}/v{self._version}/websocket", f"{self._websocket_uri}/v{self._version.major}/websocket",
headers=self._headers, headers=self._headers,
heartbeat=self._heartbeat, heartbeat=self._heartbeat,
) )

View File

@ -8,6 +8,7 @@ from typing import Any
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
from typing import Iterable from typing import Iterable
from typing import NamedTuple
from typing import Optional from typing import Optional
from .enums import RouteIPType from .enums import RouteIPType
@ -224,3 +225,9 @@ class Ping:
s_runtime = 1000 * (cost_time) s_runtime = 1000 * (cost_time)
return s_runtime return s_runtime
class LavalinkVersion(NamedTuple):
major: int
minor: int
fix: int