Skip to content

Commit

Permalink
compression: speed up wd momentum with threading.Thread (#9)
Browse files Browse the repository at this point in the history
* wd: use Process

* wd: add log

* wd: fix bug

* wd: check context

* wd: add log

* wd: use threading.Thread

* wd: use deepcopy

* wd: fix

* wd: fix typo
  • Loading branch information
jasperzhong committed Jun 23, 2020
1 parent f6fd608 commit 1e40fd9
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 22 deletions.
4 changes: 2 additions & 2 deletions byteps/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ def __init__(self, params, optimizer, optimizer_params=None, root_rank=0, compre
for i, param in enumerate(self._params):
byteps_declare_tensor("parameter_" + str(i))
if param.grad_req != 'null':
self._intra_compressors[i] = copy.copy(
self._intra_compressor)
self._intra_compressors[i] = type(self._intra_compressor)(
**self._intra_compressor.__dict__)
byteps_params = dict(
filter(lambda attr: attr[0].startswith(
"byteps_",), param.__dict__.items())
Expand Down
57 changes: 37 additions & 20 deletions byteps/mxnet/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import mxnet
import mxnet.ndarray as nd

import concurrent.futures
import threading
from queue import Queue, Empty


class Compressor(object):
Expand Down Expand Up @@ -66,35 +67,41 @@ def decompress(self, tensor, ctx, *args, **kwargs):

class WeightDecayMomentum(Compressor):
"""For 1bit compression."""
pool = concurrent.futures.ThreadPoolExecutor()

def __init__(self, compressor, mu, wd):
def __init__(self, compressor, mu, wd, *args, **kwargs):
self.compressor = compressor
self.mom = None
self.cache = None
self.mu = mu
self.wd = wd
self.task_queue = Queue()
self.done_queue = Queue()
threading.Thread(target=self._worker, args=(
self.mu, self.wd, self.task_queue, self.done_queue), daemon=True).start()

def __del__(self):
self.task_queue.put('STOP')

@staticmethod
def _wd_mom(x, mom, cache, wd, mu):
nd._internal._mul_scalar(x, wd, out=cache)
mom += cache
nd._internal._mul_scalar(mom, mu, out=mom)
cache += mom
def _worker(mu, wd, input, output):
mom = None
cache = None
for x, _ in iter(input.get, 'STOP'):
if mom is None:
mom = nd.zeros_like(x)
cache = nd.zeros_like(x)

nd._internal._mul_scalar(x, wd, out=cache)
mom += cache
nd._internal._mul_scalar(mom, mu, out=mom)
cache += mom
output.put(cache)

def compress(self, tensor, *args, **kwargs):
"""Returns the tensor unmodified."""
if "x" not in kwargs:
return self.compressor.compress(tensor)

x = kwargs["x"]

if self.mom is None:
self.mom = nd.zeros_like(x)
self.cache = nd.zeros_like(x)

self.future = self.pool.submit(
self._wd_mom, x, self.mom, self.cache, self.wd, self.mu)
self.task_queue.put((x, None))
return self.compressor.compress(tensor)

def decompress(self, tensor, ctx, *args, **kwargs):
Expand All @@ -103,9 +110,10 @@ def decompress(self, tensor, ctx, *args, **kwargs):
x_{t+1} = x_t - \eta_t (tensor + \mu m_t + wd * x_t)
"""
try:
self.future.result(timeout=0.1)
tensor += self.cache
except concurrent.futures.TimeoutError:
tensor += self.done_queue.get(timeout=0.1)
except Empty:
print("empty for wd-momentum")
except TimeoutError:
print("timeout for wd-momentum")
return self.compressor.decompress(tensor, ctx)

Expand All @@ -121,3 +129,12 @@ class Compression(object):

"""Additional Momentum for weight decay. This is only for 1bit. This is a wrapper."""
wdmom = WeightDecayMomentum


# if __name__ == "__main__":
# x = WeightDecayMomentum(Compression.none, 0.9, 1e-4)
# import copy
# print(x.__dict__)
# y = type(x)(**x.__dict__)
# print(y.__dict__)
# print(x.task_queue is y.task_queue)

0 comments on commit 1e40fd9

Please sign in to comment.