From 77fb425dd4588534504b99f0b562b3cde8b44dbd Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 10 Dec 2020 08:38:14 +0100 Subject: [PATCH 1/3] update usage of deprecated profiler (#5010) * drop deprecated profiler * lut Co-authored-by: Roger Shieh --- pl_examples/domain_templates/imagenet.py | 2 +- .../trainer/connectors/profiler_connector.py | 19 +++++++++++-------- tests/trainer/test_trainer.py | 4 ++-- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py index f02b6dc0952d7..b7116547d389b 100644 --- a/pl_examples/domain_templates/imagenet.py +++ b/pl_examples/domain_templates/imagenet.py @@ -237,7 +237,7 @@ def run_cli(): help='seed for initializing training.') parser = ImageNetLightningModel.add_model_specific_args(parent_parser) parser.set_defaults( - profiler=True, + profiler="simple", deterministic=True, max_epochs=90, ) diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py index 0f6686f1f83c7..3ecc168231b38 100644 --- a/pytorch_lightning/trainer/connectors/profiler_connector.py +++ b/pytorch_lightning/trainer/connectors/profiler_connector.py @@ -18,6 +18,11 @@ from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +PROFILERS = { + "simple": SimpleProfiler, + "advanced": AdvancedProfiler, +} + class ProfilerConnector: @@ -28,9 +33,9 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]): if profiler and not isinstance(profiler, (bool, str, BaseProfiler)): # TODO: Update exception on removal of bool - raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler` " - "are valid values for `Trainer`'s `profiler` parameter. " - f"Received {profiler} which is of type {type(profiler)}.") + raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler`" + " are valid values for `Trainer`'s `profiler` parameter." + f" Received {profiler} which is of type {type(profiler)}.") if isinstance(profiler, bool): rank_zero_warn("Passing a bool value as a `profiler` argument to `Trainer` is deprecated" @@ -39,11 +44,9 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]): if profiler: profiler = SimpleProfiler() elif isinstance(profiler, str): - profiler = profiler.lower() - if profiler == "simple": - profiler = SimpleProfiler() - elif profiler == "advanced": - profiler = AdvancedProfiler() + if profiler.lower() in PROFILERS: + profiler_class = PROFILERS[profiler.lower()] + profiler = profiler_class() else: raise ValueError("When passing string value for the `profiler` parameter of" " `Trainer`, it can only be 'simple' or 'advanced'") diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a25067b136fef..9b29d6ec2b1dd 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1476,6 +1476,6 @@ def test_trainer_profiler_incorrect_str_arg(): )) def test_trainer_profiler_incorrect_arg_type(profiler): with pytest.raises(MisconfigurationException, - match=r"Only None, bool, str and subclasses of `BaseProfiler` " - r"are valid values for `Trainer`'s `profiler` parameter. *"): + match=r"Only None, bool, str and subclasses of `BaseProfiler`" + r" are valid values for `Trainer`'s `profiler` parameter. *"): Trainer(profiler=profiler) From 4ebce38478f28c70edc2c7236c514665df105217 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 10 Dec 2020 11:01:33 +0100 Subject: [PATCH 2/3] update usage of deprecated automatic_optimization (#5011) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * drop deprecated usage automatic_optimization * Apply suggestions from code review Co-authored-by: Adrian Wälchli * Apply suggestions from code review Co-authored-by: Rohit Gupta Co-authored-by: Adrian Wälchli Co-authored-by: Rohit Gupta --- pytorch_lightning/core/lightning.py | 6 +- .../trainer/configuration_validator.py | 11 +-- pytorch_lightning/trainer/trainer.py | 2 +- tests/core/test_lightning_module.py | 2 - tests/core/test_lightning_optimizer.py | 17 +++-- .../dynamic_args/test_multiple_optimizers.py | 5 +- .../optimization/test_manual_optimization.py | 75 ++++++++++++++----- 7 files changed, 83 insertions(+), 35 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f66055bccab3e..f29e7f75bfbff 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1412,8 +1412,10 @@ def get_progress_bar_dict(self): def _verify_is_manual_optimization(self, fn_name): if self.trainer.train_loop.automatic_optimization: - m = f'to use {fn_name}, please disable automatic optimization: Trainer(automatic_optimization=False)' - raise MisconfigurationException(m) + raise MisconfigurationException( + f'to use {fn_name}, please disable automatic optimization:' + ' set model property `automatic_optimization` as False' + ) @classmethod def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 974bd69229a73..21d6af043df02 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -79,7 +79,7 @@ def __verify_train_loop_configuration(self, model): if trainer.overriden_optimizer_step and not enable_pl_optimizer and automatic_optimization: rank_zero_warn( "When overriding `LightningModule` optimizer_step with" - " `Trainer(..., enable_pl_optimizer=False, automatic_optimization=True, ...)`," + " `Trainer(..., enable_pl_optimizer=False, ...)`," " we won't be calling `.zero_grad` we can't assume when you call your `optimizer.step()`." " For Lightning to take care of it, please use `Trainer(enable_pl_optimizer=True)`." ) @@ -89,15 +89,16 @@ def __verify_train_loop_configuration(self, model): has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad if (has_overriden_optimization_functions) and going_to_accumulate_grad_batches and automatic_optimization: raise MisconfigurationException( - 'When overriding `LightningModule` optimizer_step or optimizer_zero_grad with ' - '`Trainer(automatic_optimization=True, ...)`, `accumulate_grad_batches` should to be 1.' + 'When overriding `LightningModule` optimizer_step or optimizer_zero_grad' + ' , `accumulate_grad_batches` in `Trainer` should to be 1.' ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.' ) if (enable_pl_optimizer) and trainer.overriden_optimizer_zero_grad and not automatic_optimization: raise MisconfigurationException( - 'When overriding `LightningModule` optimizer_zero_grad with ' - '`Trainer(automatic_optimization=False, enable_pl_optimizer=True, ...) is not supported' + 'When overriding `LightningModule` optimizer_zero_grad' + ' and preserving model property `automatic_optimization` as True with' + ' `Trainer(enable_pl_optimizer=True, ...) is not supported' ) def __verify_eval_loop_configuration(self, model, eval_loop_name): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 31a64d00ccb60..35da90625adef 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -358,7 +358,7 @@ def __init__( ) # init train loop related flags - # TODO: deprecate in 1.2.0 + # TODO: remove in 1.3.0 if automatic_optimization is None: automatic_optimization = True else: diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index 0c71259373d1b..3e2e6d040f44c 100644 --- a/tests/core/test_lightning_module.py +++ b/tests/core/test_lightning_module.py @@ -38,7 +38,6 @@ def optimizer_step(self, *_, **__): default_root_dir=tmpdir, limit_train_batches=2, accumulate_grad_batches=2, - automatic_optimization=True ) trainer.fit(model) @@ -90,7 +89,6 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, default_root_dir=tmpdir, limit_train_batches=8, accumulate_grad_batches=1, - automatic_optimization=True, enable_pl_optimizer=enable_pl_optimizer ) diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index bd19c26784bc2..e6ec59ec4f5aa 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -112,6 +112,10 @@ def configure_optimizers(self): lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) return [optimizer_1, optimizer_2], [lr_scheduler] + @property + def automatic_optimization(self) -> bool: + return False + model = TestModel() model.training_step_end = None model.training_epoch_end = None @@ -121,8 +125,8 @@ def configure_optimizers(self): limit_val_batches=1, max_epochs=1, weights_summary=None, - automatic_optimization=False, - enable_pl_optimizer=True) + enable_pl_optimizer=True, + ) trainer.fit(model) assert len(mock_sgd_step.mock_calls) == 2 @@ -161,6 +165,10 @@ def configure_optimizers(self): lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) return [optimizer_1, optimizer_2], [lr_scheduler] + @property + def automatic_optimization(self) -> bool: + return False + model = TestModel() model.training_step_end = None model.training_epoch_end = None @@ -170,7 +178,6 @@ def configure_optimizers(self): limit_val_batches=1, max_epochs=1, weights_summary=None, - automatic_optimization=False, accumulate_grad_batches=2, enable_pl_optimizer=True, ) @@ -237,7 +244,6 @@ def configure_optimizers(self): max_epochs=1, weights_summary=None, enable_pl_optimizer=True, - automatic_optimization=True ) trainer.fit(model) @@ -291,7 +297,6 @@ def configure_optimizers(self): max_epochs=1, weights_summary=None, enable_pl_optimizer=True, - automatic_optimization=True ) trainer.fit(model) @@ -352,7 +357,6 @@ def configure_optimizers(self): max_epochs=1, weights_summary=None, enable_pl_optimizer=True, - automatic_optimization=True ) trainer.fit(model) @@ -406,7 +410,6 @@ def configure_optimizers(self): max_epochs=1, weights_summary=None, enable_pl_optimizer=True, - automatic_optimization=True, ) trainer.fit(model) diff --git a/tests/trainer/dynamic_args/test_multiple_optimizers.py b/tests/trainer/dynamic_args/test_multiple_optimizers.py index a01ef8e0b2a7e..48b1bf6ab7ac9 100644 --- a/tests/trainer/dynamic_args/test_multiple_optimizers.py +++ b/tests/trainer/dynamic_args/test_multiple_optimizers.py @@ -97,11 +97,14 @@ def configure_optimizers(self): optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer, optimizer_2 + @property + def automatic_optimization(self) -> bool: + return False + model = TestModel() model.val_dataloader = None trainer = Trainer( - automatic_optimization=False, default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 2b2881cb928b4..5e341e9c66f63 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -69,12 +69,15 @@ def configure_optimizers(self): optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer, optimizer_2 + @property + def automatic_optimization(self) -> bool: + return False + model = TestModel() model.val_dataloader = None limit_train_batches = 2 trainer = Trainer( - automatic_optimization=False, default_root_dir=tmpdir, limit_train_batches=limit_train_batches, limit_val_batches=2, @@ -133,12 +136,15 @@ def configure_optimizers(self): optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer, optimizer_2 + @property + def automatic_optimization(self) -> bool: + return False + model = TestModel() model.val_dataloader = None limit_train_batches = 2 trainer = Trainer( - automatic_optimization=False, default_root_dir=tmpdir, limit_train_batches=limit_train_batches, limit_val_batches=2, @@ -198,12 +204,15 @@ def configure_optimizers(self): optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer, optimizer_2 + @property + def automatic_optimization(self) -> bool: + return False + model = TestModel() model.val_dataloader = None limit_train_batches = 2 trainer = Trainer( - automatic_optimization=False, default_root_dir=tmpdir, limit_train_batches=limit_train_batches, limit_val_batches=2, @@ -265,12 +274,15 @@ def configure_optimizers(self): optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer, optimizer_2 + @property + def automatic_optimization(self) -> bool: + return False + model = TestModel() model.val_dataloader = None limit_train_batches = 2 trainer = Trainer( - automatic_optimization=False, default_root_dir=tmpdir, limit_train_batches=limit_train_batches, limit_val_batches=2, @@ -278,7 +290,7 @@ def configure_optimizers(self): log_every_n_steps=1, weights_summary=None, precision=16, - gpus=1 + gpus=1, ) trainer.fit(model) @@ -335,12 +347,15 @@ def configure_optimizers(self): optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer, optimizer_2 + @property + def automatic_optimization(self) -> bool: + return False + model = TestModel() model.val_dataloader = None limit_train_batches = 2 trainer = Trainer( - automatic_optimization=False, default_root_dir=tmpdir, limit_train_batches=limit_train_batches, limit_val_batches=2, @@ -412,6 +427,10 @@ def on_train_end(self): assert self.called["on_train_batch_start"] == 10 assert self.called["on_train_batch_end"] == 10 + @property + def automatic_optimization(self) -> bool: + return False + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @@ -431,7 +450,6 @@ def test_manual_optimization_and_return_tensor(tmpdir): limit_train_batches=10, limit_test_batches=0, limit_val_batches=0, - automatic_optimization=False, precision=16, amp_backend='native', accelerator="ddp_spawn", @@ -461,7 +479,6 @@ def test_manual_optimization_and_return_detached_tensor(tmpdir): limit_train_batches=10, limit_test_batches=0, limit_val_batches=0, - automatic_optimization=False, precision=16, amp_backend='native', accelerator="ddp_spawn", @@ -538,6 +555,10 @@ def on_train_end(self): assert self.called["on_train_batch_start"] == 20 assert self.called["on_train_batch_end"] == 20 + @property + def automatic_optimization(self) -> bool: + return False + model = ExtendedModel() model.training_step_end = None model.training_epoch_end = None @@ -548,7 +569,6 @@ def on_train_end(self): limit_train_batches=20, limit_test_batches=0, limit_val_batches=0, - automatic_optimization=False, precision=16, amp_backend='native', accumulate_grad_batches=4, @@ -610,12 +630,15 @@ def configure_optimizers(self): optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer, optimizer_2 + @property + def automatic_optimization(self) -> bool: + return False + model = TestModel() model.val_dataloader = None limit_train_batches = 2 trainer = Trainer( - automatic_optimization=False, default_root_dir=tmpdir, limit_train_batches=limit_train_batches, limit_val_batches=2, @@ -692,13 +715,16 @@ def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer + @property + def automatic_optimization(self) -> bool: + return False + model = TestModel() model.val_dataloader = None model.training_epoch_end = None limit_train_batches = 2 trainer = Trainer( - automatic_optimization=False, default_root_dir=tmpdir, limit_train_batches=limit_train_batches, limit_val_batches=2, @@ -753,13 +779,16 @@ def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer + @property + def automatic_optimization(self) -> bool: + return False + model = TestModel() model.val_dataloader = None model.training_epoch_end = None limit_train_batches = 4 trainer = Trainer( - automatic_optimization=False, default_root_dir=tmpdir, limit_train_batches=limit_train_batches, limit_val_batches=2, @@ -806,13 +835,16 @@ def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer + @property + def automatic_optimization(self) -> bool: + return False + model = TestModel() model.val_dataloader = None model.training_epoch_end = None limit_train_batches = 4 trainer = Trainer( - automatic_optimization=False, default_root_dir=tmpdir, limit_train_batches=limit_train_batches, limit_val_batches=2, @@ -881,13 +913,16 @@ def configure_optimizers(self): optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001) return [optimizer_gen, optimizer_dis] + @property + def automatic_optimization(self) -> bool: + return False + model = TestModel() model.val_dataloader = None model.training_epoch_end = None limit_train_batches = 8 trainer = Trainer( - automatic_optimization=False, default_root_dir=tmpdir, limit_train_batches=limit_train_batches, limit_val_batches=2, @@ -985,6 +1020,10 @@ def configure_optimizers(self): optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001) return [optimizer_gen, optimizer_dis] + @property + def automatic_optimization(self) -> bool: + return False + seed_everything(42) model = TestModel() @@ -993,7 +1032,6 @@ def configure_optimizers(self): limit_train_batches = 8 trainer = Trainer( - automatic_optimization=False, default_root_dir=tmpdir, limit_train_batches=limit_train_batches, limit_val_batches=2, @@ -1023,13 +1061,16 @@ class TestModel(BoringModel): def optimizer_zero_grad(self, *_): pass + @property + def automatic_optimization(self) -> bool: + return False + model = TestModel() model.val_dataloader = None model.training_epoch_end = None limit_train_batches = 8 trainer = Trainer( - automatic_optimization=False, default_root_dir=tmpdir, limit_train_batches=limit_train_batches, limit_val_batches=2, @@ -1039,4 +1080,4 @@ def optimizer_zero_grad(self, *_): enable_pl_optimizer=True, ) except MisconfigurationException as e: - assert "`Trainer(automatic_optimization=False, enable_pl_optimizer=True, ...) is not supported" in str(e) + assert "`Trainer(enable_pl_optimizer=True, ...) is not supported" in str(e) From 820d5c734876feecbf05c5a73503d7600fcc6e41 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Thu, 10 Dec 2020 16:26:18 +0530 Subject: [PATCH 3/3] Add a notebook example to reach a quick baseline of ~94% accuracy on CIFAR (#4818) * Add a notebook example to reach a quick baseline of ~94% accuracy on CIFAR10 using Resnet in Lightning * Remove outputs * PR Feedback * some changes * some more changes Co-authored-by: chaton Co-authored-by: rohitgr7 --- notebooks/06-cifar10-baseline.ipynb | 394 ++++++++++++++++++++++++++++ notebooks/README.md | 15 +- 2 files changed, 402 insertions(+), 7 deletions(-) create mode 100644 notebooks/06-cifar10-baseline.ipynb diff --git a/notebooks/06-cifar10-baseline.ipynb b/notebooks/06-cifar10-baseline.ipynb new file mode 100644 index 0000000000000..d4b2209cc91b6 --- /dev/null +++ b/notebooks/06-cifar10-baseline.ipynb @@ -0,0 +1,394 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "06_cifar10_baseline.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "qMDj0BYNECU8" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ECu0zDh8UXU8" + }, + "source": [ + "# PyTorch Lightning CIFAR10 ~94% Baseline Tutorial ⚡\n", + "\n", + "Train a Resnet to 94% accuracy on Cifar10!\n", + "\n", + "Main takeaways:\n", + "1. Experiment with different Learning Rate schedules and frequencies in the configure_optimizers method in pl.LightningModule\n", + "2. Use an existing Resnet architecture with modifications directly with Lightning\n", + "\n", + "---\n", + "\n", + " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", + " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", + " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HYpMlx7apuHq" + }, + "source": [ + "### Setup\n", + "Lightning is easy to install. Simply `pip install pytorch-lightning`.\n", + "Also check out [bolts](https://github.com/PyTorchLightning/pytorch-lightning-bolts/) for pre-existing data modules and models." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ziAQCrE-TYWG" + }, + "source": [ + "! pip install pytorch-lightning pytorch-lightning-bolts -qU" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "L-W_Gq2FORoU" + }, + "source": [ + "# Run this if you intend to use TPUs\n", + "# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n", + "# !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "wjov-2N_TgeS" + }, + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torch.optim.lr_scheduler import OneCycleLR\n", + "from torch.optim.swa_utils import AveragedModel, update_bn\n", + "import torchvision\n", + "\n", + "import pytorch_lightning as pl\n", + "from pytorch_lightning.callbacks import LearningRateMonitor\n", + "from pytorch_lightning.metrics.functional import accuracy\n", + "from pl_bolts.datamodules import CIFAR10DataModule\n", + "from pl_bolts.transforms.dataset_normalizations import cifar10_normalization" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "54JMU1N-0y0g" + }, + "source": [ + "pl.seed_everything(7);" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FA90qwFcqIXR" + }, + "source": [ + "### CIFAR10 Data Module\n", + "\n", + "Import the existing data module from `bolts` and modify the train and test transforms." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "S9e-W8CSa8nH" + }, + "source": [ + "batch_size = 32\n", + "\n", + "train_transforms = torchvision.transforms.Compose([\n", + " torchvision.transforms.RandomCrop(32, padding=4),\n", + " torchvision.transforms.RandomHorizontalFlip(),\n", + " torchvision.transforms.ToTensor(),\n", + " cifar10_normalization(),\n", + "])\n", + "\n", + "test_transforms = torchvision.transforms.Compose([\n", + " torchvision.transforms.ToTensor(),\n", + " cifar10_normalization(),\n", + "])\n", + "\n", + "cifar10_dm = CIFAR10DataModule(\n", + " batch_size=batch_size,\n", + " train_transforms=train_transforms,\n", + " test_transforms=test_transforms,\n", + " val_transforms=test_transforms,\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SfCsutp3qUMc" + }, + "source": [ + "### Resnet\n", + "Modify the pre-existing Resnet architecture from TorchVision. The pre-existing architecture is based on ImageNet images (224x224) as input. So we need to modify it for CIFAR10 images (32x32)." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GNSeJgwvhHp-" + }, + "source": [ + "def create_model():\n", + " model = torchvision.models.resnet18(pretrained=False, num_classes=10)\n", + " model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " model.maxpool = nn.Identity()\n", + " return model" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HUCj5TKsqty1" + }, + "source": [ + "### Lightning Module\n", + "Check out the [`configure_optimizers`](https://pytorch-lightning.readthedocs.io/en/stable/lightning_module.html#configure-optimizers) method to use custom Learning Rate schedulers. The OneCycleLR with SGD will get you to around 92-93% accuracy in 20-30 epochs and 93-94% accuracy in 40-50 epochs. Feel free to experiment with different LR schedules from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "03OMrBa5iGtT" + }, + "source": [ + "class LitResnet(pl.LightningModule):\n", + " def __init__(self, lr=0.05):\n", + " super().__init__()\n", + "\n", + " self.save_hyperparameters()\n", + " self.model = create_model()\n", + "\n", + " def forward(self, x):\n", + " out = self.model(x)\n", + " return F.log_softmax(out, dim=1)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = F.log_softmax(self.model(x), dim=1)\n", + " loss = F.nll_loss(logits, y)\n", + " self.log('train_loss', loss)\n", + " return loss\n", + "\n", + " def evaluate(self, batch, stage=None):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " preds = torch.argmax(logits, dim=1)\n", + " acc = accuracy(preds, y)\n", + "\n", + " if stage:\n", + " self.log(f'{stage}_loss', loss, prog_bar=True)\n", + " self.log(f'{stage}_acc', acc, prog_bar=True)\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " self.evaluate(batch, 'val')\n", + "\n", + " def test_step(self, batch, batch_idx):\n", + " self.evaluate(batch, 'test')\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)\n", + " steps_per_epoch = 45000 // batch_size\n", + " scheduler_dict = {\n", + " 'scheduler': OneCycleLR(optimizer, 0.1, epochs=self.trainer.max_epochs, steps_per_epoch=steps_per_epoch),\n", + " 'interval': 'step',\n", + " }\n", + " return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "3FFPgpAFi9KU" + }, + "source": [ + "model = LitResnet(lr=0.05)\n", + "model.datamodule = cifar10_dm\n", + "\n", + "trainer = pl.Trainer(\n", + " progress_bar_refresh_rate=20,\n", + " max_epochs=40,\n", + " gpus=1,\n", + " logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='resnet'),\n", + " callbacks=[LearningRateMonitor(logging_interval='step')],\n", + ")\n", + "\n", + "trainer.fit(model, cifar10_dm)\n", + "trainer.test(model, datamodule=cifar10_dm);" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lWL_WpeVIXWQ" + }, + "source": [ + "### Bonus: Use [Stochastic Weight Averaging](https://arxiv.org/abs/1803.05407) to get a boost on performance\n", + "\n", + "Use SWA from torch.optim to get a quick performance boost. Also shows a couple of cool features from Lightning:\n", + "- Use `training_epoch_end` to run code after the end of every epoch\n", + "- Use a pretrained model directly with this wrapper for SWA" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bsSwqKv0t9uY" + }, + "source": [ + "class SWAResnet(LitResnet):\n", + " def __init__(self, trained_model, lr=0.01):\n", + " super().__init__()\n", + "\n", + " self.save_hyperparameters('lr')\n", + " self.model = trained_model\n", + " self.swa_model = AveragedModel(self.model)\n", + "\n", + " def forward(self, x):\n", + " out = self.swa_model(x)\n", + " return F.log_softmax(out, dim=1)\n", + "\n", + " def training_epoch_end(self, training_step_outputs):\n", + " self.swa_model.update_parameters(self.model)\n", + "\n", + " def validation_step(self, batch, batch_idx, stage=None):\n", + " x, y = batch\n", + " logits = F.log_softmax(self.model(x), dim=1)\n", + " loss = F.nll_loss(logits, y)\n", + " preds = torch.argmax(logits, dim=1)\n", + " acc = accuracy(preds, y)\n", + "\n", + " self.log(f'val_loss', loss, prog_bar=True)\n", + " self.log(f'val_acc', acc, prog_bar=True)\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)\n", + " return optimizer\n", + "\n", + " def on_train_end(self):\n", + " update_bn(self.datamodule.train_dataloader(), self.swa_model, device=self.device)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "cA6ZG7C74rjL" + }, + "source": [ + "swa_model = SWAResnet(model.model, lr=0.01)\n", + "swa_model.datamodule = cifar10_dm\n", + "\n", + "swa_trainer = pl.Trainer(\n", + " progress_bar_refresh_rate=20,\n", + " max_epochs=20,\n", + " gpus=1,\n", + " logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='swa_resnet'),\n", + ")\n", + "\n", + "swa_trainer.fit(swa_model, cifar10_dm)\n", + "swa_trainer.test(swa_model, datamodule=cifar10_dm);" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "RRHMfGiDpZ2M" + }, + "source": [ + "# Start tensorboard.\n", + "%reload_ext tensorboard\n", + "%tensorboard --logdir lightning_logs/" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RltpFGS-s0M1" + }, + "source": [ + "\n", + "

Congratulations - Time to Join the Community!

\n", + "
\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", + "\n", + "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", + "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", + "\n", + "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", + "\n", + "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", + "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", + "\n", + "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", + "\n", + "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", + "\n", + "### Contributions !\n", + "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", + "\n", + "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* You can also contribute your own notebooks with useful examples !\n", + "\n", + "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", + "\n", + "" + ] + } + ] +} diff --git a/notebooks/README.md b/notebooks/README.md index 695e1a038c18f..5d0f3564e9387 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -4,10 +4,11 @@ You can easily run any of the official notebooks by clicking the 'Open in Colab' links in the table below :smile: -| Notebook | Description | Colab Link | -| :--- | :--- | :---: | -| __MNIST Hello World__ | Train your first Lightning Module on the classic MNIST Handwritten Digits Dataset. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb) | -| __Datamodules__ | Learn about DataModules and train a dataset-agnostic model on MNIST and CIFAR10.| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/02-datamodules.ipynb)| -| __GAN__ | Train a GAN on the MNIST Dataset. Learn how to use multiple optimizers in Lightning. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb) | -| __BERT__ | Fine-tune HuggingFace Transformers models on the GLUE Benchmark | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/04-transformers-text-classification.ipynb) | -| __Trainer Flags__ | Overview of the available Lightning `Trainer` flags | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/05-trainer-flags-overview.ipynb) | +| Notebook | Description | Colab Link | +| :----------------------- | :----------------------------------------------------------------------------------- | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| **MNIST Hello World** | Train your first Lightning Module on the classic MNIST Handwritten Digits Dataset. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb) | +| **Datamodules** | Learn about DataModules and train a dataset-agnostic model on MNIST and CIFAR10. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/02-datamodules.ipynb) | +| **GAN** | Train a GAN on the MNIST Dataset. Learn how to use multiple optimizers in Lightning. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb) | +| **BERT** | Fine-tune HuggingFace Transformers models on the GLUE Benchmark | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/04-transformers-text-classification.ipynb) | +| **Trainer Flags** | Overview of the available Lightning `Trainer` flags | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/05-trainer-flags-overview.ipynb) | +| **94% Baseline CIFAR10** | Establish a quick baseline of ~94% accuracy on CIFAR10 using Resnet in Lightning | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/06-cifar10-baseline.ipynb) |