@@ -673,6 +673,10 @@ class string names and the values are dictionaries of input transform
673
673
allow_batched_models: Set to true to fit the models in a batch if supported.
674
674
Set to false to fit individual models to each metric in a loop.
675
675
refit_on_cv: Whether to refit the model on the cross-validation folds.
676
+ warm_start_refit: Whether to warm-start refitting from the current state_dict
677
+ during cross-validation. If refit_on_cv is True, generally one
678
+ would set this to False, so that no information is leaked across
679
+ folds.
676
680
metric_to_best_model_config: Dictionary mapping a metric name to the best
677
681
model config. This is only used by BotorchGenerator.cross_validate and for
678
682
logging what model was used.
@@ -698,6 +702,7 @@ def __init__(
698
702
likelihood_options : dict [str , Any ] | None = None ,
699
703
allow_batched_models : bool = True ,
700
704
refit_on_cv : bool = False ,
705
+ warm_start_refit : bool = True ,
701
706
metric_to_best_model_config : dict [str , ModelConfig ] | None = None ,
702
707
) -> None :
703
708
warnings_raised = _raise_deprecation_warning (
@@ -764,6 +769,7 @@ def __init__(
764
769
self ._outcomes : list [str ] | None = None
765
770
self ._model : Model | None = None
766
771
self .refit_on_cv = refit_on_cv
772
+ self .warm_start_refit = warm_start_refit
767
773
768
774
def __repr__ (self ) -> str :
769
775
return f"<{ self .__class__ .__name__ } " f" surrogate_spec={ self .surrogate_spec } >"
@@ -855,7 +861,7 @@ def _construct_model(
855
861
)
856
862
# pyre-ignore [45]
857
863
model = botorch_model_class (** formatted_model_inputs )
858
- if state_dict is not None :
864
+ if state_dict is not None and ( not refit or self . warm_start_refit ) :
859
865
model .load_state_dict (state_dict )
860
866
if state_dict is None or refit :
861
867
fit_botorch_model (
@@ -958,7 +964,18 @@ def fit(
958
964
model_config = self .metric_to_best_model_config .get (
959
965
dataset .outcome_names [0 ]
960
966
)
961
- if len (model_configs ) == 1 or (not refit and model_config is not None ):
967
+ # Model selection is not performed if the best `ModelConfig` has already
968
+ # been identified (as specified in `metric_to_best_model_config`).
969
+ # The reason for doing this is to support the following flow:
970
+ # - Fit model to data and perform model selection, refitting on each fold
971
+ # if `refit_on_cv=True`. This will set the best ModelConfig in
972
+ # metric_to_best_model_config.
973
+ # - Evaluate the choice of model/visualize its performance via
974
+ # `Modelbridge.cross_validate`. This will also refit on each fold if
975
+ # `refit_on_cv=True`, but we wouldn't want to perform model selection
976
+ # on each fold, but rather show the performance of the selected
977
+ # `ModelConfig`, since that is what will be used.
978
+ if len (model_configs ) == 1 or (model_config is not None ):
962
979
best_model_config = model_config or model_configs [0 ]
963
980
model = self ._construct_model (
964
981
dataset = dataset ,
@@ -1329,6 +1346,7 @@ def _serialize_attributes_as_kwargs(self) -> dict[str, Any]:
1329
1346
return {
1330
1347
"surrogate_spec" : self .surrogate_spec ,
1331
1348
"refit_on_cv" : self .refit_on_cv ,
1349
+ "warm_start_refit" : self .warm_start_refit ,
1332
1350
"metric_to_best_model_config" : self .metric_to_best_model_config ,
1333
1351
}
1334
1352
0 commit comments