Skip to content
This repository has been archived by the owner on Mar 13, 2023. It is now read-only.

feat: Convert client.on_error into a proper event #573

Merged
merged 3 commits into from
Jul 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions naff/api/events/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def on_guild_join(event):

"""
import re
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional

from naff.client.const import MISSING
from naff.models.discord.snowflake import to_snowflake
Expand All @@ -33,6 +33,7 @@ def on_guild_join(event):
"Component",
"Connect",
"Disconnect",
"Error",
"ShardConnect",
"ShardDisconnect",
"GuildEvent",
Expand All @@ -47,7 +48,7 @@ def on_guild_join(event):

if TYPE_CHECKING:
from naff import Client
from naff.models.naff.context import ComponentContext
from naff.models.naff.context import ComponentContext, Context
from naff.models.discord.snowflake import Snowflake_Type
from naff.models.discord.guild import Guild

Expand Down Expand Up @@ -161,3 +162,14 @@ class Button(Component):
@define(kw_only=False)
class Select(Component):
"""Dispatched when a user uses a Select."""


@define(kw_only=False)
class Error(BaseEvent):
"""Dispatched when the library encounters an error."""

source: str = field(metadata=docs("The source of the error"))
error: Exception = field(metadata=docs("The error that was encountered"))
args: tuple[Any] = field(factory=tuple)
kwargs: dict[str, Any] = field(factory=dict)
ctx: Optional["Context"] = field(default=None, metadata=docs("The Context, if one was active"))
2 changes: 1 addition & 1 deletion naff/client/auto_shard_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ async def _on_websocket_ready(self, event: events.RawGatewayEvent) -> None:
try:
await asyncio.gather(*self.async_startup_tasks)
except Exception as e:
await self.on_error("async-extension-loader", e)
self.dispatch(events.Error("async-extension-loader", e))

# cache slash commands
if not self._startup:
Expand Down
34 changes: 23 additions & 11 deletions naff/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,11 @@ async def _async_wrap(_coro: Listener, _event: BaseEvent, *_args, **_kwargs) ->
except asyncio.CancelledError:
pass
except Exception as e:
await self.on_error(event, e)
if isinstance(event, events.Error):
# No infinite loops please
self.default_error_handler(repr(event), e)
else:
self.dispatch(events.Error(repr(event), e))

wrapped = _async_wrap(coro, event, *args, **kwargs)

Expand Down Expand Up @@ -490,6 +494,10 @@ def default_error_handler(source: str, error: BaseException) -> None:
"Ignoring exception in {}:{}{}".format(source, "\n" if len(out) > 1 else " ", "".join(out)),
)

@Listener.create()
async def _on_error(self, event: events.Error) -> None:
self.on_error(event.source, event.error, *event.args, **event.kwargs)

async def on_error(self, source: str, error: Exception, *args, **kwargs) -> None:
"""
Catches all errors dispatched by the library.
Expand All @@ -510,7 +518,7 @@ async def on_command_error(self, ctx: Context, error: Exception, *args, **kwargs
Override this to change error handling behavior

"""
await self.on_error(f"cmd /`{ctx.invoke_target}`", error, *args, **kwargs)
self.dispatch(events.Error(f"cmd /`{ctx.invoke_target}`", error, args, kwargs, ctx))
try:
if isinstance(error, errors.CommandOnCooldown):
await ctx.send(
Expand All @@ -537,7 +545,8 @@ async def on_command_error(self, ctx: Context, error: Exception, *args, **kwargs
)
elif self.send_command_tracebacks:
out = "".join(traceback.format_exception(error))
out = out.replace(self.http.token, "[REDACTED TOKEN]")
if self.http.token is not None:
out = out.replace(self.http.token, "[REDACTED TOKEN]")
await ctx.send(
embeds=Embed(
title=f"Error: {type(error).__name__}",
Expand Down Expand Up @@ -575,7 +584,7 @@ async def on_component_error(self, ctx: ComponentContext, error: Exception, *arg
Override this to change error handling behavior

"""
return await self.on_error(f"Component Callback for {ctx.custom_id}", error, *args, **kwargs)
return self.dispatch(events.Error(f"Component Callback for {ctx.custom_id}", error, args, kwargs, ctx))

async def on_component(self, ctx: ComponentContext) -> None:
"""
Expand All @@ -599,11 +608,14 @@ async def on_autocomplete_error(self, ctx: AutocompleteContext, error: Exception
Override this to change error handling behavior

"""
return await self.on_error(
f"Autocomplete Callback for /{ctx.invoke_target} - Option: {ctx.focussed_option}",
error,
*args,
**kwargs,
return self.dispatch(
events.Error(
f"Autocomplete Callback for /{ctx.invoke_target} - Option: {ctx.focussed_option}",
error,
args,
kwargs,
ctx,
)
)

async def on_autocomplete(self, ctx: AutocompleteContext) -> None:
Expand Down Expand Up @@ -667,7 +679,7 @@ async def _on_websocket_ready(self, event: events.RawGatewayEvent) -> None:
try:
await asyncio.gather(*self.async_startup_tasks)
except Exception as e:
await self.on_error("async-extension-loader", e)
self.dispatch(events.Error("async-extension-loader", e))

# cache slash commands
if not self._startup:
Expand Down Expand Up @@ -1081,7 +1093,7 @@ async def _init_interactions(self) -> None:
else:
await self._cache_interactions(warn_missing=False)
except Exception as e:
await self.on_error("Interaction Syncing", e)
self.dispatch(events.Error("Interaction Syncing", e))

async def _cache_interactions(self, warn_missing: bool = False) -> None:
"""Get all interactions used by this bot and cache them."""
Expand Down