From 6b0ad553f3b6bd6ff4c4d0137b28b09d22a8e91a Mon Sep 17 00:00:00 2001
From: Elizabeth Santorella
Date: Tue, 9 Jul 2024 16:20:19 -0700
Subject: [PATCH] Bug fix: Make MBM acquisition pass the correct args to each BoTorch optimizer; cleanup (#2571)

Summary:
Pull Request resolved: https://github.com/facebook/Ax/pull/2571

Motivation: MBM can dispatch to four different BoTorch optimizers depending on the search space. Currently, `optimize_acqf_discrete_local_search` fails because it is passed the inappropriate argument `sequential`.

This diff:
* Makes the logic clearer: first, `Acquisition.optimize` determines which of the four optimizers is appropriate; then it constructs the arguments for that optimizer using `optimizer_argparse`; finally, it performs any optimizer-specific logic and calls the optimizer.
* Fixes `optimizer_argparse` so that inappropriate arguments such as `sequential` are not passed to optimizers that don't support them.
* Removes special-casing for qNEHVI and qMES in `optimizer_argparse` that wasn't actually doing anything.
* Extends unit tests for `optimizer_argparse` to check all optimizers.
* Reduces the usage and scope of mocks in test_acquisition so that `optimize_acqf` and its variants are actually run as much as possible.

Differential Revision: D59354709
---
 .../torch/botorch_modular/acquisition.py      |  67 ++--
 .../botorch_modular/optimizer_argparse.py     |  59 +++-
 ax/models/torch/tests/test_acquisition.py     | 292 ++++++++++++------
 .../torch/tests/test_optimizer_argparse.py    | 118 +++++--
 4 files changed, 372 insertions(+), 164 deletions(-)

diff --git a/ax/models/torch/botorch_modular/acquisition.py b/ax/models/torch/botorch_modular/acquisition.py
index b97903ef43f..a190ac91b9b 100644
--- a/ax/models/torch/botorch_modular/acquisition.py
+++ b/ax/models/torch/botorch_modular/acquisition.py
@@ -425,6 +425,35 @@ def optimize(
         _tensorize = partial(torch.tensor, dtype=self.dtype, device=self.device)
         ssd = search_space_digest
         bounds = _tensorize(ssd.bounds).t()
+        discrete_features = sorted(ssd.ordinal_features + ssd.categorical_features)
+        discrete_choices = mk_discrete_choices(ssd=ssd, fixed_features=fixed_features)
+        if (
+            optimizer_options is not None
+            and "force_use_optimize_acqf" in optimizer_options
+        ):
+            force_use_optimize_acqf = optimizer_options.pop("force_use_optimize_acqf")
+        else:
+            force_use_optimize_acqf = False
+
+        if (len(discrete_features) == 0) or force_use_optimize_acqf:
+            optimizer = "optimize_acqf"
+        else:
+            fully_discrete = len(discrete_choices) == len(ssd.feature_names)
+            if fully_discrete:
+                total_discrete_choices = reduce(
+                    operator.mul, [float(len(c)) for c in discrete_choices.values()]
+                )
+                if total_discrete_choices > MAX_CHOICES_ENUMERATE:
+                    optimizer = "optimize_acqf_discrete_local_search"
+                else:
+                    optimizer = "optimize_acqf_discrete"
+                    # `raw_samples` is not supported by `optimize_acqf_discrete`.
+                    # TODO[santorella]: Rather than manually removing it, we should
+                    # ensure that it is never passed.
+ if optimizer_options is not None: + optimizer_options.pop("raw_samples") + else: + optimizer = "optimize_acqf_mixed" # Prepare arguments for optimizer optimizer_options_with_defaults = optimizer_argparse( @@ -432,12 +461,12 @@ def optimize( bounds=bounds, q=n, optimizer_options=optimizer_options, + optimizer=optimizer, ) post_processing_func = get_post_processing_func( rounding_func=rounding_func, optimizer_options=optimizer_options_with_defaults, ) - discrete_features = sorted(ssd.ordinal_features + ssd.categorical_features) if fixed_features is not None: for i in fixed_features: if not 0 <= i < len(ssd.feature_names): @@ -446,10 +475,7 @@ def optimize( # customized in subclasses if necessary. arm_weights = torch.ones(n, dtype=self.dtype) # 1. Handle the fully continuous search space. - if ( - optimizer_options_with_defaults.pop("force_use_optimize_acqf", False) - or not discrete_features - ): + if optimizer == "optimize_acqf": candidates, acqf_values = optimize_acqf( acq_function=self.acqf, bounds=bounds, @@ -462,10 +488,11 @@ def optimize( return candidates, acqf_values, arm_weights # 2. Handle search spaces with discrete features. - discrete_choices = mk_discrete_choices(ssd=ssd, fixed_features=fixed_features) - # 2a. Handle the fully discrete search space. - if len(discrete_choices) == len(ssd.feature_names): + if optimizer in ( + "optimize_acqf_discrete", + "optimize_acqf_discrete_local_search", + ): X_observed = self.X_observed if self.X_pending is not None: if X_observed is None: @@ -474,10 +501,7 @@ def optimize( X_observed = torch.cat([X_observed, self.X_pending], dim=0) # Special handling for search spaces with a large number of choices - total_choices = reduce( - operator.mul, [float(len(c)) for c in discrete_choices.values()] - ) - if total_choices > MAX_CHOICES_ENUMERATE: + if optimizer == "optimize_acqf_discrete_local_search": discrete_choices = [ torch.tensor(c, device=self.device, dtype=self.dtype) for c in discrete_choices.values() @@ -492,6 +516,7 @@ def optimize( ) return candidates, acqf_values, arm_weights + # Else, optimizer is `optimize_acqf_discrete` # Enumerate all possible choices all_choices = (discrete_choices[i] for i in range(len(discrete_choices))) all_choices = _tensorize(tuple(product(*all_choices))) @@ -530,26 +555,16 @@ def optimize( ) n = num_choices - # `raw_samples` is not supported by `optimize_acqf_discrete`. - # TODO[santorella]: Rather than manually removing it, we should - # ensure that it is never passed. - if optimizer_options is not None: - optimizer_options.pop("raw_samples") - discrete_opt_options = optimizer_argparse( - self.acqf, - bounds=bounds, - q=n, - optimizer_options=optimizer_options, - optimizer_is_discrete=True, - ) candidates, acqf_values = optimize_acqf_discrete( - acq_function=self.acqf, q=n, choices=all_choices, **discrete_opt_options + acq_function=self.acqf, + q=n, + choices=all_choices, + **optimizer_options_with_defaults, ) return candidates, acqf_values, arm_weights # 2b. Handle mixed search spaces that have discrete and continuous features. # Only sequential optimization is supported for `optimize_acqf_mixed`. 
- optimizer_options_with_defaults.pop("sequential") candidates, acqf_values = optimize_acqf_mixed( acq_function=self.acqf, bounds=bounds, diff --git a/ax/models/torch/botorch_modular/optimizer_argparse.py b/ax/models/torch/botorch_modular/optimizer_argparse.py index ccc41b062ad..d39a83a4fdf 100644 --- a/ax/models/torch/botorch_modular/optimizer_argparse.py +++ b/ax/models/torch/botorch_modular/optimizer_argparse.py @@ -42,16 +42,23 @@ def _argparse_base( init_batch_limit: int = INIT_BATCH_LIMIT, batch_limit: int = BATCH_LIMIT, optimizer_options: Optional[Dict[str, Any]] = None, - optimizer_is_discrete: bool = False, + *, + optimizer: str = "optimize_acqf", **ignore: Any, ) -> Dict[str, Any]: - """Extract the base optimizer kwargs form the given arguments. + """Extract the kwargs to be passed to a BoTorch optimizer. NOTE: Since `optimizer_options` is how the user would typically pass in these options, it takes precedence over other arguments. E.g., if both `num_restarts` and `optimizer_options["num_restarts"]` are provided, this will use `num_restarts` from `optimizer_options`. + NOTE: Arguments in `**ignore` are ignored; in addition, any arguments + specified in `optimizer_options` that are supported for some optimizers and + not others can be silently ignored. For example, `sequential` will not be + passed when the `optimizer` is `optimize_acqf_discrete`, since + `optimize_acqf_discrete` does not support it. + Args: acqf: The acquisition function being optimized. sequential: Whether we choose one candidate at a time in a sequential manner. @@ -75,23 +82,46 @@ def _argparse_base( >>> }, >>> "retry_on_optimization_warning": False, >>> } - optimizer_is_discrete: True if the optimizer is `optimizer_acqf_discrete`, - which supports a limited set of arguments. + optimizer: one of "optimize_acqf", + "optimize_acqf_discrete_local_search", "optimize_acqf_discrete", or + "optimize_acqf_mixed". This is generally chosen by + `Acquisition.optimize`. """ - optimizer_options = optimizer_options or {} - if optimizer_is_discrete: - return optimizer_options - return { - "sequential": sequential, + supported_optimizers = [ + "optimize_acqf", + "optimize_acqf_discrete_local_search", + "optimize_acqf_discrete", + "optimize_acqf_mixed", + ] + if optimizer not in supported_optimizers: + raise ValueError( + f"optimizer=`{optimizer}` is not supported. 
Accepted options are " + f"{supported_optimizers}" + ) + provided_options = optimizer_options if optimizer_options is not None else {} + + # construct arguments from options that are not `provided_options` + options = { "num_restarts": num_restarts, "raw_samples": raw_samples, - "options": { + } + # if not, 'options' will be silently ignored + if optimizer in ["optimize_acqf", "optimize_acqf_mixed"]: + options["options"] = { "init_batch_limit": init_batch_limit, "batch_limit": batch_limit, - **optimizer_options.get("options", {}), - }, - **{k: v for k, v in optimizer_options.items() if k != "options"}, - } + **provided_options.get("options", {}), + } + + if optimizer == "optimize_acqf": + options["sequential"] = sequential + + # optimize_acqf_discrete only accepts 'choices', 'max_batch_size', 'unique' + if optimizer == "optimize_acqf_discrete": + return provided_options + + options.update(**{k: v for k, v in provided_options.items() if k != "options"}) + return options @optimizer_argparse.register(qKnowledgeGradient) @@ -117,6 +147,7 @@ def _argparse_kg( num_restarts=num_restarts, raw_samples=raw_samples, optimizer_options=optimizer_options, + optimizer="optimize_acqf", **kwargs, ) diff --git a/ax/models/torch/tests/test_acquisition.py b/ax/models/torch/tests/test_acquisition.py index 3fed96edb87..49e70aec0f1 100644 --- a/ax/models/torch/tests/test_acquisition.py +++ b/ax/models/torch/tests/test_acquisition.py @@ -21,6 +21,7 @@ from ax.core.search_space import SearchSpaceDigest from ax.exceptions.core import AxWarning, SearchSpaceExhausted from ax.models.torch.botorch_modular.acquisition import Acquisition +from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.models.torch.utils import ( _get_X_pending_and_observed, @@ -30,7 +31,10 @@ from ax.models.torch_base import TorchOptConfig from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase -from ax.utils.testing.mock import fast_botorch_optimize +from ax.utils.testing.mock import ( + fast_botorch_optimize, + fast_botorch_optimize_context_manager, +) from ax.utils.testing.utils import generic_equals from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.input_constructors import ( @@ -45,6 +49,11 @@ ) from botorch.acquisition.objective import LinearMCObjective from botorch.models.gp_regression import SingleTaskGP +from botorch.optim.optimize import ( + optimize_acqf, + optimize_acqf_discrete, + optimize_acqf_mixed, +) from botorch.utils.constraints import get_outcome_constraint_transforms from botorch.utils.datasets import SupervisedDataset from botorch.utils.testing import MockPosterior @@ -66,18 +75,14 @@ def __init__(self, **kwargs: Any) -> None: # pyre-fixme[6]: For 1st param expected `Model` but got `None`. AcquisitionFunction.__init__(self, model=None) - def __call__(self, X: Tensor, **kwargs: Any) -> Tensor: - return X.sum(dim=-1) - - # pyre-fixme[14]: `set_X_pending` overrides method defined in - # `AcquisitionFunction` inconsistently. - def set_X_pending(self, X: Tensor, **kwargs: Any) -> None: - self.X_pending = X - # pyre-fixme[15]: `forward` overrides method defined in `AcquisitionFunction` # inconsistently. 
def forward(self, X: torch.Tensor) -> None: - pass + # take the norm and sum over the q-batch dim + if len(X.shape) > 2: + return torch.linalg.norm(X, dim=-1).sum(-1) + else: + return torch.linalg.norm(X, dim=-1).squeeze(-1) class DummyOneShotAcquisitionFunction(DummyAcquisitionFunction, qKnowledgeGradient): @@ -86,7 +91,6 @@ def evaluate(self, X: Tensor, **kwargs: Any) -> Tensor: class AcquisitionTest(TestCase): - @fast_botorch_optimize def setUp(self) -> None: super().setUp() qNEI_input_constructor = get_acqf_input_constructor(qNoisyExpectedImprovement) @@ -124,14 +128,15 @@ def setUp(self) -> None: bounds=[(0.0, 10.0), (0.0, 10.0), (0.0, 10.0)], target_values={2: 1.0}, ) - self.surrogate.fit( - datasets=self.training_data, - search_space_digest=SearchSpaceDigest( - feature_names=self.search_space_digest.feature_names[:1], - bounds=self.search_space_digest.bounds, - target_values=self.search_space_digest.target_values, - ), - ) + with fast_botorch_optimize_context_manager(): + self.surrogate.fit( + datasets=self.training_data, + search_space_digest=SearchSpaceDigest( + feature_names=self.search_space_digest.feature_names[:1], + bounds=self.search_space_digest.bounds, + target_values=self.search_space_digest.target_values, + ), + ) self.botorch_acqf_class = DummyAcquisitionFunction self.objective_weights = torch.tensor([1.0]) @@ -148,7 +153,11 @@ def setUp(self) -> None: self.fixed_features = {1: 2.0} self.options = {"cache_root": False, "prune_baseline": False} self.inequality_constraints = [ - (torch.tensor([0, 1], **tkwargs), torch.tensor([-1.0, 1.0], **tkwargs), 1) + ( + torch.tensor([0, 1], dtype=torch.int), + torch.tensor([-1.0, 1.0], **tkwargs), + 1, + ) ] self.rounding_func = lambda x: x self.optimizer_options = {Keys.NUM_RESTARTS: 20, Keys.RAW_SAMPLES: 1024} @@ -245,13 +254,8 @@ def test_init( f"{CURRENT_PATH}.Acquisition.compute_model_dependencies", return_value={"eta": 0.1}, ) - @mock.patch( - f"{DummyAcquisitionFunction.__module__}.DummyAcquisitionFunction.__init__", - return_value=None, - ) def test_init_with_subset_model_false( self, - mock_botorch_acqf_class: Mock, mock_compute_model_deps: Mock, mock_get_objective_and_transform: Mock, mock_subset_model: Mock, @@ -272,50 +276,59 @@ def test_init_with_subset_model_false( botorch_acqf_class=self.botorch_acqf_class, options=self.options, ) - mock_subset_model.assert_not_called() - # Check `get_botorch_objective_and_transform` kwargs - mock_get_objective_and_transform.assert_called_once() - _, ckwargs = mock_get_objective_and_transform.call_args - self.assertIs(ckwargs["model"], acquisition.surrogates["surrogate"].model) - self.assertIs(ckwargs["objective_weights"], self.objective_weights) - self.assertIs(ckwargs["outcome_constraints"], self.outcome_constraints) - self.assertTrue(torch.equal(ckwargs["X_observed"], self.X[:1])) - # Check final `acqf` creation - model_deps = {"eta": 0.1} - self.mock_input_constructor.assert_called_once() - mock_botorch_acqf_class.assert_called_once() - _, ckwargs = self.mock_input_constructor.call_args - self.assertIs(ckwargs["model"], acquisition.surrogates["surrogate"].model) - self.assertIs(ckwargs["objective"], botorch_objective) - self.assertTrue( - torch.equal(ckwargs["X_pending"], self.pending_observations[0]) - ) - for k, v in chain(self.options.items(), model_deps.items()): - self.assertEqual(ckwargs[k], v) - self.assertIs( - ckwargs["constraints"], - self.constraints, - ) - mock_get_outcome_constraint_transforms.assert_called_once_with( - outcome_constraints=self.outcome_constraints - ) 
+ mock_subset_model.assert_not_called() + # Check `get_botorch_objective_and_transform` kwargs + mock_get_objective_and_transform.assert_called_once() + _, ckwargs = mock_get_objective_and_transform.call_args + self.assertIs(ckwargs["model"], acquisition.surrogates["surrogate"].model) + self.assertIs(ckwargs["objective_weights"], self.objective_weights) + self.assertIs(ckwargs["outcome_constraints"], self.outcome_constraints) + self.assertTrue(torch.equal(ckwargs["X_observed"], self.X[:1])) + # Check final `acqf` creation + model_deps = {"eta": 0.1} + self.mock_input_constructor.assert_called_once() + _, ckwargs = self.mock_input_constructor.call_args + self.assertIs(ckwargs["model"], acquisition.surrogates["surrogate"].model) + self.assertIs(ckwargs["objective"], botorch_objective) + self.assertTrue(torch.equal(ckwargs["X_pending"], self.pending_observations[0])) + for k, v in chain(self.options.items(), model_deps.items()): + self.assertEqual(ckwargs[k], v) + self.assertIs( + ckwargs["constraints"], + self.constraints, + ) + mock_get_outcome_constraint_transforms.assert_called_once_with( + outcome_constraints=self.outcome_constraints + ) - @mock.patch(f"{ACQUISITION_PATH}.optimize_acqf", return_value=(Mock(), Mock())) + @fast_botorch_optimize + @mock.patch(f"{ACQUISITION_PATH}.optimize_acqf", wraps=optimize_acqf) def test_optimize(self, mock_optimize_acqf: Mock) -> None: acquisition = self.get_acquisition_function(fixed_features=self.fixed_features) - acquisition.optimize( - n=3, - search_space_digest=self.search_space_digest, - inequality_constraints=self.inequality_constraints, - fixed_features=self.fixed_features, - rounding_func=self.rounding_func, + n = 5 + with mock.patch( + f"{ACQUISITION_PATH}.optimizer_argparse", wraps=optimizer_argparse + ) as mock_optimizer_argparse: + acquisition.optimize( + n=n, + search_space_digest=self.search_space_digest, + inequality_constraints=self.inequality_constraints, + fixed_features=self.fixed_features, + rounding_func=self.rounding_func, + optimizer_options=self.optimizer_options, + ) + mock_optimizer_argparse.assert_called_once_with( + acquisition.acqf, + bounds=mock.ANY, + q=n, optimizer_options=self.optimizer_options, + optimizer="optimize_acqf", ) mock_optimize_acqf.assert_called_with( acq_function=acquisition.acqf, sequential=True, bounds=mock.ANY, - q=3, + q=n, options={"init_batch_limit": 32, "batch_limit": 5}, inequality_constraints=self.inequality_constraints, fixed_features=self.fixed_features, @@ -340,7 +353,7 @@ def test_optimize(self, mock_optimize_acqf: Mock) -> None: optimizer_options["force_use_optimize_acqf"] = True mock_optimize_acqf.reset_mock() acquisition.optimize( - n=3, + n=n, search_space_digest=self.search_space_digest, inequality_constraints=self.inequality_constraints, fixed_features=self.fixed_features, @@ -351,7 +364,7 @@ def test_optimize(self, mock_optimize_acqf: Mock) -> None: acq_function=acquisition.acqf, sequential=True, bounds=mock.ANY, - q=3, + q=n, options={"init_batch_limit": 32, "batch_limit": 5}, inequality_constraints=self.inequality_constraints, fixed_features=self.fixed_features, @@ -361,8 +374,9 @@ def test_optimize(self, mock_optimize_acqf: Mock) -> None: # Now using both rounding func and post processing func. 
mock_optimize_acqf.reset_mock() rounding_func = Mock(side_effect=lambda x: x // 4) + post_processing_func.reset_mock() acquisition.optimize( - n=3, + n=n, search_space_digest=self.search_space_digest, inequality_constraints=self.inequality_constraints, fixed_features=self.fixed_features, @@ -373,15 +387,13 @@ def test_optimize(self, mock_optimize_acqf: Mock) -> None: # Call it with a known input to check that the functions are called in # the correct order. self.assertEqual(actual_func(3.0), 2) # (3 ** 2) // 4 = 2 - post_processing_func.assert_called_once_with(3.0) - rounding_func.assert_called_once_with(9.0) + post_processing_func.assert_called_with(3.0) + rounding_func.assert_called_with(9.0) def test_optimize_discrete(self) -> None: ssd1 = SearchSpaceDigest( feature_names=["a", "b", "c"], - # pyre-fixme[6]: For 2nd param expected `List[Tuple[Union[float, int], - # Union[float, int]]]` but got `List[Tuple[int, int, int]]`. - bounds=[(1, 2, 3), (2, 3, 4)], + bounds=[(1, 2), (2, 3), (3, 4)], categorical_features=[0, 1, 2], discrete_choices={0: [1, 2], 1: [2, 3], 2: [3, 4]}, ) @@ -422,11 +434,46 @@ def test_optimize_discrete(self) -> None: # 2 candidates have acqf value 8, but [1, 3, 4] is pending and thus should # not be selected. [2, 3, 4] is the best point, but has already been picked acquisition = self.get_acquisition_function() - X_selected, _, weights = acquisition.optimize( - n=2, - search_space_digest=ssd1, - rounding_func=self.rounding_func, + n = 2 + with mock.patch( + f"{ACQUISITION_PATH}.optimizer_argparse", wraps=optimizer_argparse + ) as mock_optimizer_argparse, mock.patch( + f"{ACQUISITION_PATH}.optimize_acqf_discrete", wraps=optimize_acqf_discrete + ) as mock_optimize_acqf_discrete: + X_selected, _, weights = acquisition.optimize( + n=n, + search_space_digest=ssd1, + rounding_func=self.rounding_func, + ) + mock_optimizer_argparse.assert_called_once_with( + acquisition.acqf, + bounds=mock.ANY, + q=n, + optimizer_options=None, + optimizer="optimize_acqf_discrete", + ) + + mock_optimize_acqf_discrete.assert_called_once_with( + acq_function=acquisition.acqf, + q=n, + choices=mock.ANY, + ) + expected_choices = torch.tensor( + [ + elt + for elt in all_possible_choices + # not a pending observation + if not (self.pending_observations[0] == torch.tensor(elt)).all() + # not in training data + and not (self.X == torch.tensor(elt)).all(1).any() + ], + ) + self.assertTrue( + torch.equal( + expected_choices, mock_optimize_acqf_discrete.call_args[1]["choices"] + ) ) + expected = torch.tensor([[2, 2, 4], [2, 3, 3]]).to(self.X) self.assertTrue(X_selected.shape == (2, 3)) self.assertTrue( @@ -438,20 +485,36 @@ def test_optimize_discrete(self) -> None: # [4, 2, 4], [3, 2, 4], [4, 2, 3] ssd2 = SearchSpaceDigest( feature_names=["a", "b", "c"], - # pyre-fixme[6]: For 2nd param expected `List[Tuple[Union[float, int], - # Union[float, int]]]` but got `List[Tuple[int, int, int]]`. - bounds=[(0, 0, 0), (4, 4, 4)], + bounds=[(0, 4) for _ in range(3)], categorical_features=[0, 1, 2], # pyre-fixme[6]: For 4th param expected `Dict[int, List[Union[float, # int]]]` but got `Dict[int, List[int]]`. 
discrete_choices={k: [0, 1, 2, 3, 4] for k in range(3)}, ) - X_selected, _, weights = acquisition.optimize( - n=3, - search_space_digest=ssd2, - fixed_features=self.fixed_features, - rounding_func=self.rounding_func, + with mock.patch( + f"{ACQUISITION_PATH}.optimizer_argparse", wraps=optimizer_argparse + ) as mock_optimizer_argparse, mock.patch( + f"{ACQUISITION_PATH}.optimize_acqf_discrete", wraps=optimize_acqf_discrete + ) as mock_optimize_acqf_discrete: + X_selected, _, weights = acquisition.optimize( + n=3, + search_space_digest=ssd2, + fixed_features=self.fixed_features, + rounding_func=self.rounding_func, + ) + mock_optimizer_argparse.assert_called_once_with( + acquisition.acqf, + bounds=mock.ANY, + q=3, + optimizer_options=None, + optimizer="optimize_acqf_discrete", + ) + mock_optimize_acqf_discrete.assert_called_once_with( + acq_function=acquisition.acqf, + q=3, + choices=mock.ANY, ) + expected = torch.tensor([[4, 2, 4], [3, 2, 4], [4, 2, 3]]).to(self.X) self.assertTrue(X_selected.shape == (3, 3)) self.assertTrue( @@ -499,6 +562,8 @@ def test_optimize_discrete(self) -> None: all((x.unsqueeze(0) == expected).all(dim=-1).any() for x in X_selected) ) + # mock `optimize_acqf_discrete_local_search` because it isn't handled by + # `fast_botorch_optimize` @mock.patch( f"{ACQUISITION_PATH}.optimize_acqf_discrete_local_search", return_value=(Mock(), Mock()), @@ -517,16 +582,39 @@ def test_optimize_acqf_discrete_local_search( }, ) acquisition = self.get_acquisition_function() - acquisition.optimize( - n=3, - search_space_digest=ssd, - inequality_constraints=self.inequality_constraints, - fixed_features=None, - rounding_func=self.rounding_func, + with mock.patch( + f"{ACQUISITION_PATH}.optimizer_argparse", wraps=optimizer_argparse + ) as mock_optimizer_argparse: + acquisition.optimize( + n=3, + search_space_digest=ssd, + inequality_constraints=self.inequality_constraints, + fixed_features=None, + rounding_func=self.rounding_func, + optimizer_options=self.optimizer_options, + ) + mock_optimizer_argparse.assert_called_once_with( + acquisition.acqf, + bounds=mock.ANY, + q=3, optimizer_options=self.optimizer_options, + optimizer="optimize_acqf_discrete_local_search", ) mock_optimize_acqf_discrete_local_search.assert_called_once() args, kwargs = mock_optimize_acqf_discrete_local_search.call_args + self.assertEqual(len(args), 0) + self.assertSetEqual( + { + "acq_function", + "discrete_choices", + "q", + "num_restarts", + "raw_samples", + "inequality_constraints", + "X_avoid", + }, + set(kwargs.keys()), + ) self.assertEqual(kwargs["acq_function"], acquisition.acqf) self.assertEqual(kwargs["q"], 3) self.assertEqual(kwargs["inequality_constraints"], self.inequality_constraints) @@ -544,10 +632,8 @@ def test_optimize_acqf_discrete_local_search( all((X_avoid_true == x).all(dim=-1).any().item() for x in kwargs["X_avoid"]) ) - @mock.patch( - f"{ACQUISITION_PATH}.optimize_acqf_mixed", return_value=(Mock(), Mock()) - ) - def test_optimize_mixed(self, mock_optimize_acqf_mixed: Mock) -> None: + @fast_botorch_optimize + def test_optimize_mixed(self) -> None: tkwargs = {"dtype": self.X.dtype, "device": self.X.device} ssd = SearchSpaceDigest( feature_names=["a", "b"], @@ -556,14 +642,17 @@ def test_optimize_mixed(self, mock_optimize_acqf_mixed: Mock) -> None: discrete_choices={1: [0, 1, 2]}, ) acquisition = self.get_acquisition_function() - acquisition.optimize( - n=3, - search_space_digest=ssd, - inequality_constraints=self.inequality_constraints, - fixed_features=None, - rounding_func=self.rounding_func, - 
optimizer_options=self.optimizer_options, - ) + with mock.patch( + f"{ACQUISITION_PATH}.optimize_acqf_mixed", wraps=optimize_acqf_mixed + ) as mock_optimize_acqf_mixed: + acquisition.optimize( + n=3, + search_space_digest=ssd, + inequality_constraints=self.inequality_constraints, + fixed_features=None, + rounding_func=self.rounding_func, + optimizer_options=self.optimizer_options, + ) mock_optimize_acqf_mixed.assert_called_with( acq_function=acquisition.acqf, bounds=mock.ANY, @@ -582,12 +671,13 @@ def test_optimize_mixed(self, mock_optimize_acqf_mixed: Mock) -> None: ) ) # Check that we don't use mixed optimizer if force_use_optimize_acqf is True. - mock_optimize_acqf_mixed.reset_mock() optimizer_options = self.optimizer_options.copy() optimizer_options["force_use_optimize_acqf"] = True with mock.patch( - f"{ACQUISITION_PATH}.optimize_acqf", return_value=(Mock(), Mock()) - ) as mock_optimize_acqf: + f"{ACQUISITION_PATH}.optimize_acqf", wraps=optimize_acqf + ) as mock_optimize_acqf, mock.patch( + f"{ACQUISITION_PATH}.optimize_acqf_mixed", wraps=optimize_acqf_mixed + ) as mock_optimize_acqf_mixed: acquisition.optimize( n=3, search_space_digest=ssd, diff --git a/ax/models/torch/tests/test_optimizer_argparse.py b/ax/models/torch/tests/test_optimizer_argparse.py index 2602d6e01d1..3161cedcf37 100644 --- a/ax/models/torch/tests/test_optimizer_argparse.py +++ b/ax/models/torch/tests/test_optimizer_argparse.py @@ -9,18 +9,23 @@ from __future__ import annotations from importlib import reload +from itertools import product from unittest.mock import patch from ax.models.torch.botorch_modular import optimizer_argparse as Argparse from ax.models.torch.botorch_modular.optimizer_argparse import ( _argparse_base, + BATCH_LIMIT, INIT_BATCH_LIMIT, MaybeType, + NUM_RESTARTS, optimizer_argparse, + RAW_SAMPLES, ) from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase from botorch.acquisition.acquisition import AcquisitionFunction +from botorch.acquisition.analytic import LogExpectedImprovement from botorch.acquisition.knowledge_gradient import ( qKnowledgeGradient, qMultiFidelityKnowledgeGradient, @@ -32,17 +37,50 @@ class DummyAcquisitionFunction(AcquisitionFunction): class OptimizerArgparseTest(TestCase): + def setUp(self) -> None: + super().setUp() + self.default_expected_options = { + "optimize_acqf": { + "num_restarts": NUM_RESTARTS, + "raw_samples": RAW_SAMPLES, + "options": { + "init_batch_limit": INIT_BATCH_LIMIT, + "batch_limit": BATCH_LIMIT, + }, + "sequential": True, + }, + "optimize_acqf_discrete_local_search": { + "num_restarts": NUM_RESTARTS, + "raw_samples": RAW_SAMPLES, + }, + "optimize_acqf_discrete": {}, + "optimize_acqf_mixed": { + "num_restarts": NUM_RESTARTS, + "raw_samples": RAW_SAMPLES, + "options": { + "init_batch_limit": INIT_BATCH_LIMIT, + "batch_limit": BATCH_LIMIT, + }, + }, + } + def test_notImplemented(self) -> None: - with self.assertRaises(NotImplementedError) as e: + with self.assertRaisesRegex( + NotImplementedError, "Could not find signature for" + ): optimizer_argparse[type(None)] # passing `None` produces a different error - self.assertTrue("Could not find signature for" in str(e)) + + def test_unsupported_optimizer(self) -> None: + with self.assertRaisesRegex( + ValueError, "optimizer=`wishful thinking` is not supported" + ): + optimizer_argparse(LogExpectedImprovement, optimizer="wishful thinking") def test_register(self) -> None: with patch.dict(optimizer_argparse.funcs, {}): @optimizer_argparse.register(DummyAcquisitionFunction) - # 
pyre-fixme[3]: Return type must be annotated. - def _argparse(acqf: MaybeType[DummyAcquisitionFunction]): + def _argparse(acqf: MaybeType[DummyAcquisitionFunction]) -> None: pass self.assertEqual(optimizer_argparse[DummyAcquisitionFunction], _argparse) @@ -51,33 +89,67 @@ def test_fallback(self) -> None: with patch.dict(optimizer_argparse.funcs, {}): @optimizer_argparse.register(AcquisitionFunction) - # pyre-fixme[3]: Return type must be annotated. - def _argparse(acqf: MaybeType[DummyAcquisitionFunction]): + def _argparse(acqf: MaybeType[DummyAcquisitionFunction]) -> None: pass self.assertEqual(optimizer_argparse[DummyAcquisitionFunction], _argparse) def test_optimizer_options(self) -> None: - # This has a bespoke test - skipped_func = optimizer_argparse[qKnowledgeGradient] + # qKG should have a bespoke test + # currently there is only one function in fns_to_test + fns_to_test = [ + elt + for elt in optimizer_argparse.funcs.values() + if elt is not optimizer_argparse[qKnowledgeGradient] + ] user_options = {"foo": "bar", "num_restarts": 13} - for func in optimizer_argparse.funcs.values(): - if func is skipped_func: - continue - - parsed_options = func(None, optimizer_options=user_options) - for key, val in user_options.items(): - self.assertEqual(val, parsed_options.get(key)) + for func, optimizer in product( + fns_to_test, + [ + "optimize_acqf", + "optimize_acqf_discrete", + "optimize_acqf_mixed", + "optimize_acqf_discrete_local_search", + ], + ): + with self.subTest(func=func, optimizer=optimizer): + parsed_options = func( + None, optimizer_options=user_options, optimizer=optimizer + ) + self.assertDictEqual( + {**self.default_expected_options[optimizer], **user_options}, + parsed_options, + ) # Also test sub-options. - func = _argparse_base - parsed_options = func( - None, optimizer_options={"options": {"batch_limit": 10, "maxiter": 20}} - ) - self.assertEqual( - parsed_options["options"], - {"batch_limit": 10, "init_batch_limit": INIT_BATCH_LIMIT, "maxiter": 20}, - ) + inner_options = {"batch_limit": 10, "maxiter": 20} + options = {"options": inner_options} + for func in fns_to_test: + for optimizer in [ + "optimize_acqf", + "optimize_acqf_mixed", + "optimize_acqf_discrete", + ]: + default = self.default_expected_options[optimizer] + parsed_options = func( + None, optimizer_options=options, optimizer=optimizer + ) + expected_options = {k: v for k, v in default.items() if k != "options"} + if "options" in default: + expected_options["options"] = { + **default["options"], + **inner_options, + } + else: + expected_options["options"] = inner_options + self.assertDictEqual(expected_options, parsed_options) + + parsed_options = func( + None, + optimizer_options={"options": {"batch_limit": 10, "maxiter": 20}}, + optimizer="optimize_acqf_discrete_local_search", + ) + self.assertNotIn("options", parsed_options) def test_kg(self) -> None: with patch(