Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fixes from review
Browse files Browse the repository at this point in the history
  • Loading branch information
ptrendx committed Jul 23, 2019
1 parent 40c0d00 commit 1609a2b
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 39 deletions.
33 changes: 33 additions & 0 deletions src/common/cuda_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*!
* Copyright (c) 2017 by Contributors
* \file cuda_utils.cc
* \brief CUDA debugging utilities.
*/

#include <mxnet/base.h>
#include <mshadow/base.h>
#include "cuda_utils.h"

#if MXNET_USE_CUDA

namespace mxnet {
namespace common {
namespace cuda {

/*!
 * \brief Pick the widest mshadow type whose size evenly divides N bytes,
 *        so a buffer of N bytes can be read with the fewest loads.
 * \param N number of bytes to be read
 * \return mshadow TypeFlag of the widest suitable load type
 */
int get_load_type(size_t N) {
  using namespace mshadow;
  // Test divisibility by 8/4/2 via the low bits instead of modulo.
  if ((N & 7) == 0) return kFloat64;   // 8-byte loads
  if ((N & 3) == 0) return kFloat32;   // 4-byte loads
  if ((N & 1) == 0) return kFloat16;   // 2-byte loads
  return kInt8;                        // fall back to byte loads
}
} // namespace cuda
} // namespace common
} // namespace mxnet

#endif // MXNET_USE_CUDA
35 changes: 33 additions & 2 deletions src/common/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,15 @@ class DeviceStore {
bool restore_;
};

/*! \brief Get the largest datatype suitable to read
* requested number of bytes.
*
* \param N Number of bytes to be read
* \return mshadow representation of type that could
* be used for reading
*/
int get_load_type(size_t N);

} // namespace cuda
} // namespace common
} // namespace mxnet
Expand Down Expand Up @@ -550,7 +559,7 @@ static inline __device__ void atomicAdd(double *address, double val) {
// Overload atomicAdd for half precision
// Taken from:
// https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
#if defined(__CUDA_ARCH__)
#ifdef __CUDACC__
static inline __device__ void atomicAdd(mshadow::half::half_t *address,
mshadow::half::half_t val) {
unsigned int *address_as_ui =
Expand Down Expand Up @@ -615,6 +624,28 @@ __device__ inline DType ldg(const DType* address) {
return *address;
#endif
}
#endif

/*! \brief Reduce a value across a full warp (32 lanes) with the supplied
 *         binary functor; the final result is valid in lane 0.
 *
 * Assumes all 32 lanes participate (mask 0xffffffff) — do not call from
 * divergent code where part of the warp is inactive.
 */
template <typename OP, typename T>
__device__ inline T warp_reduce(T value, OP redfun) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    value = redfun(value, __shfl_down_sync(0xffffffff, value, offset));
  }
  return value;
}

/*! \brief Overload of warp_reduce for half precision: shuffles and the
 *         reduction itself are done in float (half has no shuffle support
 *         and float accumulation avoids precision loss), then the result
 *         is narrowed back to half. Valid in lane 0; assumes a full warp.
 */
template <typename OP>
__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
  float acc = static_cast<float>(value);
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    acc = redfun(acc, __shfl_down_sync(0xffffffff, acc, offset));
  }
  return mshadow::half::half_t(acc);
}

#endif // __CUDACC__

#endif // MXNET_COMMON_CUDA_UTILS_H_
42 changes: 5 additions & 37 deletions src/operator/nn/softmax-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../tensor/broadcast_reduce_op.h"
#include "../../common/cuda_utils.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -312,27 +313,6 @@ __global__ void softmax_compute_kernel(DType *in, OType *out, IType *length,

const int softmax_threads_per_block = 512;

/*! \brief Warp-wide reduction with functor `redfun`; result lands in lane 0.
 *         All 32 lanes must be active (participation mask 0xffffffff).
 */
template <typename OP, typename T>
__device__ inline T warp_reduce(T value, OP redfun) {
#pragma unroll
  for (int delta = 16; delta > 0; delta >>= 1) {
    value = redfun(value, __shfl_down_sync(0xffffffff, value, delta));
  }
  return value;
}

/*! \brief Half-precision warp reduction: promote to float, shuffle-reduce
 *         across the full warp, narrow back to half. Result in lane 0.
 */
template <typename OP>
__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
  float acc = static_cast<float>(value);
#pragma unroll
  for (int delta = 16; delta > 0; delta >>= 1) {
    acc = redfun(acc, __shfl_down_sync(0xffffffff, acc, delta));
  }
  return mshadow::half::half_t(acc);
}

template<typename OP, bool negate, typename AType, typename LType,
typename DType, typename OType, typename IType>
__global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, IType *length,
Expand All @@ -356,7 +336,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp
// the division by zero warning generated for such invalid cases.
const int row_length = entries_per_load > 0 ? M / entries_per_load : 0;

const LType * in_aligned = reinterpret_cast<const LType *>(in);
const LType* in_aligned = reinterpret_cast<const LType*>(in);
size_t base = my_row * row_length;

for (index_t i = my_id; i < row_length; i += threads_per_row) {
Expand Down Expand Up @@ -420,7 +400,7 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp
}
__syncthreads();

LType * out_aligned = reinterpret_cast<LType *>(out);
LType* out_aligned = reinterpret_cast<LType*>(out);

for (index_t i = my_id; i < row_length; i += threads_per_row) {
out_aligned[base + i] = persistent_storage[my_local_row * row_length + i];
Expand All @@ -429,18 +409,6 @@ __global__ void softmax_stride1_compute_kernel(const DType *in, OType *out, ITyp

namespace {

// Widest mshadow TypeFlag whose size evenly divides N bytes, so N bytes
// can be covered by the fewest aligned loads.
int get_load_type(size_t N) {
  if ((N & 7) == 0) return kFloat64;   // divisible by 8
  if ((N & 3) == 0) return kFloat32;   // divisible by 4
  if ((N & 1) == 0) return kFloat16;   // divisible by 2
  return kInt8;
}

int get_rows_per_block(size_t N) {
const int warp_size = 32;
// How many read instructions should 1 thread at least do
Expand Down Expand Up @@ -479,9 +447,9 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out, IType *length,
// Using 20 kB of shared memory for persistent storage in the optimized case
const size_t max_opt_M = 20 * 1024 / DSize;
if (stride[axis] == 1 &&
M <= max_opt_M &&
static_cast<size_t>(M) <= max_opt_M &&
std::is_same<DType, OType>::value) {
int ltype = get_load_type(M * sizeof(DType));
int ltype = mxnet::common::cuda::get_load_type(M * sizeof(DType));
MSHADOW_TYPE_SWITCH(ltype, LType, {
int rows_per_block = get_rows_per_block(M * sizeof(DType) / sizeof(LType));
int nblocks = (N + rows_per_block - 1) / rows_per_block;
Expand Down

0 comments on commit 1609a2b

Please sign in to comment.