Skip to content

Commit

Permalink
Softmax fp16 (#201)
Browse files Browse the repository at this point in the history
* softmax for fp16 with fp32 accumulator

* return AType in kernel

* add dtype

* kernel
  • Loading branch information
eric-haibin-lin authored Feb 12, 2019
1 parent f5ba735 commit 48b6c30
Show file tree
Hide file tree
Showing 4 changed files with 307 additions and 71 deletions.
42 changes: 42 additions & 0 deletions src/operator/mxnet_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,48 @@ inline int get_num_threads<cpu>(const int N) {
LOG(FATAL) << "Unknown type enum " << type; \
}

#define MXNET_REAL_ACC_TYPE_SWITCH(type, DType, AType, ...)\
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
typedef double AType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
typedef double AType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
typedef float AType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
LOG(FATAL) << "This operation only support " \
"floating point types not uint8"; \
break; \
case mshadow::kInt8: \
LOG(FATAL) << "This operation only support " \
"floating point types not int8"; \
break; \
case mshadow::kInt32: \
LOG(FATAL) << "This operation only support " \
"floating point types, not int32"; \
break; \
case mshadow::kInt64: \
LOG(FATAL) << "This operation only support " \
"floating point types, not int64"; \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}

/*!
* \brief assign the val to out according
Expand Down
Loading

0 comments on commit 48b6c30

Please sign in to comment.