Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

FullyConnected Bias performance improvement on GPU #16039

Merged
merged 8 commits into from
Sep 24, 2019
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 68 additions & 8 deletions src/common/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ extern __cuda_fake_struct blockIdx;
#include <cublas_v2.h>
#include <curand.h>

#include <vector>

#define STATIC_ASSERT_CUDA_VERSION_GE(min_version) \
static_assert(CUDA_VERSION >= min_version, "Compiled-against CUDA version " \
QUOTEVALUE(CUDA_VERSION) " is too old, please upgrade system to version " \
Expand Down Expand Up @@ -353,16 +355,41 @@ int get_rows_per_block(size_t row_size, int num_threads_per_block);
} // namespace common
} // namespace mxnet

/*! \brief Maximum number of GPUs */
constexpr size_t kMaxNumGpus = 64;

// The implementations below assume that accesses of 32-bit ints are inherently atomic and
// can be read/written by multiple threads without locks. The values held should be < 2^31.

/*!
* \brief Return an attribute GPU `device_id`.
* \param device_id The device index of the cuda-capable gpu of interest.
* \param cached_values An array of attributes for already-looked-up GPUs.
* \param attr The attribute, by number.
* \param attr_name A string representation of the attribute, for error messages.
* \return the gpu's attribute value.
*/
inline int cudaAttributeLookup(int device_id, std::vector<int32_t> *cached_values,
cudaDeviceAttr attr, const char *attr_name) {
if (device_id < 0 || device_id >= static_cast<int>(cached_values->size())) {
LOG(FATAL) << attr_name << "(device_id) called with invalid id: " << device_id;
} else if ((*cached_values)[device_id] < 0) {
int temp = -1;
CUDA_CALL(cudaDeviceGetAttribute(&temp, attr, device_id));
(*cached_values)[device_id] = static_cast<int32_t>(temp);
}
return (*cached_values)[device_id];
}

DickJC123 marked this conversation as resolved.
Show resolved Hide resolved
/*!
* \brief Determine major version number of the gpu's cuda compute architecture.
* \param device_id The device index of the cuda-capable gpu of interest.
* \return the major version number of the gpu's cuda compute architecture.
*/
inline int ComputeCapabilityMajor(int device_id) {
int major = 0;
CUDA_CALL(cudaDeviceGetAttribute(&major,
cudaDevAttrComputeCapabilityMajor, device_id));
return major;
static std::vector<int32_t> capability_major(kMaxNumGpus, -1);
return cudaAttributeLookup(device_id, &capability_major,
cudaDevAttrComputeCapabilityMajor, "ComputeCapabilityMajor");
}

/*!
Expand All @@ -371,10 +398,9 @@ inline int ComputeCapabilityMajor(int device_id) {
* \return the minor version number of the gpu's cuda compute architecture.
*/
inline int ComputeCapabilityMinor(int device_id) {
int minor = 0;
CUDA_CALL(cudaDeviceGetAttribute(&minor,
cudaDevAttrComputeCapabilityMinor, device_id));
return minor;
static std::vector<int32_t> capability_minor(kMaxNumGpus, -1);
return cudaAttributeLookup(device_id, &capability_minor,
cudaDevAttrComputeCapabilityMinor, "ComputeCapabilityMinor");
}

/*!
Expand All @@ -388,6 +414,40 @@ inline int SMArch(int device_id) {
return 10 * major + minor;
}

/*!
* \brief Return the number of streaming multiprocessors of GPU `device_id`.
* \param device_id The device index of the cuda-capable gpu of interest.
* \return the gpu's count of streaming multiprocessors.
*/
inline int MultiprocessorCount(int device_id) {
static std::vector<int32_t> sm_counts(kMaxNumGpus, -1);
return cudaAttributeLookup(device_id, &sm_counts,
cudaDevAttrMultiProcessorCount, "MultiprocessorCount");
}

