From f3125bd776185fdc6f54304d731321d53d5585da Mon Sep 17 00:00:00 2001 From: Daniel Cohen Date: Tue, 5 Dec 2023 14:26:57 -0800 Subject: [PATCH] Use inspect to support typeguard version being overridden (#2044) Summary: Attempt to solve/mitigate https://github.com/facebook/Ax/issues/2043 We stopped pinning typeguard to 2.13.3, but gracefully failing if it doesn't have that version manually changes the version. This is important so ax can install in a standard way without conflicting with user's other dependencies. Reviewed By: mpolson64 Differential Revision: D51853729 --- ax/utils/common/kwargs.py | 4 ++-- ax/utils/common/typeutils.py | 17 ++++++++++++++++- setup.py | 2 +- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/ax/utils/common/kwargs.py b/ax/utils/common/kwargs.py index 40280066778..4083e9f4f04 100644 --- a/ax/utils/common/kwargs.py +++ b/ax/utils/common/kwargs.py @@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional from ax.utils.common.logger import get_logger -from typeguard import check_type +from ax.utils.common.typeutils import version_safe_check_type logger: Logger = get_logger(__name__) @@ -82,7 +82,7 @@ def validate_kwarg_typing(typed_callables: List[Callable], **kwargs: Any) -> Non # if the keyword is a callable, we only do shallow checks if not (callable(kw_val) and callable(param.annotation)): try: - check_type(kw, kw_val, param.annotation) + version_safe_check_type(kw, kw_val, param.annotation) except TypeError: message = ( f"`{typed_callable}` expected argument `{kw}` to be of" diff --git a/ax/utils/common/typeutils.py b/ax/utils/common/typeutils.py index a7b1009b50e..a36781eb9bd 100644 --- a/ax/utils/common/typeutils.py +++ b/ax/utils/common/typeutils.py @@ -4,10 +4,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from inspect import signature from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar import numpy as np - +from typeguard import check_type T = TypeVar("T") V = TypeVar("V") @@ -108,6 +109,20 @@ def checked_cast_to_tuple(typ: Tuple[Type[V], ...], val: V) -> T: return val +def version_safe_check_type(argname: str, value: T, expected_type: Type[T]) -> None: + """Excecute the check_type function if it has the expected signature, otherwise + warn. This is done to support newer versions of typeguard with minimal loss + of functionality for users that have dependency conflicts""" + # Get the signature of the check_type function + sig = signature(check_type) + # Get the parameters of the check_type function + params = sig.parameters + # Check if the check_type function has the expected signature + params = set(params.keys()) + if all(arg in params for arg in ["argname", "value", "expected_type"]): + check_type(argname, value, expected_type) + + # pyre-fixme[3]: Return annotation cannot be `Any`. # pyre-fixme[2]: Parameter annotation cannot be `Any`. def numpy_type_to_python_type(value: Any) -> Any: diff --git a/setup.py b/setup.py index c0ae7ec6c4b..f8101501dfb 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ "ipywidgets", # Needed for compatibility with ipywidgets >= 8.0.0 "plotly>=5.12.0", - "typeguard==2.13.3", + "typeguard", "pyre-extensions", ]