diff --git a/Makefile b/Makefile index 3dc9b566..9853008b 100644 --- a/Makefile +++ b/Makefile @@ -20,10 +20,10 @@ endif INCPATH = -I./src -I./include -I$(DEPS_PATH)/include CFLAGS = -std=c++14 -msse2 -fPIC -O3 -ggdb -Wall -finline-functions $(INCPATH) $(ADD_CFLAGS) -LIBS = -pthread +LIBS = -pthread -lrt ifeq ($(USE_RDMA), 1) -LIBS += -lrdmacm -libverbs -lrt +LIBS += -lrdmacm -libverbs CFLAGS += -DDMLC_USE_RDMA endif @@ -38,25 +38,21 @@ include make/deps.mk clean: rm -rf build $(TEST) tests/*.d tests/*.dSYM - find src -name "*.pb.[ch]*" -delete lint: python tests/lint.py ps all include/ps src ps: build/libps.a -OBJS = $(addprefix build/, customer.o postoffice.o van.o meta.pb.o) +OBJS = $(addprefix build/, customer.o postoffice.o van.o) build/libps.a: $(OBJS) ar crv $@ $(filter %.o, $?) -build/%.o: src/%.cc ${ZMQ} src/meta.pb.h +build/%.o: src/%.cc ${ZMQ} @mkdir -p $(@D) $(CXX) $(INCPATH) -std=c++0x -MM -MT build/$*.o $< >build/$*.d $(CXX) $(CFLAGS) $(LIBS) -c $< -o $@ -src/%.pb.cc src/%.pb.h : src/%.proto ${PROTOBUF} - $(PROTOC) --cpp_out=./src --proto_path=./src $< - -include build/*.d -include build/*/*.d diff --git a/include/ps/internal/van.h b/include/ps/internal/van.h index e33bf4f5..a943975e 100644 --- a/include/ps/internal/van.h +++ b/include/ps/internal/van.h @@ -16,7 +16,6 @@ #include "ps/internal/message.h" namespace ps { class Resender; -class PBMeta; /** * \brief Van sends messages to remote nodes * @@ -107,14 +106,14 @@ class Van { virtual int SendMsg(Message &msg) = 0; /** - * \brief pack meta into a string + * \brief get the length of pack meta */ - void PackMeta(const Meta &meta, char **meta_buf, int *buf_size); + int GetPackMetaLen(const Meta &meta); /** - * \brief pack meta into protobuf + * \brief pack meta into a string */ - void PackMetaPB(const Meta &meta, PBMeta *pb); + void PackMeta(const Meta &meta, char **meta_buf, int *buf_size); /** * \brief unpack meta from a string diff --git a/make/deps.mk b/make/deps.mk index 8463ecca..a8ec57be 
100644 --- a/make/deps.mk +++ b/make/deps.mk @@ -1,21 +1,10 @@ # Install dependencies URL1=https://raw.githubusercontent.com/mli/deps/master/build -URL2=https://github.com/google/protobuf/releases/download/v3.5.1 ifndef WGET WGET = wget endif -# protobuf -PROTOBUF = ${DEPS_PATH}/include/google/protobuf/message.h -${PROTOBUF}: - $(eval FILE=protobuf-cpp-3.5.1.tar.gz) - $(eval DIR=protobuf-3.5.1) - rm -rf $(FILE) $(DIR) - $(WGET) $(URL2)/$(FILE) && tar --no-same-owner -zxf $(FILE) - cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure -prefix=$(DEPS_PATH) && $(MAKE) && $(MAKE) install - rm -rf $(FILE) $(DIR) - # zmq ZMQ = ${DEPS_PATH}/include/zmq.h diff --git a/make/ps.mk b/make/ps.mk index 0b0f678d..866caf32 100644 --- a/make/ps.mk +++ b/make/ps.mk @@ -9,5 +9,5 @@ ifeq ($(USE_KEY32), 1) ADD_CFLAGS += -DUSE_KEY32=1 endif -PS_LDFLAGS_SO = -L$(DEPS_PATH)/lib -lprotobuf-lite -lzmq -PS_LDFLAGS_A = $(addprefix $(DEPS_PATH)/lib/, libprotobuf-lite.a libzmq.a) +PS_LDFLAGS_SO = -L$(DEPS_PATH)/lib -lzmq +PS_LDFLAGS_A = $(addprefix $(DEPS_PATH)/lib/, libzmq.a) diff --git a/src/meta.h b/src/meta.h new file mode 100644 index 00000000..98dff57b --- /dev/null +++ b/src/meta.h @@ -0,0 +1,76 @@ +/** + * Copyright (c) 2018-2019 Bytedance Inc. + * Author: zhuyibo@bytedance.com (Yibo Zhu) +*/ +#ifndef PS_LITE_META_H_ +#define PS_LITE_META_H_ + +#include + +namespace ps { + +struct RawNode { + // the node role + int role; + // node id + int id; + // hostname or ip + char hostname[64]; + // the port this node is binding + int port; + // whether this node is created by failover + bool is_recovery; + // the locally unique id of a customer + int customer_id; +}; + +// system control info +struct RawControl { + int cmd; + int node_size; + int barrier_group; + uint64_t msg_sig; +}; + +// meta information about a message +struct RawMeta { + // message.head + int head; + // message.body + int body_size; + // if set, then it is system control task. 
otherwise, it is for app + RawControl control; + // true: a request task + // false: the response task to the request task with the same *time* + bool request; + // the unique id of an application + int app_id; + // the timestamp of this message + int timestamp; + // data type of message.data[i] + int data_type_size; + // the locally unique id of a customer + int customer_id; + // whether or not a push message + bool push; + // whether or not it's for SimpleApp + bool simple_app; + // message.data_size + int data_size; + // message.key + uint64_t key; + // message.addr + uint64_t addr; + // the length of the message's value + int val_len; + // the option field + int option; + + // body + // data_type + // node +}; + +} // namespace ps + +#endif diff --git a/src/meta.proto b/src/meta.proto deleted file mode 100644 index 9a146b8a..00000000 --- a/src/meta.proto +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright (c) 2015 by Contributors - */ -syntax = "proto2"; -package ps; -option optimize_for = LITE_RUNTIME; - -message PBNode { - // the node role - required int32 role = 1; - // node id - optional int32 id = 2; - // hostname or ip - optional string hostname = 3; - // the port this node is binding - optional int32 port = 4; - // whether this node is created by failover - optional bool is_recovery = 5; - // the locally unique id of an customer - optional int32 customer_id = 10; -} - -// system control info -message PBControl { - required int32 cmd = 1; - repeated PBNode node = 2; - optional int32 barrier_group = 3; - optional uint64 msg_sig = 4; -} - -// mete information about a message -message PBMeta { - // message.head - optional int32 head = 1; - // message.body - optional bytes body = 2; - // if set, then it is system control task. 
otherwise, it is for app - optional PBControl control = 3; - // true: a request task - // false: the response task to the request task with the same *time* - optional bool request = 4 [default = false]; - // the unique id of an application - optional int32 app_id = 7; - // the timestamp of this message - optional int32 timestamp = 8; - // data type of message.data[i] - repeated int32 data_type = 9 [packed=true]; - // the locally unique id of an customer - optional int32 customer_id = 10; - // whether or not a push message - optional bool push = 5; - // whether or not it's for SimpleApp - optional bool simple_app = 6 [default = false]; - // message.data_size - optional int32 data_size = 11; - // message.key - optional uint64 key = 12; - // message.addr - optional uint64 addr = 13 [default = 0]; - // the length of the message's value - optional int32 val_len = 14; - // the option field - optional int32 option = 15; -} diff --git a/src/rdma_transport.h b/src/rdma_transport.h new file mode 100644 index 00000000..fec1c70d --- /dev/null +++ b/src/rdma_transport.h @@ -0,0 +1,601 @@ +// Copyright 2019 Bytedance Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#ifndef PS_RDMA_TRANSPORT_H_ +#define PS_RDMA_TRANSPORT_H_ + +#ifdef DMLC_USE_RDMA + +#include "rdma_utils.h" + +namespace ps { + +class Transport; + +struct Endpoint { + enum ConnectionStatus { IDLE, CONNECTING, CONNECTED, REJECTED }; + + ConnectionStatus status; + int node_id; + std::condition_variable cv; + std::mutex connect_mu; + struct rdma_cm_id *cm_id; + std::shared_ptr trans; + + WRContext rx_ctx[kRxDepth]; + + WRContext start_ctx[kStartDepth]; + WRContext reply_ctx[kReplyDepth]; + + ThreadsafeQueue free_start_ctx; + ThreadsafeQueue free_reply_ctx; + + Endpoint() : status(IDLE), node_id(Node::kEmpty), cm_id(nullptr), rx_ctx() {} + + ~Endpoint() { + for (int i = 0; i < kRxDepth; ++i) { + if (!(rx_ctx[i].buffer)) { + continue; + } + free(rx_ctx[i].buffer->addr); + CHECK_EQ(ibv_dereg_mr(rx_ctx[i].buffer), 0); + } + + for (int i = 0; i < kStartDepth; ++i) { + if (start_ctx[i].buffer) { + free(start_ctx[i].buffer->addr); + CHECK_EQ(ibv_dereg_mr(start_ctx[i].buffer), 0); + } + } + + for (int i = 0; i < kReplyDepth; ++i) { + if (reply_ctx[i].buffer) { + free(reply_ctx[i].buffer->addr); + CHECK_EQ(ibv_dereg_mr(reply_ctx[i].buffer), 0); + } + } + + rdma_destroy_qp(cm_id); + CHECK_EQ(rdma_destroy_id(cm_id), 0) << strerror(errno); + } + + void SetTransport(std::shared_ptr t) { trans = t; } + + std::shared_ptr GetTransport() { return trans; } + + void Disconnect() { + std::unique_lock lk(connect_mu); + CHECK_EQ(rdma_disconnect(cm_id), 0) << strerror(errno); + cv.wait(lk, [this] { return status == IDLE; }); + trans.reset(); + } + + void SetNodeID(int id) { node_id = id; } + + void InitSendContextHelper(struct ibv_pd *pd, WRContext *ctx, + ThreadsafeQueue *queue, size_t num, + WRContextType type) { + for (size_t i = 0; i < num; ++i) { + void *buf; + ib_malloc((void**) &buf, kMempoolChunkSize); + CHECK(buf); + struct ibv_mr *mr = ibv_reg_mr(pd, buf, kMempoolChunkSize, 0); + CHECK(mr); + 
+ ctx[i].type = type; + ctx[i].buffer = mr; + ctx[i].private_data = this; + queue->Push(&ctx[i]); + } + } + + void Init(struct ibv_cq *cq, struct ibv_pd *pd) { + struct ibv_qp_init_attr attr; + memset(&attr, 0, sizeof(ibv_qp_init_attr)); + attr.send_cq = cq; + attr.recv_cq = cq; + attr.cap.max_send_wr = kStartDepth + kReplyDepth; + attr.cap.max_recv_wr = kRxDepth; + attr.cap.max_send_sge = kSGEntry; + attr.cap.max_recv_sge = kSGEntry; + attr.qp_type = IBV_QPT_RC; + attr.sq_sig_all = 0; + + CHECK_EQ(rdma_create_qp(cm_id, pd, &attr), 0) + << "Create RDMA queue pair failed: " << strerror(errno); + + InitSendContextHelper(pd, start_ctx, &free_start_ctx, kStartDepth, + kRendezvousStartContext); + InitSendContextHelper(pd, reply_ctx, &free_reply_ctx, kReplyDepth, + kRendezvousReplyContext); + + for (size_t i = 0; i < kRxDepth; ++i) { + void *buf; + ib_malloc((void**) &buf, kMempoolChunkSize); + CHECK(buf); + struct ibv_mr *mr = + ibv_reg_mr(pd, buf, kMempoolChunkSize, IBV_ACCESS_LOCAL_WRITE); + CHECK(mr); + + rx_ctx[i].type = kReceiveContext; + rx_ctx[i].buffer = mr; + rx_ctx[i].private_data = this; + + PostRecv(&rx_ctx[i]); + } + } + + void PostRecv(WRContext *ctx) { + struct ibv_recv_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + + struct ibv_sge sge; + sge.addr = reinterpret_cast(ctx->buffer->addr); + sge.length = kMempoolChunkSize; + sge.lkey = ctx->buffer->lkey; + + wr.wr_id = reinterpret_cast(ctx); + wr.next = nullptr; + wr.sg_list = &sge; + wr.num_sge = 1; + + CHECK_EQ(ibv_post_recv(cm_id->qp, &wr, &bad_wr), 0) + << "ibv_post_recv failed."; + } +}; + +class Transport { + public: + virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) = 0; + + virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; + virtual int RecvPullRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; + virtual int RecvPushResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) = 
0; + virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) = 0; + + virtual void Send(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; + virtual void SendPullRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; + virtual void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; + virtual void SendPushResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) = 0; + virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple, size_t lkey) = 0; + virtual void SendRendezvousBegin(Message &msg, MessageBuffer *msg_buf) = 0; + virtual void SendRendezvousReply(RendezvousStart *req, AddressPool &pool) = 0; + + virtual SArray CreateFunctionalSarray(void *value, size_t size) = 0; + +}; // class Transport + + +class RDMATransport : public Transport { + public: + explicit RDMATransport(Endpoint *endpoint, MemoryAllocator *allocator) { + endpoint_ = CHECK_NOTNULL(endpoint); + allocator_ = CHECK_NOTNULL(allocator); + pagesize_ = sysconf(_SC_PAGESIZE); + + auto val = Environment::Get()->find("DMLC_ROLE"); + std::string role(val); + is_server_ = (role=="server"); + }; + + ~RDMATransport() {}; + + virtual void RDMAWriteWithImm(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { + struct ibv_sge sge; + sge.addr = reinterpret_cast(msg_buf->inline_buf); + sge.length = msg_buf->inline_len; + sge.lkey = allocator_->LocalKey(msg_buf->inline_buf); + + struct ibv_send_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = reinterpret_cast(msg_buf); + wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wr.next = nullptr; + wr.imm_data = idx; + wr.send_flags = IBV_SEND_SIGNALED; + wr.sg_list = &sge; + wr.num_sge = 1; + wr.wr.rdma.remote_addr = remote_addr; + wr.wr.rdma.rkey = rkey; + + CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) + << "ibv_post_send failed."; + } + + void 
SendRendezvousBegin(Message &msg, MessageBuffer *msg_buf) { + WRContext *context = nullptr; + endpoint_->free_start_ctx.WaitAndPop(&context); + + RendezvousStart *req = + reinterpret_cast(context->buffer->addr); + req->meta_len = msg_buf->inline_len; + req->origin_addr = reinterpret_cast(msg_buf); + req->data_num = msg_buf->data.size(); + for (size_t i = 0; i < req->data_num; ++i) { + req->data_len[i] = msg.data[i].size(); + } + + struct ibv_sge sge; + sge.addr = reinterpret_cast(req); + sge.lkey = context->buffer->lkey; + sge.length = sizeof(RendezvousStart); + + struct ibv_send_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = reinterpret_cast(context); + wr.opcode = IBV_WR_SEND_WITH_IMM; + wr.next = nullptr; + wr.imm_data = kRendezvousStart; + wr.send_flags = IBV_SEND_SIGNALED; + wr.sg_list = &sge; + wr.num_sge = 1; + + CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) + << strerror(errno); + } + + void SendRendezvousReply(RendezvousStart *req, AddressPool &addrpool) { + BufferContext *buf_ctx = new BufferContext(); + buf_ctx->meta_len = req->meta_len; + buf_ctx->data_num = req->data_num; + + size_t data_len = 0; + for (size_t i = 0; i < req->data_num; ++i) { + buf_ctx->data_len[i] = req->data_len[i]; + data_len += req->data_len[i]; + } + + // worker only needs a buffer for receiving meta + char *buffer = allocator_->Alloc( + is_server_ ? 
(align_ceil(req->meta_len, pagesize_) + data_len) : req->meta_len); + CHECK(buffer); + buf_ctx->buffer = buffer; + + WRContext *reply_ctx = nullptr; + endpoint_->free_reply_ctx.WaitAndPop(&reply_ctx); + + RendezvousReply *resp = + reinterpret_cast(reply_ctx->buffer->addr); + + resp->addr = reinterpret_cast(buffer); + resp->rkey = allocator_->RemoteKey(buffer); + resp->origin_addr = req->origin_addr; + resp->idx = addrpool.StoreAddress(buf_ctx); + + struct ibv_sge sge; + sge.addr = reinterpret_cast(resp); + sge.length = sizeof(RendezvousReply); + sge.lkey = reply_ctx->buffer->lkey; + + struct ibv_send_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + + wr.wr_id = reinterpret_cast(reply_ctx); + wr.opcode = IBV_WR_SEND_WITH_IMM; + wr.next = nullptr; + + wr.imm_data = kRendezvousReply; + + wr.send_flags = IBV_SEND_SIGNALED; + wr.sg_list = &sge; + wr.num_sge = 1; + + CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) + << "ibv_post_send failed."; + } + + void Send(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + auto raddr = std::get<0>(remote_tuple); + auto rkey = std::get<1>(remote_tuple); + auto idx = std::get<2>(remote_tuple); + + RDMAWriteWithImm(msg_buf, raddr, rkey, idx); + } + + void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + CHECK_EQ(msg_buf->mrs.size(), 3); + auto raddr = std::get<0>(remote_tuple); + auto rkey = std::get<1>(remote_tuple); + auto idx = std::get<2>(remote_tuple); + + // push request, split the meta and data into two writes + // further, it does not send keys and lens since these meta already carries these info + struct ibv_sge my_sge; + my_sge.addr = reinterpret_cast(msg_buf->mrs[1].first->addr); + my_sge.length = msg_buf->mrs[1].second; + my_sge.lkey = msg_buf->mrs[1].first->lkey; + + // this rdma-write will not trigger any signal both remotely and locally + struct ibv_send_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = 0; + wr.opcode = 
IBV_WR_RDMA_WRITE; + wr.next = nullptr; + wr.sg_list = &my_sge; + wr.num_sge = 1; + wr.wr.rdma.rkey = rkey; + + // write to the next page-aligned address (remote_addr should already be aligned) + wr.wr.rdma.remote_addr = raddr + align_ceil(msg_buf->inline_len, pagesize_); + + CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) + << "ibv_post_send failed."; + + RDMAWriteWithImm(msg_buf, raddr, rkey, idx); + } + + void SendPullRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + CHECK_EQ(msg_buf->mrs.size(), 0); + Send(msg, msg_buf, remote_tuple); + } + + virtual void SendPushResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + CHECK_EQ(msg_buf->mrs.size(), 0); + Send(msg, msg_buf, remote_tuple); + } + + virtual void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple, size_t lkey) { + CHECK_EQ(msg_buf->mrs.size(), 0); + + auto raddr = msg.meta.addr; + auto rkey = msg.meta.option; + auto len = msg.meta.val_len; + CHECK_EQ((size_t) msg.meta.val_len, msg_buf->data[1].size()); + + struct ibv_sge sge; + sge.addr = reinterpret_cast(msg_buf->data[1].data()); + sge.length = len; + sge.lkey = lkey; + + // this rdma-write will not trigger any signal both remotely and locally + struct ibv_send_wr wr, *bad_wr = nullptr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = reinterpret_cast(raddr); + wr.opcode = IBV_WR_RDMA_WRITE; + wr.next = nullptr; + wr.sg_list = &sge; + wr.num_sge = 1; + wr.wr.rdma.remote_addr = raddr; + wr.wr.rdma.rkey = rkey; + + CHECK_EQ(ibv_post_send(endpoint_->cm_id->qp, &wr, &bad_wr), 0) + << "ibv_post_send failed."; + + // after write keys/vals/lens (no imm), write the meta (with imm) + Send(msg, msg_buf, remote_tuple); + } + + virtual int RecvPushResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) { + CHECK_EQ(buffer_ctx->data_num, 0); + return 0; + } + + virtual int RecvPullRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) { + SArray keys = 
CreateFunctionalSarray(&msg->meta.key, sizeof(Key)); + + SArray vals; // add an empty sarray to pass kvapp check + + msg->data.push_back(keys); + msg->data.push_back(vals); + + return keys.size() + vals.size(); + } + + virtual int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) { + CHECK(msg->meta.push && msg->meta.request); + CHECK_EQ(buffer_ctx->data_num, 3); + uint32_t len = buffer_ctx->data_len[1]; + char* cur = buffer_ctx->buffer + align_ceil((size_t) meta_len, pagesize_); + + SArray keys = CreateFunctionalSarray(&msg->meta.key, sizeof(Key)); + + SArray vals; + vals.reset(cur, len, [](void *) {}); // no need to delete + + SArray lens = CreateFunctionalSarray(&msg->meta.val_len, sizeof(int)); + + msg->data.push_back(keys); + msg->data.push_back(vals); + msg->data.push_back(lens); + + return keys.size() + vals.size() + lens.size(); + } + + virtual int RecvPullResponse(Message *msg, BufferContext *buffer_ctx, int meta_len) { + SArray keys = CreateFunctionalSarray(&msg->meta.key, sizeof(Key)); + + SArray vals; + auto addr = msg->meta.addr; + vals.reset(reinterpret_cast(addr), msg->meta.val_len, [](void *){}); + + SArray lens = CreateFunctionalSarray(&msg->meta.val_len, sizeof(int)); + + msg->data.push_back(keys); + msg->data.push_back(vals); + msg->data.push_back(lens); + + return keys.size() + vals.size() + lens.size(); + } + + SArray CreateFunctionalSarray(void *value, size_t size) { + SArray sarr; + void *p = malloc(size); + memcpy(p, value, size); + sarr.reset((char *) p, size, [p](void *) { free(p); }); + return sarr; + } + + protected: + size_t pagesize_ = 4096; + Endpoint *endpoint_; + MemoryAllocator *allocator_; + bool is_server_; + +}; // class RDMATransport + +class IPCTransport : public RDMATransport { + public: + + explicit IPCTransport(Endpoint *endpoint, MemoryAllocator *allocator) : RDMATransport(endpoint, allocator) { + auto val = Environment::Get()->find("BYTEPS_IPC_COPY_NUM_THREADS"); + ipc_copy_nthreads_ = val ? 
atoi(val) : 4; + for (int i = 0; i < ipc_copy_nthreads_; ++i) { + auto q = new ThreadsafeQueue; + async_copy_queue_.push_back(q); + } + for (int i = 0; i < ipc_copy_nthreads_; ++i) { + auto t = new std::thread(&IPCTransport::AsyncCopyThread, this, i); + ipc_copy_thread_list_.push_back(t); + } + val = Environment::Get()->find("BYTEPS_PARTITION_BYTES"); + byteps_partition_bytes_ = val ? atoi(val) : 4096000; + + val = Environment::Get()->find("BYTEPS_LOCAL_SIZE"); + auto byteps_local_size = val ? atoi(val) : 8; + byteps_partition_bytes_ = RoundUp(byteps_partition_bytes_, byteps_local_size * sysconf(_SC_PAGESIZE)); + + val = Environment::Get()->find("BYTEPS_IPC_ENABLE_ASYNC_COPY"); + enable_async_copy_ = val ? atoi(val) : 1; // default enabled + if (!enable_async_copy_) LOG(INFO) << "Async copy has been disabled, this could affect the performance"; + }; + + ~IPCTransport() { + for (size_t i = 0; i < ipc_copy_thread_list_.size(); ++i) { + AsyncCopy m; + m.shutdown = true; + async_copy_queue_[i]->Push(m); + ipc_copy_thread_list_[i]->join(); + } + } + + void SendPushRequest(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + Send(msg, msg_buf, remote_tuple); + } + + void SendPullResponse(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple, size_t lkey) { + auto addr = (void*) CHECK_NOTNULL(msg.data[1].data()); + void* shm_addr = CHECK_NOTNULL(GetSharedMemory(kShmPrefix, msg.meta.key)); + + if (enable_async_copy_) { + // async copy with a simple load-balancing strategy + AsyncCopy m = {msg_buf, remote_tuple, shm_addr, addr, msg.meta.val_len, false}; + auto cnt = cpy_counter_.fetch_add(1); + async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); + } else { + // synchronous copy + memcpy(shm_addr, addr, msg.meta.val_len); + Send(msg, msg_buf, remote_tuple); + } + } + + int RecvPushRequest(Message *msg, BufferContext *buffer_ctx, int meta_len) { + // get data message from local shared memory + auto key = msg->meta.key; + auto len = msg->meta.val_len; 
+ + SArray keys = CreateFunctionalSarray(&msg->meta.key, sizeof(Key)); + + SArray vals; + void* addr = GetSharedMemory(kShmPrefix, key); + vals.reset(reinterpret_cast(addr), len, [](void *){}); + + SArray lens = CreateFunctionalSarray(&msg->meta.val_len, sizeof(int)); + + msg->data.push_back(keys); + msg->data.push_back(vals); + msg->data.push_back(lens); + + return keys.size() + vals.size() + lens.size(); + } + + private: + struct AsyncCopy { + MessageBuffer* msg_buf; + RemoteTuple remote_tuple; + void* dst; + void* src; + int len; + bool shutdown; + }; + + void AsyncCopyThread(int i) { + auto& q = async_copy_queue_[i]; + while (true) { + AsyncCopy m; + q->WaitAndPop(&m); + if (m.shutdown) break; + if (m.len == 0) continue; + + // TODO: use parallel copy + CHECK(m.dst); + CHECK(m.src); + memcpy(m.dst, m.src, m.len); + + auto raddr = std::get<0>(m.remote_tuple); + auto rkey = std::get<1>(m.remote_tuple); + auto idx = std::get<2>(m.remote_tuple); + + RDMAWriteWithImm(m.msg_buf, raddr, rkey, idx); + } + } + + void* GetSharedMemory(const std::string& prefix, uint64_t key) { + std::lock_guard lock(shm_mu_); + auto worker_key = DecodeWorkerKey(key); + auto seq_num = worker_key % (1 << 16); + auto base_key = worker_key - seq_num; + uint64_t offset = byteps_partition_bytes_ * seq_num; + if (key_shm_addr_.find(base_key) != key_shm_addr_.end()) { + return key_shm_addr_[base_key] + offset; + } + std::string shm_name(prefix); + shm_name += std::to_string(base_key); + int shm_fd = shm_open(shm_name.c_str(), O_RDWR, 0666); + CHECK_GE(shm_fd, 0) << "shm_open failed for " << shm_name + << ", " << strerror(errno); + + struct stat sb; + CHECK_EQ(0, fstat(shm_fd, &sb)) << strerror(errno); + auto total_shm_size = sb.st_size; + + void* base_ptr = mmap(0, total_shm_size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); + CHECK_NE(base_ptr, (void*) -1) << strerror(errno); + key_shm_addr_[base_key] = base_ptr; + + PS_VLOG(1) << "open Shared Memory: " << shm_name << ", offset=" + << offset 
<< ", (in bytes) size=" << total_shm_size; + return key_shm_addr_[base_key] + offset; + } + + int ipc_copy_nthreads_; + std::vector ipc_copy_thread_list_; + std::vector*> async_copy_queue_; + std::atomic cpy_counter_{0}; + + int byteps_partition_bytes_ = 4096000; + + std::mutex shm_mu_; + std::unordered_map key_shm_addr_; + + bool enable_async_copy_; + +}; // class IPCTransport + + +}; // namespace ps + +#endif // DMLC_USE_RDMA +#endif // PS_RDMA_VAN_H_ + diff --git a/src/rdma_utils.h b/src/rdma_utils.h new file mode 100644 index 00000000..5d40a8ec --- /dev/null +++ b/src/rdma_utils.h @@ -0,0 +1,318 @@ +// Copyright 2019 Bytedance Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#ifndef PS_RDMA_UTILS_H_ +#define PS_RDMA_UTILS_H_ + +#ifdef DMLC_USE_RDMA + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ps/internal/threadsafe_queue.h" +#include "ps/internal/van.h" + + +namespace ps { + + +#define DIVUP(x, y) (((x)+(y)-1)/(y)) +#define ROUNDUP(x, y) (DIVUP((x), (y))*(y)) + +static const int kStartDepth = 1024; +static const int kWriteDepth = kStartDepth; + +static const int kRxDepth = kStartDepth + kWriteDepth; +static const int kReplyDepth = kRxDepth; + +static const int kSGEntry = 1; +static const int kTimeoutms = 1000; +static const int kRdmaListenBacklog = 128; +static const int kMaxConcurrentWorkRequest = + kRxDepth + kStartDepth + kReplyDepth + kWriteDepth; +static const int kMaxHostnameLength = 16; +static const int kMaxDataFields = 4; + +static const int kMaxResolveRetry = 50000; +static const int kBasePort = 9010; + +// should have the same prefix with BytePS shared memory +static const std::string kShmPrefix("BytePS_ShM_"); + +template +static inline T align_floor(T v, T align) { + return v - (v % align); +} + +template +static inline T align_ceil(T v, T align) { + return align_floor(v + align - 1, align); +} + +static inline void ib_malloc(void** ptr, size_t size) { + size_t page_size = sysconf(_SC_PAGESIZE); + void* p; + int size_aligned = ROUNDUP(size, page_size); + int ret = posix_memalign(&p, page_size, size_aligned); + CHECK_EQ(ret, 0) << "posix_memalign error: " << strerror(ret); + CHECK(p); + memset(p, 0, size); + *ptr = p; +} + +class MemoryAllocator { + public: + explicit MemoryAllocator(struct ibv_pd *pd) { + std::lock_guard lk(mu_); + pd_ = pd; + } + + ~MemoryAllocator() { + std::lock_guard lk(mu_); + for(auto &it : mr_) { + 
CHECK_EQ(ibv_dereg_mr(it.second), 0); + free(it.first); + } + } + + char *Alloc(size_t size) { + if (size == 0) { + return nullptr; + } + + // align to page size (usually 4KB) + size = align_ceil(size, pagesize_); + + char *p; + ib_malloc((void**) &p, size); + CHECK(p); + + struct ibv_mr *mr; + CHECK(mr = ibv_reg_mr(pd_, p, size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)); + + std::lock_guard lk(mu_); + mr_[p] = mr; + used_list.emplace(p, size); + + return p; + } + + uint32_t LocalKey(char *addr) { + return Addr2MR(addr)->lkey; + } + + uint32_t RemoteKey(char *addr) { + return Addr2MR(addr)->rkey; + } + + struct ibv_pd* GetPD() { + return pd_; + } + + private: + // convert the memory address to its associated RDMA memory region + inline struct ibv_mr* Addr2MR(char *addr) { + std::lock_guard lk(mu_); + auto it = mr_.find(addr); + CHECK_NE(it, mr_.end()) + << "cannot find the associated memory region"; + + return it->second; + } + + std::mutex mu_; + struct ibv_pd *pd_; + size_t pagesize_ = sysconf(_SC_PAGESIZE); + std::unordered_map used_list; + std::unordered_map mr_; +}; + +enum MessageTypes : uint32_t { + kRendezvousStart, + kRendezvousReply, +}; + +struct RendezvousStart { + uint64_t meta_len; + uint64_t data_num; + uint64_t data_len[kMaxDataFields]; + uint64_t origin_addr; +}; + +struct RendezvousReply { + uint64_t addr; + uint64_t origin_addr; + uint32_t rkey; + uint32_t idx; +}; + +enum WRContextType { + kRendezvousStartContext, + kRendezvousReplyContext, + kWriteContext, + kReceiveContext +}; + +struct WRContext { + WRContextType type; + struct ibv_mr *buffer; + void *private_data; +}; + +struct BufferContext { + char *buffer; + size_t meta_len; + size_t data_num; + size_t data_len[kMaxDataFields]; +}; + +typedef std::unique_ptr> + MRPtr; + +struct MessageBuffer { + size_t inline_len; + char *inline_buf; + WRContext *reserved_context; + std::vector> data; + std::vector> mrs; +}; + +struct RequestContext { + uint32_t node; + uint16_t port; + char 
hostname[kMaxHostnameLength]; +}; + +// +typedef std::tuple RemoteTuple; + +// recver, +typedef std::unordered_map RemoteAndLocalAddress; + +static_assert(std::is_pod::value, + "RendezvousStart must be a POD type."); +static_assert(std::is_pod::value, + "RendezvousReply must be a POD type."); +static_assert(std::is_pod::value, + "RequestContext must be a POD type."); + +static const size_t kMempoolChunkSize = + std::max({sizeof(RendezvousStart), sizeof(RendezvousReply)}); + +template +class AddressPool { + public: + AddressPool() { + auto addrpool_size = Environment::Get()->find("BYTEPS_ADDRESS_POOL_SIZE"); + kMaxEntries = addrpool_size ? atoi(addrpool_size) : kMaxEntries; + std::lock_guard lk(mu_); + table_ = new T*[kMaxEntries]; + // init the queue + for (int i = 0; i < kMaxEntries; i++) { + indices_.push(i); + table_[i] = nullptr; + } + } + + T *GetAddressAndRelease(uint32_t index) { + std::lock_guard lk(mu_); + T *ptr = table_[index]; + CHECK(ptr); + indices_.push(index); + table_[index] = nullptr; + return ptr; + } + + // TODO: make the address pool size dynamic + T *GetAddress(uint32_t index) { + std::lock_guard lk(mu_); + return CHECK_NOTNULL(table_[index]); + } + + uint32_t StoreAddress(T *ptr) { + std::lock_guard lk(mu_); + CHECK(ptr); + CHECK(!indices_.empty()) + << "Address pool size is too small, " + << "current size is " << kMaxEntries + << ", consider increasing BYTEPS_ADDRESS_POOL_SIZE"; + uint32_t idx = indices_.front(); + indices_.pop(); + CHECK_EQ(table_[idx], nullptr) << idx; + table_[idx] = ptr; + return idx; + } + + private: + int kMaxEntries = 10240; + + std::mutex mu_; + std::queue indices_; + T **table_; +}; + +bool IsValidPushpull(const Message &msg) { + if (!msg.meta.control.empty()) return false; + if (msg.meta.simple_app) return false; + return true; +} + +uint64_t DecodeKey(SArray keys) { // just a translation, the decoded key might not be readable when we have multiple servers + ps::Key key = 0; + uint64_t coef = 1; + for (unsigned int 
i = 0; i < keys.size(); ++i) { + key += coef * (uint8_t) keys.data()[i]; + coef *= 256; // 256=2^8 (uint8_t) + } + return key; +} + +uint64_t DecodeWorkerKey(uint64_t key) { + auto kr = ps::Postoffice::Get()->GetServerKeyRanges()[ps::Postoffice::Get()->my_rank()]; + return key - kr.begin(); +} + +int AlignTo(int input, int alignment) { return input / alignment * alignment; } +int DivUp(int x, int y) { return (x + y - 1) / y; } +int RoundUp(int x, int y) { return DivUp(x, y) * y; } + +}; // namespace ps + +#endif // DMLC_USE_RDMA +#endif // PS_RDMA_VAN_H_ + diff --git a/src/rdma_van.h b/src/rdma_van.h index 9c63c195..945b205f 100755 --- a/src/rdma_van.h +++ b/src/rdma_van.h @@ -1,494 +1,28 @@ -/** - * Copyright (c) 2018-2019 Bytedance Inc. - * Author: lanchang@bytedance.com (Chang Lan) - * jiangyimin@bytedance.com (Yimin Jiang) - * chenjingrong@bytedance.com (Jingrong Chen) -*/ +// Copyright 2019 Bytedance Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + #ifndef PS_RDMA_VAN_H_ #define PS_RDMA_VAN_H_ #ifdef DMLC_USE_RDMA -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ps/internal/threadsafe_queue.h" -#include "ps/internal/van.h" +#include "rdma_utils.h" +#include "rdma_transport.h" namespace ps { -#define DIVUP(x, y) (((x)+(y)-1)/(y)) -#define ROUNDUP(x, y) (DIVUP((x), (y))*(y)) - -static const int kStartDepth = 128; -static const int kWriteDepth = kStartDepth; - -static const int kRxDepth = kStartDepth + kWriteDepth; -static const int kReplyDepth = kRxDepth; - -static const int kSGEntry = 4; -static const int kTimeoutms = 1000; -static const int kRdmaListenBacklog = 128; -static const int kMaxConcurrentWorkRequest = - kRxDepth + kStartDepth + kReplyDepth + kWriteDepth; -static const int kMaxHostnameLength = 16; -static const int kMaxDataFields = 4; -static const size_t kAlignment = 8; - -static const int kMaxResolveRetry = 50000; -static const int kBasePort = 9010; - -// should have the same prefix with BytePS shared memory -static const std::string kShmPrefix("BytePS_ShM_"); - -template -static inline T align_floor(T v, T align) { - return v - (v % align); -} - -template -static inline T align_ceil(T v, T align) { - return align_floor(v + align - 1, align); -} - -static inline void ib_malloc(void** ptr, size_t size) { - size_t page_size = sysconf(_SC_PAGESIZE); - void* p; - int size_aligned = ROUNDUP(size, page_size); - int ret = posix_memalign(&p, page_size, size_aligned); - CHECK_EQ(ret, 0) << "posix_memalign error: " << strerror(ret); - CHECK(p); - memset(p, 0, size); - *ptr = p; -} - -class SimpleMempool { - public: - explicit SimpleMempool(struct ibv_pd *pd, size_t size = 0x10000000) { - std::lock_guard lk(mu_); - pd_ = pd; - struct ibv_mr 
*mr; - - // set init mempool size - auto byteps_rdma_mempool_size = Environment::Get()->find("BYTEPS_RDMA_MEMPOOL_SIZE"); - size = byteps_rdma_mempool_size ? atoi(byteps_rdma_mempool_size) : size; - auto byteps_rdma_mempool_num = Environment::Get()->find("BYTEPS_RDMA_MEMPOOL_NUM"); - size_t mempool_num = byteps_rdma_mempool_num ? atoi(byteps_rdma_mempool_num) : 1; - PS_VLOG(1) << "RDMA initial mempool size set to " << size - << ", mempool num set to " << mempool_num; - - for (size_t i = 0; i < mempool_num; ++i) { - char *p; - ib_malloc((void**) &p, size); - total_allocated_size += size; - CHECK(p); - CHECK(mr = ibv_reg_mr(pd, p, size, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)); - mr_list.emplace(p+size, mr); // this mr is associated with memory address range [p, p+size] - free_list.emplace(size, p); - } - } - - ~SimpleMempool() { - std::lock_guard lk(mu_); - for(auto it = mr_list.begin(); it != mr_list.end(); it++){ - CHECK_EQ(ibv_dereg_mr(it->second), 0); - free(it->second->addr); - } - } - - char *Alloc(size_t size) { - if (size == 0) { - return nullptr; - } - - std::lock_guard lk(mu_); - - size_t proper_size = align_ceil(size, kAlignment); - - auto it = free_list.lower_bound(proper_size); - - if (it == free_list.end()) { // if there is no space left, need to allocate and register new memory - size_t new_mem_size = total_allocated_size; - while (proper_size > new_mem_size) { - new_mem_size *= 2; - } - char *p; - ib_malloc((void**) &p, new_mem_size); - CHECK(p); - struct ibv_mr *mr; - CHECK(mr = ibv_reg_mr(pd_, p, new_mem_size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)); - mr_list.emplace(p+new_mem_size, mr); - free_list.emplace(new_mem_size, p); - it = free_list.lower_bound(proper_size); - PS_VLOG(1) << "Not enough memory in the pool, requested size " << proper_size << ", new allocated size " << new_mem_size; - total_allocated_size += new_mem_size; - } - - CHECK_NE(free_list.end(), it) << "Not enough memory"; - CHECK_GE(it->first, proper_size); 
- - char *addr = it->second; - size_t space_left = it->first - proper_size; - - free_list.erase(it); - CHECK_EQ(used_list.find(addr), used_list.end()) - << "Address is already allocated"; - - used_list.emplace(addr, proper_size); - - if (space_left) { - free_list.emplace(space_left, addr + proper_size); - } - - return addr; - } - - void Free(char *addr) { - if (!addr) { - return; - } - - std::lock_guard lk(mu_); - - auto it = used_list.find(addr); - CHECK_NE(used_list.end(), it) - << "Cannot find info about address: " << (uintptr_t)addr; - - size_t size = it->second; - used_list.erase(it); - free_list.emplace(size, addr); - } - - uint32_t LocalKey(char *addr) { - struct ibv_mr *mr = Addr2MR(addr); - return mr->lkey; - } - uint32_t RemoteKey(char *addr) { - struct ibv_mr *mr = Addr2MR(addr); - return mr->rkey; - } - - private: - std::mutex mu_; - std::multimap free_list; - std::unordered_map used_list; - struct ibv_pd *pd_; - size_t total_allocated_size = 0; - - // first: `end` of this mr address (e.g., for mr with [addr, addr+size], point to `addr+size`) - std::map mr_list; - - // convert the memory address to its associated RDMA memory region - inline struct ibv_mr* Addr2MR(char *addr) { - std::lock_guard lk(mu_); - auto it = mr_list.lower_bound(addr); - CHECK_NE(it, mr_list.end()) << "cannot find the associated memory region"; - return it->second; - } - -}; - -class Block { - public: - explicit Block(SimpleMempool *pool, char *addr, int count) - : pool(pool), addr(addr), counter(count) {} - - ~Block() { - CHECK_EQ(counter, 0); - pool->Free(addr); - } - - void Release() { - int v = counter.fetch_sub(1); - if (v == 1) { - delete this; - } - } - - private: - SimpleMempool *pool; - char *addr; - std::atomic counter; -}; - -enum MessageTypes : uint32_t { - kRendezvousStart, - kRendezvousReply, -}; - -struct RendezvousStart { - uint64_t meta_len; - uint64_t data_num; - uint64_t data_len[kMaxDataFields]; - uint64_t origin_addr; -}; - -struct RendezvousReply { - uint64_t 
addr; - uint64_t origin_addr; - uint32_t rkey; - uint32_t idx; -}; - -enum WRContextType { - kRendezvousStartContext, - kRendezvousReplyContext, - kWriteContext, - kReceiveContext -}; - -struct WRContext { - WRContextType type; - struct ibv_mr *buffer; - void *private_data; -}; - -struct BufferContext { - char *buffer; - size_t meta_len; - size_t data_num; - size_t data_len[kMaxDataFields]; -}; - -typedef std::unique_ptr> - MRPtr; - -struct MessageBuffer { - size_t inline_len; - char *inline_buf; - WRContext *reserved_context; - std::vector> data; - std::vector> mrs; -}; - -struct RequestContext { - uint32_t node; - uint16_t port; - char hostname[kMaxHostnameLength]; -}; - -static_assert(std::is_pod::value, - "RendezvousStart must be a POD type."); -static_assert(std::is_pod::value, - "RendezvousReply must be a POD type."); -static_assert(std::is_pod::value, - "RequestContext must be a POD type."); - -static const size_t kMempoolChunkSize = - std::max({sizeof(RendezvousStart), sizeof(RendezvousReply)}); - -template -class AddressPool { - public: - AddressPool() { - std::lock_guard lk(mu_); - // init the queue - for (int i = 0; i < kMaxEntries; i++) { - indices_.push(i); - table_[i] = nullptr; - } - } - - T *GetAddressAndRelease(uint32_t index) { - std::lock_guard lk(mu_); - T *ptr = table_[index]; - CHECK(ptr); - indices_.push(index); - table_[index] = nullptr; - return ptr; - } - - uint32_t StoreAddress(T *ptr) { - std::lock_guard lk(mu_); - CHECK(ptr); - CHECK(!indices_.empty()) - << "Address pool size is too small, " - << "consider increasing kMaxEntries"; - uint32_t idx = indices_.front(); - indices_.pop(); - CHECK_EQ(table_[idx], nullptr) << idx; - table_[idx] = ptr; - return idx; - } - - private: - static const int kMaxEntries = 5120; - - std::mutex mu_; - std::queue indices_; - T *table_[kMaxEntries]; -}; - -struct Endpoint { - enum ConnectionStatus { IDLE, CONNECTING, CONNECTED, REJECTED }; - - ConnectionStatus status; - int node_id; - 
std::condition_variable cv; - std::mutex connect_mu; - struct rdma_cm_id *cm_id; - - WRContext rx_ctx[kRxDepth]; - - WRContext start_ctx[kStartDepth]; - WRContext reply_ctx[kReplyDepth]; - WRContext write_ctx[kWriteDepth]; - - ThreadsafeQueue free_start_ctx; - ThreadsafeQueue free_reply_ctx; - ThreadsafeQueue free_write_ctx; - - Endpoint() : status(IDLE), node_id(Node::kEmpty), cm_id(nullptr), rx_ctx() {} - - ~Endpoint() { - for (int i = 0; i < kRxDepth; ++i) { - if (!(rx_ctx[i].buffer)) { - continue; - } - free(rx_ctx[i].buffer->addr); - CHECK_EQ(ibv_dereg_mr(rx_ctx[i].buffer), 0); - } - - for (int i = 0; i < kStartDepth; ++i) { - if (start_ctx[i].buffer) { - free(start_ctx[i].buffer->addr); - CHECK_EQ(ibv_dereg_mr(start_ctx[i].buffer), 0); - } - } - - for (int i = 0; i < kReplyDepth; ++i) { - if (reply_ctx[i].buffer) { - free(reply_ctx[i].buffer->addr); - CHECK_EQ(ibv_dereg_mr(reply_ctx[i].buffer), 0); - } - } - - for (int i = 0; i < kWriteDepth; ++i) { - if (write_ctx[i].buffer) { - free(write_ctx[i].buffer->addr); - CHECK_EQ(ibv_dereg_mr(write_ctx[i].buffer), 0); - } - } - - rdma_destroy_qp(cm_id); - CHECK_EQ(rdma_destroy_id(cm_id), 0) << strerror(errno); - } - - void Disconnect() { - std::unique_lock lk(connect_mu); - CHECK_EQ(rdma_disconnect(cm_id), 0) << strerror(errno); - cv.wait(lk, [this] { return status == IDLE; }); - } - - void SetNodeID(int id) { node_id = id; } - - void InitSendContextHelper(struct ibv_pd *pd, WRContext *ctx, - ThreadsafeQueue *queue, size_t num, - WRContextType type) { - for (size_t i = 0; i < num; ++i) { - void *buf; - ib_malloc((void**) &buf, kMempoolChunkSize); - CHECK(buf); - struct ibv_mr *mr = ibv_reg_mr(pd, buf, kMempoolChunkSize, 0); - CHECK(mr); - - ctx[i].type = type; - ctx[i].buffer = mr; - ctx[i].private_data = this; - queue->Push(&ctx[i]); - } - } - - void Init(struct ibv_cq *cq, struct ibv_pd *pd) { - struct ibv_qp_init_attr attr; - memset(&attr, 0, sizeof(ibv_qp_init_attr)); - attr.send_cq = cq; - attr.recv_cq = cq; - 
attr.cap.max_send_wr = kStartDepth + kReplyDepth + kWriteDepth; - attr.cap.max_recv_wr = kRxDepth; - attr.cap.max_send_sge = kSGEntry; - attr.cap.max_recv_sge = kSGEntry; - attr.qp_type = IBV_QPT_RC; - attr.sq_sig_all = 0; - - CHECK_EQ(rdma_create_qp(cm_id, pd, &attr), 0) - << "Create RDMA queue pair failed"; - - InitSendContextHelper(pd, start_ctx, &free_start_ctx, kStartDepth, - kRendezvousStartContext); - InitSendContextHelper(pd, reply_ctx, &free_reply_ctx, kReplyDepth, - kRendezvousReplyContext); - InitSendContextHelper(pd, write_ctx, &free_write_ctx, kWriteDepth, - kWriteContext); - - for (size_t i = 0; i < kRxDepth; ++i) { - void *buf; - ib_malloc((void**) &buf, kMempoolChunkSize); - CHECK(buf); - struct ibv_mr *mr = - ibv_reg_mr(pd, buf, kMempoolChunkSize, IBV_ACCESS_LOCAL_WRITE); - CHECK(mr); - - rx_ctx[i].type = kReceiveContext; - rx_ctx[i].buffer = mr; - rx_ctx[i].private_data = this; - - PostRecv(&rx_ctx[i]); - } - } - - void PostRecv(WRContext *ctx) { - struct ibv_recv_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - - struct ibv_sge sge; - sge.addr = reinterpret_cast(ctx->buffer->addr); - sge.length = kMempoolChunkSize; - sge.lkey = ctx->buffer->lkey; - - wr.wr_id = reinterpret_cast(ctx); - wr.next = nullptr; - wr.sg_list = &sge; - wr.num_sge = 1; - - CHECK_EQ(ibv_post_recv(cm_id->qp, &wr, &bad_wr), 0) - << "ibv_post_recv failed."; - } -}; - -struct AsyncCopy { - Endpoint* endpoint; - MessageBuffer* msg_buf; - void* dst; - void* src; - int len; - uint64_t meta_len; - bool shutdown; -}; - class RDMAVan : public Van { public: RDMAVan() { @@ -501,46 +35,10 @@ class RDMAVan : public Van { start_mu_.lock(); should_stop_ = false; - auto val = Environment::Get()->find("DMLC_ROLE"); - std::string role(val); - is_server_ = role=="server"; - if (is_server_) LOG(INFO) << "This is server"; - else LOG(INFO) << "This is " << ((role=="worker") ? "worker" : "scheduler"); - - val = Environment::Get()->find("ENABLE_RDMA_LOG"); - enable_rdma_log_ = val? 
atoi(val) : false; - if (enable_rdma_log_) LOG(INFO) << "Enable RDMA logging"; - else LOG(INFO) << "You can enable RDMA logging with ENABLE_RDMA_LOG=1"; - - val = Environment::Get()->find("BYTEPS_ENABLE_IPC"); + auto val = Environment::Get()->find("BYTEPS_ENABLE_IPC"); disable_ipc_ = val ? !atoi(val) : true; if (disable_ipc_) LOG(INFO) << "Shared memory IPC has been disabled"; - val = Environment::Get()->find("BYTEPS_PARTITION_BYTES"); - byteps_partition_bytes_ = val ? atoi(val) : 4096000; - - val = Environment::Get()->find("BYTEPS_LOCAL_SIZE"); - auto byteps_local_size = val ? atoi(val) : 1; - byteps_partition_bytes_ = AlignTo(byteps_partition_bytes_, (8 * byteps_local_size)); - if (!disable_ipc_) { - CHECK(val) << "BYTEPS_LOCAL_SIZE not set"; - LOG(INFO) << "partition bytes set to " << byteps_partition_bytes_ << ", should be identical with byteps core"; - } - - val = Environment::Get()->find("BYTEPS_IPC_COPY_NUM_THREADS"); - ipc_copy_nthreads_ = val ? atoi(val) : 4; - if (!disable_ipc_) { - LOG(INFO) << "IPC async copy nthreads set to " << ipc_copy_nthreads_; - for (int i = 0; i < ipc_copy_nthreads_; ++i) { - auto q = new ThreadsafeQueue; - async_copy_queue_.push_back(q); - } - for (int i = 0; i < ipc_copy_nthreads_; ++i) { - auto t = new std::thread(&RDMAVan::AsyncCopyThread, this, i); - ipc_copy_thread_list_.push_back(t); - } - } - if (event_channel_ == nullptr) { event_channel_ = rdma_create_event_channel(); CHECK(event_channel_) << "Create RDMA event channel failed"; @@ -549,6 +47,11 @@ class RDMAVan : public Van { new std::thread(&RDMAVan::PollEvents, this)); } + // enable logging + val = Environment::Get()->find("BYTEPS_PRINT_RDMA_LOG"); + enable_log_ = val ? 
atoi(val) : false; + if (enable_log_) LOG(INFO) << "Enable RDMA logging."; + start_mu_.unlock(); Van::Start(customer_id); } @@ -568,14 +71,8 @@ class RDMAVan : public Van { cm_event_polling_thread_->join(); cm_event_polling_thread_.reset(); - PS_VLOG(1) << "Clearing mempool."; - mempool_.reset(); - - auto map_iter = memory_mr_map.begin(); - while (map_iter != memory_mr_map.end()) { - ibv_dereg_mr(map_iter->second); - map_iter++; - } + PS_VLOG(1) << "Clearing memory allocator."; + mem_allocator_.reset(); PS_VLOG(1) << "Clearing endpoints."; incoming_.clear(); @@ -585,13 +82,8 @@ class RDMAVan : public Van { CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ"; CHECK(!ibv_destroy_comp_channel(comp_event_channel_)) << "Failed to destroy channel"; - - for (size_t i = 0; i < ipc_copy_thread_list_.size(); ++i) { - AsyncCopy m; - m.shutdown = true; - async_copy_queue_[i]->Push(m); - ipc_copy_thread_list_[i]->join(); - } + + for (auto& it : mem_mr_) ibv_dereg_mr(it.second); // TODO: ibv_dealloc_pd sometimes complains resource busy, need to fix this // CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD: " << @@ -636,7 +128,7 @@ class RDMAVan : public Van { } void Connect(const Node &node) override { - PS_VLOG(1) << "Connecting to Node " << node.id; + PS_VLOG(1) << "Connecting to Node " << node.id << ", My_Node=" << my_node_.id; CHECK_NE(node.id, node.kEmpty); CHECK_NE(node.port, node.kEmpty); CHECK(node.hostname.size()); @@ -646,14 +138,6 @@ class RDMAVan : public Van { return; } - if (disable_ipc_) { - is_local_[node.id] = false; - } else { - std::lock_guard lock(local_mu_); - is_local_[node.id] = (node.hostname == my_node_.hostname) ? 
true : false; - LOG(INFO) << "IPC connected to " << node.id; - } - if (node.id != Node::kEmpty) { auto it = endpoints_.find(node.id); @@ -727,309 +211,97 @@ class RDMAVan : public Van { std::this_thread::sleep_for(std::chrono::milliseconds(500)); } - freeaddrinfo(remote_addr); - } - } - - bool IsValidPushpull(const Message &msg) { - if (!msg.meta.control.empty()) return false; - if (msg.meta.simple_app) return false; - return true; - } - - uint64_t DecodeKey(SArray keys) { // just a translation, the decoded key might not be readable when we have multiple servers - ps::Key key = 0; - uint64_t coef = 1; - for (unsigned int i = 0; i < keys.size(); ++i) { - key += coef * (uint8_t) keys.data()[i]; - coef *= 256; // 256=2^8 (uint8_t) - } - return key; - } - - uint64_t DecodeWorkerKey(uint64_t key) { - auto kr = ps::Postoffice::Get()->GetServerKeyRanges()[ps::Postoffice::Get()->my_rank()]; - return key - kr.begin(); - } - - void* GetSharedMemory(const std::string& prefix, uint64_t key) { - std::lock_guard lock(shm_mu_); - auto worker_key = DecodeWorkerKey(key); - auto seq_num = worker_key % (1 << 16); - auto base_key = worker_key - seq_num; - uint64_t offset = byteps_partition_bytes_ * seq_num; - if (key_shm_addr_.find(base_key) != key_shm_addr_.end()) { - return key_shm_addr_[base_key] + offset; + local_mu_.lock(); + if (disable_ipc_) { + is_local_[node.id] = false; + } else { + is_local_[node.id] = (node.hostname == my_node_.hostname) ? 
true : false; } - std::string shm_name(prefix); - shm_name += std::to_string(base_key); - int shm_fd = shm_open(shm_name.c_str(), O_RDWR, 0666); - CHECK_GE(shm_fd, 0) << "shm_open failed for " << shm_name; - - struct stat sb; - CHECK_EQ(0, fstat(shm_fd, &sb)) << strerror(errno); - auto total_shm_size = sb.st_size; - - void* base_ptr = mmap(0, total_shm_size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); - CHECK_NE(base_ptr, (void*) -1) << strerror(errno); - key_shm_addr_[base_key] = base_ptr; - - LOG(INFO) << "open Shared Memory: " << shm_name - << ", offset=" << offset - << ", (in bytes) size=" << total_shm_size; - return key_shm_addr_[base_key] + offset; - } - - void SendRendezvousBegin(Endpoint* endpoint, - uint64_t origin_addr, WRContext *context, MessageTypes msg_type) { - struct ibv_sge sge; - sge.addr = origin_addr; - sge.lkey = context->buffer->lkey; - sge.length = sizeof(RendezvousStart); - - struct ibv_send_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); + LOG(INFO) << "Connect to Node " << node.id + << " with Transport=" << (is_local_[node.id]?"IPC" : "RDMA"); + local_mu_.unlock(); - wr.wr_id = reinterpret_cast(context); - wr.opcode = IBV_WR_SEND_WITH_IMM; - wr.next = nullptr; + std::shared_ptr t = is_local_[node.id] ? 
+ std::make_shared(endpoint, mem_allocator_.get()) : + std::make_shared(endpoint, mem_allocator_.get()); + endpoint->SetTransport(t); - wr.imm_data = msg_type; - - wr.send_flags = IBV_SEND_SIGNALED; - wr.sg_list = &sge; - wr.num_sge = 1; - CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) - << strerror(errno); - } - - void AsyncCopyThread(int i) { - auto& q = async_copy_queue_[i]; - while (true) { - AsyncCopy m; - q->WaitAndPop(&m); - if (m.shutdown) break; - if (m.len == 0) continue; - - // TODO: use parallel copy - CHECK(m.dst); - CHECK(m.src); - memcpy(m.dst, m.src, m.len); - - WRContext *context = nullptr, *reserved = nullptr; - m.endpoint->free_write_ctx.WaitAndPop(&reserved); - m.endpoint->free_start_ctx.WaitAndPop(&context); - - m.msg_buf->reserved_context = reserved; - RendezvousStart *req = - reinterpret_cast(context->buffer->addr); - req->meta_len = m.meta_len; - req->origin_addr = reinterpret_cast(m.msg_buf); - - auto addr = reinterpret_cast(req); - req->data_num = 0; - SendRendezvousBegin(m.endpoint, addr, context, kRendezvousStart); + freeaddrinfo(remote_addr); } } int SendMsg(Message &msg) override { int remote_id = msg.meta.recver; CHECK_NE(remote_id, Meta::kEmpty); - - for (auto& sa : msg.data) { - if (sa.size()) { - std::lock_guard lock(map_mu_); - auto search_map_iterator = memory_mr_map.find(sa.data()); - if (search_map_iterator == memory_mr_map.end()) { - struct ibv_mr *temp_mr; - CHECK(sa.data()) << "address empty"; - CHECK (temp_mr = ibv_reg_mr(pd_, sa.data(), sa.size(), - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) - << "Failed to register the memory region: " - << strerror(errno) - << ", sa.size()=" << sa.size(); - memory_mr_map[sa.data()] = temp_mr; - } - } - } - - // init for inplace push_pull - if (IsValidPushpull(msg)) { - if (!is_server_) { // worker - std::lock_guard lock(map_mu_); - uint64_t key = DecodeKey(msg.data[0]); - msg.meta.key = key; - - if (msg.meta.push && msg.meta.request) { // push request - 
CHECK_EQ(msg.data.size(), 3) << msg.data.size(); - CHECK_NE(memory_mr_map.find(msg.data[1].data()), memory_mr_map.end()); - - auto& vals = msg.data[1]; - msg.meta.addr = reinterpret_cast(vals.data()); // vals address - msg.meta.val_len = vals.size(); - msg.meta.option = memory_mr_map[vals.data()]->rkey; - - if (enable_rdma_log_) { - LOG(INFO) << "send push key=" << key - << ", val_len=" << msg.meta.val_len - << ", recver=" << msg.meta.recver - << ", val_addr=" << msg.meta.addr - << ", rkey=" << msg.meta.option; - } - } - } - if (!msg.meta.push && !msg.meta.request) { // server, pull response - CHECK(is_server_); - CHECK_EQ(msg.data.size(), 3) << msg.data.size(); - - std::lock_guard lock(map_mu_); - uint64_t key = msg.meta.key; - auto recver = msg.meta.recver; - - CHECK_NE(key_meta_map_.find(key), key_meta_map_.end()) - << "key=" << key << " not inited in key_meta_map"; - CHECK_NE(key_meta_map_[key].find(recver), key_meta_map_[key].end()) - << "key=" << key << ", recver=" << recver << " not inited in key_meta_map[key]"; - - msg.meta.val_len = std::get<0>(key_meta_map_[key][recver]); - msg.meta.addr = std::get<1>(key_meta_map_[key][recver]); - msg.meta.option = std::get<2>(key_meta_map_[key][recver]); - - if (enable_rdma_log_) { - LOG(INFO) << "send pull response key=" << key - << ", val_len=" << msg.meta.val_len - << ", val_addr=" << msg.meta.addr - << ", rkey=" << msg.meta.option; - } - } - } - - PBMeta meta; - PackMetaPB(msg.meta, &meta); CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); Endpoint *endpoint = endpoints_[remote_id].get(); - MessageBuffer *msg_buf = new MessageBuffer(); - size_t meta_len = meta.ByteSize(); + int meta_len = GetPackMetaLen(msg.meta); size_t data_len = msg.meta.data_size; size_t total_len = meta_len + data_len; CHECK(meta_len); - // prepare memory - if (msg.meta.simple_app || !msg.meta.control.empty()){ // simple_app or control message - msg_buf->inline_len = total_len; - msg_buf->inline_buf = mempool_->Alloc(total_len); - 
meta.SerializeToArray(msg_buf->inline_buf, meta_len); - char *cur = msg_buf->inline_buf + meta_len; - for (auto &sa : msg.data) { - size_t seg_len = sa.size(); - memcpy(cur, sa.data(), seg_len); - cur += seg_len; - } - } else { // data message - msg_buf->inline_len = meta_len; - msg_buf->inline_buf = mempool_->Alloc(meta_len); - msg_buf->data = msg.data; - meta.SerializeToArray(msg_buf->inline_buf, meta_len); - if (!is_server_ && !is_local_[remote_id]) { // worker, send to non-local servers - for (auto &sa : msg_buf->data) { - if (sa.size()) { - auto search_map_iterator = memory_mr_map.find(sa.data()); - CHECK_NE(search_map_iterator, memory_mr_map.end()) << "not registered memory region"; - MRPtr ptr(search_map_iterator->second, [](struct ibv_mr *mr) {}); - CHECK(ptr.get()) << strerror(errno); - msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size())); - } - } - } + RegisterMemory(msg); + + // pack meta info + if (IsValidPushpull(msg)) { + AddMeta(msg); + PackWorkerTensorAddress(msg); } - // server send pull response (vals) with RDMA-write / IPC - if (is_server_ && IsValidPushpull(msg) && - !msg.meta.push && !msg.meta.request) { - std::lock_guard lock(map_mu_); + auto trans = CHECK_NOTNULL(endpoint->GetTransport()); + + // start rendezvous if no remote info + if (!IsValidPushpull(msg)) { + MessageBuffer *msg_buf = PrepareNewMsgBuf(msg); + StoreMsgBuf(msg_buf, msg); + trans->SendRendezvousBegin(msg, msg_buf); + return total_len; + + } else { + auto is_push = msg.meta.push; auto key = msg.meta.key; - auto recver = msg.meta.recver; - auto len = std::get<0>(key_meta_map_[key][recver]); - - CHECK_EQ(msg_buf->data.size(), 3) << "Actual msg_buf size is " << msg_buf->data.size(); - CHECK_NE(key_meta_map_.find(key), key_meta_map_.end()) - << "key=" << key << " not initiated"; - CHECK_NE(key_meta_map_[key].find(recver), key_meta_map_[key].end()) - << "key=" << key - << ", recver=" << recver - << " not initiated"; - CHECK_EQ(msg_buf->data[1].size(), (unsigned int) 
len) - << msg_buf->data[1].size() << ", " << len; - - if (is_local_[remote_id]) { - // IPC - auto addr = (void*) msg_buf->data[1].data(); - CHECK(addr); - void* shm_addr = GetSharedMemory(kShmPrefix, key); - - // async copy - AsyncCopy m = {endpoint, msg_buf, shm_addr, addr, len, meta_len, false}; - auto cnt = cpy_counter_.fetch_add(1); - async_copy_queue_[cnt % ipc_copy_nthreads_]->Push(m); + if (!HasRemoteInfo(msg, key, is_push, remote_id)) { + MessageBuffer *msg_buf = PrepareNewMsgBuf(msg); + StoreMsgBuf(msg_buf, msg); + PrepareData(msg, msg_buf); + trans->SendRendezvousBegin(msg, msg_buf); return total_len; - } else { - // RDMA write - auto raddr = std::get<1>(key_meta_map_[key][recver]); - auto rkey = std::get<2>(key_meta_map_[key][recver]); - - auto temp_mr = memory_mr_map.find(msg_buf->data[1].data()); - CHECK_NE(temp_mr, memory_mr_map.end()); - - struct ibv_sge sge; - sge.addr = reinterpret_cast(msg_buf->data[1].data()); - sge.length = msg_buf->data[1].size(); - sge.lkey = temp_mr->second->lkey; - - struct ibv_send_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - - wr.wr_id = reinterpret_cast(raddr); - wr.opcode = IBV_WR_RDMA_WRITE; - wr.next = nullptr; - // wr.send_flags = IBV_SEND_SIGNALED; - wr.sg_list = &sge; - wr.num_sge = 1; - - wr.wr.rdma.remote_addr = raddr; - wr.wr.rdma.rkey = rkey; - - CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) - << "ibv_post_send failed."; - } + } } - WRContext *context = nullptr, *reserved = nullptr; - endpoint->free_write_ctx.WaitAndPop(&reserved); - endpoint->free_start_ctx.WaitAndPop(&context); - - msg_buf->reserved_context = reserved; - RendezvousStart *req = - reinterpret_cast(context->buffer->addr); - req->meta_len = meta_len; - req->origin_addr = reinterpret_cast(msg_buf); + auto addr_tuple = GetRemoteAndLocalInfo(msg.meta.key, msg.meta.push, remote_id); + MessageBuffer *msg_buf = std::get<3>(addr_tuple); // local message buffer - auto addr = reinterpret_cast(req); - - // rendezvous message, 
not data message - if (!is_server_ && is_local_[remote_id] && IsValidPushpull(msg)) { - // local IPC with shared memory - req->data_num = 0; - auto key = DecodeKey(msg.data[0]); - CHECK_EQ(key, msg.meta.key); - } else { - // normal RDMA - req->data_num = msg.data.size(); - for (size_t i = 0; i < req->data_num; ++i) { - req->data_len[i] = msg.data[i].size(); - } + // prepare new meta and data + CHECK_EQ(msg_buf->inline_len, (size_t) meta_len); + CHECK(msg_buf->inline_buf); + msg_buf->data = msg.data; // may not need this + PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); + + PrintSendLog(msg, msg_buf, addr_tuple); + + // already know remote address, directly use RDMA-write + if (msg.meta.push && msg.meta.request) { + // worker, push request + trans->SendPushRequest(msg, msg_buf, addr_tuple); + } else if (msg.meta.push && !msg.meta.request) { + // server, push response + trans->SendPushResponse(msg, msg_buf, addr_tuple); + } else if (!msg.meta.push && msg.meta.request) { + // worker, pull request + trans->SendPullRequest(msg, msg_buf, addr_tuple); + } else if (!msg.meta.push && !msg.meta.request) { + // server, pull response + map_mu_.lock(); + auto temp_mr = mem_mr_.find(msg_buf->data[1].data()); + CHECK_NE(temp_mr, mem_mr_.end()); + map_mu_.unlock(); + trans->SendPullResponse(msg, msg_buf, addr_tuple, temp_mr->second->lkey); + } else { + CHECK(0) << "unexpected message type"; } - SendRendezvousBegin(endpoint, addr, context, kRendezvousStart); + return total_len; } @@ -1041,143 +313,270 @@ class RDMAVan : public Van { Endpoint *endpoint = std::get(notification); BufferContext *buffer_ctx = std::get(notification); - int total_len = 0; - msg->meta.recver = my_node_.id; msg->meta.sender = endpoint->node_id; - char *cur = buffer_ctx->buffer; + // the second argument is actually deprecated, + // we keep it as is in order to be compatible + UnpackMeta(buffer_ctx->buffer, buffer_ctx->meta_len, &msg->meta); + int meta_len = GetPackMetaLen(msg->meta); - UnpackMeta(cur, 
buffer_ctx->meta_len, &msg->meta); - total_len += buffer_ctx->meta_len; - uint64_t data_num = buffer_ctx->data_num; - cur += buffer_ctx->meta_len; + int total_len = 0; + total_len += meta_len; + + auto trans = CHECK_NOTNULL(endpoint->GetTransport()); + + PrintRecvLog(msg, buffer_ctx, meta_len); + + if (!IsValidPushpull(*msg)) { + return total_len; + } + + // valid data message + if (msg->meta.push && msg->meta.request) { + // push request + total_len += trans->RecvPushRequest(msg, buffer_ctx, meta_len); + StoreWorkerTensorAddress(msg); + } else if (!msg->meta.push && msg->meta.request) { + // pull request + total_len += trans->RecvPullRequest(msg, buffer_ctx, meta_len); + } else if (msg->meta.push && !msg->meta.request) { + // push response + total_len += trans->RecvPushResponse(msg, buffer_ctx, meta_len); + } else if (!msg->meta.push && !msg->meta.request) { + // pull response + total_len += trans->RecvPullResponse(msg, buffer_ctx, meta_len); + } else { + CHECK(0) << "unknown msg type"; + } - bool is_released = false; - if (is_server_ && IsValidPushpull(*msg) && is_local_[msg->meta.sender]) { - // get data message from local shared memory - auto key = msg->meta.key; + return total_len; + } - std::lock_guard lock(map_mu_); - if (key_addr_map_.find(key) == key_addr_map_.end()) { - key_addr_map_[key] = key; - } + private: - SArray keys; - SArray vals; - SArray lens; - keys.reset(reinterpret_cast(&key_addr_map_[key]), sizeof(ps::Key), [](void *){}); - msg->data.push_back(keys); - total_len += keys.size(); - - if (msg->meta.push && msg->meta.request) { // push request - auto len = msg->meta.val_len; - if (key_len_map_.find(key) == key_len_map_.end()) { - key_len_map_[key] = len; - } - CHECK_EQ(len, key_len_map_[key]) << "key=" << key - << ": " << len << ", " << key_len_map_[key]; - - void* addr = GetSharedMemory(kShmPrefix, key); - vals.reset(reinterpret_cast(addr), len, [](void *){}); - lens.reset(reinterpret_cast(&key_len_map_[key]), sizeof(int), [](void *){}); - 
msg->data.push_back(vals); - msg->data.push_back(lens); - } else { // pull request - msg->data.push_back(vals); - } - total_len += vals.size() + lens.size(); + void PrintSendLog(Message &msg, MessageBuffer *msg_buf, RemoteTuple remote_tuple) { + if (!enable_log_) return; + std::lock_guard lock(log_mu_); + + if (!IsValidPushpull(msg)) { + LOG(INFO) << "Send Control Message" << std::flush; + } else if (msg.meta.push && msg.meta.request) { + // worker, push request + LOG(INFO) << "Send Push Request: key=" << msg.meta.key + << "\t timestamp=" << msg.meta.timestamp + << "\t recver=" << msg.meta.recver + << "\t tensor_len=" << msg_buf->mrs[1].second + << "\t remote_idx=" << std::get<2>(remote_tuple) + << "\t remote_addr=" << std::get<0>(remote_tuple) + << std::flush; + } else if (msg.meta.push && !msg.meta.request) { + // server, push response + LOG(INFO) << "Send Push Response: key=" << msg.meta.key + << "\t timestamp=" << msg.meta.timestamp + << "\t recver=" << msg.meta.recver + << "\t remote_idx=" << std::get<2>(remote_tuple) + << "\t remote_addr=" << std::get<0>(remote_tuple) + << std::flush; + } else if (!msg.meta.push && msg.meta.request) { + // worker, pull request + LOG(INFO) << "Send Pull Request: key=" << msg.meta.key + << "\t timestamp=" << msg.meta.timestamp + << "\t recver=" << msg.meta.recver + << "\t remote_idx=" << std::get<2>(remote_tuple) + << "\t remote_addr=" << std::get<0>(remote_tuple) + << std::flush; + } else if (!msg.meta.push && !msg.meta.request) { + // server, pull response + LOG(INFO) << "Send Pull Response: key=" << msg.meta.key + << "\t timestamp=" << msg.meta.timestamp + << "\t recver=" << msg.meta.recver + << "\t tensor_len=" << msg.meta.val_len + << "\t idx=" << "none" + << "\t remote_addr=" << msg.meta.addr + << std::flush; + } + + } + + void PrintRecvLog(Message *msg, BufferContext *buffer_ctx, int meta_len) { + if (!enable_log_) return; + std::lock_guard lock(log_mu_); + + if (!IsValidPushpull(*msg)) { + LOG(INFO) << "Recv Control 
Message" << std::flush; + } else if (msg->meta.push && msg->meta.request) { + // push request + LOG(INFO) << "Recv Push Request: key=" << msg->meta.key + << "\t timestamp=" << msg->meta.timestamp + << "\t sender=" << msg->meta.sender + << "\t tensor_len=" << buffer_ctx->data_len[1] + << std::flush; + } else if (!msg->meta.push && msg->meta.request) { + // pull request + LOG(INFO) << "Recv Pull Request: key=" << msg->meta.key + << "\t timestamp=" << msg->meta.timestamp + << "\t sender=" << msg->meta.sender + << std::flush; + } else if (msg->meta.push && !msg->meta.request) { + // push response + LOG(INFO) << "Recv Push Response: key=" << msg->meta.key + << "\t timestamp=" << msg->meta.timestamp + << "\t sender=" << msg->meta.sender + << std::flush; + } else if (!msg->meta.push && !msg->meta.request) { + // pull response + LOG(INFO) << "Recv Pull Response: key=" << msg->meta.key + << "\t timestamp=" << msg->meta.timestamp + << "\t sender=" << msg->meta.sender + << "\t tensor_len=" << msg->meta.val_len; + } + } + + void PackWorkerTensorAddress(Message &msg) { + // must be pull response + if (msg.meta.push || msg.meta.request) return; + + uint64_t key = msg.meta.key; + auto recver = msg.meta.recver; + + std::lock_guard lock(info_mu_); + CHECK_NE(tensor_info_map_.find(key), tensor_info_map_.end()); + CHECK_NE(tensor_info_map_[key].find(recver), tensor_info_map_[key].end()); + msg.meta.val_len = std::get<0>(tensor_info_map_[key][recver]); + msg.meta.addr = std::get<1>(tensor_info_map_[key][recver]); + msg.meta.option = std::get<2>(tensor_info_map_[key][recver]); + } + + void StoreWorkerTensorAddress(Message *msg) { + auto key = msg->meta.key; + auto len = msg->meta.val_len; + auto addr = msg->meta.addr; + auto rkey = msg->meta.option; + auto sender = msg->meta.sender; + + std::lock_guard lock(info_mu_); + if (tensor_info_map_.find(key) == tensor_info_map_.end() + || tensor_info_map_[key].find(sender) == tensor_info_map_[key].end()) { + tensor_info_map_[key][sender] = 
std::make_tuple(len, addr, rkey); + } else { + CHECK_EQ(len, std::get<0>(tensor_info_map_[key][sender])); + CHECK_EQ(addr, std::get<1>(tensor_info_map_[key][sender])); + CHECK_EQ(rkey, std::get<2>(tensor_info_map_[key][sender])); + } + } - mempool_->Free(buffer_ctx->buffer); - is_released = true; + bool HasRemoteInfo(Message& msg, uint64_t key, bool is_push, int recver) { + std::lock_guard lk(addr_mu_); + if (is_push && (push_addr_.find(key) != push_addr_.end()) + && (push_addr_[key].find(recver) != push_addr_[key].end())) { + return true; + } + if (!is_push && (pull_addr_.find(key) != pull_addr_.end()) + && (pull_addr_[key].find(recver) != pull_addr_[key].end())) { + return true; } - if (IsValidPushpull(*msg) && !msg->meta.push && !msg->meta.request) { - // worker, get message directly from inplace tensor - std::lock_guard lock(map_mu_); - auto key = msg->meta.key; - CHECK(!is_server_); - if (key_len_map_.find(key) == key_len_map_.end()) { - key_addr_map_[key] = (ps::Key) key; - key_len_map_[key] = (int) msg->meta.val_len; - } - CHECK_NE(key_len_map_.find(key), key_len_map_.end()) << key; - CHECK_NE(key_addr_map_.find(key), key_addr_map_.end()) << key; - - auto addr = msg->meta.addr; - - CHECK_NE(key_len_map_[key], 0) << msg->DebugString(); - - SArray keys; - SArray vals; - SArray lens; - - keys.reset(reinterpret_cast(&key_addr_map_[key]), sizeof(ps::Key), [](void *){}); - vals.reset(reinterpret_cast(addr), key_len_map_[key], [](void *){}); - lens.reset(reinterpret_cast(&key_len_map_[key]), sizeof(int), [](void *){}); - - msg->data.push_back(keys); - msg->data.push_back(vals); - msg->data.push_back(lens); - total_len += keys.size() + vals.size() + lens.size(); - - mempool_->Free(buffer_ctx->buffer); - } else if (data_num > 0) { - Block *mem_block = - new Block(mempool_.get(), buffer_ctx->buffer, data_num); - - for (size_t i = 0; i < data_num; i++) { - uint32_t len = buffer_ctx->data_len[i]; - SArray data; - data.reset(cur, len, [mem_block](void *) { - 
mem_block->Release(); - }); // Defer the deletion of block_ref - msg->data.push_back(data); - cur += len; - total_len += len; - } + return false; + } + + void StoreMsgBuf(MessageBuffer *msg_buf, Message& msg) { + std::lock_guard lk(addr_mu_); + CHECK_EQ(msgbuf_cache_.find(msg_buf), msgbuf_cache_.end()); + msgbuf_cache_[msg_buf] = msg; + } + + Message* GetFirstMsg(MessageBuffer *msg_buf) { + std::lock_guard lk(addr_mu_); + CHECK_NE(msgbuf_cache_.find(msg_buf), msgbuf_cache_.end()); + return &msgbuf_cache_[msg_buf]; + } + + void ReleaseFirstMsg(MessageBuffer *msg_buf) { + std::lock_guard lk(addr_mu_); + CHECK_NE(msgbuf_cache_.find(msg_buf), msgbuf_cache_.end()); + msgbuf_cache_.erase(msg_buf); + } + + void StoreRemoteAndLocalInfo(MessageBuffer *msg_buf, uint64_t remote_addr, uint32_t rkey, uint32_t idx) { + std::lock_guard lk(addr_mu_); + + CHECK_NE(msgbuf_cache_.find(msg_buf), msgbuf_cache_.end()); + + auto& msg = msgbuf_cache_[msg_buf]; + + auto key = msg.meta.key; + auto is_push = msg.meta.push; + auto recver = msg.meta.recver; + + auto t = std::make_tuple(remote_addr, rkey, idx, msg_buf); + if (is_push) { + push_addr_[key][recver] = t; } else { - if (!is_released) mempool_->Free(buffer_ctx->buffer); + pull_addr_[key][recver] = t; } + } + + RemoteTuple GetRemoteAndLocalInfo(uint64_t key, bool is_push, int recver) { + std::lock_guard lk(addr_mu_); + return (is_push ? 
push_addr_[key][recver] : pull_addr_[key][recver]); + } - if (msg->meta.push && msg->meta.request) { // server - CHECK(is_server_); - auto key = msg->meta.key; - auto len = msg->meta.val_len; - auto addr = msg->meta.addr; - auto rkey = msg->meta.option; - auto sender = msg->meta.sender; + MessageBuffer* PrepareNewMsgBuf(Message& msg) { + MessageBuffer *msg_buf = new MessageBuffer(); + auto meta_len = GetPackMetaLen(msg.meta); + msg_buf->inline_len = meta_len; + msg_buf->inline_buf = mem_allocator_->Alloc(meta_len); + msg_buf->data = msg.data; + PackMeta(msg.meta, &(msg_buf->inline_buf), &meta_len); + return msg_buf; + } + void RegisterMemory(Message &msg) { + for (auto& sa : msg.data) { + if (sa.size() == 0) continue; std::lock_guard lock(map_mu_); - if (key_meta_map_.find(key) == key_meta_map_.end() - || key_meta_map_[key].find(sender) == key_meta_map_[key].end()) { - if (enable_rdma_log_) { - LOG(INFO) << "(init) recv key=" << key - << ", len=" << len - << ", sender=" << msg->meta.sender - << ", val_addr=" << addr - << ", rkey=" << rkey; - } - key_meta_map_[key][sender] = std::make_tuple(len, addr, rkey); - } else { - CHECK_EQ(len, std::get<0>(key_meta_map_[key][sender])); - CHECK_EQ(addr, std::get<1>(key_meta_map_[key][sender])); - CHECK_EQ(rkey, std::get<2>(key_meta_map_[key][sender])); - - if (enable_rdma_log_) { - LOG(INFO) << "recv push key=" << key - << ", len=" << len - << ", val_addr=" << addr - << ", rkey=" << rkey; - } + if (mem_mr_.find(sa.data()) == mem_mr_.end()) { + struct ibv_mr *temp_mr; + CHECK (temp_mr = ibv_reg_mr(mem_allocator_->GetPD(), sa.data(), sa.size(), + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) + << "Failed to register the memory region: " << strerror(errno) + << ", sa.size()=" << sa.size(); + mem_mr_[sa.data()] = temp_mr; } } + } - delete buffer_ctx; - return total_len; + void PrepareData(Message &msg, MessageBuffer *msg_buf) { + if (!(msg.meta.push && msg.meta.request)) return; // only push request + for (auto &sa : 
msg_buf->data) { + if (sa.size() == 0) continue; + std::lock_guard lock(map_mu_); + auto it = mem_mr_.find(sa.data()); + MRPtr ptr(it->second, [](struct ibv_mr *mr) {}); + CHECK(ptr.get()) << strerror(errno); + msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size())); + } + } + + void AddMeta(Message &msg) { + if (msg.meta.request) { + msg.meta.key = DecodeKey(msg.data[0]); + } + if (msg.meta.push && msg.meta.request) { + // push request + CHECK_EQ(msg.data.size(), 3) << msg.data.size(); + + std::lock_guard lock(map_mu_); + CHECK_NE(mem_mr_.find(msg.data[1].data()), mem_mr_.end()); + + auto& vals = msg.data[1]; + msg.meta.addr = reinterpret_cast(vals.data()); // vals address + msg.meta.val_len = vals.size(); + msg.meta.option = mem_mr_[vals.data()]->rkey; + } } - private: void InitContext(struct ibv_context *context) { context_ = context; CHECK(context_) << "ibv_context* empty"; @@ -1185,7 +584,7 @@ class RDMAVan : public Van { pd_ = ibv_alloc_pd(context_); CHECK(pd_) << "Failed to allocate protection domain"; - mempool_.reset(new SimpleMempool(pd_)); + mem_allocator_.reset(new MemoryAllocator(pd_)); comp_event_channel_ = ibv_create_comp_channel(context_); @@ -1205,9 +604,6 @@ class RDMAVan : public Van { case kRendezvousReplyContext: endpoint->free_reply_ctx.Push(context); break; - case kWriteContext: - endpoint->free_write_ctx.Push(context); - break; case kReceiveContext: endpoint->PostRecv(context); break; @@ -1215,40 +611,6 @@ class RDMAVan : public Van { CHECK(0); } } - - void SendRendezvousReply(Endpoint* endpoint, BufferContext *buf_ctx, uint64_t origin_addr) { - WRContext *reply_ctx = nullptr; - endpoint->free_reply_ctx.WaitAndPop(&reply_ctx); - RendezvousReply *resp = - reinterpret_cast(reply_ctx->buffer->addr); - - char* buffer = buf_ctx->buffer; - resp->addr = reinterpret_cast(buffer); - resp->rkey = mempool_->RemoteKey(buffer); - resp->origin_addr = origin_addr; - resp->idx = addr_pool_.StoreAddress(buf_ctx); - - struct ibv_sge sge; - sge.addr 
= reinterpret_cast(resp); - sge.length = sizeof(RendezvousReply); - sge.lkey = reply_ctx->buffer->lkey; - - struct ibv_send_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - - wr.wr_id = reinterpret_cast(reply_ctx); - wr.opcode = IBV_WR_SEND_WITH_IMM; - wr.next = nullptr; - - wr.imm_data = kRendezvousReply; - - wr.send_flags = IBV_SEND_SIGNALED; - wr.sg_list = &sge; - wr.num_sge = 1; - - CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) - << "ibv_post_send failed."; - } void PollCQ() { // Pre-allocated work completions array used for polling @@ -1261,31 +623,24 @@ class RDMAVan : public Van { << "Failed status \n" << ibv_wc_status_str(wc[i].status) << " " << wc[i].status << " " << static_cast(wc[i].wr_id) << " " << wc[i].vendor_err; - + WRContext *context = reinterpret_cast(wc[i].wr_id); Endpoint *endpoint = reinterpret_cast(context->private_data); - CHECK(endpoint); + // IBV_WC_RDMA_WRITE use msg_buf as the wr_id + // so there won't be context and endpoint for this op switch (wc[i].opcode) { - case IBV_WC_SEND: - // LOG(INFO) << "opcode: IBV_WC_SEND"; + case IBV_WC_SEND: { ReleaseWorkRequestContext(context, endpoint); - break; + } break; case IBV_WC_RDMA_WRITE: { - // LOG(INFO) << "opcode: IBV_WC_RDMA_WRITE"; - // Note: This is not a struct ibv_mr* - MessageBuffer *msg_buf = - *reinterpret_cast(context->buffer->addr); - mempool_->Free(msg_buf->inline_buf); - delete msg_buf; - ReleaseWorkRequestContext(context, endpoint); + // do nothing } break; case IBV_WC_RECV_RDMA_WITH_IMM: { - // LOG(INFO) << "opcode: IBV_WC_RECV_RDMA_WITH_IMM"; uint32_t addr_idx = wc[i].imm_data; - BufferContext *buf_ctx = addr_pool_.GetAddressAndRelease(addr_idx); + BufferContext *buf_ctx = addr_pool_.GetAddress(addr_idx); recv_buffers_.Push(std::make_tuple(endpoint, buf_ctx)); ReleaseWorkRequestContext(context, endpoint); } break; @@ -1295,25 +650,12 @@ class RDMAVan : public Van { struct ibv_mr *mr = context->buffer; if (imm == kRendezvousStart) { - // LOG(INFO) << 
"opcode: IBV_WC_RECV kRendezvousStart"; RendezvousStart *req = reinterpret_cast(mr->addr); - - BufferContext *buf_ctx = new BufferContext(); - uint64_t len = req->meta_len; - buf_ctx->meta_len = req->meta_len; - buf_ctx->data_num = req->data_num; - for (size_t i = 0; i < req->data_num; ++i) { - buf_ctx->data_len[i] = req->data_len[i]; - len += req->data_len[i]; - } - char *buffer = mempool_->Alloc(is_server_ ? len : req->meta_len); - CHECK(buffer) << len; - buf_ctx->buffer = buffer; - - SendRendezvousReply(endpoint, buf_ctx, req->origin_addr); + auto trans = CHECK_NOTNULL(endpoint->GetTransport()); + trans->SendRendezvousReply(req, addr_pool_); + } else if (imm == kRendezvousReply) { - // LOG(INFO) << "opcode: IBV_WC_RECV kRendezvousReply"; RendezvousReply *resp = reinterpret_cast(mr->addr); uint64_t remote_addr = resp->addr; @@ -1324,48 +666,40 @@ class RDMAVan : public Van { MessageBuffer *msg_buf = reinterpret_cast(origin_addr); - struct ibv_sge sge[1 + msg_buf->mrs.size()]; - - sge[0].addr = reinterpret_cast(msg_buf->inline_buf); - sge[0].length = msg_buf->inline_len; - sge[0].lkey = mempool_->LocalKey(msg_buf->inline_buf); - - size_t num_sge = 1; - for (auto &pair : msg_buf->mrs) { - size_t length = pair.second; - CHECK(length); - sge[num_sge].addr = - reinterpret_cast(pair.first->addr); - sge[num_sge].length = length; - sge[num_sge].lkey = pair.first->lkey; - ++num_sge; + // Before RDMA write, store the remote info so that + // subsequent write does not need repeated rendezvous + StoreRemoteAndLocalInfo(msg_buf, remote_addr, rkey, idx); + + Message *msg = GetFirstMsg(msg_buf); + + auto addr_tuple = GetRemoteAndLocalInfo(msg->meta.key, msg->meta.push, msg->meta.recver); + + PrintSendLog(*msg, msg_buf, addr_tuple); + + auto trans = CHECK_NOTNULL(endpoint->GetTransport()); + if (!IsValidPushpull(*msg)) { + // control message + trans->RDMAWriteWithImm(msg_buf, remote_addr, rkey, idx); + } else if (msg->meta.push && msg->meta.request) { + // worker, push request + 
trans->SendPushRequest(*msg, msg_buf, addr_tuple); + } else if (msg->meta.push && !msg->meta.request) { + // server, push response + trans->SendPushResponse(*msg, msg_buf, addr_tuple); + } else if (!msg->meta.push && msg->meta.request) { + // worker, pull request + trans->SendPullRequest(*msg, msg_buf, addr_tuple); + } else if (!msg->meta.push && !msg->meta.request) { + // server, pull response + map_mu_.lock(); + auto temp_mr = mem_mr_.find(msg_buf->data[1].data()); + CHECK_NE(temp_mr, mem_mr_.end()); + map_mu_.unlock(); + trans->SendPullResponse(*msg, msg_buf, addr_tuple, temp_mr->second->lkey); } - if (is_server_) CHECK_EQ(num_sge, 1) << num_sge; - - WRContext *write_ctx = msg_buf->reserved_context; - - MessageBuffer **tmp = - reinterpret_cast(write_ctx->buffer->addr); - *tmp = msg_buf; // write the addr of msg_buf into the mr buffer - struct ibv_send_wr wr, *bad_wr = nullptr; - memset(&wr, 0, sizeof(wr)); - - wr.wr_id = reinterpret_cast(write_ctx); - wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wr.next = nullptr; - - wr.imm_data = idx; - - wr.send_flags = IBV_SEND_SIGNALED; - wr.sg_list = sge; - wr.num_sge = num_sge; - - wr.wr.rdma.remote_addr = remote_addr; - wr.wr.rdma.rkey = rkey; - - CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) - << "ibv_post_send failed."; + // release the msg_buf from msgbuf_cache_ + ReleaseFirstMsg(msg_buf); } else { CHECK(0); @@ -1469,6 +803,21 @@ class RDMAVan : public Van { endpoint->Init(cq_, pd_); + local_mu_.lock(); + if (disable_ipc_) { + is_local_[remote_ctx->node] = false; + } else { + is_local_[remote_ctx->node] = (std::string(remote_ctx->hostname) == my_node_.hostname) ? true : false; + } + LOG(INFO) << "OnConnect to Node " << remote_ctx->node + << " with Transport=" << (is_local_[remote_ctx->node]?"IPC" : "RDMA"); + local_mu_.unlock(); + + std::shared_ptr t = is_local_[remote_ctx->node] ? 
+ std::make_shared(endpoint, mem_allocator_.get()) : + std::make_shared(endpoint, mem_allocator_.get()); + endpoint->SetTransport(t); + RequestContext ctx; ctx.node = static_cast(my_node_.id); ctx.port = static_cast(my_node_.port); @@ -1496,7 +845,7 @@ class RDMAVan : public Van { void OnRouteResolved(struct rdma_cm_event *event) { struct rdma_cm_id *id = event->id; Endpoint *endpoint = reinterpret_cast(id->context); - + if (context_ == nullptr) { InitContext(id->verbs); } @@ -1551,10 +900,11 @@ class RDMAVan : public Van { LOG(INFO) << "OnDisconnected from Node " << endpoint->node_id; } - int AlignTo(int input, int alignment) { return input / alignment * alignment; } - AddressPool addr_pool_; - std::unique_ptr mempool_; + std::unique_ptr mem_allocator_; + + std::unique_ptr rdma_trans_; + std::unique_ptr ipc_trans_; struct rdma_cm_id *listener_ = nullptr; std::atomic should_stop_; @@ -1565,8 +915,6 @@ class RDMAVan : public Van { struct rdma_event_channel *event_channel_ = nullptr; struct ibv_context *context_ = nullptr; - std::unordered_map memory_mr_map; - // ibverbs protection domain struct ibv_pd *pd_ = nullptr; // Completion event channel, to wait for work completions @@ -1580,37 +928,32 @@ class RDMAVan : public Van { // Recv buffer queue ThreadsafeQueue> recv_buffers_; - // role is server or worker - bool is_server_; - // RDMA logging info - bool enable_rdma_log_; - - std::mutex map_mu_; - // macros for key_meta_map - using MetaInfo = std::tuple; // len, addr, rkey - using SenderMeta = std::unordered_map; // sender as the key - // (key, sender) --> MetaInfo - std::unordered_map key_meta_map_; - // a static address for the key - std::unordered_map key_addr_map_; - // a static address for the length - std::unordered_map key_len_map_; - // local IPC related bool disable_ipc_ = false; std::mutex local_mu_; std::unordered_map is_local_; - std::mutex shm_mu_; - std::unordered_map key_shm_addr_; + // worker's tensor address + std::mutex info_mu_; + using TensorInfo 
= std::tuple; // len, addr, rkey + using RemoteTensorMeta = std::unordered_map; // sender as the key + std::unordered_map tensor_info_map_; // (key, sender) --> TensorInfo - int byteps_partition_bytes_ = 4096000; + std::mutex addr_mu_; + // , () + std::unordered_map push_addr_; + std::unordered_map pull_addr_; + std::unordered_map msgbuf_cache_; // msg_buf, msg + + std::mutex map_mu_; + std::unordered_map mem_mr_; // (memory address, ibv_mr) + + // logging + bool enable_log_; + std::mutex log_mu_; + +}; // class RDMAVan - int ipc_copy_nthreads_; - std::vector ipc_copy_thread_list_; - std::vector*> async_copy_queue_; - std::atomic cpy_counter_{0}; -}; // namespace ps }; // namespace ps #endif // DMLC_USE_RDMA diff --git a/src/van.cc b/src/van.cc index 64f004ed..aa40880c 100644 --- a/src/van.cc +++ b/src/van.cc @@ -14,7 +14,7 @@ #include "ps/internal/van.h" #include "ps/sarray.h" -#include "./meta.pb.h" +#include "./meta.h" #include "./network_utils.h" #include "./rdma_van.h" #include "./resender.h" @@ -494,129 +494,116 @@ void Van::Receiving() { } } -void Van::PackMetaPB(const Meta &meta, PBMeta *pb) { - pb->set_head(meta.head); - if (meta.app_id != Meta::kEmpty) pb->set_app_id(meta.app_id); - if (meta.timestamp != Meta::kEmpty) pb->set_timestamp(meta.timestamp); - if (meta.body.size()) pb->set_body(meta.body); - pb->set_push(meta.push); - pb->set_request(meta.request); - pb->set_simple_app(meta.simple_app); - pb->set_customer_id(meta.customer_id); - for (auto d : meta.data_type) pb->add_data_type(d); - if (!meta.control.empty()) { - auto ctrl = pb->mutable_control(); - ctrl->set_cmd(meta.control.cmd); - if (meta.control.cmd == Control::BARRIER) { - ctrl->set_barrier_group(meta.control.barrier_group); - } else if (meta.control.cmd == Control::ACK) { - ctrl->set_msg_sig(meta.control.msg_sig); - } - for (const auto &n : meta.control.node) { - auto p = ctrl->add_node(); - p->set_id(n.id); - p->set_role(n.role); - p->set_port(n.port); - p->set_hostname(n.hostname); - 
p->set_is_recovery(n.is_recovery); - p->set_customer_id(n.customer_id); - } - } - pb->set_data_size(meta.data_size); - pb->set_key(meta.key); - pb->set_addr(meta.addr); - pb->set_val_len(meta.val_len); - pb->set_option(meta.option); +int Van::GetPackMetaLen(const Meta &meta) { + return sizeof(RawMeta) + meta.body.size() + + meta.data_type.size() * sizeof(int) + + meta.control.node.size() * sizeof(RawNode); } void Van::PackMeta(const Meta &meta, char **meta_buf, int *buf_size) { - // convert into protobuf - PBMeta pb; - pb.set_head(meta.head); - if (meta.app_id != Meta::kEmpty) pb.set_app_id(meta.app_id); - if (meta.timestamp != Meta::kEmpty) pb.set_timestamp(meta.timestamp); - if (meta.body.size()) pb.set_body(meta.body); - pb.set_push(meta.push); - pb.set_request(meta.request); - pb.set_simple_app(meta.simple_app); - pb.set_customer_id(meta.customer_id); - for (auto d : meta.data_type) pb.add_data_type(d); + *buf_size = GetPackMetaLen(meta); + // allocate buffer only when needed + if (*meta_buf == nullptr) { + *meta_buf = new char[*buf_size + 1]; + } + + RawMeta *raw = (RawMeta*)*meta_buf; + bzero(raw, sizeof(RawMeta)); + char *raw_body = *meta_buf + sizeof(RawMeta); + int *raw_data_type = (int*)(raw_body + meta.body.size()); + RawNode *raw_node = (RawNode*)(raw_data_type + meta.data_type.size()); + + // convert into raw buffer + raw->head = meta.head; + raw->app_id = meta.app_id; + raw->timestamp = meta.timestamp; + if (meta.body.size()) { + memcpy(raw_body, meta.body.c_str(), meta.body.size()); + raw->body_size = meta.body.size(); + } + raw->push = meta.push; + raw->request = meta.request; + raw->simple_app = meta.simple_app; + raw->customer_id = meta.customer_id; + int data_type_count = 0; + for (auto d : meta.data_type) { + raw_data_type[data_type_count] = d; + data_type_count++; + } + raw->data_type_size = meta.data_type.size(); + auto ctrl = &(raw->control); if (!meta.control.empty()) { - auto ctrl = pb.mutable_control(); - ctrl->set_cmd(meta.control.cmd); + 
ctrl->cmd = meta.control.cmd; if (meta.control.cmd == Control::BARRIER) { - ctrl->set_barrier_group(meta.control.barrier_group); + ctrl->barrier_group = meta.control.barrier_group; } else if (meta.control.cmd == Control::ACK) { - ctrl->set_msg_sig(meta.control.msg_sig); + ctrl->msg_sig = meta.control.msg_sig; } + ctrl->node_size = meta.control.node.size(); + int node_count = 0; for (const auto &n : meta.control.node) { - auto p = ctrl->add_node(); - p->set_id(n.id); - p->set_role(n.role); - p->set_port(n.port); - p->set_hostname(n.hostname); - p->set_is_recovery(n.is_recovery); - p->set_customer_id(n.customer_id); + raw_node[node_count].id = n.id; + raw_node[node_count].role = n.role; + raw_node[node_count].port = n.port; + bzero(raw_node[node_count].hostname, sizeof(raw_node[node_count].hostname)); + memcpy(raw_node[node_count].hostname, n.hostname.c_str(), n.hostname.size()); + raw_node[node_count].is_recovery = n.is_recovery; + raw_node[node_count].customer_id = n.customer_id; + node_count++; } } - pb.set_data_size(meta.data_size); - pb.set_key(meta.key); - pb.set_addr(meta.addr); - pb.set_val_len(meta.val_len); - pb.set_option(meta.option); - - // to string - *buf_size = pb.ByteSize(); - // allocate buffer only when needed - if (*meta_buf == nullptr) { - *meta_buf = new char[*buf_size + 1]; + else { + ctrl->cmd = Control::EMPTY; } - CHECK(pb.SerializeToArray(*meta_buf, *buf_size)) << "failed to serialize protbuf"; + raw->data_size = meta.data_size; + raw->key = meta.key; + raw->addr = meta.addr; + raw->val_len = meta.val_len; + raw->option = meta.option; } void Van::UnpackMeta(const char *meta_buf, int buf_size, Meta *meta) { - // to protobuf - PBMeta pb; - CHECK(pb.ParseFromArray(meta_buf, buf_size)) - << "failed to parse string into protobuf, buf_size=" << buf_size; + + RawMeta *raw = (RawMeta*)meta_buf; + const char *raw_body = meta_buf + sizeof(RawMeta); + const int *raw_data_type = (const int*)(raw_body + raw->body_size); + const RawNode *raw_node = 
(RawNode*)(raw_data_type + raw->data_type_size); // to meta - meta->head = pb.head(); - meta->app_id = pb.has_app_id() ? pb.app_id() : Meta::kEmpty; - meta->timestamp = pb.has_timestamp() ? pb.timestamp() : Meta::kEmpty; - meta->request = pb.request(); - meta->push = pb.push(); - meta->simple_app = pb.simple_app(); - meta->body = pb.body(); - meta->customer_id = pb.customer_id(); - meta->data_type.resize(pb.data_type_size()); - for (int i = 0; i < pb.data_type_size(); ++i) { - meta->data_type[i] = static_cast(pb.data_type(i)); + meta->head = raw->head; + meta->app_id = raw->app_id; + meta->timestamp = raw->timestamp; + meta->request = raw->request; + meta->push = raw->push; + meta->simple_app = raw->simple_app; + meta->body = std::string(raw_body, raw->body_size); + meta->customer_id = raw->customer_id; + meta->data_type.resize(raw->data_type_size); + for (int i = 0; i < raw->data_type_size; ++i) { + meta->data_type[i] = static_cast(raw_data_type[i]); } - if (pb.has_control()) { - const auto &ctrl = pb.control(); - meta->control.cmd = static_cast(ctrl.cmd()); - meta->control.barrier_group = ctrl.barrier_group(); - meta->control.msg_sig = ctrl.msg_sig(); - for (int i = 0; i < ctrl.node_size(); ++i) { - const auto &p = ctrl.node(i); - Node n; - n.role = static_cast(p.role()); - n.port = p.port(); - n.hostname = p.hostname(); - n.id = p.has_id() ? 
p.id() : Node::kEmpty; - n.is_recovery = p.is_recovery(); - n.customer_id = p.customer_id(); - meta->control.node.push_back(n); - } - } else { - meta->control.cmd = Control::EMPTY; + + auto ctrl = &(raw->control); + meta->control.cmd = static_cast(ctrl->cmd); + meta->control.barrier_group = ctrl->barrier_group; + meta->control.msg_sig = ctrl->msg_sig; + for (int i = 0; i < ctrl->node_size; ++i) { + const auto &p = raw_node[i]; + Node n; + n.role = static_cast(p.role); + n.port = p.port; + n.hostname = p.hostname; + n.id = p.id; + n.is_recovery = p.is_recovery; + n.customer_id = p.customer_id; + meta->control.node.push_back(n); } - meta->data_size = pb.data_size(); - meta->key = pb.key(); - meta->addr = pb.addr(); - meta->val_len = pb.val_len(); - meta->option = pb.option(); + + meta->data_size = raw->data_size; + meta->key = raw->key; + meta->addr = raw->addr; + meta->val_len = raw->val_len; + meta->option = raw->option; } void Van::Heartbeat() { diff --git a/tests/test_benchmark.cc b/tests/test_benchmark.cc new file mode 100644 index 00000000..1fafc0aa --- /dev/null +++ b/tests/test_benchmark.cc @@ -0,0 +1,294 @@ +#include +#include +#include +#include +#include "ps/ps.h" + +#define DIVUP(x, y) (((x)+(y)-1)/(y)) +#define ROUNDUP(x, y) (DIVUP((x), (y))*(y)) +#define DEBUG_PRINT_TENSOR_VALUE(X) (*((float *)(X) + 0)) +#define DEBUG_PRINT_TENSOR_ADDRESS(X) (reinterpret_cast(X)) + +using namespace ps; + +enum MODE { + PUSH_THEN_PULL = 0, + PUSH_PULL = 1, + PUSH_ONLY = 2, + PULL_ONLY = 3 +}; +std::unordered_map > mem_map; +bool debug_mode_ = false; + +void aligned_memory_alloc(void** ptr, size_t size) { + size_t page_size = sysconf(_SC_PAGESIZE); + void* p; + int size_aligned = ROUNDUP(size, page_size); + int ret = posix_memalign(&p, page_size, size_aligned); + CHECK_EQ(ret, 0) << "posix_memalign error: " << strerror(ret); + CHECK(p); + memset(p, 1, size); + *ptr = p; +} + +void float_sum(float *dst, float *src, size_t len) { + if (len == 0) return; + for (size_t i = 0; 
i < len / (size_t) sizeof(float); ++i) { + dst[i] = dst[i] + src[i]; + } +} + +template +void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer *server) { + uint64_t key = req_data.keys[0]; + if (req_meta.push) { + CHECK(req_data.lens.size()); + CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]) + << "key=" << key << ", " << req_data.vals.size() << ", " << req_data.lens[0]; + + + if (mem_map.find(key) == mem_map.end()) { + size_t len = (size_t) req_data.vals.size(); + + void* ptr_val; + aligned_memory_alloc(&ptr_val, len); + mem_map[key].vals.reset((char*)ptr_val, len, [](void *){ }); + + void* ptr_key; + aligned_memory_alloc(&ptr_key, sizeof(Key)); + mem_map[key].keys.reset((Key*)ptr_key, 1, [](void *){ }); + memcpy(ptr_key, &key, sizeof(Key)); + + void* ptr_len; + aligned_memory_alloc(&ptr_len, sizeof(int)); + mem_map[key].lens.reset((int*)ptr_len, 1, [](void *){ }); + memcpy(ptr_len, &len, sizeof(int)); + } + + auto recved = reinterpret_cast(req_data.vals.data()); + // only sum the first 4 bytes + size_t sum_len = debug_mode_ ? req_data.vals.size() : 0; + float_sum((float*) mem_map[key].vals.data(), (float*) recved, sum_len); + + if (debug_mode_) { + LOG(INFO) << "recved tensor! key=" << key << "\t" + << "store: " << DEBUG_PRINT_TENSOR_VALUE(mem_map[key].vals.data()) << "\t" + << "recv: " << DEBUG_PRINT_TENSOR_VALUE(recved) << "\t" + << "address: " << DEBUG_PRINT_TENSOR_ADDRESS(recved) << "\t" + << "len: " << req_data.vals.size() << "\t" + << "sender: " << req_meta.sender; + } + + // send push response (empty) + KVPairs res; + server->Response(req_meta, res); + } + else { + auto iter = mem_map.find(key); + CHECK_NE(iter, mem_map.end()); + server->Response(req_meta, iter->second); + } +} + +void StartServer() { + if (!IsServer()) return; + debug_mode_ = Environment::Get()->find("DEBUG_MODE") ? 
true : false; + + auto server = new KVServer(0); + server->set_request_handle(EmptyHandler); + RegisterExitCallback([server]() { delete server; }); +} + +void push_pull(KVWorker &kv, + std::vector > &server_keys, + std::vector > &server_vals, + std::vector > &server_lens, + int len, int num_servers, int total_key_num, + int how_many_key_per_server, MODE mode) { + CHECK_GT(mode, 0); + switch (mode) { + case PUSH_PULL: + LOG(INFO) << "========= PUSH_PULL mode ========="; + LOG(INFO) << "========= msg_size=" << len*sizeof(char) << " bytes ========="; + break; + case PUSH_ONLY: + LOG(INFO) << "========= PUSH_ONLY mode ========="; + LOG(INFO) << "========= msg_size=" << len*sizeof(char) << " bytes ========="; + break; + case PULL_ONLY: + LOG(INFO) << "========= PULL_ONLY mode ========="; + LOG(INFO) << "========= msg_size=" << len*sizeof(char) << " bytes ========="; + break; + default: CHECK(0); + } + + std::vector timestamp_list; + auto start = std::chrono::high_resolution_clock::now(); + auto end = std::chrono::high_resolution_clock::now(); + + auto val = Environment::Get()->find("LOG_DURATION"); + unsigned int log_duration = val ? 
atoi(val) : 10; + + int cnt = 0; + while (1) { + for (int key = 0; key < total_key_num; key++) { + auto keys = server_keys[key]; + auto lens = server_lens[key]; + auto vals = server_vals[key]; + + switch (mode) { + case PUSH_PULL: { + timestamp_list.push_back(kv.ZPush(keys, vals, lens)); + timestamp_list.push_back(kv.ZPull(keys, &vals, &lens)); + } break; + case PUSH_ONLY: { + timestamp_list.push_back(kv.ZPush(keys, vals, lens)); + } break; + case PULL_ONLY: { + timestamp_list.push_back(kv.ZPull(keys, &vals, &lens)); + } break; + default: { + CHECK(0); + break; + } + } + } + + for (auto& ts : timestamp_list) { kv.Wait(ts); } + timestamp_list.clear(); + + cnt++; + if (cnt % log_duration != 0) continue; + + end = std::chrono::high_resolution_clock::now(); + LL << "Application goodput: " + << 8.0 * len * sizeof(char) * total_key_num * cnt / (end - start).count() + << " Gbps"; + cnt = 0; + start = std::chrono::high_resolution_clock::now(); + } +} + +void RunWorker(int argc, char *argv[]) { + if (!IsWorker()) return; + KVWorker kv(0, 0); + auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); + + const int num_servers = krs.size(); + LOG(INFO) << num_servers << " servers in total"; + CHECK_GT(num_servers, 0); + + // init + int len = (argc > 1) ? atoi(argv[1]) : 1024000; + int repeat = (argc > 2) ? atoi(argv[2]) : 10; + MODE mode = (argc > 3) ? static_cast(atoi(argv[3])) : PUSH_PULL; + + auto v = Environment::Get()->find("NUM_KEY_PER_SERVER"); + const int how_many_key_per_server = v ? 
atoi(v) : 40; + const int total_key_num = num_servers * how_many_key_per_server; + + std::vector > server_vals; + std::vector > server_keys; + std::vector > server_lens; + for (int key = 0; key < total_key_num; key++) { + void* ptr; + aligned_memory_alloc(&ptr, len); + SArray vals; + vals.reset((char*) ptr, len * sizeof(char), [](void *){}); + server_vals.push_back(vals); + } + + // init push, do not count this into time cost + for (int key = 0; key < total_key_num; key++) { + int server = key % num_servers; + PS_VLOG(1) << "key=" << key << " assigned to server " << server; + + auto vals = server_vals[key]; + + // page aligned keys + void* ptr_key; + aligned_memory_alloc(&ptr_key, sizeof(Key)); + SArray keys; + keys.reset((Key*) ptr_key, 1, [](void *){}); + ps::Key ps_key = krs[server].begin() + key; + memcpy(ptr_key, &ps_key, sizeof(Key)); + server_keys.push_back(keys); + + // page aligned vals + void* ptr_len; + aligned_memory_alloc(&ptr_len, sizeof(int)); + SArray lens; + lens.reset((int*) ptr_len, 1, [](void *){}); + memcpy(ptr_len, &len, sizeof(len)); + server_lens.push_back(lens); + + kv.Wait(kv.ZPush(keys, vals, lens)); + } + + switch(mode) { + case PUSH_THEN_PULL: { + LOG(INFO) << "PUSH_THEN_PULL mode"; + // push + uint64_t accumulated_ms = 0; + for (int i = 0; i < repeat; ++i) { + auto start = std::chrono::high_resolution_clock::now(); + for (int server = 0; server < num_servers; server++) { + auto keys = server_keys[server]; + auto lens = server_lens[server]; + auto vals = server_vals[server]; + + kv.Wait(kv.ZPush(keys, vals, lens)); + } + auto end = std::chrono::high_resolution_clock::now(); + accumulated_ms += (end - start).count(); // ns + } + LL << "push " << len * sizeof(char) + << " bytes to each server, repeat=" << repeat + << ", total_time=" + << accumulated_ms / 1e6 << "ms"; + + // pull + accumulated_ms = 0; + for (int i = 0; i < repeat; ++i) { + auto start = std::chrono::high_resolution_clock::now(); + for (int server = 0; server < num_servers; 
server++) { + auto keys = server_keys[server]; + auto lens = server_lens[server]; + auto vals = server_vals[server]; + + kv.Wait(kv.ZPull(keys, &vals, &lens)); + } + auto end = std::chrono::high_resolution_clock::now(); + accumulated_ms += (end - start).count(); // ns + } + + LL << "pull " << len * sizeof(char) + << " bytes to each server, repeat=" << repeat + << ", total_time=" + << accumulated_ms / 1e6 << "ms"; + } break; + case PUSH_PULL: + case PUSH_ONLY: + case PULL_ONLY: + push_pull(kv, server_keys, server_vals, server_lens, len, num_servers, total_key_num, how_many_key_per_server, mode); + break; + default: + CHECK(0) << "unknown mode " << mode; + } + + +} + +int main(int argc, char *argv[]) { + // disable multi-threaded processing first + setenv("ENABLE_SERVER_MULTIPULL", "0", 1); + // start system + Start(0); + // setup server nodes + StartServer(); + // run worker nodes + RunWorker(argc, argv); + // stop system + Finalize(0, true); + return 0; +} \ No newline at end of file diff --git a/tests/test_connection.cc b/tests/test_connection.cc deleted file mode 100644 index eeea7618..00000000 --- a/tests/test_connection.cc +++ /dev/null @@ -1,8 +0,0 @@ -#include "ps/ps.h" - -int main(int argc, char *argv[]) { - ps::Start(0); - // do nothing - ps::Finalize(0, true); - return 0; -} diff --git a/tests/test_ipc_benchmark.cc b/tests/test_ipc_benchmark.cc new file mode 100644 index 00000000..3f9330b6 --- /dev/null +++ b/tests/test_ipc_benchmark.cc @@ -0,0 +1,230 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ps/ps.h" + +#define DIVUP(x, y) (((x)+(y)-1)/(y)) +#define ROUNDUP(x, y) (DIVUP((x), (y))*(y)) + +using namespace ps; + +enum MODE { IPC }; + +std::unordered_map _key_shm_addr; +std::unordered_map _key_shm_size; +std::unordered_map store_; +std::mutex mu_; + +void* OpenSharedMemory(const std::string& prefix, uint64_t key, size_t size) { + std::string shm_name(prefix); + shm_name += 
std::to_string(key); + int shm_fd = shm_open(shm_name.c_str(), O_CREAT | O_RDWR, 0666); + CHECK_GE(shm_fd, 0) << "shm_open failed for " << shm_name; + CHECK_GE(ftruncate(shm_fd, size), 0) << strerror(errno); + + void* ptr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); + CHECK_NE(ptr, (void*)-1) << strerror(errno); + + LOG(INFO) << "initialized share memory size=" << size + << " for key=" << key << ", name=" << shm_name; + _key_shm_addr[shm_name] = ptr; + _key_shm_size[shm_name] = size; + return ptr; +} + +uint64_t EncodeKey(uint64_t seed) { + return seed << 16; +} + +uint64_t DecodeKey(ps::Key key) { + auto kr = ps::Postoffice::Get()->GetServerKeyRanges()[ps::MyRank()]; + return key - kr.begin(); +} + +void aligned_memory_alloc(void** ptr, size_t size) { + size_t page_size = sysconf(_SC_PAGESIZE); + void* p; + int size_aligned = ROUNDUP(size, page_size); + int ret = posix_memalign(&p, page_size, size_aligned); + CHECK_EQ(ret, 0) << "posix_memalign error: " << strerror(ret); + CHECK(p); + memset(p, 0, size); + *ptr = p; +} + +std::unordered_map > mem_map; +template +void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer *server) { + uint64_t key = DecodeKey(req_data.keys[0]); + if (req_meta.push) { + CHECK(req_data.lens.size()); + CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]) + << "key=" << key << ", " << req_data.vals.size() << ", " << req_data.lens[0]; + + if (mem_map.find(key) == mem_map.end()) { + size_t len = (size_t) req_data.vals.size(); + + void* ptr_val; + aligned_memory_alloc(&ptr_val, len); + mem_map[key].vals.reset((char*)ptr_val, len, [](void *){ }); + + void* ptr_key; + aligned_memory_alloc(&ptr_key, sizeof(Key)); + mem_map[key].keys.reset((Key*)ptr_key, 1, [](void *){ }); + memcpy(ptr_key, &key, sizeof(Key)); + + void* ptr_len; + aligned_memory_alloc(&ptr_len, sizeof(int)); + mem_map[key].lens.reset((int*)ptr_len, 1, [](void *){ }); + memcpy(ptr_len, &len, sizeof(int)); + } + + // send push response 
(empty) + KVPairs res; + server->Response(req_meta, res); + } + else { + auto iter = mem_map.find(key); + CHECK_NE(iter, mem_map.end()); + server->Response(req_meta, iter->second); + } +} + +void StartServer() { + if (!IsServer()) return; + auto server = new KVServer(0); + server->set_request_handle(EmptyHandler); + RegisterExitCallback([server]() { delete server; }); +} + +void push_pull(KVWorker &kv, + std::vector > &server_keys, + std::vector > &server_vals, + std::vector > &server_lens, + int len, int num_servers, int total_key_num, + int how_many_key_per_server, MODE mode) { + std::vector timestamp_list; + auto start = std::chrono::high_resolution_clock::now(); + auto end = std::chrono::high_resolution_clock::now(); + auto val = Environment::Get()->find("LOG_DURATION"); + unsigned int log_duration = val ? atoi(val) : 10; + + int cnt = 0; + while (1) { + for (int i = 0; i < total_key_num; i++) { + auto keys = server_keys[i]; + auto lens = server_lens[i]; + auto vals = server_vals[i]; + + timestamp_list.push_back(kv.ZPush(keys, vals, lens)); + timestamp_list.push_back(kv.ZPull(keys, &vals, &lens)); + } + + for (auto& ts : timestamp_list) { kv.Wait(ts); } + timestamp_list.clear(); + + cnt++; + if (cnt % log_duration != 0) continue; + + end = std::chrono::high_resolution_clock::now(); + LL << "Application goodput: " + << 8.0 * len * sizeof(char) * total_key_num * cnt / (end - start).count() + << " Gbps"; + cnt = 0; + start = std::chrono::high_resolution_clock::now(); + } +} + +void RunWorker(int argc, char *argv[]) { + if (!IsWorker()) return; + KVWorker kv(0, 0); + auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); + + const int num_servers = krs.size(); + LOG(INFO) << num_servers << " servers in total"; + CHECK_GT(num_servers, 0); + + // init + int len = (argc > 1) ? atoi(argv[1]) : 1024000; + MODE mode = IPC; + + size_t partition_bytes = Environment::Get()->find("BYTEPS_PARTITION_BYTES") ? 
+ atoi(Environment::Get()->find("BYTEPS_PARTITION_BYTES")) : 4096000; + CHECK_GE(partition_bytes, (size_t)len) + << "tensor partition is not supported in this benchmark" + << ", try reduce tensor size or increase BYTEPS_PARTITION_BYTES"; + + auto v = Environment::Get()->find("NUM_KEY_PER_SERVER"); + const int how_many_key_per_server = v ? atoi(v) : 20; + const int total_key_num = num_servers * how_many_key_per_server; + + std::vector > server_vals; + std::vector > server_keys; + std::vector > server_lens; + for (int i = 0; i < total_key_num; i++) { + auto key = EncodeKey(i); + SArray vals; + auto addr = (char*) OpenSharedMemory(std::string("BytePS_ShM_"), key, len); + vals.reset((char*) addr, len, [](void *){}); + server_vals.push_back(vals); + } + + // init push, do not count this into time cost + for (int i = 0; i < total_key_num; i++) { + int server = i % num_servers; + auto vals = server_vals[i]; + + auto key = EncodeKey(i); + PS_VLOG(1) << "key=" << key + << " (i=" << i << ")" + << " assigned to server " << server; + + // page aligned keys + void* ptr_key; + aligned_memory_alloc(&ptr_key, sizeof(Key)); + SArray keys; + keys.reset((Key*) ptr_key, 1, [](void *){}); + ps::Key ps_key = krs[server].begin() + key; + memcpy(ptr_key, &ps_key, sizeof(Key)); + server_keys.push_back(keys); + + // page aligned vals + void* ptr_len; + aligned_memory_alloc(&ptr_len, sizeof(int)); + SArray lens; + lens.reset((int*) ptr_len, 1, [](void *){}); + memcpy(ptr_len, &len, sizeof(len)); + server_lens.push_back(lens); + + kv.Wait(kv.ZPush(keys, vals, lens)); + } + + push_pull(kv, server_keys, server_vals, server_lens, len, num_servers, total_key_num, how_many_key_per_server, mode); +} + +int main(int argc, char *argv[]) { + // disable multi-threaded processing first + setenv("BYTEPS_LOCAL_SIZE", "1", 1); + setenv("BYTEPS_ENABLE_IPC", "1", 1); + // start system + Start(0); + // setup server nodes + StartServer(); + // run worker nodes + RunWorker(argc, argv); + // stop system + 
Finalize(0, true); + // release shm + for (auto &it : _key_shm_addr) { + munmap(it.second, _key_shm_size[it.first]); + shm_unlink(it.first.c_str()); + } + return 0; +} diff --git a/tests/test_kv_app.cc b/tests/test_kv_app.cc deleted file mode 100644 index aa143304..00000000 --- a/tests/test_kv_app.cc +++ /dev/null @@ -1,65 +0,0 @@ -#include -#include "ps/ps.h" -#include - -using namespace ps; - -void StartServer() { - if (!IsServer()) { - return; - } - auto server = new KVServer(0); - server->set_request_handle(KVServerDefaultHandle()); - RegisterExitCallback([server](){ delete server; }); -} - -void RunWorker() { - if (!IsWorker()) return; - KVWorker kv(0, 0); - - // init - int num = 10000; - std::vector keys(num); - std::vector vals(num); - - int rank = MyRank(); - srand(rank + 7); - for (int i = 0; i < num; ++i) { - keys[i] = kMaxKey / num * i + rank; - vals[i] = (rand() % 1000); - } - - // push - int repeat = 50; - std::vector ts; - for (int i = 0; i < repeat; ++i) { - ts.push_back(kv.Push(keys, vals)); - - // to avoid too frequency push, which leads huge memory usage - if (i > 10) kv.Wait(ts[ts.size()-10]); - } - for (int t : ts) kv.Wait(t); - - // pull - std::vector rets; - kv.Wait(kv.Pull(keys, &rets)); - - float res = 0; - for (int i = 0; i < num; ++i) { - res += std::fabs(rets[i] - vals[i] * repeat); - } - CHECK_LT(res / repeat, 1e-5); - LL << "error: " << res / repeat; -} - -int main(int argc, char *argv[]) { - // start system - Start(0); - // setup server nodes - StartServer(); - // run worker nodes - RunWorker(); - // stop system - Finalize(0, true); - return 0; -} diff --git a/tests/test_kv_app_benchmark.cc b/tests/test_kv_app_benchmark.cc deleted file mode 100644 index 3bb6383e..00000000 --- a/tests/test_kv_app_benchmark.cc +++ /dev/null @@ -1,235 +0,0 @@ -#include -#include -#include -#include "ps/ps.h" - -using namespace ps; - -enum MODE { - PUSH_THEN_PULL = 0, - PUSH_PULL_MIX_ENDLESS = 1, - PUSH_ONLY = 2 -}; -std::unordered_map > mem_map; - 
-template -void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer *server) { - uint64_t key = req_data.keys[0]; - if (req_meta.push) { - CHECK(req_data.lens.size()); - CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]); - - if (mem_map.find(key) == mem_map.end()) { - PS_VLOG(1) << "key " << key << " from worker-" << req_meta.sender; - size_t len = (size_t) req_data.vals.size(); - mem_map[key].keys.push_back(key); - mem_map[key].vals.CopyFrom(req_data.vals); - mem_map[key].lens.push_back(len); - } - - // send push response (empty) - KVPairs res; - server->Response(req_meta, res); - } - else { - auto iter = mem_map.find(key); - CHECK_NE(iter, mem_map.end()); - server->Response(req_meta, iter->second); - } -} - -void StartServer() { - if (!IsServer()) return; - auto server = new KVServer(0); - server->set_request_handle(EmptyHandler); - RegisterExitCallback([server]() { delete server; }); -} - -struct PSKV { - SArray keys; // n keys - SArray lens; // the length of the i-th value -}; -std::unordered_map ps_kv_; - -void RunWorker(int argc, char *argv[]) { - if (!IsWorker()) return; - CHECK_GE(argc, 3) << "input argument should be at least 3: SCRIPT, LEN, REPEAT, (OPTIONAL) MODE"; - KVWorker kv(0, 0); - auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); - - const int num_servers = krs.size(); - LOG(INFO) << num_servers << " servers in total"; - CHECK_GT(num_servers, 0); - - // init - int len = atoi(argv[1]); - int repeat = atoi(argv[2]); - MODE mode = (argc > 3) ? 
static_cast(atoi(argv[3])) : PUSH_PULL_MIX_ENDLESS; - - std::vector > server_vals; - for (int server = 0; server < num_servers; server++) { - std::vector vec(len); - SArray vals(vec); - server_vals.push_back(vals); - } - - // init push, do not count this into time cos - for (int server = 0; server < num_servers; server++) { - int key = server; // could be other value - auto vals = server_vals[server]; - PSKV& pskv = ps_kv_[key]; - SArray keys; - ps::Key ps_key = krs[server].begin() + key; - keys.push_back(ps_key); - SArray lens; - lens.push_back(len); - pskv.keys.push_back(ps_key); - pskv.lens.push_back(len); - - kv.Wait(kv.ZPush(keys, vals, lens)); - } - - switch(mode) { - case PUSH_THEN_PULL: { - LOG(INFO) << "PUSH_THEN_PULL mode"; - // push - uint64_t accumulated_ms = 0; - for (int i = 0; i < repeat; ++i) { - auto start = std::chrono::high_resolution_clock::now(); - for (int server = 0; server < num_servers; server++) { - int key = server; - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; - auto vals = server_vals[server]; - - kv.Wait(kv.ZPush(keys, vals, lens)); - } - auto end = std::chrono::high_resolution_clock::now(); - accumulated_ms += (end - start).count(); // ns - } - LL << "push " << len * sizeof(float) - << " bytes to each server, repeat=" << repeat - << ", total_time=" - << accumulated_ms / 1e6 << "ms"; - - // pull - accumulated_ms = 0; - for (int i = 0; i < repeat; ++i) { - auto start = std::chrono::high_resolution_clock::now(); - for (int server = 0; server < num_servers; server++) { - int key = server; - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; - auto vals = server_vals[server]; - - kv.Wait(kv.ZPull(keys, &vals, &lens)); - } - auto end = std::chrono::high_resolution_clock::now(); - accumulated_ms += (end - start).count(); // ns - } - - LL << "pull " << len * sizeof(float) - << " bytes to each server, repeat=" << repeat - << ", total_time=" - << accumulated_ms / 1e6 << "ms"; - } - break; 
- - case PUSH_PULL_MIX_ENDLESS: { - LOG(INFO) << "PUSH_PULL_MIX_ENDLESS mode, should exit by Ctrl+C"; - std::vector timestamp_list; - auto start = std::chrono::high_resolution_clock::now(); - auto end = std::chrono::high_resolution_clock::now(); - auto val = Environment::Get()->find("THRESHOLD"); - unsigned int threshold = val ? atoi(val) : 10; - val = Environment::Get()->find("LOG_DURATION"); - unsigned int log_duration = val ? atoi(val) : 50; - int cnt = 0; - while (1) { - for (int server = 0; server < num_servers; server++) { - int key = server; - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; - auto vals = server_vals[server]; - - timestamp_list.push_back(kv.ZPush(keys, vals, lens)); - timestamp_list.push_back(kv.ZPull(keys, &vals, &lens)); - } - if (timestamp_list.size()/2/num_servers >= threshold) { // flow control - for (auto& ts : timestamp_list) { - kv.Wait(ts); - } - timestamp_list.clear(); - cnt++; - if (cnt % log_duration == 0) { - end = std::chrono::high_resolution_clock::now(); - LL << "Application goodput: " - << 8.0 * len * sizeof(float) * num_servers * cnt * threshold / (end - start).count() - << " Gbps"; - cnt = 0; - start = std::chrono::high_resolution_clock::now(); - } - } - } - } break; - case PUSH_ONLY: { - LOG(INFO) << "PUSH_ONLY mode, should exit by Ctrl+C"; - std::vector timestamp_list; - auto start = std::chrono::high_resolution_clock::now(); - auto end = std::chrono::high_resolution_clock::now(); - auto val = Environment::Get()->find("THRESHOLD"); - unsigned int threshold = val ? atoi(val) : 10; - val = Environment::Get()->find("LOG_DURATION"); - unsigned int log_duration = val ? 
atoi(val) : 50; - int cnt = 0; - while (1) { - for (int server = 0; server < num_servers; server++) { - int key = server; - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; - auto vals = server_vals[server]; - - timestamp_list.push_back(kv.ZPush(keys, vals, lens)); - } - if (timestamp_list.size()/num_servers >= threshold) { // flow control - for (auto& ts : timestamp_list) { - kv.Wait(ts); - } - timestamp_list.clear(); - cnt++; - if (cnt % log_duration == 0) { - end = std::chrono::high_resolution_clock::now(); - LL << "Application goodput: " - << 8.0 * len * sizeof(float) * num_servers * cnt * threshold / (end - start).count() - << " Gbps"; - cnt = 0; - start = std::chrono::high_resolution_clock::now(); - } - } - } - } break; - - default: - CHECK(0) << "unknown mode " << mode; - } - - -} - -int main(int argc, char *argv[]) { - // disable multi-threaded processing first - setenv("ENABLE_SERVER_MULTIPULL", "0", 1); - // start system - Start(0); - // setup server nodes - StartServer(); - // run worker nodes - RunWorker(argc, argv); - // stop system - Finalize(0, true); - return 0; -} diff --git a/tests/test_kv_app_ipc_benchmark.cc b/tests/test_kv_app_ipc_benchmark.cc deleted file mode 100644 index 2ecf9a95..00000000 --- a/tests/test_kv_app_ipc_benchmark.cc +++ /dev/null @@ -1,246 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "ps/ps.h" -#define DATA_TYPE char -using namespace ps; - -enum MODE { - PUSH_THEN_PULL = 0, - PUSH_PULL_MIX_ENDLESS = 1 -}; -std::unordered_map > mem_map; -std::unordered_map _key_shm_addr; -std::unordered_map _key_shm_size; -std::unordered_map store_; -std::mutex mu_; - -void* OpenSharedMemory(const std::string& prefix, - uint64_t key, size_t size) { - std::string shm_name(prefix); - shm_name += std::to_string(key); - int shm_fd = shm_open(shm_name.c_str(), O_CREAT | O_RDWR, 0666); - CHECK_GE(shm_fd, 0) << "shm_open failed for " << shm_name; - 
CHECK_GE(ftruncate(shm_fd, size), 0) << strerror(errno); - - void* ptr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); - CHECK_NE(ptr, (void*)-1) << strerror(errno); - - LOG(INFO) << "initialized share memory size=" << size - << " for key=" << key << ", name=" << shm_name; - _key_shm_addr[shm_name] = ptr; - _key_shm_size[shm_name] = size; - return ptr; -} - -uint64_t DecodeKey(ps::Key key) { - auto kr = ps::Postoffice::Get()->GetServerKeyRanges()[ps::MyRank()]; - return key - kr.begin(); -} - -template -void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer *server) { - std::lock_guard lk(mu_); - uint64_t key = DecodeKey(req_data.keys[0]); - if (req_meta.push) { - CHECK(req_data.lens.size()); - CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]); - - if (mem_map.find(key) == mem_map.end()) { - PS_VLOG(1) << "key " << key << " from worker-" << req_meta.sender; - size_t len = (size_t) req_data.vals.size(); - mem_map[key].keys.push_back(key); - mem_map[key].lens.push_back(len); - - store_[key] = (char*) malloc(len); - mem_map[key].vals = ps::SArray(store_[key], len, false); - } - - // send push response (empty) - KVPairs res; - server->Response(req_meta, res); - } else { // pull request - auto iter = mem_map.find(key); - CHECK_NE(iter, mem_map.end()); - server->Response(req_meta, iter->second); - } -} - -void StartServer() { - if (!IsServer()) return; - auto server = new KVServer(0); - server->set_request_handle(EmptyHandler); - RegisterExitCallback([server]() { delete server; }); -} - -struct PSKV { - SArray keys; // n keys - SArray lens; // the length of the i-th value -}; -std::unordered_map ps_kv_; - -uint64_t EncodeKey(uint64_t seed) { - return seed << 16; -} - -void RunWorker(int argc, char *argv[]) { - if (!IsWorker()) return; - CHECK_GE(argc, 3) << "input argument should be at least 3: SCRIPT, LEN, REPEAT, (OPTIONAL) MODE"; - KVWorker kv(0, 0); - auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); - - const int 
num_servers = krs.size(); - LOG(INFO) << num_servers << " servers in total"; - CHECK_GT(num_servers, 0); - - // init - auto val = Environment::Get()->find("BYTEPS_PARTITION_BYTES"); - unsigned int partition_bytes = val ? atoi(val) : 4096000; - int len = atoi(argv[1]); - CHECK_GE(partition_bytes, len) - << "tensor partition is not supported in this benchmark" - << ", try reduce tensor size or increase BYTEPS_PARTITION_BYTES"; - int repeat = atoi(argv[2]); - MODE mode = (argc > 3) ? static_cast(atoi(argv[3])) : PUSH_PULL_MIX_ENDLESS; - - std::vector > server_vals; - for (int server = 0; server < num_servers; server++) { - auto key = EncodeKey(server); - auto addr = (char*) OpenSharedMemory(std::string("BytePS_ShM_"), key, len); - SArray vals(addr, len, false); - server_vals.push_back(vals); - } - - // init broadcast - for (int server = 0; server < num_servers; server++) { - auto key = EncodeKey(server); - auto vals = server_vals[server]; - PSKV& pskv = ps_kv_[key]; - SArray keys; - ps::Key ps_key = krs[server].begin() + key; - keys.push_back(ps_key); - SArray lens; - lens.push_back(len); - pskv.keys.push_back(ps_key); - pskv.lens.push_back(len); - - kv.Wait(kv.ZPush(keys, vals, lens)); - } - - switch(mode) { - case PUSH_THEN_PULL: { - LOG(INFO) << "PUSH_THEN_PULL mode"; - // push - uint64_t accumulated_ms = 0; - for (int i = 0; i < repeat; ++i) { - auto start = std::chrono::high_resolution_clock::now(); - for (int server = 0; server < num_servers; server++) { - auto key = EncodeKey(server); - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; - auto vals = server_vals[server]; - - kv.Wait(kv.ZPush(keys, vals, lens)); - } - auto end = std::chrono::high_resolution_clock::now(); - accumulated_ms += (end - start).count(); // ns - } - LL << "push " << len * sizeof(DATA_TYPE) - << " bytes to each server, repeat=" << repeat - << ", total_time=" - << accumulated_ms / 1e6 << "ms"; - - // pull - accumulated_ms = 0; - for (int i = 0; i < repeat; ++i) { 
- auto start = std::chrono::high_resolution_clock::now(); - for (int server = 0; server < num_servers; server++) { - auto key = EncodeKey(server); - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; - auto vals = server_vals[server]; - - kv.Wait(kv.ZPull(keys, &vals, &lens)); - } - auto end = std::chrono::high_resolution_clock::now(); - accumulated_ms += (end - start).count(); // ns - } - - LL << "pull " << len * sizeof(DATA_TYPE) - << " bytes to each server, repeat=" << repeat - << ", total_time=" - << accumulated_ms / 1e6 << "ms"; - } - break; - - case PUSH_PULL_MIX_ENDLESS: { - LOG(INFO) << "PUSH_PULL_MIX_ENDLESS mode, should exit by Ctrl+C"; - std::vector timestamp_list; - auto start = std::chrono::high_resolution_clock::now(); - auto end = std::chrono::high_resolution_clock::now(); - auto val = Environment::Get()->find("THRESHOLD"); - unsigned int threshold = val ? atoi(val) : 10; - val = Environment::Get()->find("LOG_DURATION"); - unsigned int log_duration = val ? 
atoi(val) : 50; - int cnt = 0; - while (1) { - for (int server = 0; server < num_servers; server++) { - auto key = EncodeKey(server); - PSKV& pskv = ps_kv_[key]; - auto keys = pskv.keys; - auto lens = pskv.lens; - auto vals = server_vals[server]; - timestamp_list.push_back(kv.ZPush(keys, vals, lens)); - timestamp_list.push_back(kv.ZPull(keys, &vals, &lens)); - } - if (timestamp_list.size()/2/num_servers >= threshold) { // flow control - for (auto& ts : timestamp_list) { - kv.Wait(ts); - } - timestamp_list.clear(); - cnt++; - if (cnt % log_duration == 0) { - end = std::chrono::high_resolution_clock::now(); - LL << "Application goodput: " - << 8.0 * len * sizeof(DATA_TYPE) * num_servers * cnt * threshold / (end - start).count() - << " Gbps"; - cnt = 0; - start = std::chrono::high_resolution_clock::now(); - } - } - } - } break; - default: - CHECK(0) << "unknown mode " << mode; - } -} - -int main(int argc, char *argv[]) { - // disable multi-threaded processing first - setenv("ENABLE_SERVER_MULTIPULL", "0", 1); - setenv("BYTEPS_LOCAL_SIZE", "1", 1); - setenv("BYTEPS_ENABLE_IPC", "1", 0); - // start system - Start(0); - // setup server nodes - StartServer(); - // run worker nodes - RunWorker(argc, argv); - // stop system - Finalize(0, true); - // release shm - for (auto &it : _key_shm_addr) { - munmap(it.second, _key_shm_size[it.first]); - shm_unlink(it.first.c_str()); - } - return 0; -} diff --git a/tests/test_kv_app_multi_servers.cc b/tests/test_kv_app_multi_servers.cc deleted file mode 100644 index d4cc44c9..00000000 --- a/tests/test_kv_app_multi_servers.cc +++ /dev/null @@ -1,108 +0,0 @@ -/* -This code only works for 1 worker VS N server -*/ -#include -#include -#include -#include "ps/ps.h" - -using namespace ps; - -std::unordered_map > mem_map; - -template -void EmptyHandler(const KVMeta &req_meta, const KVPairs &req_data, KVServer *server) { - uint64_t key = req_data.keys[0]; - if (req_meta.push) { - CHECK(req_data.lens.size()); - CHECK_EQ(req_data.vals.size(), 
(size_t)req_data.lens[0]); - - if (mem_map.find(key) == mem_map.end()) { - size_t len = (size_t) req_data.vals.size(); - mem_map[key].keys.push_back(key); - mem_map[key].vals.CopyFrom(req_data.vals); - mem_map[key].lens.push_back(len); - } - - // send push response (empty) - KVPairs res; - server->Response(req_meta, res); - } - else { - auto iter = mem_map.find(key); - CHECK_NE(iter, mem_map.end()); - server->Response(req_meta, iter->second); - } -} - -void StartServer() { - if (!IsServer()) return; - int myrank = Postoffice::Get()->my_rank(); - LOG(INFO) << "This is server " << myrank; - auto server = new KVServer(0); - server->set_request_handle(EmptyHandler); - RegisterExitCallback([server]() { delete server; }); -} - -void RunWorker(int argc, char *argv[]) { - if (!IsWorker()) return; - CHECK_EQ(argc, 2) << "input argument should be: [SCRIPT, LEN]"; - KVWorker kv(0, 0); - - auto krs = Postoffice::Get()->GetServerKeyRanges(); - const int num_servers = krs.size(); - CHECK_GT(num_servers, 0); - - // init - int len = atoi(argv[1]); - - std::vector vec(len); - - std::vector > keys(num_servers); - std::vector > lens(num_servers); - std::vector > vals; - - for (int i = 0; i < num_servers ; ++i) { - int key = i; - int server = (key * 9973) % num_servers; - ps::Key ps_key = krs[server].begin() + key; - CHECK_LT(ps_key, krs[server].end()); - - SArray tmp_vals(vec); - - keys[i].push_back(ps_key); - lens[i].push_back(len); - vals.push_back(tmp_vals); - - // init push, to register memory, better not count this into time cost - kv.Wait(kv.ZPush(keys[i], vals[i], lens[i])); - } - - std::vector timestamp_list; - while (1) { - for (int j = 0; j < num_servers; ++j) { - timestamp_list.push_back(kv.ZPush(keys[j], vals[j], lens[j])); - timestamp_list.push_back(kv.ZPull(keys[j], &vals[j], &lens[j])); - } - if (timestamp_list.size() >= 30) { // flow control - for (auto& ts : timestamp_list) { - kv.Wait(ts); - } - timestamp_list.clear(); - } - } -} - -int main(int argc, char *argv[]) 
{ - // disable multi-threaded processing first - setenv("ENABLE_SERVER_MULTIPULL", "0", 1); - // start system - Start(0); - // setup server nodes - StartServer(); - // run worker nodes - RunWorker(argc, argv); - // stop system - Finalize(0, true); - return 0; -} diff --git a/tests/test_kv_app_multi_workers.cc b/tests/test_kv_app_multi_workers.cc deleted file mode 100644 index 636d494b..00000000 --- a/tests/test_kv_app_multi_workers.cc +++ /dev/null @@ -1,71 +0,0 @@ -#include -#include "ps/ps.h" -using namespace ps; - -void StartServer() { - if (!IsServer()) return; - auto server = new KVServer(0); - server->set_request_handle(KVServerDefaultHandle()); - RegisterExitCallback([server](){ delete server; }); -} - -void RunWorker(int customer_id) { - Start(customer_id); - if (!IsWorker()) { - return; - } - KVWorker kv(0, customer_id); - // init - int num = 10000; - std::vector keys(num); - std::vector vals(num); - - int rank = MyRank(); - srand(rank + 7); - for (int i = 0; i < num; ++i) { - keys[i] = kMaxKey / num * i + customer_id; - vals[i] = (rand() % 1000); - } - // push - int repeat = 50; - std::vector ts; - for (int i = 0; i < repeat; ++i) { - ts.push_back(kv.Push(keys, vals)); - - // to avoid too frequency push, which leads huge memory usage - if (i > 10) kv.Wait(ts[ts.size()-10]); - } - for (int t : ts) kv.Wait(t); - - // pull - std::vector rets; - kv.Wait(kv.Pull(keys, &rets)); - - float res = 0; - for (int i = 0; i < num; ++i) { - res += fabs(rets[i] - vals[i] * repeat); - } - CHECK_LT(res / repeat, 1e-5); - LL << "error: " << res / repeat; - // stop system - Finalize(customer_id, true); -} - -int main(int argc, char *argv[]) { - // start system - bool isWorker = (strcmp(argv[1], "worker") == 0); - if (!isWorker) { - Start(0); - // setup server nodes - StartServer(); - Finalize(0, true); - return 0; - } - // run worker nodes - std::thread t0(RunWorker, 0); - std::thread t1(RunWorker, 1); - - t0.join(); - t1.join(); - return 0; -} diff --git 
a/tests/test_simple_app.cc b/tests/test_simple_app.cc deleted file mode 100644 index cecf5de7..00000000 --- a/tests/test_simple_app.cc +++ /dev/null @@ -1,35 +0,0 @@ -#include "ps/ps.h" -using namespace ps; - -int num = 0; - -void ReqHandle(const SimpleData& req, SimpleApp* app) { - CHECK_EQ(req.head, 1); - CHECK_EQ(req.body, "test"); - app->Response(req); - ++ num; -} - -int main(int argc, char *argv[]) { - int n = 100; - Start(0); - SimpleApp app(0, 0); - app.set_request_handle(ReqHandle); - - if (IsScheduler()) { - std::vector ts; - for (int i = 0; i < n; ++i) { - int recver = kScheduler + kServerGroup + kWorkerGroup; - ts.push_back(app.Request(1, "test", recver)); - } - - for (int t : ts) { - app.Wait(t); - } - } - - Finalize(0, true); - - CHECK_EQ(num, n); - return 0; -}