Merged
Changes from all commits
22 commits
8dfce3a
Add unary ops which have spmd_rule but were not added in the yaml file.
Glencsa Apr 10, 2025
1d129c2
Merge branch 'spmd_test' into develop
Glencsa Apr 15, 2025
b9c9e6a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa Apr 15, 2025
f24c883
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa Apr 16, 2025
746356c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa Apr 16, 2025
efc91c5
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa Apr 23, 2025
2109cf9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa May 8, 2025
773fda6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Glencsa May 22, 2025
044b45b
Add spmd_rule for batch_norm ops.
Glencsa May 23, 2025
bc53066
Add spmd_rule for batch_norm and batch_norm_grad.
Glencsa May 24, 2025
99b2c70
fix bug
Glencsa May 24, 2025
e8a47a6
fix bug.
Glencsa May 25, 2025
da1c541
fix bug.
Glencsa May 25, 2025
19683c0
add spmd_rule for sync_batch_norm
Glencsa May 31, 2025
7b8aa5f
add spmd_rule for sync_batch_norm
Glencsa May 31, 2025
134d4f8
fix bug.
Glencsa Jun 1, 2025
7fa09bd
fix bug.
Glencsa Jun 1, 2025
ccc360e
Add partial status.
Glencsa Jun 4, 2025
4d220eb
fix ci bug.
Glencsa Jun 4, 2025
874f5ae
fix CI bug.
Glencsa Jun 5, 2025
6ebf375
apply review.
Glencsa Jun 13, 2025
8cea62d
fix bug.
Glencsa Jun 13, 2025
427 changes: 427 additions & 0 deletions paddle/phi/infermeta/spmd_rules/batch_norm.cc

Large diffs are not rendered by default.

57 changes: 57 additions & 0 deletions paddle/phi/infermeta/spmd_rules/batch_norm.h
@@ -0,0 +1,57 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {
SpmdInfo BatchNormInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& mean,
const DistMetaTensor& variance,
const DistMetaTensor& scale,
const DistMetaTensor& bias,
const bool is_test = false,
const float momentum = 0.9,
const float epsilon = 1e-05,
const std::string& data_format = "NCHW",
const bool use_global_stats = false,
const bool trainable_statistics = false);
SpmdInfo BatchNormInferSpmdStatic(const DistMetaTensor& x,
const DistMetaTensor& mean,
const DistMetaTensor& variance,
const DistMetaTensor& scale,
const DistMetaTensor& bias);

SpmdInfo BatchNormGradInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& scale,
const DistMetaTensor& bias,
const DistMetaTensor& mean_out,
const DistMetaTensor& variance_out,
const DistMetaTensor& saved_mean,
const DistMetaTensor& saved_variance,
const DistMetaTensor& reserve_space,
const DistMetaTensor& out_grad,
const float momentum = 0.9,
const float epsilon = 1e-05,
const std::string& data_format = "NCHW",
const bool is_test = false,
const bool use_global_stats = false,
const bool trainable_statistics = false);

} // namespace distributed
} // namespace phi
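For orientation, here is a minimal sketch (not part of this PR) of how the forward rule declared above can be invoked directly; it mirrors the unit test added later in this diff. In a dims_mapping, -1 marks a tensor dimension as replicated, while a non-negative value is the index of the process-mesh axis that dimension is sharded along. The function and variable names below are illustrative only.

// Sketch only: assumes the same auto-parallel headers the unit test below
// relies on (ProcessMesh, TensorDistAttr, DistMetaTensor).
#include "paddle/phi/infermeta/spmd_rules/batch_norm.h"

using phi::distributed::BatchNormInferSpmdStatic;
using phi::distributed::DistMetaTensor;

void BatchNormSpmdExample() {
  // 2x2 process mesh over 4 ranks.
  phi::distributed::ProcessMesh mesh({2, 2}, {0, 1, 2, 3}, {"x", "y"});

  // x: NCHW, batch dim sharded on mesh axis 0, channel dim on axis 1.
  phi::distributed::TensorDistAttr x_attr;
  x_attr.set_process_mesh(mesh);
  x_attr.set_dims_mapping({0, 1, -1, -1});

  // 1-D per-channel statistics, fully replicated.
  phi::distributed::TensorDistAttr stat_attr;
  stat_attr.set_process_mesh(mesh);
  stat_attr.set_dims_mapping({-1});

  DistMetaTensor x(common::make_ddim({16, 16, 16, 16}), x_attr);
  DistMetaTensor mean(common::make_ddim({16}), stat_attr);
  DistMetaTensor variance(common::make_ddim({16}), stat_attr);
  DistMetaTensor scale(common::make_ddim({16}), stat_attr);
  DistMetaTensor bias(common::make_ddim({16}), stat_attr);

  // info.first holds the inferred input dist attrs, info.second the outputs.
  auto info = BatchNormInferSpmdStatic(x, mean, variance, scale, bias);
  (void)info;  // inspect the inferred mappings here
}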
4 changes: 3 additions & 1 deletion paddle/phi/infermeta/spmd_rules/rules.cc
@@ -66,7 +66,6 @@ PD_REGISTER_SPMD_RULE(
fused_rotary_position_embedding,
PD_INFER_SPMD(phi::distributed::FusedRopeInferSpmd),
PD_INFER_SPMD(phi::distributed::FusedRopeInferSpmdReverse));

// replicated rule /* for unittest */
PD_REGISTER_SPMD_RULE(
replicated,
@@ -525,6 +524,9 @@ PD_REGISTER_SPMD_RULE(
PD_REGISTER_SPMD_RULE(mean_all,
PD_INFER_SPMD(phi::distributed::MeanAllInferSpmd),
PD_INFER_SPMD(phi::distributed::MeanAllGradInferSpmd));
// batch_norm
PD_REGISTER_SPMD_RULE(
batch_norm, PD_INFER_SPMD(phi::distributed::BatchNormInferSpmdStatic));

// layer_norm
PD_REGISTER_SPMD_RULE(
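Registration makes the rule discoverable by operator name at runtime. A hedged sketch of a lookup, assuming the SpmdRuleFactory interface declared in paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h (that interface is not shown in this diff):

// Assumption: SpmdRuleFactory and ContainsSpmdRule come from
// inferspmd_utils.h; they are not part of this diff.
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"

bool HasBatchNormSpmdRule() {
  // PD_REGISTER_SPMD_RULE(batch_norm, ...) above stores the rule under
  // the key "batch_norm".
  return phi::distributed::SpmdRuleFactory::Instance().ContainsSpmdRule(
      "batch_norm");
}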
1 change: 1 addition & 0 deletions paddle/phi/infermeta/spmd_rules/rules.h
@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/phi/infermeta/spmd_rules/argmax.h"
#include "paddle/phi/infermeta/spmd_rules/argmin.h"
#include "paddle/phi/infermeta/spmd_rules/argsort.h"
#include "paddle/phi/infermeta/spmd_rules/batch_norm.h"
#include "paddle/phi/infermeta/spmd_rules/c_embedding.h"
#include "paddle/phi/infermeta/spmd_rules/c_softmax_with_cross_entropy.h"
#include "paddle/phi/infermeta/spmd_rules/c_softmax_with_multi_label_cross_entropy.h"
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/inconsistent/dygraph_backward.yaml
@@ -76,6 +76,7 @@
infer_meta :
func : GeneralTernaryGradInferMeta
param : [x, scale, bias]
spmd_rule : BatchNormGradInferSpmd
kernel :
func : batch_norm_grad
data_type : out_grad
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml
@@ -50,6 +50,7 @@
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
infer_meta:
func : BatchNormInferMeta
spmd_rule : BatchNormInferSpmd
kernel :
func : batch_norm
data_type : x
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/inconsistent/static_backward.yaml
@@ -88,6 +88,7 @@
infer_meta :
func : GeneralTernaryGradInferMeta
param : [x, scale, bias]
spmd_rule : BatchNormGradInferSpmd
kernel :
func : batch_norm_grad
data_type : out_grad
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/inconsistent/static_ops.yaml
@@ -71,6 +71,7 @@
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
infer_meta:
func : BatchNormInferMeta
spmd_rule : BatchNormInferSpmd
kernel :
func : batch_norm
data_type : x
107 changes: 107 additions & 0 deletions test/cpp/auto_parallel/spmd_rule_test.cc
@@ -2605,6 +2605,113 @@ TEST(MeanAll, Ctor) {
check_dim_mapping(backward_info.first[1], {});
check_dim_mapping(backward_info.second[0], {-1, -1});
}
TEST(BatchNorm, Ctor) {
std::vector<int64_t> mesh_shape = {2, 2};
std::vector<int64_t> process_ids = {0, 1, 2, 3};
std::vector<std::string> dim_names = {"x", "y"};
ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);

// test forward
// data_format = NCHW
// [0, 1, -1, -1], [-1], [-1], [-1], [-1] -> [-1, 1, -1, -1], [1], [1], [1], [1], [-1]
auto x_dist_attr = TensorDistAttr();
x_dist_attr.set_process_mesh(process_mesh);
x_dist_attr.set_dims_mapping({0, 1, -1, -1});
x_dist_attr.set_dynamic_dims({false, false, false, false});
auto one_dim_dist_attr = TensorDistAttr();
one_dim_dist_attr.set_process_mesh(process_mesh);
one_dim_dist_attr.set_dims_mapping({-1});
one_dim_dist_attr.set_dynamic_dims({false});

phi::distributed::DistMetaTensor x = phi::distributed::DistMetaTensor(
common::make_ddim({16, 16, 16, 16}), x_dist_attr);
phi::distributed::DistMetaTensor mean = phi::distributed::DistMetaTensor(
common::make_ddim({16}), one_dim_dist_attr);
phi::distributed::DistMetaTensor variance = phi::distributed::DistMetaTensor(
common::make_ddim({16}), one_dim_dist_attr);
phi::distributed::DistMetaTensor scale = phi::distributed::DistMetaTensor(
common::make_ddim({16}), one_dim_dist_attr);
phi::distributed::DistMetaTensor bias = phi::distributed::DistMetaTensor(
common::make_ddim({16}), one_dim_dist_attr);
phi::distributed::SpmdInfo forward_info =
phi::distributed::BatchNormInferSpmdStatic(
x, mean, variance, scale, bias);

EXPECT_EQ(forward_info.first.size(), 5UL);
EXPECT_EQ(forward_info.second.size(), 6UL);
check_dim_mapping(forward_info.first[0], {-1, 1, -1, -1});
check_dim_mapping(forward_info.first[1], {1});
check_dim_mapping(forward_info.first[2], {1});
check_dim_mapping(forward_info.first[3], {-1});
check_dim_mapping(forward_info.first[4], {-1});
check_dim_mapping(forward_info.second[0], {-1, 1, -1, -1});
check_dim_mapping(forward_info.second[1], {1});
check_dim_mapping(forward_info.second[2], {1});
check_dim_mapping(forward_info.second[3], {1});
check_dim_mapping(forward_info.second[4], {1});
check_dim_mapping(forward_info.second[5], {-1});

// test backward
// data_format = NCHW
// [0, 1, -1, -1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [0, 1, -1, -1]
// -> [-1, 1, -1, -1], [-1], [-1]
// dst_input: [-1, 1, -1, -1], [-1], [-1], [1], [1], [1], [1], [-1], [-1, 1, -1, -1]

x = phi::distributed::DistMetaTensor(common::make_ddim({16, 16, 16, 16}),
x_dist_attr);
phi::distributed::DistMetaTensor out_grad = phi::distributed::DistMetaTensor(
common::make_ddim({16, 16, 16, 16}), x_dist_attr);
phi::distributed::DistMetaTensor mean_out = phi::distributed::DistMetaTensor(
common::make_ddim({16}), one_dim_dist_attr);
phi::distributed::DistMetaTensor variance_out =
phi::distributed::DistMetaTensor(common::make_ddim({16}),
one_dim_dist_attr);
scale = phi::distributed::DistMetaTensor(common::make_ddim({16}),
one_dim_dist_attr);
bias = phi::distributed::DistMetaTensor(common::make_ddim({16}),
one_dim_dist_attr);
phi::distributed::DistMetaTensor saved_mean =
phi::distributed::DistMetaTensor(common::make_ddim({16}),
one_dim_dist_attr);
phi::distributed::DistMetaTensor saved_variance =
phi::distributed::DistMetaTensor(common::make_ddim({16}),
one_dim_dist_attr);
phi::distributed::DistMetaTensor reserve_space =
phi::distributed::DistMetaTensor(common::make_ddim({16}),
one_dim_dist_attr);
phi::distributed::SpmdInfo backward_info =
phi::distributed::BatchNormGradInferSpmd(x,
scale,
bias,
mean_out,
variance_out,
saved_mean,
saved_variance,
reserve_space,
out_grad,
0.9,
0.1,
"NCHW",
false,
false,
false);

EXPECT_EQ(backward_info.first.size(), 9UL);
EXPECT_EQ(backward_info.second.size(), 3UL);
check_dim_mapping(backward_info.first[0], {-1, 1, -1, -1});
check_dim_mapping(backward_info.first[1], {-1});
check_dim_mapping(backward_info.first[2], {-1});
check_dim_mapping(backward_info.first[3], {1});
check_dim_mapping(backward_info.first[4], {1});
check_dim_mapping(backward_info.first[5], {1});
check_dim_mapping(backward_info.first[6], {1});
check_dim_mapping(backward_info.first[7], {-1});
check_dim_mapping(backward_info.first[8], {-1, 1, -1, -1});

check_dim_mapping(backward_info.second[0], {-1, 1, -1, -1});
check_dim_mapping(backward_info.second[1], {-1});
check_dim_mapping(backward_info.second[2], {-1});
}
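Read together, the forward assertions above encode a simple propagation pattern for NCHW input (this is a summary of the expected values, not code from this PR): the batch axis is forced to replicated, since the statistics reduce over N/H/W, while the per-channel tensors follow the channel axis of x. A small illustrative sketch, with hypothetical helper names:

// Illustrative only: distills the expected values asserted above; this is
// not the implementation in batch_norm.cc.
#include <cstdint>
#include <vector>

// For NCHW input, drop any sharding on the batch axis.
std::vector<int64_t> ExpectedXMapping(std::vector<int64_t> x_mapping) {
  x_mapping[0] = -1;  // e.g. {0, 1, -1, -1} -> {-1, 1, -1, -1}
  return x_mapping;
}

// Per-channel statistics inherit the channel-axis mapping of x.
std::vector<int64_t> ExpectedStatsMapping(
    const std::vector<int64_t>& x_mapping) {
  return {x_mapping[1]};  // e.g. {0, 1, -1, -1} -> {1}
}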
TEST(Topk, Ctor) {
std::vector<int64_t> mesh_shape = {2, 2};
std::vector<int64_t> process_ids = {0, 1, 2, 3};