diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 09069da8eb805..6fee7bdd6cc6b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -289,7 +289,8 @@ def on_train_start(self, trainer, pl_module): self.dirpath = ckpt_path assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0' - os.makedirs(self.dirpath, exist_ok=True) + if not gfile.exists(self.dirpath): + makedirs(self.dirpath) @rank_zero_only def on_validation_end(self, trainer, pl_module): diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index a7ec3474862f4..23d8920c061d9 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -308,10 +308,12 @@ def on_train_end(self, trainer, pl_module): default_root_dir ^^^^^^^^^^^^^^^^ -Default path for logs and weights when no logger -or :class:`pytorch_lightning.callbacks.ModelCheckpoint` callback passed. -On certain clusters you might want to separate where logs and checkpoints -are stored. If you don't then use this argument for convenience. +Default path for logs and weights when no logger or +:class:`pytorch_lightning.callbacks.ModelCheckpoint` callback passed. On +certain clusters you might want to separate where logs and checkpoints are +stored. If you don't then use this argument for convenience. Paths can be local +paths or remote paths such as `s3://bucket/path` or 'hdfs://path/'. Credentials +will need to be set up to use remote filepaths. Example:: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2593ab786bf04..edbba05813ad1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -220,6 +220,7 @@ def __init__( default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed. Default: ``os.getcwd()``. + Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' gradient_clip_val: 0 means don't clip. @@ -305,8 +306,10 @@ def __init__( weights_save_path: Where to save weights if specified. Will override default_root_dir for checkpoints only. Use this if for whatever reason you need the checkpoints stored in a different place than the logs written in `default_root_dir`. + Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' Defaults to `default_root_dir`. + amp_level: The optimization level to use (O1, O2, etc...). .. warning:: .. deprecated:: v0.7.4 @@ -877,6 +880,9 @@ def default_root_dir(self) -> str: The default location to save artifacts of loggers, checkpoints etc. It is used as a fallback if logger or checkpoint callback do not define specific save paths. """ + if "://" in str(self._default_root_dir): + # it is a remote uri, use as is + return self._default_root_dir return os.path.normpath(self._default_root_dir) @property @@ -885,6 +891,9 @@ def weights_save_path(self) -> str: The default root location to save weights (checkpoints), e.g., when the :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path. """ + if "://" in str(self._weights_save_path): + # it is a remote uri, use as is + return self._weights_save_path return os.path.normpath(self._weights_save_path) # ----------------------------- diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py index 7ab2656771c00..f6b0f5b42b831 100644 --- a/pytorch_lightning/utilities/cloud_io.py +++ b/pytorch_lightning/utilities/cloud_io.py @@ -1,4 +1,4 @@ -import sys +import platform import os from typing import Union from pathlib import Path @@ -35,14 +35,15 @@ def modern_gfile(): file operations """ tb_version = version.parse(tensorboard.version.VERSION) - modern_gfile = tb_version >= version.parse('2.0') + modern_gfile = tb_version >= version.parse("2.0") + return modern_gfile -def cloud_open(path: pathlike, mode: str, newline:str = None): - if sys.platform == "win32": +def cloud_open(path: pathlike, mode: str, newline: str = None): + if platform.system() == "Windows": log.debug( "gfile does not handle newlines correctly on windows so remote files are not" - "supported falling back to normal local file open." + " supported falling back to normal local file open." ) return open(path, mode, newline=newline) if not modern_gfile():