From eff63d10ce996bab1442efcae084909b04a2dbd2 Mon Sep 17 00:00:00 2001
From: Dusan Drevicky
Date: Thu, 11 Feb 2021 15:52:38 +0100
Subject: [PATCH 01/10] Raise if scheduler interval not 'step' or 'epoch'

---
 pytorch_lightning/trainer/optimizers.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index 6772dcc645e3b..c94b51c2bfd95 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -117,6 +117,13 @@ def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
                 raise MisconfigurationException(
                     'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
                 )
+            allowed_interval_keys = ['step', 'epoch']
+            if 'interval' in scheduler and scheduler['interval'] not in allowed_interval_keys:
+                raise MisconfigurationException(
+                    f'The "interval" key in lr scheduler dict must be one of {allowed_interval_keys},'
+                    f' but is "{scheduler["interval"]}"'
+                )
+
             scheduler['reduce_on_plateau'] = isinstance(
                 scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
             )

From 706274432b7d444bd3595a94351df9bbc3c24114 Mon Sep 17 00:00:00 2001
From: Dusan Drevicky
Date: Thu, 11 Feb 2021 16:23:22 +0100
Subject: [PATCH 02/10] Add test for unknown 'interval' value in scheduler

---
 tests/trainer/optimization/test_optimizers.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py
index c9a9250995dd0..4844e79383847 100644
--- a/tests/trainer/optimization/test_optimizers.py
+++ b/tests/trainer/optimization/test_optimizers.py
@@ -459,6 +459,24 @@ def test_unknown_configure_optimizers_raises(tmpdir):
         trainer.fit(model)
 
 
+def test_lr_scheduler_with_unknown_interval_raises(tmpdir):
+    """
+    Test exception when lr_scheduler dict has unknown interval param value
+    """
+    model = EvalModelTemplate()
+    optimizer = torch.optim.Adam(model.parameters())
+    model.configure_optimizers = lambda: {
+        'optimizer': optimizer,
+        'lr_scheduler': {
+            'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, 1),
+            'interval': "incorrect_unknown_value"
+        },
+    }
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    with pytest.raises(MisconfigurationException, match=r'The "interval" key in lr scheduler dict must be one of.*'):
+        trainer.fit(model)
+
+
 def test_lr_scheduler_with_extra_keys_warns(tmpdir):
     """
     Test warning when lr_scheduler dict has extra keys

From c42f008dfbe173de9598d941c445f301555bd00b Mon Sep 17 00:00:00 2001
From: Dusan Drevicky <55678224+ddrevicky@users.noreply.github.com>
Date: Thu, 11 Feb 2021 17:53:37 +0100
Subject: [PATCH 03/10] Use BoringModel instead of EvalModelTemplate

Co-authored-by: Jirka Borovec
---
 tests/trainer/optimization/test_optimizers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py
index 4844e79383847..e23d0d781b6f0 100644
--- a/tests/trainer/optimization/test_optimizers.py
+++ b/tests/trainer/optimization/test_optimizers.py
@@ -463,7 +463,7 @@ def test_lr_scheduler_with_unknown_interval_raises(tmpdir):
     """
     Test exception when lr_scheduler dict has unknown interval param value
     """
-    model = EvalModelTemplate()
+    model = BoringModel()
     optimizer = torch.optim.Adam(model.parameters())
     model.configure_optimizers = lambda: {
         'optimizer': optimizer,
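The three patches above add validation of the lr scheduler dict's 'interval' key plus a regression test. For context, here is a minimal sketch of the user-facing configuration this check guards; the module name LitModel, the layer sizes, and the learning rate are illustrative placeholders, not part of the PR:

import torch
from pytorch_lightning import LightningModule


class LitModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
                # 'interval' must now be 'step' or 'epoch'; any other value
                # raises MisconfigurationException during trainer.fit()
                'interval': 'epoch',
            },
        }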
From 7eb4016566b4289b85d1e9383772b86183c764b7 Mon Sep 17 00:00:00 2001
From: Dusan Drevicky
Date: Thu, 11 Feb 2021 17:58:52 +0100
Subject: [PATCH 04/10] Fix import order

---
 pytorch_lightning/core/datamodule.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index ecf5a99e703c9..c0fd2ff5c9dfb 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -15,10 +15,9 @@
 
 import functools
 import inspect
-import os
 from abc import abstractmethod
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Optional, Tuple, Union, Dict, Sequence, Mapping
+from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
 
 import torch
 from torch.utils.data import DataLoader, Dataset
@@ -386,6 +385,7 @@ def from_datasets(
             number of CPUs available.
 
         """
+
         def dataloader(ds, shuffle=False):
             return DataLoader(
                 ds,
@@ -399,7 +399,7 @@ def train_dataloader():
             if isinstance(train_dataset, Mapping):
                 return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
             if isinstance(train_dataset, Sequence):
-                return [dataloader(ds, shuffle=True) for ds in train_dataset]
+                return [dataloader(ds, shuffle=True) for ds in train_dataset]
             return dataloader(train_dataset, shuffle=True)
 
         def val_dataloader():

From 5fd026e7442dec7df846e9c368749d793ae502b6 Mon Sep 17 00:00:00 2001
From: Dusan Drevicky
Date: Thu, 11 Feb 2021 18:10:58 +0100
Subject: [PATCH 05/10] Apply yapf in test_datamodules

---
 tests/core/test_datamodules.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index a5c7c1cab3ee7..33ea0f085ead8 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -419,6 +419,7 @@ def test_dm_transfer_batch_to_device(tmpdir):
 
     class CustomBatch:
+
         def __init__(self, data):
            self.samples = data[0]
            self.targets = data[1]
@@ -452,6 +453,7 @@ def transfer_batch_to_device(self, data, device):
 
 
 class CustomMNISTDataModule(LightningDataModule):
+
     def __init__(self, data_dir: str = "./"):
         super().__init__()
         self.data_dir = data_dir
@@ -508,6 +510,7 @@ def train_dataloader(self):
 
 
     class DummyDS(torch.utils.data.Dataset):
+
        def __getitem__(self, index):
            return 1
 

From 4ab281dc58831ad0c40e92cb226f8a879bfbb457 Mon Sep 17 00:00:00 2001
From: Dusan Drevicky
Date: Thu, 11 Feb 2021 18:15:18 +0100
Subject: [PATCH 06/10] Add missing imports to test_datamodules

---
 tests/core/test_datamodules.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 33ea0f085ead8..799f7fe41a59a 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -19,6 +19,7 @@
 import pytest
 import torch
 import torch.nn.functional as F
+from torch.utils.data import DataLoader, random_split
 
 from pytorch_lightning import LightningDataModule, Trainer
 from pytorch_lightning.accelerators.legacy.gpu_accelerator import GPUAccelerator
@@ -26,6 +27,7 @@
 from pytorch_lightning.trainer.states import TrainerState
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.datamodules import ClassifDataModule
+from tests.helpers.datasets import TrialMNIST
 from tests.helpers.simple_models import ClassificationModel
 from tests.helpers.utils import reset_seed, set_random_master_port
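Patches 04-06 touch the LightningDataModule.from_datasets() helper and its tests. A short usage sketch of that helper, assuming the PL 1.2-era signature visible in the diff; the TensorDataset inputs are illustrative only:

import torch
from torch.utils.data import TensorDataset
from pytorch_lightning import LightningDataModule

train_ds = TensorDataset(torch.randn(64, 32))
val_ds = TensorDataset(torch.randn(16, 32))

# Each dataset is wrapped in a DataLoader by the generated *_dataloader()
# methods; the train loader is shuffled, and passing a Sequence of datasets
# would yield a list of loaders, as the train_dataloader() closure above shows.
dm = LightningDataModule.from_datasets(
    train_dataset=train_ds,
    val_dataset=val_ds,
    batch_size=8,
    num_workers=0,  # 0 loads data in the main process, per the fixed docstring
)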
From 4d95d35bb87daeed18f98370bcb111b9f7b7d642 Mon Sep 17 00:00:00 2001
From: Dusan Drevicky
Date: Thu, 11 Feb 2021 18:23:29 +0100
Subject: [PATCH 07/10] Fix too long comment

---
 pytorch_lightning/core/datamodule.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index c0fd2ff5c9dfb..943e206c7be2b 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -381,8 +381,8 @@ def from_datasets(
             val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader()
             test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader()
             batch_size: Batch size to use for each dataloader. Default is 1.
-            num_workers: Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
-                number of CPUs available.
+            num_workers: Number of subprocesses to use for data loading. 0 means that the data will be loaded
+                in the main process.
 
         """
 

From 69d1968b8551d714bea383e763683f19f6db192a Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Thu, 11 Feb 2021 20:38:29 +0100
Subject: [PATCH 08/10] Update pytorch_lightning/trainer/optimizers.py

---
 pytorch_lightning/trainer/optimizers.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index c94b51c2bfd95..53c05ae8ced07 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -117,8 +117,7 @@ def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
                 raise MisconfigurationException(
                     'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
                 )
-            allowed_interval_keys = ['step', 'epoch']
-            if 'interval' in scheduler and scheduler['interval'] not in allowed_interval_keys:
+            if 'interval' in scheduler and scheduler['interval'] not in ('step', 'epoch'):
                 raise MisconfigurationException(
                     f'The "interval" key in lr scheduler dict must be one of {allowed_interval_keys},'
                     f' but is "{scheduler["interval"]}"'
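Note that PATCH 08 leaves the f-string referring to allowed_interval_keys, a name it just deleted, so reaching this branch would raise NameError instead of the intended MisconfigurationException; PATCH 09 below rewrites the message. A standalone sketch of the validation as it reads after that fix (the helper name check_scheduler_interval is hypothetical, not part of the PR):

from pytorch_lightning.utilities.exceptions import MisconfigurationException


def check_scheduler_interval(scheduler: dict) -> None:
    # Mirrors the check in configure_schedulers() after PATCH 09: only the
    # literal values 'step' and 'epoch' are accepted for 'interval'.
    if 'interval' in scheduler and scheduler['interval'] not in ('step', 'epoch'):
        raise MisconfigurationException(
            f'The "interval" key in lr scheduler dict must be "step" or "epoch"'
            f' but is "{scheduler["interval"]}"'
        )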
From e1e00de49bdcf272792a6062c37c8397f4055a8b Mon Sep 17 00:00:00 2001
From: Dusan Drevicky
Date: Fri, 12 Feb 2021 11:44:42 +0100
Subject: [PATCH 09/10] Fix unused imports and exception message

---
 pytorch_lightning/trainer/optimizers.py | 2 +-
 tests/core/test_datamodules.py          | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index 53c05ae8ced07..6793a370fdc35 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -119,7 +119,7 @@ def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
                 )
             if 'interval' in scheduler and scheduler['interval'] not in ('step', 'epoch'):
                 raise MisconfigurationException(
-                    f'The "interval" key in lr scheduler dict must be one of {allowed_interval_keys},'
+                    f'The "interval" key in lr scheduler dict must be "step" or "epoch"'
                     f' but is "{scheduler["interval"]}"'
                 )
 
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index a4cc35c458a2b..a83a6a41c9287 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -19,7 +19,6 @@
 import pytest
 import torch
 import torch.nn.functional as F
-from torch.utils.data import DataLoader, random_split
 
 from pytorch_lightning import LightningDataModule, Trainer
 from pytorch_lightning.accelerators.legacy.gpu_accelerator import GPUAccelerator
@@ -26,7 +26,6 @@
 from pytorch_lightning.trainer.states import TrainerState
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.datamodules import ClassifDataModule
-from tests.helpers.datasets import TrialMNIST
 from tests.helpers.simple_models import ClassificationModel
 from tests.helpers.utils import reset_seed, set_random_master_port

From c0820f6ad373c38fd99ee985a18fe5c212bbd4ca Mon Sep 17 00:00:00 2001
From: Dusan Drevicky
Date: Fri, 12 Feb 2021 12:24:10 +0100
Subject: [PATCH 10/10] Fix failing test

---
 tests/trainer/optimization/test_optimizers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py
index e23d0d781b6f0..7172b2dca76da 100644
--- a/tests/trainer/optimization/test_optimizers.py
+++ b/tests/trainer/optimization/test_optimizers.py
@@ -473,7 +473,7 @@ def test_lr_scheduler_with_unknown_interval_raises(tmpdir):
         },
     }
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-    with pytest.raises(MisconfigurationException, match=r'The "interval" key in lr scheduler dict must be one of.*'):
+    with pytest.raises(MisconfigurationException, match=r'The "interval" key in lr scheduler dict must be'):
        trainer.fit(model)
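The shortened pattern in PATCH 10 works because pytest.raises(match=...) applies re.search to the exception text, so matching any prefix of the final message is enough; the old pattern stopped matching once "one of" disappeared from the message in PATCH 09. A self-contained sketch of that matching behaviour, independent of Lightning:

import re

# The message produced after PATCH 09 for an invalid 'interval' value.
message = ('The "interval" key in lr scheduler dict must be "step" or "epoch"'
           ' but is "incorrect_unknown_value"')

# New pattern from PATCH 10: a prefix of the message, found by re.search.
assert re.search(r'The "interval" key in lr scheduler dict must be', message)
# Old pattern from PATCH 02: no longer present in the rewritten message.
assert not re.search(r'The "interval" key in lr scheduler dict must be one of.*', message)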