Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
095f053
Add all for emulators
sgreenbury Jul 12, 2025
28506b0
Add benchmark script
sgreenbury Jul 12, 2025
6af60fc
Update GP config, remove assignment
sgreenbury Jul 13, 2025
110db49
Update params for GP
sgreenbury Jul 14, 2025
1392bd6
Move benchmark script to experimental
sgreenbury Jul 14, 2025
2448b47
Add simulator name to base class
sgreenbury Jul 14, 2025
7e3fa1c
Add constants to simuilations init
sgreenbury Jul 14, 2025
fbc0af1
Update benchmark to loop over simulators
sgreenbury Jul 14, 2025
001dea6
Merge remote-tracking branch 'origin/fix/607' into 454-benchmark
sgreenbury Jul 14, 2025
e676e29
Remove method since added to base class
sgreenbury Jul 14, 2025
7780f18
Add exception handling
sgreenbury Jul 14, 2025
8625fe7
Fix click args
sgreenbury Jul 14, 2025
527cc30
Fix init
sgreenbury Jul 14, 2025
5c646cb
Fix values
sgreenbury Jul 14, 2025
4fa354d
Fix missing model config to pass to cv
sgreenbury Jul 15, 2025
51afb03
Add notebook to plot benchmark
sgreenbury Jul 15, 2025
db3542b
Merge remote-tracking branch 'origin/main' into 454-benchmark
sgreenbury Jul 15, 2025
4518507
Add flexibility for output file name
sgreenbury Jul 15, 2025
493ab64
Merge branch 'main' into 454-benchmark
sgreenbury Jul 15, 2025
c31928a
Update plot_benchmarks.ipynb
sgreenbury Jul 15, 2025
64bab78
Merge branch '634-revise-gp-predict' into 454-benchmark
sgreenbury Jul 22, 2025
4770598
Add exception handling in tuner
sgreenbury Jul 22, 2025
9c3654e
Merge remote-tracking branch 'origin/main' into 454-benchmark
sgreenbury Jul 22, 2025
c50a751
Rename as simulator registry
sgreenbury Jul 22, 2025
3941e93
Fix arg for cross_validate
sgreenbury Jul 22, 2025
28f11ee
Remove fixed seed for EnsembleMLP
sgreenbury Jul 22, 2025
edccdae
Revise MLP learning rate
sgreenbury Jul 22, 2025
65e8fe6
Fix max score calculation
sgreenbury Jul 22, 2025
6953ed0
Revise tuner logging debug msg
sgreenbury Jul 22, 2025
6510d54
Revise to use argmax
sgreenbury Jul 22, 2025
903c5bb
Add max_retries
sgreenbury Jul 23, 2025
ce762f2
Update default bounds for benchmarks
sgreenbury Jul 23, 2025
666245e
Add run_benchmark script
sgreenbury Jul 23, 2025
14bff1b
Print arguments
sgreenbury Jul 23, 2025
c2010cf
Rerun plot_benchmark
sgreenbury Jul 23, 2025
31a59d7
Add retry logic and exception handling for a given model
sgreenbury Jul 23, 2025
1758dd8
Update MLP tune config
sgreenbury Jul 23, 2025
d48463b
Raise error
sgreenbury Jul 23, 2025
6e300a5
Update benchmark to ensure sample supersets
sgreenbury Jul 23, 2025
1b27594
Update benchmark script and plot notebook
sgreenbury Jul 24, 2025
b5691eb
Move benchmark scripts, update config
sgreenbury Jul 24, 2025
bb761fb
Add benchmark README
sgreenbury Jul 24, 2025
9a9f0f2
Merge remote-tracking branch 'origin/main' into 454-benchmark
sgreenbury Jul 24, 2025
0cadd3d
Refactor tuner, add comments, fix type hints
sgreenbury Jul 24, 2025
8d1126d
Merge remote-tracking branch 'origin/main' into 454-benchmark
sgreenbury Jul 24, 2025
24150b5
Merge remote-tracking branch 'origin/main' into 454-benchmark
sgreenbury Jul 24, 2025
43684ea
Update simulator registry
sgreenbury Jul 24, 2025
46725ea
Add FlowProblem to script, fix type lint
sgreenbury Jul 24, 2025
17c5013
Update n_iter in script
sgreenbury Jul 24, 2025
b8f4922
Fix script location
sgreenbury Jul 24, 2025
e2e3cf5
Add defaults for flow_problem
sgreenbury Jul 24, 2025
80caa20
Update results
sgreenbury Jul 25, 2025
7a01e5e
Fix: add seed to simulator
sgreenbury Jul 25, 2025
508b652
Merge remote-tracking branch 'origin/main' into 454-benchmark
sgreenbury Jul 26, 2025
5e1e984
Update plot_benchmark notebook
sgreenbury Jul 28, 2025
410bede
Update plot_benchmark notebook
sgreenbury Jul 28, 2025
83b3ec4
Update plot_benchmark notebook
sgreenbury Jul 28, 2025
2ca6acc
Revise to current main reasults
sgreenbury Jul 30, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ repos:
hooks:
- id: black
language_version: python3
exclude: "^autoemulate/experimental/|^tests/experimental/"
exclude: "^autoemulate/experimental/|^tests/experimental/|^benchmarks/"
- repo: https://github.com/asottile/reorder-python-imports
rev: v3.12.0
hooks:
- id: reorder-python-imports
exclude: "^autoemulate/experimental/|^tests/experimental/"
exclude: "^autoemulate/experimental/|^tests/experimental/|^benchmarks/"
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.11.4
Expand All @@ -18,13 +18,13 @@ repos:
- id: ruff
types_or: [ python, pyi ]
args: [ --fix ]
files: ^autoemulate/experimental/|^tests/experimental/
files: ^autoemulate/experimental/|^tests/experimental/|^benchmarks/
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi ]
files: ^autoemulate/experimental/|^tests/experimental/
files: ^autoemulate/experimental/|^tests/experimental/|^benchmarks/
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.398
hooks:
- id: pyright
files: ^autoemulate/experimental/|^tests/experimental/
files: ^autoemulate/experimental/|^tests/experimental/|^benchmarks/
5 changes: 3 additions & 2 deletions autoemulate/experimental/simulations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .epidemic import Epidemic
from .flow_problem import FlowProblem
from .projectile import Projectile, ProjectileMultioutput

