From df2cb199582c0ed237219a62ac8ab5288380a126 Mon Sep 17 00:00:00 2001 From: Anand J Date: Sun, 8 Sep 2019 18:32:49 -0400 Subject: [PATCH] [MXNET-1294] Add KVSTORE PushPull API (#15559) * Switch to latest ps-lite * Add PushPull API * Add PushPull test cases --- 3rdparty/ps-lite | 2 +- include/mxnet/c_api.h | 42 ++++++++++++ include/mxnet/kvstore.h | 27 ++++++++ python/mxnet/kvstore.py | 81 +++++++++++++++++++++++ src/c_api/c_api.cc | 52 +++++++++++++++ src/kvstore/kvstore_dist.h | 75 +++++++++++++++++++++ src/kvstore/kvstore_dist_server.h | 22 ++++-- src/kvstore/kvstore_local.h | 31 +++++++++ tests/nightly/dist_device_sync_kvstore.py | 30 ++++++--- 9 files changed, 345 insertions(+), 17 deletions(-) diff --git a/3rdparty/ps-lite b/3rdparty/ps-lite index 2c8ed25384fc..60b826e4422f 160000 --- a/3rdparty/ps-lite +++ b/3rdparty/ps-lite @@ -1 +1 @@ -Subproject commit 2c8ed25384fc0ea00a70f081f8308170146d8f25 +Subproject commit 60b826e4422fee0df00b892c66ffffea11e5da3f diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index e5719cb36447..2fa2af5ebcf2 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -2649,6 +2649,48 @@ MXNET_DLL int MXKVStorePullRowSparseEx(KVStoreHandle handle, const NDArrayHandle* row_ids, int priority); +/*! 
+ * \brief push and pull a list of (key, value) pairs from the kvstore + * \param handle handle to the kvstore + * \param vnum the number of key-value pairs corresponding to vkeys + * \param vkeys the list of keys for the values to be pushed + * \param onum the number of key-value pairs corresponding to okeys + * \param okeys the list of keys for the values to be pulled + * \param vals the list of values to be pushed + * \param outs the list of output buffers for the pulled values + * \param priority the priority of the action + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStorePushPull(KVStoreHandle handle, + mx_uint vnum, + const int* vkeys, + mx_uint onum, + const int* okeys, + NDArrayHandle* vals, + NDArrayHandle* outs, + int priority); +/*! + * \brief push and pull a list of (key, value) pairs from the kvstore, + * where each key is a string + * \param handle handle to the kvstore + * \param vnum the number of key-value pairs corresponding to vkeys + * \param vkeys the list of keys for the values to be pushed + * \param onum the number of key-value pairs corresponding to okeys + * \param okeys the list of keys for the values to be pulled + * \param vals the list of values to be pushed + * \param outs the list of output buffers for the pulled values + * \param priority the priority of the action + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStorePushPullEx(KVStoreHandle handle, + mx_uint vnum, + const char** vkeys, + mx_uint onum, + const char** okeys, + NDArrayHandle* vals, + NDArrayHandle* outs, + int priority); + /*! * \brief user-defined updater for the kvstore * It's this updater's responsibility to delete \a recv and \a local diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index a73d96356132..c5e8a8219989 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -198,6 +198,33 @@ class KVStore { const std::vector<NDArray*>& values, int priority = 0, bool ignore_sparse = true) = 0; + /*!
+ * \brief push and pull a list of key-value pairs from the store + * \param vkeys the list of keys to be pushed + * \param okeys the list of keys to be pulled. Should be the same set of keys in vkeys. + * \param values the list of values to be pushed + * \param outs the list of buffers for the pulled data, they should be preallocated + * \param priority Priority of the action. + */ + virtual void PushPull(const std::vector<int>& vkeys, + const std::vector<int>& okeys, + const std::vector<NDArray>& values, + const std::vector<NDArray*>& outs, + int priority = 0) = 0; + + /*! + * \brief push and pull a list of key-value pairs from the store + * \param str_vkeys the list of keys to be pushed in string format + * \param str_okeys the list of keys to be pulled in string format. Should match str_vkeys. + * \param values the list of values to be pushed + * \param outs the list of buffers for the pulled data, they should be preallocated + * \param priority Priority of the action. + */ + virtual void PushPull(const std::vector<std::string>& str_vkeys, + const std::vector<std::string>& str_okeys, + const std::vector<NDArray>& values, + const std::vector<NDArray*>& outs, + int priority = 0) = 0; /*! * \brief pull a list of key-value pairs from the store. * The NDArray pulled back will be in row_sparse storage with only the diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index a54817501391..5d332ff45ecb 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -311,6 +311,87 @@ def pull(self, key, out=None, priority=0, ignore_sparse=True): cvals, ctypes.c_int(priority), ctypes.c_bool(ignore_sparse))) + def pushpull(self, key, value, out=None, priority=0): + """ Performs push and pull of a single value or a sequence of values from the store. + + This function is a coalesced form of push and pull operations. This function returns + immediately after adding an operator to the engine. Subsequent attempts to read + from the `out` variable will be blocked until the pull operation completes.
+ + `value` is pushed to the kvstore server for the specified keys and the updated + values are pulled from the server to `out`. If `out` is not specified, the pulled + values are written to `value`. The returned values are guaranteed to be the latest + values in the store. + + pushpull with `RowSparseNDArray` is not supported for dist kvstore. + + Parameters + ---------- + key : str, int, or sequence of str or int + Keys. + + value : NDArray, RowSparseNDArray, list of NDArray or RowSparseNDArray, + or list of list of NDArray or RowSparseNDArray + Values corresponding to the keys. + + out : NDArray or list of NDArray or list of list of NDArray, optional + Output buffers for the pulled values. Defaults to `value`. + + priority : int, optional + The priority of the pushpull operation. + Higher priority pushpull operations are likely to be executed before + other pushpull actions. + + Examples + -------- + >>> # pushpull a single key-value pair + >>> kv.pushpull('3', mx.nd.ones(shape)*8, out=a) + >>> print a.asnumpy() + [[ 8. 8. 8.] + [ 8. 8. 8.]] + + >>> # aggregate the values and then pushpull + >>> gpus = [mx.gpu(i) for i in range(4)] + >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus] + >>> kv.pushpull('3', b, out=a) + >>> print a.asnumpy() + [[ 4. 4. 4.] + [ 4. 4. 4.]] + + >>> # pushpull a list of keys. + >>> # single device + >>> keys = ['4', '5', '6'] + >>> b = [mx.nd.zeros(shape)]*len(keys) + >>> kv.pushpull(keys, [mx.nd.ones(shape)]*len(keys), out=b) + >>> print b[1].asnumpy() + [[ 1. 1. 1.] + [ 1. 1. 1.]] + + >>> # multiple devices: + >>> keys = ['7', '8', '9'] + >>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys) + >>> kv.pushpull(keys, b) + >>> print b[1][1].asnumpy() + [[ 4. 4. 4.] + [ 4. 4.
4.]] + """ + + cvkeys, cvals, use_str_keys = _ctype_key_value(key, value) + if out is not None: + cokeys, couts, _ = _ctype_key_value(key, out) + else: + cokeys = cvkeys + couts = cvals + + if use_str_keys: + check_call(_LIB.MXKVStorePushPullEx( + self.handle, mx_uint(len(cvkeys)), cvkeys, mx_uint(len(cokeys)), cokeys, + cvals, couts, ctypes.c_int(priority))) + else: + check_call(_LIB.MXKVStorePushPull( + self.handle, mx_uint(len(cvkeys)), cvkeys, mx_uint(len(cokeys)), cokeys, + cvals, couts, ctypes.c_int(priority))) + def row_sparse_pull(self, key, out=None, priority=0, row_ids=None): """ Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \ from the store with specified row_ids. When there is only one row_id, KVStoreRowSparsePull \ diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index f3e0ba8f5c26..70ce869b0ef7 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1064,6 +1064,58 @@ int MXKVStorePullEx(KVStoreHandle handle, API_END(); } +int MXKVStorePushPull(KVStoreHandle handle, + mx_uint vnum, + const int* vkeys, + mx_uint onum, + const int* okeys, + NDArrayHandle* vals, + NDArrayHandle* outs, + int priority) { + API_BEGIN(); + std::vector<int> v_vkeys(vnum); + std::vector<int> v_okeys(onum); + std::vector<NDArray> v_vals(vnum); + std::vector<NDArray*> v_outs(onum); + for (mx_uint i = 0; i < vnum; ++i) { + v_vkeys[i] = vkeys[i]; + v_vals[i] = *static_cast<NDArray*>(vals[i]); + } + for (mx_uint i = 0; i < onum; ++i) { + v_okeys[i] = okeys[i]; + v_outs[i] = static_cast<NDArray*>(outs[i]); + } + static_cast<KVStore*>(handle)->PushPull(v_vkeys, v_okeys, v_vals, v_outs, + priority); + API_END(); +} + +int MXKVStorePushPullEx(KVStoreHandle handle, + mx_uint vnum, + const char** vkeys, + mx_uint onum, + const char** okeys, + NDArrayHandle* vals, + NDArrayHandle* outs, + int priority) { + API_BEGIN(); + std::vector<std::string> v_vkeys(vnum); + std::vector<std::string> v_okeys(onum); + std::vector<NDArray> v_vals(vnum); + std::vector<NDArray*> v_outs(onum); + for (mx_uint i = 0; i < vnum; ++i) { + v_vkeys[i] = vkeys[i]; + v_vals[i]
= *static_cast<NDArray*>(vals[i]); + } + for (mx_uint i = 0; i < onum; ++i) { + v_okeys[i] = okeys[i]; + v_outs[i] = static_cast<NDArray*>(outs[i]); + } + static_cast<KVStore*>(handle)->PushPull(v_vkeys, v_okeys, v_vals, v_outs, + priority); + API_END(); +} + int MXKVStorePullWithSparse(KVStoreHandle handle, uint32_t num, const int* keys, diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index 5ac28cb20c8d..f2df582db3b1 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -206,6 +206,81 @@ class KVStoreDist : public KVStoreLocal { } } + void PushPullImpl(const std::vector<int>& vkeys, + const std::vector<int>& okeys, + const std::vector<NDArray>& values, + const std::vector<NDArray*>& outputs, + int priority) override { + std::vector<int> uniq_vkeys; + std::vector<int> uniq_okeys; + std::vector<std::vector<NDArray>> grouped_vals; + std::vector<std::vector<NDArray*>> grouped_outs; + + GroupKVPairsPush(vkeys, values, &uniq_vkeys, &grouped_vals, false); + GroupKVPairsPull(okeys, outputs, &uniq_okeys, &grouped_outs, true); + CHECK_EQ(uniq_vkeys.size(), uniq_okeys.size()) + << "List of push and pull keys are different"; + + for (size_t i = 0; i < uniq_vkeys.size(); ++i) { + CHECK_EQ(uniq_vkeys[i], uniq_okeys[i]) + << "Mismatch in push and pull key"; + int key = uniq_vkeys[i]; + const auto& vals = grouped_vals[i]; + const auto& outs = grouped_outs[i]; + + NDArray merged = comm_->Reduce(key, vals, priority); + + const auto push_stype = merged.storage_type(); + const auto pull_stype = outs[0]->storage_type(); + CHECK_EQ(push_stype, kDefaultStorage) + << "Expected push_stype of value to be kDefaultStorage"; + CHECK_EQ(pull_stype, kDefaultStorage) + << "Expected pull_stype of value to be kDefaultStorage"; + + const int push_dtype = merged.dtype(); + const int pull_dtype = outs[0]->dtype(); + CHECK_EQ(push_dtype, pull_dtype) << "Output buffer dtype is different"; + + auto &comm_buf = comm_buf_[key]; + if (merged.ctx().dev_mask() == cpu::kDevMask) { + comm_buf = merged; // avoid memory copy + } else { + if (comm_buf.is_none()) { + comm_buf =
NDArray(outs[0]->shape(), pinned_ctx_, true, pull_dtype); + } + CopyFromTo(merged, &comm_buf); + } + + CHECK(gradient_compression_->get_type() == CompressionType::kNone) + << "Compression not supported with PushPull"; + auto pushpull = [this, key, comm_buf]( + RunContext rctx, Engine::CallbackOnComplete cb) { + size_t size = comm_buf.shape().Size(); + const int dtype = comm_buf.dtype(); + const int num_bytes = mshadow::mshadow_sizeof(dtype); + const int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype); + + PSKV& pskv = EncodeDefaultKey(key, size, num_bytes); + char* data = static_cast<char*>(comm_buf.data().dptr_); + auto vals = new ps::SArray<char>(data, size * num_bytes, false); + + CHECK_NOTNULL(ps_worker_)->ZPushPull( + pskv.keys, *vals, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); }); + }; + + CHECK_NOTNULL(Engine::Get())->PushAsync( + pushpull, + pinned_ctx_, + {}, + {comm_buf.var()}, + FnProperty::kNormal, + priority, + "KVStoreDistDefaultStoragePushPull"); + + comm_->Broadcast(key, comm_buf, outs, priority); + } + } + void PushImpl(const std::vector<int>& keys, const std::vector<NDArray>& values, int priority) override { diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h index 0cb1a11e3fcc..65ded79743e4 100644 --- a/src/kvstore/kvstore_dist_server.h +++ b/src/kvstore/kvstore_dist_server.h @@ -344,7 +344,8 @@ class KVStoreDistServer { } inline void ApplyUpdates(const DataHandleType type, const int key, - UpdateBuf *update_buf, ps::KVServer<char>* server) { + const ps::KVPairs<char>& req_data, UpdateBuf *update_buf, + ps::KVServer<char>* server) { if (!sync_mode_ || update_buf->request.size() == (size_t) ps::NumWorkers()) { // let the main thread to execute updater_, which is necessary for python auto& stored = has_multi_precision_copy(type) ?
store_realt_[key] : store_[key]; @@ -364,7 +365,16 @@ class KVStoreDistServer { LOG(INFO) << "sent response to " << update_buf->request.size() << " workers"; } for (const auto& req : update_buf->request) { - server->Response(req); + /** + * Request can be for either push, pull or pushpull + * If pull flag is set, respond immediately with the updated values + * Otherwise, only send the notification + */ + if (req.pull) { + DefaultStorageResponse(type, key, req, req_data, server); + } else { + server->Response(req); + } } update_buf->request.clear(); if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]); @@ -532,7 +542,7 @@ class KVStoreDistServer { true, merged_dtype); } // else nothing to aggregate updates.request.push_back(req_meta); - ApplyUpdates(type, master_key, &updates, server); + ApplyUpdates(type, master_key, req_data, &updates, server); } else { server->Response(req_meta); } @@ -570,7 +580,7 @@ class KVStoreDistServer { AccumulateRowSparseGrads(type, recved, &updates); } updates.request.push_back(req_meta); - ApplyUpdates(type, master_key, &updates, server); + ApplyUpdates(type, master_key, req_data, &updates, server); } } } else { @@ -649,7 +659,7 @@ class KVStoreDistServer { merged.merged += decomp_buf; } merged.request.push_back(req_meta); - ApplyUpdates(type, key, &merged, server); + ApplyUpdates(type, key, req_data, &merged, server); } else { // async push gradient_compression_->Dequantize(recved, &decomp_buf, 0); @@ -732,7 +742,7 @@ class KVStoreDistServer { } } updates.request.push_back(req_meta); - ApplyUpdates(type, key, &updates, server); + ApplyUpdates(type, key, req_data, &updates, server); } } else { DefaultStorageResponse(type, key, req_meta, req_data, server); diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 5c3c80838a36..ad70bc15ea0a 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -129,6 +129,15 @@ class KVStoreLocal : public KVStore { PullImpl(keys, values, priority, 
ignore_sparse); } + void PushPull(const std::vector<int>& vkeys, + const std::vector<int>& okeys, + const std::vector<NDArray>& values, + const std::vector<NDArray*>& outs, + int priority) override { + SetKeyType(kIntKey); + PushPullImpl(vkeys, okeys, values, outs, priority); + } + void PullRowSparse(const std::vector<int>& keys, const std::vector<std::pair<NDArray*, NDArray>>& val_rowids, int priority = 0) override { @@ -155,6 +164,19 @@ class KVStoreLocal : public KVStore { PullImpl(keys, values, priority, ignore_sparse); } + void PushPull(const std::vector<std::string>& str_vkeys, + const std::vector<std::string>& str_okeys, + const std::vector<NDArray>& values, + const std::vector<NDArray*>& outs, + int priority) override { + SetKeyType(kStringKey); + std::vector<int> vkeys(str_vkeys.size()); + std::vector<int> okeys(str_okeys.size()); + LookupKeys(str_vkeys, &vkeys); + LookupKeys(str_okeys, &okeys); + PushPullImpl(vkeys, okeys, values, outs, priority); + } + void PullRowSparse(const std::vector<std::string>& str_keys, const std::vector<std::pair<NDArray*, NDArray>>& val_rowids, int priority = 0) override { @@ -269,6 +291,15 @@ class KVStoreLocal : public KVStore { CHECK_EQ(key_type_, key_type) << "Mixed key types are not allowed"; } + virtual void PushPullImpl(const std::vector<int>& vkeys, + const std::vector<int>& okeys, + const std::vector<NDArray>& values, + const std::vector<NDArray*>& outs, + int priority) { + PushImpl(vkeys, values, priority); + PullImpl(okeys, outs, priority, true); + } + /** * \brief group values on keys for push */ diff --git a/tests/nightly/dist_device_sync_kvstore.py b/tests/nightly/dist_device_sync_kvstore.py index 7fd0333aea79..dc2c7bc35747 100644 --- a/tests/nightly/dist_device_sync_kvstore.py +++ b/tests/nightly/dist_device_sync_kvstore.py @@ -55,23 +55,33 @@ def init_kv(): def test_sync_push_pull(): kv, my_rank, nworker = init_kv() num_gpus = 2 - def check_default_keys(kv, my_rank, nworker): - nrepeat = 3 + def check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False): # checks pull after push in loop, because behavior during # consecutive pushes doesn't offer any guarantees - for i in
range(nrepeat): + for i in range(offset, offset + nrepeat): scale = my_rank + 1 - kv.push('3', [mx.nd.ones(shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)]) - kv.push('99', [mx.nd.ones(big_shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)]) num = (nworker + 1) * nworker * rate * num_gpus / 2 * (i + 1) + 1 + + arr = [mx.nd.ones(shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)] val = mx.nd.zeros(shape) - kv.pull('3', out=val) + if use_pushpull: + kv.pushpull('3', arr, out=val) + else: + kv.push('3', arr) + kv.pull('3', out=val) check_diff_to_scalar(val, num) - val2 = mx.nd.zeros(big_shape) - kv.pull('99', out=val2) - check_diff_to_scalar(val2, num) - check_default_keys(kv, my_rank, nworker) + big_arr = [mx.nd.ones(big_shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)] + big_val = mx.nd.zeros(big_shape) + if use_pushpull: + kv.pushpull('99', big_arr, out=big_val) + else: + kv.push('99', big_arr) + kv.pull('99', out=big_val) + check_diff_to_scalar(big_val, num) + + check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False) + check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=3, use_pushpull=True) print('worker ' + str(my_rank) + ' is done') def test_sync_init():