Skip to content

Commit

Permalink
common: enforce page aligned memory (#184)
Browse files Browse the repository at this point in the history
* common: enforce page aligned memory

* common: fix compile

* add missing header file

* fix compile

* fix compile

* final fix, can compile now

* fix setup

* do not add key into push response

* enforce thread safety when getting stored values

* server: fix blocking mode

* server: correctly handle fast-worker adjacent pull

* update pslite

* remove hard code value for page size

* disable scheduling by default

* declare page size in global.cc

* mxnet: fix async incompatibility

* server: add async support

* update ps-lite submodule
  • Loading branch information
ymjiang authored and bobzhuyb committed Jan 5, 2020
1 parent 3fba75d commit bb91039
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 114 deletions.
16 changes: 9 additions & 7 deletions byteps/common/global.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "global.h"
#include <malloc.h>
#include <numa.h>
#include <unistd.h>
#include <sstream>

namespace byteps {
Expand Down Expand Up @@ -46,6 +45,8 @@ std::string BytePSGlobal::_trace_dir;
std::unordered_map<std::string, int> BytePSGlobal::_name2end;
int BytePSGlobal::_output_counter = 0;

int BytePSGlobal::_pagesize = 0;

std::shared_ptr<BytePSComm> BytePSGlobal::_basic_comm;
std::shared_ptr<BytePSSharedMemory> BytePSGlobal::_shm_obj;
std::unordered_map<uint64_t, PSKV> BytePSGlobal::ps_kv_;
Expand Down Expand Up @@ -112,17 +113,17 @@ void BytePSGlobal::Init() {
&_my_role);

_is_root_device = (_my_role == LOCAL_ROOT) ? true : false;

// should round up partition bytes in order to be page aligned
if (getenv("BYTEPS_PARTITION_BYTES")) {
_partition_bytes = atoi(getenv("BYTEPS_PARTITION_BYTES"));
}
BPS_LOG(DEBUG) << "Partition bound set to " << _partition_bytes << " bytes"
<< ", aligned to "
<< AlignTo(_partition_bytes, (8 * _local_size)) << " bytes";
// alignment for Reduce-Scatter/All-Gather
_partition_bytes = AlignTo(_partition_bytes, (8 * _local_size));
_pagesize = sysconf(_SC_PAGESIZE);
BPS_CHECK_GT(_pagesize, 0);
_partition_bytes = RoundUp(_partition_bytes, _local_size * _pagesize);
BPS_LOG(DEBUG) << "Partition size round up to " << _partition_bytes << " (bytes)";

BPS_CHECK(getenv("DMLC_NUM_WORKER")) << "error: env DMLC_NUM_WORKER not set";

_num_worker = atoi(getenv("DMLC_NUM_WORKER"));

if (getenv("BYTEPS_FORCE_DISTRIBUTED")) {
Expand Down Expand Up @@ -522,6 +523,7 @@ PSKV& BytePSGlobal::EncodeDefaultKey(uint64_t key, size_t len) {
BPS_LOG(DEBUG) << "key " << key << " assigned to server " << server
<< ", accumulated workload for this server is "
<< _server_accumulated_len[server];

ps::Key ps_key = krs[server].begin() + key;
BPS_CHECK_LT(ps_key, krs[server].end());
pskv.keys.push_back(ps_key);
Expand Down
10 changes: 7 additions & 3 deletions byteps/common/global.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <thread>
#include <unordered_map>
#include <vector>
#include <unistd.h>
#include "common.h"
#include "communicator.h"
#include "cpu_reducer.h"
Expand Down Expand Up @@ -183,10 +184,12 @@ class BytePSGlobal {
// for debug sampling
static uint64_t _sample_key;

// Round `input` DOWN to the nearest multiple of `alignment`.
// e.g. AlignTo(10, 8) == 8, AlignTo(16, 8) == 16.
// Assumes alignment > 0 and input >= 0 (integer division truncates toward
// zero, so negative inputs would not floor-align).
static int AlignTo(int input, int alignment) {
  return input / alignment * alignment;
}

static int _pagesize;
// Ceiling integer division: smallest q such that q * denominator >= numerator
// (for non-negative numerator and positive denominator).
static int DivUp(int numerator, int denominator) {
  return (numerator + denominator - 1) / denominator;
}
// Round `value` UP to the nearest multiple of `multiple`
// (for non-negative value and positive multiple).
static int RoundUp(int value, int multiple) {
  const int quotient = (value + multiple - 1) / multiple;
  return quotient * multiple;
}

// hash functions
static std::string _hash_knob;
static std::hash<std::string> _built_in_hash_fn;
Expand All @@ -195,6 +198,7 @@ class BytePSGlobal {
static uint64_t Hash_BuiltIn(uint64_t key);
static uint64_t Hash_DJB2(uint64_t key);
static uint64_t Hash_SDBM(uint64_t key);

};

} // namespace common
Expand Down
8 changes: 4 additions & 4 deletions byteps/common/scheduled_queue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ BytePSScheduledQueue::BytePSScheduledQueue(QueueType type) {
}

size_t credit_in_partition = BytePSGlobal::GetNccl()->GetGroupSize() + 1;
if (getenv("BYTEPS_SCHEDULING_CREDIT")) {
credit_in_partition = atoi(getenv("BYTEPS_SCHEDULING_CREDIT"));
}
if (!credit_in_partition) {

auto byteps_scheduling_credit = getenv("BYTEPS_SCHEDULING_CREDIT");
credit_in_partition = byteps_scheduling_credit ? atoi(byteps_scheduling_credit) : 0;
if (!credit_in_partition) { // disable scheduling by default
_is_scheduled = false;
}

Expand Down
26 changes: 22 additions & 4 deletions byteps/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,40 @@ def _do_push_pull_param(self, index, delta_weight):

def update(self, index, weight, grad, state):
    """Update parameters for one push-pull step.

    In async mode: snapshot the weights, apply the local optimizer
    update, compute the delta (new - old) in place into ``weight``,
    then push the delta and pull the aggregated weights back into the
    same tensors.  In sync mode: push-pull the gradients first, then
    apply the local optimizer update.

    NOTE(review): ``weight``/``grad`` are assumed to be lists of
    MXNet NDArrays (the list form of Optimizer.update) — confirm
    against the caller.
    """
    if self._enable_async:
        # Create a tmp list holding a snapshot of the original weights.
        temp_weight_list = [w.copy() for w in weight]
        assert len(temp_weight_list) == len(weight)

        # Update the parameters locally.
        self._optimizer.update(index, weight, grad, state)

        # Compute delta weight in place: weight[i] = new - old.
        for i, temp_weight in enumerate(temp_weight_list):
            weight[i].__isub__(temp_weight)

        # Push delta weight, and pull the weight back into the same tensor.
        self._do_push_pull_param(index, weight)

    else:
        self._do_push_pull(index, grad)
        self._optimizer.update(index, weight, grad, state)

def update_multi_precision(self, index, weight, grad, state):
    """Multi-precision variant of :meth:`update`.

    Same flow as ``update`` but delegates to the optimizer's
    ``update_multi_precision`` so fp16 weights keep an fp32 master copy.

    NOTE(review): ``weight``/``grad`` are assumed to be lists of
    MXNet NDArrays (the list form of Optimizer.update) — confirm
    against the caller.
    """
    if self._enable_async:
        # Create a tmp list holding a snapshot of the original weights.
        temp_weight_list = [w.copy() for w in weight]
        assert len(temp_weight_list) == len(weight)

        # Update the parameters locally.
        self._optimizer.update_multi_precision(index, weight, grad, state)

        # Compute delta weight in place: weight[i] = new - old.
        for i, temp_weight in enumerate(temp_weight_list):
            weight[i].__isub__(temp_weight)

        # Push delta weight, and pull the weight back into the same tensor.
        self._do_push_pull_param(index, weight)

    else:
        self._do_push_pull(index, grad)
        self._optimizer.update_multi_precision(index, weight, grad, state)
Expand Down
Loading

0 comments on commit bb91039

Please sign in to comment.