Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Improve sparse pull performance for gluon trainer (#11429)
Browse files Browse the repository at this point in the history
* clip sparse grad. fix _reduce for rowsparse param

* fix kvstore init for local kv

* trigger

* pull with ignore sparse

* rsp pull with priority

* add doc

* fix bug in sparse kvstore

* +kvstore test

* add dist kvstore test

* enhance dist kv test

* fix lint

* fix lint

* CR comments
Loading branch information
eric-haibin-lin authored Jul 9, 2018
1 parent 459a891 commit 266de6b
Show file tree
Hide file tree
Showing 20 changed files with 350 additions and 143 deletions.
1 change: 1 addition & 0 deletions ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,7 @@ integrationtest_ubuntu_gpu_dist_kvstore() {
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --no-multiprecision
../../tools/launch.py -n 7 --launcher local python dist_device_sync_kvstore.py
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=invalid
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=gluon
}

Expand Down
32 changes: 32 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1915,6 +1915,38 @@ MXNET_DLL int MXKVStorePushEx(KVStoreHandle handle,
const char** keys,
NDArrayHandle* vals,
int priority);
/*!
* \brief pull a list of (key, value) pairs from the kvstore
* \param handle handle to the kvstore
* \param num the number of key-value pairs
* \param keys the list of keys
* \param vals the list of values
* \param priority the priority of the action
* \param ignore_sparse whether to ignore sparse arrays in the request
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStorePullWithSparse(KVStoreHandle handle,
mx_uint num,
const int* keys,
NDArrayHandle* vals,
int priority,
bool ignore_sparse);
/*!
* \brief pull a list of (key, value) pairs from the kvstore, where each key is a string
* \param handle handle to the kvstore
* \param num the number of key-value pairs
* \param keys the list of keys
* \param vals the list of values
* \param priority the priority of the action
* \param ignore_sparse whether to ignore sparse arrays in the request
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStorePullWithSparseEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals,
int priority,
bool ignore_sparse);
/*!
* \brief pull a list of (key, value) pairs from the kvstore
* \param handle handle to the kvstore
Expand Down
4 changes: 3 additions & 1 deletion include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ enum class FnProperty {
/*! \brief Asynchronous function call */
kAsync,
/*! \brief Delete variable call */
kDeleteVar
kDeleteVar,
/*! \brief Prioritized sync operation on GPU */
kGPUPrioritized
}; // enum class FnProperty

/*!
Expand Down
6 changes: 4 additions & 2 deletions include/mxnet/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,19 +170,21 @@ class KVStore {
* \param keys the list of keys
* \param values the list of buffers for the pulled data, they should be preallocated
* \param priority Priority of the action.
* \param ignore_sparse whether to ignore sparse arrays in the request
*/
virtual void Pull(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
int priority = 0) = 0;
int priority = 0, bool ignore_sparse = true) = 0;
/*!
* \brief pull a list of key-value pairs from the store
* \param keys the list of keys in string format
* \param values the list of buffers for the pulled data, they should be preallocated
* \param priority Priority of the action.
* \param ignore_sparse whether to ignore sparse arrays in the request
*/
virtual void Pull(const std::vector<std::string>& str_keys,
const std::vector<NDArray*>& values,
int priority = 0) = 0;
int priority = 0, bool ignore_sparse = true) = 0;

/*!
* \brief pull a list of key-value pairs from the store.
Expand Down
4 changes: 3 additions & 1 deletion include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -1026,10 +1026,12 @@ void CopyFromTo(const NDArray &from, const NDArray *to, int priority = 0);
* \param from the ndarray we want to copy data from
* \param to the target ndarray
* \param priority Priority of the action.
* \param is_opr whether it is invoked by an operator. For example, false if invoked from
KVStore, true if invoked from `_copyto` operator.
* \note The function name explicitly marks the order of from and to
* due to different possible convention carried by copy function.
*/
void CopyFromTo(const NDArray &from, const NDArray& to, int priority = 0);
void CopyFromTo(const NDArray &from, const NDArray& to, int priority = 0, bool is_opr = false);

