
Commit f728f15

xuxinyi389 and zty-king authored
[AutoParallel] PipelineStage (#72155)
* PipelineStage
* poolish_code
* Fix several code issues in PipelineStage and submit the related unit tests
* Improve map_structure_only and add tests for some exceptional cases
* Standardize the names of the utility functions
* Increase the number of training epochs, fix a typo, add annotations
* support_auto_dp
* codestyle
* fix_multi_stages_inference_backward
* fix_prepare_bwd
* Add a conditional check to flatten_args and exercise it in the unit tests
* Fix a bug in the bias grad of the last stage, add a unit test for layers with a bias, and force stop_grad=True on input variables
* Match the unit tests to fix_prepare_bwd

---------

Co-authored-by: zty-king <[email protected]>
1 parent 0de3f8e commit f728f15

File tree

7 files changed: +2141 -1 lines changed

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = []
Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from collections.abc import Iterator

import paddle

from .utils import _map_debug_info

logger = logging.getLogger(__name__)


def stage_backward_input(
    stage_outputs_or_loss: list[paddle.Tensor],
    output_grads: list[paddle.Tensor] | None,
    input_values: list[paddle.Tensor],
    weights: Iterator[paddle.Tensor],
) -> tuple[tuple[paddle.Tensor | None, ...], list[dict[str, Any]]]:
    raise NotImplementedError("stage_backward_input is not implemented yet")


def stage_backward_weight(
    weights: Iterator[paddle.Tensor],
    param_groups: list[dict[str, Any]],
    retain_graph=False,
) -> tuple[paddle.Tensor | None, ...]:
    raise NotImplementedError("stage_backward_weight is not implemented yet")


def stage_backward(
    stage_output,
    output_grads,
    input_values,
) -> tuple[paddle.Tensor | None, ...]:
    """
    This is a helper function to:
    1. compute the gradients for the stage inputs, and
    2. accumulate gradients for the stage module's parameters.

    Given the input value(s) and the corresponding gradient for the output
    value(s), compute and accumulate gradients for all parameter values
    (leaves in the autograd trace), and return a list of the gradients for
    the input values.
    """
    try:
        # stage_output may be a composite datatype like dict. Extract all
        # individual tensor values here.
        stage_output_tensors: list[paddle.Tensor] = []
        output_grad_tensors: list[paddle.Tensor | None] = []

        def extract_tensors_with_grads(
            output_val,
            grad_val,
            extract_tensors_with_grads,
        ):
            if isinstance(output_val, paddle.Tensor):
                if output_val.stop_gradient and output_val.grad_fn is None:
                    return
                assert isinstance(
                    grad_val, (paddle.Tensor, type(None))
                ), f"Expected Tensor or None gradient but got {type(grad_val)}"
                stage_output_tensors.append(output_val)
                output_grad_tensors.append(grad_val)
            elif isinstance(output_val, (tuple, list)):
                if grad_val is None:
                    return
                assert isinstance(
                    grad_val, (tuple, list)
                ), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
                assert len(output_val) == len(grad_val)
                for ov, gv in zip(output_val, grad_val):
                    extract_tensors_with_grads(
                        ov,
                        gv,
                        extract_tensors_with_grads,
                    )
            elif isinstance(output_val, dict):
                if grad_val is None:
                    return
                assert isinstance(grad_val, dict)
                assert set(output_val.keys()) == set(grad_val.keys())
                for k in output_val.keys():
                    extract_tensors_with_grads(
                        output_val[k], grad_val[k], extract_tensors_with_grads
                    )
            else:
                # Output is a non-tensor type; just ignore it
                pass

        # Note: ref cycle
        # Break a ref cycle that would keep tensors alive until GC runs:
        # 1. extract_tensors_with_grads refers to a cell that holds refs to
        #    any vars defined in stage_backward and used in
        #    extract_tensors_with_grads.
        # 2. extract_tensors_with_grads referred to stage_output_tensors,
        #    output_grad_tensors, and to itself (extract_tensors_with_grads)
        #    since it makes a recursive call.
        # 3. stage_output_tensors was kept alive by the above ref cycle, and
        #    it holds activation tensors, which is bad.
        # Fix: explicitly pass the ref to the fn as an argument, so there is
        # no GC cycle anymore.
        extract_tensors_with_grads(
            stage_output, output_grads, extract_tensors_with_grads
        )

        # Deactivate the auto mixed precision context in the backward phase
        with paddle.amp.auto_cast(enable=False):
            paddle.autograd.backward(
                stage_output_tensors, grad_tensors=output_grad_tensors  # type: ignore[arg-type]
            )

        # Extract gradients wrt the input values
        grad_inputs: list[paddle.Tensor | None] = []
        for val in input_values:
            if isinstance(val, paddle.Tensor):
                grad_inputs.append(val.grad)
            else:
                grad_inputs.append(None)

    except Exception as e:
        exc_msg = f"""
        Failed to run stage backward:
        Stage output: {_map_debug_info(stage_output)}
        Output gradient: {_map_debug_info(output_grads)}
        Input: {_map_debug_info(input_values)}
        """
        raise RuntimeError(exc_msg) from e

    return tuple(grad_inputs)
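For orientation, here is a minimal sketch of how stage_backward might be driven for a single stage. The toy Linear layer, the synthetic input, and the manually built output gradient are illustrative assumptions; only stage_backward's signature comes from the diff above.

# Minimal usage sketch (assumed setup; only stage_backward is from the diff).
import paddle

layer = paddle.nn.Linear(4, 4)
x = paddle.randn([2, 4])
x.stop_gradient = False  # the stage input itself needs a gradient

out = layer(x)  # forward pass of this stage

# Gradient arriving from the next stage; ones_like mimics loss = out.sum()
grad_from_next = paddle.ones_like(out)

grad_inputs = stage_backward(
    stage_output=(out,),
    output_grads=(grad_from_next,),
    input_values=[x],
)
# grad_inputs[0] holds dL/dx, to be sent to the previous stage;
# layer.weight.grad and layer.bias.grad were accumulated as a side effect.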

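The "Note: ref cycle" comment in the diff describes a pattern worth seeing in isolation: a recursive inner function that names itself closes over its own cell, creating a reference cycle that pins everything else the closure captures (here, activation tensors) until the cycle collector runs. The sketch below, with hypothetical names not taken from the diff, contrasts the leaky form with the fixed one.

# Standalone illustration of the ref-cycle note (hypothetical helpers,
# not part of the diff).

def leaky():
    collected = []  # stands in for stage_output_tensors (activations)

    def walk(val):
        if isinstance(val, list):
            for v in val:
                walk(v)  # names `walk` via its own closure cell:
                         # walk -> cell -> walk is a reference cycle
        else:
            collected.append(val)

    walk([1, [2, 3]])
    # the cycle also pins `collected`, so the "activations" survive
    # after leaky() returns, until gc.collect() runs


def fixed():
    collected = []

    def walk(val, walk_fn):
        if isinstance(val, list):
            for v in val:
                walk_fn(v, walk_fn)  # recurse via the parameter only;
                                     # no self-referential closure cell
        else:
            collected.append(val)

    walk([1, [2, 3]], walk)
    # plain refcounting frees `collected` as soon as fixed() returns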