Skip to content
This repository has been archived by the owner on Mar 13, 2023. It is now read-only.

Commit

Permalink
fix💥: Make naff always use the correct logger (v2) (#647)
Browse files Browse the repository at this point in the history
* fix💥: make naff always use the correct logger

* fix: default logger missing parentheses

* feat: add access to the logger to more classes

* refactor: add direct access to the logger to more classes and rename `logger()` to `get_logger()`

* fix: wrong name
  • Loading branch information
Kigstn authored Sep 27, 2022
1 parent 36a8609 commit 460be45
Show file tree
Hide file tree
Showing 28 changed files with 199 additions and 162 deletions.
3 changes: 1 addition & 2 deletions naff/api/events/processors/message_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import naff.api.events as events

from naff.client.const import logger
from ._template import EventMixinTemplate, Processor
from naff.models import to_snowflake, BaseMessage

Expand Down Expand Up @@ -55,7 +54,7 @@ async def _on_raw_message_delete(self, event: "RawGatewayEvent") -> None:
if not message:
message = BaseMessage.from_dict(event.data, self)
self.cache.delete_message(event.data["channel_id"], event.data["id"])
logger.debug(f"Dispatching Event: {event.resolved_name}")
self.logger.debug(f"Dispatching Event: {event.resolved_name}")
self.dispatch(events.MessageDelete(message))

@Processor.define()
Expand Down
30 changes: 15 additions & 15 deletions naff/api/gateway/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TypeVar, TYPE_CHECKING

from naff.api import events
from naff.client.const import logger, MISSING, __api_version__
from naff.client.const import MISSING, __api_version__
from naff.client.utils.input_utils import OverriddenJson
from naff.client.utils.serializer import dict_filter_none
from naff.models.discord.enums import Status
Expand Down Expand Up @@ -176,31 +176,31 @@ async def dispatch_opcode(self, data, op: OPCODE) -> None:
match op:

case OPCODE.HEARTBEAT:
logger.debug("Received heartbeat request from gateway")
self.logger.debug("Received heartbeat request from gateway")
return await self.send_heartbeat()

case OPCODE.HEARTBEAT_ACK:
self.latency.append(time.perf_counter() - self._last_heartbeat)

if self._last_heartbeat != 0 and self.latency[-1] >= 15:
logger.warning(
self.logger.warning(
f"High Latency! shard ID {self.shard[0]} heartbeat took {self.latency[-1]:.1f}s to be acknowledged!"
)
else:
logger.debug(f"❤ Heartbeat acknowledged after {self.latency[-1]:.5f} seconds")
self.logger.debug(f"❤ Heartbeat acknowledged after {self.latency[-1]:.5f} seconds")

return self._acknowledged.set()

case OPCODE.RECONNECT:
logger.debug("Gateway requested reconnect. Reconnecting...")
self.logger.debug("Gateway requested reconnect. Reconnecting...")
return await self.reconnect(resume=True, url=self.ws_resume_url)

case OPCODE.INVALIDATE_SESSION:
logger.warning("Gateway has invalidated session! Reconnecting...")
self.logger.warning("Gateway has invalidated session! Reconnecting...")
return await self.reconnect()

case _:
return logger.debug(f"Unhandled OPCODE: {op} = {OPCODE(op).name}")
return self.logger.debug(f"Unhandled OPCODE: {op} = {OPCODE(op).name}")

async def dispatch_event(self, data, seq, event) -> None:
match event:
Expand All @@ -212,12 +212,12 @@ async def dispatch_event(self, data, seq, event) -> None:
self.ws_resume_url = (
f"{data['resume_gateway_url']}?encoding=json&v={__api_version__}&compress=zlib-stream"
)
logger.info(f"Shard {self.shard[0]} has connected to gateway!")
logger.debug(f"Session ID: {self.session_id} Trace: {self._trace}")
self.logger.info(f"Shard {self.shard[0]} has connected to gateway!")
self.logger.debug(f"Session ID: {self.session_id} Trace: {self._trace}")
return self.state.client.dispatch(events.WebsocketReady(data))

case "RESUMED":
logger.info(f"Successfully resumed connection! Session_ID: {self.session_id}")
self.logger.info(f"Successfully resumed connection! Session_ID: {self.session_id}")
self.state.client.dispatch(events.Resume())
return

Expand All @@ -232,9 +232,9 @@ async def dispatch_event(self, data, seq, event) -> None:
try:
asyncio.create_task(processor(events.RawGatewayEvent(data.copy(), override_name=event_name)))
except Exception as ex:
logger.error(f"Failed to run event processor for {event_name}: {ex}")
self.logger.error(f"Failed to run event processor for {event_name}: {ex}")
else:
logger.debug(f"No processor for `{event_name}`")
self.logger.debug(f"No processor for `{event_name}`")

self.state.client.dispatch(events.RawGatewayEvent(data.copy(), override_name="raw_gateway_event"))
self.state.client.dispatch(events.RawGatewayEvent(data.copy(), override_name=f"raw_{event.lower()}"))
Expand Down Expand Up @@ -263,7 +263,7 @@ async def _identify(self) -> None:
serialized = OverriddenJson.dumps(payload)
await self.ws.send_str(serialized)

logger.debug(
self.logger.debug(
f"Shard ID {self.shard[0]} has identified itself to Gateway, requesting intents: {self.state.intents}!"
)

Expand All @@ -285,11 +285,11 @@ async def _resume_connection(self) -> None:
serialized = OverriddenJson.dumps(payload)
await self.ws.send_str(serialized)

logger.debug(f"{self.shard[0]} is attempting to resume a connection")
self.logger.debug(f"{self.shard[0]} is attempting to resume a connection")

async def send_heartbeat(self) -> None:
await self.send_json({"op": OPCODE.HEARTBEAT, "d": self.sequence}, bypass=True)
logger.debug(f"❤ Shard {self.shard[0]} is sending a Heartbeat")
self.logger.debug(f"❤ Shard {self.shard[0]} is sending a Heartbeat")

async def change_presence(self, activity=None, status: Status = Status.ONLINE, since=None) -> None:
"""Update the bot's presence status."""
Expand Down
21 changes: 13 additions & 8 deletions naff/api/gateway/state.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
import traceback
from datetime import datetime
from logging import Logger
from typing import TYPE_CHECKING, Optional, Union

import naff
from naff.api import events
from naff.client.const import logger, MISSING, Absent
from naff.client.const import Absent, MISSING, get_logger
from naff.client.errors import NaffException, WebSocketClosed
from naff.client.utils.attr_utils import define, field
from naff.models.discord.activity import Activity
Expand Down Expand Up @@ -43,6 +44,8 @@ class ConnectionState:

_shard_task: asyncio.Task | None = None

logger: Logger = field(init=False, factory=get_logger)

def __attrs_post_init__(self, *args, **kwargs) -> None:
self._shard_ready = asyncio.Event()

Expand All @@ -68,7 +71,7 @@ async def start(self) -> None:
"""Connect to the Discord Gateway."""
self.gateway_url = await self.client.http.get_gateway()

logger.debug(f"Starting Shard ID {self.shard_id}")
self.logger.debug(f"Starting Shard ID {self.shard_id}")
self.start_time = datetime.now()
self._shard_task = asyncio.create_task(self._ws_connect())

Expand All @@ -80,7 +83,7 @@ async def start(self) -> None:

async def stop(self) -> None:
"""Disconnect from the Discord Gateway."""
logger.debug(f"Shutting down shard ID {self.shard_id}")
self.logger.debug(f"Shutting down shard ID {self.shard_id}")
if self.gateway is not None:
self.gateway.close()
self.gateway = None
Expand All @@ -98,7 +101,7 @@ def clear_ready(self) -> None:

async def _ws_connect(self) -> None:
"""Connect to the Discord Gateway."""
logger.info(f"Shard {self.shard_id} is attempting to connect to gateway...")
self.logger.info(f"Shard {self.shard_id} is attempting to connect to gateway...")
try:
async with GatewayClient(self, (self.shard_id, self.client.total_shards)) as self.gateway:
try:
Expand All @@ -123,7 +126,7 @@ async def _ws_connect(self) -> None:

except Exception as e:
self.client.dispatch(events.Disconnect())
logger.error("".join(traceback.format_exception(type(e), e, e.__traceback__)))
self.logger.error("".join(traceback.format_exception(type(e), e, e.__traceback__)))

async def change_presence(
self, status: Optional[Union[str, Status]] = Status.ONLINE, activity: Absent[Union[Activity, str]] = MISSING
Expand All @@ -149,15 +152,17 @@ async def change_presence(

if activity.type == ActivityType.STREAMING:
if not activity.url:
logger.warning("Streaming activity cannot be set without a valid URL attribute")
self.logger.warning("Streaming activity cannot be set without a valid URL attribute")
elif activity.type not in [
ActivityType.GAME,
ActivityType.STREAMING,
ActivityType.LISTENING,
ActivityType.WATCHING,
ActivityType.COMPETING,
]:
logger.warning(f"Activity type `{ActivityType(activity.type).name}` may not be enabled for bots")
self.logger.warning(
f"Activity type `{ActivityType(activity.type).name}` may not be enabled for bots"
)
else:
activity = self.client.activity

Expand All @@ -172,7 +177,7 @@ async def change_presence(
if self.client.status:
status = self.client.status
else:
logger.warning("Status must be set to a valid status type, defaulting to online")
self.logger.warning("Status must be set to a valid status type, defaulting to online")
status = Status.ONLINE

self.client._status = status
Expand Down
16 changes: 8 additions & 8 deletions naff/api/gateway/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from aiohttp import WSMsgType
from typing import TypeVar, TYPE_CHECKING

from naff.client.const import logger
from naff.client.errors import WebSocketClosed
from naff.models.naff.cooldowns import CooldownSystem

Expand Down Expand Up @@ -41,6 +40,7 @@ async def rate_limit(self) -> None:
class WebsocketClient:
def __init__(self, state: "ConnectionState") -> None:
self.state = state
self.logger = state.client.logger
self.ws = None
self.ws_url = None

Expand Down Expand Up @@ -134,11 +134,11 @@ async def send(self, data: str, bypass=False) -> None:
bypass: Should the rate limit be ignored for this send (used for heartbeats)
"""
logger.debug(f"Sending data to websocket: {data}")
self.logger.debug(f"Sending data to websocket: {data}")

async with self._race_lock:
if self.ws is None:
return logger.warning("Attempted to send data while websocket is not connected!")
return self.logger.warning("Attempted to send data while websocket is not connected!")
if not bypass:
await self.rl_manager.rate_limit()

Expand Down Expand Up @@ -177,7 +177,7 @@ async def receive(self, force: bool = False) -> str:
resp = await self.ws.receive()

if resp.type == WSMsgType.CLOSE:
logger.debug(f"Disconnecting from gateway! Reason: {resp.data}::{resp.extra}")
self.logger.debug(f"Disconnecting from gateway! Reason: {resp.data}::{resp.extra}")
if resp.data >= 4000:
# This should propagate to __aexit__() which will forcefully shut down everything
# and cleanup correctly.
Expand Down Expand Up @@ -232,7 +232,7 @@ async def receive(self, force: bool = False) -> str:
try:
msg = OverriddenJson.loads(msg)
except Exception as e:
logger.error(e)
self.logger.error(e)
continue

return msg
Expand Down Expand Up @@ -270,7 +270,7 @@ async def run_bee_gees(self) -> None:
await self._start_bee_gees()
except Exception:
self.close()
logger.error("The heartbeater raised an exception!", exc_info=True)
self.logger.error("The heartbeater raised an exception!", exc_info=True)

async def _start_bee_gees(self) -> None:
if self.heartbeat_interval is None:
Expand All @@ -283,10 +283,10 @@ async def _start_bee_gees(self) -> None:
else:
return

logger.debug(f"Sending heartbeat every {self.heartbeat_interval} seconds")
self.logger.debug(f"Sending heartbeat every {self.heartbeat_interval} seconds")
while not self._kill_bee_gees.is_set():
if not self._acknowledged.is_set():
logger.warning(
self.logger.warning(
f"Heartbeat has not been acknowledged for {self.heartbeat_interval} seconds,"
" likely zombied connection. Reconnect!"
)
Expand Down
28 changes: 17 additions & 11 deletions naff/api/http/http_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""This file handles the interaction with discords http endpoints."""
import asyncio
from logging import Logger
from typing import Any, cast
from urllib.parse import quote as _uriquote
from weakref import WeakValueDictionary
Expand All @@ -26,12 +27,13 @@
ScheduledEventsRequests,
)
from naff.client.const import (
MISSING,
__py_version__,
__repo_url__,
__version__,
logger,
__api_version__,
)
import naff.client.const as constants
from naff.client.errors import DiscordError, Forbidden, GatewayNotFound, HTTPException, NotFound, LoginError
from naff.client.mixins.serialization import DictSerializationMixin
from naff.client.utils.input_utils import response_decode, OverriddenJson
Expand Down Expand Up @@ -144,7 +146,7 @@ class HTTPClient(
):
"""A http client for sending requests to the Discord API."""

def __init__(self, connector: BaseConnector | None = None) -> None:
def __init__(self, connector: BaseConnector | None = None, logger: Logger = MISSING) -> None:
self.connector: BaseConnector | None = connector
self.__session: ClientSession | None = None
self.token: str | None = None
Expand All @@ -158,6 +160,10 @@ def __init__(self, connector: BaseConnector | None = None) -> None:
f"DiscordBot ({__repo_url__} {__version__} Python/{__py_version__}) aiohttp/{aiohttp.__version__}"
)

if logger is MISSING:
logger = constants.get_logger()
self.logger = logger

def get_ratelimit(self, route: Route) -> BucketLock:
"""
Get a route's rate limit bucket.
Expand Down Expand Up @@ -191,7 +197,7 @@ def ingest_ratelimit(self, route: Route, header: CIMultiDictProxy, bucket_lock:

if bucket_lock.bucket_hash:
# We only ever try and cache the bucket if the bucket hash has been set (ignores unlimited endpoints)
logger.debug(f"Caching ingested rate limit data for: {bucket_lock.bucket_hash}")
self.logger.debug(f"Caching ingested rate limit data for: {bucket_lock.bucket_hash}")
self._endpoints[route.rl_bucket] = bucket_lock.bucket_hash
self.ratelimit_locks[bucket_lock.bucket_hash] = bucket_lock

Expand Down Expand Up @@ -304,14 +310,14 @@ async def request(
if result.get("global", False):
# global ratelimit is reached
# if we get a global, that's pretty bad, this would usually happen if the user is hitting the api from 2 clients sharing a token
logger.error(
self.logger.error(
f"Bot has exceeded global ratelimit, locking REST API for {result['retry_after']} seconds"
)
await self.global_lock.lock(float(result["retry_after"]))
continue
elif result.get("message") == "The resource is being rate limited.":
# resource ratelimit is reached
logger.warning(
self.logger.warning(
f"{route.endpoint} The resource is being rate limited! "
f"Reset in {result.get('retry_after')} seconds"
)
Expand All @@ -322,21 +328,21 @@ async def request(
# endpoint ratelimit is reached
# 429's are unfortunately unavoidable, but we can attempt to avoid them
# so long as these are infrequent we're doing well
logger.warning(
self.logger.warning(
f"{route.endpoint} Has exceeded it's ratelimit ({lock.limit})! Reset in {lock.delta} seconds"
)
await lock.defer_unlock() # lock this route and wait for unlock
continue
elif lock.remaining == 0:
# Last call available in the bucket, lock until reset
logger.debug(
self.logger.debug(
f"{route.endpoint} Has exhausted its ratelimit ({lock.limit})! Locking route for {lock.delta} seconds"
)
await lock.blind_defer_unlock() # lock this route, but continue processing the current response

elif response.status in {500, 502, 504}:
# Server issues, retry
logger.warning(
self.logger.warning(
f"{route.endpoint} Received {response.status}... retrying in {1 + attempt * 2} seconds"
)
await asyncio.sleep(1 + attempt * 2)
Expand All @@ -345,7 +351,7 @@ async def request(
if not 300 > response.status >= 200:
await self._raise_exception(response, route, result)

logger.debug(
self.logger.debug(
f"{route.endpoint} Received {response.status} :: [{lock.remaining}/{lock.limit} calls remaining]"
)
return result
Expand All @@ -356,7 +362,7 @@ async def request(
raise

async def _raise_exception(self, response, route, result) -> None:
logger.error(f"{route.method}::{route.url}: {response.status}")
self.logger.error(f"{route.method}::{route.url}: {response.status}")

if response.status == 403:
raise Forbidden(response, response_data=result, route=route)
Expand All @@ -368,7 +374,7 @@ async def _raise_exception(self, response, route, result) -> None:
raise HTTPException(response, response_data=result, route=route)

async def request_cdn(self, url, asset) -> bytes: # pyright: ignore [reportGeneralTypeIssues]
logger.debug(f"{asset} requests {url} from CDN")
self.logger.debug(f"{asset} requests {url} from CDN")
async with cast(ClientSession, self.__session).get(url) as response:
if response.status == 200:
return await response.read()
Expand Down
Loading

0 comments on commit 460be45

Please sign in to comment.