Commit 31656c9

[Auto Parallel] Add spmd rule No.15 for index_put and index_put_grad ops. (#73486)
* Add unary ops which have a spmd_rule but were not yet added in the yaml files.
* Add spmd_rule for index_put.
* Remove annotations.
* Fix bugs and CI failures.
* Adapt the changes in PR#73233.
* Apply review comments.
1 parent 02208f4 commit 31656c9

File tree: 7 files changed, +377 -1 lines changed
paddle/phi/infermeta/spmd_rules/index_put.cc

Lines changed: 260 additions & 0 deletions

```cpp
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/infermeta/spmd_rules/index_put.h"

#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace phi::distributed {
SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x,
                           const std::vector<DistMetaTensor>& indices,
                           const DistMetaTensor& value,
                           const bool accumulate) {
  // Step0: verify input args
  auto x_shape = common::vectorize(x.dims());
  int indices_size = indices.size();
  auto indices_shape = common::vectorize(indices[0].dims());
  auto value_shape = common::vectorize(value.dims());
  int x_ndim = static_cast<int>(x_shape.size());
  int indices_ndim = static_cast<int>(indices_shape.size());
  int value_ndim = static_cast<int>(value_shape.size());

  TensorDistAttr x_dist_attr_src = x.dist_attr();
  std::vector<TensorDistAttr> indices_dist_attrs_src;
  std::transform(indices.begin(),
                 indices.end(),
                 std::back_inserter(indices_dist_attrs_src),
                 [](auto& meta) { return meta.dist_attr(); });
  TensorDistAttr value_dist_attr_src = value.dist_attr();

  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();

  PADDLE_ENFORCE_GE(x_ndim,
                    indices_size,
                    common::errors::InvalidArgument(
                        "The ndim of x in index_put should be "
                        "greater than or equal to the size of indices, "
                        "but got x_ndim:[%d], indices_size:[%d].",
                        x_ndim,
                        indices_size));

  PADDLE_ENFORCE_LE(
      value_ndim,
      x_ndim - indices_size + 1,
      common::errors::InvalidArgument("The ndim of value in index_put should "
                                      "be less than or equal to [%d], "
                                      "but got value_ndim:[%d].",
                                      x_ndim - indices_size + 1,
                                      value_ndim));
  PADDLE_ENFORCE_EQ(
      indices_ndim,
      1,
      common::errors::InvalidArgument(
          "The ndim of indices in index_put should be equal to 1, "
          "but got indices_ndim:[%d].",
          indices_ndim));
  for (int i = 0; i < indices_size; i++) {
    PADDLE_ENFORCE_EQ(
        indices[i].dims().size(),
        1,
        common::errors::InvalidArgument(
            "The ndim of indices[%d] in index_put should be equal to 1, "
            "but got indices[%d] ndim:[%d].",
            i,
            i,
            indices[i].dims().size()));
  }
  std::string alphabet = "ijklmnopqrstuvwxyz";
  std::string x_axes(x_ndim, '1');
  for (int i = 0; i < x_ndim; ++i) {
    x_axes[i] = alphabet[i];
  }
  std::string value_axes(value_ndim, '1');
  int index = indices_size - 1;
  for (int i = 0; i < value_ndim; ++i) {
    value_axes[i] = x_axes[index++];
  }

  // Step1: set dims_mapping for input
  for (int i = 0; i < indices_size; i++) {
    x_dims_mapping[i] = -1;
  }
  std::unordered_map<std::string, int64_t> axis_to_dim_map =
      ShardingMergeForTensors({{x_axes, x_dims_mapping}});
  // Step2: set dims_mapping for output
  TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
  out_dist_attr.set_dims_mapping(x_dims_mapping);
  // Step3: update input dims mapping
  TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_dist_attr_dst.set_dims_mapping(x_dims_mapping);
  TensorDistAttr value_dist_attr_dst =
      CopyTensorDistAttrForOutput(value.dist_attr());
  value_dist_attr_dst.set_dims_mapping(
      GetDimsMappingForAxes(value_axes, axis_to_dim_map));
  std::vector<TensorDistAttr> indices_dist_attrs_dst = indices_dist_attrs_src;
  for (auto& input_attr : indices_dist_attrs_dst) {
    input_attr.set_dims_mapping(std::vector<int64_t>{-1});
  }
  // Step4: Log SpmdInfo
  LOG_SPMD_INPUT(x);
  // LOG_SPMD_INPUT(indices);
  VLOG(4) << "name: indices";
  VLOG(4) << "ndim: " << std::to_string(indices_ndim) << " "
          << "indices_size: " << std::to_string(indices_size) << " "
          << "indices_dist_attr_src: [" << indices_dist_attrs_src[0].to_string()
          << "] "
          << "indices_dist_attr_dst: [" << indices_dist_attrs_dst[0].to_string()
          << "]";

  LOG_SPMD_INPUT(value);
  LOG_SPMD_OUTPUT(out_dist_attr);

  return {{x_dist_attr_dst, indices_dist_attrs_dst, value_dist_attr_dst},
          {out_dist_attr}};
}

SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x,
                               const std::vector<DistMetaTensor>& indices,
                               const DistMetaTensor& value,
                               const DistMetaTensor& out_grad,
                               const bool accumulate) {
  // Step0: verify input args
  auto x_shape = common::vectorize(x.dims());
  int indices_size = indices.size();
  auto indices_shape = common::vectorize(indices[0].dims());
  auto value_shape = common::vectorize(value.dims());
  auto out_grad_shape = common::vectorize(out_grad.dims());
  int x_ndim = static_cast<int>(x_shape.size());
  int indices_ndim = static_cast<int>(indices_shape.size());
  int value_ndim = static_cast<int>(value_shape.size());
  int out_grad_ndim = static_cast<int>(out_grad_shape.size());
  TensorDistAttr x_dist_attr_src = x.dist_attr();
  std::vector<TensorDistAttr> indices_dist_attrs_src;
  std::transform(indices.begin(),
                 indices.end(),
                 std::back_inserter(indices_dist_attrs_src),
                 [](auto& meta) { return meta.dist_attr(); });
  TensorDistAttr value_dist_attr_src = value.dist_attr();
  TensorDistAttr out_grad_dist_attr_src = out_grad.dist_attr();
  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
  PADDLE_ENFORCE_EQ(
      out_grad_ndim,
      x_ndim,
      common::errors::InvalidArgument(
          "The ndim of out_grad in index_put_grad should be equal to the "
          "ndim of x, but got out_grad_ndim:[%d], x_ndim:[%d].",
          out_grad_ndim,
          x_ndim));
  PADDLE_ENFORCE_GE(x_ndim,
                    indices_size,
                    common::errors::InvalidArgument(
                        "The ndim of x in index_put should be "
                        "greater than or equal to the size of indices, "
                        "but got x_ndim:[%d], indices_size:[%d].",
                        x_ndim,
                        indices_size));

  PADDLE_ENFORCE_LE(
      value_ndim,
      x_ndim - indices_size + 1,
      common::errors::InvalidArgument("The ndim of value in index_put should "
                                      "be less than or equal to [%d], "
                                      "but got value_ndim:[%d].",
                                      x_ndim - indices_size + 1,
                                      value_ndim));
  PADDLE_ENFORCE_EQ(
      indices_ndim,
      1,
      common::errors::InvalidArgument(
          "The ndim of indices in index_put should be equal to 1, "
          "but got indices_ndim:[%d].",
          indices_ndim));
  for (int i = 0; i < indices_size; i++) {
    PADDLE_ENFORCE_EQ(
        indices[i].dims().size(),
        1,
        common::errors::InvalidArgument(
            "The ndim of indices[%d] in index_put should be equal to 1, "
            "but got indices[%d] ndim:[%d].",
            i,
            i,
            indices[i].dims().size()));
  }
  std::string alphabet = "ijklmnopqrstuvwxyz";
  std::string x_axes(x_ndim, '1');
  for (int i = 0; i < x_ndim; ++i) {
    x_axes[i] = alphabet[i];
  }
  std::string value_axes(value_ndim, '1');
  int index = indices_size - 1;
  for (int i = 0; i < value_ndim; ++i) {
    value_axes[i] = x_axes[index++];
  }
  // Step1: set x_dims_mapping
  for (int i = 0; i < indices_size; i++) {
    x_dims_mapping[i] = -1;
  }
  std::unordered_map<std::string, int64_t> axis_to_dim_map =
      ShardingMergeForTensors({{x_axes, x_dims_mapping}});
  // Step2: set dims_mapping for output
  TensorDistAttr x_grad_dist_attr =
      CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_grad_dist_attr.set_dims_mapping(x_dims_mapping);
  TensorDistAttr value_grad_dist_attr =
      CopyTensorDistAttrForOutput(value_dist_attr_src);
  value_grad_dist_attr.set_dims_mapping(
      GetDimsMappingForAxes(value_axes, axis_to_dim_map));
  // Step3: update input dims mapping
  TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_dist_attr_dst.set_dims_mapping(x_dims_mapping);
  TensorDistAttr out_grad_dist_attr_dst =
      CopyTensorDistAttrForOutput(x_dist_attr_src);
  out_grad_dist_attr_dst.set_dims_mapping(x_dims_mapping);
  TensorDistAttr value_dist_attr_dst =
      CopyTensorDistAttrForOutput(value.dist_attr());
  value_dist_attr_dst.set_dims_mapping(
      GetDimsMappingForAxes(value_axes, axis_to_dim_map));
  std::vector<TensorDistAttr> indices_dist_attrs_dst = indices_dist_attrs_src;
  for (auto& input_attr : indices_dist_attrs_dst) {
    input_attr.set_dims_mapping(std::vector<int64_t>{-1});
  }
  // Step4: Log SpmdInfo
  LOG_SPMD_INPUT(x);
  // LOG_SPMD_INPUT(indices);
  VLOG(4) << "name: indices";
  VLOG(4) << "ndim: " << std::to_string(indices_ndim) << " "
          << "indices_size: " << std::to_string(indices_size) << " "
          << "indices_dist_attr_src: [" << indices_dist_attrs_src[0].to_string()
          << "] "
          << "indices_dist_attr_dst: [" << indices_dist_attrs_dst[0].to_string()
          << "]";

  LOG_SPMD_INPUT(value);
  LOG_SPMD_INPUT(out_grad);
  LOG_SPMD_OUTPUT(x_grad_dist_attr);
  LOG_SPMD_OUTPUT(value_grad_dist_attr);

  return {{x_dist_attr_dst,
           indices_dist_attrs_dst,
           value_dist_attr_dst,
           out_grad_dist_attr_dst},
          {x_grad_dist_attr, value_grad_dist_attr}};
}

}  // namespace phi::distributed
```
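Both rules drive their sharding decisions off the same einsum-style bookkeeping: one letter of `x_axes` per dimension of `x`, `value_axes` aligned with `x_axes` starting at the last indexed dimension, and every indexed dimension of `x` forced to replicated (`-1`). The standalone sketch below (illustration only, not part of the commit; the shapes and dims_mapping values are made up) replays that bookkeeping for a rank-4 `x` with two 1-D index tensors:

```cpp
// Standalone replay of the axis bookkeeping in IndexPutInferSpmd.
// The example shapes/mappings are assumptions chosen for illustration.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // x is rank 4 with dims_mapping {0, -1, -1, 1} (dims 0 and 3 sharded),
  // there are two 1-D index tensors, and value is rank 3.
  const int x_ndim = 4, indices_size = 2, value_ndim = 3;

  // One letter per x dimension, exactly as the rule does: "ijkl".
  const std::string alphabet = "ijklmnopqrstuvwxyz";
  std::string x_axes(x_ndim, '1');
  for (int i = 0; i < x_ndim; ++i) x_axes[i] = alphabet[i];

  // value aligns with x starting at the last indexed dimension, giving
  // "jkl": the gathered rows plus the trailing (un-indexed) dims of x.
  std::string value_axes(value_ndim, '1');
  int index = indices_size - 1;
  for (int i = 0; i < value_ndim; ++i) value_axes[i] = x_axes[index++];

  // The indexed leading dims of x are forced to replicate (-1); trailing
  // dims keep their sharding, so value inherits x's sharding on axis "l".
  std::vector<int64_t> x_dims_mapping = {0, -1, -1, 1};
  for (int i = 0; i < indices_size; ++i) x_dims_mapping[i] = -1;

  std::cout << "x_axes=" << x_axes << "  value_axes=" << value_axes
            << "  x_dims_mapping=";
  for (int64_t d : x_dims_mapping) std::cout << d << ' ';
  std::cout << '\n';  // x_axes=ijkl  value_axes=jkl  x_dims_mapping=-1 -1 -1 1
}
```

This also explains the shape constraints enforced above: the broadcast 1-D index tensors collapse the indexed dims of `x` into a single gathered dim, so `value_axes` starts at `x_axes[indices_size - 1]` and `value` can have at most `x_ndim - indices_size + 1` dimensions.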
paddle/phi/infermeta/spmd_rules/index_put.h

Lines changed: 32 additions & 0 deletions

```cpp
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {
SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x,
                           const std::vector<DistMetaTensor>& indices,
                           const DistMetaTensor& value,
                           const bool accumulate = false);
SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x,
                               const std::vector<DistMetaTensor>& indices,
                               const DistMetaTensor& value,
                               const DistMetaTensor& out_grad,
                               const bool accumulate = false);
}  // namespace distributed
}  // namespace phi
```
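These entry points can be exercised directly, which is how Paddle's C++ spmd-rule unit tests typically drive such rules. The sketch below follows that pattern for a single 1-D index tensor; the `ProcessMesh`, `TensorDistAttr`, and `DistMetaTensor` construction calls are assumptions drawn from the auto-parallel headers, not something this commit adds:

```cpp
// Hedged sketch: invoking IndexPutInferSpmd the way Paddle's C++
// spmd-rule unit tests construct inputs. Constructor and setter
// signatures below are assumptions, not part of this commit.
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/infermeta/spmd_rules/index_put.h"

void IndexPutSpmdDemo() {
  using namespace phi::distributed;

  // A 2x2 mesh over ranks 0..3 (assumed ProcessMesh signature).
  ProcessMesh mesh({2, 2}, {0, 1, 2, 3}, {"dp", "mp"});

  // x: [8, 16], sharded along its last dim on mesh axis 0; dim 0 is
  // indexed, so the rule will keep/force it replicated.
  TensorDistAttr x_attr;
  x_attr.set_process_mesh(mesh);
  x_attr.set_dims_mapping({-1, 0});
  DistMetaTensor x(common::make_ddim({8, 16}), x_attr);

  // One 1-D index tensor; the rule replicates indices regardless.
  TensorDistAttr idx_attr;
  idx_attr.set_process_mesh(mesh);
  idx_attr.set_dims_mapping({-1});
  DistMetaTensor idx(common::make_ddim({4}), idx_attr);

  // value: [4, 16]; its axes align with x starting at the last indexed dim.
  TensorDistAttr value_attr;
  value_attr.set_process_mesh(mesh);
  value_attr.set_dims_mapping({-1, -1});
  DistMetaTensor value(common::make_ddim({4, 16}), value_attr);

  // info.first holds the dist_attrs the inputs should be resharded to;
  // info.second holds the inferred dist_attr of out.
  SpmdInfo info = IndexPutInferSpmd(x, {idx}, value, /*accumulate=*/false);
}
```

For these inputs the rule should propagate `x`'s sharding on the trailing dim to both `value` and `out`, while the indexed dim and all index tensors stay replicated.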

paddle/phi/infermeta/spmd_rules/rules.cc

Lines changed: 4 additions & 1 deletion

```diff
@@ -544,7 +544,10 @@ PD_REGISTER_SPMD_RULE(
 PD_REGISTER_SPMD_RULE(fused_rms_norm,
                       PD_INFER_SPMD(phi::distributed::RmsNormInferSpmd),
                       PD_INFER_SPMD(phi::distributed::RmsNormInferSpmdReverse));
-
+// index_put
+PD_REGISTER_SPMD_RULE(index_put,
+                      PD_INFER_SPMD(phi::distributed::IndexPutInferSpmd),
+                      PD_INFER_SPMD(phi::distributed::IndexPutGradInferSpmd));
 PD_REGISTER_SPMD_RULE(
     flash_attention,
     PD_INFER_SPMD(phi::distributed::FlashAttInferSpmdStatic),
```

paddle/phi/infermeta/spmd_rules/rules.h

Lines changed: 1 addition & 0 deletions

```diff
@@ -49,6 +49,7 @@ limitations under the License. */
 #include "paddle/phi/infermeta/spmd_rules/gather_nd.h"
 #include "paddle/phi/infermeta/spmd_rules/gelu.h"
 #include "paddle/phi/infermeta/spmd_rules/group_norm.h"
+#include "paddle/phi/infermeta/spmd_rules/index_put.h"
 #include "paddle/phi/infermeta/spmd_rules/index_select.h"
 #include "paddle/phi/infermeta/spmd_rules/instance_norm.h"
 #include "paddle/phi/infermeta/spmd_rules/label_smooth.h"
```

paddle/phi/ops/yaml/backward.yaml

Lines changed: 1 addition & 0 deletions

```diff
@@ -1725,6 +1725,7 @@
   output : Tensor(x_grad), Tensor(value_grad)
   infer_meta :
     func : IndexPutGradInferMeta
+    spmd_rule : IndexPutGradInferSpmd
   kernel :
     func : index_put_grad
     data_type : out_grad
```

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 0 deletions

```diff
@@ -2825,6 +2825,7 @@
   output : Tensor(out)
   infer_meta :
     func : IndexPutInferMeta
+    spmd_rule : IndexPutInferSpmd
   kernel :
     func : index_put
     data_type : x
```
