Skip to content

Commit

Permalink
Revert #16401 and user proper CSVLogger (#16405)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka <[email protected]>
  • Loading branch information
carmocca and Borda authored Jan 17, 2023
1 parent 76b3cd5 commit fc195b9
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 16 deletions.
7 changes: 0 additions & 7 deletions .github/workflows/ci-tests-pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,3 @@ jobs:
flags: ${COVERAGE_SCOPE},cpu,pytest-full,python${{ matrix.python-version }},pytorch${{ matrix.pytorch-version }}
name: CPU-coverage
fail_ci_if_error: false

# TODO
# - name: Testing legacy creation
# working-directory: tests/
# run: |
# export PYTHONPATH=$(dirname $LEGACY_PATH);$PYTHONPATH # for `import tests_pytorch`
# python legacy/simple_classif_training.py
11 changes: 7 additions & 4 deletions src/pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,11 +582,14 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH:
return self.dirpath

if len(trainer.loggers) > 0:
logger_ = trainer.loggers[0]
save_dir = getattr(logger_, "save_dir", None) or trainer.default_root_dir
version = logger_.version
if trainer.loggers[0].save_dir is not None:
save_dir = trainer.loggers[0].save_dir
else:
save_dir = trainer.default_root_dir
name = trainer.loggers[0].name
version = trainer.loggers[0].version
version = version if isinstance(version, str) else f"version_{version}"
ckpt_path = os.path.join(save_dir, str(logger_.name), version, "checkpoints")
ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
else:
# if no loggers, use default_root_dir
ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
from torch import Tensor

import pytorch_lightning as pl
from lightning_fabric.loggers import CSVLogger
from lightning_fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
from lightning_fabric.plugins.environments import SLURMEnvironment
from lightning_fabric.utilities import move_data_to_device
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
from pytorch_lightning.loggers import Logger, TensorBoardLogger
from pytorch_lightning.loggers import CSVLogger, Logger, TensorBoardLogger
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT

warning_cache = WarningCache()
Expand Down Expand Up @@ -72,7 +71,7 @@ def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> Non
" or `tensorboardX` packages are found."
" Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default"
)
logger_ = CSVLogger(root_dir=self.trainer.default_root_dir) # type: ignore[assignment]
logger_ = CSVLogger(save_dir=self.trainer.default_root_dir) # type: ignore[assignment]
self.trainer.loggers = [logger_]
elif isinstance(logger, Iterable):
self.trainer.loggers = list(logger)
Expand Down
7 changes: 5 additions & 2 deletions src/pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.loggers import Logger
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loops import PredictionLoop, TrainingEpochLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
from pytorch_lightning.loops.fit_loop import FitLoop
Expand Down Expand Up @@ -1806,8 +1807,10 @@ def model(self, model: torch.nn.Module) -> None:
@property
def log_dir(self) -> Optional[str]:
if len(self.loggers) > 0:
logger_ = self.loggers[0]
dirpath = getattr(logger_, "log_dir", None) or getattr(logger_, "save_dir", None)
if not isinstance(self.loggers[0], TensorBoardLogger):
dirpath = self.loggers[0].save_dir
else:
dirpath = self.loggers[0].log_dir
else:
dirpath = self.default_root_dir

Expand Down

0 comments on commit fc195b9

Please sign in to comment.