Skip to content

Commit

Permalink
Merge load functions (#995)
Browse files Browse the repository at this point in the history
* Update README.md

* Update README.md

* Use callable object for patching dataloaders (#971)

* Use callable object for patching dataloaders

* Add test for ddp with dataloaders passed to fit()

* Update pytorch_lightning/trainer/trainer.py

Co-Authored-By: Jirka Borovec <[email protected]>

* Update pytorch_lightning/trainer/trainer.py

Co-Authored-By: Jirka Borovec <[email protected]>

Co-authored-by: Jirka Borovec <[email protected]>

* merge load functions

* update tests

* fix documentation warnings

* fix line too long

* fix line too long

* print deprecation warning

Co-Authored-By: Jirka Borovec <[email protected]>

* move tags_csv argument to end of signature

* fix typo, update version numbers

* fix line too long

* add typing as requested

* update changelog

Co-authored-by: William Falcon <[email protected]>
Co-authored-by: Sho Arora <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
4 people authored Mar 3, 2020
1 parent f862d9f commit 5458d05
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 74 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed default TQDM to use `tqdm.auto` for prettier outputs in IPython notebooks ([#752](https://github.com/PyTorchLightning/pytorch-lightning/pull/752))
- Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767))
- Moved the default `tqdm_dict` definition from Trainer to `LightningModule`, so it can be overridden by the user ([#749](https://github.com/PyTorchLightning/pytorch-lightning/pull/749))
- Moved functionality of `LightningModule.load_from_metrics` into `LightningModule.load_from_checkpoint` ([#995](https://github.com/PyTorchLightning/pytorch-lightning/pull/995))

### Deprecated

- None
- Deprecated `LightningModule.load_from_metrics` in favour of `LightningModule.load_from_checkpoint` ([#995](https://github.com/PyTorchLightning/pytorch-lightning/pull/995))

### Removed

Expand Down
122 changes: 56 additions & 66 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings
from abc import ABC, abstractmethod
from argparse import Namespace
from typing import Optional, Union, Dict, Callable

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -1090,77 +1091,35 @@ def val_dataloader(self):
@classmethod
def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
r"""
You should use `load_from_checkpoint` instead!
However, if your .ckpt weights don't have the hyperparameters saved, use this method to pass
in a .csv with the hparams you'd like to use. These will be converted into a argparse.Namespace
and passed into your LightningModule for use.
Args:
weights_path (str): Path to a PyTorch checkpoint
tags_csv (str): Path to a .csv with two columns (key, value) as in this
Example::
key,value
drop_prob,0.2
batch_size,32
map_location (dict | str | torch.device | function):
If your checkpoint saved a GPU model and you now load on CPUs
or a different number of GPUs, use this to map to the new setup
(example: {'cuda:1':'cuda:0'}).
The behaviour is the same as in
`torch.load <https://pytorch.org/docs/stable/torch.html#torch.load>`_.
Return:
LightningModule with loaded weights and hyperparameters (if available).
Example
-------
.. code-block:: python
pretrained_model = MyLightningModule.load_from_metrics(
weights_path='/path/to/pytorch_checkpoint.ckpt',
tags_csv='/path/to/hparams_file.csv',
on_gpu=True,
map_location=None
)
# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)
Warning:
Deprecated in version 0.7.0.
You should use `load_from_checkpoint` instead.
Will be removed in v0.9.0.
"""

hparams = load_hparams_from_tags_csv(tags_csv)
hparams.__setattr__('on_gpu', False)

if map_location is not None:
checkpoint = torch.load(weights_path, map_location=map_location)
else:
checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)

# add the hparams from csv file to checkpoint
checkpoint['hparams'] = vars(hparams)

model = cls._load_model_state(checkpoint)
return model
warnings.warn(
"`load_from_metrics` method has been unified with `load_from_checkpoint` in v0.7.0."
" The deprecated method will be removed in v0.9.0.", DeprecationWarning
)
return cls.load_from_checkpoint(weights_path, tags_csv=tags_csv, map_location=map_location)

@classmethod
def load_from_checkpoint(cls, checkpoint_path, map_location=None):
def load_from_checkpoint(
cls,
checkpoint_path: str,
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
tags_csv: Optional[str] = None,
) -> 'LightningModule':
r"""
Primary way of loading model from a checkpoint. When Lightning saves a checkpoint
it stores the hyperparameters in the checkpoint if you initialized your LightningModule
with an argument called `hparams` which is a Namespace or dictionary of hyperparameters
it stores the hyperparameters in the checkpoint if you initialized your LightningModule
with an argument called `hparams` which is a Namespace (output of using argparse
to parse command line arguments) or dictionary of hyperparameters.
Example
-------
.. code-block:: python
# --------------
# Case 1
# when using Namespace (output of using Argparse to parse command line arguments)
from argparse import Namespace
hparams = Namespace(**{'learning_rate': 0.1})
Expand All @@ -1171,12 +1130,25 @@ def __init__(self, hparams):
self.learning_rate = hparams.learning_rate
Args:
checkpoint_path (str): Path to checkpoint.
map_location (dict | str | torch.device | function):
checkpoint_path: Path to checkpoint.
map_location:
If your checkpoint saved a GPU model and you now load on CPUs
or a different number of GPUs, use this to map to the new setup.
The behaviour is the same as in
`torch.load <https://pytorch.org/docs/stable/torch.html#torch.load>`_.
tags_csv: Optional path to a .csv file with two columns (key, value)
as in this example::
key,value
drop_prob,0.2
batch_size,32
You most likely won't need this since Lightning will always save the hyperparameters
to the checkpoint.
However, if your checkpoint weights don't have the hyperparameters saved,
use this method to pass in a .csv file with the hparams you'd like to use.
These will be converted into a argparse.Namespace and passed into your
LightningModule for use.
Return:
LightningModule with loaded weights and hyperparameters (if available).
Expand All @@ -1185,20 +1157,38 @@ def __init__(self, hparams):
-------
.. code-block:: python
# load weights without mapping
# load weights without mapping ...
MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
# load weights mapping all weights from GPU 1 to GPU 0
# or load weights mapping all weights from GPU 1 to GPU 0 ...
map_location = {'cuda:1':'cuda:0'}
MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt', map_location=map_location)
MyLightningModule.load_from_checkpoint(
'path/to/checkpoint.ckpt',
map_location=map_location
)
"""
# or load weights and hyperparameters from separate files.
MyLightningModule.load_from_checkpoint(
'path/to/checkpoint.ckpt',
tags_csv='/path/to/hparams_file.csv'
)
# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)
"""
if map_location is not None:
checkpoint = torch.load(checkpoint_path, map_location=map_location)
else:
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

if tags_csv is not None:
# add the hparams from csv file to checkpoint
hparams = load_hparams_from_tags_csv(tags_csv)
hparams.__setattr__('on_gpu', False)
checkpoint['hparams'] = vars(hparams)

model = cls._load_model_state(checkpoint)
return model

Expand Down
15 changes: 15 additions & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,6 +1211,21 @@ def __call__(self) -> Union[List[DataLoader], DataLoader]:
return self.dataloader


class _PatchDataLoader(object):
r'''
Callable object for patching dataloaders passed into trainer.fit().
Use this class to override model.*_dataloader() and be pickle-compatible.
Args:
dataloader: Dataloader object to return when called.
'''
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
self.dataloader = dataloader

def __call__(self) -> Union[List[DataLoader], DataLoader]:
return self.dataloader


def _set_dataloader(model, dataloader, attribute):
r'''
Check dataloaders passed to .fit() method if they are pytorch DataLoader
Expand Down
5 changes: 4 additions & 1 deletion tests/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,10 @@ def load_model(exp, root_weights_dir, module_class=LightningTemplateModel, path_
checkpoints = [x for x in os.listdir(root_weights_dir) if '.ckpt' in x]
weights_dir = os.path.join(root_weights_dir, checkpoints[0])

trained_model = module_class.load_from_checkpoint(weights_dir)
trained_model = module_class.load_from_checkpoint(
checkpoint_path=weights_dir,
tags_csv=tags_path
)

assert trained_model is not None, 'loading model failed'

Expand Down
6 changes: 4 additions & 2 deletions tests/test_restore_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,10 @@ def test_model_saving_loading(tmpdir):
# load new model
tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
tags_path = os.path.join(tags_path, 'meta_tags.csv')
model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
tags_csv=tags_path)
model_2 = LightningTestModel.load_from_checkpoint(
checkpoint_path=new_weights_path,
tags_csv=tags_path
)
model_2.eval()

# make prediction
Expand Down
12 changes: 8 additions & 4 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ class CurrentTestModel(LightTrainDataloader, TestModelBase):
# load new model
tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
tags_path = os.path.join(tags_path, 'meta_tags.csv')
model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
tags_csv=tags_path)
model_2 = LightningTestModel.load_from_checkpoint(
checkpoint_path=new_weights_path,
tags_csv=tags_path
)
model_2.eval()


Expand Down Expand Up @@ -99,8 +101,10 @@ class CurrentTestModel(LightTrainDataloader, LightValidationStepMixin, TestModel
# load new model
tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
tags_path = os.path.join(tags_path, 'meta_tags.csv')
model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
tags_csv=tags_path)
model_2 = LightningTestModel.load_from_checkpoint(
checkpoint_path=new_weights_path,
tags_csv=tags_path
)
model_2.eval()


Expand Down

0 comments on commit 5458d05

Please sign in to comment.