diff --git a/byteps/mxnet/compression.py b/byteps/mxnet/compression.py index 43380e3d9..bb8310c2e 100644 --- a/byteps/mxnet/compression.py +++ b/byteps/mxnet/compression.py @@ -96,7 +96,8 @@ def decompress(self, tensor, ctx, *args, **kwargs): nd._internal._mul_scalar(x, self.wd, out=self.cache) self.mom += self.cache nd._internal._mul_scalar(self.mom, self.mu, out=self.mom) - tensor += self.mom + self.cache + tensor += self.mom + tensor += self.cache return self.compressor.decompress(tensor, ctx)