/*!
* \brief Perform elementwise sum over each data from source, store result into out.
Expand Down
49 changes: 35 additions & 14 deletions python/mxnet/gluon/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
"got %s."%(type(params)))
self._params = []
# parameters to initialize on the kvstore
self._contains_sparse = False
self._contains_sparse_weight = False
self._contains_sparse_grad = False
self._param2idx = {}
for i, param in enumerate(params):
if not isinstance(param, Parameter):
Expand All @@ -80,7 +81,9 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
self._params.append(param)
param._set_trainer(self)
if param._stype != 'default':
self._contains_sparse = True
self._contains_sparse_weight = True
if param._grad_stype != 'default':
self._contains_sparse_grad = True
self._compression_params = compression_params
optimizer_params = optimizer_params if optimizer_params else {}
self._scale = float(optimizer_params.get('rescale_grad', 1.0))
Expand Down Expand Up @@ -153,13 +156,31 @@ def _reset_kvstore(self):
def _init_kvstore(self):
"""Create kvstore."""
config = self._kvstore_params
if self._contains_sparse:
# if weight is sparse, the weight must be updated on KVStore.
# training loop contains:
# - row_sparse_pull(sparse_weight)
# - forward()
# - backward()
# - push(sparse_grad), push(dense_grad)
# - pull(dense_weight)
if self._contains_sparse_weight:
kvstore, update_on_kvstore = _create_sparse_kvstore(config['kvstore'])
# update_on_kvstore is set to False by the user
# raise Error if update_on_kvstore is set to False by the user
if config['update_on_kvstore'] is False:
raise RuntimeError("Cannot set update_on_kvstore to False when sparse "
"gradients and/or sparse weights are present for "
"Parameter '%s'."%param.name)
raise RuntimeError("Cannot set update_on_kvstore to False when sparse weights "
"are present.")
# if weight is dense and grad is sparse, the weight better not be updated on KVStore.
# training loop contains:
# - forward()
# - backward()
# - push(grad)
# - pull(grad)
# - update(grad, weight)
elif self._contains_sparse_grad:
arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
kvstore, _ = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays)
update_on_kvstore = False
# normal case
else:
arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts),
Expand All @@ -169,9 +190,9 @@ def _init_kvstore(self):
if kvstore:
if self._compression_params:
kvstore.set_gradient_compression(self._compression_params)
# kv.pull(row_sparse_grad) is not supported
if 'dist' in kvstore.type and not self._contains_sparse:
update_on_kvstore = False
if 'dist' in kvstore.type:
# kv.pull(row_sparse_grad) is not supported for dist kvstore
update_on_kvstore = self._contains_sparse_weight or self._contains_sparse_grad
if update_on_kvstore:
# optimizer preferably needs to be set before init for multiprecision
kvstore.set_optimizer(self._optimizer)
Expand Down Expand Up @@ -211,8 +232,8 @@ def _row_sparse_pull(self, parameter, out, row_id):
self._init_kvstore()
if self._params_to_init:
self._init_params()
self._kvstore.row_sparse_pull(self._param2idx[parameter.name], \
out=out, row_ids=row_id)
idx = self._param2idx[parameter.name]
self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx)

def step(self, batch_size, ignore_stale_grad=False):
"""Makes one step of parameter update. Should be called after
Expand Down Expand Up @@ -272,7 +293,7 @@ def _allreduce_grads(self):
self._kvstore.push(i, param.list_grad(), priority=-i)

if not self._update_on_kvstore:
self._kvstore.pull(i, param.list_grad(), priority=-i)
self._kvstore.pull(i, param.list_grad(), priority=-i, ignore_sparse=False)

def update(self, batch_size, ignore_stale_grad=False):
"""Makes one step of parameter update.
Expand Down Expand Up @@ -327,7 +348,7 @@ def _update(self, ignore_stale_grad=False):
if self._kvstore and self._update_on_kvstore:
if param._stype == 'default':
# 'row_sparse' parameters are not pulled immediately - they're pulled
# in `SparseBlock.sparse_forward`
# in `Block.forward`
self._kvstore.pull(i, param.list_data(), priority=-i)
continue

Expand Down
19 changes: 12 additions & 7 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def push(self, key, value, priority=0):
self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))


def pull(self, key, out=None, priority=0):
def pull(self, key, out=None, priority=0, ignore_sparse=True):
""" Pulls a single value or a sequence of values from the store.
This function returns immediately after adding an operator to the engine.
Expand All @@ -247,8 +247,8 @@ def pull(self, key, out=None, priority=0):
The returned values are guaranteed to be the latest values in the store.
For `RowSparseNDArray` values, this call is ignored,
please use ``row_sparse_pull`` instead.
pull with `RowSparseNDArray` is not supported for dist kvstore.
Please use ``row_sparse_pull`` instead.
Parameters
----------
Expand All @@ -263,6 +263,9 @@ def pull(self, key, out=None, priority=0):
Higher priority pull operations are likely to be executed before
other pull actions.
ignore_sparse: bool, optional, default True
Whether to ignore sparse arrays in the request.
Examples
--------
>>> # pull a single key-value pair
Expand Down Expand Up @@ -298,11 +301,13 @@ def pull(self, key, out=None, priority=0):
assert(out is not None)
ckeys, cvals, use_str_keys = _ctype_key_value(key, out)
if use_str_keys:
check_call(_LIB.MXKVStorePullEx(
self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
check_call(_LIB.MXKVStorePullWithSparseEx(self.handle, mx_uint(len(ckeys)), ckeys,
cvals, ctypes.c_int(priority),
ctypes.c_bool(ignore_sparse)))
else:
check_call(_LIB.MXKVStorePull(
self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
check_call(_LIB.MXKVStorePullWithSparse(self.handle, mx_uint(len(ckeys)), ckeys,
cvals, ctypes.c_int(priority),
ctypes.c_bool(ignore_sparse)))

def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
""" Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \
Expand Down
46 changes: 40 additions & 6 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -869,23 +869,57 @@ int MXKVStorePull(KVStoreHandle handle,
v_keys[i] = keys[i];
v_vals[i] = static_cast<NDArray*>(vals[i]);
}
static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority);
static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, true);
API_END();
}

/*!
 * \brief Pull values for a list of string keys from the kvstore.
 *        Sparse (row_sparse) arrays in the request are ignored; callers that
 *        need them pulled must use MXKVStorePullWithSparseEx instead.
 * \param handle   handle to the kvstore
 * \param num      number of key-value pairs
 * \param keys     list of string keys
 * \param vals     list of preallocated destination NDArray handles
 * \param priority priority of the pull action
 * \return 0 on success, -1 on failure
 */
