Merge pull request #39 from NiceAesth/fix

feat: allow custom logging handler; fix: missing typings
This commit is contained in:
Clxud 2023-03-13 18:53:28 -04:00 committed by GitHub
commit 6d96a9e53d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 24 deletions

View File

@ -13,6 +13,8 @@ pre-commit = "*"
furo = "*" furo = "*"
sphinx = "*" sphinx = "*"
myst-parser = "*" myst-parser = "*"
black = "*"
typing-extensions = "*"
[requires] [requires]
python_version = "3.8" python_version = "3.8"

View File

@ -70,6 +70,10 @@ After you have initialized your function, we need to fill in the proper paramete
- `LogLevel` - `LogLevel`
- The logging level for the node. The default logging level is `LogLevel.INFO`. - The logging level for the node. The default logging level is `LogLevel.INFO`.
* - `log_handler`
- `Optional[logging.Handler]`
- The logging handler for the node. Set to `None` to disable the default logging handler.
::: :::

View File

@ -131,10 +131,10 @@ class Player(VoiceProtocol):
"_player_endpoint_uri", "_player_endpoint_uri",
) )
def __call__(self, client: Client, channel: VoiceChannel): def __call__(self, client: Client, channel: VoiceChannel) -> Player:
self.client: Client = client self.client = client
self.channel: VoiceChannel = channel self.channel = channel
self._guild: Guild = channel.guild self._guild = channel.guild
return self return self
@ -262,7 +262,7 @@ class Player(VoiceProtocol):
""" """
return self.guild.id not in self._node._players return self.guild.id not in self._node._players
def _adjust_end_time(self): def _adjust_end_time(self) -> Optional[str]:
if self._node._version >= LavalinkVersion(3, 7, 5): if self._node._version >= LavalinkVersion(3, 7, 5):
return None return None

View File

@ -17,6 +17,7 @@ from urllib.parse import quote
import aiohttp import aiohttp
from discord import Client from discord import Client
from discord.ext import commands from discord.ext import commands
from discord.utils import MISSING
from . import __version__ from . import __version__
from . import applemusic from . import applemusic
@ -49,6 +50,8 @@ __all__ = (
"NodePool", "NodePool",
) )
VERSION_REGEX = re.compile(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?(?:[a-zA-Z0-9_-]+)?")
class Node: class Node:
"""The base class for a node. """The base class for a node.
@ -86,6 +89,7 @@ class Node:
"_apple_music_client", "_apple_music_client",
"_route_planner", "_route_planner",
"_log", "_log",
"_log_handler",
"_stats", "_stats",
"available", "available",
) )
@ -108,6 +112,7 @@ class Node:
apple_music: bool = False, apple_music: bool = False,
fallback: bool = False, fallback: bool = False,
log_level: LogLevel = LogLevel.INFO, log_level: LogLevel = LogLevel.INFO,
log_handler: Optional[logging.Handler] = MISSING,
): ):
self._bot: commands.Bot = bot self._bot: commands.Bot = bot
self._host: str = host self._host: str = host
@ -119,6 +124,7 @@ class Node:
self._secure: bool = secure self._secure: bool = secure
self._fallback: bool = fallback self._fallback: bool = fallback
self._log_level: LogLevel = log_level self._log_level: LogLevel = log_level
self._log_handler = log_handler
self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}" self._websocket_uri: str = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}"
self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}" self._rest_uri: str = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}"
@ -130,7 +136,7 @@ class Node:
self._session_id: Optional[str] = None self._session_id: Optional[str] = None
self._available: bool = False self._available: bool = False
self._version: LavalinkVersion = None self._version: LavalinkVersion = LavalinkVersion(0, 0, 0)
self._route_planner = RoutePlanner(self) self._route_planner = RoutePlanner(self)
self._log = self._setup_logging(self._log_level) self._log = self._setup_logging(self._log_level)
@ -212,6 +218,13 @@ class Node:
def _setup_logging(self, level: LogLevel) -> logging.Logger: def _setup_logging(self, level: LogLevel) -> logging.Logger:
logger = logging.getLogger("pomice") logger = logging.getLogger("pomice")
logger.setLevel(level)
handler = None
if self._log_handler is not None:
handler = self._log_handler
elif self._log_handler is MISSING:
handler = logging.StreamHandler() handler = logging.StreamHandler()
dt_fmt = "%Y-%m-%d %H:%M:%S" dt_fmt = "%Y-%m-%d %H:%M:%S"
formatter = logging.Formatter( formatter = logging.Formatter(
@ -220,25 +233,37 @@ class Node:
style="{", style="{",
) )
handler.setFormatter(formatter) handler.setFormatter(formatter)
logger.setLevel(level)
if handler:
logger.handlers.clear()
logger.addHandler(handler) logger.addHandler(handler)
return logger return logger
async def _handle_version_check(self, version: str): async def _handle_version_check(self, version: str) -> None:
if version.endswith("-SNAPSHOT"): if version.endswith("-SNAPSHOT"):
# we're just gonna assume all snapshot versions correlate with v4 # we're just gonna assume all snapshot versions correlate with v4
self._version = LavalinkVersion(major=4, minor=0, fix=0) self._version = LavalinkVersion(major=4, minor=0, fix=0)
return return
# this crazy ass line maps the split version string into _version_rx = VERSION_REGEX.match(version)
# an iterable with ints instead of strings and then if not _version_rx:
# turns that iterable into a tuple. yeah, i know self._available = False
raise LavalinkVersionIncompatible(
"The Lavalink version you're using is incompatible. "
"Lavalink version 3.7.0 or above is required to use this library.",
)
split = tuple(map(int, tuple(version.split(".")))) _version_groups = _version_rx.groups()
self._version = LavalinkVersion(*split) major, minor, fix = (
if not version.endswith("-SNAPSHOT") and ( int(_version_groups[0] or 0),
self._version.major == 3 and self._version.minor < 7 int(_version_groups[1] or 0),
): int(_version_groups[2] or 0),
)
self._log.debug(f"Parsed Lavalink version: {major}.{minor}.{fix}")
self._version = LavalinkVersion(major=major, minor=minor, fix=fix)
if self._version < LavalinkVersion(3, 7, 0):
self._available = False self._available = False
raise LavalinkVersionIncompatible( raise LavalinkVersionIncompatible(
"The Lavalink version you're using is incompatible. " "The Lavalink version you're using is incompatible. "
@ -877,6 +902,7 @@ class NodePool:
apple_music: bool = False, apple_music: bool = False,
fallback: bool = False, fallback: bool = False,
log_level: LogLevel = LogLevel.INFO, log_level: LogLevel = LogLevel.INFO,
log_handler: Optional[logging.Handler] = None,
) -> Node: ) -> Node:
"""Creates a Node object to be then added into the node pool. """Creates a Node object to be then added into the node pool.
For Spotify searching capabilites, pass in valid Spotify API credentials. For Spotify searching capabilites, pass in valid Spotify API credentials.
@ -902,6 +928,7 @@ class NodePool:
apple_music=apple_music, apple_music=apple_music,
fallback=fallback, fallback=fallback,
log_level=log_level, log_level=log_level,
log_handler=log_handler,
) )
await node.connect() await node.connect()