From d8f5a80722570cde705d8726e3398a085b0465c4 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Wed, 11 Dec 2019 16:37:05 +0800 Subject: [PATCH 01/79] add new abstraction for transport --- src/rdma_van.h | 89 +++++++++++++++++++++++++++++++------------------- 1 file changed, 55 insertions(+), 34 deletions(-) diff --git a/src/rdma_van.h b/src/rdma_van.h index 9c63c195..f6c5aac8 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -489,6 +489,60 @@ struct AsyncCopy { bool shutdown; }; +class Transport { + Transport(); + ~Transport(); + + void SendPushResponse(Message &msg) { + + } + + void SendPullRequest(Message &msg) { + + } + + void Recv(Message &msg); + void SendPushRequest(); + void SendPullResponse(); +}; + + +class RDMATransport : public Transport { + RDMATransport() { + + } + + void SendPushRequest(Message &msg) override { + + } + + void SendPullResponse(Message &msg) override { + + } + + void Recv(Message &msg) override { + + } +}; + +class IPCTransport : public Transport { + IPCTransport() { + + } + + void SendPushRequest(Message &msg) override { + + } + + void SendPullResponse(Message &msg) override { + + } + + void Recv(Message &msg) override { + + } +}; + class RDMAVan : public Van { public: RDMAVan() { @@ -507,11 +561,6 @@ class RDMAVan : public Van { if (is_server_) LOG(INFO) << "This is server"; else LOG(INFO) << "This is " << ((role=="worker") ? "worker" : "scheduler"); - val = Environment::Get()->find("ENABLE_RDMA_LOG"); - enable_rdma_log_ = val? atoi(val) : false; - if (enable_rdma_log_) LOG(INFO) << "Enable RDMA logging"; - else LOG(INFO) << "You can enable RDMA logging with ENABLE_RDMA_LOG=1"; - val = Environment::Get()->find("BYTEPS_ENABLE_IPC"); disable_ipc_ = val ? !atoi(val) : true; if (disable_ipc_) LOG(INFO) << "Shared memory IPC has been disabled"; @@ -836,6 +885,7 @@ class RDMAVan : public Van { int remote_id = msg.meta.recver; CHECK_NE(remote_id, Meta::kEmpty); + // register RDMA memory for (auto& sa : msg.data) { if (sa.size()) { std::lock_guard lock(map_mu_); @@ -868,14 +918,6 @@ class RDMAVan : public Van { msg.meta.addr = reinterpret_cast(vals.data()); // vals address msg.meta.val_len = vals.size(); msg.meta.option = memory_mr_map[vals.data()]->rkey; - - if (enable_rdma_log_) { - LOG(INFO) << "send push key=" << key - << ", val_len=" << msg.meta.val_len - << ", recver=" << msg.meta.recver - << ", val_addr=" << msg.meta.addr - << ", rkey=" << msg.meta.option; - } } } if (!msg.meta.push && !msg.meta.request) { // server, pull response @@ -894,13 +936,6 @@ class RDMAVan : public Van { msg.meta.val_len = std::get<0>(key_meta_map_[key][recver]); msg.meta.addr = std::get<1>(key_meta_map_[key][recver]); msg.meta.option = std::get<2>(key_meta_map_[key][recver]); - - if (enable_rdma_log_) { - LOG(INFO) << "send pull response key=" << key - << ", val_len=" << msg.meta.val_len - << ", val_addr=" << msg.meta.addr - << ", rkey=" << msg.meta.option; - } } } @@ -1151,25 +1186,11 @@ class RDMAVan : public Van { std::lock_guard lock(map_mu_); if (key_meta_map_.find(key) == key_meta_map_.end() || key_meta_map_[key].find(sender) == key_meta_map_[key].end()) { - if (enable_rdma_log_) { - LOG(INFO) << "(init) recv key=" << key - << ", len=" << len - << ", sender=" << msg->meta.sender - << ", val_addr=" << addr - << ", rkey=" << rkey; - } key_meta_map_[key][sender] = std::make_tuple(len, addr, rkey); } else { CHECK_EQ(len, std::get<0>(key_meta_map_[key][sender])); CHECK_EQ(addr, std::get<1>(key_meta_map_[key][sender])); CHECK_EQ(rkey, std::get<2>(key_meta_map_[key][sender])); - - if (enable_rdma_log_) { - LOG(INFO) << "recv push key=" << key - << ", len=" << len - << ", val_addr=" << addr - << ", rkey=" << rkey; - } } } From e125ad7a495a1905e01250edc89d40d08a8acde1 Mon Sep 17 00:00:00 2001 From: Yibo Zhu Date: Wed, 11 Dec 2019 22:03:24 +0800 Subject: [PATCH 02/79] remove redundant PackMetaPB --- include/ps/internal/van.h | 5 ----- src/rdma_van.h | 11 ++++++----- src/van.cc | 35 ----------------------------------- 3 files changed, 6 insertions(+), 45 deletions(-) diff --git a/include/ps/internal/van.h b/include/ps/internal/van.h index e33bf4f5..45455845 100644 --- a/include/ps/internal/van.h +++ b/include/ps/internal/van.h @@ -111,11 +111,6 @@ class Van { */ void PackMeta(const Meta &meta, char **meta_buf, int *buf_size); - /** - * \brief pack meta into protobuf - */ - void PackMetaPB(const Meta &meta, PBMeta *pb); - /** * \brief unpack meta from a string */ diff --git a/src/rdma_van.h b/src/rdma_van.h index 9c63c195..1c2ba75a 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -904,13 +904,14 @@ class RDMAVan : public Van { } } - PBMeta meta; - PackMetaPB(msg.meta, &meta); + int meta_len; + char* meta_buf = nullptr; + PackMeta(msg.meta, &meta_buf, &meta_size); + CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); Endpoint *endpoint = endpoints_[remote_id].get(); MessageBuffer *msg_buf = new MessageBuffer(); - size_t meta_len = meta.ByteSize(); size_t data_len = msg.meta.data_size; size_t total_len = meta_len + data_len; CHECK(meta_len); @@ -919,7 +920,7 @@ class RDMAVan : public Van { if (msg.meta.simple_app || !msg.meta.control.empty()){ // simple_app or control message msg_buf->inline_len = total_len; msg_buf->inline_buf = mempool_->Alloc(total_len); - meta.SerializeToArray(msg_buf->inline_buf, meta_len); + memcpy(msg_buf->inline_buf, meta_buf, meta_len); char *cur = msg_buf->inline_buf + meta_len; for (auto &sa : msg.data) { size_t seg_len = sa.size(); @@ -930,7 +931,7 @@ class RDMAVan : public Van { msg_buf->inline_len = meta_len; msg_buf->inline_buf = mempool_->Alloc(meta_len); msg_buf->data = msg.data; - meta.SerializeToArray(msg_buf->inline_buf, meta_len); + memcpy(msg_buf->inline_buf, meta_buf, meta_len); if (!is_server_ && !is_local_[remote_id]) { // worker, send to non-local servers for (auto &sa : msg_buf->data) { if (sa.size()) { diff --git a/src/van.cc b/src/van.cc index 64f004ed..483ac636 100644 --- a/src/van.cc +++ b/src/van.cc @@ -494,41 +494,6 @@ void Van::Receiving() { } } -void Van::PackMetaPB(const Meta &meta, PBMeta *pb) { - pb->set_head(meta.head); - if (meta.app_id != Meta::kEmpty) pb->set_app_id(meta.app_id); - if (meta.timestamp != Meta::kEmpty) pb->set_timestamp(meta.timestamp); - if (meta.body.size()) pb->set_body(meta.body); - pb->set_push(meta.push); - pb->set_request(meta.request); - pb->set_simple_app(meta.simple_app); - pb->set_customer_id(meta.customer_id); - for (auto d : meta.data_type) pb->add_data_type(d); - if (!meta.control.empty()) { - auto ctrl = pb->mutable_control(); - ctrl->set_cmd(meta.control.cmd); - if (meta.control.cmd == Control::BARRIER) { - ctrl->set_barrier_group(meta.control.barrier_group); - } else if (meta.control.cmd == Control::ACK) { - ctrl->set_msg_sig(meta.control.msg_sig); - } - for (const auto &n : meta.control.node) { - auto p = ctrl->add_node(); - p->set_id(n.id); - p->set_role(n.role); - p->set_port(n.port); - p->set_hostname(n.hostname); - p->set_is_recovery(n.is_recovery); - p->set_customer_id(n.customer_id); - } - } - pb->set_data_size(meta.data_size); - pb->set_key(meta.key); - pb->set_addr(meta.addr); - pb->set_val_len(meta.val_len); - pb->set_option(meta.option); -} - void Van::PackMeta(const Meta &meta, char **meta_buf, int *buf_size) { // convert into protobuf PBMeta pb; From 2d145c920d1eb1ea470ed8b7c61ab2f6f08b681a Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Wed, 11 Dec 2019 22:04:05 +0800 Subject: [PATCH 03/79] wip: refactoring --- src/rdma_van.h | 189 ++++++++++++++++++++++++++----------------------- 1 file changed, 101 insertions(+), 88 deletions(-) diff --git a/src/rdma_van.h b/src/rdma_van.h index f6c5aac8..76818e9d 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -490,9 +490,9 @@ struct AsyncCopy { }; class Transport { + public: Transport(); ~Transport(); - void SendPushResponse(Message &msg) { } @@ -501,17 +501,24 @@ class Transport { } - void Recv(Message &msg); + void SendControlMessage(Message &msg) { + + } + + void Recv(Message *msg); void SendPushRequest(); void SendPullResponse(); }; class RDMATransport : public Transport { + public: RDMATransport() { } + ~RDMATransport(); + void SendPushRequest(Message &msg) override { } @@ -520,16 +527,44 @@ class RDMATransport : public Transport { } - void Recv(Message &msg) override { + // get remote address using SEND/RECV + void GetRemoteAddr() { + + } + + // register RDMA memory + void RegisterMemory(Message &msg) { + for (auto& sa : msg.data) { + if (!sa.size()) continue; + CHECK(sa.data()); + std::lock_guard lock(map_mu_); + if (memory_mr_map_.find(sa.data()) == memory_mr_map_.end()) { + struct ibv_mr *temp_mr; + CHECK (temp_mr = ibv_reg_mr(pd_, sa.data(), sa.size(), + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) + << "Failed to register the memory region: " << strerror(errno) + << ", sa.size()=" << sa.size(); + memory_mr_map_[sa.data()] = temp_mr; + } + } + } + + void Recv(Message *msg) override { } + + private: + }; class IPCTransport : public Transport { + public: IPCTransport() { } + ~IPCTransport(); + void SendPushRequest(Message &msg) override { } @@ -538,7 +573,7 @@ class IPCTransport : public Transport { } - void Recv(Message &msg) override { + void Recv(Message *msg) override { } }; @@ -620,11 +655,7 @@ class RDMAVan : public Van { PS_VLOG(1) << "Clearing mempool."; mempool_.reset(); - auto map_iter = memory_mr_map.begin(); - while (map_iter != memory_mr_map.end()) { - ibv_dereg_mr(map_iter->second); - map_iter++; - } + for (auto& it : memory_mr_map_) ibv_dereg_mr(it.second); PS_VLOG(1) << "Clearing endpoints."; incoming_.clear(); @@ -860,7 +891,7 @@ class RDMAVan : public Van { if (m.shutdown) break; if (m.len == 0) continue; - // TODO: use parallel copy + // TODO: use parallel copy CHECK(m.dst); CHECK(m.src); memcpy(m.dst, m.src, m.len); @@ -885,82 +916,22 @@ class RDMAVan : public Van { int remote_id = msg.meta.recver; CHECK_NE(remote_id, Meta::kEmpty); - // register RDMA memory - for (auto& sa : msg.data) { - if (sa.size()) { - std::lock_guard lock(map_mu_); - auto search_map_iterator = memory_mr_map.find(sa.data()); - if (search_map_iterator == memory_mr_map.end()) { - struct ibv_mr *temp_mr; - CHECK(sa.data()) << "address empty"; - CHECK (temp_mr = ibv_reg_mr(pd_, sa.data(), sa.size(), - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) - << "Failed to register the memory region: " - << strerror(errno) - << ", sa.size()=" << sa.size(); - memory_mr_map[sa.data()] = temp_mr; - } - } - } - - // init for inplace push_pull - if (IsValidPushpull(msg)) { - if (!is_server_) { // worker - std::lock_guard lock(map_mu_); - uint64_t key = DecodeKey(msg.data[0]); - msg.meta.key = key; - - if (msg.meta.push && msg.meta.request) { // push request - CHECK_EQ(msg.data.size(), 3) << msg.data.size(); - CHECK_NE(memory_mr_map.find(msg.data[1].data()), memory_mr_map.end()); - - auto& vals = msg.data[1]; - msg.meta.addr = reinterpret_cast(vals.data()); // vals address - msg.meta.val_len = vals.size(); - msg.meta.option = memory_mr_map[vals.data()]->rkey; - } - } - if (!msg.meta.push && !msg.meta.request) { // server, pull response - CHECK(is_server_); - CHECK_EQ(msg.data.size(), 3) << msg.data.size(); - - std::lock_guard lock(map_mu_); - uint64_t key = msg.meta.key; - auto recver = msg.meta.recver; - - CHECK_NE(key_meta_map_.find(key), key_meta_map_.end()) - << "key=" << key << " not inited in key_meta_map"; - CHECK_NE(key_meta_map_[key].find(recver), key_meta_map_[key].end()) - << "key=" << key << ", recver=" << recver << " not inited in key_meta_map[key]"; - - msg.meta.val_len = std::get<0>(key_meta_map_[key][recver]); - msg.meta.addr = std::get<1>(key_meta_map_[key][recver]); - msg.meta.option = std::get<2>(key_meta_map_[key][recver]); - } - } - PBMeta meta; PackMetaPB(msg.meta, &meta); CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); Endpoint *endpoint = endpoints_[remote_id].get(); MessageBuffer *msg_buf = new MessageBuffer(); - + size_t meta_len = meta.ByteSize(); size_t data_len = msg.meta.data_size; size_t total_len = meta_len + data_len; CHECK(meta_len); // prepare memory - if (msg.meta.simple_app || !msg.meta.control.empty()){ // simple_app or control message + if (!IsValidPushpull(msg)) { // simple_app or control message msg_buf->inline_len = total_len; msg_buf->inline_buf = mempool_->Alloc(total_len); meta.SerializeToArray(msg_buf->inline_buf, meta_len); - char *cur = msg_buf->inline_buf + meta_len; - for (auto &sa : msg.data) { - size_t seg_len = sa.size(); - memcpy(cur, sa.data(), seg_len); - cur += seg_len; - } } else { // data message msg_buf->inline_len = meta_len; msg_buf->inline_buf = mempool_->Alloc(meta_len); @@ -968,20 +939,58 @@ class RDMAVan : public Van { meta.SerializeToArray(msg_buf->inline_buf, meta_len); if (!is_server_ && !is_local_[remote_id]) { // worker, send to non-local servers for (auto &sa : msg_buf->data) { - if (sa.size()) { - auto search_map_iterator = memory_mr_map.find(sa.data()); - CHECK_NE(search_map_iterator, memory_mr_map.end()) << "not registered memory region"; - MRPtr ptr(search_map_iterator->second, [](struct ibv_mr *mr) {}); - CHECK(ptr.get()) << strerror(errno); - msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size())); - } + if (!sa.size()) continue; + auto search_map_iterator = memory_mr_map_.find(sa.data()); + CHECK_NE(search_map_iterator, memory_mr_map_.end()) << "not registered memory region"; + MRPtr ptr(search_map_iterator->second, [](struct ibv_mr *mr) {}); + CHECK(ptr.get()) << strerror(errno); + msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size())); } } } - // server send pull response (vals) with RDMA-write / IPC - if (is_server_ && IsValidPushpull(msg) && - !msg.meta.push && !msg.meta.request) { + auto trans = reinterpret_cast(is_local_[remote_id] ? ipc_trans_ : rdma_trans_); + if (!IsValidPushpull(msg)) { // control message + msg_buf->inline_len = total_len; + msg_buf->inline_buf = mempool_->Alloc(total_len); + meta.SerializeToArray(msg_buf->inline_buf, meta_len); + trans->SendControlMessage(); + } else if (msg.meta.push && msg.meta.request) { // worker, push request + std::lock_guard lock(map_mu_); + uint64_t key = DecodeKey(msg.data[0]); + msg.meta.key = key; + + CHECK_EQ(msg.data.size(), 3) << msg.data.size(); + CHECK_NE(memory_mr_map_.find(msg.data[1].data()), memory_mr_map_.end()); + + auto& vals = msg.data[1]; + msg.meta.addr = reinterpret_cast(vals.data()); // vals address + msg.meta.val_len = vals.size(); + msg.meta.option = memory_mr_map_[vals.data()]->rkey; + + trans->SendPushRequest(); + } else if (msg.meta.push && !msg.meta.request) { // server, push response + trans->SendPushResponse(); + } else if (!msg.meta.push && msg.meta.request) { // worker, pull request + trans->SendPullRequest(); + } else if (!msg.meta.push && !msg.meta.request) { // server, pull response + CHECK(is_server_); + CHECK_EQ(msg.data.size(), 3) << msg.data.size(); + + std::lock_guard lock(map_mu_); + uint64_t key = msg.meta.key; + auto recver = msg.meta.recver; + + CHECK_NE(key_meta_map_.find(key), key_meta_map_.end()) + << "key=" << key << " not inited in key_meta_map"; + CHECK_NE(key_meta_map_[key].find(recver), key_meta_map_[key].end()) + << "key=" << key << ", recver=" << recver << " not inited in key_meta_map[key]"; + + msg.meta.val_len = std::get<0>(key_meta_map_[key][recver]); + msg.meta.addr = std::get<1>(key_meta_map_[key][recver]); + msg.meta.option = std::get<2>(key_meta_map_[key][recver]); + + // RDMA write or IPC std::lock_guard lock(map_mu_); auto key = msg.meta.key; auto recver = msg.meta.recver; @@ -1002,7 +1011,6 @@ class RDMAVan : public Van { auto addr = (void*) msg_buf->data[1].data(); CHECK(addr); void* shm_addr = GetSharedMemory(kShmPrefix, key); - // async copy AsyncCopy m = {endpoint, msg_buf, shm_addr, addr, len, meta_len, false}; auto cnt = cpy_counter_.fetch_add(1); @@ -1013,8 +1021,8 @@ class RDMAVan : public Van { auto raddr = std::get<1>(key_meta_map_[key][recver]); auto rkey = std::get<2>(key_meta_map_[key][recver]); - auto temp_mr = memory_mr_map.find(msg_buf->data[1].data()); - CHECK_NE(temp_mr, memory_mr_map.end()); + auto temp_mr = memory_mr_map_.find(msg_buf->data[1].data()); + CHECK_NE(temp_mr, memory_mr_map_.end()); struct ibv_sge sge; sge.addr = reinterpret_cast(msg_buf->data[1].data()); @@ -1037,6 +1045,10 @@ class RDMAVan : public Van { CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) << "ibv_post_send failed."; } + + trans->SendPullResponse(); + } else { + CHECK(0) << "unexpected message type"; } WRContext *context = nullptr, *reserved = nullptr; @@ -1055,8 +1067,6 @@ class RDMAVan : public Van { if (!is_server_ && is_local_[remote_id] && IsValidPushpull(msg)) { // local IPC with shared memory req->data_num = 0; - auto key = DecodeKey(msg.data[0]); - CHECK_EQ(key, msg.meta.key); } else { // normal RDMA req->data_num = msg.data.size(); @@ -1577,6 +1587,9 @@ class RDMAVan : public Van { AddressPool addr_pool_; std::unique_ptr mempool_; + std::unique_ptr rdma_trans_; + std::unique_ptr ipc_trans_; + struct rdma_cm_id *listener_ = nullptr; std::atomic should_stop_; @@ -1586,7 +1599,7 @@ class RDMAVan : public Van { struct rdma_event_channel *event_channel_ = nullptr; struct ibv_context *context_ = nullptr; - std::unordered_map memory_mr_map; + std::unordered_map memory_mr_map_; // ibverbs protection domain struct ibv_pd *pd_ = nullptr; From a60d46d94e75086fd2a420bc04a518dddd6a0c4f Mon Sep 17 00:00:00 2001 From: Yibo Zhu Date: Wed, 11 Dec 2019 22:21:18 +0800 Subject: [PATCH 04/79] Fix makefile and compilation error --- Makefile | 4 ++-- src/rdma_van.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 3dc9b566..e83f671d 100644 --- a/Makefile +++ b/Makefile @@ -20,10 +20,10 @@ endif INCPATH = -I./src -I./include -I$(DEPS_PATH)/include CFLAGS = -std=c++14 -msse2 -fPIC -O3 -ggdb -Wall -finline-functions $(INCPATH) $(ADD_CFLAGS) -LIBS = -pthread +LIBS = -pthread -lrt ifeq ($(USE_RDMA), 1) -LIBS += -lrdmacm -libverbs -lrt +LIBS += -lrdmacm -libverbs CFLAGS += -DDMLC_USE_RDMA endif diff --git a/src/rdma_van.h b/src/rdma_van.h index 1c2ba75a..9e4d9d36 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -906,7 +906,7 @@ class RDMAVan : public Van { int meta_len; char* meta_buf = nullptr; - PackMeta(msg.meta, &meta_buf, &meta_size); + PackMeta(msg.meta, &meta_buf, &meta_len); CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); Endpoint *endpoint = endpoints_[remote_id].get(); From c06df604d7cd1258538b1b4f84aa1ff48cb524f0 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Thu, 12 Dec 2019 12:30:09 +0800 Subject: [PATCH 05/79] wip: auto transport type --- src/rdma_van.h | 128 ++++------------------------------------------- src/transport.h | 130 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 117 deletions(-) create mode 100644 src/transport.h diff --git a/src/rdma_van.h b/src/rdma_van.h index 76818e9d..329fdfdc 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -9,33 +9,7 @@ #ifdef DMLC_USE_RDMA -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ps/internal/threadsafe_queue.h" -#include "ps/internal/van.h" +#include "transport.h" namespace ps { @@ -350,6 +324,7 @@ struct Endpoint { std::condition_variable cv; std::mutex connect_mu; struct rdma_cm_id *cm_id; + std::unique_ptr tran; WRContext rx_ctx[kRxDepth]; @@ -397,10 +372,15 @@ struct Endpoint { CHECK_EQ(rdma_destroy_id(cm_id), 0) << strerror(errno); } + void SetTransport(std::unique_ptr t) { tran = t; } + + std::unique_ptr GetTransport() { return tran; } + void Disconnect() { std::unique_lock lk(connect_mu); CHECK_EQ(rdma_disconnect(cm_id), 0) << strerror(errno); cv.wait(lk, [this] { return status == IDLE; }); + tran.reset(); } void SetNodeID(int id) { node_id = id; } @@ -489,95 +469,6 @@ struct AsyncCopy { bool shutdown; }; -class Transport { - public: - Transport(); - ~Transport(); - void SendPushResponse(Message &msg) { - - } - - void SendPullRequest(Message &msg) { - - } - - void SendControlMessage(Message &msg) { - - } - - void Recv(Message *msg); - void SendPushRequest(); - void SendPullResponse(); -}; - - -class RDMATransport : public Transport { - public: - RDMATransport() { - - } - - ~RDMATransport(); - - void SendPushRequest(Message &msg) override { - - } - - void SendPullResponse(Message &msg) override { - - } - - // get remote address using SEND/RECV - void GetRemoteAddr() { - - } - - // register RDMA memory - void RegisterMemory(Message &msg) { - for (auto& sa : msg.data) { - if (!sa.size()) continue; - CHECK(sa.data()); - std::lock_guard lock(map_mu_); - if (memory_mr_map_.find(sa.data()) == memory_mr_map_.end()) { - struct ibv_mr *temp_mr; - CHECK (temp_mr = ibv_reg_mr(pd_, sa.data(), sa.size(), - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) - << "Failed to register the memory region: " << strerror(errno) - << ", sa.size()=" << sa.size(); - memory_mr_map_[sa.data()] = temp_mr; - } - } - } - - void Recv(Message *msg) override { - - } - - private: - -}; - -class IPCTransport : public Transport { - public: - IPCTransport() { - - } - - ~IPCTransport(); - - void SendPushRequest(Message &msg) override { - - } - - void SendPullResponse(Message &msg) override { - - } - - void Recv(Message *msg) override { - - } -}; - class RDMAVan : public Van { public: RDMAVan() { @@ -748,6 +639,9 @@ class RDMAVan : public Van { endpoint->SetNodeID(node.id); + Transport* tran = is_local_[node.id] ? std::make_unique() : std::make_unique(); + endpoint->SetTransport(tran); + struct addrinfo *remote_addr; CHECK_EQ( getaddrinfo(node.hostname.c_str(), std::to_string(node.port).c_str(), @@ -949,7 +843,7 @@ class RDMAVan : public Van { } } - auto trans = reinterpret_cast(is_local_[remote_id] ? ipc_trans_ : rdma_trans_); + auto trans = endpoint.GetTransport(); if (!IsValidPushpull(msg)) { // control message msg_buf->inline_len = total_len; msg_buf->inline_buf = mempool_->Alloc(total_len); diff --git a/src/transport.h b/src/transport.h new file mode 100644 index 00000000..ca81faef --- /dev/null +++ b/src/transport.h @@ -0,0 +1,130 @@ +#ifndef PS_RDMA_VAN_H_ +#define PS_RDMA_VAN_H_ + +#ifdef DMLC_USE_RDMA + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ps/internal/threadsafe_queue.h" +#include "ps/internal/van.h" + +namespace ps { + +class Transport { + public: + Transport(); + ~Transport(); + void SendPushResponse(Message &msg) { + + } + + void SendPullRequest(Message &msg) { + + } + + void SendControlMessage(Message &msg) { + + } + + void Recv(Message *msg); + void SendPushRequest(); + void SendPullResponse(); +}; // class Transport + + +class RDMATransport : public Transport { + public: + RDMATransport() { + + } + + ~RDMATransport(); + + void SendPushRequest(Message &msg) override { + + } + + void SendPullResponse(Message &msg) override { + + } + + // get remote address using SEND/RECV + void GetRemoteAddr() { + + } + + // register RDMA memory + void RegisterMemory(Message &msg) { + for (auto& sa : msg.data) { + if (!sa.size()) continue; + CHECK(sa.data()); + std::lock_guard lock(map_mu_); + if (memory_mr_map_.find(sa.data()) == memory_mr_map_.end()) { + struct ibv_mr *temp_mr; + CHECK (temp_mr = ibv_reg_mr(pd_, sa.data(), sa.size(), + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) + << "Failed to register the memory region: " << strerror(errno) + << ", sa.size()=" << sa.size(); + memory_mr_map_[sa.data()] = temp_mr; + } + } + } + + void Recv(Message *msg) override { + + } + + private: + +}; // class RDMATransport + +class IPCTransport : public Transport { + public: + IPCTransport() { + + } + + ~IPCTransport(); + + void SendPushRequest(Message &msg) override { + + } + + void SendPullResponse(Message &msg) override { + + } + + void Recv(Message *msg) override { + + } +}; // class IPCTransport + + + +} // namespace ps + +#endif // DMLC_USE_RDMA +#endif // PS_RDMA_VAN_H_ \ No newline at end of file From ee533df11b88f7fab21113e6ecc54dd8cced33e5 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Thu, 12 Dec 2019 17:34:39 +0800 Subject: [PATCH 06/79] wip: simplify SendMsg --- src/rdma_van.h | 183 +++++++++--------------------------------------- src/transport.h | 126 +++++++++++++++++++++++++++++---- 2 files changed, 147 insertions(+), 162 deletions(-) diff --git a/src/rdma_van.h b/src/rdma_van.h index 329fdfdc..e3780acb 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -639,8 +639,8 @@ class RDMAVan : public Van { endpoint->SetNodeID(node.id); - Transport* tran = is_local_[node.id] ? std::make_unique() : std::make_unique(); - endpoint->SetTransport(tran); + Transport* t = is_local_[node.id] ? std::make_unique() : std::make_unique(); + endpoint->SetTransport(t); struct addrinfo *remote_addr; CHECK_EQ( @@ -809,166 +809,53 @@ class RDMAVan : public Van { int SendMsg(Message &msg) override { int remote_id = msg.meta.recver; CHECK_NE(remote_id, Meta::kEmpty); + CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); + Endpoint *endpoint = endpoints_[remote_id].get(); PBMeta meta; PackMetaPB(msg.meta, &meta); - CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); - Endpoint *endpoint = endpoints_[remote_id].get(); MessageBuffer *msg_buf = new MessageBuffer(); size_t meta_len = meta.ByteSize(); - size_t data_len = msg.meta.data_size; - size_t total_len = meta_len + data_len; - CHECK(meta_len); - - // prepare memory - if (!IsValidPushpull(msg)) { // simple_app or control message - msg_buf->inline_len = total_len; - msg_buf->inline_buf = mempool_->Alloc(total_len); - meta.SerializeToArray(msg_buf->inline_buf, meta_len); - } else { // data message - msg_buf->inline_len = meta_len; - msg_buf->inline_buf = mempool_->Alloc(meta_len); - msg_buf->data = msg.data; - meta.SerializeToArray(msg_buf->inline_buf, meta_len); - if (!is_server_ && !is_local_[remote_id]) { // worker, send to non-local servers - for (auto &sa : msg_buf->data) { - if (!sa.size()) continue; - auto search_map_iterator = memory_mr_map_.find(sa.data()); - CHECK_NE(search_map_iterator, memory_mr_map_.end()) << "not registered memory region"; - MRPtr ptr(search_map_iterator->second, [](struct ibv_mr *mr) {}); - CHECK(ptr.get()) << strerror(errno); - msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size())); - } + size_t total_len = meta_len + msg.meta.data_size; + + msg_buf->inline_len = meta_len; + msg_buf->inline_buf = mempool_->Alloc(meta_len); + meta.SerializeToArray(msg_buf->inline_buf, meta_len); + msg_buf->data = msg.data; + + // prepare memory + if (!is_server_ && !is_local_[remote_id]) { + for (auto &sa : msg_buf->data) { + if (!sa.size()) continue; + auto it = memory_mr_map_.find(sa.data()); + CHECK_NE(it, memory_mr_map_.end()) << "not registered memory region"; + MRPtr ptr(it->second, [](struct ibv_mr *mr) {}); + CHECK(ptr.get()) << strerror(errno); + msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size())); } } auto trans = endpoint.GetTransport(); - if (!IsValidPushpull(msg)) { // control message - msg_buf->inline_len = total_len; - msg_buf->inline_buf = mempool_->Alloc(total_len); - meta.SerializeToArray(msg_buf->inline_buf, meta_len); - trans->SendControlMessage(); - } else if (msg.meta.push && msg.meta.request) { // worker, push request - std::lock_guard lock(map_mu_); - uint64_t key = DecodeKey(msg.data[0]); - msg.meta.key = key; - - CHECK_EQ(msg.data.size(), 3) << msg.data.size(); - CHECK_NE(memory_mr_map_.find(msg.data[1].data()), memory_mr_map_.end()); - - auto& vals = msg.data[1]; - msg.meta.addr = reinterpret_cast(vals.data()); // vals address - msg.meta.val_len = vals.size(); - msg.meta.option = memory_mr_map_[vals.data()]->rkey; - - trans->SendPushRequest(); - } else if (msg.meta.push && !msg.meta.request) { // server, push response - trans->SendPushResponse(); - } else if (!msg.meta.push && msg.meta.request) { // worker, pull request - trans->SendPullRequest(); - } else if (!msg.meta.push && !msg.meta.request) { // server, pull response - CHECK(is_server_); - CHECK_EQ(msg.data.size(), 3) << msg.data.size(); - - std::lock_guard lock(map_mu_); - uint64_t key = msg.meta.key; - auto recver = msg.meta.recver; - - CHECK_NE(key_meta_map_.find(key), key_meta_map_.end()) - << "key=" << key << " not inited in key_meta_map"; - CHECK_NE(key_meta_map_[key].find(recver), key_meta_map_[key].end()) - << "key=" << key << ", recver=" << recver << " not inited in key_meta_map[key]"; - - msg.meta.val_len = std::get<0>(key_meta_map_[key][recver]); - msg.meta.addr = std::get<1>(key_meta_map_[key][recver]); - msg.meta.option = std::get<2>(key_meta_map_[key][recver]); - - // RDMA write or IPC - std::lock_guard lock(map_mu_); - auto key = msg.meta.key; - auto recver = msg.meta.recver; - auto len = std::get<0>(key_meta_map_[key][recver]); - - CHECK_EQ(msg_buf->data.size(), 3) << "Actual msg_buf size is " << msg_buf->data.size(); - CHECK_NE(key_meta_map_.find(key), key_meta_map_.end()) - << "key=" << key << " not initiated"; - CHECK_NE(key_meta_map_[key].find(recver), key_meta_map_[key].end()) - << "key=" << key - << ", recver=" << recver - << " not initiated"; - CHECK_EQ(msg_buf->data[1].size(), (unsigned int) len) - << msg_buf->data[1].size() << ", " << len; - - if (is_local_[remote_id]) { - // IPC - auto addr = (void*) msg_buf->data[1].data(); - CHECK(addr); - void* shm_addr = GetSharedMemory(kShmPrefix, key); - // async copy - AsyncCopy m = {endpoint, msg_buf, shm_addr, addr, len, meta_len, false}; - auto cnt = cpy_counter_.fetch_add(1); - async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); - return total_len; - } else { - // RDMA write - auto raddr = std::get<1>(key_meta_map_[key][recver]); - auto rkey = std::get<2>(key_meta_map_[key][recver]); - - auto temp_mr = memory_mr_map_.find(msg_buf->data[1].data()); - CHECK_NE(temp_mr, memory_mr_map_.end()); - - struct ibv_sge sge; - sge.addr = reinterpret_cast(msg_buf->data[1].data()); - sge.length = msg_buf->data[1].size(); - sge.lkey = temp_mr->second->lkey; - - struct ibv_send_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - - wr.wr_id = reinterpret_cast(raddr); - wr.opcode = IBV_WR_RDMA_WRITE; - wr.next = nullptr; - // wr.send_flags = IBV_SEND_SIGNALED; - wr.sg_list = &sge; - wr.num_sge = 1; - - wr.wr.rdma.remote_addr = raddr; - wr.wr.rdma.rkey = rkey; - - CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) - << "ibv_post_send failed."; - } - - trans->SendPullResponse(); + if (!IsValidPushpull(msg)) { + // control message + trans->SendControlMessage(endpoint, msg_buf); + } else if (msg.meta.push && msg.meta.request) { + // worker, push request + trans->SendPushRequest(endpoint, msg_buf); + } else if (msg.meta.push && !msg.meta.request) { + // server, push response + trans->SendPushResponse(endpoint, msg_buf); + } else if (!msg.meta.push && msg.meta.request) { + // worker, pull request + trans->SendPullRequest(endpoint, msg_buf); + } else if (!msg.meta.push && !msg.meta.request) { + // server, pull response + trans->SendPullResponse(endpoint, msg_buf); } else { CHECK(0) << "unexpected message type"; } - WRContext *context = nullptr, *reserved = nullptr; - endpoint->free_write_ctx.WaitAndPop(&reserved); - endpoint->free_start_ctx.WaitAndPop(&context); - - msg_buf->reserved_context = reserved; - RendezvousStart *req = - reinterpret_cast(context->buffer->addr); - req->meta_len = meta_len; - req->origin_addr = reinterpret_cast(msg_buf); - - auto addr = reinterpret_cast(req); - - // rendezvous message, not data message - if (!is_server_ && is_local_[remote_id] && IsValidPushpull(msg)) { - // local IPC with shared memory - req->data_num = 0; - } else { - // normal RDMA - req->data_num = msg.data.size(); - for (size_t i = 0; i < req->data_num; ++i) { - req->data_len[i] = msg.data[i].size(); - } - } - SendRendezvousBegin(endpoint, addr, context, kRendezvousStart); return total_len; } diff --git a/src/transport.h b/src/transport.h index ca81faef..77f14849 100644 --- a/src/transport.h +++ b/src/transport.h @@ -35,8 +35,12 @@ namespace ps { class Transport { public: - Transport(); + explicit Transport(Endpoint *endpoint) { + endpoint_ = endpoint; + }; + ~Transport(); + void SendPushResponse(Message &msg) { } @@ -46,29 +50,110 @@ class Transport { } void SendControlMessage(Message &msg) { + PBMeta meta; + PackMetaPB(msg.meta, &meta); + MessageBuffer *msg_buf = new MessageBuffer(); + + size_t meta_len = meta.ByteSize(); + size_t total_len = meta_len + msg.meta.data_size; + CHECK(meta_len); + + msg_buf->inline_len = total_len; + msg_buf->inline_buf = mempool_->Alloc(total_len); + meta.SerializeToArray(msg_buf->inline_buf, meta_len); + WRContext *context = nullptr, *reserved = nullptr; + endpoint_->free_write_ctx.WaitAndPop(&reserved); + endpoint_->free_start_ctx.WaitAndPop(&context); + + msg_buf->reserved_context = reserved; + RendezvousStart *req = + reinterpret_cast(context->buffer->addr); + req->meta_len = meta_len; + req->origin_addr = reinterpret_cast(msg_buf); + + auto addr = reinterpret_cast(req); + + // rendezvous message, not data message + if (!is_server_ && is_local_[remote_id] && IsValidPushpull(msg)) { + // local IPC with shared memory + req->data_num = 0; + } else { + // normal RDMA + req->data_num = msg.data.size(); + for (size_t i = 0; i < req->data_num; ++i) { + req->data_len[i] = msg.data[i].size(); + } + } + SendRendezvousBegin(endpoint_, addr, context, kRendezvousStart); } void Recv(Message *msg); void SendPushRequest(); void SendPullResponse(); + + Endpoint* endpoint_; }; // class Transport + class RDMATransport : public Transport { public: - RDMATransport() { - - } - - ~RDMATransport(); - void SendPushRequest(Message &msg) override { + std::lock_guard lock(map_mu_); + uint64_t key = DecodeKey(msg.data[0]); + msg.meta.key = key; + + CHECK_EQ(msg.data.size(), 3) << msg.data.size(); + CHECK_NE(memory_mr_map_.find(msg.data[1].data()), memory_mr_map_.end()); + auto& vals = msg.data[1]; + msg.meta.addr = reinterpret_cast(vals.data()); // vals address + msg.meta.val_len = vals.size(); + msg.meta.option = memory_mr_map_[vals.data()]->rkey; } void SendPullResponse(Message &msg) override { - + std::lock_guard lock(map_mu_); + uint64_t key = msg.meta.key; + auto recver = msg.meta.recver; + + CHECK_NE(key_meta_map_.find(key), key_meta_map_.end()) + << "key=" << key << " not inited in key_meta_map"; + CHECK_NE(key_meta_map_[key].find(recver), key_meta_map_[key].end()) + << "key=" << key << ", recver=" << recver << " not inited in key_meta_map[key]"; + + msg.meta.val_len = std::get<0>(key_meta_map_[key][recver]); + msg.meta.addr = std::get<1>(key_meta_map_[key][recver]); + msg.meta.option = std::get<2>(key_meta_map_[key][recver]); + + // RDMA write + auto raddr = std::get<1>(key_meta_map_[key][recver]); + auto rkey = std::get<2>(key_meta_map_[key][recver]); + + auto temp_mr = memory_mr_map_.find(msg_buf->data[1].data()); + CHECK_NE(temp_mr, memory_mr_map_.end()); + + struct ibv_sge sge; + sge.addr = reinterpret_cast(msg_buf->data[1].data()); + sge.length = msg_buf->data[1].size(); + sge.lkey = temp_mr->second->lkey; + + struct ibv_send_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + + wr.wr_id = reinterpret_cast(raddr); + wr.opcode = IBV_WR_RDMA_WRITE; + wr.next = nullptr; + // wr.send_flags = IBV_SEND_SIGNALED; + wr.sg_list = &sge; + wr.num_sge = 1; + + wr.wr.rdma.remote_addr = raddr; + wr.wr.rdma.rkey = rkey; + + CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) + << "ibv_post_send failed."; } // get remote address using SEND/RECV @@ -101,20 +186,33 @@ class RDMATransport : public Transport { }; // class RDMATransport -class IPCTransport : public Transport { - public: - IPCTransport() { - } - ~IPCTransport(); + + + + +class IPCTransport : public Transport { + public: void SendPushRequest(Message &msg) override { } void SendPullResponse(Message &msg) override { - + std::lock_guard lock(map_mu_); + auto key = msg.meta.key; + auto recver = msg.meta.recver; + auto len = std::get<0>(key_meta_map_[key][recver]); + + // IPC + auto addr = (void*) msg_buf->data[1].data(); + CHECK(addr); + void* shm_addr = GetSharedMemory(kShmPrefix, key); + // async copy + AsyncCopy m = {endpoint, msg_buf, shm_addr, addr, len, meta_len, false}; + auto cnt = cpy_counter_.fetch_add(1); + async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); } void Recv(Message *msg) override { From f300a2ed3759352747fbb48496d8e0a87a0f9205 Mon Sep 17 00:00:00 2001 From: Yibo Zhu Date: Thu, 12 Dec 2019 18:19:15 +0800 Subject: [PATCH 07/79] Use raw serializer instead of pb --- Makefile | 8 +- include/ps/internal/van.h | 1 - make/deps.mk | 11 --- make/ps.mk | 4 +- src/meta.h | 76 ++++++++++++++++++ src/meta.proto | 64 --------------- src/van.cc | 163 +++++++++++++++++++++----------------- 7 files changed, 170 insertions(+), 157 deletions(-) create mode 100644 src/meta.h delete mode 100644 src/meta.proto diff --git a/Makefile b/Makefile index e83f671d..9853008b 100644 --- a/Makefile +++ b/Makefile @@ -38,25 +38,21 @@ include make/deps.mk clean: rm -rf build $(TEST) tests/*.d tests/*.dSYM - find src -name "*.pb.[ch]*" -delete lint: python tests/lint.py ps all include/ps src ps: build/libps.a -OBJS = $(addprefix build/, customer.o postoffice.o van.o meta.pb.o) +OBJS = $(addprefix build/, customer.o postoffice.o van.o) build/libps.a: $(OBJS) ar crv $@ $(filter %.o, $?) -build/%.o: src/%.cc ${ZMQ} src/meta.pb.h +build/%.o: src/%.cc ${ZMQ} @mkdir -p $(@D) $(CXX) $(INCPATH) -std=c++0x -MM -MT build/$*.o $< >build/$*.d $(CXX) $(CFLAGS) $(LIBS) -c $< -o $@ -src/%.pb.cc src/%.pb.h : src/%.proto ${PROTOBUF} - $(PROTOC) --cpp_out=./src --proto_path=./src $< - -include build/*.d -include build/*/*.d diff --git a/include/ps/internal/van.h b/include/ps/internal/van.h index 45455845..67f55bc9 100644 --- a/include/ps/internal/van.h +++ b/include/ps/internal/van.h @@ -16,7 +16,6 @@ #include "ps/internal/message.h" namespace ps { class Resender; -class PBMeta; /** * \brief Van sends messages to remote nodes * diff --git a/make/deps.mk b/make/deps.mk index 8463ecca..a8ec57be 100644 --- a/make/deps.mk +++ b/make/deps.mk @@ -1,21 +1,10 @@ # Install dependencies URL1=https://raw.githubusercontent.com/mli/deps/master/build -URL2=https://github.com/google/protobuf/releases/download/v3.5.1 ifndef WGET WGET = wget endif -# protobuf -PROTOBUF = ${DEPS_PATH}/include/google/protobuf/message.h -${PROTOBUF}: - $(eval FILE=protobuf-cpp-3.5.1.tar.gz) - $(eval DIR=protobuf-3.5.1) - rm -rf $(FILE) $(DIR) - $(WGET) $(URL2)/$(FILE) && tar --no-same-owner -zxf $(FILE) - cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure -prefix=$(DEPS_PATH) && $(MAKE) && $(MAKE) install - rm -rf $(FILE) $(DIR) - # zmq ZMQ = ${DEPS_PATH}/include/zmq.h diff --git a/make/ps.mk b/make/ps.mk index 0b0f678d..866caf32 100644 --- a/make/ps.mk +++ b/make/ps.mk @@ -9,5 +9,5 @@ ifeq ($(USE_KEY32), 1) ADD_CFLAGS += -DUSE_KEY32=1 endif -PS_LDFLAGS_SO = -L$(DEPS_PATH)/lib -lprotobuf-lite -lzmq -PS_LDFLAGS_A = $(addprefix $(DEPS_PATH)/lib/, libprotobuf-lite.a libzmq.a) +PS_LDFLAGS_SO = -L$(DEPS_PATH)/lib -lzmq +PS_LDFLAGS_A = $(addprefix $(DEPS_PATH)/lib/, libzmq.a) diff --git a/src/meta.h b/src/meta.h new file mode 100644 index 00000000..98dff57b --- /dev/null +++ b/src/meta.h @@ -0,0 +1,76 @@ +/** + * Copyright (c) 2018-2019 Bytedance Inc. + * Author: zhuyibo@bytedance.com (Yibo Zhu) +*/ +#ifndef PS_LITE_META_H_ +#define PS_LITE_META_H_ + +#include + +namespace ps { + +struct RawNode { + // the node role + int role; + // node id + int id; + // hostname or ip + char hostname[64]; + // the port this node is binding + int port; + // whether this node is created by failover + bool is_recovery; + // the locally unique id of an customer + int customer_id; +}; + +// system control info +struct RawControl { + int cmd; + int node_size; + int barrier_group; + uint64_t msg_sig; +}; + +// mete information about a message +struct RawMeta { + // message.head + int head; + // message.body + int body_size; + // if set, then it is system control task. otherwise, it is for app + RawControl control; + // true: a request task + // false: the response task to the request task with the same *time* + bool request; + // the unique id of an application + int app_id; + // the timestamp of this message + int timestamp; + // data type of message.data[i] + int data_type_size; + // the locally unique id of an customer + int customer_id; + // whether or not a push message + bool push; + // whether or not it's for SimpleApp + bool simple_app; + // message.data_size + int data_size; + // message.key + uint64_t key; + // message.addr + uint64_t addr; + // the length of the message's value + int val_len; + // the option field + int option; + + // body + // data_type + // node +}; + +} // namespace + +#endif diff --git a/src/meta.proto b/src/meta.proto deleted file mode 100644 index 9a146b8a..00000000 --- a/src/meta.proto +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright (c) 2015 by Contributors - */ -syntax = "proto2"; -package ps; -option optimize_for = LITE_RUNTIME; - -message PBNode { - // the node role - required int32 role = 1; - // node id - optional int32 id = 2; - // hostname or ip - optional string hostname = 3; - // the port this node is binding - optional int32 port = 4; - // whether this node is created by failover - optional bool is_recovery = 5; - // the locally unique id of an customer - optional int32 customer_id = 10; -} - -// system control info -message PBControl { - required int32 cmd = 1; - repeated PBNode node = 2; - optional int32 barrier_group = 3; - optional uint64 msg_sig = 4; -} - -// mete information about a message -message PBMeta { - // message.head - optional int32 head = 1; - // message.body - optional bytes body = 2; - // if set, then it is system control task. otherwise, it is for app - optional PBControl control = 3; - // true: a request task - // false: the response task to the request task with the same *time* - optional bool request = 4 [default = false]; - // the unique id of an application - optional int32 app_id = 7; - // the timestamp of this message - optional int32 timestamp = 8; - // data type of message.data[i] - repeated int32 data_type = 9 [packed=true]; - // the locally unique id of an customer - optional int32 customer_id = 10; - // whether or not a push message - optional bool push = 5; - // whether or not it's for SimpleApp - optional bool simple_app = 6 [default = false]; - // message.data_size - optional int32 data_size = 11; - // message.key - optional uint64 key = 12; - // message.addr - optional uint64 addr = 13 [default = 0]; - // the length of the message's value - optional int32 val_len = 14; - // the option field - optional int32 option = 15; -} diff --git a/src/van.cc b/src/van.cc index 483ac636..f98b02b2 100644 --- a/src/van.cc +++ b/src/van.cc @@ -14,7 +14,7 @@ #include "ps/internal/van.h" #include "ps/sarray.h" -#include "./meta.pb.h" +#include "./meta.h" #include "./network_utils.h" #include "./rdma_van.h" #include "./resender.h" @@ -495,93 +495,110 @@ void Van::Receiving() { } void Van::PackMeta(const Meta &meta, char **meta_buf, int *buf_size) { - // convert into protobuf - PBMeta pb; - pb.set_head(meta.head); - if (meta.app_id != Meta::kEmpty) pb.set_app_id(meta.app_id); - if (meta.timestamp != Meta::kEmpty) pb.set_timestamp(meta.timestamp); - if (meta.body.size()) pb.set_body(meta.body); - pb.set_push(meta.push); - pb.set_request(meta.request); - pb.set_simple_app(meta.simple_app); - pb.set_customer_id(meta.customer_id); - for (auto d : meta.data_type) pb.add_data_type(d); + *buf_size = sizeof(RawMeta) + meta.body.size() + + meta.data_type.size() * sizeof(int) + + meta.control.node.size() * sizeof(RawNode); + // allocate buffer only when needed + if (*meta_buf == nullptr) { + *meta_buf = new char[*buf_size + 1]; + } + + RawMeta *raw = (RawMeta*)*meta_buf; + char *raw_body = *meta_buf + sizeof(RawMeta); + int *raw_data_type = (int*)(raw_body + meta.body.size()); + RawNode *raw_node = (RawNode*)(raw_data_type + meta.data_type.size()); + + // convert into raw buffer + raw->head = meta.head; + raw->app_id = meta.app_id; + raw->timestamp = meta.timestamp; + if (meta.body.size()) { + memcpy(raw_body, meta.body.c_str(), meta.body.size()); + raw->body_size = meta.body.size(); + } + raw->push = meta.push; + raw->request = meta.request; + raw->simple_app = meta.simple_app; + raw->customer_id = meta.customer_id; + int data_type_count = 0; + for (auto d : meta.data_type) { + raw_data_type[data_type_count] = d; + data_type_count++; + } + raw->data_type_size = meta.data_type.size(); + auto ctrl = &(raw->control); if (!meta.control.empty()) { - auto ctrl = pb.mutable_control(); - ctrl->set_cmd(meta.control.cmd); + ctrl->cmd = meta.control.cmd; if (meta.control.cmd == Control::BARRIER) { - ctrl->set_barrier_group(meta.control.barrier_group); + ctrl->barrier_group = meta.control.barrier_group; } else if (meta.control.cmd == Control::ACK) { - ctrl->set_msg_sig(meta.control.msg_sig); + ctrl->msg_sig = meta.control.msg_sig; } + ctrl->node_size = meta.control.node.size(); + int node_count = 0; for (const auto &n : meta.control.node) { - auto p = ctrl->add_node(); - p->set_id(n.id); - p->set_role(n.role); - p->set_port(n.port); - p->set_hostname(n.hostname); - p->set_is_recovery(n.is_recovery); - p->set_customer_id(n.customer_id); + raw_node[node_count].id = n.id; + raw_node[node_count].role = n.role; + raw_node[node_count].port = n.port; + bzero(raw_node[node_count].hostname, sizeof(raw_node[node_count].hostname)); + memcpy(raw_node[node_count].hostname, n.hostname.c_str(), n.hostname.size()); + raw_node[node_count].is_recovery = n.is_recovery; + raw_node[node_count].customer_id = n.customer_id; + node_count++; } } - pb.set_data_size(meta.data_size); - pb.set_key(meta.key); - pb.set_addr(meta.addr); - pb.set_val_len(meta.val_len); - pb.set_option(meta.option); - - // to string - *buf_size = pb.ByteSize(); - // allocate buffer only when needed - if (*meta_buf == nullptr) { - *meta_buf = new char[*buf_size + 1]; + else { + ctrl->cmd = Control::EMPTY; } - CHECK(pb.SerializeToArray(*meta_buf, *buf_size)) << "failed to serialize protbuf"; + raw->data_size = meta.data_size; + raw->key = meta.key; + raw->addr = meta.addr; + raw->val_len = meta.val_len; + raw->option = meta.option; } void Van::UnpackMeta(const char *meta_buf, int buf_size, Meta *meta) { - // to protobuf - PBMeta pb; - CHECK(pb.ParseFromArray(meta_buf, buf_size)) - << "failed to parse string into protobuf, buf_size=" << buf_size; + + RawMeta *raw = (RawMeta*)meta_buf; + const char *raw_body = meta_buf + sizeof(RawMeta); + const int *raw_data_type = (const int*)(raw_body + raw->body_size); + const RawNode *raw_node = (RawNode*)(raw_data_type + raw->data_type_size); // to meta - meta->head = pb.head(); - meta->app_id = pb.has_app_id() ? pb.app_id() : Meta::kEmpty; - meta->timestamp = pb.has_timestamp() ? pb.timestamp() : Meta::kEmpty; - meta->request = pb.request(); - meta->push = pb.push(); - meta->simple_app = pb.simple_app(); - meta->body = pb.body(); - meta->customer_id = pb.customer_id(); - meta->data_type.resize(pb.data_type_size()); - for (int i = 0; i < pb.data_type_size(); ++i) { - meta->data_type[i] = static_cast(pb.data_type(i)); + meta->head = raw->head; + meta->app_id = raw->app_id; + meta->timestamp = raw->timestamp; + meta->request = raw->request; + meta->push = raw->push; + meta->simple_app = raw->simple_app; + meta->body = std::string(raw_body, raw->body_size); + meta->customer_id = raw->customer_id; + meta->data_type.resize(raw->data_type_size); + for (int i = 0; i < raw->data_type_size; ++i) { + meta->data_type[i] = static_cast(raw_data_type[i]); } - if (pb.has_control()) { - const auto &ctrl = pb.control(); - meta->control.cmd = static_cast(ctrl.cmd()); - meta->control.barrier_group = ctrl.barrier_group(); - meta->control.msg_sig = ctrl.msg_sig(); - for (int i = 0; i < ctrl.node_size(); ++i) { - const auto &p = ctrl.node(i); - Node n; - n.role = static_cast(p.role()); - n.port = p.port(); - n.hostname = p.hostname(); - n.id = p.has_id() ? p.id() : Node::kEmpty; - n.is_recovery = p.is_recovery(); - n.customer_id = p.customer_id(); - meta->control.node.push_back(n); - } - } else { - meta->control.cmd = Control::EMPTY; + + auto ctrl = &(raw->control); + meta->control.cmd = static_cast(ctrl->cmd); + meta->control.barrier_group = ctrl->barrier_group; + meta->control.msg_sig = ctrl->msg_sig; + for (int i = 0; i < ctrl->node_size; ++i) { + const auto &p = raw_node[i]; + Node n; + n.role = static_cast(p.role); + n.port = p.port; + n.hostname = p.hostname; + n.id = p.id; + n.is_recovery = p.is_recovery; + n.customer_id = p.customer_id; + meta->control.node.push_back(n); } - meta->data_size = pb.data_size(); - meta->key = pb.key(); - meta->addr = pb.addr(); - meta->val_len = pb.val_len(); - meta->option = pb.option(); + + meta->data_size = raw->data_size; + meta->key = raw->key; + meta->addr = raw->addr; + meta->val_len = raw->val_len; + meta->option = raw->option; } void Van::Heartbeat() { From 44e38fae785696b19da8042d19ee07991390dd1e Mon Sep 17 00:00:00 2001 From: Yibo Zhu Date: Thu, 12 Dec 2019 19:44:02 +0800 Subject: [PATCH 08/79] bzero after init serialization buffer --- src/van.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/van.cc b/src/van.cc index f98b02b2..7102d2ac 100644 --- a/src/van.cc +++ b/src/van.cc @@ -504,6 +504,7 @@ void Van::PackMeta(const Meta &meta, char **meta_buf, int *buf_size) { } RawMeta *raw = (RawMeta*)*meta_buf; + bzero(raw, sizeof(RawMeta)); char *raw_body = *meta_buf + sizeof(RawMeta); int *raw_data_type = (int*)(raw_body + meta.body.size()); RawNode *raw_node = (RawNode*)(raw_data_type + meta.data_type.size()); From 98de771daf8085496ef43746fd681114c034e9a0 Mon Sep 17 00:00:00 2001 From: Yibo Zhu Date: Thu, 12 Dec 2019 20:07:11 +0800 Subject: [PATCH 09/79] reduce copies for PackMeta --- include/ps/internal/van.h | 5 +++++ src/rdma_van.h | 9 +++------ src/van.cc | 10 +++++++--- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/include/ps/internal/van.h b/include/ps/internal/van.h index 67f55bc9..a943975e 100644 --- a/include/ps/internal/van.h +++ b/include/ps/internal/van.h @@ -105,6 +105,11 @@ class Van { */ virtual int SendMsg(Message &msg) = 0; + /** + * \brief get the length of pack meta + */ + int GetPackMetaLen(const Meta &meta); + /** * \brief pack meta into a string */ diff --git a/src/rdma_van.h b/src/rdma_van.h index 9e4d9d36..df8e4132 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -904,14 +904,11 @@ class RDMAVan : public Van { } } - int meta_len; - char* meta_buf = nullptr; - PackMeta(msg.meta, &meta_buf, &meta_len); - CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); Endpoint *endpoint = endpoints_[remote_id].get(); MessageBuffer *msg_buf = new MessageBuffer(); + int meta_len = GetPackMetaLen(msg.meta); size_t data_len = msg.meta.data_size; size_t total_len = meta_len + data_len; CHECK(meta_len); @@ -920,7 +917,7 @@ class RDMAVan : public Van { if (msg.meta.simple_app || !msg.meta.control.empty()){ // simple_app or control message msg_buf->inline_len = total_len; msg_buf->inline_buf = mempool_->Alloc(total_len); - memcpy(msg_buf->inline_buf, meta_buf, meta_len); + PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); char *cur = msg_buf->inline_buf + meta_len; for (auto &sa : msg.data) { size_t seg_len = sa.size(); @@ -931,7 +928,7 @@ class RDMAVan : public Van { msg_buf->inline_len = meta_len; msg_buf->inline_buf = mempool_->Alloc(meta_len); msg_buf->data = msg.data; - memcpy(msg_buf->inline_buf, meta_buf, meta_len); + PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); if (!is_server_ && !is_local_[remote_id]) { // worker, send to non-local servers for (auto &sa : msg_buf->data) { if (sa.size()) { diff --git a/src/van.cc b/src/van.cc index 7102d2ac..aa40880c 100644 --- a/src/van.cc +++ b/src/van.cc @@ -494,10 +494,14 @@ void Van::Receiving() { } } +int Van::GetPackMetaLen(const Meta &meta) { + return sizeof(RawMeta) + meta.body.size() + + meta.data_type.size() * sizeof(int) + + meta.control.node.size() * sizeof(RawNode); +} + void Van::PackMeta(const Meta &meta, char **meta_buf, int *buf_size) { - *buf_size = sizeof(RawMeta) + meta.body.size() + - meta.data_type.size() * sizeof(int) + - meta.control.node.size() * sizeof(RawNode); + *buf_size = GetPackMetaLen(meta); // allocate buffer only when needed if (*meta_buf == nullptr) { *meta_buf = new char[*buf_size + 1]; From dbf15e164db8ac2bdc19d7e6feb4a6c5b6d7d3a7 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Thu, 12 Dec 2019 20:47:31 +0800 Subject: [PATCH 10/79] wip: remove transport.h --- src/rdma_van.h | 243 ++++++++++++++++++++++++++++++++++++++++++++++-- src/transport.h | 228 --------------------------------------------- 2 files changed, 235 insertions(+), 236 deletions(-) delete mode 100644 src/transport.h diff --git a/src/rdma_van.h b/src/rdma_van.h index e3780acb..933e7d7e 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -9,7 +9,34 @@ #ifdef DMLC_USE_RDMA -#include "transport.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ps/internal/threadsafe_queue.h" +#include "ps/internal/van.h" + namespace ps { @@ -469,6 +496,203 @@ struct AsyncCopy { bool shutdown; }; + +class Transport { + public: + explicit Transport(Endpoint *endpoint) { + endpoint_ = endpoint; + }; + + ~Transport(); + + void SendPushResponse(MessageBuffer *msg_buf) { + + } + + void SendPullRequest(MessageBuffer *msg_buf) { + + } + + void SendControlMessage(MessageBuffer *msg_buf) { + if (no remote address) { + Rendezvous and get address; + } + RDMAWriteWithImm(msg_buf); + } + + void RDMAWriteWithImm(MessageBuffer *msg_buf) { + struct ibv_sge sge[1 + msg_buf->mrs.size()]; + sge[0].addr = reinterpret_cast(msg_buf->inline_buf); + sge[0].length = msg_buf->inline_len; + sge[0].lkey = mempool_->LocalKey(msg_buf->inline_buf); + + size_t num_sge = 1; + for (auto &pair : msg_buf->mrs) { + size_t length = pair.second; + CHECK(length); + sge[num_sge].addr = + reinterpret_cast(pair.first->addr); + sge[num_sge].length = length; + sge[num_sge].lkey = pair.first->lkey; + ++num_sge; + } + if (is_server_) CHECK_EQ(num_sge, 1) << num_sge; + + WRContext *write_ctx = msg_buf->reserved_context; + + MessageBuffer **tmp = + reinterpret_cast(write_ctx->buffer->addr); + *tmp = msg_buf; // write the addr of msg_buf into the mr buffer + + struct ibv_send_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + + wr.wr_id = reinterpret_cast(write_ctx); + wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wr.next = nullptr; + + wr.imm_data = idx; + + wr.send_flags = IBV_SEND_SIGNALED; + wr.sg_list = sge; + wr.num_sge = num_sge; + + wr.wr.rdma.remote_addr = remote_addr; + wr.wr.rdma.rkey = rkey; + + CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) + << "ibv_post_send failed."; + } + + void Recv(Message *msg); + void SendPushRequest(MessageBuffer *msg_buf); + void SendPullResponse(MessageBuffer *msg_buf); + + Endpoint *endpoint_; +}; // class Transport + + + +class RDMATransport : public Transport { + public: + void SendPushRequest(MessageBuffer *msg_buf) override { + std::lock_guard lock(map_mu_); + uint64_t key = DecodeKey(msg.data[0]); + msg.meta.key = key; + + CHECK_EQ(msg.data.size(), 3) << msg.data.size(); + CHECK_NE(memory_mr_map_.find(msg.data[1].data()), memory_mr_map_.end()); + + auto& vals = msg.data[1]; + msg.meta.addr = reinterpret_cast(vals.data()); // vals address + msg.meta.val_len = vals.size(); + msg.meta.option = memory_mr_map_[vals.data()]->rkey; + } + + void SendPullResponse(MessageBuffer *msg_buf) override { + std::lock_guard lock(map_mu_); + uint64_t key = msg.meta.key; + auto recver = msg.meta.recver; + + CHECK_NE(key_meta_map_.find(key), key_meta_map_.end()) + << "key=" << key << " not inited in key_meta_map"; + CHECK_NE(key_meta_map_[key].find(recver), key_meta_map_[key].end()) + << "key=" << key << ", recver=" << recver << " not inited in key_meta_map[key]"; + + msg.meta.val_len = std::get<0>(key_meta_map_[key][recver]); + msg.meta.addr = std::get<1>(key_meta_map_[key][recver]); + msg.meta.option = std::get<2>(key_meta_map_[key][recver]); + + // RDMA write + auto raddr = std::get<1>(key_meta_map_[key][recver]); + auto rkey = std::get<2>(key_meta_map_[key][recver]); + + auto temp_mr = memory_mr_map_.find(msg_buf->data[1].data()); + CHECK_NE(temp_mr, memory_mr_map_.end()); + + struct ibv_sge sge; + sge.addr = reinterpret_cast(msg_buf->data[1].data()); + sge.length = msg_buf->data[1].size(); + sge.lkey = temp_mr->second->lkey; + + struct ibv_send_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + + wr.wr_id = reinterpret_cast(raddr); + wr.opcode = IBV_WR_RDMA_WRITE; + wr.next = nullptr; + // wr.send_flags = IBV_SEND_SIGNALED; + wr.sg_list = &sge; + wr.num_sge = 1; + + wr.wr.rdma.remote_addr = raddr; + wr.wr.rdma.rkey = rkey; + + CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) + << "ibv_post_send failed."; + } + + // get remote address using SEND/RECV + void GetRemoteAddr() { + + } + + // register RDMA memory + void RegisterMemory(Message &msg) { + for (auto& sa : msg.data) { + if (!sa.size()) continue; + CHECK(sa.data()); + std::lock_guard lock(map_mu_); + if (memory_mr_map_.find(sa.data()) == memory_mr_map_.end()) { + struct ibv_mr *temp_mr; + CHECK (temp_mr = ibv_reg_mr(pd_, sa.data(), sa.size(), + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) + << "Failed to register the memory region: " << strerror(errno) + << ", sa.size()=" << sa.size(); + memory_mr_map_[sa.data()] = temp_mr; + } + } + } + + void Recv(Message *msg) override { + + } + + private: + +}; // class RDMATransport + + +class IPCTransport : public Transport { + public: + + void SendPushRequest(Message &msg) override { + + } + + void SendPullResponse(Message &msg) override { + std::lock_guard lock(map_mu_); + auto key = msg.meta.key; + auto recver = msg.meta.recver; + auto len = std::get<0>(key_meta_map_[key][recver]); + + // IPC + auto addr = (void*) msg_buf->data[1].data(); + CHECK(addr); + void* shm_addr = GetSharedMemory(kShmPrefix, key); + // async copy + AsyncCopy m = {endpoint, msg_buf, shm_addr, addr, len, meta_len, false}; + auto cnt = cpy_counter_.fetch_add(1); + async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); + } + + void Recv(Message *msg) override { + + } +}; // class IPCTransport + + + class RDMAVan : public Van { public: RDMAVan() { @@ -639,7 +863,8 @@ class RDMAVan : public Van { endpoint->SetNodeID(node.id); - Transport* t = is_local_[node.id] ? std::make_unique() : std::make_unique(); + Transport *t = is_local_[node.id] ? + std::make_unique(endpoint) : std::make_unique(endpoint); endpoint->SetTransport(t); struct addrinfo *remote_addr; @@ -839,19 +1064,19 @@ class RDMAVan : public Van { auto trans = endpoint.GetTransport(); if (!IsValidPushpull(msg)) { // control message - trans->SendControlMessage(endpoint, msg_buf); + trans->SendControlMessage(msg_buf); } else if (msg.meta.push && msg.meta.request) { // worker, push request - trans->SendPushRequest(endpoint, msg_buf); + trans->SendPushRequest(msg_buf); } else if (msg.meta.push && !msg.meta.request) { // server, push response - trans->SendPushResponse(endpoint, msg_buf); + trans->SendPushResponse(msg_buf); } else if (!msg.meta.push && msg.meta.request) { // worker, pull request - trans->SendPullRequest(endpoint, msg_buf); + trans->SendPullRequest(msg_buf); } else if (!msg.meta.push && !msg.meta.request) { // server, pull response - trans->SendPullResponse(endpoint, msg_buf); + trans->SendPullResponse(msg_buf); } else { CHECK(0) << "unexpected message type"; } @@ -1425,7 +1650,9 @@ class RDMAVan : public Van { std::vector ipc_copy_thread_list_; std::vector*> async_copy_queue_; std::atomic cpy_counter_{0}; -}; // namespace ps + +}; // class RDMAVan + }; // namespace ps #endif // DMLC_USE_RDMA diff --git a/src/transport.h b/src/transport.h deleted file mode 100644 index 77f14849..00000000 --- a/src/transport.h +++ /dev/null @@ -1,228 +0,0 @@ -#ifndef PS_RDMA_VAN_H_ -#define PS_RDMA_VAN_H_ - -#ifdef DMLC_USE_RDMA - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ps/internal/threadsafe_queue.h" -#include "ps/internal/van.h" - -namespace ps { - -class Transport { - public: - explicit Transport(Endpoint *endpoint) { - endpoint_ = endpoint; - }; - - ~Transport(); - - void SendPushResponse(Message &msg) { - - } - - void SendPullRequest(Message &msg) { - - } - - void SendControlMessage(Message &msg) { - PBMeta meta; - PackMetaPB(msg.meta, &meta); - MessageBuffer *msg_buf = new MessageBuffer(); - - size_t meta_len = meta.ByteSize(); - size_t total_len = meta_len + msg.meta.data_size; - CHECK(meta_len); - - msg_buf->inline_len = total_len; - msg_buf->inline_buf = mempool_->Alloc(total_len); - meta.SerializeToArray(msg_buf->inline_buf, meta_len); - - WRContext *context = nullptr, *reserved = nullptr; - endpoint_->free_write_ctx.WaitAndPop(&reserved); - endpoint_->free_start_ctx.WaitAndPop(&context); - - msg_buf->reserved_context = reserved; - RendezvousStart *req = - reinterpret_cast(context->buffer->addr); - req->meta_len = meta_len; - req->origin_addr = reinterpret_cast(msg_buf); - - auto addr = reinterpret_cast(req); - - // rendezvous message, not data message - if (!is_server_ && is_local_[remote_id] && IsValidPushpull(msg)) { - // local IPC with shared memory - req->data_num = 0; - } else { - // normal RDMA - req->data_num = msg.data.size(); - for (size_t i = 0; i < req->data_num; ++i) { - req->data_len[i] = msg.data[i].size(); - } - } - SendRendezvousBegin(endpoint_, addr, context, kRendezvousStart); - } - - void Recv(Message *msg); - void SendPushRequest(); - void SendPullResponse(); - - Endpoint* endpoint_; -}; // class Transport - - - -class RDMATransport : public Transport { - public: - void SendPushRequest(Message &msg) override { - std::lock_guard lock(map_mu_); - uint64_t key = DecodeKey(msg.data[0]); - msg.meta.key = key; - - CHECK_EQ(msg.data.size(), 3) << msg.data.size(); - CHECK_NE(memory_mr_map_.find(msg.data[1].data()), memory_mr_map_.end()); - - auto& vals = msg.data[1]; - msg.meta.addr = reinterpret_cast(vals.data()); // vals address - msg.meta.val_len = vals.size(); - msg.meta.option = memory_mr_map_[vals.data()]->rkey; - } - - void SendPullResponse(Message &msg) override { - std::lock_guard lock(map_mu_); - uint64_t key = msg.meta.key; - auto recver = msg.meta.recver; - - CHECK_NE(key_meta_map_.find(key), key_meta_map_.end()) - << "key=" << key << " not inited in key_meta_map"; - CHECK_NE(key_meta_map_[key].find(recver), key_meta_map_[key].end()) - << "key=" << key << ", recver=" << recver << " not inited in key_meta_map[key]"; - - msg.meta.val_len = std::get<0>(key_meta_map_[key][recver]); - msg.meta.addr = std::get<1>(key_meta_map_[key][recver]); - msg.meta.option = std::get<2>(key_meta_map_[key][recver]); - - // RDMA write - auto raddr = std::get<1>(key_meta_map_[key][recver]); - auto rkey = std::get<2>(key_meta_map_[key][recver]); - - auto temp_mr = memory_mr_map_.find(msg_buf->data[1].data()); - CHECK_NE(temp_mr, memory_mr_map_.end()); - - struct ibv_sge sge; - sge.addr = reinterpret_cast(msg_buf->data[1].data()); - sge.length = msg_buf->data[1].size(); - sge.lkey = temp_mr->second->lkey; - - struct ibv_send_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - - wr.wr_id = reinterpret_cast(raddr); - wr.opcode = IBV_WR_RDMA_WRITE; - wr.next = nullptr; - // wr.send_flags = IBV_SEND_SIGNALED; - wr.sg_list = &sge; - wr.num_sge = 1; - - wr.wr.rdma.remote_addr = raddr; - wr.wr.rdma.rkey = rkey; - - CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) - << "ibv_post_send failed."; - } - - // get remote address using SEND/RECV - void GetRemoteAddr() { - - } - - // register RDMA memory - void RegisterMemory(Message &msg) { - for (auto& sa : msg.data) { - if (!sa.size()) continue; - CHECK(sa.data()); - std::lock_guard lock(map_mu_); - if (memory_mr_map_.find(sa.data()) == memory_mr_map_.end()) { - struct ibv_mr *temp_mr; - CHECK (temp_mr = ibv_reg_mr(pd_, sa.data(), sa.size(), - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) - << "Failed to register the memory region: " << strerror(errno) - << ", sa.size()=" << sa.size(); - memory_mr_map_[sa.data()] = temp_mr; - } - } - } - - void Recv(Message *msg) override { - - } - - private: - -}; // class RDMATransport - - - - - - - -class IPCTransport : public Transport { - public: - - void SendPushRequest(Message &msg) override { - - } - - void SendPullResponse(Message &msg) override { - std::lock_guard lock(map_mu_); - auto key = msg.meta.key; - auto recver = msg.meta.recver; - auto len = std::get<0>(key_meta_map_[key][recver]); - - // IPC - auto addr = (void*) msg_buf->data[1].data(); - CHECK(addr); - void* shm_addr = GetSharedMemory(kShmPrefix, key); - // async copy - AsyncCopy m = {endpoint, msg_buf, shm_addr, addr, len, meta_len, false}; - auto cnt = cpy_counter_.fetch_add(1); - async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); - } - - void Recv(Message *msg) override { - - } -}; // class IPCTransport - - - -} // namespace ps - -#endif // DMLC_USE_RDMA -#endif // PS_RDMA_VAN_H_ \ No newline at end of file From 1ffb37ef1cf373af5ff14a731ddd0d0ba1a59fcd Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Fri, 13 Dec 2019 11:16:47 +0800 Subject: [PATCH 11/79] wip: split the files --- src/rdma_utils.h | 515 +++++++++++++++++++++++++++++++++++ src/rdma_van.h | 682 +---------------------------------------------- src/transport.h | 187 +++++++++++++ 3 files changed, 703 insertions(+), 681 deletions(-) create mode 100644 src/rdma_utils.h create mode 100644 src/transport.h diff --git a/src/rdma_utils.h b/src/rdma_utils.h new file mode 100644 index 00000000..5d1043b9 --- /dev/null +++ b/src/rdma_utils.h @@ -0,0 +1,515 @@ +// Copyright 2019 Bytedance Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef PS_RDMA_VAN_H_ +#define PS_RDMA_VAN_H_ + +#ifdef DMLC_USE_RDMA + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ps/internal/threadsafe_queue.h" +#include "ps/internal/van.h" + + +namespace ps { + + +#define DIVUP(x, y) (((x)+(y)-1)/(y)) +#define ROUNDUP(x, y) (DIVUP((x), (y))*(y)) + +static const int kStartDepth = 128; +static const int kWriteDepth = kStartDepth; + +static const int kRxDepth = kStartDepth + kWriteDepth; +static const int kReplyDepth = kRxDepth; + +static const int kSGEntry = 4; +static const int kTimeoutms = 1000; +static const int kRdmaListenBacklog = 128; +static const int kMaxConcurrentWorkRequest = + kRxDepth + kStartDepth + kReplyDepth + kWriteDepth; +static const int kMaxHostnameLength = 16; +static const int kMaxDataFields = 4; +static const size_t kAlignment = 8; + +static const int kMaxResolveRetry = 50000; +static const int kBasePort = 9010; + +// should have the same prefix with BytePS shared memory +static const std::string kShmPrefix("BytePS_ShM_"); + +template +static inline T align_floor(T v, T align) { + return v - (v % align); +} + +template +static inline T align_ceil(T v, T align) { + return align_floor(v + align - 1, align); +} + +static inline void ib_malloc(void** ptr, size_t size) { + size_t page_size = sysconf(_SC_PAGESIZE); + void* p; + int size_aligned = ROUNDUP(size, page_size); + int ret = posix_memalign(&p, page_size, size_aligned); + CHECK_EQ(ret, 0) << "posix_memalign error: " << strerror(ret); + CHECK(p); + memset(p, 0, size); + *ptr = p; +} + +class SimpleMempool { + public: + explicit SimpleMempool(struct ibv_pd *pd, size_t size = 0x10000000) { + std::lock_guard lk(mu_); + pd_ = pd; + struct ibv_mr *mr; + + // set init mempool size + auto byteps_rdma_mempool_size = Environment::Get()->find("BYTEPS_RDMA_MEMPOOL_SIZE"); + size = byteps_rdma_mempool_size ? atoi(byteps_rdma_mempool_size) : size; + auto byteps_rdma_mempool_num = Environment::Get()->find("BYTEPS_RDMA_MEMPOOL_NUM"); + size_t mempool_num = byteps_rdma_mempool_num ? atoi(byteps_rdma_mempool_num) : 1; + PS_VLOG(1) << "RDMA initial mempool size set to " << size + << ", mempool num set to " << mempool_num; + + for (size_t i = 0; i < mempool_num; ++i) { + char *p; + ib_malloc((void**) &p, size); + total_allocated_size += size; + CHECK(p); + CHECK(mr = ibv_reg_mr(pd, p, size, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)); + mr_list.emplace(p+size, mr); // this mr is associated with memory address range [p, p+size] + free_list.emplace(size, p); + } + } + + ~SimpleMempool() { + std::lock_guard lk(mu_); + for(auto it = mr_list.begin(); it != mr_list.end(); it++){ + CHECK_EQ(ibv_dereg_mr(it->second), 0); + free(it->second->addr); + } + } + + char *Alloc(size_t size) { + if (size == 0) { + return nullptr; + } + + std::lock_guard lk(mu_); + + size_t proper_size = align_ceil(size, kAlignment); + + auto it = free_list.lower_bound(proper_size); + + if (it == free_list.end()) { // if there is no space left, need to allocate and register new memory + size_t new_mem_size = total_allocated_size; + while (proper_size > new_mem_size) { + new_mem_size *= 2; + } + char *p; + ib_malloc((void**) &p, new_mem_size); + CHECK(p); + struct ibv_mr *mr; + CHECK(mr = ibv_reg_mr(pd_, p, new_mem_size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)); + mr_list.emplace(p+new_mem_size, mr); + free_list.emplace(new_mem_size, p); + it = free_list.lower_bound(proper_size); + PS_VLOG(1) << "Not enough memory in the pool, requested size " << proper_size << ", new allocated size " << new_mem_size; + total_allocated_size += new_mem_size; + } + + CHECK_NE(free_list.end(), it) << "Not enough memory"; + CHECK_GE(it->first, proper_size); + + char *addr = it->second; + size_t space_left = it->first - proper_size; + + free_list.erase(it); + CHECK_EQ(used_list.find(addr), used_list.end()) + << "Address is already allocated"; + + used_list.emplace(addr, proper_size); + + if (space_left) { + free_list.emplace(space_left, addr + proper_size); + } + + return addr; + } + + void Free(char *addr) { + if (!addr) { + return; + } + + std::lock_guard lk(mu_); + + auto it = used_list.find(addr); + CHECK_NE(used_list.end(), it) + << "Cannot find info about address: " << (uintptr_t)addr; + + size_t size = it->second; + used_list.erase(it); + free_list.emplace(size, addr); + } + + uint32_t LocalKey(char *addr) { + struct ibv_mr *mr = Addr2MR(addr); + return mr->lkey; + } + uint32_t RemoteKey(char *addr) { + struct ibv_mr *mr = Addr2MR(addr); + return mr->rkey; + } + + private: + std::mutex mu_; + std::multimap free_list; + std::unordered_map used_list; + struct ibv_pd *pd_; + size_t total_allocated_size = 0; + + // first: `end` of this mr address (e.g., for mr with [addr, addr+size], point to `addr+size`) + std::map mr_list; + + // convert the memory address to its associated RDMA memory region + inline struct ibv_mr* Addr2MR(char *addr) { + std::lock_guard lk(mu_); + auto it = mr_list.lower_bound(addr); + CHECK_NE(it, mr_list.end()) << "cannot find the associated memory region"; + return it->second; + } + +}; + +class Block { + public: + explicit Block(SimpleMempool *pool, char *addr, int count) + : pool(pool), addr(addr), counter(count) {} + + ~Block() { + CHECK_EQ(counter, 0); + pool->Free(addr); + } + + void Release() { + int v = counter.fetch_sub(1); + if (v == 1) { + delete this; + } + } + + private: + SimpleMempool *pool; + char *addr; + std::atomic counter; +}; + +enum MessageTypes : uint32_t { + kRendezvousStart, + kRendezvousReply, +}; + +struct RendezvousStart { + uint64_t meta_len; + uint64_t data_num; + uint64_t data_len[kMaxDataFields]; + uint64_t origin_addr; +}; + +struct RendezvousReply { + uint64_t addr; + uint64_t origin_addr; + uint32_t rkey; + uint32_t idx; +}; + +enum WRContextType { + kRendezvousStartContext, + kRendezvousReplyContext, + kWriteContext, + kReceiveContext +}; + +struct WRContext { + WRContextType type; + struct ibv_mr *buffer; + void *private_data; +}; + +struct BufferContext { + char *buffer; + size_t meta_len; + size_t data_num; + size_t data_len[kMaxDataFields]; +}; + +typedef std::unique_ptr> + MRPtr; + +struct MessageBuffer { + size_t inline_len; + char *inline_buf; + WRContext *reserved_context; + std::vector> data; + std::vector> mrs; +}; + +struct RequestContext { + uint32_t node; + uint16_t port; + char hostname[kMaxHostnameLength]; +}; + +static_assert(std::is_pod::value, + "RendezvousStart must be a POD type."); +static_assert(std::is_pod::value, + "RendezvousReply must be a POD type."); +static_assert(std::is_pod::value, + "RequestContext must be a POD type."); + +static const size_t kMempoolChunkSize = + std::max({sizeof(RendezvousStart), sizeof(RendezvousReply)}); + +template +class AddressPool { + public: + AddressPool() { + std::lock_guard lk(mu_); + // init the queue + for (int i = 0; i < kMaxEntries; i++) { + indices_.push(i); + table_[i] = nullptr; + } + } + + T *GetAddressAndRelease(uint32_t index) { + std::lock_guard lk(mu_); + T *ptr = table_[index]; + CHECK(ptr); + indices_.push(index); + table_[index] = nullptr; + return ptr; + } + + uint32_t StoreAddress(T *ptr) { + std::lock_guard lk(mu_); + CHECK(ptr); + CHECK(!indices_.empty()) + << "Address pool size is too small, " + << "consider increasing kMaxEntries"; + uint32_t idx = indices_.front(); + indices_.pop(); + CHECK_EQ(table_[idx], nullptr) << idx; + table_[idx] = ptr; + return idx; + } + + private: + static const int kMaxEntries = 5120; + + std::mutex mu_; + std::queue indices_; + T *table_[kMaxEntries]; +}; + +struct Endpoint { + enum ConnectionStatus { IDLE, CONNECTING, CONNECTED, REJECTED }; + + ConnectionStatus status; + int node_id; + std::condition_variable cv; + std::mutex connect_mu; + struct rdma_cm_id *cm_id; + std::unique_ptr tran; + + WRContext rx_ctx[kRxDepth]; + + WRContext start_ctx[kStartDepth]; + WRContext reply_ctx[kReplyDepth]; + WRContext write_ctx[kWriteDepth]; + + ThreadsafeQueue free_start_ctx; + ThreadsafeQueue free_reply_ctx; + ThreadsafeQueue free_write_ctx; + + Endpoint() : status(IDLE), node_id(Node::kEmpty), cm_id(nullptr), rx_ctx() {} + + ~Endpoint() { + for (int i = 0; i < kRxDepth; ++i) { + if (!(rx_ctx[i].buffer)) { + continue; + } + free(rx_ctx[i].buffer->addr); + CHECK_EQ(ibv_dereg_mr(rx_ctx[i].buffer), 0); + } + + for (int i = 0; i < kStartDepth; ++i) { + if (start_ctx[i].buffer) { + free(start_ctx[i].buffer->addr); + CHECK_EQ(ibv_dereg_mr(start_ctx[i].buffer), 0); + } + } + + for (int i = 0; i < kReplyDepth; ++i) { + if (reply_ctx[i].buffer) { + free(reply_ctx[i].buffer->addr); + CHECK_EQ(ibv_dereg_mr(reply_ctx[i].buffer), 0); + } + } + + for (int i = 0; i < kWriteDepth; ++i) { + if (write_ctx[i].buffer) { + free(write_ctx[i].buffer->addr); + CHECK_EQ(ibv_dereg_mr(write_ctx[i].buffer), 0); + } + } + + rdma_destroy_qp(cm_id); + CHECK_EQ(rdma_destroy_id(cm_id), 0) << strerror(errno); + } + + void SetTransport(std::unique_ptr t) { tran = t; } + + std::unique_ptr GetTransport() { return tran; } + + void Disconnect() { + std::unique_lock lk(connect_mu); + CHECK_EQ(rdma_disconnect(cm_id), 0) << strerror(errno); + cv.wait(lk, [this] { return status == IDLE; }); + tran.reset(); + } + + void SetNodeID(int id) { node_id = id; } + + void InitSendContextHelper(struct ibv_pd *pd, WRContext *ctx, + ThreadsafeQueue *queue, size_t num, + WRContextType type) { + for (size_t i = 0; i < num; ++i) { + void *buf; + ib_malloc((void**) &buf, kMempoolChunkSize); + CHECK(buf); + struct ibv_mr *mr = ibv_reg_mr(pd, buf, kMempoolChunkSize, 0); + CHECK(mr); + + ctx[i].type = type; + ctx[i].buffer = mr; + ctx[i].private_data = this; + queue->Push(&ctx[i]); + } + } + + void Init(struct ibv_cq *cq, struct ibv_pd *pd) { + struct ibv_qp_init_attr attr; + memset(&attr, 0, sizeof(ibv_qp_init_attr)); + attr.send_cq = cq; + attr.recv_cq = cq; + attr.cap.max_send_wr = kStartDepth + kReplyDepth + kWriteDepth; + attr.cap.max_recv_wr = kRxDepth; + attr.cap.max_send_sge = kSGEntry; + attr.cap.max_recv_sge = kSGEntry; + attr.qp_type = IBV_QPT_RC; + attr.sq_sig_all = 0; + + CHECK_EQ(rdma_create_qp(cm_id, pd, &attr), 0) + << "Create RDMA queue pair failed"; + + InitSendContextHelper(pd, start_ctx, &free_start_ctx, kStartDepth, + kRendezvousStartContext); + InitSendContextHelper(pd, reply_ctx, &free_reply_ctx, kReplyDepth, + kRendezvousReplyContext); + InitSendContextHelper(pd, write_ctx, &free_write_ctx, kWriteDepth, + kWriteContext); + + for (size_t i = 0; i < kRxDepth; ++i) { + void *buf; + ib_malloc((void**) &buf, kMempoolChunkSize); + CHECK(buf); + struct ibv_mr *mr = + ibv_reg_mr(pd, buf, kMempoolChunkSize, IBV_ACCESS_LOCAL_WRITE); + CHECK(mr); + + rx_ctx[i].type = kReceiveContext; + rx_ctx[i].buffer = mr; + rx_ctx[i].private_data = this; + + PostRecv(&rx_ctx[i]); + } + } + + void PostRecv(WRContext *ctx) { + struct ibv_recv_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + + struct ibv_sge sge; + sge.addr = reinterpret_cast(ctx->buffer->addr); + sge.length = kMempoolChunkSize; + sge.lkey = ctx->buffer->lkey; + + wr.wr_id = reinterpret_cast(ctx); + wr.next = nullptr; + wr.sg_list = &sge; + wr.num_sge = 1; + + CHECK_EQ(ibv_post_recv(cm_id->qp, &wr, &bad_wr), 0) + << "ibv_post_recv failed."; + } +}; + +struct AsyncCopy { + Endpoint* endpoint; + MessageBuffer* msg_buf; + void* dst; + void* src; + int len; + uint64_t meta_len; + bool shutdown; +}; + + +}; // namespace ps + +#endif // DMLC_USE_RDMA +#endif // PS_RDMA_VAN_H_ + diff --git a/src/rdma_van.h b/src/rdma_van.h index 933e7d7e..3f916ed3 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -9,690 +9,10 @@ #ifdef DMLC_USE_RDMA -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ps/internal/threadsafe_queue.h" -#include "ps/internal/van.h" - +#include "transport.h" namespace ps { -#define DIVUP(x, y) (((x)+(y)-1)/(y)) -#define ROUNDUP(x, y) (DIVUP((x), (y))*(y)) - -static const int kStartDepth = 128; -static const int kWriteDepth = kStartDepth; - -static const int kRxDepth = kStartDepth + kWriteDepth; -static const int kReplyDepth = kRxDepth; - -static const int kSGEntry = 4; -static const int kTimeoutms = 1000; -static const int kRdmaListenBacklog = 128; -static const int kMaxConcurrentWorkRequest = - kRxDepth + kStartDepth + kReplyDepth + kWriteDepth; -static const int kMaxHostnameLength = 16; -static const int kMaxDataFields = 4; -static const size_t kAlignment = 8; - -static const int kMaxResolveRetry = 50000; -static const int kBasePort = 9010; - -// should have the same prefix with BytePS shared memory -static const std::string kShmPrefix("BytePS_ShM_"); - -template -static inline T align_floor(T v, T align) { - return v - (v % align); -} - -template -static inline T align_ceil(T v, T align) { - return align_floor(v + align - 1, align); -} - -static inline void ib_malloc(void** ptr, size_t size) { - size_t page_size = sysconf(_SC_PAGESIZE); - void* p; - int size_aligned = ROUNDUP(size, page_size); - int ret = posix_memalign(&p, page_size, size_aligned); - CHECK_EQ(ret, 0) << "posix_memalign error: " << strerror(ret); - CHECK(p); - memset(p, 0, size); - *ptr = p; -} - -class SimpleMempool { - public: - explicit SimpleMempool(struct ibv_pd *pd, size_t size = 0x10000000) { - std::lock_guard lk(mu_); - pd_ = pd; - struct ibv_mr *mr; - - // set init mempool size - auto byteps_rdma_mempool_size = Environment::Get()->find("BYTEPS_RDMA_MEMPOOL_SIZE"); - size = byteps_rdma_mempool_size ? atoi(byteps_rdma_mempool_size) : size; - auto byteps_rdma_mempool_num = Environment::Get()->find("BYTEPS_RDMA_MEMPOOL_NUM"); - size_t mempool_num = byteps_rdma_mempool_num ? atoi(byteps_rdma_mempool_num) : 1; - PS_VLOG(1) << "RDMA initial mempool size set to " << size - << ", mempool num set to " << mempool_num; - - for (size_t i = 0; i < mempool_num; ++i) { - char *p; - ib_malloc((void**) &p, size); - total_allocated_size += size; - CHECK(p); - CHECK(mr = ibv_reg_mr(pd, p, size, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)); - mr_list.emplace(p+size, mr); // this mr is associated with memory address range [p, p+size] - free_list.emplace(size, p); - } - } - - ~SimpleMempool() { - std::lock_guard lk(mu_); - for(auto it = mr_list.begin(); it != mr_list.end(); it++){ - CHECK_EQ(ibv_dereg_mr(it->second), 0); - free(it->second->addr); - } - } - - char *Alloc(size_t size) { - if (size == 0) { - return nullptr; - } - - std::lock_guard lk(mu_); - - size_t proper_size = align_ceil(size, kAlignment); - - auto it = free_list.lower_bound(proper_size); - - if (it == free_list.end()) { // if there is no space left, need to allocate and register new memory - size_t new_mem_size = total_allocated_size; - while (proper_size > new_mem_size) { - new_mem_size *= 2; - } - char *p; - ib_malloc((void**) &p, new_mem_size); - CHECK(p); - struct ibv_mr *mr; - CHECK(mr = ibv_reg_mr(pd_, p, new_mem_size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)); - mr_list.emplace(p+new_mem_size, mr); - free_list.emplace(new_mem_size, p); - it = free_list.lower_bound(proper_size); - PS_VLOG(1) << "Not enough memory in the pool, requested size " << proper_size << ", new allocated size " << new_mem_size; - total_allocated_size += new_mem_size; - } - - CHECK_NE(free_list.end(), it) << "Not enough memory"; - CHECK_GE(it->first, proper_size); - - char *addr = it->second; - size_t space_left = it->first - proper_size; - - free_list.erase(it); - CHECK_EQ(used_list.find(addr), used_list.end()) - << "Address is already allocated"; - - used_list.emplace(addr, proper_size); - - if (space_left) { - free_list.emplace(space_left, addr + proper_size); - } - - return addr; - } - - void Free(char *addr) { - if (!addr) { - return; - } - - std::lock_guard lk(mu_); - - auto it = used_list.find(addr); - CHECK_NE(used_list.end(), it) - << "Cannot find info about address: " << (uintptr_t)addr; - - size_t size = it->second; - used_list.erase(it); - free_list.emplace(size, addr); - } - - uint32_t LocalKey(char *addr) { - struct ibv_mr *mr = Addr2MR(addr); - return mr->lkey; - } - uint32_t RemoteKey(char *addr) { - struct ibv_mr *mr = Addr2MR(addr); - return mr->rkey; - } - - private: - std::mutex mu_; - std::multimap free_list; - std::unordered_map used_list; - struct ibv_pd *pd_; - size_t total_allocated_size = 0; - - // first: `end` of this mr address (e.g., for mr with [addr, addr+size], point to `addr+size`) - std::map mr_list; - - // convert the memory address to its associated RDMA memory region - inline struct ibv_mr* Addr2MR(char *addr) { - std::lock_guard lk(mu_); - auto it = mr_list.lower_bound(addr); - CHECK_NE(it, mr_list.end()) << "cannot find the associated memory region"; - return it->second; - } - -}; - -class Block { - public: - explicit Block(SimpleMempool *pool, char *addr, int count) - : pool(pool), addr(addr), counter(count) {} - - ~Block() { - CHECK_EQ(counter, 0); - pool->Free(addr); - } - - void Release() { - int v = counter.fetch_sub(1); - if (v == 1) { - delete this; - } - } - - private: - SimpleMempool *pool; - char *addr; - std::atomic counter; -}; - -enum MessageTypes : uint32_t { - kRendezvousStart, - kRendezvousReply, -}; - -struct RendezvousStart { - uint64_t meta_len; - uint64_t data_num; - uint64_t data_len[kMaxDataFields]; - uint64_t origin_addr; -}; - -struct RendezvousReply { - uint64_t addr; - uint64_t origin_addr; - uint32_t rkey; - uint32_t idx; -}; - -enum WRContextType { - kRendezvousStartContext, - kRendezvousReplyContext, - kWriteContext, - kReceiveContext -}; - -struct WRContext { - WRContextType type; - struct ibv_mr *buffer; - void *private_data; -}; - -struct BufferContext { - char *buffer; - size_t meta_len; - size_t data_num; - size_t data_len[kMaxDataFields]; -}; - -typedef std::unique_ptr> - MRPtr; - -struct MessageBuffer { - size_t inline_len; - char *inline_buf; - WRContext *reserved_context; - std::vector> data; - std::vector> mrs; -}; - -struct RequestContext { - uint32_t node; - uint16_t port; - char hostname[kMaxHostnameLength]; -}; - -static_assert(std::is_pod::value, - "RendezvousStart must be a POD type."); -static_assert(std::is_pod::value, - "RendezvousReply must be a POD type."); -static_assert(std::is_pod::value, - "RequestContext must be a POD type."); - -static const size_t kMempoolChunkSize = - std::max({sizeof(RendezvousStart), sizeof(RendezvousReply)}); - -template -class AddressPool { - public: - AddressPool() { - std::lock_guard lk(mu_); - // init the queue - for (int i = 0; i < kMaxEntries; i++) { - indices_.push(i); - table_[i] = nullptr; - } - } - - T *GetAddressAndRelease(uint32_t index) { - std::lock_guard lk(mu_); - T *ptr = table_[index]; - CHECK(ptr); - indices_.push(index); - table_[index] = nullptr; - return ptr; - } - - uint32_t StoreAddress(T *ptr) { - std::lock_guard lk(mu_); - CHECK(ptr); - CHECK(!indices_.empty()) - << "Address pool size is too small, " - << "consider increasing kMaxEntries"; - uint32_t idx = indices_.front(); - indices_.pop(); - CHECK_EQ(table_[idx], nullptr) << idx; - table_[idx] = ptr; - return idx; - } - - private: - static const int kMaxEntries = 5120; - - std::mutex mu_; - std::queue indices_; - T *table_[kMaxEntries]; -}; - -struct Endpoint { - enum ConnectionStatus { IDLE, CONNECTING, CONNECTED, REJECTED }; - - ConnectionStatus status; - int node_id; - std::condition_variable cv; - std::mutex connect_mu; - struct rdma_cm_id *cm_id; - std::unique_ptr tran; - - WRContext rx_ctx[kRxDepth]; - - WRContext start_ctx[kStartDepth]; - WRContext reply_ctx[kReplyDepth]; - WRContext write_ctx[kWriteDepth]; - - ThreadsafeQueue free_start_ctx; - ThreadsafeQueue free_reply_ctx; - ThreadsafeQueue free_write_ctx; - - Endpoint() : status(IDLE), node_id(Node::kEmpty), cm_id(nullptr), rx_ctx() {} - - ~Endpoint() { - for (int i = 0; i < kRxDepth; ++i) { - if (!(rx_ctx[i].buffer)) { - continue; - } - free(rx_ctx[i].buffer->addr); - CHECK_EQ(ibv_dereg_mr(rx_ctx[i].buffer), 0); - } - - for (int i = 0; i < kStartDepth; ++i) { - if (start_ctx[i].buffer) { - free(start_ctx[i].buffer->addr); - CHECK_EQ(ibv_dereg_mr(start_ctx[i].buffer), 0); - } - } - - for (int i = 0; i < kReplyDepth; ++i) { - if (reply_ctx[i].buffer) { - free(reply_ctx[i].buffer->addr); - CHECK_EQ(ibv_dereg_mr(reply_ctx[i].buffer), 0); - } - } - - for (int i = 0; i < kWriteDepth; ++i) { - if (write_ctx[i].buffer) { - free(write_ctx[i].buffer->addr); - CHECK_EQ(ibv_dereg_mr(write_ctx[i].buffer), 0); - } - } - - rdma_destroy_qp(cm_id); - CHECK_EQ(rdma_destroy_id(cm_id), 0) << strerror(errno); - } - - void SetTransport(std::unique_ptr t) { tran = t; } - - std::unique_ptr GetTransport() { return tran; } - - void Disconnect() { - std::unique_lock lk(connect_mu); - CHECK_EQ(rdma_disconnect(cm_id), 0) << strerror(errno); - cv.wait(lk, [this] { return status == IDLE; }); - tran.reset(); - } - - void SetNodeID(int id) { node_id = id; } - - void InitSendContextHelper(struct ibv_pd *pd, WRContext *ctx, - ThreadsafeQueue *queue, size_t num, - WRContextType type) { - for (size_t i = 0; i < num; ++i) { - void *buf; - ib_malloc((void**) &buf, kMempoolChunkSize); - CHECK(buf); - struct ibv_mr *mr = ibv_reg_mr(pd, buf, kMempoolChunkSize, 0); - CHECK(mr); - - ctx[i].type = type; - ctx[i].buffer = mr; - ctx[i].private_data = this; - queue->Push(&ctx[i]); - } - } - - void Init(struct ibv_cq *cq, struct ibv_pd *pd) { - struct ibv_qp_init_attr attr; - memset(&attr, 0, sizeof(ibv_qp_init_attr)); - attr.send_cq = cq; - attr.recv_cq = cq; - attr.cap.max_send_wr = kStartDepth + kReplyDepth + kWriteDepth; - attr.cap.max_recv_wr = kRxDepth; - attr.cap.max_send_sge = kSGEntry; - attr.cap.max_recv_sge = kSGEntry; - attr.qp_type = IBV_QPT_RC; - attr.sq_sig_all = 0; - - CHECK_EQ(rdma_create_qp(cm_id, pd, &attr), 0) - << "Create RDMA queue pair failed"; - - InitSendContextHelper(pd, start_ctx, &free_start_ctx, kStartDepth, - kRendezvousStartContext); - InitSendContextHelper(pd, reply_ctx, &free_reply_ctx, kReplyDepth, - kRendezvousReplyContext); - InitSendContextHelper(pd, write_ctx, &free_write_ctx, kWriteDepth, - kWriteContext); - - for (size_t i = 0; i < kRxDepth; ++i) { - void *buf; - ib_malloc((void**) &buf, kMempoolChunkSize); - CHECK(buf); - struct ibv_mr *mr = - ibv_reg_mr(pd, buf, kMempoolChunkSize, IBV_ACCESS_LOCAL_WRITE); - CHECK(mr); - - rx_ctx[i].type = kReceiveContext; - rx_ctx[i].buffer = mr; - rx_ctx[i].private_data = this; - - PostRecv(&rx_ctx[i]); - } - } - - void PostRecv(WRContext *ctx) { - struct ibv_recv_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - - struct ibv_sge sge; - sge.addr = reinterpret_cast(ctx->buffer->addr); - sge.length = kMempoolChunkSize; - sge.lkey = ctx->buffer->lkey; - - wr.wr_id = reinterpret_cast(ctx); - wr.next = nullptr; - wr.sg_list = &sge; - wr.num_sge = 1; - - CHECK_EQ(ibv_post_recv(cm_id->qp, &wr, &bad_wr), 0) - << "ibv_post_recv failed."; - } -}; - -struct AsyncCopy { - Endpoint* endpoint; - MessageBuffer* msg_buf; - void* dst; - void* src; - int len; - uint64_t meta_len; - bool shutdown; -}; - - -class Transport { - public: - explicit Transport(Endpoint *endpoint) { - endpoint_ = endpoint; - }; - - ~Transport(); - - void SendPushResponse(MessageBuffer *msg_buf) { - - } - - void SendPullRequest(MessageBuffer *msg_buf) { - - } - - void SendControlMessage(MessageBuffer *msg_buf) { - if (no remote address) { - Rendezvous and get address; - } - RDMAWriteWithImm(msg_buf); - } - - void RDMAWriteWithImm(MessageBuffer *msg_buf) { - struct ibv_sge sge[1 + msg_buf->mrs.size()]; - sge[0].addr = reinterpret_cast(msg_buf->inline_buf); - sge[0].length = msg_buf->inline_len; - sge[0].lkey = mempool_->LocalKey(msg_buf->inline_buf); - - size_t num_sge = 1; - for (auto &pair : msg_buf->mrs) { - size_t length = pair.second; - CHECK(length); - sge[num_sge].addr = - reinterpret_cast(pair.first->addr); - sge[num_sge].length = length; - sge[num_sge].lkey = pair.first->lkey; - ++num_sge; - } - if (is_server_) CHECK_EQ(num_sge, 1) << num_sge; - - WRContext *write_ctx = msg_buf->reserved_context; - - MessageBuffer **tmp = - reinterpret_cast(write_ctx->buffer->addr); - *tmp = msg_buf; // write the addr of msg_buf into the mr buffer - - struct ibv_send_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - - wr.wr_id = reinterpret_cast(write_ctx); - wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wr.next = nullptr; - - wr.imm_data = idx; - - wr.send_flags = IBV_SEND_SIGNALED; - wr.sg_list = sge; - wr.num_sge = num_sge; - - wr.wr.rdma.remote_addr = remote_addr; - wr.wr.rdma.rkey = rkey; - - CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) - << "ibv_post_send failed."; - } - - void Recv(Message *msg); - void SendPushRequest(MessageBuffer *msg_buf); - void SendPullResponse(MessageBuffer *msg_buf); - - Endpoint *endpoint_; -}; // class Transport - - - -class RDMATransport : public Transport { - public: - void SendPushRequest(MessageBuffer *msg_buf) override { - std::lock_guard lock(map_mu_); - uint64_t key = DecodeKey(msg.data[0]); - msg.meta.key = key; - - CHECK_EQ(msg.data.size(), 3) << msg.data.size(); - CHECK_NE(memory_mr_map_.find(msg.data[1].data()), memory_mr_map_.end()); - - auto& vals = msg.data[1]; - msg.meta.addr = reinterpret_cast(vals.data()); // vals address - msg.meta.val_len = vals.size(); - msg.meta.option = memory_mr_map_[vals.data()]->rkey; - } - - void SendPullResponse(MessageBuffer *msg_buf) override { - std::lock_guard lock(map_mu_); - uint64_t key = msg.meta.key; - auto recver = msg.meta.recver; - - CHECK_NE(key_meta_map_.find(key), key_meta_map_.end()) - << "key=" << key << " not inited in key_meta_map"; - CHECK_NE(key_meta_map_[key].find(recver), key_meta_map_[key].end()) - << "key=" << key << ", recver=" << recver << " not inited in key_meta_map[key]"; - - msg.meta.val_len = std::get<0>(key_meta_map_[key][recver]); - msg.meta.addr = std::get<1>(key_meta_map_[key][recver]); - msg.meta.option = std::get<2>(key_meta_map_[key][recver]); - - // RDMA write - auto raddr = std::get<1>(key_meta_map_[key][recver]); - auto rkey = std::get<2>(key_meta_map_[key][recver]); - - auto temp_mr = memory_mr_map_.find(msg_buf->data[1].data()); - CHECK_NE(temp_mr, memory_mr_map_.end()); - - struct ibv_sge sge; - sge.addr = reinterpret_cast(msg_buf->data[1].data()); - sge.length = msg_buf->data[1].size(); - sge.lkey = temp_mr->second->lkey; - - struct ibv_send_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - - wr.wr_id = reinterpret_cast(raddr); - wr.opcode = IBV_WR_RDMA_WRITE; - wr.next = nullptr; - // wr.send_flags = IBV_SEND_SIGNALED; - wr.sg_list = &sge; - wr.num_sge = 1; - - wr.wr.rdma.remote_addr = raddr; - wr.wr.rdma.rkey = rkey; - - CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) - << "ibv_post_send failed."; - } - - // get remote address using SEND/RECV - void GetRemoteAddr() { - - } - - // register RDMA memory - void RegisterMemory(Message &msg) { - for (auto& sa : msg.data) { - if (!sa.size()) continue; - CHECK(sa.data()); - std::lock_guard lock(map_mu_); - if (memory_mr_map_.find(sa.data()) == memory_mr_map_.end()) { - struct ibv_mr *temp_mr; - CHECK (temp_mr = ibv_reg_mr(pd_, sa.data(), sa.size(), - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) - << "Failed to register the memory region: " << strerror(errno) - << ", sa.size()=" << sa.size(); - memory_mr_map_[sa.data()] = temp_mr; - } - } - } - - void Recv(Message *msg) override { - - } - - private: - -}; // class RDMATransport - - -class IPCTransport : public Transport { - public: - - void SendPushRequest(Message &msg) override { - - } - - void SendPullResponse(Message &msg) override { - std::lock_guard lock(map_mu_); - auto key = msg.meta.key; - auto recver = msg.meta.recver; - auto len = std::get<0>(key_meta_map_[key][recver]); - - // IPC - auto addr = (void*) msg_buf->data[1].data(); - CHECK(addr); - void* shm_addr = GetSharedMemory(kShmPrefix, key); - // async copy - AsyncCopy m = {endpoint, msg_buf, shm_addr, addr, len, meta_len, false}; - auto cnt = cpy_counter_.fetch_add(1); - async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); - } - - void Recv(Message *msg) override { - - } -}; // class IPCTransport - - - class RDMAVan : public Van { public: RDMAVan() { diff --git a/src/transport.h b/src/transport.h new file mode 100644 index 00000000..b98af056 --- /dev/null +++ b/src/transport.h @@ -0,0 +1,187 @@ +// Copyright 2019 Bytedance Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef PS_RDMA_VAN_H_ +#define PS_RDMA_VAN_H_ + +#ifdef DMLC_USE_RDMA + +#include "rdma_utils.h" + +namespace ps { + + +class Transport { + public: + explicit Transport(Endpoint *endpoint) { + endpoint_ = endpoint; + }; + + ~Transport(); + + void SendPushResponse(MessageBuffer *msg_buf) { + + } + + void SendPullRequest(MessageBuffer *msg_buf) { + + } + + void SendControlMessage(MessageBuffer *msg_buf) { + if (no remote address) { + Rendezvous and get address; + } + RDMAWriteWithImm(msg_buf); + } + + void RDMAWriteWithImm(MessageBuffer *msg_buf) { + + } + + void Recv(Message *msg); + void SendPushRequest(MessageBuffer *msg_buf); + void SendPullResponse(MessageBuffer *msg_buf); + + Endpoint *endpoint_; +}; // class Transport + + + +class RDMATransport : public Transport { + public: + void SendPushRequest(MessageBuffer *msg_buf) override { + std::lock_guard lock(map_mu_); + uint64_t key = DecodeKey(msg.data[0]); + msg.meta.key = key; + + CHECK_EQ(msg.data.size(), 3) << msg.data.size(); + CHECK_NE(memory_mr_map_.find(msg.data[1].data()), memory_mr_map_.end()); + + auto& vals = msg.data[1]; + msg.meta.addr = reinterpret_cast(vals.data()); // vals address + msg.meta.val_len = vals.size(); + msg.meta.option = memory_mr_map_[vals.data()]->rkey; + } + + void SendPullResponse(MessageBuffer *msg_buf) override { + std::lock_guard lock(map_mu_); + uint64_t key = msg.meta.key; + auto recver = msg.meta.recver; + + CHECK_NE(key_meta_map_.find(key), key_meta_map_.end()) + << "key=" << key << " not inited in key_meta_map"; + CHECK_NE(key_meta_map_[key].find(recver), key_meta_map_[key].end()) + << "key=" << key << ", recver=" << recver << " not inited in key_meta_map[key]"; + + msg.meta.val_len = std::get<0>(key_meta_map_[key][recver]); + msg.meta.addr = std::get<1>(key_meta_map_[key][recver]); + msg.meta.option = std::get<2>(key_meta_map_[key][recver]); + + // RDMA write + auto raddr = std::get<1>(key_meta_map_[key][recver]); + auto rkey = std::get<2>(key_meta_map_[key][recver]); + + auto temp_mr = memory_mr_map_.find(msg_buf->data[1].data()); + CHECK_NE(temp_mr, memory_mr_map_.end()); + + struct ibv_sge sge; + sge.addr = reinterpret_cast(msg_buf->data[1].data()); + sge.length = msg_buf->data[1].size(); + sge.lkey = temp_mr->second->lkey; + + struct ibv_send_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + + wr.wr_id = reinterpret_cast(raddr); + wr.opcode = IBV_WR_RDMA_WRITE; + wr.next = nullptr; + // wr.send_flags = IBV_SEND_SIGNALED; + wr.sg_list = &sge; + wr.num_sge = 1; + + wr.wr.rdma.remote_addr = raddr; + wr.wr.rdma.rkey = rkey; + + CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) + << "ibv_post_send failed."; + } + + // get remote address using SEND/RECV + void GetRemoteAddr() { + + } + + // register RDMA memory + void RegisterMemory(Message &msg) { + for (auto& sa : msg.data) { + if (!sa.size()) continue; + CHECK(sa.data()); + std::lock_guard lock(map_mu_); + if (memory_mr_map_.find(sa.data()) == memory_mr_map_.end()) { + struct ibv_mr *temp_mr; + CHECK (temp_mr = ibv_reg_mr(pd_, sa.data(), sa.size(), + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) + << "Failed to register the memory region: " << strerror(errno) + << ", sa.size()=" << sa.size(); + memory_mr_map_[sa.data()] = temp_mr; + } + } + } + + void Recv(Message *msg) override { + + } + + private: + +}; // class RDMATransport + + +class IPCTransport : public Transport { + public: + + void SendPushRequest(Message &msg) override { + + } + + void SendPullResponse(Message &msg) override { + std::lock_guard lock(map_mu_); + auto key = msg.meta.key; + auto recver = msg.meta.recver; + auto len = std::get<0>(key_meta_map_[key][recver]); + + // IPC + auto addr = (void*) msg_buf->data[1].data(); + CHECK(addr); + void* shm_addr = GetSharedMemory(kShmPrefix, key); + // async copy + AsyncCopy m = {endpoint, msg_buf, shm_addr, addr, len, meta_len, false}; + auto cnt = cpy_counter_.fetch_add(1); + async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); + } + + void Recv(Message *msg) override { + + } +}; // class IPCTransport + + + + +}; // namespace ps + +#endif // DMLC_USE_RDMA +#endif // PS_RDMA_VAN_H_ + From 38a245fce67305c56d63255b3cab726585792462 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Fri, 13 Dec 2019 11:17:48 +0800 Subject: [PATCH 12/79] nit --- src/rdma_van.h | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/rdma_van.h b/src/rdma_van.h index 3f916ed3..0f138a77 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -1,9 +1,18 @@ -/** - * Copyright (c) 2018-2019 Bytedance Inc. - * Author: lanchang@bytedance.com (Chang Lan) - * jiangyimin@bytedance.com (Yimin Jiang) - * chenjingrong@bytedance.com (Jingrong Chen) -*/ +// Copyright 2019 Bytedance Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + #ifndef PS_RDMA_VAN_H_ #define PS_RDMA_VAN_H_ From c97b9eb521c633b3c44b22997cfa09ea1e89fd70 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Fri, 13 Dec 2019 18:17:16 +0800 Subject: [PATCH 13/79] wip: finish Transport base class --- src/rdma_utils.h | 4 +- src/rdma_van.h | 183 +++++-------------------------------------- src/transport.h | 200 ++++++++++++++++++++++++++++++++++------------- 3 files changed, 169 insertions(+), 218 deletions(-) diff --git a/src/rdma_utils.h b/src/rdma_utils.h index 5d1043b9..5f28e9a9 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -54,8 +54,8 @@ namespace ps { #define DIVUP(x, y) (((x)+(y)-1)/(y)) #define ROUNDUP(x, y) (DIVUP((x), (y))*(y)) -static const int kStartDepth = 128; -static const int kWriteDepth = kStartDepth; +static const int kStartDepth = 1024; +static const int kWriteDepth = kStartDepth * 2; static const int kRxDepth = kStartDepth + kWriteDepth; static const int kReplyDepth = kRxDepth; diff --git a/src/rdma_van.h b/src/rdma_van.h index 0f138a77..94391813 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -378,23 +378,21 @@ class RDMAVan : public Van { meta.SerializeToArray(msg_buf->inline_buf, meta_len); msg_buf->data = msg.data; - // prepare memory - if (!is_server_ && !is_local_[remote_id]) { - for (auto &sa : msg_buf->data) { - if (!sa.size()) continue; - auto it = memory_mr_map_.find(sa.data()); - CHECK_NE(it, memory_mr_map_.end()) << "not registered memory region"; - MRPtr ptr(it->second, [](struct ibv_mr *mr) {}); - CHECK(ptr.get()) << strerror(errno); - msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size())); - } - } - auto trans = endpoint.GetTransport(); if (!IsValidPushpull(msg)) { - // control message - trans->SendControlMessage(msg_buf); - } else if (msg.meta.push && msg.meta.request) { + trans->SendRendezvousBegin(msg_buf); + return total_len; + } else { + auto is_push = msg.meta.push; + auto key = DecodeKey(msg.data[0]); + if (!trans->HasRemoteInfo(msg_buf, key, is_push)) { + trans->SendRendezvousBegin(msg_buf); + return total_len; + } + } + + // already know remote address, directly use RDMA-write + if (msg.meta.push && msg.meta.request) { // worker, push request trans->SendPushRequest(msg_buf); } else if (msg.meta.push && !msg.meta.request) { @@ -433,111 +431,8 @@ class RDMAVan : public Van { uint64_t data_num = buffer_ctx->data_num; cur += buffer_ctx->meta_len; - bool is_released = false; - if (is_server_ && IsValidPushpull(*msg) && is_local_[msg->meta.sender]) { - // get data message from local shared memory - auto key = msg->meta.key; - - std::lock_guard lock(map_mu_); - if (key_addr_map_.find(key) == key_addr_map_.end()) { - key_addr_map_[key] = key; - } - - SArray keys; - SArray vals; - SArray lens; - keys.reset(reinterpret_cast(&key_addr_map_[key]), sizeof(ps::Key), [](void *){}); - msg->data.push_back(keys); - total_len += keys.size(); - - if (msg->meta.push && msg->meta.request) { // push request - auto len = msg->meta.val_len; - if (key_len_map_.find(key) == key_len_map_.end()) { - key_len_map_[key] = len; - } - CHECK_EQ(len, key_len_map_[key]) << "key=" << key - << ": " << len << ", " << key_len_map_[key]; - - void* addr = GetSharedMemory(kShmPrefix, key); - vals.reset(reinterpret_cast(addr), len, [](void *){}); - lens.reset(reinterpret_cast(&key_len_map_[key]), sizeof(int), [](void *){}); - msg->data.push_back(vals); - msg->data.push_back(lens); - } else { // pull request - msg->data.push_back(vals); - } - total_len += vals.size() + lens.size(); - - mempool_->Free(buffer_ctx->buffer); - is_released = true; - } - - if (IsValidPushpull(*msg) && !msg->meta.push && !msg->meta.request) { - // worker, get message directly from inplace tensor - std::lock_guard lock(map_mu_); - auto key = msg->meta.key; - CHECK(!is_server_); - if (key_len_map_.find(key) == key_len_map_.end()) { - key_addr_map_[key] = (ps::Key) key; - key_len_map_[key] = (int) msg->meta.val_len; - } - CHECK_NE(key_len_map_.find(key), key_len_map_.end()) << key; - CHECK_NE(key_addr_map_.find(key), key_addr_map_.end()) << key; - - auto addr = msg->meta.addr; - - CHECK_NE(key_len_map_[key], 0) << msg->DebugString(); - - SArray keys; - SArray vals; - SArray lens; - - keys.reset(reinterpret_cast(&key_addr_map_[key]), sizeof(ps::Key), [](void *){}); - vals.reset(reinterpret_cast(addr), key_len_map_[key], [](void *){}); - lens.reset(reinterpret_cast(&key_len_map_[key]), sizeof(int), [](void *){}); - - msg->data.push_back(keys); - msg->data.push_back(vals); - msg->data.push_back(lens); - total_len += keys.size() + vals.size() + lens.size(); - - mempool_->Free(buffer_ctx->buffer); - } else if (data_num > 0) { - Block *mem_block = - new Block(mempool_.get(), buffer_ctx->buffer, data_num); - - for (size_t i = 0; i < data_num; i++) { - uint32_t len = buffer_ctx->data_len[i]; - SArray data; - data.reset(cur, len, [mem_block](void *) { - mem_block->Release(); - }); // Defer the deletion of block_ref - msg->data.push_back(data); - cur += len; - total_len += len; - } - } else { - if (!is_released) mempool_->Free(buffer_ctx->buffer); - } - - if (msg->meta.push && msg->meta.request) { // server - CHECK(is_server_); - auto key = msg->meta.key; - auto len = msg->meta.val_len; - auto addr = msg->meta.addr; - auto rkey = msg->meta.option; - auto sender = msg->meta.sender; - - std::lock_guard lock(map_mu_); - if (key_meta_map_.find(key) == key_meta_map_.end() - || key_meta_map_[key].find(sender) == key_meta_map_[key].end()) { - key_meta_map_[key][sender] = std::make_tuple(len, addr, rkey); - } else { - CHECK_EQ(len, std::get<0>(key_meta_map_[key][sender])); - CHECK_EQ(addr, std::get<1>(key_meta_map_[key][sender])); - CHECK_EQ(rkey, std::get<2>(key_meta_map_[key][sender])); - } - } + auto trans = endpoint.GetTransport(); + trans->Recv(msg); delete buffer_ctx; return total_len; @@ -679,6 +574,7 @@ class RDMAVan : public Van { SendRendezvousReply(endpoint, buf_ctx, req->origin_addr); } else if (imm == kRendezvousReply) { + auto trans = endpoint.GetTransport(); // LOG(INFO) << "opcode: IBV_WC_RECV kRendezvousReply"; RendezvousReply *resp = reinterpret_cast(mr->addr); @@ -690,49 +586,10 @@ class RDMAVan : public Van { MessageBuffer *msg_buf = reinterpret_cast(origin_addr); - struct ibv_sge sge[1 + msg_buf->mrs.size()]; - - sge[0].addr = reinterpret_cast(msg_buf->inline_buf); - sge[0].length = msg_buf->inline_len; - sge[0].lkey = mempool_->LocalKey(msg_buf->inline_buf); - - size_t num_sge = 1; - for (auto &pair : msg_buf->mrs) { - size_t length = pair.second; - CHECK(length); - sge[num_sge].addr = - reinterpret_cast(pair.first->addr); - sge[num_sge].length = length; - sge[num_sge].lkey = pair.first->lkey; - ++num_sge; - } - if (is_server_) CHECK_EQ(num_sge, 1) << num_sge; - - WRContext *write_ctx = msg_buf->reserved_context; - - MessageBuffer **tmp = - reinterpret_cast(write_ctx->buffer->addr); - *tmp = msg_buf; // write the addr of msg_buf into the mr buffer - - struct ibv_send_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - - wr.wr_id = reinterpret_cast(write_ctx); - wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wr.next = nullptr; - - wr.imm_data = idx; - - wr.send_flags = IBV_SEND_SIGNALED; - wr.sg_list = sge; - wr.num_sge = num_sge; - - wr.wr.rdma.remote_addr = remote_addr; - wr.wr.rdma.rkey = rkey; - - CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) - << "ibv_post_send failed."; - + // Before RDMA write, store the remote info so that + // subsequent write does not need repeated rendezvous + trans->StoreRemoteInfo(msg_buf, remote_addr, rkey, idx); + trans->RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); } else { CHECK(0); } diff --git a/src/transport.h b/src/transport.h index b98af056..cd032f8c 100644 --- a/src/transport.h +++ b/src/transport.h @@ -22,43 +22,166 @@ namespace ps { - class Transport { public: - explicit Transport(Endpoint *endpoint) { + explicit Transport(Endpoint *endpoint, SimpleMempool *mempool) { endpoint_ = endpoint; + mempool_ = mempool; }; ~Transport(); + void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { + // prepare memory + for (auto& sa : msg_buf->data) { + if (!sa.size()) continue; + CHECK(sa.data()); + std::lock_guard lock(mr_mu_); + if (mem_mr_map_.find(sa.data()) == mem_mr_map_.end()) { + struct ibv_mr *temp_mr; + CHECK (temp_mr = ibv_reg_mr(pd_, sa.data(), sa.size(), + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) + << "Failed to register the memory region: " << strerror(errno) + << ", sa.size()=" << sa.size(); + mem_mr_map_[sa.data()] = temp_mr; + } + auto it = mem_mr_map_.find(sa.data()); + MRPtr ptr(it->second, [](struct ibv_mr *mr) {}); + CHECK(ptr.get()) << strerror(errno); + msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size())); + } + + // prepare RDMA write sge list + struct ibv_sge sge[1 + msg_buf->mrs.size()]; + sge[0].addr = reinterpret_cast(msg_buf->inline_buf); + sge[0].length = msg_buf->inline_len; + sge[0].lkey = mempool_->LocalKey(msg_buf->inline_buf); + + size_t num_sge = 1; + for (auto &pair : msg_buf->mrs) { + size_t length = pair.second; + CHECK(length); + sge[num_sge].addr = + reinterpret_cast(pair.first->addr); + sge[num_sge].length = length; + sge[num_sge].lkey = pair.first->lkey; + ++num_sge; + } + + WRContext *write_ctx = msg_buf->reserved_context; + MessageBuffer **tmp = + reinterpret_cast(write_ctx->buffer->addr); + *tmp = msg_buf; // write the addr of msg_buf into the mr buffer + + struct ibv_send_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = reinterpret_cast(write_ctx); + wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wr.next = nullptr; + wr.imm_data = idx; + wr.send_flags = IBV_SEND_SIGNALED; + wr.sg_list = sge; + wr.num_sge = num_sge; + wr.wr.rdma.remote_addr = remote_addr; + wr.wr.rdma.rkey = rkey; + + CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) + << "ibv_post_send failed."; + } + void SendPushResponse(MessageBuffer *msg_buf) { - + std::lock_guard lk(mu_); + auto key = DecodeKey(msg_buf->data[0]); + auto remote_addr = std::get<0>(push_addr_[key]); + auto rkey = std::get<1>(push_addr_[key]); + auto idx = std::get<2>(push_addr_[key]); + RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); } void SendPullRequest(MessageBuffer *msg_buf) { + std::lock_guard lk(mu_); + auto key = DecodeKey(msg_buf->data[0]); + auto remote_addr = std::get<0>(pull_addr_[key]); + auto rkey = std::get<1>(pull_addr_[key]); + auto idx = std::get<2>(pull_addr_[key]); + RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); + } + bool HasRemoteInfo(MessageBuffer *msg_buf, uint64_t key, bool is_push) { + std::lock_guard lk(mu_); + if ( is_push && (push_addr_.find(key) != push_addr_.end())) return true; + if (!is_push && (pull_addr_.find(key) != pull_addr_.end())) return true; + // no remote info, store the msg_buf address and push/pull flag for RendezvousReply + msgbuf_cache_.emplace(reinterpret_cast(msg_buf), is_push); + return false; } - void SendControlMessage(MessageBuffer *msg_buf) { - if (no remote address) { - Rendezvous and get address; + void StoreRemoteInfo(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { + if (msg_buf->data.size() == 0) return; + auto key = DecodeKey(msg_buf->data[0]); + auto buf = reinterpret_cast(msg_buf); + + std::lock_guard lk(mu_); + auto is_push = msgbuf_cache_[buf]; + if (is_push) { + push_addr_.emplace(key, std::make_tuple(remote_addr, rkey, idx)); + } else { + pull_addr_.emplace(key, std::make_tuple(remote_addr, rkey, idx)); } - RDMAWriteWithImm(msg_buf); + msgbuf_cache_.erase(buf); } - void RDMAWriteWithImm(MessageBuffer *msg_buf) { + void SendRendezvousBegin(MessageBuffer *msg_buf) { + WRContext *context = nullptr, *reserved = nullptr; + endpoint_->free_write_ctx.WaitAndPop(&reserved); + endpoint_->free_start_ctx.WaitAndPop(&context); + + msg_buf->reserved_context = reserved; + RendezvousStart *req = + reinterpret_cast(context->buffer->addr); + req->meta_len = msg_buf->inline_len; + req->origin_addr = reinterpret_cast(msg_buf); + req->data_num = msg_buf->data.size(); + for (size_t i = 0; i < req->data_num; ++i) { + req->data_len[i] = msg->data[i].size(); + } + struct ibv_sge sge; + sge.addr = reinterpret_cast(req); + sge.lkey = context->buffer->lkey; + sge.length = sizeof(RendezvousStart); + + struct ibv_send_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = reinterpret_cast(context); + wr.opcode = IBV_WR_SEND_WITH_IMM; + wr.next = nullptr; + wr.imm_data = kRendezvousStart; + wr.send_flags = IBV_SEND_SIGNALED; + wr.sg_list = &sge; + wr.num_sge = 1; + + CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) + << strerror(errno); } void Recv(Message *msg); void SendPushRequest(MessageBuffer *msg_buf); void SendPullResponse(MessageBuffer *msg_buf); - + + private: Endpoint *endpoint_; + SimpleMempool *mempool_; + std::unordered_map > push_addr_; // key, + std::unordered_map > pull_addr_; // key, + std::unordered_map msgbuf_cache_; // msg_buf, is_push + std::mutex mu_; + + std::unordered_map mem_mr_map_; + std::mutex mr_mu_; }; // class Transport - class RDMATransport : public Transport { public: void SendPushRequest(MessageBuffer *msg_buf) override { @@ -118,34 +241,10 @@ class RDMATransport : public Transport { << "ibv_post_send failed."; } - // get remote address using SEND/RECV - void GetRemoteAddr() { - - } - - // register RDMA memory - void RegisterMemory(Message &msg) { - for (auto& sa : msg.data) { - if (!sa.size()) continue; - CHECK(sa.data()); - std::lock_guard lock(map_mu_); - if (memory_mr_map_.find(sa.data()) == memory_mr_map_.end()) { - struct ibv_mr *temp_mr; - CHECK (temp_mr = ibv_reg_mr(pd_, sa.data(), sa.size(), - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) - << "Failed to register the memory region: " << strerror(errno) - << ", sa.size()=" << sa.size(); - memory_mr_map_[sa.data()] = temp_mr; - } - } - } - void Recv(Message *msg) override { } - private: - }; // class RDMATransport @@ -153,33 +252,28 @@ class IPCTransport : public Transport { public: void SendPushRequest(Message &msg) override { - + // get from shared memory } void SendPullResponse(Message &msg) override { - std::lock_guard lock(map_mu_); - auto key = msg.meta.key; - auto recver = msg.meta.recver; - auto len = std::get<0>(key_meta_map_[key][recver]); - - // IPC - auto addr = (void*) msg_buf->data[1].data(); - CHECK(addr); - void* shm_addr = GetSharedMemory(kShmPrefix, key); - // async copy - AsyncCopy m = {endpoint, msg_buf, shm_addr, addr, len, meta_len, false}; - auto cnt = cpy_counter_.fetch_add(1); - async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); + // std::lock_guard lock(map_mu_); + // auto key = msg.meta.key; + // auto recver = msg.meta.recver; + // auto len = std::get<0>(key_meta_map_[key][recver]); + + // // IPC + // auto addr = (void*) msg_buf->data[1].data(); + // CHECK(addr); + // void* shm_addr = GetSharedMemory(kShmPrefix, key); + // // async copy + // AsyncCopy m = {endpoint, msg_buf, shm_addr, addr, len, meta_len, false}; + // auto cnt = cpy_counter_.fetch_add(1); + // async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); } - void Recv(Message *msg) override { - - } }; // class IPCTransport - - }; // namespace ps #endif // DMLC_USE_RDMA From b0fccf4d35d46586643215422ea89d3c4d12e695 Mon Sep 17 00:00:00 2001 From: Yibo Zhu Date: Fri, 13 Dec 2019 20:12:13 +0800 Subject: [PATCH 14/79] refactor transport to pure virtual class --- src/rdma_utils.h | 4 +-- src/rdma_van.h | 2 +- src/transport.h | 63 +++++++++++++++++++++++++++--------------------- 3 files changed, 39 insertions(+), 30 deletions(-) diff --git a/src/rdma_utils.h b/src/rdma_utils.h index 5f28e9a9..000f8d9d 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -13,8 +13,8 @@ // limitations under the License. // ============================================================================= -#ifndef PS_RDMA_VAN_H_ -#define PS_RDMA_VAN_H_ +#ifndef PS_RDMA_UTILS_H_ +#define PS_RDMA_UTILS_H_ #ifdef DMLC_USE_RDMA diff --git a/src/rdma_van.h b/src/rdma_van.h index 94391813..f3314ecc 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -431,7 +431,7 @@ class RDMAVan : public Van { uint64_t data_num = buffer_ctx->data_num; cur += buffer_ctx->meta_len; - auto trans = endpoint.GetTransport(); + auto trans = endpoint->GetTransport(); trans->Recv(msg); delete buffer_ctx; diff --git a/src/transport.h b/src/transport.h index cd032f8c..0aba822a 100644 --- a/src/transport.h +++ b/src/transport.h @@ -13,8 +13,8 @@ // limitations under the License. // ============================================================================= -#ifndef PS_RDMA_VAN_H_ -#define PS_RDMA_VAN_H_ +#ifndef PS_RDMA_TRANSPORT_H_ +#define PS_RDMA_TRANSPORT_H_ #ifdef DMLC_USE_RDMA @@ -24,12 +24,30 @@ namespace ps { class Transport { public: - explicit Transport(Endpoint *endpoint, SimpleMempool *mempool) { + + virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; + virtual void Recv(Message *msg) = 0; + + virtual void SendPullRequest(MessageBuffer *msg_buf) = 0; + virtual void SendPushRequest(MessageBuffer *msg_buf) = 0; + virtual void SendPushResponse(MessageBuffer *msg_buf) = 0; + virtual void SendPullResponse(MessageBuffer *msg_buf) = 0; + + virtual bool HasRemoteInfo(MessageBuffer *msg_buf, uint64_t key, bool is_push) = 0; + virtual void StoreRemoteInfo(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; + virtual void SendRendezvousBegin(MessageBuffer *msg_buf) = 0; + +}; // class Transport + + +class RDMATransport : public Transport { + public: + explicit RDMATransport(Endpoint *endpoint, SimpleMempool *mempool) { endpoint_ = endpoint; mempool_ = mempool; }; - ~Transport(); + ~RDMATransport(); void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { // prepare memory @@ -165,25 +183,6 @@ class Transport { << strerror(errno); } - void Recv(Message *msg); - void SendPushRequest(MessageBuffer *msg_buf); - void SendPullResponse(MessageBuffer *msg_buf); - - private: - Endpoint *endpoint_; - SimpleMempool *mempool_; - std::unordered_map > push_addr_; // key, - std::unordered_map > pull_addr_; // key, - std::unordered_map msgbuf_cache_; // msg_buf, is_push - std::mutex mu_; - - std::unordered_map mem_mr_map_; - std::mutex mr_mu_; -}; // class Transport - - -class RDMATransport : public Transport { - public: void SendPushRequest(MessageBuffer *msg_buf) override { std::lock_guard lock(map_mu_); uint64_t key = DecodeKey(msg.data[0]); @@ -241,14 +240,24 @@ class RDMATransport : public Transport { << "ibv_post_send failed."; } - void Recv(Message *msg) override { - + void Recv(Message *msg) { + } + + protected: + Endpoint *endpoint_; + SimpleMempool *mempool_; + std::unordered_map > push_addr_; // key, + std::unordered_map > pull_addr_; // key, + std::unordered_map msgbuf_cache_; // msg_buf, is_push + std::mutex mu_; -}; // class RDMATransport + std::unordered_map mem_mr_map_; + std::mutex mr_mu_; +}; // class Transport -class IPCTransport : public Transport { +class IPCTransport : public RDMATransport { public: void SendPushRequest(Message &msg) override { From d871b30eec7a16e091be0ae7494449f463c7f9c5 Mon Sep 17 00:00:00 2001 From: Yibo Zhu Date: Fri, 13 Dec 2019 21:51:42 +0800 Subject: [PATCH 15/79] refactored most parts of IPCTransport --- src/rdma_utils.h | 33 ++++++- src/rdma_van.h | 237 +++-------------------------------------------- src/transport.h | 224 +++++++++++++++++++++++++++++++++++++++----- 3 files changed, 245 insertions(+), 249 deletions(-) diff --git a/src/rdma_utils.h b/src/rdma_utils.h index 000f8d9d..1313db16 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -203,6 +203,10 @@ class SimpleMempool { return mr->rkey; } + struct ibv_pd* GetPD() { + return pd_; + } + private: std::mutex mu_; std::multimap free_list; @@ -362,7 +366,7 @@ struct Endpoint { std::condition_variable cv; std::mutex connect_mu; struct rdma_cm_id *cm_id; - std::unique_ptr tran; + std::shared_ptr tran; WRContext rx_ctx[kRxDepth]; @@ -410,9 +414,9 @@ struct Endpoint { CHECK_EQ(rdma_destroy_id(cm_id), 0) << strerror(errno); } - void SetTransport(std::unique_ptr t) { tran = t; } + void SetTransport(std::shared_ptr t) { tran = t; } - std::unique_ptr GetTransport() { return tran; } + std::shared_ptr GetTransport() { return tran; } void Disconnect() { std::unique_lock lk(connect_mu); @@ -508,6 +512,29 @@ struct AsyncCopy { }; +bool IsValidPushpull(const Message &msg) { + if (!msg.meta.control.empty()) return false; + if (msg.meta.simple_app) return false; + return true; +} + +uint64_t DecodeKey(SArray keys) { // just a translation, the decoded key might not be readable when we have multiple servers + ps::Key key = 0; + uint64_t coef = 1; + for (unsigned int i = 0; i < keys.size(); ++i) { + key += coef * (uint8_t) keys.data()[i]; + coef *= 256; // 256=2^8 (uint8_t) + } + return key; +} + +uint64_t DecodeWorkerKey(uint64_t key) { + auto kr = ps::Postoffice::Get()->GetServerKeyRanges()[ps::Postoffice::Get()->my_rank()]; + return key - kr.begin(); +} + +int AlignTo(int input, int alignment) { return input / alignment * alignment; } + }; // namespace ps #endif // DMLC_USE_RDMA diff --git a/src/rdma_van.h b/src/rdma_van.h index 7a04fad5..d1973f39 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -34,41 +34,10 @@ class RDMAVan : public Van { start_mu_.lock(); should_stop_ = false; - auto val = Environment::Get()->find("DMLC_ROLE"); - std::string role(val); - is_server_ = role=="server"; - if (is_server_) LOG(INFO) << "This is server"; - else LOG(INFO) << "This is " << ((role=="worker") ? "worker" : "scheduler"); - - val = Environment::Get()->find("BYTEPS_ENABLE_IPC"); + auto val = Environment::Get()->find("BYTEPS_ENABLE_IPC"); disable_ipc_ = val ? !atoi(val) : true; if (disable_ipc_) LOG(INFO) << "Shared memory IPC has been disabled"; - val = Environment::Get()->find("BYTEPS_PARTITION_BYTES"); - byteps_partition_bytes_ = val ? atoi(val) : 4096000; - - val = Environment::Get()->find("BYTEPS_LOCAL_SIZE"); - auto byteps_local_size = val ? atoi(val) : 1; - byteps_partition_bytes_ = AlignTo(byteps_partition_bytes_, (8 * byteps_local_size)); - if (!disable_ipc_) { - CHECK(val) << "BYTEPS_LOCAL_SIZE not set"; - LOG(INFO) << "partition bytes set to " << byteps_partition_bytes_ << ", should be identical with byteps core"; - } - - val = Environment::Get()->find("BYTEPS_IPC_COPY_NUM_THREADS"); - ipc_copy_nthreads_ = val ? atoi(val) : 4; - if (!disable_ipc_) { - LOG(INFO) << "IPC async copy nthreads set to " << ipc_copy_nthreads_; - for (int i = 0; i < ipc_copy_nthreads_; ++i) { - auto q = new ThreadsafeQueue; - async_copy_queue_.push_back(q); - } - for (int i = 0; i < ipc_copy_nthreads_; ++i) { - auto t = new std::thread(&RDMAVan::AsyncCopyThread, this, i); - ipc_copy_thread_list_.push_back(t); - } - } - if (event_channel_ == nullptr) { event_channel_ = rdma_create_event_channel(); CHECK(event_channel_) << "Create RDMA event channel failed"; @@ -99,8 +68,6 @@ class RDMAVan : public Van { PS_VLOG(1) << "Clearing mempool."; mempool_.reset(); - for (auto& it : memory_mr_map_) ibv_dereg_mr(it.second); - PS_VLOG(1) << "Clearing endpoints."; incoming_.clear(); endpoints_.clear(); @@ -109,13 +76,6 @@ class RDMAVan : public Van { CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ"; CHECK(!ibv_destroy_comp_channel(comp_event_channel_)) << "Failed to destroy channel"; - - for (size_t i = 0; i < ipc_copy_thread_list_.size(); ++i) { - AsyncCopy m; - m.shutdown = true; - async_copy_queue_[i]->Push(m); - ipc_copy_thread_list_[i]->join(); - } // TODO: ibv_dealloc_pd sometimes complains resource busy, need to fix this // CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD: " << @@ -192,8 +152,9 @@ class RDMAVan : public Van { endpoint->SetNodeID(node.id); - Transport *t = is_local_[node.id] ? - std::make_unique(endpoint) : std::make_unique(endpoint); + std::shared_ptr t = is_local_[node.id] ? + std::make_shared(endpoint, mempool_.get()) : + std::make_shared(endpoint, mempool_.get()); endpoint->SetTransport(t); struct addrinfo *remote_addr; @@ -259,115 +220,12 @@ class RDMAVan : public Van { } } - bool IsValidPushpull(const Message &msg) { - if (!msg.meta.control.empty()) return false; - if (msg.meta.simple_app) return false; - return true; - } - - uint64_t DecodeKey(SArray keys) { // just a translation, the decoded key might not be readable when we have multiple servers - ps::Key key = 0; - uint64_t coef = 1; - for (unsigned int i = 0; i < keys.size(); ++i) { - key += coef * (uint8_t) keys.data()[i]; - coef *= 256; // 256=2^8 (uint8_t) - } - return key; - } - - uint64_t DecodeWorkerKey(uint64_t key) { - auto kr = ps::Postoffice::Get()->GetServerKeyRanges()[ps::Postoffice::Get()->my_rank()]; - return key - kr.begin(); - } - - void* GetSharedMemory(const std::string& prefix, uint64_t key) { - std::lock_guard lock(shm_mu_); - auto worker_key = DecodeWorkerKey(key); - auto seq_num = worker_key % (1 << 16); - auto base_key = worker_key - seq_num; - uint64_t offset = byteps_partition_bytes_ * seq_num; - if (key_shm_addr_.find(base_key) != key_shm_addr_.end()) { - return key_shm_addr_[base_key] + offset; - } - std::string shm_name(prefix); - shm_name += std::to_string(base_key); - int shm_fd = shm_open(shm_name.c_str(), O_RDWR, 0666); - CHECK_GE(shm_fd, 0) << "shm_open failed for " << shm_name; - - struct stat sb; - CHECK_EQ(0, fstat(shm_fd, &sb)) << strerror(errno); - auto total_shm_size = sb.st_size; - - void* base_ptr = mmap(0, total_shm_size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); - CHECK_NE(base_ptr, (void*) -1) << strerror(errno); - key_shm_addr_[base_key] = base_ptr; - - LOG(INFO) << "open Shared Memory: " << shm_name - << ", offset=" << offset - << ", (in bytes) size=" << total_shm_size; - return key_shm_addr_[base_key] + offset; - } - - void SendRendezvousBegin(Endpoint* endpoint, - uint64_t origin_addr, WRContext *context, MessageTypes msg_type) { - struct ibv_sge sge; - sge.addr = origin_addr; - sge.lkey = context->buffer->lkey; - sge.length = sizeof(RendezvousStart); - - struct ibv_send_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - - wr.wr_id = reinterpret_cast(context); - wr.opcode = IBV_WR_SEND_WITH_IMM; - wr.next = nullptr; - - wr.imm_data = msg_type; - - wr.send_flags = IBV_SEND_SIGNALED; - wr.sg_list = &sge; - wr.num_sge = 1; - CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) - << strerror(errno); - } - - void AsyncCopyThread(int i) { - auto& q = async_copy_queue_[i]; - while (true) { - AsyncCopy m; - q->WaitAndPop(&m); - if (m.shutdown) break; - if (m.len == 0) continue; - - // TODO: use parallel copy - CHECK(m.dst); - CHECK(m.src); - memcpy(m.dst, m.src, m.len); - - WRContext *context = nullptr, *reserved = nullptr; - m.endpoint->free_write_ctx.WaitAndPop(&reserved); - m.endpoint->free_start_ctx.WaitAndPop(&context); - - m.msg_buf->reserved_context = reserved; - RendezvousStart *req = - reinterpret_cast(context->buffer->addr); - req->meta_len = m.meta_len; - req->origin_addr = reinterpret_cast(m.msg_buf); - - auto addr = reinterpret_cast(req); - req->data_num = 0; - SendRendezvousBegin(m.endpoint, addr, context, kRendezvousStart); - } - } - int SendMsg(Message &msg) override { int remote_id = msg.meta.recver; CHECK_NE(remote_id, Meta::kEmpty); CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); Endpoint *endpoint = endpoints_[remote_id].get(); - CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); - Endpoint *endpoint = endpoints_[remote_id].get(); MessageBuffer *msg_buf = new MessageBuffer(); int meta_len = GetPackMetaLen(msg.meta); @@ -380,15 +238,15 @@ class RDMAVan : public Van { PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); msg_buf->data = msg.data; - auto trans = endpoint.GetTransport(); + auto trans = endpoint->GetTransport(); if (!IsValidPushpull(msg)) { - trans->SendRendezvousBegin(msg_buf); + trans->SendRendezvousBegin(msg, msg_buf); return total_len; } else { auto is_push = msg.meta.push; auto key = DecodeKey(msg.data[0]); if (!trans->HasRemoteInfo(msg_buf, key, is_push)) { - trans->SendRendezvousBegin(msg_buf); + trans->SendRendezvousBegin(msg, msg_buf); return total_len; } } @@ -396,16 +254,16 @@ class RDMAVan : public Van { // already know remote address, directly use RDMA-write if (msg.meta.push && msg.meta.request) { // worker, push request - trans->SendPushRequest(msg_buf); + trans->SendPushRequest(msg, msg_buf); } else if (msg.meta.push && !msg.meta.request) { // server, push response - trans->SendPushResponse(msg_buf); + trans->SendPushResponse(msg, msg_buf); } else if (!msg.meta.push && msg.meta.request) { // worker, pull request - trans->SendPullRequest(msg_buf); + trans->SendPullRequest(msg, msg_buf); } else if (!msg.meta.push && !msg.meta.request) { // server, pull response - trans->SendPullResponse(msg_buf); + trans->SendPullResponse(msg, msg_buf); } else { CHECK(0) << "unexpected message type"; } @@ -478,40 +336,6 @@ class RDMAVan : public Van { CHECK(0); } } - - void SendRendezvousReply(Endpoint* endpoint, BufferContext *buf_ctx, uint64_t origin_addr) { - WRContext *reply_ctx = nullptr; - endpoint->free_reply_ctx.WaitAndPop(&reply_ctx); - RendezvousReply *resp = - reinterpret_cast(reply_ctx->buffer->addr); - - char* buffer = buf_ctx->buffer; - resp->addr = reinterpret_cast(buffer); - resp->rkey = mempool_->RemoteKey(buffer); - resp->origin_addr = origin_addr; - resp->idx = addr_pool_.StoreAddress(buf_ctx); - - struct ibv_sge sge; - sge.addr = reinterpret_cast(resp); - sge.length = sizeof(RendezvousReply); - sge.lkey = reply_ctx->buffer->lkey; - - struct ibv_send_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - - wr.wr_id = reinterpret_cast(reply_ctx); - wr.opcode = IBV_WR_SEND_WITH_IMM; - wr.next = nullptr; - - wr.imm_data = kRendezvousReply; - - wr.send_flags = IBV_SEND_SIGNALED; - wr.sg_list = &sge; - wr.num_sge = 1; - - CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) - << "ibv_post_send failed."; - } void PollCQ() { // Pre-allocated work completions array used for polling @@ -561,22 +385,9 @@ class RDMAVan : public Van { // LOG(INFO) << "opcode: IBV_WC_RECV kRendezvousStart"; RendezvousStart *req = reinterpret_cast(mr->addr); - - BufferContext *buf_ctx = new BufferContext(); - uint64_t len = req->meta_len; - buf_ctx->meta_len = req->meta_len; - buf_ctx->data_num = req->data_num; - for (size_t i = 0; i < req->data_num; ++i) { - buf_ctx->data_len[i] = req->data_len[i]; - len += req->data_len[i]; - } - char *buffer = mempool_->Alloc(is_server_ ? len : req->meta_len); - CHECK(buffer) << len; - buf_ctx->buffer = buffer; - - SendRendezvousReply(endpoint, buf_ctx, req->origin_addr); + endpoint->GetTransport()->SendRendezvousReply(req, addr_pool_); } else if (imm == kRendezvousReply) { - auto trans = endpoint.GetTransport(); + auto trans = endpoint->GetTransport(); // LOG(INFO) << "opcode: IBV_WC_RECV kRendezvousReply"; RendezvousReply *resp = reinterpret_cast(mr->addr); @@ -776,8 +587,6 @@ class RDMAVan : public Van { LOG(INFO) << "OnDisconnected from Node " << endpoint->node_id; } - int AlignTo(int input, int alignment) { return input / alignment * alignment; } - AddressPool addr_pool_; std::unique_ptr mempool_; @@ -793,8 +602,6 @@ class RDMAVan : public Van { struct rdma_event_channel *event_channel_ = nullptr; struct ibv_context *context_ = nullptr; - std::unordered_map memory_mr_map_; - // ibverbs protection domain struct ibv_pd *pd_ = nullptr; // Completion event channel, to wait for work completions @@ -808,17 +615,9 @@ class RDMAVan : public Van { // Recv buffer queue ThreadsafeQueue> recv_buffers_; - // role is server or worker - bool is_server_; // RDMA logging info bool enable_rdma_log_; - std::mutex map_mu_; - // macros for key_meta_map - using MetaInfo = std::tuple; // len, addr, rkey - using SenderMeta = std::unordered_map; // sender as the key - // (key, sender) --> MetaInfo - std::unordered_map key_meta_map_; // a static address for the key std::unordered_map key_addr_map_; // a static address for the length @@ -829,16 +628,6 @@ class RDMAVan : public Van { std::mutex local_mu_; std::unordered_map is_local_; - std::mutex shm_mu_; - std::unordered_map key_shm_addr_; - - int byteps_partition_bytes_ = 4096000; - - int ipc_copy_nthreads_; - std::vector ipc_copy_thread_list_; - std::vector*> async_copy_queue_; - std::atomic cpy_counter_{0}; - }; // class RDMAVan }; // namespace ps diff --git a/src/transport.h b/src/transport.h index 0aba822a..5c0e4f2c 100644 --- a/src/transport.h +++ b/src/transport.h @@ -28,14 +28,15 @@ class Transport { virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; virtual void Recv(Message *msg) = 0; - virtual void SendPullRequest(MessageBuffer *msg_buf) = 0; - virtual void SendPushRequest(MessageBuffer *msg_buf) = 0; - virtual void SendPushResponse(MessageBuffer *msg_buf) = 0; - virtual void SendPullResponse(MessageBuffer *msg_buf) = 0; + virtual void SendPullRequest(Message &msg, MessageBuffer *msg_buf) = 0; + virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf) = 0; + virtual void SendPushResponse(Message &msg, MessageBuffer *msg_buf) = 0; + virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf) = 0; + virtual void SendRendezvousBegin(Message &msg, MessageBuffer *msg_buf) = 0; + virtual void SendRendezvousReply(RendezvousStart *req, AddressPool &pool) = 0; virtual bool HasRemoteInfo(MessageBuffer *msg_buf, uint64_t key, bool is_push) = 0; virtual void StoreRemoteInfo(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; - virtual void SendRendezvousBegin(MessageBuffer *msg_buf) = 0; }; // class Transport @@ -45,9 +46,16 @@ class RDMATransport : public Transport { explicit RDMATransport(Endpoint *endpoint, SimpleMempool *mempool) { endpoint_ = endpoint; mempool_ = mempool; + auto val = Environment::Get()->find("DMLC_ROLE"); + std::string role(val); + is_server_ = role=="server"; + if (is_server_) LOG(INFO) << "This is server"; + else LOG(INFO) << "This is " << ((role=="worker") ? "worker" : "scheduler"); }; - ~RDMATransport(); + ~RDMATransport() { + for (auto& it : mem_mr_map_) ibv_dereg_mr(it.second); + }; void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { // prepare memory @@ -57,7 +65,7 @@ class RDMATransport : public Transport { std::lock_guard lock(mr_mu_); if (mem_mr_map_.find(sa.data()) == mem_mr_map_.end()) { struct ibv_mr *temp_mr; - CHECK (temp_mr = ibv_reg_mr(pd_, sa.data(), sa.size(), + CHECK (temp_mr = ibv_reg_mr(mempool_->GetPD(), sa.data(), sa.size(), IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) << "Failed to register the memory region: " << strerror(errno) << ", sa.size()=" << sa.size(); @@ -107,7 +115,7 @@ class RDMATransport : public Transport { << "ibv_post_send failed."; } - void SendPushResponse(MessageBuffer *msg_buf) { + void SendPushResponse(Message &msg, MessageBuffer *msg_buf) { std::lock_guard lk(mu_); auto key = DecodeKey(msg_buf->data[0]); auto remote_addr = std::get<0>(push_addr_[key]); @@ -116,7 +124,7 @@ class RDMATransport : public Transport { RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); } - void SendPullRequest(MessageBuffer *msg_buf) { + void SendPullRequest(Message &msg, MessageBuffer *msg_buf) { std::lock_guard lk(mu_); auto key = DecodeKey(msg_buf->data[0]); auto remote_addr = std::get<0>(pull_addr_[key]); @@ -149,7 +157,7 @@ class RDMATransport : public Transport { msgbuf_cache_.erase(buf); } - void SendRendezvousBegin(MessageBuffer *msg_buf) { + void SendRendezvousBegin(Message &msg, MessageBuffer *msg_buf) { WRContext *context = nullptr, *reserved = nullptr; endpoint_->free_write_ctx.WaitAndPop(&reserved); endpoint_->free_start_ctx.WaitAndPop(&context); @@ -161,7 +169,7 @@ class RDMATransport : public Transport { req->origin_addr = reinterpret_cast(msg_buf); req->data_num = msg_buf->data.size(); for (size_t i = 0; i < req->data_num; ++i) { - req->data_len[i] = msg->data[i].size(); + req->data_len[i] = msg.data[i].size(); } struct ibv_sge sge; @@ -182,22 +190,68 @@ class RDMATransport : public Transport { CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) << strerror(errno); } - - void SendPushRequest(MessageBuffer *msg_buf) override { + + void SendRendezvousReply(RendezvousStart *req, AddressPool &pool) { + + BufferContext *buf_ctx = new BufferContext(); + uint64_t len = req->meta_len; + buf_ctx->meta_len = req->meta_len; + buf_ctx->data_num = req->data_num; + for (size_t i = 0; i < req->data_num; ++i) { + buf_ctx->data_len[i] = req->data_len[i]; + len += req->data_len[i]; + } + char *buffer = mempool_->Alloc(is_server_ ? len : req->meta_len); + CHECK(buffer) << len; + buf_ctx->buffer = buffer; + WRContext *reply_ctx = nullptr; + endpoint_->free_reply_ctx.WaitAndPop(&reply_ctx); + RendezvousReply *resp = + reinterpret_cast(reply_ctx->buffer->addr); + + char* buffer = buf_ctx->buffer; + resp->addr = reinterpret_cast(buffer); + resp->rkey = mempool_->RemoteKey(buffer); + resp->origin_addr = req->origin_addr; + resp->idx = pool.StoreAddress(buf_ctx); + + struct ibv_sge sge; + sge.addr = reinterpret_cast(resp); + sge.length = sizeof(RendezvousReply); + sge.lkey = reply_ctx->buffer->lkey; + + struct ibv_send_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + + wr.wr_id = reinterpret_cast(reply_ctx); + wr.opcode = IBV_WR_SEND_WITH_IMM; + wr.next = nullptr; + + wr.imm_data = kRendezvousReply; + + wr.send_flags = IBV_SEND_SIGNALED; + wr.sg_list = &sge; + wr.num_sge = 1; + + CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) + << "ibv_post_send failed."; + } + + virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf) { std::lock_guard lock(map_mu_); uint64_t key = DecodeKey(msg.data[0]); msg.meta.key = key; CHECK_EQ(msg.data.size(), 3) << msg.data.size(); - CHECK_NE(memory_mr_map_.find(msg.data[1].data()), memory_mr_map_.end()); + CHECK_NE(mem_mr_map_.find(msg.data[1].data()), mem_mr_map_.end()); auto& vals = msg.data[1]; msg.meta.addr = reinterpret_cast(vals.data()); // vals address msg.meta.val_len = vals.size(); - msg.meta.option = memory_mr_map_[vals.data()]->rkey; + msg.meta.option = mem_mr_map_[vals.data()]->rkey; } - void SendPullResponse(MessageBuffer *msg_buf) override { + virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf) { std::lock_guard lock(map_mu_); uint64_t key = msg.meta.key; auto recver = msg.meta.recver; @@ -215,8 +269,8 @@ class RDMATransport : public Transport { auto raddr = std::get<1>(key_meta_map_[key][recver]); auto rkey = std::get<2>(key_meta_map_[key][recver]); - auto temp_mr = memory_mr_map_.find(msg_buf->data[1].data()); - CHECK_NE(temp_mr, memory_mr_map_.end()); + auto temp_mr = mem_mr_map_.find(msg_buf->data[1].data()); + CHECK_NE(temp_mr, mem_mr_map_.end()); struct ibv_sge sge; sge.addr = reinterpret_cast(msg_buf->data[1].data()); @@ -236,35 +290,75 @@ class RDMATransport : public Transport { wr.wr.rdma.remote_addr = raddr; wr.wr.rdma.rkey = rkey; - CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) + CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) << "ibv_post_send failed."; } void Recv(Message *msg) { - + } protected: Endpoint *endpoint_; SimpleMempool *mempool_; + // role is server or worker + bool is_server_; std::unordered_map > push_addr_; // key, std::unordered_map > pull_addr_; // key, std::unordered_map msgbuf_cache_; // msg_buf, is_push std::mutex mu_; std::unordered_map mem_mr_map_; + std::mutex map_mu_; + // macros for key_meta_map + using MetaInfo = std::tuple; // len, addr, rkey + using SenderMeta = std::unordered_map; // sender as the key + // (key, sender) --> MetaInfo + std::unordered_map key_meta_map_; std::mutex mr_mu_; }; // class Transport + class IPCTransport : public RDMATransport { public: - void SendPushRequest(Message &msg) override { + explicit IPCTransport(Endpoint *endpoint, SimpleMempool *mempool) : RDMATransport(endpoint, mempool) { + auto val = Environment::Get()->find("BYTEPS_IPC_COPY_NUM_THREADS"); + ipc_copy_nthreads_ = val ? atoi(val) : 4; + LOG(INFO) << "IPC async copy nthreads set to " << ipc_copy_nthreads_; + for (int i = 0; i < ipc_copy_nthreads_; ++i) { + auto q = new ThreadsafeQueue; + async_copy_queue_.push_back(q); + } + for (int i = 0; i < ipc_copy_nthreads_; ++i) { + auto t = new std::thread(&IPCTransport::AsyncCopyThread, this, i); + ipc_copy_thread_list_.push_back(t); + } + val = Environment::Get()->find("BYTEPS_PARTITION_BYTES"); + byteps_partition_bytes_ = val ? atoi(val) : 4096000; + + val = Environment::Get()->find("BYTEPS_LOCAL_SIZE"); + auto byteps_local_size = val ? atoi(val) : 1; + byteps_partition_bytes_ = AlignTo(byteps_partition_bytes_, (8 * byteps_local_size)); + CHECK(val) << "BYTEPS_LOCAL_SIZE not set"; + LOG(INFO) << "partition bytes set to " << byteps_partition_bytes_ << ", should be identical with byteps core"; + }; + + ~IPCTransport() { + for (size_t i = 0; i < ipc_copy_thread_list_.size(); ++i) { + AsyncCopy m; + m.shutdown = true; + async_copy_queue_[i]->Push(m); + ipc_copy_thread_list_[i]->join(); + } + } + + void SendPushRequest(Message &msg, MessageBuffer *msg_buf) { // get from shared memory } - void SendPullResponse(Message &msg) override { + void SendPullResponse(Message &msg, MessageBuffer *msg_buf) { // std::lock_guard lock(map_mu_); // auto key = msg.meta.key; // auto recver = msg.meta.recver; @@ -280,6 +374,92 @@ class IPCTransport : public RDMATransport { // async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); } + void AsyncCopyThread(int i) { + auto& q = async_copy_queue_[i]; + while (true) { + AsyncCopy m; + q->WaitAndPop(&m); + if (m.shutdown) break; + if (m.len == 0) continue; + + // TODO: use parallel copy + CHECK(m.dst); + CHECK(m.src); + memcpy(m.dst, m.src, m.len); + + WRContext *context = nullptr, *reserved = nullptr; + m.endpoint->free_write_ctx.WaitAndPop(&reserved); + m.endpoint->free_start_ctx.WaitAndPop(&context); + + m.msg_buf->reserved_context = reserved; + RendezvousStart *req = + reinterpret_cast(context->buffer->addr); + req->meta_len = m.meta_len; + req->origin_addr = reinterpret_cast(m.msg_buf); + + auto addr = reinterpret_cast(req); + req->data_num = 0; + + struct ibv_sge sge; + sge.addr = reinterpret_cast(req); + sge.lkey = context->buffer->lkey; + sge.length = sizeof(RendezvousStart); + + struct ibv_send_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = reinterpret_cast(context); + wr.opcode = IBV_WR_SEND_WITH_IMM; + wr.next = nullptr; + wr.imm_data = kRendezvousStart; + wr.send_flags = IBV_SEND_SIGNALED; + wr.sg_list = &sge; + wr.num_sge = 1; + + CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) + << strerror(errno); + } + } + + private: + + void* GetSharedMemory(const std::string& prefix, uint64_t key) { + std::lock_guard lock(shm_mu_); + auto worker_key = DecodeWorkerKey(key); + auto seq_num = worker_key % (1 << 16); + auto base_key = worker_key - seq_num; + uint64_t offset = byteps_partition_bytes_ * seq_num; + if (key_shm_addr_.find(base_key) != key_shm_addr_.end()) { + return key_shm_addr_[base_key] + offset; + } + std::string shm_name(prefix); + shm_name += std::to_string(base_key); + int shm_fd = shm_open(shm_name.c_str(), O_RDWR, 0666); + CHECK_GE(shm_fd, 0) << "shm_open failed for " << shm_name; + + struct stat sb; + CHECK_EQ(0, fstat(shm_fd, &sb)) << strerror(errno); + auto total_shm_size = sb.st_size; + + void* base_ptr = mmap(0, total_shm_size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); + CHECK_NE(base_ptr, (void*) -1) << strerror(errno); + key_shm_addr_[base_key] = base_ptr; + + LOG(INFO) << "open Shared Memory: " << shm_name + << ", offset=" << offset + << ", (in bytes) size=" << total_shm_size; + return key_shm_addr_[base_key] + offset; + } + + int ipc_copy_nthreads_; + std::vector ipc_copy_thread_list_; + std::vector*> async_copy_queue_; + std::atomic cpy_counter_{0}; + + int byteps_partition_bytes_ = 4096000; + + std::mutex shm_mu_; + std::unordered_map key_shm_addr_; + }; // class IPCTransport From 0d8704c4db6510cbcc9b87a8e5a5a03a64c3d2f9 Mon Sep 17 00:00:00 2001 From: Yibo Zhu Date: Fri, 13 Dec 2019 21:57:22 +0800 Subject: [PATCH 16/79] rename transport.h to rdma_transport.h --- src/{transport.h => rdma_transport.h} | 0 src/rdma_van.h | 3 ++- 2 files changed, 2 insertions(+), 1 deletion(-) rename src/{transport.h => rdma_transport.h} (100%) diff --git a/src/transport.h b/src/rdma_transport.h similarity index 100% rename from src/transport.h rename to src/rdma_transport.h diff --git a/src/rdma_van.h b/src/rdma_van.h index d1973f39..0181c00b 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -18,7 +18,8 @@ #ifdef DMLC_USE_RDMA -#include "transport.h" +#include "rdma_utils.h" +#include "rdma_transport.h" namespace ps { From 60895b986ec68f0be13acce52246cd91852e6b80 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Sun, 15 Dec 2019 14:09:08 +0800 Subject: [PATCH 17/79] add RecvPushResponse and RecvPullRequest --- src/rdma_transport.h | 47 ++++++++++++++++++++++++++++++++++++++++++-- src/rdma_utils.h | 2 +- src/rdma_van.h | 30 +++++++++++++++++++++++----- 3 files changed, 71 insertions(+), 8 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 5c0e4f2c..43411e97 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -27,6 +27,10 @@ class Transport { virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; virtual void Recv(Message *msg) = 0; + virtual void RecvPushRequest(Message *msg, BufferContext *buffer_ctx) = 0; + virtual void RecvPullRequest(Message *msg, BufferContext *buffer_ctx) = 0; + virtual void RecvPushResponse(Message *msg, BufferContext *buffer_ctx) = 0; + virtual void RecvPullResponse(Message *msg, BufferContext *buffer_ctx) = 0; virtual void SendPullRequest(Message &msg, MessageBuffer *msg_buf) = 0; virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf) = 0; @@ -48,7 +52,7 @@ class RDMATransport : public Transport { mempool_ = mempool; auto val = Environment::Get()->find("DMLC_ROLE"); std::string role(val); - is_server_ = role=="server"; + is_server_ = (role=="server"); if (is_server_) LOG(INFO) << "This is server"; else LOG(INFO) << "This is " << ((role=="worker") ? "worker" : "scheduler"); }; @@ -294,8 +298,47 @@ class RDMATransport : public Transport { << "ibv_post_send failed."; } - void Recv(Message *msg) { + virtual int RecvPushResponse(Message *msg, BufferContext *buffer_ctx) { + return Recv(msg, buffer_ctx); + } + + virtual int RecvPullRequest(Message *msg, BufferContext *buffer_ctx) { + return Recv(msg, buffer_ctx); + } + + virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx) { + + } + + virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx) { + + } + + private: + int Recv(Message *msg, BufferContext *buffer_ctx) { + uint64_t data_num = buffer_ctx->data_num; + if (data_num == 0) { + mempool_->Free(buffer_ctx->buffer); + delete buffer_ctx; + return 0; + } + + int total_data_len = 0; + char *cur = buffer_ctx->buffer + buffer_ctx->meta_len; // offset + + Block *mem_block = new Block(mempool_.get(), buffer_ctx->buffer, data_num); + for (size_t i = 0; i < data_num; i++) { + uint32_t len = buffer_ctx->data_len[i]; + SArray data; + data.reset(cur, len, [mem_block](void *) { + mem_block->Release(); + }); // Defer the deletion of block_ref + msg->data.push_back(data); + cur += len; + total_data_len += len; + } + return total_data_len; } protected: diff --git a/src/rdma_utils.h b/src/rdma_utils.h index 1313db16..5fc691fb 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -125,7 +125,7 @@ class SimpleMempool { ~SimpleMempool() { std::lock_guard lk(mu_); - for(auto it = mr_list.begin(); it != mr_list.end(); it++){ + for(auto it = mr_list.begin(); it != mr_list.end(); it++) { CHECK_EQ(ibv_dereg_mr(it->second), 0); free(it->second->addr); } diff --git a/src/rdma_van.h b/src/rdma_van.h index 0181c00b..1a9b29e1 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -280,20 +280,40 @@ class RDMAVan : public Van { Endpoint *endpoint = std::get(notification); BufferContext *buffer_ctx = std::get(notification); - int total_len = 0; - msg->meta.recver = my_node_.id; msg->meta.sender = endpoint->node_id; char *cur = buffer_ctx->buffer; UnpackMeta(cur, buffer_ctx->meta_len, &msg->meta); + + int total_len = 0; total_len += buffer_ctx->meta_len; - uint64_t data_num = buffer_ctx->data_num; - cur += buffer_ctx->meta_len; auto trans = endpoint->GetTransport(); - trans->Recv(msg); + + if (!IsValidPushpull(*msg)) { + mempool_->Free(buffer_ctx->buffer); + delete buffer_ctx; + return total_len; + } + + // valid data message + if (msg->meta.push && msg->meta.request) { + // push request + total_len += trans->RecvPushRequest(msg, buffer_ctx); + } else if (!msg->meta.push && msg->meta.request) { + // pull request + total_len += trans->RecvPullRequest(msg, buffer_ctx); + } else if (msg->meta.push && !msg->meta.request) { + // push response + total_len += trans->RecvPushResponse(msg, buffer_ctx); + } else if (!msg->meta.push && !msg->meta.request) { + // pull response + total_len += trans->RecvPullResponse(msg, buffer_ctx); + } else { + CHECK(0) << "unknown msg type"; + } delete buffer_ctx; return total_len; From 9c0fa316381419ce92bae2af2ed9fb11caac4daa Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Sun, 15 Dec 2019 14:20:01 +0800 Subject: [PATCH 18/79] add RecvPushRequest and RecvPullResponse --- src/rdma_transport.h | 66 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 57 insertions(+), 9 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 43411e97..be25bff6 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -66,7 +66,7 @@ class RDMATransport : public Transport { for (auto& sa : msg_buf->data) { if (!sa.size()) continue; CHECK(sa.data()); - std::lock_guard lock(mr_mu_); + std::lock_guard lock(map_mu_); if (mem_mr_map_.find(sa.data()) == mem_mr_map_.end()) { struct ibv_mr *temp_mr; CHECK (temp_mr = ibv_reg_mr(mempool_->GetPD(), sa.data(), sa.size(), @@ -307,11 +307,53 @@ class RDMATransport : public Transport { } virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx) { - + int total_data_len = 0; + std::lock_guard lock(map_mu_); + auto key = msg->meta.key; + if (key_len_map_.find(key) == key_len_map_.end()) { + key_addr_map_[key] = (ps::Key) key; + key_len_map_[key] = (int) msg->meta.val_len; + } + CHECK_NE(key_len_map_.find(key), key_len_map_.end()) << key; + CHECK_NE(key_addr_map_.find(key), key_addr_map_.end()) << key; + + auto addr = msg->meta.addr; + + CHECK_NE(key_len_map_[key], 0) << msg->DebugString(); + + SArray keys; + SArray vals; + SArray lens; + + keys.reset(reinterpret_cast(&key_addr_map_[key]), sizeof(ps::Key), [](void *){}); + vals.reset(reinterpret_cast(addr), key_len_map_[key], [](void *){}); + lens.reset(reinterpret_cast(&key_len_map_[key]), sizeof(int), [](void *){}); + + msg->data.push_back(keys); + msg->data.push_back(vals); + msg->data.push_back(lens); + total_data_len += keys.size() + vals.size() + lens.size(); + + mempool_->Free(buffer_ctx->buffer); + return total_data_len; } virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx) { - + auto key = msg->meta.key; + auto len = msg->meta.val_len; + auto addr = msg->meta.addr; + auto rkey = msg->meta.option; + auto sender = msg->meta.sender; + + std::lock_guard lock(map_mu_); + if (key_meta_map_.find(key) == key_meta_map_.end() + || key_meta_map_[key].find(sender) == key_meta_map_[key].end()) { + key_meta_map_[key][sender] = std::make_tuple(len, addr, rkey); + } else { + CHECK_EQ(len, std::get<0>(key_meta_map_[key][sender])); + CHECK_EQ(addr, std::get<1>(key_meta_map_[key][sender])); + CHECK_EQ(rkey, std::get<2>(key_meta_map_[key][sender])); + } } private: @@ -346,19 +388,25 @@ class RDMATransport : public Transport { SimpleMempool *mempool_; // role is server or worker bool is_server_; + std::mutex mu_; std::unordered_map > push_addr_; // key, std::unordered_map > pull_addr_; // key, std::unordered_map msgbuf_cache_; // msg_buf, is_push - std::mutex mu_; - std::unordered_map mem_mr_map_; + // manage the following map std::mutex map_mu_; - // macros for key_meta_map + + // (memory, ibv_mr) + std::unordered_map mem_mr_map_; + + // store the static address for keys and lens + std::unordered_map key_addr_map_; + std::unordered_map key_len_map_; + using MetaInfo = std::tuple; // len, addr, rkey using SenderMeta = std::unordered_map; // sender as the key - // (key, sender) --> MetaInfo - std::unordered_map key_meta_map_; - std::mutex mr_mu_; + std::unordered_map key_meta_map_; // (key, sender) --> MetaInfo + }; // class Transport From 2c6e090b577e170d8465369b009de19e5ebbdace Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Sun, 15 Dec 2019 15:48:51 +0800 Subject: [PATCH 19/79] can compile --- src/rdma_transport.h | 280 ++++++++++++++++++++++++++++++++++--------- src/rdma_utils.h | 154 ------------------------ src/rdma_van.h | 8 +- 3 files changed, 232 insertions(+), 210 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index be25bff6..df627b7e 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -22,15 +22,173 @@ namespace ps { +class Transport; + +struct Endpoint { + enum ConnectionStatus { IDLE, CONNECTING, CONNECTED, REJECTED }; + + ConnectionStatus status; + int node_id; + std::condition_variable cv; + std::mutex connect_mu; + struct rdma_cm_id *cm_id; + std::shared_ptr tran; + + WRContext rx_ctx[kRxDepth]; + + WRContext start_ctx[kStartDepth]; + WRContext reply_ctx[kReplyDepth]; + WRContext write_ctx[kWriteDepth]; + + ThreadsafeQueue free_start_ctx; + ThreadsafeQueue free_reply_ctx; + ThreadsafeQueue free_write_ctx; + + Endpoint() : status(IDLE), node_id(Node::kEmpty), cm_id(nullptr), rx_ctx() {} + + ~Endpoint() { + for (int i = 0; i < kRxDepth; ++i) { + if (!(rx_ctx[i].buffer)) { + continue; + } + free(rx_ctx[i].buffer->addr); + CHECK_EQ(ibv_dereg_mr(rx_ctx[i].buffer), 0); + } + + for (int i = 0; i < kStartDepth; ++i) { + if (start_ctx[i].buffer) { + free(start_ctx[i].buffer->addr); + CHECK_EQ(ibv_dereg_mr(start_ctx[i].buffer), 0); + } + } + + for (int i = 0; i < kReplyDepth; ++i) { + if (reply_ctx[i].buffer) { + free(reply_ctx[i].buffer->addr); + CHECK_EQ(ibv_dereg_mr(reply_ctx[i].buffer), 0); + } + } + + for (int i = 0; i < kWriteDepth; ++i) { + if (write_ctx[i].buffer) { + free(write_ctx[i].buffer->addr); + CHECK_EQ(ibv_dereg_mr(write_ctx[i].buffer), 0); + } + } + + rdma_destroy_qp(cm_id); + CHECK_EQ(rdma_destroy_id(cm_id), 0) << strerror(errno); + } + + void SetTransport(std::shared_ptr t) { tran = t; } + + std::shared_ptr GetTransport() { return tran; } + + void Disconnect() { + std::unique_lock lk(connect_mu); + CHECK_EQ(rdma_disconnect(cm_id), 0) << strerror(errno); + cv.wait(lk, [this] { return status == IDLE; }); + tran.reset(); + } + + void SetNodeID(int id) { node_id = id; } + + void InitSendContextHelper(struct ibv_pd *pd, WRContext *ctx, + ThreadsafeQueue *queue, size_t num, + WRContextType type) { + for (size_t i = 0; i < num; ++i) { + void *buf; + ib_malloc((void**) &buf, kMempoolChunkSize); + CHECK(buf); + struct ibv_mr *mr = ibv_reg_mr(pd, buf, kMempoolChunkSize, 0); + CHECK(mr); + + ctx[i].type = type; + ctx[i].buffer = mr; + ctx[i].private_data = this; + queue->Push(&ctx[i]); + } + } + + void Init(struct ibv_cq *cq, struct ibv_pd *pd) { + struct ibv_qp_init_attr attr; + memset(&attr, 0, sizeof(ibv_qp_init_attr)); + attr.send_cq = cq; + attr.recv_cq = cq; + attr.cap.max_send_wr = kStartDepth + kReplyDepth + kWriteDepth; + attr.cap.max_recv_wr = kRxDepth; + attr.cap.max_send_sge = kSGEntry; + attr.cap.max_recv_sge = kSGEntry; + attr.qp_type = IBV_QPT_RC; + attr.sq_sig_all = 0; + + CHECK_EQ(rdma_create_qp(cm_id, pd, &attr), 0) + << "Create RDMA queue pair failed"; + + InitSendContextHelper(pd, start_ctx, &free_start_ctx, kStartDepth, + kRendezvousStartContext); + InitSendContextHelper(pd, reply_ctx, &free_reply_ctx, kReplyDepth, + kRendezvousReplyContext); + InitSendContextHelper(pd, write_ctx, &free_write_ctx, kWriteDepth, + kWriteContext); + + for (size_t i = 0; i < kRxDepth; ++i) { + void *buf; + ib_malloc((void**) &buf, kMempoolChunkSize); + CHECK(buf); + struct ibv_mr *mr = + ibv_reg_mr(pd, buf, kMempoolChunkSize, IBV_ACCESS_LOCAL_WRITE); + CHECK(mr); + + rx_ctx[i].type = kReceiveContext; + rx_ctx[i].buffer = mr; + rx_ctx[i].private_data = this; + + PostRecv(&rx_ctx[i]); + } + } + + void PostRecv(WRContext *ctx) { + struct ibv_recv_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + + struct ibv_sge sge; + sge.addr = reinterpret_cast(ctx->buffer->addr); + sge.length = kMempoolChunkSize; + sge.lkey = ctx->buffer->lkey; + + wr.wr_id = reinterpret_cast(ctx); + wr.next = nullptr; + wr.sg_list = &sge; + wr.num_sge = 1; + + CHECK_EQ(ibv_post_recv(cm_id->qp, &wr, &bad_wr), 0) + << "ibv_post_recv failed."; + } +}; + +struct AsyncCopy { + Endpoint* endpoint; + MessageBuffer* msg_buf; + void* dst; + void* src; + int len; + uint64_t meta_len; + bool shutdown; +}; + + class Transport { public: virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; - virtual void Recv(Message *msg) = 0; - virtual void RecvPushRequest(Message *msg, BufferContext *buffer_ctx) = 0; - virtual void RecvPullRequest(Message *msg, BufferContext *buffer_ctx) = 0; - virtual void RecvPushResponse(Message *msg, BufferContext *buffer_ctx) = 0; - virtual void RecvPullResponse(Message *msg, BufferContext *buffer_ctx) = 0; + virtual int Recv(Message *msg, BufferContext *buffer_ctx) = 0; + virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx) = 0; + virtual int RecvPullRequest(Message *msg, BufferContext *buffer_ctx) = 0; + virtual int RecvPushResponse(Message *msg, BufferContext *buffer_ctx) = 0; + virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx) = 0; + + virtual void AddMeta(Message &msg) = 0; virtual void SendPullRequest(Message &msg, MessageBuffer *msg_buf) = 0; virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf) = 0; @@ -119,26 +277,8 @@ class RDMATransport : public Transport { << "ibv_post_send failed."; } - void SendPushResponse(Message &msg, MessageBuffer *msg_buf) { - std::lock_guard lk(mu_); - auto key = DecodeKey(msg_buf->data[0]); - auto remote_addr = std::get<0>(push_addr_[key]); - auto rkey = std::get<1>(push_addr_[key]); - auto idx = std::get<2>(push_addr_[key]); - RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); - } - - void SendPullRequest(Message &msg, MessageBuffer *msg_buf) { - std::lock_guard lk(mu_); - auto key = DecodeKey(msg_buf->data[0]); - auto remote_addr = std::get<0>(pull_addr_[key]); - auto rkey = std::get<1>(pull_addr_[key]); - auto idx = std::get<2>(pull_addr_[key]); - RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); - } - bool HasRemoteInfo(MessageBuffer *msg_buf, uint64_t key, bool is_push) { - std::lock_guard lk(mu_); + std::lock_guard lk(addr_mu_); if ( is_push && (push_addr_.find(key) != push_addr_.end())) return true; if (!is_push && (pull_addr_.find(key) != pull_addr_.end())) return true; // no remote info, store the msg_buf address and push/pull flag for RendezvousReply @@ -151,7 +291,7 @@ class RDMATransport : public Transport { auto key = DecodeKey(msg_buf->data[0]); auto buf = reinterpret_cast(msg_buf); - std::lock_guard lk(mu_); + std::lock_guard lk(addr_mu_); auto is_push = msgbuf_cache_[buf]; if (is_push) { push_addr_.emplace(key, std::make_tuple(remote_addr, rkey, idx)); @@ -195,8 +335,7 @@ class RDMATransport : public Transport { << strerror(errno); } - void SendRendezvousReply(RendezvousStart *req, AddressPool &pool) { - + void SendRendezvousReply(RendezvousStart *req, AddressPool &addrpool) { BufferContext *buf_ctx = new BufferContext(); uint64_t len = req->meta_len; buf_ctx->meta_len = req->meta_len; @@ -210,14 +349,14 @@ class RDMATransport : public Transport { buf_ctx->buffer = buffer; WRContext *reply_ctx = nullptr; endpoint_->free_reply_ctx.WaitAndPop(&reply_ctx); + RendezvousReply *resp = reinterpret_cast(reply_ctx->buffer->addr); - char* buffer = buf_ctx->buffer; resp->addr = reinterpret_cast(buffer); resp->rkey = mempool_->RemoteKey(buffer); resp->origin_addr = req->origin_addr; - resp->idx = pool.StoreAddress(buf_ctx); + resp->idx = addrpool.StoreAddress(buf_ctx); struct ibv_sge sge; sge.addr = reinterpret_cast(resp); @@ -241,35 +380,65 @@ class RDMATransport : public Transport { << "ibv_post_send failed."; } - virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf) { - std::lock_guard lock(map_mu_); - uint64_t key = DecodeKey(msg.data[0]); - msg.meta.key = key; + void AddMeta(Message &msg) { + // should only be invoked when send + if (msg.meta.push && msg.meta.request) { + // push request + uint64_t key = DecodeKey(msg.data[0]); + msg.meta.key = key; + CHECK_EQ(msg.data.size(), 3) << msg.data.size(); - CHECK_EQ(msg.data.size(), 3) << msg.data.size(); - CHECK_NE(mem_mr_map_.find(msg.data[1].data()), mem_mr_map_.end()); + std::lock_guard lock(map_mu_); + CHECK_NE(mem_mr_map_.find(msg.data[1].data()), mem_mr_map_.end()); + + auto& vals = msg.data[1]; + msg.meta.addr = reinterpret_cast(vals.data()); // vals address + msg.meta.val_len = vals.size(); + msg.meta.option = mem_mr_map_[vals.data()]->rkey; + } else if (!msg.meta.push && !msg.meta.request) { + // pull response + uint64_t key = msg.meta.key; + auto recver = msg.meta.recver; - auto& vals = msg.data[1]; - msg.meta.addr = reinterpret_cast(vals.data()); // vals address - msg.meta.val_len = vals.size(); - msg.meta.option = mem_mr_map_[vals.data()]->rkey; + std::lock_guard lock(map_mu_); + CHECK_NE(key_meta_map_.find(key), key_meta_map_.end()) + << "key=" << key << " not inited in key_meta_map"; + CHECK_NE(key_meta_map_[key].find(recver), key_meta_map_[key].end()) + << "key=" << key << ", recver=" << recver << " not inited in key_meta_map[key]"; + + msg.meta.val_len = std::get<0>(key_meta_map_[key][recver]); + msg.meta.addr = std::get<1>(key_meta_map_[key][recver]); + msg.meta.option = std::get<2>(key_meta_map_[key][recver]); + } } - virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf) { - std::lock_guard lock(map_mu_); - uint64_t key = msg.meta.key; - auto recver = msg.meta.recver; + void SendPushResponse(Message &msg, MessageBuffer *msg_buf) { + auto key = DecodeKey(msg_buf->data[0]); + std::lock_guard lk(addr_mu_); + auto remote_addr = std::get<0>(push_addr_[key]); + auto rkey = std::get<1>(push_addr_[key]); + auto idx = std::get<2>(push_addr_[key]); + RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); + } - CHECK_NE(key_meta_map_.find(key), key_meta_map_.end()) - << "key=" << key << " not inited in key_meta_map"; - CHECK_NE(key_meta_map_[key].find(recver), key_meta_map_[key].end()) - << "key=" << key << ", recver=" << recver << " not inited in key_meta_map[key]"; + void SendPullRequest(Message &msg, MessageBuffer *msg_buf) { + auto key = DecodeKey(msg_buf->data[0]); + std::lock_guard lk(addr_mu_); + auto remote_addr = std::get<0>(pull_addr_[key]); + auto rkey = std::get<1>(pull_addr_[key]); + auto idx = std::get<2>(pull_addr_[key]); + RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); + } - msg.meta.val_len = std::get<0>(key_meta_map_[key][recver]); - msg.meta.addr = std::get<1>(key_meta_map_[key][recver]); - msg.meta.option = std::get<2>(key_meta_map_[key][recver]); + virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf) { + // RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); + } - // RDMA write + virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf) { + std::lock_guard lock(map_mu_); + auto key = msg.meta.key; + auto recver = msg.meta.recver; + auto len = std::get<0>(key_meta_map_[key][recver]); auto raddr = std::get<1>(key_meta_map_[key][recver]); auto rkey = std::get<2>(key_meta_map_[key][recver]); @@ -339,6 +508,8 @@ class RDMATransport : public Transport { } virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx) { + int total_data_len = Recv(msg, buffer_ctx); + auto key = msg->meta.key; auto len = msg->meta.val_len; auto addr = msg->meta.addr; @@ -354,10 +525,11 @@ class RDMATransport : public Transport { CHECK_EQ(addr, std::get<1>(key_meta_map_[key][sender])); CHECK_EQ(rkey, std::get<2>(key_meta_map_[key][sender])); } + return total_data_len; } private: - int Recv(Message *msg, BufferContext *buffer_ctx) { + virtual int Recv(Message *msg, BufferContext *buffer_ctx) { uint64_t data_num = buffer_ctx->data_num; if (data_num == 0) { mempool_->Free(buffer_ctx->buffer); @@ -368,7 +540,7 @@ class RDMATransport : public Transport { int total_data_len = 0; char *cur = buffer_ctx->buffer + buffer_ctx->meta_len; // offset - Block *mem_block = new Block(mempool_.get(), buffer_ctx->buffer, data_num); + Block *mem_block = new Block(mempool_, buffer_ctx->buffer, data_num); for (size_t i = 0; i < data_num; i++) { uint32_t len = buffer_ctx->data_len[i]; SArray data; @@ -388,7 +560,7 @@ class RDMATransport : public Transport { SimpleMempool *mempool_; // role is server or worker bool is_server_; - std::mutex mu_; + std::mutex addr_mu_; std::unordered_map > push_addr_; // key, std::unordered_map > pull_addr_; // key, std::unordered_map msgbuf_cache_; // msg_buf, is_push diff --git a/src/rdma_utils.h b/src/rdma_utils.h index 5fc691fb..03f4f08b 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -358,160 +358,6 @@ class AddressPool { T *table_[kMaxEntries]; }; -struct Endpoint { - enum ConnectionStatus { IDLE, CONNECTING, CONNECTED, REJECTED }; - - ConnectionStatus status; - int node_id; - std::condition_variable cv; - std::mutex connect_mu; - struct rdma_cm_id *cm_id; - std::shared_ptr tran; - - WRContext rx_ctx[kRxDepth]; - - WRContext start_ctx[kStartDepth]; - WRContext reply_ctx[kReplyDepth]; - WRContext write_ctx[kWriteDepth]; - - ThreadsafeQueue free_start_ctx; - ThreadsafeQueue free_reply_ctx; - ThreadsafeQueue free_write_ctx; - - Endpoint() : status(IDLE), node_id(Node::kEmpty), cm_id(nullptr), rx_ctx() {} - - ~Endpoint() { - for (int i = 0; i < kRxDepth; ++i) { - if (!(rx_ctx[i].buffer)) { - continue; - } - free(rx_ctx[i].buffer->addr); - CHECK_EQ(ibv_dereg_mr(rx_ctx[i].buffer), 0); - } - - for (int i = 0; i < kStartDepth; ++i) { - if (start_ctx[i].buffer) { - free(start_ctx[i].buffer->addr); - CHECK_EQ(ibv_dereg_mr(start_ctx[i].buffer), 0); - } - } - - for (int i = 0; i < kReplyDepth; ++i) { - if (reply_ctx[i].buffer) { - free(reply_ctx[i].buffer->addr); - CHECK_EQ(ibv_dereg_mr(reply_ctx[i].buffer), 0); - } - } - - for (int i = 0; i < kWriteDepth; ++i) { - if (write_ctx[i].buffer) { - free(write_ctx[i].buffer->addr); - CHECK_EQ(ibv_dereg_mr(write_ctx[i].buffer), 0); - } - } - - rdma_destroy_qp(cm_id); - CHECK_EQ(rdma_destroy_id(cm_id), 0) << strerror(errno); - } - - void SetTransport(std::shared_ptr t) { tran = t; } - - std::shared_ptr GetTransport() { return tran; } - - void Disconnect() { - std::unique_lock lk(connect_mu); - CHECK_EQ(rdma_disconnect(cm_id), 0) << strerror(errno); - cv.wait(lk, [this] { return status == IDLE; }); - tran.reset(); - } - - void SetNodeID(int id) { node_id = id; } - - void InitSendContextHelper(struct ibv_pd *pd, WRContext *ctx, - ThreadsafeQueue *queue, size_t num, - WRContextType type) { - for (size_t i = 0; i < num; ++i) { - void *buf; - ib_malloc((void**) &buf, kMempoolChunkSize); - CHECK(buf); - struct ibv_mr *mr = ibv_reg_mr(pd, buf, kMempoolChunkSize, 0); - CHECK(mr); - - ctx[i].type = type; - ctx[i].buffer = mr; - ctx[i].private_data = this; - queue->Push(&ctx[i]); - } - } - - void Init(struct ibv_cq *cq, struct ibv_pd *pd) { - struct ibv_qp_init_attr attr; - memset(&attr, 0, sizeof(ibv_qp_init_attr)); - attr.send_cq = cq; - attr.recv_cq = cq; - attr.cap.max_send_wr = kStartDepth + kReplyDepth + kWriteDepth; - attr.cap.max_recv_wr = kRxDepth; - attr.cap.max_send_sge = kSGEntry; - attr.cap.max_recv_sge = kSGEntry; - attr.qp_type = IBV_QPT_RC; - attr.sq_sig_all = 0; - - CHECK_EQ(rdma_create_qp(cm_id, pd, &attr), 0) - << "Create RDMA queue pair failed"; - - InitSendContextHelper(pd, start_ctx, &free_start_ctx, kStartDepth, - kRendezvousStartContext); - InitSendContextHelper(pd, reply_ctx, &free_reply_ctx, kReplyDepth, - kRendezvousReplyContext); - InitSendContextHelper(pd, write_ctx, &free_write_ctx, kWriteDepth, - kWriteContext); - - for (size_t i = 0; i < kRxDepth; ++i) { - void *buf; - ib_malloc((void**) &buf, kMempoolChunkSize); - CHECK(buf); - struct ibv_mr *mr = - ibv_reg_mr(pd, buf, kMempoolChunkSize, IBV_ACCESS_LOCAL_WRITE); - CHECK(mr); - - rx_ctx[i].type = kReceiveContext; - rx_ctx[i].buffer = mr; - rx_ctx[i].private_data = this; - - PostRecv(&rx_ctx[i]); - } - } - - void PostRecv(WRContext *ctx) { - struct ibv_recv_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - - struct ibv_sge sge; - sge.addr = reinterpret_cast(ctx->buffer->addr); - sge.length = kMempoolChunkSize; - sge.lkey = ctx->buffer->lkey; - - wr.wr_id = reinterpret_cast(ctx); - wr.next = nullptr; - wr.sg_list = &sge; - wr.num_sge = 1; - - CHECK_EQ(ibv_post_recv(cm_id->qp, &wr, &bad_wr), 0) - << "ibv_post_recv failed."; - } -}; - -struct AsyncCopy { - Endpoint* endpoint; - MessageBuffer* msg_buf; - void* dst; - void* src; - int len; - uint64_t meta_len; - bool shutdown; -}; - - bool IsValidPushpull(const Message &msg) { if (!msg.meta.control.empty()) return false; if (msg.meta.simple_app) return false; diff --git a/src/rdma_van.h b/src/rdma_van.h index 1a9b29e1..e3c1e328 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -227,6 +227,8 @@ class RDMAVan : public Van { CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); Endpoint *endpoint = endpoints_[remote_id].get(); + auto trans = endpoint->GetTransport(); + MessageBuffer *msg_buf = new MessageBuffer(); int meta_len = GetPackMetaLen(msg.meta); @@ -236,10 +238,12 @@ class RDMAVan : public Van { msg_buf->inline_len = total_len; msg_buf->inline_buf = mempool_->Alloc(total_len); - PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); msg_buf->data = msg.data; - auto trans = endpoint->GetTransport(); + if (IsValidPushpull(msg)) trans->AddMeta(msg); + + PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); + if (!IsValidPushpull(msg)) { trans->SendRendezvousBegin(msg, msg_buf); return total_len; From 834c9036876cc21e7a1d00b3e61f84b6a0425cbe Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Sun, 15 Dec 2019 17:15:04 +0800 Subject: [PATCH 20/79] wip: improve compile --- src/rdma_transport.h | 29 ++++++++++++++++++++--------- src/rdma_van.h | 16 ++++++++++------ 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index df627b7e..dbd1b7bb 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -32,7 +32,7 @@ struct Endpoint { std::condition_variable cv; std::mutex connect_mu; struct rdma_cm_id *cm_id; - std::shared_ptr tran; + std::shared_ptr trans; WRContext rx_ctx[kRxDepth]; @@ -80,15 +80,15 @@ struct Endpoint { CHECK_EQ(rdma_destroy_id(cm_id), 0) << strerror(errno); } - void SetTransport(std::shared_ptr t) { tran = t; } + void SetTransport(std::shared_ptr t) { trans = t; } - std::shared_ptr GetTransport() { return tran; } + std::shared_ptr GetTransport() { return trans; } void Disconnect() { std::unique_lock lk(connect_mu); CHECK_EQ(rdma_disconnect(cm_id), 0) << strerror(errno); cv.wait(lk, [this] { return status == IDLE; }); - tran.reset(); + trans.reset(); } void SetNodeID(int id) { node_id = id; } @@ -189,6 +189,9 @@ class Transport { virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx) = 0; virtual void AddMeta(Message &msg) = 0; + virtual void RegisterMemory(Message &msg) = 0; + virtual void PrepareData(MessageBuffer *msg_buf) = 0; + virtual void SendPullRequest(Message &msg, MessageBuffer *msg_buf) = 0; virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf) = 0; @@ -219,11 +222,9 @@ class RDMATransport : public Transport { for (auto& it : mem_mr_map_) ibv_dereg_mr(it.second); }; - void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { - // prepare memory - for (auto& sa : msg_buf->data) { - if (!sa.size()) continue; - CHECK(sa.data()); + void RegisterMemory(Message &msg) { + for (auto& sa : msg.data) { + if (sa.size() == 0) continue; std::lock_guard lock(map_mu_); if (mem_mr_map_.find(sa.data()) == mem_mr_map_.end()) { struct ibv_mr *temp_mr; @@ -233,11 +234,21 @@ class RDMATransport : public Transport { << ", sa.size()=" << sa.size(); mem_mr_map_[sa.data()] = temp_mr; } + } + } + + void PrepareData(MessageBuffer *msg_buf) { + for (auto &sa : msg_buf->data) { + std::lock_guard lock(map_mu_); + if (sa.size() == 0) continue; auto it = mem_mr_map_.find(sa.data()); MRPtr ptr(it->second, [](struct ibv_mr *mr) {}); CHECK(ptr.get()) << strerror(errno); msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size())); } + } + + void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { // prepare RDMA write sge list struct ibv_sge sge[1 + msg_buf->mrs.size()]; diff --git a/src/rdma_van.h b/src/rdma_van.h index e3c1e328..a6e13140 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -227,7 +227,8 @@ class RDMAVan : public Van { CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); Endpoint *endpoint = endpoints_[remote_id].get(); - auto trans = endpoint->GetTransport(); + auto trans = CHECK_NOTNULL(endpoint->GetTransport()); + trans->RegisterMemory(msg); MessageBuffer *msg_buf = new MessageBuffer(); @@ -244,6 +245,8 @@ class RDMAVan : public Van { PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); + trans->PrepareData(msg_buf); + if (!IsValidPushpull(msg)) { trans->SendRendezvousBegin(msg, msg_buf); return total_len; @@ -294,7 +297,7 @@ class RDMAVan : public Van { int total_len = 0; total_len += buffer_ctx->meta_len; - auto trans = endpoint->GetTransport(); + auto trans = CHECK_NOTNULL(endpoint->GetTransport()); if (!IsValidPushpull(*msg)) { mempool_->Free(buffer_ctx->buffer); @@ -407,13 +410,13 @@ class RDMAVan : public Van { struct ibv_mr *mr = context->buffer; if (imm == kRendezvousStart) { - // LOG(INFO) << "opcode: IBV_WC_RECV kRendezvousStart"; RendezvousStart *req = reinterpret_cast(mr->addr); - endpoint->GetTransport()->SendRendezvousReply(req, addr_pool_); + auto trans = CHECK_NOTNULL(endpoint->GetTransport()); + trans->SendRendezvousReply(req, addr_pool_); + } else if (imm == kRendezvousReply) { - auto trans = endpoint->GetTransport(); - // LOG(INFO) << "opcode: IBV_WC_RECV kRendezvousReply"; + auto trans = CHECK_NOTNULL(endpoint->GetTransport()); RendezvousReply *resp = reinterpret_cast(mr->addr); uint64_t remote_addr = resp->addr; @@ -428,6 +431,7 @@ class RDMAVan : public Van { // subsequent write does not need repeated rendezvous trans->StoreRemoteInfo(msg_buf, remote_addr, rkey, idx); trans->RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); + } else { CHECK(0); } From 174d59aaeccbb316f6956566843736797afd7241 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Sun, 15 Dec 2019 18:07:29 +0800 Subject: [PATCH 21/79] fix GetTransport nullptr problem and add check --- src/rdma_transport.h | 4 ++-- src/rdma_van.h | 20 ++++++++++++++------ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index dbd1b7bb..89e177ac 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -209,8 +209,8 @@ class Transport { class RDMATransport : public Transport { public: explicit RDMATransport(Endpoint *endpoint, SimpleMempool *mempool) { - endpoint_ = endpoint; - mempool_ = mempool; + endpoint_ = CHECK_NOTNULL(endpoint); + mempool_ = CHECK_NOTNULL(mempool); auto val = Environment::Get()->find("DMLC_ROLE"); std::string role(val); is_server_ = (role=="server"); diff --git a/src/rdma_van.h b/src/rdma_van.h index a6e13140..35db2b08 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -153,11 +153,6 @@ class RDMAVan : public Van { endpoint->SetNodeID(node.id); - std::shared_ptr t = is_local_[node.id] ? - std::make_shared(endpoint, mempool_.get()) : - std::make_shared(endpoint, mempool_.get()); - endpoint->SetTransport(t); - struct addrinfo *remote_addr; CHECK_EQ( getaddrinfo(node.hostname.c_str(), std::to_string(node.port).c_str(), @@ -217,6 +212,13 @@ class RDMAVan : public Van { std::this_thread::sleep_for(std::chrono::milliseconds(500)); } + if (node.id != my_node_.id) { + std::shared_ptr t = is_local_[node.id] ? + std::make_shared(endpoint, mempool_.get()) : + std::make_shared(endpoint, mempool_.get()); + endpoint->SetTransport(t); + } + freeaddrinfo(remote_addr); } } @@ -335,6 +337,7 @@ class RDMAVan : public Van { CHECK(pd_) << "Failed to allocate protection domain"; mempool_.reset(new SimpleMempool(pd_)); + LOG(INFO) << "mempool_ inited =========="; comp_event_channel_ = ibv_create_comp_channel(context_); @@ -534,6 +537,11 @@ class RDMAVan : public Van { endpoint->Init(cq_, pd_); + std::shared_ptr t = is_local_[remote_ctx->node] ? + std::make_shared(endpoint, mempool_.get()) : + std::make_shared(endpoint, mempool_.get()); + endpoint->SetTransport(t); + RequestContext ctx; ctx.node = static_cast(my_node_.id); ctx.port = static_cast(my_node_.port); @@ -561,7 +569,7 @@ class RDMAVan : public Van { void OnRouteResolved(struct rdma_cm_event *event) { struct rdma_cm_id *id = event->id; Endpoint *endpoint = reinterpret_cast(id->context); - + if (context_ == nullptr) { InitContext(id->verbs); } From 84ee377792f587fb470d42ace8885823d6aead26 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Sun, 15 Dec 2019 18:46:57 +0800 Subject: [PATCH 22/79] fix connection & attempt to fallback original functionalities --- src/rdma_transport.h | 7 +------ src/rdma_van.h | 37 ++++++++++++++++++------------------- 2 files changed, 19 insertions(+), 25 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 89e177ac..b7f54e6b 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -180,7 +180,6 @@ struct AsyncCopy { class Transport { public: - virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; virtual int Recv(Message *msg, BufferContext *buffer_ctx) = 0; virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx) = 0; @@ -192,7 +191,6 @@ class Transport { virtual void RegisterMemory(Message &msg) = 0; virtual void PrepareData(MessageBuffer *msg_buf) = 0; - virtual void SendPullRequest(Message &msg, MessageBuffer *msg_buf) = 0; virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf) = 0; virtual void SendPushResponse(Message &msg, MessageBuffer *msg_buf) = 0; @@ -214,8 +212,6 @@ class RDMATransport : public Transport { auto val = Environment::Get()->find("DMLC_ROLE"); std::string role(val); is_server_ = (role=="server"); - if (is_server_) LOG(INFO) << "This is server"; - else LOG(INFO) << "This is " << ((role=="worker") ? "worker" : "scheduler"); }; ~RDMATransport() { @@ -239,8 +235,8 @@ class RDMATransport : public Transport { void PrepareData(MessageBuffer *msg_buf) { for (auto &sa : msg_buf->data) { - std::lock_guard lock(map_mu_); if (sa.size() == 0) continue; + std::lock_guard lock(map_mu_); auto it = mem_mr_map_.find(sa.data()); MRPtr ptr(it->second, [](struct ibv_mr *mr) {}); CHECK(ptr.get()) << strerror(errno); @@ -249,7 +245,6 @@ class RDMATransport : public Transport { } void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { - // prepare RDMA write sge list struct ibv_sge sge[1 + msg_buf->mrs.size()]; sge[0].addr = reinterpret_cast(msg_buf->inline_buf); diff --git a/src/rdma_van.h b/src/rdma_van.h index 35db2b08..912bcf2e 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -212,12 +212,10 @@ class RDMAVan : public Van { std::this_thread::sleep_for(std::chrono::milliseconds(500)); } - if (node.id != my_node_.id) { - std::shared_ptr t = is_local_[node.id] ? - std::make_shared(endpoint, mempool_.get()) : - std::make_shared(endpoint, mempool_.get()); - endpoint->SetTransport(t); - } + std::shared_ptr t = is_local_[node.id] ? + std::make_shared(endpoint, mempool_.get()) : + std::make_shared(endpoint, mempool_.get()); + endpoint->SetTransport(t); freeaddrinfo(remote_addr); } @@ -249,17 +247,19 @@ class RDMAVan : public Van { trans->PrepareData(msg_buf); - if (!IsValidPushpull(msg)) { - trans->SendRendezvousBegin(msg, msg_buf); - return total_len; - } else { - auto is_push = msg.meta.push; - auto key = DecodeKey(msg.data[0]); - if (!trans->HasRemoteInfo(msg_buf, key, is_push)) { - trans->SendRendezvousBegin(msg, msg_buf); - return total_len; - } - } + // if (!IsValidPushpull(msg)) { + // trans->SendRendezvousBegin(msg, msg_buf); + // return total_len; + // } else { + // auto is_push = msg.meta.push; + // auto key = DecodeKey(msg.data[0]); + // if (!trans->HasRemoteInfo(msg_buf, key, is_push)) { + // trans->SendRendezvousBegin(msg, msg_buf); + // return total_len; + // } + // } + trans->SendRendezvousBegin(msg, msg_buf); + return total_len; // already know remote address, directly use RDMA-write if (msg.meta.push && msg.meta.request) { @@ -337,7 +337,6 @@ class RDMAVan : public Van { CHECK(pd_) << "Failed to allocate protection domain"; mempool_.reset(new SimpleMempool(pd_)); - LOG(INFO) << "mempool_ inited =========="; comp_event_channel_ = ibv_create_comp_channel(context_); @@ -432,7 +431,7 @@ class RDMAVan : public Van { // Before RDMA write, store the remote info so that // subsequent write does not need repeated rendezvous - trans->StoreRemoteInfo(msg_buf, remote_addr, rkey, idx); + // trans->StoreRemoteInfo(msg_buf, remote_addr, rkey, idx); trans->RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); } else { From ed1f4ec1f508ee72b82d8dd017aa618e83631fad Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Sun, 15 Dec 2019 22:01:36 +0800 Subject: [PATCH 23/79] wip: do not free memory in mempool --- src/rdma_transport.h | 6 +++--- src/rdma_utils.h | 16 +++++++++++++--- src/rdma_van.h | 31 ++++++++++++++----------------- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index b7f54e6b..ba948ad3 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -387,11 +387,11 @@ class RDMATransport : public Transport { } void AddMeta(Message &msg) { - // should only be invoked when send + if (msg.meta.request) { + msg.meta.key = DecodeKey(msg.data[0]); + } if (msg.meta.push && msg.meta.request) { // push request - uint64_t key = DecodeKey(msg.data[0]); - msg.meta.key = key; CHECK_EQ(msg.data.size(), 3) << msg.data.size(); std::lock_guard lock(map_mu_); diff --git a/src/rdma_utils.h b/src/rdma_utils.h index 03f4f08b..ecaf3e3d 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -320,7 +320,10 @@ template class AddressPool { public: AddressPool() { + auto addrpool_size = Environment::Get()->find("BYTEPS_ADDRESS_POOL_SIZE"); + kMaxEntries = addrpool_size ? atoi(addrpool_size) : kMaxEntries; std::lock_guard lk(mu_); + table_ = new T*[kMaxEntries]; // init the queue for (int i = 0; i < kMaxEntries; i++) { indices_.push(i); @@ -336,13 +339,20 @@ class AddressPool { table_[index] = nullptr; return ptr; } + + // TODO: make the address pool size dynamic + T *GetAddress(uint32_t index) { + std::lock_guard lk(mu_); + return CHECK_NOTNULL(table_[index]); + } uint32_t StoreAddress(T *ptr) { std::lock_guard lk(mu_); CHECK(ptr); CHECK(!indices_.empty()) << "Address pool size is too small, " - << "consider increasing kMaxEntries"; + << "current size is " << kMaxEntries + << ", consider increasing BYTEPS_ADDRESS_POOL_SIZE"; uint32_t idx = indices_.front(); indices_.pop(); CHECK_EQ(table_[idx], nullptr) << idx; @@ -351,11 +361,11 @@ class AddressPool { } private: - static const int kMaxEntries = 5120; + int kMaxEntries = 10240; std::mutex mu_; std::queue indices_; - T *table_[kMaxEntries]; + T **table_; }; bool IsValidPushpull(const Message &msg) { diff --git a/src/rdma_van.h b/src/rdma_van.h index 912bcf2e..9a23acb5 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -247,19 +247,17 @@ class RDMAVan : public Van { trans->PrepareData(msg_buf); - // if (!IsValidPushpull(msg)) { - // trans->SendRendezvousBegin(msg, msg_buf); - // return total_len; - // } else { - // auto is_push = msg.meta.push; - // auto key = DecodeKey(msg.data[0]); - // if (!trans->HasRemoteInfo(msg_buf, key, is_push)) { - // trans->SendRendezvousBegin(msg, msg_buf); - // return total_len; - // } - // } - trans->SendRendezvousBegin(msg, msg_buf); - return total_len; + if (!IsValidPushpull(msg)) { + trans->SendRendezvousBegin(msg, msg_buf); + return total_len; + } else { + auto is_push = msg.meta.push; + auto key = msg.meta.key; + if (!trans->HasRemoteInfo(msg_buf, key, is_push)) { + trans->SendRendezvousBegin(msg, msg_buf); + return total_len; + } + } // already know remote address, directly use RDMA-write if (msg.meta.push && msg.meta.request) { @@ -400,11 +398,10 @@ class RDMAVan : public Van { ReleaseWorkRequestContext(context, endpoint); } break; case IBV_WC_RECV_RDMA_WITH_IMM: { - // LOG(INFO) << "opcode: IBV_WC_RECV_RDMA_WITH_IMM"; uint32_t addr_idx = wc[i].imm_data; - BufferContext *buf_ctx = addr_pool_.GetAddressAndRelease(addr_idx); + BufferContext *buf_ctx = addr_pool_.GetAddress(addr_idx); recv_buffers_.Push(std::make_tuple(endpoint, buf_ctx)); - ReleaseWorkRequestContext(context, endpoint); + // ReleaseWorkRequestContext(context, endpoint); } break; case IBV_WC_RECV: { CHECK(wc[i].wc_flags & IBV_WC_WITH_IMM); @@ -431,7 +428,7 @@ class RDMAVan : public Van { // Before RDMA write, store the remote info so that // subsequent write does not need repeated rendezvous - // trans->StoreRemoteInfo(msg_buf, remote_addr, rkey, idx); + trans->StoreRemoteInfo(msg_buf, remote_addr, rkey, idx); trans->RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); } else { From a6d5cda8353bb5c6910e3cbba22ad76e42b58d76 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Mon, 16 Dec 2019 15:44:52 +0800 Subject: [PATCH 24/79] add a max bound for meta size --- src/rdma_transport.h | 13 ++++++++----- src/rdma_utils.h | 3 +++ src/rdma_van.h | 12 ++++++++---- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index ba948ad3..c0ccb33c 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -343,15 +343,18 @@ class RDMATransport : public Transport { void SendRendezvousReply(RendezvousStart *req, AddressPool &addrpool) { BufferContext *buf_ctx = new BufferContext(); - uint64_t len = req->meta_len; buf_ctx->meta_len = req->meta_len; buf_ctx->data_num = req->data_num; + + uint64_t data_len = 0; for (size_t i = 0; i < req->data_num; ++i) { buf_ctx->data_len[i] = req->data_len[i]; - len += req->data_len[i]; + data_len += req->data_len[i]; } - char *buffer = mempool_->Alloc(is_server_ ? len : req->meta_len); - CHECK(buffer) << len; + + // worker only needs a buffer for receving meta + char *buffer = mempool_->Alloc(is_server_ ? (kMaxMetaBound + data_len) : kMaxMetaBound); + CHECK(buffer); buf_ctx->buffer = buffer; WRContext *reply_ctx = nullptr; endpoint_->free_reply_ctx.WaitAndPop(&reply_ctx); @@ -544,7 +547,7 @@ class RDMATransport : public Transport { } int total_data_len = 0; - char *cur = buffer_ctx->buffer + buffer_ctx->meta_len; // offset + char *cur = buffer_ctx->buffer + kMaxMetaBound; // offset Block *mem_block = new Block(mempool_, buffer_ctx->buffer, data_num); for (size_t i = 0; i < data_num; i++) { diff --git a/src/rdma_utils.h b/src/rdma_utils.h index ecaf3e3d..3cb57c1f 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -72,6 +72,9 @@ static const size_t kAlignment = 8; static const int kMaxResolveRetry = 50000; static const int kBasePort = 9010; +// allocate 4KB more for meta with potentially variable length +static const int kMaxMetaBound = 4096000; + // should have the same prefix with BytePS shared memory static const std::string kShmPrefix("BytePS_ShM_"); diff --git a/src/rdma_van.h b/src/rdma_van.h index 9a23acb5..27974671 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -233,6 +233,8 @@ class RDMAVan : public Van { MessageBuffer *msg_buf = new MessageBuffer(); int meta_len = GetPackMetaLen(msg.meta); + CHECK_LE(meta_len, kMaxMetaBound) << meta_len; + size_t data_len = msg.meta.data_size; size_t total_len = meta_len + data_len; CHECK(meta_len); @@ -290,12 +292,14 @@ class RDMAVan : public Van { msg->meta.recver = my_node_.id; msg->meta.sender = endpoint->node_id; - char *cur = buffer_ctx->buffer; - - UnpackMeta(cur, buffer_ctx->meta_len, &msg->meta); + // the second argument is actually deprecated, + // we keep it as is in order to be compatible + UnpackMeta(buffer_ctx->buffer, buffer_ctx->meta_len, &msg->meta); + int real_meta_len = GetPackMetaLen(msg->meta); + CHECK_LE(real_meta_len, kMaxMetaBound) << real_meta_len; int total_len = 0; - total_len += buffer_ctx->meta_len; + total_len += real_meta_len; auto trans = CHECK_NOTNULL(endpoint->GetTransport()); From 5e78f301ac09d3049ec57da507c4575fafde9613 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Mon, 16 Dec 2019 17:30:19 +0800 Subject: [PATCH 25/79] basically finish RDMATransport --- src/rdma_transport.h | 64 +++++++++++++++++----------------- src/rdma_van.h | 33 ++++++++++-------- tests/test_kv_app_benchmark.cc | 5 +-- 3 files changed, 53 insertions(+), 49 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index c0ccb33c..77bfeda9 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -181,16 +181,17 @@ struct AsyncCopy { class Transport { public: virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; - virtual int Recv(Message *msg, BufferContext *buffer_ctx) = 0; - virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx) = 0; - virtual int RecvPullRequest(Message *msg, BufferContext *buffer_ctx) = 0; - virtual int RecvPushResponse(Message *msg, BufferContext *buffer_ctx) = 0; - virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx) = 0; + virtual int Recv(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; + virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; + virtual int RecvPullRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; + virtual int RecvPushResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; + virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; virtual void AddMeta(Message &msg) = 0; virtual void RegisterMemory(Message &msg) = 0; virtual void PrepareData(MessageBuffer *msg_buf) = 0; + virtual void Send(Message &msg, MessageBuffer *msg_buf, bool is_push) = 0; virtual void SendPullRequest(Message &msg, MessageBuffer *msg_buf) = 0; virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf) = 0; virtual void SendPushResponse(Message &msg, MessageBuffer *msg_buf) = 0; @@ -346,14 +347,14 @@ class RDMATransport : public Transport { buf_ctx->meta_len = req->meta_len; buf_ctx->data_num = req->data_num; - uint64_t data_len = 0; + uint64_t len = req->meta_len; for (size_t i = 0; i < req->data_num; ++i) { buf_ctx->data_len[i] = req->data_len[i]; - data_len += req->data_len[i]; + len += req->data_len[i]; } // worker only needs a buffer for receving meta - char *buffer = mempool_->Alloc(is_server_ ? (kMaxMetaBound + data_len) : kMaxMetaBound); + char *buffer = mempool_->Alloc(is_server_ ? (kMaxMetaBound + len) : (kMaxMetaBound + req->meta_len)); CHECK(buffer); buf_ctx->buffer = buffer; WRContext *reply_ctx = nullptr; @@ -421,26 +422,25 @@ class RDMATransport : public Transport { } } - void SendPushResponse(Message &msg, MessageBuffer *msg_buf) { + void Send(Message &msg, MessageBuffer *msg_buf, bool is_push) { auto key = DecodeKey(msg_buf->data[0]); std::lock_guard lk(addr_mu_); - auto remote_addr = std::get<0>(push_addr_[key]); - auto rkey = std::get<1>(push_addr_[key]); - auto idx = std::get<2>(push_addr_[key]); + auto remote_addr = is_push ? std::get<0>(push_addr_[key]) : std::get<0>(pull_addr_[key]); + auto rkey = is_push ? std::get<1>(push_addr_[key]) : std::get<1>(pull_addr_[key]); + auto idx = is_push ? std::get<2>(push_addr_[key]) : std::get<2>(pull_addr_[key]); RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); } + void SendPushResponse(Message &msg, MessageBuffer *msg_buf) { + Send(msg, msg_buf, true); + } + void SendPullRequest(Message &msg, MessageBuffer *msg_buf) { - auto key = DecodeKey(msg_buf->data[0]); - std::lock_guard lk(addr_mu_); - auto remote_addr = std::get<0>(pull_addr_[key]); - auto rkey = std::get<1>(pull_addr_[key]); - auto idx = std::get<2>(pull_addr_[key]); - RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); + Send(msg, msg_buf, false); } virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf) { - // RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); + Send(msg, msg_buf, true); } virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf) { @@ -476,15 +476,15 @@ class RDMATransport : public Transport { << "ibv_post_send failed."; } - virtual int RecvPushResponse(Message *msg, BufferContext *buffer_ctx) { - return Recv(msg, buffer_ctx); + virtual int RecvPushResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) { + return Recv(msg, buffer_ctx, meta_len); } - virtual int RecvPullRequest(Message *msg, BufferContext *buffer_ctx) { - return Recv(msg, buffer_ctx); + virtual int RecvPullRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) { + return Recv(msg, buffer_ctx, meta_len); } - virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx) { + virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) { int total_data_len = 0; std::lock_guard lock(map_mu_); auto key = msg->meta.key; @@ -516,8 +516,8 @@ class RDMATransport : public Transport { return total_data_len; } - virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx) { - int total_data_len = Recv(msg, buffer_ctx); + virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) { + int total_data_len = Recv(msg, buffer_ctx, meta_len); auto key = msg->meta.key; auto len = msg->meta.val_len; @@ -525,6 +525,8 @@ class RDMATransport : public Transport { auto rkey = msg->meta.option; auto sender = msg->meta.sender; + LOG(INFO) << "key=" << key << " len=" << len << " sender=" << sender; + std::lock_guard lock(map_mu_); if (key_meta_map_.find(key) == key_meta_map_.end() || key_meta_map_[key].find(sender) == key_meta_map_[key].end()) { @@ -538,7 +540,7 @@ class RDMATransport : public Transport { } private: - virtual int Recv(Message *msg, BufferContext *buffer_ctx) { + virtual int Recv(Message *msg, BufferContext *buffer_ctx, int meta_len) { uint64_t data_num = buffer_ctx->data_num; if (data_num == 0) { mempool_->Free(buffer_ctx->buffer); @@ -547,15 +549,13 @@ class RDMATransport : public Transport { } int total_data_len = 0; - char *cur = buffer_ctx->buffer + kMaxMetaBound; // offset + char *cur = buffer_ctx->buffer + meta_len; // offset - Block *mem_block = new Block(mempool_, buffer_ctx->buffer, data_num); for (size_t i = 0; i < data_num; i++) { uint32_t len = buffer_ctx->data_len[i]; + LOG(INFO) << "=====Recv: data_len=" << len; SArray data; - data.reset(cur, len, [mem_block](void *) { - mem_block->Release(); - }); // Defer the deletion of block_ref + data.reset(cur, len, [](void *) {}); // no need for delete msg->data.push_back(data); cur += len; total_data_len += len; diff --git a/src/rdma_van.h b/src/rdma_van.h index 27974671..7a505952 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -233,7 +233,6 @@ class RDMAVan : public Van { MessageBuffer *msg_buf = new MessageBuffer(); int meta_len = GetPackMetaLen(msg.meta); - CHECK_LE(meta_len, kMaxMetaBound) << meta_len; size_t data_len = msg.meta.data_size; size_t total_len = meta_len + data_len; @@ -295,11 +294,10 @@ class RDMAVan : public Van { // the second argument is actually deprecated, // we keep it as is in order to be compatible UnpackMeta(buffer_ctx->buffer, buffer_ctx->meta_len, &msg->meta); - int real_meta_len = GetPackMetaLen(msg->meta); - CHECK_LE(real_meta_len, kMaxMetaBound) << real_meta_len; + int meta_len = GetPackMetaLen(msg->meta); int total_len = 0; - total_len += real_meta_len; + total_len += meta_len; auto trans = CHECK_NOTNULL(endpoint->GetTransport()); @@ -312,16 +310,20 @@ class RDMAVan : public Van { // valid data message if (msg->meta.push && msg->meta.request) { // push request - total_len += trans->RecvPushRequest(msg, buffer_ctx); + LOG(INFO) << "recv push request"; + total_len += trans->RecvPushRequest(msg, buffer_ctx, meta_len); } else if (!msg->meta.push && msg->meta.request) { // pull request - total_len += trans->RecvPullRequest(msg, buffer_ctx); + LOG(INFO) << "recv pull request"; + total_len += trans->RecvPullRequest(msg, buffer_ctx, meta_len); } else if (msg->meta.push && !msg->meta.request) { // push response - total_len += trans->RecvPushResponse(msg, buffer_ctx); + LOG(INFO) << "recv push response"; + total_len += trans->RecvPushResponse(msg, buffer_ctx, meta_len); } else if (!msg->meta.push && !msg->meta.request) { // pull response - total_len += trans->RecvPullResponse(msg, buffer_ctx); + LOG(INFO) << "recv push response"; + total_len += trans->RecvPullResponse(msg, buffer_ctx, meta_len); } else { CHECK(0) << "unknown msg type"; } @@ -393,13 +395,14 @@ class RDMAVan : public Van { ReleaseWorkRequestContext(context, endpoint); break; case IBV_WC_RDMA_WRITE: { - // LOG(INFO) << "opcode: IBV_WC_RDMA_WRITE"; - // Note: This is not a struct ibv_mr* - MessageBuffer *msg_buf = - *reinterpret_cast(context->buffer->addr); - mempool_->Free(msg_buf->inline_buf); - delete msg_buf; - ReleaseWorkRequestContext(context, endpoint); + // do nothing + LOG(INFO) << "opcode: IBV_WC_RDMA_WRITE"; + + // MessageBuffer *msg_buf = + // *reinterpret_cast(context->buffer->addr); + // mempool_->Free(msg_buf->inline_buf); + // delete msg_buf; + // ReleaseWorkRequestContext(context, endpoint); } break; case IBV_WC_RECV_RDMA_WITH_IMM: { uint32_t addr_idx = wc[i].imm_data; diff --git a/tests/test_kv_app_benchmark.cc b/tests/test_kv_app_benchmark.cc index 3bb6383e..499483c3 100644 --- a/tests/test_kv_app_benchmark.cc +++ b/tests/test_kv_app_benchmark.cc @@ -17,7 +17,8 @@ void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer uint64_t key = req_data.keys[0]; if (req_meta.push) { CHECK(req_data.lens.size()); - CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]); + CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]) + << "key=" << key << ", " << req_data.vals.size() << ", " << req_data.lens[0]; if (mem_map.find(key) == mem_map.end()) { PS_VLOG(1) << "key " << key << " from worker-" << req_meta.sender; @@ -143,7 +144,7 @@ void RunWorker(int argc, char *argv[]) { auto start = std::chrono::high_resolution_clock::now(); auto end = std::chrono::high_resolution_clock::now(); auto val = Environment::Get()->find("THRESHOLD"); - unsigned int threshold = val ? atoi(val) : 10; + unsigned int threshold = val ? atoi(val) : 1; val = Environment::Get()->find("LOG_DURATION"); unsigned int log_duration = val ? atoi(val) : 50; int cnt = 0; From 3aca8d47891501579ce604609188b3ca0921744e Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Mon, 16 Dec 2019 18:42:33 +0800 Subject: [PATCH 26/79] fix msg_buf size --- src/rdma_transport.h | 2 +- src/rdma_van.h | 23 ++++++++++++++--------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 77bfeda9..049654bc 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -550,10 +550,10 @@ class RDMATransport : public Transport { int total_data_len = 0; char *cur = buffer_ctx->buffer + meta_len; // offset + LOG(INFO) << "meta_len=" << meta_len; for (size_t i = 0; i < data_num; i++) { uint32_t len = buffer_ctx->data_len[i]; - LOG(INFO) << "=====Recv: data_len=" << len; SArray data; data.reset(cur, len, [](void *) {}); // no need for delete msg->data.push_back(data); diff --git a/src/rdma_van.h b/src/rdma_van.h index 7a505952..d79d233b 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -226,6 +226,11 @@ class RDMAVan : public Van { CHECK_NE(remote_id, Meta::kEmpty); CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); Endpoint *endpoint = endpoints_[remote_id].get(); + + if (msg.meta.push && msg.meta.request) { + SArray lens(msg.data[2]); + LOG(INFO) << "sendmsg, len=" << lens[0]; + } auto trans = CHECK_NOTNULL(endpoint->GetTransport()); trans->RegisterMemory(msg); @@ -238,10 +243,13 @@ class RDMAVan : public Van { size_t total_len = meta_len + data_len; CHECK(meta_len); - msg_buf->inline_len = total_len; - msg_buf->inline_buf = mempool_->Alloc(total_len); + msg_buf->inline_len = meta_len; + msg_buf->inline_buf = mempool_->Alloc(meta_len); msg_buf->data = msg.data; + LOG(INFO) << "meta_len=" << meta_len + << ", total_len=" << total_len; + if (IsValidPushpull(msg)) trans->AddMeta(msg); PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); @@ -395,13 +403,10 @@ class RDMAVan : public Van { ReleaseWorkRequestContext(context, endpoint); break; case IBV_WC_RDMA_WRITE: { - // do nothing - LOG(INFO) << "opcode: IBV_WC_RDMA_WRITE"; - - // MessageBuffer *msg_buf = - // *reinterpret_cast(context->buffer->addr); - // mempool_->Free(msg_buf->inline_buf); - // delete msg_buf; + MessageBuffer *msg_buf = + *reinterpret_cast(context->buffer->addr); + mempool_->Free(msg_buf->inline_buf); + delete msg_buf; // ReleaseWorkRequestContext(context, endpoint); } break; case IBV_WC_RECV_RDMA_WITH_IMM: { From 4ac2137cb9d2e7f418b32a0542a2b03dcae7215c Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Mon, 16 Dec 2019 19:09:47 +0800 Subject: [PATCH 27/79] fix Rendezvous recv --- src/rdma_transport.h | 4 +--- src/rdma_van.h | 16 ++++++++++------ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 049654bc..f344e751 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -264,6 +264,7 @@ class RDMATransport : public Transport { } WRContext *write_ctx = msg_buf->reserved_context; + CHECK(write_ctx); MessageBuffer **tmp = reinterpret_cast(write_ctx->buffer->addr); *tmp = msg_buf; // write the addr of msg_buf into the mr buffer @@ -525,8 +526,6 @@ class RDMATransport : public Transport { auto rkey = msg->meta.option; auto sender = msg->meta.sender; - LOG(INFO) << "key=" << key << " len=" << len << " sender=" << sender; - std::lock_guard lock(map_mu_); if (key_meta_map_.find(key) == key_meta_map_.end() || key_meta_map_[key].find(sender) == key_meta_map_[key].end()) { @@ -544,7 +543,6 @@ class RDMATransport : public Transport { uint64_t data_num = buffer_ctx->data_num; if (data_num == 0) { mempool_->Free(buffer_ctx->buffer); - delete buffer_ctx; return 0; } diff --git a/src/rdma_van.h b/src/rdma_van.h index d79d233b..49870f9e 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -226,11 +226,6 @@ class RDMAVan : public Van { CHECK_NE(remote_id, Meta::kEmpty); CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); Endpoint *endpoint = endpoints_[remote_id].get(); - - if (msg.meta.push && msg.meta.request) { - SArray lens(msg.data[2]); - LOG(INFO) << "sendmsg, len=" << lens[0]; - } auto trans = CHECK_NOTNULL(endpoint->GetTransport()); trans->RegisterMemory(msg); @@ -247,7 +242,7 @@ class RDMAVan : public Van { msg_buf->inline_buf = mempool_->Alloc(meta_len); msg_buf->data = msg.data; - LOG(INFO) << "meta_len=" << meta_len + LOG(INFO) << "Send a message with meta_len=" << meta_len << ", total_len=" << total_len; if (IsValidPushpull(msg)) trans->AddMeta(msg); @@ -263,6 +258,9 @@ class RDMAVan : public Van { auto is_push = msg.meta.push; auto key = msg.meta.key; if (!trans->HasRemoteInfo(msg_buf, key, is_push)) { + LOG(INFO) << "Call SendRendezvousBegin" + << ", " << (is_push?"push":"pull") + << " " << (msg.meta.request?"request":"response"); trans->SendRendezvousBegin(msg, msg_buf); return total_len; } @@ -271,15 +269,19 @@ class RDMAVan : public Van { // already know remote address, directly use RDMA-write if (msg.meta.push && msg.meta.request) { // worker, push request + LOG(INFO) << "PUSH REQUEST"; trans->SendPushRequest(msg, msg_buf); } else if (msg.meta.push && !msg.meta.request) { // server, push response + LOG(INFO) << "PUSH RESPONSE"; trans->SendPushResponse(msg, msg_buf); } else if (!msg.meta.push && msg.meta.request) { // worker, pull request + LOG(INFO) << "PULL REQUEST"; trans->SendPullRequest(msg, msg_buf); } else if (!msg.meta.push && !msg.meta.request) { // server, pull response + LOG(INFO) << "PULL RESPONSE"; trans->SendPullResponse(msg, msg_buf); } else { CHECK(0) << "unexpected message type"; @@ -304,6 +306,8 @@ class RDMAVan : public Van { UnpackMeta(buffer_ctx->buffer, buffer_ctx->meta_len, &msg->meta); int meta_len = GetPackMetaLen(msg->meta); + LOG(INFO) << "receive a message with meta_len=" << meta_len; + int total_len = 0; total_len += meta_len; From 55132768ad9e6a526464c3cf1918365439c3fbe9 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Tue, 17 Dec 2019 10:09:51 +0800 Subject: [PATCH 28/79] release write context --- src/rdma_transport.h | 10 ++++++++-- src/rdma_van.h | 10 ++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index f344e751..4cddda48 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -246,6 +246,8 @@ class RDMATransport : public Transport { } void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { + LOG(INFO) << "RDMAWriteWithImm: remote_addr=" << remote_addr + << ", idx=" << idx; // prepare RDMA write sge list struct ibv_sge sge[1 + msg_buf->mrs.size()]; sge[0].addr = reinterpret_cast(msg_buf->inline_buf); @@ -424,11 +426,16 @@ class RDMATransport : public Transport { } void Send(Message &msg, MessageBuffer *msg_buf, bool is_push) { + WRContext *reserved = nullptr; + endpoint_->free_write_ctx.WaitAndPop(&reserved); + msg_buf->reserved_context = reserved; auto key = DecodeKey(msg_buf->data[0]); + std::lock_guard lk(addr_mu_); auto remote_addr = is_push ? std::get<0>(push_addr_[key]) : std::get<0>(pull_addr_[key]); auto rkey = is_push ? std::get<1>(push_addr_[key]) : std::get<1>(pull_addr_[key]); auto idx = is_push ? std::get<2>(push_addr_[key]) : std::get<2>(pull_addr_[key]); + RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); } @@ -513,7 +520,6 @@ class RDMATransport : public Transport { msg->data.push_back(lens); total_data_len += keys.size() + vals.size() + lens.size(); - mempool_->Free(buffer_ctx->buffer); return total_data_len; } @@ -544,7 +550,7 @@ class RDMATransport : public Transport { if (data_num == 0) { mempool_->Free(buffer_ctx->buffer); return 0; - } + } int total_data_len = 0; char *cur = buffer_ctx->buffer + meta_len; // offset diff --git a/src/rdma_van.h b/src/rdma_van.h index 49870f9e..8f987ebd 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -265,7 +265,7 @@ class RDMAVan : public Van { return total_len; } } - + // already know remote address, directly use RDMA-write if (msg.meta.push && msg.meta.request) { // worker, push request @@ -301,8 +301,10 @@ class RDMAVan : public Van { msg->meta.recver = my_node_.id; msg->meta.sender = endpoint->node_id; + LOG(INFO) << "RecvMsg meta_len=" << buffer_ctx->meta_len; + // the second argument is actually deprecated, - // we keep it as is in order to be compatible + // we keep it as is in order to be compatible UnpackMeta(buffer_ctx->buffer, buffer_ctx->meta_len, &msg->meta); int meta_len = GetPackMetaLen(msg->meta); @@ -411,13 +413,13 @@ class RDMAVan : public Van { *reinterpret_cast(context->buffer->addr); mempool_->Free(msg_buf->inline_buf); delete msg_buf; - // ReleaseWorkRequestContext(context, endpoint); + ReleaseWorkRequestContext(context, endpoint); } break; case IBV_WC_RECV_RDMA_WITH_IMM: { uint32_t addr_idx = wc[i].imm_data; BufferContext *buf_ctx = addr_pool_.GetAddress(addr_idx); recv_buffers_.Push(std::make_tuple(endpoint, buf_ctx)); - // ReleaseWorkRequestContext(context, endpoint); + ReleaseWorkRequestContext(context, endpoint); } break; case IBV_WC_RECV: { CHECK(wc[i].wc_flags & IBV_WC_WITH_IMM); From 696c7faa996cb6a0d7ca2d9485bc9f734f6374fa Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Tue, 17 Dec 2019 11:41:11 +0800 Subject: [PATCH 29/79] server: fix receiving push request --- src/rdma_transport.h | 14 ++++++++++---- src/rdma_van.h | 7 +++++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 4cddda48..511e1be7 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -246,8 +246,6 @@ class RDMATransport : public Transport { } void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { - LOG(INFO) << "RDMAWriteWithImm: remote_addr=" << remote_addr - << ", idx=" << idx; // prepare RDMA write sge list struct ibv_sge sge[1 + msg_buf->mrs.size()]; sge[0].addr = reinterpret_cast(msg_buf->inline_buf); @@ -255,15 +253,23 @@ class RDMATransport : public Transport { sge[0].lkey = mempool_->LocalKey(msg_buf->inline_buf); size_t num_sge = 1; + + uint64_t data_len = 0; for (auto &pair : msg_buf->mrs) { - size_t length = pair.second; + size_t length = pair.second; CHECK(length); sge[num_sge].addr = reinterpret_cast(pair.first->addr); sge[num_sge].length = length; sge[num_sge].lkey = pair.first->lkey; ++num_sge; - } + + data_len += length; + } + + LOG(INFO) << "RDMAWriteWithImm: remote_addr=" << remote_addr + << ", idx=" << idx + << ", len=" << data_len; WRContext *write_ctx = msg_buf->reserved_context; CHECK(write_ctx); diff --git a/src/rdma_van.h b/src/rdma_van.h index 8f987ebd..15a46a05 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -301,7 +301,8 @@ class RDMAVan : public Van { msg->meta.recver = my_node_.id; msg->meta.sender = endpoint->node_id; - LOG(INFO) << "RecvMsg meta_len=" << buffer_ctx->meta_len; + LOG(INFO) << "RecvMsg meta_len=" << buffer_ctx->meta_len + << ", buffer addr=" << reinterpret_cast(buffer_ctx->buffer); // the second argument is actually deprecated, // we keep it as is in order to be compatible @@ -342,7 +343,6 @@ class RDMAVan : public Van { CHECK(0) << "unknown msg type"; } - delete buffer_ctx; return total_len; } @@ -418,6 +418,9 @@ class RDMAVan : public Van { case IBV_WC_RECV_RDMA_WITH_IMM: { uint32_t addr_idx = wc[i].imm_data; BufferContext *buf_ctx = addr_pool_.GetAddress(addr_idx); + LOG(INFO) << "------- receving RDMA_WITH_IMM with idx=" << addr_idx + << " bufctx=" << reinterpret_cast(buf_ctx) + << " buffer=" << reinterpret_cast(buf_ctx->buffer); recv_buffers_.Push(std::make_tuple(endpoint, buf_ctx)); ReleaseWorkRequestContext(context, endpoint); } break; From 59bdad2aa1a7bf8ae63f755505b193b9128ea035 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Tue, 17 Dec 2019 14:13:45 +0800 Subject: [PATCH 30/79] fix storing worker tensor address --- src/rdma_transport.h | 48 +++--------------------------------- src/rdma_van.h | 58 +++++++++++++++++++++++++++++++++++--------- 2 files changed, 50 insertions(+), 56 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 511e1be7..987c52ac 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -267,10 +267,6 @@ class RDMATransport : public Transport { data_len += length; } - LOG(INFO) << "RDMAWriteWithImm: remote_addr=" << remote_addr - << ", idx=" << idx - << ", len=" << data_len; - WRContext *write_ctx = msg_buf->reserved_context; CHECK(write_ctx); MessageBuffer **tmp = @@ -414,20 +410,6 @@ class RDMATransport : public Transport { msg.meta.addr = reinterpret_cast(vals.data()); // vals address msg.meta.val_len = vals.size(); msg.meta.option = mem_mr_map_[vals.data()]->rkey; - } else if (!msg.meta.push && !msg.meta.request) { - // pull response - uint64_t key = msg.meta.key; - auto recver = msg.meta.recver; - - std::lock_guard lock(map_mu_); - CHECK_NE(key_meta_map_.find(key), key_meta_map_.end()) - << "key=" << key << " not inited in key_meta_map"; - CHECK_NE(key_meta_map_[key].find(recver), key_meta_map_[key].end()) - << "key=" << key << ", recver=" << recver << " not inited in key_meta_map[key]"; - - msg.meta.val_len = std::get<0>(key_meta_map_[key][recver]); - msg.meta.addr = std::get<1>(key_meta_map_[key][recver]); - msg.meta.option = std::get<2>(key_meta_map_[key][recver]); } } @@ -461,9 +443,9 @@ class RDMATransport : public Transport { std::lock_guard lock(map_mu_); auto key = msg.meta.key; auto recver = msg.meta.recver; - auto len = std::get<0>(key_meta_map_[key][recver]); - auto raddr = std::get<1>(key_meta_map_[key][recver]); - auto rkey = std::get<2>(key_meta_map_[key][recver]); + auto len = msg.meta.val_len; + auto raddr = msg.meta.addr; + auto rkey = msg.meta.option; auto temp_mr = mem_mr_map_.find(msg_buf->data[1].data()); CHECK_NE(temp_mr, mem_mr_map_.end()); @@ -530,24 +512,7 @@ class RDMATransport : public Transport { } virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) { - int total_data_len = Recv(msg, buffer_ctx, meta_len); - - auto key = msg->meta.key; - auto len = msg->meta.val_len; - auto addr = msg->meta.addr; - auto rkey = msg->meta.option; - auto sender = msg->meta.sender; - - std::lock_guard lock(map_mu_); - if (key_meta_map_.find(key) == key_meta_map_.end() - || key_meta_map_[key].find(sender) == key_meta_map_[key].end()) { - key_meta_map_[key][sender] = std::make_tuple(len, addr, rkey); - } else { - CHECK_EQ(len, std::get<0>(key_meta_map_[key][sender])); - CHECK_EQ(addr, std::get<1>(key_meta_map_[key][sender])); - CHECK_EQ(rkey, std::get<2>(key_meta_map_[key][sender])); - } - return total_data_len; + return Recv(msg, buffer_ctx, meta_len); } private: @@ -560,7 +525,6 @@ class RDMATransport : public Transport { int total_data_len = 0; char *cur = buffer_ctx->buffer + meta_len; // offset - LOG(INFO) << "meta_len=" << meta_len; for (size_t i = 0; i < data_num; i++) { uint32_t len = buffer_ctx->data_len[i]; @@ -594,10 +558,6 @@ class RDMATransport : public Transport { std::unordered_map key_addr_map_; std::unordered_map key_len_map_; - using MetaInfo = std::tuple; // len, addr, rkey - using SenderMeta = std::unordered_map; // sender as the key - std::unordered_map key_meta_map_; // (key, sender) --> MetaInfo - }; // class Transport diff --git a/src/rdma_van.h b/src/rdma_van.h index 15a46a05..8e15435c 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -221,6 +221,41 @@ class RDMAVan : public Van { } } + void PackWorkerTensorAddress(Message &msg) { + // must be pull response + if (msg.meta.push || msg.meta.request) return; + + uint64_t key = msg.meta.key; + auto recver = msg.meta.recver; + + std::lock_guard lock(info_mu_); + CHECK_NE(tensor_info_map_.find(key), tensor_info_map_.end()) + << "key=" << key << " not inited in tensor_info_map_"; + CHECK_NE(tensor_info_map_[key].find(recver), tensor_info_map_[key].end()) + << "key=" << key << ", recver=" << recver << " not inited in tensor_info_map_[key]"; + msg.meta.val_len = std::get<0>(tensor_info_map_[key][recver]); + msg.meta.addr = std::get<1>(tensor_info_map_[key][recver]); + msg.meta.option = std::get<2>(tensor_info_map_[key][recver]); + } + + void StoreWorkerTensorAddress(Message *msg) { + auto key = msg->meta.key; + auto len = msg->meta.val_len; + auto addr = msg->meta.addr; + auto rkey = msg->meta.option; + auto sender = msg->meta.sender; + + std::lock_guard lock(info_mu_); + if (tensor_info_map_.find(key) == tensor_info_map_.end() + || tensor_info_map_[key].find(sender) == tensor_info_map_[key].end()) { + tensor_info_map_[key][sender] = std::make_tuple(len, addr, rkey); + } else { + CHECK_EQ(len, std::get<0>(tensor_info_map_[key][sender])); + CHECK_EQ(addr, std::get<1>(tensor_info_map_[key][sender])); + CHECK_EQ(rkey, std::get<2>(tensor_info_map_[key][sender])); + } + } + int SendMsg(Message &msg) override { int remote_id = msg.meta.recver; CHECK_NE(remote_id, Meta::kEmpty); @@ -242,10 +277,10 @@ class RDMAVan : public Van { msg_buf->inline_buf = mempool_->Alloc(meta_len); msg_buf->data = msg.data; - LOG(INFO) << "Send a message with meta_len=" << meta_len - << ", total_len=" << total_len; - - if (IsValidPushpull(msg)) trans->AddMeta(msg); + if (IsValidPushpull(msg)) { + trans->AddMeta(msg); + PackWorkerTensorAddress(msg); + } PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); @@ -301,16 +336,11 @@ class RDMAVan : public Van { msg->meta.recver = my_node_.id; msg->meta.sender = endpoint->node_id; - LOG(INFO) << "RecvMsg meta_len=" << buffer_ctx->meta_len - << ", buffer addr=" << reinterpret_cast(buffer_ctx->buffer); - // the second argument is actually deprecated, // we keep it as is in order to be compatible UnpackMeta(buffer_ctx->buffer, buffer_ctx->meta_len, &msg->meta); int meta_len = GetPackMetaLen(msg->meta); - LOG(INFO) << "receive a message with meta_len=" << meta_len; - int total_len = 0; total_len += meta_len; @@ -327,6 +357,7 @@ class RDMAVan : public Van { // push request LOG(INFO) << "recv push request"; total_len += trans->RecvPushRequest(msg, buffer_ctx, meta_len); + StoreWorkerTensorAddress(msg); } else if (!msg->meta.push && msg->meta.request) { // pull request LOG(INFO) << "recv pull request"; @@ -418,9 +449,6 @@ class RDMAVan : public Van { case IBV_WC_RECV_RDMA_WITH_IMM: { uint32_t addr_idx = wc[i].imm_data; BufferContext *buf_ctx = addr_pool_.GetAddress(addr_idx); - LOG(INFO) << "------- receving RDMA_WITH_IMM with idx=" << addr_idx - << " bufctx=" << reinterpret_cast(buf_ctx) - << " buffer=" << reinterpret_cast(buf_ctx->buffer); recv_buffers_.Push(std::make_tuple(endpoint, buf_ctx)); ReleaseWorkRequestContext(context, endpoint); } break; @@ -682,6 +710,12 @@ class RDMAVan : public Van { std::mutex local_mu_; std::unordered_map is_local_; + // to store worker's tensor address + std::mutex info_mu_; + using RemoteInfo = std::tuple; // len, addr, rkey + using SenderMeta = std::unordered_map; // sender as the key + std::unordered_map tensor_info_map_; // (key, sender) --> RemoteInfo + }; // class RDMAVan }; // namespace ps From 3457b0b7df851ee53f5d27814f7494177fa2ffd5 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Tue, 17 Dec 2019 16:33:45 +0800 Subject: [PATCH 31/79] fix push response redundant rendez --- src/rdma_transport.h | 79 ++++++++++------------------------ src/rdma_utils.h | 2 + src/rdma_van.h | 79 ++++++++++++++++++++++++---------- tests/test_kv_app_benchmark.cc | 2 +- 4 files changed, 83 insertions(+), 79 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 987c52ac..ef69f950 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -189,19 +189,16 @@ class Transport { virtual void AddMeta(Message &msg) = 0; virtual void RegisterMemory(Message &msg) = 0; - virtual void PrepareData(MessageBuffer *msg_buf) = 0; + virtual void PrepareData(Message &msg, MessageBuffer *msg_buf) = 0; - virtual void Send(Message &msg, MessageBuffer *msg_buf, bool is_push) = 0; - virtual void SendPullRequest(Message &msg, MessageBuffer *msg_buf) = 0; - virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf) = 0; - virtual void SendPushResponse(Message &msg, MessageBuffer *msg_buf) = 0; - virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf) = 0; + virtual void Send(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) = 0; + virtual void SendPullRequest(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) = 0; + virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) = 0; + virtual void SendPushResponse(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) = 0; + virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) = 0; virtual void SendRendezvousBegin(Message &msg, MessageBuffer *msg_buf) = 0; virtual void SendRendezvousReply(RendezvousStart *req, AddressPool &pool) = 0; - virtual bool HasRemoteInfo(MessageBuffer *msg_buf, uint64_t key, bool is_push) = 0; - virtual void StoreRemoteInfo(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; - }; // class Transport @@ -234,7 +231,10 @@ class RDMATransport : public Transport { } } - void PrepareData(MessageBuffer *msg_buf) { + void PrepareData(Message &msg, MessageBuffer *msg_buf) { + // pull response send with rdma write (no imm) + if (!msg.meta.push && !msg.meta.request) return; + for (auto &sa : msg_buf->data) { if (sa.size() == 0) continue; std::lock_guard lock(map_mu_); @@ -289,30 +289,6 @@ class RDMATransport : public Transport { << "ibv_post_send failed."; } - bool HasRemoteInfo(MessageBuffer *msg_buf, uint64_t key, bool is_push) { - std::lock_guard lk(addr_mu_); - if ( is_push && (push_addr_.find(key) != push_addr_.end())) return true; - if (!is_push && (pull_addr_.find(key) != pull_addr_.end())) return true; - // no remote info, store the msg_buf address and push/pull flag for RendezvousReply - msgbuf_cache_.emplace(reinterpret_cast(msg_buf), is_push); - return false; - } - - void StoreRemoteInfo(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { - if (msg_buf->data.size() == 0) return; - auto key = DecodeKey(msg_buf->data[0]); - auto buf = reinterpret_cast(msg_buf); - - std::lock_guard lk(addr_mu_); - auto is_push = msgbuf_cache_[buf]; - if (is_push) { - push_addr_.emplace(key, std::make_tuple(remote_addr, rkey, idx)); - } else { - pull_addr_.emplace(key, std::make_tuple(remote_addr, rkey, idx)); - } - msgbuf_cache_.erase(buf); - } - void SendRendezvousBegin(Message &msg, MessageBuffer *msg_buf) { WRContext *context = nullptr, *reserved = nullptr; endpoint_->free_write_ctx.WaitAndPop(&reserved); @@ -413,40 +389,38 @@ class RDMATransport : public Transport { } } - void Send(Message &msg, MessageBuffer *msg_buf, bool is_push) { + void Send(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { WRContext *reserved = nullptr; endpoint_->free_write_ctx.WaitAndPop(&reserved); msg_buf->reserved_context = reserved; - auto key = DecodeKey(msg_buf->data[0]); + auto key = msg.meta.key; - std::lock_guard lk(addr_mu_); - auto remote_addr = is_push ? std::get<0>(push_addr_[key]) : std::get<0>(pull_addr_[key]); - auto rkey = is_push ? std::get<1>(push_addr_[key]) : std::get<1>(pull_addr_[key]); - auto idx = is_push ? std::get<2>(push_addr_[key]) : std::get<2>(pull_addr_[key]); + auto raddr = std::get<0>(remote_addr); + auto rkey = std::get<1>(remote_addr); + auto idx = std::get<2>(remote_addr); - RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); + RDMAWriteWithImm(msg_buf, raddr, rkey, idx); } - void SendPushResponse(Message &msg, MessageBuffer *msg_buf) { - Send(msg, msg_buf, true); + void SendPushResponse(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { + Send(msg, msg_buf, remote_addr); } - void SendPullRequest(Message &msg, MessageBuffer *msg_buf) { - Send(msg, msg_buf, false); + void SendPullRequest(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { + Send(msg, msg_buf, remote_addr); } - virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf) { - Send(msg, msg_buf, true); + virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { + Send(msg, msg_buf, remote_addr); } - virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf) { + virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { std::lock_guard lock(map_mu_); auto key = msg.meta.key; auto recver = msg.meta.recver; auto len = msg.meta.val_len; auto raddr = msg.meta.addr; auto rkey = msg.meta.option; - auto temp_mr = mem_mr_map_.find(msg_buf->data[1].data()); CHECK_NE(temp_mr, mem_mr_map_.end()); @@ -457,14 +431,12 @@ class RDMATransport : public Transport { struct ibv_send_wr wr, *bad_wr = nullptr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = reinterpret_cast(raddr); wr.opcode = IBV_WR_RDMA_WRITE; wr.next = nullptr; // wr.send_flags = IBV_SEND_SIGNALED; wr.sg_list = &sge; wr.num_sge = 1; - wr.wr.rdma.remote_addr = raddr; wr.wr.rdma.rkey = rkey; @@ -519,7 +491,6 @@ class RDMATransport : public Transport { virtual int Recv(Message *msg, BufferContext *buffer_ctx, int meta_len) { uint64_t data_num = buffer_ctx->data_num; if (data_num == 0) { - mempool_->Free(buffer_ctx->buffer); return 0; } @@ -543,10 +514,6 @@ class RDMATransport : public Transport { SimpleMempool *mempool_; // role is server or worker bool is_server_; - std::mutex addr_mu_; - std::unordered_map > push_addr_; // key, - std::unordered_map > pull_addr_; // key, - std::unordered_map msgbuf_cache_; // msg_buf, is_push // manage the following map std::mutex map_mu_; diff --git a/src/rdma_utils.h b/src/rdma_utils.h index 3cb57c1f..2bbfd657 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -295,6 +295,8 @@ struct BufferContext { typedef std::unique_ptr> MRPtr; +typedef std::tuple RemoteAddress; // + struct MessageBuffer { size_t inline_len; char *inline_buf; diff --git a/src/rdma_van.h b/src/rdma_van.h index 8e15435c..2dd2c68b 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -256,6 +256,30 @@ class RDMAVan : public Van { } } + bool HasRemoteInfo(MessageBuffer *msg_buf, uint64_t key, bool is_push) { + std::lock_guard lk(addr_mu_); + if ( is_push && (push_addr_.find(key) != push_addr_.end())) return true; + if (!is_push && (pull_addr_.find(key) != pull_addr_.end())) return true; + // no remote info, store the msg_buf address and push/pull flag for RendezvousReply + msgbuf_cache_.emplace(reinterpret_cast(msg_buf), std::make_pair(key, is_push)); + return false; + } + + void StoreRemoteInfo(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { + auto buf = reinterpret_cast(msg_buf); + if (msgbuf_cache_.find(buf) == msgbuf_cache_.end()) return; // control message + std::lock_guard lk(addr_mu_); + auto key = std::get<0>(msgbuf_cache_[buf]); + auto is_push = std::get<1>(msgbuf_cache_[buf]); + if (is_push) { + push_addr_[key] = std::make_tuple(remote_addr, rkey, idx); + } else { + pull_addr_[key] = std::make_tuple(remote_addr, rkey, idx); + } + CHECK_NE(msgbuf_cache_.find(buf), msgbuf_cache_.end()); + msgbuf_cache_.erase(buf); + } + int SendMsg(Message &msg) override { int remote_id = msg.meta.recver; CHECK_NE(remote_id, Meta::kEmpty); @@ -284,40 +308,45 @@ class RDMAVan : public Van { PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); - trans->PrepareData(msg_buf); - if (!IsValidPushpull(msg)) { trans->SendRendezvousBegin(msg, msg_buf); return total_len; } else { + trans->PrepareData(msg, msg_buf); auto is_push = msg.meta.push; auto key = msg.meta.key; - if (!trans->HasRemoteInfo(msg_buf, key, is_push)) { + if (!HasRemoteInfo(msg_buf, key, is_push)) { LOG(INFO) << "Call SendRendezvousBegin" + << ", key=" << key << ", " << (is_push?"push":"pull") - << " " << (msg.meta.request?"request":"response"); + << " " << (msg.meta.request?"request":"response") + << ", push_addr.size=" << push_addr_.size(); trans->SendRendezvousBegin(msg, msg_buf); return total_len; } } + std::lock_guard lk(addr_mu_); + auto key = msg.meta.key; + auto remote_addr_tuple = (msg.meta.push ? push_addr_[key] : pull_addr_[key]); + // already know remote address, directly use RDMA-write if (msg.meta.push && msg.meta.request) { // worker, push request - LOG(INFO) << "PUSH REQUEST"; - trans->SendPushRequest(msg, msg_buf); + LOG(INFO) << "SEND PUSH REQUEST, key=" << key; + trans->SendPushRequest(msg, msg_buf, remote_addr_tuple); } else if (msg.meta.push && !msg.meta.request) { // server, push response - LOG(INFO) << "PUSH RESPONSE"; - trans->SendPushResponse(msg, msg_buf); + LOG(INFO) << "SEND PUSH RESPONSE, key=" << key; + trans->SendPushResponse(msg, msg_buf, remote_addr_tuple); } else if (!msg.meta.push && msg.meta.request) { // worker, pull request - LOG(INFO) << "PULL REQUEST"; - trans->SendPullRequest(msg, msg_buf); + LOG(INFO) << "SEND PULL REQUEST, key=" << key; + trans->SendPullRequest(msg, msg_buf, remote_addr_tuple); } else if (!msg.meta.push && !msg.meta.request) { // server, pull response - LOG(INFO) << "PULL RESPONSE"; - trans->SendPullResponse(msg, msg_buf); + LOG(INFO) << "SEND PULL RESPONSE, key=" << key; + trans->SendPullResponse(msg, msg_buf, remote_addr_tuple); } else { CHECK(0) << "unexpected message type"; } @@ -355,20 +384,20 @@ class RDMAVan : public Van { // valid data message if (msg->meta.push && msg->meta.request) { // push request - LOG(INFO) << "recv push request"; + LOG(INFO) << "RECV PUSH REQUEST, key=" << msg->meta.key; total_len += trans->RecvPushRequest(msg, buffer_ctx, meta_len); StoreWorkerTensorAddress(msg); } else if (!msg->meta.push && msg->meta.request) { // pull request - LOG(INFO) << "recv pull request"; + LOG(INFO) << "RECV PULL REQUEST, key=" << msg->meta.key; total_len += trans->RecvPullRequest(msg, buffer_ctx, meta_len); } else if (msg->meta.push && !msg->meta.request) { // push response - LOG(INFO) << "recv push response"; + LOG(INFO) << "RECV PUSH RESPONSE, key=" << msg->meta.key; total_len += trans->RecvPushResponse(msg, buffer_ctx, meta_len); } else if (!msg->meta.push && !msg->meta.request) { // pull response - LOG(INFO) << "recv push response"; + LOG(INFO) << "RECV PULL RESPONSE, key=" << msg->meta.key; total_len += trans->RecvPullResponse(msg, buffer_ctx, meta_len); } else { CHECK(0) << "unknown msg type"; @@ -477,7 +506,8 @@ class RDMAVan : public Van { // Before RDMA write, store the remote info so that // subsequent write does not need repeated rendezvous - trans->StoreRemoteInfo(msg_buf, remote_addr, rkey, idx); + StoreRemoteInfo(msg_buf, remote_addr, rkey, idx); + LOG(INFO) << "kRendezvousReply: push_addr_.size=" << push_addr_.size(); trans->RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); } else { @@ -710,12 +740,17 @@ class RDMAVan : public Van { std::mutex local_mu_; std::unordered_map is_local_; - // to store worker's tensor address + // worker's tensor address std::mutex info_mu_; - using RemoteInfo = std::tuple; // len, addr, rkey - using SenderMeta = std::unordered_map; // sender as the key - std::unordered_map tensor_info_map_; // (key, sender) --> RemoteInfo - + using TensorInfo = std::tuple; // len, addr, rkey + using RemoteTensorMeta = std::unordered_map; // sender as the key + std::unordered_map tensor_info_map_; // (key, sender) --> TensorInfo + + // store rendezvous address + std::mutex addr_mu_; + std::unordered_map push_addr_; // key, + std::unordered_map pull_addr_; // key, + std::unordered_map > msgbuf_cache_; // msg_buf, }; // class RDMAVan }; // namespace ps diff --git a/tests/test_kv_app_benchmark.cc b/tests/test_kv_app_benchmark.cc index 499483c3..89592829 100644 --- a/tests/test_kv_app_benchmark.cc +++ b/tests/test_kv_app_benchmark.cc @@ -182,7 +182,7 @@ void RunWorker(int argc, char *argv[]) { auto start = std::chrono::high_resolution_clock::now(); auto end = std::chrono::high_resolution_clock::now(); auto val = Environment::Get()->find("THRESHOLD"); - unsigned int threshold = val ? atoi(val) : 10; + unsigned int threshold = val ? atoi(val) : 1; val = Environment::Get()->find("LOG_DURATION"); unsigned int log_duration = val ? atoi(val) : 50; int cnt = 0; From b4b4487d4ca1633e9461301bc364fd7e31cf65fb Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Tue, 17 Dec 2019 16:41:33 +0800 Subject: [PATCH 32/79] can run 1v1 --- src/rdma_transport.h | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index ef69f950..f7a67434 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -232,9 +232,6 @@ class RDMATransport : public Transport { } void PrepareData(Message &msg, MessageBuffer *msg_buf) { - // pull response send with rdma write (no imm) - if (!msg.meta.push && !msg.meta.request) return; - for (auto &sa : msg_buf->data) { if (sa.size() == 0) continue; std::lock_guard lock(map_mu_); @@ -442,6 +439,11 @@ class RDMATransport : public Transport { CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) << "ibv_post_send failed."; + + // after write keys/vals/lens (no imm), write the meta (with imm) + // TODO: consolidate this into one RDMA_WRITE_WITH_IMM + msg_buf->mrs.clear(); + Send(msg, msg_buf, remote_addr); } virtual int RecvPushResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) { From 59b85af215ed0bf9d42708c7f8b4ecf071c3b8df Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Tue, 17 Dec 2019 21:28:31 +0800 Subject: [PATCH 33/79] add push/pull tests --- src/rdma_van.h | 26 ++-- tests/test_kv_app_benchmark.cc | 246 ++++++++++++++++----------------- 2 files changed, 133 insertions(+), 139 deletions(-) diff --git a/src/rdma_van.h b/src/rdma_van.h index 2dd2c68b..fdc90b39 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -316,11 +316,9 @@ class RDMAVan : public Van { auto is_push = msg.meta.push; auto key = msg.meta.key; if (!HasRemoteInfo(msg_buf, key, is_push)) { - LOG(INFO) << "Call SendRendezvousBegin" - << ", key=" << key - << ", " << (is_push?"push":"pull") - << " " << (msg.meta.request?"request":"response") - << ", push_addr.size=" << push_addr_.size(); + // LOG(INFO) << "Call SendRendezvousBegin" << ", key=" << key + // << ", " << (is_push?"push":"pull") << " " << (msg.meta.request?"request":"response") + // << ", push_addr.size=" << push_addr_.size(); trans->SendRendezvousBegin(msg, msg_buf); return total_len; } @@ -333,19 +331,19 @@ class RDMAVan : public Van { // already know remote address, directly use RDMA-write if (msg.meta.push && msg.meta.request) { // worker, push request - LOG(INFO) << "SEND PUSH REQUEST, key=" << key; + // LOG(INFO) << "SEND PUSH REQUEST, key=" << key; trans->SendPushRequest(msg, msg_buf, remote_addr_tuple); } else if (msg.meta.push && !msg.meta.request) { // server, push response - LOG(INFO) << "SEND PUSH RESPONSE, key=" << key; + // LOG(INFO) << "SEND PUSH RESPONSE, key=" << key; trans->SendPushResponse(msg, msg_buf, remote_addr_tuple); } else if (!msg.meta.push && msg.meta.request) { // worker, pull request - LOG(INFO) << "SEND PULL REQUEST, key=" << key; + // LOG(INFO) << "SEND PULL REQUEST, key=" << key; trans->SendPullRequest(msg, msg_buf, remote_addr_tuple); } else if (!msg.meta.push && !msg.meta.request) { // server, pull response - LOG(INFO) << "SEND PULL RESPONSE, key=" << key; + // LOG(INFO) << "SEND PULL RESPONSE, key=" << key; trans->SendPullResponse(msg, msg_buf, remote_addr_tuple); } else { CHECK(0) << "unexpected message type"; @@ -384,20 +382,20 @@ class RDMAVan : public Van { // valid data message if (msg->meta.push && msg->meta.request) { // push request - LOG(INFO) << "RECV PUSH REQUEST, key=" << msg->meta.key; + // LOG(INFO) << "RECV PUSH REQUEST, key=" << msg->meta.key; total_len += trans->RecvPushRequest(msg, buffer_ctx, meta_len); StoreWorkerTensorAddress(msg); } else if (!msg->meta.push && msg->meta.request) { // pull request - LOG(INFO) << "RECV PULL REQUEST, key=" << msg->meta.key; + // LOG(INFO) << "RECV PULL REQUEST, key=" << msg->meta.key; total_len += trans->RecvPullRequest(msg, buffer_ctx, meta_len); } else if (msg->meta.push && !msg->meta.request) { // push response - LOG(INFO) << "RECV PUSH RESPONSE, key=" << msg->meta.key; + // LOG(INFO) << "RECV PUSH RESPONSE, key=" << msg->meta.key; total_len += trans->RecvPushResponse(msg, buffer_ctx, meta_len); } else if (!msg->meta.push && !msg->meta.request) { // pull response - LOG(INFO) << "RECV PULL RESPONSE, key=" << msg->meta.key; + // LOG(INFO) << "RECV PULL RESPONSE, key=" << msg->meta.key; total_len += trans->RecvPullResponse(msg, buffer_ctx, meta_len); } else { CHECK(0) << "unknown msg type"; @@ -507,9 +505,7 @@ class RDMAVan : public Van { // Before RDMA write, store the remote info so that // subsequent write does not need repeated rendezvous StoreRemoteInfo(msg_buf, remote_addr, rkey, idx); - LOG(INFO) << "kRendezvousReply: push_addr_.size=" << push_addr_.size(); trans->RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); - } else { CHECK(0); } diff --git a/tests/test_kv_app_benchmark.cc b/tests/test_kv_app_benchmark.cc index 89592829..35c4fb96 100644 --- a/tests/test_kv_app_benchmark.cc +++ b/tests/test_kv_app_benchmark.cc @@ -7,8 +7,9 @@ using namespace ps; enum MODE { PUSH_THEN_PULL = 0, - PUSH_PULL_MIX_ENDLESS = 1, - PUSH_ONLY = 2 + PUSH_PULL = 1, + PUSH_ONLY = 2, + PULL_ONLY = 3 }; std::unordered_map > mem_map; @@ -52,6 +53,70 @@ struct PSKV { }; std::unordered_map ps_kv_; +void push_pull(KVWorker &kv, std::vector > &server_vals, + int len, int num_servers, int total_key_num, int how_many_key_per_server, MODE mode) { + CHECK_GT(mode, 0); + switch (mode) { + case PUSH_PULL: + LOG(INFO) << "========= PUSH_PULL mode ========="; + break; + case PUSH_ONLY: + LOG(INFO) << "========= PUSH_ONLY mode ========="; + break; + case PULL_ONLY: + LOG(INFO) << "========= PULL_ONLY mode ========="; + break; + default: CHECK(0); + } + + std::vector timestamp_list; + auto start = std::chrono::high_resolution_clock::now(); + auto end = std::chrono::high_resolution_clock::now(); + + auto val = Environment::Get()->find("LOG_DURATION"); + unsigned int log_duration = val ? atoi(val) : 10; + + int cnt = 0; + while (1) { + for (int key = 0; key < total_key_num; key++) { + PSKV& pskv = ps_kv_[key]; + auto keys = pskv.keys; + auto lens = pskv.lens; + auto vals = server_vals[key]; + + switch (mode) { + case PUSH_PULL: { + timestamp_list.push_back(kv.ZPush(keys, vals, lens)); + timestamp_list.push_back(kv.ZPull(keys, &vals, &lens)); + } break; + case PUSH_ONLY: { + timestamp_list.push_back(kv.ZPush(keys, vals, lens)); + } break; + case PULL_ONLY: { + timestamp_list.push_back(kv.ZPull(keys, &vals, &lens)); + } break; + default: { + CHECK(0); + break; + } + } + } + + for (auto& ts : timestamp_list) { kv.Wait(ts); } + timestamp_list.clear(); + + cnt++; + if (cnt % log_duration != 0) continue; + + end = std::chrono::high_resolution_clock::now(); + LL << "Application goodput: " + << 8.0 * len * sizeof(float) * total_key_num * cnt / (end - start).count() + << " Gbps"; + cnt = 0; + start = std::chrono::high_resolution_clock::now(); + } +} + void RunWorker(int argc, char *argv[]) { if (!IsWorker()) return; CHECK_GE(argc, 3) << "input argument should be at least 3: SCRIPT, LEN, REPEAT, (OPTIONAL) MODE"; @@ -65,21 +130,27 @@ void RunWorker(int argc, char *argv[]) { // init int len = atoi(argv[1]); int repeat = atoi(argv[2]); - MODE mode = (argc > 3) ? static_cast(atoi(argv[3])) : PUSH_PULL_MIX_ENDLESS; + MODE mode = (argc > 3) ? static_cast(atoi(argv[3])) : PUSH_PULL; + + auto v = Environment::Get()->find("NUM_KEY_PER_SERVER"); + const int how_many_key_per_server = v ? atoi(v) : 10; + const int total_key_num = num_servers * how_many_key_per_server; std::vector > server_vals; - for (int server = 0; server < num_servers; server++) { + for (int key = 0; key < total_key_num; key++) { std::vector vec(len); SArray vals(vec); server_vals.push_back(vals); } - // init push, do not count this into time cos - for (int server = 0; server < num_servers; server++) { - int key = server; // could be other value - auto vals = server_vals[server]; + // init push, do not count this into time cost + for (int key = 0; key < total_key_num; key++) { + auto vals = server_vals[key]; PSKV& pskv = ps_kv_[key]; SArray keys; + + int server = key % num_servers; + LOG(INFO) << "key=" << key << " assigned to server " << server; ps::Key ps_key = krs[server].begin() + key; keys.push_back(ps_key); SArray lens; @@ -92,128 +163,55 @@ void RunWorker(int argc, char *argv[]) { switch(mode) { case PUSH_THEN_PULL: { - LOG(INFO) << "PUSH_THEN_PULL mode"; - // push - uint64_t accumulated_ms = 0; - for (int i = 0; i < repeat; ++i) { - auto start = std::chrono::high_resolution_clock::now(); - for (int server = 0; server < num_servers; server++) { - int key = server; - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; - auto vals = server_vals[server]; - - kv.Wait(kv.ZPush(keys, vals, lens)); - } - auto end = std::chrono::high_resolution_clock::now(); - accumulated_ms += (end - start).count(); // ns - } - LL << "push " << len * sizeof(float) - << " bytes to each server, repeat=" << repeat - << ", total_time=" - << accumulated_ms / 1e6 << "ms"; - - // pull - accumulated_ms = 0; - for (int i = 0; i < repeat; ++i) { - auto start = std::chrono::high_resolution_clock::now(); - for (int server = 0; server < num_servers; server++) { - int key = server; - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; - auto vals = server_vals[server]; - - kv.Wait(kv.ZPull(keys, &vals, &lens)); - } - auto end = std::chrono::high_resolution_clock::now(); - accumulated_ms += (end - start).count(); // ns + LOG(INFO) << "PUSH_THEN_PULL mode"; + // push + uint64_t accumulated_ms = 0; + for (int i = 0; i < repeat; ++i) { + auto start = std::chrono::high_resolution_clock::now(); + for (int server = 0; server < num_servers; server++) { + int key = server; + PSKV& pskv = ps_kv_[key]; + auto keys = pskv.keys; + auto lens = pskv.lens; + auto vals = server_vals[server]; + + kv.Wait(kv.ZPush(keys, vals, lens)); } - - LL << "pull " << len * sizeof(float) - << " bytes to each server, repeat=" << repeat - << ", total_time=" - << accumulated_ms / 1e6 << "ms"; + auto end = std::chrono::high_resolution_clock::now(); + accumulated_ms += (end - start).count(); // ns } - break; - - case PUSH_PULL_MIX_ENDLESS: { - LOG(INFO) << "PUSH_PULL_MIX_ENDLESS mode, should exit by Ctrl+C"; - std::vector timestamp_list; + LL << "push " << len * sizeof(float) + << " bytes to each server, repeat=" << repeat + << ", total_time=" + << accumulated_ms / 1e6 << "ms"; + + // pull + accumulated_ms = 0; + for (int i = 0; i < repeat; ++i) { auto start = std::chrono::high_resolution_clock::now(); - auto end = std::chrono::high_resolution_clock::now(); - auto val = Environment::Get()->find("THRESHOLD"); - unsigned int threshold = val ? atoi(val) : 1; - val = Environment::Get()->find("LOG_DURATION"); - unsigned int log_duration = val ? atoi(val) : 50; - int cnt = 0; - while (1) { - for (int server = 0; server < num_servers; server++) { - int key = server; - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; - auto vals = server_vals[server]; - - timestamp_list.push_back(kv.ZPush(keys, vals, lens)); - timestamp_list.push_back(kv.ZPull(keys, &vals, &lens)); - } - if (timestamp_list.size()/2/num_servers >= threshold) { // flow control - for (auto& ts : timestamp_list) { - kv.Wait(ts); - } - timestamp_list.clear(); - cnt++; - if (cnt % log_duration == 0) { - end = std::chrono::high_resolution_clock::now(); - LL << "Application goodput: " - << 8.0 * len * sizeof(float) * num_servers * cnt * threshold / (end - start).count() - << " Gbps"; - cnt = 0; - start = std::chrono::high_resolution_clock::now(); - } - } + for (int server = 0; server < num_servers; server++) { + int key = server; + PSKV& pskv = ps_kv_[key]; + auto keys = pskv.keys; + auto lens = pskv.lens; + auto vals = server_vals[server]; + + kv.Wait(kv.ZPull(keys, &vals, &lens)); } - } break; - case PUSH_ONLY: { - LOG(INFO) << "PUSH_ONLY mode, should exit by Ctrl+C"; - std::vector timestamp_list; - auto start = std::chrono::high_resolution_clock::now(); auto end = std::chrono::high_resolution_clock::now(); - auto val = Environment::Get()->find("THRESHOLD"); - unsigned int threshold = val ? atoi(val) : 1; - val = Environment::Get()->find("LOG_DURATION"); - unsigned int log_duration = val ? atoi(val) : 50; - int cnt = 0; - while (1) { - for (int server = 0; server < num_servers; server++) { - int key = server; - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; - auto vals = server_vals[server]; - - timestamp_list.push_back(kv.ZPush(keys, vals, lens)); - } - if (timestamp_list.size()/num_servers >= threshold) { // flow control - for (auto& ts : timestamp_list) { - kv.Wait(ts); - } - timestamp_list.clear(); - cnt++; - if (cnt % log_duration == 0) { - end = std::chrono::high_resolution_clock::now(); - LL << "Application goodput: " - << 8.0 * len * sizeof(float) * num_servers * cnt * threshold / (end - start).count() - << " Gbps"; - cnt = 0; - start = std::chrono::high_resolution_clock::now(); - } - } - } - } break; + accumulated_ms += (end - start).count(); // ns + } + LL << "pull " << len * sizeof(float) + << " bytes to each server, repeat=" << repeat + << ", total_time=" + << accumulated_ms / 1e6 << "ms"; + } break; + case PUSH_PULL: + case PUSH_ONLY: + case PULL_ONLY: + push_pull(kv, server_vals, len, num_servers, total_key_num, how_many_key_per_server, mode); + break; default: CHECK(0) << "unknown mode " << mode; } From b1ec799daa82489c542f2b66073688caaaaad5cc Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Wed, 18 Dec 2019 10:25:49 +0800 Subject: [PATCH 34/79] use two separate mempools --- src/rdma_van.h | 24 +++++++++++++----------- tests/test_kv_app_benchmark.cc | 4 ++-- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/rdma_van.h b/src/rdma_van.h index fdc90b39..2275b6da 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -67,7 +67,8 @@ class RDMAVan : public Van { cm_event_polling_thread_.reset(); PS_VLOG(1) << "Clearing mempool."; - mempool_.reset(); + send_mempool_.reset(); + recv_mempool_.reset(); PS_VLOG(1) << "Clearing endpoints."; incoming_.clear(); @@ -213,8 +214,8 @@ class RDMAVan : public Van { } std::shared_ptr t = is_local_[node.id] ? - std::make_shared(endpoint, mempool_.get()) : - std::make_shared(endpoint, mempool_.get()); + std::make_shared(endpoint, send_mempool_.get()) : + std::make_shared(endpoint, send_mempool_.get()); endpoint->SetTransport(t); freeaddrinfo(remote_addr); @@ -298,7 +299,7 @@ class RDMAVan : public Van { CHECK(meta_len); msg_buf->inline_len = meta_len; - msg_buf->inline_buf = mempool_->Alloc(meta_len); + msg_buf->inline_buf = send_mempool_->Alloc(meta_len); msg_buf->data = msg.data; if (IsValidPushpull(msg)) { @@ -374,7 +375,7 @@ class RDMAVan : public Van { auto trans = CHECK_NOTNULL(endpoint->GetTransport()); if (!IsValidPushpull(*msg)) { - mempool_->Free(buffer_ctx->buffer); + recv_mempool_->Free(buffer_ctx->buffer); delete buffer_ctx; return total_len; } @@ -412,7 +413,8 @@ class RDMAVan : public Van { pd_ = ibv_alloc_pd(context_); CHECK(pd_) << "Failed to allocate protection domain"; - mempool_.reset(new SimpleMempool(pd_)); + send_mempool_.reset(new SimpleMempool(pd_)); + recv_mempool_.reset(new SimpleMempool(pd_)); comp_event_channel_ = ibv_create_comp_channel(context_); @@ -463,13 +465,12 @@ class RDMAVan : public Van { switch (wc[i].opcode) { case IBV_WC_SEND: - // LOG(INFO) << "opcode: IBV_WC_SEND"; ReleaseWorkRequestContext(context, endpoint); break; case IBV_WC_RDMA_WRITE: { MessageBuffer *msg_buf = *reinterpret_cast(context->buffer->addr); - mempool_->Free(msg_buf->inline_buf); + send_mempool_->Free(msg_buf->inline_buf); delete msg_buf; ReleaseWorkRequestContext(context, endpoint); } break; @@ -609,8 +610,8 @@ class RDMAVan : public Van { endpoint->Init(cq_, pd_); std::shared_ptr t = is_local_[remote_ctx->node] ? - std::make_shared(endpoint, mempool_.get()) : - std::make_shared(endpoint, mempool_.get()); + std::make_shared(endpoint, recv_mempool_.get()) : + std::make_shared(endpoint, recv_mempool_.get()); endpoint->SetTransport(t); RequestContext ctx; @@ -696,7 +697,8 @@ class RDMAVan : public Van { } AddressPool addr_pool_; - std::unique_ptr mempool_; + std::unique_ptr recv_mempool_; + std::unique_ptr send_mempool_; std::unique_ptr rdma_trans_; std::unique_ptr ipc_trans_; diff --git a/tests/test_kv_app_benchmark.cc b/tests/test_kv_app_benchmark.cc index 35c4fb96..437ec84f 100644 --- a/tests/test_kv_app_benchmark.cc +++ b/tests/test_kv_app_benchmark.cc @@ -22,7 +22,7 @@ void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer << "key=" << key << ", " << req_data.vals.size() << ", " << req_data.lens[0]; if (mem_map.find(key) == mem_map.end()) { - PS_VLOG(1) << "key " << key << " from worker-" << req_meta.sender; + PS_VLOG(1) << "receive key-" << key << " from worker-" << req_meta.sender; size_t len = (size_t) req_data.vals.size(); mem_map[key].keys.push_back(key); mem_map[key].vals.CopyFrom(req_data.vals); @@ -150,7 +150,7 @@ void RunWorker(int argc, char *argv[]) { SArray keys; int server = key % num_servers; - LOG(INFO) << "key=" << key << " assigned to server " << server; + PS_VLOG(1) << "key=" << key << " assigned to server " << server; ps::Key ps_key = krs[server].begin() + key; keys.push_back(ps_key); SArray lens; From c7e9bb1212c8c4cb3a161ed16792a5685c41716e Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Wed, 18 Dec 2019 15:52:47 +0800 Subject: [PATCH 35/79] nit: a little improvement --- src/rdma_transport.h | 2 +- src/rdma_van.h | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index f7a67434..8222ef69 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -426,12 +426,12 @@ class RDMATransport : public Transport { sge.length = msg_buf->data[1].size(); sge.lkey = temp_mr->second->lkey; + // this rdma-write will not trigger any signal both remotely and locally struct ibv_send_wr wr, *bad_wr = nullptr; memset(&wr, 0, sizeof(wr)); wr.wr_id = reinterpret_cast(raddr); wr.opcode = IBV_WR_RDMA_WRITE; wr.next = nullptr; - // wr.send_flags = IBV_SEND_SIGNALED; wr.sg_list = &sge; wr.num_sge = 1; wr.wr.rdma.remote_addr = raddr; diff --git a/src/rdma_van.h b/src/rdma_van.h index 2275b6da..13bfa24e 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -281,6 +281,11 @@ class RDMAVan : public Van { msgbuf_cache_.erase(buf); } + RemoteAddress GetRemoteInfo(uint64_t key, bool is_push) { + std::lock_guard lk(addr_mu_); + return (is_push ? push_addr_[key] : pull_addr_[key]); + } + int SendMsg(Message &msg) override { int remote_id = msg.meta.recver; CHECK_NE(remote_id, Meta::kEmpty); @@ -325,9 +330,7 @@ class RDMAVan : public Van { } } - std::lock_guard lk(addr_mu_); - auto key = msg.meta.key; - auto remote_addr_tuple = (msg.meta.push ? push_addr_[key] : pull_addr_[key]); + auto remote_addr_tuple = GetRemoteInfo(msg.meta.key, msg.meta.push); // already know remote address, directly use RDMA-write if (msg.meta.push && msg.meta.request) { From c1b45b902aa8a56eb580c6c8668dfa6ce8467924 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Wed, 18 Dec 2019 16:42:25 +0800 Subject: [PATCH 36/79] tests: use page aligned memory --- tests/test_kv_app_benchmark.cc | 47 ++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/tests/test_kv_app_benchmark.cc b/tests/test_kv_app_benchmark.cc index 437ec84f..768d0f64 100644 --- a/tests/test_kv_app_benchmark.cc +++ b/tests/test_kv_app_benchmark.cc @@ -1,8 +1,12 @@ #include #include #include +#include #include "ps/ps.h" +#define DIVUP(x, y) (((x)+(y)-1)/(y)) +#define ROUNDUP(x, y) (DIVUP((x), (y))*(y)) + using namespace ps; enum MODE { @@ -11,7 +15,18 @@ enum MODE { PUSH_ONLY = 2, PULL_ONLY = 3 }; -std::unordered_map > mem_map; +std::unordered_map > mem_map; + +void aligned_memory_alloc(void** ptr, size_t size) { + size_t page_size = sysconf(_SC_PAGESIZE); + void* p; + int size_aligned = ROUNDUP(size, page_size); + int ret = posix_memalign(&p, page_size, size_aligned); + CHECK_EQ(ret, 0) << "posix_memalign error: " << strerror(ret); + CHECK(p); + memset(p, 0, size); + *ptr = p; +} template void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer *server) { @@ -25,12 +40,15 @@ void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer PS_VLOG(1) << "receive key-" << key << " from worker-" << req_meta.sender; size_t len = (size_t) req_data.vals.size(); mem_map[key].keys.push_back(key); - mem_map[key].vals.CopyFrom(req_data.vals); mem_map[key].lens.push_back(len); + + void* ptr; + aligned_memory_alloc(&ptr, len); + mem_map[key].vals.reset((char*)ptr, len, [](void *){ }); } // send push response (empty) - KVPairs res; + KVPairs res; server->Response(req_meta, res); } else { @@ -42,8 +60,8 @@ void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer void StartServer() { if (!IsServer()) return; - auto server = new KVServer(0); - server->set_request_handle(EmptyHandler); + auto server = new KVServer(0); + server->set_request_handle(EmptyHandler); RegisterExitCallback([server]() { delete server; }); } @@ -53,7 +71,7 @@ struct PSKV { }; std::unordered_map ps_kv_; -void push_pull(KVWorker &kv, std::vector > &server_vals, +void push_pull(KVWorker &kv, std::vector > &server_vals, int len, int num_servers, int total_key_num, int how_many_key_per_server, MODE mode) { CHECK_GT(mode, 0); switch (mode) { @@ -110,7 +128,7 @@ void push_pull(KVWorker &kv, std::vector > &server_vals, end = std::chrono::high_resolution_clock::now(); LL << "Application goodput: " - << 8.0 * len * sizeof(float) * total_key_num * cnt / (end - start).count() + << 8.0 * len * sizeof(char) * total_key_num * cnt / (end - start).count() << " Gbps"; cnt = 0; start = std::chrono::high_resolution_clock::now(); @@ -120,7 +138,7 @@ void push_pull(KVWorker &kv, std::vector > &server_vals, void RunWorker(int argc, char *argv[]) { if (!IsWorker()) return; CHECK_GE(argc, 3) << "input argument should be at least 3: SCRIPT, LEN, REPEAT, (OPTIONAL) MODE"; - KVWorker kv(0, 0); + KVWorker kv(0, 0); auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); const int num_servers = krs.size(); @@ -136,10 +154,13 @@ void RunWorker(int argc, char *argv[]) { const int how_many_key_per_server = v ? atoi(v) : 10; const int total_key_num = num_servers * how_many_key_per_server; - std::vector > server_vals; + std::vector > server_vals; for (int key = 0; key < total_key_num; key++) { - std::vector vec(len); - SArray vals(vec); + std::vector vec(len); + void* ptr; + aligned_memory_alloc(&ptr, len); + SArray vals; + vals.reset((char*) ptr, len * sizeof(char), [](void *){}); server_vals.push_back(vals); } @@ -180,7 +201,7 @@ void RunWorker(int argc, char *argv[]) { auto end = std::chrono::high_resolution_clock::now(); accumulated_ms += (end - start).count(); // ns } - LL << "push " << len * sizeof(float) + LL << "push " << len * sizeof(char) << " bytes to each server, repeat=" << repeat << ", total_time=" << accumulated_ms / 1e6 << "ms"; @@ -202,7 +223,7 @@ void RunWorker(int argc, char *argv[]) { accumulated_ms += (end - start).count(); // ns } - LL << "pull " << len * sizeof(float) + LL << "pull " << len * sizeof(char) << " bytes to each server, repeat=" << repeat << ", total_time=" << accumulated_ms / 1e6 << "ms"; From 1b5b01c437c1a08165c37a95c5b5a46bc7fd21f6 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Wed, 18 Dec 2019 16:47:54 +0800 Subject: [PATCH 37/79] nit: some improvement --- tests/test_kv_app_benchmark.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_kv_app_benchmark.cc b/tests/test_kv_app_benchmark.cc index 768d0f64..6b1ce9e5 100644 --- a/tests/test_kv_app_benchmark.cc +++ b/tests/test_kv_app_benchmark.cc @@ -77,12 +77,15 @@ void push_pull(KVWorker &kv, std::vector > &server_vals, switch (mode) { case PUSH_PULL: LOG(INFO) << "========= PUSH_PULL mode ========="; + LOG(INFO) << "========= msg_size=" << len*sizeof(char) << " bytes ========="; break; case PUSH_ONLY: LOG(INFO) << "========= PUSH_ONLY mode ========="; + LOG(INFO) << "========= msg_size=" << len*sizeof(char) << " bytes ========="; break; case PULL_ONLY: LOG(INFO) << "========= PULL_ONLY mode ========="; + LOG(INFO) << "========= msg_size=" << len*sizeof(char) << " bytes ========="; break; default: CHECK(0); } @@ -156,7 +159,6 @@ void RunWorker(int argc, char *argv[]) { std::vector > server_vals; for (int key = 0; key < total_key_num; key++) { - std::vector vec(len); void* ptr; aligned_memory_alloc(&ptr, len); SArray vals; From 3c858862263180895ca93878469679b00f38bbf4 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Wed, 18 Dec 2019 17:45:50 +0800 Subject: [PATCH 38/79] mempool->Alloc returns page aligned memory --- src/rdma_utils.h | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/rdma_utils.h b/src/rdma_utils.h index 2bbfd657..5b9e4ee0 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -67,7 +67,6 @@ static const int kMaxConcurrentWorkRequest = kRxDepth + kStartDepth + kReplyDepth + kWriteDepth; static const int kMaxHostnameLength = 16; static const int kMaxDataFields = 4; -static const size_t kAlignment = 8; static const int kMaxResolveRetry = 50000; static const int kBasePort = 9010; @@ -105,6 +104,8 @@ class SimpleMempool { std::lock_guard lk(mu_); pd_ = pd; struct ibv_mr *mr; + + pagesize_ = sysconf(_SC_PAGESIZE); // set init mempool size auto byteps_rdma_mempool_size = Environment::Get()->find("BYTEPS_RDMA_MEMPOOL_SIZE"); @@ -141,11 +142,13 @@ class SimpleMempool { std::lock_guard lk(mu_); - size_t proper_size = align_ceil(size, kAlignment); + // use page aligned memory + size_t proper_size = align_ceil(size, pagesize_); auto it = free_list.lower_bound(proper_size); - if (it == free_list.end()) { // if there is no space left, need to allocate and register new memory + // if there is no space left, need to allocate and register new memory + if (it == free_list.end()) { size_t new_mem_size = total_allocated_size; while (proper_size > new_mem_size) { new_mem_size *= 2; @@ -228,6 +231,8 @@ class SimpleMempool { return it->second; } + size_t pagesize_; + }; class Block { From 137d7c8d798aa4c7617f9d5101d9ee910748f4b5 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Wed, 18 Dec 2019 18:37:09 +0800 Subject: [PATCH 39/79] tests: all keys/lens use page aligned mem --- include/ps/kv_app.h | 2 +- tests/test_kv_app_benchmark.cc | 76 ++++++++++++++++++++-------------- 2 files changed, 46 insertions(+), 32 deletions(-) diff --git a/include/ps/kv_app.h b/include/ps/kv_app.h index fcfd34ff..cf37f329 100644 --- a/include/ps/kv_app.h +++ b/include/ps/kv_app.h @@ -582,7 +582,7 @@ void KVWorker::DefaultSlicer( k = send.vals.size() / send.keys.size(); CHECK_EQ(k * send.keys.size(), send.vals.size()); } else { - CHECK_EQ(send.keys.size(), send.lens.size()); + CHECK_EQ(send.keys.size(), send.lens.size()) << send.keys.size() << ", " << send.lens.size(); } // slice diff --git a/tests/test_kv_app_benchmark.cc b/tests/test_kv_app_benchmark.cc index 6b1ce9e5..b971b242 100644 --- a/tests/test_kv_app_benchmark.cc +++ b/tests/test_kv_app_benchmark.cc @@ -39,12 +39,20 @@ void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer if (mem_map.find(key) == mem_map.end()) { PS_VLOG(1) << "receive key-" << key << " from worker-" << req_meta.sender; size_t len = (size_t) req_data.vals.size(); - mem_map[key].keys.push_back(key); - mem_map[key].lens.push_back(len); - void* ptr; - aligned_memory_alloc(&ptr, len); - mem_map[key].vals.reset((char*)ptr, len, [](void *){ }); + void* ptr_val; + aligned_memory_alloc(&ptr_val, len); + mem_map[key].vals.reset((char*)ptr_val, len, [](void *){ }); + + void* ptr_key; + aligned_memory_alloc(&ptr_key, sizeof(Key)); + mem_map[key].keys.reset((Key*)ptr_key, 1, [](void *){ }); + memcpy(ptr_key, &key, sizeof(Key)); + + void* ptr_len; + aligned_memory_alloc(&ptr_len, sizeof(int)); + mem_map[key].lens.reset((int*)ptr_len, 1, [](void *){ }); + memcpy(ptr_len, &len, sizeof(int)); } // send push response (empty) @@ -65,14 +73,12 @@ void StartServer() { RegisterExitCallback([server]() { delete server; }); } -struct PSKV { - SArray keys; // n keys - SArray lens; // the length of the i-th value -}; -std::unordered_map ps_kv_; - -void push_pull(KVWorker &kv, std::vector > &server_vals, - int len, int num_servers, int total_key_num, int how_many_key_per_server, MODE mode) { +void push_pull(KVWorker &kv, + std::vector > &server_keys, + std::vector > &server_vals, + std::vector > &server_lens, + int len, int num_servers, int total_key_num, + int how_many_key_per_server, MODE mode) { CHECK_GT(mode, 0); switch (mode) { case PUSH_PULL: @@ -100,9 +106,8 @@ void push_pull(KVWorker &kv, std::vector > &server_vals, int cnt = 0; while (1) { for (int key = 0; key < total_key_num; key++) { - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; + auto keys = server_keys[key]; + auto lens = server_lens[key]; auto vals = server_vals[key]; switch (mode) { @@ -158,6 +163,8 @@ void RunWorker(int argc, char *argv[]) { const int total_key_num = num_servers * how_many_key_per_server; std::vector > server_vals; + std::vector > server_keys; + std::vector > server_lens; for (int key = 0; key < total_key_num; key++) { void* ptr; aligned_memory_alloc(&ptr, len); @@ -168,18 +175,27 @@ void RunWorker(int argc, char *argv[]) { // init push, do not count this into time cost for (int key = 0; key < total_key_num; key++) { - auto vals = server_vals[key]; - PSKV& pskv = ps_kv_[key]; - SArray keys; - int server = key % num_servers; PS_VLOG(1) << "key=" << key << " assigned to server " << server; + + auto vals = server_vals[key]; + + // page aligned keys + void* ptr_key; + aligned_memory_alloc(&ptr_key, sizeof(Key)); + SArray keys; + keys.reset((Key*) ptr_key, 1, [](void *){}); ps::Key ps_key = krs[server].begin() + key; - keys.push_back(ps_key); + memcpy(ptr_key, &ps_key, sizeof(Key)); + server_keys.push_back(keys); + + // page aligned vals + void* ptr_len; + aligned_memory_alloc(&ptr_len, sizeof(int)); SArray lens; - lens.push_back(len); - pskv.keys.push_back(ps_key); - pskv.lens.push_back(len); + lens.reset((int*) ptr_len, 1, [](void *){}); + memcpy(ptr_len, &len, sizeof(len)); + server_lens.push_back(lens); kv.Wait(kv.ZPush(keys, vals, lens)); } @@ -193,9 +209,8 @@ void RunWorker(int argc, char *argv[]) { auto start = std::chrono::high_resolution_clock::now(); for (int server = 0; server < num_servers; server++) { int key = server; - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; + auto keys = server_keys[server]; + auto lens = server_lens[server]; auto vals = server_vals[server]; kv.Wait(kv.ZPush(keys, vals, lens)); @@ -214,9 +229,8 @@ void RunWorker(int argc, char *argv[]) { auto start = std::chrono::high_resolution_clock::now(); for (int server = 0; server < num_servers; server++) { int key = server; - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; + auto keys = server_keys[server]; + auto lens = server_lens[server]; auto vals = server_vals[server]; kv.Wait(kv.ZPull(keys, &vals, &lens)); @@ -233,7 +247,7 @@ void RunWorker(int argc, char *argv[]) { case PUSH_PULL: case PUSH_ONLY: case PULL_ONLY: - push_pull(kv, server_vals, len, num_servers, total_key_num, how_many_key_per_server, mode); + push_pull(kv, server_keys, server_vals, server_lens, len, num_servers, total_key_num, how_many_key_per_server, mode); break; default: CHECK(0) << "unknown mode " << mode; From a14acfe0440acba6c26972331949315e6590e4d8 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Wed, 18 Dec 2019 21:47:27 +0800 Subject: [PATCH 40/79] push request: only write one sge --- include/ps/kv_app.h | 2 +- src/rdma_transport.h | 55 +++++++++++++++++++++++++++++++++++--------- 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/include/ps/kv_app.h b/include/ps/kv_app.h index cf37f329..fcfd34ff 100644 --- a/include/ps/kv_app.h +++ b/include/ps/kv_app.h @@ -582,7 +582,7 @@ void KVWorker::DefaultSlicer( k = send.vals.size() / send.keys.size(); CHECK_EQ(k * send.keys.size(), send.vals.size()); } else { - CHECK_EQ(send.keys.size(), send.lens.size()) << send.keys.size() << ", " << send.lens.size(); + CHECK_EQ(send.keys.size(), send.lens.size()); } // slice diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 8222ef69..7a0b247a 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -252,16 +252,25 @@ class RDMATransport : public Transport { size_t num_sge = 1; uint64_t data_len = 0; - for (auto &pair : msg_buf->mrs) { - size_t length = pair.second; - CHECK(length); - sge[num_sge].addr = - reinterpret_cast(pair.first->addr); - sge[num_sge].length = length; - sge[num_sge].lkey = pair.first->lkey; + if (msg_buf->mrs.size() == 3) { + // push request, only write vals + sge[1].addr = reinterpret_cast(msg_buf->mrs[1].first->addr); + sge[1].length = msg_buf->mrs[1].second; + sge[1].lkey = msg_buf->mrs[1].first->lkey; ++num_sge; - - data_len += length; + data_len += sge[1].length; + } else { + for (auto &pair : msg_buf->mrs) { + size_t length = pair.second; + CHECK(length); + sge[num_sge].addr = + reinterpret_cast(pair.first->addr); + sge[num_sge].length = length; + sge[num_sge].lkey = pair.first->lkey; + ++num_sge; + + data_len += length; + } } WRContext *write_ctx = msg_buf->reserved_context; @@ -459,6 +468,7 @@ class RDMATransport : public Transport { std::lock_guard lock(map_mu_); auto key = msg->meta.key; if (key_len_map_.find(key) == key_len_map_.end()) { + // need a static address for keys/lens key_addr_map_[key] = (ps::Key) key; key_len_map_[key] = (int) msg->meta.val_len; } @@ -496,9 +506,33 @@ class RDMATransport : public Transport { return 0; } - int total_data_len = 0; char *cur = buffer_ctx->buffer + meta_len; // offset + if (msg->meta.push && msg->meta.request) { // push request + CHECK_EQ(data_num, 3); + uint32_t len = buffer_ctx->data_len[1]; + + SArray keys; + void *p = malloc(sizeof(Key)); + memcpy(p, &msg->meta.key, sizeof(Key)); + keys.reset((char *) p, sizeof(Key), [p](void *) { free(p); }); + + SArray vals; + vals.reset(cur, len, [](void *) {}); // no need to delete + + SArray lens; + void *q = malloc(sizeof(int)); + memcpy(q, &len, sizeof(int)); + lens.reset((char *) q, sizeof(int), [q](void *) { free(q); }); + + msg->data.push_back(keys); + msg->data.push_back(vals); + msg->data.push_back(lens); + + return sizeof(Key) + len + sizeof(int); + } + + int total_data_len = 0; for (size_t i = 0; i < data_num; i++) { uint32_t len = buffer_ctx->data_len[i]; SArray data; @@ -507,7 +541,6 @@ class RDMATransport : public Transport { cur += len; total_data_len += len; } - return total_data_len; } From 3a9f237c4dfa2e3e0e46e6f2ef2f397b4c4e81d3 Mon Sep 17 00:00:00 2001 From: Yibo Zhu Date: Wed, 18 Dec 2019 22:21:25 +0800 Subject: [PATCH 41/79] Split one write into two writes --- src/rdma_transport.h | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 7a0b247a..31db406d 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -252,13 +252,26 @@ class RDMATransport : public Transport { size_t num_sge = 1; uint64_t data_len = 0; - if (msg_buf->mrs.size() == 3) { - // push request, only write vals - sge[1].addr = reinterpret_cast(msg_buf->mrs[1].first->addr); - sge[1].length = msg_buf->mrs[1].second; - sge[1].lkey = msg_buf->mrs[1].first->lkey; - ++num_sge; - data_len += sge[1].length; + if (msg_buf->mrs.size() == 3) { + struct ibv_sge my_sge; + my_sge.addr = reinterpret_cast(msg_buf->mrs[1].first->addr); + my_sge.length = msg_buf->mrs[1].second; + my_sge.lkey = msg_buf->mrs[1].first->lkey; + + // this rdma-write will not trigger any signal both remotely and locally + struct ibv_send_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = 0; + wr.opcode = IBV_WR_RDMA_WRITE; + wr.next = nullptr; + wr.sg_list = &my_sge; + wr.num_sge = 1; + wr.wr.rdma.remote_addr = remote_addr + msg_buf->inline_len; + wr.wr.rdma.rkey = rkey; + + CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) + << "ibv_post_send failed."; + } else { for (auto &pair : msg_buf->mrs) { size_t length = pair.second; From 20b3782a14b9708837bb47a920ca8de2c9da5970 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Wed, 18 Dec 2019 23:03:33 +0800 Subject: [PATCH 42/79] server receives push data with aligned memory --- src/rdma_transport.h | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 31db406d..2ce06adc 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -210,6 +210,7 @@ class RDMATransport : public Transport { auto val = Environment::Get()->find("DMLC_ROLE"); std::string role(val); is_server_ = (role=="server"); + pagesize_ = sysconf(_SC_PAGESIZE); }; ~RDMATransport() { @@ -252,7 +253,8 @@ class RDMATransport : public Transport { size_t num_sge = 1; uint64_t data_len = 0; - if (msg_buf->mrs.size() == 3) { + if (msg_buf->mrs.size() == 3) { + // push request, split the meta and data into two writes struct ibv_sge my_sge; my_sge.addr = reinterpret_cast(msg_buf->mrs[1].first->addr); my_sge.length = msg_buf->mrs[1].second; @@ -266,7 +268,7 @@ class RDMATransport : public Transport { wr.next = nullptr; wr.sg_list = &my_sge; wr.num_sge = 1; - wr.wr.rdma.remote_addr = remote_addr + msg_buf->inline_len; + wr.wr.rdma.remote_addr = remote_addr + align_ceil(msg_buf->inline_len, pagesize_); wr.wr.rdma.rkey = rkey; CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) @@ -524,6 +526,8 @@ class RDMATransport : public Transport { if (msg->meta.push && msg->meta.request) { // push request CHECK_EQ(data_num, 3); uint32_t len = buffer_ctx->data_len[1]; + + cur = buffer_ctx->buffer + align_ceil((size_t)meta_len, pagesize_); SArray keys; void *p = malloc(sizeof(Key)); @@ -558,6 +562,7 @@ class RDMATransport : public Transport { } protected: + size_t pagesize_ = 4096; Endpoint *endpoint_; SimpleMempool *mempool_; // role is server or worker From 50c1bf49e08f6dbe229dc2749c0dee1a922573d6 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Thu, 19 Dec 2019 12:32:47 +0800 Subject: [PATCH 43/79] tests: clean testcase and set default value --- tests/test_connection.cc | 8 --- tests/test_kv_app.cc | 65 ----------------- tests/test_kv_app_benchmark.cc | 6 +- tests/test_kv_app_multi_servers.cc | 108 ----------------------------- tests/test_kv_app_multi_workers.cc | 71 ------------------- tests/test_simple_app.cc | 35 ---------- 6 files changed, 3 insertions(+), 290 deletions(-) delete mode 100644 tests/test_connection.cc delete mode 100644 tests/test_kv_app.cc delete mode 100644 tests/test_kv_app_multi_servers.cc delete mode 100644 tests/test_kv_app_multi_workers.cc delete mode 100644 tests/test_simple_app.cc diff --git a/tests/test_connection.cc b/tests/test_connection.cc deleted file mode 100644 index eeea7618..00000000 --- a/tests/test_connection.cc +++ /dev/null @@ -1,8 +0,0 @@ -#include "ps/ps.h" - -int main(int argc, char *argv[]) { - ps::Start(0); - // do nothing - ps::Finalize(0, true); - return 0; -} diff --git a/tests/test_kv_app.cc b/tests/test_kv_app.cc deleted file mode 100644 index aa143304..00000000 --- a/tests/test_kv_app.cc +++ /dev/null @@ -1,65 +0,0 @@ -#include -#include "ps/ps.h" -#include - -using namespace ps; - -void StartServer() { - if (!IsServer()) { - return; - } - auto server = new KVServer(0); - server->set_request_handle(KVServerDefaultHandle()); - RegisterExitCallback([server](){ delete server; }); -} - -void RunWorker() { - if (!IsWorker()) return; - KVWorker kv(0, 0); - - // init - int num = 10000; - std::vector keys(num); - std::vector vals(num); - - int rank = MyRank(); - srand(rank + 7); - for (int i = 0; i < num; ++i) { - keys[i] = kMaxKey / num * i + rank; - vals[i] = (rand() % 1000); - } - - // push - int repeat = 50; - std::vector ts; - for (int i = 0; i < repeat; ++i) { - ts.push_back(kv.Push(keys, vals)); - - // to avoid too frequency push, which leads huge memory usage - if (i > 10) kv.Wait(ts[ts.size()-10]); - } - for (int t : ts) kv.Wait(t); - - // pull - std::vector rets; - kv.Wait(kv.Pull(keys, &rets)); - - float res = 0; - for (int i = 0; i < num; ++i) { - res += std::fabs(rets[i] - vals[i] * repeat); - } - CHECK_LT(res / repeat, 1e-5); - LL << "error: " << res / repeat; -} - -int main(int argc, char *argv[]) { - // start system - Start(0); - // setup server nodes - StartServer(); - // run worker nodes - RunWorker(); - // stop system - Finalize(0, true); - return 0; -} diff --git a/tests/test_kv_app_benchmark.cc b/tests/test_kv_app_benchmark.cc index b971b242..611e1b7d 100644 --- a/tests/test_kv_app_benchmark.cc +++ b/tests/test_kv_app_benchmark.cc @@ -154,12 +154,12 @@ void RunWorker(int argc, char *argv[]) { CHECK_GT(num_servers, 0); // init - int len = atoi(argv[1]); - int repeat = atoi(argv[2]); + int len = (argc > 1) ? atoi(argv[1]) : 1024000; + int repeat = (argc > 2) ? atoi(argv[2]) : 10; MODE mode = (argc > 3) ? static_cast(atoi(argv[3])) : PUSH_PULL; auto v = Environment::Get()->find("NUM_KEY_PER_SERVER"); - const int how_many_key_per_server = v ? atoi(v) : 10; + const int how_many_key_per_server = v ? atoi(v) : 40; const int total_key_num = num_servers * how_many_key_per_server; std::vector > server_vals; diff --git a/tests/test_kv_app_multi_servers.cc b/tests/test_kv_app_multi_servers.cc deleted file mode 100644 index d4cc44c9..00000000 --- a/tests/test_kv_app_multi_servers.cc +++ /dev/null @@ -1,108 +0,0 @@ -/* -This code only works for 1 worker VS N server -*/ -#include -#include -#include -#include "ps/ps.h" - -using namespace ps; - -std::unordered_map > mem_map; - -template -void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer *server) { - uint64_t key = req_data.keys[0]; - if (req_meta.push) { - CHECK(req_data.lens.size()); - CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]); - - if (mem_map.find(key) == mem_map.end()) { - size_t len = (size_t) req_data.vals.size(); - mem_map[key].keys.push_back(key); - mem_map[key].vals.CopyFrom(req_data.vals); - mem_map[key].lens.push_back(len); - } - - // send push response (empty) - KVPairs res; - server->Response(req_meta, res); - } - else { - auto iter = mem_map.find(key); - CHECK_NE(iter, mem_map.end()); - server->Response(req_meta, iter->second); - } -} - -void StartServer() { - if (!IsServer()) return; - int myrank = Postoffice::Get()->my_rank(); - LOG(INFO) << "This is server " << myrank; - auto server = new KVServer(0); - server->set_request_handle(EmptyHandler); - RegisterExitCallback([server]() { delete server; }); -} - -void RunWorker(int argc, char *argv[]) { - if (!IsWorker()) return; - CHECK_EQ(argc, 2) << "input argument should be: [SCRIPT, LEN]"; - KVWorker kv(0, 0); - - auto krs = Postoffice::Get()->GetServerKeyRanges(); - const int num_servers = krs.size(); - CHECK_GT(num_servers, 0); - - // init - int len = atoi(argv[1]); - - std::vector vec(len); - - std::vector > keys(num_servers); - std::vector > lens(num_servers); - std::vector > vals; - - for (int i = 0; i < num_servers ; ++i) { - int key = i; - int server = (key * 9973) % num_servers; - ps::Key ps_key = krs[server].begin() + key; - CHECK_LT(ps_key, krs[server].end()); - - SArray tmp_vals(vec); - - keys[i].push_back(ps_key); - lens[i].push_back(len); - vals.push_back(tmp_vals); - - // init push, to register memory, better not count this into time cost - kv.Wait(kv.ZPush(keys[i], vals[i], lens[i])); - } - - std::vector timestamp_list; - while (1) { - for (int j = 0; j < num_servers; ++j) { - timestamp_list.push_back(kv.ZPush(keys[j], vals[j], lens[j])); - timestamp_list.push_back(kv.ZPull(keys[j], &vals[j], &lens[j])); - } - if (timestamp_list.size() >= 30) { // flow control - for (auto& ts : timestamp_list) { - kv.Wait(ts); - } - timestamp_list.clear(); - } - } -} - -int main(int argc, char *argv[]) { - // disable multi-threaded processing first - setenv("ENABLE_SERVER_MULTIPULL", "0", 1); - // start system - Start(0); - // setup server nodes - StartServer(); - // run worker nodes - RunWorker(argc, argv); - // stop system - Finalize(0, true); - return 0; -} diff --git a/tests/test_kv_app_multi_workers.cc b/tests/test_kv_app_multi_workers.cc deleted file mode 100644 index 636d494b..00000000 --- a/tests/test_kv_app_multi_workers.cc +++ /dev/null @@ -1,71 +0,0 @@ -#include -#include "ps/ps.h" -using namespace ps; - -void StartServer() { - if (!IsServer()) return; - auto server = new KVServer(0); - server->set_request_handle(KVServerDefaultHandle()); - RegisterExitCallback([server](){ delete server; }); -} - -void RunWorker(int customer_id) { - Start(customer_id); - if (!IsWorker()) { - return; - } - KVWorker kv(0, customer_id); - // init - int num = 10000; - std::vector keys(num); - std::vector vals(num); - - int rank = MyRank(); - srand(rank + 7); - for (int i = 0; i < num; ++i) { - keys[i] = kMaxKey / num * i + customer_id; - vals[i] = (rand() % 1000); - } - // push - int repeat = 50; - std::vector ts; - for (int i = 0; i < repeat; ++i) { - ts.push_back(kv.Push(keys, vals)); - - // to avoid too frequency push, which leads huge memory usage - if (i > 10) kv.Wait(ts[ts.size()-10]); - } - for (int t : ts) kv.Wait(t); - - // pull - std::vector rets; - kv.Wait(kv.Pull(keys, &rets)); - - float res = 0; - for (int i = 0; i < num; ++i) { - res += fabs(rets[i] - vals[i] * repeat); - } - CHECK_LT(res / repeat, 1e-5); - LL << "error: " << res / repeat; - // stop system - Finalize(customer_id, true); -} - -int main(int argc, char *argv[]) { - // start system - bool isWorker = (strcmp(argv[1], "worker") == 0); - if (!isWorker) { - Start(0); - // setup server nodes - StartServer(); - Finalize(0, true); - return 0; - } - // run worker nodes - std::thread t0(RunWorker, 0); - std::thread t1(RunWorker, 1); - - t0.join(); - t1.join(); - return 0; -} diff --git a/tests/test_simple_app.cc b/tests/test_simple_app.cc deleted file mode 100644 index cecf5de7..00000000 --- a/tests/test_simple_app.cc +++ /dev/null @@ -1,35 +0,0 @@ -#include "ps/ps.h" -using namespace ps; - -int num = 0; - -void ReqHandle(const SimpleData& req, SimpleApp* app) { - CHECK_EQ(req.head, 1); - CHECK_EQ(req.body, "test"); - app->Response(req); - ++ num; -} - -int main(int argc, char *argv[]) { - int n = 100; - Start(0); - SimpleApp app(0, 0); - app.set_request_handle(ReqHandle); - - if (IsScheduler()) { - std::vector ts; - for (int i = 0; i < n; ++i) { - int recver = kScheduler + kServerGroup + kWorkerGroup; - ts.push_back(app.Request(1, "test", recver)); - } - - for (int t : ts) { - app.Wait(t); - } - } - - Finalize(0, true); - - CHECK_EQ(num, n); - return 0; -} From 3f8ee6ff2a84170ad0f43d88ca000978fd41c875 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Thu, 19 Dec 2019 12:34:17 +0800 Subject: [PATCH 44/79] quick fix --- tests/test_kv_app_benchmark.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_kv_app_benchmark.cc b/tests/test_kv_app_benchmark.cc index 611e1b7d..1dbfb2e4 100644 --- a/tests/test_kv_app_benchmark.cc +++ b/tests/test_kv_app_benchmark.cc @@ -145,7 +145,6 @@ void push_pull(KVWorker &kv, void RunWorker(int argc, char *argv[]) { if (!IsWorker()) return; - CHECK_GE(argc, 3) << "input argument should be at least 3: SCRIPT, LEN, REPEAT, (OPTIONAL) MODE"; KVWorker kv(0, 0); auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); From fd391bdddf9c3071cf606853761ac4d89e4d7669 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Thu, 19 Dec 2019 16:00:07 +0800 Subject: [PATCH 45/79] remove key_addr_map and key_len_map --- src/rdma_transport.h | 50 +++++++++++++++++--------------------------- 1 file changed, 19 insertions(+), 31 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 2ce06adc..276f5bab 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -207,10 +207,12 @@ class RDMATransport : public Transport { explicit RDMATransport(Endpoint *endpoint, SimpleMempool *mempool) { endpoint_ = CHECK_NOTNULL(endpoint); mempool_ = CHECK_NOTNULL(mempool); + pagesize_ = sysconf(_SC_PAGESIZE); + PS_VLOG(1) << "System page size is " << pagesize_; + auto val = Environment::Get()->find("DMLC_ROLE"); std::string role(val); is_server_ = (role=="server"); - pagesize_ = sysconf(_SC_PAGESIZE); }; ~RDMATransport() { @@ -255,6 +257,7 @@ class RDMATransport : public Transport { uint64_t data_len = 0; if (msg_buf->mrs.size() == 3) { // push request, split the meta and data into two writes + // further, it does not send keys and lens since these meta already carries these info struct ibv_sge my_sge; my_sge.addr = reinterpret_cast(msg_buf->mrs[1].first->addr); my_sge.length = msg_buf->mrs[1].second; @@ -268,9 +271,11 @@ class RDMATransport : public Transport { wr.next = nullptr; wr.sg_list = &my_sge; wr.num_sge = 1; - wr.wr.rdma.remote_addr = remote_addr + align_ceil(msg_buf->inline_len, pagesize_); wr.wr.rdma.rkey = rkey; + // write to the next page-aligned address (remote_addr should already be aligned) + wr.wr.rdma.remote_addr = remote_addr + align_ceil(msg_buf->inline_len, pagesize_); + CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) << "ibv_post_send failed."; @@ -479,35 +484,27 @@ class RDMATransport : public Transport { } virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) { - int total_data_len = 0; std::lock_guard lock(map_mu_); auto key = msg->meta.key; - if (key_len_map_.find(key) == key_len_map_.end()) { - // need a static address for keys/lens - key_addr_map_[key] = (ps::Key) key; - key_len_map_[key] = (int) msg->meta.val_len; - } - CHECK_NE(key_len_map_.find(key), key_len_map_.end()) << key; - CHECK_NE(key_addr_map_.find(key), key_addr_map_.end()) << key; - auto addr = msg->meta.addr; - CHECK_NE(key_len_map_[key], 0) << msg->DebugString(); - SArray keys; + void *p = malloc(sizeof(Key)); + memcpy(p, &msg->meta.key, sizeof(Key)); + keys.reset((char *) p, sizeof(Key), [p](void *) { free(p); }); + SArray vals; - SArray lens; + vals.reset(reinterpret_cast(addr), msg->meta.val_len, [](void *){}); - keys.reset(reinterpret_cast(&key_addr_map_[key]), sizeof(ps::Key), [](void *){}); - vals.reset(reinterpret_cast(addr), key_len_map_[key], [](void *){}); - lens.reset(reinterpret_cast(&key_len_map_[key]), sizeof(int), [](void *){}); + SArray lens; + void *q = malloc(sizeof(int)); + memcpy(q, &msg->meta.val_len, sizeof(int)); + lens.reset((char *) q, sizeof(int), [q](void *) { free(q); }); msg->data.push_back(keys); msg->data.push_back(vals); msg->data.push_back(lens); - total_data_len += keys.size() + vals.size() + lens.size(); - - return total_data_len; + return keys.size() + vals.size() + lens.size(); } virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) { @@ -527,7 +524,7 @@ class RDMATransport : public Transport { CHECK_EQ(data_num, 3); uint32_t len = buffer_ctx->data_len[1]; - cur = buffer_ctx->buffer + align_ceil((size_t)meta_len, pagesize_); + cur = buffer_ctx->buffer + align_ceil((size_t) meta_len, pagesize_); SArray keys; void *p = malloc(sizeof(Key)); @@ -565,18 +562,9 @@ class RDMATransport : public Transport { size_t pagesize_ = 4096; Endpoint *endpoint_; SimpleMempool *mempool_; - // role is server or worker bool is_server_; - - // manage the following map std::mutex map_mu_; - - // (memory, ibv_mr) - std::unordered_map mem_mr_map_; - - // store the static address for keys and lens - std::unordered_map key_addr_map_; - std::unordered_map key_len_map_; + std::unordered_map mem_mr_map_; // (memory, ibv_mr) }; // class Transport From be13a5dc7b7f7109be7950d13bbc63a8bcca6567 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Thu, 19 Dec 2019 16:30:22 +0800 Subject: [PATCH 46/79] some cleaning --- src/rdma_van.h | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/rdma_van.h b/src/rdma_van.h index 13bfa24e..7d638c63 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -322,9 +322,6 @@ class RDMAVan : public Van { auto is_push = msg.meta.push; auto key = msg.meta.key; if (!HasRemoteInfo(msg_buf, key, is_push)) { - // LOG(INFO) << "Call SendRendezvousBegin" << ", key=" << key - // << ", " << (is_push?"push":"pull") << " " << (msg.meta.request?"request":"response") - // << ", push_addr.size=" << push_addr_.size(); trans->SendRendezvousBegin(msg, msg_buf); return total_len; } @@ -335,19 +332,15 @@ class RDMAVan : public Van { // already know remote address, directly use RDMA-write if (msg.meta.push && msg.meta.request) { // worker, push request - // LOG(INFO) << "SEND PUSH REQUEST, key=" << key; trans->SendPushRequest(msg, msg_buf, remote_addr_tuple); } else if (msg.meta.push && !msg.meta.request) { // server, push response - // LOG(INFO) << "SEND PUSH RESPONSE, key=" << key; trans->SendPushResponse(msg, msg_buf, remote_addr_tuple); } else if (!msg.meta.push && msg.meta.request) { // worker, pull request - // LOG(INFO) << "SEND PULL REQUEST, key=" << key; trans->SendPullRequest(msg, msg_buf, remote_addr_tuple); } else if (!msg.meta.push && !msg.meta.request) { // server, pull response - // LOG(INFO) << "SEND PULL RESPONSE, key=" << key; trans->SendPullResponse(msg, msg_buf, remote_addr_tuple); } else { CHECK(0) << "unexpected message type"; @@ -386,20 +379,16 @@ class RDMAVan : public Van { // valid data message if (msg->meta.push && msg->meta.request) { // push request - // LOG(INFO) << "RECV PUSH REQUEST, key=" << msg->meta.key; total_len += trans->RecvPushRequest(msg, buffer_ctx, meta_len); StoreWorkerTensorAddress(msg); } else if (!msg->meta.push && msg->meta.request) { // pull request - // LOG(INFO) << "RECV PULL REQUEST, key=" << msg->meta.key; total_len += trans->RecvPullRequest(msg, buffer_ctx, meta_len); } else if (msg->meta.push && !msg->meta.request) { // push response - // LOG(INFO) << "RECV PUSH RESPONSE, key=" << msg->meta.key; total_len += trans->RecvPushResponse(msg, buffer_ctx, meta_len); } else if (!msg->meta.push && !msg->meta.request) { // pull response - // LOG(INFO) << "RECV PULL RESPONSE, key=" << msg->meta.key; total_len += trans->RecvPullResponse(msg, buffer_ctx, meta_len); } else { CHECK(0) << "unknown msg type"; @@ -728,14 +717,6 @@ class RDMAVan : public Van { // Recv buffer queue ThreadsafeQueue> recv_buffers_; - // RDMA logging info - bool enable_rdma_log_; - - // a static address for the key - std::unordered_map key_addr_map_; - // a static address for the length - std::unordered_map key_len_map_; - // local IPC related bool disable_ipc_ = false; std::mutex local_mu_; From b0830ca062522639dad1a7d9d638713a247f2fb2 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Thu, 19 Dec 2019 17:02:05 +0800 Subject: [PATCH 47/79] can run 2v2 --- src/rdma_transport.h | 7 ++++--- src/rdma_utils.h | 4 +++- src/rdma_van.h | 27 +++++++++++++++++---------- tests/test_kv_app_benchmark.cc | 1 - 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 276f5bab..b9d76a2f 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -420,10 +420,11 @@ class RDMATransport : public Transport { endpoint_->free_write_ctx.WaitAndPop(&reserved); msg_buf->reserved_context = reserved; auto key = msg.meta.key; + auto recver = msg.meta.recver; - auto raddr = std::get<0>(remote_addr); - auto rkey = std::get<1>(remote_addr); - auto idx = std::get<2>(remote_addr); + auto raddr = std::get<0>(remote_addr[recver]); + auto rkey = std::get<1>(remote_addr[recver]); + auto idx = std::get<2>(remote_addr[recver]); RDMAWriteWithImm(msg_buf, raddr, rkey, idx); } diff --git a/src/rdma_utils.h b/src/rdma_utils.h index 5b9e4ee0..fcde3120 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -300,7 +300,9 @@ struct BufferContext { typedef std::unique_ptr> MRPtr; -typedef std::tuple RemoteAddress; // +// recver, +typedef std::unordered_map > + RemoteAddress; struct MessageBuffer { size_t inline_len; diff --git a/src/rdma_van.h b/src/rdma_van.h index 7d638c63..11c58373 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -257,12 +257,18 @@ class RDMAVan : public Van { } } - bool HasRemoteInfo(MessageBuffer *msg_buf, uint64_t key, bool is_push) { + bool HasRemoteInfo(MessageBuffer *msg_buf, uint64_t key, bool is_push, int recver) { std::lock_guard lk(addr_mu_); - if ( is_push && (push_addr_.find(key) != push_addr_.end())) return true; - if (!is_push && (pull_addr_.find(key) != pull_addr_.end())) return true; + if (is_push && (push_addr_.find(key) != push_addr_.end()) + && (push_addr_[key].find(recver) != push_addr_[key].end())) { + return true; + } + if (!is_push && (pull_addr_.find(key) != pull_addr_.end()) + && (pull_addr_[key].find(recver) != pull_addr_[key].end())) { + return true; + } // no remote info, store the msg_buf address and push/pull flag for RendezvousReply - msgbuf_cache_.emplace(reinterpret_cast(msg_buf), std::make_pair(key, is_push)); + msgbuf_cache_.emplace(reinterpret_cast(msg_buf), std::make_tuple(key, is_push, recver)); return false; } @@ -272,10 +278,11 @@ class RDMAVan : public Van { std::lock_guard lk(addr_mu_); auto key = std::get<0>(msgbuf_cache_[buf]); auto is_push = std::get<1>(msgbuf_cache_[buf]); + auto recver = std::get<2>(msgbuf_cache_[buf]); if (is_push) { - push_addr_[key] = std::make_tuple(remote_addr, rkey, idx); + push_addr_[key][recver] = std::make_tuple(remote_addr, rkey, idx); } else { - pull_addr_[key] = std::make_tuple(remote_addr, rkey, idx); + pull_addr_[key][recver] = std::make_tuple(remote_addr, rkey, idx); } CHECK_NE(msgbuf_cache_.find(buf), msgbuf_cache_.end()); msgbuf_cache_.erase(buf); @@ -321,7 +328,7 @@ class RDMAVan : public Van { trans->PrepareData(msg, msg_buf); auto is_push = msg.meta.push; auto key = msg.meta.key; - if (!HasRemoteInfo(msg_buf, key, is_push)) { + if (!HasRemoteInfo(msg_buf, key, is_push, remote_id)) { trans->SendRendezvousBegin(msg, msg_buf); return total_len; } @@ -730,9 +737,9 @@ class RDMAVan : public Van { // store rendezvous address std::mutex addr_mu_; - std::unordered_map push_addr_; // key, - std::unordered_map pull_addr_; // key, - std::unordered_map > msgbuf_cache_; // msg_buf, + std::unordered_map push_addr_; // , + std::unordered_map pull_addr_; // , + std::unordered_map > msgbuf_cache_; // msg_buf, }; // class RDMAVan }; // namespace ps diff --git a/tests/test_kv_app_benchmark.cc b/tests/test_kv_app_benchmark.cc index 1dbfb2e4..209ed223 100644 --- a/tests/test_kv_app_benchmark.cc +++ b/tests/test_kv_app_benchmark.cc @@ -37,7 +37,6 @@ void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer << "key=" << key << ", " << req_data.vals.size() << ", " << req_data.lens[0]; if (mem_map.find(key) == mem_map.end()) { - PS_VLOG(1) << "receive key-" << key << " from worker-" << req_meta.sender; size_t len = (size_t) req_data.vals.size(); void* ptr_val; From ab198cd108253037320c0a965dbe22d99b4ef3ab Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Thu, 19 Dec 2019 19:14:08 +0800 Subject: [PATCH 48/79] finish IPCTransport --- src/rdma_transport.h | 178 +++++++++++++++++++++++-------------------- 1 file changed, 95 insertions(+), 83 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index b9d76a2f..d2112510 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -167,20 +167,11 @@ struct Endpoint { } }; -struct AsyncCopy { - Endpoint* endpoint; - MessageBuffer* msg_buf; - void* dst; - void* src; - int len; - uint64_t meta_len; - bool shutdown; -}; - - class Transport { public: virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; + virtual void PostRDMAWriteWithImm(MessageBuffer *msg_buf, struct ibv_sge *sge, size_t num_sge, + uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; virtual int Recv(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; virtual int RecvPullRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; @@ -199,6 +190,8 @@ class Transport { virtual void SendRendezvousBegin(Message &msg, MessageBuffer *msg_buf) = 0; virtual void SendRendezvousReply(RendezvousStart *req, AddressPool &pool) = 0; + virtual SArray CreateFunctionalSarray(size_t size, void *value) = 0; + }; // class Transport @@ -208,7 +201,6 @@ class RDMATransport : public Transport { endpoint_ = CHECK_NOTNULL(endpoint); mempool_ = CHECK_NOTNULL(mempool); pagesize_ = sysconf(_SC_PAGESIZE); - PS_VLOG(1) << "System page size is " << pagesize_; auto val = Environment::Get()->find("DMLC_ROLE"); std::string role(val); @@ -245,7 +237,7 @@ class RDMATransport : public Transport { } } - void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { + virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { // prepare RDMA write sge list struct ibv_sge sge[1 + msg_buf->mrs.size()]; sge[0].addr = reinterpret_cast(msg_buf->inline_buf); @@ -254,7 +246,6 @@ class RDMATransport : public Transport { size_t num_sge = 1; - uint64_t data_len = 0; if (msg_buf->mrs.size() == 3) { // push request, split the meta and data into two writes // further, it does not send keys and lens since these meta already carries these info @@ -288,11 +279,14 @@ class RDMATransport : public Transport { sge[num_sge].length = length; sge[num_sge].lkey = pair.first->lkey; ++num_sge; - - data_len += length; } } + PostRDMAWriteWithImm(msg_buf, sge, num_sge, remote_addr, rkey, idx); + } + void PostRDMAWriteWithImm(MessageBuffer *msg_buf, struct ibv_sge *sge, + size_t num_sge, uint64_t remote_addr, + uint32_t rkey, uint32_t idx) { WRContext *write_ctx = msg_buf->reserved_context; CHECK(write_ctx); MessageBuffer **tmp = @@ -421,7 +415,6 @@ class RDMATransport : public Transport { msg_buf->reserved_context = reserved; auto key = msg.meta.key; auto recver = msg.meta.recver; - auto raddr = std::get<0>(remote_addr[recver]); auto rkey = std::get<1>(remote_addr[recver]); auto idx = std::get<2>(remote_addr[recver]); @@ -489,18 +482,12 @@ class RDMATransport : public Transport { auto key = msg->meta.key; auto addr = msg->meta.addr; - SArray keys; - void *p = malloc(sizeof(Key)); - memcpy(p, &msg->meta.key, sizeof(Key)); - keys.reset((char *) p, sizeof(Key), [p](void *) { free(p); }); + SArray keys = CreateFunctionalSarray(sizeof(Key), &msg->meta.key); SArray vals; vals.reset(reinterpret_cast(addr), msg->meta.val_len, [](void *){}); - SArray lens; - void *q = malloc(sizeof(int)); - memcpy(q, &msg->meta.val_len, sizeof(int)); - lens.reset((char *) q, sizeof(int), [q](void *) { free(q); }); + SArray lens = CreateFunctionalSarray(sizeof(int), &msg->meta.val_len); msg->data.push_back(keys); msg->data.push_back(vals); @@ -512,6 +499,14 @@ class RDMATransport : public Transport { return Recv(msg, buffer_ctx, meta_len); } + SArray CreateFunctionalSarray(size_t size, void *value) { + SArray sarr; + void *p = malloc(size); + memcpy(p, value, size); + sarr.reset((char *) p, size, [p](void *) { free(p); }); + return sarr; + } + private: virtual int Recv(Message *msg, BufferContext *buffer_ctx, int meta_len) { uint64_t data_num = buffer_ctx->data_num; @@ -527,18 +522,12 @@ class RDMATransport : public Transport { cur = buffer_ctx->buffer + align_ceil((size_t) meta_len, pagesize_); - SArray keys; - void *p = malloc(sizeof(Key)); - memcpy(p, &msg->meta.key, sizeof(Key)); - keys.reset((char *) p, sizeof(Key), [p](void *) { free(p); }); + SArray keys = CreateFunctionalSarray(sizeof(Key), &msg->meta.key); SArray vals; vals.reset(cur, len, [](void *) {}); // no need to delete - SArray lens; - void *q = malloc(sizeof(int)); - memcpy(q, &len, sizeof(int)); - lens.reset((char *) q, sizeof(int), [q](void *) { free(q); }); + SArray lens = CreateFunctionalSarray(sizeof(int), &msg->meta.val_len); msg->data.push_back(keys); msg->data.push_back(vals); @@ -569,8 +558,6 @@ class RDMATransport : public Transport { }; // class Transport - - class IPCTransport : public RDMATransport { public: @@ -605,26 +592,74 @@ class IPCTransport : public RDMATransport { } } - void SendPushRequest(Message &msg, MessageBuffer *msg_buf) { - // get from shared memory + void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { + // prepare RDMA write sge list + struct ibv_sge sge[1 + msg_buf->mrs.size()]; + sge[0].addr = reinterpret_cast(msg_buf->inline_buf); + sge[0].length = msg_buf->inline_len; + sge[0].lkey = mempool_->LocalKey(msg_buf->inline_buf); + + size_t num_sge = 1; + if (msg_buf->mrs.size() != 3) { + // not push request + for (auto &pair : msg_buf->mrs) { + size_t length = pair.second; + CHECK(length); + sge[num_sge].addr = + reinterpret_cast(pair.first->addr); + sge[num_sge].length = length; + sge[num_sge].lkey = pair.first->lkey; + ++num_sge; + } + } + PostRDMAWriteWithImm(msg_buf, sge, num_sge, remote_addr, rkey, idx); } - void SendPullResponse(Message &msg, MessageBuffer *msg_buf) { - // std::lock_guard lock(map_mu_); - // auto key = msg.meta.key; - // auto recver = msg.meta.recver; - // auto len = std::get<0>(key_meta_map_[key][recver]); - - // // IPC - // auto addr = (void*) msg_buf->data[1].data(); - // CHECK(addr); - // void* shm_addr = GetSharedMemory(kShmPrefix, key); - // // async copy - // AsyncCopy m = {endpoint, msg_buf, shm_addr, addr, len, meta_len, false}; - // auto cnt = cpy_counter_.fetch_add(1); - // async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); + void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { + auto key = msg.meta.key; + auto recver = msg.meta.recver; + auto len = msg.meta.val_len; + auto addr = (void*) msg_buf->data[1].data(); + CHECK(addr); + void* shm_addr = GetSharedMemory(kShmPrefix, key); + // async copy with a simple load-balancing strategy + AsyncCopy m = {msg_buf, shm_addr, addr, len, remote_addr, recver, false}; + auto cnt = cpy_counter_.fetch_add(1); + async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); } + int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) { + CHECK(msg->meta.push && msg->meta.request); + // get data message from local shared memory + auto key = msg->meta.key; + auto len = msg->meta.val_len; + + SArray keys = CreateFunctionalSarray(sizeof(Key), &msg->meta.key); + + SArray vals; + void* addr = GetSharedMemory(kShmPrefix, key); + vals.reset(reinterpret_cast(addr), len, [](void *){}); + + SArray lens = CreateFunctionalSarray(sizeof(int), &msg->meta.val_len); + + msg->data.push_back(keys); + msg->data.push_back(vals); + msg->data.push_back(lens); + + return keys.size() + vals.size() + lens.size(); + } + + private: + struct AsyncCopy { + MessageBuffer* msg_buf; + void* dst; + void* src; + int len; + RemoteAddress remote_addr; + int recver; + bool shutdown; + }; + void AsyncCopyThread(int i) { auto& q = async_copy_queue_[i]; while (true) { @@ -638,41 +673,18 @@ class IPCTransport : public RDMATransport { CHECK(m.src); memcpy(m.dst, m.src, m.len); - WRContext *context = nullptr, *reserved = nullptr; - m.endpoint->free_write_ctx.WaitAndPop(&reserved); - m.endpoint->free_start_ctx.WaitAndPop(&context); - - m.msg_buf->reserved_context = reserved; - RendezvousStart *req = - reinterpret_cast(context->buffer->addr); - req->meta_len = m.meta_len; - req->origin_addr = reinterpret_cast(m.msg_buf); - - auto addr = reinterpret_cast(req); - req->data_num = 0; - - struct ibv_sge sge; - sge.addr = reinterpret_cast(req); - sge.lkey = context->buffer->lkey; - sge.length = sizeof(RendezvousStart); - - struct ibv_send_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - wr.wr_id = reinterpret_cast(context); - wr.opcode = IBV_WR_SEND_WITH_IMM; - wr.next = nullptr; - wr.imm_data = kRendezvousStart; - wr.send_flags = IBV_SEND_SIGNALED; - wr.sg_list = &sge; - wr.num_sge = 1; - - CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) - << strerror(errno); + struct ibv_sge sge[1]; + sge[0].addr = reinterpret_cast(m.msg_buf->inline_buf); + sge[0].length = m.msg_buf->inline_len; + sge[0].lkey = mempool_->LocalKey(m.msg_buf->inline_buf); + + auto raddr = std::get<0>(m.remote_addr[m.recver]); + auto rkey = std::get<1>(m.remote_addr[m.recver]); + auto idx = std::get<2>(m.remote_addr[m.recver]); + PostRDMAWriteWithImm(m.msg_buf, sge, 1, raddr, rkey, idx); } } - private: - void* GetSharedMemory(const std::string& prefix, uint64_t key) { std::lock_guard lock(shm_mu_); auto worker_key = DecodeWorkerKey(key); From 30bf1a3caed177f59306c8bcc79835289660349a Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Thu, 19 Dec 2019 20:13:46 +0800 Subject: [PATCH 49/79] tests: add ipc testcase --- src/rdma_transport.h | 21 +- ..._kv_app_benchmark.cc => test_benchmark.cc} | 68 ++++- tests/test_kv_app_ipc_benchmark.cc | 246 ------------------ 3 files changed, 73 insertions(+), 262 deletions(-) rename tests/{test_kv_app_benchmark.cc => test_benchmark.cc} (78%) delete mode 100644 tests/test_kv_app_ipc_benchmark.cc diff --git a/src/rdma_transport.h b/src/rdma_transport.h index d2112510..9480cb3e 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -190,7 +190,7 @@ class Transport { virtual void SendRendezvousBegin(Message &msg, MessageBuffer *msg_buf) = 0; virtual void SendRendezvousReply(RendezvousStart *req, AddressPool &pool) = 0; - virtual SArray CreateFunctionalSarray(size_t size, void *value) = 0; + virtual SArray CreateFunctionalSarray(void *value, size_t size) = 0; }; // class Transport @@ -355,7 +355,8 @@ class RDMATransport : public Transport { } // worker only needs a buffer for receving meta - char *buffer = mempool_->Alloc(is_server_ ? (kMaxMetaBound + len) : (kMaxMetaBound + req->meta_len)); + char *buffer = + mempool_->Alloc(is_server_ ? (kMaxMetaBound + len) : (kMaxMetaBound + req->meta_len)); CHECK(buffer); buf_ctx->buffer = buffer; WRContext *reply_ctx = nullptr; @@ -482,12 +483,12 @@ class RDMATransport : public Transport { auto key = msg->meta.key; auto addr = msg->meta.addr; - SArray keys = CreateFunctionalSarray(sizeof(Key), &msg->meta.key); + SArray keys = CreateFunctionalSarray(&msg->meta.key, sizeof(Key)); SArray vals; vals.reset(reinterpret_cast(addr), msg->meta.val_len, [](void *){}); - SArray lens = CreateFunctionalSarray(sizeof(int), &msg->meta.val_len); + SArray lens = CreateFunctionalSarray(&msg->meta.val_len, sizeof(int)); msg->data.push_back(keys); msg->data.push_back(vals); @@ -499,7 +500,7 @@ class RDMATransport : public Transport { return Recv(msg, buffer_ctx, meta_len); } - SArray CreateFunctionalSarray(size_t size, void *value) { + SArray CreateFunctionalSarray(void *value, size_t size) { SArray sarr; void *p = malloc(size); memcpy(p, value, size); @@ -522,12 +523,12 @@ class RDMATransport : public Transport { cur = buffer_ctx->buffer + align_ceil((size_t) meta_len, pagesize_); - SArray keys = CreateFunctionalSarray(sizeof(Key), &msg->meta.key); + SArray keys = CreateFunctionalSarray(&msg->meta.key, sizeof(Key)); SArray vals; vals.reset(cur, len, [](void *) {}); // no need to delete - SArray lens = CreateFunctionalSarray(sizeof(int), &msg->meta.val_len); + SArray lens = CreateFunctionalSarray(&msg->meta.val_len, sizeof(int)); msg->data.push_back(keys); msg->data.push_back(vals); @@ -564,7 +565,6 @@ class IPCTransport : public RDMATransport { explicit IPCTransport(Endpoint *endpoint, SimpleMempool *mempool) : RDMATransport(endpoint, mempool) { auto val = Environment::Get()->find("BYTEPS_IPC_COPY_NUM_THREADS"); ipc_copy_nthreads_ = val ? atoi(val) : 4; - LOG(INFO) << "IPC async copy nthreads set to " << ipc_copy_nthreads_; for (int i = 0; i < ipc_copy_nthreads_; ++i) { auto q = new ThreadsafeQueue; async_copy_queue_.push_back(q); @@ -580,7 +580,6 @@ class IPCTransport : public RDMATransport { auto byteps_local_size = val ? atoi(val) : 1; byteps_partition_bytes_ = AlignTo(byteps_partition_bytes_, (8 * byteps_local_size)); CHECK(val) << "BYTEPS_LOCAL_SIZE not set"; - LOG(INFO) << "partition bytes set to " << byteps_partition_bytes_ << ", should be identical with byteps core"; }; ~IPCTransport() { @@ -634,13 +633,13 @@ class IPCTransport : public RDMATransport { auto key = msg->meta.key; auto len = msg->meta.val_len; - SArray keys = CreateFunctionalSarray(sizeof(Key), &msg->meta.key); + SArray keys = CreateFunctionalSarray(&msg->meta.key, sizeof(Key)); SArray vals; void* addr = GetSharedMemory(kShmPrefix, key); vals.reset(reinterpret_cast(addr), len, [](void *){}); - SArray lens = CreateFunctionalSarray(sizeof(int), &msg->meta.val_len); + SArray lens = CreateFunctionalSarray(&msg->meta.val_len, sizeof(int)); msg->data.push_back(keys); msg->data.push_back(vals); diff --git a/tests/test_kv_app_benchmark.cc b/tests/test_benchmark.cc similarity index 78% rename from tests/test_kv_app_benchmark.cc rename to tests/test_benchmark.cc index 209ed223..ab8872dd 100644 --- a/tests/test_kv_app_benchmark.cc +++ b/tests/test_benchmark.cc @@ -2,6 +2,12 @@ #include #include #include +#include +#include +#include +#include +#include +#include #include "ps/ps.h" #define DIVUP(x, y) (((x)+(y)-1)/(y)) @@ -13,9 +19,37 @@ enum MODE { PUSH_THEN_PULL = 0, PUSH_PULL = 1, PUSH_ONLY = 2, - PULL_ONLY = 3 + PULL_ONLY = 3, + IPC = 4 }; -std::unordered_map > mem_map; + +std::unordered_map _key_shm_addr; +std::unordered_map _key_shm_size; +std::unordered_map store_; +std::mutex mu_; + +void* OpenSharedMemory(const std::string& prefix, + uint64_t key, size_t size) { + std::string shm_name(prefix); + shm_name += std::to_string(key); + int shm_fd = shm_open(shm_name.c_str(), O_CREAT | O_RDWR, 0666); + CHECK_GE(shm_fd, 0) << "shm_open failed for " << shm_name; + CHECK_GE(ftruncate(shm_fd, size), 0) << strerror(errno); + + void* ptr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); + CHECK_NE(ptr, (void*)-1) << strerror(errno); + + LOG(INFO) << "initialized share memory size=" << size + << " for key=" << key << ", name=" << shm_name; + _key_shm_addr[shm_name] = ptr; + _key_shm_size[shm_name] = size; + return ptr; +} + +uint64_t DecodeKey(ps::Key key) { + auto kr = ps::Postoffice::Get()->GetServerKeyRanges()[ps::MyRank()]; + return key - kr.begin(); +} void aligned_memory_alloc(void** ptr, size_t size) { size_t page_size = sysconf(_SC_PAGESIZE); @@ -28,6 +62,7 @@ void aligned_memory_alloc(void** ptr, size_t size) { *ptr = p; } +std::unordered_map > mem_map; template void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer *server) { uint64_t key = req_data.keys[0]; @@ -92,6 +127,10 @@ void push_pull(KVWorker &kv, LOG(INFO) << "========= PULL_ONLY mode ========="; LOG(INFO) << "========= msg_size=" << len*sizeof(char) << " bytes ========="; break; + case IPC: + LOG(INFO) << "========= IPC mode ========="; + LOG(INFO) << "========= msg_size=" << len*sizeof(char) << " bytes ========="; + break; default: CHECK(0); } @@ -110,6 +149,7 @@ void push_pull(KVWorker &kv, auto vals = server_vals[key]; switch (mode) { + case IPC: case PUSH_PULL: { timestamp_list.push_back(kv.ZPush(keys, vals, lens)); timestamp_list.push_back(kv.ZPull(keys, &vals, &lens)); @@ -156,8 +196,14 @@ void RunWorker(int argc, char *argv[]) { int repeat = (argc > 2) ? atoi(argv[2]) : 10; MODE mode = (argc > 3) ? static_cast(atoi(argv[3])) : PUSH_PULL; + size_t partition_bytes = Environment::Get()->find("BYTEPS_PARTITION_BYTES") ? + atoi(Environment::Get()->find("BYTEPS_PARTITION_BYTES")) : 4096000; + CHECK_GE(partition_bytes, len) + << "tensor partition is not supported in this benchmark" + << ", try reduce tensor size or increase BYTEPS_PARTITION_BYTES"; + auto v = Environment::Get()->find("NUM_KEY_PER_SERVER"); - const int how_many_key_per_server = v ? atoi(v) : 40; + const int how_many_key_per_server = v ? atoi(v) : 20; const int total_key_num = num_servers * how_many_key_per_server; std::vector > server_vals; @@ -165,8 +211,9 @@ void RunWorker(int argc, char *argv[]) { std::vector > server_lens; for (int key = 0; key < total_key_num; key++) { void* ptr; - aligned_memory_alloc(&ptr, len); + // aligned_memory_alloc(&ptr, len); SArray vals; + auto addr = (char*) OpenSharedMemory(std::string("BytePS_ShM_"), key, len); vals.reset((char*) ptr, len * sizeof(char), [](void *){}); server_vals.push_back(vals); } @@ -245,6 +292,7 @@ void RunWorker(int argc, char *argv[]) { case PUSH_PULL: case PUSH_ONLY: case PULL_ONLY: + case IPC: push_pull(kv, server_keys, server_vals, server_lens, len, num_servers, total_key_num, how_many_key_per_server, mode); break; default: @@ -256,7 +304,12 @@ void RunWorker(int argc, char *argv[]) { int main(int argc, char *argv[]) { // disable multi-threaded processing first - setenv("ENABLE_SERVER_MULTIPULL", "0", 1); + setenv("BYTEPS_LOCAL_SIZE", "1", 1); + MODE mode = (argc > 3) ? static_cast(atoi(argv[3])) : PUSH_PULL; + if (mode == IPC) { + setenv("BYTEPS_ENABLE_IPC", "1", 1); + LOG(INFO) << "IPC mode on"; + } // start system Start(0); // setup server nodes @@ -265,5 +318,10 @@ int main(int argc, char *argv[]) { RunWorker(argc, argv); // stop system Finalize(0, true); + // release shm + for (auto &it : _key_shm_addr) { + munmap(it.second, _key_shm_size[it.first]); + shm_unlink(it.first.c_str()); + } return 0; } diff --git a/tests/test_kv_app_ipc_benchmark.cc b/tests/test_kv_app_ipc_benchmark.cc deleted file mode 100644 index 2ecf9a95..00000000 --- a/tests/test_kv_app_ipc_benchmark.cc +++ /dev/null @@ -1,246 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "ps/ps.h" -#define DATA_TYPE char -using namespace ps; - -enum MODE { - PUSH_THEN_PULL = 0, - PUSH_PULL_MIX_ENDLESS = 1 -}; -std::unordered_map > mem_map; -std::unordered_map _key_shm_addr; -std::unordered_map _key_shm_size; -std::unordered_map store_; -std::mutex mu_; - -void* OpenSharedMemory(const std::string& prefix, - uint64_t key, size_t size) { - std::string shm_name(prefix); - shm_name += std::to_string(key); - int shm_fd = shm_open(shm_name.c_str(), O_CREAT | O_RDWR, 0666); - CHECK_GE(shm_fd, 0) << "shm_open failed for " << shm_name; - CHECK_GE(ftruncate(shm_fd, size), 0) << strerror(errno); - - void* ptr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); - CHECK_NE(ptr, (void*)-1) << strerror(errno); - - LOG(INFO) << "initialized share memory size=" << size - << " for key=" << key << ", name=" << shm_name; - _key_shm_addr[shm_name] = ptr; - _key_shm_size[shm_name] = size; - return ptr; -} - -uint64_t DecodeKey(ps::Key key) { - auto kr = ps::Postoffice::Get()->GetServerKeyRanges()[ps::MyRank()]; - return key - kr.begin(); -} - -template -void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer *server) { - std::lock_guard lk(mu_); - uint64_t key = DecodeKey(req_data.keys[0]); - if (req_meta.push) { - CHECK(req_data.lens.size()); - CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]); - - if (mem_map.find(key) == mem_map.end()) { - PS_VLOG(1) << "key " << key << " from worker-" << req_meta.sender; - size_t len = (size_t) req_data.vals.size(); - mem_map[key].keys.push_back(key); - mem_map[key].lens.push_back(len); - - store_[key] = (char*) malloc(len); - mem_map[key].vals = ps::SArray(store_[key], len, false); - } - - // send push response (empty) - KVPairs res; - server->Response(req_meta, res); - } else { // pull request - auto iter = mem_map.find(key); - CHECK_NE(iter, mem_map.end()); - server->Response(req_meta, iter->second); - } -} - -void StartServer() { - if (!IsServer()) return; - auto server = new KVServer(0); - server->set_request_handle(EmptyHandler); - RegisterExitCallback([server]() { delete server; }); -} - -struct PSKV { - SArray keys; // n keys - SArray lens; // the length of the i-th value -}; -std::unordered_map ps_kv_; - -uint64_t EncodeKey(uint64_t seed) { - return seed << 16; -} - -void RunWorker(int argc, char *argv[]) { - if (!IsWorker()) return; - CHECK_GE(argc, 3) << "input argument should be at least 3: SCRIPT, LEN, REPEAT, (OPTIONAL) MODE"; - KVWorker kv(0, 0); - auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); - - const int num_servers = krs.size(); - LOG(INFO) << num_servers << " servers in total"; - CHECK_GT(num_servers, 0); - - // init - auto val = Environment::Get()->find("BYTEPS_PARTITION_BYTES"); - unsigned int partition_bytes = val ? atoi(val) : 4096000; - int len = atoi(argv[1]); - CHECK_GE(partition_bytes, len) - << "tensor partition is not supported in this benchmark" - << ", try reduce tensor size or increase BYTEPS_PARTITION_BYTES"; - int repeat = atoi(argv[2]); - MODE mode = (argc > 3) ? static_cast(atoi(argv[3])) : PUSH_PULL_MIX_ENDLESS; - - std::vector > server_vals; - for (int server = 0; server < num_servers; server++) { - auto key = EncodeKey(server); - auto addr = (char*) OpenSharedMemory(std::string("BytePS_ShM_"), key, len); - SArray vals(addr, len, false); - server_vals.push_back(vals); - } - - // init broadcast - for (int server = 0; server < num_servers; server++) { - auto key = EncodeKey(server); - auto vals = server_vals[server]; - PSKV& pskv = ps_kv_[key]; - SArray keys; - ps::Key ps_key = krs[server].begin() + key; - keys.push_back(ps_key); - SArray lens; - lens.push_back(len); - pskv.keys.push_back(ps_key); - pskv.lens.push_back(len); - - kv.Wait(kv.ZPush(keys, vals, lens)); - } - - switch(mode) { - case PUSH_THEN_PULL: { - LOG(INFO) << "PUSH_THEN_PULL mode"; - // push - uint64_t accumulated_ms = 0; - for (int i = 0; i < repeat; ++i) { - auto start = std::chrono::high_resolution_clock::now(); - for (int server = 0; server < num_servers; server++) { - auto key = EncodeKey(server); - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; - auto vals = server_vals[server]; - - kv.Wait(kv.ZPush(keys, vals, lens)); - } - auto end = std::chrono::high_resolution_clock::now(); - accumulated_ms += (end - start).count(); // ns - } - LL << "push " << len * sizeof(DATA_TYPE) - << " bytes to each server, repeat=" << repeat - << ", total_time=" - << accumulated_ms / 1e6 << "ms"; - - // pull - accumulated_ms = 0; - for (int i = 0; i < repeat; ++i) { - auto start = std::chrono::high_resolution_clock::now(); - for (int server = 0; server < num_servers; server++) { - auto key = EncodeKey(server); - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; - auto vals = server_vals[server]; - - kv.Wait(kv.ZPull(keys, &vals, &lens)); - } - auto end = std::chrono::high_resolution_clock::now(); - accumulated_ms += (end - start).count(); // ns - } - - LL << "pull " << len * sizeof(DATA_TYPE) - << " bytes to each server, repeat=" << repeat - << ", total_time=" - << accumulated_ms / 1e6 << "ms"; - } - break; - - case PUSH_PULL_MIX_ENDLESS: { - LOG(INFO) << "PUSH_PULL_MIX_ENDLESS mode, should exit by Ctrl+C"; - std::vector timestamp_list; - auto start = std::chrono::high_resolution_clock::now(); - auto end = std::chrono::high_resolution_clock::now(); - auto val = Environment::Get()->find("THRESHOLD"); - unsigned int threshold = val ? atoi(val) : 10; - val = Environment::Get()->find("LOG_DURATION"); - unsigned int log_duration = val ? atoi(val) : 50; - int cnt = 0; - while (1) { - for (int server = 0; server < num_servers; server++) { - auto key = EncodeKey(server); - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; - auto vals = server_vals[server]; - timestamp_list.push_back(kv.ZPush(keys, vals, lens)); - timestamp_list.push_back(kv.ZPull(keys, &vals, &lens)); - } - if (timestamp_list.size()/2/num_servers >= threshold) { // flow control - for (auto& ts : timestamp_list) { - kv.Wait(ts); - } - timestamp_list.clear(); - cnt++; - if (cnt % log_duration == 0) { - end = std::chrono::high_resolution_clock::now(); - LL << "Application goodput: " - << 8.0 * len * sizeof(DATA_TYPE) * num_servers * cnt * threshold / (end - start).count() - << " Gbps"; - cnt = 0; - start = std::chrono::high_resolution_clock::now(); - } - } - } - } break; - default: - CHECK(0) << "unknown mode " << mode; - } -} - -int main(int argc, char *argv[]) { - // disable multi-threaded processing first - setenv("ENABLE_SERVER_MULTIPULL", "0", 1); - setenv("BYTEPS_LOCAL_SIZE", "1", 1); - setenv("BYTEPS_ENABLE_IPC", "1", 0); - // start system - Start(0); - // setup server nodes - StartServer(); - // run worker nodes - RunWorker(argc, argv); - // stop system - Finalize(0, true); - // release shm - for (auto &it : _key_shm_addr) { - munmap(it.second, _key_shm_size[it.first]); - shm_unlink(it.first.c_str()); - } - return 0; -} From dacde24a81bc42abf03df39f4519fb91d805e0c6 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Thu, 19 Dec 2019 20:56:28 +0800 Subject: [PATCH 50/79] split testcases into two files --- tests/test_benchmark.cc | 70 +---------- tests/test_ipc_benchmark.cc | 232 ++++++++++++++++++++++++++++++++++++ 2 files changed, 238 insertions(+), 64 deletions(-) create mode 100644 tests/test_ipc_benchmark.cc diff --git a/tests/test_benchmark.cc b/tests/test_benchmark.cc index ab8872dd..c08a5013 100644 --- a/tests/test_benchmark.cc +++ b/tests/test_benchmark.cc @@ -2,12 +2,6 @@ #include #include #include -#include -#include -#include -#include -#include -#include #include "ps/ps.h" #define DIVUP(x, y) (((x)+(y)-1)/(y)) @@ -19,37 +13,9 @@ enum MODE { PUSH_THEN_PULL = 0, PUSH_PULL = 1, PUSH_ONLY = 2, - PULL_ONLY = 3, - IPC = 4 + PULL_ONLY = 3 }; - -std::unordered_map _key_shm_addr; -std::unordered_map _key_shm_size; -std::unordered_map store_; -std::mutex mu_; - -void* OpenSharedMemory(const std::string& prefix, - uint64_t key, size_t size) { - std::string shm_name(prefix); - shm_name += std::to_string(key); - int shm_fd = shm_open(shm_name.c_str(), O_CREAT | O_RDWR, 0666); - CHECK_GE(shm_fd, 0) << "shm_open failed for " << shm_name; - CHECK_GE(ftruncate(shm_fd, size), 0) << strerror(errno); - - void* ptr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); - CHECK_NE(ptr, (void*)-1) << strerror(errno); - - LOG(INFO) << "initialized share memory size=" << size - << " for key=" << key << ", name=" << shm_name; - _key_shm_addr[shm_name] = ptr; - _key_shm_size[shm_name] = size; - return ptr; -} - -uint64_t DecodeKey(ps::Key key) { - auto kr = ps::Postoffice::Get()->GetServerKeyRanges()[ps::MyRank()]; - return key - kr.begin(); -} +std::unordered_map > mem_map; void aligned_memory_alloc(void** ptr, size_t size) { size_t page_size = sysconf(_SC_PAGESIZE); @@ -62,7 +28,6 @@ void aligned_memory_alloc(void** ptr, size_t size) { *ptr = p; } -std::unordered_map > mem_map; template void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer *server) { uint64_t key = req_data.keys[0]; @@ -127,10 +92,6 @@ void push_pull(KVWorker &kv, LOG(INFO) << "========= PULL_ONLY mode ========="; LOG(INFO) << "========= msg_size=" << len*sizeof(char) << " bytes ========="; break; - case IPC: - LOG(INFO) << "========= IPC mode ========="; - LOG(INFO) << "========= msg_size=" << len*sizeof(char) << " bytes ========="; - break; default: CHECK(0); } @@ -149,7 +110,6 @@ void push_pull(KVWorker &kv, auto vals = server_vals[key]; switch (mode) { - case IPC: case PUSH_PULL: { timestamp_list.push_back(kv.ZPush(keys, vals, lens)); timestamp_list.push_back(kv.ZPull(keys, &vals, &lens)); @@ -196,14 +156,8 @@ void RunWorker(int argc, char *argv[]) { int repeat = (argc > 2) ? atoi(argv[2]) : 10; MODE mode = (argc > 3) ? static_cast(atoi(argv[3])) : PUSH_PULL; - size_t partition_bytes = Environment::Get()->find("BYTEPS_PARTITION_BYTES") ? - atoi(Environment::Get()->find("BYTEPS_PARTITION_BYTES")) : 4096000; - CHECK_GE(partition_bytes, len) - << "tensor partition is not supported in this benchmark" - << ", try reduce tensor size or increase BYTEPS_PARTITION_BYTES"; - auto v = Environment::Get()->find("NUM_KEY_PER_SERVER"); - const int how_many_key_per_server = v ? atoi(v) : 20; + const int how_many_key_per_server = v ? atoi(v) : 40; const int total_key_num = num_servers * how_many_key_per_server; std::vector > server_vals; @@ -211,9 +165,8 @@ void RunWorker(int argc, char *argv[]) { std::vector > server_lens; for (int key = 0; key < total_key_num; key++) { void* ptr; - // aligned_memory_alloc(&ptr, len); + aligned_memory_alloc(&ptr, len); SArray vals; - auto addr = (char*) OpenSharedMemory(std::string("BytePS_ShM_"), key, len); vals.reset((char*) ptr, len * sizeof(char), [](void *){}); server_vals.push_back(vals); } @@ -292,7 +245,6 @@ void RunWorker(int argc, char *argv[]) { case PUSH_PULL: case PUSH_ONLY: case PULL_ONLY: - case IPC: push_pull(kv, server_keys, server_vals, server_lens, len, num_servers, total_key_num, how_many_key_per_server, mode); break; default: @@ -304,12 +256,7 @@ void RunWorker(int argc, char *argv[]) { int main(int argc, char *argv[]) { // disable multi-threaded processing first - setenv("BYTEPS_LOCAL_SIZE", "1", 1); - MODE mode = (argc > 3) ? static_cast(atoi(argv[3])) : PUSH_PULL; - if (mode == IPC) { - setenv("BYTEPS_ENABLE_IPC", "1", 1); - LOG(INFO) << "IPC mode on"; - } + setenv("ENABLE_SERVER_MULTIPULL", "0", 1); // start system Start(0); // setup server nodes @@ -318,10 +265,5 @@ int main(int argc, char *argv[]) { RunWorker(argc, argv); // stop system Finalize(0, true); - // release shm - for (auto &it : _key_shm_addr) { - munmap(it.second, _key_shm_size[it.first]); - shm_unlink(it.first.c_str()); - } return 0; -} +} \ No newline at end of file diff --git a/tests/test_ipc_benchmark.cc b/tests/test_ipc_benchmark.cc new file mode 100644 index 00000000..3952f9c7 --- /dev/null +++ b/tests/test_ipc_benchmark.cc @@ -0,0 +1,232 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ps/ps.h" + +#define DIVUP(x, y) (((x)+(y)-1)/(y)) +#define ROUNDUP(x, y) (DIVUP((x), (y))*(y)) + +using namespace ps; + +enum MODE { IPC }; + +std::unordered_map _key_shm_addr; +std::unordered_map _key_shm_size; +std::unordered_map store_; +std::mutex mu_; + +void* OpenSharedMemory(const std::string& prefix, uint64_t key, size_t size) { + std::string shm_name(prefix); + shm_name += std::to_string(key); + int shm_fd = shm_open(shm_name.c_str(), O_CREAT | O_RDWR, 0666); + CHECK_GE(shm_fd, 0) << "shm_open failed for " << shm_name; + CHECK_GE(ftruncate(shm_fd, size), 0) << strerror(errno); + + void* ptr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); + CHECK_NE(ptr, (void*)-1) << strerror(errno); + + LOG(INFO) << "initialized share memory size=" << size + << " for key=" << key << ", name=" << shm_name; + _key_shm_addr[shm_name] = ptr; + _key_shm_size[shm_name] = size; + return ptr; +} + +uint64_t EncodeKey(uint64_t seed) { + return seed << 16; +} + +uint64_t DecodeKey(ps::Key key) { + auto kr = ps::Postoffice::Get()->GetServerKeyRanges()[ps::MyRank()]; + return key - kr.begin(); +} + +void aligned_memory_alloc(void** ptr, size_t size) { + size_t page_size = sysconf(_SC_PAGESIZE); + void* p; + int size_aligned = ROUNDUP(size, page_size); + int ret = posix_memalign(&p, page_size, size_aligned); + CHECK_EQ(ret, 0) << "posix_memalign error: " << strerror(ret); + CHECK(p); + memset(p, 0, size); + *ptr = p; +} + +std::unordered_map > mem_map; +template +void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer *server) { + uint64_t key = DecodeKey(req_data.keys[0]); + if (req_meta.push) { + CHECK(req_data.lens.size()); + CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]) + << "key=" << key << ", " << req_data.vals.size() << ", " << req_data.lens[0]; + + if (mem_map.find(key) == mem_map.end()) { + size_t len = (size_t) req_data.vals.size(); + + void* ptr_val; + aligned_memory_alloc(&ptr_val, len); + mem_map[key].vals.reset((char*)ptr_val, len, [](void *){ }); + + void* ptr_key; + aligned_memory_alloc(&ptr_key, sizeof(Key)); + mem_map[key].keys.reset((Key*)ptr_key, 1, [](void *){ }); + memcpy(ptr_key, &key, sizeof(Key)); + + void* ptr_len; + aligned_memory_alloc(&ptr_len, sizeof(int)); + mem_map[key].lens.reset((int*)ptr_len, 1, [](void *){ }); + memcpy(ptr_len, &len, sizeof(int)); + } + + // send push response (empty) + KVPairs res; + server->Response(req_meta, res); + } + else { + auto iter = mem_map.find(key); + CHECK_NE(iter, mem_map.end()); + server->Response(req_meta, iter->second); + } +} + +void StartServer() { + if (!IsServer()) return; + auto server = new KVServer(0); + server->set_request_handle(EmptyHandler); + RegisterExitCallback([server]() { delete server; }); +} + +void push_pull(KVWorker &kv, + std::vector > &server_keys, + std::vector > &server_vals, + std::vector > &server_lens, + int len, int num_servers, int total_key_num, + int how_many_key_per_server, MODE mode) { + std::vector timestamp_list; + auto start = std::chrono::high_resolution_clock::now(); + auto end = std::chrono::high_resolution_clock::now(); + auto val = Environment::Get()->find("LOG_DURATION"); + unsigned int log_duration = val ? atoi(val) : 10; + + int cnt = 0; + while (1) { + for (int i = 0; i < total_key_num; i++) { + auto key = EncodeKey(i); + auto keys = server_keys[i]; + auto lens = server_lens[i]; + auto vals = server_vals[i]; + + timestamp_list.push_back(kv.ZPush(keys, vals, lens)); + timestamp_list.push_back(kv.ZPull(keys, &vals, &lens)); + } + + for (auto& ts : timestamp_list) { kv.Wait(ts); } + timestamp_list.clear(); + + cnt++; + if (cnt % log_duration != 0) continue; + + end = std::chrono::high_resolution_clock::now(); + LL << "Application goodput: " + << 8.0 * len * sizeof(char) * total_key_num * cnt / (end - start).count() + << " Gbps"; + cnt = 0; + start = std::chrono::high_resolution_clock::now(); + } +} + +void RunWorker(int argc, char *argv[]) { + if (!IsWorker()) return; + KVWorker kv(0, 0); + auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); + + const int num_servers = krs.size(); + LOG(INFO) << num_servers << " servers in total"; + CHECK_GT(num_servers, 0); + + // init + int len = (argc > 1) ? atoi(argv[1]) : 1024000; + MODE mode = IPC; + + size_t partition_bytes = Environment::Get()->find("BYTEPS_PARTITION_BYTES") ? + atoi(Environment::Get()->find("BYTEPS_PARTITION_BYTES")) : 4096000; + CHECK_GE(partition_bytes, len) + << "tensor partition is not supported in this benchmark" + << ", try reduce tensor size or increase BYTEPS_PARTITION_BYTES"; + + auto v = Environment::Get()->find("NUM_KEY_PER_SERVER"); + const int how_many_key_per_server = v ? atoi(v) : 20; + const int total_key_num = num_servers * how_many_key_per_server; + + std::vector > server_vals; + std::vector > server_keys; + std::vector > server_lens; + for (int i = 0; i < total_key_num; i++) { + auto key = EncodeKey(i); + void* ptr; + SArray vals; + auto addr = (char*) OpenSharedMemory(std::string("BytePS_ShM_"), key, len); + vals.reset((char*) addr, len, [](void *){}); + server_vals.push_back(vals); + } + + // init push, do not count this into time cost + for (int i = 0; i < total_key_num; i++) { + int server = i % num_servers; + auto vals = server_vals[i]; + + auto key = EncodeKey(i); + PS_VLOG(1) << "key=" << key + << " (i=" << i << ")" + << " assigned to server " << server; + + // page aligned keys + void* ptr_key; + aligned_memory_alloc(&ptr_key, sizeof(Key)); + SArray keys; + keys.reset((Key*) ptr_key, 1, [](void *){}); + ps::Key ps_key = krs[server].begin() + key; + memcpy(ptr_key, &ps_key, sizeof(Key)); + server_keys.push_back(keys); + + // page aligned vals + void* ptr_len; + aligned_memory_alloc(&ptr_len, sizeof(int)); + SArray lens; + lens.reset((int*) ptr_len, 1, [](void *){}); + memcpy(ptr_len, &len, sizeof(len)); + server_lens.push_back(lens); + + kv.Wait(kv.ZPush(keys, vals, lens)); + } + + push_pull(kv, server_keys, server_vals, server_lens, len, num_servers, total_key_num, how_many_key_per_server, mode); +} + +int main(int argc, char *argv[]) { + // disable multi-threaded processing first + setenv("BYTEPS_LOCAL_SIZE", "1", 1); + setenv("BYTEPS_ENABLE_IPC", "1", 1); + // start system + Start(0); + // setup server nodes + StartServer(); + // run worker nodes + RunWorker(argc, argv); + // stop system + Finalize(0, true); + // release shm + for (auto &it : _key_shm_addr) { + munmap(it.second, _key_shm_size[it.first]); + shm_unlink(it.first.c_str()); + } + return 0; +} From 51a9ac1fc542b7c2b9262235942400f8a8e4ef92 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Thu, 19 Dec 2019 21:11:22 +0800 Subject: [PATCH 51/79] can run ipc 2v2 --- src/rdma_transport.h | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 9480cb3e..8dd5d5d7 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -238,6 +238,9 @@ class RDMATransport : public Transport { } virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { + WRContext *reserved = nullptr; + endpoint_->free_write_ctx.WaitAndPop(&reserved); + msg_buf->reserved_context = reserved; // prepare RDMA write sge list struct ibv_sge sge[1 + msg_buf->mrs.size()]; sge[0].addr = reinterpret_cast(msg_buf->inline_buf); @@ -411,15 +414,10 @@ class RDMATransport : public Transport { } void Send(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { - WRContext *reserved = nullptr; - endpoint_->free_write_ctx.WaitAndPop(&reserved); - msg_buf->reserved_context = reserved; - auto key = msg.meta.key; auto recver = msg.meta.recver; auto raddr = std::get<0>(remote_addr[recver]); auto rkey = std::get<1>(remote_addr[recver]); auto idx = std::get<2>(remote_addr[recver]); - RDMAWriteWithImm(msg_buf, raddr, rkey, idx); } @@ -592,6 +590,9 @@ class IPCTransport : public RDMATransport { } void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { + WRContext *reserved = nullptr; + endpoint_->free_write_ctx.WaitAndPop(&reserved); + msg_buf->reserved_context = reserved; // prepare RDMA write sge list struct ibv_sge sge[1 + msg_buf->mrs.size()]; sge[0].addr = reinterpret_cast(msg_buf->inline_buf); @@ -680,6 +681,10 @@ class IPCTransport : public RDMATransport { auto raddr = std::get<0>(m.remote_addr[m.recver]); auto rkey = std::get<1>(m.remote_addr[m.recver]); auto idx = std::get<2>(m.remote_addr[m.recver]); + + WRContext *reserved = nullptr; + endpoint_->free_write_ctx.WaitAndPop(&reserved); + m.msg_buf->reserved_context = reserved; PostRDMAWriteWithImm(m.msg_buf, sge, 1, raddr, rkey, idx); } } From a5bdc8e647d48dbd8bac4779e7e18f638434ed67 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Thu, 19 Dec 2019 21:34:46 +0800 Subject: [PATCH 52/79] clean some compile warnings --- src/rdma_transport.h | 4 ---- tests/test_benchmark.cc | 2 -- tests/test_ipc_benchmark.cc | 4 +--- 3 files changed, 1 insertion(+), 9 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 8dd5d5d7..37e7fc01 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -435,9 +435,6 @@ class RDMATransport : public Transport { virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { std::lock_guard lock(map_mu_); - auto key = msg.meta.key; - auto recver = msg.meta.recver; - auto len = msg.meta.val_len; auto raddr = msg.meta.addr; auto rkey = msg.meta.option; auto temp_mr = mem_mr_map_.find(msg_buf->data[1].data()); @@ -478,7 +475,6 @@ class RDMATransport : public Transport { virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) { std::lock_guard lock(map_mu_); - auto key = msg->meta.key; auto addr = msg->meta.addr; SArray keys = CreateFunctionalSarray(&msg->meta.key, sizeof(Key)); diff --git a/tests/test_benchmark.cc b/tests/test_benchmark.cc index c08a5013..d8a9c4f3 100644 --- a/tests/test_benchmark.cc +++ b/tests/test_benchmark.cc @@ -206,7 +206,6 @@ void RunWorker(int argc, char *argv[]) { for (int i = 0; i < repeat; ++i) { auto start = std::chrono::high_resolution_clock::now(); for (int server = 0; server < num_servers; server++) { - int key = server; auto keys = server_keys[server]; auto lens = server_lens[server]; auto vals = server_vals[server]; @@ -226,7 +225,6 @@ void RunWorker(int argc, char *argv[]) { for (int i = 0; i < repeat; ++i) { auto start = std::chrono::high_resolution_clock::now(); for (int server = 0; server < num_servers; server++) { - int key = server; auto keys = server_keys[server]; auto lens = server_lens[server]; auto vals = server_vals[server]; diff --git a/tests/test_ipc_benchmark.cc b/tests/test_ipc_benchmark.cc index 3952f9c7..3f9330b6 100644 --- a/tests/test_ipc_benchmark.cc +++ b/tests/test_ipc_benchmark.cc @@ -119,7 +119,6 @@ void push_pull(KVWorker &kv, int cnt = 0; while (1) { for (int i = 0; i < total_key_num; i++) { - auto key = EncodeKey(i); auto keys = server_keys[i]; auto lens = server_lens[i]; auto vals = server_vals[i]; @@ -158,7 +157,7 @@ void RunWorker(int argc, char *argv[]) { size_t partition_bytes = Environment::Get()->find("BYTEPS_PARTITION_BYTES") ? atoi(Environment::Get()->find("BYTEPS_PARTITION_BYTES")) : 4096000; - CHECK_GE(partition_bytes, len) + CHECK_GE(partition_bytes, (size_t)len) << "tensor partition is not supported in this benchmark" << ", try reduce tensor size or increase BYTEPS_PARTITION_BYTES"; @@ -171,7 +170,6 @@ void RunWorker(int argc, char *argv[]) { std::vector > server_lens; for (int i = 0; i < total_key_num; i++) { auto key = EncodeKey(i); - void* ptr; SArray vals; auto addr = (char*) OpenSharedMemory(std::string("BytePS_ShM_"), key, len); vals.reset((char*) addr, len, [](void *){}); From b6109e53547f799f119ce7c49ea92404c97d14d0 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Fri, 20 Dec 2019 10:43:58 +0800 Subject: [PATCH 53/79] keep the same partition size with worker --- src/rdma_transport.h | 5 ++--- src/rdma_utils.h | 2 ++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 37e7fc01..bc3e8b11 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -570,10 +570,9 @@ class IPCTransport : public RDMATransport { val = Environment::Get()->find("BYTEPS_PARTITION_BYTES"); byteps_partition_bytes_ = val ? atoi(val) : 4096000; - val = Environment::Get()->find("BYTEPS_LOCAL_SIZE"); + val = CHECK_NOTNULL(Environment::Get()->find("BYTEPS_LOCAL_SIZE")); auto byteps_local_size = val ? atoi(val) : 1; - byteps_partition_bytes_ = AlignTo(byteps_partition_bytes_, (8 * byteps_local_size)); - CHECK(val) << "BYTEPS_LOCAL_SIZE not set"; + byteps_partition_bytes_ = RoundUp(byteps_partition_bytes_, byteps_local_size * sysconf(_SC_PAGESIZE)); }; ~IPCTransport() { diff --git a/src/rdma_utils.h b/src/rdma_utils.h index fcde3120..684ff05d 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -402,6 +402,8 @@ uint64_t DecodeWorkerKey(uint64_t key) { } int AlignTo(int input, int alignment) { return input / alignment * alignment; } +int DivUp(int x, int y) { return (x + y - 1) / y; } +int RoundUp(int x, int y) { return DivUp(x, y) * y; } }; // namespace ps From c01535f4875934c0a8cbd0c4d670865166d3e4ed Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Fri, 20 Dec 2019 12:08:10 +0800 Subject: [PATCH 54/79] default byteps_local_size=1 --- src/rdma_transport.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index bc3e8b11..0c9e4daa 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -570,7 +570,7 @@ class IPCTransport : public RDMATransport { val = Environment::Get()->find("BYTEPS_PARTITION_BYTES"); byteps_partition_bytes_ = val ? atoi(val) : 4096000; - val = CHECK_NOTNULL(Environment::Get()->find("BYTEPS_LOCAL_SIZE")); + val = Environment::Get()->find("BYTEPS_LOCAL_SIZE"); auto byteps_local_size = val ? atoi(val) : 1; byteps_partition_bytes_ = RoundUp(byteps_partition_bytes_, byteps_local_size * sysconf(_SC_PAGESIZE)); }; From 644f103d4d4d77d3cbfe0ee64c883f20ecb85f0a Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Fri, 20 Dec 2019 15:44:32 +0800 Subject: [PATCH 55/79] kMaxMetaBound: 4MB->4KB --- src/rdma_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rdma_utils.h b/src/rdma_utils.h index 684ff05d..04d39c58 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -72,7 +72,7 @@ static const int kMaxResolveRetry = 50000; static const int kBasePort = 9010; // allocate 4KB more for meta with potentially variable length -static const int kMaxMetaBound = 4096000; +static const int kMaxMetaBound = 4096; // should have the same prefix with BytePS shared memory static const std::string kShmPrefix("BytePS_ShM_"); From 1efd8ecbf7113585314cd9ad53cc29d6977c3be2 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Fri, 20 Dec 2019 16:29:17 +0800 Subject: [PATCH 56/79] kMaxMetaBound: 4KB->2KB --- src/rdma_utils.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rdma_utils.h b/src/rdma_utils.h index 04d39c58..be41b693 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -71,8 +71,8 @@ static const int kMaxDataFields = 4; static const int kMaxResolveRetry = 50000; static const int kBasePort = 9010; -// allocate 4KB more for meta with potentially variable length -static const int kMaxMetaBound = 4096; +// allocate 2KB more for meta with potentially variable length +static const int kMaxMetaBound = 2048; // should have the same prefix with BytePS shared memory static const std::string kShmPrefix("BytePS_ShM_"); From 1752620af8e49524433e2002e7581f0a69386b38 Mon Sep 17 00:00:00 2001 From: Yibo Zhu Date: Fri, 20 Dec 2019 09:36:37 -0800 Subject: [PATCH 57/79] Allocate a whole page for meta --- src/rdma_utils.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rdma_utils.h b/src/rdma_utils.h index be41b693..d456d1bb 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -71,8 +71,8 @@ static const int kMaxDataFields = 4; static const int kMaxResolveRetry = 50000; static const int kBasePort = 9010; -// allocate 2KB more for meta with potentially variable length -static const int kMaxMetaBound = 2048; +// allocate a whole page for meta with potentially variable length +static const int kMaxMetaBound = sysconf(_SC_PAGESIZE); // should have the same prefix with BytePS shared memory static const std::string kShmPrefix("BytePS_ShM_"); From e4296da4fffb21aacd02bfb6080d039d20a47401 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Sat, 21 Dec 2019 13:48:10 +0800 Subject: [PATCH 58/79] pull response: do not add msg_buf->mrs --- src/rdma_transport.h | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 0c9e4daa..32a0024c 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -227,12 +227,17 @@ class RDMATransport : public Transport { } void PrepareData(Message &msg, MessageBuffer *msg_buf) { + if (!msg.meta.push && !msg.meta.request) return; // pull response does not need data + for (auto &sa : msg_buf->data) { if (sa.size() == 0) continue; + std::lock_guard lock(map_mu_); auto it = mem_mr_map_.find(sa.data()); MRPtr ptr(it->second, [](struct ibv_mr *mr) {}); + CHECK(ptr.get()) << strerror(errno); + msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size())); } } @@ -460,8 +465,6 @@ class RDMATransport : public Transport { << "ibv_post_send failed."; // after write keys/vals/lens (no imm), write the meta (with imm) - // TODO: consolidate this into one RDMA_WRITE_WITH_IMM - msg_buf->mrs.clear(); Send(msg, msg_buf, remote_addr); } From 397fee7501294cfa5c61ddae3f6da61f1650b0f8 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Sat, 21 Dec 2019 23:55:51 +0800 Subject: [PATCH 59/79] [debug] ipc: sync copy --- src/rdma_transport.h | 53 ++++++++++++++++---------------------------- 1 file changed, 19 insertions(+), 34 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 32a0024c..3e5e120c 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -255,6 +255,7 @@ class RDMATransport : public Transport { size_t num_sge = 1; if (msg_buf->mrs.size() == 3) { + LOG(INFO) << "#########"; // push request, split the meta and data into two writes // further, it does not send keys and lens since these meta already carries these info struct ibv_sge my_sge; @@ -587,43 +588,27 @@ class IPCTransport : public RDMATransport { } } - void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { - WRContext *reserved = nullptr; - endpoint_->free_write_ctx.WaitAndPop(&reserved); - msg_buf->reserved_context = reserved; - // prepare RDMA write sge list - struct ibv_sge sge[1 + msg_buf->mrs.size()]; - sge[0].addr = reinterpret_cast(msg_buf->inline_buf); - sge[0].length = msg_buf->inline_len; - sge[0].lkey = mempool_->LocalKey(msg_buf->inline_buf); - - size_t num_sge = 1; - if (msg_buf->mrs.size() != 3) { - // not push request - for (auto &pair : msg_buf->mrs) { - size_t length = pair.second; - CHECK(length); - sge[num_sge].addr = - reinterpret_cast(pair.first->addr); - sge[num_sge].length = length; - sge[num_sge].lkey = pair.first->lkey; - ++num_sge; - } - } - PostRDMAWriteWithImm(msg_buf, sge, num_sge, remote_addr, rkey, idx); + void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { + msg_buf->mrs.clear(); + CHECK_EQ(msg_buf->mrs.size(), 0); + Send(msg, msg_buf, remote_addr); } void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { - auto key = msg.meta.key; - auto recver = msg.meta.recver; - auto len = msg.meta.val_len; - auto addr = (void*) msg_buf->data[1].data(); - CHECK(addr); - void* shm_addr = GetSharedMemory(kShmPrefix, key); - // async copy with a simple load-balancing strategy - AsyncCopy m = {msg_buf, shm_addr, addr, len, remote_addr, recver, false}; - auto cnt = cpy_counter_.fetch_add(1); - async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); + CHECK_EQ(msg_buf->mrs.size(), 0); + std::lock_guard lock(map_mu_); + auto addr = (void*) CHECK_NOTNULL(msg.data[1].data()); + void* shm_addr = GetSharedMemory(kShmPrefix, msg.meta.key); + CHECK(shm_addr); + + /* async copy with a simple load-balancing strategy */ + // AsyncCopy m = {msg_buf, shm_addr, addr, len, remote_addr, recver, false}; + // auto cnt = cpy_counter_.fetch_add(1); + // async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); + + // synchronous copy + memcpy(shm_addr, addr, msg.meta.val_len); + Send(msg, msg_buf, remote_addr); } int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) { From ef0bc3bca17aafb6367d2c90efead1038d04270f Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Sun, 22 Dec 2019 23:31:43 +0800 Subject: [PATCH 60/79] fix transport type for two directions --- src/rdma_transport.h | 1 - src/rdma_van.h | 28 ++++++++++++++++++++-------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 3e5e120c..232a7c40 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -255,7 +255,6 @@ class RDMATransport : public Transport { size_t num_sge = 1; if (msg_buf->mrs.size() == 3) { - LOG(INFO) << "#########"; // push request, split the meta and data into two writes // further, it does not send keys and lens since these meta already carries these info struct ibv_sge my_sge; diff --git a/src/rdma_van.h b/src/rdma_van.h index 11c58373..4f009e8a 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -132,14 +132,6 @@ class RDMAVan : public Van { return; } - if (disable_ipc_) { - is_local_[node.id] = false; - } else { - std::lock_guard lock(local_mu_); - is_local_[node.id] = (node.hostname == my_node_.hostname) ? true : false; - LOG(INFO) << "IPC connected to " << node.id; - } - if (node.id != Node::kEmpty) { auto it = endpoints_.find(node.id); @@ -213,6 +205,16 @@ class RDMAVan : public Van { std::this_thread::sleep_for(std::chrono::milliseconds(500)); } + local_mu_.lock(); + if (disable_ipc_) { + is_local_[node.id] = false; + } else { + is_local_[node.id] = (node.hostname == my_node_.hostname) ? true : false; + } + LOG(INFO) << "Connect to Node " << node.id + << " with Transport=" << (is_local_[node.id]?"IPC" : "RDMA"); + local_mu_.unlock(); + std::shared_ptr t = is_local_[node.id] ? std::make_shared(endpoint, send_mempool_.get()) : std::make_shared(endpoint, send_mempool_.get()); @@ -608,6 +610,16 @@ class RDMAVan : public Van { endpoint->Init(cq_, pd_); + local_mu_.lock(); + if (disable_ipc_) { + is_local_[remote_ctx->node] = false; + } else { + is_local_[remote_ctx->node] = (std::string(remote_ctx->hostname) == my_node_.hostname) ? true : false; + } + LOG(INFO) << "OnConnect to Node " << remote_ctx->node + << " with Transport=" << (is_local_[remote_ctx->node]?"IPC" : "RDMA"); + local_mu_.unlock(); + std::shared_ptr t = is_local_[remote_ctx->node] ? std::make_shared(endpoint, recv_mempool_.get()) : std::make_shared(endpoint, recv_mempool_.get()); From fe42f941c91db4f710b63138a905b302be7b1eef Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Mon, 23 Dec 2019 11:31:27 +0800 Subject: [PATCH 61/79] fix async copy --- src/rdma_transport.h | 51 +++++++++++++++++++------------------------- 1 file changed, 22 insertions(+), 29 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 232a7c40..bf791009 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -228,16 +228,12 @@ class RDMATransport : public Transport { void PrepareData(Message &msg, MessageBuffer *msg_buf) { if (!msg.meta.push && !msg.meta.request) return; // pull response does not need data - for (auto &sa : msg_buf->data) { if (sa.size() == 0) continue; - std::lock_guard lock(map_mu_); auto it = mem_mr_map_.find(sa.data()); MRPtr ptr(it->second, [](struct ibv_mr *mr) {}); - CHECK(ptr.get()) << strerror(errno); - msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size())); } } @@ -576,6 +572,10 @@ class IPCTransport : public RDMATransport { val = Environment::Get()->find("BYTEPS_LOCAL_SIZE"); auto byteps_local_size = val ? atoi(val) : 1; byteps_partition_bytes_ = RoundUp(byteps_partition_bytes_, byteps_local_size * sysconf(_SC_PAGESIZE)); + + val = Environment::Get()->find("BYTEPS_IPC_ENABLE_ASYNC_COPY"); + enable_async_copy_ = val ? atoi(val) : 1; // default enabled + if (!enable_async_copy_) LOG(INFO) << "Async copy has been disabled, this could affect the performance"; }; ~IPCTransport() { @@ -588,26 +588,26 @@ class IPCTransport : public RDMATransport { } void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { - msg_buf->mrs.clear(); + msg_buf->mrs.clear(); // avoid rdma-write in RDMAWriteWithImm() CHECK_EQ(msg_buf->mrs.size(), 0); Send(msg, msg_buf, remote_addr); } void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { CHECK_EQ(msg_buf->mrs.size(), 0); - std::lock_guard lock(map_mu_); auto addr = (void*) CHECK_NOTNULL(msg.data[1].data()); - void* shm_addr = GetSharedMemory(kShmPrefix, msg.meta.key); - CHECK(shm_addr); + void* shm_addr = CHECK_NOTNULL(GetSharedMemory(kShmPrefix, msg.meta.key)); - /* async copy with a simple load-balancing strategy */ - // AsyncCopy m = {msg_buf, shm_addr, addr, len, remote_addr, recver, false}; - // auto cnt = cpy_counter_.fetch_add(1); - // async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); - - // synchronous copy - memcpy(shm_addr, addr, msg.meta.val_len); - Send(msg, msg_buf, remote_addr); + if (enable_async_copy_) { + // async copy with a simple load-balancing strategy + AsyncCopy m = {msg_buf, shm_addr, addr, msg.meta.val_len, remote_addr, msg.meta.recver, false}; + auto cnt = cpy_counter_.fetch_add(1); + async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); + } else { + // synchronous copy + memcpy(shm_addr, addr, msg.meta.val_len); + Send(msg, msg_buf, remote_addr); + } } int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) { @@ -655,19 +655,11 @@ class IPCTransport : public RDMATransport { CHECK(m.src); memcpy(m.dst, m.src, m.len); - struct ibv_sge sge[1]; - sge[0].addr = reinterpret_cast(m.msg_buf->inline_buf); - sge[0].length = m.msg_buf->inline_len; - sge[0].lkey = mempool_->LocalKey(m.msg_buf->inline_buf); - auto raddr = std::get<0>(m.remote_addr[m.recver]); auto rkey = std::get<1>(m.remote_addr[m.recver]); auto idx = std::get<2>(m.remote_addr[m.recver]); - - WRContext *reserved = nullptr; - endpoint_->free_write_ctx.WaitAndPop(&reserved); - m.msg_buf->reserved_context = reserved; - PostRDMAWriteWithImm(m.msg_buf, sge, 1, raddr, rkey, idx); + + RDMAWriteWithImm(m.msg_buf, raddr, rkey, idx); } } @@ -693,9 +685,8 @@ class IPCTransport : public RDMATransport { CHECK_NE(base_ptr, (void*) -1) << strerror(errno); key_shm_addr_[base_key] = base_ptr; - LOG(INFO) << "open Shared Memory: " << shm_name - << ", offset=" << offset - << ", (in bytes) size=" << total_shm_size; + PS_VLOG(1) << "open Shared Memory: " << shm_name << ", offset=" + << offset << ", (in bytes) size=" << total_shm_size; return key_shm_addr_[base_key] + offset; } @@ -709,6 +700,8 @@ class IPCTransport : public RDMATransport { std::mutex shm_mu_; std::unordered_map key_shm_addr_; + bool enable_async_copy_; + }; // class IPCTransport From e30a44e36cb9e68cdfaf3a63faac73db8b6e4342 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Mon, 23 Dec 2019 12:05:46 +0800 Subject: [PATCH 62/79] minimize map_mu_ lock coverage --- src/rdma_transport.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index bf791009..9770345b 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -435,11 +435,13 @@ class RDMATransport : public Transport { } virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { - std::lock_guard lock(map_mu_); auto raddr = msg.meta.addr; auto rkey = msg.meta.option; + + map_mu_.lock(); auto temp_mr = mem_mr_map_.find(msg_buf->data[1].data()); CHECK_NE(temp_mr, mem_mr_map_.end()); + map_mu_.unlock(); struct ibv_sge sge; sge.addr = reinterpret_cast(msg_buf->data[1].data()); @@ -473,7 +475,6 @@ class RDMATransport : public Transport { } virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) { - std::lock_guard lock(map_mu_); auto addr = msg->meta.addr; SArray keys = CreateFunctionalSarray(&msg->meta.key, sizeof(Key)); From e1286fdd75da9fd687da5a8faec0c5599dd9bcf3 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Mon, 23 Dec 2019 18:30:45 +0800 Subject: [PATCH 63/79] non-functional improvement --- src/rdma_transport.h | 66 +++++++++++++++++++------------------------- src/rdma_utils.h | 6 ++-- src/rdma_van.h | 31 +++++++++++---------- 3 files changed, 49 insertions(+), 54 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 9770345b..fa19f292 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -170,8 +170,7 @@ struct Endpoint { class Transport { public: virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; - virtual void PostRDMAWriteWithImm(MessageBuffer *msg_buf, struct ibv_sge *sge, size_t num_sge, - uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; + virtual int Recv(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; virtual int RecvPullRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; @@ -182,11 +181,11 @@ class Transport { virtual void RegisterMemory(Message &msg) = 0; virtual void PrepareData(Message &msg, MessageBuffer *msg_buf) = 0; - virtual void Send(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) = 0; - virtual void SendPullRequest(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) = 0; - virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) = 0; - virtual void SendPushResponse(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) = 0; - virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) = 0; + virtual void Send(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; + virtual void SendPullRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; + virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; + virtual void SendPushResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; + virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; virtual void SendRendezvousBegin(Message &msg, MessageBuffer *msg_buf) = 0; virtual void SendRendezvousReply(RendezvousStart *req, AddressPool &pool) = 0; @@ -285,12 +284,7 @@ class RDMATransport : public Transport { ++num_sge; } } - PostRDMAWriteWithImm(msg_buf, sge, num_sge, remote_addr, rkey, idx); - } - - void PostRDMAWriteWithImm(MessageBuffer *msg_buf, struct ibv_sge *sge, - size_t num_sge, uint64_t remote_addr, - uint32_t rkey, uint32_t idx) { + WRContext *write_ctx = msg_buf->reserved_context; CHECK(write_ctx); MessageBuffer **tmp = @@ -414,27 +408,26 @@ class RDMATransport : public Transport { } } - void Send(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { - auto recver = msg.meta.recver; - auto raddr = std::get<0>(remote_addr[recver]); - auto rkey = std::get<1>(remote_addr[recver]); - auto idx = std::get<2>(remote_addr[recver]); + void Send(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + auto raddr = std::get<0>(remote_tuple); + auto rkey = std::get<1>(remote_tuple); + auto idx = std::get<2>(remote_tuple); RDMAWriteWithImm(msg_buf, raddr, rkey, idx); } - void SendPushResponse(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { - Send(msg, msg_buf, remote_addr); + void SendPushResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + Send(msg, msg_buf, remote_tuple); } - void SendPullRequest(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { - Send(msg, msg_buf, remote_addr); + void SendPullRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + Send(msg, msg_buf, remote_tuple); } - virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { - Send(msg, msg_buf, remote_addr); + virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + Send(msg, msg_buf, remote_tuple); } - virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { + virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { auto raddr = msg.meta.addr; auto rkey = msg.meta.option; @@ -463,7 +456,7 @@ class RDMATransport : public Transport { << "ibv_post_send failed."; // after write keys/vals/lens (no imm), write the meta (with imm) - Send(msg, msg_buf, remote_addr); + Send(msg, msg_buf, remote_tuple); } virtual int RecvPushResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) { @@ -571,7 +564,7 @@ class IPCTransport : public RDMATransport { byteps_partition_bytes_ = val ? atoi(val) : 4096000; val = Environment::Get()->find("BYTEPS_LOCAL_SIZE"); - auto byteps_local_size = val ? atoi(val) : 1; + auto byteps_local_size = val ? atoi(val) : 8; byteps_partition_bytes_ = RoundUp(byteps_partition_bytes_, byteps_local_size * sysconf(_SC_PAGESIZE)); val = Environment::Get()->find("BYTEPS_IPC_ENABLE_ASYNC_COPY"); @@ -588,26 +581,26 @@ class IPCTransport : public RDMATransport { } } - void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { + void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { msg_buf->mrs.clear(); // avoid rdma-write in RDMAWriteWithImm() CHECK_EQ(msg_buf->mrs.size(), 0); - Send(msg, msg_buf, remote_addr); + Send(msg, msg_buf, remote_tuple); } - void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteAddress remote_addr) { + void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { CHECK_EQ(msg_buf->mrs.size(), 0); auto addr = (void*) CHECK_NOTNULL(msg.data[1].data()); void* shm_addr = CHECK_NOTNULL(GetSharedMemory(kShmPrefix, msg.meta.key)); if (enable_async_copy_) { // async copy with a simple load-balancing strategy - AsyncCopy m = {msg_buf, shm_addr, addr, msg.meta.val_len, remote_addr, msg.meta.recver, false}; + AsyncCopy m = {msg_buf, remote_tuple, shm_addr, addr, msg.meta.val_len, false}; auto cnt = cpy_counter_.fetch_add(1); async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); } else { // synchronous copy memcpy(shm_addr, addr, msg.meta.val_len); - Send(msg, msg_buf, remote_addr); + Send(msg, msg_buf, remote_tuple); } } @@ -635,11 +628,10 @@ class IPCTransport : public RDMATransport { private: struct AsyncCopy { MessageBuffer* msg_buf; + RemoteTuple remote_tuple; void* dst; void* src; int len; - RemoteAddress remote_addr; - int recver; bool shutdown; }; @@ -656,9 +648,9 @@ class IPCTransport : public RDMATransport { CHECK(m.src); memcpy(m.dst, m.src, m.len); - auto raddr = std::get<0>(m.remote_addr[m.recver]); - auto rkey = std::get<1>(m.remote_addr[m.recver]); - auto idx = std::get<2>(m.remote_addr[m.recver]); + auto raddr = std::get<0>(m.remote_tuple); + auto rkey = std::get<1>(m.remote_tuple); + auto idx = std::get<2>(m.remote_tuple); RDMAWriteWithImm(m.msg_buf, raddr, rkey, idx); } diff --git a/src/rdma_utils.h b/src/rdma_utils.h index d456d1bb..afb27231 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -300,9 +300,11 @@ struct BufferContext { typedef std::unique_ptr> MRPtr; +// +typedef std::tuple RemoteTuple; + // recver, -typedef std::unordered_map > - RemoteAddress; +typedef std::unordered_map RemoteAddress; struct MessageBuffer { size_t inline_len; diff --git a/src/rdma_van.h b/src/rdma_van.h index 4f009e8a..2c6d56d9 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -270,29 +270,30 @@ class RDMAVan : public Van { return true; } // no remote info, store the msg_buf address and push/pull flag for RendezvousReply - msgbuf_cache_.emplace(reinterpret_cast(msg_buf), std::make_tuple(key, is_push, recver)); + auto buf_addr = reinterpret_cast(msg_buf); + CHECK_EQ(msgbuf_cache_.find(buf_addr), msgbuf_cache_.end()); + msgbuf_cache_.emplace(buf_addr, std::make_tuple(key, is_push, recver)); return false; } void StoreRemoteInfo(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { - auto buf = reinterpret_cast(msg_buf); - if (msgbuf_cache_.find(buf) == msgbuf_cache_.end()) return; // control message std::lock_guard lk(addr_mu_); - auto key = std::get<0>(msgbuf_cache_[buf]); - auto is_push = std::get<1>(msgbuf_cache_[buf]); - auto recver = std::get<2>(msgbuf_cache_[buf]); + auto buf_addr = reinterpret_cast(msg_buf); + if(msgbuf_cache_.find(buf_addr) == msgbuf_cache_.end()) { return; } // control message + auto key = std::get<0>(msgbuf_cache_[buf_addr]); + auto is_push = std::get<1>(msgbuf_cache_[buf_addr]); + auto recver = std::get<2>(msgbuf_cache_[buf_addr]); if (is_push) { push_addr_[key][recver] = std::make_tuple(remote_addr, rkey, idx); } else { pull_addr_[key][recver] = std::make_tuple(remote_addr, rkey, idx); } - CHECK_NE(msgbuf_cache_.find(buf), msgbuf_cache_.end()); - msgbuf_cache_.erase(buf); + msgbuf_cache_.erase(buf_addr); } - RemoteAddress GetRemoteInfo(uint64_t key, bool is_push) { + RemoteTuple GetRemoteInfo(uint64_t key, bool is_push, int recver) { std::lock_guard lk(addr_mu_); - return (is_push ? push_addr_[key] : pull_addr_[key]); + return (is_push ? push_addr_[key][recver] : pull_addr_[key][recver]); } int SendMsg(Message &msg) override { @@ -336,21 +337,21 @@ class RDMAVan : public Van { } } - auto remote_addr_tuple = GetRemoteInfo(msg.meta.key, msg.meta.push); + auto remote_tuple = GetRemoteInfo(msg.meta.key, msg.meta.push, remote_id); // already know remote address, directly use RDMA-write if (msg.meta.push && msg.meta.request) { // worker, push request - trans->SendPushRequest(msg, msg_buf, remote_addr_tuple); + trans->SendPushRequest(msg, msg_buf, remote_tuple); } else if (msg.meta.push && !msg.meta.request) { // server, push response - trans->SendPushResponse(msg, msg_buf, remote_addr_tuple); + trans->SendPushResponse(msg, msg_buf, remote_tuple); } else if (!msg.meta.push && msg.meta.request) { // worker, pull request - trans->SendPullRequest(msg, msg_buf, remote_addr_tuple); + trans->SendPullRequest(msg, msg_buf, remote_tuple); } else if (!msg.meta.push && !msg.meta.request) { // server, pull response - trans->SendPullResponse(msg, msg_buf, remote_addr_tuple); + trans->SendPullResponse(msg, msg_buf, remote_tuple); } else { CHECK(0) << "unexpected message type"; } From f2dcad19ef69f69d3aad04d363c0a738f5897a63 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Tue, 24 Dec 2019 16:02:46 +0800 Subject: [PATCH 64/79] enforce all num_sge to 1 --- src/rdma_transport.h | 106 ++++++++++++++++++------------------------- 1 file changed, 44 insertions(+), 62 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index fa19f292..85fdb659 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -171,7 +171,6 @@ class Transport { public: virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; - virtual int Recv(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; virtual int RecvPullRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; virtual int RecvPushResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; @@ -226,7 +225,7 @@ class RDMATransport : public Transport { } void PrepareData(Message &msg, MessageBuffer *msg_buf) { - if (!msg.meta.push && !msg.meta.request) return; // pull response does not need data + if (!(msg.meta.push && msg.meta.request)) return; // only push request for (auto &sa : msg_buf->data) { if (sa.size() == 0) continue; std::lock_guard lock(map_mu_); @@ -247,8 +246,6 @@ class RDMATransport : public Transport { sge[0].length = msg_buf->inline_len; sge[0].lkey = mempool_->LocalKey(msg_buf->inline_buf); - size_t num_sge = 1; - if (msg_buf->mrs.size() == 3) { // push request, split the meta and data into two writes // further, it does not send keys and lens since these meta already carries these info @@ -274,15 +271,7 @@ class RDMATransport : public Transport { << "ibv_post_send failed."; } else { - for (auto &pair : msg_buf->mrs) { - size_t length = pair.second; - CHECK(length); - sge[num_sge].addr = - reinterpret_cast(pair.first->addr); - sge[num_sge].length = length; - sge[num_sge].lkey = pair.first->lkey; - ++num_sge; - } + CHECK_EQ(msg_buf->mrs.size(),0); } WRContext *write_ctx = msg_buf->reserved_context; @@ -299,7 +288,7 @@ class RDMATransport : public Transport { wr.imm_data = idx; wr.send_flags = IBV_SEND_SIGNALED; wr.sg_list = sge; - wr.num_sge = num_sge; + wr.num_sge = 1; wr.wr.rdma.remote_addr = remote_addr; wr.wr.rdma.rkey = rkey; @@ -412,6 +401,7 @@ class RDMATransport : public Transport { auto raddr = std::get<0>(remote_tuple); auto rkey = std::get<1>(remote_tuple); auto idx = std::get<2>(remote_tuple); + RDMAWriteWithImm(msg_buf, raddr, rkey, idx); } @@ -460,19 +450,55 @@ class RDMATransport : public Transport { } virtual int RecvPushResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) { - return Recv(msg, buffer_ctx, meta_len); + CHECK_EQ(buffer_ctx->data_num, 0); + return 0; } virtual int RecvPullRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) { - return Recv(msg, buffer_ctx, meta_len); + SArray keys = CreateFunctionalSarray(&msg->meta.key, sizeof(Key)); + + SArray vals; // add an empty sarray to pass kvapp check + + msg->data.push_back(keys); + msg->data.push_back(vals); + + return keys.size() + vals.size(); + } + + virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) { + CHECK(msg->meta.push && msg->meta.request); + CHECK_EQ(buffer_ctx->data_num, 3); + uint32_t len = buffer_ctx->data_len[1]; + char* cur = buffer_ctx->buffer + align_ceil((size_t) meta_len, pagesize_); + + SArray keys = CreateFunctionalSarray(&msg->meta.key, sizeof(Key)); + + SArray vals; + vals.reset(cur, len, [](void *) {}); // no need to delete + + SArray lens = CreateFunctionalSarray(&msg->meta.val_len, sizeof(int)); + + msg->data.push_back(keys); + msg->data.push_back(vals); + msg->data.push_back(lens); + + return keys.size() + vals.size() + lens.size(); } virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) { - auto addr = msg->meta.addr; + LOG(INFO) << "RecvPullResponse: key=" << msg->meta.key + << ", " << (msg->meta.push ? "push" : "pull") + << " " << (msg->meta.request ? "request" : "response") + << ", sender=" << msg->meta.sender + << ", meta_len=" << meta_len + << ", buffer_ctx=" << reinterpret_cast(buffer_ctx) + << ", tensor_addr=" << msg->meta.addr + << ", val_len=" << msg->meta.val_len; SArray keys = CreateFunctionalSarray(&msg->meta.key, sizeof(Key)); SArray vals; + auto addr = msg->meta.addr; vals.reset(reinterpret_cast(addr), msg->meta.val_len, [](void *){}); SArray lens = CreateFunctionalSarray(&msg->meta.val_len, sizeof(int)); @@ -480,11 +506,8 @@ class RDMATransport : public Transport { msg->data.push_back(keys); msg->data.push_back(vals); msg->data.push_back(lens); - return keys.size() + vals.size() + lens.size(); - } - virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) { - return Recv(msg, buffer_ctx, meta_len); + return keys.size() + vals.size() + lens.size(); } SArray CreateFunctionalSarray(void *value, size_t size) { @@ -495,47 +518,6 @@ class RDMATransport : public Transport { return sarr; } - private: - virtual int Recv(Message *msg, BufferContext *buffer_ctx, int meta_len) { - uint64_t data_num = buffer_ctx->data_num; - if (data_num == 0) { - return 0; - } - - char *cur = buffer_ctx->buffer + meta_len; // offset - - if (msg->meta.push && msg->meta.request) { // push request - CHECK_EQ(data_num, 3); - uint32_t len = buffer_ctx->data_len[1]; - - cur = buffer_ctx->buffer + align_ceil((size_t) meta_len, pagesize_); - - SArray keys = CreateFunctionalSarray(&msg->meta.key, sizeof(Key)); - - SArray vals; - vals.reset(cur, len, [](void *) {}); // no need to delete - - SArray lens = CreateFunctionalSarray(&msg->meta.val_len, sizeof(int)); - - msg->data.push_back(keys); - msg->data.push_back(vals); - msg->data.push_back(lens); - - return sizeof(Key) + len + sizeof(int); - } - - int total_data_len = 0; - for (size_t i = 0; i < data_num; i++) { - uint32_t len = buffer_ctx->data_len[i]; - SArray data; - data.reset(cur, len, [](void *) {}); // no need for delete - msg->data.push_back(data); - cur += len; - total_data_len += len; - } - return total_data_len; - } - protected: size_t pagesize_ = 4096; Endpoint *endpoint_; From 73fd1cabadc969f1d0943929db111b18c65b18bc Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Tue, 24 Dec 2019 16:07:02 +0800 Subject: [PATCH 65/79] remove runtime log --- src/rdma_transport.h | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 85fdb659..e4ea2f43 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -486,15 +486,6 @@ class RDMATransport : public Transport { } virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) { - LOG(INFO) << "RecvPullResponse: key=" << msg->meta.key - << ", " << (msg->meta.push ? "push" : "pull") - << " " << (msg->meta.request ? "request" : "response") - << ", sender=" << msg->meta.sender - << ", meta_len=" << meta_len - << ", buffer_ctx=" << reinterpret_cast(buffer_ctx) - << ", tensor_addr=" << msg->meta.addr - << ", val_len=" << msg->meta.val_len; - SArray keys = CreateFunctionalSarray(&msg->meta.key, sizeof(Key)); SArray vals; From e7ad975ebed67fc3d76246e1b449c59b211071a1 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Thu, 26 Dec 2019 17:50:53 +0800 Subject: [PATCH 66/79] fix mr_list range --- src/rdma_utils.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/rdma_utils.h b/src/rdma_utils.h index afb27231..896b9e8d 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -110,6 +110,8 @@ class SimpleMempool { // set init mempool size auto byteps_rdma_mempool_size = Environment::Get()->find("BYTEPS_RDMA_MEMPOOL_SIZE"); size = byteps_rdma_mempool_size ? atoi(byteps_rdma_mempool_size) : size; + size = align_ceil(size, pagesize_); + auto byteps_rdma_mempool_num = Environment::Get()->find("BYTEPS_RDMA_MEMPOOL_NUM"); size_t mempool_num = byteps_rdma_mempool_num ? atoi(byteps_rdma_mempool_num) : 1; PS_VLOG(1) << "RDMA initial mempool size set to " << size @@ -122,7 +124,7 @@ class SimpleMempool { CHECK(p); CHECK(mr = ibv_reg_mr(pd, p, size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)); - mr_list.emplace(p+size, mr); // this mr is associated with memory address range [p, p+size] + mr_list.emplace(p+size-1, mr); // this mr is associated with memory address range [p, p+size-1] free_list.emplace(size, p); } } @@ -158,7 +160,7 @@ class SimpleMempool { CHECK(p); struct ibv_mr *mr; CHECK(mr = ibv_reg_mr(pd_, p, new_mem_size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)); - mr_list.emplace(p+new_mem_size, mr); + mr_list.emplace(p+new_mem_size-1, mr); free_list.emplace(new_mem_size, p); it = free_list.lower_bound(proper_size); PS_VLOG(1) << "Not enough memory in the pool, requested size " << proper_size << ", new allocated size " << new_mem_size; From 8c35b87406980fd82d8cdf27199b96b905015090 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Fri, 27 Dec 2019 12:34:23 +0800 Subject: [PATCH 67/79] reuse local msg_buf address --- src/rdma_transport.h | 34 ++++++------ src/rdma_utils.h | 44 ++++----------- src/rdma_van.h | 124 ++++++++++++++++++++++++------------------- 3 files changed, 96 insertions(+), 106 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index e4ea2f43..efca2460 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -195,9 +195,9 @@ class Transport { class RDMATransport : public Transport { public: - explicit RDMATransport(Endpoint *endpoint, SimpleMempool *mempool) { + explicit RDMATransport(Endpoint *endpoint, MemoryAllocator *allocator) { endpoint_ = CHECK_NOTNULL(endpoint); - mempool_ = CHECK_NOTNULL(mempool); + allocator_ = CHECK_NOTNULL(allocator); pagesize_ = sysconf(_SC_PAGESIZE); auto val = Environment::Get()->find("DMLC_ROLE"); @@ -215,7 +215,7 @@ class RDMATransport : public Transport { std::lock_guard lock(map_mu_); if (mem_mr_map_.find(sa.data()) == mem_mr_map_.end()) { struct ibv_mr *temp_mr; - CHECK (temp_mr = ibv_reg_mr(mempool_->GetPD(), sa.data(), sa.size(), + CHECK (temp_mr = ibv_reg_mr(allocator_->GetPD(), sa.data(), sa.size(), IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) << "Failed to register the memory region: " << strerror(errno) << ", sa.size()=" << sa.size(); @@ -237,15 +237,6 @@ class RDMATransport : public Transport { } virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { - WRContext *reserved = nullptr; - endpoint_->free_write_ctx.WaitAndPop(&reserved); - msg_buf->reserved_context = reserved; - // prepare RDMA write sge list - struct ibv_sge sge[1 + msg_buf->mrs.size()]; - sge[0].addr = reinterpret_cast(msg_buf->inline_buf); - sge[0].length = msg_buf->inline_len; - sge[0].lkey = mempool_->LocalKey(msg_buf->inline_buf); - if (msg_buf->mrs.size() == 3) { // push request, split the meta and data into two writes // further, it does not send keys and lens since these meta already carries these info @@ -273,6 +264,15 @@ class RDMATransport : public Transport { } else { CHECK_EQ(msg_buf->mrs.size(),0); } + + WRContext *reserved = nullptr; + endpoint_->free_write_ctx.WaitAndPop(&reserved); + msg_buf->reserved_context = reserved; + // prepare RDMA write sge list + struct ibv_sge sge; + sge.addr = reinterpret_cast(msg_buf->inline_buf); + sge.length = msg_buf->inline_len; + sge.lkey = allocator_->LocalKey(msg_buf->inline_buf); WRContext *write_ctx = msg_buf->reserved_context; CHECK(write_ctx); @@ -287,7 +287,7 @@ class RDMATransport : public Transport { wr.next = nullptr; wr.imm_data = idx; wr.send_flags = IBV_SEND_SIGNALED; - wr.sg_list = sge; + wr.sg_list = &sge; wr.num_sge = 1; wr.wr.rdma.remote_addr = remote_addr; wr.wr.rdma.rkey = rkey; @@ -343,7 +343,7 @@ class RDMATransport : public Transport { // worker only needs a buffer for receving meta char *buffer = - mempool_->Alloc(is_server_ ? (kMaxMetaBound + len) : (kMaxMetaBound + req->meta_len)); + allocator_->Alloc(is_server_ ? (kMaxMetaBound + len) : (kMaxMetaBound + req->meta_len)); CHECK(buffer); buf_ctx->buffer = buffer; WRContext *reply_ctx = nullptr; @@ -353,7 +353,7 @@ class RDMATransport : public Transport { reinterpret_cast(reply_ctx->buffer->addr); resp->addr = reinterpret_cast(buffer); - resp->rkey = mempool_->RemoteKey(buffer); + resp->rkey = allocator_->RemoteKey(buffer); resp->origin_addr = req->origin_addr; resp->idx = addrpool.StoreAddress(buf_ctx); @@ -512,7 +512,7 @@ class RDMATransport : public Transport { protected: size_t pagesize_ = 4096; Endpoint *endpoint_; - SimpleMempool *mempool_; + MemoryAllocator *allocator_; bool is_server_; std::mutex map_mu_; std::unordered_map mem_mr_map_; // (memory, ibv_mr) @@ -522,7 +522,7 @@ class RDMATransport : public Transport { class IPCTransport : public RDMATransport { public: - explicit IPCTransport(Endpoint *endpoint, SimpleMempool *mempool) : RDMATransport(endpoint, mempool) { + explicit IPCTransport(Endpoint *endpoint, MemoryAllocator *allocator) : RDMATransport(endpoint, allocator) { auto val = Environment::Get()->find("BYTEPS_IPC_COPY_NUM_THREADS"); ipc_copy_nthreads_ = val ? atoi(val) : 4; for (int i = 0; i < ipc_copy_nthreads_; ++i) { diff --git a/src/rdma_utils.h b/src/rdma_utils.h index 896b9e8d..1de7d769 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -98,9 +98,9 @@ static inline void ib_malloc(void** ptr, size_t size) { *ptr = p; } -class SimpleMempool { +class MemoryAllocator { public: - explicit SimpleMempool(struct ibv_pd *pd, size_t size = 0x10000000) { + explicit MemoryAllocator(struct ibv_pd *pd, size_t size = 0x10000000) { std::lock_guard lk(mu_); pd_ = pd; struct ibv_mr *mr; @@ -129,7 +129,7 @@ class SimpleMempool { } } - ~SimpleMempool() { + ~MemoryAllocator() { std::lock_guard lk(mu_); for(auto it = mr_list.begin(); it != mr_list.end(); it++) { CHECK_EQ(ibv_dereg_mr(it->second), 0); @@ -206,6 +206,7 @@ class SimpleMempool { struct ibv_mr *mr = Addr2MR(addr); return mr->lkey; } + uint32_t RemoteKey(char *addr) { struct ibv_mr *mr = Addr2MR(addr); return mr->rkey; @@ -221,6 +222,7 @@ class SimpleMempool { std::unordered_map used_list; struct ibv_pd *pd_; size_t total_allocated_size = 0; + size_t pagesize_; // first: `end` of this mr address (e.g., for mr with [addr, addr+size], point to `addr+size`) std::map mr_list; @@ -233,33 +235,9 @@ class SimpleMempool { return it->second; } - size_t pagesize_; }; -class Block { - public: - explicit Block(SimpleMempool *pool, char *addr, int count) - : pool(pool), addr(addr), counter(count) {} - - ~Block() { - CHECK_EQ(counter, 0); - pool->Free(addr); - } - - void Release() { - int v = counter.fetch_sub(1); - if (v == 1) { - delete this; - } - } - - private: - SimpleMempool *pool; - char *addr; - std::atomic counter; -}; - enum MessageTypes : uint32_t { kRendezvousStart, kRendezvousReply, @@ -302,12 +280,6 @@ struct BufferContext { typedef std::unique_ptr> MRPtr; -// -typedef std::tuple RemoteTuple; - -// recver, -typedef std::unordered_map RemoteAddress; - struct MessageBuffer { size_t inline_len; char *inline_buf; @@ -322,6 +294,12 @@ struct RequestContext { char hostname[kMaxHostnameLength]; }; +// +typedef std::tuple RemoteTuple; + +// recver, +typedef std::unordered_map RemoteAndLocalAddress; + static_assert(std::is_pod::value, "RendezvousStart must be a POD type."); static_assert(std::is_pod::value, diff --git a/src/rdma_van.h b/src/rdma_van.h index 2c6d56d9..c96f951c 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -66,9 +66,8 @@ class RDMAVan : public Van { cm_event_polling_thread_->join(); cm_event_polling_thread_.reset(); - PS_VLOG(1) << "Clearing mempool."; - send_mempool_.reset(); - recv_mempool_.reset(); + PS_VLOG(1) << "Clearing memory allocator."; + mem_allocator_.reset(); PS_VLOG(1) << "Clearing endpoints."; incoming_.clear(); @@ -216,8 +215,8 @@ class RDMAVan : public Van { local_mu_.unlock(); std::shared_ptr t = is_local_[node.id] ? - std::make_shared(endpoint, send_mempool_.get()) : - std::make_shared(endpoint, send_mempool_.get()); + std::make_shared(endpoint, mem_allocator_.get()) : + std::make_shared(endpoint, mem_allocator_.get()); endpoint->SetTransport(t); freeaddrinfo(remote_addr); @@ -259,7 +258,7 @@ class RDMAVan : public Van { } } - bool HasRemoteInfo(MessageBuffer *msg_buf, uint64_t key, bool is_push, int recver) { + bool HasRemoteInfo(Message& msg, uint64_t key, bool is_push, int recver) { std::lock_guard lk(addr_mu_); if (is_push && (push_addr_.find(key) != push_addr_.end()) && (push_addr_[key].find(recver) != push_addr_[key].end())) { @@ -269,89 +268,109 @@ class RDMAVan : public Van { && (pull_addr_[key].find(recver) != pull_addr_[key].end())) { return true; } - // no remote info, store the msg_buf address and push/pull flag for RendezvousReply - auto buf_addr = reinterpret_cast(msg_buf); - CHECK_EQ(msgbuf_cache_.find(buf_addr), msgbuf_cache_.end()); - msgbuf_cache_.emplace(buf_addr, std::make_tuple(key, is_push, recver)); + return false; } - void StoreRemoteInfo(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { + void StoreMsgBuf(MessageBuffer *msg_buf, uint64_t key, bool is_push, int recver) { + std::lock_guard lk(addr_mu_); + CHECK_EQ(msgbuf_cache_.find(msg_buf), msgbuf_cache_.end()); + msgbuf_cache_[msg_buf] = std::make_tuple(key, is_push, recver); + } + + void StoreRemoteAndLocalInfo(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { std::lock_guard lk(addr_mu_); - auto buf_addr = reinterpret_cast(msg_buf); - if(msgbuf_cache_.find(buf_addr) == msgbuf_cache_.end()) { return; } // control message - auto key = std::get<0>(msgbuf_cache_[buf_addr]); - auto is_push = std::get<1>(msgbuf_cache_[buf_addr]); - auto recver = std::get<2>(msgbuf_cache_[buf_addr]); + + CHECK_NE(msgbuf_cache_.find(msg_buf), msgbuf_cache_.end()); + + auto key = std::get<0>(msgbuf_cache_[msg_buf]); + auto is_push = std::get<1>(msgbuf_cache_[msg_buf]); + auto recver = std::get<2>(msgbuf_cache_[msg_buf]); + + auto t = std::make_tuple(remote_addr, rkey, idx, msg_buf); if (is_push) { - push_addr_[key][recver] = std::make_tuple(remote_addr, rkey, idx); + push_addr_[key][recver] = t; } else { - pull_addr_[key][recver] = std::make_tuple(remote_addr, rkey, idx); + pull_addr_[key][recver] = t; } - msgbuf_cache_.erase(buf_addr); } - RemoteTuple GetRemoteInfo(uint64_t key, bool is_push, int recver) { + RemoteTuple GetRemoteAndLocalInfo(uint64_t key, bool is_push, int recver) { std::lock_guard lk(addr_mu_); return (is_push ? push_addr_[key][recver] : pull_addr_[key][recver]); } + MessageBuffer* PrepareNewMsgBuf(Message& msg) { + MessageBuffer *msg_buf = new MessageBuffer(); + auto meta_len = GetPackMetaLen(msg.meta); + msg_buf->inline_len = meta_len; + msg_buf->inline_buf = mem_allocator_->Alloc(meta_len); + msg_buf->data = msg.data; + PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); + return msg_buf; + } + int SendMsg(Message &msg) override { int remote_id = msg.meta.recver; CHECK_NE(remote_id, Meta::kEmpty); CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); Endpoint *endpoint = endpoints_[remote_id].get(); - auto trans = CHECK_NOTNULL(endpoint->GetTransport()); - trans->RegisterMemory(msg); - - MessageBuffer *msg_buf = new MessageBuffer(); - int meta_len = GetPackMetaLen(msg.meta); - size_t data_len = msg.meta.data_size; size_t total_len = meta_len + data_len; CHECK(meta_len); - msg_buf->inline_len = meta_len; - msg_buf->inline_buf = send_mempool_->Alloc(meta_len); - msg_buf->data = msg.data; + auto trans = CHECK_NOTNULL(endpoint->GetTransport()); + trans->RegisterMemory(msg); + // pack meta info if (IsValidPushpull(msg)) { trans->AddMeta(msg); PackWorkerTensorAddress(msg); } - PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); - + // start rendezvous if no remote info if (!IsValidPushpull(msg)) { + MessageBuffer *msg_buf = PrepareNewMsgBuf(msg); + StoreMsgBuf(msg_buf, 0, 0, -1); trans->SendRendezvousBegin(msg, msg_buf); return total_len; + } else { - trans->PrepareData(msg, msg_buf); auto is_push = msg.meta.push; auto key = msg.meta.key; - if (!HasRemoteInfo(msg_buf, key, is_push, remote_id)) { + if (!HasRemoteInfo(msg, key, is_push, remote_id)) { + MessageBuffer *msg_buf = PrepareNewMsgBuf(msg); + StoreMsgBuf(msg_buf, key, is_push, remote_id); + trans->PrepareData(msg, msg_buf); trans->SendRendezvousBegin(msg, msg_buf); return total_len; - } + } } - auto remote_tuple = GetRemoteInfo(msg.meta.key, msg.meta.push, remote_id); + auto addr_tuple = GetRemoteAndLocalInfo(msg.meta.key, msg.meta.push, remote_id); + MessageBuffer *msg_buf = std::get<3>(addr_tuple); // local message buffer + + // prepare new meta and data + CHECK_EQ(msg_buf->inline_len, meta_len); + CHECK(msg_buf->inline_buf); + msg_buf->data = msg.data; // may not need this + PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); // already know remote address, directly use RDMA-write if (msg.meta.push && msg.meta.request) { // worker, push request - trans->SendPushRequest(msg, msg_buf, remote_tuple); + trans->SendPushRequest(msg, msg_buf, addr_tuple); } else if (msg.meta.push && !msg.meta.request) { // server, push response - trans->SendPushResponse(msg, msg_buf, remote_tuple); + trans->SendPushResponse(msg, msg_buf, addr_tuple); } else if (!msg.meta.push && msg.meta.request) { // worker, pull request - trans->SendPullRequest(msg, msg_buf, remote_tuple); + trans->SendPullRequest(msg, msg_buf, addr_tuple); } else if (!msg.meta.push && !msg.meta.request) { // server, pull response - trans->SendPullResponse(msg, msg_buf, remote_tuple); + trans->SendPullResponse(msg, msg_buf, addr_tuple); } else { CHECK(0) << "unexpected message type"; } @@ -381,8 +400,6 @@ class RDMAVan : public Van { auto trans = CHECK_NOTNULL(endpoint->GetTransport()); if (!IsValidPushpull(*msg)) { - recv_mempool_->Free(buffer_ctx->buffer); - delete buffer_ctx; return total_len; } @@ -415,8 +432,7 @@ class RDMAVan : public Van { pd_ = ibv_alloc_pd(context_); CHECK(pd_) << "Failed to allocate protection domain"; - send_mempool_.reset(new SimpleMempool(pd_)); - recv_mempool_.reset(new SimpleMempool(pd_)); + mem_allocator_.reset(new MemoryAllocator(pd_)); comp_event_channel_ = ibv_create_comp_channel(context_); @@ -470,10 +486,6 @@ class RDMAVan : public Van { ReleaseWorkRequestContext(context, endpoint); break; case IBV_WC_RDMA_WRITE: { - MessageBuffer *msg_buf = - *reinterpret_cast(context->buffer->addr); - send_mempool_->Free(msg_buf->inline_buf); - delete msg_buf; ReleaseWorkRequestContext(context, endpoint); } break; case IBV_WC_RECV_RDMA_WITH_IMM: { @@ -507,7 +519,7 @@ class RDMAVan : public Van { // Before RDMA write, store the remote info so that // subsequent write does not need repeated rendezvous - StoreRemoteInfo(msg_buf, remote_addr, rkey, idx); + StoreRemoteAndLocalInfo(msg_buf, remote_addr, rkey, idx); trans->RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); } else { CHECK(0); @@ -622,8 +634,8 @@ class RDMAVan : public Van { local_mu_.unlock(); std::shared_ptr t = is_local_[remote_ctx->node] ? - std::make_shared(endpoint, recv_mempool_.get()) : - std::make_shared(endpoint, recv_mempool_.get()); + std::make_shared(endpoint, mem_allocator_.get()) : + std::make_shared(endpoint, mem_allocator_.get()); endpoint->SetTransport(t); RequestContext ctx; @@ -709,8 +721,7 @@ class RDMAVan : public Van { } AddressPool addr_pool_; - std::unique_ptr recv_mempool_; - std::unique_ptr send_mempool_; + std::unique_ptr mem_allocator_; std::unique_ptr rdma_trans_; std::unique_ptr ipc_trans_; @@ -748,11 +759,12 @@ class RDMAVan : public Van { using RemoteTensorMeta = std::unordered_map; // sender as the key std::unordered_map tensor_info_map_; // (key, sender) --> TensorInfo - // store rendezvous address std::mutex addr_mu_; - std::unordered_map push_addr_; // , - std::unordered_map pull_addr_; // , - std::unordered_map > msgbuf_cache_; // msg_buf, + // , () + std::unordered_map push_addr_; + std::unordered_map pull_addr_; + std::unordered_map > msgbuf_cache_; // msg_buf, + }; // class RDMAVan }; // namespace ps From 6c7f80bf47f1ce295bd558e40dc9c186b4289f1a Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Fri, 27 Dec 2019 13:38:09 +0800 Subject: [PATCH 68/79] add memory allocator --- src/rdma_utils.h | 125 ++++++++++------------------------------------- 1 file changed, 26 insertions(+), 99 deletions(-) diff --git a/src/rdma_utils.h b/src/rdma_utils.h index 1de7d769..e7e49731 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -100,40 +100,16 @@ static inline void ib_malloc(void** ptr, size_t size) { class MemoryAllocator { public: - explicit MemoryAllocator(struct ibv_pd *pd, size_t size = 0x10000000) { + explicit MemoryAllocator(struct ibv_pd *pd) { std::lock_guard lk(mu_); pd_ = pd; - struct ibv_mr *mr; - - pagesize_ = sysconf(_SC_PAGESIZE); - - // set init mempool size - auto byteps_rdma_mempool_size = Environment::Get()->find("BYTEPS_RDMA_MEMPOOL_SIZE"); - size = byteps_rdma_mempool_size ? atoi(byteps_rdma_mempool_size) : size; - size = align_ceil(size, pagesize_); - - auto byteps_rdma_mempool_num = Environment::Get()->find("BYTEPS_RDMA_MEMPOOL_NUM"); - size_t mempool_num = byteps_rdma_mempool_num ? atoi(byteps_rdma_mempool_num) : 1; - PS_VLOG(1) << "RDMA initial mempool size set to " << size - << ", mempool num set to " << mempool_num; - - for (size_t i = 0; i < mempool_num; ++i) { - char *p; - ib_malloc((void**) &p, size); - total_allocated_size += size; - CHECK(p); - CHECK(mr = ibv_reg_mr(pd, p, size, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)); - mr_list.emplace(p+size-1, mr); // this mr is associated with memory address range [p, p+size-1] - free_list.emplace(size, p); - } } ~MemoryAllocator() { std::lock_guard lk(mu_); - for(auto it = mr_list.begin(); it != mr_list.end(); it++) { - CHECK_EQ(ibv_dereg_mr(it->second), 0); - free(it->second->addr); + for(auto &it : mr_) { + CHECK_EQ(ibv_dereg_mr(it.second), 0); + free(it.first); } } @@ -142,74 +118,29 @@ class MemoryAllocator { return nullptr; } - std::lock_guard lk(mu_); - - // use page aligned memory - size_t proper_size = align_ceil(size, pagesize_); - - auto it = free_list.lower_bound(proper_size); - - // if there is no space left, need to allocate and register new memory - if (it == free_list.end()) { - size_t new_mem_size = total_allocated_size; - while (proper_size > new_mem_size) { - new_mem_size *= 2; - } - char *p; - ib_malloc((void**) &p, new_mem_size); - CHECK(p); - struct ibv_mr *mr; - CHECK(mr = ibv_reg_mr(pd_, p, new_mem_size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)); - mr_list.emplace(p+new_mem_size-1, mr); - free_list.emplace(new_mem_size, p); - it = free_list.lower_bound(proper_size); - PS_VLOG(1) << "Not enough memory in the pool, requested size " << proper_size << ", new allocated size " << new_mem_size; - total_allocated_size += new_mem_size; - } - - CHECK_NE(free_list.end(), it) << "Not enough memory"; - CHECK_GE(it->first, proper_size); - - char *addr = it->second; - size_t space_left = it->first - proper_size; - - free_list.erase(it); - CHECK_EQ(used_list.find(addr), used_list.end()) - << "Address is already allocated"; - - used_list.emplace(addr, proper_size); - - if (space_left) { - free_list.emplace(space_left, addr + proper_size); - } - - return addr; - } + // align to page size (usually 4KB) + size = align_ceil(size, pagesize_); - void Free(char *addr) { - if (!addr) { - return; - } + char *p; + ib_malloc((void**) &p, size); + CHECK(p); + struct ibv_mr *mr; + CHECK(mr = ibv_reg_mr(pd_, p, size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)); + std::lock_guard lk(mu_); + mr_[p] = mr; + used_list.emplace(p, size); - auto it = used_list.find(addr); - CHECK_NE(used_list.end(), it) - << "Cannot find info about address: " << (uintptr_t)addr; - - size_t size = it->second; - used_list.erase(it); - free_list.emplace(size, addr); + return p; } uint32_t LocalKey(char *addr) { - struct ibv_mr *mr = Addr2MR(addr); - return mr->lkey; + return Addr2MR(addr)->lkey; } uint32_t RemoteKey(char *addr) { - struct ibv_mr *mr = Addr2MR(addr); - return mr->rkey; + return Addr2MR(addr)->rkey; } struct ibv_pd* GetPD() { @@ -217,25 +148,21 @@ class MemoryAllocator { } private: - std::mutex mu_; - std::multimap free_list; - std::unordered_map used_list; - struct ibv_pd *pd_; - size_t total_allocated_size = 0; - size_t pagesize_; - - // first: `end` of this mr address (e.g., for mr with [addr, addr+size], point to `addr+size`) - std::map mr_list; - // convert the memory address to its associated RDMA memory region inline struct ibv_mr* Addr2MR(char *addr) { std::lock_guard lk(mu_); - auto it = mr_list.lower_bound(addr); - CHECK_NE(it, mr_list.end()) << "cannot find the associated memory region"; + auto it = mr_.find(addr); + CHECK_NE(it, mr_.end()) + << "cannot find the associated memory region"; + return it->second; } - + std::mutex mu_; + struct ibv_pd *pd_; + size_t pagesize_ = sysconf(_SC_PAGESIZE); + std::unordered_map used_list; + std::unordered_map mr_; }; enum MessageTypes : uint32_t { From 565bd4544a70074673b17a0915817446fa9faa26 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Fri, 27 Dec 2019 14:01:34 +0800 Subject: [PATCH 69/79] improve receiver memory alignment --- src/rdma_transport.h | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index efca2460..af17825b 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -335,17 +335,18 @@ class RDMATransport : public Transport { buf_ctx->meta_len = req->meta_len; buf_ctx->data_num = req->data_num; - uint64_t len = req->meta_len; + auto data_len = 0; for (size_t i = 0; i < req->data_num; ++i) { buf_ctx->data_len[i] = req->data_len[i]; - len += req->data_len[i]; + data_len += req->data_len[i]; } // worker only needs a buffer for receving meta - char *buffer = - allocator_->Alloc(is_server_ ? (kMaxMetaBound + len) : (kMaxMetaBound + req->meta_len)); + char *buffer = allocator_->Alloc( + is_server_ ? (data+len + align_ceil(req->meta_len, pagesize_)) : req->meta_len); CHECK(buffer); buf_ctx->buffer = buffer; + WRContext *reply_ctx = nullptr; endpoint_->free_reply_ctx.WaitAndPop(&reply_ctx); From 3a7437c21af5595c276b8fe41e96f7b78845cee5 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Fri, 27 Dec 2019 14:05:37 +0800 Subject: [PATCH 70/79] quick fix --- src/rdma_transport.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index af17825b..30e2fa2b 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -343,7 +343,7 @@ class RDMATransport : public Transport { // worker only needs a buffer for receving meta char *buffer = allocator_->Alloc( - is_server_ ? (data+len + align_ceil(req->meta_len, pagesize_)) : req->meta_len); + is_server_ ? (data_len + align_ceil(req->meta_len, pagesize_)) : req->meta_len); CHECK(buffer); buf_ctx->buffer = buffer; From 93d44954d18965076565e3697430e258e1c5cc27 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Fri, 27 Dec 2019 18:15:22 +0800 Subject: [PATCH 71/79] fix 1v1 hang --- src/rdma_transport.h | 21 +++++---------------- src/rdma_utils.h | 2 +- src/rdma_van.h | 11 ++++++----- 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 30e2fa2b..bb2a85d3 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -123,7 +123,7 @@ struct Endpoint { attr.sq_sig_all = 0; CHECK_EQ(rdma_create_qp(cm_id, pd, &attr), 0) - << "Create RDMA queue pair failed"; + << "Create RDMA queue pair failed: " << strerror(errno); InitSendContextHelper(pd, start_ctx, &free_start_ctx, kStartDepth, kRendezvousStartContext); @@ -265,24 +265,15 @@ class RDMATransport : public Transport { CHECK_EQ(msg_buf->mrs.size(),0); } - WRContext *reserved = nullptr; - endpoint_->free_write_ctx.WaitAndPop(&reserved); - msg_buf->reserved_context = reserved; // prepare RDMA write sge list struct ibv_sge sge; sge.addr = reinterpret_cast(msg_buf->inline_buf); sge.length = msg_buf->inline_len; sge.lkey = allocator_->LocalKey(msg_buf->inline_buf); - WRContext *write_ctx = msg_buf->reserved_context; - CHECK(write_ctx); - MessageBuffer **tmp = - reinterpret_cast(write_ctx->buffer->addr); - *tmp = msg_buf; // write the addr of msg_buf into the mr buffer - struct ibv_send_wr wr, *bad_wr = nullptr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = reinterpret_cast(write_ctx); + wr.wr_id = reinterpret_cast(msg_buf); wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; wr.next = nullptr; wr.imm_data = idx; @@ -297,11 +288,9 @@ class RDMATransport : public Transport { } void SendRendezvousBegin(Message &msg, MessageBuffer *msg_buf) { - WRContext *context = nullptr, *reserved = nullptr; - endpoint_->free_write_ctx.WaitAndPop(&reserved); + WRContext *context = nullptr; endpoint_->free_start_ctx.WaitAndPop(&context); - msg_buf->reserved_context = reserved; RendezvousStart *req = reinterpret_cast(context->buffer->addr); req->meta_len = msg_buf->inline_len; @@ -335,7 +324,7 @@ class RDMATransport : public Transport { buf_ctx->meta_len = req->meta_len; buf_ctx->data_num = req->data_num; - auto data_len = 0; + size_t data_len = 0; for (size_t i = 0; i < req->data_num; ++i) { buf_ctx->data_len[i] = req->data_len[i]; data_len += req->data_len[i]; @@ -343,7 +332,7 @@ class RDMATransport : public Transport { // worker only needs a buffer for receving meta char *buffer = allocator_->Alloc( - is_server_ ? (data_len + align_ceil(req->meta_len, pagesize_)) : req->meta_len); + is_server_ ? (align_ceil(req->meta_len, pagesize_) + data_len) : req->meta_len); CHECK(buffer); buf_ctx->buffer = buffer; diff --git a/src/rdma_utils.h b/src/rdma_utils.h index e7e49731..c21c733b 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -60,7 +60,7 @@ static const int kWriteDepth = kStartDepth * 2; static const int kRxDepth = kStartDepth + kWriteDepth; static const int kReplyDepth = kRxDepth; -static const int kSGEntry = 4; +static const int kSGEntry = 1; static const int kTimeoutms = 1000; static const int kRdmaListenBacklog = 128; static const int kMaxConcurrentWorkRequest = diff --git a/src/rdma_van.h b/src/rdma_van.h index c96f951c..c2308b40 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -474,19 +474,20 @@ class RDMAVan : public Van { << "Failed status \n" << ibv_wc_status_str(wc[i].status) << " " << wc[i].status << " " << static_cast(wc[i].wr_id) << " " << wc[i].vendor_err; - + WRContext *context = reinterpret_cast(wc[i].wr_id); Endpoint *endpoint = reinterpret_cast(context->private_data); - CHECK(endpoint); + // IBV_WC_RDMA_WRITE use msg_buf as the wr_id + // so there won't be context and endpoint for this op switch (wc[i].opcode) { - case IBV_WC_SEND: + case IBV_WC_SEND: { ReleaseWorkRequestContext(context, endpoint); - break; + } break; case IBV_WC_RDMA_WRITE: { - ReleaseWorkRequestContext(context, endpoint); + // do nothing } break; case IBV_WC_RECV_RDMA_WITH_IMM: { uint32_t addr_idx = wc[i].imm_data; From 84279d8c8ccc0da1e21ec89246066cfc1400e435 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Fri, 27 Dec 2019 21:00:27 +0800 Subject: [PATCH 72/79] add debug mode for test case --- tests/test_benchmark.cc | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/test_benchmark.cc b/tests/test_benchmark.cc index d8a9c4f3..a0324e92 100644 --- a/tests/test_benchmark.cc +++ b/tests/test_benchmark.cc @@ -6,6 +6,8 @@ #define DIVUP(x, y) (((x)+(y)-1)/(y)) #define ROUNDUP(x, y) (DIVUP((x), (y))*(y)) +#define DEBUG_PRINT_TENSOR_VALUE(X) (*((float *)(X) + 0)) +#define DEBUG_PRINT_TENSOR_ADDRESS(X) (reinterpret_cast(X)) using namespace ps; @@ -16,6 +18,7 @@ enum MODE { PULL_ONLY = 3 }; std::unordered_map > mem_map; +bool debug_mode_ = false; void aligned_memory_alloc(void** ptr, size_t size) { size_t page_size = sysconf(_SC_PAGESIZE); @@ -24,10 +27,16 @@ void aligned_memory_alloc(void** ptr, size_t size) { int ret = posix_memalign(&p, page_size, size_aligned); CHECK_EQ(ret, 0) << "posix_memalign error: " << strerror(ret); CHECK(p); - memset(p, 0, size); + memset(p, 1, size); *ptr = p; } +void float_sum(float *dst, float *src, size_t len) { + for (size_t i = 0; i < len / (size_t) sizeof(float); ++i) { + dst[i] = dst[i] + src[i]; + } +} + template void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer *server) { uint64_t key = req_data.keys[0]; @@ -36,6 +45,8 @@ void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]) << "key=" << key << ", " << req_data.vals.size() << ", " << req_data.lens[0]; + size_t tensor_len = req_data.vals.size(); + if (mem_map.find(key) == mem_map.end()) { size_t len = (size_t) req_data.vals.size(); @@ -54,6 +65,18 @@ void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer memcpy(ptr_len, &len, sizeof(int)); } + auto recved = reinterpret_cast(req_data.vals.data()); + float_sum((float*) mem_map[key].vals.data(), (float*) recved, tensor_len); + + if (debug_mode_) { + LOG(INFO) << "recved tensor! key=" << key << "\t" + << "store: " << DEBUG_PRINT_TENSOR_VALUE(mem_map[key].vals.data()) << "\t" + << "recv: " << DEBUG_PRINT_TENSOR_VALUE(recved) << "\t" + << "address: " << DEBUG_PRINT_TENSOR_ADDRESS(recved) << "\t" + << "len: " << req_data.vals.size() << "\t" + << "sender: " << req_meta.sender; + } + // send push response (empty) KVPairs res; server->Response(req_meta, res); @@ -67,6 +90,8 @@ void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer void StartServer() { if (!IsServer()) return; + debug_mode_ = Environment::Get()->find("DEBUG_MODE") ? true : false; + auto server = new KVServer(0); server->set_request_handle(EmptyHandler); RegisterExitCallback([server]() { delete server; }); From add1d10eb6b011ef5ae5960a37f8421a8835606e Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Sat, 28 Dec 2019 00:14:01 +0800 Subject: [PATCH 73/79] tests: disable sum by default --- tests/test_benchmark.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_benchmark.cc b/tests/test_benchmark.cc index a0324e92..1fafc0aa 100644 --- a/tests/test_benchmark.cc +++ b/tests/test_benchmark.cc @@ -32,6 +32,7 @@ void aligned_memory_alloc(void** ptr, size_t size) { } void float_sum(float *dst, float *src, size_t len) { + if (len == 0) return; for (size_t i = 0; i < len / (size_t) sizeof(float); ++i) { dst[i] = dst[i] + src[i]; } @@ -45,7 +46,6 @@ void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]) << "key=" << key << ", " << req_data.vals.size() << ", " << req_data.lens[0]; - size_t tensor_len = req_data.vals.size(); if (mem_map.find(key) == mem_map.end()) { size_t len = (size_t) req_data.vals.size(); @@ -66,7 +66,9 @@ void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer } auto recved = reinterpret_cast(req_data.vals.data()); - float_sum((float*) mem_map[key].vals.data(), (float*) recved, tensor_len); + // only sum the first 4 bytes + size_t sum_len = debug_mode_ ? req_data.vals.size() : 0; + float_sum((float*) mem_map[key].vals.data(), (float*) recved, sum_len); if (debug_mode_) { LOG(INFO) << "recved tensor! key=" << key << "\t" From d43f37da079baebfa1280f24482f52c2a4668ac6 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Sun, 29 Dec 2019 20:31:01 +0800 Subject: [PATCH 74/79] cleaning: remove useless write ctx --- src/rdma_transport.h | 13 +------------ src/rdma_utils.h | 2 +- src/rdma_van.h | 3 --- 3 files changed, 2 insertions(+), 16 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index bb2a85d3..517259e1 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -38,11 +38,9 @@ struct Endpoint { WRContext start_ctx[kStartDepth]; WRContext reply_ctx[kReplyDepth]; - WRContext write_ctx[kWriteDepth]; ThreadsafeQueue free_start_ctx; ThreadsafeQueue free_reply_ctx; - ThreadsafeQueue free_write_ctx; Endpoint() : status(IDLE), node_id(Node::kEmpty), cm_id(nullptr), rx_ctx() {} @@ -69,13 +67,6 @@ struct Endpoint { } } - for (int i = 0; i < kWriteDepth; ++i) { - if (write_ctx[i].buffer) { - free(write_ctx[i].buffer->addr); - CHECK_EQ(ibv_dereg_mr(write_ctx[i].buffer), 0); - } - } - rdma_destroy_qp(cm_id); CHECK_EQ(rdma_destroy_id(cm_id), 0) << strerror(errno); } @@ -115,7 +106,7 @@ struct Endpoint { memset(&attr, 0, sizeof(ibv_qp_init_attr)); attr.send_cq = cq; attr.recv_cq = cq; - attr.cap.max_send_wr = kStartDepth + kReplyDepth + kWriteDepth; + attr.cap.max_send_wr = kStartDepth + kReplyDepth; attr.cap.max_recv_wr = kRxDepth; attr.cap.max_send_sge = kSGEntry; attr.cap.max_recv_sge = kSGEntry; @@ -129,8 +120,6 @@ struct Endpoint { kRendezvousStartContext); InitSendContextHelper(pd, reply_ctx, &free_reply_ctx, kReplyDepth, kRendezvousReplyContext); - InitSendContextHelper(pd, write_ctx, &free_write_ctx, kWriteDepth, - kWriteContext); for (size_t i = 0; i < kRxDepth; ++i) { void *buf; diff --git a/src/rdma_utils.h b/src/rdma_utils.h index c21c733b..16de28d9 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -55,7 +55,7 @@ namespace ps { #define ROUNDUP(x, y) (DIVUP((x), (y))*(y)) static const int kStartDepth = 1024; -static const int kWriteDepth = kStartDepth * 2; +static const int kWriteDepth = kStartDepth; static const int kRxDepth = kStartDepth + kWriteDepth; static const int kReplyDepth = kRxDepth; diff --git a/src/rdma_van.h b/src/rdma_van.h index c2308b40..1b28de5c 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -452,9 +452,6 @@ class RDMAVan : public Van { case kRendezvousReplyContext: endpoint->free_reply_ctx.Push(context); break; - case kWriteContext: - endpoint->free_write_ctx.Push(context); - break; case kReceiveContext: endpoint->PostRecv(context); break; From 962a05cdcf8d26045a913fcf6da842dc820c32c7 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Sun, 29 Dec 2019 23:11:50 +0800 Subject: [PATCH 75/79] simplify RDMAWriteWithImm --- src/rdma_transport.h | 101 ++++++++++++++++++++++--------------------- src/rdma_utils.h | 3 -- src/rdma_van.h | 2 +- 3 files changed, 53 insertions(+), 53 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 517259e1..ab836036 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -158,26 +158,26 @@ struct Endpoint { class Transport { public: - virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; - - virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; - virtual int RecvPullRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; - virtual int RecvPushResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; - virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; + virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; - virtual void AddMeta(Message &msg) = 0; - virtual void RegisterMemory(Message &msg) = 0; - virtual void PrepareData(Message &msg, MessageBuffer *msg_buf) = 0; + virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; + virtual int RecvPullRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; + virtual int RecvPushResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; + virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; - virtual void Send(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; - virtual void SendPullRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; - virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; - virtual void SendPushResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; - virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; - virtual void SendRendezvousBegin(Message &msg, MessageBuffer *msg_buf) = 0; - virtual void SendRendezvousReply(RendezvousStart *req, AddressPool &pool) = 0; + virtual void AddMeta(Message &msg) = 0; + virtual void RegisterMemory(Message &msg) = 0; + virtual void PrepareData(Message &msg, MessageBuffer *msg_buf) = 0; - virtual SArray CreateFunctionalSarray(void *value, size_t size) = 0; + virtual void Send(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; + virtual void SendPullRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; + virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; + virtual void SendPushResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; + virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; + virtual void SendRendezvousBegin(Message &msg, MessageBuffer *msg_buf) = 0; + virtual void SendRendezvousReply(RendezvousStart *req, AddressPool &pool) = 0; + + virtual SArray CreateFunctionalSarray(void *value, size_t size) = 0; }; // class Transport @@ -226,35 +226,6 @@ class RDMATransport : public Transport { } virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { - if (msg_buf->mrs.size() == 3) { - // push request, split the meta and data into two writes - // further, it does not send keys and lens since these meta already carries these info - struct ibv_sge my_sge; - my_sge.addr = reinterpret_cast(msg_buf->mrs[1].first->addr); - my_sge.length = msg_buf->mrs[1].second; - my_sge.lkey = msg_buf->mrs[1].first->lkey; - - // this rdma-write will not trigger any signal both remotely and locally - struct ibv_send_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - wr.wr_id = 0; - wr.opcode = IBV_WR_RDMA_WRITE; - wr.next = nullptr; - wr.sg_list = &my_sge; - wr.num_sge = 1; - wr.wr.rdma.rkey = rkey; - - // write to the next page-aligned address (remote_addr should already be aligned) - wr.wr.rdma.remote_addr = remote_addr + align_ceil(msg_buf->inline_len, pagesize_); - - CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) - << "ibv_post_send failed."; - - } else { - CHECK_EQ(msg_buf->mrs.size(),0); - } - - // prepare RDMA write sge list struct ibv_sge sge; sge.addr = reinterpret_cast(msg_buf->inline_buf); sge.length = msg_buf->inline_len; @@ -384,19 +355,51 @@ class RDMATransport : public Transport { RDMAWriteWithImm(msg_buf, raddr, rkey, idx); } - void SendPushResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { - Send(msg, msg_buf, remote_tuple); + void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + CHECK_EQ(msg_buf->mrs.size(), 3); + auto raddr = std::get<0>(remote_tuple); + auto rkey = std::get<1>(remote_tuple); + auto idx = std::get<2>(remote_tuple); + + // push request, split the meta and data into two writes + // further, it does not send keys and lens since these meta already carries these info + struct ibv_sge my_sge; + my_sge.addr = reinterpret_cast(msg_buf->mrs[1].first->addr); + my_sge.length = msg_buf->mrs[1].second; + my_sge.lkey = msg_buf->mrs[1].first->lkey; + + // this rdma-write will not trigger any signal both remotely and locally + struct ibv_send_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = 0; + wr.opcode = IBV_WR_RDMA_WRITE; + wr.next = nullptr; + wr.sg_list = &my_sge; + wr.num_sge = 1; + wr.wr.rdma.rkey = rkey; + + // write to the next page-aligned address (remote_addr should already be aligned) + wr.wr.rdma.remote_addr = raddr + align_ceil(msg_buf->inline_len, pagesize_); + + CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) + << "ibv_post_send failed."; + + RDMAWriteWithImm(msg_buf, raddr, rkey, idx); } void SendPullRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + CHECK_EQ(msg_buf->mrs.size(), 0); Send(msg, msg_buf, remote_tuple); } - virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + virtual void SendPushResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + CHECK_EQ(msg_buf->mrs.size(), 0); Send(msg, msg_buf, remote_tuple); } virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + CHECK_EQ(msg_buf->mrs.size(), 0); + auto raddr = msg.meta.addr; auto rkey = msg.meta.option; diff --git a/src/rdma_utils.h b/src/rdma_utils.h index 16de28d9..5d40a8ec 100644 --- a/src/rdma_utils.h +++ b/src/rdma_utils.h @@ -71,9 +71,6 @@ static const int kMaxDataFields = 4; static const int kMaxResolveRetry = 50000; static const int kBasePort = 9010; -// allocate a whole page for meta with potentially variable length -static const int kMaxMetaBound = sysconf(_SC_PAGESIZE); - // should have the same prefix with BytePS shared memory static const std::string kShmPrefix("BytePS_ShM_"); diff --git a/src/rdma_van.h b/src/rdma_van.h index 1b28de5c..7b919766 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -353,7 +353,7 @@ class RDMAVan : public Van { MessageBuffer *msg_buf = std::get<3>(addr_tuple); // local message buffer // prepare new meta and data - CHECK_EQ(msg_buf->inline_len, meta_len); + CHECK_EQ(msg_buf->inline_len, (size_t) meta_len); CHECK(msg_buf->inline_buf); msg_buf->data = msg.data; // may not need this PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); From e551a42f2e509f007eb7beb6e8ddd5a0bccc0a13 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Mon, 30 Dec 2019 14:14:20 +0800 Subject: [PATCH 76/79] fix repeated ibv_reg_mr --- src/rdma_transport.h | 70 +++++--------------------------------------- src/rdma_van.h | 70 ++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 72 insertions(+), 68 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index ab836036..585d1a8c 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -165,15 +165,11 @@ class Transport { virtual int RecvPushResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; - virtual void AddMeta(Message &msg) = 0; - virtual void RegisterMemory(Message &msg) = 0; - virtual void PrepareData(Message &msg, MessageBuffer *msg_buf) = 0; - virtual void Send(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; virtual void SendPullRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; virtual void SendPushResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; - virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; + virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple, size_t lkey) = 0; virtual void SendRendezvousBegin(Message &msg, MessageBuffer *msg_buf) = 0; virtual void SendRendezvousReply(RendezvousStart *req, AddressPool &pool) = 0; @@ -194,36 +190,7 @@ class RDMATransport : public Transport { is_server_ = (role=="server"); }; - ~RDMATransport() { - for (auto& it : mem_mr_map_) ibv_dereg_mr(it.second); - }; - - void RegisterMemory(Message &msg) { - for (auto& sa : msg.data) { - if (sa.size() == 0) continue; - std::lock_guard lock(map_mu_); - if (mem_mr_map_.find(sa.data()) == mem_mr_map_.end()) { - struct ibv_mr *temp_mr; - CHECK (temp_mr = ibv_reg_mr(allocator_->GetPD(), sa.data(), sa.size(), - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) - << "Failed to register the memory region: " << strerror(errno) - << ", sa.size()=" << sa.size(); - mem_mr_map_[sa.data()] = temp_mr; - } - } - } - - void PrepareData(Message &msg, MessageBuffer *msg_buf) { - if (!(msg.meta.push && msg.meta.request)) return; // only push request - for (auto &sa : msg_buf->data) { - if (sa.size() == 0) continue; - std::lock_guard lock(map_mu_); - auto it = mem_mr_map_.find(sa.data()); - MRPtr ptr(it->second, [](struct ibv_mr *mr) {}); - CHECK(ptr.get()) << strerror(errno); - msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size())); - } - } + ~RDMATransport() {}; virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { struct ibv_sge sge; @@ -329,24 +296,6 @@ class RDMATransport : public Transport { << "ibv_post_send failed."; } - void AddMeta(Message &msg) { - if (msg.meta.request) { - msg.meta.key = DecodeKey(msg.data[0]); - } - if (msg.meta.push && msg.meta.request) { - // push request - CHECK_EQ(msg.data.size(), 3) << msg.data.size(); - - std::lock_guard lock(map_mu_); - CHECK_NE(mem_mr_map_.find(msg.data[1].data()), mem_mr_map_.end()); - - auto& vals = msg.data[1]; - msg.meta.addr = reinterpret_cast(vals.data()); // vals address - msg.meta.val_len = vals.size(); - msg.meta.option = mem_mr_map_[vals.data()]->rkey; - } - } - void Send(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { auto raddr = std::get<0>(remote_tuple); auto rkey = std::get<1>(remote_tuple); @@ -397,21 +346,18 @@ class RDMATransport : public Transport { Send(msg, msg_buf, remote_tuple); } - virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple, size_t lkey) { CHECK_EQ(msg_buf->mrs.size(), 0); auto raddr = msg.meta.addr; auto rkey = msg.meta.option; - - map_mu_.lock(); - auto temp_mr = mem_mr_map_.find(msg_buf->data[1].data()); - CHECK_NE(temp_mr, mem_mr_map_.end()); - map_mu_.unlock(); + auto len = msg.meta.val_len; + CHECK_EQ(msg.meta.val_len, msg_buf->data[1].size()); struct ibv_sge sge; sge.addr = reinterpret_cast(msg_buf->data[1].data()); - sge.length = msg_buf->data[1].size(); - sge.lkey = temp_mr->second->lkey; + sge.length = len; + sge.lkey = lkey; // this rdma-write will not trigger any signal both remotely and locally struct ibv_send_wr wr, *bad_wr = nullptr; @@ -496,8 +442,6 @@ class RDMATransport : public Transport { Endpoint *endpoint_; MemoryAllocator *allocator_; bool is_server_; - std::mutex map_mu_; - std::unordered_map mem_mr_map_; // (memory, ibv_mr) }; // class Transport diff --git a/src/rdma_van.h b/src/rdma_van.h index 7b919766..e351d55a 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -78,6 +78,8 @@ class RDMAVan : public Van { CHECK(!ibv_destroy_comp_channel(comp_event_channel_)) << "Failed to destroy channel"; + for (auto& it : mem_mr_) ibv_dereg_mr(it.second); + // TODO: ibv_dealloc_pd sometimes complains resource busy, need to fix this // CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD: " << // strerror(errno); @@ -310,6 +312,56 @@ class RDMAVan : public Van { return msg_buf; } + void RegisterMemory(Message &msg) { + for (auto& sa : msg.data) { + if (sa.size() == 0) continue; + std::lock_guard lock(map_mu_); + if (mem_mr_.find(sa.data()) == mem_mr_.end()) { + struct ibv_mr *temp_mr; + LOG(INFO) << "ibv_mr register size=" << sa.size() + << "\t key=" << DecodeKey(msg.data[0]) + << "\t " << (msg.meta.push ? "push" : "pull") + << " " << (msg.meta.request ? "request" : "response"); + + CHECK (temp_mr = ibv_reg_mr(mem_allocator_->GetPD(), sa.data(), sa.size(), + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) + << "Failed to register the memory region: " << strerror(errno) + << ", sa.size()=" << sa.size(); + mem_mr_[sa.data()] = temp_mr; + } + } + } + + void PrepareData(Message &msg, MessageBuffer *msg_buf) { + if (!(msg.meta.push && msg.meta.request)) return; // only push request + for (auto &sa : msg_buf->data) { + if (sa.size() == 0) continue; + std::lock_guard lock(map_mu_); + auto it = mem_mr_.find(sa.data()); + MRPtr ptr(it->second, [](struct ibv_mr *mr) {}); + CHECK(ptr.get()) << strerror(errno); + msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size())); + } + } + + void AddMeta(Message &msg) { + if (msg.meta.request) { + msg.meta.key = DecodeKey(msg.data[0]); + } + if (msg.meta.push && msg.meta.request) { + // push request + CHECK_EQ(msg.data.size(), 3) << msg.data.size(); + + std::lock_guard lock(map_mu_); + CHECK_NE(mem_mr_.find(msg.data[1].data()), mem_mr_.end()); + + auto& vals = msg.data[1]; + msg.meta.addr = reinterpret_cast(vals.data()); // vals address + msg.meta.val_len = vals.size(); + msg.meta.option = mem_mr_[vals.data()]->rkey; + } + } + int SendMsg(Message &msg) override { int remote_id = msg.meta.recver; CHECK_NE(remote_id, Meta::kEmpty); @@ -321,15 +373,16 @@ class RDMAVan : public Van { size_t total_len = meta_len + data_len; CHECK(meta_len); - auto trans = CHECK_NOTNULL(endpoint->GetTransport()); - trans->RegisterMemory(msg); + RegisterMemory(msg); // pack meta info if (IsValidPushpull(msg)) { - trans->AddMeta(msg); + AddMeta(msg); PackWorkerTensorAddress(msg); } + auto trans = CHECK_NOTNULL(endpoint->GetTransport()); + // start rendezvous if no remote info if (!IsValidPushpull(msg)) { MessageBuffer *msg_buf = PrepareNewMsgBuf(msg); @@ -343,7 +396,7 @@ class RDMAVan : public Van { if (!HasRemoteInfo(msg, key, is_push, remote_id)) { MessageBuffer *msg_buf = PrepareNewMsgBuf(msg); StoreMsgBuf(msg_buf, key, is_push, remote_id); - trans->PrepareData(msg, msg_buf); + PrepareData(msg, msg_buf); trans->SendRendezvousBegin(msg, msg_buf); return total_len; } @@ -370,7 +423,11 @@ class RDMAVan : public Van { trans->SendPullRequest(msg, msg_buf, addr_tuple); } else if (!msg.meta.push && !msg.meta.request) { // server, pull response - trans->SendPullResponse(msg, msg_buf, addr_tuple); + map_mu_.lock(); + auto temp_mr = mem_mr_.find(msg_buf->data[1].data()); + CHECK_NE(temp_mr, mem_mr_.end()); + map_mu_.unlock(); + trans->SendPullResponse(msg, msg_buf, addr_tuple, temp_mr->second->lkey); } else { CHECK(0) << "unexpected message type"; } @@ -763,6 +820,9 @@ class RDMAVan : public Van { std::unordered_map pull_addr_; std::unordered_map > msgbuf_cache_; // msg_buf, + std::mutex map_mu_; + std::unordered_map mem_mr_; // (memory address, ibv_mr) + }; // class RDMAVan }; // namespace ps From a702ca562ebfdae6dc914e27c35859ba4ab67ab1 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Mon, 30 Dec 2019 16:48:12 +0800 Subject: [PATCH 77/79] fix broadcast: send missed first message in rendez --- src/rdma_transport.h | 2 +- src/rdma_van.h | 312 ++++++++++++++++++++++++------------------- 2 files changed, 175 insertions(+), 139 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 585d1a8c..6653dd7a 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -486,7 +486,7 @@ class IPCTransport : public RDMATransport { Send(msg, msg_buf, remote_tuple); } - void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple, size_t lkey) { CHECK_EQ(msg_buf->mrs.size(), 0); auto addr = (void*) CHECK_NOTNULL(msg.data[1].data()); void* shm_addr = CHECK_NOTNULL(GetSharedMemory(kShmPrefix, msg.meta.key)); diff --git a/src/rdma_van.h b/src/rdma_van.h index e351d55a..7e83cd2c 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -225,6 +225,127 @@ class RDMAVan : public Van { } } + int SendMsg(Message &msg) override { + int remote_id = msg.meta.recver; + CHECK_NE(remote_id, Meta::kEmpty); + CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); + Endpoint *endpoint = endpoints_[remote_id].get(); + + int meta_len = GetPackMetaLen(msg.meta); + size_t data_len = msg.meta.data_size; + size_t total_len = meta_len + data_len; + CHECK(meta_len); + + RegisterMemory(msg); + + // pack meta info + if (IsValidPushpull(msg)) { + AddMeta(msg); + PackWorkerTensorAddress(msg); + } + + auto trans = CHECK_NOTNULL(endpoint->GetTransport()); + + // start rendezvous if no remote info + if (!IsValidPushpull(msg)) { + MessageBuffer *msg_buf = PrepareNewMsgBuf(msg); + StoreMsgBuf(msg_buf, msg); + trans->SendRendezvousBegin(msg, msg_buf); + return total_len; + + } else { + auto is_push = msg.meta.push; + auto key = msg.meta.key; + if (!HasRemoteInfo(msg, key, is_push, remote_id)) { + MessageBuffer *msg_buf = PrepareNewMsgBuf(msg); + StoreMsgBuf(msg_buf, msg); + PrepareData(msg, msg_buf); + trans->SendRendezvousBegin(msg, msg_buf); + return total_len; + } + } + + auto addr_tuple = GetRemoteAndLocalInfo(msg.meta.key, msg.meta.push, remote_id); + MessageBuffer *msg_buf = std::get<3>(addr_tuple); // local message buffer + + // prepare new meta and data + CHECK_EQ(msg_buf->inline_len, (size_t) meta_len); + CHECK(msg_buf->inline_buf); + msg_buf->data = msg.data; // may not need this + PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); + + // already know remote address, directly use RDMA-write + if (msg.meta.push && msg.meta.request) { + // worker, push request + trans->SendPushRequest(msg, msg_buf, addr_tuple); + } else if (msg.meta.push && !msg.meta.request) { + // server, push response + trans->SendPushResponse(msg, msg_buf, addr_tuple); + } else if (!msg.meta.push && msg.meta.request) { + // worker, pull request + trans->SendPullRequest(msg, msg_buf, addr_tuple); + } else if (!msg.meta.push && !msg.meta.request) { + // server, pull response + map_mu_.lock(); + auto temp_mr = mem_mr_.find(msg_buf->data[1].data()); + CHECK_NE(temp_mr, mem_mr_.end()); + map_mu_.unlock(); + trans->SendPullResponse(msg, msg_buf, addr_tuple, temp_mr->second->lkey); + } else { + CHECK(0) << "unexpected message type"; + } + + return total_len; + } + + int RecvMsg(Message *msg) override { + msg->data.clear(); + std::tuple notification; + recv_buffers_.WaitAndPop(¬ification); + + Endpoint *endpoint = std::get(notification); + BufferContext *buffer_ctx = std::get(notification); + + msg->meta.recver = my_node_.id; + msg->meta.sender = endpoint->node_id; + + // the second argument is actually deprecated, + // we keep it as is in order to be compatible + UnpackMeta(buffer_ctx->buffer, buffer_ctx->meta_len, &msg->meta); + int meta_len = GetPackMetaLen(msg->meta); + + int total_len = 0; + total_len += meta_len; + + auto trans = CHECK_NOTNULL(endpoint->GetTransport()); + + if (!IsValidPushpull(*msg)) { + return total_len; + } + + // valid data message + if (msg->meta.push && msg->meta.request) { + // push request + total_len += trans->RecvPushRequest(msg, buffer_ctx, meta_len); + StoreWorkerTensorAddress(msg); + } else if (!msg->meta.push && msg->meta.request) { + // pull request + total_len += trans->RecvPullRequest(msg, buffer_ctx, meta_len); + } else if (msg->meta.push && !msg->meta.request) { + // push response + total_len += trans->RecvPushResponse(msg, buffer_ctx, meta_len); + } else if (!msg->meta.push && !msg->meta.request) { + // pull response + total_len += trans->RecvPullResponse(msg, buffer_ctx, meta_len); + } else { + CHECK(0) << "unknown msg type"; + } + + return total_len; + } + + private: + void PackWorkerTensorAddress(Message &msg) { // must be pull response if (msg.meta.push || msg.meta.request) return; @@ -233,10 +354,8 @@ class RDMAVan : public Van { auto recver = msg.meta.recver; std::lock_guard lock(info_mu_); - CHECK_NE(tensor_info_map_.find(key), tensor_info_map_.end()) - << "key=" << key << " not inited in tensor_info_map_"; - CHECK_NE(tensor_info_map_[key].find(recver), tensor_info_map_[key].end()) - << "key=" << key << ", recver=" << recver << " not inited in tensor_info_map_[key]"; + CHECK_NE(tensor_info_map_.find(key), tensor_info_map_.end()); + CHECK_NE(tensor_info_map_[key].find(recver), tensor_info_map_[key].end()); msg.meta.val_len = std::get<0>(tensor_info_map_[key][recver]); msg.meta.addr = std::get<1>(tensor_info_map_[key][recver]); msg.meta.option = std::get<2>(tensor_info_map_[key][recver]); @@ -274,20 +393,34 @@ class RDMAVan : public Van { return false; } - void StoreMsgBuf(MessageBuffer *msg_buf, uint64_t key, bool is_push, int recver) { + void StoreMsgBuf(MessageBuffer *msg_buf, Message& msg) { std::lock_guard lk(addr_mu_); CHECK_EQ(msgbuf_cache_.find(msg_buf), msgbuf_cache_.end()); - msgbuf_cache_[msg_buf] = std::make_tuple(key, is_push, recver); - } + msgbuf_cache_[msg_buf] = msg; + } + + Message* GetFirstMsg(MessageBuffer *msg_buf) { + std::lock_guard lk(addr_mu_); + CHECK_NE(msgbuf_cache_.find(msg_buf), msgbuf_cache_.end()); + return &msgbuf_cache_[msg_buf]; + } + + void ReleaseFirstMsg(MessageBuffer *msg_buf) { + std::lock_guard lk(addr_mu_); + CHECK_NE(msgbuf_cache_.find(msg_buf), msgbuf_cache_.end()); + msgbuf_cache_.erase(msg_buf); + } void StoreRemoteAndLocalInfo(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { std::lock_guard lk(addr_mu_); CHECK_NE(msgbuf_cache_.find(msg_buf), msgbuf_cache_.end()); + + auto& msg = msgbuf_cache_[msg_buf]; - auto key = std::get<0>(msgbuf_cache_[msg_buf]); - auto is_push = std::get<1>(msgbuf_cache_[msg_buf]); - auto recver = std::get<2>(msgbuf_cache_[msg_buf]); + auto key = msg.meta.key; + auto is_push = msg.meta.push; + auto recver = msg.meta.recver; auto t = std::make_tuple(remote_addr, rkey, idx, msg_buf); if (is_push) { @@ -318,11 +451,6 @@ class RDMAVan : public Van { std::lock_guard lock(map_mu_); if (mem_mr_.find(sa.data()) == mem_mr_.end()) { struct ibv_mr *temp_mr; - LOG(INFO) << "ibv_mr register size=" << sa.size() - << "\t key=" << DecodeKey(msg.data[0]) - << "\t " << (msg.meta.push ? "push" : "pull") - << " " << (msg.meta.request ? "request" : "response"); - CHECK (temp_mr = ibv_reg_mr(mem_allocator_->GetPD(), sa.data(), sa.size(), IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) << "Failed to register the memory region: " << strerror(errno) @@ -362,126 +490,6 @@ class RDMAVan : public Van { } } - int SendMsg(Message &msg) override { - int remote_id = msg.meta.recver; - CHECK_NE(remote_id, Meta::kEmpty); - CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); - Endpoint *endpoint = endpoints_[remote_id].get(); - - int meta_len = GetPackMetaLen(msg.meta); - size_t data_len = msg.meta.data_size; - size_t total_len = meta_len + data_len; - CHECK(meta_len); - - RegisterMemory(msg); - - // pack meta info - if (IsValidPushpull(msg)) { - AddMeta(msg); - PackWorkerTensorAddress(msg); - } - - auto trans = CHECK_NOTNULL(endpoint->GetTransport()); - - // start rendezvous if no remote info - if (!IsValidPushpull(msg)) { - MessageBuffer *msg_buf = PrepareNewMsgBuf(msg); - StoreMsgBuf(msg_buf, 0, 0, -1); - trans->SendRendezvousBegin(msg, msg_buf); - return total_len; - - } else { - auto is_push = msg.meta.push; - auto key = msg.meta.key; - if (!HasRemoteInfo(msg, key, is_push, remote_id)) { - MessageBuffer *msg_buf = PrepareNewMsgBuf(msg); - StoreMsgBuf(msg_buf, key, is_push, remote_id); - PrepareData(msg, msg_buf); - trans->SendRendezvousBegin(msg, msg_buf); - return total_len; - } - } - - auto addr_tuple = GetRemoteAndLocalInfo(msg.meta.key, msg.meta.push, remote_id); - MessageBuffer *msg_buf = std::get<3>(addr_tuple); // local message buffer - - // prepare new meta and data - CHECK_EQ(msg_buf->inline_len, (size_t) meta_len); - CHECK(msg_buf->inline_buf); - msg_buf->data = msg.data; // may not need this - PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); - - // already know remote address, directly use RDMA-write - if (msg.meta.push && msg.meta.request) { - // worker, push request - trans->SendPushRequest(msg, msg_buf, addr_tuple); - } else if (msg.meta.push && !msg.meta.request) { - // server, push response - trans->SendPushResponse(msg, msg_buf, addr_tuple); - } else if (!msg.meta.push && msg.meta.request) { - // worker, pull request - trans->SendPullRequest(msg, msg_buf, addr_tuple); - } else if (!msg.meta.push && !msg.meta.request) { - // server, pull response - map_mu_.lock(); - auto temp_mr = mem_mr_.find(msg_buf->data[1].data()); - CHECK_NE(temp_mr, mem_mr_.end()); - map_mu_.unlock(); - trans->SendPullResponse(msg, msg_buf, addr_tuple, temp_mr->second->lkey); - } else { - CHECK(0) << "unexpected message type"; - } - - return total_len; - } - - int RecvMsg(Message *msg) override { - msg->data.clear(); - std::tuple notification; - recv_buffers_.WaitAndPop(¬ification); - - Endpoint *endpoint = std::get(notification); - BufferContext *buffer_ctx = std::get(notification); - - msg->meta.recver = my_node_.id; - msg->meta.sender = endpoint->node_id; - - // the second argument is actually deprecated, - // we keep it as is in order to be compatible - UnpackMeta(buffer_ctx->buffer, buffer_ctx->meta_len, &msg->meta); - int meta_len = GetPackMetaLen(msg->meta); - - int total_len = 0; - total_len += meta_len; - - auto trans = CHECK_NOTNULL(endpoint->GetTransport()); - - if (!IsValidPushpull(*msg)) { - return total_len; - } - - // valid data message - if (msg->meta.push && msg->meta.request) { - // push request - total_len += trans->RecvPushRequest(msg, buffer_ctx, meta_len); - StoreWorkerTensorAddress(msg); - } else if (!msg->meta.push && msg->meta.request) { - // pull request - total_len += trans->RecvPullRequest(msg, buffer_ctx, meta_len); - } else if (msg->meta.push && !msg->meta.request) { - // push response - total_len += trans->RecvPushResponse(msg, buffer_ctx, meta_len); - } else if (!msg->meta.push && !msg->meta.request) { - // pull response - total_len += trans->RecvPullResponse(msg, buffer_ctx, meta_len); - } else { - CHECK(0) << "unknown msg type"; - } - - return total_len; - } - - private: void InitContext(struct ibv_context *context) { context_ = context; CHECK(context_) << "ibv_context* empty"; @@ -561,7 +569,6 @@ class RDMAVan : public Van { trans->SendRendezvousReply(req, addr_pool_); } else if (imm == kRendezvousReply) { - auto trans = CHECK_NOTNULL(endpoint->GetTransport()); RendezvousReply *resp = reinterpret_cast(mr->addr); uint64_t remote_addr = resp->addr; @@ -575,7 +582,36 @@ class RDMAVan : public Van { // Before RDMA write, store the remote info so that // subsequent write does not need repeated rendezvous StoreRemoteAndLocalInfo(msg_buf, remote_addr, rkey, idx); - trans->RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); + + Message *msg = GetFirstMsg(msg_buf); + + auto addr_tuple = GetRemoteAndLocalInfo(msg->meta.key, msg->meta.push, msg->meta.recver); + + auto trans = CHECK_NOTNULL(endpoint->GetTransport()); + if (!IsValidPushpull(*msg)) { + // control message + trans->RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); + } else if (msg->meta.push && msg->meta.request) { + // worker, push request + trans->SendPushRequest(*msg, msg_buf, addr_tuple); + } else if (msg->meta.push && !msg->meta.request) { + // server, push response + trans->SendPushResponse(*msg, msg_buf, addr_tuple); + } else if (!msg->meta.push && msg->meta.request) { + // worker, pull request + trans->SendPullRequest(*msg, msg_buf, addr_tuple); + } else if (!msg->meta.push && !msg->meta.request) { + // server, pull response + map_mu_.lock(); + auto temp_mr = mem_mr_.find(msg_buf->data[1].data()); + CHECK_NE(temp_mr, mem_mr_.end()); + map_mu_.unlock(); + trans->SendPullResponse(*msg, msg_buf, addr_tuple, temp_mr->second->lkey); + } + + // release the msg_buf from msgbuf_cache_ + ReleaseFirstMsg(msg_buf); + } else { CHECK(0); } @@ -818,7 +854,7 @@ class RDMAVan : public Van { // , () std::unordered_map push_addr_; std::unordered_map pull_addr_; - std::unordered_map > msgbuf_cache_; // msg_buf, + std::unordered_map msgbuf_cache_; // msg_buf, msg std::mutex map_mu_; std::unordered_map mem_mr_; // (memory address, ibv_mr) From 7d1ec4932fe93e5cf591bcf81a86beadea969147 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Mon, 30 Dec 2019 19:20:55 +0800 Subject: [PATCH 78/79] log: add send/recv log --- src/rdma_transport.h | 2 +- src/rdma_van.h | 95 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 95 insertions(+), 2 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 6653dd7a..44a802f9 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -352,7 +352,7 @@ class RDMATransport : public Transport { auto raddr = msg.meta.addr; auto rkey = msg.meta.option; auto len = msg.meta.val_len; - CHECK_EQ(msg.meta.val_len, msg_buf->data[1].size()); + CHECK_EQ((size_t) msg.meta.val_len, msg_buf->data[1].size()); struct ibv_sge sge; sge.addr = reinterpret_cast(msg_buf->data[1].data()); diff --git a/src/rdma_van.h b/src/rdma_van.h index 7e83cd2c..1854e78c 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -47,6 +47,11 @@ class RDMAVan : public Van { new std::thread(&RDMAVan::PollEvents, this)); } + // enable logging + val = Environment::Get()->find("BYTEPS_PRINT_RDMA_LOG"); + enable_log_ = val ? atoi(val) : false; + if (enable_log_) LOG(INFO) << "Enable RDMA logging."; + start_mu_.unlock(); Van::Start(customer_id); } @@ -123,7 +128,7 @@ class RDMAVan : public Van { } void Connect(const Node &node) override { - PS_VLOG(1) << "Connecting to Node " << node.id; + PS_VLOG(1) << "Connecting to Node " << node.id << ", My_Node=" << my_node_.id; CHECK_NE(node.id, node.kEmpty); CHECK_NE(node.port, node.kEmpty); CHECK(node.hostname.size()); @@ -274,6 +279,8 @@ class RDMAVan : public Van { msg_buf->data = msg.data; // may not need this PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); + PrintSendLog(msg, msg_buf, addr_tuple); + // already know remote address, directly use RDMA-write if (msg.meta.push && msg.meta.request) { // worker, push request @@ -319,6 +326,8 @@ class RDMAVan : public Van { auto trans = CHECK_NOTNULL(endpoint->GetTransport()); + PrintRecvLog(msg, buffer_ctx, meta_len); + if (!IsValidPushpull(*msg)) { return total_len; } @@ -346,6 +355,84 @@ class RDMAVan : public Van { private: + void PrintSendLog(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + if (!enable_log_) return; + std::lock_guard lock(log_mu_); + + if (!IsValidPushpull(msg)) { + LOG(INFO) << "Send Control Message" << std::flush; + } else if (msg.meta.push && msg.meta.request) { + // worker, push request + LOG(INFO) << "Send Push Request: key=" << msg.meta.key + << "\t timestamp=" << msg.meta.timestamp + << "\t recver=" << msg.meta.recver + << "\t tensor_len=" << msg_buf->mrs[1].first->lkey + << "\t remote_idx=" << std::get<2>(remote_tuple) + << "\t remote_addr=" << std::get<0>(remote_tuple) + << std::flush; + } else if (msg.meta.push && !msg.meta.request) { + // server, push response + LOG(INFO) << "Send Push Response: key=" << msg.meta.key + << "\t timestamp=" << msg.meta.timestamp + << "\t recver=" << msg.meta.recver + << "\t remote_idx=" << std::get<2>(remote_tuple) + << "\t remote_addr=" << std::get<0>(remote_tuple) + << std::flush; + } else if (!msg.meta.push && msg.meta.request) { + // worker, pull request + LOG(INFO) << "Send Pull Request: key=" << msg.meta.key + << "\t timestamp=" << msg.meta.timestamp + << "\t recver=" << msg.meta.recver + << "\t remote_idx=" << std::get<2>(remote_tuple) + << "\t remote_addr=" << std::get<0>(remote_tuple) + << std::flush; + } else if (!msg.meta.push && !msg.meta.request) { + // server, pull response + LOG(INFO) << "Send Pull Response: key=" << msg.meta.key + << "\t timestamp=" << msg.meta.timestamp + << "\t recver=" << msg.meta.recver + << "\t tensor_len=" << msg.meta.val_len + << "\t idx=" << "none" + << "\t remote_addr=" << msg.meta.addr + << std::flush; + } + + } + + void PrintRecvLog(Message *msg, BufferContext *buffer_ctx, int meta_len) { + if (!enable_log_) return; + std::lock_guard lock(log_mu_); + + if (!IsValidPushpull(*msg)) { + LOG(INFO) << "Recv Control Message" << std::flush; + } else if (msg->meta.push && msg->meta.request) { + // push request + LOG(INFO) << "Recv Push Request: key=" << msg->meta.key + << "\t timestamp=" << msg->meta.timestamp + << "\t sender=" << msg->meta.sender + << "\t tensor_len=" << buffer_ctx->data_len[1] + << std::flush; + } else if (!msg->meta.push && msg->meta.request) { + // pull request + LOG(INFO) << "Recv Pull Request: key=" << msg->meta.key + << "\t timestamp=" << msg->meta.timestamp + << "\t sender=" << msg->meta.sender + << std::flush; + } else if (msg->meta.push && !msg->meta.request) { + // push response + LOG(INFO) << "Recv Push Response: key=" << msg->meta.key + << "\t timestamp=" << msg->meta.timestamp + << "\t sender=" << msg->meta.sender + << std::flush; + } else if (!msg->meta.push && !msg->meta.request) { + // pull response + LOG(INFO) << "Recv Pull Response: key=" << msg->meta.key + << "\t timestamp=" << msg->meta.timestamp + << "\t sender=" << msg->meta.sender + << "\t tensor_len=" << msg->meta.val_len; + } + } + void PackWorkerTensorAddress(Message &msg) { // must be pull response if (msg.meta.push || msg.meta.request) return; @@ -587,6 +674,8 @@ class RDMAVan : public Van { auto addr_tuple = GetRemoteAndLocalInfo(msg->meta.key, msg->meta.push, msg->meta.recver); + PrintSendLog(*msg, msg_buf, addr_tuple); + auto trans = CHECK_NOTNULL(endpoint->GetTransport()); if (!IsValidPushpull(*msg)) { // control message @@ -859,6 +948,10 @@ class RDMAVan : public Van { std::mutex map_mu_; std::unordered_map mem_mr_; // (memory address, ibv_mr) + // logging + bool enable_log_; + std::mutex log_mu_; + }; // class RDMAVan }; // namespace ps From f7a704ab5082e6be897b6a039e4b0d2c89f12435 Mon Sep 17 00:00:00 2001 From: jiangyimin Date: Tue, 31 Dec 2019 12:32:45 +0800 Subject: [PATCH 79/79] fix log --- src/rdma_transport.h | 7 ++----- src/rdma_van.h | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/rdma_transport.h b/src/rdma_transport.h index 44a802f9..fec1c70d 100644 --- a/src/rdma_transport.h +++ b/src/rdma_transport.h @@ -481,13 +481,10 @@ class IPCTransport : public RDMATransport { } void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { - msg_buf->mrs.clear(); // avoid rdma-write in RDMAWriteWithImm() - CHECK_EQ(msg_buf->mrs.size(), 0); Send(msg, msg_buf, remote_tuple); } void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple, size_t lkey) { - CHECK_EQ(msg_buf->mrs.size(), 0); auto addr = (void*) CHECK_NOTNULL(msg.data[1].data()); void* shm_addr = CHECK_NOTNULL(GetSharedMemory(kShmPrefix, msg.meta.key)); @@ -504,7 +501,6 @@ class IPCTransport : public RDMATransport { } int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) { - CHECK(msg->meta.push && msg->meta.request); // get data message from local shared memory auto key = msg->meta.key; auto len = msg->meta.val_len; @@ -567,7 +563,8 @@ class IPCTransport : public RDMATransport { std::string shm_name(prefix); shm_name += std::to_string(base_key); int shm_fd = shm_open(shm_name.c_str(), O_RDWR, 0666); - CHECK_GE(shm_fd, 0) << "shm_open failed for " << shm_name; + CHECK_GE(shm_fd, 0) << "shm_open failed for " << shm_name + << ", " << strerror(errno); struct stat sb; CHECK_EQ(0, fstat(shm_fd, &sb)) << strerror(errno); diff --git a/src/rdma_van.h b/src/rdma_van.h index 1854e78c..945b205f 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -366,7 +366,7 @@ class RDMAVan : public Van { LOG(INFO) << "Send Push Request: key=" << msg.meta.key << "\t timestamp=" << msg.meta.timestamp << "\t recver=" << msg.meta.recver - << "\t tensor_len=" << msg_buf->mrs[1].first->lkey + << "\t tensor_len=" << msg_buf->mrs[1].second << "\t remote_idx=" << std::get<2>(remote_tuple) << "\t remote_addr=" << std::get<0>(remote_tuple) << std::flush;