int MXKVStorePullEx(KVStoreHandle handle,
                    mx_uint num,
                    const char** keys,
                    NDArrayHandle* vals,
                    int priority) {
  API_BEGIN();
  std::vector<std::string> v_keys(num);
  std::vector<NDArray*> v_vals(num);
  for (mx_uint i = 0; i < num; ++i) {
    v_keys[i] = keys[i];
    v_vals[i] = static_cast<NDArray*>(vals[i]);
  }
  // ignore_sparse = true preserves the legacy behavior of this entry point:
  // row_sparse arrays are skipped rather than pulled.
  static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, true);
  API_END();
}

/*!
 * \brief Pull values for a list of integer keys from the kvstore, with an
 *        explicit choice of whether sparse arrays in the request are skipped.
 * \param handle        handle to the kvstore
 * \param num           number of key-value pairs
 * \param keys          list of integer keys
 * \param vals          list of preallocated destination NDArray handles
 * \param priority      priority of the pull action
 * \param ignore_sparse whether to ignore sparse arrays in the request
 * \return 0 on success, -1 on failure
 */
int MXKVStorePullWithSparse(KVStoreHandle handle,
                            mx_uint num,
                            const int* keys,
                            NDArrayHandle* vals,
                            int priority,
                            bool ignore_sparse) {
  API_BEGIN();
  // Build the key vector directly from the caller's array; values need a
  // per-element cast from the opaque handle type.
  std::vector<int> key_vec(keys, keys + num);
  std::vector<NDArray*> val_vec;
  val_vec.reserve(num);
  for (mx_uint idx = 0; idx < num; ++idx) {
    val_vec.push_back(static_cast<NDArray*>(vals[idx]));
  }
  static_cast<KVStore*>(handle)->Pull(key_vec, val_vec, priority, ignore_sparse);
  API_END();
}

/*!
 * \brief Pull values for a list of string keys from the kvstore, with an
 *        explicit choice of whether sparse arrays in the request are skipped.
 * \param handle        handle to the kvstore
 * \param num           number of key-value pairs
 * \param keys          list of string keys
 * \param vals          list of preallocated destination NDArray handles
 * \param priority      priority of the pull action
 * \param ignore_sparse whether to ignore sparse arrays in the request
 * \return 0 on success, -1 on failure
 */
int MXKVStorePullWithSparseEx(KVStoreHandle handle,
                              mx_uint num,
                              const char** keys,
                              NDArrayHandle* vals,
                              int priority,
                              bool ignore_sparse) {
  API_BEGIN();
  std::vector<std::string> v_keys(num);
  std::vector<NDArray*> v_vals(num);
  for (mx_uint i = 0; i < num; ++i) {
    v_keys[i] = keys[i];
    v_vals[i] = static_cast<NDArray*>(vals[i]);
  }
  // Single pull honoring ignore_sparse; the previous diff artifact issued a
  // second Pull without the flag, which would pull every key twice.
  static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, ignore_sparse);
  API_END();
}

Expand Down
Loading

0 comments on commit 266de6b

Please sign in to comment.