Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Large array support for randint #14242

Merged
merged 3 commits into from
Mar 2, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/operator/random/sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ struct UniformSampler {
template<typename xpu>
struct SampleRandIntKernel {
template<typename IType, typename OType>
MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
const int N, const int step,
MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, OType> gen,
const index_t N, const index_t step,
index_t nParm, index_t nSample,
const IType *lower, const IType *upper, OType *out) {
RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
Expand Down
11 changes: 11 additions & 0 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ def test_ndarray_random_uniform():
a = nd.random.uniform(shape=(LARGE_X, SMALL_Y))
assert a[-1][0] != 0

def test_ndarray_random_randint():
ChaiBapchya marked this conversation as resolved.
Show resolved Hide resolved
a = nd.random.randint(100, 10000, shape=(LARGE_X, SMALL_Y))
assert a.shape == (LARGE_X, SMALL_Y)
# check if randint can generate value greater than 2**32 (large)
low_large_value = 2**32
high_large_value = 2**34
a = nd.random.randint(low_large_value,high_large_value)
low = mx.nd.array([low_large_value],dtype='int64')
high = mx.nd.array([high_large_value],dtype='int64')
assert a.__gt__(low) & a.__lt__(high)

def test_ndarray_empty():
a = nd.empty((LARGE_X, SMALL_Y))
assert a.shape == (LARGE_X, SMALL_Y)
Expand Down