Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

add ctx argument for rand_ndarray and rand_sparse_ndarray test util funcs #14966

Merged
merged 1 commit into from
May 18, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def assign_each2(input1, input2, function):

def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None,
data_init=None, rsp_indices=None, modifier_func=None,
shuffle_csr_indices=False):
shuffle_csr_indices=False, ctx=None):
"""Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np)

Parameters
Expand Down Expand Up @@ -301,6 +301,7 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
>>> assert(row4nnz == 2*row3nnz)

"""
ctx = ctx if ctx else default_context()
density = rnd.rand() if density is None else density
dtype = default_dtype() if dtype is None else dtype
distribution = "uniform" if distribution is None else distribution
Expand All @@ -315,7 +316,7 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
idx_sample = rnd.rand(shape[0])
indices = np.argwhere(idx_sample < density).flatten()
if indices.shape[0] == 0:
result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype)
result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype, ctx=ctx)
return result, (np.array([], dtype=dtype), np.array([]))
# generate random values
val = rnd.rand(indices.shape[0], *shape[1:]).astype(dtype)
Expand All @@ -326,17 +327,17 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
if modifier_func is not None:
val = assign_each(val, modifier_func)

arr = mx.nd.sparse.row_sparse_array((val, indices), shape=shape, dtype=dtype)
arr = mx.nd.sparse.row_sparse_array((val, indices), shape=shape, dtype=dtype, ctx=ctx)
return arr, (val, indices)
elif stype == 'csr':
assert len(shape) == 2
if distribution == "uniform":
csr = _get_uniform_dataset_csr(shape[0], shape[1], density,
data_init=data_init,
shuffle_csr_indices=shuffle_csr_indices, dtype=dtype)
shuffle_csr_indices=shuffle_csr_indices, dtype=dtype).as_in_context(ctx)
return csr, (csr.indptr, csr.indices, csr.data)
elif distribution == "powerlaw":
csr = _get_powerlaw_dataset_csr(shape[0], shape[1], density=density, dtype=dtype)
csr = _get_powerlaw_dataset_csr(shape[0], shape[1], density=density, dtype=dtype).as_in_context(ctx)
return csr, (csr.indptr, csr.indices, csr.data)
else:
assert(False), "Distribution not supported: %s" % (distribution)
Expand All @@ -345,15 +346,17 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
assert(False), "unknown storage type"
return False

def rand_ndarray(shape, stype='default', density=None, dtype=None,
modifier_func=None, shuffle_csr_indices=False, distribution=None):
def rand_ndarray(shape, stype='default', density=None, dtype=None, modifier_func=None,
shuffle_csr_indices=False, distribution=None, ctx=None):
"""Generate a random sparse ndarray. Returns the generated ndarray."""
ctx = ctx if ctx else default_context()
if stype == 'default':
arr = mx.nd.array(random_arrays(shape), dtype=dtype)
arr = mx.nd.array(random_arrays(shape), dtype=dtype, ctx=ctx)
else:
arr, _ = rand_sparse_ndarray(shape, stype, density=density,
modifier_func=modifier_func, dtype=dtype,
shuffle_csr_indices=shuffle_csr_indices,
distribution=distribution)
distribution=distribution, ctx=ctx)
return arr


Expand Down