Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed basic_examples to use LightningCLI #6862

Merged
merged 12 commits into from
Apr 15, 2021
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ jobs:
set -e
python -m pytest pl_examples -v --maxfail=2 --durations=0
pip install . --user --quiet
bash pl_examples/run_examples-args.sh --gpus 1 --max_epochs 1 --batch_size 64 --limit_train_batches 5 --limit_val_batches 3
bash pl_examples/run_ddp-examples.sh --max_epochs 1 --batch_size 32 --limit_train_batches 2 --limit_val_batches 2
bash pl_examples/run_examples-args.sh --trainer.gpus 1 --trainer.max_epochs 1 --data.batch_size 64 --trainer.limit_train_batches 5 --trainer.limit_val_batches 3
bash pl_examples/run_ddp-examples.sh --trainer.max_epochs 1 --data.batch_size 32 --trainer.limit_train_batches 2 --trainer.limit_val_batches 2
# cd pl_examples/basic_examples
# bash submit_ddp_job.sh
# bash submit_ddp2_job.sh
Expand Down
2 changes: 1 addition & 1 deletion docs/source/common/lightning_cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ datamodule class. However, there are many cases in which the objective is to eas
multiple models and datasets. For these cases the tool can be configured such that a model and/or a datamodule is
specified by an import path and init arguments. For example, with a tool implemented as:

.. testcode::
.. code-block:: python

from pytorch_lightning.utilities.cli import LightningCLI

Expand Down
12 changes: 6 additions & 6 deletions pl_examples/basic_examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ Trains MNIST where the model is defined inside the `LightningModule`.
python simple_image_classifier.py

# gpus (any number)
python simple_image_classifier.py --gpus 2
python simple_image_classifier.py --trainer.gpus 2

# dataparallel
python simple_image_classifier.py --gpus 2 --distributed_backend 'dp'
python simple_image_classifier.py --trainer.gpus 2 --trainer.accelerator 'dp'
```

---
Expand All @@ -30,10 +30,10 @@ Generic image classifier with an arbitrary backbone (ie: a simple system)
python backbone_image_classifier.py

# gpus (any number)
python backbone_image_classifier.py --gpus 2
python backbone_image_classifier.py --trainer.gpus 2

# dataparallel
python backbone_image_classifier.py --gpus 2 --distributed_backend 'dp'
python backbone_image_classifier.py --trainer.gpus 2 --trainer.accelerator 'dp'
```

---
Expand All @@ -44,10 +44,10 @@ Showing the power of a system... arbitrarily complex training loops
python autoencoder.py

# gpus (any number)
python autoencoder.py --gpus 2
python autoencoder.py --trainer.gpus 2

# dataparallel
python autoencoder.py --gpus 2 --distributed_backend 'dp'
python autoencoder.py --trainer.gpus 2 --trainer.accelerator 'dp'
```
---
# Multi-node example
Expand Down
68 changes: 30 additions & 38 deletions pl_examples/basic_examples/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MNIST autoencoder example.

from argparse import ArgumentParser
To run:
python autoencoder.py --trainer.max_epochs=50
"""

import torch
import torch.nn.functional as F
Expand All @@ -21,6 +25,7 @@

import pytorch_lightning as pl
from pl_examples import _DATASETS_PATH, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
Expand Down Expand Up @@ -87,44 +92,31 @@ def configure_optimizers(self):
return optimizer


class MyDataModule(pl.LightningDataModule):

def __init__(
self,
batch_size: int = 32,
):
super().__init__()
dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
self.mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000])
self.batch_size = batch_size

def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=self.batch_size)

def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=self.batch_size)

def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size)


def cli_main():
pl.seed_everything(1234)

# ------------
# args
# ------------
parser = ArgumentParser()
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--hidden_dim', type=int, default=64)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()

# ------------
# data
# ------------
dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])

train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
test_loader = DataLoader(mnist_test, batch_size=args.batch_size)

# ------------
# model
# ------------
model = LitAutoEncoder(args.hidden_dim)

# ------------
# training
# ------------
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, train_loader, val_loader)

# ------------
# testing
# ------------
result = trainer.test(test_dataloaders=test_loader)
cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234)
result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
print(result)


Expand Down
89 changes: 44 additions & 45 deletions pl_examples/basic_examples/backbone_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MNIST backbone image classifier example.

from argparse import ArgumentParser
To run:
python backbone_image_classifier.py --trainer.max_epochs=50
"""

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

import pytorch_lightning as pl
from pl_examples import _DATASETS_PATH, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
Expand Down Expand Up @@ -59,7 +64,11 @@ class LitClassifier(pl.LightningModule):
)
"""

def __init__(self, backbone, learning_rate=1e-3):
def __init__(
self,
backbone,
learning_rate: float = 0.0001,
):
super().__init__()
self.save_hyperparameters()
self.backbone = backbone
Expand Down Expand Up @@ -92,52 +101,42 @@ def configure_optimizers(self):
# self.hparams available because we called self.save_hyperparameters()
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

@staticmethod
def add_model_specific_args(parent_parser):
parser = parent_parser.add_argument_group("LitClassifier")
parser.add_argument('--learning_rate', type=float, default=0.0001)
return parent_parser

class MyDataModule(pl.LightningDataModule):

def __init__(
self,
batch_size: int = 32,
):
super().__init__()
dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
self.mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000])
self.batch_size = batch_size

def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=self.batch_size)

def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=self.batch_size)

def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size)


class MyLightningCLI(LightningCLI):

def add_arguments_to_parser(self, parser):
parser.add_class_arguments(Backbone, 'model.backbone')

def instantiate_model(self):
self.config_init['model']['backbone'] = Backbone(**self.config['model']['backbone'])
super().instantiate_model()


def cli_main():
pl.seed_everything(1234)

# ------------
# args
# ------------
parser = ArgumentParser()
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--hidden_dim', type=int, default=128)
parser = pl.Trainer.add_argparse_args(parser)
parser = LitClassifier.add_model_specific_args(parser)
args = parser.parse_args()

# ------------
# data
# ------------
dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])

train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
test_loader = DataLoader(mnist_test, batch_size=args.batch_size)

# ------------
# model
# ------------
model = LitClassifier(Backbone(hidden_dim=args.hidden_dim), args.learning_rate)

# ------------
# training
# ------------
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, train_loader, val_loader)

# ------------
# testing
# ------------
result = trainer.test(test_dataloaders=test_loader)
cli = MyLightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
print(result)


Expand Down
Loading