[AutoParallel] Revise InferMeta of LayerNorm for Sequence-Data Hybrid Parallelism #58776

Merged 7 commits on Nov 14, 2023
Changes from 6 commits
33 changes: 19 additions & 14 deletions paddle/phi/infermeta/spmd_rules/layer_norm.cc
@@ -67,24 +67,27 @@ SpmdInfo LayerNormInferSpmd(const DistMetaTensor& x,
// x[0:begin_norm_axis], only the first axis of x can
// be sharded
std::string x_axes(x_ndim, '1');
x_axes[0] = alphabet[0];
std::string mean_axes(begin_norm_axis, '1');
std::string variance_axes(begin_norm_axis, '1');
// allow axes before begin_norm_axis to be sharded
for (int i = 0; i < begin_norm_axis; ++i) {
x_axes[i] = alphabet[i];
mean_axes[i] = alphabet[i];
variance_axes[i] = alphabet[i];
}
// x_axes[0] = alphabet[0];
std::string scale_axes(1, x_axes[x_ndim - 1]);
std::string bias_axes(1, x_axes[x_ndim - 1]);

// get output notation
std::string out_axes = x_axes;
std::string mean_axes(1, '1'), variance_axes(1, '1');
if (begin_norm_axis > 0) {
mean_axes[0] = out_axes[0];
variance_axes[0] = out_axes[0];
}

// Step2: Sharding Propagation
// Step2.1: merge input sharding
// As the mean and variance in outputs are `flattened` from
// x[0:begin_norm_axis], only the first axis can be sharded,
// the axes 1 to begin_norm_axis-1 are set to be replicated.
std::fill(x_dims_mapping.begin() + 1, x_dims_mapping.end(), -1);
std::fill(x_dims_mapping.begin() + begin_norm_axis, x_dims_mapping.end(), -1);
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors({{x_axes, x_dims_mapping}});

@@ -199,16 +202,18 @@ SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x,
// the axes after norm_axis should be replicated,
// so set their notation to '1'.
std::string x_axes(x_ndim, '1');
x_axes[0] = alphabet[0];
std::string mean_axes(begin_norm_axis, '1');
std::string variance_axes(begin_norm_axis, '1');
// allow axes before begin_norm_axis to be sharded
for (int i = 0; i < begin_norm_axis; ++i) {
x_axes[i] = alphabet[i];
mean_axes[i] = alphabet[i];
variance_axes[i] = alphabet[i];
}

std::string scale_axes(1, x_axes[x_ndim - 1]);
std::string bias_axes(1, x_axes[x_ndim - 1]);

std::string out_axes = x_axes;
std::string mean_axes(1, '1'), variance_axes(1, '1');
if (begin_norm_axis > 0) {
mean_axes[0] = out_axes[0];
variance_axes[0] = out_axes[0];
}

// Step2: Sharding Propagation
// For the axes after norm_axis in both input and output tensors,
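Note on the rule change above: before this PR only axis 0 of x received a notation letter, so only the batch axis of x, mean and variance could be sharded. The revised rule gives every axis before begin_norm_axis its own letter and keeps the normalized axes replicated ('1'). A minimal pure-Python sketch of that notation construction (illustrative only; the letter alphabet is an assumption, not the phi implementation):

    from string import ascii_lowercase

    def layer_norm_axes(x_ndim, begin_norm_axis):
        # '1' marks an axis that must stay replicated
        x_axes = ['1'] * x_ndim
        mean_axes = ['1'] * begin_norm_axis
        variance_axes = ['1'] * begin_norm_axis
        for i in range(begin_norm_axis):  # every leading axis keeps its own letter
            x_axes[i] = ascii_lowercase[i]
            mean_axes[i] = ascii_lowercase[i]
            variance_axes[i] = ascii_lowercase[i]
        return ''.join(x_axes), ''.join(mean_axes), ''.join(variance_axes)

    # rank-3 x with begin_norm_axis = 2: batch and sequence axes remain shardable
    print(layer_norm_axes(3, 2))  # ('ab1', 'ab', 'ab')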
9 changes: 6 additions & 3 deletions paddle/phi/infermeta/ternary.cc
@@ -579,7 +579,10 @@ void LayerNormInferMeta(const MetaTensor& x,
x_dim.size()));

auto matrix_dim = phi::flatten_to_2d(x_dim, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);

// keep the axis sizes before normalization for the shapes of mean and variance
auto before_norm_dims = slice_ddim(x_dim, 0, begin_norm_axis);
// int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
if (scale) {
PADDLE_ENFORCE_EQ(scale.dims().size(),
@@ -644,11 +647,11 @@ void LayerNormInferMeta(const MetaTensor& x,
? phi::DataType::FLOAT32
: x_dtype;
if (mean) {
mean->set_dims({left});
mean->set_dims({before_norm_dims});
mean->set_dtype(param_type);
}
if (variance) {
variance->set_dims({left});
variance->set_dims({before_norm_dims});
variance->set_dtype(param_type);
}
}
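For intuition on the InferMeta change: the element count of mean and variance is unchanged, but instead of flattening the leading axes into a single `left` dimension they now keep x's shape before begin_norm_axis. A framework-free sketch with assumed example shapes:

    from math import prod

    def layer_norm_stat_dims(x_dim, begin_norm_axis):
        before_norm_dims = list(x_dim[:begin_norm_axis])  # new behaviour
        left = prod(x_dim[:begin_norm_axis])              # old behaviour: one flattened dim
        return before_norm_dims, [left]

    # e.g. x_dim = [8, 1024, 768], begin_norm_axis = 2
    print(layer_norm_stat_dims([8, 1024, 768], 2))  # ([8, 1024], [8192])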
13 changes: 10 additions & 3 deletions paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc
@@ -52,8 +52,15 @@ void LayerNormGradKernel(const Context& dev_ctx,
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
DDim matrix_shape({left, right});
DDim var_shape({left});

d_y.Resize(matrix_shape);
// resize mean and var to match the shape of the resized d_y for broadcast (Resize
// will not modify the underlying data)
auto mean_tmp = mean;
mean_tmp.Resize(var_shape);
auto variance_tmp = variance;
variance_tmp.Resize(var_shape);

funcs::ColwiseSum2D<phi::CPUContext, T> colwise_sum(left, right, dev_ctx);
DenseTensor x_tmp = x;
@@ -69,11 +76,11 @@ void LayerNormGradKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(&temp_norm);
// get x_norm
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
dev_ctx, x_tmp, mean, funcs::SubtractFunctor<T>(), &temp_norm, 0);
dev_ctx, x_tmp, mean_tmp, funcs::SubtractFunctor<T>(), &temp_norm, 0);
phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T>(
dev_ctx,
temp_norm,
variance,
variance_tmp,
funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
&temp_norm,
0);
@@ -137,7 +144,7 @@ void LayerNormGradKernel(const Context& dev_ctx,
phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T>(
dev_ctx,
*d_x,
variance,
variance_tmp,
funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
d_x,
0);
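The grad kernel keeps working on a [left, right] matrix view, so the multi-dimensional mean/variance only need a metadata-level Resize on local copies. A plain-Python sketch of the row-wise broadcast this enables (toy shapes; not the phi kernel):

    from math import prod

    def rowwise_center(x_flat, x_dim, begin_norm_axis, mean_flat):
        left = prod(x_dim[:begin_norm_axis])
        right = prod(x_dim[begin_norm_axis:])
        assert len(mean_flat) == left  # Resize changes the view, not the buffer
        return [[x_flat[r * right + c] - mean_flat[r] for c in range(right)]
                for r in range(left)]

    # x of shape [2, 2, 3] with begin_norm_axis = 2 is viewed as [4, 3]; mean holds 4 values
    print(rowwise_center([float(v) for v in range(12)], [2, 2, 3], 2, [1.0, 4.0, 7.0, 10.0]))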
30 changes: 19 additions & 11 deletions paddle/phi/kernels/cpu/layer_norm_kernel.cc
@@ -50,36 +50,44 @@ void LayerNormKernel(const Context& dev_ctx,
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
DDim matrix_shape({left, right});
DDim normalized_shape({left});

auto x_tmp = x;
x_tmp.Resize(matrix_shape);
DenseTensor out;
out.ShareDataWith(*y);
out.Resize(matrix_shape);
// resize mean and var to match the shape of the resized x_tmp for broadcast
DenseTensor mean_tmp;
mean_tmp.ShareDataWith(*mean);
mean_tmp.Resize(normalized_shape);
DenseTensor var_tmp;
var_tmp.ShareDataWith(*var);
var_tmp.Resize(normalized_shape);

#if defined(PADDLE_WITH_CUDA) || defined(_WIN32) || defined(__APPLE__) || \
defined(__OSX__)

funcs::RowwiseMean2D<phi::CPUContext, T> row_mean(left, right, dev_ctx);

// get mean
row_mean(dev_ctx, x_tmp, mean);
row_mean(dev_ctx, x_tmp, &mean_tmp);

// get variance

phi::funcs::ElementwiseCompute<funcs::SubAndSquareFunctor<T>, T>(
dev_ctx, x_tmp, *mean, funcs::SubAndSquareFunctor<T>(), &out, 0);
dev_ctx, x_tmp, mean_tmp, funcs::SubAndSquareFunctor<T>(), &out, 0);

row_mean(dev_ctx, out, var);
row_mean(dev_ctx, out, &var_tmp);

// get x_norm
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
dev_ctx, x_tmp, *mean, funcs::SubtractFunctor<T>(), &out, 0);
dev_ctx, x_tmp, mean_tmp, funcs::SubtractFunctor<T>(), &out, 0);

phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T>(
dev_ctx,
out,
*var,
var_tmp,
funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
&out,
0);
@@ -93,17 +101,17 @@ void LayerNormKernel(const Context& dev_ctx,
dev_ctx, out, *bias, funcs::AddFunctor<T>(), &out, 1);
}
#else
PADDLE_ENFORCE_EQ(mean->numel(),
PADDLE_ENFORCE_EQ(mean_tmp.numel(),
left,
phi::errors::InvalidArgument(
"mean's length (%d) is not equal with expected (%d).",
mean->numel(),
mean_tmp.numel(),
left));
PADDLE_ENFORCE_EQ(var->numel(),
PADDLE_ENFORCE_EQ(var_tmp.numel(),
left,
phi::errors::InvalidArgument(
"var's length (%d) is not equal with expected (%d).",
var->numel(),
var_tmp.numel(),
left));
if (scale) {
PADDLE_ENFORCE_EQ(
@@ -128,8 +136,8 @@ void LayerNormKernel(const Context& dev_ctx,
.At(right);
ker(x_tmp.data<T>(),
out.data<T>(),
mean->data<T>(),
var->data<T>(),
mean_tmp.data<T>(),
var_tmp.data<T>(),
scale ? scale->data<T>() : nullptr,
bias ? bias->data<T>() : nullptr,
static_cast<int>(left),
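Same idea in the forward kernel: ShareDataWith plus Resize lets the kernel fill one mean and one variance per row of the [left, right] view while the caller's outputs keep the new x_dim[:begin_norm_axis] shape over the same buffer. A framework-free sketch of the per-row statistics (toy shapes; biased variance as in layer_norm):

    from math import prod

    def rowwise_mean_var(x_flat, x_dim, begin_norm_axis):
        left = prod(x_dim[:begin_norm_axis])
        right = prod(x_dim[begin_norm_axis:])
        rows = [x_flat[r * right:(r + 1) * right] for r in range(left)]
        means = [sum(row) / right for row in rows]
        variances = [sum((v - m) ** 2 for v in row) / right
                     for row, m in zip(rows, means)]
        return means, variances  # `left` values each, viewable as x_dim[:begin_norm_axis]

    print(rowwise_mean_var([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3], 1))  # ([2.0, 5.0], [0.666..., 0.666...])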
5 changes: 3 additions & 2 deletions python/paddle/incubate/autograd/composite_rules.py
@@ -179,8 +179,9 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
bias = reshape(bias, x.shape[begin_norm_axis:])
out = out + bias

mean_ = reshape(mean_, [-1])
variance = reshape(variance, [-1])
# keep the mean and variance shapes equal to the input x's shape before begin_norm_axis
mean_ = reshape(mean_, x.shape[:begin_norm_axis])
variance = reshape(variance, x.shape[:begin_norm_axis])
if is_amp:
out = cast(out, dtype)
return out, mean_, variance
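Why the statistics need x.shape[:begin_norm_axis] rather than a flat [-1]: only then can both the data-parallel batch sharding and the sequence-parallel sharding be expressed on mean and variance. An illustrative sketch of the propagation (the dims_mapping convention of -1 for replicated and mesh-axis indices otherwise is assumed):

    def propagate_to_stats(x_dims_mapping, begin_norm_axis):
        # the normalized axes of x must be replicated
        x_out = x_dims_mapping[:begin_norm_axis] + [-1] * (len(x_dims_mapping) - begin_norm_axis)
        stats = x_dims_mapping[:begin_norm_axis]  # expressible only because stats keep these axes
        return x_out, stats

    # batch sharded on mesh axis 0, sequence on mesh axis 1 (sequence-data hybrid parallelism)
    print(propagate_to_stats([0, 1, -1], begin_norm_axis=2))  # ([0, 1, -1], [0, 1])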