Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[fix] Support nullop in transpose #15865

Merged
merged 11 commits into from
Sep 2, 2019
5 changes: 4 additions & 1 deletion src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,11 @@ void Transpose(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (req[0] == kNullOp) {
return;
}
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace";
CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo";
if (param.axes.ndim() == 0) {
mxnet::TShape axes(inputs[0].ndim(), -1);
for (int i = 0; i < axes.ndim(); ++i) {
Expand Down
5 changes: 4 additions & 1 deletion src/operator/tensor/matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,11 @@ static void TransposeComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (req[0] == kNullOp) {
return;
}
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace";
CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo";
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);

Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,22 @@ def test_children_same_name():
for c in b.get_children():
pass


def test_transpose_nullop():
for dim in range(1, 7):
a = mx.sym.Variable('a')
b = mx.sym.transpose(a, axes=tuple(np.random.permutation(dim)))
c = mx.sym.zeros_like(b)

shape = rand_shape_nd(dim)
nd_a = mx.nd.random.normal(shape=shape)
c_out = c.eval(ctx=mx.cpu(), a=nd_a)
b_out = b.eval(ctx=mx.cpu(), a=nd_a)

assert mx.test_utils.same(c_out[0].asnumpy(),
np.zeros_like(b_out[0].asnumpy()))


def test_gen_atomic_symbol_multiple_outputs():
data=mx.sym.Variable('data')
p = mx.sym.Variable('param')
Expand Down