Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #1131 A class with async "__call__" method fails to work as a middleware #1132

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions slack_bolt/app/async_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
AsyncMessageListenerMatches,
)
from slack_bolt.oauth.async_internals import select_consistent_installation_store
from slack_bolt.util.utils import get_name_for_callable
from slack_bolt.util.utils import get_name_for_callable, is_coroutine_function
from slack_bolt.workflows.step.async_step import (
AsyncWorkflowStep,
AsyncWorkflowStepBuilder,
Expand Down Expand Up @@ -778,7 +778,7 @@ async def custom_error_handler(error, body, logger):
func: The function that is supposed to be executed
when getting an unhandled error in Bolt app.
"""
if not inspect.iscoroutinefunction(func):
if not is_coroutine_function(func):
name = get_name_for_callable(func)
raise BoltError(error_listener_function_must_be_coro_func(name))
self._async_listener_runner.listener_error_handler = AsyncCustomListenerErrorHandler(
Expand Down Expand Up @@ -1410,7 +1410,7 @@ def _register_listener(
value_to_return = functions[0]

for func in functions:
if not inspect.iscoroutinefunction(func):
if not is_coroutine_function(func):
name = get_name_for_callable(func)
raise BoltError(error_listener_function_must_be_coro_func(name))

Expand All @@ -1422,7 +1422,7 @@ def _register_listener(
for m in middleware or []:
if isinstance(m, AsyncMiddleware):
listener_middleware.append(m)
elif isinstance(m, Callable) and inspect.iscoroutinefunction(m):
elif isinstance(m, Callable) and is_coroutine_function(m):
listener_middleware.append(AsyncCustomMiddleware(app_name=self.name, func=m, base_logger=self._base_logger))
else:
raise ValueError(error_unexpected_listener_middleware(type(m)))
Expand Down
5 changes: 2 additions & 3 deletions slack_bolt/middleware/async_custom_middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import inspect
from logging import Logger
from typing import Callable, Awaitable, Any, Sequence, Optional

Expand All @@ -7,7 +6,7 @@
from slack_bolt.request.async_request import AsyncBoltRequest
from slack_bolt.response import BoltResponse
from .async_middleware import AsyncMiddleware
from slack_bolt.util.utils import get_name_for_callable, get_arg_names_of_callable
from slack_bolt.util.utils import get_name_for_callable, get_arg_names_of_callable, is_coroutine_function


class AsyncCustomMiddleware(AsyncMiddleware):
Expand All @@ -24,7 +23,7 @@ def __init__(
base_logger: Optional[Logger] = None,
):
self.app_name = app_name
if inspect.iscoroutinefunction(func):
if is_coroutine_function(func):
self.func = func
else:
raise ValueError("Async middleware function must be an async function")
Expand Down
6 changes: 6 additions & 0 deletions slack_bolt/util/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,9 @@ def get_name_for_callable(func: Callable) -> str:

def get_arg_names_of_callable(func: Callable) -> List[str]:
return inspect.getfullargspec(inspect.unwrap(func)).args


def is_coroutine_function(func: Optional[Any]) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI I may rename is_coroutine_function to is_callable_coroutine in a follow up PR

return func is not None and (
inspect.iscoroutinefunction(func) or (hasattr(func, "__call__") and inspect.iscoroutinefunction(func.__call__))
)
20 changes: 20 additions & 0 deletions tests/scenario_tests_async/test_app_using_methods_in_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ async def test_instance_methods(self):
app.shortcut("test-shortcut")(awesome.instance_method)
await self.run_app_and_verify(app)

@pytest.mark.asyncio
async def test_callable_class(self):
app = AsyncApp(client=self.web_client, signing_secret=self.signing_secret)
instance = CallableClass("Slackbot")
app.use(instance)
app.shortcut("test-shortcut")(instance.event_handler)
await self.run_app_and_verify(app)

@pytest.mark.asyncio
async def test_instance_methods_uncommon_name_1(self):
app = AsyncApp(client=self.web_client, signing_secret=self.signing_secret)
Expand Down Expand Up @@ -225,6 +233,18 @@ async def static_method(context: AsyncBoltContext, say: AsyncSay, ack: AsyncAck)
await say(f"Hello <@{context.user_id}>!")


class CallableClass:
def __init__(self, name: str):
self.name = name

async def __call__(self, next: Callable):
await next()

async def event_handler(self, context: AsyncBoltContext, say: AsyncSay, ack: AsyncAck):
await ack()
await say(f"Hello <@{context.user_id}>! My name is {self.name}")


async def top_level_function(invalid_arg, ack, say):
assert invalid_arg is None
await ack()
Expand Down