Skip to content

Commit

Permalink
changes based on code reviews (#176)
Browse files Browse the repository at this point in the history
* remove scipy dependency

* move kvstore checks to backned

* add const to lambda
  • Loading branch information
eric-haibin-lin authored Aug 21, 2017
1 parent b208b01 commit 7241bee
Show file tree
Hide file tree
Showing 12 changed files with 133 additions and 73 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
22 changes: 2 additions & 20 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,6 @@ def pull(self, key, out=None, priority=0):
[ 2. 2. 2.]]
"""
assert(out is not None)
if not isinstance(out, (list, tuple)):
out = [out]
for val in out:
if not isinstance(val, (list, tuple)):
assert(val.stype == 'default')
else:
for v in val:
assert(v.stype == 'default')
ckeys, cvals = _ctype_key_value(key, out)
check_call(_LIB.MXKVStorePullEx(
self.handle, mx_uint(len(ckeys)), ckeys, cvals,
Expand Down Expand Up @@ -270,8 +262,8 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
other pull actions.
row_ids : NDArray or list of NDArray
The row_ids for which to pull for each value. The row_ids doesn't have to be unique
or sorted.
The row_ids for which to pull for each value. Each row_id is an 1D-NDArray \
whose values don't have to be unique nor sorted.
Examples
--------
Expand Down Expand Up @@ -299,16 +291,6 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
"""
assert(out is not None)
assert(row_ids is not None)
if isinstance(row_ids, NDArray):
row_ids = [row_ids]
if not isinstance(out, (list, tuple)):
out = [out]
for val in out:
if not isinstance(val, (list, tuple)):
assert(val.stype == 'row_sparse')
else:
for v in val:
assert(v.stype == 'row_sparse')
ckeys, cvals = _ctype_key_value(key, out)
_, crow_ids = _ctype_key_value(key, row_ids)
assert(len(crow_ids) == len(cvals)), "number of row_ids doesn't match number of values"
Expand Down
31 changes: 3 additions & 28 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,29 +93,14 @@ def _create_kvstore(kvstore, num_device, arg_params):

return (kv, update_on_kvstore)

def _contains_non_default_storage(params):
if isinstance(params, (list, tuple)):
for param in params:
if param.stype != 'default':
return True
elif isinstance(params, NDArray):
return param.stype != 'default'
else:
return False

def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, update_on_kvstore):
"""Initialize kvstore"""
for idx, param_on_devs in enumerate(param_arrays):
name = param_names[idx]
kvstore.init(name, arg_params[name])

if update_on_kvstore:
if _contains_non_default_storage(param_on_devs):
# skip pulling row_sparse weights
warnings.warn('Detected non-default weight in kvstore to pull. Please make ' \
'sure to pull it with row_ids explicitly', RuntimeWarning)
else:
kvstore.pull(name, param_on_devs, priority=-idx)
kvstore.pull(name, param_on_devs, priority=-idx)

def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names):
"""Perform update of param_arrays from grad_arrays on kvstore."""
Expand All @@ -127,12 +112,7 @@ def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names):
# push gradient, priority is negative index
kvstore.push(name, grad_list, priority=-index)
# pull back the weights
if _contains_non_default_storage(arg_list):
# skip pulling row_sparse weights
warnings.warn('Detected non-default weight in kvstore to pull. Please make ' \
'sure to pull it with row_ids', RuntimeWarning)
else:
kvstore.pull(name, arg_list, priority=-index)
kvstore.pull(name, arg_list, priority=-index)

