Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[fix] Support nullop in transpose #15865

Merged
merged 11 commits into from
Sep 2, 2019
3 changes: 3 additions & 0 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,9 @@ void Transpose(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (req[0] == kNullOp) {
return;
}
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace";
TaoLv marked this conversation as resolved.
Show resolved Hide resolved
if (param.axes.ndim() == 0) {
Expand Down
3 changes: 3 additions & 0 deletions src/operator/tensor/matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,9 @@ static void TransposeComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (req[0] == kNullOp) {
return;
}
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace";
CHECK_EQ(inputs.size(), 1U);
Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,22 @@ def test_children_same_name():
for c in b.get_children():
pass


def test_transpose_nullop():
for dim in range(1, 7):
a = mx.sym.Variable('a')
b = mx.sym.transpose(a, axes=tuple(np.random.permutation(dim)))
c = mx.sym.zeros_like(b)

shape = rand_shape_nd(dim)
nd_a = mx.nd.random.normal(shape=shape)
c_out = c.eval(ctx=mx.cpu(), a=nd_a)
b_out = b.eval(ctx=mx.cpu(), a=nd_a)

assert mx.test_utils.same(c_out[0].asnumpy(),
np.zeros_like(b_out[0].asnumpy()))


if __name__ == '__main__':
import nose
nose.runmodule()