Skip to content

Commit

Permalink
Merge pull request #297 from alan-turing-institute/history_matching
Browse files Browse the repository at this point in the history
added history matching and test
  • Loading branch information
marjanfamili authored Feb 21, 2025
2 parents 5957407 + 3237767 commit b61fcf3
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 1,231 deletions.
49 changes: 49 additions & 0 deletions autoemulate/history_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np


def history_matching(obs, predictions, threshold=3.0, discrepancy=0.0, rank=1):
    """
    Compute implausibility and classify points as NROY or RO in a single wave.

    History matching is performed as one run, without iterative refinement or
    staged waves. For every sample and output the implausibility is

        I = |obs_mean - pred_mean| / sqrt(pred_var + discrepancy + obs_var)

    and each value is compared with ``threshold``.

    Parameters:
        obs (tuple): Observations as ``(mean, variance)``. Each entry may be a
            scalar or a sequence with one value per output.
        predictions (tuple): Predicted ``(mean, variance)``. Each entry is
            normally an array of shape ``(n_samples, n_outputs)``. A 1-D array
            is read as one value per sample when there is a single observed
            output, and as a single sample otherwise.
        threshold (float): Implausibility threshold. Values less than or equal
            to it count as NROY, larger values as RO.
        discrepancy (float or ndarray): Model discrepancy added to the total
            variance. A single value is repeated for every output.
        rank (int): Accepted for interface compatibility. This implementation
            does not use it when classifying points.

    Returns:
        dict: ``"I"`` holds the implausibility array with one row per sample
        and one column per output. ``"NROY"`` and ``"RO"`` are lists of sample
        (row) indices of the entries that are within / above the threshold.
        The comparison is made per entry, so with several outputs a sample
        index can appear more than once and in both lists.

    Raises:
        ValueError: If the number of observed means or variances differs from
            the number of predicted output columns.
    """
    obs_mean, obs_var = np.atleast_1d(obs[0]), np.atleast_1d(obs[1])
    n_obs = len(obs_mean)

    # Work on (n_samples, n_outputs) arrays so the output count can be read
    # from the shape. Indexing a sample row to measure it fails when there is
    # only one sample or when the predictions are 1-D.
    pred_mean = _as_sample_matrix(predictions[0], n_obs)
    pred_var = _as_sample_matrix(predictions[1], len(obs_var))

    if n_obs != pred_mean.shape[1]:
        raise ValueError(
            "The number of means in observations and predictions must be equal."
        )
    if len(obs_var) != pred_var.shape[1]:
        raise ValueError(
            "The number of variances in observations and predictions must be equal."
        )

    discrepancy = np.atleast_1d(discrepancy)
    if discrepancy.size == 1:
        # Repeat a single discrepancy value for every output.
        discrepancy = np.full(n_obs, discrepancy)

    total_var = pred_var + discrepancy + obs_var
    implausibility = np.abs(obs_mean - pred_mean) / np.sqrt(total_var)

    # np.where returns one index array per axis; element 0 holds row indices.
    nroy = np.where(implausibility <= threshold)[0]
    ro = np.where(implausibility > threshold)[0]

    return {"I": implausibility, "NROY": list(nroy), "RO": list(ro)}


def _as_sample_matrix(values, n_outputs):
    """Return ``values`` as an array with samples on axis 0 and outputs on axis 1."""
    arr = np.atleast_1d(values)
    if arr.ndim == 1:
        # With one observed output a flat array lists samples; otherwise it is
        # taken to be the outputs of a single sample.
        arr = arr.reshape(-1, 1) if n_outputs == 1 else arr.reshape(1, -1)
    return arr
1 change: 1 addition & 0 deletions docs/_toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ chapters:
sections:
- file: reference/compare
- file: reference/datasets
- file: reference/history_matching
- file: reference/sensitivity_analysis
- file: reference/simulations/index
sections:
Expand Down
5 changes: 5 additions & 0 deletions docs/reference/history_matching.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
History Matching
====================

.. automodule:: autoemulate.history_matching
:members:
:show-inheritance:
1,312 changes: 81 additions & 1,231 deletions docs/tutorials/01_start.ipynb

Large diffs are not rendered by default.

56 changes: 56 additions & 0 deletions tests/test_history_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import numpy as np
import pytest

from autoemulate.history_matching import history_matching


@pytest.fixture
def sample_data_2d():
    """Provide five sample points with two outputs each and matching observations."""
    pred_mean = np.array([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1], [4.0, 4.1], [5.0, 5.1]])
    pred_std = np.array([[0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4], [0.5, 0.5]])
    pred_var = np.square(pred_std)
    predictions = (pred_mean, pred_var)
    # history_matching reads obs[0] as the means and obs[1] as the variances,
    # so the two outputs are grouped as (means, variances), not as one
    # (mean, variance) pair per output.
    obs = ([1.5, 2.5], [0.1, 0.2])
    return predictions, obs


@pytest.fixture
def sample_data_1d():
    """Provide five single-output sample points and one observation."""
    means = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]])
    # Variances are derived from standard deviations.
    variances = np.square(np.array([[0.1], [0.2], [0.3], [0.4], [0.5]]))
    observation = ([1.5], [0.5])
    return (means, variances), observation


def test_history_matching_1d(sample_data_1d):
    """Single-output data at threshold 1.0 yields a non-empty NROY list."""
    preds, observed = sample_data_1d
    outcome = history_matching(observed, preds, threshold=1.0)
    # The key must be present before its value is inspected.
    assert "NROY" in outcome
    nroy = outcome["NROY"]
    assert isinstance(nroy, list)
    assert len(nroy) > 0


def test_history_matching_threshold_1d(sample_data_1d):
    """At threshold 0.5 the NROY list is no longer than the number of samples."""
    preds, observed = sample_data_1d
    outcome = history_matching(observed, preds, threshold=0.5)
    assert "NROY" in outcome
    nroy = outcome["NROY"]
    assert isinstance(nroy, list)
    n_samples = len(preds[0])
    assert len(nroy) <= n_samples


def test_history_matching_2d(sample_data_2d):
    """Two-output data at threshold 1.0 yields a non-empty NROY list."""
    preds, observed = sample_data_2d
    outcome = history_matching(observed, preds, threshold=1.0)
    # The key must be present before its value is inspected.
    assert "NROY" in outcome
    nroy = outcome["NROY"]
    assert isinstance(nroy, list)
    assert len(nroy) > 0


def test_history_matching_threshold_2d(sample_data_2d):
    """At threshold 0.5 the NROY list is no longer than the number of samples."""
    preds, observed = sample_data_2d
    outcome = history_matching(observed, preds, threshold=0.5)
    assert "NROY" in outcome
    nroy = outcome["NROY"]
    assert isinstance(nroy, list)
    n_samples = len(preds[0])
    assert len(nroy) <= n_samples

0 comments on commit b61fcf3

Please sign in to comment.