From c2bdd75c5142f1146003b5640ceaf297aeba2b32 Mon Sep 17 00:00:00 2001
From: Yimin Jiang
Date: Wed, 11 Dec 2019 12:55:15 +0800
Subject: [PATCH] rdma: add IPC support (#13)

* basic worker->server
* basic server->worker
* finish GetSharedMemory()
* bugfix for recvmsg
* fix reserved_context
* fix repeated release
* can run 1v1 with very large partition bytes
* improve env check
* fix GetSharedMemory and clean
* fix seg fault
* add async copy
* fix compile
* join threads when shutdown
* quick fix
* add ipc benchmark
* add ipc benchmark again
* fix ipc benchmark
* fix 2v2
---
 Makefile                           |   2 +-
 src/rdma_van.h                     | 445 +++++++++++++++++++++--------
 tests/local_multi_workers.sh       |   5 +-
 tests/test_kv_app_ipc_benchmark.cc | 246 ++++++++++++++++
 4 files changed, 573 insertions(+), 125 deletions(-)
 mode change 100644 => 100755 src/rdma_van.h
 create mode 100644 tests/test_kv_app_ipc_benchmark.cc

diff --git a/Makefile b/Makefile
index 7bcfc69a8..3dc9b5661 100644
--- a/Makefile
+++ b/Makefile
@@ -23,7 +23,7 @@ CFLAGS = -std=c++14 -msse2 -fPIC -O3 -ggdb -Wall -finline-functions $(INCPATH) $
 LIBS = -pthread
 
 ifeq ($(USE_RDMA), 1)
-LIBS += -lrdmacm -libverbs
+LIBS += -lrdmacm -libverbs -lrt
 CFLAGS += -DDMLC_USE_RDMA
 endif
 
diff --git a/src/rdma_van.h b/src/rdma_van.h
old mode 100644
new mode 100755
index 1d3def530..9c63c195b
--- a/src/rdma_van.h
+++ b/src/rdma_van.h
@@ -11,6 +11,11 @@
 #include
 #include
+#include <fcntl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
 #include
 #include
 #include
@@ -40,7 +45,7 @@ namespace ps {
 
 static const int kStartDepth = 128;
 static const int kWriteDepth = kStartDepth;
-static const int kRxDepth = kStartDepth * 2;
+static const int kRxDepth = kStartDepth + kWriteDepth;
 static const int kReplyDepth = kRxDepth;
 
 static const int kSGEntry = 4;
@@ -55,6 +60,8 @@
 static const size_t kAlignment = 8;
 
 static const int kMaxResolveRetry = 50000;
 static const int kBasePort = 9010;
+// must use the same prefix as the BytePS shared memory
+static const std::string kShmPrefix("BytePS_ShM_");
 
 template <typename T>
 static inline T align_floor(T v, T align) {
@@ -291,7 +298,7 @@ static_assert(std::is_pod<RequestContext>::value,
               "RequestContext must be a POD type.");
 
 static const size_t kMempoolChunkSize =
-    std::max(sizeof(RendezvousStart), sizeof(RendezvousReply));
+    std::max({sizeof(RendezvousStart), sizeof(RendezvousReply)});
 
 template <typename T>
 class AddressPool {
@@ -472,6 +479,16 @@ struct Endpoint {
   }
 };
 
+struct AsyncCopy {
+  Endpoint* endpoint;
+  MessageBuffer* msg_buf;
+  void* dst;
+  void* src;
+  int len;
+  uint64_t meta_len;
+  bool shutdown;
+};
+
 class RDMAVan : public Van {
  public:
   RDMAVan() {
   }
   ~RDMAVan() {}
 
- protected:
+ protected:
   void Start(int customer_id) override {
     start_mu_.lock();
     should_stop_ = false;
 
     auto val = Environment::Get()->find("DMLC_ROLE");
     std::string role(val);
-    is_server = role=="server";
-    if (is_server) LOG(INFO) << "This is server";
+    is_server_ = role=="server";
+    if (is_server_) LOG(INFO) << "This is server";
     else LOG(INFO) << "This is " << ((role=="worker") ? "worker" : "scheduler");
 
     val = Environment::Get()->find("ENABLE_RDMA_LOG");
     enable_rdma_log_ = val? atoi(val) : false;
     if (enable_rdma_log_) LOG(INFO) << "Enable RDMA logging";
-    else LOG(INFO) << "RDMA logging is disabled, you can enable it with ENABLE_RDMA_LOG=1";
+    else LOG(INFO) << "You can enable RDMA logging with ENABLE_RDMA_LOG=1";
+
+    val = Environment::Get()->find("BYTEPS_ENABLE_IPC");
+    disable_ipc_ = val ? !atoi(val) : true;
+    if (disable_ipc_) LOG(INFO) << "Shared memory IPC has been disabled";
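+
+    // The partition size below must match the partitioning done by BytePS
+    // core, since shared-memory offsets are computed as multiples of it.
+    // It is rounded down to a multiple of (8 * local_size), presumably so
+    // that every local worker gets an identically aligned slice.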
+
+    val = Environment::Get()->find("BYTEPS_PARTITION_BYTES");
+    byteps_partition_bytes_ = val ? atoi(val) : 4096000;
+
+    val = Environment::Get()->find("BYTEPS_LOCAL_SIZE");
+    auto byteps_local_size = val ? atoi(val) : 1;
+    byteps_partition_bytes_ = AlignTo(byteps_partition_bytes_, (8 * byteps_local_size));
+    if (!disable_ipc_) {
+      CHECK(val) << "BYTEPS_LOCAL_SIZE not set";
+      LOG(INFO) << "partition bytes set to " << byteps_partition_bytes_
+                << ", must be identical to BytePS core";
+    }
+
+    val = Environment::Get()->find("BYTEPS_IPC_COPY_NUM_THREADS");
+    ipc_copy_nthreads_ = val ? atoi(val) : 4;
+    if (!disable_ipc_) {
+      LOG(INFO) << "IPC async copy nthreads set to " << ipc_copy_nthreads_;
+      for (int i = 0; i < ipc_copy_nthreads_; ++i) {
+        auto q = new ThreadsafeQueue<AsyncCopy>;
+        async_copy_queue_.push_back(q);
+      }
+      for (int i = 0; i < ipc_copy_nthreads_; ++i) {
+        auto t = new std::thread(&RDMAVan::AsyncCopyThread, this, i);
+        ipc_copy_thread_list_.push_back(t);
+      }
+    }
 
     if (event_channel_ == nullptr) {
       event_channel_ = rdma_create_event_channel();
       CHECK(event_channel_) << "Create RDMA event channel failed";
@@ -539,6 +585,13 @@ class RDMAVan : public Van {
     CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ";
     CHECK(!ibv_destroy_comp_channel(comp_event_channel_))
         << "Failed to destroy channel";
+
+    for (size_t i = 0; i < ipc_copy_thread_list_.size(); ++i) {
+      AsyncCopy m;
+      m.shutdown = true;
+      async_copy_queue_[i]->Push(m);
+      ipc_copy_thread_list_[i]->join();
+    }
 
     // TODO: ibv_dealloc_pd sometimes complains resource busy, need to fix this
     // CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD: " <<
@@ -593,7 +646,14 @@ class RDMAVan : public Van {
       return;
     }
 
-    std::string node_host_ip = node.hostname + ":" + std::to_string(node.port);
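+    // Decide up front whether this peer is a shared-memory IPC candidate:
+    // only peers reporting the same hostname as this node qualify; all
+    // other traffic keeps using RDMA.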
+    if (disable_ipc_) {
+      is_local_[node.id] = false;
+    } else {
+      std::lock_guard<std::mutex> lock(local_mu_);
+      is_local_[node.id] = (node.hostname == my_node_.hostname) ? true : false;
+      LOG(INFO) << "IPC connected to " << node.id;
+    }
 
     if (node.id != Node::kEmpty) {
       auto it = endpoints_.find(node.id);
@@ -687,6 +747,91 @@ class RDMAVan : public Van {
     return key;
   }
 
+  uint64_t DecodeWorkerKey(uint64_t key) {
+    auto kr = ps::Postoffice::Get()->GetServerKeyRanges()[ps::Postoffice::Get()->my_rank()];
+    return key - kr.begin();
+  }
+
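+  // Assumed key layout, mirroring the BytePS core convention: the low 16
+  // bits of a worker key are the partition sequence number, the remaining
+  // bits identify the tensor. Worked example: with byteps_partition_bytes_
+  // = 4096000, partition seq_num=3 of base_key B lives at byte offset
+  // 3 * 4096000 inside the shm segment named "BytePS_ShM_" + B.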
<< ", rkey=" << msg.meta.option; } } } if (!msg.meta.push && !msg.meta.request) { // server, pull response - CHECK(is_server); + CHECK(is_server_); CHECK_EQ(msg.data.size(), 3) << msg.data.size(); std::lock_guard lock(map_mu_); @@ -760,7 +906,6 @@ class RDMAVan : public Van { PBMeta meta; PackMetaPB(msg.meta, &meta); - CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); Endpoint *endpoint = endpoints_[remote_id].get(); MessageBuffer *msg_buf = new MessageBuffer(); @@ -768,9 +913,9 @@ class RDMAVan : public Van { size_t meta_len = meta.ByteSize(); size_t data_len = msg.meta.data_size; size_t total_len = meta_len + data_len; - CHECK(meta_len); + // prepare memory if (msg.meta.simple_app || !msg.meta.control.empty()){ // simple_app or control message msg_buf->inline_len = total_len; msg_buf->inline_buf = mempool_->Alloc(total_len); @@ -786,7 +931,7 @@ class RDMAVan : public Van { msg_buf->inline_buf = mempool_->Alloc(meta_len); msg_buf->data = msg.data; meta.SerializeToArray(msg_buf->inline_buf, meta_len); - if (!is_server) { // worker remains the same + if (!is_server_ && !is_local_[remote_id]) { // worker, send to non-local servers for (auto &sa : msg_buf->data) { if (sa.size()) { auto search_map_iterator = memory_mr_map.find(sa.data()); @@ -799,11 +944,13 @@ class RDMAVan : public Van { } } - if (is_server && IsValidPushpull(msg) && - !msg.meta.push && !msg.meta.request) { // server send pull response (vals) with RDMA-write + // server send pull response (vals) with RDMA-write / IPC + if (is_server_ && IsValidPushpull(msg) && + !msg.meta.push && !msg.meta.request) { std::lock_guard lock(map_mu_); auto key = msg.meta.key; auto recver = msg.meta.recver; + auto len = std::get<0>(key_meta_map_[key][recver]); CHECK_EQ(msg_buf->data.size(), 3) << "Actual msg_buf size is " << msg_buf->data.size(); CHECK_NE(key_meta_map_.find(key), key_meta_map_.end()) @@ -812,75 +959,77 @@ class RDMAVan : public Van { << "key=" << key << ", recver=" << recver << " not initiated"; - - auto len = std::get<0>(key_meta_map_[key][recver]); - auto raddr = std::get<1>(key_meta_map_[key][recver]); - auto rkey = std::get<2>(key_meta_map_[key][recver]); - CHECK_EQ(msg_buf->data[1].size(), (unsigned int) len) - << msg_buf->data[1].size() << ", " << len; - - auto temp_mr = memory_mr_map.find(msg_buf->data[1].data()); - CHECK_NE(temp_mr, memory_mr_map.end()); - - struct ibv_sge sge; - sge.addr = reinterpret_cast(msg_buf->data[1].data()); - sge.length = msg_buf->data[1].size(); - sge.lkey = temp_mr->second->lkey; - - struct ibv_send_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - - wr.wr_id = reinterpret_cast(raddr); - wr.opcode = IBV_WR_RDMA_WRITE; - wr.next = nullptr; - // wr.send_flags = IBV_SEND_SIGNALED; - wr.sg_list = &sge; - wr.num_sge = 1; - - wr.wr.rdma.remote_addr = raddr; - wr.wr.rdma.rkey = rkey; - - CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) - << "ibv_post_send failed."; + << msg_buf->data[1].size() << ", " << len; + + if (is_local_[remote_id]) { + // IPC + auto addr = (void*) msg_buf->data[1].data(); + CHECK(addr); + void* shm_addr = GetSharedMemory(kShmPrefix, key); + + // async copy + AsyncCopy m = {endpoint, msg_buf, shm_addr, addr, len, meta_len, false}; + auto cnt = cpy_counter_.fetch_add(1); + async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); + return total_len; + } else { + // RDMA write + auto raddr = std::get<1>(key_meta_map_[key][recver]); + auto rkey = std::get<2>(key_meta_map_[key][recver]); + + auto temp_mr = memory_mr_map.find(msg_buf->data[1].data()); + 
+     if (is_local_[remote_id]) {
+       // IPC
+       auto addr = (void*) msg_buf->data[1].data();
+       CHECK(addr);
+       void* shm_addr = GetSharedMemory(kShmPrefix, key);
+
+       // async copy
+       AsyncCopy m = {endpoint, msg_buf, shm_addr, addr, len, meta_len, false};
+       auto cnt = cpy_counter_.fetch_add(1);
+       async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m);
+       return total_len;
+     } else {
+       // RDMA write
+       auto raddr = std::get<1>(key_meta_map_[key][recver]);
+       auto rkey = std::get<2>(key_meta_map_[key][recver]);
+
+       auto temp_mr = memory_mr_map.find(msg_buf->data[1].data());
+       CHECK_NE(temp_mr, memory_mr_map.end());
+
+       struct ibv_sge sge;
+       sge.addr = reinterpret_cast<uint64_t>(msg_buf->data[1].data());
+       sge.length = msg_buf->data[1].size();
+       sge.lkey = temp_mr->second->lkey;
+
+       struct ibv_send_wr wr, *bad_wr = nullptr;
+       memset(&wr, 0, sizeof(wr));
+
+       wr.wr_id = reinterpret_cast<uint64_t>(raddr);
+       wr.opcode = IBV_WR_RDMA_WRITE;
+       wr.next = nullptr;
+       // wr.send_flags = IBV_SEND_SIGNALED;
+       wr.sg_list = &sge;
+       wr.num_sge = 1;
+
+       wr.wr.rdma.remote_addr = raddr;
+       wr.wr.rdma.rkey = rkey;
+
+       CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0)
+           << "ibv_post_send failed.";
+     }
    }
 
    WRContext *context = nullptr, *reserved = nullptr;
    endpoint->free_write_ctx.WaitAndPop(&reserved);
    endpoint->free_start_ctx.WaitAndPop(&context);
 
    msg_buf->reserved_context = reserved;
    RendezvousStart *req =
        reinterpret_cast<RendezvousStart *>(context->buffer->addr);
    req->meta_len = meta_len;
-
-   for (size_t i = 0; i < msg.data.size(); ++i) {
-     req->data_len[i] = msg.data[i].size();
-   }
-   req->data_num = msg.data.size();
    req->origin_addr = reinterpret_cast<uint64_t>(msg_buf);
-
-   struct ibv_sge sge;
-   sge.addr = reinterpret_cast<uint64_t>(req);
-   sge.length = sizeof(RendezvousStart);
-   sge.lkey = context->buffer->lkey;
-
-   struct ibv_send_wr wr, *bad_wr = nullptr;
-   memset(&wr, 0, sizeof(wr));
-
-   wr.wr_id = reinterpret_cast<uint64_t>(context);
-   wr.opcode = IBV_WR_SEND_WITH_IMM;
-   wr.next = nullptr;
-
-   wr.imm_data = kRendezvousStart;
-
-   wr.send_flags = IBV_SEND_SIGNALED;
-   wr.sg_list = &sge;
-   wr.num_sge = 1;
-   CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0)
-       << strerror(errno);
-
+
+   auto addr = reinterpret_cast<uint64_t>(req);
+
+   // rendezvous message, not data message
+   if (!is_server_ && is_local_[remote_id] && IsValidPushpull(msg)) {
+     // local IPC with shared memory
+     req->data_num = 0;
+     auto key = DecodeKey(msg.data[0]);
+     CHECK_EQ(key, msg.meta.key);
+   } else {
+     // normal RDMA
+     req->data_num = msg.data.size();
+     for (size_t i = 0; i < req->data_num; ++i) {
+       req->data_len[i] = msg.data[i].size();
+     }
+   }
+   SendRendezvousBegin(endpoint, addr, context, kRendezvousStart);
    return total_len;
  }
 
@@ -904,10 +1053,50 @@ class RDMAVan : public Van {
    uint64_t data_num = buffer_ctx->data_num;
    cur += buffer_ctx->meta_len;
 
-   if (IsValidPushpull(*msg) && !msg->meta.push && !msg->meta.request) { // worker
+   bool is_released = false;
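+   // For a local sender the payload never crossed the NIC: rebuild the
+   // keys/vals/lens SArrays so they point directly into the shared-memory
+   // segment instead of into an RDMA receive buffer.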
+   if (is_server_ && IsValidPushpull(*msg) && is_local_[msg->meta.sender]) {
+     // get data message from local shared memory
+     auto key = msg->meta.key;
+
+     std::lock_guard<std::mutex> lock(map_mu_);
+     if (key_addr_map_.find(key) == key_addr_map_.end()) {
+       key_addr_map_[key] = key;
+     }
+
+     SArray<char> keys;
+     SArray<char> vals;
+     SArray<char> lens;
+     keys.reset(reinterpret_cast<char*>(&key_addr_map_[key]), sizeof(ps::Key), [](void *){});
+     msg->data.push_back(keys);
+     total_len += keys.size();
+
+     if (msg->meta.push && msg->meta.request) { // push request
+       auto len = msg->meta.val_len;
+       if (key_len_map_.find(key) == key_len_map_.end()) {
+         key_len_map_[key] = len;
+       }
+       CHECK_EQ(len, key_len_map_[key]) << "key=" << key
+           << ": " << len << ", " << key_len_map_[key];
+
+       void* addr = GetSharedMemory(kShmPrefix, key);
+       vals.reset(reinterpret_cast<char*>(addr), len, [](void *){});
+       lens.reset(reinterpret_cast<char*>(&key_len_map_[key]), sizeof(int), [](void *){});
+       msg->data.push_back(vals);
+       msg->data.push_back(lens);
+     } else { // pull request
+       msg->data.push_back(vals);
+     }
+     total_len += vals.size() + lens.size();
+
+     mempool_->Free(buffer_ctx->buffer);
+     is_released = true;
+   }
+
+   if (IsValidPushpull(*msg) && !msg->meta.push && !msg->meta.request) {
+     // worker, get message directly from inplace tensor
      std::lock_guard<std::mutex> lock(map_mu_);
      auto key = msg->meta.key;
-     CHECK(!is_server);
+     CHECK(!is_server_);
 
      if (key_len_map_.find(key) == key_len_map_.end()) {
        key_addr_map_[key] = (ps::Key) key;
        key_len_map_[key] = (int) msg->meta.val_len;
@@ -948,11 +1137,11 @@ class RDMAVan : public Van {
        total_len += len;
      }
    } else {
-     mempool_->Free(buffer_ctx->buffer);
+     if (!is_released) mempool_->Free(buffer_ctx->buffer);
    }
 
    if (msg->meta.push && msg->meta.request) { // server
-     CHECK(is_server);
+     CHECK(is_server_);
      auto key = msg->meta.key;
      auto len = msg->meta.val_len;
      auto addr = msg->meta.addr;
@@ -1026,6 +1215,40 @@ class RDMAVan : public Van {
      CHECK(0);
    }
  }
+
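+ // Answer a RendezvousStart: return the receiver-side buffer address, its
+ // rkey, and an addr_pool_ index, so the sender can RDMA-write the payload
+ // and refer back to the buffer by index on completion.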
+ void SendRendezvousReply(Endpoint* endpoint, BufferContext *buf_ctx, uint64_t origin_addr) {
+   WRContext *reply_ctx = nullptr;
+   endpoint->free_reply_ctx.WaitAndPop(&reply_ctx);
+   RendezvousReply *resp =
+       reinterpret_cast<RendezvousReply *>(reply_ctx->buffer->addr);
+
+   char* buffer = buf_ctx->buffer;
+   resp->addr = reinterpret_cast<uint64_t>(buffer);
+   resp->rkey = mempool_->RemoteKey(buffer);
+   resp->origin_addr = origin_addr;
+   resp->idx = addr_pool_.StoreAddress(buf_ctx);
+
+   struct ibv_sge sge;
+   sge.addr = reinterpret_cast<uint64_t>(resp);
+   sge.length = sizeof(RendezvousReply);
+   sge.lkey = reply_ctx->buffer->lkey;
+
+   struct ibv_send_wr wr, *bad_wr = nullptr;
+   memset(&wr, 0, sizeof(wr));
+
+   wr.wr_id = reinterpret_cast<uint64_t>(reply_ctx);
+   wr.opcode = IBV_WR_SEND_WITH_IMM;
+   wr.next = nullptr;
+
+   wr.imm_data = kRendezvousReply;
+
+   wr.send_flags = IBV_SEND_SIGNALED;
+   wr.sg_list = &sge;
+   wr.num_sge = 1;
+
+   CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0)
+       << "ibv_post_send failed.";
+ }
 
  void PollCQ() {
    // Pre-allocated work completions array used for polling
@@ -1075,8 +1298,8 @@ class RDMAVan : public Van {
          // LOG(INFO) << "opcode: IBV_WC_RECV kRendezvousStart";
          RendezvousStart *req =
              reinterpret_cast<RendezvousStart *>(mr->addr);
+
          BufferContext *buf_ctx = new BufferContext();
-
          uint64_t len = req->meta_len;
          buf_ctx->meta_len = req->meta_len;
          buf_ctx->data_num = req->data_num;
@@ -1084,46 +1307,11 @@ class RDMAVan : public Van {
            buf_ctx->data_len[i] = req->data_len[i];
            len += req->data_len[i];
          }
-
-         char *buffer = mempool_->Alloc(is_server ? len : req->meta_len);
-         CHECK(buffer) << "Alloc for " << len
-                       << " bytes, data_num: " << req->data_num;
-
+         char *buffer = mempool_->Alloc(is_server_ ? len : req->meta_len);
+         CHECK(buffer) << len;
          buf_ctx->buffer = buffer;
-         uint64_t origin_addr = req->origin_addr;
-
-         WRContext *reply_ctx = nullptr;
-         endpoint->free_reply_ctx.WaitAndPop(&reply_ctx);
-         RendezvousReply *resp =
-             reinterpret_cast<RendezvousReply *>(reply_ctx->buffer->addr);
-
-         resp->addr = reinterpret_cast<uint64_t>(buffer);
-         resp->rkey = mempool_->RemoteKey(buffer);
-         resp->origin_addr = origin_addr;
-         resp->idx = addr_pool_.StoreAddress(buf_ctx);
-
-         struct ibv_sge sge;
-         sge.addr = reinterpret_cast<uint64_t>(resp);
-         sge.length = sizeof(RendezvousReply);
-         sge.lkey = reply_ctx->buffer->lkey;
-
-         struct ibv_send_wr wr, *bad_wr = nullptr;
-         memset(&wr, 0, sizeof(wr));
-
-         wr.wr_id = reinterpret_cast<uint64_t>(reply_ctx);
-         wr.opcode = IBV_WR_SEND_WITH_IMM;
-         wr.next = nullptr;
-
-         wr.imm_data = kRendezvousReply;
-
-         wr.send_flags = IBV_SEND_SIGNALED;
-         wr.sg_list = &sge;
-         wr.num_sge = 1;
-
-         CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0)
-             << "ibv_post_send failed.";
-
+         SendRendezvousReply(endpoint, buf_ctx, req->origin_addr);
        } else if (imm == kRendezvousReply) {
          // LOG(INFO) << "opcode: IBV_WC_RECV kRendezvousReply";
          RendezvousReply *resp =
@@ -1152,7 +1340,7 @@ class RDMAVan : public Van {
            sge[num_sge].lkey = pair.first->lkey;
            ++num_sge;
          }
-         if (is_server) CHECK_EQ(num_sge, 1) << num_sge;
+         if (is_server_) CHECK_EQ(num_sge, 1) << num_sge;
 
          WRContext *write_ctx = msg_buf->reserved_context;
@@ -1363,6 +1551,8 @@ class RDMAVan : public Van {
    LOG(INFO) << "OnDisconnected from Node " << endpoint->node_id;
  }
 
+ int AlignTo(int input, int alignment) { return input / alignment * alignment; }
+
  AddressPool<BufferContext> addr_pool_;
  std::unique_ptr<SimpleMempool> mempool_;
@@ -1390,13 +1580,12 @@ class RDMAVan : public Van {
  // Recv buffer queue
  ThreadsafeQueue<std::tuple<Endpoint *, BufferContext *>> recv_buffers_;
 
- // JYM: the following are for push/pull buffer reuse
-
- // whether my role is server or not
- bool is_server;
+ // role is server or worker
+ bool is_server_;
 
  // RDMA logging info
  bool enable_rdma_log_;
 
+ std::mutex map_mu_;
  // macros for key_meta_map
  using MetaInfo = std::tuple<int, uint64_t, uint32_t>; // len, addr, rkey
  using SenderMeta = std::unordered_map<int, MetaInfo>; // sender as the key
  // (key, sender) --> MetaInfo
  std::unordered_map<ps::Key, SenderMeta> key_meta_map_;
  // a static address for the key
  std::unordered_map<ps::Key, ps::Key> key_addr_map_;
  // a static address for the length
  std::unordered_map<ps::Key, int> key_len_map_;
- std::mutex map_mu_;
 
+ // local IPC related
+ bool disable_ipc_ = false;
+ std::mutex local_mu_;
+ std::unordered_map<int, bool> is_local_;
+
+ std::mutex shm_mu_;
+ std::unordered_map<uint64_t, void *> key_shm_addr_;
+
+ int byteps_partition_bytes_ = 4096000;
+ int ipc_copy_nthreads_;
+ std::vector<std::thread *> ipc_copy_thread_list_;
+ std::vector<ThreadsafeQueue<AsyncCopy> *> async_copy_queue_;
+ std::atomic<uint64_t> cpy_counter_{0};
 };  // class RDMAVan
 
 };  // namespace ps
diff --git a/tests/local_multi_workers.sh b/tests/local_multi_workers.sh
index 089b1747a..b0f536167 100755
--- a/tests/local_multi_workers.sh
+++ b/tests/local_multi_workers.sh
@@ -14,8 +14,9 @@ shift
 arg="$@"
 
 # start the scheduler
-export DMLC_PS_ROOT_URI='127.0.0.1'
-export DMLC_PS_ROOT_PORT=8000
+export DMLC_PS_ROOT_URI=${DMLC_PS_ROOT_URI:-'127.0.0.1'}
+export DMLC_PS_ROOT_PORT=${DMLC_PS_ROOT_PORT:-8000}
+export DMLC_INTERFACE=${DMLC_INTERFACE:-eth10}
 export DMLC_ROLE='scheduler'
 ${bin} ${arg} &
 
diff --git a/tests/test_kv_app_ipc_benchmark.cc b/tests/test_kv_app_ipc_benchmark.cc
new file mode 100644
index 000000000..2ecf9a95b
--- /dev/null
+++ b/tests/test_kv_app_ipc_benchmark.cc
@@ -0,0 +1,246 @@
+#include <fcntl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <chrono>
+#include <cstdlib>
+#include <cstring>
+#include <mutex>
+#include <unordered_map>
+#include <vector>
+#include "ps/ps.h"
+#define DATA_TYPE char
+using namespace ps;
+
+enum MODE {
+  PUSH_THEN_PULL = 0,
+  PUSH_PULL_MIX_ENDLESS = 1
+};
+std::unordered_map<uint64_t, KVPairs<char> > mem_map;
+std::unordered_map<std::string, void *> _key_shm_addr;
+std::unordered_map<std::string, size_t> _key_shm_size;
+std::unordered_map<uint64_t, char *> store_;
+std::mutex mu_;
+
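+// This benchmark plays the role of BytePS core: it creates and sizes the
+// "BytePS_ShM_" + base-key segments up front, so the server side's
+// GetSharedMemory() can later shm_open() the same names and serve pushes
+// and pulls straight out of shared memory.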
+void* OpenSharedMemory(const std::string& prefix,
+                       uint64_t key, size_t size) {
+  std::string shm_name(prefix);
+  shm_name += std::to_string(key);
+  int shm_fd = shm_open(shm_name.c_str(), O_CREAT | O_RDWR, 0666);
+  CHECK_GE(shm_fd, 0) << "shm_open failed for " << shm_name;
+  CHECK_GE(ftruncate(shm_fd, size), 0) << strerror(errno);
+
+  void* ptr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0);
+  CHECK_NE(ptr, (void*)-1) << strerror(errno);
+
+  LOG(INFO) << "initialized shared memory size=" << size
+            << " for key=" << key << ", name=" << shm_name;
+  _key_shm_addr[shm_name] = ptr;
+  _key_shm_size[shm_name] = size;
+  return ptr;
+}
+
+uint64_t DecodeKey(ps::Key key) {
+  auto kr = ps::Postoffice::Get()->GetServerKeyRanges()[ps::MyRank()];
+  return key - kr.begin();
+}
+
+template <typename Val>
+void EmptyHandler(const KVMeta &req_meta, const KVPairs<Val> &req_data, KVServer<Val> *server) {
+  std::lock_guard<std::mutex> lk(mu_);
+  uint64_t key = DecodeKey(req_data.keys[0]);
+  if (req_meta.push) {
+    CHECK(req_data.lens.size());
+    CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]);
+
+    if (mem_map.find(key) == mem_map.end()) {
+      PS_VLOG(1) << "key " << key << " from worker-" << req_meta.sender;
+      size_t len = (size_t) req_data.vals.size();
+      mem_map[key].keys.push_back(key);
+      mem_map[key].lens.push_back(len);
+
+      store_[key] = (char*) malloc(len);
+      mem_map[key].vals = ps::SArray<char>(store_[key], len, false);
+    }
+
+    // send push response (empty)
+    KVPairs<char> res;
+    server->Response(req_meta, res);
+  } else { // pull request
+    auto iter = mem_map.find(key);
+    CHECK_NE(iter, mem_map.end());
+    server->Response(req_meta, iter->second);
+  }
+}
+
+void StartServer() {
+  if (!IsServer()) return;
+  auto server = new KVServer<char>(0);
+  server->set_request_handle(EmptyHandler<char>);
+  RegisterExitCallback([server]() { delete server; });
+}
+
+struct PSKV {
+  SArray<ps::Key> keys;  // n keys
+  SArray<int> lens;      // the length of the i-th value
+};
+std::unordered_map<uint64_t, PSKV> ps_kv_;
+
+uint64_t EncodeKey(uint64_t seed) {
+  return seed << 16;
+}
+
+void RunWorker(int argc, char *argv[]) {
+  if (!IsWorker()) return;
+  CHECK_GE(argc, 3) << "expected at least 3 arguments: SCRIPT, LEN, REPEAT, (OPTIONAL) MODE";
+  KVWorker<char> kv(0, 0);
+  auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
+
+  const int num_servers = krs.size();
+  LOG(INFO) << num_servers << " servers in total";
+  CHECK_GT(num_servers, 0);
+
+  // init
+  auto val = Environment::Get()->find("BYTEPS_PARTITION_BYTES");
+  unsigned int partition_bytes = val ? atoi(val) : 4096000;
+  int len = atoi(argv[1]);
+  CHECK_GE(partition_bytes, (unsigned int) len)
+      << "tensor partitioning is not supported in this benchmark"
+      << ", try reducing the tensor size or increasing BYTEPS_PARTITION_BYTES";
+  int repeat = atoi(argv[2]);
+  MODE mode = (argc > 3) ? static_cast<MODE>(atoi(argv[3])) : PUSH_PULL_MIX_ENDLESS;
+
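+  // One shm-backed value buffer per server. EncodeKey(server) = server << 16
+  // keeps the low 16 bits zero, so on the server side DecodeWorkerKey() sees
+  // seq_num == 0 and the buffer maps at offset 0 of its shm segment; wrapping
+  // the shm pointer in an SArray makes the ZPush/ZPull below zero-copy.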
+  std::vector<SArray<char> > server_vals;
+  for (int server = 0; server < num_servers; server++) {
+    auto key = EncodeKey(server);
+    auto addr = (char*) OpenSharedMemory(std::string("BytePS_ShM_"), key, len);
+    SArray<char> vals(addr, len, false);
+    server_vals.push_back(vals);
+  }
+
+  // init broadcast
+  for (int server = 0; server < num_servers; server++) {
+    auto key = EncodeKey(server);
+    auto vals = server_vals[server];
+
+    PSKV& pskv = ps_kv_[key];
+    SArray<ps::Key> keys;
+    ps::Key ps_key = krs[server].begin() + key;
+    keys.push_back(ps_key);
+    SArray<int> lens;
+    lens.push_back(len);
+    pskv.keys.push_back(ps_key);
+    pskv.lens.push_back(len);
+
+    kv.Wait(kv.ZPush(keys, vals, lens));
+  }
+
+  switch(mode) {
+    case PUSH_THEN_PULL: {
+      LOG(INFO) << "PUSH_THEN_PULL mode";
+      // push
+      uint64_t accumulated_ms = 0;
+      for (int i = 0; i < repeat; ++i) {
+        auto start = std::chrono::high_resolution_clock::now();
+        for (int server = 0; server < num_servers; server++) {
+          auto key = EncodeKey(server);
+          PSKV& pskv = ps_kv_[key];
+          auto keys = pskv.keys;
+          auto lens = pskv.lens;
+          auto vals = server_vals[server];
+
+          kv.Wait(kv.ZPush(keys, vals, lens));
+        }
+        auto end = std::chrono::high_resolution_clock::now();
+        accumulated_ms += (end - start).count(); // ns
+      }
+      LL << "push " << len * sizeof(DATA_TYPE)
+         << " bytes to each server, repeat=" << repeat
+         << ", total_time="
+         << accumulated_ms / 1e6 << "ms";
+
+      // pull
+      accumulated_ms = 0;
+      for (int i = 0; i < repeat; ++i) {
+        auto start = std::chrono::high_resolution_clock::now();
+        for (int server = 0; server < num_servers; server++) {
+          auto key = EncodeKey(server);
+          PSKV& pskv = ps_kv_[key];
+          auto keys = pskv.keys;
+          auto lens = pskv.lens;
+          auto vals = server_vals[server];
+
+          kv.Wait(kv.ZPull(keys, &vals, &lens));
+        }
+        auto end = std::chrono::high_resolution_clock::now();
+        accumulated_ms += (end - start).count(); // ns
+      }
+
+      LL << "pull " << len * sizeof(DATA_TYPE)
+         << " bytes from each server, repeat=" << repeat
+         << ", total_time="
+         << accumulated_ms / 1e6 << "ms";
+    }
+    break;
+
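+    // Flow control: allow at most THRESHOLD outstanding push+pull rounds per
+    // server, then Wait() on all of them before issuing more. The goodput
+    // below is 8 * bytes / elapsed_ns, and since bits per nanosecond equals
+    // Gbit/s, the printed figure is directly in Gbps.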
+    case PUSH_PULL_MIX_ENDLESS: {
+      LOG(INFO) << "PUSH_PULL_MIX_ENDLESS mode, should exit by Ctrl+C";
+      std::vector<int> timestamp_list;
+      auto start = std::chrono::high_resolution_clock::now();
+      auto end = std::chrono::high_resolution_clock::now();
+      auto val = Environment::Get()->find("THRESHOLD");
+      unsigned int threshold = val ? atoi(val) : 10;
+      val = Environment::Get()->find("LOG_DURATION");
+      unsigned int log_duration = val ? atoi(val) : 50;
+      int cnt = 0;
+      while (1) {
+        for (int server = 0; server < num_servers; server++) {
+          auto key = EncodeKey(server);
+          PSKV& pskv = ps_kv_[key];
+          auto keys = pskv.keys;
+          auto lens = pskv.lens;
+          auto vals = server_vals[server];
+          timestamp_list.push_back(kv.ZPush(keys, vals, lens));
+          timestamp_list.push_back(kv.ZPull(keys, &vals, &lens));
+        }
+        if (timestamp_list.size()/2/num_servers >= threshold) { // flow control
+          for (auto& ts : timestamp_list) {
+            kv.Wait(ts);
+          }
+          timestamp_list.clear();
+          cnt++;
+          if (cnt % log_duration == 0) {
+            end = std::chrono::high_resolution_clock::now();
+            LL << "Application goodput: "
+               << 8.0 * len * sizeof(DATA_TYPE) * num_servers * cnt * threshold / (end - start).count()
+               << " Gbps";
+            cnt = 0;
+            start = std::chrono::high_resolution_clock::now();
+          }
+        }
+      }
+    } break;
+
+    default:
+      CHECK(0) << "unknown mode " << mode;
+  }
+}
+
+int main(int argc, char *argv[]) {
+  // disable multi-threaded processing first
+  setenv("ENABLE_SERVER_MULTIPULL", "0", 1);
+  setenv("BYTEPS_LOCAL_SIZE", "1", 1);
+  setenv("BYTEPS_ENABLE_IPC", "1", 0);
+  // start system
+  Start(0);
+  // setup server nodes
+  StartServer();
+  // run worker nodes
+  RunWorker(argc, argv);
+  // stop system
+  Finalize(0, true);
+  // release shm
+  for (auto &it : _key_shm_addr) {
+    munmap(it.second, _key_shm_size[it.first]);
+    shm_unlink(it.first.c_str());
+  }
+  return 0;
+}