fp16 safe norm operator (apache#14616)
* use safe aggregation for norm

* safe norm with DataType, AccuType and OutType

* new test for backward

* change back to MSHADOW_TYPE_SWITCH

* remove dead debug outputs

* Allow integer types
haojin2 authored and haohuw committed Jun 23, 2019
1 parent 9ae2674 commit 47be295
Showing 8 changed files with 430 additions and 148 deletions.
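
The commit's core idea is to accumulate the norm reduction in a wider type (AType) than the element type, so that fp16 inputs are summed in fp32. A rough standalone sketch of why the accumulator width matters, not MXNet code: float and double stand in for fp16 and fp32 (standard C++ has no half type), and the loop sums 1e8 squared terms of 1e-4, whose exact total is 1.0.

#include <cstdio>

int main() {
  const int n = 100000000;     // 1e8 terms; exact sum of squares is 1.0
  const float term = 1e-4f;    // each squared term is about 1e-8

  float  narrow = 0.0f;        // accumulate in the input's own precision
  double wide   = 0.0;         // accumulate in a wider type, as the AType path does
  for (int i = 0; i < n; ++i) {
    narrow += term * term;                          // additions eventually round to nothing
    wide   += static_cast<double>(term) * term;
  }
  std::printf("narrow accumulator: %g\n", narrow);  // stalls around 0.25 on IEEE-754 floats
  std::printf("wide accumulator:   %g\n", wide);    // close to 1.0
  return 0;
}
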
68 changes: 64 additions & 4 deletions src/operator/mshadow_op.h
@@ -945,13 +945,13 @@ struct nanprod {
/*! \brief compute l2 norm */
struct nrm2 {
/*! \brief do reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src) { // NOLINT(*)
template<typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares, volatile DType src) { // NOLINT(*)
sum_of_squares += src * src;
}
/*! \brief do stable reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src, volatile DType& scale) { // NOLINT(*)
template<typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares, volatile DType src, volatile DType& scale) { // NOLINT(*)
if (src != 0) {
DType abs = mshadow_op::abs::Map(src);
if (scale < abs) {
@@ -1012,6 +1012,66 @@ struct nrm2 {
}
};

/*! \brief sum reducer */
struct sum {
/*! \brief do reduction into dst */
template<typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src) { // NOLINT(*)
dst += src;
}
/*! \brief do stable reduction into dst */
template<typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*)
DType y = src - residual;
DType t = dst + y;
residual = (t - dst) - y;
dst = t;
}
/*! \brief combine the results of two reducers */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief combine the results of two reducers */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
DType t1 = dst_val + src_val;
DType e = t1 - dst_val;
DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
dst_val = t1 + t2;
dst_residual = t2 - (dst_val - t1);
}
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element
*/
template<typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return 1;
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
initv = 0;
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*)
SetInitValue(initv);
residual = 0;
}
};

struct nanprod_grad : public mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
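
The stable Reduce overload of the new sum reducer above is compensated (Kahan) summation: residual carries the rounding error of the previous addition and is folded back into the next one. A minimal standalone sketch of the same update on a plain float accumulator, using a hypothetical kahan_reduce helper rather than MXNet's reducer interface:

#include <cstdio>

// One compensated step: add src into dst while tracking the rounding error in residual.
// (Build without -ffast-math, which is free to optimize the compensation away.)
inline void kahan_reduce(float &dst, float src, float &residual) {
  float y = src - residual;   // re-inject the error left over from the previous step
  float t = dst + y;          // low-order bits of y can be lost here...
  residual = (t - dst) - y;   // ...and are recovered algebraically
  dst = t;
}

int main() {
  const int n = 100000000;    // 1e8 terms of 1e-8; exact sum is 1.0
  const float term = 1e-8f;

  float naive = 0.0f;
  float kahan = 0.0f, residual = 0.0f;
  for (int i = 0; i < n; ++i) {
    naive += term;
    kahan_reduce(kahan, term, residual);
  }
  std::printf("naive: %g\n", naive);   // stalls far below 1.0
  std::printf("kahan: %g\n", kahan);   // close to 1.0
  return 0;
}
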
83 changes: 75 additions & 8 deletions src/operator/mxnet_op.h
@@ -273,20 +273,87 @@ inline int get_num_threads<cpu>(const int N) {
} \
break; \
case mshadow::kUint8: \
LOG(FATAL) << "This operation only support " \
"floating point types not uint8"; \
{ \
typedef uint8_t DType; \
typedef uint8_t AType; \
LOG(FATAL) << "This operation only support " \
"floating point types not uint8"; \
} \
break; \
case mshadow::kInt8: \
{ \
typedef int8_t DType; \
typedef int8_t AType; \
LOG(FATAL) << "This operation only support " \
"floating point types not int8"; \
} \
break; \
case mshadow::kInt32: \
{ \
typedef int32_t DType; \
typedef int32_t AType; \
LOG(FATAL) << "This operation only support " \
"floating point types, not int32"; \
} \
break; \
case mshadow::kInt64: \
{ \
typedef int64_t DType; \
typedef int64_t AType; \
LOG(FATAL) << "This operation only support " \
"floating point types, not int64"; \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}

#define MXNET_ACC_TYPE_SWITCH(type, DType, AType, ...)\
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
typedef double AType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
typedef double AType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
typedef float AType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
{ \
typedef uint8_t DType; \
typedef uint32_t AType; \
} \
break; \
case mshadow::kInt8: \
LOG(FATAL) << "This operation only support " \
"floating point types not int8"; \
{ \
typedef int8_t DType; \
typedef int32_t AType; \
} \
break; \
case mshadow::kInt32: \
LOG(FATAL) << "This operation only support " \
"floating point types, not int32"; \
{ \
typedef int32_t DType; \
typedef int64_t AType; \
} \
break; \
case mshadow::kInt64: \
LOG(FATAL) << "This operation only support " \
"floating point types, not int64"; \
{ \
typedef int64_t DType; \
typedef int64_t AType; \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
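
The new MXNET_ACC_TYPE_SWITCH maps a runtime dtype flag to a pair of typedefs, the element type DType and a widened accumulation type AType (half_t to float, float to double, uint8_t to uint32_t, and so on), and compiles the supplied body against those typedefs. Below is a self-contained miniature of that pattern; DEMO_ACC_TYPE_SWITCH and TypeFlag are illustrative stand-ins, not the MXNet macro or enum:

#include <cstdint>
#include <cstdio>

enum TypeFlag { kFloat32, kFloat64, kUint8 };   // illustrative subset of dtype flags

#define DEMO_ACC_TYPE_SWITCH(type, DType, AType, ...)  \
  switch (type) {                                      \
    case kFloat32: {                                   \
      typedef float DType;                             \
      typedef double AType;                            \
      { __VA_ARGS__ }                                  \
    } break;                                           \
    case kFloat64: {                                   \
      typedef double DType;                            \
      typedef double AType;                            \
      { __VA_ARGS__ }                                  \
    } break;                                           \
    case kUint8: {                                     \
      typedef uint8_t DType;                           \
      typedef uint32_t AType;                          \
      { __VA_ARGS__ }                                  \
    } break;                                           \
  }

int main() {
  int flag = kFloat32;   // in MXNet this would come from a tensor's dtype at run time
  DEMO_ACC_TYPE_SWITCH(flag, DType, AType, {
    std::printf("element bytes: %zu, accumulator bytes: %zu\n",
                sizeof(DType), sizeof(AType));
  });
  return 0;
}
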
61 changes: 36 additions & 25 deletions src/operator/tensor/broadcast_reduce-inl.cuh
@@ -72,15 +72,15 @@ void BinaryBroadcastComputeImpl(Stream<gpu> *s, const OpReqType req,
}

const int nthread_reduce = kMaxThreadsPerBlock;
template<typename Reducer, int ndim, typename DType, typename OP, int unroll>
template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP, int unroll>
__launch_bounds__(nthread_reduce)
__global__ void reduce_kernel(const int N, const int M, const bool addto,
const DType* __restrict big, DType *small,
const DType* __restrict big, OType *small,
const Shape<ndim> big_shape0, const Shape<ndim> small_shape,
const Shape<ndim> big_shape, const Shape<ndim> big_stride,
const int Mnext, const bool do_transpose) {
extern __shared__ char shTileChar[];
DType* shTile = (DType*)(shTileChar);
AType* shTile = (AType*)(shTileChar);
const int tid = threadIdx.x + threadIdx.y*blockDim.x;
const int bx = (do_transpose) ? blockDim.y : blockDim.x;
const int by = (do_transpose) ? blockDim.x : blockDim.y;
@@ -95,7 +95,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
Shape<ndim> coord = unravel(idx, small_shape);
int idx_big0 = ravel(coord, big_shape0);

DType val, residual;
AType val, residual;
Reducer::SetInitValue(val, residual);
if (idx < N) {
for (int k = tidy + Mstart; k < Mend; k += by*unroll) {
@@ -113,7 +113,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
}
#pragma unroll
for (int u=0;u < unroll;u++) {
if (k + u*by < Mend) Reducer::Reduce(val, tmp[u], residual);
if (k + u*by < Mend) Reducer::Reduce(val, AType(tmp[u]), residual);
}
}
}
@@ -127,7 +127,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
shTile[it0 * 2 + 1] = residual;
__syncthreads();
for (int t=1;t < by;t <<= 1) {
DType tmp, tmp_residual;
AType tmp, tmp_residual;
Reducer::SetInitValue(tmp, tmp_residual);
if (tidy + t < by) {
tmp = shTile[(it0 + t*fbx) * 2];
@@ -139,12 +139,12 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
}
if (idx < N && tidy == 0) {
Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
assign(&small[idx + m0*N], addto, shTile[tidx * 2]);
assign(&small[idx + m0*N], addto, OType(shTile[tidx * 2]));
}
} else {
if (idx < N) {
Reducer::Finalize(val, residual);
assign(&small[idx + m0*N], addto, val);
assign(&small[idx + m0*N], addto, OType(val));
}
}
}
@@ -261,18 +261,18 @@ __global__ void reduce_lines_kernel(const int N, const int M, const bool addto,
}
}

template<typename Reducer, int ndim, typename DType, typename OP>
template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP>
__global__ void reduce_kernel_M1(const int N, const bool addto,
const DType* __restrict big, DType *small, const Shape<ndim> bshape,
const DType* __restrict big, OType *small, const Shape<ndim> bshape,
const Shape<ndim> sshape) {
for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
Shape<ndim> coord = unravel(idx, sshape);
int j = ravel(coord, bshape);
DType val, residual;
AType val, residual;
Reducer::SetInitValue(val, residual);
Reducer::Reduce(val, OP::Map(big[j]), residual);
Reducer::Reduce(val, AType(OP::Map(big[j])), residual);
Reducer::Finalize(val, residual);
assign(&small[idx], addto, val);
assign(&small[idx], addto, OType(val));
}
}
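
reduce_kernel_M1 now reads each element as DType, accumulates in AType, and casts to OType only when the result is stored. A plain C++ sketch of that three-type pipeline, using a hypothetical reduce_row helper rather than the CUDA kernel:

#include <cstdio>

// Elements are read as DType, accumulated in AType, and only the final result is
// cast to OType; this mirrors the kernel's inner loop and final assign.
template <typename AType, typename DType, typename OType>
void reduce_row(const DType *big, int M, OType *out) {
  AType val = AType(0);
  for (int k = 0; k < M; ++k) {
    val += AType(big[k]);   // cast each element into the accumulation type
  }
  *out = OType(val);        // single narrowing cast on the final store
}

int main() {
  float data[4] = {1.f, 2.f, 3.f, 4.f};
  float result = 0.f;
  reduce_row<double, float, float>(data, 4, &result);
  std::printf("%g\n", result);   // 10
  return 0;
}
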

@@ -491,7 +491,7 @@ ReduceImplConfig<ndim> ConfigureReduceImpl(const mxnet::TShape& small, const mxn

if (config.Mnext > 1) {
// small_dptr[] is N*Mnext*sizeof(DType) bytes
config.workspace_size += config.N*config.Mnext*sizeof(DType);
config.workspace_size += config.N*config.Mnext*sizeof(double);
// Set gridDim.y to Mnext
config.kernel_1.gridDim.y = std::min(kBaseGridNum, config.Mnext);
}
@@ -516,23 +516,22 @@ ReduceImplConfig<ndim> ConfigureReduceImpl(const mxnet::TShape& small, const mxn
{__VA_ARGS__} \
}

template<typename Reducer, int ndim, typename DType, typename OP>
template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP>
void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req,
const TBlob& big, const Tensor<gpu, 1, char>& workspace,
const ReduceImplConfig<ndim>& config) {
if (config.M == 1) {
reduce_kernel_M1<Reducer, ndim, DType, OP>
reduce_kernel_M1<Reducer, ndim, AType, DType, OType, OP>
<<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>(
config.N, req == kAddTo, big.dptr<DType>(), small.dptr<DType>(), big.shape_.get<ndim>(),
config.N, req == kAddTo, big.dptr<DType>(), small.dptr<OType>(), big.shape_.get<ndim>(),
small.shape_.get<ndim>());
MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1);
} else {

DType* small_dptr = small.dptr<DType>();
OType* small_dptr = small.dptr<OType>();
bool addto = (req == kAddTo);
if (config.Mnext > 1) {
// small_dptr[] is N*Mnext*sizeof(DType) bytes
small_dptr = reinterpret_cast<DType*>(workspace.dptr_);
small_dptr = reinterpret_cast<OType*>(workspace.dptr_);
addto = false;
// Check that the workspace is contigiuous
CHECK_EQ(workspace.CheckContiguous(), true);
@@ -544,7 +543,7 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req,
config.kernel_1.blockDim.x : config.kernel_1.blockDim.y;
const bool do_unroll = ( config.M / (by*config.Mnext) >= config.unroll_reduce );
KERNEL_UNROLL_SWITCH(do_unroll, ReduceImplConfig<ndim>::unroll_reduce, UNROLL, {
reduce_kernel<Reducer, ndim, DType, OP, UNROLL>
reduce_kernel<Reducer, ndim, AType, DType, OType, OP, UNROLL>
<<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>(
config.N, config.M, addto, big.dptr<DType>(), small_dptr, big.shape_.get<ndim>(),
small.shape_.get<ndim>(), config.rshape, config.rstride, config.Mnext,
Expand All @@ -553,9 +552,9 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req,
MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel);

if (config.Mnext > 1) {
reduce_lines_kernel<Reducer, DType>
reduce_lines_kernel<Reducer, OType>
<<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>>
(config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr<DType>());
(config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr<OType>());
MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel);
}
}
@@ -610,14 +609,26 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const

#undef KERNEL_UNROLL_SWITCH

template<typename Reducer, int ndim, typename DType, typename OP>
template<typename Reducer, int ndim, typename DType, typename OP, bool safe_acc = false>
void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
const Tensor<gpu, 1, char>& workspace, const TBlob& big) {
if (req == kNullOp) return;
cudaStream_t stream = Stream<gpu>::GetStream(s);
ReduceImplConfig<ndim> config =
ConfigureReduceImpl<ndim, DType>(small.shape_, big.shape_, NULL, NULL);
ReduceImpl<Reducer, ndim, DType, OP>(stream, small, req, big, workspace, config);
if (safe_acc) {
MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
config = ConfigureReduceImpl<ndim, AccType>(small.shape_, big.shape_, NULL, NULL);
ReduceImpl<Reducer, ndim, AccType, DataType, OutType, OP>(
stream, small, req, big, workspace, config);
});
});
} else {
ReduceImpl<Reducer, ndim, DType, DType, DType, OP>(stream, small, req, big, workspace, config);
}
}

template <typename Reducer, int ndim, typename DType, typename OP>
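
The new Reduce overload takes a safe_acc template flag and uses std::conditional to pick the accumulation and output types at compile time, falling back to the old single-type ReduceImpl when safe_acc is false. A standalone sketch of that compile-time selection, with a hypothetical sum_squares function in place of the MXNet reduction:

#include <cstdio>
#include <type_traits>

// With safe_acc=true the accumulator is double, otherwise it is DType itself;
// the real code picks the widened type per input dtype via MXNET_ACC_TYPE_SWITCH.
template <typename DType, bool safe_acc>
DType sum_squares(const DType *x, int n) {
  using AccType = typename std::conditional<safe_acc, double, DType>::type;
  AccType acc = AccType(0);
  for (int i = 0; i < n; ++i) {
    acc += AccType(x[i]) * AccType(x[i]);
  }
  return DType(acc);   // cast back to the element type on the way out
}

int main() {
  float x[3] = {3.f, 4.f, 12.f};
  std::printf("safe_acc=false: %g\n", sum_squares<float, false>(x, 3));   // 169
  std::printf("safe_acc=true:  %g\n", sum_squares<float, true>(x, 3));    // 169
  return 0;
}
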
(The remaining 5 changed files are not rendered in this view.)
