Fix: hparams.yaml saved twice when using TensorBoardLogger (#5953)
kaushikb11 authored Feb 15, 2021
1 parent ba806c8 commit b5d29df
Showing 2 changed files with 14 additions and 4 deletions.
6 changes: 2 additions & 4 deletions pytorch_lightning/loggers/tensorboard.py
@@ -224,14 +224,12 @@ def log_graph(self, model: LightningModule, input_array=None):
     def save(self) -> None:
         super().save()
         dir_path = self.log_dir
-        if not self._fs.isdir(dir_path):
-            dir_path = self.save_dir
 
         # prepare the file path
         hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)
 
-        # save the metatags file if it doesn't exist
-        if not self._fs.isfile(hparams_file):
+        # save the metatags file if it doesn't exist and the log directory exists
+        if self._fs.isdir(dir_path) and not self._fs.isfile(hparams_file):
             save_hparams_to_yaml(hparams_file, self.hparams)
 
     @rank_zero_only
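For context, a minimal self-contained sketch of the pre-fix behavior (not part of the commit; the existing_dirs set and written list are illustrative stand-ins for self._fs and save_hparams_to_yaml): save() runs both before and after the versioned log directory is created, and the old fallback to save_dir means the two calls target different paths, so the isfile guard never fires and hparams.yaml ends up in both places.

# A sketch, not part of the commit: the paths and the in-memory "filesystem"
# below are illustrative stand-ins for self._fs and save_hparams_to_yaml.
import os

NAME_HPARAMS_FILE = "hparams.yaml"
save_dir = "lightning_logs"
log_dir = os.path.join(save_dir, "version_0")
existing_dirs = {save_dir}   # the versioned log_dir is not created yet
written = []                 # records every path handed to the YAML writer


def save_old():
    # pre-fix logic: fall back to save_dir when log_dir does not exist
    dir_path = log_dir
    if dir_path not in existing_dirs:
        dir_path = save_dir
    hparams_file = os.path.join(dir_path, NAME_HPARAMS_FILE)
    if hparams_file not in written:   # stands in for self._fs.isfile(...)
        written.append(hparams_file)  # stands in for save_hparams_to_yaml(...)


save_old()                   # early call: writes lightning_logs/hparams.yaml
existing_dirs.add(log_dir)   # the trainer later creates version_0
save_old()                   # writes lightning_logs/version_0/hparams.yaml too
assert len(written) == 2     # the file was saved twice -- the reported bug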
12 changes: 12 additions & 0 deletions tests/loggers/test_tensorboard.py
@@ -291,3 +291,15 @@ def test_tensorboard_finalize(summary_writer, tmpdir):
     logger.finalize("any")
     summary_writer().flush.assert_called()
     summary_writer().close.assert_called()
+
+
+def test_tensorboard_save_hparams_to_yaml_once(tmpdir):
+    model = BoringModel()
+    logger = TensorBoardLogger(save_dir=tmpdir, default_hp_metric=False)
+    trainer = Trainer(max_steps=1, default_root_dir=tmpdir, logger=logger)
+    assert trainer.log_dir == trainer.logger.log_dir
+    trainer.fit(model)
+
+    hparams_file = "hparams.yaml"
+    assert os.path.isfile(os.path.join(trainer.log_dir, hparams_file))
+    assert not os.path.isfile(os.path.join(tmpdir, hparams_file))
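The same scenario run against the patched logic (again a sketch mirroring the diff above, not the real TensorBoardLogger) writes exactly once, and only under log_dir, which is what the new test asserts against the real logger:

# A sketch mirroring the patched save(), with the same stand-ins as above.
import os

NAME_HPARAMS_FILE = "hparams.yaml"
save_dir = "lightning_logs"
log_dir = os.path.join(save_dir, "version_0")
existing_dirs = {save_dir}
written = []


def save_new():
    # post-fix logic: write only once the log directory actually exists
    hparams_file = os.path.join(log_dir, NAME_HPARAMS_FILE)
    if log_dir in existing_dirs and hparams_file not in written:
        written.append(hparams_file)


save_new()                   # no-op: version_0 does not exist yet
existing_dirs.add(log_dir)
save_new()                   # writes lightning_logs/version_0/hparams.yaml
save_new()                   # no-op: file already written
assert written == [os.path.join(log_dir, NAME_HPARAMS_FILE)]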
