From c7e7ca370d691b4524290abbabb0bc6c5118f65b Mon Sep 17 00:00:00 2001
From: Anthony Bisulco
Date: Fri, 19 Jun 2020 22:17:48 -0400
Subject: [PATCH 1/6] configuration fix

job name template change to model create hydra examples folder fix error with none values optimizers and lr schedules clean up model structure model has data included dont configure outputs document hydra example update readme rename trainer conf scheduler example schedulers update change out structure for opt and sched flatten config dirs reduce number of classes scheduler and opt configs spelling change group config store location change import and store structured conf remaining classes fix for date change location of trainer config fix package name trainer instantiation clean up init trainer type fixes clean up imports update readme add in seed

Update pl_examples/hydra_examples/README.md

Co-authored-by: Omry Yadan

Update pl_examples/hydra_examples/README.md

Co-authored-by: Omry Yadan

change to model clean up hydra example data to absolute path update file name fix path isort run name change hydra logging change config dir use name as logging group load configs in init py callout callbacks fix callbacks empty list example param data params example with two other data classes fix saving params dataset path correction comments in trainer conf logic in user app better config clean up arguments multiprocessing handled by PL settings cleaner callback list callback clean up top level config wip user config add in callbacks fix callbacks in user config fix group names name config fix user config instantiation without + change type split for readability user config move master config yaml hydra from master changes remove init py clean up model configuration add comments add to readme function doc need hydra for instantiate defaults defined in config yaml remove to do lines issue note remove imports unused cfg init removal double define instantiate changes change back to full config

Update pl_examples/hydra_examples/pl_template.py

Co-authored-by: Omry Yadan

Revert "double define"

This reverts commit 4a9a962c5af50d394520491c7398c9b45dfd4d98.
fix data configuration remove bug comment, fixed already fix callbacks instantiate
---
 pl_examples/hydra_examples/README.md          |  21 ++
 .../hydra_examples/conf/callbacks/basic.yaml  |  13 +
 pl_examples/hydra_examples/conf/config.yaml   |  11 +
 .../conf/data/fashionmnist.yaml               |  27 ++
 .../hydra_examples/conf/data/kmnist.yaml      |  27 ++
 .../hydra_examples/conf/data/mnist.yaml       |  27 ++
 .../hydra_examples/conf/model/basic.yaml      |   7 +
 pl_examples/hydra_examples/conf/optimizer.py  | 108 ++++++++
 pl_examples/hydra_examples/conf/scheduler.py  | 139 ++++++++++
 pl_examples/hydra_examples/pl_template.py     |  48 ++++
 pl_examples/hydra_examples/user_config.py     |  56 +++++
 pl_examples/models/hydra_config_model.py      | 122 +++++++++
 pytorch_lightning/trainer/trainer_conf.py     | 237 ++++++++++++++++++
 13 files changed, 843 insertions(+)
 create mode 100644 pl_examples/hydra_examples/README.md
 create mode 100644 pl_examples/hydra_examples/conf/callbacks/basic.yaml
 create mode 100644 pl_examples/hydra_examples/conf/config.yaml
 create mode 100644 pl_examples/hydra_examples/conf/data/fashionmnist.yaml
 create mode 100644 pl_examples/hydra_examples/conf/data/kmnist.yaml
 create mode 100644 pl_examples/hydra_examples/conf/data/mnist.yaml
 create mode 100644 pl_examples/hydra_examples/conf/model/basic.yaml
 create mode 100644 pl_examples/hydra_examples/conf/optimizer.py
 create mode 100644 pl_examples/hydra_examples/conf/scheduler.py
 create mode 100644 pl_examples/hydra_examples/pl_template.py
 create mode 100644 pl_examples/hydra_examples/user_config.py
 create mode 100644 pl_examples/models/hydra_config_model.py
 create mode 100644 pytorch_lightning/trainer/trainer_conf.py

diff --git a/pl_examples/hydra_examples/README.md b/pl_examples/hydra_examples/README.md
new file mode 100644
index 0000000000000..9509a55b25dbb
--- /dev/null
+++ b/pl_examples/hydra_examples/README.md
@@ -0,0 +1,21 @@
+## Hydra PyTorch Lightning Example
+
+This directory contains an example of configuring PyTorch Lightning with [Hydra](https://hydra.cc/). Hydra is a tool that allows for the easy configuration of complex applications.
+The core of this directory is a set of structured configs for PyTorch Lightning, importable via `from pytorch_lightning.trainer.trainer_conf import PLConfig`. Within the PL config there are 5 configurations: 1) Trainer Configuration, 2) Profiler Configuration, 3) Early Stopping Configuration, 4) Logger Configuration and 5) Checkpoint Configuration. Each of these mirrors the arguments of the corresponding object, and they are used to instantiate those objects with Hydra's instantiation utility.
+
+Aside from the PyTorch Lightning configuration, we have included a few other important configurations. Optimizer and Scheduler are off-the-shelf configurations for your optimizer and learning rate scheduler: add them to your config defaults list as needed and use them to configure these objects. Additionally, we provide the model and data configurations for changing model and data hyperparameters.
+
+All of the above hyperparameters come together in the config.yaml file, which contains the top-level configuration. Its defaults list selects the default configuration for each of these Hydra groups. Beyond this configuration file, every parameter defined there can be overridden via the command line.
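+
+For example, you can switch config groups and override individual values from the command line (an illustrative invocation, assuming it is run from the `pl_examples/hydra_examples` directory):
+
+```bash
+# swap the dataset group, pick the SGD optimizer and override a couple of leaf values
+python pl_template.py data=fashionmnist opt=sgd trainer.max_epochs=5 opt.params.lr=0.01
+```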
+
+Additionally, for type safety, `user_config.py` shows an example of extending the `PLConfig` data class with a user configuration, so we get the benefits of type safety for the entire config.yaml. For further examples of this, [check out the structured config tutorial](https://hydra.cc/docs/next/tutorials/structured_config/intro).
+
+### TensorBoard Visualization
+
+By default, Hydra changes the working directory of your program to outputs/[DATE]/[TIME], so any data written to a relative path ends up in that directory. To visualize all of your TensorBoard runs, run `tensorboard --logdir outputs`; this lets you compare your results across runs.
+You can also [customize](https://hydra.cc/docs/configure_hydra/workdir) your Hydra working directory.
+
+### Multi Run
+
+One nice feature of Hydra is [multi-run](https://hydra.cc/docs/next/tutorials/basic/running_your_app/multi-run/), which lets you run your application multiple times with different configurations. A new directory called multirun is created with the results of these runs, and you can visualize them with `tensorboard --logdir multirun`.
+
+Other useful information about Hydra can be found in the [docs](https://hydra.cc/docs/intro/).
diff --git a/pl_examples/hydra_examples/conf/callbacks/basic.yaml b/pl_examples/hydra_examples/conf/callbacks/basic.yaml
new file mode 100644
index 0000000000000..35f8057b836cc
--- /dev/null
+++ b/pl_examples/hydra_examples/conf/callbacks/basic.yaml
@@ -0,0 +1,13 @@
+# @package _group_
+
+functions:
+  print:
+    target: pl_examples.hydra_examples.user_config.MyPrintingCallback
+  message:
+    target: pl_examples.hydra_examples.user_config.MessageCallback
+    params:
+      iter_num: 12
+
+callbacks_list:
+  - ${callbacks.functions.print}
+  - ${callbacks.functions.message}
diff --git a/pl_examples/hydra_examples/conf/config.yaml b/pl_examples/hydra_examples/conf/config.yaml
new file mode 100644
index 0000000000000..8ed42fa08ca64
--- /dev/null
+++ b/pl_examples/hydra_examples/conf/config.yaml
@@ -0,0 +1,11 @@
+defaults:
+  - data: mnist
+  - model: basic
+  - callbacks: basic
+  - profiler: null
+  - logger: null
+  - checkpoint: null
+  - early_stopping: null
+  - trainer: trainer
+  - scheduler: step
+  - opt: adam
diff --git a/pl_examples/hydra_examples/conf/data/fashionmnist.yaml b/pl_examples/hydra_examples/conf/data/fashionmnist.yaml
new file mode 100644
index 0000000000000..ad401af119ab6
--- /dev/null
+++ b/pl_examples/hydra_examples/conf/data/fashionmnist.yaml
@@ -0,0 +1,27 @@
+# @package _group_
+
+ds:
+  target: torchvision.datasets.FashionMNIST
+  params:
+    root: ${hydra:runtime.cwd}/datasets
+    download: true
+    train: null
+
+tf:
+  - tensor:
+      target: torchvision.transforms.ToTensor
+  - normal:
+      target: torchvision.transforms.Normalize
+      params:
+        mean: .5
+        std: .2
+
+dl:
+  target: torch.utils.data.DataLoader
+  params:
+    batch_size: 5
+    shuffle: true
+    num_workers: 4
+    pin_memory: False
+    drop_last: False
+    timeout: 0
diff --git a/pl_examples/hydra_examples/conf/data/kmnist.yaml b/pl_examples/hydra_examples/conf/data/kmnist.yaml
new file mode 100644
index 0000000000000..9aaa5d1ff6110
--- /dev/null
+++ b/pl_examples/hydra_examples/conf/data/kmnist.yaml
@@ -0,0 +1,27 @@
+# @package _group_
+
+ds:
+  target: torchvision.datasets.KMNIST
+  params:
+    root: ${hydra:runtime.cwd}/datasets
+    download: true
+    train: null
+
+tf:
+  - tensor:
+      target: torchvision.transforms.ToTensor
+  - normal:
+      target:
torchvision.transforms.Normalize + params: + mean: .5 + std: .2 + +dl: + target: torch.utils.data.DataLoader + params: + batch_size: 5 + shuffle: true + num_workers: 4 + pin_memory: False + drop_last: False + timeout: 0 \ No newline at end of file diff --git a/pl_examples/hydra_examples/conf/data/mnist.yaml b/pl_examples/hydra_examples/conf/data/mnist.yaml new file mode 100644 index 0000000000000..3975aed085ec2 --- /dev/null +++ b/pl_examples/hydra_examples/conf/data/mnist.yaml @@ -0,0 +1,27 @@ +# @package _group_ + +ds: + target: torchvision.datasets.MNIST + params: + root: ${hydra:runtime.cwd}/datasets + download: true + train: null + +tf: + - tensor: + target: torchvision.transforms.ToTensor + - normal: + target: torchvision.transforms.Normalize + params: + mean: .5 + std: .2 + +dl: + target: torch.utils.data.DataLoader + params: + batch_size: 5 + shuffle: true + num_workers: 4 + pin_memory: False + drop_last: False + timeout: 0 diff --git a/pl_examples/hydra_examples/conf/model/basic.yaml b/pl_examples/hydra_examples/conf/model/basic.yaml new file mode 100644 index 0000000000000..6a3182176f738 --- /dev/null +++ b/pl_examples/hydra_examples/conf/model/basic.yaml @@ -0,0 +1,7 @@ +# @package _group_ + +drop_prob: 0.2 +in_features: 784 +out_features: 10 +hidden_dim: 1000 +seed: 123 diff --git a/pl_examples/hydra_examples/conf/optimizer.py b/pl_examples/hydra_examples/conf/optimizer.py new file mode 100644 index 0000000000000..9dd82a5354b20 --- /dev/null +++ b/pl_examples/hydra_examples/conf/optimizer.py @@ -0,0 +1,108 @@ +from dataclasses import dataclass +from typing import Optional + +from hydra.core.config_store import ConfigStore +from hydra.types import ObjectConf + +cs = ConfigStore.instance() + + +@dataclass +class AdamConf: + betas: tuple = (0.9, 0.999) + lr: float = 1e-3 + eps: float = 1e-8 + weight_decay: float = 0 + amsgrad: bool = False + + +cs.store( + group="opt", name="adam", node=ObjectConf(target="torch.optim.Adam", params=AdamConf()), +) +cs.store( + group="opt", name="adamw", node=ObjectConf(target="torch.optim.AdamW", params=AdamConf()), +) + + +@dataclass +class AdamaxConf: + betas: tuple = (0.9, 0.999) + lr: float = 1e-3 + eps: float = 1e-8 + weight_decay: float = 0 + + +cs.store( + group="opt", name="adamax", node=ObjectConf(target="torch.optim.Adamax", params=AdamaxConf()), +) + + +@dataclass +class ASGDConf: + alpha: float = 0.75 + lr: float = 1e-3 + lambd: float = 1e-4 + t0: float = 1e6 + weight_decay: float = 0 + + +cs.store( + group="opt", name="asgd", node=ObjectConf(target="torch.optim.ASGD", params=ASGDConf()), +) + + +@dataclass +class LBFGSConf: + lr: float = 1 + max_iter: int = 20 + max_eval: int = 25 + tolerance_grad: float = 1e-5 + tolerance_change: float = 1e-9 + history_size: int = 100 + line_search_fn: Optional[str] = None + + +cs.store( + group="opt", name="lbfgs", node=ObjectConf(target="torch.optim.LBFGS", params=LBFGSConf()), +) + + +@dataclass +class RMSpropConf: + lr: float = 1e-2 + momentum: float = 0 + alpha: float = 0.99 + eps: float = 1e-8 + centered: bool = True + weight_decay: float = 0 + + +cs.store( + group="opt", name="rmsprop", node=ObjectConf(target="torch.optim.RMSprop", params=RMSpropConf()), +) + + +@dataclass +class RpropConf: + lr: float = 1e-2 + etas: tuple = (0.5, 1.2) + step_sizes: tuple = (1e-6, 50) + + +cs.store( + group="opt", name="rprop", node=ObjectConf(target="torch.optim.Rprop", params=RpropConf()), +) + + +@dataclass +class SGDConf: + lr: float = 1e-2 + momentum: float = 0 + weight_decay: float = 0 + dampening: float 
= 0
+    nesterov: bool = False
+
+
+cs.store(
+    group="opt", name="sgd", node=ObjectConf(target="torch.optim.SGD", params=SGDConf()),
+)
diff --git a/pl_examples/hydra_examples/conf/scheduler.py b/pl_examples/hydra_examples/conf/scheduler.py
new file mode 100644
index 0000000000000..4081f4be611d7
--- /dev/null
+++ b/pl_examples/hydra_examples/conf/scheduler.py
@@ -0,0 +1,139 @@
+from dataclasses import dataclass, field
+from typing import Any, List, Optional
+
+from hydra.core.config_store import ConfigStore
+from hydra.types import ObjectConf
+
+cs = ConfigStore.instance()
+
+
+@dataclass
+class CosineConf:
+    T_max: int = 100
+    eta_min: float = 0
+    last_epoch: int = -1
+
+
+cs.store(
+    group="scheduler",
+    name="cosine",
+    node=ObjectConf(target="torch.optim.lr_scheduler.CosineAnnealingLR", params=CosineConf()),
+)
+
+
+@dataclass
+class CosineWarmConf:
+    T_0: int = 10
+    T_mult: int = 1
+    eta_min: float = 0
+    last_epoch: int = -1
+
+
+cs.store(
+    group="scheduler",
+    name="cosinewarm",
+    node=ObjectConf(target="torch.optim.lr_scheduler.CosineAnnealingWarmRestarts", params=CosineWarmConf()),
+)
+
+
+@dataclass
+class CyclicConf:
+    base_lr: Any = 1e-3
+    max_lr: Any = 1e-2
+    step_size_up: int = 2000
+    step_size_down: int = 2000
+    mode: str = "triangular"
+    gamma: float = 1
+    scale_fn: Optional[Any] = None
+    scale_mode: str = "cycle"
+    cycle_momentum: bool = True
+    base_momentum: Any = 0.8
+    max_momentum: Any = 0.9
+    last_epoch: int = -1
+
+
+cs.store(
+    group="scheduler", name="cyclic", node=ObjectConf(target="torch.optim.lr_scheduler.CyclicLR", params=CyclicConf()),
+)
+
+
+@dataclass
+class ExponentialConf:
+    gamma: float = 1
+    last_epoch: int = -1
+
+
+cs.store(
+    group="scheduler",
+    name="exponential",
+    node=ObjectConf(target="torch.optim.lr_scheduler.ExponentialLR", params=ExponentialConf()),
+)
+
+
+@dataclass
+class RedPlatConf:
+    mode: str = "min"
+    factor: float = 0.1
+    patience: int = 10
+    verbose: bool = False
+    threshold: float = 1e-4
+    threshold_mode: str = "rel"
+    cooldown: int = 0
+    min_lr: Any = 0
+    eps: float = 1e-8
+
+
+cs.store(
+    group="scheduler",
+    name="redplat",
+    node=ObjectConf(target="torch.optim.lr_scheduler.ReduceLROnPlateau", params=RedPlatConf()),
+)
+
+
+@dataclass
+class MultiStepConf:
+    milestones: List = field(default_factory=lambda: [10, 20, 30, 40])
+    gamma: float = 0.1
+    last_epoch: int = -1
+
+
+cs.store(
+    group="scheduler",
+    name="multistep",
+    node=ObjectConf(target="torch.optim.lr_scheduler.MultiStepLR", params=MultiStepConf()),
+)
+
+
+@dataclass
+class OneCycleConf:
+    max_lr: Any = 1e-2
+    total_steps: int = 2000
+    epochs: int = 200
+    steps_per_epoch: int = 100
+    pct_start: float = 0.3
+    anneal_strategy: str = "cos"
+    cycle_momentum: bool = True
+    base_momentum: Any = 0.8
+    max_momentum: Any = 0.9
+    div_factor: float = 25
+    final_div_factor: float = 1e4
+    last_epoch: int = -1
+
+
+cs.store(
+    group="scheduler",
+    name="onecycle",
+    node=ObjectConf(target="torch.optim.lr_scheduler.OneCycleLR", params=OneCycleConf()),
+)
+
+
+@dataclass
+class StepConf:
+    step_size: int = 20
+    gamma: float = 0.1
+    last_epoch: int = -1
+
+
+cs.store(
+    group="scheduler", name="step", node=ObjectConf(target="torch.optim.lr_scheduler.StepLR", params=StepConf()),
+)
diff --git a/pl_examples/hydra_examples/pl_template.py b/pl_examples/hydra_examples/pl_template.py
new file mode 100644
index 0000000000000..f881466f75f8c
--- /dev/null
+++ b/pl_examples/hydra_examples/pl_template.py
@@ -0,0 +1,48 @@
+"""
+PyTorch Lightning training using Hydra for configuration
+"""
+
+import hydra
+import pl_examples.hydra_examples.user_config +from omegaconf import DictConfig, OmegaConf +from hydra.utils import instantiate + + +from pl_examples.models.hydra_config_model import LightningTemplateModel +from pytorch_lightning import Callback, seed_everything, Trainer + + +@hydra.main(config_path="conf", config_name="config") +def main(cfg: DictConfig): + """ + Main training routine specific for this project + :param cfg: + """ + seed_everything(cfg.model.seed) + # ------------------------ + # 1 INIT LIGHTNING MODEL + # ------------------------ + model = LightningTemplateModel(cfg.model, cfg.data, cfg.scheduler, cfg.opt) + # ------------------------ + # 2 INIT TRAINER + # ------------------------ + + callbacks = [instantiate(c) for c in cfg.callbacks.callbacks_list] if cfg.callbacks else [] + + trainer = Trainer( + **cfg.trainer, + logger=instantiate(cfg.logger), + profiler=instantiate(cfg.profiler), + checkpoint_callback=instantiate(cfg.checkpoint), + early_stop_callback=instantiate(cfg.early_stopping), + callbacks=callbacks, + ) + + # ------------------------ + # 3 START TRAINING + # ------------------------ + trainer.fit(model) + + +if __name__ == "__main__": + main() diff --git a/pl_examples/hydra_examples/user_config.py b/pl_examples/hydra_examples/user_config.py new file mode 100644 index 0000000000000..4a917493e8485 --- /dev/null +++ b/pl_examples/hydra_examples/user_config.py @@ -0,0 +1,56 @@ +from pytorch_lightning.trainer.trainer_conf import PLConfig +import pl_examples.hydra_examples.conf.optimizer +import pl_examples.hydra_examples.conf.scheduler +import hydra +from hydra.core.config_store import ConfigStore +from hydra.types import ObjectConf +from omegaconf import MISSING +from typing import Any, List +from dataclasses import dataclass +from pytorch_lightning import Callback + +cs = ConfigStore.instance() + +# Sample callback definition used in hydra yaml config +class MyPrintingCallback(Callback): + def on_init_start(self, trainer): + print("Starting to init trainer!") + + def on_init_end(self, trainer): + print("trainer is init now") + + def on_train_end(self, trainer, pl_module): + print("do something when training ends") + + +# Sample callback definition with param used in hydra yaml config +class MessageCallback(Callback): + def __init__(self, iter_num): + self.iter_num = iter_num + + def on_batch_start(self, trainer, pl_module): + if trainer.batch_idx == self.iter_num: + print(f"Iteration {self.iter_num}") + + +""" +Top Level used config which can be extended by a user. +For use in Pytorch Lightning we can extend the PLConfig +dataclass which has all the trainer settings. For further +config with type safety, we can extend this class and +add in other config groups. +""" + + +@dataclass +class UserConfig(PLConfig): + defaults: List[Any] = MISSING + data: Any = MISSING + model: Any = MISSING + scheduler: ObjectConf = MISSING + opt: ObjectConf = MISSING + callbacks: Any = None + + +# Stored as config node, for top level config used for type checking. 
+cs.store(name="config", node=UserConfig) diff --git a/pl_examples/models/hydra_config_model.py b/pl_examples/models/hydra_config_model.py new file mode 100644 index 0000000000000..09706fe109c84 --- /dev/null +++ b/pl_examples/models/hydra_config_model.py @@ -0,0 +1,122 @@ +""" +Example template for defining a Lightning Module with Hydra +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as transforms + +from pytorch_lightning.core import LightningModule +import hydra +import logging + +# Hydra configures the Python logging subsystem automatically. +log = logging.getLogger(__name__) + + +class LightningTemplateModel(LightningModule): + def __init__(self, model, data, scheduler, opt) -> "LightningTemplateModel": + # init superclass + super().__init__() + self.save_hyperparameters() + self.model = model + self.data = data + self.opt = opt + self.scheduler = scheduler + self.c_d1 = nn.Linear(in_features=self.model.in_features, out_features=self.model.hidden_dim) + self.c_d1_bn = nn.BatchNorm1d(self.model.hidden_dim) + self.c_d1_drop = nn.Dropout(self.model.drop_prob) + + self.c_d2 = nn.Linear(in_features=self.model.hidden_dim, out_features=self.model.out_features) + + self.example_input_array = torch.zeros(2, 1, 28, 28) + + def forward(self, x): + """ + No special modification required for Lightning, define it as you normally would + in the `nn.Module` in vanilla PyTorch. + """ + x = self.c_d1(x.view(x.size(0), -1)) + x = torch.tanh(x) + x = self.c_d1_bn(x) + x = self.c_d1_drop(x) + x = self.c_d2(x) + return x + + def training_step(self, batch, batch_idx): + """ + Lightning calls this inside the training loop with the data from the training dataloader + passed in as `batch`. + """ + # forward pass + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + tensorboard_logs = {"train_loss": loss} + return {"loss": loss, "log": tensorboard_logs} + + def validation_step(self, batch, batch_idx): + """ + Lightning calls this inside the validation loop with the data from the validation dataloader + passed in as `batch`. + """ + x, y = batch + y_hat = self(x) + val_loss = F.cross_entropy(y_hat, y) + labels_hat = torch.argmax(y_hat, dim=1) + n_correct_pred = torch.sum(y == labels_hat).item() + return {"val_loss": val_loss, "n_correct_pred": n_correct_pred, "n_pred": len(x)} + + def test_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + test_loss = F.cross_entropy(y_hat, y) + labels_hat = torch.argmax(y_hat, dim=1) + n_correct_pred = torch.sum(y == labels_hat).item() + return {"test_loss": test_loss, "n_correct_pred": n_correct_pred, "n_pred": len(x)} + + def validation_epoch_end(self, outputs): + """ + Called at the end of validation to aggregate outputs. + :param outputs: list of individual outputs of each validation step. 
+ """ + avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() + val_acc = sum([x["n_correct_pred"] for x in outputs]) / sum(x["n_pred"] for x in outputs) + tensorboard_logs = {"val_loss": avg_loss, "val_acc": val_acc} + return {"val_loss": avg_loss, "log": tensorboard_logs} + + def test_epoch_end(self, outputs): + avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean() + test_acc = sum([x["n_correct_pred"] for x in outputs]) / sum(x["n_pred"] for x in outputs) + tensorboard_logs = {"test_loss": avg_loss, "test_acc": test_acc} + return {"test_loss": avg_loss, "log": tensorboard_logs} + + # --------------------- + # TRAINING SETUP + # --------------------- + def configure_optimizers(self): + """ + Return whatever optimizers and learning rate schedulers you want here. + At least one optimizer is required. + """ + optimizer = hydra.utils.instantiate(self.opt, params=self.parameters()) + scheduler = hydra.utils.instantiate(self.scheduler, optimizer=optimizer) + return [optimizer], [scheduler] + + def prepare_data(self): + transform = transforms.Compose([hydra.utils.instantiate(trans) for trans in self.data.tf]) + self.train_set = hydra.utils.instantiate(self.data.ds, transform=transform, train=True) + self.test_set = hydra.utils.instantiate(self.data.ds, transform=transform, train=False) + + def train_dataloader(self): + log.info("Training data loader called.") + return hydra.utils.instantiate(self.data.dl, dataset=self.train_set) + + def val_dataloader(self): + log.info("Validation data loader called.") + return hydra.utils.instantiate(self.data.dl, dataset=self.test_set) + + def test_dataloader(self): + log.info("Test data loader called.") + return hydra.utils.instantiate(self.data.dl, dataset=self.test_set) + diff --git a/pytorch_lightning/trainer/trainer_conf.py b/pytorch_lightning/trainer/trainer_conf.py new file mode 100644 index 0000000000000..795262810dc0c --- /dev/null +++ b/pytorch_lightning/trainer/trainer_conf.py @@ -0,0 +1,237 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from hydra.core.config_store import ConfigStore +from hydra.types import ObjectConf +from hydra.utils import instantiate +from omegaconf import MISSING, DictConfig + +cs = ConfigStore.instance() + + +@dataclass +class LightningTrainerConf: + default_root_dir: Optional[str] = None + gradient_clip_val: float = 0 + process_position: int = 0 + num_nodes: int = 1 + num_processes: int = 1 + gpus: Optional[Any] = None + auto_select_gpus: bool = False + tpu_cores: Optional[Any] = None + log_gpu_memory: Optional[str] = None + progress_bar_refresh_rate: int = 1 + overfit_batches: float = 0.0 + track_grad_norm: Any = -1 + check_val_every_n_epoch: int = 1 + fast_dev_run: bool = False + accumulate_grad_batches: Any = 1 + max_epochs: int = 1000 + min_epochs: int = 1 + max_steps: Optional[int] = None + min_steps: Optional[int] = None + limit_train_batches: float = 1.0 + limit_val_batches: float = 1.0 + limit_test_batches: float = 1.0 + val_check_interval: float = 1.0 + log_save_interval: int = 100 + row_log_interval: int = 50 + distributed_backend: Optional[str] = None + precision: int = 32 + print_nan_grads: bool = False + weights_summary: Optional[str] = "top" + weights_save_path: Optional[str] = None + num_sanity_val_steps: int = 2 + truncated_bptt_steps: Optional[int] = None + resume_from_checkpoint: Optional[str] = None + benchmark: bool = False + deterministic: bool = False + reload_dataloaders_every_epoch: bool = False + auto_lr_find: Any = False + 
replace_sampler_ddp: bool = True + terminate_on_nan: bool = False + auto_scale_batch_size: Any = False + prepare_data_per_node: bool = True + amp_level: str = "O1" + num_tpu_cores: Optional[int] = None + + +cs.store(group="trainer", name="trainer", node=LightningTrainerConf) + + +@dataclass +class ModelCheckpointConf: + filepath: Optional[str] = None + monitor: str = "val_loss" + verbose: bool = False + save_last: bool = False + save_top_k: int = 1 + save_weights_only: bool = False + mode: str = "auto" + period: int = 1 + prefix: str = "" + + +cs.store( + group="checkpoint", + name="modelckpt", + node=ObjectConf(target="pytorch_lightning.callbacks.ModelCheckpoint", params=ModelCheckpointConf()), +) + + +@dataclass +class EarlyStoppingConf: + monitor: str = "val_loss" + verbose: bool = False + mode: str = "auto" + patience: int = 3 + strict: bool = True + min_delta: float = 0.0 + + +cs.store( + group="early_stopping", + name="earlystop", + node=ObjectConf(target="pytorch_lightning.callbacks.EarlyStopping", params=EarlyStoppingConf()), +) + + +@dataclass +class SimpleProfilerConf: + output_filename: Optional[str] = None + + +@dataclass +class AdvancedProfilerConf: + output_filename: Optional[str] = None + line_count_restriction: float = 1.0 + + +cs.store( + group="profiler", + name="simple", + node=ObjectConf(target="pytorch_lightning.profiler.SimpleProfiler", params=SimpleProfilerConf()), +) + +cs.store( + group="profiler", + name="advanced", + node=ObjectConf(target="pytorch_lightning.profiler.AdvancedProfiler", params=AdvancedProfilerConf()), +) + + +@dataclass +class CometLoggerConf: + api_key: Optional[str] = None + save_dir: Optional[str] = None + workspace: Optional[str] = None + project_name: Optional[str] = None + rest_api_key: Optional[str] = None + experiment_name: Optional[str] = None + experiment_key: Optional[str] = None + + +cs.store( + group="logger", + name="comet", + node=ObjectConf(target="pytorch_lightning.loggers.comet.CometLogger", params=CometLoggerConf()), +) + + +@dataclass +class MLFlowLoggerConf: + experiment_name: str = "default" + tracking_uri: Optional[str] = None + tags: Optional[Dict[str, Any]] = None + save_dir: Optional[str] = None + + +cs.store( + group="logger", + name="mlflow", + node=ObjectConf(target="pytorch_lightning.loggers.mlflow.MLFlowLogger", params=MLFlowLoggerConf()), +) + + +@dataclass +class NeptuneLoggerConf: + api_key: Optional[str] = None + project_name: Optional[str] = None + close_after_fit: Optional[bool] = True + offline_mode: bool = False + experiment_name: Optional[str] = None + upload_source_files: Optional[List[str]] = None + params: Optional[Dict[str, Any]] = None + properties: Optional[Dict[str, Any]] = None + tags: Optional[List[str]] = None + + +cs.store( + group="logger", + name="neptune", + node=ObjectConf(target="pytorch_lightning.loggers.neptune.NeptuneLogger", params=NeptuneLoggerConf()), +) + + +@dataclass +class TensorboardLoggerConf: + save_dir: str = "" + name: Optional[str] = "default" + version: Any = None + + +cs.store( + group="logger", + name="tensorboard", + node=ObjectConf(target="pytorch_lightning.loggers.tensorboard.TensorBoardLogger", params=TensorboardLoggerConf()), +) + + +@dataclass +class TestTubeLoggerConf: + save_dir: str = "" + name: str = "default" + description: Optional[str] = None + debug: bool = False + version: Optional[int] = None + create_git_tag: bool = False + + +cs.store( + group="logger", + name="testtube", + node=ObjectConf(target="pytorch_lightning.loggers.test_tube.TestTubeLogger", 
params=TestTubeLoggerConf()), +) + + +@dataclass +class WandbConf: + name: Optional[str] = None + save_dir: Optional[str] = None + offline: bool = False + id: Optional[str] = None + anonymous: bool = False + version: Optional[str] = None + project: Optional[str] = None + tags: Optional[List[str]] = None + log_model: bool = False + experiment = None + entity = None + group: Optional[str] = None + + +cs.store( + group="logger", + name="wandb", + node=ObjectConf(target="pytorch_lightning.loggers.wandb.WandbLogger", params=WandbConf()), +) + + +@dataclass +class PLConfig(DictConfig): + logger: Optional[ObjectConf] = None + profiler: Optional[ObjectConf] = None + checkpoint: Optional[ObjectConf] = None + early_stopping: Optional[ObjectConf] = None + trainer: LightningTrainerConf = MISSING + From 7429e01b44e86606d848032a3ae1319fe1f81c25 Mon Sep 17 00:00:00 2001 From: Anthony Bisulco Date: Sun, 19 Jul 2020 08:32:17 -0400 Subject: [PATCH 2/6] change out links --- pl_examples/hydra_examples/README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pl_examples/hydra_examples/README.md b/pl_examples/hydra_examples/README.md index 9509a55b25dbb..0a2913dafeca0 100644 --- a/pl_examples/hydra_examples/README.md +++ b/pl_examples/hydra_examples/README.md @@ -1,13 +1,13 @@ ## Hydra Pytorch Lightning Example -This directory consists of an example of configuring Pytorch Lightning with [Hydra](https://hydra.cc/). Hydra is a tool that allows for the easy configuration of complex applications. -The core of this directory consists of a set of structured configs used for pytorch lightining, which are stored under the `from pytorch_lightning.trainer.trainer_conf import PLConfig`. Within the PL config there are 5 cofigurations: 1) Trainer Configuration, 2) Profiler Configuration, 3) Early Stopping Configuration, 4) Logger Configuration and 5) Checkpoint Configuration. All of these are basically mirrors of the arguments that make up these objects. These configuration are used to instantiate the objects using Hydras instantiation utility. +This directory consists of an example of configuring Pytorch Lightning with [Hydra](https://hydra.cc/). Hydra is a tool that allows for the easy configuration of complex applications. +The core of this directory consists of a set of structured configs used for pytorch lightining, which are stored under the `from pytorch_lightning.trainer.trainer_conf import PLConfig`. Within the PL config there are 5 cofigurations: 1) Trainer Configuration, 2) Profiler Configuration, 3) Early Stopping Configuration, 4) Logger Configuration and 5) Checkpoint Configuration. All of these are basically mirrors of the arguments that make up these objects. These configuration are used to instantiate the objects using Hydras instantiation utility. -Aside from the PyTorch Lightning configuration we have included a few other important configurations. Optimizer and Scheduler are easy off-the-shelf configurations for configuring your optimizer and learning rate scheduler. You can add them to your config defaults list as needed and use them to configure these objects. Additionally, we provide the arch and data configurations for changing model and data hyperparameters. +Aside from the PyTorch Lightning configuration we have included a few other important configurations. Optimizer and Scheduler are easy off-the-shelf configurations for configuring your optimizer and learning rate scheduler. 
You can add them to your config defaults list as needed and use them to configure these objects. Additionally, we provide the arch and data configurations for changing model and data hyperparameters. -All of the above hyperparameters are configured in the config.yaml file which contains the top level configuration for all these configurations. Under this file is a defaults list which highlights for each of these Hydra groups what is the default configuration. Beyond this configuration file, all of the parameters defined can be overriden via the command line. +All of the above hyperparameters are configured in the config.yaml file which contains the top level configuration for all these configurations. Under this file is a defaults list which highlights for each of these Hydra groups what is the default configuration. Beyond this configuration file, all of the parameters defined can be overriden via the command line. -Additionally, for type safety we highlight in our file `user_config.py` an example of extending the `PLConfig` data class with a user configuration. Hence, we can get the benefits of type safety for our entire config.yaml. For further examples of this, [checkout](https://hydra.cc/docs/next/tutorials/structured_config/intro). +Additionally, for type safety we highlight in our file `user_config.py` an example of extending the `PLConfig` data class with a user configuration. Hence, we can get the benefits of type safety for our entire config.yaml. Please read through the [basic tutorial](https://hydra.cc/docs/next/tutorials/basic/your_first_app/simple_cli) and [structured configuration tutorial](https://hydra.cc/docs/next/tutorials/structured_config/intro) for more information on using Hydra. ### Tensorboard Visualization From a5a764eb84981f0cff5552410f99ea994bf72ffa Mon Sep 17 00:00:00 2001 From: Anthony Bisulco Date: Sun, 19 Jul 2020 08:39:05 -0400 Subject: [PATCH 3/6] conf consistency --- pl_examples/hydra_examples/user_config.py | 6 +++--- pytorch_lightning/trainer/trainer_conf.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pl_examples/hydra_examples/user_config.py b/pl_examples/hydra_examples/user_config.py index 4a917493e8485..6054dac71247a 100644 --- a/pl_examples/hydra_examples/user_config.py +++ b/pl_examples/hydra_examples/user_config.py @@ -1,4 +1,4 @@ -from pytorch_lightning.trainer.trainer_conf import PLConfig +from pytorch_lightning.trainer.trainer_conf import PLConf import pl_examples.hydra_examples.conf.optimizer import pl_examples.hydra_examples.conf.scheduler import hydra @@ -43,7 +43,7 @@ def on_batch_start(self, trainer, pl_module): @dataclass -class UserConfig(PLConfig): +class UserConf(PLConf): defaults: List[Any] = MISSING data: Any = MISSING model: Any = MISSING @@ -53,4 +53,4 @@ class UserConfig(PLConfig): # Stored as config node, for top level config used for type checking. 
-cs.store(name="config", node=UserConfig) +cs.store(name="config", node=UserConf) diff --git a/pytorch_lightning/trainer/trainer_conf.py b/pytorch_lightning/trainer/trainer_conf.py index 795262810dc0c..36d0d5af18a36 100644 --- a/pytorch_lightning/trainer/trainer_conf.py +++ b/pytorch_lightning/trainer/trainer_conf.py @@ -228,7 +228,7 @@ class WandbConf: @dataclass -class PLConfig(DictConfig): +class PLConf(DictConfig): logger: Optional[ObjectConf] = None profiler: Optional[ObjectConf] = None checkpoint: Optional[ObjectConf] = None From 739ae2f3d2d14fb2ffbb3b9e65e981ae34f6f0e3 Mon Sep 17 00:00:00 2001 From: Anthony Bisulco Date: Sun, 19 Jul 2020 08:46:46 -0400 Subject: [PATCH 4/6] support simpler callbacks --- .../hydra_examples/conf/callbacks/basic.yaml | 15 ++++----------- pl_examples/hydra_examples/pl_template.py | 4 ++-- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/pl_examples/hydra_examples/conf/callbacks/basic.yaml b/pl_examples/hydra_examples/conf/callbacks/basic.yaml index 35f8057b836cc..2442912736d08 100644 --- a/pl_examples/hydra_examples/conf/callbacks/basic.yaml +++ b/pl_examples/hydra_examples/conf/callbacks/basic.yaml @@ -1,13 +1,6 @@ # @package _group_ - -functions: - print: - target: pl_examples.hydra_examples.user_config.MyPrintingCallback - message: - target: pl_examples.hydra_examples.user_config.MessageCallback +callbacks: + - target: pl_examples.hydra_examples.user_config.MyPrintingCallback + - target: pl_examples.hydra_examples.user_config.MessageCallback params: - iter_num: 12 - -callbacks_list: - - ${callbacks.functions.print} - - ${callbacks.functions.message} + iter_num: 12 \ No newline at end of file diff --git a/pl_examples/hydra_examples/pl_template.py b/pl_examples/hydra_examples/pl_template.py index f881466f75f8c..b0d191ffa52b9 100644 --- a/pl_examples/hydra_examples/pl_template.py +++ b/pl_examples/hydra_examples/pl_template.py @@ -26,8 +26,8 @@ def main(cfg: DictConfig): # ------------------------ # 2 INIT TRAINER # ------------------------ - - callbacks = [instantiate(c) for c in cfg.callbacks.callbacks_list] if cfg.callbacks else [] + + callbacks = [instantiate(c) for c in cfg.callbacks.callbacks] if cfg.callbacks else [] trainer = Trainer( **cfg.trainer, From ab71dd16efe0a89bfeea7770ae29a6c40a7d73da Mon Sep 17 00:00:00 2001 From: Anthony Bisulco Date: Sun, 19 Jul 2020 08:58:09 -0400 Subject: [PATCH 5/6] callbacks is in trainer config --- pl_examples/hydra_examples/user_config.py | 2 +- pytorch_lightning/trainer/trainer_conf.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_examples/hydra_examples/user_config.py b/pl_examples/hydra_examples/user_config.py index 6054dac71247a..92d3da59cf302 100644 --- a/pl_examples/hydra_examples/user_config.py +++ b/pl_examples/hydra_examples/user_config.py @@ -49,7 +49,7 @@ class UserConf(PLConf): model: Any = MISSING scheduler: ObjectConf = MISSING opt: ObjectConf = MISSING - callbacks: Any = None + # Stored as config node, for top level config used for type checking. 
diff --git a/pytorch_lightning/trainer/trainer_conf.py b/pytorch_lightning/trainer/trainer_conf.py index 36d0d5af18a36..7bb4d10b8a00b 100644 --- a/pytorch_lightning/trainer/trainer_conf.py +++ b/pytorch_lightning/trainer/trainer_conf.py @@ -234,4 +234,4 @@ class PLConf(DictConfig): checkpoint: Optional[ObjectConf] = None early_stopping: Optional[ObjectConf] = None trainer: LightningTrainerConf = MISSING - + callbacks: Any = None From 85181b423cc91c218e2f6661523e5592c5eb3966 Mon Sep 17 00:00:00 2001 From: romesc Date: Mon, 20 Jul 2020 22:25:37 -0700 Subject: [PATCH 6/6] Add simplified example for novice users. --- .../conf_simple/config_simple.yaml | 9 +++ .../conf_simple/model/basic.yaml | 7 +++ .../hydra_examples/conf_simple/trainer.py | 59 +++++++++++++++++++ .../hydra_examples/pl_template_simple.py | 44 ++++++++++++++ 4 files changed, 119 insertions(+) create mode 100644 pl_examples/hydra_examples/conf_simple/config_simple.yaml create mode 100644 pl_examples/hydra_examples/conf_simple/model/basic.yaml create mode 100644 pl_examples/hydra_examples/conf_simple/trainer.py create mode 100644 pl_examples/hydra_examples/pl_template_simple.py diff --git a/pl_examples/hydra_examples/conf_simple/config_simple.yaml b/pl_examples/hydra_examples/conf_simple/config_simple.yaml new file mode 100644 index 0000000000000..13ce795c44bd7 --- /dev/null +++ b/pl_examples/hydra_examples/conf_simple/config_simple.yaml @@ -0,0 +1,9 @@ +defaults: + - model: basic + - trainer: trainer + +misc: + data_root: ${hydra:runtime.cwd}/data + batch_size: 128 + learning_rate: 0.001 + num_workers: 4 diff --git a/pl_examples/hydra_examples/conf_simple/model/basic.yaml b/pl_examples/hydra_examples/conf_simple/model/basic.yaml new file mode 100644 index 0000000000000..89d66f3d498f4 --- /dev/null +++ b/pl_examples/hydra_examples/conf_simple/model/basic.yaml @@ -0,0 +1,7 @@ +# @package _group_ + +drop_prob: 0.1 +in_features: 784 +out_features: 10 +hidden_dim: 1000 +seed: 123 diff --git a/pl_examples/hydra_examples/conf_simple/trainer.py b/pl_examples/hydra_examples/conf_simple/trainer.py new file mode 100644 index 0000000000000..d1b2e44dfc132 --- /dev/null +++ b/pl_examples/hydra_examples/conf_simple/trainer.py @@ -0,0 +1,59 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from hydra.core.config_store import ConfigStore +from hydra.types import ObjectConf +from hydra.utils import instantiate +from omegaconf import MISSING, DictConfig + +cs = ConfigStore.instance() + + +@dataclass +class LightningTrainerConf: + default_root_dir: Optional[str] = None + gradient_clip_val: float = 0 + process_position: int = 0 + num_nodes: int = 1 + num_processes: int = 1 + gpus: Optional[Any] = None + auto_select_gpus: bool = False + tpu_cores: Optional[Any] = None + log_gpu_memory: Optional[str] = None + progress_bar_refresh_rate: int = 1 + overfit_batches: float = 0.0 + track_grad_norm: Any = -1 + check_val_every_n_epoch: int = 1 + fast_dev_run: bool = False + accumulate_grad_batches: Any = 1 + max_epochs: int = 1000 + min_epochs: int = 1 + max_steps: Optional[int] = None + min_steps: Optional[int] = None + limit_train_batches: float = 1.0 + limit_val_batches: float = 1.0 + limit_test_batches: float = 1.0 + val_check_interval: float = 1.0 + log_save_interval: int = 100 + row_log_interval: int = 50 + distributed_backend: Optional[str] = None + precision: int = 32 + print_nan_grads: bool = False + weights_summary: Optional[str] = "top" + weights_save_path: Optional[str] = None + num_sanity_val_steps: 
int = 2 + truncated_bptt_steps: Optional[int] = None + resume_from_checkpoint: Optional[str] = None + benchmark: bool = False + deterministic: bool = False + reload_dataloaders_every_epoch: bool = False + auto_lr_find: Any = False + replace_sampler_ddp: bool = True + terminate_on_nan: bool = False + auto_scale_batch_size: Any = False + prepare_data_per_node: bool = True + amp_level: str = "O1" + num_tpu_cores: Optional[int] = None + + +cs.store(group="trainer", name="trainer", node=LightningTrainerConf) diff --git a/pl_examples/hydra_examples/pl_template_simple.py b/pl_examples/hydra_examples/pl_template_simple.py new file mode 100644 index 0000000000000..af2b9b1918754 --- /dev/null +++ b/pl_examples/hydra_examples/pl_template_simple.py @@ -0,0 +1,44 @@ +""" +Pytorch Lightning training using Hydra for configuration +""" +import hydra +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning import seed_everything, Trainer + +# Hydra defaults from non-yaml config stores (trainer, optimizer, scheduler) +import pl_examples.hydra_examples.conf_simple.trainer +#import pl_examples.hydra_examples.conf.optimizer +#import pl_examples.hydra_examples.conf.scheduler + +# Original lightning template +from pl_examples.models.lightning_template import LightningTemplateModel + + +@hydra.main(config_path="conf_simple", config_name="config_simple") +def main(cfg: DictConfig): + """ + Main training routine specific for this project + :param cfg: + """ + + print(cfg.pretty()) + seed_everything(cfg.model.seed) + + # ------------------------ + # 1 INIT LIGHTNING MODEL + # ------------------------ + model = LightningTemplateModel(**cfg.model,**cfg.misc) + + # ------------------------ + # 2 INIT TRAINER + # ------------------------ + trainer = Trainer(**cfg.trainer) + + # ------------------------ + # 3 START TRAINING + # ------------------------ + trainer.fit(model) + + +if __name__ == "__main__": + main()
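+
+# Illustrative invocations (paths assumed; run from the repository root):
+#   python pl_examples/hydra_examples/pl_template_simple.py trainer.max_epochs=3 misc.batch_size=64
+# A Hydra multirun sweep over two learning rates writes its results under ./multirun:
+#   python pl_examples/hydra_examples/pl_template_simple.py -m misc.learning_rate=0.001,0.01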