Commit

Safe LayerNorm (apache#15002)
* use float32 to store the reduction result of float16

enable safe accumulation

fix bug

fix

* update test for safe_accumulate

* fix
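
In float16, the LayerNorm reductions changed below previously accumulated their running sums in float16 as well, so long sums lose precision or stop growing entirely. A rough NumPy sketch of that failure mode and of the float32-accumulator behavior this commit enables under MXNET_SAFE_ACCUMULATION (illustrative only, not MXNet code; the element count and value are arbitrary):

import numpy as np

x = np.full(100_000, 0.01, dtype=np.float16)   # true sum is roughly 1000

acc16 = np.float16(0.0)   # running sum kept in half precision
acc32 = np.float32(0.0)   # running sum kept in single precision
for v in x:
    # Once acc16 reaches about 32, adding 0.01 falls below half the float16
    # spacing and rounds away, so the float16 running sum stalls.
    acc16 = np.float16(acc16 + v)
    # A float32 accumulator keeps absorbing the small increments, which is
    # what the safe-accumulation path does for float16 inputs.
    acc32 += np.float32(v)

print(acc16)   # far below the true sum
print(acc32)   # close to 1000
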
sxjscience authored and Rohit Kumar Srivastava committed May 22, 2019
1 parent 89a3852 commit 5dbf828
Showing 3 changed files with 245 additions and 150 deletions.
102 changes: 68 additions & 34 deletions src/operator/nn/layer_norm-inl.h
@@ -116,8 +116,13 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
   // Calculate mean
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
-      broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
-        s, mean_data, req[0], workspace, in_data);
+      if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
+          s, mean_data, req[0], workspace, in_data);
+      } else {
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
+          s, mean_data, req[0], workspace, in_data);
+      }
       Tensor<xpu, 1, DType> mean_data_tensor = mean_data.FlatTo1D<xpu, DType>(s);
       mean_data_tensor /= scalar<DType>(channel_size);
     });
@@ -130,25 +135,30 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
   const TBlob centered_out = outputs[0].reshape(red_src_shape);
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
-      broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::square>(
-        s, std_data, req[0], workspace, centered_out);
+      if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, true>(
+          s, std_data, req[0], workspace, centered_out);
+      } else {
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, false>(
+          s, std_data, req[0], workspace, centered_out);
+      }
       Tensor<xpu, 1, DType> std_data_tensor = std_data.FlatTo1D<xpu, DType>(s);
       std_data_tensor = F<mshadow_op::square_root>(std_data_tensor / scalar<DType>(channel_size)
                                                    + scalar<DType>(param.eps));
     });
   });
   // Calculate data = data / std
-  BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
-                                                   {outputs[0], outputs[layernorm::kStd]},
-                                                   {kWriteTo}, {outputs[0]});
+  BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
+                                               {outputs[0], outputs[layernorm::kStd]},
+                                               {kWriteTo}, {outputs[0]});
   // Calculate data = data * gamma
-  BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
-                                                   {outputs[0], gamma},
-                                                   {kWriteTo}, {outputs[0]});
+  BinaryBroadcastCompute<xpu, mshadow_op::mul>(attrs, ctx,
+                                               {outputs[0], gamma},
+                                               {kWriteTo}, {outputs[0]});
   // Calculate data = data + beta
-  BinaryBroadcastCompute<xpu, op::mshadow_op::plus>(attrs, ctx,
-                                                    {outputs[0], beta},
-                                                    {kWriteTo}, {outputs[0]});
+  BinaryBroadcastCompute<xpu, mshadow_op::plus>(attrs, ctx,
+                                                {outputs[0], beta},
+                                                {kWriteTo}, {outputs[0]});
 }
 
 template<typename xpu>
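
For reference, the forward pass this hunk modifies is the standard LayerNorm over the normalized axis (matching the comments above); with channel size C it computes, per instance,

\mu = \frac{1}{C}\sum_{i=1}^{C} x_i, \qquad
\sigma = \sqrt{\frac{1}{C}\sum_{i=1}^{C}(x_i - \mu)^2 + \epsilon}, \qquad
y_i = \gamma_i \, \frac{x_i - \mu}{\sigma} + \beta_i

With MXNET_SAFE_ACCUMULATION=1 and float16 inputs, the two sums above are accumulated in float32.
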
@@ -233,19 +243,25 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
   const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2,
                               mean.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id());
   // Compute normalized_data = (data - mean) / std
-  BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
-                                                     {data, mean},
-                                                     {kWriteTo}, {normalized_data});
-  BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
-                                                   {normalized_data, std},
-                                                   {kWriteTo}, {normalized_data});
+  BinaryBroadcastCompute<xpu, mshadow_op::minus>(attrs, ctx,
+                                                 {data, mean},
+                                                 {kWriteTo}, {normalized_data});
+  BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
+                                               {normalized_data, std},
+                                               {kWriteTo}, {normalized_data});
   // Calculate grad_beta
   if (req[2] != kNullOp) {
     MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
-        broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
-          s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
-          ograd.reshape(red_exclude_src_shape));
+        if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
+            s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
+            ograd.reshape(red_exclude_src_shape));
+        } else {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
+            s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
+            ograd.reshape(red_exclude_src_shape));
+        }
       });
     });
   }
@@ -255,9 +271,15 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
   if (req[1] != kNullOp) {
     MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
-        broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
-          s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
-          ograd_mult.reshape(red_exclude_src_shape));
+        if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
+            s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
+            ograd_mult.reshape(red_exclude_src_shape));
+        } else {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
+            s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
+            ograd_mult.reshape(red_exclude_src_shape));
+        }
       });
     });
   }
@@ -274,9 +296,15 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
                                     {kWriteTo}, {ograd_mult});
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
-        broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
-          s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
-          ograd_mult.reshape(red_src_shape));
+        if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
+            s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+            ograd_mult.reshape(red_src_shape));
+        } else {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
+            s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+            ograd_mult.reshape(red_src_shape));
+        }
       });
       Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
       red_out_tensor /= scalar<DType>(channel_size);
@@ -288,16 +316,22 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
                                     {kWriteTo}, {ograd_mult});
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
-        broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity>(
-          s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
-          ograd_mult.reshape(red_src_shape));
+        if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
+            s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+            ograd_mult.reshape(red_src_shape));
+        } else {
+          broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
+            s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+            ograd_mult.reshape(red_src_shape));
+        }
       });
       Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
       red_out_tensor /= scalar<DType>(- channel_size);
     });
-    BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
-                                                     {normalized_data, red_out},
-                                                     {kAddTo}, {outputs[0]});
+    BinaryBroadcastCompute<xpu, mshadow_op::mul>(attrs, ctx,
+                                                 {normalized_data, red_out},
+                                                 {kAddTo}, {outputs[0]});
   }
 }
 
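For completeness, a hedged usage sketch of the new behavior from the Python API (assuming a build that contains this commit; the environment variable is read via dmlc::GetEnv inside the kernels above, so it must be set before the operator runs):

import os
os.environ["MXNET_SAFE_ACCUMULATION"] = "1"   # opt in to float32 accumulation

import mxnet as mx

x = mx.nd.random.normal(shape=(8, 1024)).astype('float16')
gamma = mx.nd.ones((1024,), dtype='float16')
beta = mx.nd.zeros((1024,), dtype='float16')

# With the flag set, the mean/std reductions in the forward pass (and the
# reductions in the backward pass) accumulate in float32 for this float16 input.
y = mx.nd.LayerNorm(x, gamma, beta, axis=-1, eps=1e-5)
print(y.dtype, y.shape)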