@@ -144,26 +144,59 @@ def forward(self, x: TensorLike) -> TensorLike | None:
144144 return y
145145 return None
146146
147- def forward_batch (
148- self , x : TensorLike , return_x : bool = False
149- ) -> TensorLike | tuple [ TensorLike , TensorLike ]:
150- """
151- Run multiple simulations with different parameters .
147+ def forward_batch (self , x : TensorLike ) -> TensorLike :
148+ """Run multiple simulations with different parameters.
149+
150+ For infallible simulators that always succeed.
151+ If your simulator might fail, use `forward_batch_skip_failures()` instead .
152152
153153 Parameters
154154 ----------
155155 x: TensorLike
156156 Tensor of input parameters to make predictions for.
157- return_x: bool
158- Whether to return parameters for simulation runs that completed succesfully.
159- Set to False if simulation always completes. Defaults to False.
160157
161- Returns:
158+ Returns
162159 -------
163- TensorLike | tuple[TensorLike, TensorLike]
160+ TensorLike
164161 Tensor of simulation results of shape (n_batch, self.out_dim).
165- If `return_x` is True, also returns parameters corresponding to succesful
166- simulation results.
162+
163+ Raises
164+ ------
165+ RuntimeError
166+ If the number of simulations does not match the input.
167+ Use `forward_batch_skip_failures()` to handle failures.
168+ """
169+ results , x_valid = self .forward_batch_skip_failures (x )
170+
171+ # Raise an error if the number of simulations does not match the input
172+ if x .shape [0 ] != x_valid .shape [0 ]:
173+ msg = (
174+ "Some simulations failed. Use forward_batch_skip_failures() to handle "
175+ "failures."
176+ )
177+ raise RuntimeError (msg )
178+
179+ return results
180+
181+ def forward_batch_skip_failures (
182+ self , x : TensorLike
183+ ) -> tuple [TensorLike , TensorLike ]:
184+ """Run multiple simulations, skipping any that fail.
185+
186+ For simulators where for some inputs the simulation can fail.
187+ Failed simulations are skipped, and only successful results are returned
188+ along with their corresponding input parameters.
189+
190+ Parameters
191+ ----------
192+ x: TensorLike
193+ Tensor of input parameters to make predictions for.
194+
195+ Returns
196+ -------
197+ tuple[TensorLike, TensorLike]
198+ Tuple of (simulation_results, valid_input_parameters).
199+ Only successful simulations are included.
167200 """
168201 self .logger .info ("Running batch simulation for %d samples" , len (x ))
169202
@@ -202,9 +235,8 @@ def forward_batch(
202235
203236 # stack results into a 2D array on first dim using torch
204237 results_tensor = torch .cat (results , dim = 0 )
205- if return_x :
206- return results_tensor , x [valid_idx ]
207- return results_tensor
238+
239+ return results_tensor , x [valid_idx ]
208240
209241 def get_parameter_idx (self , name : str ) -> int :
210242 """
0 commit comments