Skip to content

Commit

Permalink
fix logging wrt primary host
Browse files Browse the repository at this point in the history
  • Loading branch information
matpalm committed Feb 9, 2021
1 parent c928fa9 commit 6d9a587
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def total_ensemble_xent_loss(params, x, y_true):
sample_data=opts.sample_data)

for imgs, labels in train_ds:

logging.info("labels %s" % labels)

# replicate batch across M devices
imgs = u.replicate(imgs) # (M, B, H, W, 3)
labels = u.replicate(labels) # (M, B)
Expand Down Expand Up @@ -206,13 +209,14 @@ def total_ensemble_xent_loss(params, x, y_true):
wandb.log({'validation_loss': mean_validation_loss}, step=epoch)

# close out wandb run
if wandb_enabled and u.primary_host():
wandb.log({'final_validation_loss': mean_validation_loss},
step=opts.epochs)
wandb.join()
else:
logging.info("finished %s final validation_loss %f" %
(run, mean_validation_loss))
if u.primary_host():
if wandb_enabled:
wandb.log({'final_validation_loss': mean_validation_loss},
step=opts.epochs)
wandb.join()
else:
logging.info("finished %s final validation_loss %f" %
(run, mean_validation_loss))

# return validation loss to ax
return mean_validation_loss
Expand Down

0 comments on commit 6d9a587

Please sign in to comment.