Skip to content

Commit bfa6bbf

Browse files
committed
Simulator.forward_batch can return params for which simulation succeeded
1 parent 73453f5 commit bfa6bbf

File tree

3 files changed

+28
-34
lines changed

3 files changed

+28
-34
lines changed

autoemulate/experimental/calibration/history_matching.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -407,19 +407,15 @@ def simulate(self, x: TensorLike) -> tuple[TensorLike, TensorLike]:
407407
tuple[TensorLike, TensorLike]
408408
Tensors of succesfully simulated input parameters and predictions.
409409
"""
410-
y, exclude_indices = self.simulator.forward_batch(x, return_failed_idx=True)
410+
# if simulation fails, returned y and x have fewer rows than input x
411+
y, x = self.simulator.forward_batch(x, return_x=True)
411412
y = y.to(self.device)
412-
exclude_indices = exclude_indices.to(self.device)
413+
x = x.to(self.device)
413414

414-
all_indices = torch.arange(x.size(0), device=self.device)
415-
mask = ~torch.isin(all_indices, exclude_indices)
416-
valid_x = x[mask]
417-
valid_y = y[mask]
415+
self.train_y = torch.cat([self.train_y, y], dim=0)
416+
self.train_x = torch.cat([self.train_x, x], dim=0)
418417

419-
self.train_y = torch.cat([self.train_y, valid_y], dim=0)
420-
self.train_x = torch.cat([self.train_x, valid_x], dim=0)
421-
422-
return valid_x, valid_y
418+
return x, y
423419

424420
def run(
425421
self,

autoemulate/experimental/simulations/base.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -145,64 +145,65 @@ def forward(self, x: TensorLike) -> TensorLike | None:
145145
return None
146146

147147
def forward_batch(
148-
self, samples: TensorLike, return_failed_idx: bool = False
148+
self, x: TensorLike, return_x: bool = False
149149
) -> TensorLike | tuple[TensorLike, TensorLike]:
150150
"""
151151
Run multiple simulations with different parameters.
152152
153153
Parameters
154154
----------
155-
samples: TensorLike
155+
x: TensorLike
156156
Tensor of input parameters to make predictions for.
157-
return_failed_idx: bool
158-
Whether to return indexes of failed simulations. Defaults to False.
157+
return_x: bool
158+
Whether to return parameters for simulation runs that completed succesfully.
159+
Set to False if simulation always completes. Defaults to False.
159160
160161
Returns:
161162
-------
162163
TensorLike | tuple[TensorLike, TensorLike]
163164
Tensor of simulation results of shape (n_batch, self.out_dim).
164-
If `return_failed_idx` is True, also returns tensor of failed
165-
simulation indexes.
165+
If `return_x` is True, also returns parameters corresponding to succesful
166+
simulation results.
166167
"""
167-
self.logger.info("Running batch simulation for %d samples", len(samples))
168+
self.logger.info("Running batch simulation for %d samples", len(x))
168169

169170
results = []
170171
successful = 0
171-
failed_idx = []
172+
valid_idx = []
172173

173174
# Process each sample with progress tracking
174175
for i in tqdm(
175-
range(len(samples)),
176+
range(len(x)),
176177
desc="Running simulations",
177178
disable=not self.progress_bar,
178-
total=len(samples),
179+
total=len(x),
179180
unit="sample",
180181
unit_scale=True,
181182
):
182-
logger.debug("Running simulation for sample %d/%d", i + 1, len(samples))
183-
result = self.forward(samples[i : i + 1])
183+
logger.debug("Running simulation for sample %d/%d", i + 1, len(x))
184+
result = self.forward(x[i : i + 1])
184185
if result is not None:
185186
results.append(result)
186187
successful += 1
187-
logger.debug("Simulation %d/%d successful", i + 1, len(samples))
188+
valid_idx.append(i)
189+
logger.debug("Simulation %d/%d successful", i + 1, len(x))
188190
else:
189-
failed_idx.append(i)
190191
logger.warning(
191-
"Simulation %d/%d failed. Result is None.", i + 1, len(samples)
192+
"Simulation %d/%d failed. Result is None.", i + 1, len(x)
192193
)
193194

194195
# Report results
195196
self.logger.info(
196197
"Successfully completed %d/%d simulations (%.1f%%)",
197198
successful,
198-
len(samples),
199-
(successful / len(samples) * 100 if len(samples) > 0 else 0.0),
199+
len(x),
200+
(successful / len(x) * 100 if len(x) > 0 else 0.0),
200201
)
201202

202203
# stack results into a 2D array on first dim using torch
203204
results_tensor = torch.cat(results, dim=0)
204-
if return_failed_idx:
205-
return results_tensor, torch.tensor(failed_idx)
205+
if return_x:
206+
return results_tensor, x[valid_idx]
206207
return results_tensor
207208

208209
def get_parameter_idx(self, name: str) -> int:

tests/experimental/test_experimental_base_simulator.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,12 +188,9 @@ def _forward(self, x: TensorLike) -> TensorLike | None:
188188

189189
# This should process all samples without errors
190190
# We're just verifying it doesn't crash
191-
results, invalid_indices = simulator.forward_batch(batch, return_failed_idx=True)
191+
results, valid_x = simulator.forward_batch(batch, return_x=True)
192192
assert isinstance(results, TensorLike)
193193

194194
# Verify results shape
195195
assert results.shape == (2, 1)
196-
197-
assert len(invalid_indices) == 2
198-
assert invalid_indices[0] == 0
199-
assert invalid_indices[1] == 2
196+
assert valid_x.shape == (2, 3)

0 commit comments

Comments
 (0)