Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add PushPull API
Browse files Browse the repository at this point in the history
  • Loading branch information
Anand J committed Jul 21, 2019
1 parent 7b7d0b9 commit 3507312
Show file tree
Hide file tree
Showing 7 changed files with 329 additions and 11 deletions.
46 changes: 46 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2389,6 +2389,52 @@ MXNET_DLL int MXKVStorePullRowSparseEx(KVStoreHandle handle,
const NDArrayHandle* row_ids,
int priority);

/*!
 * \brief push a list of (key, value) pairs to the kvstore and pull a list of
 *        (key, output) pairs back from it in a single call,
 *        where each key is an integer
 * \param handle handle to the kvstore
 * \param vnum the number of key-value pairs to push, i.e. the number of
 *        entries read from both vkeys and vals
 * \param vkeys the list of keys for the values to be pushed (vnum entries)
 * \param onum the number of key-output pairs to pull, i.e. the number of
 *        entries read from both okeys and outs
 * \param okeys the list of keys for the values to be pulled (onum entries)
 * \param vals the list of values to push (vnum NDArray handles)
 * \param outs the list of outputs that receive the pulled values
 *        (onum NDArray handles)
 * \param priority the priority of the action
 * \param ignore_sparse whether to ignore sparse arrays in the request
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXKVStorePushPull(KVStoreHandle handle,
mx_uint vnum,
const int* vkeys,
mx_uint onum,
const int* okeys,
NDArrayHandle* vals,
NDArrayHandle* outs,
int priority,
bool ignore_sparse);
/*!
 * \brief push a list of (key, value) pairs to the kvstore and pull a list of
 *        (key, output) pairs back from it in a single call,
 *        where each key is a string
 * \param handle handle to the kvstore
 * \param vnum the number of key-value pairs to push, i.e. the number of
 *        entries read from both vkeys and vals
 * \param vkeys the list of string keys for the values to be pushed
 *        (vnum entries)
 * \param onum the number of key-output pairs to pull, i.e. the number of
 *        entries read from both okeys and outs
 * \param okeys the list of string keys for the values to be pulled
 *        (onum entries)
 * \param vals the list of values to push (vnum NDArray handles)
 * \param outs the list of outputs that receive the pulled values
 *        (onum NDArray handles)
 * \param priority the priority of the action
 * \param ignore_sparse whether to ignore sparse arrays in the request
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXKVStorePushPullEx(KVStoreHandle handle,
mx_uint vnum,
const char** vkeys,
mx_uint onum,
const char** okeys,
NDArrayHandle* vals,
NDArrayHandle* outs,
int priority,
bool ignore_sparse);

/*!
* \brief user-defined updater for the kvstore
* It's this updater's responsibility to delete \a recv and \a local
Expand Down
17 changes: 17 additions & 0 deletions include/mxnet/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,23 @@ class KVStore {
const std::vector<NDArray*>& values,
int priority = 0, bool ignore_sparse = true) = 0;

/*!
 * \brief push a list of values to the store and pull the results back in one
 *        combined call, using integer keys.
 *
 * This is a pure virtual interface; each kvstore implementation supplies the
 * behavior.
 *
 * \param vkeys the list of keys for the values to be pushed
 * \param okeys the list of keys for the values to be pulled
 * \param values the list of values to push
 *        (presumably one entry per entry of vkeys - confirm with implementations)
 * \param outs the list of arrays that receive the pulled values
 *        (presumably one entry per entry of okeys - confirm with implementations)
 * \param priority the priority of the action; defaults to 0
 * \param ignore_sparse whether to ignore sparse arrays in the request;
 *        defaults to true
 */
virtual void PushPull(const std::vector<int>& vkeys,
const std::vector<int>& okeys,
const std::vector<NDArray>& values,
const std::vector<NDArray*>& outs,
int priority = 0, bool ignore_sparse = true) = 0;

