Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
return AType in kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Feb 8, 2019
1 parent 1f17c28 commit eea52d3
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/operator/nn/softmax-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ namespace mxnet_op {

struct softmax_fwd {
template<typename DType, typename AType>
MSHADOW_XINLINE static DType Map(DType a, AType b) {
return DType(expf(a)/b);
MSHADOW_XINLINE static AType Map(DType a, AType b) {
return AType(expf(a)/b);
}
};


struct log_softmax_fwd {
template<typename DType, typename AType>
MSHADOW_XINLINE static DType Map(DType a, AType b) {
return DType(a - logf(b));
MSHADOW_XINLINE static AType Map(DType a, AType b) {
return AType(a - logf(b));
}
};

Expand Down Expand Up @@ -105,16 +105,16 @@ inline void Softmax(Stream<cpu> *s, DType *in, DType *out,

struct softmax_bwd {
template<typename DType, typename AType>
MSHADOW_XINLINE static DType Map(DType ograd, DType out, AType sum) {
return DType(out * (ograd - sum));
MSHADOW_XINLINE static AType Map(DType ograd, DType out, AType sum) {
return AType(out * (ograd - sum));
}
};


struct log_softmax_bwd {
template<typename DType, typename AType>
MSHADOW_XINLINE static DType Map(DType ograd, DType out, AType sum) {
return DType(ograd - expf(out)*sum);
MSHADOW_XINLINE static AType Map(DType ograd, DType out, AType sum) {
return AType(ograd - expf(out)*sum);
}
};

Expand Down

0 comments on commit eea52d3

Please sign in to comment.