# Configuration System

Lighter is a configuration-centric framework: the config sets up the entire machine learning workflow, from model architecture, loss function, and optimizer selection to dataset preparation and running the training/evaluation/inference process.

Our configuration system is heavily based on the MONAI bundle parser, but with a standardized structure. Every configuration must define a few mandatory items.

Let us walk through a simple example config to dig deeper into Lighter's configuration system. You can go through the config and click on the + signs for more information about specific concepts.

<div class="annotate" markdown>

```yaml title="cifar10.yaml"
trainer:
  _target_: pytorch_lightning.Trainer # (1)!
  max_epochs: 100 # (2)!

system:
  _target_: lighter.LighterSystem
  batch_size: 512

  model:
    _target_: torchvision.models.resnet18
    num_classes: 10

  criterion:
    _target_: torch.nn.CrossEntropyLoss

  optimizer:
    _target_: torch.optim.Adam
    params: "$@system#model.parameters()" # (3)!
    lr: 0.001

  datasets:
    train:
      _target_: torchvision.datasets.CIFAR10
      download: True
      root: .datasets
      train: True
      transform:
        _target_: torchvision.transforms.Compose
        transforms: # (4)!
          - _target_: torchvision.transforms.ToTensor
          - _target_: torchvision.transforms.Normalize
            mean: [0.5, 0.5, 0.5]
            std: [0.5, 0.5, 0.5]
```
</div>
1. `_target_` is a special reserved key that instantiates a Python object from the provided dotted path. In this case, a `Trainer` object from the `pytorch_lightning` library is initialized.
2. `max_epochs` is an argument of the `Trainer` class, passed through this format. Any argument accepted by the class can be passed similarly.
3. `$@` combines `$`, which evaluates a Python expression, and `@`, which references a Python object. In this case, we first reference the model with `@system#model` (the `torchvision.models.resnet18` defined earlier) and then access its parameters by evaluating `$@system#model.parameters()`.
4. YAML allows passing a list in the format below, where each `_target_` specifies a transform that is added to the list of transforms in `Compose`. `torchvision.datasets.CIFAR10` accepts these through its `transform` argument and applies them to each item.

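Under the hood, `_target_` resolution boils down to importing a dotted path and calling the resulting class with the remaining keys as keyword arguments. Here is a minimal sketch of the idea, using a hypothetical `locate_and_instantiate` helper and a standard-library class in place of `pytorch_lightning.Trainer` (the actual parsing is handled by MONAI's bundle utilities):

```python
import importlib

def locate_and_instantiate(target: str, **kwargs):
    """Import the dotted path in `target` and call it with `kwargs`."""
    module_path, _, attr_name = target.rpartition(".")
    cls = getattr(importlib.import_module(module_path), attr_name)
    return cls(**kwargs)

# Equivalent of a config entry like:
#   _target_: fractions.Fraction
#   numerator: 1
#   denominator: 2
half = locate_and_instantiate("fractions.Fraction", numerator=1, denominator=2)
```

In the real config, `_target_: pytorch_lightning.Trainer` with `max_epochs: 100` is resolved the same way: import `pytorch_lightning`, look up `Trainer`, and call it with `max_epochs=100`.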
## Configuration Concepts
As seen in the [Quickstart](./quickstart.md), Lighter has two main components:

### Trainer Setup
```yaml
trainer:
  _target_: pytorch_lightning.Trainer
  max_epochs: 100
```

The trainer object (`pytorch_lightning.Trainer`) is initialized through the `_target_` key. For more info on `_target_` and other special keys, see [Special Syntax and Keywords](#special-syntax-and-keywords).

`max_epochs` is an argument provided to the `pytorch_lightning.Trainer` object during its instantiation. All arguments that the class accepts at instantiation can be provided similarly.

### LighterSystem Configuration
While Lighter borrows the Trainer from PyTorch Lightning, LighterSystem is a custom component unique to Lighter. It draws on several PyTorch Lightning concepts, such as LightningModule, to provide a simple way to capture all the integral elements of a deep learning system.

Concepts encapsulated by LighterSystem include:

#### Model definition
The `torchvision` library is installed by default with Lighter, so you can choose from the different torchvision models here. We also package `monai` with Lighter, so if you want to use a MONAI ResNet instead, this is all you need to modify in your config:

=== "Torchvision ResNet18"

    ```yaml
    LighterSystem:
      ...
      model:
        _target_: torchvision.models.resnet18
        num_classes: 10
      ...
    ```

=== "MONAI ResNet50"

    ```yaml
    LighterSystem:
      ...
      model:
        _target_: monai.networks.nets.resnet50
        num_classes: 10
        spatial_dims: 2
      ...
    ```

=== "MONAI 3D ResNet50"

    ```yaml
    LighterSystem:
      ...
      model:
        _target_: monai.networks.nets.resnet50
        num_classes: 10
        spatial_dims: 3
      ...
    ```

<br/>
#### Criterion/Loss

Similar to swapping models, you can easily switch between the various loss functions provided by libraries such as `torch` and `monai`. This flexibility lets you experiment with different approaches to optimizing your model's performance without changing code. Below are some examples of how you can modify the criterion section of your configuration file to use different loss functions.

=== "CrossEntropyLoss"
    ```yaml
    LighterSystem:
      ...
      criterion:
        _target_: torch.nn.CrossEntropyLoss
      ...
    ```

=== "MONAI's Dice Loss"
    ```yaml
    LighterSystem:
      ...
      criterion:
        _target_: monai.losses.DiceLoss
      ...
    ```

<br/>
#### Optimizer

Likewise, you can experiment with different optimizers and their parameters. The model's parameters are passed directly to the optimizer through the `params` argument.
```yaml hl_lines="5"
LighterSystem:
  ...
  optimizer:
    _target_: torch.optim.Adam
    params: "$@system#model.parameters()"
    lr: 0.001
  ...
```

You can also define a scheduler for the optimizer, as below:
```yaml hl_lines="10"
LighterSystem:
  ...
  optimizer:
    _target_: torch.optim.Adam
    params: "$@system#model.parameters()"
    lr: 0.001

  scheduler:
    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
    optimizer: "@system#optimizer"
    eta_min: 1.0e-06
    T_max: "%trainer#max_epochs"

  ...
```
Here, the optimizer is passed to the scheduler through the `optimizer` argument, and `%trainer#max_epochs` copies the value of `max_epochs` from the `trainer` section.

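The `%` macro is essentially a textual copy of another config value, located by walking `#`-separated keys through the config tree. A toy resolver sketching the idea (illustrative only, not Lighter's or MONAI's actual implementation):

```python
def resolve_macro(expr: str, config: dict):
    """Resolve '%a#b' by walking the nested dict keys a -> b."""
    value = config
    for key in expr.lstrip("%").split("#"):
        value = value[key]
    return value

# The config as a nested dict, as the YAML above would be loaded:
config = {"trainer": {"max_epochs": 100}}
t_max = resolve_macro("%trainer#max_epochs", config)  # copies the value 100
```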
<br/>
#### Datasets

The datasets section is often the most frequently changed part of the config, since common workflows involve training or inferring on your own dataset. The `datasets` key accepts `train`, `val`, `test`, and `predict` keys that generate dataloaders for each of the workflows provided by PyTorch Lightning. These are described in detail [here](./workflows.md).

<div class="annotate" markdown>

```yaml
LighterSystem:
  ...
  datasets:
    train:
      _target_: torchvision.datasets.CIFAR10 # (1)!
      download: True
      root: .datasets
      train: True
      transform: # (2)!
        _target_: torchvision.transforms.Compose
        transforms:
          - _target_: torchvision.transforms.ToTensor
          - _target_: torchvision.transforms.Normalize
            mean: [0.5, 0.5, 0.5]
            std: [0.5, 0.5, 0.5]
  ...
```

</div>
1. Define your own dataset class here or use one of the many existing dataset classes. Read more about this [here](./projects.md).
2. Transforms can be applied to each element of the dataset by initializing a `Compose` object and providing it a list of transforms. This is often the easiest way to adapt preprocessing to your data.

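`Compose` itself is simple: it applies each transform in order, feeding the output of one into the next. A standard-library sketch of the same idea, with toy functions standing in for `ToTensor` and `Normalize`:

```python
class Compose:
    """Apply a list of callables in sequence, like torchvision's Compose."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for transform in self.transforms:
            x = transform(x)
        return x

# Toy stand-ins for ToTensor / Normalize:
pipeline = Compose([lambda x: x * 2, lambda x: x + 1])
result = pipeline(10)  # (10 * 2) + 1 = 21
```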
### Special Syntax and Keywords
- `_target_`: Indicates the Python class to instantiate. If a function is provided instead, a partial function is created. Any configuration key set with `_target_` maps to a Python object.
- **@**: References another configuration value. With this syntax, keys mapped to Python objects can be accessed. For instance, for an `optimizer` instantiated to `torch.optim.Adam` using `_target_`, the learning rate can be accessed with `@system#optimizer#lr`.
- **$**: Evaluates a Python expression.
- **%**: Macro for textual replacement of another configuration value.
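
To see how these pieces combine, consider `"$@system#model.parameters()"`: the `@system#model` reference is substituted with the instantiated model object, and the `$` prefix causes the resulting expression to be evaluated. A toy resolver sketching this (hypothetical, not the actual parser; a plain list stands in for the model):

```python
def resolve(expr: str, objects: dict):
    """Substitute '@id' references, then evaluate if the string starts with '$'."""
    if not expr.startswith("$"):
        return expr
    code = expr[1:]
    for ref_id in objects:  # ref_id looks like 'system#model'
        code = code.replace("@" + ref_id, f"objects[{ref_id!r}]")
    return eval(code, {"objects": objects})

objects = {"system#model": [1, 2, 2, 3]}  # stand-in for the instantiated model
count = resolve("$@system#model.count(2)", objects)  # evaluates list.count(2)
```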