From 727b84d7b0c7d6b64ae01c1aee1fd79dccce23ce Mon Sep 17 00:00:00 2001
From: Hao Jin
Date: Fri, 1 Nov 2019 08:54:04 +0000
Subject: [PATCH] cut some integers in softmax

---
 src/operator/mxnet_op.h       | 57 +++++++++++++++++++++++++++++++++++
 src/operator/nn/softmax-inl.h |  6 ++--
 2 files changed, 60 insertions(+), 3 deletions(-)

diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 468517d0fb4e..6d97bb65e9a5 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -481,6 +481,63 @@ struct AccType {
     LOG(FATAL) << "Unknown type enum " << type;    \
   }
 
+#define MXNET_INT32_INT64_TYPE_SWITCH(type, DType, ...)\
+  switch (type) {                                  \
+  case mshadow::kFloat32:                          \
+    {                                              \
+      typedef float DType;                         \
+      LOG(FATAL) << "This operation only support " \
+                    "integer types, not float32";  \
+    }                                              \
+    break;                                         \
+  case mshadow::kFloat64:                          \
+    {                                              \
+      typedef double DType;                        \
+      LOG(FATAL) << "This operation only support " \
+                    "integer types, not float64";  \
+    }                                              \
+    break;                                         \
+  case mshadow::kFloat16:                          \
+    {                                              \
+      typedef mshadow::half::half_t DType;         \
+      LOG(FATAL) << "This operation only support " \
+                    "integer types, not float16";  \
+    }                                              \
+    break;                                         \
+  case mshadow::kUint8:                            \
+    {                                              \
+      LOG(FATAL) << "This operation only support " \
+                    "integer types, not uint8";    \
+    }                                              \
+    break;                                         \
+  case mshadow::kInt8:                             \
+    {                                              \
+      LOG(FATAL) << "This operation only support " \
+                    "integer types, not int8";     \
+    }                                              \
+    break;                                         \
+  case mshadow::kInt32:                            \
+    {                                              \
+      typedef int32_t DType;                       \
+      {__VA_ARGS__}                                \
+    }                                              \
+    break;                                         \
+  case mshadow::kInt64:                            \
+    {                                              \
+      typedef int64_t DType;                       \
+      {__VA_ARGS__}                                \
+    }                                              \
+    break;                                         \
+  case mshadow::kBool:                             \
+    {                                              \
+      LOG(FATAL) << "This operation only support " \
+                    "integer types, not bool";     \
+    }                                              \
+    break;                                         \
+  default:                                         \
+    LOG(FATAL) << "Unknown type enum " << type;    \
+  }
+
 #define MXNET_LOAD_TYPE_SWITCH(type, DType, ...)   \
   switch (type) {                                  \
   case mshadow::kFloat32:                          \
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 601a0526650c..89da570c133b 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -790,7 +790,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
         << "Mask needs to be provided when using softmax with use_length=True.";
     type = inputs[1].type_flag_;
   }
-  MXNET_INT_TYPE_SWITCH(type, IType, {
+  MXNET_INT32_INT64_TYPE_SWITCH(type, IType, {
     IType* mask_ptr = nullptr;
     if (param.use_length.value()) {
       mask_ptr = inputs[1].dptr<IType>();
@@ -834,7 +834,7 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
                         const std::vector<TBlob>& outputs) {
   using namespace mxnet_op;
   if (softmax_use_length(attrs)) {
-    MXNET_INT_TYPE_SWITCH(inputs[2].type_flag_, IType, {
+    MXNET_INT32_INT64_TYPE_SWITCH(inputs[2].type_flag_, IType, {
       if (req[1] != kNullOp) {
         mxnet_op::Kernel<set_zero, xpu>::Launch(
             ctx.get_stream<xpu>(), outputs[1].Size(), outputs[1].dptr<IType>());
@@ -856,7 +856,7 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
   MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        MXNET_INT_TYPE_SWITCH(itype, IType, {
+        MXNET_INT32_INT64_TYPE_SWITCH(itype, IType, {
           IType * length_ptr = nullptr;
           if (softmax_use_length(attrs)) {
             length_ptr = inputs[2].dptr<IType>();