diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 1b1c10e79fea..765af9144e77 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -2389,6 +2389,52 @@ MXNET_DLL int MXKVStorePullRowSparseEx(KVStoreHandle handle, const NDArrayHandle* row_ids, int priority); +/*! + * \brief push and pull a list of (key, value) pairs from the kvstore + * \param handle handle to the kvstore + * \param vnum the number of key-value pairs corresponding to vkeys + * \param vkeys the list of keys for the values to be pushed + * \param onum the number of key-value pairs corresponding to okeys + * \param okeys the list of keys for the values to be pulled + * \param vals the list of values + * \param outs the list of outputs + * \param priority the priority of the action + * \param ignore_sparse whether to ignore sparse arrays in the request + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStorePushPull(KVStoreHandle handle, + mx_uint vnum, + const int* vkeys, + mx_uint onum, + const int* okeys, + NDArrayHandle* vals, + NDArrayHandle* outs, + int priority, + bool ignore_sparse); +/*! + * \brief push and pull a list of (key, value) pairs from the kvstore, + * where each key is a string + * \param handle handle to the kvstore + * \param vnum the number of key-value pairs corresponding to vkeys + * \param vkeys the list of keys for the values to be pushed + * \param onum the number of key-value pairs corresponding to okeys + * \param okeys the list of keys for the values to be pulled + * \param vals the list of values + * \param outs the list of outputs + * \param priority the priority of the action + * \param ignore_sparse whether to ignore sparse arrays in the request + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStorePushPullEx(KVStoreHandle handle, + mx_uint vnum, + const char** vkeys, + mx_uint onum, + const char** okeys, + NDArrayHandle* vals, + NDArrayHandle* outs, + int priority, + bool ignore_sparse); + /*! * \brief user-defined updater for the kvstore * It's this updater's responsibility to delete \a recv and \a local diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index a73d96356132..88da5d610836 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -198,6 +198,23 @@ class KVStore { const std::vector& values, int priority = 0, bool ignore_sparse = true) = 0; + /** + * TODO: Comment + */ + virtual void PushPull(const std::vector& vkeys, + const std::vector& okeys, + const std::vector& values, + const std::vector& outs, + int priority = 0, bool ignore_sparse = true) = 0; + + /** + * TODO: Comment + */ + virtual void PushPull(const std::vector& str_vkeys, + const std::vector& str_okeys, + const std::vector& values, + const std::vector& outs, + int priority = 0, bool ignore_sparse = true) = 0; /*! * \brief pull a list of key-value pairs from the store. * The NDArray pulled back will be in row_sparse storage with only the diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index a54817501391..2bc4fa79ec1b 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -311,6 +311,92 @@ def pull(self, key, out=None, priority=0, ignore_sparse=True): cvals, ctypes.c_int(priority), ctypes.c_bool(ignore_sparse))) + def pushpull(self, key, value, out=None, priority=0, ignore_sparse=True): + """ Performs push and pull a single value or a sequence of values from the store. + + This function is coalesced form of push and pull operations. This function returns + immediately after adding an operator to the engine. Subsequent attempts to read + from the `out` variable will be blocked until the pull operation completes. + + `value` is pushed to the kvstore server for the specified keys and the updated + values are pulled from the server to `out`. If `out` is not specified the pulled + values are written to `value`. The returned values are guaranteed to be the latest + values in the store. + + pushpull with `RowSparseNDArray` is not supported for dist kvstore. + + Parameters + ---------- + key : str, int, or sequence of str or int + Keys. + + value : NDArray, RowSparseNDArray, list of NDArray or RowSparseNDArray, + or list of list of NDArray or RowSparseNDArray + Values corresponding to the keys. + + out: NDArray or list of NDArray or list of list of NDArray + Values corresponding to the keys. + + priority : int, optional + The priority of the pull operation. + Higher priority pull operations are likely to be executed before + other pull actions. + + ignore_sparse: bool, optional, default True + Whether to ignore sparse arrays in the request. + + Examples + -------- + >>> # push a single key-value pair + >>> kv.pushpull('3', mx.nd.ones(shape)*8, out=a) + >>> print a.asnumpy() + [[ 8. 8. 8.] + [ 8. 8. 8.]] + + >>> # aggregate the value and the push + >>> gpus = [mx.gpu(i) for i in range(4)] + >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus] + >>> kv.pushpull('3', b, out=a) + >>> print a.asnumpy() + [[ 4. 4. 4.] + [ 4. 4. 4.]] + + >>> # push a list of keys. + >>> # single device + >>> keys = ['4', '5', '6'] + >>> b = [mx.nd.zeros(shape)]*len(keys) + >>> kv.push(keys, [mx.nd.ones(shape)]*len(keys), out=b) + >>> print b[1].asnumpy() + [[ 1. 1. 1.] + [ 1. 1. 1.]] + + >>> # multiple devices: + >>> keys = ['7', '8', '9'] + >>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys) + >>> kv.pushpull(keys, b) + >>> print b[1][1].asnumpy() + [[ 4. 4. 4.] + [ 4. 4. 4.]] + """ + + cvkeys, cvals, use_str_keys = _ctype_key_value(key, value) + if out is not None: + cokeys, couts, _ = _ctype_key_value(key, out) + else: + cokeys = cvkeys + couts = cvals + + if use_str_keys: + check_call(_LIB.MXKVStorePushPullEx( + self.handle, mx_uint(len(cvkeys)), cvkeys, mx_uint(len(cokeys)), cokeys, + cvals, couts, ctypes.c_int(priority), + ctypes.c_bool(ignore_sparse))) + else: + check_call(_LIB.MXKVStorePushPull( + self.handle, mx_uint(len(cvkeys)), cvkeys, mx_uint(len(cokeys)), cokeys, + cvals, couts, ctypes.c_int(priority), + ctypes.c_bool(ignore_sparse))) + def row_sparse_pull(self, key, out=None, priority=0, row_ids=None): """ Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \ from the store with specified row_ids. When there is only one row_id, KVStoreRowSparsePull \ diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index f5d72d53d2b7..dc7516b25d73 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -977,6 +977,60 @@ int MXKVStorePullEx(KVStoreHandle handle, API_END(); } +int MXKVStorePushPull(KVStoreHandle handle, + mx_uint vnum, + const int* vkeys, + mx_uint onum, + const int* okeys, + NDArrayHandle* vals, + NDArrayHandle* outs, + int priority, + bool ignore_sparse) { + API_BEGIN(); + std::vector v_vkeys(vnum); + std::vector v_okeys(onum); + std::vector v_vals(vnum); + std::vector v_outs(onum); + for (mx_uint i = 0; i < vnum; ++i) { + v_vkeys[i] = vkeys[i]; + v_vals[i] = *static_cast(vals[i]); + } + for (mx_uint i = 0; i < onum; ++i) { + v_okeys[i] = okeys[i]; + v_outs[i] = static_cast(outs[i]); + } + static_cast(handle)->PushPull(v_vkeys, v_okeys, v_vals, v_outs, + priority, ignore_sparse); + API_END(); +} + +int MXKVStorePushPullEx(KVStoreHandle handle, + mx_uint vnum, + const char** vkeys, + mx_uint onum, + const char** okeys, + NDArrayHandle* vals, + NDArrayHandle* outs, + int priority, + bool ignore_sparse) { + API_BEGIN(); + std::vector v_vkeys(vnum); + std::vector v_okeys(onum); + std::vector v_vals(vnum); + std::vector v_outs(onum); + for (mx_uint i = 0; i < vnum; ++i) { + v_vkeys[i] = vkeys[i]; + v_vals[i] = *static_cast(vals[i]); + } + for (mx_uint i = 0; i < onum; ++i) { + v_okeys[i] = okeys[i]; + v_outs[i] = static_cast(outs[i]); + } + static_cast(handle)->PushPull(v_vkeys, v_okeys, v_vals, v_outs, + priority, ignore_sparse); + API_END(); +} + int MXKVStorePullWithSparse(KVStoreHandle handle, mx_uint num, const int* keys, diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index 9fe41c51a2c6..8ecb34b3ca8b 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -208,6 +208,84 @@ class KVStoreDist : public KVStoreLocal { } } + void PushPullImpl(const std::vector& vkeys, + const std::vector& okeys, + const std::vector& values, + const std::vector& outputs, + int priority, + bool ignore_sparse) override { + CHECK(ignore_sparse) << "dist kvstore pull doesn't support ignore_sparse=False"; + + std::vector uniq_vkeys; + std::vector uniq_okeys; + std::vector> grouped_vals; + std::vector> grouped_outs; + + GroupKVPairsPush(vkeys, values, &uniq_vkeys, &grouped_vals, false); + GroupKVPairsPull(okeys, outputs, &uniq_okeys, &grouped_outs, true); + CHECK_EQ(uniq_vkeys.size(), uniq_okeys.size()) + << "List of push and pull keys are different"; + + for (size_t i = 0; i < uniq_vkeys.size(); ++i) { + CHECK_EQ(uniq_vkeys[i], uniq_okeys[i]) + << "Mismatch in push and pull key"; + int key = uniq_vkeys[i]; + const auto& vals = grouped_vals[i]; + const auto& outs = grouped_outs[i]; + + NDArray merged = comm_->Reduce(key, vals, priority); + + const auto push_stype = merged.storage_type(); + const auto pull_stype = outs[0]->storage_type(); + CHECK_EQ(push_stype, kDefaultStorage) + << "Expected push_stype of value to be kDefaultStorage"; + CHECK_EQ(pull_stype, kDefaultStorage) + << "Expected pull_stype of value to be kDefaultStorage"; + + const int push_dtype = merged.dtype(); + const int pull_dtype = outs[0]->dtype(); + CHECK_EQ(push_dtype, pull_dtype) << "Output buffer dtype is different"; + + auto &comm_buf = comm_buf_[key]; + if (merged.ctx().dev_mask() == cpu::kDevMask) { + comm_buf = merged; // avoid memory copy + } else { + if (comm_buf.is_none()) { + comm_buf = NDArray(outs[0]->shape(), pinned_ctx_, true, pull_dtype); + } + CopyFromTo(merged, &comm_buf); + } + + CHECK(gradient_compression_->get_type() == CompressionType::kNone) + << "Compression not supported with PushPull"; + auto pushpull = [this, key, comm_buf]( + RunContext rctx, Engine::CallbackOnComplete cb) { + size_t size = comm_buf.shape().Size(); + const int dtype = comm_buf.dtype(); + const int num_bytes = mshadow::mshadow_sizeof(dtype); + const int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype); + + PSKV& pskv = EncodeDefaultKey(key, size, num_bytes); + char* data = static_cast(comm_buf.data().dptr_); + auto vals = new ps::SArray(data, size * num_bytes, false); + + CHECK_NOTNULL(ps_worker_)->ZPushPull( + pskv.keys, *vals, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); }); + }; + + CHECK_NOTNULL(Engine::Get())->PushAsync( + pushpull, + pinned_ctx_, + {}, + {comm_buf.var()}, + FnProperty::kNormal, + priority, + "KVStoreDistDefaultStoragePushPull"); + + comm_->Broadcast(key, comm_buf, outs, priority); + } + } + void PushImpl(const std::vector& keys, const std::vector& values, int priority) override { diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h index 0cb1a11e3fcc..f2c7e9824386 100644 --- a/src/kvstore/kvstore_dist_server.h +++ b/src/kvstore/kvstore_dist_server.h @@ -344,7 +344,8 @@ class KVStoreDistServer { } inline void ApplyUpdates(const DataHandleType type, const int key, - UpdateBuf *update_buf, ps::KVServer* server) { + const ps::KVPairs& req_data, UpdateBuf *update_buf, + ps::KVServer* server) { if (!sync_mode_ || update_buf->request.size() == (size_t) ps::NumWorkers()) { // let the main thread to execute updater_, which is necessary for python auto& stored = has_multi_precision_copy(type) ? store_realt_[key] : store_[key]; @@ -364,7 +365,11 @@ class KVStoreDistServer { LOG(INFO) << "sent response to " << update_buf->request.size() << " workers"; } for (const auto& req : update_buf->request) { - server->Response(req); + if (req.pull) { + DefaultStorageResponse(type, key, req, req_data, server); + } else { + server->Response(req); + } } update_buf->request.clear(); if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]); @@ -532,7 +537,7 @@ class KVStoreDistServer { true, merged_dtype); } // else nothing to aggregate updates.request.push_back(req_meta); - ApplyUpdates(type, master_key, &updates, server); + ApplyUpdates(type, master_key, req_data, &updates, server); } else { server->Response(req_meta); } @@ -570,7 +575,7 @@ class KVStoreDistServer { AccumulateRowSparseGrads(type, recved, &updates); } updates.request.push_back(req_meta); - ApplyUpdates(type, master_key, &updates, server); + ApplyUpdates(type, master_key, req_data, &updates, server); } } } else { @@ -588,14 +593,12 @@ class KVStoreDistServer { const NDArray& stored = store_[key]; CHECK(!stored.is_none()) << "init " << key << " first"; - // as server returns when store_realt is ready in this case - if (has_multi_precision_copy(type)) stored.WaitToRead(); - auto len = stored.shape().Size() * mshadow::mshadow_sizeof(stored.dtype()); response.keys = req_data.keys; response.lens = {len}; - // TODO(mli) try to remove this CopyFrom - response.vals.CopyFrom(static_cast(stored.data().dptr_), len); + stored.WaitToRead(); + auto data = static_cast(stored.data().dptr_); + response.vals.reset(data, len, [](char* data){}); server->Response(req_meta, response); } @@ -649,7 +652,7 @@ class KVStoreDistServer { merged.merged += decomp_buf; } merged.request.push_back(req_meta); - ApplyUpdates(type, key, &merged, server); + ApplyUpdates(type, key, req_data, &merged, server); } else { // async push gradient_compression_->Dequantize(recved, &decomp_buf, 0); @@ -732,7 +735,7 @@ class KVStoreDistServer { } } updates.request.push_back(req_meta); - ApplyUpdates(type, key, &updates, server); + ApplyUpdates(type, key, req_data, &updates, server); } } else { DefaultStorageResponse(type, key, req_meta, req_data, server); diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 4e004a3a3008..868b15191b01 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -129,6 +129,16 @@ class KVStoreLocal : public KVStore { PullImpl(keys, values, priority, ignore_sparse); } + void PushPull(const std::vector& vkeys, + const std::vector& okeys, + const std::vector& values, + const std::vector& outs, + int priority, + bool ignore_sparse) override { + SetKeyType(kIntKey); + PushPullImpl(vkeys, okeys, values, outs, priority, ignore_sparse); + } + void PullRowSparse(const std::vector& keys, const std::vector>& val_rowids, int priority = 0) override { @@ -155,6 +165,20 @@ class KVStoreLocal : public KVStore { PullImpl(keys, values, priority, ignore_sparse); } + void PushPull(const std::vector& str_vkeys, + const std::vector& str_okeys, + const std::vector& values, + const std::vector& outs, + int priority, + bool ignore_sparse) override { + SetKeyType(kStringKey); + std::vector vkeys(str_vkeys.size()); + std::vector okeys(str_okeys.size()); + LookupKeys(str_vkeys, &vkeys); + LookupKeys(str_okeys, &okeys); + PushPullImpl(vkeys, okeys, values, outs, priority, ignore_sparse); + } + void PullRowSparse(const std::vector& str_keys, const std::vector>& val_rowids, int priority = 0) override { @@ -269,6 +293,16 @@ class KVStoreLocal : public KVStore { CHECK_EQ(key_type_, key_type) << "Mixed key types are not allowed"; } + virtual void PushPullImpl(const std::vector& vkeys, + const std::vector& okeys, + const std::vector& values, + const std::vector& outs, + int priority, + bool ignore_sparse) { + PushImpl(vkeys, values, priority); + PullImpl(okeys, outs, priority, ignore_sparse); + } + /** * \brief group values on keys for push */