From 3c68bd35a3b755fe8e3a9b8b869b79dfa992eea4 Mon Sep 17 00:00:00 2001 From: Eric Onofrey Date: Thu, 16 Jan 2025 16:08:45 -0800 Subject: [PATCH] Delete `checked_cast` and replace `checked_cast_(list|dict|to_tuple|optional) (#3230) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3230 Make the below replacements: `checked_cast_list` -> `assert_is_instance_list` `checked_cast_dict` -> `assert_is_instance_dict` `checked_cast_to_tuple` -> `assert_is_instance_of_tuple` `checked_cast_optional` -> `assert_is_instance_optional` `_argparse_type_encoder` untouched Reviewed By: danielcohenlive Differential Revision: D67993468 fbshipit-source-id: b5956a6fc9a81a6516d24a762c6e1257c3cb53f4 --- ax/modelbridge/generation_strategy.py | 4 +- ax/modelbridge/modelbridge_utils.py | 15 ++-- .../transforms/power_transform_y.py | 4 +- ax/models/random/base.py | 4 +- ax/models/torch/botorch_modular/surrogate.py | 9 +- ax/plot/scatter.py | 4 +- ax/service/tests/test_instantiation_utils.py | 5 +- ax/service/utils/instantiation.py | 21 +++-- ax/utils/common/tests/test_typeutils.py | 57 ++++++------ ax/utils/common/typeutils.py | 87 +++++++++---------- tutorials/external_generation_node.ipynb | 4 +- tutorials/multi_task.ipynb | 6 +- tutorials/sebo.ipynb | 6 +- 13 files changed, 114 insertions(+), 112 deletions(-) diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index 96716891cdc..067ac1ffadc 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -34,7 +34,7 @@ from ax.modelbridge.model_spec import FactoryFunctionModelSpec from ax.modelbridge.transition_criterion import TrialBasedCriterion from ax.utils.common.logger import _round_floats_for_logging, get_logger -from ax.utils.common.typeutils import checked_cast_list +from ax.utils.common.typeutils import assert_is_instance_list from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -626,7 +626,7 @@ def clone_reset(self) -> GenerationStrategy: return GenerationStrategy(name=self.name, nodes=cloned_nodes) return GenerationStrategy( - name=self.name, steps=checked_cast_list(GenerationStep, cloned_nodes) + name=self.name, steps=assert_is_instance_list(cloned_nodes, GenerationStep) ) def _unset_non_persistent_state_fields(self) -> None: diff --git a/ax/modelbridge/modelbridge_utils.py b/ax/modelbridge/modelbridge_utils.py index 4bc3d0191da..a3715c077ce 100644 --- a/ax/modelbridge/modelbridge_utils.py +++ b/ax/modelbridge/modelbridge_utils.py @@ -56,7 +56,10 @@ pareto_frontier_evaluator, ) from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast_optional, checked_cast_to_tuple +from ax.utils.common.typeutils import ( + assert_is_instance_of_tuple, + assert_is_instance_optional, +) from botorch.acquisition.multi_objective.multi_output_risk_measures import ( IndependentCVaR, IndependentVaR, @@ -218,7 +221,9 @@ def extract_search_space_digest( if isinstance(p, ChoiceParameter): if p.is_task: task_features.append(i) - target_values[i] = checked_cast_to_tuple((int, float), p.target_value) + target_values[i] = assert_is_instance_of_tuple( + p.target_value, (int, float) + ) elif p.is_ordered: ordinal_features.append(i) else: @@ -243,7 +248,7 @@ def extract_search_space_digest( raise ValueError(f"Unknown parameter type {type(p)}") if p.is_fidelity: fidelity_features.append(i) - target_values[i] = checked_cast_to_tuple((int, float), p.target_value) + target_values[i] = assert_is_instance_of_tuple(p.target_value, (int, float)) return SearchSpaceDigest( feature_names=param_names, @@ -1054,8 +1059,8 @@ def _get_multiobjective_optimization_config( objective_thresholds: TRefPoint | None = None, ) -> MultiObjectiveOptimizationConfig: # Optimization_config - mooc = optimization_config or checked_cast_optional( - MultiObjectiveOptimizationConfig, modelbridge._optimization_config + mooc = optimization_config or assert_is_instance_optional( + modelbridge._optimization_config, MultiObjectiveOptimizationConfig ) if not mooc: raise ValueError( diff --git a/ax/modelbridge/transforms/power_transform_y.py b/ax/modelbridge/transforms/power_transform_y.py index a8e79496b3f..3049092a48e 100644 --- a/ax/modelbridge/transforms/power_transform_y.py +++ b/ax/modelbridge/transforms/power_transform_y.py @@ -22,7 +22,7 @@ from ax.modelbridge.transforms.utils import get_data, match_ci_width_truncated from ax.models.types import TConfig from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast_list +from ax.utils.common.typeutils import assert_is_instance_list from pyre_extensions import assert_is_instance from sklearn.preprocessing import PowerTransformer @@ -216,5 +216,5 @@ def _compute_inverse_bounds( bounds[1] = (-1.0 / lambda_ - mu) / sigma elif lambda_ > 2.0 + tol: bounds[0] = (1.0 / (2.0 - lambda_) - mu) / sigma - inv_bounds[k] = tuple(checked_cast_list(float, bounds)) + inv_bounds[k] = tuple(assert_is_instance_list(bounds, float)) return inv_bounds diff --git a/ax/models/random/base.py b/ax/models/random/base.py index 042b55b023b..06c5bc53be0 100644 --- a/ax/models/random/base.py +++ b/ax/models/random/base.py @@ -24,7 +24,7 @@ from ax.models.types import TConfig from ax.utils.common.docutils import copy_doc from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast_to_tuple +from ax.utils.common.typeutils import assert_is_instance_of_tuple from botorch.utils.sampling import HitAndRunPolytopeSampler from pyre_extensions import assert_is_instance from torch import Tensor @@ -129,7 +129,7 @@ def gen( if model_gen_options: max_draws = model_gen_options.get("max_rs_draws") if max_draws is not None: - max_draws = int(checked_cast_to_tuple((int, float), max_draws)) + max_draws = int(assert_is_instance_of_tuple(max_draws, (int, float))) try: # Always rejection sample, but this only rejects if there are # constraints or actual duplicates and deduplicate is specified. diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index c2d10345697..3f35bd93913 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -51,7 +51,10 @@ from ax.utils.common.base import Base from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import _argparse_type_encoder, checked_cast_optional +from ax.utils.common.typeutils import ( + _argparse_type_encoder, + assert_is_instance_optional, +) from ax.utils.stats.model_fit_stats import ( DIAGNOSTIC_FN_DIRECTIONS, DIAGNOSTIC_FNS, @@ -1277,7 +1280,9 @@ def best_out_of_sample_point( options = options or {} acqf_class, acqf_options = pick_best_out_of_sample_point_acqf_class( outcome_constraints=torch_opt_config.outcome_constraints, - seed_inner=checked_cast_optional(int, options.get(Keys.SEED_INNER, None)), + seed_inner=assert_is_instance_optional( + options.get(Keys.SEED_INNER, None), int + ), qmc=assert_is_instance( options.get(Keys.QMC, True), bool, diff --git a/ax/plot/scatter.py b/ax/plot/scatter.py index 9b5d65c7e6a..f9b9db4c44c 100644 --- a/ax/plot/scatter.py +++ b/ax/plot/scatter.py @@ -44,7 +44,7 @@ TNullableGeneratorRunsDict, ) from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast_optional +from ax.utils.common.typeutils import assert_is_instance_optional from ax.utils.stats.statstools import relativize from plotly import subplots @@ -419,7 +419,7 @@ def plot_multiple_metrics( layout_offset_x = 0.15 else: layout_offset_x = 0 - rel = checked_cast_optional(bool, kwargs.get("rel")) + rel = assert_is_instance_optional(kwargs.get("rel"), bool) if rel is not None: warnings.warn( "Use `rel_x` and `rel_y` instead of `rel`.", diff --git a/ax/service/tests/test_instantiation_utils.py b/ax/service/tests/test_instantiation_utils.py index 137e77a4c1a..d16f9bf27c7 100644 --- a/ax/service/tests/test_instantiation_utils.py +++ b/ax/service/tests/test_instantiation_utils.py @@ -385,7 +385,10 @@ def test_choice_with_is_sorted(self) -> None: else: self.assertEqual(output.sort_values, sort_values) - with self.assertRaisesRegex(ValueError, "Value was not of type "): + with self.assertRaisesRegex( + TypeError, + r"obj is not an instance of cls: obj=\['Foo'\] cls=", + ): representation: dict[str, Any] = { "name": "foo_or_bar", "type": "choice", diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index 7f1bd6d07d9..99fb39b142a 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -47,7 +47,10 @@ from ax.exceptions.core import UnsupportedError from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger -from ax.utils.common.typeutils import checked_cast_optional, checked_cast_to_tuple +from ax.utils.common.typeutils import ( + assert_is_instance_of_tuple, + assert_is_instance_optional, +) from pyre_extensions import assert_is_instance, none_throws DEFAULT_OBJECTIVE_NAME = "objective" @@ -227,8 +230,8 @@ def _make_range_param( parameter_type=cls._to_parameter_type( bounds, parameter_type, name, "bounds" ), - lower=checked_cast_to_tuple((float, int), bounds[0]), - upper=checked_cast_to_tuple((float, int), bounds[1]), + lower=assert_is_instance_of_tuple(bounds[0], (float, int)), + upper=assert_is_instance_of_tuple(bounds[1], (float, int)), log_scale=assert_is_instance(representation.get("log_scale", False), bool), digits=representation.get("digits", None), # pyre-ignore[6] is_fidelity=assert_is_instance( @@ -258,17 +261,19 @@ def _make_choice_param( values, parameter_type, name, "values" ), values=values, - is_ordered=checked_cast_optional(bool, representation.get("is_ordered")), + is_ordered=assert_is_instance_optional( + representation.get("is_ordered"), bool + ), is_fidelity=assert_is_instance( representation.get("is_fidelity", False), bool ), is_task=assert_is_instance(representation.get("is_task", False), bool), target_value=representation.get("target_value", None), # pyre-ignore[6] - sort_values=checked_cast_optional( - bool, representation.get("sort_values", None) + sort_values=assert_is_instance_optional( + representation.get("sort_values", None), bool ), - dependents=checked_cast_optional( - dict, representation.get("dependents", None) + dependents=assert_is_instance_optional( + representation.get("dependents", None), dict ), ) diff --git a/ax/utils/common/tests/test_typeutils.py b/ax/utils/common/tests/test_typeutils.py index 506c34f5e08..a219f1a7ae9 100644 --- a/ax/utils/common/tests/test_typeutils.py +++ b/ax/utils/common/tests/test_typeutils.py @@ -10,43 +10,36 @@ import numpy as np from ax.utils.common.testutils import TestCase from ax.utils.common.typeutils import ( - checked_cast, - checked_cast_dict, - checked_cast_list, - checked_cast_optional, + assert_is_instance_dict, + assert_is_instance_list, + assert_is_instance_optional, ) from ax.utils.common.typeutils_nonnative import numpy_type_to_python_type +from pyre_extensions import assert_is_instance class TestTypeUtils(TestCase): - def test_checked_cast(self) -> None: - self.assertEqual(checked_cast(float, 2.0), 2.0) - with self.assertRaises(ValueError): - checked_cast(float, 2) - - def test_checked_cast_with_error_override(self) -> None: - self.assertEqual(checked_cast(float, 2.0), 2.0) - with self.assertRaises(NotImplementedError): - checked_cast( - float, 2, exception=NotImplementedError("foo() doesn't support ints") - ) - - def test_checked_cast_list(self) -> None: - self.assertEqual(checked_cast_list(float, [1.0, 2.0]), [1.0, 2.0]) - with self.assertRaises(ValueError): - checked_cast_list(float, [1.0, 2]) - - def test_checked_cast_optional(self) -> None: - self.assertEqual(checked_cast_optional(float, None), None) - with self.assertRaises(ValueError): - checked_cast_optional(float, 2) - - def test_checked_cast_dict(self) -> None: - self.assertEqual(checked_cast_dict(str, int, {"some": 1}), {"some": 1}) - with self.assertRaises(ValueError): - checked_cast_dict(str, int, {"some": 1.0}) - with self.assertRaises(ValueError): - checked_cast_dict(str, int, {1: 1}) + def test_assert_is_instance(self) -> None: + self.assertEqual(assert_is_instance(2.0, float), 2.0) + with self.assertRaises(TypeError): + assert_is_instance(2, float) + + def test_assert_is_instance_list(self) -> None: + self.assertEqual(assert_is_instance_list([1.0, 2.0], float), [1.0, 2.0]) + with self.assertRaises(TypeError): + assert_is_instance_list([1.0, 2], float) + + def test_assert_is_instance_optional(self) -> None: + self.assertEqual(assert_is_instance_optional(None, float), None) + with self.assertRaises(TypeError): + assert_is_instance_optional(2, float) + + def test_assert_is_instance_dict(self) -> None: + self.assertEqual(assert_is_instance_dict({"some": 1}, str, int), {"some": 1}) + with self.assertRaises(TypeError): + assert_is_instance_dict({"some": 1.0}, str, int) + with self.assertRaises(TypeError): + assert_is_instance_dict({1: 1}, str, int) def test_numpy_type_to_python_type(self) -> None: self.assertEqual(type(numpy_type_to_python_type(np.int64(2))), int) diff --git a/ax/utils/common/typeutils.py b/ax/utils/common/typeutils.py index 11f36818326..5c1f45a70bb 100644 --- a/ax/utils/common/typeutils.py +++ b/ax/utils/common/typeutils.py @@ -7,6 +7,7 @@ from typing import Any, TypeVar +from pyre_extensions import assert_is_instance T = TypeVar("T") V = TypeVar("V") @@ -15,79 +16,69 @@ Y = TypeVar("Y") -def checked_cast(typ: type[T], val: V, exception: Exception | None = None) -> T: +def assert_is_instance_optional(val: V | None, typ: type[T]) -> T | None: """ - Cast a value to a type (with a runtime safety check). - - Returns the value unchanged and checks its type at runtime. This signals to the - typechecker that the value has the designated type. - - Like `typing.cast`_ ``check_cast`` performs no runtime conversion on its argument, - but, unlike ``typing.cast``, ``checked_cast`` will throw an error if the value is - not of the expected type. The type passed as an argument should be a python class. + Asserts that the value is an instance of the given type if it is not None. Args: - typ: the type to cast to - val: the value that we are casting - exception: override exception to raise if typecheck fails + val: the value to check + typ: the type to check against Returns: - the ``val`` argument, unchanged - - .. _typing.cast: https://docs.python.org/3/library/typing.html#typing.cast + the `val` argument, unchanged """ - if not isinstance(val, typ): - raise ( - exception - if exception is not None - else ValueError(f"Value was not of type {typ}:\n{val}") - ) - return val - - -def checked_cast_optional(typ: type[T], val: V | None) -> T | None: - """Calls checked_cast only if value is not None.""" if val is None: return val - return checked_cast(typ, val) + return assert_is_instance(val, typ) -def checked_cast_list(typ: type[T], old_l: list[V]) -> list[T]: - """Calls checked_cast on all items in a list.""" - new_l = [] - for val in old_l: - val = checked_cast(typ, val) - new_l.append(val) - return new_l +def assert_is_instance_list(old_l: list[V], typ: type[T]) -> list[T]: + """ + Asserts that all items in a list are instances of the given type. + Args: + old_l: the list to check + typ: the type to check against + Returns: + the `old_l` argument, unchanged + """ + return [assert_is_instance(val, typ) for val in old_l] -def checked_cast_dict( - key_typ: type[K], value_typ: type[V], d: dict[X, Y] + +def assert_is_instance_dict( + d: dict[X, Y], key_type: type[K], val_type: type[V] ) -> dict[K, V]: - """Calls checked_cast on all keys and values in the dictionary.""" + """ + Asserts that all keys and values in the dictionary are instances + of the given classes. + + Args: + d: the dictionary to check + key_type: the type to check against for keys + val_type: the type to check against for values + Returns: + the `d` argument, unchanged + """ new_dict = {} for key, val in d.items(): - val = checked_cast(value_typ, val) - key = checked_cast(key_typ, key) + key = assert_is_instance(key, key_type) + val = assert_is_instance(val, val_type) new_dict[key] = val return new_dict # pyre-fixme[34]: `T` isn't present in the function's parameters. -def checked_cast_to_tuple(typ: tuple[type[V], ...], val: V) -> T: +def assert_is_instance_of_tuple(val: V, typ: tuple[type[V], ...]) -> T: """ - Cast a value to a union of multiple types (with a runtime safety check). - This function is similar to `checked_cast`, but allows for the type to be - defined as a tuple of types, in which case the value is cast as a union of - the types in the tuple. + Asserts that a value is an instance of any type in a tuple of types. Args: - typ: the tuple of types to cast to - val: the value that we are casting + typ: the tuple of types to check against + val: the value that we are checking Returns: - the ``val`` argument, unchanged + the `val` argument, unchanged """ if not isinstance(val, typ): - raise ValueError(f"Value was not of type {type!r}:\n{val!r}") + raise TypeError(f"Value was not of any type {typ!r}:\n{val!r}") # pyre-fixme[7]: Expected `T` but got `V`. return val diff --git a/tutorials/external_generation_node.ipynb b/tutorials/external_generation_node.ipynb index c1aac0b259d..5817bfa5bf8 100644 --- a/tutorials/external_generation_node.ipynb +++ b/tutorials/external_generation_node.ipynb @@ -56,9 +56,9 @@ "from ax.plot.trace import plot_objective_value_vs_trial_index\n", "from ax.service.ax_client import AxClient, ObjectiveProperties\n", "from ax.service.utils.report_utils import exp_to_df\n", - "from ax.utils.common.typeutils import checked_cast\n", "from ax.utils.measurement.synthetic_functions import hartmann6\n", "from sklearn.ensemble import RandomForestRegressor\n", + "from pyre_extensions import assert_is_instance\n", "\n", "\n", "class RandomForestGenerationNode(ExternalGenerationNode):\n", @@ -285,7 +285,7 @@ "\n", "def evaluate(parameterization: TParameterization) -> Dict[str, Tuple[float, float]]:\n", " x = np.array([parameterization.get(f\"x{i+1}\") for i in range(6)])\n", - " return {\"hartmann6\": (checked_cast(float, hartmann6(x)), 0.0)}" + " return {\"hartmann6\": (assert_is_instance(hartmann6(x), float), 0.0)}" ] }, { diff --git a/tutorials/multi_task.ipynb b/tutorials/multi_task.ipynb index 053b7aa4c4a..4e1559430f6 100644 --- a/tutorials/multi_task.ipynb +++ b/tutorials/multi_task.ipynb @@ -66,8 +66,8 @@ "from ax.modelbridge.transforms.task_encode import TaskChoiceToIntTaskChoice\n", "from ax.plot.diagnostic import interact_batch_comparison\n", "from ax.runners.synthetic import SyntheticRunner\n", - "from ax.utils.common.typeutils import checked_cast\n", "from ax.utils.notebook.plotting import init_notebook_plotting, render\n", + "from pyre_extensions import assert_is_instance\n", "\n", "init_notebook_plotting()\n", "\n", @@ -423,8 +423,7 @@ " )\n", "\n", " \n", - " return checked_cast(\n", - " TorchModelBridge,\n", + " return assert_is_instance(\n", " Models.ST_MTGP(\n", " experiment=experiment,\n", " search_space=search_space or experiment.search_space,\n", @@ -435,6 +434,7 @@ " torch_device=device,\n", " status_quo_features=status_quo_features,\n", " ),\n", + " TorchModelBridge,\n", " )" ] }, diff --git a/tutorials/sebo.ipynb b/tutorials/sebo.ipynb index 059b165b504..730ead6c532 100644 --- a/tutorials/sebo.ipynb +++ b/tutorials/sebo.ipynb @@ -59,9 +59,9 @@ "from ax.models.torch.botorch_modular.surrogate import Surrogate\n", "from ax.runners.synthetic import SyntheticRunner\n", "from ax.service.ax_client import AxClient, ObjectiveProperties\n", - "from ax.utils.common.typeutils import checked_cast\n", "from botorch.acquisition.multi_objective import qNoisyExpectedHypervolumeImprovement\n", - "from botorch.models import SaasFullyBayesianSingleTaskGP, SingleTaskGP" + "from botorch.models import SaasFullyBayesianSingleTaskGP, SingleTaskGP\n", + "from pyre_extensions import assert_is_instance" ] }, { @@ -158,7 +158,7 @@ "source": [ "class AugBraninMetric(NoisyFunctionMetric):\n", " def f(self, x: np.ndarray) -> float:\n", - " return checked_cast(float, branin_augment(x_vec=x, augment_dim=aug_dim))\n", + " return assert_is_instance(branin_augment(x_vec=x, augment_dim=aug_dim), float)\n", "\n", "\n", "# Create search space in Ax \n",