Skip to content

Commit

Permalink
[MXNET-1410]Adding Large Tensor Support for tensor transpose (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
access2rohit authored and haohuw committed Jun 23, 2019
1 parent 15c9932 commit b82ff54
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1950,10 +1950,10 @@ struct ReverseParam : public dmlc::Parameter<ReverseParam> {
#define REVERSE_MAX_DIM 10U

struct reverse {
MSHADOW_XINLINE static int ReverseIndex(index_t idx,
index_t nreversedim,
const index_t * stride_,
const index_t * trailing_) {
MSHADOW_XINLINE static index_t ReverseIndex(index_t idx,
index_t nreversedim,
const index_t * stride_,
const index_t * trailing_) {
index_t outputIndex = idx;
for (index_t i = 0; i < nreversedim; ++i) {
const index_t low = outputIndex % trailing_[i];
Expand Down
27 changes: 27 additions & 0 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,33 @@ def test_unravel_index():
assert (indices_2d.asnumpy() == np.array(original_2d_indices)).all()


def create_2d_tensor(rows, columns):
a = np.arange(0, rows).reshape(rows, 1)
b = np.broadcast_to(a, shape=(a.shape[0], columns))
return nd.array(b, dtype=np.int64)


def test_transpose():
b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
t = b.T
assert t.shape == (SMALL_Y, LARGE_X)
assert np.sum(t[:, -1].asnumpy() == (LARGE_X - 1)) == b.shape[1]


def test_swapaxes():
b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
t = nd.swapaxes(b, dim1=0, dim2=1)
assert t.shape == (SMALL_Y, LARGE_X)
assert np.sum(t[:, -1].asnumpy() == (LARGE_X - 1)) == b.shape[1]


def test_flip():
b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
t = nd.flip(b, axis=0)
assert t.shape == (LARGE_X, SMALL_Y)
assert np.sum(t[-1, :].asnumpy() == 0) == b.shape[1]


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit b82ff54

Please sign in to comment.