From e68225bb328a77956ed03a8d7672f89d0f3b2599 Mon Sep 17 00:00:00 2001
From: sxjscience
Date: Mon, 15 Apr 2019 15:37:08 +0800
Subject: [PATCH 1/2] try to use safe_acc

---
 src/operator/nn/layer_norm-inl.h | 48 ++++++++++++++++----------------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h
index dc4914bf2457..5a5d4027591e 100644
--- a/src/operator/nn/layer_norm-inl.h
+++ b/src/operator/nn/layer_norm-inl.h
@@ -111,7 +111,7 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs,
   // Calculate mean
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
-      broadcast::Reduce<red::sum, NDim, DType, op::identity>(
+      broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
         s, mean_data, req[0], workspace, in_data);
       Tensor<xpu, 1, DType> mean_data_tensor = mean_data.FlatTo1D<xpu, DType>(s);
       mean_data_tensor /= scalar<DType>(channel_size);
@@ -125,7 +125,7 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs,
   const TBlob centered_out = outputs[0].reshape(red_src_shape);
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
-      broadcast::Reduce<red::sum, NDim, DType, op::identity>(
+      broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
         s, std_data, req[0], workspace, centered_out);
       Tensor<xpu, 1, DType> std_data_tensor = std_data.FlatTo1D<xpu, DType>(s);
       std_data_tensor = F<mshadow_op::square_root>(std_data_tensor / scalar<DType>(channel_size)
@@ -133,17 +133,17 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs,
     });
   });
   // Calculate data = data / std
-  BinaryBroadcastCompute<xpu, op::div>(attrs, ctx,
-                                       {outputs[0], outputs[layernorm::kStd]},
-                                       {kWriteTo}, {outputs[0]});
+  BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
+                                               {outputs[0], outputs[layernorm::kStd]},
+                                               {kWriteTo}, {outputs[0]});
   // Calculate data = data * gamma
-  BinaryBroadcastCompute<xpu, op::mul>(attrs, ctx,
-                                       {outputs[0], gamma},
-                                       {kWriteTo}, {outputs[0]});
+  BinaryBroadcastCompute<xpu, mshadow_op::mul>(attrs, ctx,
+                                               {outputs[0], gamma},
+                                               {kWriteTo}, {outputs[0]});
   // Calculate data = data + beta
-  BinaryBroadcastCompute<xpu, op::plus>(attrs, ctx,
-                                        {outputs[0], beta},
-                                        {kWriteTo}, {outputs[0]});
+  BinaryBroadcastCompute<xpu, mshadow_op::plus>(attrs, ctx,
+                                                {outputs[0], beta},
+                                                {kWriteTo}, {outputs[0]});
 }
 
 /*
@@ -222,17 +222,17 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
   const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2,
                               mean.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id());
   // Compute normalized_data = (data - mean) / std
-  BinaryBroadcastCompute<xpu, op::minus>(attrs, ctx,
-                                         {data, mean},
-                                         {kWriteTo}, {normalized_data});
-  BinaryBroadcastCompute<xpu, op::div>(attrs, ctx,
-                                       {normalized_data, std},
-                                       {kWriteTo}, {normalized_data});
+  BinaryBroadcastCompute<xpu, mshadow_op::minus>(attrs, ctx,
+                                                 {data, mean},
+                                                 {kWriteTo}, {normalized_data});
+  BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
+                                               {normalized_data, std},
+                                               {kWriteTo}, {normalized_data});
   // Calculate grad_beta
   if (req[2] != kNullOp) {
     MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
-        broadcast::Reduce<red::sum, NDim, DType, op::identity>(
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
           s, outputs[2].reshape(red_exclude_dst_shape), req[2],
           workspace, ograd.reshape(red_exclude_src_shape));
       });
@@ -244,7 +244,7 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
   if (req[1] != kNullOp) {
     MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
-        broadcast::Reduce<red::sum, NDim, DType, op::identity>(
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
           s, outputs[1].reshape(red_exclude_dst_shape), req[1],
           workspace, ograd_mult.reshape(red_exclude_src_shape));
       });
@@ -263,7 +263,7 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
                                                  {kWriteTo}, {ograd_mult});
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
-        broadcast::Reduce<red::sum, NDim, DType, op::identity>(
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
           s, red_out.reshape(red_dst_shape), kWriteTo,
           workspace, ograd_mult.reshape(red_src_shape));
       });
@@ -277,16 +277,16 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
                                                  {kWriteTo}, {ograd_mult});
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
-        broadcast::Reduce<red::sum, NDim, DType, op::identity>(
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, true>(
           s, red_out.reshape(red_dst_shape), kWriteTo,
           workspace, ograd_mult.reshape(red_src_shape));
       });
       Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
       red_out_tensor /= scalar<DType>(- channel_size);
     });
-    BinaryBroadcastCompute<xpu, op::mul>(attrs, ctx,
-                                         {normalized_data, red_out},
-                                         {kAddTo}, {outputs[0]});
+    BinaryBroadcastCompute<xpu, mshadow_op::mul>(attrs, ctx,
+                                                 {normalized_data, red_out},
+                                                 {kAddTo}, {outputs[0]});
   }
 }
 

From 449021073622210d93aebe6f4a813d0a33d52c73 Mon Sep 17 00:00:00 2001
From: sxjscience
Date: Fri, 3 May 2019 15:11:35 +0800
Subject: [PATCH 2/2] fix layernorm bug

---
 src/operator/nn/layer_norm-inl.h | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h
index 5a5d4027591e..ccc2a163a7aa 100644
--- a/src/operator/nn/layer_norm-inl.h
+++ b/src/operator/nn/layer_norm-inl.h
@@ -203,14 +203,14 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
       reduce_workspace_size =
           std::max(reduce_workspace_size,
-                   broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_src_shape,
-                                                               kAddTo, red_dst_shape));
+                   broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_dst_shape,
+                                                               kAddTo, red_src_shape));
     });
     BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
       reduce_workspace_size =
           std::max(reduce_workspace_size,
-                   broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_exclude_src_shape, kAddTo,
-                                                               red_exclude_dst_shape));
+                   broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_exclude_dst_shape, kAddTo,
+                                                               red_exclude_src_shape));
     });
   });
   workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(
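
Note (commentary, not part of the patch series): PATCH 1/2 moves the mean and
variance reductions onto broadcast::Reduce's safe-accumulation path, per the
commit title "try to use safe_acc", so the running sum is kept in a wider
accumulator type than the element type. PATCH 2/2 swaps the shape arguments of
broadcast::ReduceWorkspaceSize so that the reduced (destination) shape is
passed before the source shape, matching the parameter order the function
expects; with the arguments reversed the workspace query was computed against
the wrong shapes. The sketch below illustrates why accumulating in a wider
type matters. It is plain standalone C++ written for this note, not MXNet
code, and the stalling threshold in the comments is specific to these values.

// safe_acc_demo.cc -- illustrative sketch only, not MXNet code.
// The idea behind safe accumulation: keep the running sum in a wider
// accumulator type than the element type, and cast back once at the end.
#include <cstdio>

int main() {
  const long n = 100000000L;  // 1e8 additions of 0.1; the exact sum is 1e7

  // Narrow accumulation (accumulator has the element type): once the
  // running float sum reaches about 2^21 = 2097152, the increment 0.1f is
  // smaller than half a ulp, each addition rounds to nothing, and the sum
  // stalls far below the true value.
  float narrow = 0.0f;
  for (long i = 0; i < n; ++i) narrow += 0.1f;

  // Wide accumulation (the safe-accumulation idea): each addend is still
  // the rounded float 0.1f, but the double accumulator keeps absorbing it.
  double wide = 0.0;
  for (long i = 0; i < n; ++i) wide += 0.1f;

  std::printf("narrow accumulator: %.1f\n", narrow);  // ~2097152.0
  std::printf("wide accumulator:   %.1f\n", wide);    // ~1e7, up to the
  return 0;                                           // rounding of 0.1f
}

The same effect is much more pronounced for float16 inputs, whose sums
overflow and lose precision far earlier, which is why reductions over large
normalization axes benefit from a wider accumulator.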