From 3712c3310b28f4d256dee1bec893c89c8d890546 Mon Sep 17 00:00:00 2001 From: Yuchen Zhong Date: Sat, 16 May 2020 22:39:14 +0800 Subject: [PATCH] hotfix: update non-compression case register error (#14) --- byteps/mxnet/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/byteps/mxnet/__init__.py b/byteps/mxnet/__init__.py index f4ae1ed3b..5c8b10192 100644 --- a/byteps/mxnet/__init__.py +++ b/byteps/mxnet/__init__.py @@ -253,7 +253,7 @@ def _register_compressor(self, params, optimizer_params, compression_params): for _, param in params.items(): # generic for item in check_list: - if item in compression_params and compression_params[item]: + if compression_params.get(item): if isinstance(compression_params[item], str): setattr(param, "byteps_%s_type" % item, compression_params[item]) @@ -270,12 +270,12 @@ def _register_compressor(self, params, optimizer_params, compression_params): setattr(param, "byteps_compressor_k", compression_params["k"]) - if "momentum" in compression_params: + if compression_params.get("momentum"): setattr(param, "byteps_momentum_mu", optimizer_params["momentum"]) # change - if "momentum" in compression_params: + if compression_params.get("momentum"): # 1bit compressor use an additional momentum for weight decay if compressor == "onebit" and "wd" in optimizer_params: intra_compressor = Compression.wdmom(