def _update_params(param_arrays, grad_arrays, updater, num_device,
kvstore=None, param_names=None):
Expand All @@ -147,12 +127,7 @@ def _update_params(param_arrays, grad_arrays, updater, num_device,
# push gradient, priority is negative index
kvstore.push(name, grad_list, priority=-index)
# pull back the sum gradients, to the same locations.
if _contains_non_default_storage(grad_list):
# skip pulling row_sparse weights
warnings.warn('Detected non-default weight in kvstore to pull. Please make ' \
'sure to pull it with row_ids', RuntimeWarning)
else:
kvstore.pull(name, grad_list, priority=-index)
kvstore.pull(name, grad_list, priority=-index)
for k, p in enumerate(zip(arg_list, grad_list)):
# faked an index here, to make optimizer create diff
# state for the same index but on diff devs, TODO(mli)
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import errno
import logging
from contextlib import contextmanager
import scipy.sparse as sp
import numpy as np
import numpy.testing as npt
import numpy.random as rnd
Expand Down Expand Up @@ -125,6 +124,7 @@ def _get_uniform_dataset_csr(num_rows, num_cols, density=0.1, dtype=None):
"""
_validate_csr_generation_inputs(num_rows, num_cols, density,
distribution="uniform")
from scipy import sparse as sp
csr = sp.rand(num_rows, num_cols, density, dtype=dtype, format="csr")
result = mx.nd.sparse.csr_matrix(csr.data, csr.indptr, csr.indices,
(num_rows, num_cols), dtype=dtype)
Expand Down
6 changes: 3 additions & 3 deletions src/kvstore/kvstore_dist.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class KVStoreDist : public KVStoreLocal {
int priority) override {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray*> > grouped_vals;
GroupKVPairs(keys, values, &uniq_keys, &grouped_vals);
GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals);

for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
Expand Down Expand Up @@ -160,7 +160,7 @@ class KVStoreDist : public KVStoreLocal {
const int priority = 0) {
std::vector<int> uniq_keys;
std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids;
GroupKVPairs(keys, val_rowids, &uniq_keys, &grouped_val_rowids);
GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids);

for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
Expand Down Expand Up @@ -261,7 +261,7 @@ class KVStoreDist : public KVStoreLocal {
// first aggregate the values over keys
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray> > grouped_vals;
GroupKVPairs(keys, values, &uniq_keys, &grouped_vals);
GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals);

for (size_t i = 0; i < uniq_keys.size(); ++i) {
// merge over devcies
Expand Down
91 changes: 79 additions & 12 deletions src/kvstore/kvstore_local.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <vector>
#include <string>
#include <utility>
#include <functional>
#include <algorithm>
#include "./comm.h"

Expand Down Expand Up @@ -85,7 +86,7 @@ class KVStoreLocal : public KVStore {
int priority) override {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray> > grouped_vals;
GroupKVPairs(keys, values, &uniq_keys, &grouped_vals);
GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals);

for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
Expand Down Expand Up @@ -114,7 +115,7 @@ class KVStoreLocal : public KVStore {
int priority) override {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray*> > grouped_vals;
GroupKVPairs(keys, values, &uniq_keys, &grouped_vals);
GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals);

for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
Expand All @@ -129,7 +130,7 @@ class KVStoreLocal : public KVStore {
int priority = 0) override {
std::vector<int> uniq_keys;
std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids;
GroupKVPairs(keys, val_rowids, &uniq_keys, &grouped_val_rowids);
GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids);
for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
const NDArray& local = local_[key];
Expand Down Expand Up @@ -174,13 +175,75 @@ class KVStoreLocal : public KVStore {

protected:
/**
* \brief group values on keys
* \brief group values on keys for push
*/
template <typename V>
void GroupKVPairsPush(const std::vector<int>& keys,
const std::vector<NDArray>& values,
std::vector<int> *uniq_keys,
std::vector<std::vector<NDArray>> *grouped_vals) {
// check if the storage type of a value is valid
auto validator = [this](const int key, const NDArray& nd) -> bool {
auto stype = nd.storage_type();
// valid NDArray
if (stype == kDefaultStorage || stype == kRowSparseStorage) return true;
// invalid NDArray, abort
LOG(FATAL) << "Unexpected storage type detected during kvstore push: " << stype;
return false;
};
GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator);
}
/**
* \brief group values on keys for pull
*/
void GroupKVPairsPull(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
std::vector<int> *uniq_keys,
std::vector<std::vector<NDArray*>> *grouped_vals) {
// check if the storage type of a value is valid
auto validator = [this](const int key, const NDArray* nd) -> bool {
// valid
if (nd->storage_type() == kDefaultStorage) return true;
// invalid, print warning messages once
if (this->warnings_printed_.find(key) == this->warnings_printed_.end()) {
LOG(INFO) << "Warning: non-default weights detected during kvstore pull. "
<< "Please make sure to use row_sparse_pull with row_ids instead.";
this->warnings_printed_.insert(key);
}
return false;
};
GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator);
}
/**
* \brief group values on keys for row_sparse_pull
*/
void GroupKVPairsPullRsp(const std::vector<int>& keys,
const std::vector<std::pair<NDArray*, NDArray>>& values,
std::vector<int> *uniq_keys,
std::vector<std::vector<std::pair<NDArray*, NDArray>>> *grouped_vals) {
// check if the storage type of a value is valid
auto validator = [this](const int key, const std::pair<NDArray*, NDArray>& val_rowid) -> bool {
auto val_stype = val_rowid.first->storage_type();
auto rowid_stype = val_rowid.second.storage_type();
// check storage types
CHECK_EQ(val_stype, kRowSparseStorage) << "Expected row_sparse storage type for "
<< "row_sparse_pull values, but detected storage type " << val_stype;
CHECK_EQ(rowid_stype, kDefaultStorage) << "Expected default storage type for "
<< "row_sparse_pull rowids, but detected storage type " << rowid_stype;
return true;
};
GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator);
}

/**
* \brief group values on keys with validation.
* A value `v` is not included in the result if is_valid(v) returns false.
*/
template <typename V, typename FValidate>
void GroupKVPairs(const std::vector<int>& keys,
const std::vector<V>& values,
std::vector<int>* uniq_keys,
std::vector<std::vector<V> >* grouped_vals) {
std::vector<std::vector<V> >* grouped_vals,
const FValidate& is_valid) {
CHECK_EQ(keys.size(), values.size());
// TODO(mli) check if already sorted as an optimization
using Idx = std::pair<int, int>;
Expand All @@ -194,12 +257,14 @@ class KVStoreLocal : public KVStore {

int pre_key = idx[0].first - 1;
for (auto i : idx) {
if (i.first != pre_key) {
uniq_keys->push_back(i.first);
grouped_vals->push_back({values[i.second]});
pre_key = i.first;;
} else {
grouped_vals->back().push_back(values[i.second]);
if (is_valid(i.first, values[i.second])) {
if (i.first != pre_key) {
uniq_keys->push_back(i.first);
grouped_vals->push_back({values[i.second]});
pre_key = i.first;
} else {
grouped_vals->back().push_back(values[i.second]);
}
}
}
}
Expand Down Expand Up @@ -246,6 +311,8 @@ class KVStoreLocal : public KVStore {
std::unordered_map<std::string, int> str_key_dict_;
/// the next available integer for string->int key mapping
int next_str_key_ = 0;
/// whether printed warning due to mismatch stype in each key
std::unordered_set<int> warnings_printed_;
};
} // namespace kvstore
} // namespace mxnet
Expand Down
49 changes: 43 additions & 6 deletions tests/python/unittest/test_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_single_kv_pair():
def check_single_kv_pair(kv, key):
kv.push(key, mx.nd.ones(shape))
val = mx.nd.empty(shape)
kv.pull(key, out = val)
kv.pull(key, out=val)
check_diff_to_scalar(val, 1)

check_single_kv_pair(init_kv(), 3)
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_list_kv_pair():
def check_list_kv_pair(kv, key):
kv.push(key, [mx.nd.ones(shape)*4] * len(key))
val = [mx.nd.empty(shape)] * len(key)
kv.pull(key, out = val)
kv.pull(key, out=val)
for v in val:
check_diff_to_scalar(v, 4)

Expand All @@ -122,15 +122,15 @@ def check_aggregator(kv, key, key_list):
vals = [mx.nd.ones(shape, d) for d in devs]

kv.push(key, vals)
kv.pull(key, out = vals)
kv.pull(key, out=vals)

for v in vals:
check_diff_to_scalar(v, num_devs)

# list
vals = [[mx.nd.ones(shape, d)*2.0 for d in devs]] * len(key_list)
kv.push(key_list, vals)
kv.pull(key_list, out = vals)
kv.pull(key_list, out=vals)

for vv in vals:
for v in vv:
Expand Down Expand Up @@ -196,7 +196,7 @@ def check_updater(kv, key, key_list):
vals = [mx.nd.ones(shape, d) for d in devs]

kv.push(key, vals)
kv.pull(key, out = vals)
kv.pull(key, out=vals)

for v in vals:
check_diff_to_scalar(v, num_devs)
Expand All @@ -208,7 +208,7 @@ def check_updater(kv, key, key_list):
for i in range(num_push):
kv.push(key_list, vals)

kv.pull(key_list, out = vals)
kv.pull(key_list, out=vals)

for vv in vals:
for v in vv:
Expand All @@ -227,6 +227,43 @@ def test_get_type():
kv = mx.kv.create(kvtype)
assert kv.type == kvtype

def test_invalid_pull():
def check_invalid_single_kv_pair(kv, key):
dns_val = mx.nd.ones(shape) * 2
rsp_val = dns_val.tostype('row_sparse')
kv.pull(key, out=rsp_val)
# pull should be ignored with no values updated
check_diff_to_scalar(rsp_val, 2)
try:
# row_sparse_pull should be aborted when vals.stype != row_sparse
kv.row_sparse_pull(key, out=dns_val, rowids=mx.nd.array([1]))
assert(False)
except:
pass

def check_invalid_list_kv_pair(kv, key):
dns_val = [mx.nd.ones(shape) * 2] * len(key)
rsp_val = [val.tostype('row_sparse') for val in dns_val]
kv.pull(key, out=rsp_val)
for v in rsp_val:
# pull should be ignored with no values updated
check_diff_to_scalar(v, 2)
try:
# row_sparse_pull should be aborted when vals.stype != row_sparse
kv.row_sparse_pull(key, out=dns_val, rowids=[mx.nd.array([1])] * len(key))
assert(False)
except:
pass

int_kv = init_kv()
str_kv = init_kv_with_str()

check_invalid_single_kv_pair(int_kv, 3)
check_invalid_single_kv_pair(str_kv, 'a')

check_invalid_list_kv_pair(int_kv, keys)
check_invalid_list_kv_pair(str_kv, str_keys)

if __name__ == '__main__':
test_init()
test_get_type()
Expand Down
5 changes: 2 additions & 3 deletions tests/python/unittest/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,9 +519,8 @@ def fm(factor_size, feature_dim, init):
# initialize parameters by uniform random numbers
mod.init_params(initializer=init)
# use Sparse SGD with learning rate 0.1 to train
sgd = mx.optimizer.SGD(momentum=0.1, clip_gradient=5.0, learning_rate=0.01,
rescale_grad=1.0/batch_size)
mod.init_optimizer(optimizer=sgd)
adam = mx.optimizer.Adam(clip_gradient=5.0, learning_rate=0.001, rescale_grad=1.0/batch_size)
mod.init_optimizer(optimizer=adam)
# use accuracy as the metric
metric = mx.metric.create('MSE')
# train 10 epoch
Expand Down

0 comments on commit 7241bee

Please sign in to comment.