change Checkpoint callback's save_best_only to save_top_k (#128)
* docs: enable syntax highlight

* feat: change Checkpoint callback's `save_best_only` to `save_top_k`

fix #70

* docs: update docs for save_top_k

* revert other files

* style: lint for travis-ci

* fix typo

* make flake8 happy

* update according to review

* add tests

* rename func to private

* add doc on `save_top_k == 0`

* make flake8 happy

* update according to PR comments

* change some f-strings

* Update pt_callbacks.py

* Update test_models.py

* update options

* create folders

* Update test_models.py

* change epoch num

* support calling multiple times, add docs and tests

* update docs

* roll back changes in earlystopping

* clean test files

* make flake8 happy

* fix epoch number

* update tests about epoch numbers

* clean debugging code

* fix testing utils codes

* fix testing utils codes

* fix testing utils codes

* fix testing utils codes

* change save_dir to tests/tests according to previous lines

* remove unused overwrite option

* make flake8 happy

* change var name as per review

* make flake8 happy

* update property name to work on master

* elaborate in the docs

* update docs as per review

* revert previous commit

accidentally pressed wrong button when solving conflicts
Ir1d authored and williamFalcon committed Nov 19, 2019
1 parent 619143a commit 7324dd9
Showing 6 changed files with 266 additions and 70 deletions.
45 changes: 31 additions & 14 deletions docs/Trainer/Checkpointing.md
@@ -1,6 +1,7 @@
Lightning can automate saving and loading checkpoints.

---

### Model saving
Checkpointing is enabled by default and checkpoints are saved to the current working directory.
To change the checkpoint path, pass in:
@@ -10,13 +11,13 @@ Trainer(default_save_path='/your/path/to/save/checkpoints')

To modify the behavior of checkpointing pass in your own callback.

```{.python}
from pytorch_lightning.callbacks import ModelCheckpoint
# DEFAULTS used by the Trainer
checkpoint_callback = ModelCheckpoint(
filepath=os.getcwd(),
save_best_only=True,
save_top_k=1,
verbose=True,
monitor='val_loss',
mode='min',
@@ -26,11 +27,26 @@ checkpoint_callback = ModelCheckpoint(
trainer = Trainer(checkpoint_callback=checkpoint_callback)
```

The `save_top_k` option works in the following ways:

| save_top_k | behavior |
| -------- | ----- |
| 0 | no models are saved |
| -1 | all models are saved |
| k >= 1 | the best k models are saved |


Also, if `save_top_k` >= 2 and the callback is called multiple
times inside an epoch, the name of the saved file will be
appended with a version count starting with `v0`.
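
For example, here is a minimal sketch that keeps the three best checkpoints ranked by validation loss (the value `3` is an arbitrary choice for illustration, not a default):

```{.python}
import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# keep the 3 checkpoints with the lowest val_loss; files are written as
# {prefix}_ckpt_epoch_{epoch}.ckpt, gaining a _v0, _v1, ... suffix if the
# same epoch is checkpointed more than once
checkpoint_callback = ModelCheckpoint(
    filepath=os.getcwd(),
    save_top_k=3,
    verbose=True,
    monitor='val_loss',
    mode='min',
)

trainer = Trainer(checkpoint_callback=checkpoint_callback)
```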

---

### Restoring training session

You might want to not only load a model but also continue training it. Use this method to
restore the trainer state as well. This will continue from the epoch and global step where you last left off.
However, the dataloaders will start from the first batch again (if you shuffled, it shouldn't matter).

Lightning will restore the session if you pass a logger with the same version and there's a saved checkpoint.
``` {.python}
@@ -52,18 +68,19 @@ trainer = Trainer(
trainer.fit(model)
```

The trainer restores:

- global_step
- current_epoch
- All optimizers
- All lr_schedulers
- Model weights

You can even change the logic of your model as long as the weights and "architecture" of
the system aren't different. If you add a layer, for instance, it might not work.

At a rough level, here's [what happens inside Trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/root_module/model_saving.py#L63):
```python

self.global_step = checkpoint['global_step']
@@ -79,6 +96,6 @@ lr_schedulers = checkpoint['lr_schedulers']
for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
scheduler.load_state_dict(lrs_state)

# uses the model you passed into trainer
model.load_state_dict(checkpoint['state_dict'])
```
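
A checkpoint produced by the callback can also be loaded back directly; a minimal sketch, assuming `MyLightningModule` is a placeholder name for your own LightningModule subclass and that the epoch-0 checkpoint exists:

```{.python}
import os

# MyLightningModule is a hypothetical stand-in for your own LightningModule subclass;
# checkpoint_callback is the ModelCheckpoint configured earlier
model = MyLightningModule.load_from_checkpoint(
    os.path.join(checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt")
)
model.eval()
```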
41 changes: 25 additions & 16 deletions docs/examples/Examples.md
@@ -6,14 +6,15 @@ In 99% of cases you want to just copy [one of the examples](https://github.com/w
wget https://raw.githubusercontent.com/williamFalcon/pytorch-lightning/master/pl_examples/new_project_templates/lightning_module_template.py
```

---

### Trainer Example

**\_\_main\_\_ function**

Normally, we want to let the \_\_main\_\_ function start the training.
Inside the main we parse training arguments with whatever hyperparameters we want. Your LightningModule will have a
chance to add hyperparameters.

```{.python}
from test_tube import HyperOptArgumentParser
@@ -32,13 +33,15 @@ if __name__ == '__main__':
# train model
main(hyperparams)
```

**Main Function**

The main function is your entry into the program. This is where you init your model, checkpoint directory, and launch the training.
The main function should have 3 arguments:

- hparams: a configuration of hyperparameters.
- slurm_manager: Slurm cluster manager object (can be None)
- dict: for you to return any values you want (useful in meta-learning, otherwise set to \_)

```python
def main(hparams, cluster, results_dict):
@@ -62,13 +65,15 @@ The __main__ function will start training on your **main** function. If you use
in hyperparameter optimization mode, this main function will get one set of hyperparameters. If you use it as a simple
argument parser, you get the default arguments in the argument parser.

So, calling `main(hyperparams)` runs the model with the default argparse arguments.

```{.python}
main(hyperparams)
```

---

#### CPU hyperparameter search

```{.python}
# run a grid search over 20 hyperparameter combinations.
@@ -80,7 +85,9 @@ hyperparams.optimize_parallel_cpu(
```

---

#### Hyperparameter search on a single or multiple GPUs

```{.python}
# run a grid search over 20 hyperparameter combinations.
hyperparams.optimize_parallel_gpu(
@@ -92,8 +99,10 @@ hyperparams.optimize_parallel_gpu(
```

---

#### Hyperparameter search on a SLURM HPC cluster

```{.python}
def optimize_on_cluster(hyperparams):
# enable cluster training
cluster = SlurmCluster(
@@ -126,6 +135,6 @@ def optimize_on_cluster(hyperparams):
job_name=job_display_name
)
# run cluster hyperparameter search
optimize_on_cluster(hyperparams)
```
134 changes: 96 additions & 38 deletions pytorch_lightning/callbacks/pt_callbacks.py
@@ -158,11 +158,17 @@ class ModelCheckpoint(Callback):
filepath: string, path to save the model file.
monitor: quantity to monitor.
verbose: verbosity mode, 0 or 1.
save_best_only: if `save_best_only=True`,
the latest best model according to
the quantity monitored will not be overwritten.
save_top_k: if `save_top_k == k`,
the best k models according to
the quantity monitored will be saved.
if `save_top_k == 0`, no models are saved.
if `save_top_k == -1`, all models are saved.
Please note that the monitors are checked every `period` epochs.
if `save_top_k >= 2` and the callback is called multiple
times inside an epoch, the name of the saved file will be
appended with a version count starting with `v0`.
mode: one of {auto, min, max}.
If `save_best_only=True`, the decision
If `save_top_k != 0`, the decision
to overwrite the current save file is made
based on either the maximization or the
minimization of the monitored quantity. For `val_acc`,
@@ -176,27 +182,33 @@ class ModelCheckpoint(Callback):
"""

def __init__(self, filepath, monitor='val_loss', verbose=0,
save_best_only=True, save_weights_only=False,
save_top_k=1, save_weights_only=False,
mode='auto', period=1, prefix=''):
super(ModelCheckpoint, self).__init__()
if (
save_best_only and
save_top_k and
os.path.isdir(filepath) and
len(os.listdir(filepath)) > 0
):
warnings.warn(
f"Checkpoint directory {filepath} exists and is not empty with save_best_only=True."
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
)

self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
self.save_best_only = save_best_only
if not os.path.exists(filepath):
os.makedirs(filepath)
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.epochs_since_last_save = 0
self.epochs_since_last_check = 0
self.prefix = prefix
self.best_k_models = {}
# {filename: monitor}
self.kth_best_model = ''
self.best = 0

if mode not in ['auto', 'min', 'max']:
warnings.warn(
@@ -206,66 +218,112 @@ def __init__(self, filepath, monitor='val_loss', verbose=0,

if mode == 'min':
self.monitor_op = np.less
self.best = np.Inf
self.kth_value = np.Inf
self.mode = 'min'
elif mode == 'max':
self.monitor_op = np.greater
self.best = -np.Inf
self.kth_value = -np.Inf
self.mode = 'max'
else:
if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
self.monitor_op = np.greater
self.best = -np.Inf
self.kth_value = -np.Inf
self.mode = 'max'
else:
self.monitor_op = np.less
self.best = np.Inf
self.kth_value = np.Inf
self.mode = 'min'

def save_model(self, filepath, overwrite):
dirpath = '/'.join(filepath.split('/')[:-1])
def _del_model(self, filepath):
dirpath = os.path.dirname(filepath)

# make paths
os.makedirs(os.path.dirname(filepath), exist_ok=True)
os.makedirs(dirpath, exist_ok=True)

if overwrite:
for filename in os.listdir(dirpath):
if self.prefix in filename:
path_to_delete = os.path.join(dirpath, filename)
try:
shutil.rmtree(path_to_delete)
except OSError:
os.remove(path_to_delete)
try:
shutil.rmtree(filepath)
except OSError:
os.remove(filepath)

def _save_model(self, filepath):
dirpath = os.path.dirname(filepath)

# make paths
os.makedirs(dirpath, exist_ok=True)

# delegate the saving to the model
self.save_function(filepath)

def check_monitor_top_k(self, current):
less_than_k_models = len(self.best_k_models.keys()) < self.save_top_k
if less_than_k_models:
return True
return self.monitor_op(current, self.best_k_models[self.kth_best_model])

def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epochs_since_last_save += 1
if self.epochs_since_last_save >= self.period:
self.epochs_since_last_save = 0
filepath = '{}/{}_ckpt_epoch_{}.ckpt'.format(self.filepath, self.prefix, epoch + 1)
if self.save_best_only:
self.epochs_since_last_check += 1

if self.save_top_k == 0:
# no models are saved
return
if self.epochs_since_last_check >= self.period:
self.epochs_since_last_check = 0
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
version_cnt = 0
while os.path.isfile(filepath):
# this epoch called before
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
version_cnt += 1

print(filepath)

if self.save_top_k != -1:
current = logs.get(self.monitor)

if current is None:
warnings.warn(
f'Can save best model only with {self.monitor} available,'
' skipping.', RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if self.check_monitor_top_k(current):

# remove kth
if len(self.best_k_models.keys()) == self.save_top_k:
delpath = self.kth_best_model
self.best_k_models.pop(self.kth_best_model)
self._del_model(delpath)

self.best_k_models[filepath] = current
if len(self.best_k_models.keys()) == self.save_top_k:
# monitor dict has reached k elements
if self.mode == 'min':
self.kth_best_model = max(self.best_k_models, key=self.best_k_models.get)
else:
self.kth_best_model = min(self.best_k_models, key=self.best_k_models.get)
self.kth_value = self.best_k_models[self.kth_best_model]

if self.mode == 'min':
self.best = min(self.best_k_models.values())
else:
self.best = max(self.best_k_models.values())
if self.verbose > 0:
logging.info(
f'\nEpoch {epoch + 1:05d}: {self.monitor} improved'
f' from {self.best:0.5f} to {current:0.5f},'
f' saving model to {filepath}')
self.best = current
self.save_model(filepath, overwrite=True)
f'\nEpoch {epoch:05d}: {self.monitor} reached',
f'{current:0.5f} (best {self.best:0.5f}), saving model to',
f'{filepath} as top {self.save_top_k}')
self._save_model(filepath)

else:
if self.verbose > 0:
logging.info(
f'\nEpoch {epoch + 1:05d}: {self.monitor} did not improve')
f'\nEpoch {epoch:05d}: {self.monitor}',
f'was not in top {self.save_top_k}')

else:
if self.verbose > 0:
logging.info(f'\nEpoch {epoch + 1:05d}: saving model to {filepath}')
self.save_model(filepath, overwrite=False)
logging.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
self._save_model(filepath)


class GradientAccumulationScheduler(Callback):
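
A minimal standalone sketch of the top-k bookkeeping implemented above, using invented monitor values and assuming `mode='min'` with `save_top_k=2` (for illustration only):

```python
import numpy as np

# mirror of the callback's bookkeeping: map checkpoint path -> monitored value
save_top_k = 2
monitor_op = np.less      # 'min' mode: a lower val_loss is better
best_k_models = {}
kth_best_model = ''       # path of the worst checkpoint among the current best k

def check_monitor_top_k(current):
    # accept while fewer than k checkpoints exist, otherwise require an
    # improvement over the current k-th best value
    if len(best_k_models) < save_top_k:
        return True
    return monitor_op(current, best_k_models[kth_best_model])

for epoch, val_loss in enumerate([0.90, 0.70, 0.80, 0.60]):
    filepath = f'_ckpt_epoch_{epoch}.ckpt'
    if check_monitor_top_k(val_loss):
        if len(best_k_models) == save_top_k:
            best_k_models.pop(kth_best_model)  # evict the worst of the best k
        best_k_models[filepath] = val_loss
        if len(best_k_models) == save_top_k:
            kth_best_model = max(best_k_models, key=best_k_models.get)

# epochs 1 and 3 (val_loss 0.70 and 0.60) survive as the top 2
print(sorted(best_k_models))
```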
2 changes: 1 addition & 1 deletion tests/test_a_restore_models.py
@@ -131,7 +131,7 @@ def test_load_model_from_checkpoint():
# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = LightningTestModel.load_from_checkpoint(
os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt")
os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt")
)

# test that hparams loaded correctly