From 1eb1d17e250d11184afa59b67eb4641a23fb0523 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 30 Sep 2020 12:11:19 -0700 Subject: [PATCH 1/6] Add trainer attribute to datamodule (#3749) * Split out changes from #3563 to make that PR easier to review. This formats the file according to the Black formatter * Store a reference to the trainer on the datamodule Fixes #3682 * Update data_connector.py * Update data_connector.py * Update test_datamodules.py * Add attribute to datamodule for trainer --- pytorch_lightning/core/datamodule.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 99a95f6598def5..51928c757b8d4e 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -147,6 +147,9 @@ def __init__( self._test_transforms = test_transforms self._dims = dims if dims is not None else () + # Pointer to the trainer object + self.trainer = None + # Private attrs to keep track of whether or not data hooks have been called yet self._has_prepared_data = False self._has_setup_fit = False From cf182e80fca0244eadbbe0038a675c3381ca40a7 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 30 Sep 2020 16:15:29 -0400 Subject: [PATCH 2/6] Finish Allow on_save_checkpoint... (#3688) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Finish #3562 * Apply suggestions from code review * Apply suggestions from code review * fix tests * Finish #3562 * Apply suggestions from code review * Apply suggestions from code review * fix tests * fix structure * fix structure * make save_last test pass * unnecessary global rank check * fix test * update test * update test * test * test * run save on all * remove assert * tracking saves * check if fails * test * clean up * adjust horovod test * clean up * remove unnecessary makdirs * change * undo * debug * debug * debug * debug * mock * undo debug code * add extra assertions * test Co-authored-by: Jirka Borovec Co-authored-by: Adrian Wälchli Co-authored-by: Adrian Wälchli --- .../callbacks/model_checkpoint.py | 49 +++++----- tests/callbacks/test_model_checkpoint.py | 94 +++++++++++-------- .../data/horovod/train_default_model.py | 3 - 3 files changed, 83 insertions(+), 63 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 2145940c0c03d0..d29b7e4ce8777a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -31,7 +31,7 @@ import torch from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn, rank_zero_info from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -176,16 +176,16 @@ def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]): self.best_model_score = checkpointed_state["best_model_score"] self.best_model_path = checkpointed_state["best_model_path"] - @rank_zero_only def save_checkpoint(self, trainer, pl_module): """ - Performs the main logic around saving a checkpoint + Performs the main logic around saving a checkpoint. 
+ This method runs on all ranks, it is the responsibility of `self.save_function` + to handle correct behaviour in distributed training, i.e., saving only on rank 0. """ epoch = trainer.current_epoch if ( - trainer.global_rank != 0 # only run on main process - or self.save_top_k == 0 # no models are saved + self.save_top_k == 0 # no models are saved or self.period < 1 # no models are saved or (epoch + 1) % self.period # skip epoch or trainer.running_sanity_check # don't save anything during sanity check @@ -207,13 +207,13 @@ def save_checkpoint(self, trainer, pl_module): # callback supports multiple simultaneous modes # here we call each mode sequentially - # Mode 1: save the last checkpoint - self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath) - - # Mode 2: save all checkpoints OR only the top k + # Mode 1: save all checkpoints OR only the top k if self.save_top_k: self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath) + # Mode 2: save the last checkpoint + self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath) + def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException( @@ -255,7 +255,6 @@ def __init_ckpt_dir(self, filepath, save_top_k): if self._fs.protocol == "file": # dont normalize remote paths filepath = os.path.realpath(filepath) self.dirpath, self.filename = os.path.split(filepath) - self._fs.makedirs(self.dirpath, exist_ok=True) def __init_monitor_mode(self, monitor, mode): torch_inf = torch.tensor(np.Inf) @@ -276,18 +275,21 @@ def __init_monitor_mode(self, monitor, mode): self.kth_value, self.mode = mode_dict[mode] + @rank_zero_only def _del_model(self, filepath: str): if self._fs.exists(filepath): self._fs.rm(filepath) + log.debug(f"Removed checkpoint: {filepath}") def _save_model(self, filepath: str, trainer, pl_module): # in debugging, track when we save checkpoints trainer.dev_debugger.track_checkpointing_history(filepath) # make paths - self._fs.makedirs(os.path.dirname(filepath), exist_ok=True) + if trainer.is_global_zero: + self._fs.makedirs(os.path.dirname(filepath), exist_ok=True) - # delegate the saving to the model + # delegate the saving to the trainer if self.save_function is not None: self.save_function(filepath, self.save_weights_only) else: @@ -325,7 +327,7 @@ def _format_checkpoint_name( filename = "{epoch}" # check and parse user passed keys in the string groups = re.findall(r"(\{.*?)[:\}]", filename) - if groups: + if len(groups) >= 0: metrics["epoch"] = epoch for group in groups: name = group[1:] @@ -404,10 +406,8 @@ def __resolve_ckpt_dir(self, trainer, pl_module): self.dirpath = ckpt_path - assert ( - trainer.global_rank == 0 - ), "tried to make a checkpoint from non global_rank=0" - self._fs.makedirs(self.dirpath, exist_ok=True) + if trainer.is_global_zero: + self._fs.makedirs(self.dirpath, exist_ok=True) def _add_backward_monitor_support(self, trainer): metrics = trainer.logger_connector.callback_metrics @@ -460,13 +460,18 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, fi # when user ALSO asked for the 'last.ckpt' change the name if self.save_last: - filename = self._format_checkpoint_name( + last_filepath = self._format_checkpoint_name( self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix ) - last_filepath = os.path.join(self.dirpath, f"{filename}.ckpt") + last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt") 
self._save_model(last_filepath, trainer, pl_module) - if self.last_model_path and self.last_model_path != last_filepath and (self.save_top_k != -1 or self.save_last): + if ( + self.last_model_path + and self.last_model_path != last_filepath + and (self.save_top_k != -1 or self.save_last) + and trainer.is_global_zero + ): self._del_model(self.last_model_path) self.last_model_path = last_filepath @@ -490,7 +495,7 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath): elif self.check_monitor_top_k(current): self._update_best_and_save(filepath, current, epoch, trainer, pl_module) elif self.verbose: - log.info( + rank_zero_info( f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}" ) @@ -528,7 +533,7 @@ def _update_best_and_save( self.best_model_score = self.best_k_models[self.best_model_path] if self.verbose: - log.info( + rank_zero_info( f"Epoch {epoch:d}: {self.monitor} reached" f" {current:0.5f} (best {self.best_model_score:0.5f})," f" saving model to {filepath} as top {k}" diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py index 64de5242394dfd..4fdd03e1e6ed8b 100644 --- a/tests/callbacks/test_model_checkpoint.py +++ b/tests/callbacks/test_model_checkpoint.py @@ -1,4 +1,6 @@ import os +from unittest.mock import MagicMock, Mock + import yaml import pickle import platform @@ -92,21 +94,26 @@ class ModelCheckpointTestInvocations(ModelCheckpoint): def __init__(self, expected_count, *args, **kwargs): super().__init__(*args, **kwargs) - self.count = 0 self.expected_count = expected_count + self.on_save_checkpoint_count = 0 + + def on_train_start(self, trainer, pl_module): + torch.save = Mock(wraps=torch.save) - def _save_model(self, filepath, trainer, pl_module): - # make sure we don't save twice - assert not os.path.isfile(filepath) - self.count += 1 - super()._save_model(filepath, trainer, pl_module) + def on_save_checkpoint(self, trainer, pl_module): + # expect all ranks to run but only rank 0 will actually write the checkpoint file + super().on_save_checkpoint(trainer, pl_module) + self.on_save_checkpoint_count += 1 def on_train_end(self, trainer, pl_module): super().on_train_end(trainer, pl_module) - # on rank 0 we expect the saved files and on all others no saves - assert (trainer.global_rank == 0 and self.count == self.expected_count) or ( - trainer.global_rank > 0 and self.count == 0 - ) + assert self.best_model_path + assert self.best_model_score + assert self.on_save_checkpoint_count == self.expected_count + if trainer.is_global_zero: + assert torch.save.call_count == self.expected_count + else: + assert torch.save.call_count == 0 @pytest.mark.skipif( @@ -220,34 +227,6 @@ def test_none_monitor_save_last(tmpdir): ModelCheckpoint(filepath=tmpdir, save_last=False) -def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): - """Tests that the save_last checkpoint contains the latest information.""" - seed_everything(100) - model = EvalModelTemplate() - num_epochs = 3 - model_checkpoint = ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=num_epochs, save_last=True) - trainer = Trainer( - default_root_dir=tmpdir, - early_stop_callback=False, - checkpoint_callback=model_checkpoint, - max_epochs=num_epochs, - ) - trainer.fit(model) - - path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {}) - assert path_last_epoch != model_checkpoint.last_model_path - - ckpt_last_epoch = torch.load(path_last_epoch) - ckpt_last = torch.load(model_checkpoint.last_model_path) - 
assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step")) - - # it is easier to load the model objects than to iterate over the raw dict of tensors - model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch) - model_last = EvalModelTemplate.load_from_checkpoint(model_checkpoint.last_model_path) - for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()): - assert w0.eq(w1).all() - - def test_model_checkpoint_none_monitor(tmpdir): """ Test that it is possible to save all checkpoints when monitor=None. """ model = EvalModelTemplate() @@ -439,3 +418,42 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v ) trainer.fit(model) assert caplog.messages.count('Saving latest checkpoint...') == save_last + + +def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): + """ Tests that the save_last checkpoint contains the latest information. """ + seed_everything(100) + model = EvalModelTemplate() + num_epochs = 3 + model_checkpoint = ModelCheckpoint( + monitor='val_loss', filepath=tmpdir, save_top_k=num_epochs, save_last=True + ) + trainer = Trainer( + default_root_dir=tmpdir, + early_stop_callback=False, + checkpoint_callback=model_checkpoint, + max_epochs=num_epochs, + ) + trainer.fit(model) + + path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt") + path_last = str(tmpdir / f"last.ckpt") + assert path_last == model_checkpoint.last_model_path + + ckpt_last_epoch = torch.load(path_last_epoch) + ckpt_last = torch.load(path_last) + assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step")) + + ch_type = type(model_checkpoint) + assert all(list( + ckpt_last["callbacks"][ch_type][k] == ckpt_last_epoch["callbacks"][ch_type][k] + for k in ("best_model_score", "best_model_path") + )) + + # it is easier to load the model objects than to iterate over the raw dict of tensors + model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch) + model_last = EvalModelTemplate.load_from_checkpoint( + model_checkpoint.last_model_path + ) + for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()): + assert w0.eq(w1).all() diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 1f9e6fe8cefd18..4e9ae581dfa572 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -66,9 +66,6 @@ def run_test_from_config(trainer_options): assert hvd.size() == 2 if trainer.global_rank > 0: - # on higher ranks the checkpoint location is unknown - # we want to test checkpointing on rank 0 only - assert not trainer.checkpoint_callback.best_model_path return # test model loading From a38d108a6831c248e08f078d74a091cff89fbb2a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 1 Oct 2020 01:21:38 -0400 Subject: [PATCH 3/6] add dist lib to enable syncing anything across devices (#3762) * add dist lib to enable syncing anything across devices --- docs/source/index.rst | 1 + .../accelerators/base_backend.py | 5 +++ .../accelerators/ddp_base_backend.py | 5 +++ .../accelerators/ddp_cpu_spawn_backend.py | 5 +++ pytorch_lightning/accelerators/dp_backend.py | 3 +- pytorch_lightning/accelerators/gpu_backend.py | 2 ++ .../accelerators/horovod_backend.py | 4 +++ .../callbacks/model_checkpoint.py | 6 ++-- pytorch_lightning/distributed/__init__.py | 1 + pytorch_lightning/distributed/dist.py | 36 +++++++++++++++++++ pytorch_lightning/trainer/logging.py | 2 +- 
pytorch_lightning/trainer/training_loop.py | 3 ++ tests/callbacks/test_model_checkpoint.py | 2 +- tests/core/test_datamodules.py | 1 + 14 files changed, 71 insertions(+), 5 deletions(-) create mode 100644 pytorch_lightning/distributed/__init__.py create mode 100644 pytorch_lightning/distributed/dist.py diff --git a/docs/source/index.rst b/docs/source/index.rst index 32f4ab3dbbdc89..ec85f7e043df35 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -141,3 +141,4 @@ Indices and tables api/pytorch_lightning.utilities api/pytorch_lightning.tuner api/pytorch_lightning.plugins + api/pytorch_lightning.distributed diff --git a/pytorch_lightning/accelerators/base_backend.py b/pytorch_lightning/accelerators/base_backend.py index 0afedf14ab74e9..60ea76aaa72845 100644 --- a/pytorch_lightning/accelerators/base_backend.py +++ b/pytorch_lightning/accelerators/base_backend.py @@ -7,6 +7,7 @@ from pytorch_lightning.utilities import AMPType, rank_zero_warn from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.parsing import AttributeDict try: from apex import amp @@ -21,6 +22,7 @@ class Accelerator(object): def __init__(self, trainer): self.trainer = trainer + self.dist = AttributeDict(rank=0, device=None) def setup(self, model): pass @@ -31,6 +33,9 @@ def teardown(self): def barrier(self, name: str = None): pass + def broadcast(self, obj, src=0): + return obj + def train_or_test(self): if self.trainer.testing: results = self.trainer.run_test() diff --git a/pytorch_lightning/accelerators/ddp_base_backend.py b/pytorch_lightning/accelerators/ddp_base_backend.py index 526b06d59d58a0..35dc89abe620a6 100644 --- a/pytorch_lightning/accelerators/ddp_base_backend.py +++ b/pytorch_lightning/accelerators/ddp_base_backend.py @@ -24,6 +24,7 @@ from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.distributed.dist import LightningDistributed try: from hydra.core.hydra_config import HydraConfig @@ -38,6 +39,7 @@ class DDPBase(Accelerator): def __init__(self, trainer): super().__init__(trainer) + self.dist = LightningDistributed() def training_step(self, args): if self.trainer.amp_backend == AMPType.NATIVE: @@ -177,6 +179,9 @@ def ddp_train_tmp(self, process_idx, mp_queue, model, is_master=False, proc_offs if self.trainer.global_rank == 0: return results + def broadcast(self, obj, src=0): + return self.dist.broadcast(obj) + def set_world_ranks(self, process_idx): raise NotImplementedError('to create a ddp backend, please implement set_world_ranks') diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py index 58577be61835f4..2f0c6d29c7ebd8 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py @@ -25,6 +25,7 @@ from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.distributed import find_free_network_port +from pytorch_lightning.distributed.dist import LightningDistributed try: from hydra.core.hydra_config import HydraConfig @@ -41,6 +42,7 @@ def __init__(self, trainer, nprocs): super().__init__(trainer) self.mp_queue = None self.nprocs = nprocs 
+ self.dist = LightningDistributed() def setup(self, model): os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port())) @@ -174,6 +176,9 @@ def test_step(self, args): def barrier(self, name: str = None): torch_distrib.barrier() + def broadcast(self, obj, src=0): + return self.dist.broadcast(obj) + def early_stopping_should_stop(self, pl_module): stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) dist.all_reduce(stop, op=dist.reduce_op.SUM) diff --git a/pytorch_lightning/accelerators/dp_backend.py b/pytorch_lightning/accelerators/dp_backend.py index 87d00bcd8bc8c8..0bf7e18bc24785 100644 --- a/pytorch_lightning/accelerators/dp_backend.py +++ b/pytorch_lightning/accelerators/dp_backend.py @@ -16,7 +16,7 @@ from torch import optim from pytorch_lightning.accelerators.base_backend import Accelerator -from pytorch_lightning.core import LightningModule +from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.core.step_result import Result from pytorch_lightning.overrides.data_parallel import LightningDataParallel from pytorch_lightning.utilities import AMPType @@ -28,6 +28,7 @@ class DataParallelBackend(Accelerator): def __init__(self, trainer): super().__init__(trainer) self.model_autocast_original_forward = None + self.dist = LightningDistributed() def setup(self, model): # call setup after the ddp process has connected diff --git a/pytorch_lightning/accelerators/gpu_backend.py b/pytorch_lightning/accelerators/gpu_backend.py index d3c6c59160ac4d..ea1d57cceaf3e5 100644 --- a/pytorch_lightning/accelerators/gpu_backend.py +++ b/pytorch_lightning/accelerators/gpu_backend.py @@ -16,6 +16,7 @@ from pytorch_lightning.accelerators.base_backend import Accelerator from pytorch_lightning.utilities import AMPType +from pytorch_lightning.distributed.dist import LightningDistributed class GPUBackend(Accelerator): @@ -23,6 +24,7 @@ class GPUBackend(Accelerator): def __init__(self, trainer): super().__init__(trainer) + self.dist = LightningDistributed() def setup(self, model): diff --git a/pytorch_lightning/accelerators/horovod_backend.py b/pytorch_lightning/accelerators/horovod_backend.py index cfdf80fa2b2648..2fcf75c215cf79 100644 --- a/pytorch_lightning/accelerators/horovod_backend.py +++ b/pytorch_lightning/accelerators/horovod_backend.py @@ -158,3 +158,7 @@ def on_train_epoch_end(self): def barrier(self, name: str = None): hvd.join() + + def broadcast(self, obj, src=0): + obj = hvd.broadcast_object(obj, src) + return obj diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index d29b7e4ce8777a..4357163f46c021 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -366,7 +366,6 @@ def format_checkpoint_name( ckpt_name = f"{filename}.ckpt" return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name - @rank_zero_only def __resolve_ckpt_dir(self, trainer, pl_module): """ Determines model checkpoint save directory at runtime. 
References attributes from the @@ -398,8 +397,11 @@ def __resolve_ckpt_dir(self, trainer, pl_module): if isinstance(trainer.logger.version, str) else f"version_{trainer.logger.version}" ) + + version, name = trainer.accelerator_backend.broadcast((version, trainer.logger.name)) + ckpt_path = os.path.join( - save_dir, trainer.logger.name, version, "checkpoints" + save_dir, name, version, "checkpoints" ) else: ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints") diff --git a/pytorch_lightning/distributed/__init__.py b/pytorch_lightning/distributed/__init__.py new file mode 100644 index 00000000000000..15540f7a7d63e6 --- /dev/null +++ b/pytorch_lightning/distributed/__init__.py @@ -0,0 +1 @@ +from pytorch_lightning.distributed.dist import LightningDistributed diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py new file mode 100644 index 00000000000000..a3f0378f0a0627 --- /dev/null +++ b/pytorch_lightning/distributed/dist.py @@ -0,0 +1,36 @@ +import io +import torch +from typing import Any +from torch import distributed as torch_distrib + + +class LightningDistributed: + + def __init__(self, rank=None, device=None): + self.rank = rank + self.device = device + + def broadcast(self, obj: Any): + if self.rank == 0: + self._emit(obj) + else: + obj = self._receive() + return obj + + def _emit(self, obj): + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + length_tensor = torch.tensor([len(data)]).long().to(self.device) + length_tensor = torch_distrib.broadcast(length_tensor, src=0) + data_tensor = torch.ByteTensor(data).to(self.device) + data_tensor = torch_distrib.broadcast(data_tensor, src=0) + + def _receive(self): + length_tensor = torch.tensor([0]).long().to(self.device) + torch_distrib.broadcast(length_tensor, src=0) + data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device) + torch_distrib.broadcast(data_tensor, src=0) + buffer = io.BytesIO(data_tensor.cpu().numpy()) + obj = torch.load(buffer) + return obj diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index ff8aab3743759d..26572feb5e35ea 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -69,7 +69,7 @@ def process_dict_result(self, output, train=False): m = inspect.cleandoc( f"""The {{{k}:dict keyword}} was deprecated in 0.9.1 and will be removed in 1.0.0 Please use self.log(...) inside the lightningModule instead. 
- + # log on a step or aggregate epoch metric to the logger and/or progress bar # (inside LightningModule) self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 428e081a5cab1e..e670c01f041564 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -108,6 +108,9 @@ def setup_training(self, model: LightningModule): if self.trainer.data_parallel: ref_model = model.module + self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank + self.trainer.accelerator_backend.dist.device = ref_model.device + # give model convenience properties ref_model.trainer = self.trainer diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py index 4fdd03e1e6ed8b..fd394903183d37 100644 --- a/tests/callbacks/test_model_checkpoint.py +++ b/tests/callbacks/test_model_checkpoint.py @@ -437,7 +437,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): trainer.fit(model) path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt") - path_last = str(tmpdir / f"last.ckpt") + path_last = str(tmpdir / "last.ckpt") assert path_last == model_checkpoint.last_model_path ckpt_last_epoch = torch.load(path_last_epoch) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 4e4476955adbae..5325ca828e47b7 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -235,6 +235,7 @@ def test_dm_checkpoint_save(tmpdir): assert dm.__class__.__name__ in checkpoint assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__ + def test_test_loop_only(tmpdir): reset_seed() From 5ec00ccd28392d7200e8ec5e31322ad75b2567a3 Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Thu, 1 Oct 2020 01:36:34 -0400 Subject: [PATCH 4/6] Added gradient clip test for native AMP (#3754) * added gradient clip test for fp16 * pep8 --- tests/trainer/test_trainer.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b2860cbe32ac13..d27a701cfae473 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -23,6 +23,7 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE from tests.base import EvalModelTemplate @@ -867,6 +868,36 @@ def _optimizer_step(*args, **kwargs): trainer.fit(model) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="test requires native AMP.") +def test_gradient_clipping_fp16(tmpdir): + """ + Test gradient clipping with fp16 + """ + + model = EvalModelTemplate() + + # test that gradient is clipped correctly + def _optimizer_step(*args, **kwargs): + parameters = model.parameters() + grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm) + + trainer = Trainer( + max_steps=1, + max_epochs=1, + precision=16, + gpus=1, + gradient_clip_val=1.0, + default_root_dir=tmpdir, + ) + + # for the test + model.optimizer_step = _optimizer_step + model.prev_called_batch_idx = 0 + + trainer.fit(model) 
+ def test_gpu_choice(tmpdir): trainer_options = dict( default_root_dir=tmpdir, From 7c61fc7c27ef81354af399c04e939e57c65ce046 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 1 Oct 2020 02:31:11 -0400 Subject: [PATCH 5/6] ref: fixes logging for eval steps (#3763) * fixes logging for eval steps --- pl_examples/basic_examples/image_classifier.py | 3 ++- pytorch_lightning/callbacks/model_checkpoint.py | 15 +++++---------- .../trainer/connectors/logger_connector.py | 3 +++ 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pl_examples/basic_examples/image_classifier.py b/pl_examples/basic_examples/image_classifier.py index c453822d02af26..04e965c0491f6f 100644 --- a/pl_examples/basic_examples/image_classifier.py +++ b/pl_examples/basic_examples/image_classifier.py @@ -54,13 +54,14 @@ def training_step(self, batch, batch_idx): x, y = batch y_hat = self.backbone(x) loss = F.cross_entropy(y_hat, y) + self.log('train_loss', loss, on_epoch=True) return loss def validation_step(self, batch, batch_idx): x, y = batch y_hat = self.backbone(x) loss = F.cross_entropy(y_hat, y) - self.log('valid_loss', loss) + self.log('valid_loss', loss, on_step=True) def test_step(self, batch, batch_idx): x, y = batch diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 4357163f46c021..f0a8e159a2fd77 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -296,6 +296,9 @@ def _save_model(self, filepath: str, trainer, pl_module): raise ValueError(".save_function() not set") def check_monitor_top_k(self, current) -> bool: + if current is None: + return False + if self.save_top_k == -1: return True @@ -421,7 +424,7 @@ def _add_backward_monitor_support(self, trainer): if self.monitor is None and 'checkpoint_on' in metrics: self.monitor = 'checkpoint_on' - if self.save_top_k is None: + if self.save_top_k is None and self.monitor is not None: self.save_top_k = 1 def _validate_monitor_key(self, trainer): @@ -486,15 +489,7 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath): if not isinstance(current, torch.Tensor) and current is not None: current = torch.tensor(current, device=pl_module.device) - if current is None: - m = f"Can save best model only with {self.monitor} available, skipping." - if self.monitor == 'checkpoint_on': - m = ( - 'No checkpoint_on found. HINT: Did you set it in ' - 'EvalResult(checkpoint_on=tensor) or TrainResult(checkpoint_on=tensor)?' 
- ) - rank_zero_warn(m, RuntimeWarning) - elif self.check_monitor_top_k(current): + if self.check_monitor_top_k(current): self._update_best_and_save(filepath, current, epoch, trainer, pl_module) elif self.verbose: rank_zero_info( diff --git a/pytorch_lightning/trainer/connectors/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector.py index 8030d57fe75be7..7a56639c63f616 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector.py @@ -157,6 +157,9 @@ def _log_on_evaluation_epoch_end_metrics(self): # track the final results for the dataloader self.eval_loop_results.append(deepcopy(self.callback_metrics)) + # actually log + self.log_metrics(logger_metrics, {}, step=self.trainer.global_step) + def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders): if num_loaders == 1: return metrics From e4e60e9b82adc48482db4721ce3e1fdc3ab6d6fe Mon Sep 17 00:00:00 2001 From: GimmickNG Date: Thu, 1 Oct 2020 02:33:12 -0600 Subject: [PATCH 6/6] Add datamodule parameter to lr_find() (#3425) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add datamodule parameter to lr_find() * Fixed missing import * Move datamodule parameter to end * Add datamodule parameter test with auto_lr_find * Change test for datamodule parameter * Apply suggestions from code review Co-authored-by: Nicki Skafte * Fix lr_find documentation Co-authored-by: Carlos Mocholí * formatting * Add description to datamodule param in lr_find * pep8: remove trailing whitespace on line 105 * added changelog Co-authored-by: Nicki Skafte Co-authored-by: Nicki Skafte Co-authored-by: Carlos Mocholí Co-authored-by: Jirka Borovec --- CHANGELOG.md | 2 ++ pytorch_lightning/tuner/lr_finder.py | 28 +++++++++++++++++++--------- pytorch_lightning/tuner/tuning.py | 5 ++++- tests/trainer/test_lr_finder.py | 25 +++++++++++++++++++++++++ 4 files changed, 50 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 223cabc4f4eb45..4e6505890004f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for datamodules to save and load checkpoints when training ([#3563]https://github.com/PyTorchLightning/pytorch-lightning/pull/3563) +- Added support for datamodule in learning rate finder ([#3425](https://github.com/PyTorchLightning/pytorch-lightning/pull/3425)) + ### Changed - Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251)) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 71756678af9c54..a3ba2550186a72 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -11,21 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import importlib import os +from typing import List, Optional, Sequence, Union + +import numpy as np import torch -from typing import Optional, Sequence, List, Union +from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader + +from pytorch_lightning import _logger as log +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers.base import DummyLogger from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch.optim.lr_scheduler import _LRScheduler -import importlib -from pytorch_lightning import _logger as log -import numpy as np -from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr - # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed if importlib.util.find_spec('ipywidgets') is not None: @@ -71,6 +73,7 @@ def lr_find( num_training: int = 100, mode: str = 'exponential', early_stop_threshold: float = 4.0, + datamodule: Optional[LightningDataModule] = None, ): r""" lr_find enables the user to do a range test of good initial learning rates, @@ -81,7 +84,7 @@ def lr_find( train_dataloader: A PyTorch DataLoader with training samples. If the model has - a predefined train_dataloader method this will be skipped. + a predefined train_dataloader method, this will be skipped. min_lr: minimum learning rate to investigate @@ -98,6 +101,12 @@ def lr_find( loss at any point is larger than early_stop_threshold*best_loss then the search is stopped. To disable, set to None. + datamodule: An optional `LightningDataModule` which holds the training + and validation dataloader(s). Note that the `train_dataloader` and + `val_dataloaders` parameters cannot be used at the same time as + this parameter, or a `MisconfigurationException` will be raised. 
+ + Example:: # Setup model and trainer @@ -167,7 +176,8 @@ def lr_find( # Fit, lr & loss logged in callback trainer.fit(model, train_dataloader=train_dataloader, - val_dataloaders=val_dataloaders) + val_dataloaders=val_dataloaders, + datamodule=datamodule) # Prompt if we stopped early if trainer.global_step != num_training: diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 1f1423a38db568..8c55ffac92c6a1 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -15,6 +15,7 @@ from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus from pytorch_lightning.tuner.lr_finder import _run_lr_finder_internally, lr_find from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.datamodule import LightningDataModule from typing import Optional, List, Union from torch.utils.data import DataLoader @@ -50,6 +51,7 @@ def lr_find( num_training: int = 100, mode: str = 'exponential', early_stop_threshold: float = 4.0, + datamodule: Optional[LightningDataModule] = None ): return lr_find( self.trainer, @@ -60,7 +62,8 @@ def lr_find( max_lr, num_training, mode, - early_stop_threshold + early_stop_threshold, + datamodule, ) def internal_find_lr(self, trainer, model: LightningModule): diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index cafa79e3f575b9..67c673df1318d1 100755 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -5,6 +5,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate +from tests.base.datamodules import TrialMNISTDataModule def test_error_on_more_than_1_optimizer(tmpdir): @@ -152,6 +153,30 @@ def test_call_to_trainer_method(tmpdir): 'Learning rate was not altered after running learning rate finder' +def test_datamodule_parameter(tmpdir): + """ Test that the datamodule parameter works """ + + # trial datamodule + dm = TrialMNISTDataModule(tmpdir) + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + before_lr = hparams.get('learning_rate') + # logger file to get meta + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + ) + + lrfinder = trainer.tuner.lr_find(model, datamodule=dm) + after_lr = lrfinder.suggestion() + model.learning_rate = after_lr + + assert before_lr != after_lr, \ + 'Learning rate was not altered after running learning rate finder' + + def test_accumulation_and_early_stopping(tmpdir): """ Test that early stopping of learning rate finder works, and that accumulation also works for this feature """
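

The `LightningDistributed.broadcast` helper added in PATCH 3/6 serializes an arbitrary object with `torch.save` on rank 0, broadcasts first its byte length and then the payload via `torch.distributed.broadcast`, and deserializes it on every other rank; the new `Accelerator.broadcast` hook and the `__resolve_ckpt_dir` change build on that round trip. The sketch below is illustrative only, not part of the patches: it assumes a local CPU-only "gloo" process group, fixed MASTER_ADDR/MASTER_PORT values, and made-up helper names (`demo_worker`, `run_demo`).

import os

import torch
import torch.multiprocessing as mp
from torch import distributed as torch_distrib

from pytorch_lightning.distributed import LightningDistributed


def demo_worker(rank: int, world_size: int):
    # assumed rendezvous settings for a local, CPU-only process group
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    torch_distrib.init_process_group("gloo", rank=rank, world_size=world_size)

    dist = LightningDistributed(rank=rank, device=torch.device("cpu"))

    # only rank 0 holds the value; broadcast() pickles it with torch.save and
    # every other rank receives and reconstructs the same object
    ckpt_dir = "lightning_logs/version_0/checkpoints" if rank == 0 else None
    ckpt_dir = dist.broadcast(ckpt_dir)
    assert ckpt_dir == "lightning_logs/version_0/checkpoints"

    torch_distrib.destroy_process_group()


def run_demo(world_size: int = 2):
    # spawn one process per rank; mp.spawn passes the rank as the first argument
    mp.spawn(demo_worker, args=(world_size,), nprocs=world_size)


if __name__ == "__main__":
    run_demo()

Broadcasting the length before the payload lets receiving ranks allocate a correctly sized uint8 buffer ahead of the second collective call, which is why `_emit` and `_receive` in dist.py each issue two broadcasts.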