@@ -225,7 +225,14 @@ def _step_microbatches(
225225 raise NotImplementedError
226226
227227 @abstractmethod
228- def step (self , * args , target = None , losses : list | None = None , ** kwargs ):
228+ def step (
229+ self ,
230+ * args ,
231+ target = None ,
232+ losses : list | None = None ,
233+ return_output : bool = False ,
234+ ** kwargs ,
235+ ):
229236 """
230237 Run one iteration of the pipeline schedule with *whole-batch* input.
231238 Will chunk the input into microbatches automatically, and go through the
@@ -362,7 +369,14 @@ def _initialize_stage(self, args, kwargs, labels):
362369 self ._stage ._prepare_backward_infra (self ._n_microbatches , loss )
363370 self ._stage_initialized = True
364371
365- def step (self , * args , target = None , losses : list | None = None , ** kwargs ):
372+ def step (
373+ self ,
374+ * args ,
375+ target = None ,
376+ losses : list | None = None ,
377+ return_output : bool = False ,
378+ ** kwargs ,
379+ ):
366380 """
367381 Run one iteration of the pipeline schedule with *whole-batch* input.
368382 Will chunk the input into microbatches automatically, and go through the
@@ -390,10 +404,10 @@ def step(self, *args, target=None, losses: list | None = None, **kwargs):
390404 self ._step_microbatches (args_split , kwargs_split , targets_split , losses )
391405
392406 # Return merged results per original format
393- if self . _stage . is_last :
394- return self ._merge_outputs ( self . _stage .output_chunks )
395- else :
396- return None
407+ if return_output :
408+ if self ._stage .is_last :
409+ return self . _merge_outputs ( self . _stage . output_chunks )
410+ return None
397411
398412
399413def _batch_p2p (p2p_ops : list [dist .P2POp ], desc : str | None = None ):
@@ -879,7 +893,14 @@ def _initialize_stages(self, args: tuple[Any, ...], kwargs, labels):
879893 )
880894 self ._stages_initialized = True
881895
882- def step (self , * args , target = None , losses : list | None = None , ** kwargs ):
896+ def step (
897+ self ,
898+ * args ,
899+ target = None ,
900+ losses : list | None = None ,
901+ return_output : bool = False ,
902+ ** kwargs ,
903+ ):
883904 """
884905 Run one iteration of the pipeline schedule with *whole-batch* input.
885906 Will chunk the input into microbatches automatically, and go through the
@@ -906,9 +927,10 @@ def step(self, *args, target=None, losses: list | None = None, **kwargs):
906927 self ._step_microbatches (args_split , kwargs_split , targets_split , losses )
907928
908929 # Return merged results per original format
909- for stage in self ._stages :
910- if stage .is_last :
911- return self ._merge_outputs (stage .output_chunks )
930+ if return_output :
931+ for stage in self ._stages :
932+ if stage .is_last :
933+ return self ._merge_outputs (stage .output_chunks )
912934 # Does not contain the last stage
913935 return None
914936
0 commit comments