Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
cut some integers in softmax
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Nov 1, 2019
1 parent 6b5b420 commit 727b84d
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 3 deletions.
57 changes: 57 additions & 0 deletions src/operator/mxnet_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,63 @@ struct AccType<mshadow::half::half_t> {
LOG(FATAL) << "Unknown type enum " << type; \
}

#define MXNET_INT32_INT64_TYPE_SWITCH(type, DType, ...)\
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float32"; \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float64"; \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float16"; \
} \
break; \
case mshadow::kUint8: \
{ \
LOG(FATAL) << "This operation only support " \
"integer types, not uint8"; \
} \
break; \
case mshadow::kInt8: \
{ \
LOG(FATAL) << "This operation only support " \
"integer types, not int8"; \
} \
break; \
case mshadow::kInt32: \
{ \
typedef int32_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt64: \
{ \
typedef int64_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kBool: \
{ \
LOG(FATAL) << "This operation only support " \
"integer types, not bool"; \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}

#define MXNET_LOAD_TYPE_SWITCH(type, DType, ...) \
switch (type) { \
case mshadow::kFloat32: \
Expand Down
6 changes: 3 additions & 3 deletions src/operator/nn/softmax-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
<< "Mask needs to be provided when using softmax with use_length=True.";
type = inputs[1].type_flag_;
}
MXNET_INT_TYPE_SWITCH(type, IType, {
MXNET_INT32_INT64_TYPE_SWITCH(type, IType, {
IType* mask_ptr = nullptr;
if (param.use_length.value()) {
mask_ptr = inputs[1].dptr<IType>();
Expand Down Expand Up @@ -834,7 +834,7 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
if (softmax_use_length(attrs)) {
MXNET_INT_TYPE_SWITCH(inputs[2].type_flag_, IType, {
MXNET_INT32_INT64_TYPE_SWITCH(inputs[2].type_flag_, IType, {
if (req[1] != kNullOp) {
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
ctx.get_stream<xpu>(), outputs[1].Size(), outputs[1].dptr<IType>());
Expand All @@ -856,7 +856,7 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
MXNET_INT_TYPE_SWITCH(itype, IType, {
MXNET_INT32_INT64_TYPE_SWITCH(itype, IType, {
IType * length_ptr = nullptr;
if (softmax_use_length(attrs)) {
length_ptr = inputs[2].dptr<IType>();
Expand Down

0 comments on commit 727b84d

Please sign in to comment.