Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Faster GPU frozen BatchNorm #17368

Merged
merged 24 commits into from
Aug 18, 2020
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 71 additions & 12 deletions src/common/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -764,27 +764,86 @@ __device__ inline DType ldg(const DType* address) {
#endif
}

template <typename OP, typename T>
namespace mxnet {
namespace common {
/*! \brief common utils for cuda */
namespace cuda {

static constexpr const int warp_size = 32;

/*! \brief Reduction inside a warp.
* Template parameters:
* NVALUES - number of values to reduce (defaults to warp_size).
* \param value - values to be reduced.
* \param redfun - function used to perform reduction.
*/
template <int NVALUES = warp_size, typename OP, typename T>
__device__ inline T warp_reduce(T value, OP redfun) {
value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
#pragma unroll
for (int i = warp_size / 2; i >= 1; i /= 2) {
if (NVALUES > i) value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
}
return value;
}

template <typename OP>
template <int NValues = warp_size, typename OP>
__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
float v = static_cast<float>(value);
v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
#pragma unroll
for (int i = warp_size / 2; i >= 1; i /= 2) {
if (NValues > i) v = redfun(v, __shfl_down_sync(0xffffffff, v, i));
}
return mshadow::half::half_t(v);
}

/*! \brief Reduction inside a block, requires all threads in a block to participate.
* It uses a 2 step approach:
* - all warps in a block perform intermediate reduction
* - first warp reduces the intermediate results.
* Template parameters:
* NTHREADS - number of threads in a block.
* all_reduce - whether all threads need the result of the reduction. If set to
* true, then all threads return with the same value. If set to
* false, then only thread 0 has the valid result. Defaults to true.
* \param value - value from each thread to be reduced
* \param redfun - function used to perform reduction
*/
template <int NTHREADS, bool all_reduce = true, typename OP, typename T>
__device__ inline T reduce(const T& value, OP redfun) {
static_assert(NTHREADS <= warp_size * warp_size,
"Number of threads too large for reduction");
__shared__ T scratch[NTHREADS / warp_size];
const int my_id = threadIdx.x % warp_size;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use a more informative names other than my_*?

const int my_warp = threadIdx.x / warp_size;
const T my_val = warp_reduce<warp_size>(value, redfun);
if (my_id == 0) {
scratch[my_warp] = my_val;
}
__syncthreads();
T ret = 0;
if (my_warp == 0) {
const T prev_val = threadIdx.x < (NTHREADS / warp_size) ? scratch[threadIdx.x] : 0;
const T my_val = warp_reduce<NTHREADS / warp_size>(prev_val, redfun);
if (all_reduce) {
scratch[threadIdx.x] = my_val;
} else {
ret = my_val;
}
}
// Necessary to synchronize in order to use this function again
// as the shared memory scratch space is reused between calls
__syncthreads();
if (all_reduce) {
ret = scratch[0];
__syncthreads();
}
return ret;
}

} // namespace cuda
} // namespace common
} // namespace mxnet

#endif // __CUDACC__

#endif // MXNET_COMMON_CUDA_UTILS_H_
Loading