[Ready for review] use gfile to support remote directories #2164

Merged
merged 1 commit on Aug 9, 2020
Squashed commit of the following:
commit 29fb0506cd38a15c359e369cc8bc4435916b0c78
Author: Brendan Fahy <bmfahy@gmail.com>
Date:   Sat Aug 8 19:35:30 2020 +0000

    fix checking for version for docs to build

commit 467fd64
Author: Brendan Fahy <bmfahy@gmail.com>
Date:   Sat Aug 8 18:56:05 2020 +0000

    remove no local test

commit a7cc9f8
Author: Brendan Fahy <bmfahy@gmail.com>
Date:   Sat Aug 8 18:46:44 2020 +0000

    fix

commit 3fdbb72
Author: Brendan Fahy <bmfahy@gmail.com>
Date:   Sat Aug 8 18:23:30 2020 +0000

    revert requirements

commit 9b8686b
Author: Brendan Fahy <bmfahy@gmail.com>
Date:   Sat Aug 8 18:16:42 2020 +0000

    make it a fixture

commit eec7495
Author: Brendan Fahy <bmfahy@gmail.com>
Date:   Sat Aug 8 18:01:32 2020 +0000

    fix up the testing

commit 896d94a
Author: Brendan Fahy <bmfahy@gmail.com>
Date:   Sat Aug 8 17:47:28 2020 +0000

    fix some tests

commit 6d22bde
Merge: 6175d4e 6ebe0d7
Author: Brendan Fahy <bmfahy@gmail.com>
Date:   Sat Aug 8 10:20:47 2020 +0000

    Merge remote-tracking branch 'origin/master' into tb_use_gfile

commit 6175d4e
Author: Brendan Fahy <bmfahy@gmail.com>
Date:   Fri Aug 7 10:16:36 2020 +0000

    Use tensorboard.compat.gfile to support remote writing
f4hy committed Aug 8, 2020
commit f1112bdf1594b1f9eed69cd99442219b0c31b3ca
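
Taken together, the squashed commits replace direct `os` filesystem calls with `tensorboard.compat`'s `gfile`, so checkpoint directories and logger output can live on remote storage as well as local disk. A minimal sketch of the intended usage, assuming a backend that gfile can delegate to is installed (the `gs://` bucket below is hypothetical):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # hypothetical bucket: gs:// paths resolve only when tensorboard's gfile
    # has a backend for them (e.g. with tensorflow installed)
    checkpoint_cb = ModelCheckpoint(filepath="gs://my-bucket/ckpts/{epoch}")
    trainer = Trainer(checkpoint_callback=checkpoint_cb)
    # trainer.fit(model)  # checkpoints are then written through gfile, not os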
25 changes: 18 additions & 7 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -16,6 +16,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs


class ModelCheckpoint(Callback):
@@ -104,7 +105,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
mode: str = 'auto', period: int = 1, prefix: str = ''):
super().__init__()
-        if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
+        if filepath:
+            filepath = str(filepath)  # the tests pass in a py.path.local but we want a str
+        if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
rank_zero_warn(
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
@@ -116,12 +119,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
if filepath is None: # will be determined by trainer at runtime
self.dirpath, self.filename = None, None
else:
-            if os.path.isdir(filepath):
+            if gfile.isdir(filepath):
self.dirpath, self.filename = filepath, '{epoch}'
else:
filepath = os.path.realpath(filepath)
self.dirpath, self.filename = os.path.split(filepath)
-            os.makedirs(self.dirpath, exist_ok=True)
+            if not gfile.exists(self.dirpath):
+                makedirs(self.dirpath)
self.save_last = save_last
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
@@ -163,16 +167,23 @@ def kth_best_model(self):
return self.kth_best_model_path

def _del_model(self, filepath):
-        if os.path.isfile(filepath):
-            os.remove(filepath)
+        if gfile.exists(filepath):
+            try:
+                # in compat mode, remove is not implemented so if running this
+                # against an actual remote file system and the correct remote
+                # dependencies exist then this will work fine.
+                gfile.remove(filepath)
+            except AttributeError:
+                os.remove(filepath)

def _save_model(self, filepath, trainer, pl_module):

# in debugging, track when we save checkpoints
trainer.dev_debugger.track_checkpointing_history(filepath)

# make paths
-        os.makedirs(os.path.dirname(filepath), exist_ok=True)
+        if not gfile.exists(os.path.dirname(filepath)):
+            makedirs(os.path.dirname(filepath))

# delegate the saving to the model
if self.save_function is not None:
@@ -308,7 +319,7 @@ def on_validation_end(self, trainer, pl_module):

filepath = self.format_checkpoint_name(epoch, metrics)
version_cnt = 0
-        while os.path.isfile(filepath):
+        while gfile.exists(filepath):
filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1
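
A note on the `_del_model` change: when tensorflow is not installed, `tensorboard.compat` serves a pruned gfile shim that may not implement `remove` at all, so looking it up raises AttributeError and the callback falls back to `os.remove`. A standalone sketch of the same fallback pattern (the helper name is illustrative):

    import os

    from tensorboard.compat import tf

    gfile = tf.io.gfile

    def delete_checkpoint(filepath: str) -> None:
        """Remove a checkpoint locally or remotely, tolerating the pruned gfile shim."""
        if gfile.exists(filepath):
            try:
                # the full gfile (tensorflow installed) also handles remote paths
                gfile.remove(filepath)
            except AttributeError:
                # the pruned compat shim has no remove(); fall back to local-only delete
                os.remove(filepath)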
34 changes: 18 additions & 16 deletions pytorch_lightning/core/saving.py
@@ -11,6 +11,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import gfile, cloud_open

PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
@@ -273,30 +274,30 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
True
>>> os.remove(path_csv)
"""
-    if not os.path.isfile(tags_csv):
-        rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
+    if not gfile.exists(tags_csv):
+        rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
return {}

-    with open(tags_csv) as fp:
-        csv_reader = csv.reader(fp, delimiter=',')
+    with cloud_open(tags_csv, "r", newline="") as fp:
+        csv_reader = csv.reader(fp, delimiter=",")
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

return tags


def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
-    if not os.path.isdir(os.path.dirname(tags_csv)):
-        raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.')
+    if not gfile.isdir(os.path.dirname(tags_csv)):
+        raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")

if isinstance(hparams, Namespace):
hparams = vars(hparams)

-    with open(tags_csv, 'w', newline='') as fp:
-        fieldnames = ['key', 'value']
+    with cloud_open(tags_csv, "w", newline="") as fp:
+        fieldnames = ["key", "value"]
writer = csv.DictWriter(fp, fieldnames=fieldnames)
-        writer.writerow({'key': 'key', 'value': 'value'})
+        writer.writerow({"key": "key", "value": "value"})
for k, v in hparams.items():
-            writer.writerow({'key': k, 'value': v})
+            writer.writerow({"key": k, "value": v})


def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
@@ -310,11 +311,11 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
True
>>> os.remove(path_yaml)
"""
-    if not os.path.isfile(config_yaml):
-        rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
+    if not gfile.exists(config_yaml):
+        rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning)
return {}

-    with open(config_yaml) as fp:
+    with cloud_open(config_yaml, "r") as fp:
tags = yaml.load(fp)

return tags
@@ -326,11 +327,12 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
config_yaml: path to new YAML file
hparams: parameters to be saved
"""
-    if not os.path.isdir(os.path.dirname(config_yaml)):
-        raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
+    if not gfile.isdir(os.path.dirname(config_yaml)):
+        raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")

if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
from omegaconf import OmegaConf

OmegaConf.save(hparams, config_yaml, resolve=True)
return

@@ -341,7 +343,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
hparams = dict(hparams)
assert isinstance(hparams, dict)

-    with open(config_yaml, 'w', newline='') as fp:
+    with cloud_open(config_yaml, "w", newline="") as fp:
yaml.dump(hparams, fp)


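With `open` swapped for `cloud_open`, the hparams helpers above read and write through the same code path whether the target is local or remote. A small local round-trip sketch (a remote URL such as `gs://...` would follow the identical path, given the right backend):

    import os

    from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml

    path_yaml = os.path.join(".", "hparams.yaml")  # swap in a remote URL to exercise gfile
    save_hparams_to_yaml(path_yaml, {"lr": 0.001, "batch_size": 32})
    assert load_hparams_from_yaml(path_yaml)["batch_size"] == 32
    os.remove(path_yaml)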
12 changes: 7 additions & 5 deletions pytorch_lightning/loggers/tensorboard.py
@@ -16,6 +16,7 @@
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs

try:
from omegaconf import Container, OmegaConf
@@ -109,7 +110,8 @@ def experiment(self) -> SummaryWriter:
return self._experiment

assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0'
-        os.makedirs(self.root_dir, exist_ok=True)
+        if self.root_dir and not gfile.exists(str(self.root_dir)):
+            makedirs(self.root_dir)
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
return self._experiment

@@ -162,7 +164,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
def save(self) -> None:
super().save()
dir_path = self.log_dir
-        if not os.path.isdir(dir_path):
+        if not gfile.isdir(dir_path):
dir_path = self.save_dir

# prepare the file path
@@ -188,13 +190,13 @@ def version(self) -> int:
def _get_next_version(self):
root_dir = os.path.join(self.save_dir, self.name)

-        if not os.path.isdir(root_dir):
+        if not gfile.isdir(root_dir):
log.warning('Missing logger folder: %s', root_dir)
return 0

existing_versions = []
-        for d in os.listdir(root_dir):
-            if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
+        for d in gfile.listdir(root_dir):
+            if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
existing_versions.append(int(d.split("_")[1]))

if len(existing_versions) == 0:
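
`_get_next_version` now lists the experiment root through `gfile`, so version folders are discovered the same way on local and remote storage. A condensed sketch of the scan it performs (the function name and directory layout are illustrative):

    import os

    from tensorboard.compat import tf

    gfile = tf.io.gfile

    def next_version(root_dir: str) -> int:
        """Return 1 + the highest existing version_<n> folder, or 0 if none exist."""
        if not gfile.isdir(root_dir):
            return 0
        versions = [
            int(d.split("_")[1])
            for d in gfile.listdir(root_dir)
            if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_")
        ]
        return max(versions) + 1 if versions else 0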
17 changes: 10 additions & 7 deletions pytorch_lightning/trainer/training_io.py
@@ -104,6 +104,7 @@
)
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import gfile, makedirs

try:
import torch_xla
@@ -407,9 +408,9 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):
did_restore = False

# look for hpc weights
-        folderpath = self.weights_save_path
-        if os.path.exists(folderpath):
-            files = os.listdir(folderpath)
+        folderpath = str(self.weights_save_path)
+        if gfile.exists(folderpath):
+            files = gfile.listdir(folderpath)
hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]

# if hpc weights exist restore model
@@ -488,15 +489,17 @@ def restore_training_state(self, checkpoint):
# ----------------------------------
def hpc_save(self, folderpath: str, logger):
# make sure the checkpoint folder exists
-        os.makedirs(folderpath, exist_ok=True)
+        folderpath = str(folderpath)  # because the tests pass a path object
+        if not gfile.exists(folderpath):
+            makedirs(folderpath)

# save logger to make sure we get all the metrics
logger.save()

ckpt_number = self.max_ckpt_in_folder(folderpath) + 1

-        if not os.path.exists(folderpath):
-            os.makedirs(folderpath, exist_ok=True)
+        if not gfile.exists(folderpath):
+            makedirs(folderpath)
filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

# give model a chance to do something on hpc_save
@@ -549,7 +552,7 @@ def hpc_load(self, folderpath, on_gpu):
log.info(f'restored hpc model from: {filepath}')

def max_ckpt_in_folder(self, path, name_key='ckpt_'):
-        files = os.listdir(path)
+        files = gfile.listdir(str(path))
files = [x for x in files if name_key in x]
if len(files) == 0:
return 0
60 changes: 57 additions & 3 deletions pytorch_lightning/utilities/cloud_io.py
@@ -1,11 +1,65 @@
import sys
import os
from typing import Union
from pathlib import Path
from urllib.parse import urlparse

import torch

import tensorboard
from packaging import version
from pytorch_lightning import _logger as log

# we want this for tf.io.gfile: if tensorflow is installed this resolves to the
# full tf, otherwise to a pruned-down compat version that works for some file
# backends but not all
from tensorboard.compat import tf

gfile = tf.io.gfile

pathlike = Union[Path, str]

# older versions of tensorboard had buggy gfile compatibility layers,
# so only support remote cloud paths on newer versions


def load(path_or_url: str, map_location=None):
if urlparse(path_or_url).scheme == '' or Path(path_or_url).drive: # no scheme or with a drive letter
return torch.load(path_or_url, map_location=map_location)
-    else:
-        return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
+    return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)


def modern_gfile():
    """Check the tensorboard version number to see whether it has the gfile
    compatibility layers needed for remote file operations.
    """
    tb_version = version.parse(tensorboard.version.VERSION)
    return tb_version >= version.parse("2.0")


def cloud_open(path: pathlike, mode: str, newline: str = None):
    if sys.platform == "win32":
        log.debug(
            "gfile does not handle newlines correctly on windows so remote files are not "
            "supported, falling back to normal local file open."
        )
        return open(path, mode, newline=newline)
    if not modern_gfile():
        log.debug(
            "tensorboard.compat gfile does not work on older versions "
            "of tensorboard for remote files, using normal local file open."
        )
        return open(path, mode, newline=newline)
    try:
        return gfile.GFile(path, mode)
    except NotImplementedError:
        # minimal dependencies are installed and only local files will work
        return open(path, mode, newline=newline)


def makedirs(path: pathlike):
    if hasattr(gfile, "makedirs") and modern_gfile():
        return gfile.makedirs(str(path))
    # otherwise minimal dependencies are installed and only local files will work
    return os.makedirs(path, exist_ok=True)
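
The new helpers are deliberately defensive: `cloud_open` only hands a path to `gfile.GFile` when the platform and tensorboard version support it, and otherwise falls back to the builtin `open`; `makedirs` does the same dance. A usage sketch (the `s3://` bucket is hypothetical and needs a gfile backend that can reach it):

    from pytorch_lightning.utilities.cloud_io import cloud_open, makedirs

    # local paths always work, whether via gfile or the builtin open fallback
    makedirs("./runs/exp1")
    with cloud_open("./runs/exp1/notes.txt", "w") as fp:
        fp.write("hello")

    # hypothetical bucket: works only if tensorboard's gfile can reach s3://,
    # e.g. with tensorflow (or the relevant filesystem plugin) installed
    with cloud_open("s3://my-bucket/runs/exp1/notes.txt", "w") as fp:
        fp.write("hello")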
1 change: 1 addition & 0 deletions requirements/base.txt
@@ -7,3 +7,4 @@ future>=0.17.1 # required for builtins in setup.py
# pyyaml>=3.13
PyYAML>=5.1 # OmegaConf requirement >=5.1
tqdm>=4.41.0
packaging