Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 39 additions & 25 deletions autoemulate/calibration/history_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import torch

from autoemulate.core.device import TorchDeviceMixin
from autoemulate.core.results import Result
from autoemulate.core.types import DeviceLike, TensorLike
from autoemulate.data.utils import set_random_seed
from autoemulate.emulators.base import ProbabilisticEmulator
from autoemulate.emulators import TransformedEmulator, get_emulator_class
from autoemulate.simulations.base import Simulator

logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -273,7 +274,7 @@ class HistoryMatchingWorkflow(HistoryMatching):
def __init__( # noqa: PLR0913 allow too many arguments since all currently required
self,
simulator: Simulator,
emulator: ProbabilisticEmulator,
result: Result,
observations: dict[str, tuple[float, float]] | dict[str, float],
threshold: float = 3.0,
model_discrepancy: float = 0.0,
Expand All @@ -290,8 +291,8 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
----------
simulator: Simulator
A simulator.
emulator: ProbabilisticEmulator
A ProbabilisticEmulator pre-trained on `simulator` data.
result: Result
A Result object containing the pre-trained emulator and its parameters.
observations: dict[str, tuple[float, float]] | dict[str, float]
For each output variable, specifies observed [value, noise]. In case
of no uncertainty in observations, provides just the observed value.
Expand All @@ -317,7 +318,8 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
self.simulator = simulator
if random_seed is not None:
set_random_seed(seed=random_seed)
self.emulator = emulator
self.result = result
self.emulator = result.model
self.emulator.device = self.device

# These get updated when run() is called and used to refit the emulator
Expand Down Expand Up @@ -349,8 +351,8 @@ def generate_samples(self, n: int) -> tuple[TensorLike, TensorLike]:
test_x = self.simulator.sample_inputs(n).to(self.device)

# Rule out implausible parameters from samples using an emulator
pred = self.emulator.predict(test_x)
impl_scores = self.calculate_implausibility(pred.mean, pred.variance)
mean, variance = self.emulator.predict_mean_and_variance(test_x)
impl_scores = self.calculate_implausibility(mean, variance)

return test_x, impl_scores

Expand Down Expand Up @@ -445,6 +447,21 @@ def simulate(self, x: TensorLike) -> tuple[TensorLike, TensorLike]:

return x, y

def refit_emulator(self):
    """Replace ``self.emulator`` with a newly built model and fit it.

    A new ``TransformedEmulator`` is constructed from the configuration held
    in ``self.result`` (``model_name``, ``x_transforms``, ``y_transforms`` and
    ``params``) on ``self.device``, stored as ``self.emulator``, and then
    fitted on ``self.train_x`` / ``self.train_y``.
    """
    inputs, targets = self.train_x, self.train_y
    base_model_cls = get_emulator_class(self.result.model_name)

    # Build an untrained emulator configured like the one in `self.result`.
    rebuilt = TransformedEmulator(
        inputs,
        targets,
        model=base_model_cls,
        x_transforms=self.result.x_transforms,
        y_transforms=self.result.y_transforms,
        device=self.device,
        **self.result.params,
    )
    self.emulator = rebuilt

    # Train the new emulator on all data gathered so far.
    self.emulator.fit(inputs, targets)

def run(
self,
n_simulations: int = 100,
Expand Down Expand Up @@ -529,21 +546,20 @@ def run(
_, _ = self.simulate(nroy_simulation_samples)

# Refit emulator using all available data
assert self.emulator is not None
self.emulator.fit(self.train_x, self.train_y)
self.refit_emulator()

prediction = self.emulator.predict(self.train_x)
# prediction = self.emulator.predict(self.train_x)

print("mean", ((prediction.mean - (self.train_y)) / self.train_y).mean())
print("std", ((prediction.mean - self.train_y) / self.train_y).std())
# print("mean", ((prediction.mean - (self.train_y)) / self.train_y).mean())
# print("std", ((prediction.mean - self.train_y) / self.train_y).std())

print("ratio", (prediction.variance / self.train_y).mean())
print("ratio", (prediction.variance / self.train_y).std())
# print("ratio", (prediction.variance / self.train_y).mean())
# print("ratio", (prediction.variance / self.train_y).std())

print(
"prediction variance mean", (prediction.variance / prediction.mean).mean()
)
print("prediction variance std", (prediction.variance / prediction.mean).std())
# print(
# "pred variance mean", (prediction.variance / prediction.mean).mean()
# )
# print("pred variance std", (prediction.variance / prediction.mean).std())

# Return test parameters and impl scores for this run/wave
return torch.cat(test_parameters_list, 0), torch.cat(impl_scores_list, 0)
Expand Down Expand Up @@ -603,14 +619,12 @@ def run_waves( # noqa: PLR0913
print("Wave ", i, self.simulator.param_bounds)

if len(test_x) < n_simulations or len(impl_scores) < n_simulations:
logger.warning(
" Not enough parameters or impl scores generated in wave %d/%d",
i + 1,
n_waves,
"Stopping history matching workflow. Results are stored until wave %d/%d.",
i,
n_waves,
# Adjacent f-string literals are concatenated implicitly. They must NOT be
# separated by commas, otherwise `msg` becomes a tuple and the log record
# shows the tuple repr instead of a single readable sentence.
msg = (
    f"Not enough parameters or impl scores generated in wave {i + 1}"
    f"/{n_waves}. Stopping history matching workflow. Results are "
    f"stored until wave {i}/{n_waves}."
)
logger.warning(msg)
break

wave_results.append((test_x, impl_scores))
Expand Down
2 changes: 1 addition & 1 deletion case_studies/blood_pressure/ModularCirc.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@
"\n",
"hmw = HistoryMatchingWorkflow(\n",
" simulator=simulator,\n",
" emulator=gp_matern,\n",
" result=ae_hm.best_result(),\n",
" observations=observations,\n",
" threshold=3.0,\n",
" train_x=x.float(),\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/simulator/03_history_matching.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = ae.best_result().model"
"best_result = ae.best_result()"
]
},
{
Expand Down Expand Up @@ -144,7 +144,7 @@
"source": [
"hmw = HistoryMatchingWorkflow(\n",
" simulator=simulator,\n",
" emulator=model,\n",
" result=best_result,\n",
" observations=observations,\n",
" threshold=3.0,\n",
" train_x=x,\n",
Expand Down Expand Up @@ -224,7 +224,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "autoemulate-py3.12",
"language": "python",
"name": "python3"
},
Expand Down
26 changes: 11 additions & 15 deletions tests/test_history_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,10 @@
HistoryMatching,
HistoryMatchingWorkflow,
)
from autoemulate.core.device import (
SUPPORTED_DEVICES,
check_torch_device_is_available,
)
from autoemulate.core.compare import AutoEmulate
from autoemulate.core.device import SUPPORTED_DEVICES, check_torch_device_is_available
from autoemulate.core.types import TensorLike
from autoemulate.emulators.gaussian_process.exact import (
GaussianProcess,
)
from autoemulate.emulators import TransformedEmulator
from autoemulate.simulations.epidemic import Epidemic

from .test_base_simulator import MockSimulator
Expand Down Expand Up @@ -109,14 +105,14 @@ def test_run(device):
assert isinstance(y, TensorLike)

# Run history matching
gp = GaussianProcess(x, y, device=device)
gp.fit(x, y)
ae = AutoEmulate(x, y, models=["GaussianProcess"], model_tuning=False)
res = ae.best_result()

observations = {"infection_rate": (0.3, 0.05)}

hm = HistoryMatchingWorkflow(
simulator=simulator,
emulator=gp,
result=res,
observations=observations,
threshold=3.0,
model_discrepancy=0.1,
Expand All @@ -129,7 +125,7 @@ def test_run(device):

# Check basic structure of results
assert isinstance(hm.train_x, TensorLike)
assert isinstance(hm.emulator, GaussianProcess)
assert isinstance(hm.emulator, TransformedEmulator)

assert len(hm.train_x) == 5

Expand All @@ -148,20 +144,20 @@ def test_run_max_tries():
assert isinstance(y, TensorLike)

# Run history matching
gp = GaussianProcess(x, y)
gp.fit(x, y)
ae = AutoEmulate(x, y, models=["GaussianProcess"], model_tuning=False)
res = ae.best_result()

# Extreme values outside the range of what the simulator returns
observations = {"infection_rate": (100.0, 1.0)}

hm = HistoryMatchingWorkflow(
simulator=simulator,
emulator=gp,
result=res,
observations=observations,
threshold=3.0,
model_discrepancy=0.1,
rank=1,
)

with pytest.raises(RuntimeError):
with pytest.raises(Warning):
hm.run(n_simulations=5)