Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-1294] Add KVSTORE PushPull API (#15559)
Browse files Browse the repository at this point in the history
* Switch to latest ps-lite

* Add PushPull API

* Add PushPull test cases
  • Loading branch information
anandj91 authored and eric-haibin-lin committed Sep 8, 2019
1 parent 5806200 commit 1c67928
Show file tree
Hide file tree
Showing 9 changed files with 345 additions and 17 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/ps-lite
42 changes: 42 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2649,6 +2649,48 @@ MXNET_DLL int MXKVStorePullRowSparseEx(KVStoreHandle handle,
const NDArrayHandle* row_ids,
int priority);

/*!
* \brief push and pull a list of (key, value) pairs from the kvstore
* \param handle handle to the kvstore
* \param vnum the number of key-value pairs corresponding to vkeys
* \param vkeys the list of keys for the values to be pushed
* \param onum the number of key-value pairs corresponding to okeys
* \param okeys the list of keys for the values to be pulled
* \param vals the list of values
* \param outs the list of outputs
* \param priority the priority of the action
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStorePushPull(KVStoreHandle handle,
mx_uint vnum,
const int* vkeys,
mx_uint onum,
const int* okeys,
NDArrayHandle* vals,
NDArrayHandle* outs,
int priority);
/*!
* \brief push and pull a list of (key, value) pairs from the kvstore,
* where each key is a string
* \param handle handle to the kvstore
* \param vnum the number of key-value pairs corresponding to vkeys
* \param vkeys the list of keys for the values to be pushed
* \param onum the number of key-value pairs corresponding to okeys
* \param okeys the list of keys for the values to be pulled
* \param vals the list of values
* \param outs the list of outputs
* \param priority the priority of the action
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStorePushPullEx(KVStoreHandle handle,
mx_uint vnum,
const char** vkeys,
mx_uint onum,
const char** okeys,
NDArrayHandle* vals,
NDArrayHandle* outs,
int priority);

/*!
* \brief user-defined updater for the kvstore
* It's this updater's responsibility to delete \a recv and \a local
Expand Down
27 changes: 27 additions & 0 deletions include/mxnet/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,33 @@ class KVStore {
const std::vector<NDArray*>& values,
int priority = 0, bool ignore_sparse = true) = 0;

/*!
* \brief push and pull a list of key-value pairs from the store
* \param vkeys the list of keys to be pushed
* \param okeys the list of keys to be pulled. Should be the same set of keys in vkeys.
* \param values the list of values to be pushed
* \param outs the list of buffers for the pulled data, they should be preallocated
* \param priority Priority of the action.
*/
virtual void PushPull(const std::vector<int>& vkeys,
const std::vector<int>& okeys,
const std::vector<NDArray>& values,
const std::vector<NDArray*>& outs,
int priority = 0) = 0;

/*!
* \brief push and pull a list of key-value pairs from the store
* \param vkeys the list of keys to be pushed in string format
* \param okeys the list of keys to be pulled in string format. Should be the same set of keys in vkeys.
* \param values the list of values to be pushed
* \param outs the list of buffers for the pulled data, they should be preallocated
* \param priority Priority of the action.
*/
virtual void PushPull(const std::vector<std::string>& str_vkeys,
const std::vector<std::string>& str_okeys,
const std::vector<NDArray>& values,
const std::vector<NDArray*>& outs,
int priority = 0) = 0;
/*!
* \brief pull a list of key-value pairs from the store.
* The NDArray pulled back will be in row_sparse storage with only the
Expand Down
81 changes: 81 additions & 0 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,87 @@ def pull(self, key, out=None, priority=0, ignore_sparse=True):
cvals, ctypes.c_int(priority),
ctypes.c_bool(ignore_sparse)))

def pushpull(self, key, value, out=None, priority=0):
""" Performs push and pull a single value or a sequence of values from the store.
This function is coalesced form of push and pull operations. This function returns
immediately after adding an operator to the engine. Subsequent attempts to read
from the `out` variable will be blocked until the pull operation completes.
`value` is pushed to the kvstore server for the specified keys and the updated
values are pulled from the server to `out`. If `out` is not specified the pulled
values are written to `value`. The returned values are guaranteed to be the latest
values in the store.
pushpull with `RowSparseNDArray` is not supported for dist kvstore.
Parameters
----------
key : str, int, or sequence of str or int
Keys.
value : NDArray, RowSparseNDArray, list of NDArray or RowSparseNDArray,
or list of list of NDArray or RowSparseNDArray
Values corresponding to the keys.
out: NDArray or list of NDArray or list of list of NDArray
Values corresponding to the keys.
priority : int, optional
The priority of the pull operation.
Higher priority pull operations are likely to be executed before
other pull actions.
Examples
--------
>>> # push a single key-value pair
>>> kv.pushpull('3', mx.nd.ones(shape)*8, out=a)
>>> print a.asnumpy()
[[ 8. 8. 8.]
[ 8. 8. 8.]]
>>> # aggregate the value and the push
>>> gpus = [mx.gpu(i) for i in range(4)]
>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
>>> kv.pushpull('3', b, out=a)
>>> print a.asnumpy()
[[ 4. 4. 4.]
[ 4. 4. 4.]]
>>> # push a list of keys.
>>> # single device
>>> keys = ['4', '5', '6']
>>> b = [mx.nd.zeros(shape)]*len(keys)
>>> kv.push(keys, [mx.nd.ones(shape)]*len(keys), out=b)
>>> print b[1].asnumpy()
[[ 1. 1. 1.]
[ 1. 1. 1.]]
>>> # multiple devices:
>>> keys = ['7', '8', '9']
>>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys)
>>> kv.pushpull(keys, b)
>>> print b[1][1].asnumpy()
[[ 4. 4. 4.]
[ 4. 4. 4.]]
"""

cvkeys, cvals, use_str_keys = _ctype_key_value(key, value)
if out is not None:
cokeys, couts, _ = _ctype_key_value(key, out)
else:
cokeys = cvkeys
couts = cvals

if use_str_keys:
check_call(_LIB.MXKVStorePushPullEx(
self.handle, mx_uint(len(cvkeys)), cvkeys, mx_uint(len(cokeys)), cokeys,
cvals, couts, ctypes.c_int(priority)))
else:
check_call(_LIB.MXKVStorePushPull(
self.handle, mx_uint(len(cvkeys)), cvkeys, mx_uint(len(cokeys)), cokeys,
cvals, couts, ctypes.c_int(priority)))

def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
""" Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \
from the store with specified row_ids. When there is only one row_id, KVStoreRowSparsePull \
Expand Down
52 changes: 52 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,58 @@ int MXKVStorePullEx(KVStoreHandle handle,
API_END();
}

int MXKVStorePushPull(KVStoreHandle handle,
mx_uint vnum,
const int* vkeys,
mx_uint onum,
const int* okeys,
NDArrayHandle* vals,
NDArrayHandle* outs,
int priority) {
API_BEGIN();
std::vector<int> v_vkeys(vnum);
std::vector<int> v_okeys(onum);
std::vector<NDArray> v_vals(vnum);
std::vector<NDArray*> v_outs(onum);
for (mx_uint i = 0; i < vnum; ++i) {
v_vkeys[i] = vkeys[i];
v_vals[i] = *static_cast<NDArray*>(vals[i]);
}
for (mx_uint i = 0; i < onum; ++i) {
v_okeys[i] = okeys[i];
v_outs[i] = static_cast<NDArray*>(outs[i]);
}
static_cast<KVStore*>(handle)->PushPull(v_vkeys, v_okeys, v_vals, v_outs,
priority);
API_END();
}

int MXKVStorePushPullEx(KVStoreHandle handle,
mx_uint vnum,
const char** vkeys,
mx_uint onum,
const char** okeys,
NDArrayHandle* vals,
NDArrayHandle* outs,
int priority) {
API_BEGIN();
std::vector<std::string> v_vkeys(vnum);
std::vector<std::string> v_okeys(onum);
std::vector<NDArray> v_vals(vnum);
std::vector<NDArray*> v_outs(onum);
for (mx_uint i = 0; i < vnum; ++i) {
v_vkeys[i] = vkeys[i];
v_vals[i] = *static_cast<NDArray*>(vals[i]);
}
for (mx_uint i = 0; i < onum; ++i) {
v_okeys[i] = okeys[i];
v_outs[i] = static_cast<NDArray*>(outs[i]);
}
static_cast<KVStore*>(handle)->PushPull(v_vkeys, v_okeys, v_vals, v_outs,
priority);
API_END();
}

int MXKVStorePullWithSparse(KVStoreHandle handle,
uint32_t num,
const int* keys,
Expand Down
75 changes: 75 additions & 0 deletions src/kvstore/kvstore_dist.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,81 @@ class KVStoreDist : public KVStoreLocal {
}
}

void PushPullImpl(const std::vector<int>& vkeys,
const std::vector<int>& okeys,
const std::vector<NDArray>& values,
const std::vector<NDArray*>& outputs,
int priority) override {
std::vector<int> uniq_vkeys;
std::vector<int> uniq_okeys;
std::vector<std::vector<NDArray>> grouped_vals;
std::vector<std::vector<NDArray*>> grouped_outs;

GroupKVPairsPush(vkeys, values, &uniq_vkeys, &grouped_vals, false);
GroupKVPairsPull(okeys, outputs, &uniq_okeys, &grouped_outs, true);
CHECK_EQ(uniq_vkeys.size(), uniq_okeys.size())
<< "List of push and pull keys are different";

for (size_t i = 0; i < uniq_vkeys.size(); ++i) {
CHECK_EQ(uniq_vkeys[i], uniq_okeys[i])
<< "Mismatch in push and pull key";
int key = uniq_vkeys[i];
const auto& vals = grouped_vals[i];
const auto& outs = grouped_outs[i];

NDArray merged = comm_->Reduce(key, vals, priority);

const auto push_stype = merged.storage_type();
const auto pull_stype = outs[0]->storage_type();
CHECK_EQ(push_stype, kDefaultStorage)
<< "Expected push_stype of value to be kDefaultStorage";
CHECK_EQ(pull_stype, kDefaultStorage)
<< "Expected pull_stype of value to be kDefaultStorage";

const int push_dtype = merged.dtype();
const int pull_dtype = outs[0]->dtype();
CHECK_EQ(push_dtype, pull_dtype) << "Output buffer dtype is different";

auto &comm_buf = comm_buf_[key];
if (merged.ctx().dev_mask() == cpu::kDevMask) {
comm_buf = merged; // avoid memory copy
} else {
if (comm_buf.is_none()) {
comm_buf = NDArray(outs[0]->shape(), pinned_ctx_, true, pull_dtype);
}
CopyFromTo(merged, &comm_buf);
}

CHECK(gradient_compression_->get_type() == CompressionType::kNone)
<< "Compression not supported with PushPull";
auto pushpull = [this, key, comm_buf](
RunContext rctx, Engine::CallbackOnComplete cb) {
size_t size = comm_buf.shape().Size();
const int dtype = comm_buf.dtype();
const int num_bytes = mshadow::mshadow_sizeof(dtype);
const int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);

PSKV& pskv = EncodeDefaultKey(key, size, num_bytes);
char* data = static_cast<char*>(comm_buf.data().dptr_);
auto vals = new ps::SArray<char>(data, size * num_bytes, false);

CHECK_NOTNULL(ps_worker_)->ZPushPull(
pskv.keys, *vals, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
};

CHECK_NOTNULL(Engine::Get())->PushAsync(
pushpull,
pinned_ctx_,
{},
{comm_buf.var()},
FnProperty::kNormal,
priority,
"KVStoreDistDefaultStoragePushPull");

comm_->Broadcast(key, comm_buf, outs, priority);
}
}

void PushImpl(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority) override {
Expand Down
22 changes: 16 additions & 6 deletions src/kvstore/kvstore_dist_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,8 @@ class KVStoreDistServer {
}

inline void ApplyUpdates(const DataHandleType type, const int key,
UpdateBuf *update_buf, ps::KVServer<char>* server) {
const ps::KVPairs<char>& req_data, UpdateBuf *update_buf,
ps::KVServer<char>* server) {
if (!sync_mode_ || update_buf->request.size() == (size_t) ps::NumWorkers()) {
// let the main thread to execute updater_, which is necessary for python
auto& stored = has_multi_precision_copy(type) ? store_realt_[key] : store_[key];
Expand All @@ -364,7 +365,16 @@ class KVStoreDistServer {
LOG(INFO) << "sent response to " << update_buf->request.size() << " workers";
}
for (const auto& req : update_buf->request) {
server->Response(req);
/**
* Request can be for either push, pull or pushpull
* If pull flag is set, respond immediately with the updated values
* Otherwise, only send the notification
*/
if (req.pull) {
DefaultStorageResponse(type, key, req, req_data, server);
} else {
server->Response(req);
}
}
update_buf->request.clear();
if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]);
Expand Down Expand Up @@ -532,7 +542,7 @@ class KVStoreDistServer {
true, merged_dtype);
} // else nothing to aggregate
updates.request.push_back(req_meta);
ApplyUpdates(type, master_key, &updates, server);
ApplyUpdates(type, master_key, req_data, &updates, server);
} else {
server->Response(req_meta);
}
Expand Down Expand Up @@ -570,7 +580,7 @@ class KVStoreDistServer {
AccumulateRowSparseGrads(type, recved, &updates);
}
updates.request.push_back(req_meta);
ApplyUpdates(type, master_key, &updates, server);
ApplyUpdates(type, master_key, req_data, &updates, server);
}
}
} else {
Expand Down Expand Up @@ -649,7 +659,7 @@ class KVStoreDistServer {
merged.merged += decomp_buf;
}
merged.request.push_back(req_meta);
ApplyUpdates(type, key, &merged, server);
ApplyUpdates(type, key, req_data, &merged, server);
} else {
// async push
gradient_compression_->Dequantize(recved, &decomp_buf, 0);
Expand Down Expand Up @@ -732,7 +742,7 @@ class KVStoreDistServer {
}
}
updates.request.push_back(req_meta);
ApplyUpdates(type, key, &updates, server);
ApplyUpdates(type, key, req_data, &updates, server);
}
} else {
DefaultStorageResponse(type, key, req_meta, req_data, server);
Expand Down
Loading

0 comments on commit 1c67928

Please sign in to comment.