
Faster GPU frozen BatchNorm (#17368)
* Better frozen batchnorm

* Continue FreezeBN

* Optimizations

* Reduce number of mod operations

* Cleaning

* Fixing frozen bn with fix_gamma=False

* Fix lint in BN

* Backward frozen batchnorm

* More work on backward of Frozen BN

* Let it compile

* NCHW Frozen BN backward

* Frozen BN backward NHWC

* Cleaning

* Remove the change to Makefile

* Fix from rebase

* Temp space for BN backward

* Fix from review

* Fix lint

* Changes from review
ptrendx committed Aug 18, 2020
1 parent 2610c10 commit e06ee4e
Showing 4 changed files with 561 additions and 93 deletions.
83 changes: 71 additions & 12 deletions src/common/cuda_utils.h
@@ -783,27 +783,86 @@ __device__ inline DType ldg(const DType* address) {
 #endif
 }
 
-template <typename OP, typename T>
+namespace mxnet {
+namespace common {
+/*! \brief common utils for cuda */
+namespace cuda {
+
+static constexpr const int warp_size = 32;
+
+/*! \brief Reduction inside a warp.
+ * Template parameters:
+ *   NVALUES - number of values to reduce (defaults to warp_size).
+ * \param value - value to be reduced.
+ * \param redfun - function used to perform the reduction.
+ */
+template <int NVALUES = warp_size, typename OP, typename T>
 __device__ inline T warp_reduce(T value, OP redfun) {
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
-  value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
+#pragma unroll
+  for (int i = warp_size / 2; i >= 1; i /= 2) {
+    if (NVALUES > i) value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
+  }
   return value;
 }
 
-template <typename OP>
+template <int NValues = warp_size, typename OP>
 __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
   float v = static_cast<float>(value);
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
-  v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
+#pragma unroll
+  for (int i = warp_size / 2; i >= 1; i /= 2) {
+    if (NValues > i) v = redfun(v, __shfl_down_sync(0xffffffff, v, i));
+  }
   return mshadow::half::half_t(v);
 }
 
+/*! \brief Reduction inside a block; requires all threads in the block to participate.
+ * It uses a 2-step approach:
+ *  - all warps in the block perform an intermediate reduction,
+ *  - the first warp reduces the intermediate results.
+ * Template parameters:
+ *   NTHREADS - number of threads in a block.
+ *   all_reduce - whether all threads need the result of the reduction. If set to
+ *                true, then all threads return the same value. If set to
+ *                false, then only thread 0 has the valid result. Defaults to true.
+ * \param value - value from each thread to be reduced.
+ * \param redfun - function used to perform the reduction.
+ */
+template <int NTHREADS, bool all_reduce = true, typename OP, typename T>
+__device__ inline T reduce(const T& value, OP redfun) {
+  static_assert(NTHREADS <= warp_size * warp_size,
+                "Number of threads too large for reduction");
+  __shared__ T scratch[NTHREADS / warp_size];
+  const int thread_idx_in_warp = threadIdx.x % warp_size;
+  const int warp_id = threadIdx.x / warp_size;
+  const T my_val = warp_reduce<warp_size>(value, redfun);
+  if (thread_idx_in_warp == 0) {
+    scratch[warp_id] = my_val;
+  }
+  __syncthreads();
+  T ret = 0;
+  if (warp_id == 0) {
+    const T prev_val = threadIdx.x < (NTHREADS / warp_size) ? scratch[threadIdx.x] : 0;
+    const T my_val = warp_reduce<NTHREADS / warp_size>(prev_val, redfun);
+    if (all_reduce) {
+      scratch[threadIdx.x] = my_val;
+    } else {
+      ret = my_val;
+    }
+  }
+  // Necessary to synchronize in order to use this function again,
+  // as the shared memory scratch space is reused between calls.
+  __syncthreads();
+  if (all_reduce) {
+    ret = scratch[0];
+    __syncthreads();
+  }
+  return ret;
+}
+
+}  // namespace cuda
+}  // namespace common
+}  // namespace mxnet
+
 #endif  // __CUDACC__
 
 #endif  // MXNET_COMMON_CUDA_UTILS_H_
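
The pattern behind the two new helpers: each thread folds its value inside its warp with __shfl_down_sync, lane 0 of every warp parks the partial result in shared memory, and the first warp then reduces those partials. Below is a minimal standalone sketch of the same technique; the names Sum, my_warp_reduce, my_block_reduce, and block_sum are illustrative and not part of this commit.

#include <cstdio>
#include <cuda_runtime.h>

constexpr int warp_size = 32;

// Reduction operator passed as `redfun`.
struct Sum {
  __device__ float operator()(float a, float b) const { return a + b; }
};

// Warp-level reduction via shuffles, mirroring warp_reduce above.
template <typename OP, typename T>
__device__ T my_warp_reduce(T value, OP redfun) {
#pragma unroll
  for (int i = warp_size / 2; i >= 1; i /= 2) {
    value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
  }
  return value;
}

// Block-level reduction: per-warp partials go through shared memory,
// then the first warp reduces them. Only thread 0 holds the result.
template <int NTHREADS, typename OP, typename T>
__device__ T my_block_reduce(T value, OP redfun) {
  __shared__ T scratch[NTHREADS / warp_size];
  const int lane = threadIdx.x % warp_size;
  const int warp_id = threadIdx.x / warp_size;
  T v = my_warp_reduce(value, redfun);
  if (lane == 0) scratch[warp_id] = v;
  __syncthreads();
  if (warp_id == 0) {
    v = threadIdx.x < NTHREADS / warp_size ? scratch[threadIdx.x] : T(0);
    v = my_warp_reduce(v, redfun);
  }
  __syncthreads();
  return v;
}

// One block of NTHREADS threads sums n floats into *out.
template <int NTHREADS>
__global__ void block_sum(const float* data, int n, float* out) {
  float acc = 0.f;
  for (int i = threadIdx.x; i < n; i += NTHREADS) acc += data[i];
  acc = my_block_reduce<NTHREADS>(acc, Sum());
  if (threadIdx.x == 0) *out = acc;
}

int main() {
  constexpr int n = 1000;
  float host[n];
  for (int i = 0; i < n; ++i) host[i] = 1.f;
  float *d_data, *d_out;
  cudaMalloc(&d_data, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_data, host, n * sizeof(float), cudaMemcpyHostToDevice);
  block_sum<256><<<1, 256>>>(d_data, n, d_out);
  float result = 0.f;
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f (expected %d)\n", result, n);
  cudaFree(d_data);
  cudaFree(d_out);
  return 0;
}

The commit's reduce adds the all_reduce broadcast step (writing the final value back to shared memory so every thread can read it) on top of this same two-step pattern; the half_t overload converts to float and reduces there.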
2 changes: 0 additions & 2 deletions src/operator/nn/batch_norm.cc
@@ -662,11 +662,9 @@ NNVM_REGISTER_OP(_backward_BatchNorm)
   })
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
-#if MXNET_USE_MKLDNN == 1
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
-#endif
 .set_attr_parser(ParamParser<BatchNormParam>)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<bool>("TIsMKLDNN", true)
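
The two deleted preprocessor lines make the kTempSpace request unconditional rather than MKLDNN-only; presumably this is what the "Temp space for BN backward" change above relies on, since the new GPU backward kernels also need scratch memory. As a hedged sketch of how an operator typically consumes such a request (the function name FrozenBNBackwardImpl and the element count are hypothetical, not from this commit):

#include <mxnet/op_attr_types.h>  // OpContext, Resource

// Illustrative only: obtaining the workspace granted by
// ResourceRequest::kTempSpace inside an operator implementation.
template <typename DType>
void FrozenBNBackwardImpl(const mxnet::OpContext& ctx, size_t scratch_elems) {
  mshadow::Stream<mshadow::gpu>* s = ctx.get_stream<mshadow::gpu>();
  // requested[0] corresponds to the single kTempSpace entry returned by
  // the FResourceRequest lambda registered above.
  mshadow::Tensor<mshadow::gpu, 1, DType> workspace =
      ctx.requested[0].get_space_typed<mshadow::gpu, 1, DType>(
          mshadow::Shape1(scratch_elems), s);
  // ... launch CUDA kernels that use workspace.dptr_ as scratch ...
}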