From 4d0346784101a2e9bc2fabd966f79f5ec0cda1a9 Mon Sep 17 00:00:00 2001
From: Sam Daulton
Date: Fri, 7 Feb 2025 12:09:09 -0800
Subject: [PATCH] support specifying warm start refit in per-metric model
 selection (#3317)

Summary:
Pull Request resolved: https://github.com/facebook/Ax/pull/3317

see title. This enables not warm-starting, which can be desirable for
evaluating over-parameterized models via CV.

Reviewed By: saitcakmak

Differential Revision: D68114572

fbshipit-source-id: 9ae5ce038af5f4fbf83d702af984abf1118ff7eb
---
 ax/models/torch/botorch_modular/model.py     |  7 +++++--
 ax/models/torch/botorch_modular/surrogate.py | 22 ++++++++++++++++++--
 ax/models/torch/tests/test_model.py          |  8 +++++--
 ax/models/torch/tests/test_surrogate.py      | 22 ++++++++++++++++++++
 4 files changed, 53 insertions(+), 6 deletions(-)

diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py
index 0166ac58888..d7a8eb338b5 100644
--- a/ax/models/torch/botorch_modular/model.py
+++ b/ax/models/torch/botorch_modular/model.py
@@ -71,7 +71,8 @@ class BoTorchGenerator(TorchGenerator, Base):
             state dict or the state dict of the current BoTorch ``Model`` during
             refitting. If False, model parameters will be reoptimized from
             scratch on refit. NOTE: This setting is ignored during
-            ``cross_validate`` if ``refit_on_cv`` is False.
+            ``cross_validate`` if ``refit_on_cv`` is False. This is also used in
+            Surrogate.model_selection.
     """
 
     acquisition_class: type[Acquisition]
@@ -192,7 +193,9 @@ def fit(
             else self.surrogate_spec
         )
         self._surrogate = Surrogate(
-            surrogate_spec=surrogate_spec, refit_on_cv=self.refit_on_cv
+            surrogate_spec=surrogate_spec,
+            refit_on_cv=self.refit_on_cv,
+            warm_start_refit=self.warm_start_refit,
         )
 
         # Fit the surrogate.
diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py
index 30f4ef30200..ef342bb9d76 100644
--- a/ax/models/torch/botorch_modular/surrogate.py
+++ b/ax/models/torch/botorch_modular/surrogate.py
@@ -673,6 +673,10 @@ class string names and the values are dictionaries of input transform
         allow_batched_models: Set to true to fit the models in a batch if supported.
             Set to false to fit individual models to each metric in a loop.
         refit_on_cv: Whether to refit the model on the cross-validation folds.
+        warm_start_refit: Whether to warm-start refitting from the current
+            state_dict during cross-validation. If refit_on_cv is True, one would
+            generally set this to False, so that no information is leaked across
+            folds.
         metric_to_best_model_config: Dictionary mapping a metric name to the best
             model config. This is only used by BotorchGenerator.cross_validate and
             for logging what model was used.
@@ -698,6 +702,7 @@ def __init__(
         likelihood_options: dict[str, Any] | None = None,
         allow_batched_models: bool = True,
         refit_on_cv: bool = False,
+        warm_start_refit: bool = True,
         metric_to_best_model_config: dict[str, ModelConfig] | None = None,
     ) -> None:
         warnings_raised = _raise_deprecation_warning(
@@ -764,6 +769,7 @@ def __init__(
         self._outcomes: list[str] | None = None
         self._model: Model | None = None
         self.refit_on_cv = refit_on_cv
+        self.warm_start_refit = warm_start_refit
 
     def __repr__(self) -> str:
         return f"<{self.__class__.__name__}" f" surrogate_spec={self.surrogate_spec}>"
@@ -855,7 +861,7 @@ def _construct_model(
             )
             # pyre-ignore [45]
             model = botorch_model_class(**formatted_model_inputs)
-        if state_dict is not None:
+        if state_dict is not None and (not refit or self.warm_start_refit):
             model.load_state_dict(state_dict)
         if state_dict is None or refit:
             fit_botorch_model(
@@ -958,7 +964,18 @@ def fit(
                 model_config = self.metric_to_best_model_config.get(
                     dataset.outcome_names[0]
                 )
-            if len(model_configs) == 1 or (not refit and model_config is not None):
+            # Model selection is not performed if the best `ModelConfig` has already
+            # been identified (as specified in `metric_to_best_model_config`).
+            # The reason for doing this is to support the following flow:
+            # - Fit model to data and perform model selection, refitting on each fold
+            #   if `refit_on_cv=True`. This will set the best `ModelConfig` in
+            #   `metric_to_best_model_config`.
+            # - Evaluate the choice of model/visualize its performance via
+            #   `ModelBridge.cross_validate`. This will also refit on each fold if
+            #   `refit_on_cv=True`, but we would not want to perform model selection
+            #   on each fold; rather, we want to show the performance of the selected
+            #   `ModelConfig`, since that is what will be used.
+            if len(model_configs) == 1 or (model_config is not None):
                 best_model_config = model_config or model_configs[0]
                 model = self._construct_model(
                     dataset=dataset,
@@ -1329,6 +1346,7 @@ def _serialize_attributes_as_kwargs(self) -> dict[str, Any]:
         return {
             "surrogate_spec": self.surrogate_spec,
             "refit_on_cv": self.refit_on_cv,
+            "warm_start_refit": self.warm_start_refit,
             "metric_to_best_model_config": self.metric_to_best_model_config,
         }
 
diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py
index f65cb96ba84..1fb1b576e23 100644
--- a/ax/models/torch/tests/test_model.py
+++ b/ax/models/torch/tests/test_model.py
@@ -785,7 +785,9 @@ def test_surrogate_model_options_propagation(self) -> None:
             search_space_digest=self.mf_search_space_digest,
             candidate_metadata=self.candidate_metadata,
         )
-        mock_init.assert_called_with(surrogate_spec=surrogate_spec, refit_on_cv=False)
+        mock_init.assert_called_with(
+            surrogate_spec=surrogate_spec, refit_on_cv=False, warm_start_refit=True
+        )
 
     @mock_botorch_optimize
     def test_surrogate_options_propagation(self) -> None:
@@ -797,7 +799,9 @@ def test_surrogate_options_propagation(self) -> None:
             search_space_digest=self.mf_search_space_digest,
             candidate_metadata=self.candidate_metadata,
         )
-        mock_init.assert_called_with(surrogate_spec=surrogate_spec, refit_on_cv=False)
+        mock_init.assert_called_with(
+            surrogate_spec=surrogate_spec, refit_on_cv=False, warm_start_refit=True
+        )
 
     @mock_botorch_optimize
     def test_model_list_choice(self) -> None:
diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py
index d683c37c11e..85ac82b6855 100644
--- a/ax/models/torch/tests/test_surrogate.py
+++ b/ax/models/torch/tests/test_surrogate.py
@@ -686,6 +686,27 @@ def test_construct_model(self) -> None:
         )
         mock_fit.assert_not_called()
 
+    def test_construct_model_warm_start(self) -> None:
+        for warm_start_refit in (False, True):
+            surrogate = Surrogate(warm_start_refit=warm_start_refit)
+            with patch.object(
+                SingleTaskGP, "__init__", return_value=None, autospec=True
+            ), patch(f"{SURROGATE_PATH}.fit_botorch_model"), patch.object(
+                SingleTaskGP, "load_state_dict"
+            ) as mock_load_state_dict:
+                surrogate._construct_model(
+                    dataset=self.training_data[0],
+                    search_space_digest=self.search_space_digest,
+                    model_config=surrogate.surrogate_spec.model_configs[0],
+                    default_botorch_model_class=SingleTaskGP,
+                    state_dict={},  # pyre-ignore [6]
+                    refit=True,
+                )
+            if warm_start_refit:
+                mock_load_state_dict.assert_called_once()
+            else:
+                mock_load_state_dict.assert_not_called()
+
     @mock_botorch_optimize
     def test_construct_custom_model(self, use_model_config: bool = False) -> None:
         # Test error for unsupported covar_module and likelihood.
@@ -1322,6 +1343,7 @@ def test_serialize_attributes_as_kwargs(self) -> None:
         expected = {
             "surrogate_spec": surrogate.surrogate_spec,
             "refit_on_cv": surrogate.refit_on_cv,
+            "warm_start_refit": surrogate.warm_start_refit,
             "metric_to_best_model_config": surrogate.metric_to_best_model_config,
         }
         self.assertEqual(surrogate._serialize_attributes_as_kwargs(), expected)
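
A minimal usage sketch of the new warm_start_refit flag follows. It is an illustration based on the constructor signatures exercised in this patch, not an official Ax example; the import paths follow the modules touched in the diff above.

    # Sketch: refit on every cross-validation fold, but start each per-fold
    # refit from scratch (warm_start_refit=False) so that the state_dict fit
    # on the full data does not leak information into the folds used for
    # per-metric model selection.
    from ax.models.torch.botorch_modular.model import BoTorchGenerator
    from ax.models.torch.botorch_modular.surrogate import Surrogate

    surrogate = Surrogate(
        refit_on_cv=True,
        warm_start_refit=False,
    )

    # The same pair of flags on BoTorchGenerator is forwarded to the Surrogate
    # it constructs in fit() (see the model.py hunk above).
    generator = BoTorchGenerator(
        refit_on_cv=True,
        warm_start_refit=False,
    )

With the defaults (refit_on_cv=False, warm_start_refit=True), behavior is unchanged from before this patch: the state dict is loaded whenever one is available.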