@@ -145,64 +145,65 @@ def forward(self, x: TensorLike) -> TensorLike | None:
145145 return None
146146
147147 def forward_batch (
148- self , samples : TensorLike , return_failed_idx : bool = False
148+ self , x : TensorLike , return_x : bool = False
149149 ) -> TensorLike | tuple [TensorLike , TensorLike ]:
150150 """
151151 Run multiple simulations with different parameters.
152152
153153 Parameters
154154 ----------
155- samples : TensorLike
155+ x : TensorLike
156156 Tensor of input parameters to make predictions for.
157- return_failed_idx: bool
158- Whether to return indexes of failed simulations. Defaults to False.
157+ return_x: bool
158+ Whether to return parameters for simulation runs that completed succesfully.
159+ Set to False if simulation always completes. Defaults to False.
159160
160161 Returns:
161162 -------
162163 TensorLike | tuple[TensorLike, TensorLike]
163164 Tensor of simulation results of shape (n_batch, self.out_dim).
164- If `return_failed_idx ` is True, also returns tensor of failed
165- simulation indexes .
165+ If `return_x ` is True, also returns parameters corresponding to succesful
166+ simulation results .
166167 """
167- self .logger .info ("Running batch simulation for %d samples" , len (samples ))
168+ self .logger .info ("Running batch simulation for %d samples" , len (x ))
168169
169170 results = []
170171 successful = 0
171- failed_idx = []
172+ valid_idx = []
172173
173174 # Process each sample with progress tracking
174175 for i in tqdm (
175- range (len (samples )),
176+ range (len (x )),
176177 desc = "Running simulations" ,
177178 disable = not self .progress_bar ,
178- total = len (samples ),
179+ total = len (x ),
179180 unit = "sample" ,
180181 unit_scale = True ,
181182 ):
182- logger .debug ("Running simulation for sample %d/%d" , i + 1 , len (samples ))
183- result = self .forward (samples [i : i + 1 ])
183+ logger .debug ("Running simulation for sample %d/%d" , i + 1 , len (x ))
184+ result = self .forward (x [i : i + 1 ])
184185 if result is not None :
185186 results .append (result )
186187 successful += 1
187- logger .debug ("Simulation %d/%d successful" , i + 1 , len (samples ))
188+ valid_idx .append (i )
189+ logger .debug ("Simulation %d/%d successful" , i + 1 , len (x ))
188190 else :
189- failed_idx .append (i )
190191 logger .warning (
191- "Simulation %d/%d failed. Result is None." , i + 1 , len (samples )
192+ "Simulation %d/%d failed. Result is None." , i + 1 , len (x )
192193 )
193194
194195 # Report results
195196 self .logger .info (
196197 "Successfully completed %d/%d simulations (%.1f%%)" ,
197198 successful ,
198- len (samples ),
199- (successful / len (samples ) * 100 if len (samples ) > 0 else 0.0 ),
199+ len (x ),
200+ (successful / len (x ) * 100 if len (x ) > 0 else 0.0 ),
200201 )
201202
202203 # stack results into a 2D array on first dim using torch
203204 results_tensor = torch .cat (results , dim = 0 )
204- if return_failed_idx :
205- return results_tensor , torch . tensor ( failed_idx )
205+ if return_x :
206+ return results_tensor , x [ valid_idx ]
206207 return results_tensor
207208
208209 def get_parameter_idx (self , name : str ) -> int :
0 commit comments