
improve finetuning #39

Merged
merged 17 commits into from
Feb 1, 2021
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -11,9 +11,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/pytorch-lightning/pull/9))


- Added `strategy` to `trainer.finetune` with `NoFreeze`, `Freeze`, `FreezeUnfreeze`, `UnfreezeMilestones` Callbacks ([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39))


### Changed


### Fixed


### Removed
### Removed
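A minimal usage sketch of the new `strategy` argument described above (illustrative only, not part of the changelog; `model` and `datamodule` stand in for any Flash task and datamodule, such as the image-classification example further down in this PR):

import flash
from flash.core.finetuning import FreezeUnfreeze

# ``model`` and ``datamodule`` are assumed to already exist (any Flash task / datamodule).
trainer = flash.Trainer(max_epochs=2)

# Built-in strategies can be selected by name ...
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# ... or passed as a configured ``BaseFinetuning`` callback instance.
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))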
142 changes: 142 additions & 0 deletions flash/core/finetuning.py
@@ -0,0 +1,142 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.optim import Optimizer

_EXCLUDE_PARAMETERS = ("self", "args", "kwargs")


class NoFreeze(BaseFinetuning):
pass


class FlashBaseBaseFinetuning(BaseFinetuning):

def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True):
self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names
self.train_bn = train_bn

def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn)

@staticmethod
def freeze_using_attr_names(pl_module, attr_names: List[str], train_bn: bool = True):
for attr_name in attr_names:
attr = getattr(pl_module, attr_name, None)
if attr is None or not isinstance(attr, nn.Module):
                raise MisconfigurationException(f"To use Freeze your model must have a `{attr_name}` attribute")
BaseFinetuning.freeze(module=attr, train_bn=train_bn)


class Freeze(FlashBaseBaseFinetuning):
pass


class FreezeUnfreeze(FlashBaseBaseFinetuning):

def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_epoch: int = 10):
super().__init__(attr_names, train_bn)
self.unfreeze_epoch = unfreeze_epoch

def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn)

def finetunning_function(
self,
pl_module: pl.LightningModule,
epoch: int,
optimizer: Optimizer,
opt_idx: int,
) -> None:
if epoch == self.unfreeze_epoch:
modules = []
for attr_name in self.attr_names:
modules.append(getattr(pl_module, attr_name))

self.unfreeze_and_add_param_group(
module=modules,
optimizer=optimizer,
train_bn=self.train_bn,
)


class UnfreezeMilestones(FlashBaseBaseFinetuning):

def __init__(
self,
attr_names: Union[str, List[str]] = "backbone",
train_bn: bool = True,
unfreeze_milestones: tuple = (5, 10),
num_layers: int = 5
):
self.unfreeze_milestones = unfreeze_milestones
self.num_layers = num_layers

super().__init__(attr_names, train_bn)

def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn)

def finetunning_function(
self,
pl_module: pl.LightningModule,
epoch: int,
optimizer: Optimizer,
opt_idx: int,
) -> None:
backbone_modules = list(pl_module.backbone.modules())
if epoch == self.unfreeze_milestones[0]:
            # unfreeze the last ``num_layers`` layers
self.unfreeze_and_add_param_group(
module=backbone_modules[-self.num_layers:],
optimizer=optimizer,
train_bn=self.train_bn,
)

elif epoch == self.unfreeze_milestones[1]:
            # unfreeze the remaining layers
self.unfreeze_and_add_param_group(
module=backbone_modules[:-self.num_layers],
optimizer=optimizer,
train_bn=self.train_bn,
)


_DEFAULTS_FINETUNE_STRATEGIES = {
"no_freeze": NoFreeze,
"freeze": Freeze,
"freeze_unfreeze": FreezeUnfreeze,
"unfreeze_milestones": UnfreezeMilestones
}


def instantiate_default_finetuning_callbacks(strategy):
if isinstance(strategy, str):
strategy = strategy.lower()
if strategy in _DEFAULTS_FINETUNE_STRATEGIES:
return [_DEFAULTS_FINETUNE_STRATEGIES[strategy]()]
else:
raise MisconfigurationException(
f"strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}"
f". Found {strategy}"
)
return []
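As a sketch of how these hooks can be extended (illustrative only, not part of this diff; the class name, epoch and layer count are assumptions), a custom strategy can subclass `FlashBaseBaseFinetuning` and reuse `unfreeze_and_add_param_group` on its own schedule, for example unfreezing only the last few layers of each configured attribute at a single epoch:

class UnfreezeLastLayers(FlashBaseBaseFinetuning):
    """Unfreezes the last ``num_layers`` layers of each configured attribute at ``unfreeze_epoch``."""

    def __init__(self, attr_names="backbone", train_bn=True, unfreeze_epoch=3, num_layers=2):
        super().__init__(attr_names, train_bn)
        self.unfreeze_epoch = unfreeze_epoch
        self.num_layers = num_layers

    def finetunning_function(self, pl_module, epoch, optimizer, opt_idx):
        # ``freeze_before_training`` is inherited, so the configured attributes start frozen.
        if epoch != self.unfreeze_epoch:
            return
        for attr_name in self.attr_names:
            modules = list(getattr(pl_module, attr_name).modules())
            self.unfreeze_and_add_param_group(
                module=modules[-self.num_layers:],
                optimizer=optimizer,
                train_bn=self.train_bn,
            )

Such a callback can then be passed directly as ``strategy=UnfreezeLastLayers()`` to ``Trainer.finetune``.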
2 changes: 1 addition & 1 deletion flash/core/model.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union

import pytorch_lightning as pl
import torch
97 changes: 37 additions & 60 deletions flash/core/trainer.py
@@ -15,55 +15,12 @@
from typing import List, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning
from torch.optim import Optimizer
from pytorch_lightning.callbacks import BaseFinetuning, Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader


# NOTE: copied from:
# https://github.com/PyTorchLightning/pytorch-lightning/blob/9d165f6f5655a44f1e5cd02ab36f21bc14e2a604/pl_examples/domain_templates/computer_vision_fine_tuning.py#L66
class MilestonesFinetuningCallback(BaseFinetuning):

def __init__(self, milestones: tuple = (5, 10), train_bn: bool = True):
self.milestones = milestones
self.train_bn = train_bn

def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
# TODO: might need some config to say which attribute is model
# maybe something like:
# self.freeze(module=pl_module.getattr(self.feature_attr), train_bn=self.train_bn)
# where self.feature_attr can be "backbone" or "feature_extractor", etc.
# (configured in init)
assert hasattr(
pl_module, "backbone"
), "To use MilestonesFinetuningCallback your model must have a backbone attribute"
self.freeze(module=pl_module.backbone, train_bn=self.train_bn)

def finetunning_function(
self,
pl_module: pl.LightningModule,
epoch: int,
optimizer: Optimizer,
opt_idx: int,
) -> None:
backbone_modules = list(pl_module.backbone.modules())
if epoch == self.milestones[0]:
# unfreeze 5 last layers
# TODO last N layers should be parameter
self.unfreeze_and_add_param_group(
module=backbone_modules[-5:],
optimizer=optimizer,
train_bn=self.train_bn,
)

elif epoch == self.milestones[1]:
# unfreeze remaining layers
# TODO last N layers should be parameter
self.unfreeze_and_add_param_group(
module=backbone_modules[:-5],
optimizer=optimizer,
train_bn=self.train_bn,
)
from flash.core.finetuning import instantiate_default_finetuning_callbacks
from flash.core.model import Task


class Trainer(pl.Trainer):
@@ -96,11 +53,11 @@ def fit(

def finetune(
self,
model: pl.LightningModule,
model: Task,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[pl.LightningDataModule] = None,
unfreeze_milestones: tuple = (5, 10),
strategy: Optional[Union[str, BaseFinetuning]] = None,
):
r"""
Runs the full optimization routine. Same as pytorch_lightning.Trainer().fit(), but unfreezes layers
@@ -117,18 +74,38 @@ def finetune(
val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped

unfreeze_milestones: A tuple of two integers. First value marks the epoch in which the last 5
layers of the backbone will be unfrozen. The second value marks the epoch in which the full backbone will
be unfrozen.
strategy: Should either be a string or a finetuning callback subclassing
``pytorch_lightning.callbacks.BaseFinetuning``.
                Currently, default strategies can be created with strings such as:

                * ``no_freeze``
                * ``freeze``
                * ``freeze_unfreeze``
                * ``unfreeze_milestones``

"""
if hasattr(model, "backbone"):
# TODO: if we find a finetuning callback in the trainer should we change it?
# or should we warn the user?
if not any(isinstance(c, BaseFinetuning) for c in self.callbacks):
# TODO: should pass config from arguments
self.callbacks.append(MilestonesFinetuningCallback(milestones=unfreeze_milestones))
else:
warnings.warn("Warning: model does not have a 'backbone' attribute, will train normally")
if not isinstance(strategy, (BaseFinetuning, str)):
raise MisconfigurationException(
"strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning`` Callback or a str"
)

self._resolve_callbacks(strategy)
return super().fit(model, train_dataloader, val_dataloaders, datamodule)

    def _resolve_callbacks(self, strategy):
        if isinstance(strategy, str):
            strategy = instantiate_default_finetuning_callbacks(strategy)
        else:
            strategy = [strategy]
        if sum(isinstance(c, BaseFinetuning) for c in strategy) > 1:
            raise MisconfigurationException("Only 1 callback subclassing `BaseFinetuning` should be provided.")
        # todo: change to ``configure_callbacks`` when merged to Lightning.
        self.callbacks = self._merge_callbacks(self.callbacks, strategy)

@staticmethod
def _merge_callbacks(current_callbacks: List, new_callbacks: List) -> List:
        if not len(new_callbacks):
return current_callbacks
new_callbacks_types = set(type(c) for c in new_callbacks)
current_callbacks_types = set(type(c) for c in current_callbacks)
override_types = new_callbacks_types.intersection(current_callbacks_types)
new_callbacks.extend(c for c in current_callbacks if type(c) not in override_types)
return new_callbacks
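For context, a short sketch of the intended callback-merging behavior implemented by `_resolve_callbacks` / `_merge_callbacks` (an illustration, not code from this diff; `model` and `datamodule` are assumed placeholders): a strategy passed to `finetune` takes precedence over a callback of the same type already attached to the Trainer, while unrelated callbacks are kept.

from flash import Trainer
from flash.core.finetuning import FreezeUnfreeze

trainer = Trainer(max_epochs=2, callbacks=[FreezeUnfreeze(unfreeze_epoch=5)])

# The strategy below replaces the FreezeUnfreeze instance above because they share
# the same callback type; callbacks of other types stay attached to the trainer.
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))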
5 changes: 3 additions & 2 deletions flash_examples/finetuning/image_classification.py
@@ -1,5 +1,6 @@
import flash
from flash.core.data import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.vision import ImageClassificationData, ImageClassifier

if __name__ == "__main__":
@@ -18,10 +19,10 @@
model = ImageClassifier(num_classes=datamodule.num_classes)

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)
trainer = flash.Trainer(max_epochs=2)

# 5. Train the model
trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))

# 6. Test the model
trainer.test()
2 changes: 1 addition & 1 deletion flash_examples/finetuning/text_classification.py
@@ -24,7 +24,7 @@
trainer = flash.Trainer(max_epochs=1)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))
trainer.finetune(model, datamodule=datamodule, strategy='freeze')

# 6. Test model
trainer.test()
43 changes: 43 additions & 0 deletions tests/core/test_finetuning.py
@@ -0,0 +1,43 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.nn import functional as F

from flash import ClassificationTask, Trainer
from flash.core.finetuning import NoFreeze
from tests.core.test_model import DummyDataset


@pytest.mark.parametrize(
"strategy", ['no_freeze', 'freeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat']
)
def test_finetuning(tmpdir: str, strategy):
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))
train_dl = torch.utils.data.DataLoader(DummyDataset())
val_dl = torch.utils.data.DataLoader(DummyDataset())
task = ClassificationTask(model, F.nll_loss)
trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
if strategy == "cls":
strategy = NoFreeze()
if strategy == 'chocolat':
with pytest.raises(MisconfigurationException, match="strategy should be within"):
trainer.finetune(task, train_dl, val_dl, strategy=strategy)
elif strategy is None:
with pytest.raises(MisconfigurationException, match="strategy should"):
trainer.finetune(task, train_dl, val_dl, strategy=strategy)
else:
trainer.finetune(task, train_dl, val_dl, strategy=strategy)