Clean up Argparse interface with trainer (#1606)
* fixed distutil parsing
* Apply suggestions from code review
* log
* doctest
* fixed hparams section
* formatting

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: J. Borovec <[email protected]>
1 parent: 13bf772 · commit: 4755ded
Showing 7 changed files with 141 additions and 83 deletions.
Hyperparameters
===============
Lightning has utilities to interact seamlessly with the command line ``ArgumentParser``
and plays well with the hyperparameter optimization framework of your choice.

ArgumentParser
^^^^^^^^^^^^^^
Lightning is designed to augment a lot of the functionality of the built-in Python ``ArgumentParser``.
.. code-block:: python

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument('--layer_1_dim', type=int, default=128)
    args = parser.parse_args()
This allows you to call your program like so:

.. code-block:: bash

    python trainer.py --layer_1_dim 64
Argparser Best Practices
^^^^^^^^^^^^^^^^^^^^^^^^
It is best practice to layer your arguments in three sections:

1. Trainer args (gpus, num_nodes, etc...)
2. Model specific arguments (layer_dim, num_layers, learning_rate, etc...)
3. Program arguments (data_path, cluster_email, etc...)
We can do this as follows. First, in your LightningModule, define the arguments
specific to that module. Remember that data splits or data paths may also be specific to
a module (ie: if your project has a model that trains on ImageNet and another on CIFAR-10).

.. code-block:: python

    class LitModel(LightningModule):

        @staticmethod
        def add_model_specific_args(parent_parser):
            parser = ArgumentParser(parents=[parent_parser], add_help=False)
            parser.add_argument('--encoder_layers', type=int, default=12)
            parser.add_argument('--data_path', type=str, default='/some/path')
            return parser
Now in your main trainer file, add the Trainer args, the program args, and the model args:

.. code-block:: python

    # ----------------
    # trainer_main.py
    # ----------------
    from argparse import ArgumentParser

    parser = ArgumentParser()

    # add PROGRAM level args
    parser.add_argument('--conda_env', type=str, default='some_name')
    parser.add_argument('--notification_email', type=str, default='[email protected]')

    # add model specific args
    parser = LitModel.add_model_specific_args(parser)

    # add all the available trainer options to argparse
    # ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
    parser = pl.Trainer.add_argparse_args(parser)

    hparams = parser.parse_args()
Now you can call your program like so:

.. code-block:: bash

    python trainer_main.py --gpus 2 --num_nodes 2 --conda_env 'my_env' --encoder_layers 12
Finally, make sure to start the training like so:

.. code-block:: python

    hparams = parser.parse_args()

    # YES
    model = LitModel(hparams)

    # NO
    # model = LitModel(learning_rate=hparams.learning_rate, ...)

    # YES
    trainer = Trainer.from_argparse_args(hparams, early_stopping_callback=...)

    # NO
    # trainer = Trainer(gpus=hparams.gpus, ...)
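Anything that can't come from the command line can be passed as an extra keyword argument, which ``from_argparse_args`` forwards to the ``Trainer`` as an override. A minimal sketch, assuming a checkpoint callback (the filepath is illustrative):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # hparams comes from parser.parse_args() as above; extra keyword
    # arguments act as overrides on top of the parsed trainer flags
    checkpoint_callback = ModelCheckpoint(filepath='/some/path')  # path is illustrative
    trainer = Trainer.from_argparse_args(hparams, checkpoint_callback=checkpoint_callback)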
LightningModule hparams
^^^^^^^^^^^^^^^^^^^^^^^

Normally, we don't hard-code the values to a model. We usually use the command line to
modify the network and read those values in the LightningModule:

.. code-block:: python
    :emphasize-lines: 5,6,7,12,14

    class LitMNIST(pl.LightningModule):
        def __init__(self, hparams):
            super().__init__()

            # do this to save all arguments in any logger (tensorboard)
            self.hparams = hparams

            self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim)
        ...

        def configure_optimizers(self):
            return Adam(self.parameters(), lr=self.hparams.learning_rate)

        @staticmethod
        def add_model_specific_args(parent_parser):
            parser = ArgumentParser(parents=[parent_parser], add_help=False)
            parser.add_argument('--layer_1_dim', type=int, default=128)
            parser.add_argument('--layer_2_dim', type=int, default=256)
            parser.add_argument('--batch_size', type=int, default=64)
            parser.add_argument('--learning_rate', type=float, default=0.002)
            return parser
Now pass in the params when you init your model:

.. code-block:: python

    hparams = parser.parse_args()
    model = LitMNIST(hparams)
The line ``self.hparams = hparams`` is very special. This line assigns your hparams to the LightningModule.
This does two things:

1. It adds them automatically to the TensorBoard logs under the hparams tab.
2. Lightning will save those hparams to the checkpoint and use them to restore the module correctly.
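As a quick illustration of the second point, here is a minimal sketch of restoring a module whose hparams were saved this way (the checkpoint path is hypothetical):

.. code-block:: python

    # because self.hparams was set in __init__, Lightning stored the hparams
    # in the checkpoint and can rebuild the model without passing them again
    model = LitMNIST.load_from_checkpoint('/path/to/example.ckpt')
    print(model.hparams.layer_1_dim)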
Trainer args
^^^^^^^^^^^^

To recap, add ALL possible trainer flags to the argparser and init the Trainer this way:

.. code-block:: python

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    hparams = parser.parse_args()

    trainer = Trainer.from_argparse_args(hparams)

    # or if you need to pass in callbacks
    trainer = Trainer.from_argparse_args(hparams, checkpoint_callback=..., callbacks=[...])

And now you can start your program with any Trainer flag from the command line:

.. code-block:: bash

    $ python main.py --gpus 1 --min_epochs 12 --max_epochs 64
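If you want to sanity-check which trainer flags were registered without launching a run, one option is to hand ``parse_args`` an explicit list instead of reading ``sys.argv`` (a sketch; the values are illustrative):

.. code-block:: python

    from argparse import ArgumentParser
    from pytorch_lightning import Trainer

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)

    # parse an explicit argv, just to inspect the resulting namespace
    hparams = parser.parse_args(['--max_epochs', '64'])
    print(hparams.max_epochs)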
Multiple Lightning Modules
^^^^^^^^^^^^^^^^^^^^^^^^^^

Now we can allow each model to inject the arguments it needs in ``main.py``:

.. code-block:: python

    def main(args):
        model = LitMNIST(hparams=args)
        trainer = Trainer.from_argparse_args(args)
        trainer.fit(model)

    if __name__ == '__main__':
        parser = ArgumentParser()

        # figure out which model to use
        parser.add_argument('--model_name', type=str, default='gan', help='gan or mnist')

        # THIS LINE IS KEY TO PULL THE MODEL NAME
        temp_args, _ = parser.parse_known_args()

        # let the model add what it wants
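The diff fragment ends where each model injects its own arguments. A minimal sketch of that dispatch step, assuming a hypothetical ``LitGAN`` counterpart to ``LitMNIST``:

.. code-block:: python

    # continues the __main__ block above; LitGAN is a hypothetical
    # second module used only for illustration
    if temp_args.model_name == 'gan':
        parser = LitGAN.add_model_specific_args(parser)
    elif temp_args.model_name == 'mnist':
        parser = LitMNIST.add_model_specific_args(parser)

    args = parser.parse_args()
    main(args)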
The commit also adds a new 20-line utility file with a local copy of ``strtobool`` (the "fixed distutil parsing" entries in the commit message):
def strtobool(val):
    """Convert a string representation of truth to true (1) or false (0).

    Copied from the Python implementation distutils.util.strtobool.

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
    are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
    'val' is anything else.

    >>> strtobool('YES')
    1
    >>> strtobool('FALSE')
    0
    """
    val = val.lower()
    if val in ('y', 'yes', 't', 'true', 'on', '1'):
        return 1
    elif val in ('n', 'no', 'f', 'false', 'off', '0'):
        return 0
    else:
        raise ValueError(f'invalid truth value {val}')
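To see why this helper matters for the argparse interface: ``type=bool`` treats any non-empty string as true (``bool('false')`` is ``True``), so boolean flags need a real parser. A minimal sketch, with an illustrative flag name:

.. code-block:: python

    from argparse import ArgumentParser

    parser = ArgumentParser()
    # strtobool understands the usual spellings of truth values,
    # where type=bool would turn the string 'false' into True
    parser.add_argument('--fast_dev_run', type=lambda v: bool(strtobool(v)), default=False)

    args = parser.parse_args(['--fast_dev_run', 'false'])
    print(args.fast_dev_run)  # False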