/*!
 * \brief push a list of values to the store and pull the results back in one
 *        combined call, using string keys.
 *
 * This is a pure virtual interface; each kvstore implementation supplies the
 * behavior.
 *
 * \param str_vkeys the list of string keys for the values to be pushed
 * \param str_okeys the list of string keys for the values to be pulled
 * \param values the list of values to push
 *        (presumably one entry per entry of str_vkeys - confirm with implementations)
 * \param outs the list of arrays that receive the pulled values
 *        (presumably one entry per entry of str_okeys - confirm with implementations)
 * \param priority the priority of the action; defaults to 0
 * \param ignore_sparse whether to ignore sparse arrays in the request;
 *        defaults to true
 */
virtual void PushPull(const std::vector<std::string>& str_vkeys,
const std::vector<std::string>& str_okeys,
const std::vector<NDArray>& values,
const std::vector<NDArray*>& outs,
int priority = 0, bool ignore_sparse = true) = 0;
/*!
* \brief pull a list of key-value pairs from the store.
* The NDArray pulled back will be in row_sparse storage with only the
Expand Down
86 changes: 86 additions & 0 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,92 @@ def pull(self, key, out=None, priority=0, ignore_sparse=True):
cvals, ctypes.c_int(priority),
ctypes.c_bool(ignore_sparse)))

def pushpull(self, key, value, out=None, priority=0, ignore_sparse=True):
    """ Pushes a single value or a sequence of values to the store and pulls
    the results back, as one coalesced push + pull operation.

    This function returns immediately after adding an operator to the engine.
    Subsequent attempts to read from the `out` variable will be blocked until
    the pull operation completes.

    `value` is pushed to the kvstore server for the specified keys and the
    updated values are pulled from the server to `out`. If `out` is not
    specified the pulled values are written to `value`. The returned values
    are guaranteed to be the latest values in the store.

    pushpull with `RowSparseNDArray` is not supported for dist kvstore.

    Parameters
    ----------
    key : str, int, or sequence of str or int
        Keys.

    value : NDArray, RowSparseNDArray, list of NDArray or RowSparseNDArray,
            or list of list of NDArray or RowSparseNDArray
        Values corresponding to the keys; these are pushed.

    out : NDArray or list of NDArray or list of list of NDArray, optional
        Arrays corresponding to the keys that receive the pulled values.
        When omitted, the pulled values are written back into `value`.

    priority : int, optional
        The priority of the pushpull operation.
        Higher priority operations are likely to be executed before
        other actions.

    ignore_sparse : bool, optional, default True
        Whether to ignore sparse arrays in the request.

    Examples
    --------
    >>> # pushpull a single key-value pair
    >>> kv.pushpull('3', mx.nd.ones(shape)*8, out=a)
    >>> print(a.asnumpy())
    [[ 8. 8. 8.]
    [ 8. 8. 8.]]

    >>> # aggregate the values from several devices, then pushpull
    >>> gpus = [mx.gpu(i) for i in range(4)]
    >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
    >>> kv.pushpull('3', b, out=a)
    >>> print(a.asnumpy())
    [[ 4. 4. 4.]
    [ 4. 4. 4.]]

    >>> # pushpull a list of keys.
    >>> # single device
    >>> keys = ['4', '5', '6']
    >>> b = [mx.nd.zeros(shape)]*len(keys)
    >>> kv.pushpull(keys, [mx.nd.ones(shape)]*len(keys), out=b)
    >>> print(b[1].asnumpy())
    [[ 1. 1. 1.]
    [ 1. 1. 1.]]

    >>> # multiple devices:
    >>> keys = ['7', '8', '9']
    >>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys)
    >>> kv.pushpull(keys, b)
    >>> print(b[1][1].asnumpy())
    [[ 4. 4. 4.]
    [ 4. 4. 4.]]
    """
    cvkeys, cvals, use_str_keys = _ctype_key_value(key, value)
    if out is None:
        # No separate destination: the pulled results overwrite the pushed values.
        cokeys, couts = cvkeys, cvals
    else:
        cokeys, couts, _ = _ctype_key_value(key, out)

    # String keys are routed to the "Ex" entry point, integer keys to the plain one.
    c_pushpull = _LIB.MXKVStorePushPullEx if use_str_keys else _LIB.MXKVStorePushPull
    check_call(c_pushpull(
        self.handle, mx_uint(len(cvkeys)), cvkeys, mx_uint(len(cokeys)), cokeys,
        cvals, couts, ctypes.c_int(priority),
        ctypes.c_bool(ignore_sparse)))

def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
""" Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \
from the store with specified row_ids. When there is only one row_id, KVStoreRowSparsePull \
Expand Down
54 changes: 54 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,60 @@ int MXKVStorePullEx(KVStoreHandle handle,
API_END();
}

/*!
 * C API entry point for a combined push + pull with integer keys.
 * Converts the raw key/handle arrays into vectors and forwards them to
 * KVStore::PushPull on the store behind `handle`.
 */
