This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Support mixed-precision true_divide #16711

Merged (1 commit) on Nov 3, 2019
36 changes: 36 additions & 0 deletions src/common/utils.h
@@ -842,6 +842,42 @@ inline bool is_float(const int dtype) {
return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16;
}

inline int more_precise_type(const int type1, const int type2) {
if (type1 == type2) return type1;
if (is_float(type1) && is_float(type2)) {
if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) {
return mshadow::kFloat64;
}
if (type1 == mshadow::kFloat32 || type2 == mshadow::kFloat32) {
return mshadow::kFloat32;
}
return mshadow::kFloat16;
} else if (is_float(type1) || is_float(type2)) {
return is_float(type1) ? type1 : type2;
}
if (type1 == mshadow::kInt64 || type2 == mshadow::kInt64) {
return mshadow::kInt64;
}
if (type1 == mshadow::kInt32 || type2 == mshadow::kInt32) {
return mshadow::kInt32;
}
CHECK(!((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
(type1 == mshadow::kInt8 && type2 == mshadow::kUint8)))
<< "Mixed UInt8 and Int8 operands should not reach this point";
if (type1 == mshadow::kUint8 || type2 == mshadow::kUint8) {
return mshadow::kUint8;
}
return mshadow::kInt8;
}

inline int np_binary_out_type(const int type1, const int type2) {
if ((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
(type1 == mshadow::kInt8 && type2 == mshadow::kUint8)) {
return mshadow::kInt32;
}
return more_precise_type(type1, type2);
}

} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_UTILS_H_
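
For intuition, a hypothetical test-style check of how the new promotion helpers resolve a few dtype pairs. The include path and bare main() are illustrative only; an actual check would be built inside MXNet's own test environment:

#include <cassert>
#include "common/utils.h"   // assumed include path; adjust to the MXNet source layout

int main() {
  using namespace mxnet::common;
  // Mixed integer/float: the floating-point side wins, even float16.
  assert(np_binary_out_type(mshadow::kInt32, mshadow::kFloat16) == mshadow::kFloat16);
  // Two integers: the wider type wins.
  assert(np_binary_out_type(mshadow::kInt64, mshadow::kInt32) == mshadow::kInt64);
  // uint8 with int8 is special-cased to int32 so neither range is truncated.
  assert(np_binary_out_type(mshadow::kUint8, mshadow::kInt8) == mshadow::kInt32);
  // Two floats: the more precise type wins.
  assert(more_precise_type(mshadow::kFloat32, mshadow::kFloat64) == mshadow::kFloat64);
  return 0;
}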
3 changes: 1 addition & 2 deletions src/operator/leaky_relu-inl.h
@@ -134,8 +134,7 @@ class LeakyReLUOp : public Operator {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, DType,
mshadow_op::xelu>, xpu>::
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, mshadow_op::xelu>, xpu>::
template LaunchEx(s, new_oshape.Size(), req[leakyrelu::kOut], lstride, rstride, oshape,
in_data[leakyrelu::kData].dptr<DType>(), in_data[leakyrelu::kGamma].dptr<DType>(),
out_data[leakyrelu::kOut].dptr<DType>());
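The call-site change above is mechanical: binary_broadcast_kernel is now templated on the operator alone, with the element types deduced from the pointers passed to LaunchEx, which is what permits mixed input and output types elsewhere in this PR. A stripped-down sketch of that pattern follows; the names are simplified stand-ins, not the actual MXNet kernel:

#include <cstddef>

// Kernel templated only on the operator; element types are deduced per call.
template<typename OP>
struct binary_kernel_sketch {
  template<typename LType, typename RType, typename OType>
  static void Launch(size_t n, OType *out, const LType *lhs, const RType *rhs) {
    for (size_t i = 0; i < n; ++i) {
      out[i] = OP::Map(lhs[i], rhs[i]);   // OP picks the right overload for LType/RType
    }
  }
};

struct div_to_float {
  template<typename L, typename R>
  static float Map(L a, R b) { return static_cast<float>(a) / static_cast<float>(b); }
};

// Usage (hypothetical buffers): binary_kernel_sketch<div_to_float>::Launch(n, out_f32, lhs_i32, rhs_f32);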
40 changes: 40 additions & 0 deletions src/operator/mshadow_op.h
@@ -132,6 +132,26 @@ struct true_divide : public mxnet_op::tunable {
MSHADOW_XINLINE static float Map(DType a, DType b) {
return static_cast<float>(a) / static_cast<float>(b);
}

#ifndef _WIN32
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return static_cast<mshadow::half::half_t>(a) / b;
}

template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return static_cast<float>(a) / b;
}

template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return static_cast<double>(a) / b;
}
#endif
};

struct rtrue_divide : public mxnet_op::tunable {
@@ -146,6 +166,26 @@ struct rtrue_divide : public mxnet_op::tunable {
MSHADOW_XINLINE static float Map(DType a, DType b) {
return static_cast<float>(b) / static_cast<float>(a);
}

#ifndef _WIN32
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return b / static_cast<mshadow::half::half_t>(a);
}

template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return b / static_cast<float>(a);
}

template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return b / static_cast<double>(a);
}
#endif
};

MXNET_BINARY_MATH_OP_NC(left, a);
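A minimal standalone sketch of the enable_if dispatch used in the overloads above, with half_t left out and the _WIN32 guard omitted; it only illustrates how an integral operand is promoted to the floating-point type of the other operand:

#include <cassert>
#include <type_traits>

struct true_divide_sketch {
  // Both operands integral: promote to float, mirroring the existing Map above.
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  static float Map(DType a, DType b) {
    return static_cast<float>(a) / static_cast<float>(b);
  }
  // Integral divided by double: the result follows the wider floating-point operand.
  template<typename DType,
           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
  static double Map(DType a, double b) {
    return static_cast<double>(a) / b;
  }
};

int main() {
  assert(true_divide_sketch::Map(3, 4) == 0.75f);    // int / int    -> float
  assert(true_divide_sketch::Map(1, 8.0) == 0.125);  // int / double -> double
  return 0;
}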
113 changes: 113 additions & 0 deletions src/operator/mxnet_op.h
@@ -471,6 +471,69 @@ struct AccType<mshadow::half::half_t> {
{__VA_ARGS__} \
} \
break; \
case mshadow::kBool: \
{ \
typedef bool DType; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}

#define MXNET_INT32_INT64_TYPE_SWITCH(type, DType, ...)\
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float32"; \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float64"; \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float16"; \
} \
break; \
case mshadow::kUint8: \
{ \
LOG(FATAL) << "This operation only support " \
"integer types, not uint8"; \
} \
break; \
case mshadow::kInt8: \
{ \
LOG(FATAL) << "This operation only support " \
"integer types, not int8"; \
} \
break; \
case mshadow::kInt32: \
{ \
typedef int32_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt64: \
{ \
typedef int64_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kBool: \
{ \
LOG(FATAL) << "This operation only support " \
"integer types, not bool"; \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}
@@ -783,6 +846,56 @@ struct op_with_req {
KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value));
}

#ifndef _WIN32
/*! \brief inputs are two tensors with a half_t output tensor */
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i,
mshadow::half::half_t *out,
const DType *lhs,
const mshadow::half::half_t *rhs) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
}

/*! \brief inputs are two tensors with a float output tensor */
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float *rhs) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
}

/*! \brief inputs are two tensors with a double output tensor */
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double *rhs) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
}

/*! \brief inputs are a tensor and a scalar value with a half_t output tensor */
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i,
mshadow::half::half_t *out,
const DType *lhs,
const mshadow::half::half_t value) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
}

/*! \brief inputs are a tensor and a scalar value with a float output tensor */
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float value) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
}

/*! \brief inputs are a tensor and a scalar value with a double output tensor */
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double value) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
}
#endif

/*! \brief inputs are two tensors with a float output tensor */
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
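The new MXNET_INT32_INT64_TYPE_SWITCH is what the softmax changes below rely on: it binds the requested name to int32_t or int64_t and fails loudly for every other dtype. A simplified, self-contained analogue of how such a switch is used; the enum values and macro name here are stand-ins rather than mshadow's:

#include <cstdint>
#include <cstdio>
#include <cstdlib>

enum TypeFlag { kInt32Flag, kInt64Flag };   // stand-ins for mshadow's dtype enum

#define INT32_INT64_TYPE_SWITCH(type, DType, ...)                             \
  switch (type) {                                                             \
    case kInt32Flag: { typedef int32_t DType; {__VA_ARGS__} } break;          \
    case kInt64Flag: { typedef int64_t DType; {__VA_ARGS__} } break;          \
    default: std::fprintf(stderr, "integer dtype required\n"); std::abort();  \
  }

int main() {
  int type = kInt64Flag;
  INT32_INT64_TYPE_SWITCH(type, IType, {
    IType length = 42;                     // IType is int64_t in this branch
    std::printf("length = %lld, sizeof(IType) = %zu\n",
                static_cast<long long>(length), sizeof(IType));
  });
  return 0;
}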
6 changes: 2 additions & 4 deletions src/operator/nn/dropout-inl.h
@@ -394,8 +394,7 @@ class DropoutOp {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, DType,
mshadow_op::mul>, xpu>::
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, mshadow_op::mul>, xpu>::
template LaunchEx(s, new_oshape.Size(), req[dropout::kOut],
lstride, rstride, oshape,
in.dptr<DType>(),
@@ -463,8 +462,7 @@ class DropoutOp {
mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, DType, DType,
mshadow_op::mul>, xpu>::
mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, mshadow_op::mul>, xpu>::
template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
grad.dptr<DType>(), mask.dptr<DType>(), gdata.dptr<DType>());
});
6 changes: 3 additions & 3 deletions src/operator/nn/softmax-inl.h
@@ -790,7 +790,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
<< "Mask needs to be provided when using softmax with use_length=True.";
type = inputs[1].type_flag_;
}
MXNET_INT_TYPE_SWITCH(type, IType, {
MXNET_INT32_INT64_TYPE_SWITCH(type, IType, {
IType* mask_ptr = nullptr;
if (param.use_length.value()) {
mask_ptr = inputs[1].dptr<IType>();
@@ -834,7 +834,7 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
if (softmax_use_length(attrs)) {
MXNET_INT_TYPE_SWITCH(inputs[2].type_flag_, IType, {
MXNET_INT32_INT64_TYPE_SWITCH(inputs[2].type_flag_, IType, {
if (req[1] != kNullOp) {
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
ctx.get_stream<xpu>(), outputs[1].Size(), outputs[1].dptr<IType>());
@@ -856,7 +856,7 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
MXNET_INT_TYPE_SWITCH(itype, IType, {
MXNET_INT32_INT64_TYPE_SWITCH(itype, IType, {
IType * length_ptr = nullptr;
if (softmax_use_length(attrs)) {
length_ptr = inputs[2].dptr<IType>();