From eeff444efc5bbe38cab31fef8d73fab9de36aa04 Mon Sep 17 00:00:00 2001 From: reminisce Date: Tue, 15 Aug 2017 09:17:21 -0700 Subject: [PATCH] kvstore.row_sparse_pull for GPU and end-to-end benchmark: CPU vs. multi-GPUs (#150) * Add gpu support for BroadcastRowSparse * Fix bugs * Add benchmark script * Increase output dim size * Update weight on CPU using single GPU for sparse tensors * More fix * Optimize sparse_retain for special case * Change row sparse pull locations * Avoid sparse retain on cpu if possible * Use acc for metric * Fix misc --- benchmark/python/sparse_end2end.py | 226 ++++++++++++++++++ include/mxnet/ndarray.h | 6 + src/executor/graph_executor.cc | 4 +- src/kvstore/comm.h | 110 +++++++-- src/kvstore/kvstore_local.h | 2 +- src/operator/tensor/sparse_retain-inl.h | 69 +++++- tests/python/gpu/test_kvstore_gpu.py | 51 ++++ tests/python/unittest/test_sparse_operator.py | 10 +- 8 files changed, 445 insertions(+), 33 deletions(-) create mode 100644 benchmark/python/sparse_end2end.py create mode 100644 tests/python/gpu/test_kvstore_gpu.py diff --git a/benchmark/python/sparse_end2end.py b/benchmark/python/sparse_end2end.py new file mode 100644 index 000000000000..62a3b77b8482 --- /dev/null +++ b/benchmark/python/sparse_end2end.py @@ -0,0 +1,226 @@ +from mxnet.test_utils import * +import time +import argparse +import os + +parser = argparse.ArgumentParser(description="Run sparse linear regression " \ + "with distributed kvstore", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--profiler', type=int, default=0, + help='whether to use profiler') +parser.add_argument('--num-epoch', type=int, default=1, + help='number of epochs to train') +parser.add_argument('--batch-size', type=int, default=512, + help='number of examples per batch') +parser.add_argument('--num-batch', type=int, default=99999999, + help='number of batches per epoch') +parser.add_argument('--dummy-iter', type=int, default=0, + help='whether to use dummy iterator to exclude io cost') +parser.add_argument('--kvstore', type=str, default='local', + help='what kvstore to use [local, dist_sync, etc]') +parser.add_argument('--log-level', type=str, default='debug', + help='logging level [debug, info, error]') +parser.add_argument('--dataset', type=str, default='avazu', + help='what test dataset to use') +parser.add_argument('--num-gpu', type=int, default=0, + help='number of gpus to use. 
0 means using cpu(0);' + 'otherwise, use gpu(0),...,gpu(num_gpu-1)') +parser.add_argument('--output-dim', type=int, default=4, + help='number of columns of the forward output') + + +def get_libsvm_data(data_dir, data_name, url, data_origin_name): + if not os.path.isdir(data_dir): + os.system("mkdir " + data_dir) + os.chdir(data_dir) + if (not os.path.exists(data_name)): + import urllib + zippath = os.path.join(data_dir, data_origin_name) + urllib.urlretrieve(url, zippath) + os.system("bzip2 -d %r" % data_origin_name) + os.chdir("..") + + +class DummyIter(mx.io.DataIter): + "A dummy iterator that always return the same batch, used for speed testing" + def __init__(self, real_iter): + super(DummyIter, self).__init__() + self.real_iter = real_iter + self.provide_data = real_iter.provide_data + self.provide_label = real_iter.provide_label + self.batch_size = real_iter.batch_size + + for batch in real_iter: + self.the_batch = batch + break + + def __iter__(self): + return self + + def next(self): + return self.the_batch + +# testing dataset sources +avazu = { + 'data_name': 'avazu-app.t', + 'data_origin_name': 'avazu-app.t.bz2', + 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2", + 'feature_dim': 1000000, +} + +kdda = { + 'data_name': 'kdda.t', + 'data_origin_name': 'kdda.t.bz2', + 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2", + 'feature_dim': 20216830, +} + +datasets = { 'kdda' : kdda, 'avazu' : avazu } + + +def get_sym(feature_dim): + x = mx.symbol.Variable("data", stype='csr') + norm_init = mx.initializer.Normal(sigma=0.01) + w = mx.symbol.Variable("w", shape=(feature_dim, args.output_dim), init=norm_init, stype='row_sparse') + embed = mx.symbol.dot(x, w) + y = mx.symbol.Variable("softmax_label") + model = mx.symbol.SoftmaxOutput(data=embed, label=y, name="out") + return model + + +def row_sparse_pull(kv, key, data, slices, weight_array, priority): + # if have kvstore, need to pull corresponding rows of + # the weights to each context + # column indices (NDArray type) of the csr data + # used as the row_idx of the weight row-sparse matrix + row_indices = data.indices + if len(slices) == 1: + kv.row_sparse_pull(key, weight_array, priority=priority, row_ids=row_indices) + else: # more than one slices, multi-GPU training. 
Need to retain weight rows according to data slices + # TODO(junwu): + # the following line blocks, may need to pre-compute + # and cache it outside the for loop + indptr = data.indptr.asnumpy() + row_idx_array = [] + for s in slices: + row_idx_array.append(row_indices[indptr[s.start]:indptr[s.stop]]) + kv.row_sparse_pull(key, weight_array, priority=priority, row_ids=row_idx_array) + + +if __name__ == '__main__': + + # arg parser + args = parser.parse_args() + num_epoch = args.num_epoch + num_batch = args.num_batch + kvstore = args.kvstore + profiler = args.profiler > 0 + batch_size = args.batch_size if args.num_gpu == 0 else args.num_gpu * args.batch_size + dummy_iter = args.dummy_iter + dataset = args.dataset + log_level = args.log_level + contexts = mx.context.cpu(0) if args.num_gpu < 1\ + else [mx.context.gpu(i) for i in range(args.num_gpu)] + + # create kvstore when there are gpus + kv = mx.kvstore.create(kvstore) if args.num_gpu >= 1 else None + rank = kv.rank if kv is not None else 0 + num_worker = kv.num_workers if kv is not None else 1 + + # only print log for rank 0 worker + import logging + if rank != 0: + log_level = logging.ERROR + elif log_level == 'DEBUG': + log_level = logging.DEBUG + else: + log_level = logging.INFO + head = '%(asctime)-15s %(message)s' + logging.basicConfig(level=log_level, format=head) + + # dataset + assert(dataset in datasets), "unknown dataset " + dataset + metadata = datasets[dataset] + feature_dim = metadata['feature_dim'] + if logging: + logging.debug('preparing data ... ') + data_dir = os.path.join(os.getcwd(), 'data') + path = os.path.join(data_dir, metadata['data_name']) + if not os.path.exists(path): + get_libsvm_data(data_dir, metadata['data_name'], metadata['url'], + metadata['data_origin_name']) + assert os.path.exists(path) + + # data iterator + train_data = mx.io.LibSVMIter(data_libsvm=path, data_shape=(feature_dim,), + batch_size=batch_size, num_parts=num_worker, + part_index=rank) + if dummy_iter: + train_data = DummyIter(train_data) + + # model + model = get_sym(feature_dim) + + # module + mod = mx.mod.Module(symbol=model, data_names=['data'], + label_names=['softmax_label'], context=contexts) + mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label) + mod.init_params(initializer=mx.init.Uniform(scale=.1)) + sgd = mx.optimizer.SGD(momentum=0.0, clip_gradient=5.0, + learning_rate=0.1, rescale_grad=1.0/batch_size/num_worker) + mod.init_optimizer(optimizer=sgd, kvstore=kv) + # use accuracy as the metric + metric = mx.metric.create('acc') + + index = mod._exec_group.param_names.index('w') + # weight_array bound to executors of the contexts + weight_array = mod._exec_group.param_arrays[index] + + # start profiler + if profiler: + device = 'cpu' + if args.num_gpu > 0: + device = 'gpu' + str(args.num_gpu) + name = 'profile_' + args.dataset + '_' + device + '_nworker' + str(num_worker)\ + + '_batchsize' + str(args.batch_size) + '_outdim' + str(args.output_dim) + '.json' + mx.profiler.profiler_set_config(mode='all', filename=name) + mx.profiler.profiler_set_state('run') + + logging.debug('start training ...') + start = time.time() + data_iter = iter(train_data) + for epoch in range(num_epoch): + nbatch = 0 + end_of_batch = False + data_iter.reset() + metric.reset() + next_batch = next(data_iter) + if kv is not None: + row_sparse_pull(kv, 'w', next_batch.data[0], mod._exec_group.slices, weight_array, -index) + while not end_of_batch: + nbatch += 1 + batch = next_batch + + mod.forward_backward(batch) + # update 
parameters
+            mod.update()
+
+            try:
+                # pre fetch next batch
+                next_batch = next(data_iter)
+                if nbatch == num_batch:
+                    raise StopIteration
+                if kv is not None:
+                    row_sparse_pull(kv, 'w', next_batch.data[0], mod._exec_group.slices, weight_array, -index)
+            except StopIteration:
+                end_of_batch = True
+            # accumulate prediction accuracy
+            mod.update_metric(metric, batch.label)
+        logging.info('epoch %d, %s' % (epoch, metric.get()))
+        if epoch == 0:
+            print "num_batches = ", nbatch
+    if profiler:
+        mx.profiler.profiler_set_state('stop')
+    end = time.time()
+    time_cost = end - start
+    logging.info('num_worker = ' + str(num_worker) + ', time cost = ' + str(time_cost))
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 1bafd8b272bd..56e36dffbf27 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -226,6 +226,12 @@ class NDArray {
     return ptr_->aux_shapes;
   }
 
+  /*! returns the dtypes of all aux data */
+  const std::vector<int>& aux_types() const {
+    CHECK(storage_type() != kDefaultStorage);
+    return ptr_->aux_types;
+  }
+
   /*!
    * \brief For a sparse operation on a csr matrix for example,
    * the size of the column index array
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index a22f030a28cb..9c4398343b1c 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1151,7 +1151,9 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
       CHECK_LE(nword, std::numeric_limits<index_t>::max());
       // allocate float arrays
       TShape shape{static_cast<index_t>(nword)};
-      NDArray nd(shape, ctx);
+      // TODO(junwu): adding delay_alloc=true to create nd
+      // is a temporary solution.
+      NDArray nd(shape, ctx, true);
       data_pool_[i] = nd;
       // put the new allocated arrays to shared pool
       if (shared_pool != nullptr) {
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index a8f78425f9e4..cd0d3ab02825 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -112,9 +112,9 @@ class CommCPU : public Comm {
     // avoid extra copy for single device, but it may bring problems for
     // abnormal usage of kvstore
     if (src.size() == 1) {
-      if (src[0].storage_type() == buf.merged.storage_type()) {
+      if (src[0].storage_type() == kDefaultStorage) {
         return src[0];
-      } else {
+      } else {  // if sparse and only one GPU, always update weight on CPU
         CopyFromTo(src[0], &buf.merged, priority);
         return buf.merged;
       }
@@ -188,39 +188,113 @@ class CommCPU : public Comm {
     }
   }
 
-  // TODO(haibin) support broadcast row_sparse on GPU
   void BroadcastRowSparse(int key, const NDArray& src,
                           const std::vector<std::pair<NDArray*, NDArray>>& dst,
                           const bool use_copy,
                           const int priority) override {
     using namespace mshadow;
-    auto size = dst.size();
-    for (size_t i = 0; i < size; i++) {
-      auto out = dst[i].first;
-      auto row_id = dst[i].second;
+    CHECK_EQ(src.storage_type(), kRowSparseStorage)
+      << "BroadcastRowSparse expects row-sparse src NDArray";
+    CHECK_EQ(src.ctx().dev_mask(), Context::kCPU)
+      << "BroadcastRowSparse with src on gpu context not supported";
+    for (size_t i = 0; i < dst.size(); ++i) {
+      NDArray* out = dst[i].first;
+      NDArray row_id = dst[i].second;
       if (use_copy) {
         CopyFromTo(src, out, priority);
       } else {
         CHECK_EQ(out->storage_type(), kRowSparseStorage)
                  << "BroadcastRowSparse expects row_sparse dst NDArray";
-        CHECK_EQ(out->ctx().dev_mask(), Context::kCPU)
-                 << "BroadcastRowSparse with dst on gpu context not supported";
         CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU)
-                 << "BroadcastRowSparse with src on gpu context not supported";
+                 << "BroadcastRowSparse with row_indices on gpu context not supported";
         // retain according to unique indices
-        Engine::Get()->PushSync([src, out, row_id](RunContext rctx) {
-            NDArray *output = out;
-            const auto indices = row_id.data();
-            op::SparseRetainOpForwardRspImpl<cpu>(rctx.get_stream<cpu>(),
-                                                  src, indices, kWriteTo,
-                                                  output);
-          }, Context::CPU(), {src.var(), row_id.var()}, {out->var()},
-          FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
+        const bool use_sparse_retain = (src.shape()[0] != src.storage_shape()[0])
+          || (row_id.dtype() != out->aux_type(rowsparse::kIdx))
+          || (out->ctx().dev_mask() != Context::kGPU);
+        if (use_sparse_retain) {  // use sparse_retain op
+          const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU;
+          NDArray out_cpu = is_to_gpu? NDArray(kRowSparseStorage, src.shape(),
+              src.ctx(), true, src.dtype(), src.aux_types()) : *out;
+          Engine::Get()->PushSync([=](RunContext rctx) {
+              const TBlob& indices = row_id.data();
+              NDArray temp = out_cpu;  // get rid of const qualifier
+              op::SparseRetainOpForwardRspImpl<cpu>(rctx.get_stream<cpu>(),
+                                                    src, indices, kWriteTo,
+                                                    &temp);
+            }, Context::CPU(), {src.var(), row_id.var()}, {out_cpu.var()},
+            FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
+          if (is_to_gpu) {
+            CopyFromTo(out_cpu, out, priority);
+          }
+        } else {  // direct copy rows
+          Engine::Get()->PushSync([=](RunContext rctx) {
+              CopyRetainedRowsToGPU(rctx.get_stream<cpu>(), rctx.get_stream<gpu>(),
+                                    src, row_id, out);
+            }, out->ctx(), {src.var(), row_id.var()}, {out->var()},
+            FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("KVStoreCopyRetainedRowsToGPU"));
+        }
       }
     }
   }
 
  private:
+  /*!
+   * \brief When src is a rsp with full rows,
+   * simply copy retained rows directly from cpu to gpu
+   * without invoking sparse_retain op.
+   */
+  void CopyRetainedRowsToGPU(mshadow::Stream<cpu>* cpu_stream,
+                             mshadow::Stream<gpu>* gpu_stream,
+                             const NDArray& src,
+                             const NDArray& indices,
+                             NDArray* dst) {
+#if MXNET_USE_CUDA == 1
+    CHECK_EQ(src.storage_type(), kRowSparseStorage)
+      << "CopyRetainedRowsToGPU expects row-sparse src NDArray";
+    CHECK_EQ(src.ctx().dev_mask(), Context::kCPU)
+      << "CopyRetainedRowsToGPU with src on gpu context not supported";
+    CHECK_EQ(src.storage_shape()[0], src.shape()[0])
+      << "CopyRetainedRowsToGPU only supports src rsp with full rows";
+    CHECK_EQ(indices.storage_type(), kDefaultStorage);
+    CHECK_EQ(indices.ctx().dev_mask(), Context::kCPU);
+    CHECK_EQ(dst->storage_type(), kRowSparseStorage);
+    CHECK_EQ(dst->ctx().dev_mask(), Context::kGPU);
+    CHECK_EQ(indices.dtype(), dst->aux_type(rowsparse::kIdx))
+      << "CopyRetainedRowsToGPU only supports same data type for idx array and dst aux_data(0)";
+    if (!src.storage_initialized() || indices.data().Size() == 0U) {
+      op::FillZerosRspImpl(gpu_stream, dst);
+      return;
+    }
+    using namespace mshadow;
+
+    const TBlob& src_data = src.data();
+    const TBlob& idx_data = indices.data();
+    const size_t row_length = src.shape().ProdShape(1, src.shape().ndim());
+    const size_t num_rows_retained = idx_data.Size();
+    dst->CheckAndAlloc({Shape1(num_rows_retained)});
+    TBlob dst_data = dst->data();
+    TBlob dst_idx_data = dst->aux_data(rowsparse::kIdx);
+    MSHADOW_TYPE_SWITCH(src.dtype(), DType, {
+      MSHADOW_IDX_TYPE_SWITCH(indices.dtype(), IType, {
+        // copy idx array
+        Tensor<gpu, 1, IType> dst_idx_tensor = dst_idx_data.FlatTo1D<gpu, IType>(gpu_stream);
+        const Tensor<cpu, 1, IType> idx_tensor = idx_data.FlatTo1D<cpu, IType>(cpu_stream);
+        Copy(dst_idx_tensor, idx_tensor, gpu_stream);
+        // copy src data
+        const Tensor<cpu, 2, DType> src_data_tensor = src_data.get_with_shape<cpu, 2, DType>(
+            Shape2(src_data.shape_[0], row_length), cpu_stream);
+        Tensor<gpu, 2, DType> dst_data_tensor = dst_data.get_with_shape<gpu, 2, DType>(
+            Shape2(dst_data.shape_[0], row_length), gpu_stream);
+        for (size_t i = 0; i < num_rows_retained; ++i) {
+          Copy(dst_data_tensor[i], src_data_tensor[idx_tensor[i]], gpu_stream);
+        }
+      })
+    })
+#else
+    LOG(FATAL) << "GPU not enabled";
+#endif
+  }
+
   // reduce sum into val[0]
   inline void ReduceSumCPU(const std::vector<NDArray> &in_data) {
     MSHADOW_TYPE_SWITCH(in_data[0].dtype(), DType, {
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 3e9f600fc243..d8c399edf017 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -126,7 +126,7 @@ class KVStoreLocal : public KVStore {
   void PullRowSparse(const std::vector<int>& keys,
                      const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
-                     int priority = 0) {
+                     int priority = 0) override {
     std::vector<int> uniq_keys;
     std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids;
     GroupKVPairs(keys, val_rowids, &uniq_keys, &grouped_val_rowids);
diff --git a/src/operator/tensor/sparse_retain-inl.h b/src/operator/tensor/sparse_retain-inl.h
index c88e3d373d0c..5add57c83b24 100644
--- a/src/operator/tensor/sparse_retain-inl.h
+++ b/src/operator/tensor/sparse_retain-inl.h
@@ -135,8 +135,8 @@ struct SparseRetainRspThreadKernel {
 };
 
 /*!
- * \brief This kernel is invoked when the input row-sparse
- * is actually dense.
+ * \brief This kernel should be invoked when the row indices
+ * to be retained are all in the input rsp.
  * Each thread searches for a subarray of indices of
  * the user-input idx array for retain. The first index
  * in the subarray will be searched for using binary search.
@@ -198,6 +198,36 @@ struct SparseRetainRspRowBlockKernel {
   }
 };
 
+/*!
+ * Copy input indices to output indices.
+ * Only used when input rsp is dense.
+ */
+struct SparseRetainCopyIndices {
+  template<typename RType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, RType* out_idx, IType* idx) {
+    out_idx[i] = idx[i];
+  }
+};
+
+/*!
+ * Copy input retained rows to output rows.
+ * Only used when input rsp is dense.
+ * This kernel is only used when ctx is on GPU.
+ * So it's parallelized by out_rows' elements,
+ * instead of rows.
+ * For CPU ctx, we simply call mshadow::Copy.
+ */
+struct SparseRetainCopyRetainedRowsFromDns {
+  template<typename DType, typename RType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_rows, const DType* in_rows,
+                                  const RType* in_row_idx, const IType* idx,
+                                  const size_t row_length) {
+    const size_t irow = i / row_length;
+    const size_t icol = i % row_length;
+    out_rows[i] = in_rows[static_cast<size_t>(idx[irow]) * row_length + icol];
+  }
+};
+
 template<typename xpu>
 void SparseRetainOpForwardRspImpl(mshadow::Stream<xpu> *s,
                                   const NDArray& input_nd,
@@ -205,6 +235,7 @@ void SparseRetainOpForwardRspImpl(mshadow::Stream<xpu> *s,
                                   const OpReqType req,
                                   NDArray* output_nd) {
   if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteTo) << "SparseRetainOpForwardRspImpl only support req = kWriteTo now";
   CHECK_EQ(input_nd.storage_type(), kRowSparseStorage)
     << "SparseRetainOpForwardRspImpl operator only takes row sparse NDArray as input";
   CHECK_EQ(output_nd->storage_type(), kRowSparseStorage)
@@ -231,13 +262,33 @@ void SparseRetainOpForwardRspImpl(mshadow::Stream<xpu> *s,
   MSHADOW_IDX_TYPE_SWITCH(output_idx.type_flag_, RType, {  // row index data type
     MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, {  // index array data type
       if (input_idx.Size() == input_nd.shape()[0]) {  // input rsp is dense
-        int num_threads = get_num_threads<xpu>(idx_data.Size());
-        size_t seg_len = (idx_data.Size() + num_threads - 1) / num_threads;
-        Kernel<SparseRetainRspRowBlockKernel, xpu>::Launch(s, num_threads,
-            output_data.dptr<DType>(), output_idx.dptr<RType>(), input_data.dptr<DType>(),
-            input_idx.dptr<RType>(), idx_data.dptr<IType>(), idx_data.Size(),
-            input_data.shape_[0], row_length, seg_len);
-      } else {
+        using namespace mshadow;
+        // copy indices
+        Tensor<xpu, 1, RType> output_idx_tensor = output_idx.FlatTo1D<xpu, RType>(s);
+        const size_t num_rows_retained = output_idx.Size();
+        if (output_idx.type_flag_ == idx_data.type_flag_) {  // same type, use Copy
+          const Tensor<xpu, 1, IType> idx_tensor = idx_data.FlatTo1D<xpu, IType>(s);
+          Copy(output_idx_tensor, idx_tensor, s);
+        } else {  // different index types, use Kernel::Launch
+          Kernel<SparseRetainCopyIndices, xpu>::Launch(s, num_rows_retained,
+            output_idx.dptr<RType>(), idx_data.dptr<IType>());
+        }
+        // copy data
+        if (std::is_same<xpu, cpu>::value) {  // For cpu, we can access output_idx_tensor[i]
+          const Tensor<xpu, 2, DType> input_tensor =
+            input_data.get_with_shape<xpu, 2, DType>(Shape2(input_data.shape_[0], row_length), s);
+          Tensor<xpu, 2, DType> output_tensor =
+            output_data.get_with_shape<xpu, 2, DType>(Shape2(output_data.shape_[0], row_length),
+                                                      s);
+          for (size_t i = 0; i < num_rows_retained; ++i) {
+            Copy(output_tensor[i], input_tensor[output_idx_tensor[i]], s);
+          }
+        } else {  // For gpu, have to kernel launch
+          Kernel<SparseRetainCopyRetainedRowsFromDns, xpu>::Launch(s, output_data.Size(),
+            output_data.dptr<DType>(), input_data.dptr<DType>(), input_idx.dptr<RType>(),
+            idx_data.dptr<IType>(), row_length);
+        }
+      } else {  // input rsp is not dense
         Kernel<SparseRetainRspThreadKernel, xpu>::Launch(s, idx_data.Size(),
           output_data.dptr<DType>(), output_idx.dptr<RType>(), input_data.dptr<DType>(),
           input_idx.dptr<RType>(), idx_data.dptr<IType>(), input_data.shape_[0], row_length);
diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py
new file mode 100644
index 000000000000..6d3ba989a84f
--- /dev/null
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -0,0 +1,51 @@
+# pylint: skip-file
+import mxnet as mx
+import numpy as np
+from mxnet.test_utils import assert_almost_equal, default_context
+
+shape = (4, 4)
+keys = [5, 7, 11]
+str_keys = ['b', 'c', 'd']
+
+
+def init_kv_with_str(stype='default'):
+    """init kv """
+    kv = mx.kv.create()
+    # single
+    kv.init('a', mx.nd.zeros(shape, stype=stype))
+    # list
+    kv.init(str_keys, [mx.nd.zeros(shape=shape, stype=stype)] * len(keys))
+    return kv
+
+
+def test_row_sparse_pull():
+    kv = init_kv_with_str('row_sparse')
+    kv.init('e', mx.nd.ones(shape).tostype('row_sparse'))
+
+    def check_row_sparse_pull(kv, count, ctx=default_context()):
+        num_rows = shape[0]
+        vals = []
+        row_ids = []
+        all_row_ids = np.arange(num_rows)
+        for i in range(count):
+            vals.append(mx.nd.zeros(shape, ctx=ctx).tostype('row_sparse'))
+            row_id = np.random.randint(num_rows, size=num_rows)
+            row_ids.append(mx.nd.array(row_id, dtype='int64'))
+        row_ids_to_pull = row_ids[0] if len(row_ids) == 1 else row_ids
+        vals_to_pull = vals[0] if len(vals) == 1 else vals
+
+        kv.row_sparse_pull('e', out=vals_to_pull, row_ids=row_ids_to_pull)
+        for val, row_id in zip(vals, row_ids):
+            retained = val.asnumpy()
+            excluded_row_ids = np.setdiff1d(all_row_ids, row_id.asnumpy())
+            for row in range(num_rows):
+                expected_val = np.zeros_like(retained[row])
+                expected_val += 0 if row in excluded_row_ids else 1
+                assert_almost_equal(retained[row], expected_val)
+
+    check_row_sparse_pull(kv, 1, mx.gpu(0))
+    check_row_sparse_pull(kv, 4, mx.gpu(0))
+
+
+if __name__ == '__main__':
+    test_row_sparse_pull()
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 58ab53d026f5..748a89990cbd 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -187,7 +187,7 @@ def check_csr_slice(shape, slice_input):
 
 
 def test_sparse_retain():
-    def check_sparse_retain(shape, density):
+    def check_sparse_retain(shape, density, index_type=np.int64):
         num_rows = shape[0]
         rsp, _ = rand_sparse_ndarray(shape=shape, stype='row_sparse', density=density)
         length = np.random.randint(1, num_rows + 1)
@@ -197,7 +197,7 @@ def check_sparse_retain(shape, density):
         tensor_retained_expected = np.zeros(shape)
         for i in idx:
             tensor_retained_expected[i][:] = dns[i]
-        indices = mx.nd.array(idx)
+        indices = mx.nd.array(idx, dtype=index_type)
         rsp_retained = mx.nd.sparse_retain(rsp, indices=indices)
         assert same(tensor_retained_expected, rsp_retained.asnumpy())
 
@@ -211,9 +211,11 @@ def check_sparse_retain(shape, density):
     shape = rand_shape_2d()
     shape_3d = rand_shape_3d()
    densities = [0.01, 0.1, 0.2, 0.5, 0.8, 1.0]
+    index_types = [np.float32, np.int32, np.int64]
     for density in densities:
-        check_sparse_retain(shape, density)
-        check_sparse_retain(shape_3d, density)
+        for itype in index_types:
+            check_sparse_retain(shape, density, itype)
+            check_sparse_retain(shape_3d, density, itype)
 
 
 def test_sparse_nd_zeros():
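
For reference, a minimal usage sketch of the kvstore.row_sparse_pull API that this patch extends to GPU destinations. It mirrors tests/python/gpu/test_kvstore_gpu.py above and assumes a local kvstore; ctx is set to mx.cpu() so it runs on any build, and switching it to mx.gpu(0) on a CUDA-enabled build exercises the new BroadcastRowSparse row-copy path.

    import mxnet as mx

    shape = (4, 4)
    ctx = mx.cpu()  # assumption: use mx.gpu(0) on a CUDA-enabled build
    kv = mx.kv.create('local')
    # the server-side weight is stored as a row_sparse NDArray on CPU
    kv.init('w', mx.nd.ones(shape).tostype('row_sparse'))
    # destination buffer on the target device
    out = mx.nd.zeros(shape, ctx=ctx).tostype('row_sparse')
    # pull only rows 0 and 2; only these rows are retained and copied into out
    kv.row_sparse_pull('w', out=out, row_ids=mx.nd.array([0, 2], dtype='int64'))
    print(out.asnumpy())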