Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix failing empty array (log_)softmax
Browse files Browse the repository at this point in the history
  • Loading branch information
bgawrych committed Jun 22, 2020
1 parent 2fbec60 commit be0035c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
6 changes: 4 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_log_softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,15 @@ bool SupportMKLDNNLogSoftmax(const SoftmaxParam &param,
const int in_dtype = data.dtype();
const int out_dtype = output.dtype();
const int axis = CheckAxis(param.axis, ndim);
const size_t array_size = data.shape().Size();
// MKLDNN does not support temperature argument in their log_softmax function
// now. Need update this once they start to support it.
// Currently, MKLDNN shows bad performance when log_softmax is not performed on the last dimension
if (param.temperature.has_value() ||
in_dtype != mshadow::kFloat32 ||
in_dtype != out_dtype ||
axis != (ndim - 1)) {
in_dtype != out_dtype ||
axis != (ndim - 1) ||
array_size == 0) {
return false;
}

Expand Down
6 changes: 4 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,15 @@ bool SupportMKLDNNSoftmax(const SoftmaxParam &param,
const int in_dtype = data.dtype();
const int out_dtype = output.dtype();
const int axis = CheckAxis(param.axis, ndim);
const size_t array_size = data.shape().Size();
// MKLDNN does not support temperature argument in their softmax function
// now. Need update this once they start to support it.
// Currently, MKLDNN shows bad performance when softmax is not performed on the last dimension
if (param.temperature.has_value() ||
in_dtype != mshadow::kFloat32 ||
in_dtype != out_dtype ||
axis != (ndim - 1)) {
in_dtype != out_dtype ||
axis != (ndim - 1) ||
array_size == 0) {
return false;
}

Expand Down

0 comments on commit be0035c

Please sign in to comment.