From b525e2e0d394c3636b311d5ccbe40e11419f84c5 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Wed, 16 Oct 2019 23:59:29 +0000 Subject: [PATCH 1/3] bug fix for the input of same axes --- src/operator/swapaxis-inl.h | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/operator/swapaxis-inl.h b/src/operator/swapaxis-inl.h index fd9872db6ec8..319fd508c380 100644 --- a/src/operator/swapaxis-inl.h +++ b/src/operator/swapaxis-inl.h @@ -35,6 +35,7 @@ #include #include #include "./operator_common.h" +#include "./mshadow_op.h" namespace mxnet { namespace op { @@ -63,7 +64,6 @@ template class SwapAxisOp : public Operator { public: explicit SwapAxisOp(SwapAxisParam p) { - CHECK_NE(p.dim1, p.dim2) << "dim1 can not be equal dim2."; this->param_ = p; } @@ -131,6 +131,16 @@ class SwapAxisOp : public Operator { if (shape_in.Size() == 0U) return; + if (axis1 == axis2) { + if (out_req == kAddTo) { + mxnet_op::Kernel, xpu>::Launch( + s, data_out.Size(), data_out.dptr(), data_in.dptr()); + } else { + mxnet_op::copy(s, data_out, data_in); + } + return; + } + Shape<5> inter_shape; Reshape2Five(&inter_shape, shape_in, axis1, axis2); From a457bef20ed502077122e3ebf341707bd8f8b5e5 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Thu, 17 Oct 2019 00:46:08 +0000 Subject: [PATCH 2/3] tests added --- tests/python/unittest/test_operator.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 66bd9ec6b489..a16dc6c693ab 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -716,6 +716,23 @@ def test_swapaxes(): assert_almost_equal(out, swap_) + config = [((1, 1, 2), 0, 1), + ((1, 1, 2), -1, -2), + ((4, 5, 6, 7), 1, 1), + ((4, 5, 6, 7), 2, 3), + ((4, 5, 6, 7), -2, 2), + ((4, 5, 6, 7), -2, -3)] + + for shape, axis1, axis2 in config: + data_np = np.random.uniform(size=shape) + data_mx = mx.nd.array(data_np, dtype=data_np.dtype) + ret_np = np.swapaxes(data_np, axis1=axis1, axis2=axis2) + ret_mx = mx.symbol.SwapAxis(data, dim1=axis1, dim2=axis2) + exe_c = ret_mx.bind(default_context(), args=[data_mx]) + exe_c.forward(is_train=True) + out = exe_c.outputs[0] + assert_almost_equal(out, ret_np) + @with_seed() def test_scalarop(): From e6126d609dd897f772f9fee0d43271a9af56e7f1 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Thu, 17 Oct 2019 08:26:19 +0000 Subject: [PATCH 3/3] numpy swapaxes tests added --- tests/python/unittest/test_numpy_op.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 28edc2e05306..d1ff6c61793a 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1967,9 +1967,11 @@ def check_comp_op(op_name, x1, x2): @with_seed() @use_np def test_np_swapaxes(): - config = [((0, 1, 2), 0, 1), - ((0, 1, 2), -1, -2), - ((4, 5, 6, 7), 2, 3), + config = [((0, 1, 2), 0, 0), + ((0, 1, 2), 1, 2), + ((0, 1, 2), 1, -2), + ((4, 5, 6, 7), 1, 1), + ((4, 5, 6, 7), 2, -2), ((4, 5, 6, 7), -2, -3)] class TestSwapaxes(HybridBlock):