Skip to content

Commit

Permalink
fix(tracer): mypy generic to preserve decorated method signature (#529)
Browse files Browse the repository at this point in the history
  • Loading branch information
heitorlessa authored Jul 17, 2021
1 parent 89337a2 commit d5c3431
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions aws_lambda_powertools/tracing/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import numbers
import os
from typing import Any, Callable, Dict, Optional, Sequence, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Sequence, TypeVar, Union, cast, overload

from ..shared import constants
from ..shared.functions import resolve_env_var_choice, resolve_truthy_env_var_choice
Expand All @@ -18,6 +18,9 @@
aws_xray_sdk = LazyLoader(constants.XRAY_SDK_MODULE, globals(), constants.XRAY_SDK_MODULE)
aws_xray_sdk.core = LazyLoader(constants.XRAY_SDK_CORE_MODULE, globals(), constants.XRAY_SDK_CORE_MODULE)

AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) # noqa: VNE001
AnyAwaitableT = TypeVar("AnyAwaitableT", bound=Awaitable)


class Tracer:
"""Tracer using AWS-XRay to provide decorators with known defaults for Lambda functions
Expand Down Expand Up @@ -329,12 +332,26 @@ def decorate(event, context, **kwargs):

return decorate

# see #465
@overload
def capture_method(self, method: "AnyCallableT") -> "AnyCallableT":
...

@overload
def capture_method(
self,
method: Optional[Callable] = None,
method: None = None,
capture_response: Optional[bool] = None,
capture_error: Optional[bool] = None,
):
) -> Callable[["AnyCallableT"], "AnyCallableT"]:
...

def capture_method(
self,
method: Optional[AnyCallableT] = None,
capture_response: Optional[bool] = None,
capture_error: Optional[bool] = None,
) -> AnyCallableT:
"""Decorator to create subsegment for arbitrary functions
It also captures both response and exceptions as metadata
Expand Down Expand Up @@ -487,8 +504,9 @@ async def async_tasks():
# Return a partial function with args filled
if method is None:
logger.debug("Decorator called with parameters")
return functools.partial(
self.capture_method, capture_response=capture_response, capture_error=capture_error
return cast(
AnyCallableT,
functools.partial(self.capture_method, capture_response=capture_response, capture_error=capture_error),
)

method_name = f"{method.__name__}"
Expand All @@ -509,7 +527,7 @@ async def async_tasks():
return self._decorate_generator_function(
method=method, capture_response=capture_response, capture_error=capture_error, method_name=method_name
)
elif hasattr(method, "__wrapped__") and inspect.isgeneratorfunction(method.__wrapped__):
elif hasattr(method, "__wrapped__") and inspect.isgeneratorfunction(method.__wrapped__): # type: ignore
return self._decorate_generator_function_with_context_manager(
method=method, capture_response=capture_response, capture_error=capture_error, method_name=method_name
)
Expand Down Expand Up @@ -602,11 +620,11 @@ def decorate(*args, **kwargs):

def _decorate_sync_function(
self,
method: Callable,
method: AnyCallableT,
capture_response: Optional[Union[bool, str]] = None,
capture_error: Optional[Union[bool, str]] = None,
method_name: Optional[str] = None,
):
) -> AnyCallableT:
@functools.wraps(method)
def decorate(*args, **kwargs):
with self.provider.in_subsegment(name=f"## {method_name}") as subsegment:
Expand All @@ -628,7 +646,7 @@ def decorate(*args, **kwargs):

return response

return decorate
return cast(AnyCallableT, decorate)

def _add_response_as_metadata(
self,
Expand Down

0 comments on commit d5c3431

Please sign in to comment.