Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix MBM model selection CV bug with uncertainty-based diagnostics #3318

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ class BoTorchGenerator(TorchGenerator, Base):
state dict or the state dict of the current BoTorch ``Model`` during
refitting. If False, model parameters will be reoptimized from
scratch on refit. NOTE: This setting is ignored during
``cross_validate`` if ``refit_on_cv`` is False.
``cross_validate`` if ``refit_on_cv`` is False. This setting is also
used by ``Surrogate.model_selection``.
"""

acquisition_class: type[Acquisition]
Expand Down Expand Up @@ -192,7 +193,9 @@ def fit(
else self.surrogate_spec
)
self._surrogate = Surrogate(
surrogate_spec=surrogate_spec, refit_on_cv=self.refit_on_cv
surrogate_spec=surrogate_spec,
refit_on_cv=self.refit_on_cv,
warm_start_refit=self.warm_start_refit,
)

# Fit the surrogate.
Expand Down
24 changes: 21 additions & 3 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,10 @@ class string names and the values are dictionaries of input transform
allow_batched_models: Set to true to fit the models in a batch if supported.
Set to false to fit individual models to each metric in a loop.
refit_on_cv: Whether to refit the model on the cross-validation folds.
warm_start_refit: Whether to warm-start refitting from the current state_dict
during cross-validation. If ``refit_on_cv`` is True, one would generally
set this to False, so that no information is leaked across folds.
metric_to_best_model_config: Dictionary mapping a metric name to the best
model config. This is only used by BotorchGenerator.cross_validate and for
logging what model was used.
Expand All @@ -698,6 +702,7 @@ def __init__(
likelihood_options: dict[str, Any] | None = None,
allow_batched_models: bool = True,
refit_on_cv: bool = False,
warm_start_refit: bool = True,
metric_to_best_model_config: dict[str, ModelConfig] | None = None,
) -> None:
warnings_raised = _raise_deprecation_warning(
Expand Down Expand Up @@ -764,6 +769,7 @@ def __init__(
self._outcomes: list[str] | None = None
self._model: Model | None = None
self.refit_on_cv = refit_on_cv
self.warm_start_refit = warm_start_refit

def __repr__(self) -> str:
return f"<{self.__class__.__name__}" f" surrogate_spec={self.surrogate_spec}>"
Expand Down Expand Up @@ -855,7 +861,7 @@ def _construct_model(
)
# pyre-ignore [45]
model = botorch_model_class(**formatted_model_inputs)
if state_dict is not None:
if state_dict is not None and (not refit or self.warm_start_refit):
model.load_state_dict(state_dict)
if state_dict is None or refit:
fit_botorch_model(
Expand Down Expand Up @@ -958,7 +964,18 @@ def fit(
model_config = self.metric_to_best_model_config.get(
dataset.outcome_names[0]
)
if len(model_configs) == 1 or (not refit and model_config is not None):
# Model selection is not performed if the best `ModelConfig` has already
# been identified (as specified in `metric_to_best_model_config`).
# The reason for doing this is to support the following flow:
# - Fit the model to the data and perform model selection, refitting on
#   each fold if `refit_on_cv=True`. This sets the best `ModelConfig` in
#   `metric_to_best_model_config`.
# - Evaluate the choice of model/visualize its performance via
#   `Modelbridge.cross_validate`. This will also refit on each fold if
#   `refit_on_cv=True`, but rather than performing model selection on
#   each fold, we show the performance of the selected `ModelConfig`,
#   since that is what will actually be used.
if len(model_configs) == 1 or (model_config is not None):
best_model_config = model_config or model_configs[0]
model = self._construct_model(
dataset=dataset,
Expand Down Expand Up @@ -1172,7 +1189,7 @@ def cross_validate(
return diag_fn(
y_obs=Y.view(-1).cpu().numpy(),
y_pred=pred_Y,
se_pred=pred_Yvar,
se_pred=np.sqrt(pred_Yvar),
)

def _discard_cached_model_and_data_if_search_space_digest_changed(
Expand Down Expand Up @@ -1329,6 +1346,7 @@ def _serialize_attributes_as_kwargs(self) -> dict[str, Any]:
return {
"surrogate_spec": self.surrogate_spec,
"refit_on_cv": self.refit_on_cv,
"warm_start_refit": self.warm_start_refit,
"metric_to_best_model_config": self.metric_to_best_model_config,
}

Expand Down
8 changes: 6 additions & 2 deletions ax/models/torch/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,9 @@ def test_surrogate_model_options_propagation(self) -> None:
search_space_digest=self.mf_search_space_digest,
candidate_metadata=self.candidate_metadata,
)
mock_init.assert_called_with(surrogate_spec=surrogate_spec, refit_on_cv=False)
mock_init.assert_called_with(
surrogate_spec=surrogate_spec, refit_on_cv=False, warm_start_refit=True
)

@mock_botorch_optimize
def test_surrogate_options_propagation(self) -> None:
Expand All @@ -797,7 +799,9 @@ def test_surrogate_options_propagation(self) -> None:
search_space_digest=self.mf_search_space_digest,
candidate_metadata=self.candidate_metadata,
)
mock_init.assert_called_with(surrogate_spec=surrogate_spec, refit_on_cv=False)
mock_init.assert_called_with(
surrogate_spec=surrogate_spec, refit_on_cv=False, warm_start_refit=True
)

@mock_botorch_optimize
def test_model_list_choice(self) -> None:
Expand Down
22 changes: 22 additions & 0 deletions ax/models/torch/tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,27 @@ def test_construct_model(self) -> None:
)
mock_fit.assert_not_called()

def test_construct_model_warm_start(self) -> None:
    """Verify that `warm_start_refit` gates loading the prior state_dict.

    When refitting (`refit=True`) with a state_dict available,
    `_construct_model` should load that state_dict into the new model
    only if the surrogate was built with `warm_start_refit=True`.
    """
    for should_warm_start in (False, True):
        surrogate = Surrogate(warm_start_refit=should_warm_start)
        # Stub out model construction and fitting so only the
        # state_dict-loading decision is exercised.
        with patch.object(
            SingleTaskGP, "__init__", return_value=None, autospec=True
        ):
            with patch(f"{SURROGATE_PATH}.fit_botorch_model"):
                with patch.object(
                    SingleTaskGP, "load_state_dict"
                ) as load_state_dict_mock:
                    surrogate._construct_model(
                        dataset=self.training_data[0],
                        search_space_digest=self.search_space_digest,
                        model_config=surrogate.surrogate_spec.model_configs[0],
                        default_botorch_model_class=SingleTaskGP,
                        state_dict={},  # pyre-ignore [6]
                        refit=True,
                    )
                    # The mock records calls made while the patch was live.
                    if should_warm_start:
                        load_state_dict_mock.assert_called_once()
                    else:
                        load_state_dict_mock.assert_not_called()

@mock_botorch_optimize
def test_construct_custom_model(self, use_model_config: bool = False) -> None:
# Test error for unsupported covar_module and likelihood.
Expand Down Expand Up @@ -1322,6 +1343,7 @@ def test_serialize_attributes_as_kwargs(self) -> None:
expected = {
"surrogate_spec": surrogate.surrogate_spec,
"refit_on_cv": surrogate.refit_on_cv,
"warm_start_refit": surrogate.warm_start_refit,
"metric_to_best_model_config": surrogate.metric_to_best_model_config,
}
self.assertEqual(surrogate._serialize_attributes_as_kwargs(), expected)
Expand Down