Skip to content

Commit 395ee0b

Browse files
committed
Add method for forward batch where simulator can fail
1 parent bfa6bbf commit 395ee0b

File tree

3 files changed

+49
-17
lines changed

3 files changed

+49
-17
lines changed

autoemulate/experimental/calibration/history_matching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ def simulate(self, x: TensorLike) -> tuple[TensorLike, TensorLike]:
408408
Tensors of succesfully simulated input parameters and predictions.
409409
"""
410410
# if simulation fails, returned y and x have fewer rows than input x
411-
y, x = self.simulator.forward_batch(x, return_x=True)
411+
y, x = self.simulator.forward_batch_skip_failures(x)
412412
y = y.to(self.device)
413413
x = x.to(self.device)
414414

autoemulate/experimental/simulations/base.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -144,26 +144,59 @@ def forward(self, x: TensorLike) -> TensorLike | None:
144144
return y
145145
return None
146146

147-
def forward_batch(
148-
self, x: TensorLike, return_x: bool = False
149-
) -> TensorLike | tuple[TensorLike, TensorLike]:
150-
"""
151-
Run multiple simulations with different parameters.
147+
def forward_batch(self, x: TensorLike) -> TensorLike:
148+
"""Run multiple simulations with different parameters.
149+
150+
For infallible simulators that always succeed.
151+
If your simulator might fail, use `forward_batch_skip_failures()` instead.
152152
153153
Parameters
154154
----------
155155
x: TensorLike
156156
Tensor of input parameters to make predictions for.
157-
return_x: bool
158-
Whether to return parameters for simulation runs that completed succesfully.
159-
Set to False if simulation always completes. Defaults to False.
160157
161-
Returns:
158+
Returns
162159
-------
163-
TensorLike | tuple[TensorLike, TensorLike]
160+
TensorLike
164161
Tensor of simulation results of shape (n_batch, self.out_dim).
165-
If `return_x` is True, also returns parameters corresponding to succesful
166-
simulation results.
162+
163+
Raises
164+
------
165+
RuntimeError
166+
If the number of simulations does not match the input.
167+
Use `forward_batch_skip_failures()` to handle failures.
168+
"""
169+
results, x_valid = self.forward_batch_skip_failures(x)
170+
171+
# Raise an error if the number of simulations does not match the input
172+
if x.shape[0] != x_valid.shape[0]:
173+
msg = (
174+
"Some simulations failed. Use forward_batch_skip_failures() to handle "
175+
"failures."
176+
)
177+
raise RuntimeError(msg)
178+
179+
return results
180+
181+
def forward_batch_skip_failures(
182+
self, x: TensorLike
183+
) -> tuple[TensorLike, TensorLike]:
184+
"""Run multiple simulations, skipping any that fail.
185+
186+
For simulators where for some inputs the simulation can fail.
187+
Failed simulations are skipped, and only successful results are returned
188+
along with their corresponding input parameters.
189+
190+
Parameters
191+
----------
192+
x: TensorLike
193+
Tensor of input parameters to make predictions for.
194+
195+
Returns
196+
-------
197+
tuple[TensorLike, TensorLike]
198+
Tuple of (simulation_results, valid_input_parameters).
199+
Only successful simulations are included.
167200
"""
168201
self.logger.info("Running batch simulation for %d samples", len(x))
169202

@@ -202,9 +235,8 @@ def forward_batch(
202235

203236
# stack results into a 2D array on first dim using torch
204237
results_tensor = torch.cat(results, dim=0)
205-
if return_x:
206-
return results_tensor, x[valid_idx]
207-
return results_tensor
238+
239+
return results_tensor, x[valid_idx]
208240

209241
def get_parameter_idx(self, name: str) -> int:
210242
"""

tests/experimental/test_experimental_base_simulator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def _forward(self, x: TensorLike) -> TensorLike | None:
188188

189189
# This should process all samples without errors
190190
# We're just verifying it doesn't crash
191-
results, valid_x = simulator.forward_batch(batch, return_x=True)
191+
results, valid_x = simulator.forward_batch_skip_failures(batch)
192192
assert isinstance(results, TensorLike)
193193

194194
# Verify results shape

0 commit comments

Comments
 (0)