ALL_SIMULATORS = [Epidemic, Projectile, ProjectileMultioutput]
ALL_SIMULATORS = [Epidemic, FlowProblem, Projectile, ProjectileMultioutput]

__all__ = ["Epidemic", "Projectile", "ProjectileMultioutput"]
__all__ = ["Epidemic", "FlowProblem", "Projectile", "ProjectileMultioutput"]

SIMULATOR_REGISTRY = dict(zip(__all__, ALL_SIMULATORS, strict=False))
27 changes: 25 additions & 2 deletions autoemulate/experimental/simulations/flow_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class FlowProblem(Simulator):

def __init__(
self,
parameters_range: dict[str, tuple[float, float]],
output_names: list[str],
parameters_range: dict[str, tuple[float, float]] | None = None,
output_names: list[str] | None = None,
log_level: str = "progress_bar",
ncycles: int = 10,
ncomp: int = 10,
Expand All @@ -43,6 +43,29 @@ def __init__(
ncomp: int
Number of compartments in the tube.
"""
if parameters_range is None:
parameters_range = {
# Cardiac cycle period (s)
"T": (0.5, 2.0),
# Pulse duration (s)
"td": (0.1, 0.5),
# Amplitude (e.g., pressure or flow rate)
"amp": (100.0, 1000.0),
# Time step (s)
"dt": (0.0001, 0.01),
# Compliance (unit varies based on context)
"C": (20.0, 60.0),
# Resistance (unit varies based on context)
"R": (0.01, 0.1),
# Inductance (unit varies based on context)
"L": (0.001, 0.005),
# Outflow resistance (unit varies based on context)
"R_o": (0.01, 0.05),
# Initial pressure (unit varies based on context)
"p_o": (5.0, 15.0),
}
if output_names is None:
output_names = ["pressure"]
super().__init__(parameters_range, output_names, log_level)
self.ncycles = ncycles
self.ncomp = ncomp
Expand Down
14 changes: 14 additions & 0 deletions benchmarks/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Bechmarks

- [benchmark.py](./benchmark.py): a script with CLI for running batches of simulations with AutoEmulate for different numbers of tuningiterations
- [run_benchmark.sh](./run_benchmark.sh): runs batches of simulations enabling some parallelisation
- [plot_benchmark.ipynb](./plot_benchmark.ipynb): notebook for plotting results

## Quickstart
- Install [pueue](https://github.com/Nukesor/pueue): is included in [run_benchmark.sh](./run_benchmark.sh) and simplifies running multiple python scripts
- Run:
```bash
./run_benchmark.sh
```


118 changes: 118 additions & 0 deletions benchmarks/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import itertools
from typing import cast

import click
import numpy as np
import pandas as pd
import torch
from autoemulate.experimental.compare import AutoEmulate
from autoemulate.experimental.emulators import ALL_EMULATORS
from autoemulate.experimental.emulators.base import Emulator
from autoemulate.experimental.simulations import SIMULATOR_REGISTRY
from autoemulate.experimental.simulations.base import Simulator
from tqdm import tqdm


def run_benchmark(
x: torch.Tensor, y: torch.Tensor, n_iter: int, n_splits: int, log_level: str
) -> pd.DataFrame:
ae = AutoEmulate(
x,
y,
models=cast(list[type[Emulator] | str], ALL_EMULATORS),
n_iter=n_iter,
n_splits=n_splits,
log_level=log_level,
)
return ae.summarise()


@click.command()
@click.option(
"--simulators",
type=str,
multiple=True,
default=["ProjectileMultioutput"],
help="Number of samples to generate",
)
@click.option(
"--n_samples_list",
type=int,
multiple=True,
default=[20, 50, 100, 200, 500],
help="Number of samples to generate",
)
@click.option(
"--n_iter_list",
type=int,
multiple=True,
default=[10, 50, 100],
help="Number of iterations to run",
)
@click.option(
"--n_splits_list",
type=int,
multiple=True,
default=[2, 5],
help="Number of splits for cross-validation",
)
@click.option(
"--seed",
type=int,
default=42,
help="Seed for the permutations over params",
)
@click.option(
"--output_file",
type=str,
default="benchmark_results.csv",
help="File name for output",
)
@click.option("--log_level", default="progress_bar", help="Logging level")
def main( # noqa: PLR0913
simulators, n_samples_list, n_iter_list, n_splits_list, seed, output_file, log_level
):
print(f"Running benchmark with simulators: {simulators}")
print(f"Number of samples: {n_samples_list}")
print(f"Number of iterations: {n_iter_list}")
print(f"Number of splits: {n_splits_list}")
print(f"Seed: {seed}")
print(f"Output file: {output_file}")
print(f"Log level: {log_level}")
print("-" * 50)

dfs = []
for simulator_str in simulators:
# Generate samples
simulator: Simulator = SIMULATOR_REGISTRY[simulator_str]()
max_samples = max(n_samples_list)
x_all = simulator.sample_inputs(max_samples, random_seed=seed).to(torch.float32)
y_all = simulator.forward_batch(x_all).to(torch.float32)

params = list(itertools.product(n_samples_list, n_iter_list, n_splits_list))
np.random.seed(seed)
params = np.random.permutation(params)
for n_samples, n_iter, n_splits in tqdm(params):
print(
f"Running benchmark for {simulator_str} with {n_samples} samples, "
f"{n_iter} iterations, and {n_splits} splits"
)
try:
x = x_all[:n_samples]
y = y_all[:n_samples]
df = run_benchmark(x, y, n_iter, n_splits, log_level)
df["simulator"] = simulator_str
df["n_samples"] = n_samples
df["n_iter"] = n_iter
df["n_splits"] = n_splits
dfs.append(df)
final_df = pd.concat(dfs, ignore_index=True)
final_df.sort_values("r2_test", ascending=False).to_csv(
output_file, index=False
)
except Exception as e:
print(f"Error raised while testing :\n{e}")


if __name__ == "__main__":
main()
127 changes: 127 additions & 0 deletions benchmarks/plot_benchmark.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "702bb87d",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"df = pd.read_csv(\"https://github.com/user-attachments/files/21469860/benchmark_results.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0dd4297",
"metadata": {},
"outputs": [],
"source": [
"N_BOOTSTRAPS=100\n",
"\n",
"def generate_plots(df, metric=\"r2_train\", exclude = [\"SupportVectorMachine\", \"LightGBM\"], fontsize=\"small\"):\n",
" simulator_list = sorted(df[\"simulator\"].unique().tolist())\n",
" n_iter_list = sorted(df[\"n_iter\"].unique().tolist())\n",
" n_splits_list = sorted(df[\"n_splits\"].unique().tolist())\n",
" color = {name:f\"C{idx}\" for idx, name in enumerate(sorted(df[\"model_name\"].unique().tolist()))}\n",
" for plot_idx, simulator in enumerate(simulator_list):\n",
" fig, axs = plt.subplots(len(n_splits_list), len(n_iter_list), figsize=(12, 6), squeeze=False)\n",
" handles = []\n",
" labels = []\n",
" for row_idx, n_splits in enumerate(n_splits_list):\n",
" for col_idx, n_iter in enumerate(n_iter_list):\n",
" subset = df[df[\"simulator\"].eq(simulator) & df[\"n_splits\"].eq(n_splits) & df[\"n_iter\"].eq(n_iter)]\n",
" ax = axs[row_idx][col_idx]\n",
" for idx, ((name,), group) in enumerate(subset.groupby([\"model_name\"], sort=True)): \n",
" if name in exclude:\n",
" continue\n",
" group_sorted = group.sort_values(\"n_samples\")\n",
" line = ax.plot(group_sorted[\"n_samples\"], group_sorted[metric], label=name, c=color[name])\n",
"\n",
" if row_idx == 0 and col_idx == 0:\n",
" handles.append(line[0])\n",
" labels.append(name)\n",
" \n",
" mean = group_sorted[metric]\n",
" ste = group_sorted[f\"{metric}_std\"] / np.sqrt(N_BOOTSTRAPS)\n",
" ax.fill_between(group_sorted[\"n_samples\"], mean - ste, mean + ste, alpha=0.2, lw=0, color=color[name])\n",
" ax.set_ylim(-0.1, 1.05)\n",
" # ax.set_xlim(df[\"n_samples\"].min(), df[\"n_samples\"].max())\n",
" ax.set_xlim(10, df[\"n_samples\"].max())\n",
" ax.axhline(0., lw=0.5, ls=\"--\", c=\"grey\", alpha=0.5, zorder=-1)\n",
" \n",
" ax.set_xscale(\"log\")\n",
" # ax.set_yscale(\"log\")\n",
" if col_idx == 0:\n",
" ax.set_ylabel(metric, size=fontsize)\n",
" if row_idx == len(n_splits_list)-1:\n",
" ax.set_xlabel(\"n_samples\", size=fontsize)\n",
" ax.tick_params(labelsize=fontsize)\n",
" ax.set_title(f\"{simulator} (n_iter={n_iter}, n_splits={n_splits})\", size=fontsize)\n",
" ax.grid(True, which='both', linestyle=':', linewidth=0.5, alpha=0.7)\n",
" fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.98), ncol=df[\"model_name\"].nunique()-len(exclude), fontsize=fontsize)\n",
" \n",
" # Adjust layout to make room for legend\n",
" plt.tight_layout()\n",
" plt.subplots_adjust(top=0.88)\n",
" \n",
" plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be99a004",
"metadata": {},
"outputs": [],
"source": [
"# All models\n",
"generate_plots(df, metric=\"r2_test\", exclude=[])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ffa939d7",
"metadata": {},
"outputs": [],
"source": [
"# GPs, ensembles and MLPs only\n",
"generate_plots(df, metric=\"r2_test\", exclude=[\"RandomForest\", \"LightGBM\", \"SupportVectorMachine\", \"RadialBasisFunctions\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a313ab5c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
22 changes: 22 additions & 0 deletions benchmarks/run_benchmark.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/bin/bash
set -e
source .venv/bin/activate

# Run the benchmark script with the specified parameters
date_time=$(date +"%Y-%m-%d_%H%M%S")
outpath="./benchmarks/data/${date_time}/"
mkdir -p "$outpath"
for simulator in Epidemic FlowProblem Projectile ProjectileMultioutput; do
for n_iter_pair in "10 100" "150 50" "200 20"; do
for n_splits in 5 2; do
n_iter_array=($n_iter_pair)
n_iter1=${n_iter_array[0]}
n_iter2=${n_iter_array[1]}
echo "Running benchmark for simulator: $simulator, n_splits: $n_splits, n_iter: $n_iter1 $n_iter2"
pueue add "python benchmarks/benchmark.py --simulators \"$simulator\" --n_splits_list \"$n_splits\" --n_iter_list \"$n_iter1\" --n_iter_list \"$n_iter2\" --log_level info --output_file \"${outpath}/benchmark_results_${simulator}_n_splits_${n_splits}_n_iter_${n_iter1}_${n_iter2}.csv\""
done
done
done

# Combine outputs with:
# xsv cat rows benchmarks/data/${date_time}/benchmark_*.csv > benchmark_results.csv
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ source = [".", "/tmp"]
[tool.pyright]
venvPath = "."
venv = ".venv"
include = ["autoemulate/experimental/*", "tests/experimental/*"]
include = ["autoemulate/experimental/*", "tests/experimental/*", "benchmarks/*"]

[tool.ruff]
src = ["autoemulate/"]
line-length = 88
include = ["autoemulate/experimental/**/*.py", "tests/experimental/**/*.py"]
include = ["autoemulate/experimental/**/*.py", "tests/experimental/**/*.py", "benchmarks/**/*.py"]
target-version = "py310"

[tool.ruff.format]
Expand Down
Loading