
Softmax optimization for GPU #15545

Merged
merged 16 commits on Aug 21, 2019
90 changes: 90 additions & 0 deletions src/common/cuda_utils.cc
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2017 by Contributors
* \file cuda_utils.cc
* \brief Common CUDA utilities.
*/

#include <mxnet/base.h>
#include <mshadow/base.h>

#include <algorithm>

#include "cuda_utils.h"

#if MXNET_USE_CUDA

namespace mxnet {
namespace common {
namespace cuda {

namespace {
// Returns true if N is a nonzero power of 2.
bool IsPower2(size_t N) {
  return ((N & (N - 1)) == 0) && N != 0;
}

// Rounds N up to the nearest power of 2.
size_t RoundToPower2(size_t N) {
  size_t ret = 1;
  size_t copyN = N;
  while (N >= 2) {
    ret *= 2;
    N /= 2;
  }
  if (ret < copyN) {
    ret *= 2;
  }
  return ret;
}
}  // namespace

int get_load_type(size_t N) {
using namespace mshadow;
if (N % 8 == 0) {
return kFloat64;
} else if (N % 4 == 0) {
return kFloat32;
} else if (N % 2 == 0) {
return kFloat16;
} else {
return kUint8;
}
}

int get_rows_per_block(size_t row_size, int num_threads_per_block) {
  const int warp_size = 32;
  CHECK(IsPower2(num_threads_per_block))
    << "Number of threads in a block must be a power of 2 to use get_rows_per_block function";
  // Minimum number of read instructions each thread should perform
  const int read_instructions = 2;
  const int desired_num_threads_per_row = (row_size + read_instructions - 1) / read_instructions;
  int desired_num_warps_per_row = (desired_num_threads_per_row + warp_size - 1) / warp_size;
  int actual_num_warps_per_row = std::min(desired_num_warps_per_row,
                                          num_threads_per_block / warp_size);
  // The actual number of warps per row needs to be a power of 2
  actual_num_warps_per_row = RoundToPower2(actual_num_warps_per_row);
  return num_threads_per_block / (warp_size * actual_num_warps_per_row);
}

} // namespace cuda
} // namespace common
} // namespace mxnet

#endif // MXNET_USE_CUDA
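For illustration, a minimal usage sketch of the two new helpers; the row and batch sizes, the float32 input, and the kernel name are hypothetical, not taken from this PR:

// Hypothetical launch-configuration sketch using get_load_type / get_rows_per_block.
const size_t row_size = 1000;        // elements per row (assumed float32 input)
const size_t nrows = 4096;           // number of rows in the batch (assumption)
const int threads_per_block = 512;   // must be a power of 2 for get_rows_per_block

// Widest type whose size divides the row length in bytes: 4000 bytes -> kFloat64,
// so every read moves 2 floats and the row takes row_size / 2 reads.
const int load_type = get_load_type(row_size * sizeof(float));
const size_t reads_per_row = row_size / 2;

const int rows_per_block = get_rows_per_block(reads_per_row, threads_per_block);
const int nblocks = (nrows + rows_per_block - 1) / rows_per_block;
// softmax_row_kernel<<<nblocks, threads_per_block>>>(...);  // hypothetical kernel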
50 changes: 47 additions & 3 deletions src/common/cuda_utils.h
@@ -20,7 +20,7 @@
/*!
* Copyright (c) 2015 by Contributors
* \file cuda_utils.h
* \brief CUDA debugging utilities.
* \brief Common CUDA utilities.
*/
#ifndef MXNET_COMMON_CUDA_UTILS_H_
#define MXNET_COMMON_CUDA_UTILS_H_
@@ -327,6 +327,28 @@ class DeviceStore {
bool restore_;
};

/*!
 * \brief Get the largest data type suitable for reading
 *        the requested number of bytes.
 *
 * \param N Number of bytes to be read
 * \return mshadow type enum of the widest type that can
 *         be used for the reads
 */
int get_load_type(size_t N);

/*!
 * \brief Determine how many rows of a 2D matrix a block
 *        of threads should handle, based on the row size and
 *        the number of threads in a block.
 * \param row_size Size of the row expressed as the number of reads required to fully
 *        load it. For example, if the row has N elements, but each thread
 *        reads 2 elements with a single read, row_size should be N / 2.
 * \param num_threads_per_block Number of threads in a block.
 * \return the number of rows that should be handled by a single block.
 */
int get_rows_per_block(size_t row_size, int num_threads_per_block);

} // namespace cuda
} // namespace common
} // namespace mxnet
@@ -542,7 +564,7 @@ static inline __device__ void atomicAdd(double *address, double val) {
// Overload atomicAdd for half precision
// Taken from:
// https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
#if defined(__CUDA_ARCH__)
#ifdef __CUDACC__
static inline __device__ void atomicAdd(mshadow::half::half_t *address,
mshadow::half::half_t val) {
unsigned int *address_as_ui =
@@ -607,6 +629,28 @@ __device__ inline DType ldg(const DType* address) {
return *address;
#endif
}
#endif

// Reduce a value across the 32 lanes of a warp with the supplied functor using
// shuffle-down; after the last step lane 0 holds the fully reduced value.
template <typename OP, typename T>
__device__ inline T warp_reduce(T value, OP redfun) {
  value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
  value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
  value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
  value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
  value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
  return value;
}

// Half-precision overload: the value is promoted to float for the reduction
// to avoid losing precision in the intermediate results.
template <typename OP>
__device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
  float v = static_cast<float>(value);
  v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
  v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
  v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
  v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
  v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
  return mshadow::half::half_t(v);
}

#endif // __CUDACC__

#endif // MXNET_COMMON_CUDA_UTILS_H_
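To show how the shuffle-based reduction composes with a functor, here is a small, hypothetical kernel (not taken from this PR) that reduces each row of a matrix to its maximum, assuming one 32-thread warp per block:

#include <cfloat>  // FLT_MAX

// Hypothetical functor for warp_reduce; any associative binary op works.
struct MaxOp {
  __device__ float operator()(float a, float b) const { return fmaxf(a, b); }
};

// One block (= one warp of 32 threads) per row; lane i strides over the row.
__global__ void row_max_kernel(const float* in, float* out,
                               size_t nrows, size_t row_size) {
  const size_t row = blockIdx.x;
  if (row >= nrows) return;
  float val = -FLT_MAX;
  for (size_t i = threadIdx.x; i < row_size; i += 32) {
    val = fmaxf(val, in[row * row_size + i]);
  }
  // After the shuffle reduction lane 0 holds the maximum of the whole row.
  val = warp_reduce(val, MaxOp());
  if (threadIdx.x == 0) out[row] = val;
}
// Launch sketch: row_max_kernel<<<nrows, 32>>>(in, out, nrows, row_size);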
30 changes: 30 additions & 0 deletions src/operator/mxnet_op.h
@@ -414,6 +414,36 @@ inline int get_num_threads<cpu>(const int N) {
LOG(FATAL) << "Unknown type enum " << type; \
}

#define MXNET_LOAD_TYPE_SWITCH(type, DType, ...) \
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
{ \
typedef uint8_t DType; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Invalid loading enum type " << type; \
}

/*!
* \brief assign the val to out according
* to request in Kernel::Launch
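A short, hypothetical sketch of how this macro pairs with get_load_type for vectorized reads; row_ptr and row_size are assumptions describing a contiguous float32 row:

// Hypothetical vectorized-read sketch; row_ptr / row_size are assumptions.
const int ltype = mxnet::common::cuda::get_load_type(row_size * sizeof(float));
MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
  // Reinterpret the row as the widest evenly dividing type and read it in chunks.
  const LType* vec = reinterpret_cast<const LType*>(row_ptr);
  const size_t nvec = row_size * sizeof(float) / sizeof(LType);
  for (size_t i = 0; i < nvec; ++i) {
    LType chunk = vec[i];  // one wide load instead of several narrow ones
    // ... unpack `chunk` into float elements and process them ...
  }
});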
78 changes: 78 additions & 0 deletions src/operator/nn/log_softmax.cc
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2017 by Contributors
* \file log_softmax.cc
* \brief CPU Implementation of log_softmax
*/
#include "./softmax-inl.h"
#include "../tensor/elemwise_unary_op.h"
#include "../tensor/elemwise_binary_op.h"
#include "../operator_common.h"

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(log_softmax)
.add_alias("_npx_log_softmax")
.describe(R"code(Computes the log softmax of the input.
This is equivalent to computing softmax followed by log.

Examples::

>>> x = mx.nd.array([1, 2, .1])
>>> mx.nd.log_softmax(x).asnumpy()
array([-1.41702998, -0.41702995, -2.31702995], dtype=float32)

>>> x = mx.nd.array( [[1, 2, .1],[.1, 2, 1]] )
>>> mx.nd.log_softmax(x, axis=0).asnumpy()
array([[-0.34115392, -0.69314718, -1.24115396],
[-1.24115396, -0.69314718, -0.34115392]], dtype=float32)


)code")
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>)
.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_log_softmax"})
.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
.add_argument("data", "NDArray-or-Symbol", "The input array.")
.add_arguments(SoftmaxParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_log_softmax)
.set_num_inputs(SoftmaxGradOpNumInputs)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
.set_attr<mxnet::FInferShape>("FInferShape", SoftmaxGradOpShape)
.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, mshadow_op::left,
mxnet_op::log_softmax_bwd>);

} // namespace op
} // namespace mxnet
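The forward registered here (mxnet_op::log_softmax_fwd) is mathematically log(softmax(x)). A minimal host-side C++ sketch of the numerically stable form, for illustration only:

#include <algorithm>
#include <cmath>
#include <vector>

// log_softmax over one row: out[i] = x[i] - max(x) - log(sum_j exp(x[j] - max(x)))
std::vector<float> log_softmax_row(const std::vector<float>& x) {
  const float xmax = *std::max_element(x.begin(), x.end());
  float sum = 0.0f;
  for (float v : x) sum += std::exp(v - xmax);
  const float logsum = std::log(sum);
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) out[i] = x[i] - xmax - logsum;
  return out;
}
// For x = {1, 2, 0.1} this yields approximately {-1.417, -0.417, -2.317},
// matching the first example in the operator description above.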
39 changes: 39 additions & 0 deletions src/operator/nn/log_softmax.cu
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2017 by Contributors
* \file log_softmax.cu
* \brief GPU Implementation of log_softmax
*/
#include "./softmax-inl.h"
#include "../tensor/elemwise_unary_op.h"

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(log_softmax)
.set_attr<FCompute>("FCompute<gpu>", SoftmaxCompute<gpu, mxnet_op::log_softmax_fwd>);

NNVM_REGISTER_OP(_backward_log_softmax)
.set_attr<FCompute>("FCompute<gpu>", SoftmaxGradCompute<gpu, mshadow_op::left,
mxnet_op::log_softmax_bwd>);

} // namespace op
} // namespace mxnet