No auto load weights (Lightning-AI#985)
* remove autoload

* added weights loading docs

* checkpoint loading saving docs

* docs (Lightning-AI#1010)
williamFalcon authored and tullie committed Apr 3, 2020
1 parent 470fbd0 commit d9fbcfe
Showing 10 changed files with 112 additions and 325 deletions.
80 changes: 0 additions & 80 deletions docs/source/checkpointing.rst

This file was deleted.

2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -45,7 +45,6 @@ PyTorch-Lightning Documentation
    :caption: Common Use Cases
 
    apex
-   checkpointing
    slurm
    debugging
    experiment_logging
@@ -54,6 +53,7 @@ PyTorch-Lightning Documentation
    fast_training
    hooks
    multi_gpu
+   weights_loading
    single_gpu
    sequences
    training_tricks
76 changes: 76 additions & 0 deletions docs/source/weights_loading.rst
@@ -0,0 +1,76 @@
Saving and loading weights
==========================

Lightning can automate saving and loading checkpoints.

Checkpoint saving
-----------------

Checkpointing is enabled by default and saves to the current working directory.
To change the checkpoint path, pass in:

.. code-block:: python

   Trainer(default_save_path='/your/path/to/save/checkpoints')

To modify the behavior of checkpointing, pass in your own callback:

.. code-block:: python

   import os

   from pytorch_lightning.callbacks import ModelCheckpoint

   # DEFAULTS used by the Trainer
   checkpoint_callback = ModelCheckpoint(
       filepath=os.getcwd(),
       save_best_only=True,
       verbose=True,
       monitor='val_loss',
       mode='min',
       prefix=''
   )

   trainer = Trainer(checkpoint_callback=checkpoint_callback)

Or disable checkpointing entirely by passing:

.. code-block:: python

   trainer = Trainer(checkpoint_callback=False)

The Lightning checkpoint also saves the hparams (hyperparameters) passed into the LightningModule init.

.. note:: hparams is a `Namespace <https://docs.python.org/2/library/argparse.html#argparse.Namespace>`_.

.. code-block:: python
   :emphasize-lines: 8

   from argparse import Namespace

   # usually these come from command line args
   args = Namespace(**{'learning_rate': 0.001})

   # define your module to have hparams as the first arg
   # this means your checkpoint will have everything that went into making
   # this model (in this case, learning rate)
   class MyLightningModule(pl.LightningModule):

       def __init__(self, hparams, ...):
           self.hparams = hparams
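
If you want to verify what went into a checkpoint, you can open the file directly. A minimal
sketch, assuming a checkpoint produced by the Trainer above; the file name and the exact key that
holds the hyperparameters ('hparams' here) are assumptions and may differ between versions:

.. code-block:: python

   import torch

   # load the raw checkpoint dict on CPU just to inspect it
   checkpoint = torch.load('example.ckpt', map_location='cpu')

   # model weights live under 'state_dict'; the hparams key is version-dependent
   print(checkpoint.keys())
   print(checkpoint['state_dict'].keys())
   print(checkpoint.get('hparams'))
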
Checkpoint loading
------------------

You might want to not only load a model's weights but also continue training it. Restoring the
trainer state continues from the epoch and global step where you last left off; however, the
dataloaders will start from the first batch again (if you shuffled, this shouldn't matter).

To load a model along with its weights and hparams and use it for predictions:

.. code-block:: python

   model = MyLightningModule.load_from_checkpoint(PATH)
   model.eval()
   y_hat = model(x)

A LightningModule is no different from an nn.Module, so once loaded you can use it for
predictions as you would any nn.Module.
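
Because of that, the weights can also be restored manually with plain PyTorch calls. A minimal
sketch, assuming the checkpoint keeps the weights under the 'state_dict' key and reusing the
hypothetical MyLightningModule/args from above:

.. code-block:: python

   import torch

   PATH = 'example.ckpt'  # hypothetical path

   # build the module, then copy the saved weights into it
   model = MyLightningModule(hparams=args)
   checkpoint = torch.load(PATH, map_location='cpu')
   model.load_state_dict(checkpoint['state_dict'])
   model.eval()
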
11 changes: 1 addition & 10 deletions pytorch_lightning/trainer/callback_config.py
@@ -62,7 +62,7 @@ def configure_checkpoint_callback(self):
             self.weights_save_path = self.default_save_path
 
     def configure_early_stopping(self, early_stop_callback):
-        if early_stop_callback is True:
+        if early_stop_callback is True or None:
             self.early_stop_callback = EarlyStopping(
                 monitor='val_loss',
                 patience=3,
@@ -71,15 +71,6 @@ def configure_early_stopping(self, early_stop_callback):
                 mode='min'
             )
             self.enable_early_stop = True
-        elif early_stop_callback is None:
-            self.early_stop_callback = EarlyStopping(
-                monitor='val_loss',
-                patience=3,
-                strict=False,
-                verbose=False,
-                mode='min'
-            )
-            self.enable_early_stop = True
         elif not early_stop_callback:
             self.early_stop_callback = None
             self.enable_early_stop = False
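
A quick aside on the new condition in this hunk: early_stop_callback is True or None parses as
(early_stop_callback is True) or None, because 'is' binds tighter than 'or'. A minimal,
Lightning-free sketch of how Python evaluates it:

    # operator precedence: `is` binds tighter than `or`
    for value in (True, None, False):
        result = (value is True) or None
        print(value, '->', bool(result))

    # True  -> True   (first branch: default EarlyStopping)
    # None  -> False  (falls through to the `elif not early_stop_callback` branch)
    # False -> False
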
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -347,6 +347,10 @@ def run_evaluation(self, test_mode: bool = False):
         # add metrics to prog bar
         self.add_tqdm_metrics(prog_bar_metrics)
 
+        # log results of test
+        if test_mode:
+            model.print(prog_bar_metrics)
+
         # log metrics
         self.log_metrics(log_metrics, {})
 
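
For context, the progress-bar metrics printed above come from the dict a LightningModule returns
from its test hooks. A minimal sketch under that assumption (the hook names and dict keys reflect
this era's API and are not part of this diff):

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl


    class MyTestModule(pl.LightningModule):

        def test_step(self, batch, batch_idx):
            x, y = batch
            return {'test_loss': F.cross_entropy(self(x), y)}

        def test_end(self, outputs):
            avg_loss = torch.stack([o['test_loss'] for o in outputs]).mean()
            # the 'progress_bar' entries are what run_evaluation prints via model.print
            return {'progress_bar': {'test_loss': avg_loss},
                    'log': {'test_loss': avg_loss}}
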
9 changes: 6 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -1096,7 +1096,8 @@ def run_pretrain_routine(self, model: LightningModule):
         self.register_slurm_signal_handlers()
 
         # print model summary
-        if self.proc_rank == 0 and self.weights_summary is not None:
+        # TODO: remove self.testing condition because model.summarize() is wiping out the weights
+        if self.proc_rank == 0 and self.weights_summary is not None and not self.testing:
             if self.weights_summary in ['full', 'top']:
                 ref_model.summarize(mode=self.weights_summary)
             else:
@@ -1116,7 +1117,7 @@ def run_pretrain_routine(self, model: LightningModule):
         # when testing requested only run test and return
         if self.testing:
             # only load test dataloader for testing
-            self.reset_test_dataloader(ref_model)
+            # self.reset_test_dataloader(ref_model)
             self.run_evaluation(test_mode=True)
             return
 
@@ -1189,8 +1190,10 @@ def test(self, model: Optional[LightningModule] = None):
         """
         self.testing = True
         if model is not None:
+            self.model = model
             self.fit(model)
-        self.run_evaluation(test_mode=True)
+        else:
+            self.run_evaluation(test_mode=True)
 
 
 class _PatchDataLoader(object):
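
Taken together, the test() change means the trainer either routes through the fit entry point
(which, with self.testing set, only performs the test evaluation) or evaluates the already
attached model. A hedged usage sketch of the two call patterns (model and dataloader construction
omitted):

    from pytorch_lightning import Trainer

    trainer = Trainer()

    # pass a model explicitly: the trainer attaches it and routes through fit(),
    # whose pretrain routine runs only the test evaluation
    trainer.test(model)

    # or call test() after fit(): the model already attached by fit() is evaluated directly
    trainer.fit(model)
    trainer.test()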