diff --git a/pomice/pool.py b/pomice/pool.py index 2d4998d..c8644ee 100644 --- a/pomice/pool.py +++ b/pomice/pool.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -import json +import orjson import random import re from typing import Dict, Optional, TYPE_CHECKING @@ -46,7 +46,8 @@ URL_REGEX = re.compile( r"https?://(?:www\.)?.+" ) - +JSON_ENCODER = orjson.dumps +JSON_DECODER = orjson.loads class Node: """The base class for a node. @@ -85,7 +86,7 @@ class Node: self._websocket_uri = f"{'wss' if self._secure else 'ws'}://{self._host}:{self._port}" self._rest_uri = f"{'https' if self._secure else 'http'}://{self._host}:{self._port}" - self._session = session or aiohttp.ClientSession() + self._session = session or aiohttp.ClientSession(json_serialize=JSON_ENCODER) self._websocket: aiohttp.ClientWebSocketResponse = None self._task: asyncio.Task = None @@ -194,7 +195,7 @@ class Node: if not self.is_connected: self._bot.loop.create_task(self.connect()) else: - self._bot.loop.create_task(self._handle_payload(msg.json())) + self._bot.loop.create_task(self._handle_payload(msg.json(loads=JSON_DECODER))) async def _handle_payload(self, data: dict): op = data.get("op", None) @@ -219,7 +220,7 @@ class Node: f"The node '{self._identifier}' is unavailable." ) - await self._websocket.send_str(json.dumps(data)) + await self._websocket.send_str(orjson.dumps(data)) def get_player(self, guild_id: int): """Takes a guild ID as a parameter. Returns a pomice Player object.""" @@ -284,7 +285,7 @@ class Node: f"Failed to build track. Check if the identifier is correct and try again." ) - data: dict = await resp.json() + data: dict = await resp.json(loads=JSON_DECODER) return Track(track_id=identifier, ctx=ctx, info=data) async def get_tracks( @@ -372,7 +373,7 @@ class Node: url=f"{self._rest_uri}/loadtracks?identifier={quote(query)}", headers={"Authorization": self._password} ) as response: - data: dict = await response.json() + data: dict = await response.json(loads=JSON_DECODER) track: dict = data["tracks"][0] info: dict = track.get("info") @@ -397,7 +398,7 @@ class Node: url=f"{self._rest_uri}/loadtracks?identifier={quote(query)}", headers={"Authorization": self._password} ) as response: - data = await response.json() + data = await response.json(loads=JSON_DECODER) load_type = data.get("loadType")