int MXKVStorePushPull(KVStoreHandle handle,
                      mx_uint vnum,
                      const int* vkeys,
                      mx_uint onum,
                      const int* okeys,
                      NDArrayHandle* vals,
                      NDArrayHandle* outs,
                      int priority,
                      bool ignore_sparse) {
  API_BEGIN();
  // Integer keys can be copied straight out of the caller's arrays.
  std::vector<int> push_keys(vkeys, vkeys + vnum);
  std::vector<int> pull_keys(okeys, okeys + onum);

  // Pushed values are passed as NDArray objects (copied from the handles).
  std::vector<NDArray> push_vals;
  push_vals.reserve(vnum);
  for (mx_uint idx = 0; idx < vnum; ++idx) {
    push_vals.push_back(*static_cast<NDArray*>(vals[idx]));
  }

  // Outputs are passed as pointers so the store can write into them.
  std::vector<NDArray*> pull_outs;
  pull_outs.reserve(onum);
  for (mx_uint idx = 0; idx < onum; ++idx) {
    pull_outs.push_back(static_cast<NDArray*>(outs[idx]));
  }

  KVStore* const store = static_cast<KVStore*>(handle);
  store->PushPull(push_keys, pull_keys, push_vals, pull_outs,
                  priority, ignore_sparse);
  API_END();
}

/*!
 * C API entry point for a combined push + pull with string keys.
 * Converts the raw key/handle arrays into vectors and forwards them to
 * KVStore::PushPull on the store behind `handle`.
 */
int MXKVStorePushPullEx(KVStoreHandle handle,
                        mx_uint vnum,
                        const char** vkeys,
                        mx_uint onum,
                        const char** okeys,
                        NDArrayHandle* vals,
                        NDArrayHandle* outs,
                        int priority,
                        bool ignore_sparse) {
  API_BEGIN();
  // Each C string key is converted into an owned std::string.
  std::vector<std::string> push_keys(vkeys, vkeys + vnum);
  std::vector<std::string> pull_keys(okeys, okeys + onum);

  // Pushed values are passed as NDArray objects (copied from the handles).
  std::vector<NDArray> push_vals;
  push_vals.reserve(vnum);
  for (mx_uint idx = 0; idx < vnum; ++idx) {
    push_vals.push_back(*static_cast<NDArray*>(vals[idx]));
  }

  // Outputs are passed as pointers so the store can write into them.
  std::vector<NDArray*> pull_outs;
  pull_outs.reserve(onum);
  for (mx_uint idx = 0; idx < onum; ++idx) {
    pull_outs.push_back(static_cast<NDArray*>(outs[idx]));
  }

  KVStore* const store = static_cast<KVStore*>(handle);
  store->PushPull(push_keys, pull_keys, push_vals, pull_outs,
                  priority, ignore_sparse);
  API_END();
}

int MXKVStorePullWithSparse(KVStoreHandle handle,
mx_uint num,
const int* keys,
Expand Down
78 changes: 78 additions & 0 deletions src/kvstore/kvstore_dist.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,84 @@ class KVStoreDist : public KVStoreLocal {
}
}

