From b3c91bfbbf20893bd6afe99e7995d53da3a69094 Mon Sep 17 00:00:00 2001 From: Xingjian Shi Date: Thu, 23 May 2019 02:56:23 +0800 Subject: [PATCH] Safe LayerNorm (#15002) * use float32 to store the reduction result of float16 enable safe accumulation fix bug fix * update test for safe_accumulate * fix --- src/operator/nn/layer_norm-inl.h | 102 ++++++---- src/operator/nn/layer_norm.cu | 253 +++++++++++++++---------- tests/python/unittest/test_operator.py | 40 ++-- 3 files changed, 245 insertions(+), 150 deletions(-) diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h index 29224243dc40..456a5cb805ec 100644 --- a/src/operator/nn/layer_norm-inl.h +++ b/src/operator/nn/layer_norm-inl.h @@ -116,8 +116,13 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, // Calculate mean MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { - broadcast::Reduce( - s, mean_data, req[0], workspace, in_data); + if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) { + broadcast::Reduce( + s, mean_data, req[0], workspace, in_data); + } else { + broadcast::Reduce( + s, mean_data, req[0], workspace, in_data); + } Tensor mean_data_tensor = mean_data.FlatTo1D(s); mean_data_tensor /= scalar(channel_size); }); @@ -130,25 +135,30 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs, const TBlob centered_out = outputs[0].reshape(red_src_shape); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { - broadcast::Reduce( - s, std_data, req[0], workspace, centered_out); + if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) { + broadcast::Reduce( + s, std_data, req[0], workspace, centered_out); + } else { + broadcast::Reduce( + s, std_data, req[0], workspace, centered_out); + } Tensor std_data_tensor = std_data.FlatTo1D(s); std_data_tensor = F(std_data_tensor / scalar(channel_size) + scalar(param.eps)); }); }); // Calculate data = data / std - BinaryBroadcastCompute(attrs, ctx, - {outputs[0], outputs[layernorm::kStd]}, - {kWriteTo}, {outputs[0]}); + BinaryBroadcastCompute(attrs, ctx, + {outputs[0], outputs[layernorm::kStd]}, + {kWriteTo}, {outputs[0]}); // Calculate data = data * gamma - BinaryBroadcastCompute(attrs, ctx, - {outputs[0], gamma}, - {kWriteTo}, {outputs[0]}); + BinaryBroadcastCompute(attrs, ctx, + {outputs[0], gamma}, + {kWriteTo}, {outputs[0]}); // Calculate data = data + beta - BinaryBroadcastCompute(attrs, ctx, - {outputs[0], beta}, - {kWriteTo}, {outputs[0]}); + BinaryBroadcastCompute(attrs, ctx, + {outputs[0], beta}, + {kWriteTo}, {outputs[0]}); } template @@ -233,19 +243,25 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2, mean.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id()); // Compute normalized_data = (data - mean) / std - BinaryBroadcastCompute(attrs, ctx, - {data, mean}, - {kWriteTo}, {normalized_data}); - BinaryBroadcastCompute(attrs, ctx, - {normalized_data, std}, - {kWriteTo}, {normalized_data}); + BinaryBroadcastCompute(attrs, ctx, + {data, mean}, + {kWriteTo}, {normalized_data}); + BinaryBroadcastCompute(attrs, ctx, + {normalized_data, std}, + {kWriteTo}, {normalized_data}); // Calculate grad_beta if (req[2] != kNullOp) { MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { - broadcast::Reduce( - s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, - ograd.reshape(red_exclude_src_shape)); + if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) { + broadcast::Reduce( + s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, + ograd.reshape(red_exclude_src_shape)); + } else { + broadcast::Reduce( + s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, + ograd.reshape(red_exclude_src_shape)); + } }); }); } @@ -255,9 +271,15 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, if (req[1] != kNullOp) { MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { - broadcast::Reduce( - s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, - ograd_mult.reshape(red_exclude_src_shape)); + if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) { + broadcast::Reduce( + s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, + ograd_mult.reshape(red_exclude_src_shape)); + } else { + broadcast::Reduce( + s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, + ograd_mult.reshape(red_exclude_src_shape)); + } }); }); } @@ -274,9 +296,15 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, {kWriteTo}, {ograd_mult}); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { - broadcast::Reduce( - s, red_out.reshape(red_dst_shape), kWriteTo, workspace, - ograd_mult.reshape(red_src_shape)); + if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape)); + } else { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape)); + } }); Tensor red_out_tensor = red_out.FlatTo1D(s); red_out_tensor /= scalar(channel_size); @@ -288,16 +316,22 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs, {kWriteTo}, {ograd_mult}); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { - broadcast::Reduce( - s, red_out.reshape(red_dst_shape), kWriteTo, workspace, - ograd_mult.reshape(red_src_shape)); + if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape)); + } else { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape)); + } }); Tensor red_out_tensor = red_out.FlatTo1D(s); red_out_tensor /= scalar(- channel_size); }); - BinaryBroadcastCompute(attrs, ctx, - {normalized_data, red_out}, - {kAddTo}, {outputs[0]}); + BinaryBroadcastCompute(attrs, ctx, + {normalized_data, red_out}, + {kAddTo}, {outputs[0]}); } } diff --git a/src/operator/nn/layer_norm.cu b/src/operator/nn/layer_norm.cu index b63046fc0026..db09969d6fcb 100644 --- a/src/operator/nn/layer_norm.cu +++ b/src/operator/nn/layer_norm.cu @@ -30,7 +30,7 @@ namespace mxnet { namespace op { template -__device__ __forceinline__ DType WARP_SHFL(DType value, int src_lane, +__device__ __forceinline__ DType warp_shfl(DType value, int src_lane, int width = 32, unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 return __shfl_sync(mask, value, src_lane, width); @@ -40,7 +40,7 @@ __device__ __forceinline__ DType WARP_SHFL(DType value, int src_lane, } template -__device__ __forceinline__ DType WARP_SHFL_XOR(DType value, int laneMask, +__device__ __forceinline__ DType warp_shfl_xor(DType value, int laneMask, int width = 32, unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 return __shfl_xor_sync(mask, value, laneMask, width); @@ -54,12 +54,12 @@ __device__ __forceinline__ DType WARP_SHFL_XOR(DType value, int laneMask, * The value 'curr' will be accumulated to the (mean, sigma2, count) triplet. * */ -template +template __device__ __forceinline__ void StepWelfordOnlineSum(const DType curr, DType& mean, //NOLINT DType& sigma2, //NOLINT - DType& count) { //NOLINT - count += DType(1); + IType& count) { //NOLINT + count += IType(1); DType delta = curr - mean; mean += delta / count; sigma2 += delta * (curr - mean); @@ -72,16 +72,16 @@ __device__ __forceinline__ void StepWelfordOnlineSum(const DType curr, * * TODO(sxjscience) Explore the possibility of int lhs_count and rhs_count */ -template +template __device__ __inline__ void ChanMergePartition(const DType lhs_mean, const DType lhs_sigma2, - const DType lhs_count, + const IType lhs_count, DType& rhs_mean, //NOLINT DType& rhs_sigma2, //NOLINT - DType& rhs_count) { //NOLINT + IType& rhs_count) { //NOLINT DType delta = rhs_mean - lhs_mean; - DType nA = lhs_count; - DType nB = rhs_count; + DType nA = static_cast(lhs_count); + DType nB = static_cast(rhs_count); rhs_count = nA + nB; if (rhs_count > DType(0)) { nA = nA / rhs_count; @@ -94,7 +94,10 @@ __device__ __inline__ void ChanMergePartition(const DType lhs_mean, } } - +/* Split the input column into multiple partitions and compute the mean/sigma of each partition. + * Each thread will keep a mean/sigma2. The mean/sigma2 can be further merged to get the mean and + * sigma2 of the column. + */ template __device__ __forceinline__ void BlockWelfordOnlineSum(const DType* __restrict__ col_vals, const int nchannel, @@ -110,11 +113,45 @@ __device__ __forceinline__ void BlockWelfordOnlineSum(const DType* __restrict__ for (; l + 3 < nchannel; l += 4 * nthread) { #pragma unroll for (int i = 0; i < 4; ++i) { - StepWelfordOnlineSum(col_vals[l + i], mean, sigma2, count); + StepWelfordOnlineSum(static_cast(col_vals[l + i]), mean, sigma2, count); } } for (; l < nchannel; ++l) { - StepWelfordOnlineSum(col_vals[l], mean, sigma2, count); + StepWelfordOnlineSum(static_cast(col_vals[l]), mean, sigma2, count); + } +} + +template<> +__device__ __forceinline__ +void BlockWelfordOnlineSum + (const mshadow::half::half_t* __restrict__ col_vals, + const int nchannel, + float& mean, //NOLINT + float& sigma2, //NOLINT + int& count) { //NOLINT + int tid = threadIdx.x + threadIdx.y * blockDim.x; + const int nthread = blockDim.x * blockDim.y; + // We cast the input half pointer to half2 to optimize the loading speed. + // Here, we need to notice that CUDA forces memory alignment, i.e., + // ASSERT static_cast(ptr) % sizeof(dtype) == 0. + // Thus, we need to shift the address of the half pointer to be aligned by half2. + int align_shift = (reinterpret_cast(col_vals) % 4) != 0; + int padding = (nchannel - align_shift) % 2; + int half2_size = (nchannel - align_shift) / 2; + const __half2* half2_col_vals = reinterpret_cast(col_vals + align_shift); + if (threadIdx.x == 0 && threadIdx.y == 0) { + if (align_shift) { + StepWelfordOnlineSum(__half2float(col_vals[0].cuhalf_), mean, sigma2, count); + } + if (padding) { + StepWelfordOnlineSum(__half2float(col_vals[nchannel - 1].cuhalf_), mean, sigma2, count); + } + } + + for (int l = tid; l < half2_size; l += nthread) { + float2 ele_val = __half22float2(half2_col_vals[l]); + StepWelfordOnlineSum(ele_val.x, mean, sigma2, count); + StepWelfordOnlineSum(ele_val.y, mean, sigma2, count); } } @@ -129,12 +166,12 @@ __device__ __forceinline__ void BlockWelfordOnlineSum(const DType* __restrict__ * var_data = (nbatch,) * It's always launched with (blockDim.x, blockDim.y) = (WARP_SIZE, blockDim.y) * Also, when blockDim.y > 1, it requires shared memory that has size: - * sizeof(DType) * blockDim.y + sizeof(DType) * blockDim.y / 2 + * sizeof(AType) * blockDim.y + sizeof(int) * blockDim.y / 2 */ -template +template __global__ void LayerNormFusedForwardKernelContig(const int nbatch, const int nchannel, - const DType eps, + const AType eps, const DType* __restrict__ in_data, const DType* __restrict__ gamma, const DType* __restrict__ beta, @@ -144,9 +181,9 @@ __global__ void LayerNormFusedForwardKernelContig(const int nbatch, int bid = blockIdx.x + blockIdx.y * gridDim.x; const int tid = threadIdx.y * blockDim.x + threadIdx.x; const int nthread = blockDim.x * blockDim.y; - DType count = 0; - DType mean = 0; - DType sigma2 = 0; + IType count = 0; + AType mean = 0; + AType sigma2 = 0; if (bid < nbatch) { extern __shared__ char buf[]; // Shared memory @@ -158,18 +195,19 @@ __global__ void LayerNormFusedForwardKernelContig(const int nbatch, // within a warp of threads. // After calling the function, threadIdx.x == 0 will store the result of // the aggregated (mean, sigma2, counts). - for (int mask = 16; mask > 0; mask >>= 1) { - DType meanB = WARP_SHFL_XOR(mean, mask); - DType sigma2B = WARP_SHFL_XOR(sigma2, mask); - DType countB = WARP_SHFL_XOR(count, mask); + for (int mask = blockDim.x / 2; mask > 0; mask >>= 1) { + AType meanB = warp_shfl_xor(mean, mask); + AType sigma2B = warp_shfl_xor(sigma2, mask); + IType countB = warp_shfl_xor(count, mask); ChanMergePartition(meanB, sigma2B, countB, mean, sigma2, count); } if (blockDim.y > 1) { // Inter-warp reduction. Copy the upper-half of the warps to shared memory // and merge with the lower-half warp - DType* mean_buf = reinterpret_cast(buf); - DType* sigma2_buf = reinterpret_cast(buf + sizeof(DType) * blockDim.y / 2 * 32); - DType* count_buf = reinterpret_cast(buf + sizeof(DType) * blockDim.y * 32); + AType* mean_buf = reinterpret_cast(buf); + AType* sigma2_buf = + reinterpret_cast(buf + sizeof(AType) * blockDim.y / 2 * blockDim.x); + IType* count_buf = reinterpret_cast(buf + sizeof(AType) * blockDim.y * blockDim.x); for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { const int idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; @@ -196,35 +234,40 @@ __global__ void LayerNormFusedForwardKernelContig(const int nbatch, sigma2 /= nchannel; } // Calculate the out_data: gamma * (x - mean) / sqrt(var + eps) + beta - DType std_eps = sqrt(sigma2 + eps); - DType invstd_eps = DType(1.0) / std_eps; + AType std_eps = sqrt(sigma2 + eps); + AType invstd_eps = DType(1.0) / std_eps; DType* out_col_val = out_data + bid * nchannel; if (gamma != NULL && beta != NULL) { for (int i = tid; i < nchannel; i += nthread) { - out_col_val[i] = gamma[i] * invstd_eps * (col_vals[i] - mean) + beta[i]; + out_col_val[i] = gamma[i] * static_cast(invstd_eps * + (static_cast(col_vals[i]) - mean)) + + beta[i]; } } else if (gamma == NULL && beta != NULL) { for (int i = tid; i < nchannel; i += nthread) { - out_col_val[i] = invstd_eps * (col_vals[i] - mean) + beta[i]; + out_col_val[i] = static_cast(invstd_eps * (static_cast(col_vals[i]) - mean)) + + beta[i]; } } else if (gamma != NULL && beta == NULL) { for (int i = tid; i < nchannel; i += nthread) { - out_col_val[i] = gamma[i] * invstd_eps * (col_vals[i] - mean); + out_col_val[i] = gamma[i] * static_cast(invstd_eps * + (static_cast(col_vals[i]) - mean)); } } else { for (int i = tid; i < nchannel; i += nthread) { - out_col_val[i] = invstd_eps * (col_vals[i] - mean); + out_col_val[i] = static_cast(invstd_eps * (static_cast(col_vals[i]) - mean)); } } // Write the out_data and var_data if (threadIdx.x == 0 && threadIdx.y == 0) { - mean_data[bid] = mean; - std_data[bid] = std_eps; + mean_data[bid] = static_cast(mean); + std_data[bid] = static_cast(std_eps); } } } +template void LayerNormGPUContig(const LayerNormParam param, const OpContext& ctx, const std::vector& inputs, const std::vector& req, @@ -268,12 +311,13 @@ void LayerNormGPUContig(const LayerNormParam param, } cudaStream_t stream = Stream::GetStream(ctx.get_stream()); const dim3 dimBlock(32, nthread_y); - MSHADOW_REAL_TYPE_SWITCH(in_data.type_flag_, DType, { - int nshared = nthread_y > 1 ? nthread_y * 32 * sizeof(DType) - + (nthread_y / 2) * 32 * sizeof(DType) : 0; + MXNET_REAL_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, { + typedef typename std::conditional::type AType; + int nshared = nthread_y > 1 ? nthread_y * 32 * sizeof(AType) + + (nthread_y / 2) * 32 * sizeof(int) : 0; CheckLaunchParam(dimGrid, dimBlock); - LayerNormFusedForwardKernelContig<<>> - (nbatch, nchannel, static_cast(eps), + LayerNormFusedForwardKernelContig <<>> + (nbatch, nchannel, static_cast(eps), in_data.dptr(), gamma.dptr(), beta.dptr(), out_data.dptr(), mean_data.dptr(), std_data.dptr()); }); @@ -295,7 +339,12 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs, CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis; if (axis == inputs[0].ndim() - 1) { // Try to use the accelerated CUDA kernels - return LayerNormGPUContig(param, ctx, inputs, req, outputs); + bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false); + if (safe_acc) { + return LayerNormGPUContig(param, ctx, inputs, req, outputs); + } else { + return LayerNormGPUContig(param, ctx, inputs, req, outputs); + } } return LayerNormComputeGeneral(attrs, ctx, inputs, req, outputs); } @@ -327,17 +376,17 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs, * This `LayerNormFusedBackwardKernel_PartGammaBeta` function implements the first step and * `LayerNormFusedBackwardKernel_GammaBeta` implements the second step. */ -template +template __global__ void LayerNormFusedBackwardKernel_PartGammaBeta(const int nbatch, const int nchannel, const DType* __restrict__ in_data, const DType* __restrict__ out_grad, const DType* __restrict__ mean_data, const DType* __restrict__ std_data, - DType* part_gamma_grad, - DType* part_beta_grad) { + AType* __restrict__ part_gamma_grad, + AType* __restrict__ part_beta_grad) { extern __shared__ char buf[]; - DType* d_buf = reinterpret_cast(buf); + AType* d_buf = reinterpret_cast(buf); const int npart = gridDim.y; const int block_row_num = (nbatch + npart - 1) / npart; // The rows are divided into `npart` parts. Each threadblock calculates the reduction result @@ -346,21 +395,22 @@ __global__ void LayerNormFusedBackwardKernel_PartGammaBeta(const int nbatch, const int c = blockIdx.x * blockDim.x + threadIdx.x; int r_begin = blockIdx.y * block_row_num; int r_end = min((blockIdx.y + 1) * block_row_num, nbatch); - DType* buf_gamma_grad = d_buf; - DType* buf_beta_grad = d_buf + blockDim.y * row_stride; - DType local_gamma_grad = 0; - DType local_beta_grad = 0; + AType* buf_gamma_grad = d_buf; + AType* buf_beta_grad = d_buf + blockDim.y * row_stride; + AType local_gamma_grad = 0; + AType local_beta_grad = 0; if (c < nchannel) { for (int r_b = r_begin; r_b < r_end; r_b += blockDim.y) { int r = r_b + threadIdx.y; if (r < r_end) { - DType local_mean = mean_data[r]; - DType local_std = std_data[r]; + AType local_mean = static_cast(mean_data[r]); + AType local_std = static_cast(std_data[r]); int read_idx = r * nchannel + c; - local_gamma_grad += (in_data[read_idx] - local_mean) / local_std - * out_grad[read_idx]; - local_beta_grad += out_grad[read_idx]; + AType local_in_data = static_cast(in_data[read_idx]); + AType local_out_grad = static_cast(out_grad[read_idx]); + local_gamma_grad += (local_in_data - local_mean) / local_std * local_out_grad; + local_beta_grad += local_out_grad; } } } @@ -384,20 +434,20 @@ __global__ void LayerNormFusedBackwardKernel_PartGammaBeta(const int nbatch, } } -template +template __global__ void LayerNormFusedBackwardKernel_GammaBeta(const int nbatch, const int nchannel, const int npart, - const DType* __restrict__ part_gamma_grad, - const DType* __restrict__ part_beta_grad, + const AType* __restrict__ part_gamma_grad, + const AType* __restrict__ part_beta_grad, DType* gamma_grad, DType* beta_grad) { const int c = blockIdx.x * blockDim.x + threadIdx.x; const int tid = threadIdx.y * blockDim.x + threadIdx.x; if (c < nchannel) { extern __shared__ char buf[]; - DType* buf_gamma_grad = reinterpret_cast(buf); - DType* buf_beta_grad = reinterpret_cast(buf) + blockDim.x * blockDim.y; + AType* buf_gamma_grad = reinterpret_cast(buf); + AType* buf_beta_grad = reinterpret_cast(buf) + blockDim.x * blockDim.y; buf_gamma_grad[tid] = 0; buf_beta_grad[tid] = 0; for (int r = threadIdx.y; r < npart; r += blockDim.y) { @@ -420,16 +470,16 @@ __global__ void LayerNormFusedBackwardKernel_GammaBeta(const int nbatch, if (threadIdx.y == 0) { if (gamma_grad) { if (gamma_addto) { - gamma_grad[c] += buf_gamma_grad[threadIdx.x]; + gamma_grad[c] += static_cast(buf_gamma_grad[threadIdx.x]); } else { - gamma_grad[c] = buf_gamma_grad[threadIdx.x]; + gamma_grad[c] = static_cast(buf_gamma_grad[threadIdx.x]); } } if (beta_grad) { if (beta_addto) { - beta_grad[c] += buf_beta_grad[threadIdx.x]; + beta_grad[c] += static_cast(buf_beta_grad[threadIdx.x]); } else { - beta_grad[c] = buf_beta_grad[threadIdx.x]; + beta_grad[c] = static_cast(buf_beta_grad[threadIdx.x]); } } } @@ -440,7 +490,7 @@ __global__ void LayerNormFusedBackwardKernel_GammaBeta(const int nbatch, * * */ -template +template __global__ void LayerNormFusedBackwardKernel_Data(const int nbatch, const int nchannel, const DType* __restrict__ in_data, @@ -457,38 +507,38 @@ __global__ void LayerNormFusedBackwardKernel_Data(const int nbatch, int tid = threadIdx.x + threadIdx.y * blockDim.x; // 1. Calculate: mean(out_grad * gamma / std, axis=-1) // mean(out_grad * gamma / std * (x - mean) / std, axis=-1) - DType sum_val0 = 0; // Stores mean(out_grad * gamma / std, axis=-1) - DType sum_val1 = 0; // Stores mean(out_grad * gamma / std * (x - mean) / std, axis=-1) - DType mean = mean_data[bid]; - DType invstd_eps = DType(1) / std_data[bid]; + AType sum_val0 = 0; // Stores mean(out_grad * gamma / std, axis=-1) + AType sum_val1 = 0; // Stores mean(out_grad * gamma / std * (x - mean) / std, axis=-1) + AType mean = static_cast(mean_data[bid]); + AType invstd_eps = AType(1) / static_cast(std_data[bid]); int l = LOAD_UNROLL * tid; for (; l + LOAD_UNROLL - 1 < nchannel; l += nthread * LOAD_UNROLL) { #pragma unroll for (int i = 0; i < LOAD_UNROLL; ++i) { - DType ele_og = out_grad[bid * nchannel + l + i]; - DType ele_x = in_data[bid * nchannel + l + i]; - DType ele_gamma = gamma[l + i]; + AType ele_og = static_cast(out_grad[bid * nchannel + l + i]); + AType ele_x = static_cast(in_data[bid * nchannel + l + i]); + AType ele_gamma = static_cast(gamma[l + i]); sum_val0 += ele_og * ele_gamma * invstd_eps; sum_val1 += ele_og * ele_gamma * (ele_x - mean) * invstd_eps * invstd_eps; } } for (; l < nchannel; ++l) { - DType ele_og = out_grad[bid * nchannel + l]; - DType ele_x = in_data[bid * nchannel + l]; - DType ele_gamma = gamma[l]; + AType ele_og = static_cast(out_grad[bid * nchannel + l]); + AType ele_x = static_cast(in_data[bid * nchannel + l]); + AType ele_gamma = static_cast(gamma[l]); sum_val0 += ele_og * ele_gamma * invstd_eps; sum_val1 += ele_og * ele_gamma * (ele_x - mean) * invstd_eps * invstd_eps; } // Intra-warp reduction (all-reduce) for (int mask = blockDim.x / 2; mask > 0; mask >>= 1) { - sum_val0 += WARP_SHFL_XOR(sum_val0, mask); - sum_val1 += WARP_SHFL_XOR(sum_val1, mask); + sum_val0 += warp_shfl_xor(sum_val0, mask); + sum_val1 += warp_shfl_xor(sum_val1, mask); } // Inter-warp reduction (all-reduce) if (blockDim.y > 1) { - DType* sum_val0_buf = reinterpret_cast(buf); - DType* sum_val1_buf = - reinterpret_cast(buf + blockDim.y / 2 * blockDim.x * sizeof(DType)); + AType* sum_val0_buf = reinterpret_cast(buf); + AType* sum_val1_buf = + reinterpret_cast(buf + blockDim.y / 2 * blockDim.x * sizeof(AType)); for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { const int idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; @@ -516,16 +566,17 @@ __global__ void LayerNormFusedBackwardKernel_Data(const int nbatch, // 2. Calculate the gradient as // out_grad * gamma / std - sum_val0 - (x - mean) / std * sum_val1 for (int l = tid; l < nchannel; l += nthread) { - DType ele_out_grad = out_grad[bid * nchannel + l]; - DType ele_x = in_data[bid * nchannel + l]; + AType ele_out_grad = static_cast(out_grad[bid * nchannel + l]); + AType ele_x = static_cast(in_data[bid * nchannel + l]); + AType ele_gamma = static_cast(gamma[l]); if (data_addto) { data_grad[bid * nchannel + l] += - ele_out_grad * gamma[l] * invstd_eps - sum_val0 - - (ele_x - mean) * invstd_eps * sum_val1; + static_cast(ele_out_grad * ele_gamma * invstd_eps + - sum_val0 - (ele_x - mean) * invstd_eps * sum_val1); } else { data_grad[bid * nchannel + l] = - ele_out_grad * gamma[l] * invstd_eps - sum_val0 - - (ele_x - mean) * invstd_eps * sum_val1; + static_cast(ele_out_grad * ele_gamma * invstd_eps - sum_val0 + - (ele_x - mean) * invstd_eps * sum_val1); } } } @@ -544,6 +595,7 @@ void GetGammaBetaGradKernelParams(const int nbatch, const int nchannel, CheckLaunchParam(*gb_grid_dim, *gb_block_dim); } +template void LayerNormGradGPUContig(const LayerNormParam param, const OpContext& ctx, const std::vector& inputs, const std::vector& req, @@ -584,14 +636,15 @@ void LayerNormGradGPUContig(const LayerNormParam param, GetGammaBetaGradKernelParams(nbatch, nchannel, &part_grad_block_dim, &part_grad_grid_dim, &gb_block_dim, &gb_grid_dim, &npart); if (gamma_grad_req != kNullOp || beta_grad_req != kNullOp) { - MSHADOW_REAL_TYPE_SWITCH(in_data.type_flag_, DType, { - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(2 * npart * nchannel), s); - DType* part_gamma_grad_ptr = workspace.dptr_; - DType* part_beta_grad_ptr = workspace.dptr_ + npart * nchannel; + MXNET_REAL_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, { + typedef typename std::conditional::type AType; + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(2 * npart * nchannel), s); + AType* part_gamma_grad_ptr = workspace.dptr_; + AType* part_beta_grad_ptr = workspace.dptr_ + npart * nchannel; const int nshared_K1 = 2 * (part_grad_block_dim.x + 1) - * part_grad_block_dim.y * sizeof(DType); - const int nshared_K2 = 2 * gb_block_dim.x * gb_block_dim.y * sizeof(DType); + * part_grad_block_dim.y * sizeof(AType); + const int nshared_K2 = 2 * gb_block_dim.x * gb_block_dim.y * sizeof(AType); DType* gamma_grad_ptr = (gamma_grad_req != kNullOp) ? gamma_grad.dptr() : nullptr; DType* beta_grad_ptr = (beta_grad_req != kNullOp) ? beta_grad.dptr() : nullptr; LayerNormFusedBackwardKernel_PartGammaBeta @@ -642,16 +695,17 @@ void LayerNormGradGPUContig(const LayerNormParam param, const dim3 data_block_dim(32, nthread_y); const int LOAD_UNROLL = 4; if (data_grad_req != kNullOp) { - MSHADOW_REAL_TYPE_SWITCH(in_data.type_flag_, DType, { - int nshared = data_block_dim.y > 1 ? data_block_dim.y * data_block_dim.x * sizeof(DType) : 0; + MXNET_REAL_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, { + typedef typename std::conditional::type AType; + int nshared = data_block_dim.y > 1 ? data_block_dim.y * data_block_dim.x * sizeof(AType) : 0; CheckLaunchParam(data_grid_dim, data_block_dim); if (data_grad_req == kAddTo) { - LayerNormFusedBackwardKernel_Data + LayerNormFusedBackwardKernel_Data <<>> (nbatch, nchannel, in_data.dptr(), out_grad.dptr(), mean_data.dptr(), std_data.dptr(), gamma.dptr(), data_grad.dptr()); } else { - LayerNormFusedBackwardKernel_Data + LayerNormFusedBackwardKernel_Data <<>> (nbatch, nchannel, in_data.dptr(), out_grad.dptr(), mean_data.dptr(), std_data.dptr(), gamma.dptr(), data_grad.dptr()); @@ -673,8 +727,13 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs, } CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis; if (axis == inputs[0].ndim() - 1) { - // Try to use the accelerated CUDA kernels - return LayerNormGradGPUContig(param, ctx, inputs, req, outputs); + // Use the accelerated CUDA kernels + bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false); + if (safe_acc) { + return LayerNormGradGPUContig(param, ctx, inputs, req, outputs); + } else { + return LayerNormGradGPUContig(param, ctx, inputs, req, outputs); + } } return LayerNormGradComputeGeneral(attrs, ctx, inputs, req, outputs); } diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 324e8a3d7ed9..cb9b2f9fab41 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3541,25 +3541,27 @@ def l2norm(input_data, axis=0, keepdims=True): def test_layer_norm(): - for dtype, forward_check_eps, backward_check_eps in zip([np.float16, np.float32, np.float64], - [1E-2, 1E-3, 1E-4], - [1E-2, 1E-3, 1E-4]): - if dtype != np.float16: - in_shape_l, finite_grad_check_l = [(10, 6, 5), (10, 10), (128 * 32, 512)], [True, True, False] - else: - in_shape_l, finite_grad_check_l = [(10, 6, 5), (10, 10)], [True, True] # large input + fp16 does not pass the forward check - for in_shape, finite_grad_check in zip(in_shape_l, finite_grad_check_l): - for axis in range(-len(in_shape), len(in_shape)): - for eps in [1E-2, 1E-3]: - if dtype == np.float16: - npy_grad_check = False - else: - npy_grad_check = True - check_layer_normalization(in_shape, axis, eps, dtype=dtype, - forward_check_eps=forward_check_eps, - backward_check_eps=backward_check_eps, - npy_grad_check=npy_grad_check, - finite_grad_check=finite_grad_check) + for enforce_safe_acc in ["1", "0"]: + os.environ["MXNET_SAFE_ACCUMULATION"] = enforce_safe_acc + for dtype, forward_check_eps, backward_check_eps in zip([np.float16, np.float32, np.float64], + [1E-2, 1E-3, 1E-4], + [1E-2, 1E-3, 1E-4]): + if dtype != np.float16: + in_shape_l, finite_grad_check_l = [(10, 6, 5), (10, 10), (128 * 32, 512)], [True, True, False] + else: + in_shape_l, finite_grad_check_l = [(10, 6, 5), (10, 10)], [True, True] # large input + fp16 does not pass the forward check + for in_shape, finite_grad_check in zip(in_shape_l, finite_grad_check_l): + for axis in range(-len(in_shape), len(in_shape)): + for eps in [1E-2, 1E-3]: + if dtype == np.float16: + npy_grad_check = False + else: + npy_grad_check = True + check_layer_normalization(in_shape, axis, eps, dtype=dtype, + forward_check_eps=forward_check_eps, + backward_check_eps=backward_check_eps, + npy_grad_check=npy_grad_check, + finite_grad_check=finite_grad_check) # Numpy Implementation of Sequence Ops