fix player channel bug and fix endtime

This commit is contained in:
cloudwithax 2023-03-12 20:51:05 -04:00
parent 8ee1a39cb5
commit c35fd650f7
4 changed files with 59 additions and 29 deletions

View File

@ -20,7 +20,7 @@ if not discord.version_info.major >= 2:
"using 'pip install discord.py'",
)
__version__ = "2.2"
__version__ = "2.3a"
__title__ = "pomice"
__author__ = "cloudwithax"
__license__ = "GPL-3.0"

View File

@ -129,8 +129,11 @@ class Player(VoiceProtocol):
"_player_endpoint_uri",
)
def __call__(self, client: Client, channel: VoiceChannel) -> Player:
self.__init__(client, channel) # type: ignore
def __call__(self, client: Client, channel: VoiceChannel):
self.client: Client = client
self.channel: VoiceChannel = channel
self._guild: Guild = channel.guild
return self
def __init__(
@ -256,12 +259,21 @@ class Player(VoiceProtocol):
"""
return self.guild.id not in self._node._players
def _adjust_end_time(self):
version = self._node._version
if version.major == 4:
return None
if version.major == 3:
if version.minor == 7 and version.fix == 5:
return None
else:
return "0"
async def _update_state(self, data: dict) -> None:
state: dict = data.get("state", {})
self._last_update = time.time() * 1000.0
self._last_update = int(state.get("time", 0))
self._is_connected = bool(state.get("connected"))
position = state.get("position")
self._last_position = int(position) if position else 0
self._last_position = int(state.get("position", 0))
async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None:
if {"sessionId", "event"} != self._voice_state.keys():
@ -319,11 +331,8 @@ class Player(VoiceProtocol):
self._ending_track = self._current
async def _swap_node(self, *, new_node: Node) -> None:
data: dict = {
"position": self.position,
}
if self.current:
data["encodedTrack"] = self.current.track_id
data: dict = {"position": self.position, "encodedTrack": self.current.track_id}
del self._node._players[self._guild.id]
self._node = new_node
@ -422,6 +431,9 @@ class Player(VoiceProtocol):
) -> Track:
"""Plays a track. If a Spotify track is passed in, it will be handled accordingly."""
end_time = self._adjust_end_time()
print(f"got end time of {end_time}")
# Make sure we've never searched the track before
if track.original is None:
# First lets try using the tracks ISRC, every track has one (hopefully)
@ -453,7 +465,7 @@ class Player(VoiceProtocol):
data = {
"encodedTrack": search.track_id,
"position": str(start),
"endTime": str(track.length),
"endTime": end_time,
}
track.original = search
track.track_id = search.track_id
@ -462,7 +474,7 @@ class Player(VoiceProtocol):
data = {
"encodedTrack": track.track_id,
"position": str(start),
"endTime": str(track.length),
"endTime": end_time,
}
# Lets set the current track before we play it so any
@ -489,8 +501,8 @@ class Player(VoiceProtocol):
# Lavalink v4 changed the way the end time parameter works
# so now the end time cannot be zero.
# If it isnt zero, it'll match the length of the track,
# otherwise itll be set here:
# If it isnt zero, it'll be set to None.
# Otherwise, it'll be set here:
if end > 0:
data["endTime"] = str(end)

View File

@ -34,6 +34,7 @@ from .objects import Playlist
from .objects import Track
from .routeplanner import RoutePlanner
from .utils import ExponentialBackoff
from .utils import LavalinkVersion
from .utils import NodeStats
from .utils import Ping
@ -122,7 +123,7 @@ class Node:
self._session_id: Optional[str] = None
self._available: bool = False
self._version: int = 0
self._version: LavalinkVersion = None
self._route_planner = RoutePlanner(self)
@ -201,6 +202,27 @@ class Node:
"""Alias for `Node.latency`, returns the latency of the node"""
return self.latency
async def _handle_version_check(self, version: str):
if version.endswith("-SNAPSHOT"):
# we're just gonna assume all snapshot versions correlate with v4
self._version = LavalinkVersion(major=4, minor=0, fix=0)
return
# this crazy ass line maps the split version string into
# an iterable with ints instead of strings and then
# turns that iterable into a tuple. yeah, i know
split = tuple(map(int, tuple(version.split("."))))
self._version = LavalinkVersion(*split)
if not version.endswith("-SNAPSHOT") and (
self._version.major == 3 and self._version.minor < 7
):
self._available = False
raise LavalinkVersionIncompatible(
"The Lavalink version you're using is incompatible. "
"Lavalink version 3.7.0 or above is required to use this library.",
)
async def _update_handler(self, data: dict) -> None:
await self._bot.wait_until_ready()
@ -294,7 +316,7 @@ class Node:
uri: str = (
f"{self._rest_uri}/"
f'{f"v{self._version}/" if include_version else ""}'
f'{f"v{self._version.major}/" if include_version else ""}'
f"{path}"
f'{f"/{guild_id}" if guild_id else ""}'
f'{f"?{query}" if query else ""}'
@ -338,22 +360,11 @@ class Node:
ignore_if_available=True,
include_version=False,
)
version = version.replace(".", "")
if not version.endswith("-SNAPSHOT") and int(version) < 370:
self._available = False
raise LavalinkVersionIncompatible(
"The Lavalink version you're using is incompatible. "
"Lavalink version 3.7.0 or above is required to use this library.",
)
if version.endswith("-SNAPSHOT"):
# we're just gonna assume all snapshot versions correlate with v4
self._version = 4
else:
self._version = int(version[:1])
await self._handle_version_check(version=version)
self._websocket = await self._session.ws_connect(
f"{self._websocket_uri}/v{self._version}/websocket",
f"{self._websocket_uri}/v{self._version.major}/websocket",
headers=self._headers,
heartbeat=self._heartbeat,
)

View File

@ -8,6 +8,7 @@ from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import NamedTuple
from typing import Optional
from .enums import RouteIPType
@ -224,3 +225,9 @@ class Ping:
s_runtime = 1000 * (cost_time)
return s_runtime
class LavalinkVersion(NamedTuple):
major: int
minor: int
fix: int