diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index dd98a0e6c966..2f7254040475 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -57,6 +57,8 @@ extern __cuda_fake_struct blockIdx;
 
 #include
 #include
+#include <vector>
+
 #define STATIC_ASSERT_CUDA_VERSION_GE(min_version) \
   static_assert(CUDA_VERSION >= min_version, "Compiled-against CUDA version " \
       QUOTEVALUE(CUDA_VERSION) " is too old, please upgrade system to version " \
@@ -353,16 +355,41 @@ int get_rows_per_block(size_t row_size, int num_threads_per_block);
 }  // namespace common
 }  // namespace mxnet
 
+/*! \brief Maximum number of GPUs */
+constexpr size_t kMaxNumGpus = 64;
+
+// The implementations below assume that accesses of 32-bit ints are inherently atomic, so the
+// cached values can be read/written by multiple threads without locks.  They should be < 2^31.
+
+/*!
+ * \brief Return an attribute of GPU `device_id`.
+ * \param device_id The device index of the cuda-capable gpu of interest.
+ * \param cached_values An array of attributes for already-looked-up GPUs.
+ * \param attr The attribute, by number.
+ * \param attr_name A string representation of the attribute, for error messages.
+ * \return the gpu's attribute value.
+ */
+inline int cudaAttributeLookup(int device_id, std::vector<int32_t> *cached_values,
+                               cudaDeviceAttr attr, const char *attr_name) {
+  if (device_id < 0 || device_id >= static_cast<int>(cached_values->size())) {
+    LOG(FATAL) << attr_name << "(device_id) called with invalid id: " << device_id;
+  } else if ((*cached_values)[device_id] < 0) {
+    int temp = -1;
+    CUDA_CALL(cudaDeviceGetAttribute(&temp, attr, device_id));
+    (*cached_values)[device_id] = static_cast<int32_t>(temp);
+  }
+  return (*cached_values)[device_id];
+}
+
 /*!
  * \brief Determine major version number of the gpu's cuda compute architecture.
  * \param device_id The device index of the cuda-capable gpu of interest.
  * \return the major version number of the gpu's cuda compute architecture.
  */
 inline int ComputeCapabilityMajor(int device_id) {
-  int major = 0;
-  CUDA_CALL(cudaDeviceGetAttribute(&major,
-                                   cudaDevAttrComputeCapabilityMajor, device_id));
-  return major;
+  static std::vector<int32_t> capability_major(kMaxNumGpus, -1);
+  return cudaAttributeLookup(device_id, &capability_major,
+                             cudaDevAttrComputeCapabilityMajor, "ComputeCapabilityMajor");
 }
 
 /*!
@@ -371,10 +398,9 @@ inline int ComputeCapabilityMajor(int device_id) {
  * \return the minor version number of the gpu's cuda compute architecture.
  */
 inline int ComputeCapabilityMinor(int device_id) {
-  int minor = 0;
-  CUDA_CALL(cudaDeviceGetAttribute(&minor,
-                                   cudaDevAttrComputeCapabilityMinor, device_id));
-  return minor;
+  static std::vector<int32_t> capability_minor(kMaxNumGpus, -1);
+  return cudaAttributeLookup(device_id, &capability_minor,
+                             cudaDevAttrComputeCapabilityMinor, "ComputeCapabilityMinor");
 }
 
 /*!
@@ -388,6 +414,40 @@ inline int SMArch(int device_id) {
   return 10 * major + minor;
 }
 
+/*!
+ * \brief Return the number of streaming multiprocessors of GPU `device_id`.
+ * \param device_id The device index of the cuda-capable gpu of interest.
+ * \return the gpu's count of streaming multiprocessors.
+ */
+inline int MultiprocessorCount(int device_id) {
+  static std::vector<int32_t> sm_counts(kMaxNumGpus, -1);
+  return cudaAttributeLookup(device_id, &sm_counts,
+                             cudaDevAttrMultiProcessorCount, "MultiprocessorCount");
+}
+
+/*!
+ * \brief Return the shared memory size in bytes of each of the GPU's streaming multiprocessors.
+ * \param device_id The device index of the cuda-capable gpu of interest.
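+ * \note The attribute is read with cudaDeviceGetAttribute() on first use and then served
+ *       from the cached table maintained by cudaAttributeLookup().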
+ * \return the shared memory size per streaming multiprocessor.
+ */
+inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
+  static std::vector<int32_t> max_smem_per_multiprocessor(kMaxNumGpus, -1);
+  return cudaAttributeLookup(device_id, &max_smem_per_multiprocessor,
+                             cudaDevAttrMaxSharedMemoryPerMultiprocessor,
+                             "MaxSharedMemoryPerMultiprocessor");
+}
+
+/*!
+ * \brief Return whether the GPU `device_id` supports cooperative-group kernel launching.
+ * \param device_id The device index of the cuda-capable gpu of interest.
+ * \return the gpu's ability to run cooperative-group kernels.
+ */
+inline bool SupportsCooperativeLaunch(int device_id) {
+  static std::vector<int32_t> coop_launch(kMaxNumGpus, -1);
+  return cudaAttributeLookup(device_id, &coop_launch,
+                             cudaDevAttrCooperativeLaunch, "SupportsCooperativeLaunch");
+}
+
 /*!
  * \brief Determine whether a cuda-capable gpu's architecture supports float16 math.
  *        Assume not if device_id is negative.
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index ae98952bb7bf..d90aa268195a 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -249,6 +249,16 @@ inline int get_num_threads(const int N) {
     LOG(FATAL) << "Unknown type enum " << type;  \
   }
 
+template <typename T>
+struct AccType {
+  using type = T;
+};
+
+template <>
+struct AccType<mshadow::half::half_t> {
+  using type = float;
+};
+
 #define MXNET_REAL_ACC_TYPE_SWITCH(type, DType, AType, ...)\
   switch (type) {                                          \
   case mshadow::kFloat32:                                  \
diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h
index 44af375486fb..28b6905148e4 100644
--- a/src/operator/nn/fully_connected-inl.h
+++ b/src/operator/nn/fully_connected-inl.h
@@ -32,6 +32,7 @@
 #include <vector>
 #include <string>
 #include <utility>
+#include <algorithm>
 #include "../operator_common.h"
 #include "../elemwise_op_common.h"
 #include "../linalg.h"
@@ -59,6 +60,7 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
   int num_hidden;
   bool no_bias;
   bool flatten;
+
   DMLC_DECLARE_PARAMETER(FullyConnectedParam) {
     // TODO(bing) add support for boolean
     DMLC_DECLARE_FIELD(num_hidden).set_lower_bound(1)
@@ -75,6 +77,66 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
   }
 };
 
+template<typename DType>
+void AddBias(Tensor<cpu, 1, DType> bias, Tensor<cpu, 2, DType> data,
+             Tensor<cpu, 2, DType> out, Stream<cpu>*) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  out += repmat(bias, data.size(0));
+}
+
+#if defined(__CUDACC__)
+
+namespace {
+  constexpr int nthreads_addbias = 256;
+  constexpr int nthreads_addbiasgrad_phase1 = 512;
+  constexpr int nthreads_addbiasgrad_phase2 = 128;
+  constexpr int threads_per_warp = 32;
+
+  inline int ceil_div(int x, int y) {
+    return (x + y - 1) / y;
+  }
+}  // namespace
+
+template <typename DType, typename LType>
+__global__ void add_bias_kernel(DType* mat, DType* bias, size_t lead_dim, size_t bias_length) {
+  __shared__ LType scratch[nthreads_addbias * 2];
+  const index_t N = bias_length * sizeof(DType)/sizeof(LType);
+  const index_t base = blockIdx.x * N;
+  LType* const mat_aligned = reinterpret_cast<LType*>(mat) + base;
+  const LType* const bias_aligned = reinterpret_cast<const LType*>(bias);
+  LType* const scratch_bias_load = scratch + threadIdx.x;
+  DType* const scratch_bias = reinterpret_cast<DType*>(scratch_bias_load);
+  LType* const scratch_mat_load = scratch_bias_load + nthreads_addbias;
+  DType* const scratch_mat = reinterpret_cast<DType*>(scratch_mat_load);
+  for (index_t i = threadIdx.x; i < N; i += blockDim.x) {
+    *scratch_bias_load = bias_aligned[i];
+    *scratch_mat_load = mat_aligned[i];
+#pragma unroll
+    for (int j = 0; j < sizeof(LType)/sizeof(DType); ++j) {
+      scratch_mat[j] += scratch_bias[j];
+    }
+    mat_aligned[i] = *scratch_mat_load;
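+    // the single LType store above writes back sizeof(LType)/sizeof(DType) bias-added DType values at once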
+  }
+}
+
+template<typename DType>
+void AddBias(Tensor<gpu, 1, DType> bias, Tensor<gpu, 2, DType> data,
+             Tensor<gpu, 2, DType> out, Stream<gpu>* s) {
+  int ltype = mxnet::common::cuda::get_load_type(bias.shape_[0] * sizeof(DType));
+  MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+    add_bias_kernel<DType, LType><<<data.size(0),
+                                    nthreads_addbias,
+                                    0,
+                                    Stream<gpu>::GetStream(s)>>>(out.dptr_,
+                                                                 bias.dptr_,
+                                                                 data.size(0),
+                                                                 bias.shape_[0]);
+  });
+}
+
+#endif  // __CUDACC__
+
 template<typename xpu, typename DType>
 void FCForward(const OpContext &ctx, const FullyConnectedParam &param,
                const std::vector<TBlob> &in_data, const std::vector<OpReqType> &req,
@@ -122,10 +184,153 @@ void FCForward(const OpContext &ctx, const FullyConnectedParam &param,
         << "Incomplete bias tensor detected: bias.data().shape[1] != weight.data().shape[0]."
           " This is not supported by FCForward. If bias is in row_sparse format, please"
          " make sure all row ids are present.";
-    out += repmat(bias, data.size(0));
+    AddBias(bias, data, out, s);
   }
 }
 
+#if defined (__CUDACC__)
+
+template <typename LType, typename AType, typename DType>
+__global__ void AddBiasGradKernelPhase1(AType * temp_space, const DType* grad,
+                                        const size_t lead_dim, const size_t other_dim) {
+  constexpr int num_warps = nthreads_addbiasgrad_phase1 / threads_per_warp;
+  const int values_per_read = sizeof(LType) >= sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1;
+  const size_t stride = lead_dim / values_per_read;
+  __shared__ AType scratch[threads_per_warp * num_warps * values_per_read];
+  LType * my_scratch_load = &(reinterpret_cast<LType *>(scratch)[threadIdx.x]);
+  DType * my_values_load = reinterpret_cast<DType *>(my_scratch_load);
+  AType * my_values_acc = &(scratch[threadIdx.x * values_per_read]);
+  AType acc[values_per_read];  // NOLINT(*)
+#pragma unroll
+  for (int i = 0; i < values_per_read; ++i) {
+    acc[i] = 0;
+  }
+  const size_t offset = blockIdx.x * threads_per_warp;
+  const int my_warp = threadIdx.x / threads_per_warp;
+  const int my_id = threadIdx.x % threads_per_warp;
+  const LType* aligned_grad = reinterpret_cast<const LType*>(grad);
+  const int rows_per_block = (other_dim + gridDim.y - 1) / gridDim.y;
+  const size_t start_row = my_warp + rows_per_block * blockIdx.y;
+  const size_t end_row = min(other_dim, static_cast<size_t>(rows_per_block * (blockIdx.y + 1)));
+  if (offset + my_id < stride) {
+    for (size_t i = start_row; i < end_row; i += num_warps) {
+      *my_scratch_load = aligned_grad[i * stride + offset + my_id];
+#pragma unroll
+      for (int j = 0; j < values_per_read; ++j) {
+        acc[j] += static_cast<AType>(my_values_load[j]);
+      }
+    }
+  }
+  __syncthreads();
+#pragma unroll
+  for (int i = 0; i < values_per_read; ++i) {
+    my_values_acc[i] = acc[i];
+  }
+
+  __syncthreads();
+
+  for (int i = num_warps / 2; i > 0; i /= 2) {
+    if (my_warp < i) {
+      const int shared_offset = values_per_read * i * threads_per_warp;
+#pragma unroll
+      for (int j = 0; j < values_per_read; ++j) {
+        my_values_acc[j] += my_values_acc[j + shared_offset];
+      }
+    }
+    __syncthreads();
+  }
+
+  if (threadIdx.x < min(threads_per_warp * values_per_read,
+                        static_cast<int>(lead_dim - values_per_read * offset))) {
+    const size_t offset_out = values_per_read * offset +
+                              blockIdx.y * lead_dim;
+    temp_space[offset_out + threadIdx.x] = scratch[threadIdx.x];
+  }
+}
+
+template <typename AType, typename DType>
+__global__ void AddBiasGradKernelPhase2(const AType * temp_space, DType * out,
+                                        int lead_dim, int n_blocks, OpReqType req) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid < lead_dim) {
+    AType acc = 0;
+    for (int i = tid; i < lead_dim * n_blocks; i += lead_dim) {
+      acc += temp_space[i];
+    }
+    KERNEL_ASSIGN(out[tid], req, static_cast<DType>(acc));
+  }
+}
+
+template<typename DType>
+void AddBiasGrad(const TBlob& in_grad,
+                 Tensor<gpu, 2, DType> grad,
+                 OpReqType req,
+                 int num_hidden,
+                 const OpContext& ctx) {
+  if (req == kNullOp) return;
+  using AType = typename mxnet_op::AccType<DType>::type;
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  Tensor<gpu, 1, DType> gbias = in_grad.get<gpu, 1, DType>(s);
+  TBlob grad_blob = TBlob(grad);
+  TBlob gbias_blob = TBlob(gbias);
+  mxnet::TShape x(1, 0);
+  mxnet::TShape small;
+  if (shape_assign(&gbias_blob.shape_, Shape2(num_hidden, 1))) {
+    small = gbias_blob.shape_;
+  } else {
+    small = ReduceAxesShapeImpl(grad_blob.shape_, dmlc::optional<mxnet::TShape>(x), true, false);
+  }
+  const int N = small.Size();
+  int ltype = mxnet::common::cuda::get_load_type(N * sizeof(DType));
+  const int M = grad_blob.shape_.Size() / N;
+  MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+    const unsigned int blocks_x = ceil_div(N * sizeof(DType),
+                                           threads_per_warp * sizeof(LType));
+    const unsigned int preferred_number_of_blocks = 2 *
+                                                    MultiprocessorCount(ctx.run_ctx.ctx.dev_id);
+    const unsigned int blocks_y = std::max(preferred_number_of_blocks / blocks_x, 1u);
+    const dim3 n_blocks = {blocks_x, blocks_y, 1};
+    auto scratch_space = ctx.requested[fullc::kTempSpace]
+                            .get_space_typed<gpu, 1, AType>(mshadow::Shape1(N * blocks_y), s);
+    auto stream = mshadow::Stream<gpu>::GetStream(s);
+    AddBiasGradKernelPhase1<LType><<<n_blocks,
+                                     nthreads_addbiasgrad_phase1,
+                                     0,
+                                     stream>>>(scratch_space.dptr_,
+                                               grad.dptr_, N, M);
+    const int nblocks_phase2 = ceil_div(N, nthreads_addbiasgrad_phase2);
+    AddBiasGradKernelPhase2<<<nblocks_phase2,
+                              nthreads_addbiasgrad_phase2,
+                              0,
+                              stream>>>(scratch_space.dptr_,
+                                        gbias.dptr_, N,
+                                        blocks_y, req);
+  });
+}
+#endif
+
+template<typename DType>
+void AddBiasGrad(const TBlob& in_grad,
+                 Tensor<cpu, 2, DType> grad,
+                 OpReqType req,
+                 int num_hidden,
+                 const OpContext& ctx) {
+  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+  Tensor<cpu, 1, DType> gbias = in_grad.get<cpu, 1, DType>(s);
+  TBlob grad_blob = TBlob(grad);
+  TBlob gbias_blob = TBlob(gbias);
+  mxnet::TShape x(1, 0);
+  mxnet::TShape small;
+  if (shape_assign(&gbias_blob.shape_, Shape2(num_hidden, 1))) {
+    small = gbias_blob.shape_;
+  } else {
+    small = ReduceAxesShapeImpl(grad_blob.shape_, dmlc::optional<mxnet::TShape>(x), true, false);
+  }
+  ReduceAxesComputeImpl<cpu, mshadow::red::sum, false, false,
+                        mshadow_op::identity>(ctx, {grad_blob}, {req},
+                                              {in_grad}, small);
+}
+
 template<typename xpu, typename DType>
 void FCBackward(const OpContext &ctx, const FullyConnectedParam &param,
                 const std::vector<TBlob> &out_grad, const std::vector<TBlob> &in_data,
@@ -169,19 +374,7 @@ void FCBackward(const OpContext &ctx, const FullyConnectedParam &param,
   linalg_gemm(grad, data, gwmat, true, false, s, req[fullc::kWeight]);
   // gradient of bias
   if (!param.no_bias) {
-    Tensor<xpu, 1, DType> gbias = in_grad[fullc::kBias].get<xpu, 1, DType>(s);
-    TBlob grad_blob = TBlob(grad);
-    TBlob gbias_blob = TBlob(gbias);
-    mxnet::TShape x(1, 0);
-    mxnet::TShape small;
-    if (shape_assign(&gbias_blob.shape_, Shape2(param.num_hidden, 1))) {
-      small = gbias_blob.shape_;
-    } else {
-      small = ReduceAxesShapeImpl(grad_blob.shape_, dmlc::optional<mxnet::TShape>(x), true, false);
-    }
-    ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false,
-                          mshadow_op::identity>(ctx, {grad_blob}, {req[fullc::kBias]},
-                                                {in_grad[fullc::kBias]}, small);
+    AddBiasGrad(in_grad[fullc::kBias], grad, req[fullc::kBias], param.num_hidden, ctx);
   }
   // gradient of data
   // Legacy approach shown here for comparison:
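
The two-phase AddBiasGrad path above computes the same quantity as the ReduceAxesComputeImpl call it replaces: the bias gradient is the column-wise sum of the output gradient over the batch dimension, accumulated in mxnet_op::AccType<DType>::type (float when DType is half). A minimal CPU sketch of that reduction follows for reference; it is not part of the patch, and the helper name BiasGradReference and its parameters are invented for this example.

#include <cstddef>
#include <vector>

// Reference semantics of AddBiasGradKernelPhase1/Phase2: gbias[j] = sum over rows i of grad[i][j],
// with the running sum kept in the (possibly wider) accumulation type AType.
template <typename DType, typename AType = DType>
std::vector<DType> BiasGradReference(const std::vector<DType>& grad,
                                     std::size_t rows, std::size_t cols) {
  std::vector<AType> acc(cols, AType(0));
  for (std::size_t i = 0; i < rows; ++i) {      // "other_dim": batch rows of the gradient
    for (std::size_t j = 0; j < cols; ++j) {    // "lead_dim": one slot per hidden unit
      acc[j] += static_cast<AType>(grad[i * cols + j]);
    }
  }
  std::vector<DType> out(cols);
  for (std::size_t j = 0; j < cols; ++j) {
    out[j] = static_cast<DType>(acc[j]);
  }
  return out;
}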