Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

common: enforce page aligned memory #184

Merged
merged 18 commits into from
Jan 5, 2020
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions byteps/common/global.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "global.h"
#include <malloc.h>
#include <numa.h>
#include <unistd.h>
#include <sstream>

namespace byteps {
Expand Down Expand Up @@ -46,6 +45,8 @@ std::string BytePSGlobal::_trace_dir;
std::unordered_map<std::string, int> BytePSGlobal::_name2end;
int BytePSGlobal::_output_counter = 0;

int BytePSGlobal::_pagesize = 4096;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you get the page size from syscall instead of a hard-coded number?


std::shared_ptr<BytePSComm> BytePSGlobal::_basic_comm;
std::shared_ptr<BytePSSharedMemory> BytePSGlobal::_shm_obj;
std::unordered_map<uint64_t, PSKV> BytePSGlobal::ps_kv_;
Expand Down Expand Up @@ -112,17 +113,15 @@ void BytePSGlobal::Init() {
&_my_role);

_is_root_device = (_my_role == LOCAL_ROOT) ? true : false;

if (getenv("BYTEPS_PARTITION_BYTES")) {
_partition_bytes = atoi(getenv("BYTEPS_PARTITION_BYTES"));
}
BPS_LOG(DEBUG) << "Partition bound set to " << _partition_bytes << " bytes"
<< ", aligned to "
<< AlignTo(_partition_bytes, (8 * _local_size)) << " bytes";
// alignment for Reduce-Scatter/All-Gather
_partition_bytes = AlignTo(_partition_bytes, (8 * _local_size));
_pagesize = sysconf(_SC_PAGESIZE);
_partition_bytes = RoundUp(_partition_bytes, _local_size * _pagesize);
BPS_LOG(DEBUG) << "Partition size round up to " << _partition_bytes << " (bytes)";

BPS_CHECK(getenv("DMLC_NUM_WORKER")) << "error: env DMLC_NUM_WORKER not set";

_num_worker = atoi(getenv("DMLC_NUM_WORKER"));

if (getenv("BYTEPS_FORCE_DISTRIBUTED")) {
Expand Down Expand Up @@ -522,6 +521,7 @@ PSKV& BytePSGlobal::EncodeDefaultKey(uint64_t key, size_t len) {
BPS_LOG(DEBUG) << "key " << key << " assigned to server " << server
<< ", accumulated workload for this server is "
<< _server_accumulated_len[server];

ps::Key ps_key = krs[server].begin() + key;
BPS_CHECK_LT(ps_key, krs[server].end());
pskv.keys.push_back(ps_key);
Expand Down
10 changes: 7 additions & 3 deletions byteps/common/global.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <thread>
#include <unordered_map>
#include <vector>
#include <unistd.h>
#include "common.h"
#include "communicator.h"
#include "cpu_reducer.h"
Expand Down Expand Up @@ -183,10 +184,12 @@ class BytePSGlobal {
// for debug sampling
static uint64_t _sample_key;

static int AlignTo(int input, int alignment) {
return input / alignment * alignment;
}
static int AlignTo(int input, int alignment) {
  // Round `input` DOWN to the nearest multiple of `alignment`
  // (integer division truncates the remainder before scaling back up).
  const int whole_units = input / alignment;
  return whole_units * alignment;
}

static int _pagesize;
static int DivUp(int dividend, int divisor) {
  // Ceiling division for non-negative dividends: add (divisor - 1) so any
  // remainder bumps the truncated quotient up by one.
  return (dividend + divisor - 1) / divisor;
}
static int RoundUp(int value, int multiple) {
  // Smallest multiple of `multiple` that is >= `value` (non-negative inputs):
  // ceiling-divide, then scale back up by the multiple.
  const int units_needed = (value + multiple - 1) / multiple;
  return units_needed * multiple;
}

// hash functions
static std::string _hash_knob;
static std::hash<std::string> _built_in_hash_fn;
Expand All @@ -195,6 +198,7 @@ class BytePSGlobal {
static uint64_t Hash_BuiltIn(uint64_t key);
static uint64_t Hash_DJB2(uint64_t key);
static uint64_t Hash_SDBM(uint64_t key);

};

} // namespace common
Expand Down
153 changes: 85 additions & 68 deletions byteps/server/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,19 @@ using namespace ps;
std::vector<PriorityQueue*> engine_queues_;
std::vector<std::thread *> engine_threads_;

BytePSArray* GetStore(uint64_t key) {
  // Serialize map access so concurrent request handlers never race on
  // `store_`; operator[] default-constructs the entry on first use.
  std::lock_guard<std::mutex> guard(store_mu_);
  BytePSArray* entry = &store_[key];
  return entry;
}

void SendPushResponse(uint64_t key, const ps::KVMeta& req, ps::KVServer<char>* server){
  // Reuse a cached KVPairs for keys we have answered before, so the RDMA
  // transport does not re-register memory (ibv_reg_mr) on every response.
  auto found = push_response_map_.find(key);
  if (found != push_response_map_.end()) {
    ps::KVPairs<char>* cached = &found->second;
    cached->keys[0] = key;
    server->Response(req, *cached);
    return;
  }
  // First time this key is pushed: build the response and memoize it.
  ps::KVPairs<char> fresh;
  fresh.keys.push_back(key);
  push_response_map_[key] = fresh;
  server->Response(req, fresh);
}
Expand All @@ -44,23 +47,23 @@ void SendPullResponse(const DataHandleType type,
const ps::KVMeta& req_meta,
ps::KVServer<char>* server) {
std::lock_guard<std::mutex> lock(pullresp_mu_);
auto& stored = store_[key];
CHECK(stored.tensor) << "init " << key << " first";
// as server returns when store_realt is ready in this case
auto len = stored.len;
auto stored = GetStore(key);
CHECK(stored->tensor) << "init " << key << " first";
auto len = stored->len;

// send pull response
auto iterator = pull_response_map_.find(key);
if (iterator == pull_response_map_.end()) { // new key
ps::KVPairs<char> response;
response.keys = {EncodeKey(key)};
response.lens = {len};
response.vals = ps::SArray<char>(stored.tensor, len, false); // zero copy
response.vals = ps::SArray<char>(stored->tensor, len, false); // zero copy
pull_response_map_[key] = response; // add to the map
server->Response(req_meta, response);
} else { // not new key, then reuse the memory address to avoid ibv_reg_mr on RDMA data path
ps::KVPairs<char> *response = &iterator->second;
// keys and lens remain unchanged, just update vals
auto p = static_cast<char*>(stored.tensor);
auto p = static_cast<char*>(stored->tensor);
CHECK(p);
response->vals = ps::SArray<char>(p, len, false);
server->Response(req_meta, *response);
Expand Down Expand Up @@ -101,19 +104,29 @@ void BytePSServerEngineThread(int i) {
if (is_push_finished_[i].find(msg.key) == is_push_finished_[i].end()) {
is_push_finished_[i][msg.key] = false;
pull_cnt_[i][msg.key] = 0;
seen_sender_[i][msg.key].clear();
}
is_push_finished_[i][msg.key] = true;
for (auto& req_meta : q_pull_reqmeta_[i][msg.key]) {
SendPullResponse(msg.type, msg.key, req_meta, byteps_server_);
pull_cnt_[i][msg.key] += 1;

auto it = q_pull_reqmeta_[i][msg.key].begin();
while (it != q_pull_reqmeta_[i][msg.key].end()) {
if (seen_sender_[i][msg.key].find(it->sender) == seen_sender_[i][msg.key].end()) {
SendPullResponse(msg.type, msg.key, *it, byteps_server_);
pull_cnt_[i][msg.key] += 1;
seen_sender_[i][msg.key].insert(it->sender);
it = q_pull_reqmeta_[i][msg.key].erase(it);
} else {
++it;
}
if (pull_cnt_[i][msg.key] == (size_t) ps::NumWorkers()) {
is_push_finished_[i][msg.key] = false;
pull_cnt_[i][msg.key] = 0;
seen_sender_[i][msg.key].clear();
break;
}
}
q_pull_reqmeta_[i][msg.key].clear();
break;
}
} break;

case SUM_RECV: {
auto bps_type = bps_reducer_->GetDataType(msg.type.dtype);
if (is_debug) {
Expand All @@ -136,9 +149,9 @@ void BytePSServerEngineThread(int i) {
<< "dst_addr: " << DEBUG_PRINT_TENSOR_ADDRESS(msg.dst) << "\t"
<< "src_addr: " << DEBUG_PRINT_TENSOR_ADDRESS(msg.src) << "\t";
}
break;
}
default:
} break;

default:
CHECK(0);
}
}
Expand Down Expand Up @@ -169,10 +182,10 @@ void BytePSHandler(const ps::KVMeta& req_meta,
if (req_meta.push) { // push request
CHECK_EQ(req_data.lens.size(), (size_t)1);
CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]);
auto& stored = store_[key];
auto stored = GetStore(key);
auto len = (size_t) req_data.lens[0];
auto recved = reinterpret_cast<char*>(req_data.vals.data());
if (!stored.tensor) {
if (!stored->tensor) {
if (sync_mode_ && (update_buf_.find(key) == update_buf_.end())) {
update_buf_[key].merged.len = len;
update_buf_[key].merged.dtype = type.dtype;
Expand All @@ -188,11 +201,12 @@ void BytePSHandler(const ps::KVMeta& req_meta,
<< ", init the store buffer size=" << (size_t) req_data.lens[0];
}
// initialization
stored.tensor = (char*) malloc(len);
stored.len = len;
stored.dtype = type.dtype;
CHECK(stored.tensor);
bps_reducer_->copy(stored.tensor, recved, len); // we may not need this copy
PageAlignedMalloc((void**) &stored->tensor, len);
stored->len = len;
stored->dtype = type.dtype;
CHECK(stored->tensor);

bps_reducer_->copy(stored->tensor, recved, len); // we may not need this copy
for (const auto& req : updates.request) {
SendPushResponse(key, req, server);
}
Expand All @@ -202,50 +216,46 @@ void BytePSHandler(const ps::KVMeta& req_meta,
auto tid = GetThreadID(key, len);
if (updates.request.empty()) { // from the first incoming worker
if (sync_mode_) {
if (is_engine_blocking_) {
bps_reducer_->copy(updates.merged.tensor, recved, len);
} else { // non-blocking
if (debug_mode_ && (debug_key_ == key)) {
std::lock_guard<std::mutex> lock(debug_mu_);
LOG(INFO) << "stage: FIRST_WORKER_RECV \t"
<< "stored: " << DEBUG_PRINT_TENSOR_VALUE(stored.tensor) << "\t"
<< "recved: " << DEBUG_PRINT_TENSOR_VALUE(recved) << "\t"
<< "len: " << len << "\t"
<< "addr: " << DEBUG_PRINT_TENSOR_ADDRESS(recved);
}
// zero copy
updates.merged.tensor = recved;
updates.merged.tmp_sarray = req_data;
if (debug_mode_ && (debug_key_ == key)) {
std::lock_guard<std::mutex> lock(debug_mu_);
LOG(INFO) << "stage: FIRST_WORKER_RECV \t"
<< "stored: " << DEBUG_PRINT_TENSOR_VALUE(stored->tensor) << "\t"
<< "recved: " << DEBUG_PRINT_TENSOR_VALUE(recved) << "\t"
<< "len: " << len << "\t"
<< "addr: " << DEBUG_PRINT_TENSOR_ADDRESS(recved);
}
// zero copy
updates.merged.tensor = recved;
updates.merged.tmp_sarray = req_data;
} else { // async mode, directly add to the buffer
if (is_engine_blocking_) {
CHECK_GE(bps_reducer_->sum((void *) stored.tensor,
CHECK_GE(bps_reducer_->sum((void *) stored->tensor,
(void *) recved,
len,
bps_reducer_->GetDataType(stored.dtype)), 0);
bps_reducer_->GetDataType(stored->dtype)), 0);
} else {
BytePSEngineMessage msg = {timestamp_++, type, key, stored.tensor, recved, len, SUM_RECV, req_data};
BytePSEngineMessage msg = {timestamp_++, type, key, stored->tensor, recved, len, SUM_RECV, req_data};
engine_queues_[tid]->Push(msg);
}
}
} else { // from other workers
CHECK(sync_mode_);
CHECK(updates.merged.tensor);
if (debug_mode_ && (debug_key_ == key)) {
std::lock_guard<std::mutex> lock(debug_mu_);
LOG(INFO) << "stage: OTHER_WORKER_SUM \t"
<< "stored: " << DEBUG_PRINT_TENSOR_VALUE(stored->tensor) << "\t"
<< "merged: " << DEBUG_PRINT_TENSOR_VALUE(updates.merged.tensor) << "\t"
<< "recved: " << DEBUG_PRINT_TENSOR_VALUE(recved) << "\t"
<< "len: " << len << "\t"
<< "addr: " << DEBUG_PRINT_TENSOR_ADDRESS(recved);
}
if (is_engine_blocking_) {
CHECK_GE(bps_reducer_->sum((void *) updates.merged.tensor,
(void *) recved,
len,
bps_reducer_->GetDataType(updates.merged.dtype)), 0);
} else { // non-blocking
if (debug_mode_ && (debug_key_ == key)) {
std::lock_guard<std::mutex> lock(debug_mu_);
LOG(INFO) << "stage: OTHER_WORKER_SUM \t"
<< "stored: " << DEBUG_PRINT_TENSOR_VALUE(stored.tensor) << "\t"
<< "merged: " << DEBUG_PRINT_TENSOR_VALUE(updates.merged.tensor) << "\t"
<< "recved: " << DEBUG_PRINT_TENSOR_VALUE(recved) << "\t"
<< "len: " << len << "\t"
<< "addr: " << DEBUG_PRINT_TENSOR_ADDRESS(recved);
}
BytePSEngineMessage msg = {timestamp_++, type, key, updates.merged.tensor, recved, len, SUM_RECV, req_data, req_meta};
engine_queues_[tid]->Push(msg);
}
Expand All @@ -254,19 +264,19 @@ void BytePSHandler(const ps::KVMeta& req_meta,
updates.request.push_back(req_meta);
SendPushResponse(key, req_meta, server);
if (sync_mode_ && updates.request.size() == (size_t) ps::NumWorkers()) {
auto& stored = store_[key];
auto stored = GetStore(key);
auto& update = updates.merged;
if (debug_mode_ && (debug_key_ == key)) {
std::lock_guard<std::mutex> lock(debug_mu_);
LOG(INFO) << "stage: COPY_MERGED_TO_STORE \t"
<< "stored: " << DEBUG_PRINT_TENSOR_VALUE(stored->tensor) << "\t"
<< "merged: " << DEBUG_PRINT_TENSOR_VALUE(updates.merged.tensor) << "\t"
<< "recved: " << DEBUG_PRINT_TENSOR_VALUE(recved);
}
if (is_engine_blocking_) {
bps_reducer_->copy(stored.tensor, updates.merged.tensor, len);
bps_reducer_->copy(stored->tensor, updates.merged.tensor, len);
} else {
if (debug_mode_ && (debug_key_ == key)) {
std::lock_guard<std::mutex> lock(debug_mu_);
LOG(INFO) << "stage: COPY_MERGED_TO_STORE \t"
<< "stored: " << DEBUG_PRINT_TENSOR_VALUE(stored.tensor) << "\t"
<< "merged: " << DEBUG_PRINT_TENSOR_VALUE(updates.merged.tensor) << "\t"
<< "recved: " << DEBUG_PRINT_TENSOR_VALUE(recved);
}
BytePSEngineMessage msg = {timestamp_++, type, key, stored.tensor, update.tensor, len, COPY_MERGED};
BytePSEngineMessage msg = {timestamp_++, type, key, stored->tensor, update.tensor, len, COPY_MERGED};
engine_queues_[tid]->Push(msg);
engine_queues_[tid]->ClearCounter(key);
}
Expand All @@ -278,9 +288,8 @@ void BytePSHandler(const ps::KVMeta& req_meta,
}
}
} else { // pull request
auto& stored = store_[key];
CHECK(stored.tensor) << "Processing pull request when the NDArray of key "
<< key << " has not been inited yet, which is not expected.";
auto stored = GetStore(key);
CHECK(stored->tensor) << "Should init the buffer for key=" << key << " first";
if (is_engine_blocking_) {
SendPullResponse(type, key, req_meta, server);
} else {
Expand All @@ -289,20 +298,26 @@ void BytePSHandler(const ps::KVMeta& req_meta,
if (is_push_finished_[tid].find(key) == is_push_finished_[tid].end()) {
is_push_finished_[tid][key] = false;
pull_cnt_[tid][key] = 0;
seen_sender_[tid][key].clear();
}
if (is_push_finished_[tid][key]) { // push already finished

auto it = seen_sender_[tid][key].find(req_meta.sender);
if (is_push_finished_[tid][key] && (it == seen_sender_[tid][key].end())) {
// push already finished && not send the associated pull response yet
SendPullResponse(type, key, req_meta, server);
pull_cnt_[tid][key] += 1;
seen_sender_[tid][key].insert(req_meta.sender);

if (pull_cnt_[tid][key] == (size_t) ps::NumWorkers()) {
is_push_finished_[tid][key] = false;
pull_cnt_[tid][key] = 0;
// check: remain should be 0
auto remain = q_pull_reqmeta_[tid][key].size();
CHECK_EQ(remain, 0) << remain;
seen_sender_[tid][key].clear();
}
} else { // push not finished, put into the queue, and wait for the engine
} else {
// push not finished, put into the queue, and wait for the engine
q_pull_reqmeta_[tid][key].push_back(req_meta);
}

}
}
}
Expand Down Expand Up @@ -346,10 +361,12 @@ extern "C" void byteps_server() {
std::vector<std::mutex> tmp_flagmu(engine_thread_num_);
std::vector<std::unordered_map<uint64_t, bool> > tmp_ispushfinished(engine_thread_num_);
std::vector<std::unordered_map<uint64_t, std::vector<ps::KVMeta> > > tmp_qpullreqmeta(engine_thread_num_);
std::vector<std::unordered_map<uint64_t, std::set<int> > > tmp_seensender(engine_thread_num_);
std::vector<std::unordered_map<uint64_t, size_t> > tmp_pullcnt(engine_thread_num_);
flag_mu_.swap(tmp_flagmu);
is_push_finished_.swap(tmp_ispushfinished);
q_pull_reqmeta_.swap(tmp_qpullreqmeta);
seen_sender_.swap(tmp_seensender);
pull_cnt_.swap(tmp_pullcnt);
CHECK_EQ(flag_mu_.size(), engine_thread_num_);
CHECK_EQ(is_push_finished_.size(), engine_thread_num_);
Expand Down
Loading