compression: async wd momentum (#6)
* compression: trivial updates

* async: use thread pool

* async: update init

* async: fix typo

* async: add self

* async: fix param bug

* async: fix typos

* async: fix bug

* async: update

* async: enlarge pool

* async: debug

* async: debug

* async: use coroutine

* async: use run_coroutine_threadsafe

* async: use shallow copy

* async: test

* async: add log

* async: add result

* async: wait

* async: use run_in_executor

* async: use concurrent.futures

* async: test overhead

* async: test thread

* async: test
jasperzhong committed Jun 23, 2020
1 parent b841f7b commit 7a0bea6
Showing 5 changed files with 43 additions and 36 deletions.
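What the squashed messages above converge on: the weight-decay momentum term for 1-bit compression is computed in a background thread, submitted during compress and harvested in decompress, so it overlaps with the push_pull in between. Below is a minimal runnable sketch of that submit-then-harvest pattern, with numpy standing in for mxnet.ndarray; every name here is illustrative, not the BytePS API.

import concurrent.futures

import numpy as np

pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)


def wd_mom(x, mom, cache, wd, mu):
    # cache <- wd*x; mom <- mu*(mom + cache); cache <- cache + mom
    np.multiply(x, wd, out=cache)
    mom += cache
    np.multiply(mom, mu, out=mom)
    cache += mom


class AsyncWDMomentum:
    def __init__(self, mu=0.9, wd=1e-4):
        self.mu, self.wd = mu, wd
        self.mom = self.cache = self.future = None

    def compress(self, grad, x):
        # Submit the momentum update so it overlaps with the push_pull
        # the caller issues next.
        if self.mom is None:
            self.mom = np.zeros_like(x)
            self.cache = np.zeros_like(x)
        self.future = pool.submit(wd_mom, x, self.mom, self.cache,
                                  self.wd, self.mu)
        return grad

    def decompress(self, grad):
        # Harvest the result. Future.result raises
        # concurrent.futures.TimeoutError on timeout (only an alias of the
        # builtin TimeoutError from Python 3.11 on).
        try:
            self.future.result(timeout=0.1)
            grad += self.cache
        except concurrent.futures.TimeoutError:
            print("timeout")
        return grad


# push_pull would run between these two calls, overlapping with wd_mom.
c = AsyncWDMomentum()
grad = np.ones(4, dtype=np.float32)
weight = np.ones(4, dtype=np.float32)
print(c.decompress(c.compress(grad, x=weight)))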
1 change: 0 additions & 1 deletion byteps/common/compressor/strategy/onebit.cc
@@ -160,7 +160,6 @@ void OnebitCompressor::Decompress(ByteBuf compressed, int dtype,
 
 void OnebitCompressor::Decompress(ByteBuf compressed, int dtype,
                                   ByteBuf& decompressed) {
-  float scale;
   if (decompressed.data == nullptr) decompressed.data = _buf.get();
   Unpacking(decompressed.data, compressed.data, compressed.size, dtype);
 }
25 changes: 12 additions & 13 deletions byteps/common/cpu_reducer.cc
@@ -17,8 +17,6 @@
 #include "global.h"
 #endif
 
-#include <omp.h>
-
 #include <cmath>
 
 #include "cpu_reducer.h"
@@ -44,7 +42,7 @@ CpuReducer::CpuReducer(std::shared_ptr<BytePSComm> comm) {
   if (getenv("BYTEPS_OMP_THREAD_PER_GPU")) {
     _num_threads = atoi(getenv("BYTEPS_OMP_THREAD_PER_GPU"));
   } else {
-    _num_threads = 1;
+    _num_threads = 4;
   }
   return;
 }
@@ -249,16 +247,17 @@ int CpuReducer::_sum_float16(void* dst, const void* src1, const void* src2,
 }
 
 int CpuReducer::copy(void* dst, const void* src, size_t len) {
-  auto in = reinterpret_cast<const float*>(src);
-  auto out = reinterpret_cast<float*>(dst);
-#pragma omp parallel for simd num_threads(_num_threads)
-  for (size_t i = 0; i < len / 4; ++i) {
-    out[i] = in[i];
-  }
-  if (len % 4) {
-    std::memcpy(out + len / 4, in + len / 4, len % 4);
-  }
-  return 0;
+  // auto in = reinterpret_cast<const float*>(src);
+  // auto out = reinterpret_cast<float*>(dst);
+  // #pragma omp parallel for simd num_threads(_num_threads)
+  // for (size_t i = 0; i < len / 4; ++i) {
+  //   out[i] = in[i];
+  // }
+  // if (len % 4) {
+  //   std::memcpy(out + len / 4, in + len / 4, len % 4);
+  // }
+  // return 0;
+  std::memcpy(dst, src, len);
 }
8 changes: 3 additions & 5 deletions byteps/mxnet/__init__.py
@@ -222,7 +222,7 @@ def __init__(self, params, optimizer, optimizer_params=None, root_rank=0, compre
         for i, param in enumerate(self._params):
             byteps_declare_tensor("parameter_" + str(i))
             if param.grad_req != 'null':
-                self._intra_compressors[i] = copy.deepcopy(
+                self._intra_compressors[i] = copy.copy(
                     self._intra_compressor)
         byteps_params = dict(
             filter(lambda attr: attr[0].startswith(
@@ -307,11 +307,11 @@ def _allreduce_grads(self):
                 nd._internal._mul_scalar(
                     param._grad[0], 1.0 / self._scale / self._bps_size, out=param._grad[0])
                 compressed, ctx = self._intra_compressors[i].compress(
-                    param._grad[0])
+                    param._grad[0], x=param._data[0])
                 byteps_push_pull(compressed, is_average=False,
                                  name="gradient_" + str(i), priority=-i)
                 param._grad[0] = self._intra_compressors[i].decompress(
-                    compressed, ctx, x=param._data[0])
+                    compressed, ctx)
 
     def _init_params(self):
         tensors = []
@@ -324,9 +324,7 @@
 
             if rank() != self.root_rank:
                 param_arrays[0].__imul__(0)
-            # compressed, ctx = self._compression.compress(param_arrays[0])
             byteps_push_pull(param_arrays[0], version=0, priority=0,
                              name="parameter_" + str(idx), is_average=False)
-            # param.set_data(self._compression.decompress(compressed, ctx))
 
         self._params_to_init = tensors
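The point of the kwarg move above: the weight x now reaches the compressor before communication starts, which is what lets the momentum work run concurrently with push_pull. The resulting order per parameter, restated from the diff with the overlap annotated:

compressed, ctx = self._intra_compressors[i].compress(
    param._grad[0], x=param._data[0])   # submits _wd_mom to the pool
byteps_push_pull(compressed, is_average=False,
                 name="gradient_" + str(i), priority=-i)   # runs while _wd_mom computes
param._grad[0] = self._intra_compressors[i].decompress(
    compressed, ctx)                    # waits up to 0.1 s, adds the momentum term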
40 changes: 26 additions & 14 deletions byteps/mxnet/compression.py
@@ -14,10 +14,11 @@
 # limitations under the License.
 # ==============================================================================
 """Gradient compression algorithms."""
 
 import mxnet
 import mxnet.ndarray as nd
 
+import concurrent.futures
+
 
 class Compressor(object):
     """Interface for compressing and decompressing a given tensor."""
@@ -65,6 +66,7 @@ def decompress(self, tensor, ctx, *args, **kwargs):
 
 class WeightDecayMomentum(Compressor):
     """For 1bit compression."""
+    pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)
 
     def __init__(self, compressor, mu, wd):
         self.compressor = compressor
@@ -73,28 +75,38 @@ def __init__(self, compressor, mu, wd):
         self.mu = mu
         self.wd = wd
 
+    @staticmethod
+    def _wd_mom(x, mom, cache, wd, mu):
+        nd._internal._mul_scalar(x, wd, out=cache)
+        mom += cache
+        nd._internal._mul_scalar(mom, mu, out=mom)
+        cache += mom
+
     def compress(self, tensor, *args, **kwargs):
         """Returns the tensor unmodified."""
+        if "x" not in kwargs:
+            return self.compressor.compress(tensor)
+
+        x = kwargs["x"]
+
+        if self.mom is None:
+            self.mom = nd.zeros_like(x)
+            self.cache = nd.zeros_like(x)
+
+        self.future = self.pool.submit(
+            self._wd_mom, x, self.mom, self.cache, self.wd, self.mu)
         return self.compressor.compress(tensor)
 
     def decompress(self, tensor, ctx, *args, **kwargs):
         """Returns the tensor added with additional momentum for wd
         m_t = \mu * m_{t-1} + wd * x_t
         x_{t+1} = x_t - \eta_t (tensor + \mu m_t + wd * x_t)
         """
-        if "x" not in kwargs:
-            return self.compressor.decompress(tensor, ctx)
-
-        x = kwargs["x"]
-
-        if self.mom is None:
-            self.mom = nd.zeros_like(tensor)
-            self.cache = nd.zeros_like(tensor)
-
-        nd._internal._mul_scalar(x, self.wd, out=self.cache)
-        self.mom += self.cache
-        nd._internal._mul_scalar(self.mom, self.mu, out=self.mom)
-        tensor += self.mom + self.cache
+        try:
+            self.future.result(timeout=0.1)
+            tensor += self.cache
+        except TimeoutError:
+            print("timeout")
         return self.compressor.decompress(tensor, ctx)
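For reference, the update rule the decompress docstring sketches, transcribed into consistent notation (lambda is wd, g_t the incoming gradient, eta_t the learning rate); after _wd_mom finishes, cache holds \lambda x_t + m_t, the term decompress adds to the gradient:

    m_t = \mu m_{t-1} + \lambda x_t
    x_{t+1} = x_t - \eta_t (g_t + \mu m_t + \lambda x_t)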
5 changes: 2 additions & 3 deletions byteps/server/server.cc
@@ -109,9 +109,8 @@ void BytePSServerEngineThread(int i) {
     if (msg.ops == ALL_RECV) {
       // 2. no compress
       auto& updates = update_buf_[msg.key];
-      auto stored = GetStore(msg.key);
-      updates.merged.tensor = stored->tensor;
-      updates.merged.len = stored->len;
+      updates.merged.tensor = reinterpret_cast<char*>(msg.src);
+      updates.merged.len = msg.len;
     }
   }
