diff --git a/byteps/mxnet/__init__.py b/byteps/mxnet/__init__.py index 9c14c9f92..379cc437b 100644 --- a/byteps/mxnet/__init__.py +++ b/byteps/mxnet/__init__.py @@ -222,8 +222,8 @@ def __init__(self, params, optimizer, optimizer_params=None, root_rank=0, compre for i, param in enumerate(self._params): byteps_declare_tensor("parameter_" + str(i)) if param.grad_req != 'null': - self._intra_compressors[i] = copy.copy( - self._intra_compressor) + self._intra_compressors[i] = type(self._intra_compressor)( + **self._intra_compressor.__dict__) byteps_params = dict( filter(lambda attr: attr[0].startswith( "byteps_",), param.__dict__.items()) diff --git a/byteps/mxnet/compression.py b/byteps/mxnet/compression.py index fab9a8ff6..4da5fc430 100644 --- a/byteps/mxnet/compression.py +++ b/byteps/mxnet/compression.py @@ -17,7 +17,8 @@ import mxnet import mxnet.ndarray as nd -import concurrent.futures +import threading +from queue import Queue, Empty class Compressor(object): @@ -66,21 +67,33 @@ def decompress(self, tensor, ctx, *args, **kwargs): class WeightDecayMomentum(Compressor): """For 1bit compression.""" - pool = concurrent.futures.ThreadPoolExecutor() - def __init__(self, compressor, mu, wd): + def __init__(self, compressor, mu, wd, *args, **kwargs): self.compressor = compressor - self.mom = None - self.cache = None self.mu = mu self.wd = wd + self.task_queue = Queue() + self.done_queue = Queue() + threading.Thread(target=self._worker, args=( + self.mu, self.wd, self.task_queue, self.done_queue), daemon=True).start() + + def __del__(self): + self.task_queue.put('STOP') @staticmethod - def _wd_mom(x, mom, cache, wd, mu): - nd._internal._mul_scalar(x, wd, out=cache) - mom += cache - nd._internal._mul_scalar(mom, mu, out=mom) - cache += mom + def _worker(mu, wd, input, output): + mom = None + cache = None + for x, _ in iter(input.get, 'STOP'): + if mom is None: + mom = nd.zeros_like(x) + cache = nd.zeros_like(x) + + nd._internal._mul_scalar(x, wd, out=cache) + mom += cache + nd._internal._mul_scalar(mom, mu, out=mom) + cache += mom + output.put(cache) def compress(self, tensor, *args, **kwargs): """Returns the tensor unmodified.""" @@ -88,13 +101,7 @@ def compress(self, tensor, *args, **kwargs): return self.compressor.compress(tensor) x = kwargs["x"] - - if self.mom is None: - self.mom = nd.zeros_like(x) - self.cache = nd.zeros_like(x) - - self.future = self.pool.submit( - self._wd_mom, x, self.mom, self.cache, self.wd, self.mu) + self.task_queue.put((x, None)) return self.compressor.compress(tensor) def decompress(self, tensor, ctx, *args, **kwargs): @@ -103,9 +110,10 @@ def decompress(self, tensor, ctx, *args, **kwargs): x_{t+1} = x_t - \eta_t (tensor + \mu m_t + wd * x_t) """ try: - self.future.result(timeout=0.1) - tensor += self.cache - except concurrent.futures.TimeoutError: + tensor += self.done_queue.get(timeout=0.1) + except Empty: + print("empty for wd-momentum") + except TimeoutError: print("timeout for wd-momentum") return self.compressor.decompress(tensor, ctx) @@ -121,3 +129,12 @@ class Compression(object): """Additional Momentum for weight decay. This is only for 1bit. This is a wrapper.""" wdmom = WeightDecayMomentum + + +# if __name__ == "__main__": +# x = WeightDecayMomentum(Compression.none, 0.9, 1e-4) +# import copy +# print(x.__dict__) +# y = type(x)(**x.__dict__) +# print(y.__dict__) +# print(x.task_queue is y.task_queue)