Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
softmax_fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
luobao-intel committed Jul 6, 2018
1 parent e94146f commit 2e68c96
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions src/operator/nn/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "../tensor/elemwise_binary_op.h"
#include "mkldnn/mkldnn_base-inl.h"
#include "mkldnn/mkldnn_ops-inl.h"
#include "../../operator_common.h"

namespace mxnet {
namespace op {
Expand All @@ -49,7 +50,6 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
FallBackCompute(SoftmaxCompute<cpu, mxnet_op::softmax_fwd>, attrs, ctx,
inputs, req, outputs);
}
#endif

inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
Expand All @@ -67,9 +67,18 @@ inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs,
else
#endif
wanted_mode = DispatchMode::kFCompute;
return storage_type_assign(out_attrs, static_cast<NDArrayStorageType>((*in_attrs)[0]),
dispatch_mode, wanted_mode);

bool dispatched = false;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)){
dispatched = op::storage_type_assign(out_attrs, mxnet::kDefaultStorage, dispatch_mode, wanted_mode);
}
if (!dispatched){
dispatched = op::dispatch_fallback(out_attrs, dispatch_mode);
}

return dispatched;
}
#endif

MXNET_OPERATOR_REGISTER_UNARY(softmax)
.describe(R"code(Applies the softmax function.
Expand Down

0 comments on commit 2e68c96

Please sign in to comment.