
Commit 2beead7

H-Huang authored and pytorchmergebot committed
[PP] move FSDP reduce scatters to end of step (pytorch#165106)
Move the FSDP reduce scatters to the end of the PP step. The reduce-scatter compute-stream sync blocks the other stages from executing their backwards, which leads to bubbles. There should be a way to execute these reduce scatters earlier, but do this for now as a quick fix.

(Attached trace screenshot: https://github.com/user-attachments/assets/b945dd55-8ab1-4acc-b862-c6e2e476b834)

Pull Request resolved: pytorch#165106
Approved by: https://github.com/weifengpy
ghstack dependencies: pytorch#164976
1 parent 3a110c9 commit 2beead7
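
For readers less familiar with the FSDP-plus-pipelining composition this change touches, the sketch below shows the kind of setup affected. It is a hypothetical sketch, not part of the commit: build_stage_submodule, inputs, labels, and the mesh shape are placeholders. Each pipeline stage's submodule is wrapped with FSDP2's fully_shard, and a pipeline schedule drives the microbatch forwards/backwards; after this commit, the FSDP reduce scatters for a stage's gradients (and the gradient scaling) run once at the end of schedule.step() instead of being interleaved with per-microbatch backwards.

# Hypothetical FSDP2 + pipeline-parallel setup (names and shapes are placeholders).
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

# Assume a 2D mesh: one dim for pipeline stages ("pp"), one for FSDP sharding ("dp").
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("pp", "dp"))

stage_mod = build_stage_submodule()      # placeholder: this rank's slice of the model
fully_shard(stage_mod, mesh=mesh["dp"])  # FSDP2 shards the stage's parameters

stage = PipelineStage(
    stage_mod,
    stage_index=mesh["pp"].get_local_rank(),
    num_stages=2,
    device=torch.device("cuda"),
    group=mesh["pp"].get_group(),
)
schedule = ScheduleGPipe(stage, n_microbatches=8, loss_fn=torch.nn.functional.cross_entropy)

# One training step: the microbatch forwards/backwards run first; with this change the
# FSDP reduce-scatters (and grad scaling) are issued once at the end of step().
schedule.step(inputs, target=labels)     # inputs/labels are placeholders for this rank's data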

File tree

3 files changed (+40, -38 lines)

test/distributed/test_composability.py
Lines changed: 5 additions & 2 deletions

@@ -146,6 +146,7 @@ def _build_pp_schedule(
         total_layers,
         apply_dp,
         loss_fn,
+        scale_grads=True,
     ):
         if issubclass(ScheduleClass, PipelineScheduleSingle):
             pipeline_stage, offset = self._build_pp_stage(
@@ -163,6 +164,7 @@ def _build_pp_schedule(
                 pipeline_stage,
                 n_microbatches=num_microbatches,
                 loss_fn=loss_fn,
+                scale_grads=scale_grads,
             )
         else:
             n_virtual = 2
@@ -185,6 +187,7 @@ def _build_pp_schedule(
                 stages,
                 n_microbatches=num_microbatches,
                 loss_fn=loss_fn,
+                scale_grads=scale_grads,
             )
         return pipeline_schedule, partial_models, offsets

@@ -523,8 +526,8 @@ def create_schedule(computation_types, microbatch_index=None):
         runtime.pipeline_order_with_comms = unshard_schedule
         runtime.step(dummy_input)

-        # Verify parameters are now unsharded
-        check_fsdp_unsharded_state(stage.submod, expected_unsharded=True)
+        # Verify parameters are still sharded
+        check_fsdp_unsharded_state(stage.submod, expected_unsharded=False)


 instantiate_parametrized_tests(ComposabilityTest)
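
The new scale_grads pass-through in the test helper mirrors the schedules' public scale_grads flag. Below is a hedged sketch of what that flag controls; pipeline_stage, stage_mod, x, and y are placeholders. With scale_grads=True (the default) the accumulated gradients are divided by n_microbatches at the end of the step; with scale_grads=False the schedule leaves the raw per-microbatch sum and the caller can scale manually.

# Sketch of the scale_grads knob (pipeline_stage / stage_mod / x / y are placeholders).
import torch
from torch.distributed.pipelining import ScheduleGPipe

n_microbatches = 4
sched = ScheduleGPipe(
    pipeline_stage,                      # a PipelineStage built for this rank
    n_microbatches=n_microbatches,
    loss_fn=torch.nn.functional.cross_entropy,
    scale_grads=False,                   # leave the raw per-microbatch gradient sum
)
sched.step(x, target=y)

# scale_grads=True would instead divide the accumulated grads by n_microbatches
# at the end of step(); the manual equivalent is:
for p in stage_mod.parameters():
    if p.grad is not None:
        p.grad.div_(n_microbatches)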

torch/distributed/pipelining/schedules.py
Lines changed: 10 additions & 14 deletions

@@ -625,6 +625,10 @@ def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
         # Run microbatches
         self._step_microbatches(args_split, kwargs_split, targets_split, losses)

+        # Stage post processing
+        grad_scale_factor = self._n_microbatches if self.scale_grads else 1
+        self._stage._post_backward(grad_scale_factor)
+
         # Return merged results per original format
         if self._stage.is_last:
             return self._merge_outputs(self._stage.output_chunks)
@@ -773,10 +777,6 @@ def _step_microbatches(

             logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i)

-        self._stage.scale_grads(
-            grad_scale_factor=self._n_microbatches if self.scale_grads else 1
-        )
-
         # Wait for all backward sends to finish
         for work in bwd_sends_to_wait:
             _wait_batch_p2p(work)
@@ -951,10 +951,6 @@ def _step_microbatches(
             send_work = _batch_p2p(bwd_sends, desc="bwd_send")
             bwd_mb_index += 1

-        self._stage.scale_grads(
-            grad_scale_factor=self._n_microbatches if self.scale_grads else 1
-        )
-
         # Wait for the last backward send to finish
         _wait_batch_p2p(send_work)

@@ -1555,6 +1551,12 @@ def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
         # Run microbatches
         self._step_microbatches(args_split, kwargs_split, targets_split, losses)

+        # Stage post processing
+        # TODO: remove this section and include as part of the schedule IR?
+        for stage in self._stages:
+            grad_scale_factor = self._n_microbatches if self.scale_grads else 1
+            stage._post_backward(grad_scale_factor)
+
         # Return merged results per original format
         for stage in self._stages:
             if stage.is_last:
@@ -2086,15 +2088,12 @@ def _perform_action(action: _Action) -> None:
                 loss = self._maybe_get_loss(stage, mb_index)
                 backward_counter[stage_idx] += 1
                 last_backward = backward_counter[stage_idx] == self._n_microbatches
-                grad_scale_factor = self._n_microbatches if self.scale_grads else 1
                 stage.backward_one_chunk(
                     mb_index,
                     loss=loss,
                     full_backward=True,
                     last_backward=last_backward,
                 )
-                if last_backward:
-                    stage.scale_grads(grad_scale_factor)
                 # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
                 # see [Note: V-schedule special case]
                 if is_prev_stage_on_this_rank:
@@ -2131,13 +2130,10 @@ def _perform_action(action: _Action) -> None:
                 _assert_unsharded(stage_idx)
                 backward_counter[stage_idx] += 1
                 last_backward = backward_counter[stage_idx] == self._n_microbatches
-                grad_scale_factor = self._n_microbatches if self.scale_grads else 1
                 stage.backward_weight_one_chunk(
                     mb_index,
                     last_backward=last_backward,
                 )
-                if last_backward:
-                    stage.scale_grads(grad_scale_factor)
             else:
                 raise ValueError(f"{action=} is unknown or unsupported")
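
Taken out of diff context, the post-processing that both step() implementations now perform at the end of a step reduces to the pattern below. This is a simplified paraphrase, not verbatim source, and the helper name is made up.

# Simplified paraphrase of the new end-of-step post-processing (helper name is hypothetical).
def _end_of_step_post_processing(schedule) -> None:
    grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
    # Single-stage schedules hold one stage; multi-stage schedules loop over all of them.
    for stage in getattr(schedule, "_stages", [schedule._stage]):
        # _post_backward() flushes FSDP's deferred reduce-scatters for the stage's
        # submodule (if it is FSDP-wrapped) and then scales the accumulated grads.
        stage._post_backward(grad_scale_factor)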

torch/distributed/pipelining/stage.py
Lines changed: 25 additions & 22 deletions

@@ -651,28 +651,6 @@ def perform_backward(
             self.submod.set_reshard_after_backward(False)
             self.submod.set_requires_gradient_sync(False)
             result = perform_backward(backward_type)()
-            if last_backward:
-                # Manually call post backward for FSDP
-                def run_post_backward(fsdp_module: FSDPModule) -> None:
-                    fsdp_module.set_is_last_backward(True)
-                    fsdp_module.set_reshard_after_backward(True)
-                    fsdp_module.set_requires_gradient_sync(True)
-
-                    if isinstance(fsdp_module, ReplicateModule):
-                        distributed_state = replicate.state(fsdp_module)  # type: ignore[arg-type]
-                    else:
-                        distributed_state = fully_shard.state(fsdp_module)  # type: ignore[attr-defined]
-
-                    for state in distributed_state._state_ctx.all_states:
-                        if state._fsdp_param_group:
-                            state._fsdp_param_group.post_backward()
-
-                    # it would be much better if pipelining backward invoked .backward so autograd hooks
-                    # worked and modules like DDP/FSDP behaved as expected. Working around this for the time being,
-                    # we need to call this too to ensure FSDP syncs its grad reduction ops back to the default stream.
-                    distributed_state._root_post_backward_final_callback()
-
-                run_post_backward(self.submod)

         else:
             # Non-DP submodule, regular backward
@@ -998,6 +976,31 @@ def _get_init_p2p_neighbors_ops(self) -> list[dist.P2POp]:

         return ops

+    def _post_backward(self, grad_scale_factor: int):
+        # Manually call post backward for FSDP
+        if isinstance(self.submod, FSDPModule):
+            fsdp_module = self.submod
+            fsdp_module.set_is_last_backward(True)
+            fsdp_module.set_reshard_after_backward(True)
+            fsdp_module.set_requires_gradient_sync(True)
+
+            if isinstance(fsdp_module, ReplicateModule):
+                distributed_state = replicate.state(fsdp_module)  # type: ignore[arg-type]
+            else:
+                distributed_state = fully_shard.state(fsdp_module)  # type: ignore[attr-defined]
+
+            for state in distributed_state._state_ctx.all_states:
+                if state._fsdp_param_group:
+                    state._fsdp_param_group.post_backward()
+
+            # it would be much better if pipelining backward invoked .backward so autograd hooks
+            # worked and modules like DDP/FSDP behaved as expected. Working around this for the time being,
+            # we need to call this too to ensure FSDP syncs its grad reduction ops back to the default stream.
+            distributed_state._root_post_backward_final_callback()
+        # Call gradient scaling at the end of the backward pass
+        # NOTE: this must happen after FSDP post_backward if FSDP is enabled
+        self.scale_grads(grad_scale_factor)
+

 class _PipelineStage(_PipelineStageBase):
     def __init__(
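
One way to sanity-check the new placement (not part of the commit; a hypothetical verification sketch reusing the schedule/inputs/labels placeholders from the setup sketch above) is to capture a profiler trace of a single step and confirm that the stage's reduce-scatter collectives land after all backward compute instead of being interleaved per microbatch, matching the trace screenshot in the commit message.

# Hypothetical verification sketch: profile one PP step and inspect where the
# FSDP reduce-scatter collectives land relative to backward compute.
import torch
from torch.profiler import ProfilerActivity, profile

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    schedule.step(inputs, target=labels)   # schedule/inputs/labels from the setup sketch
    torch.cuda.synchronize()

prof.export_chrome_trace("pp_fsdp_step.json")
# In the trace, the reduce-scatter kernels for this stage should now appear in a
# block at the end of the step rather than between individual microbatch backwards.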
