This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Softmax optimization for GPU #15545

Merged: 16 commits, Aug 21, 2019
52 changes: 52 additions & 0 deletions src/common/cuda_utils.cc
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2017 by Contributors
* \file cuda_utils.cc
* \brief Common CUDA utilities.
*/

#include <mxnet/base.h>
#include <mshadow/base.h>
#include "cuda_utils.h"

#if MXNET_USE_CUDA

namespace mxnet {
namespace common {
namespace cuda {

int get_load_type(size_t N) {
  using namespace mshadow;
  if (N % 8 == 0) {
    return kFloat64;
  } else if (N % 4 == 0) {
    return kFloat32;
  } else if (N % 2 == 0) {
    return kFloat16;
  } else {
    return kUint8;
  }
}
} // namespace cuda
} // namespace common
} // namespace mxnet

#endif // MXNET_USE_CUDA
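get_load_type picks the widest mshadow type whose size evenly divides N, so a kernel can read global memory in the largest aligned chunks available: a 24-byte row can be read as three 8-byte loads, a 6-byte row as three 2-byte loads, and an odd-sized row falls back to byte loads. A minimal sketch of the intended dispatch pattern, assuming the MXNET_LOAD_TYPE_SWITCH macro added to src/operator/mxnet_op.h below; vectorized_copy and copy_row are hypothetical names used only for illustration, not part of this PR:

// Hypothetical .cu sketch: pick the widest load type for a row of
// row_bytes bytes and launch a copy kernel templated on that type.
template <typename LType>
__global__ void vectorized_copy(const LType* in, LType* out, size_t n) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];  // each thread moves one LType-sized chunk
}

void copy_row(const void* in, void* out, size_t row_bytes, cudaStream_t s) {
  const int ltype = mxnet::common::cuda::get_load_type(row_bytes);
  MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
    const size_t n = row_bytes / sizeof(LType);  // divides exactly by construction
    vectorized_copy<LType><<<(n + 255) / 256, 256, 0, s>>>(
        static_cast<const LType*>(in), static_cast<LType*>(out), n);
  });
}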
37 changes: 34 additions & 3 deletions src/common/cuda_utils.h
@@ -20,7 +20,7 @@
/*!
* Copyright (c) 2015 by Contributors
* \file cuda_utils.h
- * \brief CUDA debugging utilities.
+ * \brief Common CUDA utilities.
*/
#ifndef MXNET_COMMON_CUDA_UTILS_H_
#define MXNET_COMMON_CUDA_UTILS_H_
@@ -326,6 +326,15 @@ class DeviceStore {
bool restore_;
};

/*! \brief Get the largest datatype suitable for reading
 *         the requested number of bytes.
 *
 * \param N Number of bytes to be read
 * \return mshadow type enum of the widest type that can
 *         be used for the read
 */
int get_load_type(size_t N);

} // namespace cuda
} // namespace common
} // namespace mxnet
@@ -550,7 +559,7 @@ static inline __device__ void atomicAdd(double *address, double val) {
// Overload atomicAdd for half precision
// Taken from:
// https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
-#if defined(__CUDA_ARCH__)
+#ifdef __CUDACC__
static inline __device__ void atomicAdd(mshadow::half::half_t *address,
                                        mshadow::half::half_t val) {
  unsigned int *address_as_ui =
@@ -615,6 +624,28 @@ __device__ inline DType ldg(const DType* address) {
  return *address;
#endif
}
#endif

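/*! \brief Reduce \p value across all 32 lanes of a warp with \p redfun;
 *         the fully reduced result is returned in lane 0. The half_t
 *         overload below promotes to float so the shuffles and the
 *         reduction run at single precision.
 */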
template <typename OP, typename T>
__device__ inline T warp_reduce(T value, OP redfun) {
  value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
  value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
  value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
  value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
  value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
  return value;
}

template <typename OP>
__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
  float v = static_cast<float>(value);
  v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
  v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
  v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
  v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
  v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
  return mshadow::half::half_t(v);
}

#endif // __CUDACC__

#endif // MXNET_COMMON_CUDA_UTILS_H_
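warp_reduce folds one value per lane down to lane 0 in five shuffle steps (offsets 16, 8, 4, 2, 1); because the mask is 0xffffffff, all 32 lanes of the warp must participate, so inactive lanes have to be padded with the reduction's identity element. A hedged usage sketch assuming a single-warp launch; SumOp and warp_sum are illustrative names, not part of this PR:

// Illustrative only: sum up to 32 floats with one warp, e.g.
// warp_sum<<<1, 32>>>(d_in, d_out, n);
struct SumOp {
  __device__ float operator()(float a, float b) const { return a + b; }
};

__global__ void warp_sum(const float* in, float* out, int n) {
  // Pad inactive lanes with 0, the identity of the sum reduction.
  float v = (threadIdx.x < n) ? in[threadIdx.x] : 0.f;
  v = warp_reduce(v, SumOp());  // lane 0 now holds the warp-wide total
  if (threadIdx.x == 0) *out = v;
}

The half_t overload keeps the same call shape but accumulates in float, trading a pair of conversions for full-precision intermediate results.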
30 changes: 30 additions & 0 deletions src/operator/mxnet_op.h
@@ -414,6 +414,36 @@ inline int get_num_threads<cpu>(const int N) {
LOG(FATAL) << "Unknown type enum " << type; \
}

#define MXNET_LOAD_TYPE_SWITCH(type, DType, ...)          \
  switch (type) {                                         \
    case mshadow::kFloat32:                               \
      {                                                   \
        typedef float DType;                              \
        {__VA_ARGS__}                                     \
      }                                                   \
      break;                                              \
    case mshadow::kFloat64:                               \
      {                                                   \
        typedef double DType;                             \
        {__VA_ARGS__}                                     \
      }                                                   \
      break;                                              \
    case mshadow::kFloat16:                               \
      {                                                   \
        typedef mshadow::half::half_t DType;              \
        {__VA_ARGS__}                                     \
      }                                                   \
      break;                                              \
    case mshadow::kUint8:                                 \
      {                                                   \
        typedef uint8_t DType;                            \
        {__VA_ARGS__}                                     \
      }                                                   \
      break;                                              \
    default:                                              \
      LOG(FATAL) << "Invalid loading enum type " << type; \
  }
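This mirrors MSHADOW_TYPE_SWITCH but covers exactly the four types get_load_type can return, since DType here names a memory-access width rather than a compute type (kUint8 is the scalar fallback). A minimal sketch of how the two pair up; num_loads is a hypothetical helper used only for illustration:

// Hypothetical helper: how many DType-wide loads cover nbytes bytes.
inline size_t num_loads(size_t nbytes) {
  size_t n = 0;
  const int ltype = mxnet::common::cuda::get_load_type(nbytes);
  MXNET_LOAD_TYPE_SWITCH(ltype, DType, {
    n = nbytes / sizeof(DType);  // divides exactly by construction
  });
  return n;  // e.g. 24 bytes -> 3 double-wide loads
}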

/*!
* \brief assign the val to out according
* to request in Kernel::Launch