From 367a215b05c7fb652a7058696bc307ecb16a05c4 Mon Sep 17 00:00:00 2001 From: NiceAesth Date: Mon, 13 Mar 2023 15:15:27 +0200 Subject: [PATCH] feat: allow disabling logging handler; fix: code quality --- docs/hdi/pool.md | 4 ++++ pomice/player.py | 10 +++------- pomice/pool.py | 34 +++++++++++++++++++++++----------- 3 files changed, 30 insertions(+), 18 deletions(-) diff --git a/docs/hdi/pool.md b/docs/hdi/pool.md index 12f8a4b..790570e 100644 --- a/docs/hdi/pool.md +++ b/docs/hdi/pool.md @@ -70,6 +70,10 @@ After you have initialized your function, we need to fill in the proper paramete - `LogLevel` - The logging level for the node. The default logging level is `LogLevel.INFO`. +* - `log_handler` + - `Optional[logging.Handler]` + - The logging handler for the node. Set to `None` to disable the default logging handler. + ::: diff --git a/pomice/player.py b/pomice/player.py index 93aae48..ecc180b 100644 --- a/pomice/player.py +++ b/pomice/player.py @@ -131,12 +131,8 @@ class Player(VoiceProtocol): "_player_endpoint_uri", ) - def __call__(self, client: Client, channel: VoiceChannel): - self.client: Client = client - self.channel: VoiceChannel = channel - self._guild: Guild = channel.guild - - return self + def __call__(self, client: Client, channel: VoiceChannel) -> Player: + return self.__class__(client, channel) def __init__( self, @@ -262,7 +258,7 @@ class Player(VoiceProtocol): """ return self.guild.id not in self._node._players - def _adjust_end_time(self): + def _adjust_end_time(self) -> Optional[str]: if self._node._version >= LavalinkVersion(3, 7, 5): return None diff --git a/pomice/pool.py b/pomice/pool.py index ae97773..6526293 100644 --- a/pomice/pool.py +++ b/pomice/pool.py @@ -17,6 +17,7 @@ from urllib.parse import quote import aiohttp from discord import Client from discord.ext import commands +from discord.utils import MISSING from . import __version__ from . import applemusic @@ -86,6 +87,7 @@ class Node: "_apple_music_client", "_route_planner", "_log", + "_log_handler", "_stats", "available", ) @@ -108,6 +110,7 @@ class Node: apple_music: bool = False, fallback: bool = False, log_level: LogLevel = LogLevel.INFO, + log_handler: Optional[logging.Handler] = MISSING, ): self._bot: commands.Bot = bot self._host: str = host @@ -119,6 +122,7 @@ class Node: self._secure: bool = secure self._fallback: bool = fallback self._log_level: LogLevel = log_level + self._log_handler = log_handler self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}" self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}" @@ -130,7 +134,7 @@ class Node: self._session_id: Optional[str] = None self._available: bool = False - self._version: LavalinkVersion = None + self._version: LavalinkVersion = LavalinkVersion(0, 0, 0) self._route_planner = RoutePlanner(self) self._log = self._setup_logging(self._log_level) @@ -212,19 +216,25 @@ class Node: def _setup_logging(self, level: LogLevel) -> logging.Logger: logger = logging.getLogger("pomice") - handler = logging.StreamHandler() - dt_fmt = "%Y-%m-%d %H:%M:%S" - formatter = logging.Formatter( - "[{asctime}] [{levelname:<8}] {name}: {message}", - dt_fmt, - style="{", - ) - handler.setFormatter(formatter) logger.setLevel(level) - logger.addHandler(handler) + + if self._log_handler is not None: + handler = self._log_handler + + elif self._log_handler is MISSING: + handler = logging.StreamHandler() + dt_fmt = "%Y-%m-%d %H:%M:%S" + formatter = logging.Formatter( + "[{asctime}] [{levelname:<8}] {name}: {message}", + dt_fmt, + style="{", + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + return logger - async def _handle_version_check(self, version: str): + async def _handle_version_check(self, version: str) -> None: if version.endswith("-SNAPSHOT"): # we're just gonna assume all snapshot versions correlate with v4 self._version = LavalinkVersion(major=4, minor=0, fix=0) @@ -877,6 +887,7 @@ class NodePool: apple_music: bool = False, fallback: bool = False, log_level: LogLevel = LogLevel.INFO, + log_handler: Optional[logging.Handler] = None, ) -> Node: """Creates a Node object to be then added into the node pool. For Spotify searching capabilites, pass in valid Spotify API credentials. @@ -902,6 +913,7 @@ class NodePool: apple_music=apple_music, fallback=fallback, log_level=log_level, + log_handler=log_handler, ) await node.connect()