Add warning for fp16 inputs with MXNET_SAFE_ACCUMULATION=0 (apache#15046)
eric-haibin-lin authored and haohuw committed Jun 23, 2019
1 parent 3c3cb55 commit 53bb668
Showing 4 changed files with 33 additions and 9 deletions.
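The change only adds a log message; numerics are unchanged. For context, the precision issue the warning points at is easy to reproduce outside MXNet: accumulating many small float16 values in a float16 accumulator drops low-order bits once the running sum grows, while a float32 accumulator keeps them. The following standalone NumPy sketch (illustrative only, not part of this commit) shows the effect:

import numpy as np

x = np.full(10000, 0.01, dtype=np.float16)  # true sum is roughly 100

acc16 = np.float16(0.0)
for v in x:  # accumulate in float16
    acc16 = np.float16(acc16 + v)

acc32 = np.float32(0.0)
for v in x:  # accumulate in float32
    acc32 += np.float32(v)

print(acc16)  # stops growing once the float16 spacing exceeds 0.01 (around 32)
print(acc32)  # close to 100

This is the failure mode that MXNET_SAFE_ACCUMULATION=1 avoids by accumulating float16 inputs in float32.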
21 changes: 15 additions & 6 deletions src/operator/nn/layer_norm-inl.h
@@ -114,10 +114,18 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
});
});
workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);

bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for float16 inputs for LayerNorm. "
"See https://mxnet.incubator.apache.org/versions/master/faq/env_var.html "
"for more details.");
}

// Calculate mean
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, mean_data, req[0], workspace, in_data);
} else {
@@ -136,7 +144,7 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
const TBlob centered_out = outputs[0].reshape(red_src_shape);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, false>(
s, std_data, req[0], workspace, centered_out);
} else {
@@ -251,10 +259,11 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
{normalized_data, std},
{kWriteTo}, {normalized_data});
// Calculate grad_beta
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
if (req[2] != kNullOp) {
MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
ograd.reshape(red_exclude_src_shape));
@@ -272,7 +281,7 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
if (req[1] != kNullOp) {
MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
ograd_mult.reshape(red_exclude_src_shape));
@@ -297,7 +306,7 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
{kWriteTo}, {ograd_mult});
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
ograd_mult.reshape(red_src_shape));
@@ -317,7 +326,7 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
{kWriteTo}, {ograd_mult});
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
ograd_mult.reshape(red_src_shape));
5 changes: 5 additions & 0 deletions src/operator/nn/layer_norm.cu
@@ -340,6 +340,11 @@ void LayerNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
if (axis == inputs[0].ndim() - 1) {
// Try to use the accelerated CUDA kernels
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for LayerNorm with float16 inputs. "
"See https://mxnet.incubator.apache.org/versions/master/faq/env_var.html "
"for more details.");
}
if (safe_acc) {
return LayerNormGPUContig<true>(param, ctx, inputs, req, outputs);
} else {
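On the user side, the recommendation in the new warning amounts to setting the environment variable before running float16 LayerNorm. A minimal usage sketch with the MXNet 1.x Python frontend (assumed here; not part of this commit):

import os
os.environ["MXNET_SAFE_ACCUMULATION"] = "1"  # accumulate reductions in float32

import mxnet as mx

x = mx.nd.random.uniform(shape=(32, 1024)).astype('float16')
gamma = mx.nd.ones((1024,), dtype='float16')
beta = mx.nd.zeros((1024,), dtype='float16')

# With the flag set, the warning is not logged and the mean/variance
# reductions inside LayerNorm use a float32 accumulator.
y = mx.nd.LayerNorm(data=x, gamma=gamma, beta=beta, axis=-1)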
5 changes: 5 additions & 0 deletions src/operator/nn/softmax-inl.h
@@ -411,6 +411,11 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
param.temperature.value() : 1.0;
mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for softmax with float16 inputs. "
"See https://mxnet.incubator.apache.org/versions/master/faq/env_var.html "
"for more details.");
}

MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
11 changes: 8 additions & 3 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -1183,17 +1183,22 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs,
} else {
small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, false);
}

bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for LpNorm with float16 inputs. "
"See https://mxnet.incubator.apache.org/versions/master/faq/env_var.html "
"for more details.");
}
if (param.ord == 1) {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, false, mshadow_op::abs>(
ctx, inputs, req, outputs, small);
} else {
ReduceAxesComputeImpl<xpu, mshadow_op::sum, false, false, mshadow_op::abs>(
ctx, inputs, req, outputs, small);
}
} else if (param.ord == 2) {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
ReduceAxesComputeImpl<xpu, mshadow_op::nrm2, true, false, mshadow_op::identity>(
ctx, inputs, req, outputs, small);
} else {
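The same flag covers the other two operators touched by this commit. A sketch comparing float16 softmax and L2 norm against a float32 reference (illustrative only; the operator names follow the MXNet 1.x Python frontend, and exact error magnitudes depend on the data):

import os
os.environ["MXNET_SAFE_ACCUMULATION"] = "1"

import mxnet as mx

x16 = mx.nd.random.uniform(shape=(4, 50000)).astype('float16')
x32 = x16.astype('float32')

# softmax over a long axis: the sum of exponentials is the reduction that
# benefits from float32 accumulation
s16 = mx.nd.softmax(x16, axis=-1)
s32 = mx.nd.softmax(x32, axis=-1)

# L2 norm, i.e. the param.ord == 2 path in broadcast_reduce_op.h
n16 = mx.nd.norm(x16, ord=2, axis=-1)
n32 = mx.nd.norm(x32, ord=2, axis=-1)

print(mx.nd.abs(s16.astype('float32') - s32).max())
print(mx.nd.abs(n16.astype('float32') - n32).max())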
