Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Feb 11, 2019
1 parent 0cf2c47 commit fd37040
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions src/operator/nn/softmax-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,27 @@ namespace op {
namespace mxnet_op {

struct softmax_fwd {
template<typename DType, typename AType>
MSHADOW_XINLINE static AType Map(DType a, AType b) {
template<typename AType>
MSHADOW_XINLINE static AType Map(float a, AType b) {
return AType(expf(a)/b);
}

template<typename AType>
MSHADOW_XINLINE static AType Map(double a, AType b) {
return AType(exp(a)/b);
}
};


struct log_softmax_fwd {
template<typename DType, typename AType>
MSHADOW_XINLINE static AType Map(DType a, AType b) {
return AType(a - logf(b));
template<typename DType>
MSHADOW_XINLINE static float Map(DType a, float b) {
return a - logf(b);
}

template<typename DType>
MSHADOW_XINLINE static double Map(DType a, double b) {
return a - log(b);
}
};

Expand Down Expand Up @@ -111,10 +121,15 @@ struct softmax_bwd {


struct log_softmax_bwd {
template<typename DType, typename AType>
MSHADOW_XINLINE static AType Map(DType ograd, DType out, AType sum) {
template<typename AType>
MSHADOW_XINLINE static AType Map(float ograd, float out, AType sum) {
return AType(ograd - expf(out)*sum);
}

template<typename AType>
MSHADOW_XINLINE static AType Map(double ograd, double out, AType sum) {
return AType(ograd - exp(out)*sum);
}
};


Expand Down

0 comments on commit fd37040

Please sign in to comment.