
Commit 2ccc20e

PipelineStage
1 parent 018ffdd commit 2ccc20e

File tree

3 files changed: +1352 −0 lines

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from typing import Any, Iterator

from utils import map_debug_info

import paddle

logger = logging.getLogger(__name__)


def stage_backward_input(
    stage_outputs_or_loss: list[paddle.Tensor],
    output_grads: list[paddle.Tensor] | None,
    input_values: list[paddle.Tensor],
    weights: Iterator[paddle.Tensor],
) -> tuple[tuple[paddle.Tensor | None, ...], list[dict[str, Any]]]:
    raise NotImplementedError("stage_backward_input is not implemented yet")


def stage_backward_weight(
    weights: Iterator[paddle.Tensor],
    param_groups: list[dict[str, Any]],
    retain_graph=False,
) -> tuple[paddle.Tensor | None, ...]:
    raise NotImplementedError("stage_backward_weight is not implemented yet")


def stage_backward(
    stage_output,
    output_grads,
    input_values,
) -> tuple[paddle.Tensor | None, ...]:
    """
    This is a helper function to:
    1. compute the gradients for the stage inputs, and
    2. accumulate gradients for the stage module's parameters.

    Given the input value(s) and the corresponding gradient for the output
    value(s), compute and accumulate gradients for all parameter values
    (leaves in the autograd trace), and return a list of the gradients for
    the input values.
    """
    try:
        # stage_output may be a composite datatype like dict. Extract all
        # individual tensor values here.
        stage_output_tensors: list[paddle.Tensor] = []
        output_grad_tensors: list[paddle.Tensor | None] = []

        def extract_tensors_with_grads(
            output_val,
            grad_val,
            extract_tensors_with_grads,
        ):
            if isinstance(output_val, paddle.Tensor):
                if output_val.stop_gradient and output_val.grad_fn is None:
                    return
                assert isinstance(
                    grad_val, (paddle.Tensor, type(None))
                ), f"Expected Tensor or None gradient but got {type(grad_val)}"
                stage_output_tensors.append(output_val)
                output_grad_tensors.append(grad_val)
            elif isinstance(output_val, (tuple, list)):
                if grad_val is None:
                    return
                assert isinstance(
                    grad_val, (tuple, list)
                ), f"grad_val expected to have type {type(output_val)} but got {type(grad_val)}"
                assert len(output_val) == len(grad_val)
                for ov, gv in zip(output_val, grad_val):
                    extract_tensors_with_grads(
                        ov,
                        gv,
                        extract_tensors_with_grads,
                    )
            elif isinstance(output_val, dict):
                if grad_val is None:
                    return
                assert isinstance(grad_val, dict)
                assert set(output_val.keys()) == set(grad_val.keys())
                for k in output_val.keys():
                    extract_tensors_with_grads(
                        output_val[k], grad_val[k], extract_tensors_with_grads
                    )
            else:
                # Output is a non-tensor type; just ignore it
                pass

        # Note: ref cycle
        # Break a ref cycle that would keep tensors alive until GC runs:
        # 1. extract_tensors_with_grads refers to a cell that holds refs to any
        #    vars defined in stage_backward and used in extract_tensors_with_grads
        # 2. extract_tensors_with_grads referred to stage_output_tensors,
        #    output_grad_tensors, and to itself (extract_tensors_with_grads)
        #    since it makes a recursive call
        # 3. stage_output_tensors was kept alive by the above ref cycle, and it
        #    holds activation tensors, which is bad
        # Fix: explicitly pass the function a reference to itself, so there is
        # no gc cycle anymore.
        extract_tensors_with_grads(
            stage_output, output_grads, extract_tensors_with_grads
        )
        paddle.autograd.backward(
            stage_output_tensors, grad_tensors=output_grad_tensors  # type: ignore[arg-type]
        )

        # Extract gradients wrt the input values
        grad_inputs: list[paddle.Tensor | None] = []
        for val in input_values:
            if isinstance(val, paddle.Tensor):
                grad_inputs.append(val.grad)
            else:
                grad_inputs.append(None)

    except Exception as e:
        exc_msg = f"""
        Failed to run stage backward:
        Stage output: {map_debug_info(stage_output)}
        Output gradient: {map_debug_info(output_grads)}
        Input: {map_debug_info(input_values)}
        """
        raise RuntimeError(exc_msg) from e

    return tuple(grad_inputs)
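For orientation, here is a minimal sketch of how the stage_backward above might be driven for one pipeline stage in dygraph mode. The Linear stage, the shapes, and the upstream gradient of ones are hypothetical stand-ins, not part of the commit:

import paddle

# Hypothetical stand-in for one pipeline stage's module.
stage_module = paddle.nn.Linear(4, 4)

# A stage input whose gradient must be sent back to the previous stage.
x = paddle.randn([2, 4])
x.stop_gradient = False

out = stage_module(x)

# A non-last stage would receive this gradient from the next stage;
# we fake it with ones here.
upstream_grad = paddle.ones_like(out)

grad_inputs = stage_backward(
    stage_output=(out,),
    output_grads=(upstream_grad,),
    input_values=[x],
)
# grad_inputs[0] is dL/dx, ready for the previous stage; stage_module's
# parameters have had their .grad accumulated as a side effect.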

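The "Note: ref cycle" comment is why extract_tensors_with_grads receives a reference to itself instead of calling itself by name. A standalone sketch of the two patterns (illustrative only, not from the commit):

def leaky():
    big = [0] * 1_000_000  # stands in for activation tensors

    def recurse(n):
        if n:
            recurse(n - 1)  # self-reference through the enclosing cell
        return len(big)

    # recurse.__closure__ holds a cell that points back at recurse, so the
    # function, its closure, and big form a cycle that only the cycle
    # collector (gc.collect) can reclaim.
    return recurse(3)

def safe():
    big = [0] * 1_000_000

    def recurse(n, recurse_fn):
        if n:
            recurse_fn(n - 1, recurse_fn)  # self-reference is a parameter
        return len(big)

    # No cell points back at recurse, so plain refcounting frees big as
    # soon as safe returns.
    return recurse(3, recurse)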
0 commit comments