Skip to content

Commit

Permalink
make TransposeShape infer shape from both sides (apache#15713)
Browse files Browse the repository at this point in the history
* make TransposeShape infer shape from both sides

* small fixes

* remove redundant lines

* unit tests
  • Loading branch information
dtracz authored and Ubuntu committed Aug 20, 2019
1 parent 950eae2 commit 34bd471
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
19 changes: 17 additions & 2 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,19 +344,34 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& shp = (*in_attrs)[0];
mxnet::TShape& out_shp = (*out_attrs)[0];
CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
mxnet::TShape ret(shp.ndim(), -1);
CHECK_NE(shp.ndim(), 0) << "Number of dimensions cannot be 0";
CHECK_NE(out_shp.ndim(), 0) << "Number of dimensions cannot be 0";
if (shp.ndim() == -1 && out_shp.ndim() == -1)
return false; // none of the shapes is known
if (out_shp.ndim() > 0 && shp.ndim() > 0)
CHECK_EQ(out_shp.ndim(), shp.ndim());
mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1);
mxnet::TShape ret(std::max(shp.ndim(), out_shp.ndim()), -1);
if (param.axes.ndim() == 0) {
for (int i = 0; i < shp.ndim(); ++i) {
ret[i] = shp[shp.ndim()-1-i];
}
for (int i = 0; i < out_shp.ndim(); ++i) {
get[shp.ndim()-1-i] = out_shp[i];
}
} else {
CHECK_EQ(shp.ndim(), param.axes.ndim());
CHECK_EQ(std::max(shp.ndim(), out_shp.ndim()), param.axes.ndim());
for (int i = 0; i < shp.ndim(); ++i) {
CHECK(param.axes[i] < static_cast<int64_t>(shp.ndim()));
ret[i] = shp[param.axes[i]];
}
for (int i = 0; i < out_shp.ndim(); ++i) {
get[param.axes[i]] = out_shp[i];
}
}
SHAPE_ASSIGN_CHECK(*in_attrs, 0, get);
SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret);
return shape_is_known(ret);
}
Expand Down
20 changes: 20 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8984,6 +8984,26 @@ def test_get_operator_arguments():
ok_(operator_arguments.narg == 2)


def test_transpose_infer_shape_back():
    """Check shape inference through transpose when its input has only -1 dims.

    Builds ``ones([2,3]) + transpose(ones([-1,-1]))``, binds the sum on the
    CPU context with no arguments, runs a forward pass and asserts that the
    first output has shape (2, 3).
    """
    fixed_operand = mx.sym.ones(shape=[2,3])
    open_operand = mx.sym.ones(shape=[-1,-1])
    # The transposed operand is declared with -1 in every dimension.
    total = fixed_operand + mx.sym.transpose(open_operand)
    executor = total.bind(mx.cpu(), args={})
    outputs = executor.forward()
    assert outputs[0].shape == (2, 3)


def test_transpose_infer_shape_mixed():
    """Check shape inference through transpose when both operands have a -1 dim.

    Builds ``ones([2,-1]) + transpose(ones([3,-1]))``, binds the sum on the
    CPU context with no arguments, runs a forward pass and asserts that the
    first output has shape (2, 3).
    """
    left_operand = mx.sym.ones(shape=[2,-1])
    right_operand = mx.sym.ones(shape=[3,-1])
    # Each operand declares one dimension as a number and the other as -1.
    total = left_operand + mx.sym.transpose(right_operand)
    executor = total.bind(mx.cpu(), args={})
    outputs = executor.forward()
    assert outputs[0].shape == (2, 3)


# Script entry point: when this file is executed directly, hand the module to nose.runmodule().
if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 34bd471

Please sign in to comment.