support mixed-precision true_divide (#16711)
haojin2 authored and reminisce committed Nov 3, 2019
1 parent 27a8fd5 commit e139442
Showing 12 changed files with 549 additions and 75 deletions.
36 changes: 36 additions & 0 deletions src/common/utils.h
@@ -842,6 +842,42 @@ inline bool is_float(const int dtype) {
return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16;
}

inline int more_precise_type(const int type1, const int type2) {
if (type1 == type2) return type1;
if (is_float(type1) && is_float(type2)) {
if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) {
return mshadow::kFloat64;
}
if (type1 == mshadow::kFloat32 || type2 == mshadow::kFloat32) {
return mshadow::kFloat32;
}
return mshadow::kFloat16;
} else if (is_float(type1) || is_float(type2)) {
return is_float(type1) ? type1 : type2;
}
if (type1 == mshadow::kInt64 || type2 == mshadow::kInt64) {
return mshadow::kInt64;
}
if (type1 == mshadow::kInt32 || type2 == mshadow::kInt32) {
return mshadow::kInt32;
}
CHECK(!((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
(type1 == mshadow::kInt8 && type2 == mshadow::kUint8)))
<< "1 is UInt8 and 1 is Int8 should not get here";
if (type1 == mshadow::kUint8 || type2 == mshadow::kUint8) {
return mshadow::kUint8;
}
return mshadow::kInt8;
}

inline int np_binary_out_type(const int type1, const int type2) {
if ((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
(type1 == mshadow::kInt8 && type2 == mshadow::kUint8)) {
return mshadow::kInt32;
}
return more_precise_type(type1, type2);
}

} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_UTILS_H_
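
The two helpers above implement NumPy-style type promotion for the mixed-precision binary ops. The following standalone sketch mirrors their logic with a local enum so the promotion table can be checked without building MXNet; the enum values are assumed to follow mshadow's ordering, and all names here are illustrative, not the real header.

#include <cassert>

// Illustration only: a local enum stands in for mshadow's dtype flags.
enum TypeFlag { kFloat32 = 0, kFloat64 = 1, kFloat16 = 2, kUint8 = 3,
                kInt32 = 4, kInt8 = 5, kInt64 = 6 };

bool is_float(int t) { return t == kFloat32 || t == kFloat64 || t == kFloat16; }

int more_precise_type(int t1, int t2) {
  if (t1 == t2) return t1;
  if (is_float(t1) && is_float(t2)) {                        // both float: widest wins
    if (t1 == kFloat64 || t2 == kFloat64) return kFloat64;
    if (t1 == kFloat32 || t2 == kFloat32) return kFloat32;
    return kFloat16;
  }
  if (is_float(t1) || is_float(t2)) return is_float(t1) ? t1 : t2;  // float beats integer
  if (t1 == kInt64 || t2 == kInt64) return kInt64;           // both integer: widest wins
  if (t1 == kInt32 || t2 == kInt32) return kInt32;
  // The real helper CHECK-fails on a mixed int8/uint8 pair; np_binary_out_type
  // intercepts that case before ever calling it.
  return (t1 == kUint8 || t2 == kUint8) ? kUint8 : kInt8;
}

int np_binary_out_type(int t1, int t2) {
  // int8 mixed with uint8 has no common 8-bit type, so it widens to int32.
  if ((t1 == kUint8 && t2 == kInt8) || (t1 == kInt8 && t2 == kUint8)) return kInt32;
  return more_precise_type(t1, t2);
}

int main() {
  assert(np_binary_out_type(kInt8, kUint8) == kInt32);       // mixed signedness widens
  assert(np_binary_out_type(kInt32, kFloat16) == kFloat16);  // float side wins
  assert(np_binary_out_type(kInt64, kInt32) == kInt64);      // wider integer wins
  return 0;
}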
3 changes: 1 addition & 2 deletions src/operator/leaky_relu-inl.h
@@ -134,8 +134,7 @@ class LeakyReLUOp : public Operator {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, DType,
mshadow_op::xelu>, xpu>::
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, mshadow_op::xelu>, xpu>::
template LaunchEx(s, new_oshape.Size(), req[leakyrelu::kOut], lstride, rstride, oshape,
in_data[leakyrelu::kData].dptr<DType>(), in_data[leakyrelu::kGamma].dptr<DType>(),
out_data[leakyrelu::kOut].dptr<DType>());
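
The only change here is the kernel's template signature: binary_broadcast_kernel no longer takes the two DType parameters, so the input and output element types are picked up from the pointer arguments and OP::Map can resolve to a mixed-type overload. The same change appears in dropout-inl.h below. A minimal standalone miniature of the idea (hypothetical names, no broadcasting or NDim machinery):

#include <cstdio>

struct mul_sketch {                                    // stand-in for an mshadow_op functor
  template <typename T> static T Map(T a, T b) { return a * b; }
  static float Map(int a, float b) { return static_cast<float>(a) * b; }  // mixed overload
};

// Before: the launcher pinned both element types as template parameters.
// After:  only the functor is explicit; element types are deduced from the pointers.
template <typename OP, typename LType, typename RType, typename OType>
void launch(int n, OType *out, const LType *lhs, const RType *rhs) {
  for (int i = 0; i < n; ++i) out[i] = OP::Map(lhs[i], rhs[i]);
}

int main() {
  const int   lhs[2] = {2, 3};
  const float rhs[2] = {0.5f, 1.5f};
  float out[2];
  launch<mul_sketch>(2, out, lhs, rhs);                // int * float -> float elementwise
  printf("%f %f\n", out[0], out[1]);                   // 1.000000 4.500000
  return 0;
}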
40 changes: 40 additions & 0 deletions src/operator/mshadow_op.h
@@ -132,6 +132,26 @@ struct true_divide : public mxnet_op::tunable {
MSHADOW_XINLINE static float Map(DType a, DType b) {
return static_cast<float>(a) / static_cast<float>(b);
}

#ifndef _WIN32
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return static_cast<mshadow::half::half_t>(a) / b;
}

template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return static_cast<float>(a) / b;
}

template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return static_cast<double>(a) / b;
}
#endif
};

struct rtrue_divide : public mxnet_op::tunable {
@@ -146,6 +166,26 @@ struct rtrue_divide : public mxnet_op::tunable {
MSHADOW_XINLINE static float Map(DType a, DType b) {
return static_cast<float>(b) / static_cast<float>(a);
}

#ifndef _WIN32
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return b / static_cast<mshadow::half::half_t>(a);
}

template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return b / static_cast<float>(a);
}

template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return b / static_cast<double>(a);
}
#endif
};

MXNET_BINARY_MATH_OP_NC(left, a);
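
Every new overload above follows the same rule: the integral operand is cast to the floating-point type of the other operand before dividing, so the quotient keeps floating-point precision instead of truncating (the half_t overloads do the same with mshadow's half type). A small standalone sketch of that rule, using a hypothetical helper rather than the mshadow_op struct itself:

#include <cstdio>
#include <type_traits>

template <typename IType, typename FType,
          typename std::enable_if<std::is_integral<IType>::value, int>::type = 0>
FType true_divide_sketch(IType a, FType b) {
  return static_cast<FType>(a) / b;                    // no integer truncation
}

int main() {
  printf("%f\n", true_divide_sketch(7, 2.0f));         // 3.500000 (float arithmetic)
  printf("%f\n", true_divide_sketch(7, 2.0));          // 3.500000 (double arithmetic)
  return 0;
}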
113 changes: 113 additions & 0 deletions src/operator/mxnet_op.h
@@ -471,6 +471,69 @@ struct AccType<mshadow::half::half_t> {
{__VA_ARGS__} \
} \
break; \
case mshadow::kBool: \
{ \
typedef bool DType; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}

#define MXNET_INT32_INT64_TYPE_SWITCH(type, DType, ...)\
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float32"; \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float64"; \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float16"; \
} \
break; \
case mshadow::kUint8: \
{ \
LOG(FATAL) << "This operation only support " \
"integer types, not uint8"; \
} \
break; \
case mshadow::kInt8: \
{ \
LOG(FATAL) << "This operation only support " \
"integer types, not int8"; \
} \
break; \
case mshadow::kInt32: \
{ \
typedef int32_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt64: \
{ \
typedef int64_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kBool: \
{ \
LOG(FATAL) << "This operation only support " \
"integer types, not bool"; \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}
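
For readers unfamiliar with these type-switch macros, the following self-contained miniature shows the dispatch pattern the softmax code below relies on: a runtime dtype flag selects a compile-time typedef, and anything outside int32/int64 aborts. Names and flag values here are assumptions for illustration, not the real macro.

#include <cstdint>
#include <cstdio>
#include <cstdlib>

enum { kInt32Flag = 4, kInt64Flag = 6 };               // assumed to mirror mshadow's dtype enum

#define INT32_INT64_SWITCH(type, DType, ...)                          \
  switch (type) {                                                     \
    case kInt32Flag: { typedef int32_t DType; __VA_ARGS__ } break;    \
    case kInt64Flag: { typedef int64_t DType; __VA_ARGS__ } break;    \
    default: fprintf(stderr, "integer types only\n"); exit(1);        \
  }

int main() {
  int flag = kInt64Flag;                               // e.g. the dtype of a length/mask tensor
  INT32_INT64_SWITCH(flag, IType, {
    printf("sizeof(IType) = %zu\n", sizeof(IType));    // prints 8 for the int64 flag
  });
  return 0;
}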
@@ -783,6 +846,56 @@ struct op_with_req {
KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value));
}

#ifndef _WIN32
/*! \brief inputs are two tensors with a half_t output tensor */
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i,
mshadow::half::half_t *out,
const DType *lhs,
const mshadow::half::half_t *rhs) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
}

/*! \brief inputs are two tensors with a float output tensor */
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float *rhs) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
}

/*! \brief inputs are two tensors with a double output tensor */
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double *rhs) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
}

/*! \brief inputs are a tensor and a half_t scalar with a half_t output tensor */
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i,
mshadow::half::half_t *out,
const DType *lhs,
const mshadow::half::half_t value) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
}

/*! \brief inputs are a tensor and a float scalar with a float output tensor */
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float value) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
}

/*! \brief inputs are a tensor and a double scalar with a double output tensor */
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double value) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
}
#endif

/*! \brief inputs are two tensors with a float output tensor */
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
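
The op_with_req additions come in two flavours, tensor-with-tensor and tensor-with-scalar, each writing an integral input into a wider floating-point output. A standalone miniature of the tensor-with-scalar flavour (hypothetical names, a plain assignment standing in for KERNEL_ASSIGN with req == kWriteTo):

#include <cstdio>

struct true_divide_sketch {                            // stand-in for mshadow_op::true_divide
  static float Map(int a, float b) { return static_cast<float>(a) / b; }
};

// Mirrors Map(index_t i, float *out, const DType *lhs, const float value).
template <typename OP>
void launch_scalar(int n, float *out, const int *lhs, float value) {
  for (int i = 0; i < n; ++i) out[i] = OP::Map(lhs[i], value);
}

int main() {
  const int lhs[3] = {1, 2, 3};
  float out[3];
  launch_scalar<true_divide_sketch>(3, out, lhs, 2.0f);
  printf("%f %f %f\n", out[0], out[1], out[2]);        // 0.500000 1.000000 1.500000
  return 0;
}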
6 changes: 2 additions & 4 deletions src/operator/nn/dropout-inl.h
@@ -394,8 +394,7 @@ class DropoutOp {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, DType,
mshadow_op::mul>, xpu>::
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, mshadow_op::mul>, xpu>::
template LaunchEx(s, new_oshape.Size(), req[dropout::kOut],
lstride, rstride, oshape,
in.dptr<DType>(),
@@ -463,8 +462,7 @@ class DropoutOp {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, DType,
mshadow_op::mul>, xpu>::
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, mshadow_op::mul>, xpu>::
template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
grad.dptr<DType>(), mask.dptr<DType>(), gdata.dptr<DType>());
});
6 changes: 3 additions & 3 deletions src/operator/nn/softmax-inl.h
@@ -790,7 +790,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
<< "Mask needs to be provided when using softmax with use_length=True.";
type = inputs[1].type_flag_;
}
MXNET_INT_TYPE_SWITCH(type, IType, {
MXNET_INT32_INT64_TYPE_SWITCH(type, IType, {
IType* mask_ptr = nullptr;
if (param.use_length.value()) {
mask_ptr = inputs[1].dptr<IType>();
@@ -834,7 +834,7 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
if (softmax_use_length(attrs)) {
MXNET_INT_TYPE_SWITCH(inputs[2].type_flag_, IType, {
MXNET_INT32_INT64_TYPE_SWITCH(inputs[2].type_flag_, IType, {
if (req[1] != kNullOp) {
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
ctx.get_stream<xpu>(), outputs[1].Size(), outputs[1].dptr<IType>());
@@ -856,7 +856,7 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
MXNET_INT_TYPE_SWITCH(itype, IType, {
MXNET_INT32_INT64_TYPE_SWITCH(itype, IType, {
IType * length_ptr = nullptr;
if (softmax_use_length(attrs)) {
length_ptr = inputs[2].dptr<IType>();
