Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Softmax with length (#15169)
Browse files Browse the repository at this point in the history
* softmax with length forward

* softmax with length backward

* new macro to reduce compile-time heap usage and limit length to integers only

* address comments
  • Loading branch information
haojin2 committed Jul 19, 2019
1 parent 6bdcef2 commit 076b2f3
Show file tree
Hide file tree
Showing 4 changed files with 487 additions and 64 deletions.
51 changes: 51 additions & 0 deletions src/operator/mxnet_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,57 @@ inline int get_num_threads<cpu>(const int N) {
LOG(FATAL) << "Unknown type enum " << type; \
}

#define MXNET_INT_TYPE_SWITCH(type, DType, ...)\
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float32"; \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float64"; \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float16"; \
} \
break; \
case mshadow::kUint8: \
{ \
typedef uint8_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt8: \
{ \
typedef int8_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt32: \
{ \
typedef int32_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt64: \
{ \
typedef int64_t DType; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}

/*!
* \brief assign the val to out according
* to request in Kernel::Launch
Expand Down
Loading

0 comments on commit 076b2f3

Please sign in to comment.