Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix SupportMKLDNN function for Convolution and Reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
agrabow committed Sep 2, 2021
1 parent b11cf8d commit dc3c2e9
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
5 changes: 2 additions & 3 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@ namespace op {
DMLC_REGISTER_PARAMETER(MKLDNNConvParam);

bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray& input) {
if ((params.kernel.ndim() != 1) && (params.kernel.ndim() != 2) && (params.kernel.ndim() != 3))
if (params.kernel.ndim() > 3 || params.kernel.ndim() == 0)
return false;
return IsMKLDNNType(input.dtype()) &&
((input.shape().ndim() == 3) || (input.shape().ndim() == 4) ||
(input.shape().ndim() == 5));
input.shape().ndim() >= 3 && input.shape().ndim() <= 5;
}

std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
Expand Down
4 changes: 2 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_reshape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ namespace op {
bool SupportMKLDNNReshape(const NDArray& input, const NDArray& output) {
const int input_ndims = input.shape().ndim();
const int output_ndims = output.shape().ndim();
return input_ndims >= 1 && input_ndims <= 6 && output_ndims >= 1 && output_ndims <= 6 &&
IsMKLDNNType(input.dtype()) && input.shape().Size() > 0;
return input.shape().Size() > 0 && input_ndims >= 1 && input_ndims <= 6 && output_ndims >= 1 &&
output_ndims <= 6 && IsMKLDNNType(input.dtype());
}

MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType& req,
Expand Down

0 comments on commit dc3c2e9

Please sign in to comment.