Commit 2fd8a7e

[AutoParallel] fix pp step return (#74913)
1 parent 348fa91 · commit 2fd8a7e

File tree (1 file changed: +32 -10 lines)
  • python/paddle/distributed/auto_parallel/pipelining/schedules.py

python/paddle/distributed/auto_parallel/pipelining/schedules.py

Lines changed: 32 additions & 10 deletions
@@ -225,7 +225,14 @@ def _step_microbatches(
         raise NotImplementedError

     @abstractmethod
-    def step(self, *args, target=None, losses: list | None = None, **kwargs):
+    def step(
+        self,
+        *args,
+        target=None,
+        losses: list | None = None,
+        return_output: bool = False,
+        **kwargs,
+    ):
         """
         Run one iteration of the pipeline schedule with *whole-batch* input.
         Will chunk the input into microbatches automatically, and go through the
@@ -362,7 +369,14 @@ def _initialize_stage(self, args, kwargs, labels):
         self._stage._prepare_backward_infra(self._n_microbatches, loss)
         self._stage_initialized = True

-    def step(self, *args, target=None, losses: list | None = None, **kwargs):
+    def step(
+        self,
+        *args,
+        target=None,
+        losses: list | None = None,
+        return_output: bool = False,
+        **kwargs,
+    ):
         """
         Run one iteration of the pipeline schedule with *whole-batch* input.
         Will chunk the input into microbatches automatically, and go through the
@@ -390,10 +404,10 @@ def step(self, *args, target=None, losses: list | None = None, **kwargs):
         self._step_microbatches(args_split, kwargs_split, targets_split, losses)

         # Return merged results per original format
-        if self._stage.is_last:
-            return self._merge_outputs(self._stage.output_chunks)
-        else:
-            return None
+        if return_output:
+            if self._stage.is_last:
+                return self._merge_outputs(self._stage.output_chunks)
+        return None


 def _batch_p2p(p2p_ops: list[dist.P2POp], desc: str | None = None):
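To make the behavioral change concrete, here is a small self-contained sketch (not code from this commit; the class and values are invented) that mirrors the single-stage return logic above: merged outputs are produced only when return_output=True and the local stage is the last one; otherwise step() now returns None.

# Toy stand-in (not the Paddle class) mirroring the new single-stage return
# contract shown in the hunk above.
class _ToySingleStageSchedule:
    def __init__(self, is_last_stage: bool):
        self.is_last = is_last_stage
        self.output_chunks = [1, 2, 3]  # pretend per-microbatch outputs

    def _merge_outputs(self, chunks):
        return sum(chunks)  # placeholder for the real merge step

    def step(self, *args, target=None, losses=None, return_output=False, **kwargs):
        # ... microbatch forward/backward would run here ...
        if return_output:
            if self.is_last:
                return self._merge_outputs(self.output_chunks)
        return None

print(_ToySingleStageSchedule(is_last_stage=True).step(return_output=True))  # 6
print(_ToySingleStageSchedule(is_last_stage=True).step())                    # None (new default)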
@@ -879,7 +893,14 @@ def _initialize_stages(self, args: tuple[Any, ...], kwargs, labels):
             )
         self._stages_initialized = True

-    def step(self, *args, target=None, losses: list | None = None, **kwargs):
+    def step(
+        self,
+        *args,
+        target=None,
+        losses: list | None = None,
+        return_output: bool = False,
+        **kwargs,
+    ):
         """
         Run one iteration of the pipeline schedule with *whole-batch* input.
         Will chunk the input into microbatches automatically, and go through the
@@ -906,9 +927,10 @@ def step(self, *args, target=None, losses: list | None = None, **kwargs):
         self._step_microbatches(args_split, kwargs_split, targets_split, losses)

         # Return merged results per original format
-        for stage in self._stages:
-            if stage.is_last:
-                return self._merge_outputs(stage.output_chunks)
+        if return_output:
+            for stage in self._stages:
+                if stage.is_last:
+                    return self._merge_outputs(stage.output_chunks)
         # Does not contain the last stage
         return None

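Similarly, a hedged sketch (again not from this commit; names are invented) of the multi-stage case above: only when return_output=True does the schedule scan its local stages and return the merged output_chunks of the last one. Callers that only need losses can keep passing the existing losses list and leave the new flag at its default.

# Toy stand-in (not the Paddle class) mirroring the multi-stage return logic
# in the hunk above.
class _ToyStage:
    def __init__(self, is_last: bool):
        self.is_last = is_last
        self.output_chunks = ["mb0", "mb1"]  # pretend per-microbatch outputs

class _ToyMultiStageSchedule:
    def __init__(self, stages):
        self._stages = stages

    def _merge_outputs(self, chunks):
        return list(chunks)  # placeholder for the real merge step

    def step(self, *args, target=None, losses=None, return_output=False, **kwargs):
        # ... microbatches would run on all local stages here ...
        if return_output:
            for stage in self._stages:
                if stage.is_last:
                    return self._merge_outputs(stage.output_chunks)
        # Does not contain the last stage, or outputs were not requested
        return None

sched = _ToyMultiStageSchedule([_ToyStage(False), _ToyStage(True)])
print(sched.step(return_output=True))  # ['mb0', 'mb1']
print(sched.step())                    # None (new default)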
