From b50d1f9a1c305f0364a6831b86a79b2a76671f8f Mon Sep 17 00:00:00 2001 From: radka-j Date: Fri, 15 Aug 2025 19:22:41 +0100 Subject: [PATCH 1/2] re-initialise emulator before refitting in HMW --- autoemulate/calibration/history_matching.py | 64 +++++++++++-------- .../simulator/03_history_matching.ipynb | 6 +- tests/test_history_matching.py | 26 ++++---- 3 files changed, 53 insertions(+), 43 deletions(-) diff --git a/autoemulate/calibration/history_matching.py b/autoemulate/calibration/history_matching.py index d6d85126c..f528f82e3 100644 --- a/autoemulate/calibration/history_matching.py +++ b/autoemulate/calibration/history_matching.py @@ -4,9 +4,10 @@ import torch from autoemulate.core.device import TorchDeviceMixin +from autoemulate.core.results import Result from autoemulate.core.types import DeviceLike, TensorLike from autoemulate.data.utils import set_random_seed -from autoemulate.emulators.base import ProbabilisticEmulator +from autoemulate.emulators import TransformedEmulator, get_emulator_class from autoemulate.simulations.base import Simulator logging.basicConfig(level=logging.INFO) @@ -273,7 +274,7 @@ class HistoryMatchingWorkflow(HistoryMatching): def __init__( # noqa: PLR0913 allow too many arguments since all currently required self, simulator: Simulator, - emulator: ProbabilisticEmulator, + result: Result, observations: dict[str, tuple[float, float]] | dict[str, float], threshold: float = 3.0, model_discrepancy: float = 0.0, @@ -290,8 +291,8 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ ---------- simulator: Simulator A simulator. - emulator: ProbabilisticEmulator - A ProbabilisticEmulator pre-trained on `simulator` data. + result: Result + A Result object containing the pre-trained emulator and its parameters. observations: dict[str, tuple[float, float] | dict[str, float] For each output variable, specifies observed [value, noise]. In case of no uncertainty in observations, provides just the observed value. @@ -317,7 +318,8 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ self.simulator = simulator if random_seed is not None: set_random_seed(seed=random_seed) - self.emulator = emulator + self.result = result + self.emulator = result.model self.emulator.device = self.device # These get updated when run() is called and used to refit the emulator @@ -349,8 +351,8 @@ def generate_samples(self, n: int) -> tuple[TensorLike, TensorLike]: test_x = self.simulator.sample_inputs(n).to(self.device) # Rule out implausible parameters from samples using an emulator - pred = self.emulator.predict(test_x) - impl_scores = self.calculate_implausibility(pred.mean, pred.variance) + mean, variance = self.emulator.predict_mean_and_variance(test_x) + impl_scores = self.calculate_implausibility(mean, variance) return test_x, impl_scores @@ -445,6 +447,21 @@ def simulate(self, x: TensorLike) -> tuple[TensorLike, TensorLike]: return x, y + def refit_emulator(self): + """Refit the emulator using all available training data.""" + # Create a fresh model with the same configuration + self.emulator = TransformedEmulator( + self.train_x, + self.train_y, + model=get_emulator_class(self.result.model_name), + x_transforms=self.result.x_transforms, + y_transforms=self.result.y_transforms, + device=self.device, + **self.result.params, + ) + # Fit the fresh model on the new data + self.emulator.fit(self.train_x, self.train_y) + def run( self, n_simulations: int = 100, @@ -529,21 +546,20 @@ def run( _, _ = self.simulate(nroy_simulation_samples) # Refit emulator using all available data - assert self.emulator is not None - self.emulator.fit(self.train_x, self.train_y) + self.refit_emulator() - prediction = self.emulator.predict(self.train_x) + # prediction = self.emulator.predict(self.train_x) - print("mean", ((prediction.mean - (self.train_y)) / self.train_y).mean()) - print("std", ((prediction.mean - self.train_y) / self.train_y).std()) + # print("mean", ((prediction.mean - (self.train_y)) / self.train_y).mean()) + # print("std", ((prediction.mean - self.train_y) / self.train_y).std()) - print("ratio", (prediction.variance / self.train_y).mean()) - print("ratio", (prediction.variance / self.train_y).std()) + # print("ratio", (prediction.variance / self.train_y).mean()) + # print("ratio", (prediction.variance / self.train_y).std()) - print( - "prediction variance mean", (prediction.variance / prediction.mean).mean() - ) - print("prediction variance std", (prediction.variance / prediction.mean).std()) + # print( + # "pred variance mean", (prediction.variance / prediction.mean).mean() + # ) + # print("pred variance std", (prediction.variance / prediction.mean).std()) # Return test parameters and impl scores for this run/wave return torch.cat(test_parameters_list, 0), torch.cat(impl_scores_list, 0) @@ -603,14 +619,12 @@ def run_waves( # noqa: PLR0913 print("Wave ", i, self.simulator.param_bounds) if len(test_x) < n_simulations or len(impl_scores) < n_simulations: - logger.warning( - " Not enough parameters or impl scores generated in wave %d/%d", - i + 1, - n_waves, - "Stopping history matching workflow. Results are stored until wave %d/%d.", - i, - n_waves, + msg = ( + f"Not enough parameters or impl scores generated in wave {i + 1}", + f"/{n_waves}. Stopping history matching workflow. Results are ", + f"stored until wave {i}/{n_waves}.", ) + logger.warning(msg) break wave_results.append((test_x, impl_scores)) diff --git a/docs/tutorials/simulator/03_history_matching.ipynb b/docs/tutorials/simulator/03_history_matching.ipynb index a07a5e42d..31ce0f55b 100644 --- a/docs/tutorials/simulator/03_history_matching.ipynb +++ b/docs/tutorials/simulator/03_history_matching.ipynb @@ -108,7 +108,7 @@ "metadata": {}, "outputs": [], "source": [ - "model = ae.best_result().model" + "best_result = ae.best_result()" ] }, { @@ -144,7 +144,7 @@ "source": [ "hmw = HistoryMatchingWorkflow(\n", " simulator=simulator,\n", - " emulator=model,\n", + " result=best_result,\n", " observations=observations,\n", " threshold=3.0,\n", " train_x=x,\n", @@ -224,7 +224,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "autoemulate-py3.12", "language": "python", "name": "python3" }, diff --git a/tests/test_history_matching.py b/tests/test_history_matching.py index 5d1ce0a0c..a97189e00 100644 --- a/tests/test_history_matching.py +++ b/tests/test_history_matching.py @@ -4,14 +4,10 @@ HistoryMatching, HistoryMatchingWorkflow, ) -from autoemulate.core.device import ( - SUPPORTED_DEVICES, - check_torch_device_is_available, -) +from autoemulate.core.compare import AutoEmulate +from autoemulate.core.device import SUPPORTED_DEVICES, check_torch_device_is_available from autoemulate.core.types import TensorLike -from autoemulate.emulators.gaussian_process.exact import ( - GaussianProcess, -) +from autoemulate.emulators import TransformedEmulator from autoemulate.simulations.epidemic import Epidemic from .test_base_simulator import MockSimulator @@ -109,14 +105,14 @@ def test_run(device): assert isinstance(y, TensorLike) # Run history matching - gp = GaussianProcess(x, y, device=device) - gp.fit(x, y) + ae = AutoEmulate(x, y, models=["GaussianProcess"], model_tuning=False) + res = ae.best_result() observations = {"infection_rate": (0.3, 0.05)} hm = HistoryMatchingWorkflow( simulator=simulator, - emulator=gp, + result=res, observations=observations, threshold=3.0, model_discrepancy=0.1, @@ -129,7 +125,7 @@ def test_run(device): # Check basic structure of results assert isinstance(hm.train_x, TensorLike) - assert isinstance(hm.emulator, GaussianProcess) + assert isinstance(hm.emulator, TransformedEmulator) assert len(hm.train_x) == 5 @@ -148,20 +144,20 @@ def test_run_max_tries(): assert isinstance(y, TensorLike) # Run history matching - gp = GaussianProcess(x, y) - gp.fit(x, y) + ae = AutoEmulate(x, y, models=["GaussianProcess"], model_tuning=False) + res = ae.best_result() # Extreme values outside the range of what the simulator returns observations = {"infection_rate": (100.0, 1.0)} hm = HistoryMatchingWorkflow( simulator=simulator, - emulator=gp, + result=res, observations=observations, threshold=3.0, model_discrepancy=0.1, rank=1, ) - with pytest.raises(RuntimeError): + with pytest.raises(Warning): hm.run(n_simulations=5) From 57a447781598c4a24a44088b4c5be24b04aa6991 Mon Sep 17 00:00:00 2001 From: radka-j Date: Fri, 15 Aug 2025 19:27:18 +0100 Subject: [PATCH 2/2] pass result to HMW object in ModularCirc nb --- case_studies/blood_pressure/ModularCirc.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/case_studies/blood_pressure/ModularCirc.ipynb b/case_studies/blood_pressure/ModularCirc.ipynb index 10d66105b..d097306bd 100644 --- a/case_studies/blood_pressure/ModularCirc.ipynb +++ b/case_studies/blood_pressure/ModularCirc.ipynb @@ -544,7 +544,7 @@ "\n", "hmw = HistoryMatchingWorkflow(\n", " simulator=simulator,\n", - " emulator=gp_matern,\n", + " result=ae_hm.best_result(),\n", " observations=observations,\n", " threshold=3.0,\n", " train_x=x.float(),\n",