diff --git a/notebooks/06-cifar10-baseline.ipynb b/notebooks/06-cifar10-baseline.ipynb
new file mode 100644
index 0000000000000..d4b2209cc91b6
--- /dev/null
+++ b/notebooks/06-cifar10-baseline.ipynb
@@ -0,0 +1,394 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "name": "06_cifar10_baseline.ipynb",
+ "provenance": [],
+ "collapsed_sections": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qMDj0BYNECU8"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ECu0zDh8UXU8"
+ },
+ "source": [
+ "# PyTorch Lightning CIFAR10 ~94% Baseline Tutorial ⚡\n",
+ "\n",
+ "Train a Resnet to 94% accuracy on Cifar10!\n",
+ "\n",
+ "Main takeaways:\n",
+ "1. Experiment with different Learning Rate schedules and frequencies in the configure_optimizers method in pl.LightningModule\n",
+ "2. Use an existing Resnet architecture with modifications directly with Lightning\n",
+ "\n",
+ "---\n",
+ "\n",
+ " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
+ " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
+ " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HYpMlx7apuHq"
+ },
+ "source": [
+ "### Setup\n",
+ "Lightning is easy to install. Simply `pip install pytorch-lightning`.\n",
+ "Also check out [bolts](https://github.com/PyTorchLightning/pytorch-lightning-bolts/) for pre-existing data modules and models."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "ziAQCrE-TYWG"
+ },
+ "source": [
+ "! pip install pytorch-lightning pytorch-lightning-bolts -qU"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "L-W_Gq2FORoU"
+ },
+ "source": [
+ "# Run this if you intend to use TPUs\n",
+ "# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n",
+ "# !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "wjov-2N_TgeS"
+ },
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "from torch.optim.lr_scheduler import OneCycleLR\n",
+ "from torch.optim.swa_utils import AveragedModel, update_bn\n",
+ "import torchvision\n",
+ "\n",
+ "import pytorch_lightning as pl\n",
+ "from pytorch_lightning.callbacks import LearningRateMonitor\n",
+ "from pytorch_lightning.metrics.functional import accuracy\n",
+ "from pl_bolts.datamodules import CIFAR10DataModule\n",
+ "from pl_bolts.transforms.dataset_normalizations import cifar10_normalization"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "54JMU1N-0y0g"
+ },
+ "source": [
+ "pl.seed_everything(7);"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FA90qwFcqIXR"
+ },
+ "source": [
+ "### CIFAR10 Data Module\n",
+ "\n",
+ "Import the existing data module from `bolts` and modify the train and test transforms."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "S9e-W8CSa8nH"
+ },
+ "source": [
+ "batch_size = 32\n",
+ "\n",
+ "train_transforms = torchvision.transforms.Compose([\n",
+ " torchvision.transforms.RandomCrop(32, padding=4),\n",
+ " torchvision.transforms.RandomHorizontalFlip(),\n",
+ " torchvision.transforms.ToTensor(),\n",
+ " cifar10_normalization(),\n",
+ "])\n",
+ "\n",
+ "test_transforms = torchvision.transforms.Compose([\n",
+ " torchvision.transforms.ToTensor(),\n",
+ " cifar10_normalization(),\n",
+ "])\n",
+ "\n",
+ "cifar10_dm = CIFAR10DataModule(\n",
+ " batch_size=batch_size,\n",
+ " train_transforms=train_transforms,\n",
+ " test_transforms=test_transforms,\n",
+ " val_transforms=test_transforms,\n",
+ ")"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SfCsutp3qUMc"
+ },
+ "source": [
+ "### Resnet\n",
+ "Modify the pre-existing Resnet architecture from TorchVision. The pre-existing architecture is based on ImageNet images (224x224) as input. So we need to modify it for CIFAR10 images (32x32)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "GNSeJgwvhHp-"
+ },
+ "source": [
+ "def create_model():\n",
+ " model = torchvision.models.resnet18(pretrained=False, num_classes=10)\n",
+ " model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
+ " model.maxpool = nn.Identity()\n",
+ " return model"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HUCj5TKsqty1"
+ },
+ "source": [
+ "### Lightning Module\n",
+ "Check out the [`configure_optimizers`](https://pytorch-lightning.readthedocs.io/en/stable/lightning_module.html#configure-optimizers) method to use custom Learning Rate schedulers. The OneCycleLR with SGD will get you to around 92-93% accuracy in 20-30 epochs and 93-94% accuracy in 40-50 epochs. Feel free to experiment with different LR schedules from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "03OMrBa5iGtT"
+ },
+ "source": [
+ "class LitResnet(pl.LightningModule):\n",
+ " def __init__(self, lr=0.05):\n",
+ " super().__init__()\n",
+ "\n",
+ " self.save_hyperparameters()\n",
+ " self.model = create_model()\n",
+ "\n",
+ " def forward(self, x):\n",
+ " out = self.model(x)\n",
+ " return F.log_softmax(out, dim=1)\n",
+ "\n",
+ " def training_step(self, batch, batch_idx):\n",
+ " x, y = batch\n",
+ " logits = F.log_softmax(self.model(x), dim=1)\n",
+ " loss = F.nll_loss(logits, y)\n",
+ " self.log('train_loss', loss)\n",
+ " return loss\n",
+ "\n",
+ " def evaluate(self, batch, stage=None):\n",
+ " x, y = batch\n",
+ " logits = self(x)\n",
+ " loss = F.nll_loss(logits, y)\n",
+ " preds = torch.argmax(logits, dim=1)\n",
+ " acc = accuracy(preds, y)\n",
+ "\n",
+ " if stage:\n",
+ " self.log(f'{stage}_loss', loss, prog_bar=True)\n",
+ " self.log(f'{stage}_acc', acc, prog_bar=True)\n",
+ "\n",
+ " def validation_step(self, batch, batch_idx):\n",
+ " self.evaluate(batch, 'val')\n",
+ "\n",
+ " def test_step(self, batch, batch_idx):\n",
+ " self.evaluate(batch, 'test')\n",
+ "\n",
+ " def configure_optimizers(self):\n",
+ " optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)\n",
+ " steps_per_epoch = 45000 // batch_size\n",
+ " scheduler_dict = {\n",
+ " 'scheduler': OneCycleLR(optimizer, 0.1, epochs=self.trainer.max_epochs, steps_per_epoch=steps_per_epoch),\n",
+ " 'interval': 'step',\n",
+ " }\n",
+ " return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "3FFPgpAFi9KU"
+ },
+ "source": [
+ "model = LitResnet(lr=0.05)\n",
+ "model.datamodule = cifar10_dm\n",
+ "\n",
+ "trainer = pl.Trainer(\n",
+ " progress_bar_refresh_rate=20,\n",
+ " max_epochs=40,\n",
+ " gpus=1,\n",
+ " logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='resnet'),\n",
+ " callbacks=[LearningRateMonitor(logging_interval='step')],\n",
+ ")\n",
+ "\n",
+ "trainer.fit(model, cifar10_dm)\n",
+ "trainer.test(model, datamodule=cifar10_dm);"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "lWL_WpeVIXWQ"
+ },
+ "source": [
+ "### Bonus: Use [Stochastic Weight Averaging](https://arxiv.org/abs/1803.05407) to get a boost on performance\n",
+ "\n",
+ "Use SWA from torch.optim to get a quick performance boost. Also shows a couple of cool features from Lightning:\n",
+ "- Use `training_epoch_end` to run code after the end of every epoch\n",
+ "- Use a pretrained model directly with this wrapper for SWA"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "bsSwqKv0t9uY"
+ },
+ "source": [
+ "class SWAResnet(LitResnet):\n",
+ " def __init__(self, trained_model, lr=0.01):\n",
+ " super().__init__()\n",
+ "\n",
+ " self.save_hyperparameters('lr')\n",
+ " self.model = trained_model\n",
+ " self.swa_model = AveragedModel(self.model)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " out = self.swa_model(x)\n",
+ " return F.log_softmax(out, dim=1)\n",
+ "\n",
+ " def training_epoch_end(self, training_step_outputs):\n",
+ " self.swa_model.update_parameters(self.model)\n",
+ "\n",
+ " def validation_step(self, batch, batch_idx, stage=None):\n",
+ " x, y = batch\n",
+ " logits = F.log_softmax(self.model(x), dim=1)\n",
+ " loss = F.nll_loss(logits, y)\n",
+ " preds = torch.argmax(logits, dim=1)\n",
+ " acc = accuracy(preds, y)\n",
+ "\n",
+ " self.log(f'val_loss', loss, prog_bar=True)\n",
+ " self.log(f'val_acc', acc, prog_bar=True)\n",
+ "\n",
+ " def configure_optimizers(self):\n",
+ " optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)\n",
+ " return optimizer\n",
+ "\n",
+ " def on_train_end(self):\n",
+ " update_bn(self.datamodule.train_dataloader(), self.swa_model, device=self.device)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "cA6ZG7C74rjL"
+ },
+ "source": [
+ "swa_model = SWAResnet(model.model, lr=0.01)\n",
+ "swa_model.datamodule = cifar10_dm\n",
+ "\n",
+ "swa_trainer = pl.Trainer(\n",
+ " progress_bar_refresh_rate=20,\n",
+ " max_epochs=20,\n",
+ " gpus=1,\n",
+ " logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='swa_resnet'),\n",
+ ")\n",
+ "\n",
+ "swa_trainer.fit(swa_model, cifar10_dm)\n",
+ "swa_trainer.test(swa_model, datamodule=cifar10_dm);"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "RRHMfGiDpZ2M"
+ },
+ "source": [
+ "# Start tensorboard.\n",
+ "%reload_ext tensorboard\n",
+ "%tensorboard --logdir lightning_logs/"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RltpFGS-s0M1"
+ },
+ "source": [
+ "\n",
+ " Congratulations - Time to Join the Community!
\n",
+ "
\n",
+ "\n",
+ "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n",
+ "\n",
+ "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n",
+ "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n",
+ "\n",
+ "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n",
+ "\n",
+ "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n",
+ "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n",
+ "\n",
+ "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
+ "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n",
+ "\n",
+ "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
+ "\n",
+ "### Contributions !\n",
+ "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n",
+ "\n",
+ "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
+ "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
+ "* You can also contribute your own notebooks with useful examples !\n",
+ "\n",
+ "### Great thanks from the entire Pytorch Lightning Team for your interest !\n",
+ "\n",
+ ""
+ ]
+ }
+ ]
+}
diff --git a/notebooks/README.md b/notebooks/README.md
index 695e1a038c18f..5d0f3564e9387 100644
--- a/notebooks/README.md
+++ b/notebooks/README.md
@@ -4,10 +4,11 @@
You can easily run any of the official notebooks by clicking the 'Open in Colab' links in the table below :smile:
-| Notebook | Description | Colab Link |
-| :--- | :--- | :---: |
-| __MNIST Hello World__ | Train your first Lightning Module on the classic MNIST Handwritten Digits Dataset. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb) |
-| __Datamodules__ | Learn about DataModules and train a dataset-agnostic model on MNIST and CIFAR10.| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/02-datamodules.ipynb)|
-| __GAN__ | Train a GAN on the MNIST Dataset. Learn how to use multiple optimizers in Lightning. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb) |
-| __BERT__ | Fine-tune HuggingFace Transformers models on the GLUE Benchmark | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/04-transformers-text-classification.ipynb) |
-| __Trainer Flags__ | Overview of the available Lightning `Trainer` flags | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/05-trainer-flags-overview.ipynb) |
+| Notebook | Description | Colab Link |
+| :----------------------- | :----------------------------------------------------------------------------------- | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| **MNIST Hello World** | Train your first Lightning Module on the classic MNIST Handwritten Digits Dataset. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb) |
+| **Datamodules** | Learn about DataModules and train a dataset-agnostic model on MNIST and CIFAR10. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/02-datamodules.ipynb) |
+| **GAN** | Train a GAN on the MNIST Dataset. Learn how to use multiple optimizers in Lightning. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb) |
+| **BERT** | Fine-tune HuggingFace Transformers models on the GLUE Benchmark | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/04-transformers-text-classification.ipynb) |
+| **Trainer Flags** | Overview of the available Lightning `Trainer` flags | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/05-trainer-flags-overview.ipynb) |
+| **94% Baseline CIFAR10** | Establish a quick baseline of ~94% accuracy on CIFAR10 using Resnet in Lightning | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/06-cifar10-baseline.ipynb) |
diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py
index f02b6dc0952d7..b7116547d389b 100644
--- a/pl_examples/domain_templates/imagenet.py
+++ b/pl_examples/domain_templates/imagenet.py
@@ -237,7 +237,7 @@ def run_cli():
help='seed for initializing training.')
parser = ImageNetLightningModel.add_model_specific_args(parent_parser)
parser.set_defaults(
- profiler=True,
+ profiler="simple",
deterministic=True,
max_epochs=90,
)
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index f66055bccab3e..f29e7f75bfbff 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1412,8 +1412,10 @@ def get_progress_bar_dict(self):
def _verify_is_manual_optimization(self, fn_name):
if self.trainer.train_loop.automatic_optimization:
- m = f'to use {fn_name}, please disable automatic optimization: Trainer(automatic_optimization=False)'
- raise MisconfigurationException(m)
+ raise MisconfigurationException(
+ f'to use {fn_name}, please disable automatic optimization:'
+ ' set model property `automatic_optimization` as False'
+ )
@classmethod
def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]:
diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py
index 974bd69229a73..21d6af043df02 100644
--- a/pytorch_lightning/trainer/configuration_validator.py
+++ b/pytorch_lightning/trainer/configuration_validator.py
@@ -79,7 +79,7 @@ def __verify_train_loop_configuration(self, model):
if trainer.overriden_optimizer_step and not enable_pl_optimizer and automatic_optimization:
rank_zero_warn(
"When overriding `LightningModule` optimizer_step with"
- " `Trainer(..., enable_pl_optimizer=False, automatic_optimization=True, ...)`,"
+ " `Trainer(..., enable_pl_optimizer=False, ...)`,"
" we won't be calling `.zero_grad` we can't assume when you call your `optimizer.step()`."
" For Lightning to take care of it, please use `Trainer(enable_pl_optimizer=True)`."
)
@@ -89,15 +89,16 @@ def __verify_train_loop_configuration(self, model):
has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
if (has_overriden_optimization_functions) and going_to_accumulate_grad_batches and automatic_optimization:
raise MisconfigurationException(
- 'When overriding `LightningModule` optimizer_step or optimizer_zero_grad with '
- '`Trainer(automatic_optimization=True, ...)`, `accumulate_grad_batches` should to be 1.'
+ 'When overriding `LightningModule` optimizer_step or optimizer_zero_grad'
+ ' , `accumulate_grad_batches` in `Trainer` should to be 1.'
' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
)
if (enable_pl_optimizer) and trainer.overriden_optimizer_zero_grad and not automatic_optimization:
raise MisconfigurationException(
- 'When overriding `LightningModule` optimizer_zero_grad with '
- '`Trainer(automatic_optimization=False, enable_pl_optimizer=True, ...) is not supported'
+ 'When overriding `LightningModule` optimizer_zero_grad'
+ ' and preserving model property `automatic_optimization` as True with'
+ ' `Trainer(enable_pl_optimizer=True, ...) is not supported'
)
def __verify_eval_loop_configuration(self, model, eval_loop_name):
diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py
index 0f6686f1f83c7..3ecc168231b38 100644
--- a/pytorch_lightning/trainer/connectors/profiler_connector.py
+++ b/pytorch_lightning/trainer/connectors/profiler_connector.py
@@ -18,6 +18,11 @@
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
+PROFILERS = {
+ "simple": SimpleProfiler,
+ "advanced": AdvancedProfiler,
+}
+
class ProfilerConnector:
@@ -28,9 +33,9 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]):
if profiler and not isinstance(profiler, (bool, str, BaseProfiler)):
# TODO: Update exception on removal of bool
- raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler` "
- "are valid values for `Trainer`'s `profiler` parameter. "
- f"Received {profiler} which is of type {type(profiler)}.")
+ raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler`"
+ " are valid values for `Trainer`'s `profiler` parameter."
+ f" Received {profiler} which is of type {type(profiler)}.")
if isinstance(profiler, bool):
rank_zero_warn("Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
@@ -39,11 +44,9 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]):
if profiler:
profiler = SimpleProfiler()
elif isinstance(profiler, str):
- profiler = profiler.lower()
- if profiler == "simple":
- profiler = SimpleProfiler()
- elif profiler == "advanced":
- profiler = AdvancedProfiler()
+ if profiler.lower() in PROFILERS:
+ profiler_class = PROFILERS[profiler.lower()]
+ profiler = profiler_class()
else:
raise ValueError("When passing string value for the `profiler` parameter of"
" `Trainer`, it can only be 'simple' or 'advanced'")
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 31a64d00ccb60..35da90625adef 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -358,7 +358,7 @@ def __init__(
)
# init train loop related flags
- # TODO: deprecate in 1.2.0
+ # TODO: remove in 1.3.0
if automatic_optimization is None:
automatic_optimization = True
else:
diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index 0c71259373d1b..3e2e6d040f44c 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -38,7 +38,6 @@ def optimizer_step(self, *_, **__):
default_root_dir=tmpdir,
limit_train_batches=2,
accumulate_grad_batches=2,
- automatic_optimization=True
)
trainer.fit(model)
@@ -90,7 +89,6 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
default_root_dir=tmpdir,
limit_train_batches=8,
accumulate_grad_batches=1,
- automatic_optimization=True,
enable_pl_optimizer=enable_pl_optimizer
)
diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py
index bd19c26784bc2..e6ec59ec4f5aa 100644
--- a/tests/core/test_lightning_optimizer.py
+++ b/tests/core/test_lightning_optimizer.py
@@ -112,6 +112,10 @@ def configure_optimizers(self):
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
return [optimizer_1, optimizer_2], [lr_scheduler]
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
model = TestModel()
model.training_step_end = None
model.training_epoch_end = None
@@ -121,8 +125,8 @@ def configure_optimizers(self):
limit_val_batches=1,
max_epochs=1,
weights_summary=None,
- automatic_optimization=False,
- enable_pl_optimizer=True)
+ enable_pl_optimizer=True,
+ )
trainer.fit(model)
assert len(mock_sgd_step.mock_calls) == 2
@@ -161,6 +165,10 @@ def configure_optimizers(self):
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
return [optimizer_1, optimizer_2], [lr_scheduler]
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
model = TestModel()
model.training_step_end = None
model.training_epoch_end = None
@@ -170,7 +178,6 @@ def configure_optimizers(self):
limit_val_batches=1,
max_epochs=1,
weights_summary=None,
- automatic_optimization=False,
accumulate_grad_batches=2,
enable_pl_optimizer=True,
)
@@ -237,7 +244,6 @@ def configure_optimizers(self):
max_epochs=1,
weights_summary=None,
enable_pl_optimizer=True,
- automatic_optimization=True
)
trainer.fit(model)
@@ -291,7 +297,6 @@ def configure_optimizers(self):
max_epochs=1,
weights_summary=None,
enable_pl_optimizer=True,
- automatic_optimization=True
)
trainer.fit(model)
@@ -352,7 +357,6 @@ def configure_optimizers(self):
max_epochs=1,
weights_summary=None,
enable_pl_optimizer=True,
- automatic_optimization=True
)
trainer.fit(model)
@@ -406,7 +410,6 @@ def configure_optimizers(self):
max_epochs=1,
weights_summary=None,
enable_pl_optimizer=True,
- automatic_optimization=True,
)
trainer.fit(model)
diff --git a/tests/trainer/dynamic_args/test_multiple_optimizers.py b/tests/trainer/dynamic_args/test_multiple_optimizers.py
index a01ef8e0b2a7e..48b1bf6ab7ac9 100644
--- a/tests/trainer/dynamic_args/test_multiple_optimizers.py
+++ b/tests/trainer/dynamic_args/test_multiple_optimizers.py
@@ -97,11 +97,14 @@ def configure_optimizers(self):
optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer, optimizer_2
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
model = TestModel()
model.val_dataloader = None
trainer = Trainer(
- automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py
index 2b2881cb928b4..5e341e9c66f63 100644
--- a/tests/trainer/optimization/test_manual_optimization.py
+++ b/tests/trainer/optimization/test_manual_optimization.py
@@ -69,12 +69,15 @@ def configure_optimizers(self):
optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer, optimizer_2
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
model = TestModel()
model.val_dataloader = None
limit_train_batches = 2
trainer = Trainer(
- automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
@@ -133,12 +136,15 @@ def configure_optimizers(self):
optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer, optimizer_2
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
model = TestModel()
model.val_dataloader = None
limit_train_batches = 2
trainer = Trainer(
- automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
@@ -198,12 +204,15 @@ def configure_optimizers(self):
optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer, optimizer_2
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
model = TestModel()
model.val_dataloader = None
limit_train_batches = 2
trainer = Trainer(
- automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
@@ -265,12 +274,15 @@ def configure_optimizers(self):
optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer, optimizer_2
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
model = TestModel()
model.val_dataloader = None
limit_train_batches = 2
trainer = Trainer(
- automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
@@ -278,7 +290,7 @@ def configure_optimizers(self):
log_every_n_steps=1,
weights_summary=None,
precision=16,
- gpus=1
+ gpus=1,
)
trainer.fit(model)
@@ -335,12 +347,15 @@ def configure_optimizers(self):
optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer, optimizer_2
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
model = TestModel()
model.val_dataloader = None
limit_train_batches = 2
trainer = Trainer(
- automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
@@ -412,6 +427,10 @@ def on_train_end(self):
assert self.called["on_train_batch_start"] == 10
assert self.called["on_train_batch_end"] == 10
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@@ -431,7 +450,6 @@ def test_manual_optimization_and_return_tensor(tmpdir):
limit_train_batches=10,
limit_test_batches=0,
limit_val_batches=0,
- automatic_optimization=False,
precision=16,
amp_backend='native',
accelerator="ddp_spawn",
@@ -461,7 +479,6 @@ def test_manual_optimization_and_return_detached_tensor(tmpdir):
limit_train_batches=10,
limit_test_batches=0,
limit_val_batches=0,
- automatic_optimization=False,
precision=16,
amp_backend='native',
accelerator="ddp_spawn",
@@ -538,6 +555,10 @@ def on_train_end(self):
assert self.called["on_train_batch_start"] == 20
assert self.called["on_train_batch_end"] == 20
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
model = ExtendedModel()
model.training_step_end = None
model.training_epoch_end = None
@@ -548,7 +569,6 @@ def on_train_end(self):
limit_train_batches=20,
limit_test_batches=0,
limit_val_batches=0,
- automatic_optimization=False,
precision=16,
amp_backend='native',
accumulate_grad_batches=4,
@@ -610,12 +630,15 @@ def configure_optimizers(self):
optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer, optimizer_2
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
model = TestModel()
model.val_dataloader = None
limit_train_batches = 2
trainer = Trainer(
- automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
@@ -692,13 +715,16 @@ def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
model = TestModel()
model.val_dataloader = None
model.training_epoch_end = None
limit_train_batches = 2
trainer = Trainer(
- automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
@@ -753,13 +779,16 @@ def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
model = TestModel()
model.val_dataloader = None
model.training_epoch_end = None
limit_train_batches = 4
trainer = Trainer(
- automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
@@ -806,13 +835,16 @@ def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
model = TestModel()
model.val_dataloader = None
model.training_epoch_end = None
limit_train_batches = 4
trainer = Trainer(
- automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
@@ -881,13 +913,16 @@ def configure_optimizers(self):
optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001)
return [optimizer_gen, optimizer_dis]
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
model = TestModel()
model.val_dataloader = None
model.training_epoch_end = None
limit_train_batches = 8
trainer = Trainer(
- automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
@@ -985,6 +1020,10 @@ def configure_optimizers(self):
optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001)
return [optimizer_gen, optimizer_dis]
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
seed_everything(42)
model = TestModel()
@@ -993,7 +1032,6 @@ def configure_optimizers(self):
limit_train_batches = 8
trainer = Trainer(
- automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
@@ -1023,13 +1061,16 @@ class TestModel(BoringModel):
def optimizer_zero_grad(self, *_):
pass
+ @property
+ def automatic_optimization(self) -> bool:
+ return False
+
model = TestModel()
model.val_dataloader = None
model.training_epoch_end = None
limit_train_batches = 8
trainer = Trainer(
- automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
@@ -1039,4 +1080,4 @@ def optimizer_zero_grad(self, *_):
enable_pl_optimizer=True,
)
except MisconfigurationException as e:
- assert "`Trainer(automatic_optimization=False, enable_pl_optimizer=True, ...) is not supported" in str(e)
+ assert "`Trainer(enable_pl_optimizer=True, ...) is not supported" in str(e)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index a25067b136fef..9b29d6ec2b1dd 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1476,6 +1476,6 @@ def test_trainer_profiler_incorrect_str_arg():
))
def test_trainer_profiler_incorrect_arg_type(profiler):
with pytest.raises(MisconfigurationException,
- match=r"Only None, bool, str and subclasses of `BaseProfiler` "
- r"are valid values for `Trainer`'s `profiler` parameter. *"):
+ match=r"Only None, bool, str and subclasses of `BaseProfiler`"
+ r" are valid values for `Trainer`'s `profiler` parameter. *"):
Trainer(profiler=profiler)