diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 611dd7287206..4df32d6a6331 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -330,8 +330,11 @@ void Transpose(const nnvm::NodeAttrs& attrs, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { + if (req[0] == kNullOp) { + return; + } const TransposeParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace"; + CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo"; if (param.axes.ndim() == 0) { mxnet::TShape axes(inputs[0].ndim(), -1); for (int i = 0; i < axes.ndim(); ++i) { diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index f02a38ac07c4..9137737095ab 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -322,8 +322,11 @@ static void TransposeComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { + if (req[0] == kNullOp) { + return; + } const TransposeParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace"; + CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo"; CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index 28f302f9ec15..ae770ed0644f 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -389,6 +389,22 @@ def test_children_same_name(): for c in b.get_children(): pass + +def test_transpose_nullop(): + for dim in range(1, 7): + a = mx.sym.Variable('a') + b = mx.sym.transpose(a, axes=tuple(np.random.permutation(dim))) + c = mx.sym.zeros_like(b) + + shape = rand_shape_nd(dim) + nd_a = mx.nd.random.normal(shape=shape) + c_out = c.eval(ctx=mx.cpu(), a=nd_a) + b_out = b.eval(ctx=mx.cpu(), a=nd_a) + + assert mx.test_utils.same(c_out[0].asnumpy(), + np.zeros_like(b_out[0].asnumpy())) + + def test_gen_atomic_symbol_multiple_outputs(): data=mx.sym.Variable('data') p = mx.sym.Variable('param')