From 4cd7049fd4d58562077de8315bf380f908d7f670 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 4 Dec 2019 16:47:37 -0800 Subject: [PATCH 01/19] Better frozen batchnorm --- src/operator/nn/batch_norm.cu | 172 ++++++++++++++++++++++++++++------ 1 file changed, 143 insertions(+), 29 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index be9309c8bfb1..6e3ab0381857 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -44,9 +44,30 @@ using namespace mxnet; +namespace { + /*! \brief inverse standard deviation <-> variance */ -#define VARIANCE_TO_INVSTD(__var$, __eps$) (1.0/sqrt((__var$) + DType(__eps$))) -#define INVSTD_TO_VARIANCE(__invstd$, __eps$) ((1.0 / ((__invstd$) * (__invstd$))) - (__eps$)) +template +MSHADOW_XINLINE AccReal variance_to_invstd(DType var, AccReal eps) { + return rsqrtf(static_cast(var) + eps); +} + +template <> +MSHADOW_XINLINE double variance_to_invstd(double var, double eps) { + return rsqrt(var + eps); +} + +template +MSHADOW_XINLINE AccReal invstd_to_variance(AccReal invstd, AccReal eps) { + return 1.0f / (invstd * invstd) - eps; +} + +template <> +MSHADOW_XINLINE double invstd_to_variance(double invstd, double eps) { + return 1.0 / (invstd * invstd) - eps; +} + +} // namespace namespace mxnet { namespace op { @@ -204,32 +225,94 @@ static __device__ T reduce(Op op, DeviceTensor tensor, int plane) { return shared[0]; } -template +namespace { + constexpr int inference_forward_threads = 512; +} // namespace + +template +__global__ void BatchNormalizationUpdateOutputInferenceKernel2( + const DType* input, + DType* output, + const index_t size, + const index_t outer_size, + const index_t num_channels, + const index_t inner_size, + const AType* runningMean, + const AType* runningVar, + AType* saveMean, + AType* saveInvStd, + AType* weight, + AType* bias, + const AType epsilon, + const uint32_t flags) { + constexpr int nvec = sizeof(LType) / sizeof(DType); + __shared__ union { + LType aligned[inference_forward_threads]; + DType separate[nvec * inference_forward_threads]; + } scratch; + + const index_t tid = threadIdx.x + blockIdx.x * blockDim.x; + const index_t stride = blockDim.x * gridDim.x; + const LType* input_aligned = reinterpret_cast(input); + LType* output_aligned = reinterpret_cast(output); + DType* my_scratch = scratch.separate + nvec * threadIdx.x; + for (index_t i = tid; i < size / nvec; i += stride) { + scratch.aligned[threadIdx.x] = input_aligned[i]; +#pragma unroll + for (int j = 0; j < nvec; ++j) { + const index_t my_channel = ((nvec * i + j) / inner_size) % num_channels; + AType current_input = static_cast(my_scratch[j]); + + AType invstd = variance_to_invstd(runningVar[my_channel], epsilon); + AType mean = static_cast(runningMean[my_channel]); + AType gamma = (weight != nullptr && (flags & FIX_GAMMA_FLAG) == 0) + ? static_cast(weight[my_channel]) + : 1; + AType beta = (bias != nullptr) ? 
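
The patch replaces the VARIANCE_TO_INVSTD / INVSTD_TO_VARIANCE macros with typed helper functions. A minimal host-side sketch of the same two conversions follows; the function names and the tolerance check are illustrative and not part of the patch.

#include <cassert>
#include <cmath>

template <typename AccReal>
AccReal variance_to_invstd_ref(AccReal var, AccReal eps) {
  // invstd = 1 / sqrt(var + eps)
  return AccReal(1) / std::sqrt(var + eps);
}

template <typename AccReal>
AccReal invstd_to_variance_ref(AccReal invstd, AccReal eps) {
  // var = 1 / invstd^2 - eps
  return AccReal(1) / (invstd * invstd) - eps;
}

int main() {
  const double eps = 1e-5;
  const double var = 0.25;
  const double invstd = variance_to_invstd_ref(var, eps);
  // The two helpers are inverses of each other (up to rounding error).
  assert(std::fabs(invstd_to_variance_ref(invstd, eps) - var) < 1e-12);
  return 0;
}
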
static_cast(bias[my_channel]) + : 0; + current_input = gamma * (current_input - mean) * invstd + beta; + my_scratch[j] = current_input; + } + + output_aligned[i] = scratch.aligned[threadIdx.x]; + } + + if (tid < num_channels) { + saveMean[tid] = runningMean[tid]; + saveInvStd[tid] = variance_to_invstd(runningVar[tid], epsilon); + if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0 + && weight != nullptr) { + weight[tid] = 1; + } + } +} + +template __global__ void BatchNormalizationUpdateOutputInferenceKernel( DeviceTensor input, DeviceTensor output, - DeviceTensor1 runningMean, - DeviceTensor1 runningVar, - DeviceTensor1 saveMean, - DeviceTensor1 saveInvStd, - DeviceTensor1 weight, - DeviceTensor1 bias, - const DType epsilon, + AccReal* runningMean, + AccReal* runningVar, + AccReal* saveMean, + AccReal* saveInvStd, + AccReal* weight, + AccReal* bias, + const AccReal epsilon, const uint32_t flags) { int plane = blockIdx.x; - AccReal invstd = VARIANCE_TO_INVSTD(runningVar[plane], epsilon); - AccReal mean = ScalarConvert::to(runningMean[plane]); - AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0) - ? ScalarConvert::to(weight[plane]) - : ScalarConvert::to(1); - AccReal beta = bias.numElements() > 0 ? ScalarConvert::to(bias[plane]) - : ScalarConvert::to(0); + AccReal invstd = variance_to_invstd(runningVar[plane], epsilon); + AccReal mean = static_cast(runningMean[plane]); + AccReal gamma = (weight != nullptr && (flags & FIX_GAMMA_FLAG) == 0) + ? static_cast(weight[plane]) + : 1; + AccReal beta = (bias != nullptr) ? static_cast(bias[plane]) + : 0; if (threadIdx.x == 0) { saveMean[plane] = runningMean[plane]; - saveInvStd[plane] = VARIANCE_TO_INVSTD(runningVar[plane], epsilon); + saveInvStd[plane] = variance_to_invstd(runningVar[plane], epsilon); if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0 - && weight.numElements() > 0) { + && weight != nullptr) { weight[plane] = AccReal(1); } } @@ -318,7 +401,7 @@ static __global__ void BatchNormalizationBackwardKernel( CUDATensors tensors, const uint32_t flags, const AccReal momentum, - const double eps) { + const AccReal eps) { int plane = blockIdx.x; int N = gradOutput.OuterSize() * gradOutput.InnerSize(); @@ -331,7 +414,7 @@ static __global__ void BatchNormalizationBackwardKernel( invstd = tensors.saveInvStd[plane]; } else { mean = ScalarConvert::to(tensors.runningMean[plane]); - invstd = VARIANCE_TO_INVSTD(tensors.runningVar[plane], eps); + invstd = variance_to_invstd(tensors.runningVar[plane], eps); } const AccReal weightVal = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0) ? 
@@ -352,7 +435,7 @@ static __global__ void BatchNormalizationBackwardKernel( const AccReal gradScale = invstd * weightVal; if (threadIdx.x == 0 && is_train_and_not_global_stats) { - const AccReal localVariance = INVSTD_TO_VARIANCE(tensors.saveInvStd[plane], eps); + const AccReal localVariance = invstd_to_variance(tensors.saveInvStd[plane], eps); const AccReal localMean = tensors.saveMean[plane]; // update running averages @@ -508,13 +591,44 @@ static void BatchNormalizationUpdateOutput(mshadow::Stream *s, DCHECK_GT(weight.numElements(), 0); if ((flags & IS_TRAINING_FLAG) == 0 || (flags & USE_GLOBAL_STATS_FLAG) != 0) { - dim3 blocks(input.ChannelCount()); - dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize(), false)); - BatchNormalizationUpdateOutputInferenceKernel> - <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( - input, output, runningMean, runningVar, saveMean, - saveInvStd, weight, bias, eps, flags); + AccReal* bias_ptr = bias.numElements() > 0 ? bias.dptr_ : nullptr; + AccReal* gamma_ptr = ((flags & FIX_GAMMA_FLAG) == 0) && weight.numElements() == 0 ? + weight.dptr_ : + nullptr; + if (dmlc::GetEnv("BN_DEBUG", 0)) { + int nvec = sizeof(double) / sizeof(DType); + index_t size = input.InnerSize() * input.OuterSize() * input.ChannelCount(); + index_t blocks = (size + nvec * inference_forward_threads - 1) / (nvec * inference_forward_threads); + BatchNormalizationUpdateOutputInferenceKernel2 + <<::GetStream(s)>>>( + input.dptr_, output.dptr_, + size, input.OuterSize(), + input.ChannelCount(), input.InnerSize(), + runningMean.dptr_, runningVar.dptr_, + saveMean.dptr_, saveInvStd.dptr_, + gamma_ptr, bias_ptr, + eps, flags); + } else { + if (param.axis == -1 || + param.axis == input.shape_.ndim() - 1) { + std::cout << "NHWC BN" << std::endl; + dim3 blocks(input.ChannelCount()); + dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize(), false)); + BatchNormalizationUpdateOutputInferenceKernel> + <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( + input, output, runningMean.dptr_, runningVar.dptr_, saveMean.dptr_, + saveInvStd.dptr_, gamma_ptr, bias_ptr, eps, flags); + } else { + dim3 blocks(input.ChannelCount()); + dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize(), false)); + BatchNormalizationUpdateOutputInferenceKernel> + <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( + input, output, runningMean, runningVar, saveMean, + saveInvStd, gamma_ptr, bias_ptr, eps, flags); + } + } } else { dim3 blocks(input.ChannelCount()); dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize(), false)); From ece922677648391217b7c298f656c500c7b30ab3 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 5 Dec 2019 11:08:14 -0800 Subject: [PATCH 02/19] Continue FreezeBN --- src/operator/nn/batch_norm.cu | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 6e3ab0381857..e80dfe8976e6 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -246,9 +246,12 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel2( const AType epsilon, const uint32_t flags) { constexpr int nvec = sizeof(LType) / sizeof(DType); - __shared__ union { + __shared__ union scratch { LType aligned[inference_forward_threads]; DType separate[nvec * inference_forward_threads]; + + scratch() {} + ~scratch() {} } scratch; const index_t tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -598,11 +601,12 @@ static void 
BatchNormalizationUpdateOutput(mshadow::Stream *s, if (dmlc::GetEnv("BN_DEBUG", 0)) { int nvec = sizeof(double) / sizeof(DType); index_t size = input.InnerSize() * input.OuterSize() * input.ChannelCount(); + index_t aligned_size = ((size + nvec - 1) / nvec) * nvec; index_t blocks = (size + nvec * inference_forward_threads - 1) / (nvec * inference_forward_threads); BatchNormalizationUpdateOutputInferenceKernel2 <<::GetStream(s)>>>( input.dptr_, output.dptr_, - size, input.OuterSize(), + aligned_size, input.OuterSize(), input.ChannelCount(), input.InnerSize(), runningMean.dptr_, runningVar.dptr_, saveMean.dptr_, saveInvStd.dptr_, @@ -610,7 +614,7 @@ static void BatchNormalizationUpdateOutput(mshadow::Stream *s, eps, flags); } else { if (param.axis == -1 || - param.axis == input.shape_.ndim() - 1) { + param.axis == in_data[batchnorm::kData].shape_.ndim() - 1) { std::cout << "NHWC BN" << std::endl; dim3 blocks(input.ChannelCount()); dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize(), false)); @@ -618,15 +622,17 @@ static void BatchNormalizationUpdateOutput(mshadow::Stream *s, batchnorm::BNTensor3> <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( input, output, runningMean.dptr_, runningVar.dptr_, saveMean.dptr_, - saveInvStd.dptr_, gamma_ptr, bias_ptr, eps, flags); + saveInvStd.dptr_, gamma_ptr, bias_ptr, + static_cast(eps), flags); } else { dim3 blocks(input.ChannelCount()); dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize(), false)); - BatchNormalizationUpdateOutputInferenceKernel> <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( - input, output, runningMean, runningVar, saveMean, - saveInvStd, gamma_ptr, bias_ptr, eps, flags); + input, output, runningMean.dptr_, runningVar.dptr_, saveMean.dptr_, + saveInvStd.dptr_, gamma_ptr, bias_ptr, + static_cast(eps), flags); } } } else { From 78d5d1f86576e7cc104b29f197be50b184ec54d7 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 5 Dec 2019 13:43:01 -0800 Subject: [PATCH 03/19] Optimizations --- Makefile | 2 + src/operator/nn/batch_norm.cu | 91 ++++++++++++++++++++++++----------- 2 files changed, 64 insertions(+), 29 deletions(-) diff --git a/Makefile b/Makefile index 16d7d2393736..6e49f4d01f3d 100644 --- a/Makefile +++ b/Makefile @@ -135,6 +135,8 @@ else NVCCFLAGS += -std=c++11 -Xcompiler -D_FORCE_INLINES -O3 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS) endif +NVCCFLAGS += --ptxas-options=-v + # CFLAGS for segfault logger ifeq ($(USE_SIGNAL_HANDLER), 1) CFLAGS += -DMXNET_USE_SIGNAL_HANDLER=1 diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index e80dfe8976e6..4b48f5eae079 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -227,9 +227,11 @@ static __device__ T reduce(Op op, DeviceTensor tensor, int plane) { namespace { constexpr int inference_forward_threads = 512; + constexpr int shmem_elements = 1536; } // namespace -template +template +__launch_bounds__(inference_forward_threads) __global__ void BatchNormalizationUpdateOutputInferenceKernel2( const DType* input, DType* output, @@ -246,38 +248,57 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel2( const AType epsilon, const uint32_t flags) { constexpr int nvec = sizeof(LType) / sizeof(DType); - __shared__ union scratch { - LType aligned[inference_forward_threads]; - DType separate[nvec * inference_forward_threads]; - - scratch() {} - ~scratch() {} + __shared__ AType saved_invstd[shmem_elements]; + __shared__ AType saved_mean[shmem_elements]; + __shared__ AType 
saved_weight[shmem_elements]; + __shared__ AType saved_bias[shmem_elements]; + union scratch { + LType aligned; + DType separate[nvec]; + + __device__ inline scratch() {} + __device__ inline ~scratch() {} } scratch; + if (small_num_channels) { + for (int i = threadIdx.x; i < num_channels; i += blockDim.x) { + saved_invstd[i] = variance_to_invstd(runningVar[i], epsilon); + saved_mean[i] = runningMean[i]; + saved_weight[i] = (weight != nullptr && (flags & FIX_GAMMA_FLAG) == 0) + ? weight[i] + : 1; + saved_bias[i] = (bias != nullptr) ? bias[i] : 0; + } + __syncthreads(); + } + const index_t tid = threadIdx.x + blockIdx.x * blockDim.x; const index_t stride = blockDim.x * gridDim.x; const LType* input_aligned = reinterpret_cast(input); LType* output_aligned = reinterpret_cast(output); - DType* my_scratch = scratch.separate + nvec * threadIdx.x; for (index_t i = tid; i < size / nvec; i += stride) { - scratch.aligned[threadIdx.x] = input_aligned[i]; + scratch.aligned = input_aligned[i]; #pragma unroll for (int j = 0; j < nvec; ++j) { const index_t my_channel = ((nvec * i + j) / inner_size) % num_channels; - AType current_input = static_cast(my_scratch[j]); - - AType invstd = variance_to_invstd(runningVar[my_channel], epsilon); - AType mean = static_cast(runningMean[my_channel]); - AType gamma = (weight != nullptr && (flags & FIX_GAMMA_FLAG) == 0) - ? static_cast(weight[my_channel]) - : 1; - AType beta = (bias != nullptr) ? static_cast(bias[my_channel]) - : 0; + AType current_input = static_cast(scratch.separate[j]); + + AType invstd = small_num_channels ? saved_invstd[my_channel] + : variance_to_invstd(runningVar[my_channel], epsilon); + AType mean = small_num_channels ? saved_mean[my_channel] + : runningMean[my_channel]; + AType gamma = small_num_channels ? saved_weight[my_channel] + : ((weight != nullptr && (flags & FIX_GAMMA_FLAG) == 0) + ? weight[my_channel] + : 1); + AType beta = small_num_channels ? saved_bias[my_channel] + : ((bias != nullptr) ? 
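
The kernel above stages the per-channel statistics (running mean/var, gamma, beta) into shared memory when the channel count fits in the shmem_elements arrays, so the per-element loop reads them from on-chip storage instead of global memory. A stand-alone sketch of that caching pattern follows; the kernel and names are illustrative, not the patch's code.

#include <cuda_runtime.h>

constexpr int kMaxChannels = 1536;  // mirrors shmem_elements above

__global__ void scale_bias_cached(const float* x, float* y,
                                  const float* scale, const float* bias,
                                  int num_channels, int inner_size, int n) {
  __shared__ float s_scale[kMaxChannels];
  __shared__ float s_bias[kMaxChannels];
  // Cooperative preload: each thread loads a strided subset of the channels.
  for (int c = threadIdx.x; c < num_channels; c += blockDim.x) {
    s_scale[c] = scale[c];
    s_bias[c]  = bias[c];
  }
  __syncthreads();
  const int stride = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += stride) {
    const int c = (i / inner_size) % num_channels;  // NCHW channel index
    y[i] = s_scale[c] * x[i] + s_bias[c];           // per-channel affine, like BN inference
  }
}
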
bias[my_channel] + : 0); current_input = gamma * (current_input - mean) * invstd + beta; - my_scratch[j] = current_input; + scratch.separate[j] = current_input; } - output_aligned[i] = scratch.aligned[threadIdx.x]; + output_aligned[i] = scratch.aligned; } if (tid < num_channels) { @@ -603,15 +624,27 @@ static void BatchNormalizationUpdateOutput(mshadow::Stream *s, index_t size = input.InnerSize() * input.OuterSize() * input.ChannelCount(); index_t aligned_size = ((size + nvec - 1) / nvec) * nvec; index_t blocks = (size + nvec * inference_forward_threads - 1) / (nvec * inference_forward_threads); - BatchNormalizationUpdateOutputInferenceKernel2 - <<::GetStream(s)>>>( - input.dptr_, output.dptr_, - aligned_size, input.OuterSize(), - input.ChannelCount(), input.InnerSize(), - runningMean.dptr_, runningVar.dptr_, - saveMean.dptr_, saveInvStd.dptr_, - gamma_ptr, bias_ptr, - eps, flags); + if (input.ChannelCount() < shmem_elements) { + BatchNormalizationUpdateOutputInferenceKernel2 + <<::GetStream(s)>>>( + input.dptr_, output.dptr_, + aligned_size, input.OuterSize(), + input.ChannelCount(), input.InnerSize(), + runningMean.dptr_, runningVar.dptr_, + saveMean.dptr_, saveInvStd.dptr_, + gamma_ptr, bias_ptr, + eps, flags); + } else { + BatchNormalizationUpdateOutputInferenceKernel2 + <<::GetStream(s)>>>( + input.dptr_, output.dptr_, + aligned_size, input.OuterSize(), + input.ChannelCount(), input.InnerSize(), + runningMean.dptr_, runningVar.dptr_, + saveMean.dptr_, saveInvStd.dptr_, + gamma_ptr, bias_ptr, + eps, flags); + } } else { if (param.axis == -1 || param.axis == in_data[batchnorm::kData].shape_.ndim() - 1) { From 36edffde1ad445efc397a5e28a1f8c84cae34062 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 5 Dec 2019 14:47:49 -0800 Subject: [PATCH 04/19] Reduce number of mod operations --- src/operator/nn/batch_norm.cu | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 4b48f5eae079..95ce95ffd3b8 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -278,9 +278,11 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel2( LType* output_aligned = reinterpret_cast(output); for (index_t i = tid; i < size / nvec; i += stride) { scratch.aligned = input_aligned[i]; + const index_t my_channel_base = (nvec * i) % (inner_size * num_channels); #pragma unroll for (int j = 0; j < nvec; ++j) { - const index_t my_channel = ((nvec * i + j) / inner_size) % num_channels; + index_t my_channel = (my_channel_base + j) / inner_size; + if (my_channel >= num_channels) my_channel = my_channel % num_channels; AType current_input = static_cast(scratch.separate[j]); AType invstd = small_num_channels ? saved_invstd[my_channel] @@ -616,9 +618,9 @@ static void BatchNormalizationUpdateOutput(mshadow::Stream *s, if ((flags & IS_TRAINING_FLAG) == 0 || (flags & USE_GLOBAL_STATS_FLAG) != 0) { AccReal* bias_ptr = bias.numElements() > 0 ? bias.dptr_ : nullptr; - AccReal* gamma_ptr = ((flags & FIX_GAMMA_FLAG) == 0) && weight.numElements() == 0 ? - weight.dptr_ : - nullptr; + AccReal* gamma_ptr = weight.numElements() == 0 ? 
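
The "Reduce number of mod operations" change hoists the expensive modulo out of the per-lane loop: the remainder is taken once per vector load, and only a rare wrap-around is handled inside the loop. A small host-side check of the equivalence follows (illustrative, not part of the patch).

#include <cassert>
#include <cstdint>

int main() {
  const int64_t inner_size = 7, num_channels = 5, nvec = 4;
  for (int64_t i = 0; i < 1000; ++i) {
    const int64_t base = (nvec * i) % (inner_size * num_channels);
    for (int64_t j = 0; j < nvec; ++j) {
      const int64_t naive = ((nvec * i + j) / inner_size) % num_channels;
      int64_t cheap = (base + j) / inner_size;
      if (cheap >= num_channels) cheap %= num_channels;  // rare wrap-around
      assert(cheap == naive);
    }
  }
  return 0;
}
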
+ weight.dptr_ : + nullptr; if (dmlc::GetEnv("BN_DEBUG", 0)) { int nvec = sizeof(double) / sizeof(DType); index_t size = input.InnerSize() * input.OuterSize() * input.ChannelCount(); From 3240ff30255a87f04041eee87f8d8c329dea9899 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 5 Dec 2019 15:31:52 -0800 Subject: [PATCH 05/19] Cleaning --- src/operator/nn/batch_norm.cu | 125 +++++++++------------------------- 1 file changed, 31 insertions(+), 94 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 95ce95ffd3b8..872aabe96c22 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -232,7 +232,7 @@ namespace { template __launch_bounds__(inference_forward_threads) -__global__ void BatchNormalizationUpdateOutputInferenceKernel2( +__global__ void BatchNormalizationUpdateOutputInferenceKernel( const DType* input, DType* output, const index_t size, @@ -301,53 +301,14 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel2( } output_aligned[i] = scratch.aligned; - } - - if (tid < num_channels) { - saveMean[tid] = runningMean[tid]; - saveInvStd[tid] = variance_to_invstd(runningVar[tid], epsilon); - if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0 - && weight != nullptr) { - weight[tid] = 1; - } - } -} - -template -__global__ void BatchNormalizationUpdateOutputInferenceKernel( - DeviceTensor input, - DeviceTensor output, - AccReal* runningMean, - AccReal* runningVar, - AccReal* saveMean, - AccReal* saveInvStd, - AccReal* weight, - AccReal* bias, - const AccReal epsilon, - const uint32_t flags) { - int plane = blockIdx.x; - AccReal invstd = variance_to_invstd(runningVar[plane], epsilon); - AccReal mean = static_cast(runningMean[plane]); - AccReal gamma = (weight != nullptr && (flags & FIX_GAMMA_FLAG) == 0) - ? static_cast(weight[plane]) - : 1; - AccReal beta = (bias != nullptr) ? static_cast(bias[plane]) - : 0; - if (threadIdx.x == 0) { - saveMean[plane] = runningMean[plane]; - saveInvStd[plane] = variance_to_invstd(runningVar[plane], epsilon); - if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0 - && weight != nullptr) { - weight[plane] = AccReal(1); - } - } - // Write normalized and update the output - for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) { - for (int x = threadIdx.x, nx = input.InnerSize(); x < nx; x += blockDim.x) { - const DType inp = input.get_ref(batch, plane, x); - output.get_ref(batch, plane, x) = - ScalarConvert::to(gamma * (inp - mean) * invstd + beta); + if (i < num_channels) { + saveMean[i] = runningMean[i]; + saveInvStd[i] = variance_to_invstd(runningVar[i], epsilon); + if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0 + && weight != nullptr) { + weight[i] = 1; + } } } } @@ -621,54 +582,30 @@ static void BatchNormalizationUpdateOutput(mshadow::Stream *s, AccReal* gamma_ptr = weight.numElements() == 0 ? 
weight.dptr_ : nullptr; - if (dmlc::GetEnv("BN_DEBUG", 0)) { - int nvec = sizeof(double) / sizeof(DType); - index_t size = input.InnerSize() * input.OuterSize() * input.ChannelCount(); - index_t aligned_size = ((size + nvec - 1) / nvec) * nvec; - index_t blocks = (size + nvec * inference_forward_threads - 1) / (nvec * inference_forward_threads); - if (input.ChannelCount() < shmem_elements) { - BatchNormalizationUpdateOutputInferenceKernel2 - <<::GetStream(s)>>>( - input.dptr_, output.dptr_, - aligned_size, input.OuterSize(), - input.ChannelCount(), input.InnerSize(), - runningMean.dptr_, runningVar.dptr_, - saveMean.dptr_, saveInvStd.dptr_, - gamma_ptr, bias_ptr, - eps, flags); - } else { - BatchNormalizationUpdateOutputInferenceKernel2 - <<::GetStream(s)>>>( - input.dptr_, output.dptr_, - aligned_size, input.OuterSize(), - input.ChannelCount(), input.InnerSize(), - runningMean.dptr_, runningVar.dptr_, - saveMean.dptr_, saveInvStd.dptr_, - gamma_ptr, bias_ptr, - eps, flags); - } - } else { - if (param.axis == -1 || - param.axis == in_data[batchnorm::kData].shape_.ndim() - 1) { - std::cout << "NHWC BN" << std::endl; - dim3 blocks(input.ChannelCount()); - dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize(), false)); - BatchNormalizationUpdateOutputInferenceKernel> - <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( - input, output, runningMean.dptr_, runningVar.dptr_, saveMean.dptr_, - saveInvStd.dptr_, gamma_ptr, bias_ptr, - static_cast(eps), flags); + int nvec = sizeof(double) / sizeof(DType); + index_t size = input.InnerSize() * input.OuterSize() * input.ChannelCount(); + index_t aligned_size = ((size + nvec - 1) / nvec) * nvec; + index_t blocks = std::min((size + nvec * inference_forward_threads - 1) / (nvec * inference_forward_threads), static_cast(512)); + if (input.ChannelCount() < shmem_elements) { + BatchNormalizationUpdateOutputInferenceKernel + <<::GetStream(s)>>>( + input.dptr_, output.dptr_, + aligned_size, input.OuterSize(), + input.ChannelCount(), input.InnerSize(), + runningMean.dptr_, runningVar.dptr_, + saveMean.dptr_, saveInvStd.dptr_, + gamma_ptr, bias_ptr, + eps, flags); } else { - dim3 blocks(input.ChannelCount()); - dim3 threads(batchnorm::cuda::getNumThreads(input.InnerSize(), false)); - BatchNormalizationUpdateOutputInferenceKernel> - <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( - input, output, runningMean.dptr_, runningVar.dptr_, saveMean.dptr_, - saveInvStd.dptr_, gamma_ptr, bias_ptr, - static_cast(eps), flags); - } + BatchNormalizationUpdateOutputInferenceKernel + <<::GetStream(s)>>>( + input.dptr_, output.dptr_, + aligned_size, input.OuterSize(), + input.ChannelCount(), input.InnerSize(), + runningMean.dptr_, runningVar.dptr_, + saveMean.dptr_, saveInvStd.dptr_, + gamma_ptr, bias_ptr, + eps, flags); } } else { dim3 blocks(input.ChannelCount()); From ed9d4acf4e366837c35759553e9d70149cccf26c Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 10 Dec 2019 15:49:56 -0800 Subject: [PATCH 06/19] Fixing frozen bn with fix_gamma=False --- src/operator/nn/batch_norm.cu | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 872aabe96c22..7ddc488fc0a7 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -579,9 +579,7 @@ static void BatchNormalizationUpdateOutput(mshadow::Stream *s, if ((flags & IS_TRAINING_FLAG) == 0 || (flags & USE_GLOBAL_STATS_FLAG) != 0) { AccReal* bias_ptr = bias.numElements() > 0 ? 
bias.dptr_ : nullptr; - AccReal* gamma_ptr = weight.numElements() == 0 ? - weight.dptr_ : - nullptr; + AccReal* gamma_ptr = weight.numElements() > 0 ? weight.dptr_ : nullptr; int nvec = sizeof(double) / sizeof(DType); index_t size = input.InnerSize() * input.OuterSize() * input.ChannelCount(); index_t aligned_size = ((size + nvec - 1) / nvec) * nvec; From 96c21ef4c1b4165da2a50271de1e716580beba3b Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 10 Dec 2019 16:09:49 -0800 Subject: [PATCH 07/19] Fix lint in BN --- src/operator/nn/batch_norm.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 7ddc488fc0a7..398bfb91bb00 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -254,7 +254,7 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel( __shared__ AType saved_bias[shmem_elements]; union scratch { LType aligned; - DType separate[nvec]; + DType separate[nvec]; // NOLINT(*) __device__ inline scratch() {} __device__ inline ~scratch() {} @@ -583,7 +583,9 @@ static void BatchNormalizationUpdateOutput(mshadow::Stream *s, int nvec = sizeof(double) / sizeof(DType); index_t size = input.InnerSize() * input.OuterSize() * input.ChannelCount(); index_t aligned_size = ((size + nvec - 1) / nvec) * nvec; - index_t blocks = std::min((size + nvec * inference_forward_threads - 1) / (nvec * inference_forward_threads), static_cast(512)); + index_t blocks = std::min((size + nvec * inference_forward_threads - 1) / + (nvec * inference_forward_threads), + static_cast(512)); if (input.ChannelCount() < shmem_elements) { BatchNormalizationUpdateOutputInferenceKernel <<::GetStream(s)>>>( From c1748947a5ddf20a30f788cea14046b5fe1fb468 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 11 Dec 2019 12:40:40 -0800 Subject: [PATCH 08/19] Backward frozen batchnorm --- src/operator/nn/batch_norm.cu | 127 ++++++++++++++++++++++++++-------- 1 file changed, 100 insertions(+), 27 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 398bfb91bb00..4fd68a5953d8 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -380,6 +380,75 @@ struct CUDATensors { DeviceTensor1 saveInvStd; }; +template +static __global__ void FrozenBatchNormalizationBackwardKernel( + const DType* input, + const DType* gradOutput, + DType* gradInput, + CUDATensors tensors, + const uint32_t flags, + const AccReal momentum, + const AccReal eps, + const int splitk, + const index_t num_channels, + const int channels_per_block, + const index_t outer_dim, + const index_t inner_dim, + const index_t NHW) { + + const index_t start_channel = (blockIdx.x / splitk) * channels_per_block; + + + + + + int plane = blockIdx.x; + int N = gradOutput.OuterSize() * gradOutput.InnerSize(); + + AccReal mean, invstd; + mean = ScalarConvert::to(tensors.runningMean[plane]); + invstd = variance_to_invstd(tensors.runningVar[plane], eps); + + const AccReal weightVal = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0) ? + ScalarConvert::to(tensors.weight[plane]) : AccReal(1); + const AccReal norm = AccReal(1) / N; + + // Compute two values across (batch, x/y/z) in one pass: + // 1. Sum(gradOutput) + // 2. 
DotProduct(input - mean, gradOutput) + GradOp g(mean, input, gradOutput); + Float2< DType, AccReal > res = reduce < Float2 < DType, AccReal >, + GradOp< DType, AccReal, DeviceTensor >, DeviceTensor > (g, gradOutput, plane); + const AccReal gradOutputSum = res.v1; + const AccReal dotP = res.v2; + + const AccReal gradMean = gradOutputSum * norm; + const AccReal projScale = dotP * norm * invstd * invstd; + const AccReal gradScale = invstd * weightVal; + + if (gradInput.Size() > 0 && (flags & WRITE_DATA_FLAG) != 0) { + for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { + for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { + const DType gradOut = gradOutput.get_ref(batch, plane, x); + gradInput.get_ref(batch, plane, x) = ScalarConvert::to( + gradOut * gradScale); + } + } + } + + if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_GAMMA_FLAG) != 0) { + if ((flags & FIX_GAMMA_FLAG) == 0) { + tensors.gradWeight[plane] = ScalarConvert::to(dotP * invstd); + } else { + tensors.gradWeight[plane] = DType(0); + } + } + + if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) { + tensors.gradBias[plane] = ScalarConvert::to(gradOutputSum); + } +} + template static __global__ void BatchNormalizationBackwardKernel( const DeviceTensor input, @@ -392,17 +461,9 @@ static __global__ void BatchNormalizationBackwardKernel( int plane = blockIdx.x; int N = gradOutput.OuterSize() * gradOutput.InnerSize(); - const bool is_train_and_not_global_stats = - (flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0; - AccReal mean, invstd; - if (is_train_and_not_global_stats) { - mean = ScalarConvert::to(tensors.saveMean[plane]); - invstd = tensors.saveInvStd[plane]; - } else { - mean = ScalarConvert::to(tensors.runningMean[plane]); - invstd = variance_to_invstd(tensors.runningVar[plane], eps); - } + mean = ScalarConvert::to(tensors.saveMean[plane]); + invstd = tensors.saveInvStd[plane]; const AccReal weightVal = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0) ? 
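
Because the frozen path treats mean and variance as constants, the input gradient needs none of the projection terms used in training mode; only the straight scale by gamma * invstd remains, while dgamma and dbeta are plain per-channel reductions. A CPU reference sketch of that math for a single channel follows (illustrative helper, not MXNet code).

#include <cmath>
#include <vector>

void frozen_bn_backward_ref(const std::vector<float>& x,
                            const std::vector<float>& dy,
                            float mean, float var, float gamma, float eps,
                            std::vector<float>* dx, float* dgamma, float* dbeta) {
  const float invstd = 1.0f / std::sqrt(var + eps);
  float sum_dy = 0.0f, dot = 0.0f;
  for (size_t i = 0; i < x.size(); ++i) {
    sum_dy += dy[i];
    dot += dy[i] * (x[i] - mean);        // dotP in the kernel above
    (*dx)[i] = dy[i] * gamma * invstd;   // gradScale = invstd * weightVal
  }
  *dgamma = dot * invstd;                // stored when WRITE_GAMMA_FLAG is set
  *dbeta = sum_dy;                       // stored when WRITE_BETA_FLAG is set
}
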
ScalarConvert::to(tensors.weight[plane]) : AccReal(1); @@ -421,7 +482,7 @@ static __global__ void BatchNormalizationBackwardKernel( const AccReal projScale = dotP * norm * invstd * invstd; const AccReal gradScale = invstd * weightVal; - if (threadIdx.x == 0 && is_train_and_not_global_stats) { + if (threadIdx.x == 0) { const AccReal localVariance = invstd_to_variance(tensors.saveInvStd[plane], eps); const AccReal localMean = tensors.saveMean[plane]; @@ -436,15 +497,10 @@ static __global__ void BatchNormalizationBackwardKernel( for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { const DType gradOut = gradOutput.get_ref(batch, plane, x); - if (is_train_and_not_global_stats) { - const DType inp = input.get_ref(batch, plane, x); - const AccReal proj = (inp - mean) * projScale; - gradInput.get_ref(batch, plane, x) = - ScalarConvert::to((gradOut - proj - gradMean) * gradScale); - } else { - gradInput.get_ref(batch, plane, x) = ScalarConvert::to( - gradOut * gradScale); - } + const DType inp = input.get_ref(batch, plane, x); + const AccReal proj = (inp - mean) * projScale; + gradInput.get_ref(batch, plane, x) = + ScalarConvert::to((gradOut - proj - gradMean) * gradScale); } } } @@ -651,16 +707,33 @@ static void BatchNormalizationBackward(mshadow::Stream *s, tensors.saveInvStd = devicetensor(out_data[batchnorm::kVar]); DCHECK_GT(tensors.weight.numElements(), 0); + const bool is_train_and_not_global_stats = + (flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0; + + if (is_train_and_not_global_stats) { +#ifdef NDEBUG + constexpr bool SMALLER_THREADS = false; +#else + constexpr bool SMALLER_THREADS = true; +#endif + dim3 blocks(gradOutput.ChannelCount()); + dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize(), SMALLER_THREADS)); + BatchNormalizationBackwardKernel> + <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( + input, gradOutput, gradInput, tensors, flags, momentum, eps); + } else { #ifdef NDEBUG - constexpr bool SMALLER_THREADS = false; + constexpr bool SMALLER_THREADS = false; #else - constexpr bool SMALLER_THREADS = true; + constexpr bool SMALLER_THREADS = true; #endif - dim3 blocks(gradOutput.ChannelCount()); - dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize(), SMALLER_THREADS)); - BatchNormalizationBackwardKernel> - <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( - input, gradOutput, gradInput, tensors, flags, momentum, eps); + dim3 blocks(gradOutput.ChannelCount()); + dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize(), SMALLER_THREADS)); + FrozenBatchNormalizationBackwardKernel> + <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( + input, gradOutput, gradInput, tensors, flags, momentum, eps); + } MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormalizationBackward); } From 94adcf3acbde5cd7b14ab2a58997f53387a09a40 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 18 Dec 2019 09:21:24 -0800 Subject: [PATCH 09/19] More work on backward of Frozen BN --- src/operator/nn/batch_norm.cu | 54 +++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 4fd68a5953d8..8ddd31a90a07 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -385,18 +385,66 @@ static __global__ void FrozenBatchNormalizationBackwardKernel( const DType* input, const DType* gradOutput, DType* 
gradInput, - CUDATensors tensors, + AType* gradWeight, + AType* gradBias, + const AType* weight, + const AType* runningMean, + const AType* runningVar, const uint32_t flags, const AccReal momentum, const AccReal eps, const int splitk, const index_t num_channels, - const int channels_per_block, + const int loads_per_block, const index_t outer_dim, const index_t inner_dim, const index_t NHW) { - const index_t start_channel = (blockIdx.x / splitk) * channels_per_block; + constexpr int nvec = sizeof(LType) >= sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1; + const int num_channels_per_thread = inner_dim < nvec ? nvec / inner_dim : 1; + const int num_NHW_per_thread = nvec / num_channels_per_thread; + + const int channels_per_block = loads_per_block * num_channels_per_thread; + const index_t start_channel = (blockIdx.x / splitk) * channels_per_block + (threadIdx.x % loads_per_block) * num_channels_per_thread; + + typedef union { + LType aligned; + DType separate[nvec]; // NOLINT(*) + __device__ inline scratch() {} + __device__ inline ~scratch() {} + } scratch; + + scratch temp_input, temp_grad; + + AType mean[nvec]; + AType grad_sum[nvec]; + AType dotP[nvec]; + +#pragma unroll + for (int v = 0; v < nvec; ++v) { + mean[v] = runningMean[start_channel + v % num_channels_per_thread]; + } + + const LType* aligned_input = reinterpret_cast(input); + const LType* aligned_gradOutput = reinterpret_cast(gradOutput); + + for (int i = threadIdx.x / loads_per_block; i < NHW / splitk; i += blockDim.x / loads_per_block) { + const index_t idx = (i / inner_dim) * (num_channels * inner_dim) + + inner_dim * start_channel + + (i % inner_dim); + const index_t aligned_idx = idx / nvec; + temp_input.aligned = aligned_input[aligned_idx]; + temp_grad.aligned = aligned_gradOutput[aligned_idx]; +#pragma unroll + for (int v = 0; v < nvec; ++v) { + const AType g = static_cast(temp_grad.separate[v]); + const AType inp = static_cast(temp_input[v]); + grad_sum[v] += g; + dotP[v] += (inp - mean[v]) * g; + } + } + + From 0305c12c4ed3617b1386fb005fcefc79a210d56d Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 7 Jan 2020 16:25:09 -0800 Subject: [PATCH 10/19] Let it compile --- src/operator/nn/batch_norm.cu | 123 +++++++++++++++++++++------------- 1 file changed, 77 insertions(+), 46 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 8ddd31a90a07..89cdb5ea9675 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -381,10 +381,10 @@ struct CUDATensors { }; template -static __global__ void FrozenBatchNormalizationBackwardKernel( +static __global__ void FrozenBatchNormalizationBackwardKernelCLast( const DType* input, const DType* gradOutput, - DType* gradInput, + DType* gradInput/*, AType* gradWeight, AType* gradBias, const AType* weight, @@ -398,58 +398,76 @@ static __global__ void FrozenBatchNormalizationBackwardKernel( const int loads_per_block, const index_t outer_dim, const index_t inner_dim, - const index_t NHW) { - - constexpr int nvec = sizeof(LType) >= sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1; - const int num_channels_per_thread = inner_dim < nvec ? 
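
The union of an LType and an array of DType is the vectorized-load idiom used throughout these kernels: several narrow elements are moved with one wide, aligned load/store and then processed element by element. A stand-alone sketch of the pattern follows; the kernel is illustrative and assumes the input pointers are aligned to sizeof(LType).

#include <cuda_runtime.h>

template <typename DType, typename LType>
__global__ void scale_vectorized(const DType* in, DType* out, DType alpha, size_t n_vec) {
  constexpr int nvec = sizeof(LType) / sizeof(DType);
  union {
    LType aligned;
    DType separate[nvec];  // NOLINT(*)
  } buf;
  const LType* in_aligned = reinterpret_cast<const LType*>(in);
  LType* out_aligned = reinterpret_cast<LType*>(out);
  const size_t stride = blockDim.x * gridDim.x;
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n_vec; i += stride) {
    buf.aligned = in_aligned[i];     // one wide load brings in nvec elements
    #pragma unroll
    for (int j = 0; j < nvec; ++j) {
      buf.separate[j] *= alpha;      // elementwise work on the narrow type
    }
    out_aligned[i] = buf.aligned;    // one wide store writes them back
  }
}

// Example launch (n divisible by 4): scale_vectorized<float, float4><<<128, 256>>>(in, out, 2.f, n / 4);
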
nvec / inner_dim : 1; - const int num_NHW_per_thread = nvec / num_channels_per_thread; + const index_t NHW*/) { +#if 0 + int plane = blockIdx.x; + int N = gradOutput.OuterSize() * gradOutput.InnerSize(); - const int channels_per_block = loads_per_block * num_channels_per_thread; - const index_t start_channel = (blockIdx.x / splitk) * channels_per_block + (threadIdx.x % loads_per_block) * num_channels_per_thread; + AccReal mean, invstd; + mean = ScalarConvert::to(tensors.runningMean[plane]); + invstd = variance_to_invstd(tensors.runningVar[plane], eps); - typedef union { - LType aligned; - DType separate[nvec]; // NOLINT(*) - __device__ inline scratch() {} - __device__ inline ~scratch() {} - } scratch; + const AccReal weightVal = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0) ? + ScalarConvert::to(tensors.weight[plane]) : AccReal(1); + const AccReal norm = AccReal(1) / N; - scratch temp_input, temp_grad; + // Compute two values across (batch, x/y/z) in one pass: + // 1. Sum(gradOutput) + // 2. DotProduct(input - mean, gradOutput) + GradOp g(mean, input, gradOutput); + Float2< DType, AccReal > res = reduce < Float2 < DType, AccReal >, + GradOp< DType, AccReal, DeviceTensor >, DeviceTensor > (g, gradOutput, plane); + const AccReal gradOutputSum = res.v1; + const AccReal dotP = res.v2; - AType mean[nvec]; - AType grad_sum[nvec]; - AType dotP[nvec]; + const AccReal gradMean = gradOutputSum * norm; + const AccReal projScale = dotP * norm * invstd * invstd; + const AccReal gradScale = invstd * weightVal; -#pragma unroll - for (int v = 0; v < nvec; ++v) { - mean[v] = runningMean[start_channel + v % num_channels_per_thread]; + if (gradInput.Size() > 0 && (flags & WRITE_DATA_FLAG) != 0) { + for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { + for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { + const DType gradOut = gradOutput.get_ref(batch, plane, x); + gradInput.get_ref(batch, plane, x) = ScalarConvert::to( + gradOut * gradScale); + } + } } - const LType* aligned_input = reinterpret_cast(input); - const LType* aligned_gradOutput = reinterpret_cast(gradOutput); - - for (int i = threadIdx.x / loads_per_block; i < NHW / splitk; i += blockDim.x / loads_per_block) { - const index_t idx = (i / inner_dim) * (num_channels * inner_dim) + - inner_dim * start_channel + - (i % inner_dim); - const index_t aligned_idx = idx / nvec; - temp_input.aligned = aligned_input[aligned_idx]; - temp_grad.aligned = aligned_gradOutput[aligned_idx]; -#pragma unroll - for (int v = 0; v < nvec; ++v) { - const AType g = static_cast(temp_grad.separate[v]); - const AType inp = static_cast(temp_input[v]); - grad_sum[v] += g; - dotP[v] += (inp - mean[v]) * g; + if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_GAMMA_FLAG) != 0) { + if ((flags & FIX_GAMMA_FLAG) == 0) { + tensors.gradWeight[plane] = ScalarConvert::to(dotP * invstd); + } else { + tensors.gradWeight[plane] = DType(0); } } + if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) { + tensors.gradBias[plane] = ScalarConvert::to(gradOutputSum); + } +#endif +} - - - - - +template +static __global__ void FrozenBatchNormalizationBackwardKernel( + const DType* input, + const DType* gradOutput, + DType* gradInput/*, + AType* gradWeight, + AType* gradBias, + const AType* weight, + const AType* runningMean, + const AType* runningVar, + const uint32_t flags, + const AccReal momentum, + const AccReal eps, + const int splitk, + const 
index_t num_channels, + const int loads_per_block, + const index_t outer_dim, + const index_t inner_dim, + const index_t NHW*/) { +#if 0 int plane = blockIdx.x; int N = gradOutput.OuterSize() * gradOutput.InnerSize(); @@ -495,6 +513,7 @@ static __global__ void FrozenBatchNormalizationBackwardKernel( if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) { tensors.gradBias[plane] = ScalarConvert::to(gradOutputSum); } +#endif } template @@ -770,6 +789,7 @@ static void BatchNormalizationBackward(mshadow::Stream *s, <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( input, gradOutput, gradInput, tensors, flags, momentum, eps); } else { + if (param.axis == -1 || param.axis == in_data[batchnorm::kData].shape_.ndim() - 1) { #ifdef NDEBUG constexpr bool SMALLER_THREADS = false; #else @@ -777,10 +797,21 @@ static void BatchNormalizationBackward(mshadow::Stream *s, #endif dim3 blocks(gradOutput.ChannelCount()); dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize(), SMALLER_THREADS)); - FrozenBatchNormalizationBackwardKernel> + FrozenBatchNormalizationBackwardKernelCLast <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( - input, gradOutput, gradInput, tensors, flags, momentum, eps); + input.dptr_, gradOutput.dptr_, gradInput.dptr_); + } else { +#ifdef NDEBUG + constexpr bool SMALLER_THREADS = false; +#else + constexpr bool SMALLER_THREADS = true; +#endif + dim3 blocks(gradOutput.ChannelCount()); + dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize(), SMALLER_THREADS)); + FrozenBatchNormalizationBackwardKernel + <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( + input.dptr_, gradOutput.dptr_, gradInput.dptr_); + } } MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormalizationBackward); } From 3ad5514e0e1207b991ce3b44a2f8105fc1db5721 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 8 Jan 2020 12:58:20 -0800 Subject: [PATCH 11/19] NCHW Frozen BN backward --- src/common/cuda_utils.h | 83 +++++++++++++++--- src/operator/nn/batch_norm.cu | 154 +++++++++++++++++++++------------- src/operator/nn/softmax-inl.h | 6 +- 3 files changed, 172 insertions(+), 71 deletions(-) diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h index ccf0931f2480..1c14471f9d2e 100644 --- a/src/common/cuda_utils.h +++ b/src/common/cuda_utils.h @@ -764,27 +764,86 @@ __device__ inline DType ldg(const DType* address) { #endif } -template +namespace mxnet { +namespace common { +/*! \brief common utils for cuda */ +namespace cuda { + +static constexpr const int warp_size = 32; + +/*! \brief Reduction inside a warp. + * Template parameters: + * NVALUES - number of values to reduce (defaults to warp_size). + * \param value - values to be reduced. + * \param redfun - function used to perform reduction. 
+ */ +template __device__ inline T warp_reduce(T value, OP redfun) { - value = redfun(value, __shfl_down_sync(0xffffffff, value, 16)); - value = redfun(value, __shfl_down_sync(0xffffffff, value, 8)); - value = redfun(value, __shfl_down_sync(0xffffffff, value, 4)); - value = redfun(value, __shfl_down_sync(0xffffffff, value, 2)); - value = redfun(value, __shfl_down_sync(0xffffffff, value, 1)); +#pragma unroll + for (int i = warp_size / 2; i >= 1; i /= 2) { + if (NVALUES > i) value = redfun(value, __shfl_down_sync(0xffffffff, value, i)); + } return value; } -template +template __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) { float v = static_cast(value); - v = redfun(v, __shfl_down_sync(0xffffffff, v, 16)); - v = redfun(v, __shfl_down_sync(0xffffffff, v, 8)); - v = redfun(v, __shfl_down_sync(0xffffffff, v, 4)); - v = redfun(v, __shfl_down_sync(0xffffffff, v, 2)); - v = redfun(v, __shfl_down_sync(0xffffffff, v, 1)); +#pragma unroll + for (int i = warp_size / 2; i >= 1; i /= 2) { + if (NValues > i) v = redfun(v, __shfl_down_sync(0xffffffff, v, i)); + } return mshadow::half::half_t(v); } +/*! \brief Reduction inside a block, requires all threads in a block to participate. + * It uses a 2 step approach: + * - all warps in a block perform intermediate reduction + * - first warp reduces the intermediate results. + * Template parameters: + * NTHREADS - number of threads in a block. + * all_reduce - whether all threads need the result of the reduction. If set to + * true, then all threads return with the same value. If set to + * false, then only thread 0 has the valid result. Defaults to true. + * \param value - value from each thread to be reduced + * \param redfun - function used to perform reduction + */ +template +__device__ inline T reduce(const T& value, OP redfun) { + static_assert(NTHREADS <= warp_size * warp_size, + "Number of threads too large for reduction"); + __shared__ T scratch[NTHREADS / warp_size]; + const int my_id = threadIdx.x % warp_size; + const int my_warp = threadIdx.x / warp_size; + const T my_val = warp_reduce(value, redfun); + if (my_id == 0) { + scratch[my_warp] = my_val; + } + __syncthreads(); + T ret = 0; + if (my_warp == 0) { + const T prev_val = threadIdx.x < (NTHREADS / warp_size) ? 
scratch[threadIdx.x] : 0; + const T my_val = warp_reduce(prev_val, redfun); + if (all_reduce) { + scratch[threadIdx.x] = my_val; + } else { + ret = my_val; + } + } + // Necessary to synchronize in order to use this function again + // as the shared memory scratch space is reused between calls + __syncthreads(); + if (all_reduce) { + ret = scratch[0]; + __syncthreads(); + } + return ret; +} + +} // namespace cuda +} // namespace common +} // namespace mxnet + #endif // __CUDACC__ #endif // MXNET_COMMON_CUDA_UTILS_H_ diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 89cdb5ea9675..ac387238ecef 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -27,6 +27,8 @@ #include #include #include "batch_norm-inl.h" +#include "../../common/cuda_utils.h" + #define WRITE_DATA_FLAG 1 #define WRITE_GAMMA_FLAG 2 @@ -448,72 +450,90 @@ static __global__ void FrozenBatchNormalizationBackwardKernelCLast( #endif } -template -static __global__ void FrozenBatchNormalizationBackwardKernel( +template +__global__ void FrozenBatchNormalizationBackwardKernel( const DType* input, const DType* gradOutput, - DType* gradInput/*, + DType* gradInput, AType* gradWeight, AType* gradBias, const AType* weight, const AType* runningMean, const AType* runningVar, - const uint32_t flags, - const AccReal momentum, - const AccReal eps, - const int splitk, + const index_t outer, + const index_t inner, const index_t num_channels, - const int loads_per_block, - const index_t outer_dim, - const index_t inner_dim, - const index_t NHW*/) { -#if 0 - int plane = blockIdx.x; - int N = gradOutput.OuterSize() * gradOutput.InnerSize(); + const index_t NHW_div_nvec, + const AType eps, + const uint32_t flags) { + const index_t my_channel = blockIdx.x; + const AType invstd = variance_to_invstd(runningVar[my_channel], eps); + const AType mean = runningMean[my_channel]; + const AType gamma = weight != nullptr ? weight[my_channel] : 1; + constexpr int nvec = sizeof(LType) > sizeof(DType) ? sizeof(LType) / sizeof(DType) + : 1; + union scratch { + LType aligned; + DType separate[nvec]; // NOLINT(*) - AccReal mean, invstd; - mean = ScalarConvert::to(tensors.runningMean[plane]); - invstd = variance_to_invstd(tensors.runningVar[plane], eps); + __device__ inline scratch() {} + __device__ inline ~scratch() {} + }; - const AccReal weightVal = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0) ? - ScalarConvert::to(tensors.weight[plane]) : AccReal(1); - const AccReal norm = AccReal(1) / N; + scratch scratch_input, scratch_grad; - // Compute two values across (batch, x/y/z) in one pass: - // 1. Sum(gradOutput) - // 2. 
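
The NCHW frozen-backward kernel leans on the new common::cuda::reduce helper: each warp reduces its values with shuffles, then the first warp folds the per-warp partials. A stand-alone sketch of that two-step block reduction follows (written independently for illustration; it is not the cuda_utils.h code itself). Launch sum_kernel with 256 threads per block.

#include <cuda_runtime.h>

constexpr int kWarpSize = 32;

template <typename T>
__device__ T warp_reduce_sum(T val) {
  #pragma unroll
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;
}

template <int NTHREADS, typename T>
__device__ T block_reduce_sum(T val) {
  __shared__ T scratch[NTHREADS / kWarpSize];
  val = warp_reduce_sum(val);                              // step 1: per-warp partial sums
  if (threadIdx.x % kWarpSize == 0) scratch[threadIdx.x / kWarpSize] = val;
  __syncthreads();
  if (threadIdx.x < kWarpSize) {                           // step 2: first warp combines partials
    T v = threadIdx.x < NTHREADS / kWarpSize ? scratch[threadIdx.x] : T(0);
    val = warp_reduce_sum(v);
  }
  return val;                                              // result valid in thread 0
}

__global__ void sum_kernel(const float* in, float* out, int n) {
  float acc = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) acc += in[i];
  acc = block_reduce_sum<256>(acc);
  if (threadIdx.x == 0) *out = acc;
}
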
DotProduct(input - mean, gradOutput) - GradOp g(mean, input, gradOutput); - Float2< DType, AccReal > res = reduce < Float2 < DType, AccReal >, - GradOp< DType, AccReal, DeviceTensor >, DeviceTensor > (g, gradOutput, plane); - const AccReal gradOutputSum = res.v1; - const AccReal dotP = res.v2; + const LType* input_aligned = reinterpret_cast(input); + const LType* gradOutput_aligned = reinterpret_cast(gradOutput); + LType* gradInput_aligned = reinterpret_cast(gradInput); - const AccReal gradMean = gradOutputSum * norm; - const AccReal projScale = dotP * norm * invstd * invstd; - const AccReal gradScale = invstd * weightVal; + const index_t inner_div_nvec = inner / nvec; + + AType sum_gamma = 0; + AType sum_beta = 0; + + + for (index_t i = threadIdx.x; i < NHW_div_nvec; i += blockDim.x) { + const index_t inner_idx = i % inner_div_nvec; + const index_t outer_idx = i / inner_div_nvec; + const index_t idx = inner_idx + + (my_channel + outer_idx * num_channels) * inner_div_nvec; + scratch_grad.aligned = gradOutput_aligned[idx]; + scratch_input.aligned = input_aligned[idx]; +#pragma unroll + for (int j = 0; j < nvec; ++j) { + sum_beta += static_cast(scratch_grad.separate[j]); + sum_gamma += static_cast(scratch_grad.separate[j]) * + (static_cast(scratch_input.separate[j]) - mean); - if (gradInput.Size() > 0 && (flags & WRITE_DATA_FLAG) != 0) { - for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { - for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { - const DType gradOut = gradOutput.get_ref(batch, plane, x); - gradInput.get_ref(batch, plane, x) = ScalarConvert::to( - gradOut * gradScale); - } } - } - if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_GAMMA_FLAG) != 0) { - if ((flags & FIX_GAMMA_FLAG) == 0) { - tensors.gradWeight[plane] = ScalarConvert::to(dotP * invstd); - } else { - tensors.gradWeight[plane] = DType(0); + if (flags & WRITE_DATA_FLAG) { + // Gradient to input +#pragma unroll + for (int j = 0; j < nvec; ++j) { + scratch_grad.separate[j] *= invstd * gamma; + } + gradInput_aligned[idx] = scratch_grad.aligned; } } - if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) { - tensors.gradBias[plane] = ScalarConvert::to(gradOutputSum); + sum_gamma = common::cuda::reduce(sum_gamma, + [](AType a, AType b) { return a + b; }); + sum_beta = common::cuda::reduce(sum_beta, + [](AType a, AType b) { return a + b; }); + + if (threadIdx.x == 0) { + if (flags & WRITE_GAMMA_FLAG) { + if ((flags & FIX_GAMMA_FLAG) == 0) { + gradWeight[my_channel] = sum_gamma * invstd; + } else { + gradWeight[my_channel] = 0; + } + } + if (flags & WRITE_BETA_FLAG) { + gradBias[my_channel] = sum_beta; + } } -#endif } template @@ -789,7 +809,23 @@ static void BatchNormalizationBackward(mshadow::Stream *s, <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( input, gradOutput, gradInput, tensors, flags, momentum, eps); } else { + std::cout <<"Frozen" < 0) + ? 
tensors.weight.dptr_ + : nullptr; + if (param.axis == -1 || param.axis == in_data[batchnorm::kData].shape_.ndim() - 1) { + std::cout << "NHWC " << std::endl; #ifdef NDEBUG constexpr bool SMALLER_THREADS = false; #else @@ -801,16 +837,22 @@ static void BatchNormalizationBackward(mshadow::Stream *s, <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( input.dptr_, gradOutput.dptr_, gradInput.dptr_); } else { -#ifdef NDEBUG - constexpr bool SMALLER_THREADS = false; -#else - constexpr bool SMALLER_THREADS = true; -#endif + std::cout << "NCHW " << std::endl; dim3 blocks(gradOutput.ChannelCount()); - dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize(), SMALLER_THREADS)); - FrozenBatchNormalizationBackwardKernel - <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( - input.dptr_, gradOutput.dptr_, gradInput.dptr_); + int ltype = mxnet::common::cuda::get_load_type(gradOutput.InnerSize() * sizeof(DType)); + MXNET_LOAD_TYPE_SWITCH(ltype, LType, { + constexpr int nvec = sizeof(LType) > sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1; + const index_t NHW_div_nvec = gradOutput.OuterSize() * gradOutput.InnerSize() / nvec; + constexpr int threads = 512; + FrozenBatchNormalizationBackwardKernel + <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( + input.dptr_, gradOutput.dptr_, gradInput.dptr_, + tensors.gradWeight.dptr_, tensors.gradBias.dptr_, + gamma, tensors.runningMean.dptr_, + tensors.runningVar.dptr_, + gradOutput.OuterSize(), gradOutput.InnerSize(), + gradOutput.ChannelCount(), NHW_div_nvec, eps, flags_copy); + }); } } MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormalizationBackward); diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 2dbdbe170537..4c5ec2d9caaf 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -348,7 +348,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp __syncthreads(); } if (my_id < warp_size) { - AType my_value = warp_reduce(scratch[threadIdx.x], + AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x], [](AType x, AType y) { return ::max(x, y); }); scratch[threadIdx.x] = my_value; } @@ -372,7 +372,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp __syncthreads(); } if (my_id < warp_size) { - AType my_value = warp_reduce(scratch[threadIdx.x], + AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x], [](AType x, AType y) { return x + y;}); scratch[threadIdx.x] = my_value; } @@ -485,7 +485,7 @@ __global__ void softmax_stride1_grad_kernel(const OType *out, const OType *ograd __syncthreads(); } if (my_id < warp_size) { - AType my_value = warp_reduce(scratch[threadIdx.x], + AType my_value = common::cuda::warp_reduce(scratch[threadIdx.x], [](AType x, AType y) { return x + y; }); scratch[threadIdx.x] = my_value; } From 45b2efa6020ce9024caa61aff220742e3683a7c8 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 15 Jan 2020 15:06:45 -0800 Subject: [PATCH 12/19] Frozen BN backward NHWC --- src/operator/nn/batch_norm.cu | 197 +++++++++++++++++++++++++++++++--- 1 file changed, 183 insertions(+), 14 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index ac387238ecef..2f3b466ab988 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -382,6 +382,151 @@ struct CUDATensors { DeviceTensor1 saveInvStd; }; +namespace { + inline int ceil_div(int x, int y) { + return (x + y - 1) / y; + } +} // namespace + +template +__global__ void 
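
The dispatch code picks the load type via get_load_type / MXNET_LOAD_TYPE_SWITCH before instantiating the vectorized kernel. The helper below is only an assumption about the idea behind that selection (choose the widest load width that evenly divides a row's byte size so rows can be processed with aligned vector loads); it is not the actual MXNet implementation.

#include <cstddef>
#include <cstdio>

size_t widest_load_bytes(size_t row_bytes) {
  for (size_t width = 8; width > 1; width /= 2) {  // try 8-, 4-, then 2-byte loads
    if (row_bytes % width == 0) return width;
  }
  return 1;
}

int main() {
  // A 70-byte row (35 float16 values) only admits 2-byte loads,
  // while a 128-byte row admits full 8-byte loads.
  std::printf("%zu %zu\n", widest_load_bytes(70), widest_load_bytes(128));
  return 0;
}
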
FrozenBatchNormalizationBackwardKernelCLastPhase1( + const DType* input, const DType* gradOutput, AType* temp_space, + DType* gradInput, const AType* weight, const AType* runningMean, + const AType* runningVar, const index_t outer, const index_t num_channels, + const AType eps, const uint32_t flags) { + using mxnet::common::cuda::warp_size; + constexpr int num_warps = NTHREADS / warp_size; + constexpr int nvec = sizeof(LType) >= sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1; + const size_t stride = num_channels / nvec; + + union vectorized_loader { + LType aligned; + DType separate[nvec]; // NOLINT(*) + + __device__ inline vectorized_loader() {} + __device__ inline ~vectorized_loader() {} + }; + + vectorized_loader vec_input, vec_gradOutput; + + __shared__ AType scratch[NTHREADS * 2 * nvec]; + AType * my_values_gamma = &(scratch[threadIdx.x * nvec]); + AType * my_values_beta = &(scratch[(NTHREADS + threadIdx.x) * nvec]); + + AType sum_gamma[nvec]; // NOLINT(*) + AType sum_beta[nvec]; // NOLINT(*) +#pragma unroll + for (int i = 0; i < nvec; ++i) { + sum_gamma[i] = 0; + sum_beta[i] = 0; + } + + const size_t offset = blockIdx.x * warp_size; + const int my_warp = threadIdx.x / warp_size; + const int my_id = threadIdx.x % warp_size; + + AType invstd[nvec]; + AType mean[nvec]; + AType gamma[nvec]; + size_t channel_offset = (offset + my_id) * nvec; + + if (channel_offset < num_channels) { +#pragma unroll + for (int i = 0; i < nvec; ++i) { + invstd[i] = variance_to_invstd(runningVar[channel_offset + i], eps); + mean[i] = runningMean[channel_offset + i]; + gamma[i] = weight != nullptr ? weight[channel_offset + i] : 1; + } + } + + const LType* aligned_gradOutput = reinterpret_cast(gradOutput); + const LType* aligned_input = reinterpret_cast(input); + LType* gradInput_aligned = reinterpret_cast(gradInput); + + const int rows_per_block = (outer + gridDim.y - 1) / gridDim.y; + const size_t start_row = my_warp + rows_per_block * blockIdx.y; + const size_t end_row = min(outer, static_cast(rows_per_block * (blockIdx.y + 1))); + if (offset + my_id < stride) { + for (size_t i = start_row; i < end_row; i += num_warps) { + const index_t idx = i * stride + offset + my_id; + vec_gradOutput.aligned = aligned_gradOutput[idx]; + vec_input.aligned = aligned_input[idx]; +#pragma unroll + for (int j = 0; j < nvec; ++j) { + sum_beta[j] += static_cast(vec_gradOutput.separate[j]); + sum_gamma[j] += static_cast(vec_gradOutput.separate[j]) * + (static_cast(vec_input.separate[j]) - mean[j]); + } + if (flags & WRITE_DATA_FLAG) { + // Gradient to input +#pragma unroll + for (int j = 0; j < nvec; ++j) { + vec_gradOutput.separate[j] *= invstd[j] * gamma[j]; + } + gradInput_aligned[idx] = vec_gradOutput.aligned; + } + } + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + my_values_gamma[i] = sum_gamma[i]; + my_values_beta[i] = sum_beta[i]; + } + + __syncthreads(); + + for (int i = num_warps / 2; i > 0; i /= 2) { + if (my_warp < i) { + const int shared_offset = nvec * i * warp_size; +#pragma unroll + for (int j = 0; j < nvec; ++j) { + my_values_gamma[j] += my_values_gamma[j + shared_offset]; + my_values_beta[j] += my_values_beta[j + shared_offset]; + } + } + __syncthreads(); + } + + if (threadIdx.x < min(warp_size * nvec, + static_cast(num_channels - nvec * offset))) { + const size_t offset_out = nvec * offset + + blockIdx.y * num_channels; + const size_t offset_beta = gridDim.y * num_channels; + temp_space[offset_out + threadIdx.x] = scratch[threadIdx.x]; + temp_space[offset_beta + offset_out + 
threadIdx.x] = scratch[NTHREADS * nvec + threadIdx.x]; + } +} + +template +__global__ void FrozenBatchNormalizationBackwardKernelCLastPhase2(const AType * temp_space, + const AType * runningVar, + AType * out_gamma, + AType * out_beta, + int lead_dim, int n_blocks, + AType epsilon, uint32_t flags) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < lead_dim) { + AType sum_gamma = 0; + AType sum_beta = 0; + for (int i = tid; i < lead_dim * n_blocks; i += lead_dim) { + sum_gamma += temp_space[i]; + sum_beta += temp_space[i + lead_dim * n_blocks]; + } + if (flags & WRITE_GAMMA_FLAG) { + if ((flags & FIX_GAMMA_FLAG) == 0) { + const AType invstd = variance_to_invstd(runningVar[tid], epsilon); + out_gamma[tid] = sum_gamma * invstd; + } else { + out_gamma[tid] = 0; + } + } + if (flags & WRITE_BETA_FLAG) { + out_beta[tid] = sum_beta; + } + } +} + template static __global__ void FrozenBatchNormalizationBackwardKernelCLast( const DType* input, @@ -472,15 +617,15 @@ __global__ void FrozenBatchNormalizationBackwardKernel( const AType gamma = weight != nullptr ? weight[my_channel] : 1; constexpr int nvec = sizeof(LType) > sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1; - union scratch { + union vectorized_loader { LType aligned; DType separate[nvec]; // NOLINT(*) - __device__ inline scratch() {} - __device__ inline ~scratch() {} + __device__ inline vectorized_loader() {} + __device__ inline ~vectorized_loader() {} }; - scratch scratch_input, scratch_grad; + vectorized_loader scratch_input, scratch_grad; const LType* input_aligned = reinterpret_cast(input); const LType* gradOutput_aligned = reinterpret_cast(gradOutput); @@ -810,6 +955,9 @@ static void BatchNormalizationBackward(mshadow::Stream *s, input, gradOutput, gradInput, tensors, flags, momentum, eps); } else { std::cout <<"Frozen" < *s, if (param.axis == -1 || param.axis == in_data[batchnorm::kData].shape_.ndim() - 1) { std::cout << "NHWC " << std::endl; -#ifdef NDEBUG - constexpr bool SMALLER_THREADS = false; -#else - constexpr bool SMALLER_THREADS = true; -#endif - dim3 blocks(gradOutput.ChannelCount()); - dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize(), SMALLER_THREADS)); - FrozenBatchNormalizationBackwardKernelCLast - <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( - input.dptr_, gradOutput.dptr_, gradInput.dptr_); + const int C = gradOutput.ChannelCount(); + int ltype = mxnet::common::cuda::get_load_type(C * sizeof(DType)); + const int M = gradOutput.OuterSize(); + MXNET_LOAD_TYPE_SWITCH(ltype, LType, { + const unsigned int blocks_x = ceil_div(C * sizeof(DType), + mxnet::common::cuda::warp_size * sizeof(LType)); + const unsigned int preferred_number_of_blocks = 2 * + MultiprocessorCount(ctx.run_ctx.ctx.dev_id); + const unsigned int blocks_y = std::max(preferred_number_of_blocks / blocks_x, 1u); + const dim3 n_blocks = {blocks_x, blocks_y, 1}; + auto scratch_space = ctx.requested[batchnorm::kTempSpace] + .get_space_typed(mshadow::Shape1(C * blocks_y * 2), s); + auto stream = mshadow::Stream::GetStream(s); + constexpr int nthreads_phase1 = 512; + constexpr int nthreads_phase2 = 128; + FrozenBatchNormalizationBackwardKernelCLastPhase1 + <<>>(input.dptr_, gradOutput.dptr_, + scratch_space.dptr_, + gradInput.dptr_, + gamma, + tensors.runningMean.dptr_, + tensors.runningVar.dptr_, + M, C, eps, flags_copy); + const int nblocks_phase2 = ceil_div(C, nthreads_phase2); + FrozenBatchNormalizationBackwardKernelCLastPhase2 + <<>>(scratch_space.dptr_, + tensors.runningVar.dptr_, + tensors.gradWeight.dptr_, + 
tensors.gradBias.dptr_, C, + blocks_y, eps, flags_copy); + }); } else { std::cout << "NCHW " << std::endl; dim3 blocks(gradOutput.ChannelCount()); From bd1697172c5b3c8e4512f64e4af6b642796ad183 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 17 Jan 2020 16:30:44 -0800 Subject: [PATCH 13/19] Cleaning --- src/operator/nn/batch_norm.cu | 85 +++-------------------------------- 1 file changed, 6 insertions(+), 79 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 2f3b466ab988..60d84595e7e6 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -254,7 +254,7 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel( __shared__ AType saved_mean[shmem_elements]; __shared__ AType saved_weight[shmem_elements]; __shared__ AType saved_bias[shmem_elements]; - union scratch { + union vectorized_loader { LType aligned; DType separate[nvec]; // NOLINT(*) @@ -425,9 +425,9 @@ __global__ void FrozenBatchNormalizationBackwardKernelCLastPhase1( const int my_warp = threadIdx.x / warp_size; const int my_id = threadIdx.x % warp_size; - AType invstd[nvec]; - AType mean[nvec]; - AType gamma[nvec]; + AType invstd[nvec]; // NOLINT(*) + AType mean[nvec]; // NOLINT(*) + AType gamma[nvec]; // NOLINT(*) size_t channel_offset = (offset + my_id) * nvec; if (channel_offset < num_channels) { @@ -527,74 +527,6 @@ __global__ void FrozenBatchNormalizationBackwardKernelCLastPhase2(const AType * } } -template -static __global__ void FrozenBatchNormalizationBackwardKernelCLast( - const DType* input, - const DType* gradOutput, - DType* gradInput/*, - AType* gradWeight, - AType* gradBias, - const AType* weight, - const AType* runningMean, - const AType* runningVar, - const uint32_t flags, - const AccReal momentum, - const AccReal eps, - const int splitk, - const index_t num_channels, - const int loads_per_block, - const index_t outer_dim, - const index_t inner_dim, - const index_t NHW*/) { -#if 0 - int plane = blockIdx.x; - int N = gradOutput.OuterSize() * gradOutput.InnerSize(); - - AccReal mean, invstd; - mean = ScalarConvert::to(tensors.runningMean[plane]); - invstd = variance_to_invstd(tensors.runningVar[plane], eps); - - const AccReal weightVal = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0) ? - ScalarConvert::to(tensors.weight[plane]) : AccReal(1); - const AccReal norm = AccReal(1) / N; - - // Compute two values across (batch, x/y/z) in one pass: - // 1. Sum(gradOutput) - // 2. 
DotProduct(input - mean, gradOutput) - GradOp g(mean, input, gradOutput); - Float2< DType, AccReal > res = reduce < Float2 < DType, AccReal >, - GradOp< DType, AccReal, DeviceTensor >, DeviceTensor > (g, gradOutput, plane); - const AccReal gradOutputSum = res.v1; - const AccReal dotP = res.v2; - - const AccReal gradMean = gradOutputSum * norm; - const AccReal projScale = dotP * norm * invstd * invstd; - const AccReal gradScale = invstd * weightVal; - - if (gradInput.Size() > 0 && (flags & WRITE_DATA_FLAG) != 0) { - for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { - for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { - const DType gradOut = gradOutput.get_ref(batch, plane, x); - gradInput.get_ref(batch, plane, x) = ScalarConvert::to( - gradOut * gradScale); - } - } - } - - if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_GAMMA_FLAG) != 0) { - if ((flags & FIX_GAMMA_FLAG) == 0) { - tensors.gradWeight[plane] = ScalarConvert::to(dotP * invstd); - } else { - tensors.gradWeight[plane] = DType(0); - } - } - - if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) { - tensors.gradBias[plane] = ScalarConvert::to(gradOutputSum); - } -#endif -} - template __global__ void FrozenBatchNormalizationBackwardKernel( const DType* input, @@ -954,10 +886,6 @@ static void BatchNormalizationBackward(mshadow::Stream *s, <<< blocks, threads, 0, mshadow::Stream::GetStream(s) >>> ( input, gradOutput, gradInput, tensors, flags, momentum, eps); } else { - std::cout <<"Frozen" < *s, : nullptr; if (param.axis == -1 || param.axis == in_data[batchnorm::kData].shape_.ndim() - 1) { - std::cout << "NHWC " << std::endl; const int C = gradOutput.ChannelCount(); int ltype = mxnet::common::cuda::get_load_type(C * sizeof(DType)); const int M = gradOutput.OuterSize(); @@ -985,7 +912,8 @@ static void BatchNormalizationBackward(mshadow::Stream *s, const unsigned int blocks_y = std::max(preferred_number_of_blocks / blocks_x, 1u); const dim3 n_blocks = {blocks_x, blocks_y, 1}; auto scratch_space = ctx.requested[batchnorm::kTempSpace] - .get_space_typed(mshadow::Shape1(C * blocks_y * 2), s); + .get_space_typed(mshadow::Shape1(C * blocks_y * 2), + s); auto stream = mshadow::Stream::GetStream(s); constexpr int nthreads_phase1 = 512; constexpr int nthreads_phase2 = 128; @@ -1006,7 +934,6 @@ static void BatchNormalizationBackward(mshadow::Stream *s, blocks_y, eps, flags_copy); }); } else { - std::cout << "NCHW " << std::endl; dim3 blocks(gradOutput.ChannelCount()); int ltype = mxnet::common::cuda::get_load_type(gradOutput.InnerSize() * sizeof(DType)); MXNET_LOAD_TYPE_SWITCH(ltype, LType, { From 45642bcb4329d5b42e686d6d03d12099794c3d7c Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 17 Jan 2020 16:39:03 -0800 Subject: [PATCH 14/19] Remove the change to Makefile --- Makefile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Makefile b/Makefile index 6e49f4d01f3d..16d7d2393736 100644 --- a/Makefile +++ b/Makefile @@ -135,8 +135,6 @@ else NVCCFLAGS += -std=c++11 -Xcompiler -D_FORCE_INLINES -O3 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS) endif -NVCCFLAGS += --ptxas-options=-v - # CFLAGS for segfault logger ifeq ($(USE_SIGNAL_HANDLER), 1) CFLAGS += -DMXNET_USE_SIGNAL_HANDLER=1 From abc88c2cfe606f86955d46fa1f2ed2f0b3384ed4 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 17 Jan 2020 17:46:27 -0800 Subject: [PATCH 15/19] Fix from rebase --- src/operator/nn/batch_norm.cu | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 60d84595e7e6..274593e13ee9 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -258,8 +258,8 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel( LType aligned; DType separate[nvec]; // NOLINT(*) - __device__ inline scratch() {} - __device__ inline ~scratch() {} + __device__ inline vectorized_loader() {} + __device__ inline ~vectorized_loader() {} } scratch; if (small_num_channels) { From 2fe6c68aa0eb7f8573fdef89bc70c959f3ee904e Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 21 Jan 2020 09:59:05 -0800 Subject: [PATCH 16/19] Temp space for BN backward --- src/operator/nn/batch_norm.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index ea1c76965a9b..4d1e3859ba6d 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -597,11 +597,9 @@ NNVM_REGISTER_OP(_backward_BatchNorm) .set_num_outputs(3) .set_attr("TIsBackward", true) .set_attr("FInferStorageType", BatchNormStorageType) -#if MXNET_USE_MKLDNN == 1 .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) -#endif .set_attr_parser(ParamParser) #if MXNET_USE_MKLDNN == 1 .set_attr("TIsMKLDNN", true) From a11f5fe20d23fa0b503d7878359b8d7356f5cd55 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 21 Jan 2020 10:00:41 -0800 Subject: [PATCH 17/19] Fix from review --- src/operator/nn/batch_norm.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 274593e13ee9..0df4fb135d68 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -61,7 +61,7 @@ MSHADOW_XINLINE double variance_to_invstd(double var, double eps) { template MSHADOW_XINLINE AccReal invstd_to_variance(AccReal invstd, AccReal eps) { - return 1.0f / (invstd * invstd) - eps; + return static_cast(1.0) / (invstd * invstd) - eps; } template <> From 983b7484f4213a4735412aa8a3be615bf0f02d3f Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 21 Jan 2020 10:01:28 -0800 Subject: [PATCH 18/19] Fix lint --- src/operator/nn/batch_norm.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 0df4fb135d68..97e08cfa03b9 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -581,7 +581,6 @@ __global__ void FrozenBatchNormalizationBackwardKernel( sum_beta += static_cast(scratch_grad.separate[j]); sum_gamma += static_cast(scratch_grad.separate[j]) * (static_cast(scratch_input.separate[j]) - mean); - } if (flags & WRITE_DATA_FLAG) { From e101b7ca47d74b44217d8e6db1efc7e48f6cb6be Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 3 Apr 2020 13:36:22 -0700 Subject: [PATCH 19/19] Changes from review --- src/common/cuda_utils.h | 10 +++++----- src/operator/nn/batch_norm.cu | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h index c5a960ecf09c..22ac42c6c67b 100644 --- a/src/common/cuda_utils.h +++ b/src/common/cuda_utils.h @@ -832,15 +832,15 @@ __device__ inline T reduce(const T& value, OP redfun) { static_assert(NTHREADS <= warp_size * warp_size, "Number of threads too large for reduction"); __shared__ T scratch[NTHREADS / warp_size]; - const int my_id = threadIdx.x % warp_size; - const int my_warp = threadIdx.x / warp_size; + 
const int thread_idx_in_warp = threadIdx.x % warp_size; + const int warp_id = threadIdx.x / warp_size; const T my_val = warp_reduce(value, redfun); - if (my_id == 0) { - scratch[my_warp] = my_val; + if (thread_idx_in_warp == 0) { + scratch[warp_id] = my_val; } __syncthreads(); T ret = 0; - if (my_warp == 0) { + if (warp_id == 0) { const T prev_val = threadIdx.x < (NTHREADS / warp_size) ? scratch[threadIdx.x] : 0; const T my_val = warp_reduce(prev_val, redfun); if (all_reduce) { diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 97e08cfa03b9..7f222a1e3f9c 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -423,12 +423,12 @@ __global__ void FrozenBatchNormalizationBackwardKernelCLastPhase1( const size_t offset = blockIdx.x * warp_size; const int my_warp = threadIdx.x / warp_size; - const int my_id = threadIdx.x % warp_size; + const int thread_idx_in_warp = threadIdx.x % warp_size; AType invstd[nvec]; // NOLINT(*) AType mean[nvec]; // NOLINT(*) AType gamma[nvec]; // NOLINT(*) - size_t channel_offset = (offset + my_id) * nvec; + size_t channel_offset = (offset + thread_idx_in_warp) * nvec; if (channel_offset < num_channels) { #pragma unroll @@ -446,9 +446,9 @@ __global__ void FrozenBatchNormalizationBackwardKernelCLastPhase1( const int rows_per_block = (outer + gridDim.y - 1) / gridDim.y; const size_t start_row = my_warp + rows_per_block * blockIdx.y; const size_t end_row = min(outer, static_cast(rows_per_block * (blockIdx.y + 1))); - if (offset + my_id < stride) { + if (offset + thread_idx_in_warp < stride) { for (size_t i = start_row; i < end_row; i += num_warps) { - const index_t idx = i * stride + offset + my_id; + const index_t idx = i * stride + offset + thread_idx_in_warp; vec_gradOutput.aligned = aligned_gradOutput[idx]; vec_input.aligned = aligned_input[idx]; #pragma unroll
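
The channels-last frozen-BN backward added in patch 12 and cleaned up in patch 13 is a two-pass split reduction: in phase 1 each (channel-tile, row-slice) block accumulates partial gamma/beta gradient sums into a temporary buffer while optionally writing the data gradient, and in phase 2 a second kernel folds the per-block partials and applies the inverse standard deviation once per channel. The sketch below illustrates only that pattern; the kernel names (frozen_bn_bwd_phase1/phase2), the float-only types, the scalar loads and the launch shape in main are assumptions made for the example, and it deliberately omits the patch's vectorized LType loads, the warp-level channel layout, and the WRITE_DATA/WRITE_GAMMA/WRITE_BETA/FIX_GAMMA flag handling.

#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Phase 1: each y-block reduces its slice of rows for one tile of channels.
// Layout of "partial": [0, n_blocks*C) holds dgamma partials,
// [n_blocks*C, 2*n_blocks*C) holds dbeta partials (mirrors temp_space in the patch).
__global__ void frozen_bn_bwd_phase1(const float* x, const float* dy, float* dx,
                                     const float* gamma, const float* mean,
                                     const float* var, float* partial,
                                     int rows, int C, float eps) {
  const int c = blockIdx.x * blockDim.x + threadIdx.x;   // channel handled by this thread
  if (c >= C) return;
  const int rows_per_block = (rows + gridDim.y - 1) / gridDim.y;
  const int row_begin = blockIdx.y * rows_per_block;
  const int row_end = min(rows, row_begin + rows_per_block);
  const float invstd = rsqrtf(var[c] + eps);
  float dgamma = 0.f, dbeta = 0.f;
  for (int r = row_begin; r < row_end; ++r) {
    const float g = dy[r * C + c];
    dbeta += g;
    dgamma += g * (x[r * C + c] - mean[c]);
    dx[r * C + c] = g * gamma[c] * invstd;                // frozen stats: no mean/var terms
  }
  partial[blockIdx.y * C + c] = dgamma;
  partial[(gridDim.y + blockIdx.y) * C + c] = dbeta;
}

// Phase 2: fold the per-block partials and scale dgamma by invstd once per channel.
__global__ void frozen_bn_bwd_phase2(const float* partial, const float* var,
                                     float* dgamma, float* dbeta,
                                     int C, int n_blocks, float eps) {
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c >= C) return;
  float sg = 0.f, sb = 0.f;
  for (int b = 0; b < n_blocks; ++b) {
    sg += partial[b * C + c];
    sb += partial[(n_blocks + b) * C + c];
  }
  dgamma[c] = sg * rsqrtf(var[c] + eps);
  dbeta[c] = sb;
}

int main() {
  const int rows = 1024, C = 64, y_blocks = 8;
  std::vector<float> hx(rows * C, 1.f), hdy(rows * C, 0.5f);
  std::vector<float> hmean(C, 0.f), hvar(C, 1.f), hgamma(C, 1.f);
  float *x, *dy, *dx, *mean, *var, *gamma, *partial, *dgamma, *dbeta;
  cudaMalloc(&x, rows * C * sizeof(float));   cudaMalloc(&dy, rows * C * sizeof(float));
  cudaMalloc(&dx, rows * C * sizeof(float));  cudaMalloc(&mean, C * sizeof(float));
  cudaMalloc(&var, C * sizeof(float));        cudaMalloc(&gamma, C * sizeof(float));
  cudaMalloc(&partial, 2 * y_blocks * C * sizeof(float));
  cudaMalloc(&dgamma, C * sizeof(float));     cudaMalloc(&dbeta, C * sizeof(float));
  cudaMemcpy(x, hx.data(), rows * C * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dy, hdy.data(), rows * C * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(mean, hmean.data(), C * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(var, hvar.data(), C * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(gamma, hgamma.data(), C * sizeof(float), cudaMemcpyHostToDevice);
  dim3 grid1((C + 127) / 128, y_blocks);
  frozen_bn_bwd_phase1<<<grid1, 128>>>(x, dy, dx, gamma, mean, var, partial, rows, C, 1e-5f);
  frozen_bn_bwd_phase2<<<(C + 127) / 128, 128>>>(partial, var, dgamma, dbeta, C, y_blocks, 1e-5f);
  float hdbeta0;
  cudaMemcpy(&hdbeta0, dbeta, sizeof(float), cudaMemcpyDeviceToHost);
  printf("dbeta[0] = %f (expect %f)\n", hdbeta0, 0.5f * rows);
  return 0;
}

Splitting the reduction into two launches avoids grid-wide atomics or synchronization on the gamma/beta gradients; it is also why the channels-last path requests a scratch buffer of C * blocks_y * 2 accumulator elements, and why patch 16 makes _backward_BatchNorm request kTempSpace unconditionally rather than only under MKLDNN.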