Skip to content

Commit

Permalink
rdma: fix WR flush error (#2)
Browse files Browse the repository at this point in the history
* rdma: fix server bug
Loading branch information
ymjiang authored Sep 22, 2019
1 parent 101ea1f commit a6ddd1f
Showing 1 changed file with 49 additions and 92 deletions.
141 changes: 49 additions & 92 deletions src/rdma_van.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* Author: [email protected] (Chang Lan)
* [email protected] (Yimin Jiang)
* [email protected] (Jingrong Chen)
*/
*/
#ifndef PS_RDMA_VAN_H_
#define PS_RDMA_VAN_H_

Expand Down Expand Up @@ -34,39 +34,19 @@

namespace ps {

// Number of context buffers for sending START messages
static const int kStartDepth = 128;

// Number of context buffers for writing messages
static const int kWriteDepth = kStartDepth;

// Number of context buffers for receiving messages
static const int kRxDepth = kStartDepth * 2;

// Number of context buffers for sending REPLY messages
static const int kReplyDepth = kRxDepth;

// Maximum number of scatter/gather elements in any Work Request
static const int kSGEntry = 4;

// Time to wait for resolution to complete (in milliseconds)
static const int kTimeoutms = 1000;

// Number of backlog of incoming connection requests
static const int kRdmaListenBacklog = 128;

// Number of preallocated work request buffers
static const int kMaxConcurrentWorkRequest =
kRxDepth + kStartDepth + kReplyDepth + kWriteDepth;

// Length of buffers for storing hostname in the context of a connection request
static const int kMaxHostnameLength = 16;

// Maximum number of ``data'' in a Message
// TODO(changlan): What if there are more data in Message?
static const int kMaxDataFields = 4;

// Alignment in Mempool
static const size_t kAlignment = 8;

template <typename T>
Expand All @@ -79,11 +59,9 @@ static inline T align_ceil(T v, T align) {
return align_floor(v + align - 1, align);
}

// A simple thread-safe memory pool for RDMA memory regions
class SimpleMempool {
public:
// Allocated an initial ``size'' of registered memory regions
explicit SimpleMempool(struct ibv_pd *pd, size_t size = 0x1000000) {
explicit SimpleMempool(struct ibv_pd *pd, size_t size = 0x10000000) {
std::lock_guard<std::mutex> lk(mu_);
pd_ = pd;
struct ibv_mr *mr;
Expand All @@ -92,54 +70,42 @@ class SimpleMempool {
CHECK(p);
CHECK(mr = ibv_reg_mr(pd, p, size,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE));
// this mr is associated with memory address range [p, p+size]
mr_list.emplace(p + size, mr);
mr_list.emplace(p+size, mr); // this mr is associated with memory address range [p, p+size]
free_list.emplace(size, p);
}

// Deregister and release all memory regions
~SimpleMempool() {
std::lock_guard<std::mutex> lk(mu_);
for (auto it = mr_list.begin(); it != mr_list.end(); it++) {
for(auto it = mr_list.begin(); it != mr_list.end(); it++){
CHECK_EQ(ibv_dereg_mr(it->second), 0);
free(it->second->addr);
}
}

// Take a buffer of ``size'' from the pool. If there is not enough remaining
// space in existing memory regions, allocate and register a new memory
// region.
char *Alloc(size_t size) {
if (size == 0) {
return nullptr;
}

std::lock_guard<std::mutex> lk(mu_);

// Make sure the memory addresses are aligned by rounding the size up to
// next power of two
size_t proper_size = align_ceil(size, kAlignment);

// Find a buffer of size greater than or equal to proper_size
auto it = free_list.lower_bound(proper_size);

if (it == free_list.end()) { // if there is no space left, need to allocate
// and register new memory
if (it == free_list.end()) { // if there is no space left, need to allocate and register new memory
size_t new_mem_size = total_allocated_size;
while (proper_size > new_mem_size) {
new_mem_size *= 2;
}
char *p =
reinterpret_cast<char *>(aligned_alloc(kAlignment, new_mem_size));
char *p = reinterpret_cast<char *>(aligned_alloc(kAlignment, new_mem_size));
CHECK(p);
struct ibv_mr *mr;
CHECK(mr = ibv_reg_mr(pd_, p, new_mem_size,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE));
mr_list.emplace(p + new_mem_size, mr);
CHECK(mr = ibv_reg_mr(pd_, p, new_mem_size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE));
mr_list.emplace(p+new_mem_size, mr);
free_list.emplace(new_mem_size, p);
it = free_list.lower_bound(proper_size);
PS_VLOG(1) << "Not enough memory in the pool, requested size "
<< proper_size << ", new allocated size " << new_mem_size;
PS_VLOG(1) << "Not enough memory in the pool, requested size " << proper_size << ", new allocated size " << new_mem_size;
total_allocated_size += new_mem_size;
}

Expand All @@ -162,7 +128,6 @@ class SimpleMempool {
return addr;
}

// Return the buffer pointed by ``addr'' into the pool
void Free(char *addr) {
if (!addr) {
return;
Expand All @@ -183,33 +148,29 @@ class SimpleMempool {
struct ibv_mr *mr = Addr2MR(addr);
return mr->lkey;
}

// Look up the remote access key (rkey) of the registered memory
// region that contains ``addr``. The rkey is what a peer must supply
// to target this buffer with RDMA operations.
uint32_t RemoteKey(char *addr) {
return Addr2MR(addr)->rkey;
}

private:
std::mutex mu_; // for thread safety
struct ibv_pd *pd_;

// buffer size -> buffer pointer
std::mutex mu_;
std::multimap<size_t, char *> free_list;
// buffer pointer -> buffer size
std::unordered_map<char *, size_t> used_list;
// first: `end` of this mr address (e.g., for mr with [addr, addr+size), point
// to `addr+size`)
std::map<char *, struct ibv_mr *> mr_list;

struct ibv_pd *pd_;
size_t total_allocated_size = 0;

// Convert the memory address to its associated RDMA memory region
inline struct ibv_mr *Addr2MR(char *addr) {
// first: `end` of this mr address (e.g., for mr with [addr, addr+size], point to `addr+size`)
std::map<char *, struct ibv_mr*> mr_list;

// Convert a memory address to its associated RDMA memory region.
// mr_list is keyed by the *end* address of each region (addr + size),
// so lower_bound(addr) returns the first region whose end is >= addr,
// i.e. the region covering addr for any address handed out by this pool.
// CHECK-fails (aborts) if no region covers addr.
inline struct ibv_mr* Addr2MR(char *addr) {
std::lock_guard<std::mutex> lk(mu_);  // mr_list may be mutated concurrently by Alloc
auto it = mr_list.lower_bound(addr);
CHECK_NE(it, mr_list.end()) << "cannot find the associated memory region";
return it->second;
}

};

class Block {
Expand Down Expand Up @@ -274,12 +235,15 @@ struct BufferContext {
size_t data_len[kMaxDataFields];
};

// Smart-pointer alias for an ibv_mr with a caller-supplied deleter.
// MessageBuffer stores these with a no-op deleter, i.e. as non-owning
// handles to memory regions registered elsewhere.
using MRPtr =
    std::unique_ptr<struct ibv_mr, std::function<void(struct ibv_mr *)>>;

// Staging buffer for one outgoing message.
struct MessageBuffer {
size_t inline_len;    // number of valid bytes in inline_buf
char *inline_buf;     // mempool-allocated buffer holding the serialized meta
                      // (and the data too, for control/simple_app messages)
WRContext *reserved_context;  // work-request context reserved for this send
                              // NOTE(review): ownership/lifetime not visible here -- confirm
std::vector<SArray<char>> data;  // zero-copy references to the message's data segments
std::vector<std::pair<MRPtr, size_t>> mrs;  // (memory region, byte length) per data segment;
                                            // MRPtr carries a no-op deleter, so non-owning
};

struct RequestContext {
Expand Down Expand Up @@ -522,8 +486,10 @@ class RDMAVan : public Van {
PS_VLOG(1) << "Clearing mempool.";
mempool_.reset();

for (auto &it : allocated_mr_) {
ibv_dereg_mr(it.second);
auto map_iter = memory_mr_map.begin();
while (map_iter != memory_mr_map.end()) {
ibv_dereg_mr(map_iter->second);
map_iter++;
}

PS_VLOG(1) << "Clearing endpoints.";
Expand All @@ -535,7 +501,9 @@ class RDMAVan : public Van {
CHECK(!ibv_destroy_comp_channel(comp_event_channel_))
<< "Failed to destroy channel";

// TODO: ibv_dealloc_pd sometimes complains about busy resources
// TODO: ibv_dealloc_pd sometimes complains resource busy, need to fix this
// CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD: " <<
// strerror(errno);

PS_VLOG(1) << "Destroying listener.";
rdma_destroy_id(listener_);
Expand Down Expand Up @@ -653,19 +621,19 @@ class RDMAVan : public Van {
int remote_id = msg.meta.recver;
CHECK_NE(remote_id, Meta::kEmpty);

for (auto &sa : msg.data) {
for (auto& sa : msg.data) {
if (sa.size()) {
std::lock_guard<std::mutex> lock(map_mu_);
auto search_map_iterator = allocated_mr_.find(sa.data());
if (search_map_iterator == allocated_mr_.end()) {
auto search_map_iterator = memory_mr_map.find(sa.data());
if (search_map_iterator == memory_mr_map.end()) {
struct ibv_mr *temp_mr;
CHECK(sa.data()) << "address empty";
CHECK (temp_mr = ibv_reg_mr(pd_, sa.data(), sa.size(),
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE))
<< "Failed to register the memory region: "
<< strerror(errno)
<< ", sa.size()=" << sa.size();
allocated_mr_[sa.data()] = temp_mr;
memory_mr_map[sa.data()] = temp_mr;
}
}
}
Expand All @@ -679,12 +647,12 @@ class RDMAVan : public Van {

if (msg.meta.push && msg.meta.request) { // push request
CHECK_EQ(msg.data.size(), 3) << msg.data.size();
CHECK_NE(allocated_mr_.find(msg.data[1].data()), allocated_mr_.end());
CHECK_NE(memory_mr_map.find(msg.data[1].data()), memory_mr_map.end());

auto& vals = msg.data[1];
msg.meta.addr = reinterpret_cast<uint64_t>(vals.data()); // vals address
msg.meta.val_len = vals.size();
msg.meta.option = allocated_mr_[vals.data()]->rkey;
msg.meta.option = memory_mr_map[vals.data()]->rkey;

if (enable_rdma_log_) {
LOG(INFO) << "send push key=" << key
Expand Down Expand Up @@ -733,12 +701,7 @@ class RDMAVan : public Van {

CHECK(meta_len);

// For control messages, inline the message content
// into the START message.
// Otherwise, register the data buffer as RDMA memory
// region.
if (msg.meta.simple_app ||
!msg.meta.control.empty()) { // simple_app or control message
if (msg.meta.simple_app || !msg.meta.control.empty()){ // simple_app or control message
msg_buf->inline_len = total_len;
msg_buf->inline_buf = mempool_->Alloc(total_len);
meta.SerializeToArray(msg_buf->inline_buf, meta_len);
Expand All @@ -748,26 +711,20 @@ class RDMAVan : public Van {
memcpy(cur, sa.data(), seg_len);
cur += seg_len;
}
} else { // data message
} else { // data message
msg_buf->inline_len = meta_len;
msg_buf->inline_buf = mempool_->Alloc(meta_len);
msg_buf->data = msg.data;
meta.SerializeToArray(msg_buf->inline_buf, meta_len);
if (!is_server) { // worker remains the same
for (auto &sa : msg_buf->data) {
if (sa.size() == 0) {
continue;
}
// Optimization: If the memory region has been registered,
// (assuming the previously registered address is not freed)
// re-use the same memory region.
char *p = sa.data();
auto it = allocated_mr_.find(p);
if (it == allocated_mr_.end()) {
allocated_mr_[p] = ibv_reg_mr(pd_, p, sa.size(), 0);
if (sa.size()) {
auto search_map_iterator = memory_mr_map.find(sa.data());
CHECK_NE(search_map_iterator, memory_mr_map.end()) << "not registered memory region";
MRPtr ptr(search_map_iterator->second, [](struct ibv_mr *mr) {});
CHECK(ptr.get()) << strerror(errno);
msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size()));
}
CHECK(allocated_mr_[p]) << "Invalid memory region";
msg_buf->mrs.push_back({allocated_mr_[p], sa.size()});
}
}
}
Expand All @@ -790,11 +747,11 @@ class RDMAVan : public Van {
auto raddr = std::get<1>(key_meta_map_[key][recver]);
auto rkey = std::get<2>(key_meta_map_[key][recver]);

CHECK_EQ(msg_buf->data[1].size(), (unsigned int)len)
<< msg_buf->data[1].size() << ", " << len;
CHECK_EQ(msg_buf->data[1].size(), (unsigned int) len)
<< msg_buf->data[1].size() << ", " << len;

auto temp_mr = allocated_mr_.find(msg_buf->data[1].data());
CHECK_NE(temp_mr, allocated_mr_.end());
auto temp_mr = memory_mr_map.find(msg_buf->data[1].data());
CHECK_NE(temp_mr, memory_mr_map.end());

struct ibv_sge sge;
sge.addr = reinterpret_cast<uint64_t>(msg_buf->data[1].data());
Expand All @@ -815,7 +772,7 @@ class RDMAVan : public Van {
wr.wr.rdma.rkey = rkey;

CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0)
<< "ibv_post_send failed.";
<< "ibv_post_send failed.";
}

WRContext *context = nullptr, *reserved = nullptr;
Expand Down Expand Up @@ -1345,7 +1302,7 @@ class RDMAVan : public Van {
struct rdma_event_channel *event_channel_ = nullptr;
struct ibv_context *context_ = nullptr;

std::unordered_map<char *, struct ibv_mr *> allocated_mr_;
std::unordered_map<char *, struct ibv_mr *> memory_mr_map;

// ibverbs protection domain
struct ibv_pd *pd_ = nullptr;
Expand Down Expand Up @@ -1383,4 +1340,4 @@ class RDMAVan : public Van {
}; // namespace ps

#endif // DMLC_USE_RDMA
#endif // PS_RDMA_VAN_H_
#endif // PS_RDMA_VAN_H_

0 comments on commit a6ddd1f

Please sign in to comment.