diff --git a/pomice/applemusic/client.py b/pomice/applemusic/client.py index b7304f7..e9522d8 100644 --- a/pomice/applemusic/client.py +++ b/pomice/applemusic/client.py @@ -42,10 +42,10 @@ class Client: self.session: aiohttp.ClientSession = None # type: ignore self._log = logging.getLogger(__name__) - async def request_token(self) -> None: - if not self.session: - self.session = aiohttp.ClientSession() + async def _set_session(self, session: aiohttp.ClientSession) -> None: + self.session = session + async def request_token(self) -> None: # First lets get the raw response from the main page resp = await self.session.get("https://music.apple.com") @@ -187,8 +187,3 @@ class Client: next_page_url = None return Playlist(data, album_tracks) - - async def close(self) -> None: - if self.session: - await self.session.close() - self.session = None # type: ignore diff --git a/pomice/pool.py b/pomice/pool.py index 6a0a9d2..6093065 100644 --- a/pomice/pool.py +++ b/pomice/pool.py @@ -274,6 +274,13 @@ class Node: "Lavalink version 3.7.0 or above is required to use this library.", ) + async def _set_ext_client_session(self, session: aiohttp.ClientSession) -> None: + if self._spotify_client: + await self._spotify_client._set_session(session=session) + + if self._apple_music_client: + await self._apple_music_client._set_session(session=session) + async def _update_handler(self, data: dict) -> None: await self._bot.wait_until_ready() @@ -373,37 +380,37 @@ class Node: f'{f"?{query}" if query else ""}' ) - async with self._session.request( + resp = await self._session.request( method=method, url=uri, headers=self._headers, json=data or {}, - ) as resp: - self._log.debug( - f"Making REST request to Node {self._identifier} with method {method} to {uri}", + ) + self._log.debug( + f"Making REST request to Node {self._identifier} with method {method} to {uri}", + ) + if resp.status >= 300: + resp_data: dict = await resp.json() + raise NodeRestException( + f'Error from Node {self._identifier} fetching from Lavalink REST api: {resp.status} {resp.reason}: {resp_data["message"]}', ) - if resp.status >= 300: - resp_data: dict = await resp.json() - raise NodeRestException( - f'Error from Node {self._identifier} fetching from Lavalink REST api: {resp.status} {resp.reason}: {resp_data["message"]}', - ) - - if method == "DELETE" or resp.status == 204: - self._log.debug( - f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned no data.", - ) - return await resp.json(content_type=None) - - if resp.content_type == "text/plain": - self._log.debug( - f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned text with body {await resp.text()}", - ) - return await resp.text() + if method == "DELETE" or resp.status == 204: self._log.debug( - f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned JSON with body {await resp.json()}", + f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned no data.", ) - return await resp.json() + return await resp.json(content_type=None) + + if resp.content_type == "text/plain": + self._log.debug( + f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned text with body {await resp.text()}", + ) + return await resp.text() + + self._log.debug( + f"REST request to Node {self._identifier} with method {method} to {uri} completed sucessfully and returned JSON with body {await resp.json()}", + ) + return await resp.json() def get_player(self, guild_id: int) -> Optional[Player]: """Takes a guild ID as a parameter. Returns a pomice Player object or None.""" @@ -428,6 +435,7 @@ class Node: ) await self._handle_version_check(version=version) + await self._set_ext_client_session(session=self._session) self._log.debug( f"Version check from Node {self._identifier} successful. Returned version {version}", @@ -485,14 +493,6 @@ class Node: await self._session.close() self._log.debug("Websocket and http session closed.") - if self._spotify_client: - await self._spotify_client.close() - self._log.debug("Spotify client session closed.") - - if self._apple_music_client: - await self._apple_music_client.close() - self._log.debug("Apple Music client session closed.") - del self._pool._nodes[self._identifier] self.available = False self._task.cancel() diff --git a/pomice/spotify/client.py b/pomice/spotify/client.py index 25e125c..1660118 100644 --- a/pomice/spotify/client.py +++ b/pomice/spotify/client.py @@ -49,20 +49,21 @@ class Client: self._bearer_headers: Optional[Dict] = None self._log = logging.getLogger(__name__) + async def _set_session(self, session: aiohttp.ClientSession) -> None: + self.session = session + async def _fetch_bearer_token(self) -> None: _data = {"grant_type": "client_credentials"} - if not self.session: - self.session = aiohttp.ClientSession() + resp = await self.session.post(GRANT_URL, data=_data, headers=self._grant_headers) - async with self.session.post(GRANT_URL, data=_data, headers=self._grant_headers) as resp: - if resp.status != 200: - raise SpotifyRequestException( - f"Error fetching bearer token: {resp.status} {resp.reason}", - ) + if resp.status != 200: + raise SpotifyRequestException( + f"Error fetching bearer token: {resp.status} {resp.reason}", + ) - data: dict = await resp.json(loads=json.loads) - self._log.debug(f"Fetched Spotify bearer token successfully") + data: dict = await resp.json(loads=json.loads) + self._log.debug(f"Fetched Spotify bearer token successfully") self._bearer_token = data["access_token"] self._expiry = time.time() + (int(data["expires_in"]) - 10) @@ -83,34 +84,34 @@ class Client: request_url = REQUEST_URL.format(type=spotify_type, id=spotify_id) - async with self.session.get(request_url, headers=self._bearer_headers) as resp: - if resp.status != 200: - raise SpotifyRequestException( - f"Error while fetching results: {resp.status} {resp.reason}", - ) - - data: dict = await resp.json(loads=json.loads) - self._log.debug( - f"Made request to Spotify API with status {resp.status} and response {data}", + resp = await self.session.get(request_url, headers=self._bearer_headers) + if resp.status != 200: + raise SpotifyRequestException( + f"Error while fetching results: {resp.status} {resp.reason}", ) + data: dict = await resp.json(loads=json.loads) + self._log.debug( + f"Made request to Spotify API with status {resp.status} and response {data}", + ) + if spotify_type == "track": return Track(data) elif spotify_type == "album": return Album(data) elif spotify_type == "artist": - async with self.session.get( + resp = await self.session.get( f"{request_url}/top-tracks?market=US", headers=self._bearer_headers, - ) as resp: - if resp.status != 200: - raise SpotifyRequestException( - f"Error while fetching results: {resp.status} {resp.reason}", - ) + ) + if resp.status != 200: + raise SpotifyRequestException( + f"Error while fetching results: {resp.status} {resp.reason}", + ) - track_data: dict = await resp.json(loads=json.loads) - tracks = track_data["tracks"] - return Artist(data, tracks) + track_data: dict = await resp.json(loads=json.loads) + tracks = track_data["tracks"] + return Artist(data, tracks) else: tracks = [ Track(track["track"]) @@ -126,13 +127,13 @@ class Client: next_page_url = data["tracks"]["next"] while next_page_url is not None: - async with self.session.get(next_page_url, headers=self._bearer_headers) as resp: - if resp.status != 200: - raise SpotifyRequestException( - f"Error while fetching results: {resp.status} {resp.reason}", - ) + resp = await self.session.get(next_page_url, headers=self._bearer_headers) + if resp.status != 200: + raise SpotifyRequestException( + f"Error while fetching results: {resp.status} {resp.reason}", + ) - next_data: dict = await resp.json(loads=json.loads) + next_data: dict = await resp.json(loads=json.loads) tracks += [ Track(track["track"]) @@ -164,19 +165,13 @@ class Client: id=f"?seed_tracks={spotify_id}", ) - async with self.session.get(request_url, headers=self._bearer_headers) as resp: - if resp.status != 200: - raise SpotifyRequestException( - f"Error while fetching results: {resp.status} {resp.reason}", - ) - - data: dict = await resp.json(loads=json.loads) + resp = await self.session.get(request_url, headers=self._bearer_headers) + if resp.status != 200: + raise SpotifyRequestException( + f"Error while fetching results: {resp.status} {resp.reason}", + ) + data: dict = await resp.json(loads=json.loads) tracks = [Track(track) for track in data["tracks"]] return tracks - - async def close(self) -> None: - if self.session: - await self.session.close() - self.session = None # type: ignore