
enabling Large Index support for slice and softmax
Rohit Kumar Srivastava committed Aug 23, 2019
1 parent 588a0e9 commit 21c13f2
Showing 5 changed files with 41 additions and 54 deletions.
python/mxnet/test_utils.py (13 changes: 13 additions & 0 deletions)
@@ -262,6 +262,19 @@ def assign_each2(input1, input2, function):

return output

# For testing Large Tensors having total size > 2^32 elements
def create_2d_tensor(rows, columns, dtype=np.int64):
a = nd.arange(0, rows, dtype=dtype).reshape(rows, 1)
b = nd.broadcast_to(a, shape=(a.shape[0], columns))
return nd.array(b, dtype=dtype)

# For testing Large Vectors having total size > 2^32 elements
def create_vector(size, dtype=np.int64):
a = nd.arange(0, size, dtype=dtype)
# Implicitly calling nd.waitall()
assert a[0] == 0
return a

def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None,
data_init=None, rsp_indices=None, modifier_func=None,
shuffle_csr_indices=False, ctx=None):
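A minimal usage sketch of the two helpers added above (the helper names and default int64 dtype come from the diff; the small shapes here are hypothetical, since the nightly tests target sizes beyond 2^32 elements and need tens of gigabytes of memory):

from mxnet.test_utils import create_2d_tensor, create_vector

# create_2d_tensor fills row i with the value i, broadcast across all columns.
t = create_2d_tensor(rows=5, columns=4)
assert t.shape == (5, 4)
assert t[-1, -1].asscalar() == 4

# create_vector returns arange(0, size) as int64; the assert on a[0] inside the
# helper forces a synchronization so allocation errors surface immediately.
v = create_vector(size=2**20)
assert v[-1].asscalar() == 2**20 - 1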
src/operator/softmax_output-inl.h (18 changes: 9 additions & 9 deletions)
@@ -117,9 +117,9 @@ class SoftmaxOutputOp : public Operator {
CHECK_EQ(out_data.size(), 1U) << "SoftmaxOutput Output: [output]";
Stream<xpu> *s = ctx.get_stream<xpu>();
if (param_.multi_output) {
int n = in_data[softmaxout_enum::kData].size(0);
int k = in_data[softmaxout_enum::kData].size(1);
Shape<3> s3 = Shape3(n, k, static_cast<int>(in_data[softmaxout_enum::kData].Size()/n/k));
index_t n = in_data[softmaxout_enum::kData].size(0);
index_t k = in_data[softmaxout_enum::kData].size(1);
Shape<3> s3 = Shape3(n, k, static_cast<index_t>(in_data[softmaxout_enum::kData].Size()/n/k));
Tensor<xpu, 3, DType> data =
in_data[softmaxout_enum::kData].get_with_shape<xpu, 3, DType>(s3, s);
Tensor<xpu, 3, DType> out =
@@ -131,8 +131,8 @@ class SoftmaxOutputOp : public Operator {
Tensor<xpu, 2, DType> out = out_data[softmaxout_enum::kOut].FlatTo2D<xpu, DType>(s);
Softmax(out, data);
} else {
int n = in_data[softmaxout_enum::kData].size(0);
int k = in_data[softmaxout_enum::kData].Size()/n;
index_t n = in_data[softmaxout_enum::kData].size(0);
index_t k = in_data[softmaxout_enum::kData].Size()/n;
Shape<2> s2 = Shape2(n, k);
Tensor<xpu, 2, DType> data =
in_data[softmaxout_enum::kData].get_with_shape<xpu, 2, DType>(s2, s);
@@ -171,9 +171,9 @@ class SoftmaxOutputOp : public Operator {
grad = (out - label) * scalar<DType>(param_.grad_scale);
}
} else if (param_.multi_output) {
int n = out_data[softmaxout_enum::kOut].size(0);
int k = out_data[softmaxout_enum::kOut].size(1);
Shape<3> s3 = Shape3(n, k, static_cast<int>(out_data[softmaxout_enum::kOut].Size()/n/k));
index_t n = out_data[softmaxout_enum::kOut].size(0);
index_t k = out_data[softmaxout_enum::kOut].size(1);
Shape<3> s3 = Shape3(n, k, static_cast<index_t>(out_data[softmaxout_enum::kOut].Size()/n/k));
Shape<2> s2 = Shape2(s3[0], s3[2]);
Tensor<xpu, 2, DType> label =
in_data[softmaxout_enum::kLabel].get_with_shape<xpu, 2, DType>(s2, s);
@@ -224,7 +224,7 @@ class SoftmaxOutputOp : public Operator {
// Tensor<xpu, 2, DType> out = out_data[softmaxout_enum::kOut].FlatTo2D<xpu, DType>(s);
// Tensor<xpu, 2, DType> grad = in_grad[softmaxout_enum::kData].FlatTo2D<xpu, DType>(s);
} else {
int n = out_data[softmaxout_enum::kOut].size(0);
index_t n = out_data[softmaxout_enum::kOut].size(0);
data_shape = Shape2(n, out_data[softmaxout_enum::kOut].Size()/n);
}
Tensor<xpu, 1, DType> label = in_data[softmaxout_enum::kLabel].get_with_shape<xpu, 1, DType>(
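The int to index_t changes above matter because the flat size of a tensor with more than 2^31 - 1 elements no longer fits in a signed 32-bit int, so the computed n, k, and Size()/n/k values would overflow. A small illustrative sketch of the overflow (hypothetical sizes, not code from the commit; index_t is a 64-bit type when MXNet is built with large tensor support):

import ctypes

n, k = 50_000, 50_000                       # 2.5e9 elements in total
flat_size = n * k                           # Python ints do not overflow
print(ctypes.c_int32(flat_size).value)      # -1794967296, wrapped around
print(ctypes.c_int64(flat_size).value)      # 2500000000, as expected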
src/operator/tensor/matrix_op-inl.h (6 changes: 3 additions & 3 deletions)
@@ -732,8 +732,8 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
}

inline void SetSliceOpOutputDimSize(const mxnet::TShape& dshape,
const index_t i, const int b,
const int e, const int s,
const index_t i, const index_t b,
const index_t e, const index_t s,
mxnet::TShape* oshape) {
if (!mxnet::dim_size_is_known(dshape, i)) {
(*oshape)[i] = -1;
@@ -765,7 +765,7 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
common::StaticArray<index_t, ndim> begin, end, step;
GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step);
for (int i = 0; i < param.begin.ndim(); ++i) {
const int b = begin[i], e = end[i], s = step[i];
const index_t b = begin[i], e = end[i], s = step[i];
SetSliceOpOutputDimSize(dshape, i, b, e, s, &oshape);
}
})
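With begin, end, and step carried as index_t, slice boundaries beyond the 32-bit int range can be represented. A sketch of the kind of call this enables (hypothetical size; assumes an MXNet build with large tensor support and enough free memory):

from mxnet import nd

LARGE_X = 5_000_000_000                     # > 2**32 elements
a = nd.ones(shape=(LARGE_X,))
# begin/end exceed INT32_MAX, so SliceOpShape has to keep them as index_t.
res = nd.slice(a, begin=(LARGE_X - 1000,), end=(LARGE_X,))
assert res.shape == (1000,)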
tests/nightly/test_large_array.py (8 changes: 1 addition & 7 deletions)
@@ -19,7 +19,7 @@
import numpy as np
import mxnet as mx

from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward
from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor
from mxnet import gluon, nd
from tests.python.unittest.common import with_seed

@@ -31,12 +31,6 @@
LARGE_SIZE = LARGE_X * SMALL_Y


def create_2d_tensor(rows, columns, dtype=np.int64):
a = nd.arange(0, rows, dtype=dtype).reshape(rows, 1)
b = nd.broadcast_to(a, shape=(a.shape[0], columns))
return nd.array(b, dtype=dtype)


def test_gluon_embedding():
m = gluon.nn.Embedding(SMALL_Y, MEDIUM_X)
m.initialize()
tests/nightly/test_large_vector.py (50 changes: 15 additions & 35 deletions)
@@ -18,7 +18,7 @@
import numpy as np
import mxnet as mx

from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d
from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, create_vector
from mxnet import gluon, nd
from tests.python.unittest.common import with_seed

@@ -85,25 +85,18 @@ def test_elementwise():
a = nd.ones(shape=LARGE_X)
b = nd.ones(shape=LARGE_X)
res = a + b
assert np.sum(res[-1].asnumpy() == 2) == a.shape[1]
assert res[-1].asnumpy() == 2
res = a + 1
assert np.sum(res[-1].asnumpy() == 2) == a.shape[1]
res = nd.sqrt(a + 3)
assert np.sum(res[-1].asnumpy() == 2) == a.shape[1]
assert res[-1].asnumpy() == 2
res = nd.sqrt(a + 8)
assert res[-1].asnumpy() == 3


def test_reduce():
a = nd.ones(shape=(LARGE_X, SMALL_Y))
assert nd.sum(a).asnumpy() == a.shape[0] * a.shape[1]


def test_FullyConnected():
a = nd.ones(shape=(LARGE_X, SMALL_Y))
b = nd.ones(shape=(SMALL_Y, SMALL_Y))
res = nd.FullyConnected(a, b, num_hidden=b.shape[1], no_bias=True)
assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1]


def test_broadcast():
a = nd.ones(shape=(LARGE_X, SMALL_Y*2))
b = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
@@ -116,7 +109,7 @@ def test_broadcast():
def test_clip():
a = nd.arange(0, LARGE_X)
res = nd.clip(a, a_min=100, a_max=1000)
assert np.sum(res[-1].asnumpy() == 1000) == 101
assert np.sum(res[-1].asnumpy() == 1000) == 1


def test_argmin():
@@ -139,12 +132,6 @@ def test_take():
assert np.sum(res.asnumpy() == 1) == res.shape[0]


def test_slice():
a = nd.ones(shape=(2, LARGE_X))
res = nd.slice(a, begin=(1, LARGE_X-1000000000), end=(2, LARGE_X))
assert np.sum(res[-1].asnumpy() == 1) == res.shape[1]


def test_slice_assign():
a = nd.ones(shape=LARGE_X)
a[LARGE_X-1:LARGE_X] = 1000
@@ -262,13 +249,6 @@ def test_unravel_index():
assert (indices_2d.asnumpy() == np.array(original_2d_indices)).all()


def create_large_vector(size, dtype=np.int64):
a = nd.arange(0, size, dtype=dtype)
# Implicitly calling nd.waitall()
assert a[0] == 0
return a


def test_transpose():
b = nd.arange(0, LARGE_X, dtype=np.int64).reshape(1, LARGE_X)
t = b.T
@@ -285,35 +265,35 @@ def test_swapaxes():

def test_flip():
b = nd.arange(0, LARGE_X, dtype=np.int64).reshape(1, LARGE_X)
t = nd.flip(b, axis=0)
assert t.shape == (LARGE_X, 1)
assert t[-1, :].asnumpy() == 0
t = nd.flip(b, axis=1)
assert t.shape == (1, LARGE_X)
assert t[-1, -1].asnumpy() == 0


def test_softmax():
input_data = mx.nd.ones(2, LARGE_X)
true_output = np.full(LARGE_X, 0.5)
input_data = nd.ones((2, LARGE_X))
output = nd.softmax(input_data, axis=0)
assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5)
assert output[0][0] == 0.5
assert output[-1][-1] == 0.5


def test_argsort():
b = create_large_vector(size=LARGE_X)
b = create_vector(size=LARGE_X)
s = nd.argsort(b, axis=0, is_ascend=False, dtype=np.int64)
mx.nd.waitall()
assert (s[0].asnumpy() == (LARGE_X - 1)).all()


def test_sort():
b = create_large_vector(size=LARGE_X)
b = create_vector(size=LARGE_X)
s = nd.sort(b, axis=0, is_ascend=False)
assert np.sum(s[-1][SMALL_Y//2:SMALL_Y].asnumpy() == 0).all()
s = nd.sort(b, is_ascend=True)
assert np.sum(s[0].asnumpy() == 0).all()


def test_topk():
b = create_large_vector(size=LARGE_X)
b = create_vector(size=LARGE_X)
k = nd.topk(b, k=10, axis=0, dtype=np.int64)
assert np.sum(k.asnumpy() == (LARGE_X - 1)) == SMALL_Y
ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both", is_ascend=False)
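The updated tests pass dtype=np.int64 to argsort and topk because the returned positions themselves can exceed the int32 range on these vectors. A reduced sketch of that pattern (hypothetical size; requires a large-tensor build and enough memory):

import numpy as np
from mxnet import nd

LARGE_X = 4_300_000_000                     # > 2**32 elements
b = nd.arange(0, LARGE_X, dtype=np.int64)   # same layout as create_vector
ind = nd.topk(b, k=1, axis=0, dtype=np.int64)
# The winning index is LARGE_X - 1, which does not fit in a 32-bit int.
assert ind[0].asscalar() == LARGE_X - 1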
