diff --git a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc
index 5885d73efe29..c513e65b0c0e 100644
--- a/src/operator/tensor/elemwise_sum.cc
+++ b/src/operator/tensor/elemwise_sum.cc
@@ -113,7 +113,14 @@ void ElementWiseSumComputeExCPU(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
   if (req[0] == kNullOp) return;
-  if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) ||
+#if MXNET_USE_MKLDNN == 1
+  if (IsMKLDNNData(inputs)) {
+    MKLDNNRun(MKLDNNSumForward, attrs, ctx, inputs, req, outputs);
+  } else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
+    FallBackCompute(ElementWiseSumCompute<cpu>, attrs, ctx, inputs, req, outputs);
+  } else  // NOLINT(*)
+#endif
+  if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) ||
       (inputs.size() == 3U && inputs[0].storage_type() == kDefaultStorage &&
        inputs[1].storage_type() == kCSRStorage && inputs[2].storage_type() == kDefaultStorage) ||
       (inputs.size() > 4U && common::ContainsStorageType(inputs, kDefaultStorage) &&
@@ -123,12 +130,6 @@ void ElementWiseSumComputeExCPU(const nnvm::NodeAttrs& attrs,
         ResourceRequest(ResourceRequest::kTempSpace));
     NDArray out_nd = outputs[0];
     mxnet::ndarray::ElementwiseSum(s, rsc, inputs, &out_nd);
-#if MXNET_USE_MKLDNN == 1
-  } else if (IsMKLDNNData(inputs)) {
-    MKLDNNRun(MKLDNNSumForward, attrs, ctx, inputs, req, outputs);
-  } else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
-    FallBackCompute(ElementWiseSumCompute<cpu>, attrs, ctx, inputs, req, outputs);
-#endif
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index b52bb03a80c8..44e7d3cf2be9 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -254,6 +254,28 @@ def test_flatten_slice_after_conv():
     print(p[0])
 
 
+def test_mkldnn_sum_with_mkldnn_layout():
+    x_shape = (32, 3, 224, 224)
+    x_npy = np.ones(x_shape)
+    w_shape = (32, 3, 3, 3)
+    w_npy = np.ones(w_shape)
+
+    x = mx.sym.Variable("x")
+    w = mx.sym.Variable("w")
+    z = mx.symbol.Convolution(data=x, weight=w, num_filter=32, kernel=(3, 3))
+    num_inputs = [2, 3, 4, 5]
+    for i in num_inputs:
+        inputs = []
+        for n in range(i):
+            inputs.append(z)
+        y = mx.sym.add_n(*inputs)  # (only MKLDNN data inputs)
+        exe = y.simple_bind(ctx=mx.cpu(), x=x_shape, w=w_shape)
+        out = exe.forward(is_train=False, x=x_npy, w=w_npy)[0]
+        # conv with a (3, 3) kernel of ones over an all-ones input gives 27
+        single_conv = 27.0
+        assert_almost_equal(out[0].asnumpy()[0, 0, 0], single_conv * i)
+
+
 def test_mkldnn_sum_inplace_with_cpu_layout():
 
     x_shape = (32, 3, 224, 224)
@@ -263,7 +285,7 @@ def test_mkldnn_sum_inplace_with_cpu_layout():
     x = mx.sym.Variable("x")
     y = mx.sym.Variable("y")
     z = mx.symbol.Convolution(data=x, num_filter=32, kernel=(3, 3))
-    z = mx.sym.add_n(z, y)
+    z = mx.sym.add_n(z, y)  # (MKLDNN data, cpu data)
     exe = z.simple_bind(ctx=mx.cpu(), x=x_shape, y=y_shape)
     out = exe.forward(is_train=False, x=x_npy, y=y_npy)[0]
     assert_almost_equal(out[0].asnumpy()[0, 0, 0], 1.0)
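
A standalone repro sketch of the failure mode this change appears to target, inferred from the `inputs.size() > 4U` condition in the old dispatch order and the new test's loop up to five inputs: with five or more dense inputs, the generic dense branch matched before the MKLDNN check ever ran, so MKLDNN-layout buffers could be summed as if they were plain NCHW memory. Shapes and expected values mirror the added test; this is illustrative only, not part of the patch, and assumes an MXNet build with `MXNET_USE_MKLDNN=1`.

```python
import mxnet as mx
import numpy as np
from numpy.testing import assert_almost_equal

x = mx.sym.Variable("x")
w = mx.sym.Variable("w")
# On an MKLDNN-enabled CPU build, Convolution output carries an MKLDNN layout.
conv = mx.symbol.Convolution(data=x, weight=w, num_filter=32, kernel=(3, 3))
# Five identical inputs exercise the `inputs.size() > 4U` branch that
# previously shadowed the MKLDNN path in ElementWiseSumComputeExCPU.
y = mx.sym.add_n(conv, conv, conv, conv, conv)

exe = y.simple_bind(ctx=mx.cpu(), x=(32, 3, 224, 224), w=(32, 3, 3, 3))
out = exe.forward(is_train=False,
                  x=np.ones((32, 3, 224, 224)),
                  w=np.ones((32, 3, 3, 3)))[0]
# A (3, 3) kernel of ones over an all-ones input gives 3 * 3 * 3 = 27 per
# element, so the five-way sum should be 135 everywhere.
assert_almost_equal(out.asnumpy()[0, 0, 0, 0], 135.0)
```

With the fix applied, the new test can be run on its own, e.g. `pytest tests/python/mkl/test_mkldnn.py -k test_mkldnn_sum_with_mkldnn_layout`, assuming a pytest-compatible runner for this suite.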