From 4170f581af6ede100d5a44d09af93d6d929d3fc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20M=C3=BCller?= Date: Fri, 26 Sep 2025 11:58:33 -0700 Subject: [PATCH 1/2] Enable loading checkpoints from automl/PFNs (#3017) Summary: X-link: https://github.com/pytorch/botorch/pull/3017 Enable our implementation of PFNs as surrogate to load checkpoints from trainings done with the automl/PFNs repository. Reviewed By: Balandat Differential Revision: D82313030 --- .../models/prior_fitted_network.py | 15 ++++++- pyproject.toml | 1 + .../models/test_prior_fitted_network.py | 40 ++++++++++++++++++- 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/botorch_community/models/prior_fitted_network.py b/botorch_community/models/prior_fitted_network.py index e7e705eacd..bd2a11e8e8 100644 --- a/botorch_community/models/prior_fitted_network.py +++ b/botorch_community/models/prior_fitted_network.py @@ -27,6 +27,7 @@ ModelPaths, ) from botorch_community.posteriors.riemann import BoundedRiemannPosterior +from pfns.train import MainConfig # @manual=//pytorch/PFNs:PFNs from torch import Tensor from torch.nn import Module @@ -44,6 +45,7 @@ def __init__( batch_first: bool = False, constant_model_kwargs: dict[str, Any] | None = None, input_transform: InputTransform | None = None, + load_training_checkpoint: bool = False, ) -> None: """Initialize a PFNModel. @@ -71,6 +73,8 @@ def __init__( constant_model_kwargs: A dictionary of model kwargs that will be passed to the model in each forward pass. input_transform: A Botorch input transform. + load_training_checkpoint: Whether to load a training checkpoint as + produced by the PFNs training code, see github.com/automl/PFNs. 
""" super().__init__() @@ -79,6 +83,15 @@ def __init__( model_path=checkpoint_url, ) + if load_training_checkpoint: + # the model is not an actual model, but a training checkpoint + # make a model out of it + checkpoint = model + config = MainConfig.from_dict(checkpoint["config"]) + model = config.model.create_model() + model.load_state_dict(checkpoint["model_state_dict"]) + model.eval() + if train_Yvar is not None: logger.debug("train_Yvar provided but ignored for PFNModel.") @@ -113,7 +126,7 @@ def __init__( self.train_X = train_X # shape: `b x n x d` self.train_Y = train_Y # shape: `b x n` - self.pfn = model + self.pfn = model.to(train_X.device) self.batch_first = batch_first self.constant_model_kwargs = constant_model_kwargs if input_transform is not None: diff --git a/pyproject.toml b/pyproject.toml index fc636e4658..80db3d85bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ test = [ "pytest-cov", "requests", "pymoo", + "pfns" ] dev = [ diff --git a/test_community/models/test_prior_fitted_network.py b/test_community/models/test_prior_fitted_network.py index 74df2ed8af..0c32a7d49d 100644 --- a/test_community/models/test_prior_fitted_network.py +++ b/test_community/models/test_prior_fitted_network.py @@ -19,6 +19,8 @@ download_model, ModelPaths, ) +from pfns.model.transformer_config import CrossEntropyConfig, TransformerConfig +from pfns.train import MainConfig, OptimizerConfig from torch import nn, Tensor @@ -162,6 +164,42 @@ def test_input_transform(self): self.assertIsInstance(model.input_transform, Normalize) self.assertEqual(model.input_transform.bounds.shape, torch.Size([2, 3])) + def test_unpack_checkpoint(self): + config = MainConfig( + priors=[], + optimizer=OptimizerConfig( + optimizer="adam", + lr=0.001, + ), + model=TransformerConfig( + criterion=CrossEntropyConfig(num_classes=3), + ), + batch_shape_sampler=None, + ) + + model = config.model.create_model() + + checkpoint = { + "config": config.to_dict(), + "model_state_dict": 
model.state_dict(), + } + + loaded_model = PFNModel( + train_X=torch.rand(10, 3), + train_Y=torch.rand(10, 1), + input_transform=Normalize(d=3), + model=checkpoint, + load_training_checkpoint=True, + ) + + loaded_state_dict = loaded_model.pfn.state_dict() + self.assertEqual( + sorted(loaded_state_dict.keys()), + sorted(model.state_dict().keys()), + ) + for k in loaded_state_dict.keys(): + self.assertTrue(torch.equal(loaded_state_dict[k], model.state_dict()[k])) + class TestPriorFittedNetworkUtils(BotorchTestCase): @patch("botorch_community.models.utils.prior_fitted_network.requests.get") @@ -215,7 +253,7 @@ def test_download_model_cache_miss( train_X=torch.rand(10, 3), train_Y=torch.rand(10, 1), ) - self.assertEqual(model.pfn, fake_model) + self.assertEqual(model.pfn, fake_model.to("cpu")) @patch("botorch_community.models.utils.prior_fitted_network.torch.load") @patch("botorch_community.models.utils.prior_fitted_network.os.path.exists") From 08bfa919801ba1b6df9f46b12f30cd4d7435643b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20M=C3=BCller?= Date: Fri, 26 Sep 2025 11:58:33 -0700 Subject: [PATCH 2/2] Enable evaluating PFNs trained with pytorch/PFNs to be evaluated in Ax Summary: This PR enables PFNs generally to work in our MAST evaluation, as the registry didn't quite work before, and it additionally allows training checkpoints from pytorch/PFNs to be used for evaluations. 
Reviewed By: Balandat Differential Revision: D80944578 --- test_community/models/test_prior_fitted_network.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test_community/models/test_prior_fitted_network.py b/test_community/models/test_prior_fitted_network.py index 0c32a7d49d..74c6229af5 100644 --- a/test_community/models/test_prior_fitted_network.py +++ b/test_community/models/test_prior_fitted_network.py @@ -179,9 +179,10 @@ def test_unpack_checkpoint(self): model = config.model.create_model() + state_dict = model.state_dict() checkpoint = { "config": config.to_dict(), - "model_state_dict": model.state_dict(), + "model_state_dict": state_dict, } loaded_model = PFNModel( @@ -195,10 +196,10 @@ def test_unpack_checkpoint(self): loaded_state_dict = loaded_model.pfn.state_dict() self.assertEqual( sorted(loaded_state_dict.keys()), - sorted(model.state_dict().keys()), + sorted(state_dict.keys()), ) for k in loaded_state_dict.keys(): - self.assertTrue(torch.equal(loaded_state_dict[k], model.state_dict()[k])) + self.assertTrue(torch.equal(loaded_state_dict[k], state_dict[k])) class TestPriorFittedNetworkUtils(BotorchTestCase):