Merge pull request #39 from NiceAesth/fix
feat: allow custom logging handler; fix: missing typings
This commit is contained in:
commit
6d96a9e53d
2
Pipfile
2
Pipfile
|
|
@ -13,6 +13,8 @@ pre-commit = "*"
|
||||||
furo = "*"
|
furo = "*"
|
||||||
sphinx = "*"
|
sphinx = "*"
|
||||||
myst-parser = "*"
|
myst-parser = "*"
|
||||||
|
black = "*"
|
||||||
|
typing-extensions = "*"
|
||||||
|
|
||||||
[requires]
|
[requires]
|
||||||
python_version = "3.8"
|
python_version = "3.8"
|
||||||
|
|
|
||||||
|
|
@ -70,6 +70,10 @@ After you have initialized your function, we need to fill in the proper paramete
|
||||||
- `LogLevel`
|
- `LogLevel`
|
||||||
- The logging level for the node. The default logging level is `LogLevel.INFO`.
|
- The logging level for the node. The default logging level is `LogLevel.INFO`.
|
||||||
|
|
||||||
|
* - `log_handler`
|
||||||
|
- `Optional[logging.Handler]`
|
||||||
|
- The logging handler for the node. Set to `None` to disable the default logging handler.
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -131,10 +131,10 @@ class Player(VoiceProtocol):
|
||||||
"_player_endpoint_uri",
|
"_player_endpoint_uri",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, client: Client, channel: VoiceChannel):
|
def __call__(self, client: Client, channel: VoiceChannel) -> Player:
|
||||||
self.client: Client = client
|
self.client = client
|
||||||
self.channel: VoiceChannel = channel
|
self.channel = channel
|
||||||
self._guild: Guild = channel.guild
|
self._guild = channel.guild
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -262,7 +262,7 @@ class Player(VoiceProtocol):
|
||||||
"""
|
"""
|
||||||
return self.guild.id not in self._node._players
|
return self.guild.id not in self._node._players
|
||||||
|
|
||||||
def _adjust_end_time(self):
|
def _adjust_end_time(self) -> Optional[str]:
|
||||||
if self._node._version >= LavalinkVersion(3, 7, 5):
|
if self._node._version >= LavalinkVersion(3, 7, 5):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ from urllib.parse import quote
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from discord import Client
|
from discord import Client
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
from discord.utils import MISSING
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from . import applemusic
|
from . import applemusic
|
||||||
|
|
@ -49,6 +50,8 @@ __all__ = (
|
||||||
"NodePool",
|
"NodePool",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
VERSION_REGEX = re.compile(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?(?:[a-zA-Z0-9_-]+)?")
|
||||||
|
|
||||||
|
|
||||||
class Node:
|
class Node:
|
||||||
"""The base class for a node.
|
"""The base class for a node.
|
||||||
|
|
@ -86,6 +89,7 @@ class Node:
|
||||||
"_apple_music_client",
|
"_apple_music_client",
|
||||||
"_route_planner",
|
"_route_planner",
|
||||||
"_log",
|
"_log",
|
||||||
|
"_log_handler",
|
||||||
"_stats",
|
"_stats",
|
||||||
"available",
|
"available",
|
||||||
)
|
)
|
||||||
|
|
@ -108,6 +112,7 @@ class Node:
|
||||||
apple_music: bool = False,
|
apple_music: bool = False,
|
||||||
fallback: bool = False,
|
fallback: bool = False,
|
||||||
log_level: LogLevel = LogLevel.INFO,
|
log_level: LogLevel = LogLevel.INFO,
|
||||||
|
log_handler: Optional[logging.Handler] = MISSING,
|
||||||
):
|
):
|
||||||
self._bot: commands.Bot = bot
|
self._bot: commands.Bot = bot
|
||||||
self._host: str = host
|
self._host: str = host
|
||||||
|
|
@ -119,6 +124,7 @@ class Node:
|
||||||
self._secure: bool = secure
|
self._secure: bool = secure
|
||||||
self._fallback: bool = fallback
|
self._fallback: bool = fallback
|
||||||
self._log_level: LogLevel = log_level
|
self._log_level: LogLevel = log_level
|
||||||
|
self._log_handler = log_handler
|
||||||
|
|
||||||
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
|
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
|
||||||
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
|
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
|
||||||
|
|
@ -130,7 +136,7 @@ class Node:
|
||||||
|
|
||||||
self._session_id: Optional[str] = None
|
self._session_id: Optional[str] = None
|
||||||
self._available: bool = False
|
self._available: bool = False
|
||||||
self._version: LavalinkVersion = None
|
self._version: LavalinkVersion = LavalinkVersion(0, 0, 0)
|
||||||
|
|
||||||
self._route_planner = RoutePlanner(self)
|
self._route_planner = RoutePlanner(self)
|
||||||
self._log = self._setup_logging(self._log_level)
|
self._log = self._setup_logging(self._log_level)
|
||||||
|
|
@ -212,6 +218,13 @@ class Node:
|
||||||
|
|
||||||
def _setup_logging(self, level: LogLevel) -> logging.Logger:
|
def _setup_logging(self, level: LogLevel) -> logging.Logger:
|
||||||
logger = logging.getLogger("pomice")
|
logger = logging.getLogger("pomice")
|
||||||
|
logger.setLevel(level)
|
||||||
|
handler = None
|
||||||
|
|
||||||
|
if self._log_handler is not None:
|
||||||
|
handler = self._log_handler
|
||||||
|
|
||||||
|
elif self._log_handler is MISSING:
|
||||||
handler = logging.StreamHandler()
|
handler = logging.StreamHandler()
|
||||||
dt_fmt = "%Y-%m-%d %H:%M:%S"
|
dt_fmt = "%Y-%m-%d %H:%M:%S"
|
||||||
formatter = logging.Formatter(
|
formatter = logging.Formatter(
|
||||||
|
|
@ -220,25 +233,37 @@ class Node:
|
||||||
style="{",
|
style="{",
|
||||||
)
|
)
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
logger.setLevel(level)
|
|
||||||
|
if handler:
|
||||||
|
logger.handlers.clear()
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
|
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
async def _handle_version_check(self, version: str):
|
async def _handle_version_check(self, version: str) -> None:
|
||||||
if version.endswith("-SNAPSHOT"):
|
if version.endswith("-SNAPSHOT"):
|
||||||
# we're just gonna assume all snapshot versions correlate with v4
|
# we're just gonna assume all snapshot versions correlate with v4
|
||||||
self._version = LavalinkVersion(major=4, minor=0, fix=0)
|
self._version = LavalinkVersion(major=4, minor=0, fix=0)
|
||||||
return
|
return
|
||||||
|
|
||||||
# this crazy ass line maps the split version string into
|
_version_rx = VERSION_REGEX.match(version)
|
||||||
# an iterable with ints instead of strings and then
|
if not _version_rx:
|
||||||
# turns that iterable into a tuple. yeah, i know
|
self._available = False
|
||||||
|
raise LavalinkVersionIncompatible(
|
||||||
|
"The Lavalink version you're using is incompatible. "
|
||||||
|
"Lavalink version 3.7.0 or above is required to use this library.",
|
||||||
|
)
|
||||||
|
|
||||||
split = tuple(map(int, tuple(version.split("."))))
|
_version_groups = _version_rx.groups()
|
||||||
self._version = LavalinkVersion(*split)
|
major, minor, fix = (
|
||||||
if not version.endswith("-SNAPSHOT") and (
|
int(_version_groups[0] or 0),
|
||||||
self._version.major == 3 and self._version.minor < 7
|
int(_version_groups[1] or 0),
|
||||||
):
|
int(_version_groups[2] or 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._log.debug(f"Parsed Lavalink version: {major}.{minor}.{fix}")
|
||||||
|
self._version = LavalinkVersion(major=major, minor=minor, fix=fix)
|
||||||
|
if self._version < LavalinkVersion(3, 7, 0):
|
||||||
self._available = False
|
self._available = False
|
||||||
raise LavalinkVersionIncompatible(
|
raise LavalinkVersionIncompatible(
|
||||||
"The Lavalink version you're using is incompatible. "
|
"The Lavalink version you're using is incompatible. "
|
||||||
|
|
@ -877,6 +902,7 @@ class NodePool:
|
||||||
apple_music: bool = False,
|
apple_music: bool = False,
|
||||||
fallback: bool = False,
|
fallback: bool = False,
|
||||||
log_level: LogLevel = LogLevel.INFO,
|
log_level: LogLevel = LogLevel.INFO,
|
||||||
|
log_handler: Optional[logging.Handler] = None,
|
||||||
) -> Node:
|
) -> Node:
|
||||||
"""Creates a Node object to be then added into the node pool.
|
"""Creates a Node object to be then added into the node pool.
|
||||||
For Spotify searching capabilites, pass in valid Spotify API credentials.
|
For Spotify searching capabilites, pass in valid Spotify API credentials.
|
||||||
|
|
@ -902,6 +928,7 @@ class NodePool:
|
||||||
apple_music=apple_music,
|
apple_music=apple_music,
|
||||||
fallback=fallback,
|
fallback=fallback,
|
||||||
log_level=log_level,
|
log_level=log_level,
|
||||||
|
log_handler=log_handler,
|
||||||
)
|
)
|
||||||
|
|
||||||
await node.connect()
|
await node.connect()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue