Skip to content

Commit

Permalink
Large array support for randint (apache#14242)
Browse files Browse the repository at this point in the history
* large array support for randint

* with seed for 2 random large array tests

* Trigger notification
  • Loading branch information
ChaiBapchya authored and vdantu committed Mar 31, 2019
1 parent 6d54e49 commit d40e677
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/operator/random/sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ struct UniformSampler {
template<typename xpu>
struct SampleRandIntKernel {
template<typename IType, typename OType>
MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
const int N, const int step,
MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, OType> gen,
const index_t N, const index_t step,
index_t nParm, index_t nSample,
const IType *lower, const IType *upper, OType *out) {
RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
Expand Down
14 changes: 14 additions & 0 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import mxnet as mx
import numpy as np
from mxnet import gluon, nd
from tests.python.unittest.common import with_seed

# dimension constants
MEDIUM_X = 10000
Expand Down Expand Up @@ -45,10 +46,23 @@ def test_ndarray_ones():
assert a[-1][0] == 1
assert nd.sum(a).asnumpy() == LARGE_SIZE

@with_seed()
def test_ndarray_random_uniform():
a = nd.random.uniform(shape=(LARGE_X, SMALL_Y))
assert a[-1][0] != 0

@with_seed()
def test_ndarray_random_randint():
a = nd.random.randint(100, 10000, shape=(LARGE_X, SMALL_Y))
assert a.shape == (LARGE_X, SMALL_Y)
# check if randint can generate value greater than 2**32 (large)
low_large_value = 2**32
high_large_value = 2**34
a = nd.random.randint(low_large_value,high_large_value)
low = mx.nd.array([low_large_value],dtype='int64')
high = mx.nd.array([high_large_value],dtype='int64')
assert a.__gt__(low) & a.__lt__(high)

def test_ndarray_empty():
a = nd.empty((LARGE_X, SMALL_Y))
assert a.shape == (LARGE_X, SMALL_Y)
Expand Down

0 comments on commit d40e677

Please sign in to comment.