/*!
 * \brief combined push + pull for the distributed kvstore.
 *
 * For every unique key: the pushed values are reduced into one array, that
 * array is placed in the per-key communication buffer, a single ZPushPull
 * request is issued on the parameter-server worker using that buffer for both
 * sending and receiving, and finally the buffer is broadcast to the outputs.
 *
 * \param vkeys keys of the values to push
 * \param okeys keys of the outputs to pull into
 * \param values values to push
 * \param outputs arrays that receive the pulled values
 * \param priority priority forwarded to the reduce, engine and broadcast calls
 * \param ignore_sparse must be true; the call aborts via CHECK otherwise
 */
void PushPullImpl(const std::vector<int>& vkeys,
const std::vector<int>& okeys,
const std::vector<NDArray>& values,
const std::vector<NDArray*>& outputs,
int priority,
bool ignore_sparse) override {
// NOTE(review): the message says "pull" although this is the push-pull path.
CHECK(ignore_sparse) << "dist kvstore pull doesn't support ignore_sparse=False";

std::vector<int> uniq_vkeys;
std::vector<int> uniq_okeys;
std::vector<std::vector<NDArray>> grouped_vals;
std::vector<std::vector<NDArray*>> grouped_outs;

// Presumably these helpers collect the values/outputs that share a key into
// one group per unique key (helpers are defined outside this block - confirm).
GroupKVPairsPush(vkeys, values, &uniq_vkeys, &grouped_vals, false);
GroupKVPairsPull(okeys, outputs, &uniq_okeys, &grouped_outs, true);
// The push side and the pull side must name the same number of unique keys.
CHECK_EQ(uniq_vkeys.size(), uniq_okeys.size())
<< "List of push and pull keys are different";

for (size_t i = 0; i < uniq_vkeys.size(); ++i) {
// ...and the keys must also match position by position.
CHECK_EQ(uniq_vkeys[i], uniq_okeys[i])
<< "Mismatch in push and pull key";
int key = uniq_vkeys[i];
const auto& vals = grouped_vals[i];
const auto& outs = grouped_outs[i];

// Combine all values pushed for this key into a single array.
NDArray merged = comm_->Reduce(key, vals, priority);

// Only default (dense) storage is accepted on both sides; the first output
// of the group is used as the representative for the pull side.
const auto push_stype = merged.storage_type();
const auto pull_stype = outs[0]->storage_type();
CHECK_EQ(push_stype, kDefaultStorage)
<< "Expected push_stype of value to be kDefaultStorage";
CHECK_EQ(pull_stype, kDefaultStorage)
<< "Expected pull_stype of value to be kDefaultStorage";

// The merged value and the first output must have the same dtype.
const int push_dtype = merged.dtype();
const int pull_dtype = outs[0]->dtype();
CHECK_EQ(push_dtype, pull_dtype) << "Output buffer dtype is different";

// Per-key communication buffer, kept in comm_buf_ across calls.
auto &comm_buf = comm_buf_[key];
if (merged.ctx().dev_mask() == cpu::kDevMask) {
comm_buf = merged; // avoid memory copy
} else {
// Merged result is not on CPU: create the buffer in pinned_ctx_ on first
// use (shape taken from the first output) and copy the merged data into it.
if (comm_buf.is_none()) {
comm_buf = NDArray(outs[0]->shape(), pinned_ctx_, true, pull_dtype);
}
CopyFromTo(merged, &comm_buf);
}

// Gradient compression is rejected on this path.
CHECK(gradient_compression_->get_type() == CompressionType::kNone)
<< "Compression not supported with PushPull";
// Engine operation that performs the actual network request. comm_buf is
// captured by value so the lambda holds its own NDArray handle.
auto pushpull = [this, key, comm_buf](
RunContext rctx, Engine::CallbackOnComplete cb) {
size_t size = comm_buf.shape().Size();
const int dtype = comm_buf.dtype();
const int num_bytes = mshadow::mshadow_sizeof(dtype);
const int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);

PSKV& pskv = EncodeDefaultKey(key, size, num_bytes);
// Wrap comm_buf's memory in an SArray (presumably without taking ownership,
// given the trailing `false` - confirm against ps::SArray).
char* data = static_cast<char*>(comm_buf.data().dptr_);
auto vals = new ps::SArray<char>(data, size * num_bytes, false);

// The same SArray is passed as the data to send and as the destination for
// the response. The completion callback frees the heap-allocated SArray
// wrapper and then signals the engine that this operation is done.
CHECK_NOTNULL(ps_worker_)->ZPushPull(
pskv.keys, *vals, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
};

// Schedule the request with no read-only variables and comm_buf's variable
// listed as mutable.
CHECK_NOTNULL(Engine::Get())->PushAsync(
pushpull,
pinned_ctx_,
{},
{comm_buf.var()},
FnProperty::kNormal,
priority,
"KVStoreDistDefaultStoragePushPull");

// Distribute the contents of comm_buf to every output registered for this key.
comm_->Broadcast(key, comm_buf, outs, priority);
}
}

