diff --git a/benchmark/python/cast_storage.py b/benchmark/python/sparse/cast_storage.py similarity index 100% rename from benchmark/python/cast_storage.py rename to benchmark/python/sparse/cast_storage.py diff --git a/benchmark/python/dot.py b/benchmark/python/sparse/dot.py similarity index 100% rename from benchmark/python/dot.py rename to benchmark/python/sparse/dot.py diff --git a/benchmark/python/sparse_end2end.py b/benchmark/python/sparse/sparse_end2end.py similarity index 100% rename from benchmark/python/sparse_end2end.py rename to benchmark/python/sparse/sparse_end2end.py diff --git a/benchmark/python/sparse_op.py b/benchmark/python/sparse/sparse_op.py similarity index 100% rename from benchmark/python/sparse_op.py rename to benchmark/python/sparse/sparse_op.py diff --git a/benchmark/python/util.py b/benchmark/python/sparse/util.py similarity index 100% rename from benchmark/python/util.py rename to benchmark/python/sparse/util.py diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 84759263007c..2af70e36e60a 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -234,14 +234,6 @@ def pull(self, key, out=None, priority=0): [ 2. 2. 2.]] """ assert(out is not None) - if not isinstance(out, (list, tuple)): - out = [out] - for val in out: - if not isinstance(val, (list, tuple)): - assert(val.stype == 'default') - else: - for v in val: - assert(v.stype == 'default') ckeys, cvals = _ctype_key_value(key, out) check_call(_LIB.MXKVStorePullEx( self.handle, mx_uint(len(ckeys)), ckeys, cvals, @@ -270,8 +262,8 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None): other pull actions. row_ids : NDArray or list of NDArray - The row_ids for which to pull for each value. The row_ids doesn't have to be unique - or sorted. + The row_ids for which to pull for each value. Each row_id is an 1D-NDArray \ + whose values don't have to be unique nor sorted. 
Examples -------- @@ -299,16 +291,6 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None): """ assert(out is not None) assert(row_ids is not None) - if isinstance(row_ids, NDArray): - row_ids = [row_ids] - if not isinstance(out, (list, tuple)): - out = [out] - for val in out: - if not isinstance(val, (list, tuple)): - assert(val.stype == 'row_sparse') - else: - for v in val: - assert(v.stype == 'row_sparse') ckeys, cvals = _ctype_key_value(key, out) _, crow_ids = _ctype_key_value(key, row_ids) assert(len(crow_ids) == len(cvals)), "number of row_ids doesn't match number of values" diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 38bb15484e7b..2444ca0dc59e 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -93,16 +93,6 @@ def _create_kvstore(kvstore, num_device, arg_params): return (kv, update_on_kvstore) -def _contains_non_default_storage(params): - if isinstance(params, (list, tuple)): - for param in params: - if param.stype != 'default': - return True - elif isinstance(params, NDArray): - return param.stype != 'default' - else: - return False - def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, update_on_kvstore): """Initialize kvstore""" for idx, param_on_devs in enumerate(param_arrays): @@ -110,12 +100,7 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, update_o kvstore.init(name, arg_params[name]) if update_on_kvstore: - if _contains_non_default_storage(param_on_devs): - # skip pulling row_sparse weights - warnings.warn('Detected non-default weight in kvstore to pull. 
Please make ' \ - 'sure to pull it with row_ids explicitly', RuntimeWarning) - else: - kvstore.pull(name, param_on_devs, priority=-idx) + kvstore.pull(name, param_on_devs, priority=-idx) def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names): """Perform update of param_arrays from grad_arrays on kvstore.""" @@ -127,12 +112,7 @@ def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names): # push gradient, priority is negative index kvstore.push(name, grad_list, priority=-index) # pull back the weights - if _contains_non_default_storage(arg_list): - # skip pulling row_sparse weights - warnings.warn('Detected non-default weight in kvstore to pull. Please make ' \ - 'sure to pull it with row_ids', RuntimeWarning) - else: - kvstore.pull(name, arg_list, priority=-index) + kvstore.pull(name, arg_list, priority=-index) def _update_params(param_arrays, grad_arrays, updater, num_device, kvstore=None, param_names=None): @@ -147,12 +127,7 @@ def _update_params(param_arrays, grad_arrays, updater, num_device, # push gradient, priority is negative index kvstore.push(name, grad_list, priority=-index) # pull back the sum gradients, to the same locations. - if _contains_non_default_storage(grad_list): - # skip pulling row_sparse weights - warnings.warn('Detected non-default weight in kvstore to pull. 
Please make ' \ - 'sure to pull it with row_ids', RuntimeWarning) - else: - kvstore.pull(name, grad_list, priority=-index) + kvstore.pull(name, grad_list, priority=-index) for k, p in enumerate(zip(arg_list, grad_list)): # faked an index here, to make optimizer create diff # state for the same index but on diff devs, TODO(mli) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 46f14b52b03b..c6f4bdabbcb8 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -29,7 +29,6 @@ import errno import logging from contextlib import contextmanager -import scipy.sparse as sp import numpy as np import numpy.testing as npt import numpy.random as rnd @@ -125,6 +124,7 @@ def _get_uniform_dataset_csr(num_rows, num_cols, density=0.1, dtype=None): """ _validate_csr_generation_inputs(num_rows, num_cols, density, distribution="uniform") + from scipy import sparse as sp csr = sp.rand(num_rows, num_cols, density, dtype=dtype, format="csr") result = mx.nd.csr_matrix(csr.data, csr.indptr, csr.indices, (num_rows, num_cols), dtype=dtype) diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index b33c5081e7af..399754f5406d 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -111,7 +111,7 @@ class KVStoreDist : public KVStoreLocal { int priority) override { std::vector uniq_keys; std::vector > grouped_vals; - GroupKVPairs(keys, values, &uniq_keys, &grouped_vals); + GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; @@ -160,7 +160,7 @@ class KVStoreDist : public KVStoreLocal { const int priority = 0) { std::vector uniq_keys; std::vector>> grouped_val_rowids; - GroupKVPairs(keys, val_rowids, &uniq_keys, &grouped_val_rowids); + GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; @@ -261,7 +261,7 @@ class KVStoreDist : public KVStoreLocal { // 
first aggregate the values over keys std::vector uniq_keys; std::vector > grouped_vals; - GroupKVPairs(keys, values, &uniq_keys, &grouped_vals); + GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals); for (size_t i = 0; i < uniq_keys.size(); ++i) { // merge over devcies diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index d8c399edf017..11d4b644346e 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -30,6 +30,7 @@ #include #include #include +#include #include #include "./comm.h" @@ -85,7 +86,7 @@ class KVStoreLocal : public KVStore { int priority) override { std::vector uniq_keys; std::vector > grouped_vals; - GroupKVPairs(keys, values, &uniq_keys, &grouped_vals); + GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; @@ -114,7 +115,7 @@ class KVStoreLocal : public KVStore { int priority) override { std::vector uniq_keys; std::vector > grouped_vals; - GroupKVPairs(keys, values, &uniq_keys, &grouped_vals); + GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; @@ -129,7 +130,7 @@ class KVStoreLocal : public KVStore { int priority = 0) override { std::vector uniq_keys; std::vector>> grouped_val_rowids; - GroupKVPairs(keys, val_rowids, &uniq_keys, &grouped_val_rowids); + GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; const NDArray& local = local_[key]; @@ -174,13 +175,75 @@ class KVStoreLocal : public KVStore { protected: /** - * \brief group values on keys + * \brief group values on keys for push */ - template + void GroupKVPairsPush(const std::vector& keys, + const std::vector& values, + std::vector *uniq_keys, + std::vector> *grouped_vals) { + // check if the storage type of a value is valid + auto validator = [this](const int key, const NDArray& nd) -> bool { + 
auto stype = nd.storage_type(); + // valid NDArray + if (stype == kDefaultStorage || stype == kRowSparseStorage) return true; + // invalid NDArray, abort + LOG(FATAL) << "Unexpected storage type detected during kvstore push: " << stype; + return false; + }; + GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator); + } + /** + * \brief group values on keys for pull + */ + void GroupKVPairsPull(const std::vector& keys, + const std::vector& values, + std::vector *uniq_keys, + std::vector> *grouped_vals) { + // check if the storage type of a value is valid + auto validator = [this](const int key, const NDArray* nd) -> bool { + // valid + if (nd->storage_type() == kDefaultStorage) return true; + // invalid, print warning messages once + if (this->warnings_printed_.find(key) == this->warnings_printed_.end()) { + LOG(INFO) << "Warning: non-default weights detected during kvstore pull. " + << "Please make sure to use row_sparse_pull with row_ids instead."; + this->warnings_printed_.insert(key); + } + return false; + }; + GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator); + } + /** + * \brief group values on keys for row_sparse_pull + */ + void GroupKVPairsPullRsp(const std::vector& keys, + const std::vector>& values, + std::vector *uniq_keys, + std::vector>> *grouped_vals) { + // check if the storage type of a value is valid + auto validator = [this](const int key, const std::pair& val_rowid) -> bool { + auto val_stype = val_rowid.first->storage_type(); + auto rowid_stype = val_rowid.second.storage_type(); + // check storage types + CHECK_EQ(val_stype, kRowSparseStorage) << "Expected row_sparse storage type for " + << "row_sparse_pull values, but detected storage type " << val_stype; + CHECK_EQ(rowid_stype, kDefaultStorage) << "Expected default storage type for " + << "row_sparse_pull rowids, but detected storage type " << rowid_stype; + return true; + }; + GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator); + } + + /** + * \brief group 
values on keys with validation. + * A value `v` is not included in the result if is_valid(v) returns false. + */ + template void GroupKVPairs(const std::vector& keys, const std::vector& values, std::vector* uniq_keys, - std::vector >* grouped_vals) { + std::vector >* grouped_vals, + const FValidate& is_valid) { CHECK_EQ(keys.size(), values.size()); // TODO(mli) check if already sorted as an optimization using Idx = std::pair; @@ -194,12 +257,14 @@ class KVStoreLocal : public KVStore { int pre_key = idx[0].first - 1; for (auto i : idx) { - if (i.first != pre_key) { - uniq_keys->push_back(i.first); - grouped_vals->push_back({values[i.second]}); - pre_key = i.first;; - } else { - grouped_vals->back().push_back(values[i.second]); + if (is_valid(i.first, values[i.second])) { + if (i.first != pre_key) { + uniq_keys->push_back(i.first); + grouped_vals->push_back({values[i.second]}); + pre_key = i.first; + } else { + grouped_vals->back().push_back(values[i.second]); + } } } } @@ -246,6 +311,8 @@ class KVStoreLocal : public KVStore { std::unordered_map str_key_dict_; /// the next available integer for string->int key mapping int next_str_key_ = 0; + /// whether printed warning due to mismatch stype in each key + std::unordered_set warnings_printed_; }; } // namespace kvstore } // namespace mxnet diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index c517da65de92..a43b98a635fb 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -52,7 +52,7 @@ def test_single_kv_pair(): def check_single_kv_pair(kv, key): kv.push(key, mx.nd.ones(shape)) val = mx.nd.empty(shape) - kv.pull(key, out = val) + kv.pull(key, out=val) check_diff_to_scalar(val, 1) check_single_kv_pair(init_kv(), 3) @@ -102,7 +102,7 @@ def test_list_kv_pair(): def check_list_kv_pair(kv, key): kv.push(key, [mx.nd.ones(shape)*4] * len(key)) val = [mx.nd.empty(shape)] * len(key) - kv.pull(key, out = val) + kv.pull(key, out=val) for v in 
val: check_diff_to_scalar(v, 4) @@ -122,7 +122,7 @@ def check_aggregator(kv, key, key_list): vals = [mx.nd.ones(shape, d) for d in devs] kv.push(key, vals) - kv.pull(key, out = vals) + kv.pull(key, out=vals) for v in vals: check_diff_to_scalar(v, num_devs) @@ -130,7 +130,7 @@ def check_aggregator(kv, key, key_list): # list vals = [[mx.nd.ones(shape, d)*2.0 for d in devs]] * len(key_list) kv.push(key_list, vals) - kv.pull(key_list, out = vals) + kv.pull(key_list, out=vals) for vv in vals: for v in vv: @@ -196,7 +196,7 @@ def check_updater(kv, key, key_list): vals = [mx.nd.ones(shape, d) for d in devs] kv.push(key, vals) - kv.pull(key, out = vals) + kv.pull(key, out=vals) for v in vals: check_diff_to_scalar(v, num_devs) @@ -208,7 +208,7 @@ def check_updater(kv, key, key_list): for i in range(num_push): kv.push(key_list, vals) - kv.pull(key_list, out = vals) + kv.pull(key_list, out=vals) for vv in vals: for v in vv: @@ -227,6 +227,45 @@ def test_get_type(): kv = mx.kv.create(kvtype) assert kv.type == kvtype +def test_invalid_pull(): + def check_invalid_single_kv_pair(kv, key): + dns_val = mx.nd.ones(shape) * 2 + rsp_val = dns_val.tostype('row_sparse') + kv.pull(key, out=rsp_val) + # pull should be ignored with no values updated + check_diff_to_scalar(rsp_val, 2) + try: + # row_sparse_pull should be aborted when vals.stype != row_sparse + kv.row_sparse_pull(key, out=dns_val, row_ids=mx.nd.array([1])) + except mx.base.MXNetError: + pass + else: + assert False, 'row_sparse_pull accepted a non-row_sparse output' + + def check_invalid_list_kv_pair(kv, key): + dns_val = [mx.nd.ones(shape) * 2] * len(key) + rsp_val = [val.tostype('row_sparse') for val in dns_val] + kv.pull(key, out=rsp_val) + for v in rsp_val: + # pull should be ignored with no values updated + check_diff_to_scalar(v, 2) + try: + # row_sparse_pull should be aborted when vals.stype != row_sparse + kv.row_sparse_pull(key, out=dns_val, row_ids=[mx.nd.array([1])] * len(key)) + except mx.base.MXNetError: + pass + else: + assert False, 'row_sparse_pull accepted non-row_sparse outputs' + + int_kv = init_kv() + str_kv = init_kv_with_str() + + 
check_invalid_single_kv_pair(int_kv, 3) + check_invalid_single_kv_pair(str_kv, 'a') + + check_invalid_list_kv_pair(int_kv, keys) + check_invalid_list_kv_pair(str_kv, str_keys) + if __name__ == '__main__': test_init() test_get_type() diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 9e8ace563e0d..003336d97ec4 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -519,9 +519,8 @@ def fm(factor_size, feature_dim, init): # initialize parameters by uniform random numbers mod.init_params(initializer=init) # use Sparse SGD with learning rate 0.1 to train - sgd = mx.optimizer.SGD(momentum=0.1, clip_gradient=5.0, learning_rate=0.01, - rescale_grad=1.0/batch_size) - mod.init_optimizer(optimizer=sgd) + adam = mx.optimizer.Adam(clip_gradient=5.0, learning_rate=0.001, rescale_grad=1.0/batch_size) + mod.init_optimizer(optimizer=adam) # use accuracy as the metric metric = mx.metric.create('MSE') # train 10 epoch