-
Notifications
You must be signed in to change notification settings - Fork 487
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* rdma: fix server bug
- Loading branch information
Showing
1 changed file
with
49 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
* Author: [email protected] (Chang Lan) | ||
* [email protected] (Yimin Jiang) | ||
* [email protected] (Jingrong Chen) | ||
*/ | ||
*/ | ||
#ifndef PS_RDMA_VAN_H_ | ||
#define PS_RDMA_VAN_H_ | ||
|
||
|
@@ -34,39 +34,19 @@ | |
|
||
namespace ps { | ||
|
||
// Number of context buffers for sending START messages | ||
static const int kStartDepth = 128; | ||
|
||
// Number of context buffers for writing messages | ||
static const int kWriteDepth = kStartDepth; | ||
|
||
// Number of context buffers for receiving messages | ||
static const int kRxDepth = kStartDepth * 2; | ||
|
||
// Number of context buffers for sending REPLY messages | ||
static const int kReplyDepth = kRxDepth; | ||
|
||
// Maximum number of scatter/gather elements in any Work Request | ||
static const int kSGEntry = 4; | ||
|
||
// Time to wait for resolution to complete (in milliseconds) | ||
static const int kTimeoutms = 1000; | ||
|
||
// Number of backlog of incoming connection requests | ||
static const int kRdmaListenBacklog = 128; | ||
|
||
// Number of preallocated work request buffers | ||
static const int kMaxConcurrentWorkRequest = | ||
kRxDepth + kStartDepth + kReplyDepth + kWriteDepth; | ||
|
||
// Length of buffers for storing hostname in the context of a connection request | ||
static const int kMaxHostnameLength = 16; | ||
|
||
// Maximum number of ``data'' in a Message | ||
// TODO(changlan): What if there are more data in Message? | ||
static const int kMaxDataFields = 4; | ||
|
||
// Alignment in Mempool | ||
static const size_t kAlignment = 8; | ||
|
||
template <typename T> | ||
|
@@ -79,11 +59,9 @@ static inline T align_ceil(T v, T align) { | |
return align_floor(v + align - 1, align); | ||
} | ||
|
||
// A simple thread-safe memory pool for RDMA memory regions | ||
class SimpleMempool { | ||
public: | ||
// Allocate and register an initial memory region of ``size'' bytes | |
explicit SimpleMempool(struct ibv_pd *pd, size_t size = 0x1000000) { | ||
explicit SimpleMempool(struct ibv_pd *pd, size_t size = 0x10000000) { | ||
std::lock_guard<std::mutex> lk(mu_); | ||
pd_ = pd; | ||
struct ibv_mr *mr; | ||
|
@@ -92,54 +70,42 @@ class SimpleMempool { | |
CHECK(p); | ||
CHECK(mr = ibv_reg_mr(pd, p, size, | ||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)); | ||
// this mr is associated with memory address range [p, p+size] | ||
mr_list.emplace(p + size, mr); | ||
mr_list.emplace(p+size, mr); // this mr is associated with memory address range [p, p+size] | ||
free_list.emplace(size, p); | ||
} | ||
|
||
// Deregister and release all memory regions | ||
~SimpleMempool() { | ||
std::lock_guard<std::mutex> lk(mu_); | ||
for (auto it = mr_list.begin(); it != mr_list.end(); it++) { | ||
for(auto it = mr_list.begin(); it != mr_list.end(); it++){ | ||
CHECK_EQ(ibv_dereg_mr(it->second), 0); | ||
free(it->second->addr); | ||
} | ||
} | ||
|
||
// Take a buffer of ``size'' from the pool. If there is not enough remaining | ||
// space in existing memory regions, allocate and register a new memory | ||
// region. | ||
char *Alloc(size_t size) { | ||
if (size == 0) { | ||
return nullptr; | ||
} | ||
|
||
std::lock_guard<std::mutex> lk(mu_); | ||
|
||
// Keep buffer addresses aligned by rounding the size up to the | |
// next multiple of kAlignment | |
size_t proper_size = align_ceil(size, kAlignment); | ||
|
||
// Find a buffer of size greater than or equal to proper_size | ||
auto it = free_list.lower_bound(proper_size); | ||
|
||
if (it == free_list.end()) { // if there is no space left, need to allocate | ||
// and register new memory | ||
if (it == free_list.end()) { // if there is no space left, need to allocate and register new memory | ||
size_t new_mem_size = total_allocated_size; | ||
while (proper_size > new_mem_size) { | ||
new_mem_size *= 2; | ||
} | ||
char *p = | ||
reinterpret_cast<char *>(aligned_alloc(kAlignment, new_mem_size)); | ||
char *p = reinterpret_cast<char *>(aligned_alloc(kAlignment, new_mem_size)); | ||
CHECK(p); | ||
struct ibv_mr *mr; | ||
CHECK(mr = ibv_reg_mr(pd_, p, new_mem_size, | ||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)); | ||
mr_list.emplace(p + new_mem_size, mr); | ||
CHECK(mr = ibv_reg_mr(pd_, p, new_mem_size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)); | ||
mr_list.emplace(p+new_mem_size, mr); | ||
free_list.emplace(new_mem_size, p); | ||
it = free_list.lower_bound(proper_size); | ||
PS_VLOG(1) << "Not enough memory in the pool, requested size " | ||
<< proper_size << ", new allocated size " << new_mem_size; | ||
PS_VLOG(1) << "Not enough memory in the pool, requested size " << proper_size << ", new allocated size " << new_mem_size; | ||
total_allocated_size += new_mem_size; | ||
} | ||
|
||
|
@@ -162,7 +128,6 @@ class SimpleMempool { | |
return addr; | ||
} | ||
|
||
// Return the buffer pointed to by ``addr'' to the pool | |
void Free(char *addr) { | ||
if (!addr) { | ||
return; | ||
|
@@ -183,33 +148,29 @@ class SimpleMempool { | |
struct ibv_mr *mr = Addr2MR(addr); | ||
return mr->lkey; | ||
} | ||
|
||
uint32_t RemoteKey(char *addr) { | ||
struct ibv_mr *mr = Addr2MR(addr); | ||
return mr->rkey; | ||
} | ||
|
||
private: | ||
std::mutex mu_; // for thread safety | ||
struct ibv_pd *pd_; | ||
|
||
// buffer size -> buffer pointer | ||
std::mutex mu_; | ||
std::multimap<size_t, char *> free_list; | ||
// buffer pointer -> buffer size | ||
std::unordered_map<char *, size_t> used_list; | ||
// first: `end` of this mr address (e.g., for mr with [addr, addr+size), point | ||
// to `addr+size`) | ||
std::map<char *, struct ibv_mr *> mr_list; | ||
|
||
struct ibv_pd *pd_; | ||
size_t total_allocated_size = 0; | ||
|
||
// Convert the memory address to its associated RDMA memory region | ||
inline struct ibv_mr *Addr2MR(char *addr) { | ||
// first: `end` of this mr address (e.g., for mr with [addr, addr+size], point to `addr+size`) | ||
std::map<char *, struct ibv_mr*> mr_list; | ||
|
||
// convert the memory address to its associated RDMA memory region | ||
inline struct ibv_mr* Addr2MR(char *addr) { | ||
std::lock_guard<std::mutex> lk(mu_); | ||
auto it = mr_list.lower_bound(addr); | ||
CHECK_NE(it, mr_list.end()) << "cannot find the associated memory region"; | ||
return it->second; | ||
} | ||
|
||
}; | ||
|
||
class Block { | ||
|
@@ -274,12 +235,15 @@ struct BufferContext { | |
size_t data_len[kMaxDataFields]; | ||
}; | ||
|
||
typedef std::unique_ptr<struct ibv_mr, std::function<void(struct ibv_mr *)>> | ||
MRPtr; | ||
|
||
struct MessageBuffer { | ||
size_t inline_len; | ||
char *inline_buf; | ||
WRContext *reserved_context; | ||
std::vector<SArray<char>> data; | ||
std::vector<std::pair<struct ibv_mr *, size_t>> mrs; | ||
std::vector<std::pair<MRPtr, size_t>> mrs; | ||
}; | ||
|
||
struct RequestContext { | ||
|
@@ -522,8 +486,10 @@ class RDMAVan : public Van { | |
PS_VLOG(1) << "Clearing mempool."; | ||
mempool_.reset(); | ||
|
||
for (auto &it : allocated_mr_) { | ||
ibv_dereg_mr(it.second); | ||
auto map_iter = memory_mr_map.begin(); | ||
while (map_iter != memory_mr_map.end()) { | ||
ibv_dereg_mr(map_iter->second); | ||
map_iter++; | ||
} | ||
|
||
PS_VLOG(1) << "Clearing endpoints."; | ||
|
@@ -535,7 +501,9 @@ class RDMAVan : public Van { | |
CHECK(!ibv_destroy_comp_channel(comp_event_channel_)) | ||
<< "Failed to destroy channel"; | ||
|
||
// TODO: ibv_dealloc_pd sometimes complains about busy resources | ||
// TODO: ibv_dealloc_pd sometimes complains resource busy, need to fix this | ||
// CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD: " << | ||
// strerror(errno); | ||
|
||
PS_VLOG(1) << "Destroying listener."; | ||
rdma_destroy_id(listener_); | ||
|
@@ -653,19 +621,19 @@ class RDMAVan : public Van { | |
int remote_id = msg.meta.recver; | ||
CHECK_NE(remote_id, Meta::kEmpty); | ||
|
||
for (auto &sa : msg.data) { | ||
for (auto& sa : msg.data) { | ||
if (sa.size()) { | ||
std::lock_guard<std::mutex> lock(map_mu_); | ||
auto search_map_iterator = allocated_mr_.find(sa.data()); | ||
if (search_map_iterator == allocated_mr_.end()) { | ||
auto search_map_iterator = memory_mr_map.find(sa.data()); | ||
if (search_map_iterator == memory_mr_map.end()) { | ||
struct ibv_mr *temp_mr; | ||
CHECK(sa.data()) << "address empty"; | ||
CHECK (temp_mr = ibv_reg_mr(pd_, sa.data(), sa.size(), | ||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE)) | ||
<< "Failed to register the memory region: " | ||
<< strerror(errno) | ||
<< ", sa.size()=" << sa.size(); | ||
allocated_mr_[sa.data()] = temp_mr; | ||
memory_mr_map[sa.data()] = temp_mr; | ||
} | ||
} | ||
} | ||
|
@@ -679,12 +647,12 @@ class RDMAVan : public Van { | |
|
||
if (msg.meta.push && msg.meta.request) { // push request | ||
CHECK_EQ(msg.data.size(), 3) << msg.data.size(); | ||
CHECK_NE(allocated_mr_.find(msg.data[1].data()), allocated_mr_.end()); | ||
CHECK_NE(memory_mr_map.find(msg.data[1].data()), memory_mr_map.end()); | ||
|
||
auto& vals = msg.data[1]; | ||
msg.meta.addr = reinterpret_cast<uint64_t>(vals.data()); // vals address | ||
msg.meta.val_len = vals.size(); | ||
msg.meta.option = allocated_mr_[vals.data()]->rkey; | ||
msg.meta.option = memory_mr_map[vals.data()]->rkey; | ||
|
||
if (enable_rdma_log_) { | ||
LOG(INFO) << "send push key=" << key | ||
|
@@ -733,12 +701,7 @@ class RDMAVan : public Van { | |
|
||
CHECK(meta_len); | ||
|
||
// For control messages, inline the message content | ||
// into the START message. | ||
// Otherwise, register the data buffer as RDMA memory | ||
// region. | ||
if (msg.meta.simple_app || | ||
!msg.meta.control.empty()) { // simple_app or control message | ||
if (msg.meta.simple_app || !msg.meta.control.empty()){ // simple_app or control message | ||
msg_buf->inline_len = total_len; | ||
msg_buf->inline_buf = mempool_->Alloc(total_len); | ||
meta.SerializeToArray(msg_buf->inline_buf, meta_len); | ||
|
@@ -748,26 +711,20 @@ class RDMAVan : public Van { | |
memcpy(cur, sa.data(), seg_len); | ||
cur += seg_len; | ||
} | ||
} else { // data message | ||
} else { // data message | ||
msg_buf->inline_len = meta_len; | ||
msg_buf->inline_buf = mempool_->Alloc(meta_len); | ||
msg_buf->data = msg.data; | ||
meta.SerializeToArray(msg_buf->inline_buf, meta_len); | ||
if (!is_server) { // worker remains the same | ||
for (auto &sa : msg_buf->data) { | ||
if (sa.size() == 0) { | ||
continue; | ||
} | ||
// Optimization: If the memory region has been registered, | ||
// (assuming the previously registered address is not freed) | ||
// re-use the same memory region. | ||
char *p = sa.data(); | ||
auto it = allocated_mr_.find(p); | ||
if (it == allocated_mr_.end()) { | ||
allocated_mr_[p] = ibv_reg_mr(pd_, p, sa.size(), 0); | ||
if (sa.size()) { | ||
auto search_map_iterator = memory_mr_map.find(sa.data()); | ||
CHECK_NE(search_map_iterator, memory_mr_map.end()) << "not registered memory region"; | ||
MRPtr ptr(search_map_iterator->second, [](struct ibv_mr *mr) {}); | ||
CHECK(ptr.get()) << strerror(errno); | ||
msg_buf->mrs.push_back(std::make_pair(std::move(ptr), sa.size())); | ||
} | ||
CHECK(allocated_mr_[p]) << "Invalid memory region"; | ||
msg_buf->mrs.push_back({allocated_mr_[p], sa.size()}); | ||
} | ||
} | ||
} | ||
|
@@ -790,11 +747,11 @@ class RDMAVan : public Van { | |
auto raddr = std::get<1>(key_meta_map_[key][recver]); | ||
auto rkey = std::get<2>(key_meta_map_[key][recver]); | ||
|
||
CHECK_EQ(msg_buf->data[1].size(), (unsigned int)len) | ||
<< msg_buf->data[1].size() << ", " << len; | ||
CHECK_EQ(msg_buf->data[1].size(), (unsigned int) len) | ||
<< msg_buf->data[1].size() << ", " << len; | ||
|
||
auto temp_mr = allocated_mr_.find(msg_buf->data[1].data()); | ||
CHECK_NE(temp_mr, allocated_mr_.end()); | ||
auto temp_mr = memory_mr_map.find(msg_buf->data[1].data()); | ||
CHECK_NE(temp_mr, memory_mr_map.end()); | ||
|
||
struct ibv_sge sge; | ||
sge.addr = reinterpret_cast<uint64_t>(msg_buf->data[1].data()); | ||
|
@@ -815,7 +772,7 @@ class RDMAVan : public Van { | |
wr.wr.rdma.rkey = rkey; | ||
|
||
CHECK_EQ(ibv_post_send(endpoint->cm_id->qp, &wr, &bad_wr), 0) | ||
<< "ibv_post_send failed."; | ||
<< "ibv_post_send failed."; | ||
} | ||
|
||
WRContext *context = nullptr, *reserved = nullptr; | ||
|
@@ -1345,7 +1302,7 @@ class RDMAVan : public Van { | |
struct rdma_event_channel *event_channel_ = nullptr; | ||
struct ibv_context *context_ = nullptr; | ||
|
||
std::unordered_map<char *, struct ibv_mr *> allocated_mr_; | ||
std::unordered_map<char *, struct ibv_mr *> memory_mr_map; | ||
|
||
// ibverbs protection domain | ||
struct ibv_pd *pd_ = nullptr; | ||
|
@@ -1383,4 +1340,4 @@ class RDMAVan : public Van { | |
}; // namespace ps | ||
|
||
#endif // DMLC_USE_RDMA | ||
#endif // PS_RDMA_VAN_H_ | ||
#endif // PS_RDMA_VAN_H_ |