This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

FullyConnected Bias performance improvement on GPU #16039

Merged: 8 commits, Sep 24, 2019
Changes from 1 commit
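
For context: this change replaces hard-coded kernel launch sizes (256, 512, and 128 threads, plus inline rounding arithmetic) with named constants and a small ceil_div helper. Below is a minimal standalone sketch, not part of the diff, of how those constants and the helper combine to size a launch; phase2_blocks() is a hypothetical helper added only for illustration.

// Minimal illustrative sketch (not from the PR): the named thread-count
// constants and the ceil_div helper introduced by this change.
constexpr int nthreads_addbias            = 256;  // block size for add_bias_kernel
constexpr int nthreads_addbiasgrad_phase1 = 512;  // block size for AddBiasGradKernelPhase1
constexpr int nthreads_addbiasgrad_phase2 = 128;  // block size for AddBiasGradKernelPhase2

inline int ceil_div(int x, int y) {
  return (x + y - 1) / y;  // round up so the grid covers every element
}

// Example: grid size for phase 2, one thread per bias element, rounded up to whole blocks.
inline int phase2_blocks(int bias_length) {
  return ceil_div(bias_length, nthreads_addbiasgrad_phase2);
}
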
50 changes: 34 additions & 16 deletions src/operator/nn/fully_connected-inl.h
@@ -87,17 +87,26 @@ void AddBias(Tensor<cpu, 1, DType> bias, Tensor<cpu, 2, DType> data,

#if defined(__CUDACC__)

+namespace {
+constexpr int nthreads_addbias = 256;
+constexpr int nthreads_addbiasgrad_phase1 = 512;
+constexpr int nthreads_addbiasgrad_phase2 = 128;
+
+inline int ceil_div(int x, int y) {
+return (x + y - 1) / y;
+}
+} // namespace

template <typename DType, typename LType>
__global__ void add_bias_kernel(DType* mat, DType* bias, size_t lead_dim, size_t bias_length) {
-const int NTHREADS = 256;
-__shared__ LType scratch[NTHREADS * 2];
+__shared__ LType scratch[nthreads_addbias * 2];
const index_t N = bias_length * sizeof(DType)/sizeof(LType);
const index_t base = blockIdx.x * N;
LType* const mat_aligned = reinterpret_cast<LType*>(mat) + base;
const LType* const bias_aligned = reinterpret_cast<LType*>(bias);
LType* const scratch_bias_load = scratch + threadIdx.x;
DType* const scratch_bias = reinterpret_cast<DType*>(scratch_bias_load);
-LType* const scratch_mat_load = scratch_bias_load + NTHREADS;
+LType* const scratch_mat_load = scratch_bias_load + nthreads_addbias;
DType* const scratch_mat = reinterpret_cast<DType*>(scratch_mat_load);
for (index_t i = threadIdx.x; i < N; i += blockDim.x) {
*scratch_bias_load = bias_aligned[i];
@@ -115,10 +124,13 @@ void AddBias(Tensor<gpu, 1, DType> bias, Tensor<gpu, 2, DType> data,
Tensor<gpu, 2, DType> out, Stream<gpu>* s) {
int ltype = mxnet::common::cuda::get_load_type(bias.shape_[0] * sizeof(DType));
MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
-add_bias_kernel<DType, LType><<<data.size(0), 256, 0, Stream<gpu>::GetStream(s)>>>(out.dptr_,
-bias.dptr_,
-data.size(0),
-bias.shape_[0]);
+add_bias_kernel<DType, LType><<<data.size(0),
+nthreads_addbias,
+0,
+Stream<gpu>::GetStream(s)>>>(out.dptr_,
+bias.dptr_,
+data.size(0),
+bias.shape_[0]);
});
}
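
A note for orientation (not part of the diff): the launch above uses one block per row of the output, data.size(0) blocks of nthreads_addbias threads each, and the bias row is read through the wider load type LType selected by get_load_type, so each block traverses bias_length * sizeof(DType) / sizeof(LType) vector loads instead of bias_length scalar loads. A compile-time sketch of that arithmetic, with Load16 as a hypothetical stand-in for a 16-byte LType:

#include <cstddef>

// Number of wide loads needed to cover one bias row of `bias_length`
// DType elements when reads go through LType-sized chunks.
template <typename DType, typename LType>
constexpr std::size_t wide_loads(std::size_t bias_length) {
  return bias_length * sizeof(DType) / sizeof(LType);
}

struct alignas(16) Load16 { unsigned char bytes[16]; };  // stand-in for a 16-byte load type

// 1024 float bias values (4096 bytes) read in 16-byte chunks -> 256 loads per block.
static_assert(wide_loads<float, Load16>(1024) == 256, "vectorization factor of 4");
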

@@ -180,8 +192,8 @@ void FCForward(const OpContext &ctx, const FullyConnectedParam &param,
template<typename LType, typename DType, typename AType>
__global__ void AddBiasGradKernelPhase1(AType * temp_space, const DType* grad,
const size_t lead_dim, const size_t other_dim) {
-constexpr int num_warps = 16;
constexpr int threads_per_warp = 32;
+constexpr int num_warps = nthreads_addbiasgrad_phase1 / threads_per_warp;
const int values_per_read = sizeof(LType) >= sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1;
const size_t stride = lead_dim / values_per_read;
__shared__ AType scratch[threads_per_warp * num_warps * values_per_read];
@@ -217,7 +229,7 @@ __global__ void AddBiasGradKernelPhase1(AType * temp_space, const DType* grad,

__syncthreads();

-for (int i = 8; i > 0; i /= 2) {
+for (int i = num_warps / 2; i > 0; i /= 2) {
if (my_warp < i) {
const int shared_offset = values_per_read * i * threads_per_warp;
#pragma unroll
@@ -272,20 +284,26 @@ void AddBiasGrad(const TBlob& in_grad,
int ltype = mxnet::common::cuda::get_load_type(N * sizeof(DType));
const int M = grad_blob.shape_.Size() / N;
MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
-const unsigned int blocks_x = (N * sizeof(DType) + 32 * sizeof(LType) - 1) /
-(32 * sizeof(LType));
+const unsigned int blocks_x = ceil_div(N * sizeof(DType), 32 * sizeof(LType));
const unsigned int preferred_number_of_blocks = 2 *
MultiprocessorCount(ctx.run_ctx.ctx.dev_id);
const unsigned int blocks_y = std::max(preferred_number_of_blocks / blocks_x, 1u);
const dim3 n_blocks = {blocks_x, blocks_y, 1};
auto scratch_space = ctx.requested[fullc::kTempSpace]
.get_space_typed<gpu, 1, AType>(mshadow::Shape1(N * blocks_y), s);
auto stream = mshadow::Stream<gpu>::GetStream(s);
-AddBiasGradKernelPhase1<LType><<<n_blocks, 512, 0, stream>>>(scratch_space.dptr_,
-grad.dptr_, N, M);
-AddBiasGradKernelPhase2<<<(N + 127) / 128, 128, 0, stream>>>(scratch_space.dptr_,
-gbias.dptr_, N,
-blocks_y, req);
+AddBiasGradKernelPhase1<LType><<<n_blocks,
+nthreads_addbiasgrad_phase1,
+0,
+stream>>>(scratch_space.dptr_,
+grad.dptr_, N, M);
+const int nblocks_phase2 = ceil_div(N, nthreads_addbiasgrad_phase2);
+AddBiasGradKernelPhase2<<<nblocks_phase2,
+nthreads_addbiasgrad_phase2,
+0,
+stream>>>(scratch_space.dptr_,
+gbias.dptr_, N,
+blocks_y, req);
});
}
#endif
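
A hedged sketch of the two-phase reduction wired up above: the body of AddBiasGradKernelPhase2 is outside this hunk, so the kernel below is an assumption based on the launch parameters and the scratch-space shape, not the PR's code. Phase 1 leaves blocks_y partial bias-gradient rows of length N in the temp space; phase 2 assigns one thread per bias element and sums its blocks_y partials. The req (write/add) mode handled by the real kernel is omitted here.

#include <cstddef>

// Hypothetical simplified phase-2 kernel: sum the blocks_y partial rows
// produced by phase 1 and write one gradient value per bias element.
template <typename AType, typename DType>
__global__ void add_bias_grad_phase2_sketch(const AType* temp_space, DType* gbias,
                                            size_t N, int blocks_y) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= N) return;
  AType acc = 0;
  for (int b = 0; b < blocks_y; ++b) {
    acc += temp_space[b * N + i];  // partials laid out as blocks_y rows of N elements
  }
  gbias[i] = static_cast<DType>(acc);
}

// Illustrative launch, covering all N bias elements:
//   add_bias_grad_phase2_sketch<<<ceil_div(N, nthreads_addbiasgrad_phase2),
//                                 nthreads_addbiasgrad_phase2, 0, stream>>>(
//       scratch_space.dptr_, gbias.dptr_, N, blocks_y);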