Skip to content

Commit 3f828cd

Browse files
sdaulton and facebook-github-bot
authored and committed
support specifying warm start refit in per-metric model selection (facebook#3317)
Summary: see title. This enables not warm-starting, which can be desirable for evaluating over-parameterized models via CV. Reviewed By: saitcakmak Differential Revision: D68114572
1 parent 01c6a73 commit 3f828cd

File tree

4 files changed

+52
-6
lines changed

4 files changed

+52
-6
lines changed

Diff for: ax/models/torch/botorch_modular/model.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class BoTorchGenerator(TorchGenerator, Base):
7171
state dict or the state dict of the current BoTorch ``Model`` during
7272
refitting. If False, model parameters will be reoptimized from
7373
scratch on refit. NOTE: This setting is ignored during
74-
``cross_validate`` if ``refit_on_cv`` is False.
74+
``cross_validate`` if ``refit_on_cv`` is False. This is also used in Surrogate.model_selection.
7575
"""
7676

7777
acquisition_class: type[Acquisition]
@@ -192,7 +192,9 @@ def fit(
192192
else self.surrogate_spec
193193
)
194194
self._surrogate = Surrogate(
195-
surrogate_spec=surrogate_spec, refit_on_cv=self.refit_on_cv
195+
surrogate_spec=surrogate_spec,
196+
refit_on_cv=self.refit_on_cv,
197+
warm_start_refit=self.warm_start_refit,
196198
)
197199

198200
# Fit the surrogate.

Diff for: ax/models/torch/botorch_modular/surrogate.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,10 @@ class string names and the values are dictionaries of input transform
673673
allow_batched_models: Set to true to fit the models in a batch if supported.
674674
Set to false to fit individual models to each metric in a loop.
675675
refit_on_cv: Whether to refit the model on the cross-validation folds.
676+
warm_start_refit: Whether to warm-start refitting from the current state_dict
677+
during cross-validation. If refit_on_cv is True, generally one
678+
would set this to be False, so that no information is leaked between or
679+
across folds.
676680
metric_to_best_model_config: Dictionary mapping a metric name to the best
677681
model config. This is only used by BotorchGenerator.cross_validate and for
678682
logging what model was used.
@@ -698,6 +702,7 @@ def __init__(
698702
likelihood_options: dict[str, Any] | None = None,
699703
allow_batched_models: bool = True,
700704
refit_on_cv: bool = False,
705+
warm_start_refit: bool = True,
701706
metric_to_best_model_config: dict[str, ModelConfig] | None = None,
702707
) -> None:
703708
warnings_raised = _raise_deprecation_warning(
@@ -764,6 +769,7 @@ def __init__(
764769
self._outcomes: list[str] | None = None
765770
self._model: Model | None = None
766771
self.refit_on_cv = refit_on_cv
772+
self.warm_start_refit = warm_start_refit
767773

768774
def __repr__(self) -> str:
769775
return f"<{self.__class__.__name__}" f" surrogate_spec={self.surrogate_spec}>"
@@ -855,7 +861,7 @@ def _construct_model(
855861
)
856862
# pyre-ignore [45]
857863
model = botorch_model_class(**formatted_model_inputs)
858-
if state_dict is not None:
864+
if state_dict is not None and (not refit or self.warm_start_refit):
859865
model.load_state_dict(state_dict)
860866
if state_dict is None or refit:
861867
fit_botorch_model(
@@ -958,7 +964,18 @@ def fit(
958964
model_config = self.metric_to_best_model_config.get(
959965
dataset.outcome_names[0]
960966
)
961-
if len(model_configs) == 1 or (not refit and model_config is not None):
967+
# Model selection is not performed if the best `ModelConfig` has already
968+
# been identified (as specified in `metric_to_best_model_config`).
969+
# The reason for doing this is to support the following flow:
970+
# - Fit model to data and perform model selection, refitting on each fold
971+
# if `refit_on_cv=True`. This will set the best ModelConfig in
972+
# metric_to_best_model_config.
973+
# - Evaluate the choice of model/visualize its performance via
974+
`Modelbridge.cross_validate`. This will also refit on each fold if
975+
# `refit_on_cv=True`, but we wouldn't want to perform model selection
976+
on each fold, but rather show the performance of the selected
977+
`ModelConfig` since that is what will be used.
978+
if len(model_configs) == 1 or (model_config is not None):
962979
best_model_config = model_config or model_configs[0]
963980
model = self._construct_model(
964981
dataset=dataset,
@@ -1329,6 +1346,7 @@ def _serialize_attributes_as_kwargs(self) -> dict[str, Any]:
13291346
return {
13301347
"surrogate_spec": self.surrogate_spec,
13311348
"refit_on_cv": self.refit_on_cv,
1349+
"warm_start_refit": self.warm_start_refit,
13321350
"metric_to_best_model_config": self.metric_to_best_model_config,
13331351
}
13341352

Diff for: ax/models/torch/tests/test_model.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,9 @@ def test_surrogate_model_options_propagation(self) -> None:
785785
search_space_digest=self.mf_search_space_digest,
786786
candidate_metadata=self.candidate_metadata,
787787
)
788-
mock_init.assert_called_with(surrogate_spec=surrogate_spec, refit_on_cv=False)
788+
mock_init.assert_called_with(
789+
surrogate_spec=surrogate_spec, refit_on_cv=False, warm_start_refit=True
790+
)
789791

790792
@mock_botorch_optimize
791793
def test_surrogate_options_propagation(self) -> None:
@@ -797,7 +799,9 @@ def test_surrogate_options_propagation(self) -> None:
797799
search_space_digest=self.mf_search_space_digest,
798800
candidate_metadata=self.candidate_metadata,
799801
)
800-
mock_init.assert_called_with(surrogate_spec=surrogate_spec, refit_on_cv=False)
802+
mock_init.assert_called_with(
803+
surrogate_spec=surrogate_spec, refit_on_cv=False, warm_start_refit=True
804+
)
801805

802806
@mock_botorch_optimize
803807
def test_model_list_choice(self) -> None:

Diff for: ax/models/torch/tests/test_surrogate.py

+22
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,27 @@ def test_construct_model(self) -> None:
686686
)
687687
mock_fit.assert_not_called()
688688

689+
def test_construct_model_warm_start(self) -> None:
690+
for warm_start_refit in (False, True):
691+
surrogate = Surrogate(warm_start_refit=warm_start_refit)
692+
with patch.object(
693+
SingleTaskGP, "__init__", return_value=None, autospec=True
694+
), patch(f"{SURROGATE_PATH}.fit_botorch_model"), patch.object(
695+
SingleTaskGP, "load_state_dict"
696+
) as mock_load_state_dict:
697+
surrogate._construct_model(
698+
dataset=self.training_data[0],
699+
search_space_digest=self.search_space_digest,
700+
model_config=surrogate.surrogate_spec.model_configs[0],
701+
default_botorch_model_class=SingleTaskGP,
702+
state_dict={}, # pyre-ignore [6]
703+
refit=True,
704+
)
705+
if warm_start_refit:
706+
mock_load_state_dict.assert_called_once()
707+
else:
708+
mock_load_state_dict.assert_not_called()
709+
689710
@mock_botorch_optimize
690711
def test_construct_custom_model(self, use_model_config: bool = False) -> None:
691712
# Test error for unsupported covar_module and likelihood.
@@ -1322,6 +1343,7 @@ def test_serialize_attributes_as_kwargs(self) -> None:
13221343
expected = {
13231344
"surrogate_spec": surrogate.surrogate_spec,
13241345
"refit_on_cv": surrogate.refit_on_cv,
1346+
"warm_start_refit": surrogate.warm_start_refit,
13251347
"metric_to_best_model_config": surrogate.metric_to_best_model_config,
13261348
}
13271349
self.assertEqual(surrogate._serialize_attributes_as_kwargs(), expected)

0 commit comments

Comments
 (0)