void PushImpl(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority) override {
Expand Down
25 changes: 14 additions & 11 deletions src/kvstore/kvstore_dist_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,8 @@ class KVStoreDistServer {
}

inline void ApplyUpdates(const DataHandleType type, const int key,
UpdateBuf *update_buf, ps::KVServer<char>* server) {
const ps::KVPairs<char>& req_data, UpdateBuf *update_buf,
ps::KVServer<char>* server) {
if (!sync_mode_ || update_buf->request.size() == (size_t) ps::NumWorkers()) {
// let the main thread to execute updater_, which is necessary for python
auto& stored = has_multi_precision_copy(type) ? store_realt_[key] : store_[key];
Expand All @@ -364,7 +365,11 @@ class KVStoreDistServer {
LOG(INFO) << "sent response to " << update_buf->request.size() << " workers";
}
for (const auto& req : update_buf->request) {
server->Response(req);
if (req.pull) {
DefaultStorageResponse(type, key, req, req_data, server);
} else {
server->Response(req);
}
}
update_buf->request.clear();
if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]);
Expand Down Expand Up @@ -532,7 +537,7 @@ class KVStoreDistServer {
true, merged_dtype);
} // else nothing to aggregate
updates.request.push_back(req_meta);
ApplyUpdates(type, master_key, &updates, server);
ApplyUpdates(type, master_key, req_data, &updates, server);
} else {
server->Response(req_meta);
}
Expand Down Expand Up @@ -570,7 +575,7 @@ class KVStoreDistServer {
AccumulateRowSparseGrads(type, recved, &updates);
}
updates.request.push_back(req_meta);
ApplyUpdates(type, master_key, &updates, server);
ApplyUpdates(type, master_key, req_data, &updates, server);
}
}
} else {
Expand All @@ -588,14 +593,12 @@ class KVStoreDistServer {
const NDArray& stored = store_[key];
CHECK(!stored.is_none()) << "init " << key << " first";

// as server returns when store_realt is ready in this case
if (has_multi_precision_copy(type)) stored.WaitToRead();

auto len = stored.shape().Size() * mshadow::mshadow_sizeof(stored.dtype());
response.keys = req_data.keys;
response.lens = {len};
// TODO(mli) try to remove this CopyFrom
response.vals.CopyFrom(static_cast<const char*>(stored.data().dptr_), len);
stored.WaitToRead();
auto data = static_cast<char*>(stored.data().dptr_);
response.vals.reset(data, len, [](char* data){});
server->Response(req_meta, response);
}

Expand Down Expand Up @@ -649,7 +652,7 @@ class KVStoreDistServer {
merged.merged += decomp_buf;
}
merged.request.push_back(req_meta);
ApplyUpdates(type, key, &merged, server);
ApplyUpdates(type, key, req_data, &merged, server);
} else {
// async push
gradient_compression_->Dequantize(recved, &decomp_buf, 0);
Expand Down Expand Up @@ -732,7 +735,7 @@ class KVStoreDistServer {
}
}
updates.request.push_back(req_meta);
ApplyUpdates(type, key, &updates, server);
ApplyUpdates(type, key, req_data, &updates, server);
}
} else {
DefaultStorageResponse(type, key, req_meta, req_data, server);
Expand Down
Loading

0 comments on commit 3507312

Please sign in to comment.