Merged

26 commits
8dfce3a
add unary ops which have spmd_rule but are not added in yaml file.
Glencsa Apr 10, 2025
1d129c2
Merge branch 'spmd_test' into develop
Glencsa Apr 15, 2025
b9c9e6a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa Apr 15, 2025
f24c883
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa Apr 16, 2025
746356c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa Apr 16, 2025
efc91c5
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa Apr 23, 2025
2109cf9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa May 8, 2025
773fda6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa May 22, 2025
7681189
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa May 23, 2025
f53affa
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa May 26, 2025
44b3797
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa Jun 19, 2025
443eab8
add spmd_rule for index_put.
Glencsa Jun 20, 2025
3057b25
remove annotations.
Glencsa Jun 20, 2025
d95de1e
fix bug.
Glencsa Jun 20, 2025
cbb93a4
fix bug.
Glencsa Jun 20, 2025
ae4dd7d
fix bug.
Glencsa Jun 20, 2025
1f7adcf
fix bug.
Glencsa Jun 20, 2025
0b802cf
fix CI bug.
Glencsa Jun 21, 2025
fc3ab3f
fix ci bug.
Glencsa Jun 21, 2025
b564e79
Adapt the changes in PR#73233
Glencsa Jun 21, 2025
26e6c68
fix ci bug.
Glencsa Jun 21, 2025
0f91fe4
fix ci bug.
Glencsa Jun 22, 2025
8fe21f2
apply review.
Glencsa Jun 23, 2025
2b9e9b3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa Jun 23, 2025
6c8e49e
apply review.
Glencsa Jun 30, 2025
c144691
Merge branch 'index_put' of https://github.com/Glencsa/Paddle into in…
Glencsa Jun 30, 2025
260 changes: 260 additions & 0 deletions paddle/phi/infermeta/spmd_rules/index_put.cc
@@ -0,0 +1,260 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/infermeta/spmd_rules/index_put.h"

#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace phi::distributed {
SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x,
                           const std::vector<DistMetaTensor>& indices,
                           const DistMetaTensor& value,
                           const bool accumulate) {
  // Step0: verify input args
  auto x_shape = common::vectorize(x.dims());
  int indices_size = static_cast<int>(indices.size());
  auto indices_shape = common::vectorize(indices[0].dims());
  auto value_shape = common::vectorize(value.dims());
  int x_ndim = static_cast<int>(x_shape.size());
  int indices_ndim = static_cast<int>(indices_shape.size());
  int value_ndim = static_cast<int>(value_shape.size());

  TensorDistAttr x_dist_attr_src = x.dist_attr();
  std::vector<TensorDistAttr> indices_dist_attrs_src;
  std::transform(indices.begin(),
                 indices.end(),
                 std::back_inserter(indices_dist_attrs_src),
                 [](auto& meta) { return meta.dist_attr(); });
  TensorDistAttr value_dist_attr_src = value.dist_attr();

  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();

  PADDLE_ENFORCE_GE(x_ndim,
                    indices_size,
                    common::errors::InvalidArgument(
                        "The ndim of x in index_put should be "
                        "greater than or equal to the size of indices, "
                        "but got x_ndim: [%d], indices_size: [%d].",
                        x_ndim,
                        indices_size));

  PADDLE_ENFORCE_LE(
      value_ndim,
      x_ndim - indices_size + 1,
      common::errors::InvalidArgument("The ndim of value in index_put should "
                                      "be less than or equal to [%d], "
                                      "but got value_ndim: [%d].",
                                      x_ndim - indices_size + 1,
                                      value_ndim));
  PADDLE_ENFORCE_EQ(
      indices_ndim,
      1,
      common::errors::InvalidArgument(
          "The ndim of indices in index_put should be equal to 1, "
          "but got indices_ndim: [%d].",
          indices_ndim));
  for (int i = 0; i < indices_size; i++) {
    PADDLE_ENFORCE_EQ(
        indices[i].dims().size(),
        1,
        common::errors::InvalidArgument(
            "The ndim of indices[%d] in index_put should be equal to 1, "
            "but got indices[%d] ndim: [%d].",
            i,
            i,
            indices[i].dims().size()));
  }
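
  // Name each dim of x with a distinct letter (einsum-like notation); value
  // aligns with the trailing axes of x starting at dim indices_size - 1,
  // whose leading axis is the broadcast dim produced by the index tensors.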
std::string alphabet = "ijklmnopqrstuvwxyz";
std::string x_axes(x_ndim, '1');
for (int i = 0; i < x_ndim; ++i) {
x_axes[i] = alphabet[i];
}
std::string value_axes(value_ndim, '1');
int index = indices_size - 1;
for (int i = 0; i < value_ndim; ++i) {
value_axes[i] = x_axes[index++];
}

// Step1: set dims_mapping for input
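  // The dims covered by the 1-D index tensors must be replicated (-1), since
  // writes along them land at data-dependent positions.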
  for (int i = 0; i < indices_size; i++) {
    x_dims_mapping[i] = -1;
  }
  std::unordered_map<std::string, int64_t> axis_to_dim_map =
      ShardingMergeForTensors({{x_axes, x_dims_mapping}});
  // Step2: set dims_mapping for output
  TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
  out_dist_attr.set_dims_mapping(x_dims_mapping);
  // Step3: update input dims mapping
  TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_dist_attr_dst.set_dims_mapping(x_dims_mapping);
  TensorDistAttr value_dist_attr_dst =
      CopyTensorDistAttrForOutput(value.dist_attr());
  value_dist_attr_dst.set_dims_mapping(
      GetDimsMappingForAxes(value_axes, axis_to_dim_map));
  std::vector<TensorDistAttr> indices_dist_attrs_dst = indices_dist_attrs_src;
  for (auto& input_attr : indices_dist_attrs_dst) {
    input_attr.set_dims_mapping(std::vector<int64_t>{-1});
  }
  // Step4: Log SpmdInfo
  LOG_SPMD_INPUT(x);
  VLOG(4) << "name: indices";
  VLOG(4) << "ndim: " << std::to_string(indices_ndim) << " "
          << "indices_size: " << std::to_string(indices_size) << " "
          << "indices_dist_attr_src: [" << indices_dist_attrs_src[0].to_string()
          << "] "
          << "indices_dist_attr_dst: [" << indices_dist_attrs_dst[0].to_string()
          << "]";

  LOG_SPMD_INPUT(value);
  LOG_SPMD_OUTPUT(out_dist_attr);

  return {{x_dist_attr_dst, indices_dist_attrs_dst, value_dist_attr_dst},
          {out_dist_attr}};
}

SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x,
                               const std::vector<DistMetaTensor>& indices,
                               const DistMetaTensor& value,
                               const DistMetaTensor& out_grad,
                               const bool accumulate) {
  // Step0: verify input args
  auto x_shape = common::vectorize(x.dims());
  int indices_size = static_cast<int>(indices.size());
  auto indices_shape = common::vectorize(indices[0].dims());
  auto value_shape = common::vectorize(value.dims());
  auto out_grad_shape = common::vectorize(out_grad.dims());
  int x_ndim = static_cast<int>(x_shape.size());
  int indices_ndim = static_cast<int>(indices_shape.size());
  int value_ndim = static_cast<int>(value_shape.size());
  int out_grad_ndim = static_cast<int>(out_grad_shape.size());
  TensorDistAttr x_dist_attr_src = x.dist_attr();
  std::vector<TensorDistAttr> indices_dist_attrs_src;
  std::transform(indices.begin(),
                 indices.end(),
                 std::back_inserter(indices_dist_attrs_src),
                 [](auto& meta) { return meta.dist_attr(); });
  TensorDistAttr value_dist_attr_src = value.dist_attr();
  TensorDistAttr out_grad_dist_attr_src = out_grad.dist_attr();
  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
  PADDLE_ENFORCE_EQ(
      out_grad_ndim,
      x_ndim,
      common::errors::InvalidArgument(
          "The ndim of out_grad in index_put_grad should be equal to the "
          "ndim of x, but got out_grad_ndim: [%d], x_ndim: [%d].",
          out_grad_ndim,
          x_ndim));
  PADDLE_ENFORCE_GE(x_ndim,
                    indices_size,
                    common::errors::InvalidArgument(
                        "The ndim of x in index_put_grad should be "
                        "greater than or equal to the size of indices, "
                        "but got x_ndim: [%d], indices_size: [%d].",
                        x_ndim,
                        indices_size));

  PADDLE_ENFORCE_LE(
      value_ndim,
      x_ndim - indices_size + 1,
      common::errors::InvalidArgument(
          "The ndim of value in index_put_grad should "
          "be less than or equal to [%d], "
          "but got value_ndim: [%d].",
          x_ndim - indices_size + 1,
          value_ndim));
  PADDLE_ENFORCE_EQ(
      indices_ndim,
      1,
      common::errors::InvalidArgument(
          "The ndim of indices in index_put_grad should be equal to 1, "
          "but got indices_ndim: [%d].",
          indices_ndim));
  for (int i = 0; i < indices_size; i++) {
    PADDLE_ENFORCE_EQ(
        indices[i].dims().size(),
        1,
        common::errors::InvalidArgument(
            "The ndim of indices[%d] in index_put_grad should be equal to 1, "
            "but got indices[%d] ndim: [%d].",
            i,
            i,
            indices[i].dims().size()));
  }
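
  // Same einsum-like axis naming as in IndexPutInferSpmd above.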
std::string alphabet = "ijklmnopqrstuvwxyz";
std::string x_axes(x_ndim, '1');
for (int i = 0; i < x_ndim; ++i) {
x_axes[i] = alphabet[i];
}
std::string value_axes(value_ndim, '1');
int index = indices_size - 1;
for (int i = 0; i < value_ndim; ++i) {
value_axes[i] = x_axes[index++];
}
// Step1: set x_dims_mapping
for (int i = 0; i < indices_size; i++) {
x_dims_mapping[i] = -1;
}
  std::unordered_map<std::string, int64_t> axis_to_dim_map =
      ShardingMergeForTensors({{x_axes, x_dims_mapping}});
  // Step2: set dims_mapping for output
  TensorDistAttr x_grad_dist_attr =
      CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_grad_dist_attr.set_dims_mapping(x_dims_mapping);
  TensorDistAttr value_grad_dist_attr =
      CopyTensorDistAttrForOutput(value_dist_attr_src);
  value_grad_dist_attr.set_dims_mapping(
      GetDimsMappingForAxes(value_axes, axis_to_dim_map));
  // Step3: update input dims mapping
  TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  x_dist_attr_dst.set_dims_mapping(x_dims_mapping);
  TensorDistAttr out_grad_dist_attr_dst =
      CopyTensorDistAttrForOutput(x_dist_attr_src);
  out_grad_dist_attr_dst.set_dims_mapping(x_dims_mapping);
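  // x_grad and out_grad both follow x's merged mapping; value and value_grad
  // follow the axes that value shares with x.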
  TensorDistAttr value_dist_attr_dst =
      CopyTensorDistAttrForOutput(value.dist_attr());
  value_dist_attr_dst.set_dims_mapping(
      GetDimsMappingForAxes(value_axes, axis_to_dim_map));
  std::vector<TensorDistAttr> indices_dist_attrs_dst = indices_dist_attrs_src;
  for (auto& input_attr : indices_dist_attrs_dst) {
    input_attr.set_dims_mapping(std::vector<int64_t>{-1});
  }
  // Step4: Log SpmdInfo
  LOG_SPMD_INPUT(x);
  VLOG(4) << "name: indices";
  VLOG(4) << "ndim: " << std::to_string(indices_ndim) << " "
          << "indices_size: " << std::to_string(indices_size) << " "
          << "indices_dist_attr_src: [" << indices_dist_attrs_src[0].to_string()
          << "] "
          << "indices_dist_attr_dst: [" << indices_dist_attrs_dst[0].to_string()
          << "]";

  LOG_SPMD_INPUT(value);
  LOG_SPMD_INPUT(out_grad);
  LOG_SPMD_OUTPUT(x_grad_dist_attr);
  LOG_SPMD_OUTPUT(value_grad_dist_attr);

  return {{x_dist_attr_dst,
           indices_dist_attrs_dst,
           value_dist_attr_dst,
           out_grad_dist_attr_dst},
          {x_grad_dist_attr, value_grad_dist_attr}};
}

} // namespace phi::distributed
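
To make the axis alignment above concrete, here is a minimal standalone sketch (illustration only, not part of this PR) of the same naming scheme; the file name and example shapes are hypothetical, and the loops mirror the ones in IndexPutInferSpmd.

// axes_demo.cc -- hypothetical illustration of the axis alignment used above.
#include <iostream>
#include <string>

int main() {
  const int x_ndim = 4;        // e.g. x with shape [A, B, C, D]
  const int indices_size = 2;  // two 1-D index tensors covering dims A and B
  // the largest value_ndim the rule accepts: x_ndim - indices_size + 1
  const int value_ndim = x_ndim - indices_size + 1;

  const std::string alphabet = "ijklmnopqrstuvwxyz";
  std::string x_axes(x_ndim, '1');
  for (int i = 0; i < x_ndim; ++i) x_axes[i] = alphabet[i];

  // value shares x's axes starting at indices_size - 1; its leading axis
  // ("j" here) is the broadcast dim produced by the 1-D index tensors.
  std::string value_axes(value_ndim, '1');
  int index = indices_size - 1;
  for (int i = 0; i < value_ndim; ++i) value_axes[i] = x_axes[index++];

  std::cout << "x_axes:     " << x_axes << "\n";      // prints "ijkl"
  std::cout << "value_axes: " << value_axes << "\n";  // prints "jkl"
  return 0;
}

With this alignment, a shard on x's trailing dims (k, l) propagates to value through axis_to_dim_map, while the indexed dims (i, j) are forced to -1 (replicated).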
32 changes: 32 additions & 0 deletions paddle/phi/infermeta/spmd_rules/index_put.h
@@ -0,0 +1,32 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {
SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x,
                           const std::vector<DistMetaTensor>& indices,
                           const DistMetaTensor& value,
                           const bool accumulate = false);

SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x,
                               const std::vector<DistMetaTensor>& indices,
                               const DistMetaTensor& value,
                               const DistMetaTensor& out_grad,
                               const bool accumulate = false);
} // namespace distributed
} // namespace phi
5 changes: 4 additions & 1 deletion paddle/phi/infermeta/spmd_rules/rules.cc
@@ -544,7 +544,10 @@ PD_REGISTER_SPMD_RULE(
PD_REGISTER_SPMD_RULE(fused_rms_norm,
PD_INFER_SPMD(phi::distributed::RmsNormInferSpmd),
PD_INFER_SPMD(phi::distributed::RmsNormInferSpmdReverse));

// index_put
PD_REGISTER_SPMD_RULE(index_put,
PD_INFER_SPMD(phi::distributed::IndexPutInferSpmd),
PD_INFER_SPMD(phi::distributed::IndexPutGradInferSpmd));
PD_REGISTER_SPMD_RULE(
flash_attention,
PD_INFER_SPMD(phi::distributed::FlashAttInferSpmdStatic),
1 change: 1 addition & 0 deletions paddle/phi/infermeta/spmd_rules/rules.h
@@ -49,6 +49,7 @@ limitations under the License. */
#include "paddle/phi/infermeta/spmd_rules/gather_nd.h"
#include "paddle/phi/infermeta/spmd_rules/gelu.h"
#include "paddle/phi/infermeta/spmd_rules/group_norm.h"
#include "paddle/phi/infermeta/spmd_rules/index_put.h"
#include "paddle/phi/infermeta/spmd_rules/index_select.h"
#include "paddle/phi/infermeta/spmd_rules/instance_norm.h"
#include "paddle/phi/infermeta/spmd_rules/label_smooth.h"
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/backward.yaml
@@ -1724,6 +1724,7 @@
output : Tensor(x_grad), Tensor(value_grad)
infer_meta :
func : IndexPutGradInferMeta
spmd_rule : IndexPutGradInferSpmd
kernel :
func : index_put_grad
data_type : out_grad
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
@@ -2825,6 +2825,7 @@
output : Tensor(out)
infer_meta :
func : IndexPutInferMeta
spmd_rule : IndexPutInferSpmd
kernel :
func : index_put
data_type : x