Skip to content

Commit

Permalink
support specifying warm start refit in per-metric model selection (#3317)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: #3317

See title. This enables disabling warm-starting on refit, which can be desirable for evaluating over-parameterized models via cross-validation (CV).

Reviewed By: saitcakmak

Differential Revision: D68114572

fbshipit-source-id: 9ae5ce038af5f4fbf83d702af984abf1118ff7eb
  • Loading branch information
sdaulton authored and facebook-github-bot committed Feb 7, 2025
1 parent 29ab9e0 commit 4d03467
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 6 deletions.
7 changes: 5 additions & 2 deletions ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ class BoTorchGenerator(TorchGenerator, Base):
state dict or the state dict of the current BoTorch ``Model`` during
refitting. If False, model parameters will be reoptimized from
scratch on refit. NOTE: This setting is ignored during
``cross_validate`` if ``refit_on_cv`` is False.
``cross_validate`` if ``refit_on_cv`` is False. This is also used in
Surrogate.model_selection.
"""

acquisition_class: type[Acquisition]
Expand Down Expand Up @@ -192,7 +193,9 @@ def fit(
else self.surrogate_spec
)
self._surrogate = Surrogate(
surrogate_spec=surrogate_spec, refit_on_cv=self.refit_on_cv
surrogate_spec=surrogate_spec,
refit_on_cv=self.refit_on_cv,
warm_start_refit=self.warm_start_refit,
)

# Fit the surrogate.
Expand Down
22 changes: 20 additions & 2 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,10 @@ class string names and the values are dictionaries of input transform
allow_batched_models: Set to true to fit the models in a batch if supported.
Set to false to fit individual models to each metric in a loop.
refit_on_cv: Whether to refit the model on the cross-validation folds.
warm_start_refit: Whether to warm-start refitting from the current state_dict
    during cross-validation. If refit_on_cv is True, one would generally
    set this to False, so that no information is leaked across folds.
metric_to_best_model_config: Dictionary mapping a metric name to the best
model config. This is only used by BotorchGenerator.cross_validate and for
logging what model was used.
Expand All @@ -698,6 +702,7 @@ def __init__(
likelihood_options: dict[str, Any] | None = None,
allow_batched_models: bool = True,
refit_on_cv: bool = False,
warm_start_refit: bool = True,
metric_to_best_model_config: dict[str, ModelConfig] | None = None,
) -> None:
warnings_raised = _raise_deprecation_warning(
Expand Down Expand Up @@ -764,6 +769,7 @@ def __init__(
self._outcomes: list[str] | None = None
self._model: Model | None = None
self.refit_on_cv = refit_on_cv
self.warm_start_refit = warm_start_refit

def __repr__(self) -> str:
return f"<{self.__class__.__name__}" f" surrogate_spec={self.surrogate_spec}>"
Expand Down Expand Up @@ -855,7 +861,7 @@ def _construct_model(
)
# pyre-ignore [45]
model = botorch_model_class(**formatted_model_inputs)
if state_dict is not None:
if state_dict is not None and (not refit or self.warm_start_refit):
model.load_state_dict(state_dict)
if state_dict is None or refit:
fit_botorch_model(
Expand Down Expand Up @@ -958,7 +964,18 @@ def fit(
model_config = self.metric_to_best_model_config.get(
dataset.outcome_names[0]
)
if len(model_configs) == 1 or (not refit and model_config is not None):
# Model selection is not performed if the best `ModelConfig` has already
# been identified (as specified in `metric_to_best_model_config`).
# The reason for doing this is to support the following flow:
# - Fit model to data and perform model selection, refitting on each fold
# if `refit_on_cv=True`. This will set the best ModelConfig in
# metric_to_best_model_config.
# - Evaluate the choice of model/visualize its performance via
# ``Modelbridge.cross_validate``. This also will refit on each fold if
# `refit_on_cv=True`, but we wouldn't want to perform model selection
# on each fold, but rather show the performance of the selected
# ``ModelConfig`` since that is what will be used.
if len(model_configs) == 1 or (model_config is not None):
best_model_config = model_config or model_configs[0]
model = self._construct_model(
dataset=dataset,
Expand Down Expand Up @@ -1329,6 +1346,7 @@ def _serialize_attributes_as_kwargs(self) -> dict[str, Any]:
return {
"surrogate_spec": self.surrogate_spec,
"refit_on_cv": self.refit_on_cv,
"warm_start_refit": self.warm_start_refit,
"metric_to_best_model_config": self.metric_to_best_model_config,
}

Expand Down
8 changes: 6 additions & 2 deletions ax/models/torch/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,9 @@ def test_surrogate_model_options_propagation(self) -> None:
search_space_digest=self.mf_search_space_digest,
candidate_metadata=self.candidate_metadata,
)
mock_init.assert_called_with(surrogate_spec=surrogate_spec, refit_on_cv=False)
mock_init.assert_called_with(
surrogate_spec=surrogate_spec, refit_on_cv=False, warm_start_refit=True
)

@mock_botorch_optimize
def test_surrogate_options_propagation(self) -> None:
Expand All @@ -797,7 +799,9 @@ def test_surrogate_options_propagation(self) -> None:
search_space_digest=self.mf_search_space_digest,
candidate_metadata=self.candidate_metadata,
)
mock_init.assert_called_with(surrogate_spec=surrogate_spec, refit_on_cv=False)
mock_init.assert_called_with(
surrogate_spec=surrogate_spec, refit_on_cv=False, warm_start_refit=True
)

@mock_botorch_optimize
def test_model_list_choice(self) -> None:
Expand Down
22 changes: 22 additions & 0 deletions ax/models/torch/tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,27 @@ def test_construct_model(self) -> None:
)
mock_fit.assert_not_called()

def test_construct_model_warm_start(self) -> None:
    """Verify that ``warm_start_refit`` controls whether ``_construct_model``
    loads a provided state dict when refitting (``refit=True``).
    """
    # Exercise both settings of the flag.
    for warm_start_refit in (False, True):
        surrogate = Surrogate(warm_start_refit=warm_start_refit)
        # Stub out model construction, fitting, and state-dict loading so
        # that only the load_state_dict call pattern is observed; no real
        # GP is built or optimized.
        with patch.object(
            SingleTaskGP, "__init__", return_value=None, autospec=True
        ), patch(f"{SURROGATE_PATH}.fit_botorch_model"), patch.object(
            SingleTaskGP, "load_state_dict"
        ) as mock_load_state_dict:
            surrogate._construct_model(
                dataset=self.training_data[0],
                search_space_digest=self.search_space_digest,
                model_config=surrogate.surrogate_spec.model_configs[0],
                default_botorch_model_class=SingleTaskGP,
                state_dict={},  # pyre-ignore [6]
                refit=True,
            )
            # With refit=True and a state dict present, the state dict is
            # loaded only when warm-starting is enabled; otherwise the
            # model is refit from scratch and the state dict is ignored.
            if warm_start_refit:
                mock_load_state_dict.assert_called_once()
            else:
                mock_load_state_dict.assert_not_called()

@mock_botorch_optimize
def test_construct_custom_model(self, use_model_config: bool = False) -> None:
# Test error for unsupported covar_module and likelihood.
Expand Down Expand Up @@ -1322,6 +1343,7 @@ def test_serialize_attributes_as_kwargs(self) -> None:
expected = {
"surrogate_spec": surrogate.surrogate_spec,
"refit_on_cv": surrogate.refit_on_cv,
"warm_start_refit": surrogate.warm_start_refit,
"metric_to_best_model_config": surrogate.metric_to_best_model_config,
}
self.assertEqual(surrogate._serialize_attributes_as_kwargs(), expected)
Expand Down

0 comments on commit 4d03467

Please sign in to comment.