From e06ee4e1a725aed62a66b525036676500010e7f0 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 17 Aug 2020 23:07:37 -0700 Subject: [PATCH] Faster GPU frozen BatchNorm (#17368) * Better frozen batchnorm * Continue FreezeBN * Optimizations * Reduce number of mod operations * Cleaning * Fixing frozen bn with fix_gamma=False * Fix lint in BN * Backward frozen batchnorm * More work on backward of Frozen BN * Let it compile * NCHW Frozen BN backward * Frozen BN backward NHWC * Cleaning * Remove the change to Makefile * Fix from rebase * Temp space for BN backward * Fix from review * Fix lint * Changes from review --- src/common/cuda_utils.h | 83 ++++- src/operator/nn/batch_norm.cc | 2 - src/operator/nn/batch_norm.cu | 563 +++++++++++++++++++++++++++++----- src/operator/nn/softmax-inl.h | 6 +- 4 files changed, 561 insertions(+), 93 deletions(-) diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h index 0971cfd22361..22ac42c6c67b 100644 --- a/src/common/cuda_utils.h +++ b/src/common/cuda_utils.h @@ -783,27 +783,86 @@ __device__ inline DType ldg(const DType* address) { #endif } -template +namespace mxnet { +namespace common { +/*! \brief common utils for cuda */ +namespace cuda { + +static constexpr const int warp_size = 32; + +/*! \brief Reduction inside a warp. + * Template parameters: + * NVALUES - number of values to reduce (defaults to warp_size). + * \param value - values to be reduced. + * \param redfun - function used to perform reduction. + */ +template __device__ inline T warp_reduce(T value, OP redfun) { - value = redfun(value, __shfl_down_sync(0xffffffff, value, 16)); - value = redfun(value, __shfl_down_sync(0xffffffff, value, 8)); - value = redfun(value, __shfl_down_sync(0xffffffff, value, 4)); - value = redfun(value, __shfl_down_sync(0xffffffff, value, 2)); - value = redfun(value, __shfl_down_sync(0xffffffff, value, 1)); +#pragma unroll + for (int i = warp_size / 2; i >= 1; i /= 2) { + if (NVALUES > i) value = redfun(value, __shfl_down_sync(0xffffffff, value, i)); + } return value; } -template +template __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) { float v = static_cast(value); - v = redfun(v, __shfl_down_sync(0xffffffff, v, 16)); - v = redfun(v, __shfl_down_sync(0xffffffff, v, 8)); - v = redfun(v, __shfl_down_sync(0xffffffff, v, 4)); - v = redfun(v, __shfl_down_sync(0xffffffff, v, 2)); - v = redfun(v, __shfl_down_sync(0xffffffff, v, 1)); +#pragma unroll + for (int i = warp_size / 2; i >= 1; i /= 2) { + if (NValues > i) v = redfun(v, __shfl_down_sync(0xffffffff, v, i)); + } return mshadow::half::half_t(v); } +/*! \brief Reduction inside a block, requires all threads in a block to participate. + * It uses a 2 step approach: + * - all warps in a block perform intermediate reduction + * - first warp reduces the intermediate results. + * Template parameters: + * NTHREADS - number of threads in a block. + * all_reduce - whether all threads need the result of the reduction. If set to + * true, then all threads return with the same value. If set to + * false, then only thread 0 has the valid result. Defaults to true. 
+ * \param value - value from each thread to be reduced + * \param redfun - function used to perform reduction + */ +template +__device__ inline T reduce(const T& value, OP redfun) { + static_assert(NTHREADS <= warp_size * warp_size, + "Number of threads too large for reduction"); + __shared__ T scratch[NTHREADS / warp_size]; + const int thread_idx_in_warp = threadIdx.x % warp_size; + const int warp_id = threadIdx.x / warp_size; + const T my_val = warp_reduce(value, redfun); + if (thread_idx_in_warp == 0) { + scratch[warp_id] = my_val; + } + __syncthreads(); + T ret = 0; + if (warp_id == 0) { + const T prev_val = threadIdx.x < (NTHREADS / warp_size) ? scratch[threadIdx.x] : 0; + const T my_val = warp_reduce(prev_val, redfun); + if (all_reduce) { + scratch[threadIdx.x] = my_val; + } else { + ret = my_val; + } + } + // Necessary to synchronize in order to use this function again + // as the shared memory scratch space is reused between calls + __syncthreads(); + if (all_reduce) { + ret = scratch[0]; + __syncthreads(); + } + return ret; +} + +} // namespace cuda +} // namespace common +} // namespace mxnet + #endif // __CUDACC__ #endif // MXNET_COMMON_CUDA_UTILS_H_ diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 2fdd31ec3cab..2a91a3706794 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -662,11 +662,9 @@ NNVM_REGISTER_OP(_backward_BatchNorm) }) .set_attr("TIsBackward", true) .set_attr("FInferStorageType", BatchNormStorageType) -#if MXNET_USE_MKLDNN == 1 .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) -#endif .set_attr_parser(ParamParser) #if MXNET_USE_MKLDNN == 1 .set_attr("TIsMKLDNN", true) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index c7e991f98d18..72e4a76a26d4 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -27,6 +27,8 @@ #include #include #include "batch_norm-inl.h" +#include "../../common/cuda_utils.h" + #define WRITE_DATA_FLAG 1 #define WRITE_GAMMA_FLAG 2 @@ -47,9 +49,30 @@ using namespace mxnet; +namespace { + /*! 
\brief inverse standard deviation <-> variance */ -#define VARIANCE_TO_INVSTD(__var$, __eps$) (1.0/sqrt((__var$) + DType(__eps$))) -#define INVSTD_TO_VARIANCE(__invstd$, __eps$) ((1.0 / ((__invstd$) * (__invstd$))) - (__eps$)) +template +MSHADOW_XINLINE AccReal variance_to_invstd(DType var, AccReal eps) { + return rsqrtf(static_cast(var) + eps); +} + +template <> +MSHADOW_XINLINE double variance_to_invstd(double var, double eps) { + return rsqrt(var + eps); +} + +template +MSHADOW_XINLINE AccReal invstd_to_variance(AccReal invstd, AccReal eps) { + return static_cast(1.0) / (invstd * invstd) - eps; +} + +template <> +MSHADOW_XINLINE double invstd_to_variance(double invstd, double eps) { + return 1.0 / (invstd * invstd) - eps; +} + +} // namespace namespace mxnet { namespace op { @@ -206,41 +229,90 @@ static __device__ T reduce(Op op, DeviceTensor tensor, int plane) { return shared[0]; } -template +namespace { + constexpr int inference_forward_threads = 512; + constexpr int shmem_elements = 1536; +} // namespace + +template +__launch_bounds__(inference_forward_threads) __global__ void BatchNormalizationUpdateOutputInferenceKernel( - DeviceTensor input, - DeviceTensor output, - DeviceTensor1 runningMean, - DeviceTensor1 runningVar, - DeviceTensor1 saveMean, - DeviceTensor1 saveInvStd, - DeviceTensor1 weight, - DeviceTensor1 bias, - const DType epsilon, + const DType* input, + DType* output, + const index_t size, + const index_t outer_size, + const index_t num_channels, + const index_t inner_size, + const AType* runningMean, + const AType* runningVar, + AType* saveMean, + AType* saveInvStd, + AType* weight, + AType* bias, + const AType epsilon, const uint32_t flags) { - int plane = blockIdx.x; - - AccReal invstd = VARIANCE_TO_INVSTD(runningVar[plane], epsilon); - AccReal mean = ScalarConvert::to(runningMean[plane]); - AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0) - ? ScalarConvert::to(weight[plane]) - : ScalarConvert::to(1); - AccReal beta = bias.numElements() > 0 ? ScalarConvert::to(bias[plane]) - : ScalarConvert::to(0); - if (threadIdx.x == 0) { - saveMean[plane] = runningMean[plane]; - saveInvStd[plane] = VARIANCE_TO_INVSTD(runningVar[plane], epsilon); - if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0 - && weight.numElements() > 0) { - weight[plane] = AccReal(1); + constexpr int nvec = sizeof(LType) / sizeof(DType); + __shared__ AType saved_invstd[shmem_elements]; + __shared__ AType saved_mean[shmem_elements]; + __shared__ AType saved_weight[shmem_elements]; + __shared__ AType saved_bias[shmem_elements]; + union vectorized_loader { + LType aligned; + DType separate[nvec]; // NOLINT(*) + + __device__ inline vectorized_loader() {} + __device__ inline ~vectorized_loader() {} + } scratch; + + if (small_num_channels) { + for (int i = threadIdx.x; i < num_channels; i += blockDim.x) { + saved_invstd[i] = variance_to_invstd(runningVar[i], epsilon); + saved_mean[i] = runningMean[i]; + saved_weight[i] = (weight != nullptr && (flags & FIX_GAMMA_FLAG) == 0) + ? weight[i] + : 1; + saved_bias[i] = (bias != nullptr) ? 
bias[i] : 0; } + __syncthreads(); } - // Write normalized and update the output - for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) { - for (int x = threadIdx.x, nx = input.InnerSize(); x < nx; x += blockDim.x) { - const DType inp = input.get_ref(batch, plane, x); - output.get_ref(batch, plane, x) = - ScalarConvert::to(gamma * (inp - mean) * invstd + beta); + + const index_t tid = threadIdx.x + blockIdx.x * blockDim.x; + const index_t stride = blockDim.x * gridDim.x; + const LType* input_aligned = reinterpret_cast(input); + LType* output_aligned = reinterpret_cast(output); + for (index_t i = tid; i < size / nvec; i += stride) { + scratch.aligned = input_aligned[i]; + const index_t my_channel_base = (nvec * i) % (inner_size * num_channels); +#pragma unroll + for (int j = 0; j < nvec; ++j) { + index_t my_channel = (my_channel_base + j) / inner_size; + if (my_channel >= num_channels) my_channel = my_channel % num_channels; + AType current_input = static_cast(scratch.separate[j]); + + AType invstd = small_num_channels ? saved_invstd[my_channel] + : variance_to_invstd(runningVar[my_channel], epsilon); + AType mean = small_num_channels ? saved_mean[my_channel] + : runningMean[my_channel]; + AType gamma = small_num_channels ? saved_weight[my_channel] + : ((weight != nullptr && (flags & FIX_GAMMA_FLAG) == 0) + ? weight[my_channel] + : 1); + AType beta = small_num_channels ? saved_bias[my_channel] + : ((bias != nullptr) ? bias[my_channel] + : 0); + current_input = gamma * (current_input - mean) * invstd + beta; + scratch.separate[j] = current_input; + } + + output_aligned[i] = scratch.aligned; + + if (i < num_channels) { + saveMean[i] = runningMean[i]; + saveInvStd[i] = variance_to_invstd(runningVar[i], epsilon); + if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0 + && weight != nullptr) { + weight[i] = 1; + } } } } @@ -312,6 +384,266 @@ struct CUDATensors { DeviceTensor1 saveInvStd; }; +namespace { + inline int ceil_div(int x, int y) { + return (x + y - 1) / y; + } +} // namespace + +template +__global__ void FrozenBatchNormalizationBackwardKernelCLastPhase1( + const DType* input, const DType* gradOutput, AType* temp_space, + DType* gradInput, const AType* weight, const AType* runningMean, + const AType* runningVar, const index_t outer, const index_t num_channels, + const AType eps, const uint32_t flags) { + using mxnet::common::cuda::warp_size; + constexpr int num_warps = NTHREADS / warp_size; + constexpr int nvec = sizeof(LType) >= sizeof(DType) ? 
sizeof(LType) / sizeof(DType) : 1; + const size_t stride = num_channels / nvec; + + union vectorized_loader { + LType aligned; + DType separate[nvec]; // NOLINT(*) + + __device__ inline vectorized_loader() {} + __device__ inline ~vectorized_loader() {} + }; + + vectorized_loader vec_input, vec_gradOutput; + + __shared__ AType scratch[NTHREADS * 2 * nvec]; + AType * my_values_gamma = &(scratch[threadIdx.x * nvec]); + AType * my_values_beta = &(scratch[(NTHREADS + threadIdx.x) * nvec]); + + AType sum_gamma[nvec]; // NOLINT(*) + AType sum_beta[nvec]; // NOLINT(*) +#pragma unroll + for (int i = 0; i < nvec; ++i) { + sum_gamma[i] = 0; + sum_beta[i] = 0; + } + + const size_t offset = blockIdx.x * warp_size; + const int my_warp = threadIdx.x / warp_size; + const int thread_idx_in_warp = threadIdx.x % warp_size; + + AType invstd[nvec]; // NOLINT(*) + AType mean[nvec]; // NOLINT(*) + AType gamma[nvec]; // NOLINT(*) + size_t channel_offset = (offset + thread_idx_in_warp) * nvec; + + if (channel_offset < num_channels) { +#pragma unroll + for (int i = 0; i < nvec; ++i) { + invstd[i] = variance_to_invstd(runningVar[channel_offset + i], eps); + mean[i] = runningMean[channel_offset + i]; + gamma[i] = weight != nullptr ? weight[channel_offset + i] : 1; + } + } + + const LType* aligned_gradOutput = reinterpret_cast(gradOutput); + const LType* aligned_input = reinterpret_cast(input); + LType* gradInput_aligned = reinterpret_cast(gradInput); + + const int rows_per_block = (outer + gridDim.y - 1) / gridDim.y; + const size_t start_row = my_warp + rows_per_block * blockIdx.y; + const size_t end_row = min(outer, static_cast(rows_per_block * (blockIdx.y + 1))); + if (offset + thread_idx_in_warp < stride) { + for (size_t i = start_row; i < end_row; i += num_warps) { + const index_t idx = i * stride + offset + thread_idx_in_warp; + vec_gradOutput.aligned = aligned_gradOutput[idx]; + vec_input.aligned = aligned_input[idx]; +#pragma unroll + for (int j = 0; j < nvec; ++j) { + sum_beta[j] += static_cast(vec_gradOutput.separate[j]); + sum_gamma[j] += static_cast(vec_gradOutput.separate[j]) * + (static_cast(vec_input.separate[j]) - mean[j]); + } + if (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) { + // Gradient to input +#pragma unroll + for (int j = 0; j < nvec; ++j) { + vec_gradOutput.separate[j] *= invstd[j] * gamma[j]; + } + if (flags & ADDTO_DATA_FLAG) { + vec_input.aligned = gradInput_aligned[idx]; +#pragma unroll + for (int j = 0; j < nvec; ++j) { + vec_gradOutput.separate[j] += vec_input.separate[j]; + } + } + gradInput_aligned[idx] = vec_gradOutput.aligned; + } + } + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + my_values_gamma[i] = sum_gamma[i]; + my_values_beta[i] = sum_beta[i]; + } + + __syncthreads(); + + for (int i = num_warps / 2; i > 0; i /= 2) { + if (my_warp < i) { + const int shared_offset = nvec * i * warp_size; +#pragma unroll + for (int j = 0; j < nvec; ++j) { + my_values_gamma[j] += my_values_gamma[j + shared_offset]; + my_values_beta[j] += my_values_beta[j + shared_offset]; + } + } + __syncthreads(); + } + + if (threadIdx.x < min(warp_size * nvec, + static_cast(num_channels - nvec * offset))) { + const size_t offset_out = nvec * offset + + blockIdx.y * num_channels; + const size_t offset_beta = gridDim.y * num_channels; + temp_space[offset_out + threadIdx.x] = scratch[threadIdx.x]; + temp_space[offset_beta + offset_out + threadIdx.x] = scratch[NTHREADS * nvec + threadIdx.x]; + } +} + +template +__global__ void FrozenBatchNormalizationBackwardKernelCLastPhase2(const 
AType * temp_space, + const AType * runningVar, + AType * out_gamma, + AType * out_beta, + int lead_dim, int n_blocks, + AType epsilon, uint32_t flags) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < lead_dim) { + AType sum_gamma = 0; + AType sum_beta = 0; + for (int i = tid; i < lead_dim * n_blocks; i += lead_dim) { + sum_gamma += temp_space[i]; + sum_beta += temp_space[i + lead_dim * n_blocks]; + } + if (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) { + if ((flags & FIX_GAMMA_FLAG) == 0) { + const AType invstd = variance_to_invstd(runningVar[tid], epsilon); + if (flags & WRITE_GAMMA_FLAG) { + out_gamma[tid] = sum_gamma * invstd; + } else { + out_gamma[tid] += sum_gamma * invstd; + } + } else { + if (flags & WRITE_GAMMA_FLAG) { + out_gamma[tid] = 0; + } + } + } + if (flags & WRITE_BETA_FLAG) { + out_beta[tid] = sum_beta; + } else if (flags & ADDTO_BETA_FLAG) { + out_beta[tid] += sum_beta; + } + } +} + +template +__global__ void FrozenBatchNormalizationBackwardKernel( + const DType* input, + const DType* gradOutput, + DType* gradInput, + AType* gradWeight, + AType* gradBias, + const AType* weight, + const AType* runningMean, + const AType* runningVar, + const index_t outer, + const index_t inner, + const index_t num_channels, + const index_t NHW_div_nvec, + const AType eps, + const uint32_t flags) { + const index_t my_channel = blockIdx.x; + const AType invstd = variance_to_invstd(runningVar[my_channel], eps); + const AType mean = runningMean[my_channel]; + const AType gamma = weight != nullptr ? weight[my_channel] : 1; + constexpr int nvec = sizeof(LType) > sizeof(DType) ? sizeof(LType) / sizeof(DType) + : 1; + union vectorized_loader { + LType aligned; + DType separate[nvec]; // NOLINT(*) + + __device__ inline vectorized_loader() {} + __device__ inline ~vectorized_loader() {} + }; + + vectorized_loader vec_input, vec_gradOutput; + + const LType* input_aligned = reinterpret_cast(input); + const LType* gradOutput_aligned = reinterpret_cast(gradOutput); + LType* gradInput_aligned = reinterpret_cast(gradInput); + + const index_t inner_div_nvec = inner / nvec; + + AType sum_gamma = 0; + AType sum_beta = 0; + + + for (index_t i = threadIdx.x; i < NHW_div_nvec; i += blockDim.x) { + const index_t inner_idx = i % inner_div_nvec; + const index_t outer_idx = i / inner_div_nvec; + const index_t idx = inner_idx + + (my_channel + outer_idx * num_channels) * inner_div_nvec; + vec_gradOutput.aligned = gradOutput_aligned[idx]; + vec_input.aligned = input_aligned[idx]; +#pragma unroll + for (int j = 0; j < nvec; ++j) { + sum_beta += static_cast(vec_gradOutput.separate[j]); + sum_gamma += static_cast(vec_gradOutput.separate[j]) * + (static_cast(vec_input.separate[j]) - mean); + } + + if (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) { + // Gradient to input +#pragma unroll + for (int j = 0; j < nvec; ++j) { + vec_gradOutput.separate[j] *= invstd * gamma; + } + if (flags & ADDTO_DATA_FLAG) { + vec_input.aligned = gradInput_aligned[idx]; +#pragma unroll + for (int j = 0; j < nvec; ++j) { + vec_gradOutput.separate[j] += vec_input.separate[j]; + } + } + gradInput_aligned[idx] = vec_gradOutput.aligned; + } + } + + sum_gamma = common::cuda::reduce(sum_gamma, + [](AType a, AType b) { return a + b; }); + sum_beta = common::cuda::reduce(sum_beta, + [](AType a, AType b) { return a + b; }); + + if (threadIdx.x == 0) { + if (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) { + if ((flags & FIX_GAMMA_FLAG) == 0) { + if (flags & WRITE_GAMMA_FLAG) { + gradWeight[my_channel] = sum_gamma * invstd; + } else { 
+ gradWeight[my_channel] += sum_gamma * invstd; + } + } else { + if (flags & WRITE_GAMMA_FLAG) { + gradWeight[my_channel] = 0; + } + } + } + if (flags & WRITE_BETA_FLAG) { + gradBias[my_channel] = sum_beta; + } else if (flags & ADDTO_BETA_FLAG) { + gradBias[my_channel] += sum_beta; + } + } +} + template static __global__ void BatchNormalizationBackwardKernel( const DeviceTensor input, @@ -320,21 +652,13 @@ static __global__ void BatchNormalizationBackwardKernel( CUDATensors tensors, const uint32_t flags, const AccReal momentum, - const double eps) { + const AccReal eps) { int plane = blockIdx.x; int N = gradOutput.OuterSize() * gradOutput.InnerSize(); - const bool is_train_and_not_global_stats = - (flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0; - AccReal mean, invstd; - if (is_train_and_not_global_stats) { - mean = ScalarConvert::to(tensors.saveMean[plane]); - invstd = tensors.saveInvStd[plane]; - } else { - mean = ScalarConvert::to(tensors.runningMean[plane]); - invstd = VARIANCE_TO_INVSTD(tensors.runningVar[plane], eps); - } + mean = ScalarConvert::to(tensors.saveMean[plane]); + invstd = tensors.saveInvStd[plane]; const AccReal weightVal = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0) ? ScalarConvert::to(tensors.weight[plane]) : AccReal(1); @@ -353,8 +677,8 @@ static __global__ void BatchNormalizationBackwardKernel( const AccReal projScale = dotP * norm * invstd * invstd; const AccReal gradScale = invstd * weightVal; - if (threadIdx.x == 0 && is_train_and_not_global_stats) { - const AccReal localVariance = INVSTD_TO_VARIANCE(tensors.saveInvStd[plane], eps); + if (threadIdx.x == 0) { + const AccReal localVariance = invstd_to_variance(tensors.saveInvStd[plane], eps); const AccReal localMean = tensors.saveMean[plane]; // update running averages @@ -370,15 +694,10 @@ static __global__ void BatchNormalizationBackwardKernel( for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { const DType gradOut = gradOutput.get_ref(batch, plane, x); - if (is_train_and_not_global_stats) { - const DType inp = input.get_ref(batch, plane, x); - const AccReal proj = (inp - mean) * projScale; - gradInput.get_ref(batch, plane, x) = - ScalarConvert::to((gradOut - proj - gradMean) * gradScale); - } else { - gradInput.get_ref(batch, plane, x) = ScalarConvert::to( - gradOut * gradScale); - } + const DType inp = input.get_ref(batch, plane, x); + const AccReal proj = (inp - mean) * projScale; + gradInput.get_ref(batch, plane, x) = + ScalarConvert::to((gradOut - proj - gradMean) * gradScale); } } } else { @@ -386,15 +705,10 @@ static __global__ void BatchNormalizationBackwardKernel( for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { const DType gradOut = gradOutput.get_ref(batch, plane, x); - if (is_train_and_not_global_stats) { - const DType inp = input.get_ref(batch, plane, x); - const AccReal proj = (inp - mean) * projScale; - gradInput.get_ref(batch, plane, x) += - ScalarConvert::to((gradOut - proj - gradMean) * gradScale); - } else { - gradInput.get_ref(batch, plane, x) += ScalarConvert::to( - gradOut * gradScale); - } + const DType inp = input.get_ref(batch, plane, x); + const AccReal proj = (inp - mean) * projScale; + gradInput.get_ref(batch, plane, x) += + ScalarConvert::to((gradOut - proj - gradMean) * gradScale); } } } @@ -537,13 +851,35 @@ static 
void BatchNormalizationUpdateOutput(mshadow::Stream *s, DCHECK_GT(weight.numElements(), 0); if ((flags & IS_TRAINING_FLAG) == 0 || (flags & USE_GLOBAL_STATS_FLAG) != 0) { - dim3 blocks(input.ChannelCount()); - dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize())); - BatchNormalizationUpdateOutputInferenceKernel> - <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( - input, output, runningMean, runningVar, saveMean, - saveInvStd, weight, bias, eps, flags); + AccReal* bias_ptr = bias.numElements() > 0 ? bias.dptr_ : nullptr; + AccReal* gamma_ptr = weight.numElements() > 0 ? weight.dptr_ : nullptr; + int nvec = sizeof(double) / sizeof(DType); + index_t size = input.InnerSize() * input.OuterSize() * input.ChannelCount(); + index_t aligned_size = ((size + nvec - 1) / nvec) * nvec; + index_t blocks = std::min((size + nvec * inference_forward_threads - 1) / + (nvec * inference_forward_threads), + static_cast(512)); + if (input.ChannelCount() < shmem_elements) { + BatchNormalizationUpdateOutputInferenceKernel + <<::GetStream(s)>>>( + input.dptr_, output.dptr_, + aligned_size, input.OuterSize(), + input.ChannelCount(), input.InnerSize(), + runningMean.dptr_, runningVar.dptr_, + saveMean.dptr_, saveInvStd.dptr_, + gamma_ptr, bias_ptr, + eps, flags); + } else { + BatchNormalizationUpdateOutputInferenceKernel + <<::GetStream(s)>>>( + input.dptr_, output.dptr_, + aligned_size, input.OuterSize(), + input.ChannelCount(), input.InnerSize(), + runningMean.dptr_, runningVar.dptr_, + saveMean.dptr_, saveInvStd.dptr_, + gamma_ptr, bias_ptr, + eps, flags); + } } else { dim3 blocks(input.ChannelCount()); dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize())); @@ -588,11 +924,86 @@ static void BatchNormalizationBackward(mshadow::Stream *s, tensors.saveInvStd = devicetensor(out_data[batchnorm::kVar]); DCHECK_GT(tensors.weight.numElements(), 0); - dim3 blocks(gradOutput.ChannelCount()); - dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize())); - BatchNormalizationBackwardKernel> - <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( - input, gradOutput, gradInput, tensors, flags, momentum, eps); + const bool is_train_and_not_global_stats = + (flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0; + + if (is_train_and_not_global_stats) { +#ifdef NDEBUG + constexpr bool SMALLER_THREADS = false; +#else + constexpr bool SMALLER_THREADS = true; +#endif + dim3 blocks(gradOutput.ChannelCount()); + dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize())); + BatchNormalizationBackwardKernel> + <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( + input, gradOutput, gradInput, tensors, flags, momentum, eps); + } else { + uint32_t flags_copy = flags; + if (gradInput.Size() <= 0) { + flags_copy = (flags_copy & ~WRITE_DATA_FLAG); + } + if (tensors.gradWeight.numElements() <= 0) { + flags_copy = (flags_copy & ~WRITE_GAMMA_FLAG); + } + if (tensors.gradBias.numElements() <= 0) { + flags_copy = (flags_copy & ~WRITE_BETA_FLAG); + } + AccReal* gamma = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0) + ? 
tensors.weight.dptr_ + : nullptr; + + if (param.axis == -1 || param.axis == in_data[batchnorm::kData].shape_.ndim() - 1) { + const int C = gradOutput.ChannelCount(); + int ltype = mxnet::common::cuda::get_load_type(C * sizeof(DType)); + const int M = gradOutput.OuterSize(); + MXNET_LOAD_TYPE_SWITCH(ltype, LType, { + const unsigned int blocks_x = ceil_div(C * sizeof(DType), + mxnet::common::cuda::warp_size * sizeof(LType)); + const unsigned int preferred_number_of_blocks = 2 * + MultiprocessorCount(ctx.run_ctx.ctx.dev_id); + const unsigned int blocks_y = std::max(preferred_number_of_blocks / blocks_x, 1u); + const dim3 n_blocks = {blocks_x, blocks_y, 1}; + auto scratch_space = ctx.requested[batchnorm::kTempSpace] + .get_space_typed(mshadow::Shape1(C * blocks_y * 2), + s); + auto stream = mshadow::Stream::GetStream(s); + constexpr int nthreads_phase1 = 512; + constexpr int nthreads_phase2 = 128; + FrozenBatchNormalizationBackwardKernelCLastPhase1 + <<>>(input.dptr_, gradOutput.dptr_, + scratch_space.dptr_, + gradInput.dptr_, + gamma, + tensors.runningMean.dptr_, + tensors.runningVar.dptr_, + M, C, eps, flags_copy); + const int nblocks_phase2 = ceil_div(C, nthreads_phase2); + FrozenBatchNormalizationBackwardKernelCLastPhase2 + <<>>(scratch_space.dptr_, + tensors.runningVar.dptr_, + tensors.gradWeight.dptr_, + tensors.gradBias.dptr_, C, + blocks_y, eps, flags_copy); + }); + } else { + dim3 blocks(gradOutput.ChannelCount()); + int ltype = mxnet::common::cuda::get_load_type(gradOutput.InnerSize() * sizeof(DType)); + MXNET_LOAD_TYPE_SWITCH(ltype, LType, { + constexpr int nvec = sizeof(LType) > sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1; + const index_t NHW_div_nvec = gradOutput.OuterSize() * gradOutput.InnerSize() / nvec; + constexpr int threads = 512; + FrozenBatchNormalizationBackwardKernel + <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( + input.dptr_, gradOutput.dptr_, gradInput.dptr_, + tensors.gradWeight.dptr_, tensors.gradBias.dptr_, + gamma, tensors.runningMean.dptr_, + tensors.runningVar.dptr_, + gradOutput.OuterSize(), gradOutput.InnerSize(), + gradOutput.ChannelCount(), NHW_div_nvec, eps, flags_copy); + }); + } + } MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormalizationBackward); } diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 9a67e82b8c06..ee27006c9876 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -350,7 +350,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp __syncthreads(); } if (my_id < warp_size) { - AType my_value = warp_reduce(scratch[threadIdx.x], + AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x], [](AType x, AType y) { return ::max(x, y); }); scratch[threadIdx.x] = my_value; } @@ -374,7 +374,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp __syncthreads(); } if (my_id < warp_size) { - AType my_value = warp_reduce(scratch[threadIdx.x], + AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x], [](AType x, AType y) { return x + y;}); scratch[threadIdx.x] = my_value; } @@ -488,7 +488,7 @@ __global__ void softmax_stride1_grad_kernel(const OType *out, const OType *ograd __syncthreads(); } if (my_id < warp_size) { - AType my_value = warp_reduce(scratch[threadIdx.x], + AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x], [](AType x, AType y) { return x + y; }); scratch[threadIdx.x] = my_value; }
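---
Reviewer note (illustration only, not part of the patch): the two building blocks this change leans on are the warp/block reduction helpers added to src/common/cuda_utils.h and the union-based vectorized load/store used by the frozen-BatchNorm inference and backward kernels. The standalone sketch below reproduces those patterns under assumed, hypothetical names (warp_reduce_demo, block_reduce_demo, vec_scale_demo, block_sum_demo); it is a minimal approximation that compiles with nvcc on its own, not the MXNet implementation itself.

// frozen_bn_patterns_demo.cu -- minimal sketch of the reduction and vectorized-I/O
// patterns used in this patch; all names are illustrative.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;

// Warp-level tree reduction via shuffles, same shape as the warp_reduce in the patch.
template <typename T, typename OP>
__device__ inline T warp_reduce_demo(T value, OP redfun) {
#pragma unroll
  for (int offset = kWarpSize / 2; offset >= 1; offset /= 2) {
    value = redfun(value, __shfl_down_sync(0xffffffff, value, offset));
  }
  return value;
}

// Block-level reduction: every warp reduces its values, warp leaders stash the
// partials in shared memory, then warp 0 reduces the partials. This mirrors the
// two-step reduce() helper added to cuda_utils.h (here without the all_reduce mode).
template <int NTHREADS, typename T, typename OP>
__device__ inline T block_reduce_demo(T value, OP redfun) {
  static_assert(NTHREADS % kWarpSize == 0, "block size must be a multiple of the warp size");
  __shared__ T scratch[NTHREADS / kWarpSize];
  const int lane = threadIdx.x % kWarpSize;
  const int warp_id = threadIdx.x / kWarpSize;
  T v = warp_reduce_demo(value, redfun);
  if (lane == 0) scratch[warp_id] = v;
  __syncthreads();
  T ret = T(0);
  if (warp_id == 0) {
    T partial = (threadIdx.x < NTHREADS / kWarpSize) ? scratch[threadIdx.x] : T(0);
    ret = warp_reduce_demo(partial, redfun);
  }
  __syncthreads();   // scratch may be reused by a later call
  return ret;        // valid in thread 0 only
}

// Vectorized elementwise kernel: loads sizeof(LType)/sizeof(DType) elements per
// memory transaction through a union, the same trick the frozen-BN kernels use.
template <typename DType, typename LType>
__global__ void vec_scale_demo(const DType* in, DType* out, size_t n_vec, float alpha) {
  constexpr int nvec = sizeof(LType) / sizeof(DType);
  union { LType aligned; DType separate[nvec]; } scratch;  // NOLINT(*)
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  const LType* in_aligned = reinterpret_cast<const LType*>(in);
  LType* out_aligned = reinterpret_cast<LType*>(out);
  for (size_t i = tid; i < n_vec; i += stride) {
    scratch.aligned = in_aligned[i];      // one wide load instead of nvec narrow ones
#pragma unroll
    for (int j = 0; j < nvec; ++j) {
      scratch.separate[j] = static_cast<DType>(alpha * static_cast<float>(scratch.separate[j]));
    }
    out_aligned[i] = scratch.aligned;     // one wide store
  }
}

// Single-block sum using the block-level reduction above.
__global__ void block_sum_demo(const float* data, float* result, int n) {
  float acc = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) acc += data[i];
  acc = block_reduce_demo<256>(acc, [](float a, float b) { return a + b; });
  if (threadIdx.x == 0) *result = acc;
}

int main() {
  const int n = 1 << 20;                  // even, so double-wide loads of floats line up
  float *d_in, *d_out, *d_sum;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMalloc(&d_sum, sizeof(float));
  cudaMemset(d_in, 0, n * sizeof(float));
  vec_scale_demo<float, double><<<128, 256>>>(d_in, d_out, n / 2, 2.0f);
  block_sum_demo<<<1, 256>>>(d_out, d_sum, n);
  float h_sum = 0.f;
  cudaMemcpy(&h_sum, d_sum, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", h_sum);
  cudaFree(d_in); cudaFree(d_out); cudaFree(d_sum);
  return 0;
}

The wide loads matter most for the channels-last path handled by the FrozenBatchNormalizationBackwardKernelCLastPhase1/Phase2 kernels, where consecutive threads touch consecutive channels; issuing one LType-sized transaction per nvec elements is the usual motivation for this kind of vectorization, and the reduction helpers let each block fold its per-channel partial sums without the hand-unrolled shuffle sequences the old code repeated in several places.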