/*!
* \brief Return the shared memory size in bytes of each of the GPU's streaming multiprocessors.
* \param device_id The device index of the cuda-capable gpu of interest.
* \return the shared memory size per streaming multiprocessor.
*/
inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
static std::vector<int32_t> max_smem_per_mutiprocessor(kMaxNumGpus, -1);
return cudaAttributeLookup(device_id, &max_smem_per_mutiprocessor,
cudaDevAttrMaxSharedMemoryPerMultiprocessor,
"MaxSharedMemoryPerMultiprocessor");
}

/*!
* \brief Return whether the GPU `device_id` supports cooperative-group kernel launching.
* \param device_id The device index of the cuda-capable gpu of interest.
* \return the gpu's ability to run cooperative-group kernels.
*/
inline bool SupportsCooperativeLaunch(int device_id) {
static std::vector<int32_t> coop_launch(kMaxNumGpus, -1);
return cudaAttributeLookup(device_id, &coop_launch,
cudaDevAttrCooperativeLaunch, "SupportsCooperativeLaunch");
}

/*!
* \brief Determine whether a cuda-capable gpu's architecture supports float16 math.
* Assume not if device_id is negative.
Expand Down
10 changes: 10 additions & 0 deletions src/operator/mxnet_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,16 @@ inline int get_num_threads<cpu>(const int N) {
LOG(FATAL) << "Unknown type enum " << type; \
}

template <typename T>
struct AccType {
using type = T;
};

template <>
struct AccType<mshadow::half::half_t> {
using type = float;
};

DickJC123 marked this conversation as resolved.
Show resolved Hide resolved
#define MXNET_REAL_ACC_TYPE_SWITCH(type, DType, AType, ...)\
switch (type) { \
case mshadow::kFloat32: \
Expand Down
202 changes: 188 additions & 14 deletions src/operator/nn/fully_connected-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "../linalg.h"
Expand Down Expand Up @@ -59,6 +60,7 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
int num_hidden;
bool no_bias;
bool flatten;

DMLC_DECLARE_PARAMETER(FullyConnectedParam) {
// TODO(bing) add support for boolean
DMLC_DECLARE_FIELD(num_hidden).set_lower_bound(1)
Expand All @@ -75,6 +77,53 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
}
};

template<typename DType>
void AddBias(Tensor<cpu, 1, DType> bias, Tensor<cpu, 2, DType> data,
Tensor<cpu, 2, DType> out, Stream<cpu>*) {
using namespace mshadow;
using namespace mshadow::expr;
out += repmat(bias, data.size(0));
}

#if defined(__CUDACC__)

template <typename DType, typename LType>
__global__ void add_bias_kernel(DType* mat, DType* bias, size_t lead_dim, size_t bias_length) {
const int NTHREADS = 256;
__shared__ LType scratch[NTHREADS * 2];
const index_t N = bias_length * sizeof(DType)/sizeof(LType);
const index_t base = blockIdx.x * N;
LType* const mat_aligned = reinterpret_cast<LType*>(mat) + base;
const LType* const bias_aligned = reinterpret_cast<LType*>(bias);
LType* const scratch_bias_load = scratch + threadIdx.x;
DType* const scratch_bias = reinterpret_cast<DType*>(scratch_bias_load);
LType* const scratch_mat_load = scratch_bias_load + NTHREADS;
DType* const scratch_mat = reinterpret_cast<DType*>(scratch_mat_load);
for (index_t i = threadIdx.x; i < N; i += blockDim.x) {
*scratch_bias_load = bias_aligned[i];
*scratch_mat_load = mat_aligned[i];
#pragma unroll
for (int j = 0; j < sizeof(LType)/sizeof(DType); ++j) {
scratch_mat[j] += scratch_bias[j];
}
mat_aligned[i] = *scratch_mat_load;
}
}

template<typename DType>
void AddBias(Tensor<gpu, 1, DType> bias, Tensor<gpu, 2, DType> data,
Tensor<gpu, 2, DType> out, Stream<gpu>* s) {
int ltype = mxnet::common::cuda::get_load_type(bias.shape_[0] * sizeof(DType));
MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
add_bias_kernel<DType, LType><<<data.size(0), 256, 0, Stream<gpu>::GetStream(s)>>>(out.dptr_,
bias.dptr_,
data.size(0),
bias.shape_[0]);
});
}

#endif // __CUDACC__

template<typename xpu, typename DType>
void FCForward(const OpContext &ctx, const FullyConnectedParam &param,
const std::vector<TBlob> &in_data, const std::vector<OpReqType> &req,
Expand Down Expand Up @@ -122,10 +171,147 @@ void FCForward(const OpContext &ctx, const FullyConnectedParam &param,
<< "Incomplete bias tensor detected: bias.data().shape[1] != weight.data().shape[0]."
" This is not supported by FCForward. If bias is in row_sparse format, please"
" make sure all row ids are present.";
out += repmat(bias, data.size(0));
AddBias(bias, data, out, s);
}
}

#if defined (__CUDACC__)

template<typename LType, typename DType, typename AType>
__global__ void AddBiasGradKernelPhase1(AType * temp_space, const DType* grad,
const size_t lead_dim, const size_t other_dim) {
constexpr int num_warps = 16;
constexpr int threads_per_warp = 32;
const int values_per_read = sizeof(LType) >= sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1;
const size_t stride = lead_dim / values_per_read;
__shared__ AType scratch[threads_per_warp * num_warps * values_per_read];
LType * my_scratch_load = &(reinterpret_cast<LType *>(scratch)[threadIdx.x]);
DType * my_values_load = reinterpret_cast<DType *>(my_scratch_load);
AType * my_values_acc = &(scratch[threadIdx.x * values_per_read]);
AType acc[values_per_read]; // NOLINT(*)
#pragma unroll
for (int i = 0; i < values_per_read; ++i) {
acc[i] = 0;
}
const size_t offset = blockIdx.x * threads_per_warp;
const int my_warp = threadIdx.x / threads_per_warp;
const int my_id = threadIdx.x % threads_per_warp;
const LType* aligned_grad = reinterpret_cast<const LType*>(grad);
const int rows_per_block = (other_dim + gridDim.y - 1) / gridDim.y;
const size_t start_row = my_warp + rows_per_block * blockIdx.y;
const size_t end_row = min(other_dim, static_cast<size_t>(rows_per_block * (blockIdx.y + 1)));
if (offset + my_id < stride) {
for (size_t i = start_row; i < end_row; i += num_warps) {
*my_scratch_load = aligned_grad[i * stride + offset + my_id];
#pragma unroll
for (int j = 0; j < values_per_read; ++j) {
acc[j] += static_cast<AType>(my_values_load[j]);
}
}
}
__syncthreads();
#pragma unroll
for (int i = 0; i < values_per_read; ++i) {
my_values_acc[i] = acc[i];
}

__syncthreads();

for (int i = 8; i > 0; i /= 2) {
if (my_warp < i) {
const int shared_offset = values_per_read * i * threads_per_warp;
#pragma unroll
for (int j = 0; j < values_per_read; ++j) {
my_values_acc[j] += my_values_acc[j + shared_offset];
}
}
__syncthreads();
}

if (threadIdx.x < min(threads_per_warp * values_per_read,
static_cast<int>(lead_dim - values_per_read * offset))) {
const size_t offset_out = values_per_read * offset +
blockIdx.y * lead_dim;
temp_space[offset_out + threadIdx.x] = scratch[threadIdx.x];
}
}

template <typename DType, typename AType>
__global__ void AddBiasGradKernelPhase2(const AType * temp_space, DType * out,
int lead_dim, int n_blocks, OpReqType req) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < lead_dim) {
AType acc = 0;
for (int i = tid; i < lead_dim * n_blocks; i += lead_dim) {
acc += temp_space[i];
}
KERNEL_ASSIGN(out[tid], req, static_cast<DType>(acc));
}
}

template<typename DType>
void AddBiasGrad(const TBlob& in_grad,
Tensor<gpu, 2, DType> grad,
OpReqType req,
int num_hidden,
const OpContext& ctx) {
if (req == kNullOp) return;
using AType = typename mxnet_op::AccType<DType>::type;
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
Tensor<gpu, 1, DType> gbias = in_grad.get<gpu, 1, DType>(s);
TBlob grad_blob = TBlob(grad);
TBlob gbias_blob = TBlob(gbias);
mxnet::TShape x(1, 0);
mxnet::TShape small;
if (shape_assign(&gbias_blob.shape_, Shape2(num_hidden, 1))) {
small = gbias_blob.shape_;
} else {
small = ReduceAxesShapeImpl(grad_blob.shape_, dmlc::optional<mxnet::TShape>(x), true, false);
}
const int N = small.Size();
int ltype = mxnet::common::cuda::get_load_type(N * sizeof(DType));
const int M = grad_blob.shape_.Size() / N;
MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
const unsigned int blocks_x = (N * sizeof(DType) + 32 * sizeof(LType) - 1) /
(32 * sizeof(LType));
const unsigned int preferred_number_of_blocks = 2 *
MultiprocessorCount(ctx.run_ctx.ctx.dev_id);
const unsigned int blocks_y = std::max(preferred_number_of_blocks / blocks_x, 1u);
const dim3 n_blocks = {blocks_x, blocks_y, 1};
auto scratch_space = ctx.requested[fullc::kTempSpace]
.get_space_typed<gpu, 1, AType>(mshadow::Shape1(N * blocks_y), s);
auto stream = mshadow::Stream<gpu>::GetStream(s);
AddBiasGradKernelPhase1<LType><<<n_blocks, 512, 0, stream>>>(scratch_space.dptr_,
grad.dptr_, N, M);
AddBiasGradKernelPhase2<<<(N + 127) / 128, 128, 0, stream>>>(scratch_space.dptr_,
gbias.dptr_, N,
blocks_y, req);
});
}
#endif

DickJC123 marked this conversation as resolved.
Show resolved Hide resolved
template<typename DType>
void AddBiasGrad(const TBlob& in_grad,
Tensor<cpu, 2, DType> grad,
OpReqType req,
int num_hidden,
const OpContext& ctx) {
mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
Tensor<cpu, 1, DType> gbias = in_grad.get<cpu, 1, DType>(s);
TBlob grad_blob = TBlob(grad);
TBlob gbias_blob = TBlob(gbias);
mxnet::TShape x(1, 0);
mxnet::TShape small;
if (shape_assign(&gbias_blob.shape_, Shape2(num_hidden, 1))) {
small = gbias_blob.shape_;
} else {
small = ReduceAxesShapeImpl(grad_blob.shape_, dmlc::optional<mxnet::TShape>(x), true, false);
}
ReduceAxesComputeImpl<cpu, mshadow::red::sum, false, false,
mshadow_op::identity>(ctx, {grad_blob}, {req},
{in_grad}, small);
}

template<typename xpu, typename DType>
void FCBackward(const OpContext &ctx, const FullyConnectedParam &param,
const std::vector<TBlob> &out_grad, const std::vector<TBlob> &in_data,
Expand Down Expand Up @@ -169,19 +355,7 @@ void FCBackward(const OpContext &ctx, const FullyConnectedParam &param,
linalg_gemm(grad, data, gwmat, true, false, s, req[fullc::kWeight]);
// gradient of bias
if (!param.no_bias) {
Tensor<xpu, 1, DType> gbias = in_grad[fullc::kBias].get<xpu, 1, DType>(s);
TBlob grad_blob = TBlob(grad);
TBlob gbias_blob = TBlob(gbias);
mxnet::TShape x(1, 0);
mxnet::TShape small;
if (shape_assign(&gbias_blob.shape_, Shape2(param.num_hidden, 1))) {
small = gbias_blob.shape_;
} else {
small = ReduceAxesShapeImpl(grad_blob.shape_, dmlc::optional<mxnet::TShape>(x), true, false);
}
ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false,
mshadow_op::identity>(ctx, {grad_blob}, {req[fullc::kBias]},
{in_grad[fullc::kBias]}, small);
AddBiasGrad(in_grad[fullc::kBias], grad, req[fullc::kBias], param.num_hidden, ctx);
DickJC123 marked this conversation as resolved.
Show resolved Hide resolved
}
// gradient of data
// Legacy approach shown here for comparison:
Expand Down