From f1112bdf1594b1f9eed69cd99442219b0c31b3ca Mon Sep 17 00:00:00 2001
From: Brendan Fahy
Date: Sat, 8 Aug 2020 19:37:16 +0000
Subject: [PATCH] Squashed commit of the following:

commit 29fb0506cd38a15c359e369cc8bc4435916b0c78
Author: Brendan Fahy
Date:   Sat Aug 8 19:35:30 2020 +0000

    fix checking for version for docs to build

commit 467fd640db02275972c7111af031c86bb59333e9
Author: Brendan Fahy
Date:   Sat Aug 8 18:56:05 2020 +0000

    remove no local test

commit a7cc9f88de00feec1a5406874d05313c42bd004c
Author: Brendan Fahy
Date:   Sat Aug 8 18:46:44 2020 +0000

    fix

commit 3fdbb729da79ae9348c83410a138666bad467951
Author: Brendan Fahy
Date:   Sat Aug 8 18:23:30 2020 +0000

    revert requirements

commit 9b8686bd83e2bc243cf329e26f1c667c6949cf67
Author: Brendan Fahy
Date:   Sat Aug 8 18:16:42 2020 +0000

    make it a fixture

commit eec74953d24c8b25268d3b6dde3cc4affdd5cb8f
Author: Brendan Fahy
Date:   Sat Aug 8 18:01:32 2020 +0000

    fix up the testing

commit 896d94a0e60083d52c81db2a036b7f1e015cad11
Author: Brendan Fahy
Date:   Sat Aug 8 17:47:28 2020 +0000

    fix some tests

commit 6d22bde19767bf2b71dfd44839b01efdf6888f83
Merge: 6175d4e2 6ebe0d72
Author: Brendan Fahy
Date:   Sat Aug 8 10:20:47 2020 +0000

    Merge remote-tracking branch 'origin/master' into tb_use_gfile

commit 6175d4e26b15a43c412c26d501762cd0b570616a
Author: Brendan Fahy
Date:   Fri Aug 7 10:16:36 2020 +0000

    Use tensorboard.compat.gfile to support remote writing
---
 .../callbacks/model_checkpoint.py        | 25 +++++---
 pytorch_lightning/core/saving.py         | 34 ++++++-----
 pytorch_lightning/loggers/tensorboard.py | 12 ++--
 pytorch_lightning/trainer/training_io.py | 17 +++---
 pytorch_lightning/utilities/cloud_io.py  | 60 ++++++++++++++++++-
 requirements/base.txt                    |  1 +
 6 files changed, 111 insertions(+), 38 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 2458007db218d..09069da8eb805 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -16,6 +16,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs


 class ModelCheckpoint(Callback):
@@ -104,7 +105,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
                  save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
-        if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
+        if filepath:
+            filepath = str(filepath)  # the tests pass in a py.path.local but we want a str
+        if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
             rank_zero_warn(
                 f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                 "All files in this directory will be deleted when a checkpoint is saved!"
@@ -116,12 +119,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         if filepath is None:  # will be determined by trainer at runtime
             self.dirpath, self.filename = None, None
         else:
-            if os.path.isdir(filepath):
+            if gfile.isdir(filepath):
                 self.dirpath, self.filename = filepath, '{epoch}'
             else:
                 filepath = os.path.realpath(filepath)
                 self.dirpath, self.filename = os.path.split(filepath)
-            os.makedirs(self.dirpath, exist_ok=True)
+            if not gfile.exists(self.dirpath):
+                makedirs(self.dirpath)
         self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
@@ -163,8 +167,14 @@ def kth_best_model(self):
         return self.kth_best_model_path

     def _del_model(self, filepath):
-        if os.path.isfile(filepath):
-            os.remove(filepath)
+        if gfile.exists(filepath):
+            try:
+                # in compat mode, remove is not implemented, so if running this
+                # against an actual remote file system with the correct remote
+                # dependencies installed, this will work fine
+                gfile.remove(filepath)
+            except AttributeError:
+                os.remove(filepath)

     def _save_model(self, filepath, trainer, pl_module):
@@ -172,7 +182,8 @@ def _save_model(self, filepath, trainer, pl_module):
             trainer.dev_debugger.track_checkpointing_history(filepath)

         # make paths
-        os.makedirs(os.path.dirname(filepath), exist_ok=True)
+        if not gfile.exists(os.path.dirname(filepath)):
+            makedirs(os.path.dirname(filepath))

         # delegate the saving to the model
         if self.save_function is not None:
@@ -308,7 +319,7 @@ def on_validation_end(self, trainer, pl_module):
             filepath = self.format_checkpoint_name(epoch, metrics)
             version_cnt = 0
-            while os.path.isfile(filepath):
+            while gfile.exists(filepath):
                 filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
                 # this epoch called before
                 version_cnt += 1
diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 68749f3f6cc11..28501fdcda06a 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -11,6 +11,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
 from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import gfile, cloud_open

 PRIMITIVE_TYPES = (bool, int, float, str)
 ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
@@ -273,30 +274,30 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_csv)
     """
-    if not os.path.isfile(tags_csv):
-        rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
+    if not gfile.exists(tags_csv):
+        rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
         return {}

-    with open(tags_csv) as fp:
-        csv_reader = csv.reader(fp, delimiter=',')
+    with cloud_open(tags_csv, "r", newline="") as fp:
+        csv_reader = csv.reader(fp, delimiter=",")
         tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

     return tags


 def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
-    if not os.path.isdir(os.path.dirname(tags_csv)):
-        raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.')
+    if not gfile.isdir(os.path.dirname(tags_csv)):
+        raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")

     if isinstance(hparams, Namespace):
         hparams = vars(hparams)

-    with open(tags_csv, 'w', newline='') as fp:
-        fieldnames = ['key', 'value']
+    with cloud_open(tags_csv, "w", newline="") as fp:
+        fieldnames = ["key", "value"]
         writer = csv.DictWriter(fp, fieldnames=fieldnames)
-        writer.writerow({'key': 'key', 'value': 'value'})
+        writer.writerow({"key": "key", "value": "value"})
         for k, v in hparams.items():
-            writer.writerow({'key': k, 'value': v})
+            writer.writerow({"key": k, "value": v})


 def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
@@ -310,11 +311,11 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_yaml)
     """
-    if not os.path.isfile(config_yaml):
-        rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
+    if not gfile.exists(config_yaml):
+        rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning)
         return {}

-    with open(config_yaml) as fp:
+    with cloud_open(config_yaml, "r") as fp:
         tags = yaml.load(fp)

     return tags
@@ -326,11 +327,12 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
         config_yaml: path to new YAML file
         hparams: parameters to be saved
     """
-    if not os.path.isdir(os.path.dirname(config_yaml)):
-        raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
+    if not gfile.isdir(os.path.dirname(config_yaml)):
+        raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")

     if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
         from omegaconf import OmegaConf
+
         OmegaConf.save(hparams, config_yaml, resolve=True)
         return

@@ -341,7 +343,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
         hparams = dict(hparams)
     assert isinstance(hparams, dict)

-    with open(config_yaml, 'w', newline='') as fp:
+    with cloud_open(config_yaml, "w", newline="") as fp:
         yaml.dump(hparams, fp)


diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py
index f88b5f97cff8b..ea97d1e69d773 100644
--- a/pytorch_lightning/loggers/tensorboard.py
+++ b/pytorch_lightning/loggers/tensorboard.py
@@ -16,6 +16,7 @@
 from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs

 try:
     from omegaconf import Container, OmegaConf
@@ -109,7 +110,8 @@ def experiment(self) -> SummaryWriter:
             return self._experiment

         assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0'
-        os.makedirs(self.root_dir, exist_ok=True)
+        if self.root_dir and not gfile.exists(str(self.root_dir)):
+            makedirs(self.root_dir)
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment

@@ -162,7 +164,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
     def save(self) -> None:
         super().save()
         dir_path = self.log_dir
-        if not os.path.isdir(dir_path):
+        if not gfile.isdir(dir_path):
             dir_path = self.save_dir

         # prepare the file path
@@ -188,13 +190,13 @@ def version(self) -> int:

     def _get_next_version(self):
         root_dir = os.path.join(self.save_dir, self.name)
-        if not os.path.isdir(root_dir):
+        if not gfile.isdir(root_dir):
             log.warning('Missing logger folder: %s', root_dir)
             return 0

         existing_versions = []
-        for d in os.listdir(root_dir):
-            if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
+        for d in gfile.listdir(root_dir):
+            if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
                 existing_versions.append(int(d.split("_")[1]))

         if len(existing_versions) == 0:
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index a4137dc3879c4..7a1613b919a26 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -104,6 +104,7 @@
 )
 from pytorch_lightning.utilities import rank_zero_warn, AMPType
 from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs

 try:
     import torch_xla
@@ -407,9 +408,9 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):
         did_restore = False

         # look for hpc weights
-        folderpath = self.weights_save_path
-        if os.path.exists(folderpath):
-            files = os.listdir(folderpath)
+        folderpath = str(self.weights_save_path)
+        if gfile.exists(folderpath):
+            files = gfile.listdir(folderpath)
             hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]

             # if hpc weights exist restore model
@@ -488,15 +489,17 @@ def restore_training_state(self, checkpoint):
     # ----------------------------------
     def hpc_save(self, folderpath: str, logger):
         # make sure the checkpoint folder exists
-        os.makedirs(folderpath, exist_ok=True)
+        folderpath = str(folderpath)  # because the tests pass a path object
+        if not gfile.exists(folderpath):
+            makedirs(folderpath)

         # save logger to make sure we get all the metrics
         logger.save()

         ckpt_number = self.max_ckpt_in_folder(folderpath) + 1

-        if not os.path.exists(folderpath):
-            os.makedirs(folderpath, exist_ok=True)
+        if not gfile.exists(folderpath):
+            makedirs(folderpath)

         filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

         # give model a chance to do something on hpc_save
@@ -549,7 +552,7 @@ def hpc_load(self, folderpath, on_gpu):
         log.info(f'restored hpc model from: {filepath}')

     def max_ckpt_in_folder(self, path, name_key='ckpt_'):
-        files = os.listdir(path)
+        files = gfile.listdir(str(path))
         files = [x for x in files if name_key in x]
         if len(files) == 0:
             return 0
diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py
index a2374ecf5d66d..7ab2656771c00 100644
--- a/pytorch_lightning/utilities/cloud_io.py
+++ b/pytorch_lightning/utilities/cloud_io.py
@@ -1,11 +1,65 @@
+import sys
+import os
+from typing import Union
 from pathlib import Path
 from urllib.parse import urlparse
-
 import torch
+import tensorboard
+from packaging import version
+from pytorch_lightning import _logger as log
+
+# we want this for tf.io.gfile: if TensorFlow is installed it gives the full
+# implementation, otherwise a pruned-down version that works for some file
+# backends but not all
+from tensorboard.compat import tf
+
+gfile = tf.io.gfile
+
+pathlike = Union[Path, str]
+
+# older versions of tensorboard had buggy gfile compatibility layers;
+# only support remote cloud paths on newer versions
+

 def load(path_or_url: str, map_location=None):
     if urlparse(path_or_url).scheme == '' or Path(path_or_url).drive:  # no scheme or with a drive letter
         return torch.load(path_or_url, map_location=map_location)
-    else:
-        return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
+    return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
+
+
+def modern_gfile():
+    """Check the version number of tensorboard.
+
+    Checks whether it has the gfile compatibility layers needed for remote
+    file operations.
+    """
+    tb_version = version.parse(tensorboard.version.VERSION)
+    return tb_version >= version.parse('2.0')
+
+
+def cloud_open(path: pathlike, mode: str, newline: str = None):
+    if sys.platform == "win32":
+        log.debug(
+            "gfile does not handle newlines correctly on Windows, so remote"
+            " files are not supported; falling back to normal local file open."
+        )
+        return open(path, mode, newline=newline)
+    if not modern_gfile():
+        log.debug(
+            "tensorboard.compat gfile does not support remote files on older"
+            " versions of tensorboard; using normal local file open."
+        )
+        return open(path, mode, newline=newline)
+    try:
+        return gfile.GFile(path, mode)
+    except NotImplementedError:
+        # minimal dependencies are installed and only local files will work
+        return open(path, mode, newline=newline)
+
+
+def makedirs(path: pathlike):
+    if hasattr(gfile, "makedirs") and modern_gfile():
+        return gfile.makedirs(str(path))
+    # otherwise minimal dependencies are installed and only local files will work
+    return os.makedirs(path, exist_ok=True)
diff --git a/requirements/base.txt b/requirements/base.txt
index 4282f6a12d2eb..c40ef16d85be9 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -7,3 +7,4 @@ future>=0.17.1  # required for builtins in setup.py
 # pyyaml>=3.13
 PyYAML>=5.1  # OmegaConf requirement >=5.1
 tqdm>=4.41.0
+packaging
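
Editor's note: below is a minimal usage sketch of the new cloud_io helpers introduced by this patch, not part of the diff itself. It assumes tensorboard >= 2.0 is installed; the "s3://my-bucket" path is hypothetical and only resolves if gfile's matching remote-filesystem dependencies (e.g. full TensorFlow) are present.

    # Usage sketch (not part of the patch): exercising the new cloud_io helpers.
    # Assumes tensorboard >= 2.0; the "s3://my-bucket" path is hypothetical and
    # needs gfile's S3 backend available to actually work.
    import csv
    import os

    from pytorch_lightning.utilities.cloud_io import cloud_open, gfile, makedirs

    save_dir = "s3://my-bucket/lightning_logs"  # hypothetical remote path
    if not gfile.exists(save_dir):
        makedirs(save_dir)  # falls back to os.makedirs for local-only installs

    # cloud_open returns a gfile.GFile for remote paths, and a plain open()
    # file object on Windows or with older tensorboard versions
    with cloud_open(os.path.join(save_dir, "meta_tags.csv"), "w", newline="") as fp:
        writer = csv.DictWriter(fp, fieldnames=["key", "value"])
        writer.writerow({"key": "key", "value": "value"})

This mirrors how save_hparams_to_tags_csv writes its header row in the patched core/saving.py, so the same code path works for both local and remote destinations.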