Skip to content

Commit

Permalink
hotfix: fix bugs (#20)
Browse files Browse the repository at this point in the history
* hotfix: update script

* hotfix: comment out....
  • Loading branch information
jasperzhong committed Jun 23, 2020
1 parent 7f12e15 commit 6c45c79
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
8 changes: 4 additions & 4 deletions byteps/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,10 @@ def _register_compressor(self, params, optimizer_params, compression_params):
# change
if compression_params.get("momentum"):
# 1bit compressor use an additional momentum for weight decay
# if compressor == "onebit" and "wd" in optimizer_params:
# intra_compressor = Compression.wdmom(
# intra_compressor, optimizer_params["momentum"], optimizer_params["wd"])
# del optimizer_params["wd"]
if compressor == "onebit" and "wd" in optimizer_params:
intra_compressor = Compression.wdmom(
intra_compressor, optimizer_params["momentum"], optimizer_params["wd"])
del optimizer_params["wd"]

del optimizer_params['momentum']

Expand Down
26 changes: 13 additions & 13 deletions example/mxnet/train_gluon_imagenet_byteps_gc.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def train(ctx):

best_val_score = 1

bps.byteps_declare_tensor("acc")
# bps.byteps_declare_tensor("acc")
for epoch in range(opt.resume_epoch, opt.num_epochs):
tic = time.time()
if opt.use_rec:
Expand Down Expand Up @@ -505,18 +505,18 @@ def train(ctx):

err_top1_val, err_top5_val = test(ctx, val_data)

acc = mx.nd.array([train_metric_score, err_top1_val, err_top5_val],
ctx=ctx[0])
bps.byteps_push_pull(acc, name="acc", is_average=False)
acc /= bps.size()
train_metric_score, err_top1_val, err_top5_val = acc[0].asscalar(
), acc[1].asscalar(), acc[2].asscalar()

if bps.rank() == 0:
logger.info('[Epoch %d] training: %s=%f' %
(epoch, train_metric_name, train_metric_score))
logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f' %
(epoch, err_top1_val, err_top5_val))
# acc = mx.nd.array([train_metric_score, err_top1_val, err_top5_val],
# ctx=ctx[0])
# bps.byteps_push_pull(acc, name="acc", is_average=False)
# acc /= bps.size()
# train_metric_score, err_top1_val, err_top5_val = acc[0].asscalar(
# ), acc[1].asscalar(), acc[2].asscalar()

# if bps.rank() == 0:
logger.info('[Epoch %d] training: %s=%f' %
(epoch, train_metric_name, train_metric_score))
logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f' %
(epoch, err_top1_val, err_top5_val))

if err_top1_val < best_val_score:
best_val_score = err_top1_val
Expand Down

0 comments on commit 6c45c79

Please sign in to comment.