From d7da15c8ad720f8c10370051c3529b4ec6d42c2e Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Wed, 15 May 2019 23:11:11 +0000 Subject: [PATCH] add ctx for rand_ndarray and rand_sparse_ndarray --- python/mxnet/test_utils.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 7b46be487488..fb40474bc678 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -260,7 +260,7 @@ def assign_each2(input1, input2, function): def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None, data_init=None, rsp_indices=None, modifier_func=None, - shuffle_csr_indices=False): + shuffle_csr_indices=False, ctx=None): """Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np) Parameters @@ -301,6 +301,7 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non >>> assert(row4nnz == 2*row3nnz) """ + ctx = ctx if ctx else default_context() density = rnd.rand() if density is None else density dtype = default_dtype() if dtype is None else dtype distribution = "uniform" if distribution is None else distribution @@ -315,7 +316,7 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non idx_sample = rnd.rand(shape[0]) indices = np.argwhere(idx_sample < density).flatten() if indices.shape[0] == 0: - result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype) + result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype, ctx=ctx) return result, (np.array([], dtype=dtype), np.array([])) # generate random values val = rnd.rand(indices.shape[0], *shape[1:]).astype(dtype) @@ -326,17 +327,17 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non if modifier_func is not None: val = assign_each(val, modifier_func) - arr = mx.nd.sparse.row_sparse_array((val, indices), shape=shape, dtype=dtype) + arr = mx.nd.sparse.row_sparse_array((val, indices), shape=shape, dtype=dtype, ctx=ctx) return arr, (val, indices) elif stype == 'csr': assert len(shape) == 2 if distribution == "uniform": csr = _get_uniform_dataset_csr(shape[0], shape[1], density, data_init=data_init, - shuffle_csr_indices=shuffle_csr_indices, dtype=dtype) + shuffle_csr_indices=shuffle_csr_indices, dtype=dtype).as_in_context(ctx) return csr, (csr.indptr, csr.indices, csr.data) elif distribution == "powerlaw": - csr = _get_powerlaw_dataset_csr(shape[0], shape[1], density=density, dtype=dtype) + csr = _get_powerlaw_dataset_csr(shape[0], shape[1], density=density, dtype=dtype).as_in_context(ctx) return csr, (csr.indptr, csr.indices, csr.data) else: assert(False), "Distribution not supported: %s" % (distribution) @@ -345,15 +346,17 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non assert(False), "unknown storage type" return False -def rand_ndarray(shape, stype='default', density=None, dtype=None, - modifier_func=None, shuffle_csr_indices=False, distribution=None): +def rand_ndarray(shape, stype='default', density=None, dtype=None, modifier_func=None, + shuffle_csr_indices=False, distribution=None, ctx=None): + """Generate a random sparse ndarray. Returns the generated ndarray.""" + ctx = ctx if ctx else default_context() if stype == 'default': - arr = mx.nd.array(random_arrays(shape), dtype=dtype) + arr = mx.nd.array(random_arrays(shape), dtype=dtype, ctx=ctx) else: arr, _ = rand_sparse_ndarray(shape, stype, density=density, modifier_func=modifier_func, dtype=dtype, shuffle_csr_indices=shuffle_csr_indices, - distribution=distribution) + distribution=distribution, ctx=ctx) return arr