From 8dfce3a5a7b03f109d4e657ff33d18370b462a95 Mon Sep 17 00:00:00 2001 From: Glencsa <3501406249@qq.com> Date: Thu, 10 Apr 2025 19:31:23 +0800 Subject: [PATCH 01/14] add unary ops which have spmd_rule but not add in yaml file. --- paddle/phi/ops/yaml/ops.yaml | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index f59b7d59c2d9ed..901d1265a7c4b8 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -48,6 +48,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : acos inplace: (x -> out) @@ -60,6 +61,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : acosh inplace: (x -> out) @@ -370,6 +372,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : asin inplace: (x -> out) @@ -382,6 +385,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : asinh inplace: (x -> out) @@ -433,6 +437,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : atan inplace: (x -> out) @@ -455,6 +460,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : atanh inplace: (x -> out) @@ -563,6 +569,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : bernoulli interfaces : paddle::dialect::InferSymbolicShapeInterface @@ -860,6 +867,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : ceil inplace : (x -> out) @@ -1674,6 +1682,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : erf inplace : (x 
-> out) @@ -1686,6 +1695,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : erfinv inplace : (x -> out) @@ -1736,6 +1746,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd param : [x] kernel : func : expm1 @@ -2049,6 +2060,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : floor inplace : (x -> out) @@ -2927,6 +2939,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule: ElementwiseUnaryInferSpmd kernel : func : lgamma inplace: (x -> out) @@ -2989,6 +3002,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule: ElementwiseUnaryInferSpmd kernel : func : log inplace: (x -> out) @@ -3001,6 +3015,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : log10 inplace: (x -> out) @@ -3013,6 +3028,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : log1p inplace: (x -> out) @@ -3025,6 +3041,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : log2 inplace: (x -> out) @@ -3137,6 +3154,7 @@ output : Tensor infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : logsigmoid backward : logsigmoid_grad @@ -3831,6 +3849,7 @@ output : Tensor infer_meta : func : UnchangedInferMeta + spmd_rule: ElementwiseUnaryInferSpmd kernel : func : poisson backward : poisson_grad @@ -4633,6 +4652,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : sign backward : sign_grad @@ -4669,6 +4689,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : sinh inplace: (x -> out) @@ -4726,6 +4747,7 @@ 
output : Tensor infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd param : [x] kernel : func : softsign @@ -4796,6 +4818,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : sqrt {dense -> dense}, sqrt_sr {selected_rows -> selected_rows} @@ -4946,6 +4969,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd param : [x] kernel : func : swish @@ -4994,6 +5018,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : tan inplace : (x -> out) @@ -5006,6 +5031,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : tanh inplace : (x -> out) @@ -5018,6 +5044,7 @@ output : Tensor infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : tanh_shrink backward : tanh_shrink_grad @@ -5207,6 +5234,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : trunc inplace: (input -> out) From 443eab81b85bcbc8057df796310c93656a2fd2c5 Mon Sep 17 00:00:00 2001 From: Glencsa <3501406249@qq.com> Date: Fri, 20 Jun 2025 11:12:25 +0800 Subject: [PATCH 02/14] add spmd_rule for index_put. 
--- paddle/phi/infermeta/spmd_rules/index_put.cc | 238 +++++++++++++++++++ paddle/phi/infermeta/spmd_rules/index_put.h | 32 +++ paddle/phi/infermeta/spmd_rules/rules.cc | 5 +- paddle/phi/infermeta/spmd_rules/rules.h | 1 + paddle/phi/ops/yaml/backward.yaml | 1 + paddle/phi/ops/yaml/ops.yaml | 1 + test/cpp/auto_parallel/spmd_rule_test.cc | 69 ++++++ 7 files changed, 346 insertions(+), 1 deletion(-) create mode 100644 paddle/phi/infermeta/spmd_rules/index_put.cc create mode 100644 paddle/phi/infermeta/spmd_rules/index_put.h diff --git a/paddle/phi/infermeta/spmd_rules/index_put.cc b/paddle/phi/infermeta/spmd_rules/index_put.cc new file mode 100644 index 00000000000000..32b20024bf3b7f --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/index_put.cc @@ -0,0 +1,238 @@ +/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/infermeta/spmd_rules/index_put.h" + +#include "glog/logging.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +// 当x_ndim>indices_size时,被广播的维度是可以被切的,indices内所有的tensor的dims_mapping只能是-1,value的dims_mapping也只能是-1 + +namespace phi::distributed { +SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x, + const std::vector& indices, + const DistMetaTensor& value) { + // Step0: verify input args based on group_norm logic + auto x_shape = common::vectorize(x.dims()); + int indices_size = indices.size(); + auto indices_shape = common::vectorize(indices[0].dims()); + auto value_shape = common::vectorize(value.dims()); + int x_ndim = static_cast(x_shape.size()); + int indices_ndim = static_cast(indices_shape.size()); + int value_ndim = static_cast(value_shape.size()); + + TensorDistAttr x_dist_attr_src = x.dist_attr(); + std::vector indices_dist_attrs_src; + std::transform(indices.begin(), + indices.end(), + std::back_inserter(indices_dist_attrs_src), + [](auto& meta) { return meta.dist_attr(); }); + TensorDistAttr value_dist_attr_src = value.dist_attr(); + + std::vector x_dims_mapping = x_dist_attr_src.dims_mapping(); + + PADDLE_ENFORCE_GE(x_ndim, + indices_size, + common::errors::InvalidArgument( + "The ndim of x in index_put should be " + "greater than or equal to the size of indices, " + "but got x_ndim:[%d],indices_size:[%d].", + x_ndim, + indices_size)); + + PADDLE_ENFORCE_EQ(value_ndim, + 1, + common::errors::InvalidArgument( + "The ndim of value in index_put should be equal to 1, " + "but got value_ndim:[%d].", + value_ndim)); + PADDLE_ENFORCE_EQ( + indices_ndim, + 1, + common::errors::InvalidArgument( + "The ndim of indices in index_put should be equal to 1, " + "but got 
indices_ndim:[%d].", + indices_ndim)); + for (int i = 0; i < indices_size; i++) { + PADDLE_ENFORCE_EQ( + indices[i].dims().size(), + 1, + common::errors::InvalidArgument( + "The ndim of indices[%d] in index_put should be equal to 1, " + "but got indices[%d] ndim:[%d].", + i, + i, + indices[i].dims().size())); + } + // Step1: set dims_mapping for input + if (x_ndim > indices_size) { + for (int i = 0; i < indices_size; i++) { + x_dims_mapping[i] == -1; + } + } else { // indices_size = x_ndim + for (int i = 0; i < x_ndim; i++) { + x_dims_mapping[i] = -1; + } + } + // Step2: set dims_mapping for output + TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src); + out_dist_attr.set_dims_mapping(x_dims_mapping); + // Step3: update input dims mapping + TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(x_dims_mapping); + TensorDistAttr value_dist_attr_dst = + CopyTensorDistAttrForOutput(value.dist_attr()); + value_dist_attr_dst.set_dims_mapping({-1}); + std::vector indices_dist_attrs_dst = indices_dist_attrs_src; + for (auto& input_attr : indices_dist_attrs_dst) { + input_attr.set_dims_mapping({-1}); + } + // Step4: Log SpmdInfo + LOG_SPMD_INPUT(x); + // LOG_SPMD_INPUT(indices); + VLOG(4) << "name:indices"; + VLOG(4) << "shape: [" << indices_shape << "] " + << "indices_dist_attr_src: [" << indices_dist_attrs_src[0].to_string() + << "] " + << "indices_dist_attr_dst: [" << indices_dist_attrs_dst[0].to_string() + << "]"; + + LOG_SPMD_INPUT(value); + LOG_SPMD_OUTPUT(out_dist_attr); + + return {{x_dist_attr_dst, indices_dist_attrs_dst, value_dist_attr_dst}, + {out_dist_attr}}; +} + +SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x, + const std::vector& indices, + const DistMetaTensor& value, + const DistMetaTensor& out_grad) { + // Step0: verify input args based on group_norm logic + auto x_shape = common::vectorize(x.dims()); + int indices_size = indices.size(); + auto indices_shape = 
common::vectorize(indices[0].dims()); + auto value_shape = common::vectorize(value.dims()); + auto out_grad_shape = common::vectorize(out_grad.dims()); + int x_ndim = static_cast(x_shape.size()); + int indices_ndim = static_cast(indices_shape.size()); + int value_ndim = static_cast(value_shape.size()); + int out_grad_ndim = static_cast(out_grad_shape.size()); + TensorDistAttr x_dist_attr_src = x.dist_attr(); + std::vector indices_dist_attrs_src; + std::transform(indices.begin(), + indices.end(), + std::back_inserter(indices_dist_attrs_src), + [](auto& meta) { return meta.dist_attr(); }); + TensorDistAttr value_dist_attr_src = value.dist_attr(); + TensorDistAttr out_grad_dist_attr_src = out_grad.dist_attr(); + std::vector x_dims_mapping = x_dist_attr_src.dims_mapping(); + PADDLE_ENFORCE_EQ( + out_grad_ndim, + x_ndim, + common::errors::InvalidArgument( + "The ndim of out_grad in index_put_grad should be equal to the " + "ndim of x, but got out_grad_ndim:[%d],x_ndim:[%d].", + out_grad_ndim, + x_ndim)); + PADDLE_ENFORCE_GE(x_ndim, + indices_size, + common::errors::InvalidArgument( + "The ndim of x in index_put should be " + "greater than or equal to the size of indices, " + "but got x_ndim:[%d],indices_size:[%d].", + x_ndim, + indices_size)); + + PADDLE_ENFORCE_EQ(value_ndim, + 1, + common::errors::InvalidArgument( + "The ndim of value in index_put should be equal to 1, " + "but got value_ndim:[%d].", + value_ndim)); + PADDLE_ENFORCE_EQ( + indices_ndim, + 1, + common::errors::InvalidArgument( + "The ndim of indices in index_put should be equal to 1, " + "but got indices_ndim:[%d].", + indices_ndim)); + for (int i = 0; i < indices_size; i++) { + PADDLE_ENFORCE_EQ( + indices[i].dims().size(), + 1, + common::errors::InvalidArgument( + "The ndim of indices[%d] in index_put should be equal to 1, " + "but got indices[%d] ndim:[%d].", + i, + i, + indices[i].dims().size())); + } + // Step1: set x_dims_mapping + if (x_ndim > indices_size) { + for (int i = 0; i < indices_size; 
i++) { + x_dims_mapping[i] = -1; + } + } else {  // indices_size = x_ndim + for (int i = 0; i < x_ndim; i++) { + x_dims_mapping[i] = -1; + } + } + // Step2: set dims_mapping for output + TensorDistAttr x_grad_dist_attr = + CopyTensorDistAttrForOutput(x_dist_attr_src); + x_grad_dist_attr.set_dims_mapping(x_dims_mapping); + TensorDistAttr value_grad_dist_attr = + CopyTensorDistAttrForOutput(value_dist_attr_src); + value_grad_dist_attr.set_dims_mapping({-1}); + // Step3: update input dims mapping + TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(x_dims_mapping); + TensorDistAttr out_grad_dist_attr_dst = + CopyTensorDistAttrForOutput(x_dist_attr_src); + out_grad_dist_attr_dst.set_dims_mapping(x_dims_mapping); + TensorDistAttr value_dist_attr_dst = + CopyTensorDistAttrForOutput(value.dist_attr()); + value_dist_attr_dst.set_dims_mapping({-1}); + std::vector indices_dist_attrs_dst = indices_dist_attrs_src; + for (auto& input_attr : indices_dist_attrs_dst) { + input_attr.set_dims_mapping({-1}); + } + // Step4: Log SpmdInfo + LOG_SPMD_INPUT(x); + // LOG_SPMD_INPUT(indices); + VLOG(4) << "name:indices"; + VLOG(4) << "shape: [" << indices_shape << "] " + << "indices_dist_attr_src: [" << indices_dist_attrs_src[0].to_string() + << "] " + << "indices_dist_attr_dst: [" << indices_dist_attrs_dst[0].to_string() + << "]"; + + LOG_SPMD_INPUT(value); + LOG_SPMD_INPUT(out_grad); + LOG_SPMD_OUTPUT(x_grad_dist_attr); + LOG_SPMD_OUTPUT(value_grad_dist_attr); + + return {{x_dist_attr_dst, + indices_dist_attrs_dst, + value_dist_attr_dst, + out_grad_dist_attr_dst}, + {x_grad_dist_attr, value_grad_dist_attr}}; +} + +} // namespace phi::distributed diff --git a/paddle/phi/infermeta/spmd_rules/index_put.h b/paddle/phi/infermeta/spmd_rules/index_put.h new file mode 100644 index 00000000000000..07257030915f1b --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/index_put.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2025 PaddlePaddle
Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { +SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x, + const std::vector& indices, + const DistMetaTensor& value, + bool accumulate = false); +SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x, + const std::vector& indices, + const DistMetaTensor& value, + const DistMetaTensor& out_grad, + bool accumulate = false); +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/rules.cc b/paddle/phi/infermeta/spmd_rules/rules.cc index b57c7f58e0c6a5..491631d4fe947a 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.cc +++ b/paddle/phi/infermeta/spmd_rules/rules.cc @@ -544,7 +544,10 @@ PD_REGISTER_SPMD_RULE( PD_REGISTER_SPMD_RULE(fused_rms_norm, PD_INFER_SPMD(phi::distributed::RmsNormInferSpmd), PD_INFER_SPMD(phi::distributed::RmsNormInferSpmdReverse)); - +// index_put +PD_REGISTER_SPMD_RULE(index_put, + PD_INFER_SPMD(phi::distributed::IndexPutInferSpmd), + PD_INFER_SPMD(phi::distributed::IndexPutGradInferSpmd)); PD_REGISTER_SPMD_RULE( flash_attention, PD_INFER_SPMD(phi::distributed::FlashAttInferSpmdStatic), diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index 2c3fe0a7575a6c..a34a425c3b0fa0 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ 
b/paddle/phi/infermeta/spmd_rules/rules.h @@ -49,6 +49,7 @@ limitations under the License. */ #include "paddle/phi/infermeta/spmd_rules/gather_nd.h" #include "paddle/phi/infermeta/spmd_rules/gelu.h" #include "paddle/phi/infermeta/spmd_rules/group_norm.h" +#include "paddle/phi/infermeta/spmd_rules/index_put.h" #include "paddle/phi/infermeta/spmd_rules/index_select.h" #include "paddle/phi/infermeta/spmd_rules/instance_norm.h" #include "paddle/phi/infermeta/spmd_rules/label_smooth.h" diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index abc2149bdfed38..c728f44b3367f9 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -1724,6 +1724,7 @@ output : Tensor(x_grad), Tensor(value_grad) infer_meta : func : IndexPutGradInferMeta + spmd_rule : IndexPutGradInferSpmd kernel : func : index_put_grad data_type : out_grad diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 000ab6a39f57c7..082c37f0fdab92 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -2825,6 +2825,7 @@ output : Tensor(out) infer_meta : func : IndexPutInferMeta + spmd_rule : IndexPutInferSpmd kernel : func : index_put data_type : x diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index ede5654040ec0b..5d9c9ae19cab47 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -230,6 +230,75 @@ TEST(MatmulSPMDRule, Ctor) { check_partial_dims(inferred_dist_attrs.second[0], {0}); VLOG(4) << "test11 done." 
<< std::endl << std::endl << std::endl; } + +TEST(IndexPut, Ctor) { + // build input data class + std::vector x_shape = {64, 64, 64}; + std::vector indice_shape = {32}; + std::vector value_shape = {32}; + std::vector mesh_shape = {2, 3}; + std::vector process_ids = {0, 1, 2, 3, 4, 5}; + std::vector dim_names = {"x", "y"}; + ProcessMesh process_mesh(mesh_shape, process_ids, dim_names); + + TensorDistAttr x_dist_attr = TensorDistAttr(); + x_dist_attr.set_process_mesh(process_mesh); + x_dist_attr.set_dims_mapping(std::vector({-1, 0, 1})); + x_dist_attr.set_dynamic_dims(std::vector({false, false, false, false})); + + TensorDistAttr value_dist_attr = TensorDistAttr(); + value_dist_attr.set_process_mesh(process_mesh); + value_dist_attr.set_dims_mapping(std::vector({-1})); + value_dist_attr.set_dynamic_dims(std::vector({false})); + TensorDistAttr indice_dist_attr = TensorDistAttr(); + indice_dist_attr.set_process_mesh(process_mesh); + indice_dist_attr.set_dims_mapping(std::vector({-1})); + indice_dist_attr.set_dynamic_dims(std::vector({false})); + + // Test forward. 
+ // [-1,0, 1], [[-1],[-1]], [-1] --> [-1,-1, 1] + + phi::distributed::DistMetaTensor x(common::make_ddim(x_shape), x_dist_attr); + phi::distributed::DistMetaTensor value(common::make_ddim(value_shape), + value_dist_attr); + std::vector indices; + for (int i = 0; i < 2; ++i) { + phi::distributed::DistMetaTensor indice(common::make_ddim(indice_shape), + indice_dist_attr); + indices.push_back(indice); + } + phi::distributed::SpmdInfo forward_info = + phi::distributed::IndexPutInferSpmd(x, indices, value); + size_t input_size = 3; + size_t output_size = 1; + EXPECT_EQ(forward_info.first.size(), input_size); + EXPECT_EQ(forward_info.second.size(), output_size); + check_dim_mapping(forward_info.first[0], {-1, -1, 1}); + check_dim_mapping(forward_info.first[1], {-1}); + check_dim_mapping(forward_info.first[2], {-1}); + check_dim_mapping(forward_info.second[0], {-1, -1, 1}); + VLOG(4) << "test forward done."; + + // Test backward. + // [-1,0, 1], [[-1],[-1]], [-1],[-1,0, 1] --> [-1,-1, 1], [-1] + + phi::distributed::DistMetaTensor out_grad(common::make_ddim(x_shape), + x_dist_attr); + + phi::distributed::SpmdInfo forward_info = + phi::distributed::IndexPutGradInferSpmd(x, indices, value, out_grad); + size_t input_size = 4; + size_t output_size = 2; + EXPECT_EQ(forward_info.first.size(), input_size); + EXPECT_EQ(forward_info.second.size(), output_size); + check_dim_mapping(forward_info.first[0], {-1, -1, 1}); + check_dim_mapping(forward_info.first[1], {-1}); + check_dim_mapping(forward_info.first[2], {-1}); + check_dim_mapping(forward_info.first[3], {-1, -1, 1}); + check_dim_mapping(forward_info.second[0], {-1, -1, 1}); + check_dim_mapping(forward_info.second[1], {-1}); + VLOG(4) << "test backward done."; +} TEST(InstanceNorm, Ctor) { // build input data class std::vector x_shape = {64, 64, 64, 64}; // N,C,H,W From 3057b253828c08a60aa9a2e436ce4dac6794fd31 Mon Sep 17 00:00:00 2001 From: Glencsa <3501406249@qq.com> Date: Fri, 20 Jun 2025 11:23:02 +0800 Subject: [PATCH 03/14] 
remove annotations. --- paddle/phi/infermeta/spmd_rules/index_put.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/index_put.cc b/paddle/phi/infermeta/spmd_rules/index_put.cc index 32b20024bf3b7f..e8089897cf732e 100644 --- a/paddle/phi/infermeta/spmd_rules/index_put.cc +++ b/paddle/phi/infermeta/spmd_rules/index_put.cc @@ -21,8 +21,6 @@ limitations under the License. */ #include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h" #include "paddle/phi/infermeta/spmd_rules/utils.h" -// 当x_ndim>indices_size时,被广播的维度是可以被切的,indices内所有的tensor的dims_mapping只能是-1,value的dims_mapping也只能是-1 - namespace phi::distributed { SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x, const std::vector& indices, From d95de1e5a250a849632a4127b99dac3bfac8e4d5 Mon Sep 17 00:00:00 2001 From: Glencsa <3501406249@qq.com> Date: Fri, 20 Jun 2025 12:48:12 +0800 Subject: [PATCH 04/14] fix bug. --- paddle/phi/infermeta/spmd_rules/index_put.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/index_put.cc b/paddle/phi/infermeta/spmd_rules/index_put.cc index e8089897cf732e..663d96823eb436 100644 --- a/paddle/phi/infermeta/spmd_rules/index_put.cc +++ b/paddle/phi/infermeta/spmd_rules/index_put.cc @@ -104,7 +104,7 @@ SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x, LOG_SPMD_INPUT(x); // LOG_SPMD_INPUT(indices); VLOG(4) << "name:indices"; - VLOG(4) << "shape: [" << indices_shape << "] " + VLOG(4) << "shape: [" << std::to_string(indices_shape) << "] " << "indices_dist_attr_src: [" << indices_dist_attrs_src[0].to_string() << "] " << "indices_dist_attr_dst: [" << indices_dist_attrs_dst[0].to_string() @@ -215,7 +215,7 @@ SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x, LOG_SPMD_INPUT(x); // LOG_SPMD_INPUT(indices); VLOG(4) << "name:indices"; - VLOG(4) << "shape: [" << indices_shape << "] " + VLOG(4) << "shape: [" << std::to_string(indices_shape) << "] " << "indices_dist_attr_src: [" << 
indices_dist_attrs_src[0].to_string() << "] " << "indices_dist_attr_dst: [" << indices_dist_attrs_dst[0].to_string() From cbb93a40872bba75e2960bf0813e25e3b39e70bb Mon Sep 17 00:00:00 2001 From: Glencsa <3501406249@qq.com> Date: Fri, 20 Jun 2025 13:35:23 +0800 Subject: [PATCH 05/14] fix bug. --- paddle/phi/infermeta/spmd_rules/index_put.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/index_put.cc b/paddle/phi/infermeta/spmd_rules/index_put.cc index 663d96823eb436..d82ac13407f612 100644 --- a/paddle/phi/infermeta/spmd_rules/index_put.cc +++ b/paddle/phi/infermeta/spmd_rules/index_put.cc @@ -103,8 +103,9 @@ SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x, // Step4: Log SpmdInfo LOG_SPMD_INPUT(x); // LOG_SPMD_INPUT(indices); - VLOG(4) << "name:indices"; - VLOG(4) << "shape: [" << std::to_string(indices_shape) << "] " + VLOG(4) << "name: indices"; + VLOG(4) << "ndim: " << std::to_string(indices_ndim) << " " + << "indices_size: " << std::to_string(indices_size) << " " << "indices_dist_attr_src: [" << indices_dist_attrs_src[0].to_string() << "] " << "indices_dist_attr_dst: [" << indices_dist_attrs_dst[0].to_string() @@ -214,8 +215,9 @@ SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x, // Step4: Log SpmdInfo LOG_SPMD_INPUT(x); // LOG_SPMD_INPUT(indices); - VLOG(4) << "name:indices"; - VLOG(4) << "shape: [" << std::to_string(indices_shape) << "] " + VLOG(4) << "name: indices"; + VLOG(4) << "ndim: " << std::to_string(indices_ndim) << " " + << "indices_size: " << std::to_string(indices_size) << " " << "indices_dist_attr_src: [" << indices_dist_attrs_src[0].to_string() << "] " << "indices_dist_attr_dst: [" << indices_dist_attrs_dst[0].to_string() From ae4dd7decee7192790f74006933b8fb4fe64a161 Mon Sep 17 00:00:00 2001 From: Glencsa <3501406249@qq.com> Date: Fri, 20 Jun 2025 20:42:27 +0800 Subject: [PATCH 06/14] fix bug. 
--- paddle/phi/infermeta/spmd_rules/index_put.cc | 6 ++++-- paddle/phi/infermeta/spmd_rules/index_put.h | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/index_put.cc b/paddle/phi/infermeta/spmd_rules/index_put.cc index d82ac13407f612..4049b212461556 100644 --- a/paddle/phi/infermeta/spmd_rules/index_put.cc +++ b/paddle/phi/infermeta/spmd_rules/index_put.cc @@ -24,7 +24,8 @@ limitations under the License. */ namespace phi::distributed { SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x, const std::vector& indices, - const DistMetaTensor& value) { + const DistMetaTensor& value, + const bool accumulate) { // Step0: verify input args based on group_norm logic auto x_shape = common::vectorize(x.dims()); int indices_size = indices.size(); @@ -121,7 +122,8 @@ SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x, SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x, const std::vector& indices, const DistMetaTensor& value, - const DistMetaTensor& out_grad) { + const DistMetaTensor& out_grad, + const bool accumulate) { // Step0: verify input args based on group_norm logic auto x_shape = common::vectorize(x.dims()); int indices_size = indices.size(); diff --git a/paddle/phi/infermeta/spmd_rules/index_put.h b/paddle/phi/infermeta/spmd_rules/index_put.h index 07257030915f1b..c7daa1eb0bc162 100644 --- a/paddle/phi/infermeta/spmd_rules/index_put.h +++ b/paddle/phi/infermeta/spmd_rules/index_put.h @@ -22,11 +22,11 @@ namespace distributed { SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x, const std::vector& indices, const DistMetaTensor& value, - bool accumulate = false); + const bool accumulate = false); SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x, const std::vector& indices, const DistMetaTensor& value, const DistMetaTensor& out_grad, - bool accumulate = false); + const bool accumulate = false); } // namespace distributed } // namespace phi From 1f7adcf0d43b96f6cfd32300b3d457c4dc9a38b4 Mon Sep 17 00:00:00 2001 
From: Glencsa <3501406249@qq.com> Date: Fri, 20 Jun 2025 22:11:13 +0800 Subject: [PATCH 07/14] fix bug. --- test/cpp/auto_parallel/spmd_rule_test.cc | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index 5d9c9ae19cab47..a2dc9b1d0b2b91 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -285,18 +285,18 @@ TEST(IndexPut, Ctor) { phi::distributed::DistMetaTensor out_grad(common::make_ddim(x_shape), x_dist_attr); - phi::distributed::SpmdInfo forward_info = + phi::distributed::SpmdInfo backward_info = phi::distributed::IndexPutGradInferSpmd(x, indices, value, out_grad); - size_t input_size = 4; - size_t output_size = 2; - EXPECT_EQ(forward_info.first.size(), input_size); - EXPECT_EQ(forward_info.second.size(), output_size); - check_dim_mapping(forward_info.first[0], {-1, -1, 1}); - check_dim_mapping(forward_info.first[1], {-1}); - check_dim_mapping(forward_info.first[2], {-1}); - check_dim_mapping(forward_info.first[3], {-1, -1, 1}); - check_dim_mapping(forward_info.second[0], {-1, -1, 1}); - check_dim_mapping(forward_info.second[1], {-1}); + input_size = 4; + output_size = 2; + EXPECT_EQ(backward_info.first.size(), input_size); + EXPECT_EQ(backward_info.second.size(), output_size); + check_dim_mapping(backward_info.first[0], {-1, -1, 1}); + check_dim_mapping(backward_info.first[1], {-1}); + check_dim_mapping(backward_info.first[2], {-1}); + check_dim_mapping(backward_info.first[3], {-1, -1, 1}); + check_dim_mapping(backward_info.second[0], {-1, -1, 1}); + check_dim_mapping(backward_info.second[1], {-1}); VLOG(4) << "test backward done."; } TEST(InstanceNorm, Ctor) { From 0b802cfca049e83c9a2e027b44b278ecff6d2d2d Mon Sep 17 00:00:00 2001 From: Glencsa <3501406249@qq.com> Date: Sat, 21 Jun 2025 10:30:25 +0800 Subject: [PATCH 08/14] fix CI bug. 
--- paddle/phi/infermeta/spmd_rules/index_put.cc | 2 +- test/cpp/auto_parallel/spmd_rule_test.cc | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/index_put.cc b/paddle/phi/infermeta/spmd_rules/index_put.cc index 4049b212461556..342801d57a0934 100644 --- a/paddle/phi/infermeta/spmd_rules/index_put.cc +++ b/paddle/phi/infermeta/spmd_rules/index_put.cc @@ -81,7 +81,7 @@ SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x, // Step1: set dims_mapping for input if (x_ndim > indices_size) { for (int i = 0; i < indices_size; i++) { - x_dims_mapping[i] == -1; + x_dims_mapping[i] = -1; } } else { // indices_size = x_ndim for (int i = 0; i < x_ndim; i++) { diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index a2dc9b1d0b2b91..b1b213fa465eb9 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -274,7 +274,11 @@ TEST(IndexPut, Ctor) { EXPECT_EQ(forward_info.first.size(), input_size); EXPECT_EQ(forward_info.second.size(), output_size); check_dim_mapping(forward_info.first[0], {-1, -1, 1}); - check_dim_mapping(forward_info.first[1], {-1}); + auto indices_dist_attr = forward_info.first[1]; + for (auto& attr : indices_dist_attr) { + check_dim_mapping(attr, {-1}); + } + check_dim_mapping(forward_info.first[2], {-1}); check_dim_mapping(forward_info.second[0], {-1, -1, 1}); VLOG(4) << "test forward done."; @@ -292,7 +296,11 @@ TEST(IndexPut, Ctor) { EXPECT_EQ(backward_info.first.size(), input_size); EXPECT_EQ(backward_info.second.size(), output_size); check_dim_mapping(backward_info.first[0], {-1, -1, 1}); - check_dim_mapping(backward_info.first[1], {-1}); + indices_dist_attr = backward_info.first[1]; + for (auto& attr : indices_dist_attr) { + check_dim_mapping(attr, {-1}); + } + check_dim_mapping(backward_info.first[2], {-1}); check_dim_mapping(backward_info.first[3], {-1, -1, 1}); check_dim_mapping(backward_info.second[0], 
{-1, -1, 1}); From fc3ab3f12c196b2741827212e602094dd6211231 Mon Sep 17 00:00:00 2001 From: Glencsa <3501406249@qq.com> Date: Sat, 21 Jun 2025 14:27:59 +0800 Subject: [PATCH 09/14] fix ci bug. --- test/cpp/auto_parallel/spmd_rule_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index b1b213fa465eb9..1cd705c17d03a3 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -274,7 +274,7 @@ TEST(IndexPut, Ctor) { EXPECT_EQ(forward_info.first.size(), input_size); EXPECT_EQ(forward_info.second.size(), output_size); check_dim_mapping(forward_info.first[0], {-1, -1, 1}); - auto indices_dist_attr = forward_info.first[1]; + std::vector indices_dist_attr = forward_info.first[1]; for (auto& attr : indices_dist_attr) { check_dim_mapping(attr, {-1}); } From b564e798edf5e5bd0f34e938e46d4d03860efe8b Mon Sep 17 00:00:00 2001 From: Glencsa <3501406249@qq.com> Date: Sat, 21 Jun 2025 15:10:04 +0800 Subject: [PATCH 10/14] Adapt the changes in PR#73233 --- paddle/phi/infermeta/spmd_rules/index_put.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/index_put.cc b/paddle/phi/infermeta/spmd_rules/index_put.cc index 342801d57a0934..c445ce681b6966 100644 --- a/paddle/phi/infermeta/spmd_rules/index_put.cc +++ b/paddle/phi/infermeta/spmd_rules/index_put.cc @@ -96,10 +96,10 @@ SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x, x_dist_attr_dst.set_dims_mapping(x_dims_mapping); TensorDistAttr value_dist_attr_dst = CopyTensorDistAttrForOutput(value.dist_attr()); - value_dist_attr_dst.set_dims_mapping({-1}); + value_dist_attr_dst.set_dims_mapping(std::vector{-1}); std::vector indices_dist_attrs_dst = indices_dist_attrs_src; for (auto& input_attr : indices_dist_attrs_dst) { - input_attr.set_dims_mapping({-1}); + input_attr.set_dims_mapping(std::vector{-1}); } // Step4: Log SpmdInfo 
LOG_SPMD_INPUT(x); @@ -200,7 +200,7 @@ SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x, x_grad_dist_attr.set_dims_mapping(x_dims_mapping); TensorDistAttr value_grad_dist_attr = CopyTensorDistAttrForOutput(value_dist_attr_src); - value_grad_dist_attr.set_dims_mapping({-1}); + value_grad_dist_attr.set_dims_mapping(std::vector{-1}); // Step3: update input dims mapping TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); x_dist_attr_dst.set_dims_mapping(x_dims_mapping); @@ -209,10 +209,10 @@ SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x, out_grad_dist_attr_dst.set_dims_mapping(x_dims_mapping); TensorDistAttr value_dist_attr_dst = CopyTensorDistAttrForOutput(value.dist_attr()); - value_dist_attr_dst.set_dims_mapping({-1}); + value_dist_attr_dst.set_dims_mapping(std::vector{-1}); std::vector indices_dist_attrs_dst = indices_dist_attrs_src; for (auto& input_attr : indices_dist_attrs_dst) { - input_attr.set_dims_mapping({-1}); + input_attr.set_dims_mapping(std::vector{-1}); } // Step4: Log SpmdInfo LOG_SPMD_INPUT(x); From 26e6c68ef50645c5710c10610b6b5990618a38db Mon Sep 17 00:00:00 2001 From: Glencsa <3501406249@qq.com> Date: Sat, 21 Jun 2025 22:34:20 +0800 Subject: [PATCH 11/14] fix ci bug. 
--- test/cpp/auto_parallel/spmd_rule_test.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index 1cd705c17d03a3..b2b9c7d28c2cde 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -274,7 +274,8 @@ TEST(IndexPut, Ctor) { EXPECT_EQ(forward_info.first.size(), input_size); EXPECT_EQ(forward_info.second.size(), output_size); check_dim_mapping(forward_info.first[0], {-1, -1, 1}); - std::vector indices_dist_attr = forward_info.first[1]; + std::vector indices_dist_attr = + paddle::get<1>(forward_info.first[1]); for (auto& attr : indices_dist_attr) { check_dim_mapping(attr, {-1}); } @@ -296,7 +297,7 @@ TEST(IndexPut, Ctor) { EXPECT_EQ(backward_info.first.size(), input_size); EXPECT_EQ(backward_info.second.size(), output_size); check_dim_mapping(backward_info.first[0], {-1, -1, 1}); - indices_dist_attr = backward_info.first[1]; + indices_dist_attr = paddle::get<1>(backward_info.first[1]); for (auto& attr : indices_dist_attr) { check_dim_mapping(attr, {-1}); } From 0f91fe4d4ea167182765796b852cad378517e137 Mon Sep 17 00:00:00 2001 From: Glencsa <3501406249@qq.com> Date: Sun, 22 Jun 2025 10:26:36 +0800 Subject: [PATCH 12/14] fix ci bug. 
--- paddle/phi/infermeta/spmd_rules/index_put.cc | 2 +- test/cpp/auto_parallel/spmd_rule_test.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/index_put.cc b/paddle/phi/infermeta/spmd_rules/index_put.cc index c445ce681b6966..8308c531850060 100644 --- a/paddle/phi/infermeta/spmd_rules/index_put.cc +++ b/paddle/phi/infermeta/spmd_rules/index_put.cc @@ -187,7 +187,7 @@ SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x, // Step1: set x_dims_mapping if (x_ndim > indices_size) { for (int i = 0; i < indices_size; i++) { - x_dims_mapping[i] == -1; + x_dims_mapping[i] = -1; } } else { // indices_size = x_ndim for (int i = 0; i < x_ndim; i++) { diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index b2b9c7d28c2cde..1aeec0820e90fe 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -244,7 +244,7 @@ TEST(IndexPut, Ctor) { TensorDistAttr x_dist_attr = TensorDistAttr(); x_dist_attr.set_process_mesh(process_mesh); x_dist_attr.set_dims_mapping(std::vector({-1, 0, 1})); - x_dist_attr.set_dynamic_dims(std::vector({false, false, false, false})); + x_dist_attr.set_dynamic_dims(std::vector({false, false, false})); TensorDistAttr value_dist_attr = TensorDistAttr(); value_dist_attr.set_process_mesh(process_mesh); From 8fe21f2be06d76965707c4afec0d47b3ca1287f2 Mon Sep 17 00:00:00 2001 From: Glencsa <3501406249@qq.com> Date: Mon, 23 Jun 2025 16:45:44 +0800 Subject: [PATCH 13/14] apply review. 
--- paddle/phi/infermeta/spmd_rules/index_put.cc | 62 +++++++++++++++----- test/cpp/auto_parallel/spmd_rule_test.cc | 20 +++---- 2 files changed, 57 insertions(+), 25 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/index_put.cc b/paddle/phi/infermeta/spmd_rules/index_put.cc index 8308c531850060..04139f4e911dd4 100644 --- a/paddle/phi/infermeta/spmd_rules/index_put.cc +++ b/paddle/phi/infermeta/spmd_rules/index_put.cc @@ -54,12 +54,14 @@ SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x, x_ndim, indices_size)); - PADDLE_ENFORCE_EQ(value_ndim, - 1, - common::errors::InvalidArgument( - "The ndim of value in index_put should be equal to 1, " - "but got value_ndim:[%d].", - value_ndim)); + PADDLE_ENFORCE_LE( + value_ndim, + x_ndim - indices_size + 1, + common::errors::InvalidArgument("The ndim of value in index_put should " + "be less than or equal to [%d], " + "but got value_ndim:[%d].", + x_ndim - indices_size + 1, + value_ndim)); PADDLE_ENFORCE_EQ( indices_ndim, 1, @@ -78,6 +80,17 @@ SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x, i, indices[i].dims().size())); } + std::string alphabet = "ijklmnopqrstuvwxyz"; + std::string x_axes(x_ndim, '1'); + for (int i = 0; i < x_ndim; ++i) { + x_axes[i] = alphabet[i]; + } + std::string value_axes(value_ndim, '1'); + int index = indices_size - 1; + for (int i = 0; i < value_ndim; ++i) { + value_axes[i] = x_axes[index++]; + } + // Step1: set dims_mapping for input if (x_ndim > indices_size) { for (int i = 0; i < indices_size; i++) { @@ -88,6 +101,8 @@ SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x, x_dims_mapping[i] = -1; } } + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors({{x_axes, x_dims_mapping}}); // Step2: set dims_mapping for output TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src); out_dist_attr.set_dims_mapping(x_dims_mapping); @@ -96,7 +111,8 @@ SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x, x_dist_attr_dst.set_dims_mapping(x_dims_mapping); TensorDistAttr 
value_dist_attr_dst = CopyTensorDistAttrForOutput(value.dist_attr()); - value_dist_attr_dst.set_dims_mapping(std::vector{-1}); + value_dist_attr_dst.set_dims_mapping( + GetDimsMappingForAxes(value_axes, axis_to_dim_map)); std::vector indices_dist_attrs_dst = indices_dist_attrs_src; for (auto& input_attr : indices_dist_attrs_dst) { input_attr.set_dims_mapping(std::vector{-1}); @@ -160,12 +176,14 @@ SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x, x_ndim, indices_size)); - PADDLE_ENFORCE_EQ(value_ndim, - 1, - common::errors::InvalidArgument( - "The ndim of value in index_put should be equal to 1, " - "but got value_ndim:[%d].", - value_ndim)); + PADDLE_ENFORCE_LE( + value_ndim, + x_ndim - indices_size + 1, + common::errors::InvalidArgument("The ndim of value in index_put should " + "be less than or equal to [%d], " + "but got value_ndim:[%d].", + x_ndim - indices_size + 1, + value_ndim)); PADDLE_ENFORCE_EQ( indices_ndim, 1, @@ -184,6 +202,16 @@ SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x, i, indices[i].dims().size())); } + std::string alphabet = "ijklmnopqrstuvwxyz"; + std::string x_axes(x_ndim, '1'); + for (int i = 0; i < x_ndim; ++i) { + x_axes[i] = alphabet[i]; + } + std::string value_axes(value_ndim, '1'); + int index = indices_size - 1; + for (int i = 0; i < value_ndim; ++i) { + value_axes[i] = x_axes[index++]; + } // Step1: set x_dims_mapping if (x_ndim > indices_size) { for (int i = 0; i < indices_size; i++) { @@ -194,13 +222,16 @@ SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x, x_dims_mapping[i] = -1; } } + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors({{x_axes, x_dims_mapping}}); // Step2: set dims_mapping for output TensorDistAttr x_grad_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src); x_grad_dist_attr.set_dims_mapping(x_dims_mapping); TensorDistAttr value_grad_dist_attr = CopyTensorDistAttrForOutput(value_dist_attr_src); - value_grad_dist_attr.set_dims_mapping(std::vector{-1}); + 
value_grad_dist_attr.set_dims_mapping( + GetDimsMappingForAxes(value_axes, axis_to_dim_map)); // Step3: update input dims mapping TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); x_dist_attr_dst.set_dims_mapping(x_dims_mapping); @@ -209,7 +240,8 @@ SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x, out_grad_dist_attr_dst.set_dims_mapping(x_dims_mapping); TensorDistAttr value_dist_attr_dst = CopyTensorDistAttrForOutput(value.dist_attr()); - value_dist_attr_dst.set_dims_mapping(std::vector{-1}); + value_dist_attr_dst.set_dims_mapping( + GetDimsMappingForAxes(value_axes, axis_to_dim_map)); std::vector indices_dist_attrs_dst = indices_dist_attrs_src; for (auto& input_attr : indices_dist_attrs_dst) { input_attr.set_dims_mapping(std::vector{-1}); diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index 1aeec0820e90fe..48cfe64f3ab4e9 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -235,7 +235,7 @@ TEST(IndexPut, Ctor) { // build input data class std::vector x_shape = {64, 64, 64}; std::vector indice_shape = {32}; - std::vector value_shape = {32}; + std::vector value_shape = {32, 64}; std::vector mesh_shape = {2, 3}; std::vector process_ids = {0, 1, 2, 3, 4, 5}; std::vector dim_names = {"x", "y"}; @@ -248,16 +248,16 @@ TEST(IndexPut, Ctor) { TensorDistAttr value_dist_attr = TensorDistAttr(); value_dist_attr.set_process_mesh(process_mesh); - value_dist_attr.set_dims_mapping(std::vector({-1})); - value_dist_attr.set_dynamic_dims(std::vector({false})); + value_dist_attr.set_dims_mapping(std::vector({-1, -1})); + value_dist_attr.set_dynamic_dims(std::vector({false, false})); TensorDistAttr indice_dist_attr = TensorDistAttr(); indice_dist_attr.set_process_mesh(process_mesh); indice_dist_attr.set_dims_mapping(std::vector({-1})); indice_dist_attr.set_dynamic_dims(std::vector({false})); // Test forward. 
- // [-1,0, 1], [[-1],[-1]], [-1] --> [-1,-1, 1] - + // [-1,0, 1], [[-1],[-1]], [-1,-1] --> [-1,-1, 1] + // infer input:[-1,-1, 1], [[-1],[-1]], [-1,1] phi::distributed::DistMetaTensor x(common::make_ddim(x_shape), x_dist_attr); phi::distributed::DistMetaTensor value(common::make_ddim(value_shape), value_dist_attr); @@ -280,13 +280,13 @@ TEST(IndexPut, Ctor) { check_dim_mapping(attr, {-1}); } - check_dim_mapping(forward_info.first[2], {-1}); + check_dim_mapping(forward_info.first[2], {-1, 1}); check_dim_mapping(forward_info.second[0], {-1, -1, 1}); VLOG(4) << "test forward done."; // Test backward. - // [-1,0, 1], [[-1],[-1]], [-1],[-1,0, 1] --> [-1,-1, 1], [-1] - + // [-1,0, 1], [[-1],[-1]], [-1,-1],[-1,0, 1] --> [-1,-1, 1], [-1,1] + // infer input:[-1,-1, 1], [[-1],[-1]], [-1,1],[-1,-1, 1] phi::distributed::DistMetaTensor out_grad(common::make_ddim(x_shape), x_dist_attr); @@ -302,10 +302,10 @@ TEST(IndexPut, Ctor) { check_dim_mapping(attr, {-1}); } - check_dim_mapping(backward_info.first[2], {-1}); + check_dim_mapping(backward_info.first[2], {-1, 1}); check_dim_mapping(backward_info.first[3], {-1, -1, 1}); check_dim_mapping(backward_info.second[0], {-1, -1, 1}); - check_dim_mapping(backward_info.second[1], {-1}); + check_dim_mapping(backward_info.second[1], {-1, 1}); VLOG(4) << "test backward done."; } TEST(InstanceNorm, Ctor) { From 6c8e49ef595963492a8ffd254f9826f274d429a3 Mon Sep 17 00:00:00 2001 From: Glencsa <3501406249@qq.com> Date: Mon, 30 Jun 2025 20:05:51 +0800 Subject: [PATCH 14/14] apply review. 
--- paddle/phi/infermeta/spmd_rules/index_put.cc | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/index_put.cc b/paddle/phi/infermeta/spmd_rules/index_put.cc index 8308c531850060..db914291c7064a 100644 --- a/paddle/phi/infermeta/spmd_rules/index_put.cc +++ b/paddle/phi/infermeta/spmd_rules/index_put.cc @@ -79,14 +79,8 @@ SpmdInfo IndexPutInferSpmd(const DistMetaTensor& x, indices[i].dims().size())); } // Step1: set dims_mapping for input - if (x_ndim > indices_size) { - for (int i = 0; i < indices_size; i++) { - x_dims_mapping[i] = -1; - } - } else { // indices_size = x_ndim - for (int i = 0; i < x_ndim; i++) { - x_dims_mapping[i] = -1; - } + for (int i = 0; i < indices_size; i++) { + x_dims_mapping[i] = -1; } // Step2: set dims_mapping for output TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src); @@ -185,14 +179,8 @@ SpmdInfo IndexPutGradInferSpmd(const DistMetaTensor& x, indices[i].dims().size())); } // Step1: set x_dims_mapping - if (x_ndim > indices_size) { - for (int i = 0; i < indices_size; i++) { - x_dims_mapping[i] = -1; - } - } else { // indices_size = x_ndim - for (int i = 0; i < x_ndim; i++) { - x_dims_mapping[i] = -1; - } + for (int i = 0; i < indices_size; i++) { + x_dims_mapping[i] = -1; } // Step2: set dims_mapping for output TensorDistAttr x_grad_dist_attr =