diff --git a/paddle/phi/infermeta/spmd_rules/batch_norm.cc b/paddle/phi/infermeta/spmd_rules/batch_norm.cc
new file mode 100644
index 00000000000000..ec3e57b364100a
--- /dev/null
+++ b/paddle/phi/infermeta/spmd_rules/batch_norm.cc
@@ -0,0 +1,427 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/infermeta/spmd_rules/batch_norm.h"
+
+#include "glog/logging.h"
+
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
+#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
+#include "paddle/phi/core/distributed/auto_parallel/utils.h"
+#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h"
+#include "paddle/phi/infermeta/spmd_rules/utils.h"
+
+namespace phi::distributed {
+
+SpmdInfo BatchNormInferSpmd(const DistMetaTensor& x,
+                            const DistMetaTensor& mean,
+                            const DistMetaTensor& variance,
+                            const DistMetaTensor& scale,
+                            const DistMetaTensor& bias,
+                            const bool is_test,
+                            const float momentum,
+                            const float epsilon,
+                            const std::string& data_format,
+                            const bool use_global_stats,
+                            const bool trainable_statistics) {
+  // Step0: verify input args based on batch_norm logic
+  auto x_shape = common::vectorize(x.dims());
+  auto mean_shape = common::vectorize(mean.dims());
+  auto variance_shape = common::vectorize(variance.dims());
+  auto scale_shape = common::vectorize(scale.dims());
+  auto bias_shape = common::vectorize(bias.dims());
+  int x_ndim = static_cast<int>(x_shape.size());
+  int mean_ndim = static_cast<int>(mean_shape.size());
+  int variance_ndim = static_cast<int>(variance_shape.size());
+  int scale_ndim = static_cast<int>(scale_shape.size());
+  int bias_ndim = static_cast<int>(bias_shape.size());
+  TensorDistAttr x_dist_attr_src = x.dist_attr();
+  TensorDistAttr mean_dist_attr_src = mean.dist_attr();
+  TensorDistAttr variance_dist_attr_src = variance.dist_attr();
+  TensorDistAttr scale_dist_attr_src = scale.dist_attr();
+  TensorDistAttr bias_dist_attr_src = bias.dist_attr();
+  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
+  std::vector<int64_t> mean_dims_mapping = mean.dist_attr().dims_mapping();
+  std::vector<int64_t> variance_dims_mapping =
+      variance.dist_attr().dims_mapping();
+  std::vector<int64_t> scale_dims_mapping = scale.dist_attr().dims_mapping();
+  std::vector<int64_t> bias_dims_mapping = bias.dist_attr().dims_mapping();
+
+  PADDLE_ENFORCE_GE(
+      x_ndim,
+      2,
+      common::errors::InvalidArgument(
+          "The ndim of x in batch_norm should be greater than 1, but got "
+          "[%d].",
+          x_ndim));
+  PADDLE_ENFORCE_LE(
+      x_ndim,
+      5,
+      common::errors::InvalidArgument(
+          "The ndim of x in batch_norm should be less than 6, but got [%d].",
+          x_ndim));
+
+  PADDLE_ENFORCE_EQ(
+      mean_ndim,
+      1,
+      common::errors::InvalidArgument(
+          "The ndim of mean in batch_norm should be 1, but got [%d].",
+          mean_ndim));
+
+  PADDLE_ENFORCE_EQ(
+      variance_ndim,
+      1,
+      common::errors::InvalidArgument(
+          "The ndim of variance in batch_norm should be 1, but got [%d].",
+          variance_ndim));
+
+  PADDLE_ENFORCE_EQ(
+      scale_ndim,
+      1,
+      common::errors::InvalidArgument(
+          "The ndim of scale in batch_norm should be 1, but got [%d].",
+          scale_ndim));
+
+  PADDLE_ENFORCE_EQ(
+      bias_ndim,
+      1,
+      common::errors::InvalidArgument(
+          "The ndim of bias in batch_norm should be 1, but got [%d].",
+          bias_ndim));
+
+  // Step1: Build Einsum Notation
+
+  std::string alphabet = "ijklmnopqrstuvwxyz";
+  // get input notation
+  // The mean and variance are flattened along the C axis.
+  std::string x_axes(x_ndim, '1');
+  for (int i = 0; i < x_ndim; ++i) {
+    x_axes[i] = alphabet[i];
+  }
+  int c_index = data_format[1] == 'C' ? 1 : x_ndim - 1;
+  std::string mean_axes(1, x_axes[c_index]);
+  std::string variance_axes(1, x_axes[c_index]);
+  std::string scale_axes(1, x_axes[c_index]);
+  std::string bias_axes(1, x_axes[c_index]);
+
+  // get output notation
+  std::string out_axes = x_axes;
+
+  // Step2: Sharding Propagation
+  // Step2.1: merge input sharding
+  // Only the C axis can be sharded.
+  auto c_dim =
+      x_dims_mapping[c_index];  // data_format: "NC", "NCL", "NLC", "NCHW",
+                                // "NHWC" and "NCDHW"
+
+  for (int i = 0; i < x_ndim; ++i) {
+    x_dims_mapping[i] = -1;
+  }
+  x_dims_mapping[c_index] = c_dim;
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors({{x_axes, x_dims_mapping}});
+
+  // Step2.2: infer output dims mapping
+  TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
+  TensorDistAttr mean_dist_attr = CopyTensorDistAttrForOutput(mean.dist_attr());
+  TensorDistAttr variance_dist_attr =
+      CopyTensorDistAttrForOutput(variance.dist_attr());
+  TensorDistAttr saved_mean_dist_attr =
+      CopyTensorDistAttrForOutput(mean.dist_attr());
+  TensorDistAttr saved_variance_dist_attr =
+      CopyTensorDistAttrForOutput(variance.dist_attr());
+  TensorDistAttr reserve_space_dist_attr =
+      CopyTensorDistAttrForOutput(x_dist_attr_src);
+  out_dist_attr.set_dims_mapping(
+      GetDimsMappingForAxes(out_axes, axis_to_dim_map));
+  mean_dist_attr.set_dims_mapping(
+      GetDimsMappingForAxes(mean_axes, axis_to_dim_map));
+  variance_dist_attr.set_dims_mapping(
+      GetDimsMappingForAxes(variance_axes, axis_to_dim_map));
+  saved_mean_dist_attr.set_dims_mapping(
+      GetDimsMappingForAxes(mean_axes, axis_to_dim_map));
+  saved_variance_dist_attr.set_dims_mapping(
+      GetDimsMappingForAxes(variance_axes, axis_to_dim_map));
+  reserve_space_dist_attr.set_dims_mapping({-1});
+
+  // Step2.3: update input dims mapping
+  // scale and bias are replicated; mean and variance follow the C axis mapping.
+  TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
+  TensorDistAttr scale_dist_attr_dst =
+      CopyTensorDistAttrForOutput(scale.dist_attr());
+  TensorDistAttr bias_dist_attr_dst =
+      CopyTensorDistAttrForOutput(bias.dist_attr());
+  TensorDistAttr mean_dist_attr_dst =
+      CopyTensorDistAttrForOutput(mean.dist_attr());
+  TensorDistAttr variance_dist_attr_dst =
+      CopyTensorDistAttrForOutput(variance.dist_attr());
+  scale_dist_attr_dst.set_dims_mapping({-1});
+  bias_dist_attr_dst.set_dims_mapping({-1});
+  variance_dist_attr_dst.set_dims_mapping(
+      GetDimsMappingForAxes(variance_axes, axis_to_dim_map));
+  mean_dist_attr_dst.set_dims_mapping(
+      GetDimsMappingForAxes(mean_axes, axis_to_dim_map));
+
+  x_dist_attr_dst.set_dims_mapping(x_dims_mapping);
+
+  LOG_SPMD_INPUT(x);
+  LOG_SPMD_INPUT(mean);
+  LOG_SPMD_INPUT(variance);
+  LOG_SPMD_INPUT(scale);
+  LOG_SPMD_INPUT(bias);
+  LOG_SPMD_OUTPUT(out_dist_attr);
+  LOG_SPMD_OUTPUT(mean_dist_attr);
+  LOG_SPMD_OUTPUT(variance_dist_attr);
+  LOG_SPMD_OUTPUT(saved_mean_dist_attr);
+  LOG_SPMD_OUTPUT(saved_variance_dist_attr);
+  LOG_SPMD_OUTPUT(reserve_space_dist_attr);
+
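+  // Package the result: the first group holds the dst dist_attrs of the
+  // inputs (x, mean, variance, scale, bias); the second group holds the
+  // dist_attrs of the outputs (out, mean_out, variance_out, saved_mean,
+  // saved_variance, reserve_space).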
+  return {{x_dist_attr_dst,
+           mean_dist_attr_dst,
+           variance_dist_attr_dst,
+           scale_dist_attr_dst,
+           bias_dist_attr_dst},
+          {out_dist_attr,
+           mean_dist_attr,
+           variance_dist_attr,
+           saved_mean_dist_attr,
+           saved_variance_dist_attr,
+           reserve_space_dist_attr}};
+}
+SpmdInfo BatchNormInferSpmdStatic(const DistMetaTensor& x,
+                                  const DistMetaTensor& mean,
+                                  const DistMetaTensor& variance,
+                                  const DistMetaTensor& scale,
+                                  const DistMetaTensor& bias) {
+  return BatchNormInferSpmd(x, mean, variance, scale, bias);
+}
+SpmdInfo BatchNormGradInferSpmd(const DistMetaTensor& x,
+                                const DistMetaTensor& scale,
+                                const DistMetaTensor& bias,
+                                const DistMetaTensor& mean_out,
+                                const DistMetaTensor& variance_out,
+                                const DistMetaTensor& saved_mean,
+                                const DistMetaTensor& saved_variance,
+                                const DistMetaTensor& reserve_space,
+                                const DistMetaTensor& out_grad,
+                                const float momentum,
+                                const float epsilon,
+                                const std::string& data_format,
+                                const bool is_test,
+                                const bool use_global_stats,
+                                const bool trainable_statistics) {
+  auto x_shape = common::vectorize(x.dims());
+  auto scale_shape = common::vectorize(scale.dims());
+  auto bias_shape = common::vectorize(bias.dims());
+  auto mean_out_shape = common::vectorize(mean_out.dims());
+  auto variance_out_shape = common::vectorize(variance_out.dims());
+  auto saved_mean_shape = common::vectorize(saved_mean.dims());
+  auto saved_variance_shape = common::vectorize(saved_variance.dims());
+  auto reserve_space_shape = common::vectorize(reserve_space.dims());
+  auto out_grad_shape = common::vectorize(out_grad.dims());
+  int x_ndim = static_cast<int>(x_shape.size());
+  int scale_ndim = static_cast<int>(scale_shape.size());
+  int bias_ndim = static_cast<int>(bias_shape.size());
+  int mean_out_ndim = static_cast<int>(mean_out_shape.size());
+  int variance_out_ndim = static_cast<int>(variance_out_shape.size());
+  int saved_mean_ndim = static_cast<int>(saved_mean_shape.size());
+  int saved_variance_ndim = static_cast<int>(saved_variance_shape.size());
+  int reserve_space_ndim = static_cast<int>(reserve_space_shape.size());
+  int out_grad_ndim = static_cast<int>(out_grad_shape.size());
+  TensorDistAttr x_dist_attr_src = x.dist_attr();
+  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
+  TensorDistAttr scale_dist_attr_src = scale.dist_attr();
+  TensorDistAttr bias_dist_attr_src = bias.dist_attr();
+  TensorDistAttr mean_out_dist_attr_src = mean_out.dist_attr();
+  TensorDistAttr variance_out_dist_attr_src = variance_out.dist_attr();
+  TensorDistAttr saved_mean_dist_attr_src = saved_mean.dist_attr();
+  TensorDistAttr saved_variance_dist_attr_src = saved_variance.dist_attr();
+  TensorDistAttr reserve_space_dist_attr_src = reserve_space.dist_attr();
+  TensorDistAttr out_grad_dist_attr_src = out_grad.dist_attr();
+  PADDLE_ENFORCE_GE(
+      x_ndim,
+      2,
+      common::errors::InvalidArgument(
+          "The ndim of x in batch_norm should be greater than 1, but got "
+          "[%d].",
+          x_ndim));
+  PADDLE_ENFORCE_LE(
+      x_ndim,
+      5,
+      common::errors::InvalidArgument(
+          "The ndim of x in batch_norm should be less than 6, but got [%d].",
+          x_ndim));
+  PADDLE_ENFORCE_EQ(out_grad_ndim,
+                    x_ndim,
+                    common::errors::InvalidArgument(
+                        "The ndim of out_grad in batch_norm should be equal "
+                        "to that of x, but got out_grad:[%d] and x:[%d].",
+                        out_grad_ndim,
+                        x_ndim));
+  PADDLE_ENFORCE_EQ(
+      mean_out_ndim,
+      1,
+      common::errors::InvalidArgument(
+          "The ndim of mean_out in batch_norm should be 1, but got [%d].",
+          mean_out_ndim));
+
+  PADDLE_ENFORCE_EQ(
+      variance_out_ndim,
+      1,
+      common::errors::InvalidArgument(
+          "The ndim of variance_out in batch_norm should be 1, but got [%d].",
+          variance_out_ndim));
+
+  PADDLE_ENFORCE_EQ(
+      scale_ndim,
+      1,
+      common::errors::InvalidArgument(
+          "The ndim of scale in batch_norm should be 1, but got [%d].",
+          scale_ndim));
+
+  PADDLE_ENFORCE_EQ(
+      bias_ndim,
+      1,
+      common::errors::InvalidArgument(
+          "The ndim of bias in batch_norm should be 1, but got [%d].",
+          bias_ndim));
+  PADDLE_ENFORCE_EQ(
+      saved_mean_ndim,
+      1,
+      common::errors::InvalidArgument(
+          "The ndim of saved_mean in batch_norm should be 1, but got [%d].",
+          saved_mean_ndim));
+
+  PADDLE_ENFORCE_EQ(
+      saved_variance_ndim,
+      1,
+      common::errors::InvalidArgument(
+          "The ndim of saved_variance in batch_norm should be 1, but got "
+          "[%d].",
+          saved_variance_ndim));
+
+  PADDLE_ENFORCE_EQ(
+      reserve_space_ndim,
+      1,
+      common::errors::InvalidArgument("The ndim of reserve_space in "
+                                      "batch_norm should be 1, but got [%d].",
+                                      reserve_space_ndim));
+
+  std::string alphabet = "ijklmnopqrstuvwxyz";
+  // get input notation
+  // The mean and variance are flattened along the C axis.
+  std::string x_axes(x_ndim, '1');
+  std::string out_grad_axes(out_grad_ndim, '1');
+
+  for (int i = 0; i < x_ndim; ++i) {
+    x_axes[i] = alphabet[i];
+    out_grad_axes[i] = alphabet[i];
+  }
+  int c_index = data_format[1] == 'C' ? 1 : x_ndim - 1;
+  std::string mean_out_axes(1, x_axes[c_index]);
+  std::string variance_out_axes(1, x_axes[c_index]);
+  std::string scale_axes(1, x_axes[c_index]);
+  std::string bias_axes(1, x_axes[c_index]);
+  std::string saved_mean_axes(1, x_axes[c_index]);
+  std::string saved_variance_axes(1, x_axes[c_index]);
+  std::string reserve_space_axes(1, x_axes[c_index]);
+
+  auto c_dim =
+      x_dims_mapping[c_index];  // Only the C axis can be sharded.
+                                // data_format: "NC", "NCL", "NLC", "NCHW",
+                                // "NHWC" and "NCDHW"
+
+  for (int i = 0; i < x_ndim; ++i) {
+    x_dims_mapping[i] = -1;
+  }
+  x_dims_mapping[c_index] = c_dim;
+
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors({{x_axes, x_dims_mapping}});
+  // infer output spmdinfo
+  TensorDistAttr x_grad_dist_attr =
+      CopyTensorDistAttrForOutput(x_dist_attr_src);
+  x_grad_dist_attr.set_dims_mapping(x_dims_mapping);
+  TensorDistAttr scale_grad_dist_attr =
+      CopyTensorDistAttrForOutput(scale.dist_attr());
+  scale_grad_dist_attr.set_dims_mapping({-1});
+  TensorDistAttr bias_grad_dist_attr =
+      CopyTensorDistAttrForOutput(bias.dist_attr());
+  bias_grad_dist_attr.set_dims_mapping({-1});
+  // infer input spmdinfo
+  TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
+  x_dist_attr_dst.set_dims_mapping(x_dims_mapping);
+  TensorDistAttr mean_out_dist_attr_dst =
+      CopyTensorDistAttrForOutput(x_dist_attr_src);
+  mean_out_dist_attr_dst.set_dims_mapping(
+      GetDimsMappingForAxes(mean_out_axes, axis_to_dim_map));
+  TensorDistAttr variance_out_dist_attr_dst =
+      CopyTensorDistAttrForOutput(x_dist_attr_src);
+  variance_out_dist_attr_dst.set_dims_mapping(
+      GetDimsMappingForAxes(variance_out_axes, axis_to_dim_map));
+  TensorDistAttr scale_dist_attr_dst =
+      CopyTensorDistAttrForOutput(x_dist_attr_src);
+  scale_dist_attr_dst.set_dims_mapping({-1});
+  TensorDistAttr bias_dist_attr_dst =
+      CopyTensorDistAttrForOutput(x_dist_attr_src);
+  bias_dist_attr_dst.set_dims_mapping({-1});
+  TensorDistAttr saved_mean_dist_attr_dst =
+      CopyTensorDistAttrForOutput(x_dist_attr_src);
+  saved_mean_dist_attr_dst.set_dims_mapping(
+      GetDimsMappingForAxes(saved_mean_axes, axis_to_dim_map));
+  TensorDistAttr saved_variance_dist_attr_dst =
+      CopyTensorDistAttrForOutput(x_dist_attr_src);
+  saved_variance_dist_attr_dst.set_dims_mapping(
+      GetDimsMappingForAxes(saved_variance_axes, axis_to_dim_map));
+  TensorDistAttr reserve_space_dist_attr_dst =
+      CopyTensorDistAttrForOutput(x_dist_attr_src);
+  reserve_space_dist_attr_dst.set_dims_mapping({-1});
+  TensorDistAttr out_grad_dist_attr_dst =
+      CopyTensorDistAttrForOutput(x_dist_attr_src);
+  out_grad_dist_attr_dst.set_dims_mapping(
+      GetDimsMappingForAxes(out_grad_axes, axis_to_dim_map));
+
+  // partial grad dim
+  std::vector<int64_t> partial_on_dims;
+  for (int i = 0; i < x_ndim; ++i) {
+    auto mapping = x_dims_mapping[i];
+    if (mapping != -1) {
+      partial_on_dims.push_back(mapping);
+    }
+  }
+  scale_grad_dist_attr.set_partial_status(partial_on_dims);
+  bias_grad_dist_attr.set_partial_status(partial_on_dims);
+
+  LOG_SPMD_INPUT(x);
+  LOG_SPMD_INPUT(scale);
+  LOG_SPMD_INPUT(bias);
+  LOG_SPMD_INPUT(mean_out);
+  LOG_SPMD_INPUT(variance_out);
+  LOG_SPMD_INPUT(saved_mean);
+  LOG_SPMD_INPUT(saved_variance);
+  LOG_SPMD_INPUT(reserve_space);
+  LOG_SPMD_INPUT(out_grad);
+  LOG_SPMD_OUTPUT(x_grad_dist_attr);
+  LOG_SPMD_OUTPUT(scale_grad_dist_attr);
+  LOG_SPMD_OUTPUT(bias_grad_dist_attr);
+
+  return {{x_dist_attr_dst,
+           scale_dist_attr_dst,
+           bias_dist_attr_dst,
+           mean_out_dist_attr_dst,
+           variance_out_dist_attr_dst,
+           saved_mean_dist_attr_dst,
+           saved_variance_dist_attr_dst,
+           reserve_space_dist_attr_dst,
+           out_grad_dist_attr_dst},
+          {x_grad_dist_attr, scale_grad_dist_attr, bias_grad_dist_attr}};
+}
+
+}  // namespace phi::distributed
diff --git a/paddle/phi/infermeta/spmd_rules/batch_norm.h b/paddle/phi/infermeta/spmd_rules/batch_norm.h
new file mode 100644
index 00000000000000..c8d6aaae33f47d
--- /dev/null
+++ b/paddle/phi/infermeta/spmd_rules/batch_norm.h
@@ -0,0 +1,57 @@
+/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
+#include "paddle/phi/core/distributed/type_defs.h"
+
+namespace phi {
+namespace distributed {
+SpmdInfo BatchNormInferSpmd(const DistMetaTensor& x,
+                            const DistMetaTensor& mean,
+                            const DistMetaTensor& variance,
+                            const DistMetaTensor& scale,
+                            const DistMetaTensor& bias,
+                            const bool is_test = false,
+                            const float momentum = 0.9,
+                            const float epsilon = 1e-05,
+                            const std::string& data_format = "NCHW",
+                            const bool use_global_stats = false,
+                            const bool trainable_statistics = false);
+SpmdInfo BatchNormInferSpmdStatic(const DistMetaTensor& x,
+                                  const DistMetaTensor& mean,
+                                  const DistMetaTensor& variance,
+                                  const DistMetaTensor& scale,
+                                  const DistMetaTensor& bias);
+
+SpmdInfo BatchNormGradInferSpmd(const DistMetaTensor& x,
+                                const DistMetaTensor& scale,
+                                const DistMetaTensor& bias,
+                                const DistMetaTensor& mean_out,
+                                const DistMetaTensor& variance_out,
+                                const DistMetaTensor& saved_mean,
+                                const DistMetaTensor& saved_variance,
+                                const DistMetaTensor& reserve_space,
+                                const DistMetaTensor& out_grad,
+                                const float momentum = 0.9,
+                                const float epsilon = 1e-05,
+                                const std::string& data_format = "NCHW",
+                                const bool is_test = false,
+                                const bool use_global_stats = false,
+                                const bool trainable_statistics = false);
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/rules.cc b/paddle/phi/infermeta/spmd_rules/rules.cc
index d2b3277bb9113f..26fd240c0e83cb 100644
--- a/paddle/phi/infermeta/spmd_rules/rules.cc
+++ b/paddle/phi/infermeta/spmd_rules/rules.cc
@@ -66,7 +66,6 @@ PD_REGISTER_SPMD_RULE(
     fused_rotary_position_embedding,
     PD_INFER_SPMD(phi::distributed::FusedRopeInferSpmd),
     PD_INFER_SPMD(phi::distributed::FusedRopeInferSpmdReverse));
-
 // replicated rule /* for unittest */
 PD_REGISTER_SPMD_RULE(
     replicated,
@@ -525,6 +524,9 @@ PD_REGISTER_SPMD_RULE(
 PD_REGISTER_SPMD_RULE(mean_all,
                       PD_INFER_SPMD(phi::distributed::MeanAllInferSpmd),
                       PD_INFER_SPMD(phi::distributed::MeanAllGradInferSpmd));
+// batch_norm
+PD_REGISTER_SPMD_RULE(
+    batch_norm, PD_INFER_SPMD(phi::distributed::BatchNormInferSpmdStatic));
 
 // layer_norm
 PD_REGISTER_SPMD_RULE(
diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h
index 24f37017967f0d..df44b04316dbea 100644
--- a/paddle/phi/infermeta/spmd_rules/rules.h
+++ b/paddle/phi/infermeta/spmd_rules/rules.h
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/phi/infermeta/spmd_rules/argmax.h"
 #include "paddle/phi/infermeta/spmd_rules/argmin.h"
 #include "paddle/phi/infermeta/spmd_rules/argsort.h"
+#include "paddle/phi/infermeta/spmd_rules/batch_norm.h"
 #include "paddle/phi/infermeta/spmd_rules/c_embedding.h"
 #include "paddle/phi/infermeta/spmd_rules/c_softmax_with_cross_entropy.h"
 #include "paddle/phi/infermeta/spmd_rules/c_softmax_with_multi_label_cross_entropy.h"
diff --git a/paddle/phi/ops/yaml/inconsistent/dygraph_backward.yaml b/paddle/phi/ops/yaml/inconsistent/dygraph_backward.yaml
index eace256a421022..dfd70a04e15774 100755
--- a/paddle/phi/ops/yaml/inconsistent/dygraph_backward.yaml
+++ b/paddle/phi/ops/yaml/inconsistent/dygraph_backward.yaml
@@ -76,6 +76,7 @@
   infer_meta :
     func : GeneralTernaryGradInferMeta
     param : [x, scale, bias]
+    spmd_rule : BatchNormGradInferSpmd
   kernel :
     func : batch_norm_grad
     data_type : out_grad
diff --git a/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml b/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml
index d69045cec55a60..2d91090affd14b 100755
--- a/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml
+++ b/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml
@@ -50,6 +50,7 @@
   output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
   infer_meta:
     func : BatchNormInferMeta
+    spmd_rule : BatchNormInferSpmd
   kernel :
     func : batch_norm
     data_type : x
diff --git a/paddle/phi/ops/yaml/inconsistent/static_backward.yaml b/paddle/phi/ops/yaml/inconsistent/static_backward.yaml
index ddd587c7c5a0e7..79aa3b3fd1a8f8 100644
--- a/paddle/phi/ops/yaml/inconsistent/static_backward.yaml
+++ b/paddle/phi/ops/yaml/inconsistent/static_backward.yaml
@@ -88,6 +88,7 @@
   infer_meta :
     func : GeneralTernaryGradInferMeta
     param : [x, scale, bias]
+    spmd_rule : BatchNormGradInferSpmd
   kernel :
     func : batch_norm_grad
     data_type : out_grad
diff --git a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml
index 708b55ebf3cd40..7ffbc471f52139 100644
--- a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml
+++ b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml
@@ -71,6 +71,7 @@
   output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
   infer_meta:
     func : BatchNormInferMeta
+    spmd_rule : BatchNormInferSpmd
   kernel :
     func : batch_norm
     data_type : x
diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc
index 54e2f2a7c98cb7..ef9baf4cba8cad 100644
--- a/test/cpp/auto_parallel/spmd_rule_test.cc
+++ b/test/cpp/auto_parallel/spmd_rule_test.cc
@@ -2605,6 +2605,113 @@ TEST(MeanAll, Ctor) {
   check_dim_mapping(backward_info.first[1], {});
   check_dim_mapping(backward_info.second[0], {-1, -1});
 }
+TEST(BatchNorm, Ctor) {
+  std::vector<int64_t> mesh_shape = {2, 2};
+  std::vector<int64_t> process_ids = {0, 1, 2, 3};
+  std::vector<std::string> dim_names = {"x", "y"};
+  ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);
+
+  // test forward
+  // data_format = NCHW
+  // [0, 1, -1, -1],[-1],[-1],[-1],[-1] -> [-1, 1, -1, -1],[1],[1],[1],[1],[-1]
+  auto x_dist_attr = TensorDistAttr();
+  x_dist_attr.set_process_mesh(process_mesh);
+  x_dist_attr.set_dims_mapping({0, 1, -1, -1});
+  x_dist_attr.set_dynamic_dims({false, false, false, false});
+  auto one_dim_dist_attr = TensorDistAttr();
+  one_dim_dist_attr.set_process_mesh(process_mesh);
+  one_dim_dist_attr.set_dims_mapping({-1});
+  one_dim_dist_attr.set_dynamic_dims({false});
+
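+  // x is 4-D and sharded on both the N and C axes; the 1-D stats and
+  // parameters start out replicated.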
+  phi::distributed::DistMetaTensor x = phi::distributed::DistMetaTensor(
+      common::make_ddim({16, 16, 16, 16}), x_dist_attr);
+  phi::distributed::DistMetaTensor mean = phi::distributed::DistMetaTensor(
+      common::make_ddim({16}), one_dim_dist_attr);
+  phi::distributed::DistMetaTensor variance = phi::distributed::DistMetaTensor(
+      common::make_ddim({16}), one_dim_dist_attr);
+  phi::distributed::DistMetaTensor scale = phi::distributed::DistMetaTensor(
+      common::make_ddim({16}), one_dim_dist_attr);
+  phi::distributed::DistMetaTensor bias = phi::distributed::DistMetaTensor(
+      common::make_ddim({16}), one_dim_dist_attr);
+  phi::distributed::SpmdInfo forward_info =
+      phi::distributed::BatchNormInferSpmdStatic(
+          x, mean, variance, scale, bias);
+
+  EXPECT_EQ(forward_info.first.size(), 5UL);
+  EXPECT_EQ(forward_info.second.size(), 6UL);
+  check_dim_mapping(forward_info.first[0], {-1, 1, -1, -1});
+  check_dim_mapping(forward_info.first[1], {1});
+  check_dim_mapping(forward_info.first[2], {1});
+  check_dim_mapping(forward_info.first[3], {-1});
+  check_dim_mapping(forward_info.first[4], {-1});
+  check_dim_mapping(forward_info.second[0], {-1, 1, -1, -1});
+  check_dim_mapping(forward_info.second[1], {1});
+  check_dim_mapping(forward_info.second[2], {1});
+  check_dim_mapping(forward_info.second[3], {1});
+  check_dim_mapping(forward_info.second[4], {1});
+  check_dim_mapping(forward_info.second[5], {-1});
+
+  // test backward
+  // data_format = NCHW
+  // [0, 1, -1, -1],[-1],[-1],[-1],[-1],[-1],[-1],[-1],[0, 1, -1, -1]
+  // -> [-1, 1, -1, -1],[-1],[-1]
+  // dst_input: [-1, 1, -1, -1],[-1],[-1],[1],[1],[1],[1],[-1],[-1, 1, -1, -1]
+
+  x = phi::distributed::DistMetaTensor(common::make_ddim({16, 16, 16, 16}),
+                                       x_dist_attr);
+  phi::distributed::DistMetaTensor out_grad = phi::distributed::DistMetaTensor(
+      common::make_ddim({16, 16, 16, 16}), x_dist_attr);
+  phi::distributed::DistMetaTensor mean_out = phi::distributed::DistMetaTensor(
+      common::make_ddim({16}), one_dim_dist_attr);
+  phi::distributed::DistMetaTensor variance_out =
+      phi::distributed::DistMetaTensor(common::make_ddim({16}),
+                                       one_dim_dist_attr);
+  scale = phi::distributed::DistMetaTensor(common::make_ddim({16}),
+                                           one_dim_dist_attr);
+  bias = phi::distributed::DistMetaTensor(common::make_ddim({16}),
+                                          one_dim_dist_attr);
+  phi::distributed::DistMetaTensor saved_mean =
+      phi::distributed::DistMetaTensor(common::make_ddim({16}),
+                                       one_dim_dist_attr);
+  phi::distributed::DistMetaTensor saved_variance =
+      phi::distributed::DistMetaTensor(common::make_ddim({16}),
+                                       one_dim_dist_attr);
+  phi::distributed::DistMetaTensor reserve_space =
+      phi::distributed::DistMetaTensor(common::make_ddim({16}),
+                                       one_dim_dist_attr);
+  phi::distributed::SpmdInfo backward_info =
+      phi::distributed::BatchNormGradInferSpmd(x,
+                                               scale,
+                                               bias,
+                                               mean_out,
+                                               variance_out,
+                                               saved_mean,
+                                               saved_variance,
+                                               reserve_space,
+                                               out_grad,
+                                               0.9,
+                                               0.1,
+                                               "NCHW",
+                                               false,
+                                               false,
+                                               false);
+
+  EXPECT_EQ(backward_info.first.size(), 9UL);
+  EXPECT_EQ(backward_info.second.size(), 3UL);
+  check_dim_mapping(backward_info.first[0], {-1, 1, -1, -1});
+  check_dim_mapping(backward_info.first[1], {-1});
+  check_dim_mapping(backward_info.first[2], {-1});
+  check_dim_mapping(backward_info.first[3], {1});
+  check_dim_mapping(backward_info.first[4], {1});
+  check_dim_mapping(backward_info.first[5], {1});
+  check_dim_mapping(backward_info.first[6], {1});
+  check_dim_mapping(backward_info.first[7], {-1});
+  check_dim_mapping(backward_info.first[8], {-1, 1, -1, -1});
+
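+  // x_grad keeps x's inferred dims mapping; scale_grad and bias_grad are
+  // replicated, with partial status on the mesh dims that still shard x.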
+  check_dim_mapping(backward_info.second[0], {-1, 1, -1, -1});
+  check_dim_mapping(backward_info.second[1], {-1});
+  check_dim_mapping(backward_info.second[2], {-1});
+}
 TEST(Topk, Ctor) {
   std::vector<int64_t> mesh_shape = {2, 2};
   std::vector<int64_t> process_ids = {0, 1, 2, 3};