[AutoParallel] Revise Infermeta of LayerNorm for Sequence-Data Hybrid Parallelism (#58776)

* modify infermeta

* bugfix for kernel and spmd

* fix prim

* update unittest
JZ-LIANG authored Nov 14, 2023
1 parent fdfd63b commit db105fd
Showing 9 changed files with 192 additions and 171 deletions.
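Before the per-file diffs, a rough sketch of the contract being revised (plain numpy, shapes assumed for illustration; this is not Paddle code). Previously mean and variance were flattened to a single axis of length prod(x.shape[:begin_norm_axis]); after this change they keep x.shape[:begin_norm_axis], so under sequence-data hybrid parallelism the batch and sequence axes can each carry their own shard mapping.

import numpy as np

x_shape = (2, 4, 8)      # assumed [batch, seq, hidden]
begin_norm_axis = 2

# old contract: statistics flattened to one axis
old_stat_shape = (int(np.prod(x_shape[:begin_norm_axis])),)   # (8,)
# new contract: statistics keep the leading axes of x
new_stat_shape = x_shape[:begin_norm_axis]                    # (2, 4)

# a flattened (8,) mean can be sharded along at most one mesh dimension;
# a (2, 4) mean can map batch and sequence to different mesh dimensions
print(old_stat_shape, new_stat_shape)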
33 changes: 19 additions & 14 deletions paddle/phi/infermeta/spmd_rules/layer_norm.cc
@@ -67,24 +67,27 @@ SpmdInfo LayerNormInferSpmd(const DistMetaTensor& x,
// x[0:begin_norm_axis], only the first axis of x can
// be sharded
std::string x_axes(x_ndim, '1');
x_axes[0] = alphabet[0];
std::string mean_axes(begin_norm_axis, '1');
std::string variance_axes(begin_norm_axis, '1');
// allow axes before begin_norm_axis to be sharded
for (int i = 0; i < begin_norm_axis; ++i) {
x_axes[i] = alphabet[i];
mean_axes[i] = alphabet[i];
variance_axes[i] = alphabet[i];
}
// x_axes[0] = alphabet[0];
std::string scale_axes(1, x_axes[x_ndim - 1]);
std::string bias_axes(1, x_axes[x_ndim - 1]);

// get output notation
std::string out_axes = x_axes;
std::string mean_axes(1, '1'), variance_axes(1, '1');
if (begin_norm_axis > 0) {
mean_axes[0] = out_axes[0];
variance_axes[0] = out_axes[0];
}

// Step2: Sharding Propagation
// Step2.1: merge input sharding
// As the mean and variance in outputs are `flattened` from
// x[0:begin_norm_axis], only the first axis can be sharded,
// the axes 1 to begin_norm_axis-1 are set to be replicated.
std::fill(x_dims_mapping.begin() + 1, x_dims_mapping.end(), -1);
std::fill(x_dims_mapping.begin() + begin_norm_axis, x_dims_mapping.end(), -1);
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors({{x_axes, x_dims_mapping}});

@@ -199,16 +202,18 @@ SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x,
// the axes after norm_axis should be replicated,
// so set their notation to '1'.
std::string x_axes(x_ndim, '1');
x_axes[0] = alphabet[0];
std::string mean_axes(begin_norm_axis, '1');
std::string variance_axes(begin_norm_axis, '1');
// allow axes before begin_norm_axis to be sharded
for (int i = 0; i < begin_norm_axis; ++i) {
x_axes[i] = alphabet[i];
mean_axes[i] = alphabet[i];
variance_axes[i] = alphabet[i];
}

std::string scale_axes(1, x_axes[x_ndim - 1]);
std::string bias_axes(1, x_axes[x_ndim - 1]);

std::string out_axes = x_axes;
std::string mean_axes(1, '1'), variance_axes(1, '1');
if (begin_norm_axis > 0) {
mean_axes[0] = out_axes[0];
variance_axes[0] = out_axes[0];
}

// Step2: Sharding Propagation
// For the axes after norm_axis in both input and output tensors,
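Both hunks above (the forward rule LayerNormInferSpmd and the reverse rule LayerNormInferSpmdReverse) switch from sharding only axis 0 to building one einsum-style axis per dimension before begin_norm_axis. A minimal Python sketch of that notation logic, assuming a toy dims mapping and using a plain dict lookup as a stand-in for ShardingMergeForTensors (not the phi API):

import string

def layer_norm_spmd_notation(x_dims_mapping, begin_norm_axis):
    alphabet = string.ascii_lowercase
    x_ndim = len(x_dims_mapping)
    # '1' marks a replicated axis in the notation
    x_axes = ['1'] * x_ndim
    mean_axes = ['1'] * begin_norm_axis
    variance_axes = ['1'] * begin_norm_axis
    # revised behavior: every axis before begin_norm_axis may be sharded,
    # not just axis 0
    for i in range(begin_norm_axis):
        x_axes[i] = alphabet[i]
        mean_axes[i] = alphabet[i]
        variance_axes[i] = alphabet[i]
    # axes at or after begin_norm_axis are forced to replicated (-1),
    # mirroring the std::fill(begin() + begin_norm_axis, ...) change
    x_dims_mapping = list(x_dims_mapping)
    for i in range(begin_norm_axis, x_ndim):
        x_dims_mapping[i] = -1
    axis_to_dim = {a: d for a, d in zip(x_axes, x_dims_mapping) if a != '1'}
    out = [axis_to_dim.get(a, -1) for a in x_axes]
    mean = [axis_to_dim.get(a, -1) for a in mean_axes]
    variance = [axis_to_dim.get(a, -1) for a in variance_axes]
    return out, mean, variance

# x sharded as [dp, sp, replicated] with begin_norm_axis = 2
print(layer_norm_spmd_notation([0, 1, -1], 2))   # ([0, 1, -1], [0, 1], [0, 1])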
9 changes: 6 additions & 3 deletions paddle/phi/infermeta/ternary.cc
@@ -579,7 +579,10 @@ void LayerNormInferMeta(const MetaTensor& x,
x_dim.size()));

auto matrix_dim = phi::flatten_to_2d(x_dim, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);

// keep the axis size before normalization for shape of variance and mean
auto before_norm_dims = slice_ddim(x_dim, 0, begin_norm_axis);
// int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
if (scale) {
PADDLE_ENFORCE_EQ(scale.dims().size(),
@@ -644,11 +647,11 @@
? phi::DataType::FLOAT32
: x_dtype;
if (mean) {
mean->set_dims({left});
mean->set_dims({before_norm_dims});
mean->set_dtype(param_type);
}
if (variance) {
variance->set_dims({left});
variance->set_dims({before_norm_dims});
variance->set_dtype(param_type);
}
}
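A hedged Python mirror of the revised shape rule in LayerNormInferMeta: mean and variance now take slice_ddim(x_dim, 0, begin_norm_axis) instead of the flattened left dimension (tuples stand in for DDim; flatten_to_2d is reproduced inline):

def layer_norm_infer_shapes(x_shape, begin_norm_axis):
    left = 1
    for d in x_shape[:begin_norm_axis]:
        left *= d
    right = 1
    for d in x_shape[begin_norm_axis:]:
        right *= d
    # before: mean/variance got the flattened shape (left,)
    # after:  they keep the leading axes of x
    before_norm_dims = tuple(x_shape[:begin_norm_axis])
    return {"out": tuple(x_shape),
            "mean": before_norm_dims,
            "variance": before_norm_dims,
            "matrix_2d": (left, right)}

print(layer_norm_infer_shapes((2, 4, 8), 2))
# {'out': (2, 4, 8), 'mean': (2, 4), 'variance': (2, 4), 'matrix_2d': (8, 8)}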
13 changes: 10 additions & 3 deletions paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc
@@ -52,8 +52,15 @@ void LayerNormGradKernel(const Context& dev_ctx,
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
DDim matrix_shape({left, right});
DDim var_shape({left});

d_y.Resize(matrix_shape);
// resize mean and var to match the shape of resized d_y for broadcast (Resize
// will not modify the underlying data)
auto mean_tmp = mean;
mean_tmp.Resize(var_shape);
auto variance_tmp = variance;
variance_tmp.Resize(var_shape);

funcs::ColwiseSum2D<phi::CPUContext, T> colwise_sum(left, right, dev_ctx);
DenseTensor x_tmp = x;
@@ -69,11 +76,11 @@
dev_ctx.template Alloc<T>(&temp_norm);
// get x_norm
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
dev_ctx, x_tmp, mean, funcs::SubtractFunctor<T>(), &temp_norm, 0);
dev_ctx, x_tmp, mean_tmp, funcs::SubtractFunctor<T>(), &temp_norm, 0);
phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T>(
dev_ctx,
temp_norm,
variance,
variance_tmp,
funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
&temp_norm,
0);
@@ -137,7 +144,7 @@ void LayerNormGradKernel(const Context& dev_ctx,
phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T>(
dev_ctx,
*d_x,
variance,
variance_tmp,
funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
d_x,
0);
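The grad kernel keeps computing over the flattened [left, right] view, so the now multi-dimensional mean/variance are reinterpreted as a 1-D [left] tensor for the broadcast. A small numpy analogue of that Resize trick (reshape of a contiguous array as a stand-in for DenseTensor::Resize; shapes assumed for illustration):

import numpy as np

begin_norm_axis = 2
mean = np.zeros((2, 4), dtype="float32")   # shaped like x.shape[:begin_norm_axis]
left, right = mean.size, 16                # flattened row count, assumed feature size

mean_1d = mean.reshape(left)               # 1-D view over the same buffer
assert np.shares_memory(mean, mean_1d)     # metadata change only, no copy
d_y_2d = np.ones((left, right), dtype="float32")   # flattened gradient view
# each row's mean broadcasts across that row, as in ElementwiseCompute with axis 0
print((d_y_2d - mean_1d[:, None]).shape)   # (8, 16)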
30 changes: 19 additions & 11 deletions paddle/phi/kernels/cpu/layer_norm_kernel.cc
@@ -50,36 +50,44 @@ void LayerNormKernel(const Context& dev_ctx,
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
DDim matrix_shape({left, right});
DDim normalized_shape({left});

auto x_tmp = x;
x_tmp.Resize(matrix_shape);
DenseTensor out;
out.ShareDataWith(*y);
out.Resize(matrix_shape);
// resize mean and var to match the shape of resized x_tmp for broadcast
DenseTensor mean_tmp;
mean_tmp.ShareDataWith(*mean);
mean_tmp.Resize(normalized_shape);
DenseTensor var_tmp;
var_tmp.ShareDataWith(*var);
var_tmp.Resize(normalized_shape);

#if defined(PADDLE_WITH_CUDA) || defined(_WIN32) || defined(__APPLE__) || \
defined(__OSX__)

funcs::RowwiseMean2D<phi::CPUContext, T> row_mean(left, right, dev_ctx);

// get mean
row_mean(dev_ctx, x_tmp, mean);
row_mean(dev_ctx, x_tmp, &mean_tmp);

// get variance

phi::funcs::ElementwiseCompute<funcs::SubAndSquareFunctor<T>, T>(
dev_ctx, x_tmp, *mean, funcs::SubAndSquareFunctor<T>(), &out, 0);
dev_ctx, x_tmp, mean_tmp, funcs::SubAndSquareFunctor<T>(), &out, 0);

row_mean(dev_ctx, out, var);
row_mean(dev_ctx, out, &var_tmp);

// get x_norm
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
dev_ctx, x_tmp, *mean, funcs::SubtractFunctor<T>(), &out, 0);
dev_ctx, x_tmp, mean_tmp, funcs::SubtractFunctor<T>(), &out, 0);

phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T>(
dev_ctx,
out,
*var,
var_tmp,
funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
&out,
0);
@@ -93,17 +101,17 @@ void LayerNormKernel(const Context& dev_ctx,
dev_ctx, out, *bias, funcs::AddFunctor<T>(), &out, 1);
}
#else
PADDLE_ENFORCE_EQ(mean->numel(),
PADDLE_ENFORCE_EQ(mean_tmp.numel(),
left,
phi::errors::InvalidArgument(
"mean's length (%d) is not equal with expected (%d).",
mean->numel(),
mean_tmp.numel(),
left));
PADDLE_ENFORCE_EQ(var->numel(),
PADDLE_ENFORCE_EQ(var_tmp.numel(),
left,
phi::errors::InvalidArgument(
"var's length (%d) is not equal with expected (%d).",
var->numel(),
var_tmp.numel(),
left));
if (scale) {
PADDLE_ENFORCE_EQ(
@@ -128,8 +136,8 @@
.At(right);
ker(x_tmp.data<T>(),
out.data<T>(),
mean->data<T>(),
var->data<T>(),
mean_tmp.data<T>(),
var_tmp.data<T>(),
scale ? scale->data<T>() : nullptr,
bias ? bias->data<T>() : nullptr,
static_cast<int>(left),
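For reference, a hedged numpy version of what the forward CPU path now produces: the math is unchanged and still runs on the flattened [left, right] view; only the reported shapes of mean/variance differ. Epsilon and the test shapes below are assumed for illustration.

import numpy as np

def layer_norm_ref(x, scale, bias, epsilon, begin_norm_axis):
    left = int(np.prod(x.shape[:begin_norm_axis]))
    right = int(np.prod(x.shape[begin_norm_axis:]))
    x2d = x.reshape(left, right)
    mean = x2d.mean(axis=1)                            # computed as [left]
    var = ((x2d - mean[:, None]) ** 2).mean(axis=1)    # biased row variance
    out = (x2d - mean[:, None]) / np.sqrt(var[:, None] + epsilon)
    if scale is not None:
        out = out * scale.reshape(1, right)
    if bias is not None:
        out = out + bias.reshape(1, right)
    # outputs reported with the revised shapes
    return (out.reshape(x.shape),
            mean.reshape(x.shape[:begin_norm_axis]),
            var.reshape(x.shape[:begin_norm_axis]))

x = np.random.rand(2, 4, 8).astype("float32")
out, mean, var = layer_norm_ref(x, np.ones(8), np.zeros(8), 1e-5, 2)
print(out.shape, mean.shape, var.shape)   # (2, 4, 8) (2, 4) (2, 4)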
5 changes: 3 additions & 2 deletions python/paddle/incubate/autograd/composite_rules.py
@@ -179,8 +179,9 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
bias = reshape(bias, x.shape[begin_norm_axis:])
out = out + bias

mean_ = reshape(mean_, [-1])
variance = reshape(variance, [-1])
# keep the mean and variance shape as input x before begin_norm_axis
mean_ = reshape(mean_, x.shape[:begin_norm_axis])
variance = reshape(variance, x.shape[:begin_norm_axis])
if is_amp:
out = cast(out, dtype)
return out, mean_, variance
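A small numpy sketch of the composite-rule change above: the decomposed layer_norm now reshapes its statistics back to x.shape[:begin_norm_axis] instead of flattening them with reshape(..., [-1]) (numpy stands in for the paddle ops used in composite_rules.py):

import numpy as np

def layernorm_composite_ref(x, begin_norm_axis):
    axis = tuple(range(begin_norm_axis, x.ndim))
    mean_ = x.mean(axis=axis, keepdims=True)
    variance = x.var(axis=axis, keepdims=True)
    # old: mean_ = mean_.reshape(-1); variance = variance.reshape(-1)
    mean_ = mean_.reshape(x.shape[:begin_norm_axis])
    variance = variance.reshape(x.shape[:begin_norm_axis])
    return mean_, variance

m, v = layernorm_composite_ref(np.random.rand(2, 4, 8), 2)
print(m.shape, v.shape)   # (2, 4) (2, 4)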