
improve finetuning #39

Merged: 17 commits, Feb 1, 2021
9 changes: 7 additions & 2 deletions CHANGELOG.md
@@ -9,11 +9,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/pytorch-lightning/pull/9))
- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/lightning-flash/pull/9))


- Added `strategy` to `trainer.finetune` with `NoFreeze`, `Freeze`, `FreezeUnfreeze`, `UnfreezeMilestones` callbacks ([#39](https://github.com/PyTorchLightning/lightning-flash/pull/39))


### Changed


### Fixed


### Removed
16 changes: 8 additions & 8 deletions README.md
@@ -127,7 +127,7 @@ model = ImageClassifier(num_classes=datamodule.num_classes)
trainer = flash.Trainer(max_epochs=1)

# 5. Finetune the model
trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 7. Save it!
trainer.save_checkpoint("image_classification_model.pt")
@@ -151,13 +151,13 @@ Flash is built as a collection of community-built tasks. A task is highly opinio

### Example 1: Image classification
Flash has an ImageClassification task to tackle any image classification problem.

<details>
<summary>View example</summary>
To illustrate, let's say we wanted to develop a model that could classify between ants and bees.

<img src="https://pl-flash-data.s3.amazonaws.com/images/ant_bee.png" width="300px">

Here we classify ants vs bees.

```python
@@ -208,7 +208,7 @@ Flash has a TextClassification task to tackle any text classification problem.
<details>
<summary>View example</summary>
To illustrate, say you wanted to classify movie reviews as positive or negative.

```python
import flash
from flash import download_data
@@ -261,9 +261,9 @@ Flash has a TabularClassification task to tackle any tabular classification prob

<details>
<summary>View example</summary>
To illustrate, say we want to build a model to predict if a passenger survived on the Titanic.

```python
from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall
import flash
135 changes: 135 additions & 0 deletions flash/core/finetuning.py
@@ -0,0 +1,135 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.optim import Optimizer


class FlashBaseFinetuning(BaseFinetuning):

def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True):
r"""

FlashBaseFinetuning can be used to create a custom Flash Finetuning Callback.

        Override ``finetunning_function`` to implement your unfreeze logic.

Args:
attr_names: Name(s) of the module attributes of the model to be frozen.

            train_bn: Whether to train BatchNorm layers.

"""

self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names
self.train_bn = train_bn

def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn)

def freeze_using_attr_names(self, pl_module, attr_names: List[str], train_bn: bool = True):
for attr_name in attr_names:
attr = getattr(pl_module, attr_name, None)
if attr is None or not isinstance(attr, nn.Module):
                raise MisconfigurationException(f"Your model must have a {attr_name} attribute that is an nn.Module")
self.freeze(module=attr, train_bn=train_bn)


class FreezeUnfreeze(FlashBaseFinetuning):

def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_epoch: int = 10):
super().__init__(attr_names, train_bn)
self.unfreeze_epoch = unfreeze_epoch

def finetunning_function(
self,
pl_module: pl.LightningModule,
epoch: int,
optimizer: Optimizer,
opt_idx: int,
) -> None:
if epoch == self.unfreeze_epoch:
modules = [getattr(pl_module, attr_name) for attr_name in self.attr_names]
self.unfreeze_and_add_param_group(
module=modules,
optimizer=optimizer,
train_bn=self.train_bn,
)


class UnfreezeMilestones(FlashBaseFinetuning):

def __init__(
self,
attr_names: Union[str, List[str]] = "backbone",
train_bn: bool = True,
unfreeze_milestones: tuple = (5, 10),
num_layers: int = 5
):
self.unfreeze_milestones = unfreeze_milestones
self.num_layers = num_layers

super().__init__(attr_names, train_bn)

def finetunning_function(
self,
pl_module: pl.LightningModule,
epoch: int,
optimizer: Optimizer,
opt_idx: int,
) -> None:
backbone_modules = list(pl_module.backbone.modules())
if epoch == self.unfreeze_milestones[0]:
# unfreeze num_layers last layers
self.unfreeze_and_add_param_group(
module=backbone_modules[-self.num_layers:],
optimizer=optimizer,
train_bn=self.train_bn,
)

elif epoch == self.unfreeze_milestones[1]:
# unfreeze remaining layers
self.unfreeze_and_add_param_group(
module=backbone_modules[:-self.num_layers],
optimizer=optimizer,
train_bn=self.train_bn,
)


_DEFAULTS_FINETUNE_STRATEGIES = {
"no_freeze": BaseFinetuning,
"freeze": FlashBaseFinetuning,
"freeze_unfreeze": FreezeUnfreeze,
"unfreeze_milestones": UnfreezeMilestones
}


def instantiate_default_finetuning_callbacks(strategy):
if strategy is None:
strategy = "no_freeze"
rank_zero_warn("strategy is None. Setting strategy to `no_freeze` by default.", UserWarning)
if isinstance(strategy, str):
strategy = strategy.lower()
if strategy in _DEFAULTS_FINETUNE_STRATEGIES:
return [_DEFAULTS_FINETUNE_STRATEGIES[strategy]()]
raise MisconfigurationException(
f"strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}"
f". Found {strategy}"
)
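
The four built-in strategies cover the common cases, and ``FlashBaseFinetuning`` is also intended as a base class for custom schedules. Below is a minimal sketch of such a subclass; the ``GradualUnfreeze`` name and its ``period``/``step`` parameters are illustrative and not part of this PR, and the sketch assumes ``unfreeze_and_add_param_group`` accepts a list of modules, as the strategies above already rely on.

```python
from typing import List, Union

import pytorch_lightning as pl
from torch.optim import Optimizer

from flash.core.finetuning import FlashBaseFinetuning


class GradualUnfreeze(FlashBaseFinetuning):
    """Hypothetical strategy: unfreeze the last ``step`` still-frozen backbone layers every ``period`` epochs."""

    def __init__(
        self,
        attr_names: Union[str, List[str]] = "backbone",
        train_bn: bool = True,
        period: int = 2,
        step: int = 5,
    ):
        super().__init__(attr_names, train_bn)
        self.period = period
        self.step = step

    def finetunning_function(
        self,
        pl_module: pl.LightningModule,
        epoch: int,
        optimizer: Optimizer,
        opt_idx: int,
    ) -> None:
        if epoch == 0 or epoch % self.period != 0:
            return
        modules = list(pl_module.backbone.modules())
        # Layers already unfrozen at earlier milestones; only the next slice is added
        # to the optimizer, so no parameter group is registered twice.
        already_unfrozen = self.step * (epoch // self.period - 1)
        upper = len(modules) - already_unfrozen
        lower = max(upper - self.step, 0)
        if lower < upper:
            self.unfreeze_and_add_param_group(
                module=modules[lower:upper],
                optimizer=optimizer,
                train_bn=self.train_bn,
            )
```

Such a callback can then be passed directly as ``strategy`` to ``trainer.finetune`` (see ``flash/core/trainer.py`` below).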
3 changes: 3 additions & 0 deletions flash/core/model.py
@@ -150,3 +150,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
checkpoint["pipeline"] = self.data_pipeline

def configure_finetune_callback(self):
return []
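
``configure_finetune_callback`` is the hook that ``Trainer._resolve_callbacks`` (below) queries when no explicit ``BaseFinetuning`` instance is passed. A hedged sketch of a task overriding it, assuming the Flash ``Task`` base class defined in ``flash/core/model.py``; the ``MyClassificationTask`` name and the chosen ``unfreeze_epoch`` are illustrative only:

```python
from flash.core.finetuning import FreezeUnfreeze
from flash.core.model import Task


class MyClassificationTask(Task):
    # Hypothetical task that ships a default finetuning behaviour: freeze the
    # backbone at the start and unfreeze it at epoch 5. It is used unless the
    # caller passes an explicit ``BaseFinetuning`` strategy to ``trainer.finetune``.
    def configure_finetune_callback(self):
        return [FreezeUnfreeze(attr_names="backbone", train_bn=True, unfreeze_epoch=5)]
```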
122 changes: 62 additions & 60 deletions flash/core/trainer.py
@@ -16,54 +16,12 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning
from torch.optim import Optimizer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader


# NOTE: copied from:
# https://github.com/PyTorchLightning/pytorch-lightning/blob/9d165f6f5655a44f1e5cd02ab36f21bc14e2a604/pl_examples/domain_templates/computer_vision_fine_tuning.py#L66
class MilestonesFinetuningCallback(BaseFinetuning):

def __init__(self, milestones: tuple = (5, 10), train_bn: bool = True):
self.milestones = milestones
self.train_bn = train_bn

def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
# TODO: might need some config to say which attribute is model
# maybe something like:
# self.freeze(module=pl_module.getattr(self.feature_attr), train_bn=self.train_bn)
# where self.feature_attr can be "backbone" or "feature_extractor", etc.
# (configured in init)
assert hasattr(
pl_module, "backbone"
), "To use MilestonesFinetuningCallback your model must have a backbone attribute"
self.freeze(module=pl_module.backbone, train_bn=self.train_bn)

def finetunning_function(
self,
pl_module: pl.LightningModule,
epoch: int,
optimizer: Optimizer,
opt_idx: int,
) -> None:
backbone_modules = list(pl_module.backbone.modules())
if epoch == self.milestones[0]:
# unfreeze 5 last layers
# TODO last N layers should be parameter
self.unfreeze_and_add_param_group(
module=backbone_modules[-5:],
optimizer=optimizer,
train_bn=self.train_bn,
)

elif epoch == self.milestones[1]:
# unfreeze remaining layers
# TODO last N layers should be parameter
self.unfreeze_and_add_param_group(
module=backbone_modules[:-5],
optimizer=optimizer,
train_bn=self.train_bn,
)
from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, instantiate_default_finetuning_callbacks


class Trainer(pl.Trainer):
@@ -96,13 +54,14 @@ def fit(

def finetune(
self,
model: pl.LightningModule,
model: LightningModule,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[pl.LightningDataModule] = None,
unfreeze_milestones: tuple = (5, 10),
strategy: Optional[Union[str, BaseFinetuning]] = None,
):
r"""

        Runs the full optimization routine. Same as ``pytorch_lightning.Trainer().fit()``, but unfreezes layers
        of the backbone throughout training.

@@ -117,18 +76,61 @@ def finetune(
            val_dataloaders: Either a single PyTorch DataLoader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method this will be skipped.

unfreeze_milestones: A tuple of two integers. First value marks the epoch in which the last 5
layers of the backbone will be unfrozen. The second value marks the epoch in which the full backbone will
be unfrozen.
strategy: Should either be a string or a finetuning callback subclassing
``pytorch_lightning.callbacks.BaseFinetuning``.

"""
if hasattr(model, "backbone"):
# TODO: if we find a finetuning callback in the trainer should we change it?
# or should we warn the user?
if not any(isinstance(c, BaseFinetuning) for c in self.callbacks):
# TODO: should pass config from arguments
self.callbacks.append(MilestonesFinetuningCallback(milestones=unfreeze_milestones))
else:
warnings.warn("Warning: model does not have a 'backbone' attribute, will train normally")
Currently, default strategies can be enabled with these strings:
- ``no_freeze``,
- ``freeze``,
- ``freeze_unfreeze``,
- ``unfreeze_milestones``

"""
self._resolve_callbacks(model, strategy)
return super().fit(model, train_dataloader, val_dataloaders, datamodule)

def _resolve_callbacks(self, model, strategy):
"""
This function is used to select the `BaseFinetuning` to be used for finetuning.
"""
if strategy is not None and not isinstance(strategy, (str, BaseFinetuning)):
raise MisconfigurationException(
"strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning``"
f"callback or a str within {list(_DEFAULTS_FINETUNE_STRATEGIES.keys())}"
)

if isinstance(strategy, BaseFinetuning):
            callback = [strategy]
else:
# todo: change to ``configure_callbacks`` when merged to Lightning.
model_callback = model.configure_finetune_callback()
if len(model_callback) > 1:
raise MisconfigurationException(
f"{model} configure_finetune_callback should create a list with only 1 callback"
)
if len(model_callback) == 1:
if strategy is not None:
rank_zero_warn(
"The model contains a default finetune callback. "
f"The provided {strategy} will be overriden. "
"HINT: Provide a `BaseFinetuning` callback as strategy to make it prioritized. ", UserWarning
)
                callback = model_callback
else:
callback = instantiate_default_finetuning_callbacks(strategy)

        self.callbacks = self._merge_callbacks(self.callbacks, callback)

@staticmethod
def _merge_callbacks(old_callbacks: List, new_callbacks: List) -> List:
"""
        This function keeps only one instance of each callback type:
        ``new_callbacks`` take priority, and ``old_callbacks`` of an overridden type are dropped.
        """
        if not new_callbacks:
            return old_callbacks
new_callbacks_types = set(type(c) for c in new_callbacks)
old_callbacks_types = set(type(c) for c in old_callbacks)
override_types = new_callbacks_types.intersection(old_callbacks_types)
new_callbacks.extend(c for c in old_callbacks if type(c) not in override_types)
return new_callbacks
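
Putting the pieces together, ``finetune`` now accepts either one of the registered strategy names or a ``BaseFinetuning`` instance, and falls back to the model's ``configure_finetune_callback`` (or ``no_freeze``) when nothing is provided. A minimal usage sketch, assuming the hymenoptera data and the ``from_folders`` arguments used by the repository's image-classification example:

```python
import flash
from flash.core.data import download_data
from flash.core.finetuning import UnfreezeMilestones  # used by the alternative call below
from flash.vision import ImageClassificationData, ImageClassifier

# Data and datamodule set up as in flash_examples/finetuning/image_classification.py
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    valid_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)
model = ImageClassifier(num_classes=datamodule.num_classes)
trainer = flash.Trainer(max_epochs=1)

# Built-in strategy selected by name ...
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# ... or an explicit callback instance with custom milestones:
# trainer.finetune(model, datamodule=datamodule,
#                  strategy=UnfreezeMilestones(unfreeze_milestones=(1, 3), num_layers=2))
```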
7 changes: 4 additions & 3 deletions flash_examples/finetuning/image_classification.py
@@ -1,5 +1,6 @@
import flash
from flash.core.data import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.vision import ImageClassificationData, ImageClassifier

if __name__ == "__main__":
@@ -17,11 +18,11 @@
# 3. Build the model
model = ImageClassifier(num_classes=datamodule.num_classes)

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)
# 4. Create the trainer. Run twice on data
trainer = flash.Trainer(max_epochs=2)

# 5. Train the model
trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))

# 6. Test the model
trainer.test()
2 changes: 1 addition & 1 deletion flash_examples/finetuning/text_classification.py
@@ -24,7 +24,7 @@
trainer = flash.Trainer(max_epochs=1)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))
trainer.finetune(model, datamodule=datamodule, strategy='freeze')

# 6. Test model
trainer.test()