This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

improve finetuning #39

Merged
merged 17 commits on Feb 1, 2021
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -11,9 +11,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/pytorch-lightning/pull/9))


- Added `3 BaseFinetuning Callbacks` and `configure_finetune_callbacks` ([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39))

### Changed

### Fixed


### Removed
153 changes: 153 additions & 0 deletions flash/core/finetuning.py
@@ -0,0 +1,153 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.optim import Optimizer

_EXCLUDE_PARAMTERS = ["self", "args", "kwargs"]


class NeverFreeze(BaseFinetuning):
pass


class NeverUnFreeze(BaseFinetuning):

def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True):
self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names
self.train_bn = train_bn

def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
for attr_name in self.attr_names:
attr = getattr(pl_module, attr_name, None)
if attr is None or not isinstance(attr, nn.Module):
raise MisconfigurationException(f"To use NeverUnFreeze your model must have a {attr_name} attribute")
self.freeze(module=attr, train_bn=self.train_bn)


class FreezeUnFreeze(NeverUnFreeze):

def __init__(
self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_at_epoch: int = 10
):
super().__init__(attr_names, train_bn)
self.unfreeze_at_epoch = unfreeze_at_epoch

def finetunning_function(
self,
pl_module: pl.LightningModule,
epoch: int,
optimizer: Optimizer,
opt_idx: int,
) -> None:
if epoch == self.unfreeze_at_epoch:
modules = []
for attr_name in self.attr_names:
modules.append(getattr(pl_module, attr_name))

self.unfreeze_and_add_param_group(
module=modules,
optimizer=optimizer,
train_bn=self.train_bn,
)


# NOTE: copied from:
# https://github.com/PyTorchLightning/pytorch-lightning/blob/9d165f6f5655a44f1e5cd02ab36f21bc14e2a604/pl_examples/domain_templates/computer_vision_fine_tuning.py#L66
class MilestonesFinetuning(BaseFinetuning):

def __init__(self, unfreeze_milestones: tuple = (5, 10), train_bn: bool = True, num_layers: int = 5):
self.unfreeze_milestones = unfreeze_milestones
self.train_bn = train_bn
self.num_layers = num_layers

def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
# TODO: might need some config to say which attribute is model
# maybe something like:
# self.freeze(module=pl_module.getattr(self.feature_attr), train_bn=self.train_bn)
# where self.feature_attr can be "backbone" or "feature_extractor", etc.
# (configured in init)
assert hasattr(pl_module, "backbone"), "To use MilestonesFinetuning your model must have a backbone attribute"
self.freeze(module=pl_module.backbone, train_bn=self.train_bn)

def finetunning_function(
self,
pl_module: pl.LightningModule,
epoch: int,
optimizer: Optimizer,
opt_idx: int,
) -> None:
backbone_modules = list(pl_module.backbone.modules())
if epoch == self.unfreeze_milestones[0]:
# unfreeze 5 last layers
# TODO last N layers should be parameter
self.unfreeze_and_add_param_group(
module=backbone_modules[-self.num_layers:],
optimizer=optimizer,
train_bn=self.train_bn,
)

elif epoch == self.unfreeze_milestones[1]:
# unfreeze remaining layers
# TODO last N layers should be parameter
self.unfreeze_and_add_param_group(
module=backbone_modules[:-self.num_layers],
optimizer=optimizer,
train_bn=self.train_bn,
)


def instantiate_cls(cls, kwargs):
parameters = list(inspect.signature(cls.__init__).parameters.keys())
parameters = [p for p in parameters if p not in _EXCLUDE_PARAMTERS]
cls_kwargs = {}
for p in parameters:
if p in kwargs:
cls_kwargs[p] = kwargs.pop(p)
if len(kwargs) > 0:
raise MisconfigurationException(f"Available parameters are: {parameters}. Found {kwargs} left")
return cls(**cls_kwargs)


_DEFAULTS_FINETUNE_STRATEGIES = {
"never_freeze": NeverFreeze,
"never_unfreeze": NeverUnFreeze,
"freeze_unfreeze": FreezeUnFreeze,
"unfreeze_milestones": MilestonesFinetuning
}


def instantiate_default_finetuning_callbacks(kwargs):
finetune_strategy = kwargs.pop("finetune_strategy", None)
if isinstance(finetune_strategy, str):
finetune_strategy = finetune_strategy.lower()
if finetune_strategy in _DEFAULTS_FINETUNE_STRATEGIES:
return [instantiate_cls(_DEFAULTS_FINETUNE_STRATEGIES[finetune_strategy], kwargs)]
else:
msg = "\n Extra arguments can be: \n"
for n, cls in _DEFAULTS_FINETUNE_STRATEGIES.items():
parameters = list(inspect.signature(cls.__init__).parameters.keys())
parameters = [p for p in parameters if p not in _EXCLUDE_PARAMTERS]
msg += f"{n}: {parameters} \n"
raise MisconfigurationException(
f"finetune_strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}"
f"{msg}"
f". Found {finetune_strategy}"
)
return []
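
For context, a minimal usage sketch of the helpers above (not part of this diff): the strategy name selects a callback class from `_DEFAULTS_FINETUNE_STRATEGIES`, and `instantiate_cls` matches the remaining keyword arguments against that class's `__init__` signature, raising a `MisconfigurationException` if anything is left over.

# Hypothetical usage sketch, assuming the helpers defined in flash/core/finetuning.py above.
from flash.core.finetuning import instantiate_default_finetuning_callbacks

# "freeze_unfreeze" maps to FreezeUnFreeze; the remaining keys are matched against
# FreezeUnFreeze.__init__ (attr_names, train_bn, unfreeze_at_epoch).
callbacks = instantiate_default_finetuning_callbacks({
    "finetune_strategy": "freeze_unfreeze",
    "attr_names": "backbone",
    "train_bn": True,
    "unfreeze_at_epoch": 5,
})
# callbacks is a one-element list holding the configured FreezeUnFreeze callback.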
6 changes: 5 additions & 1 deletion flash/core/model.py
@@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union

import pytorch_lightning as pl
import torch
from torch import nn

from flash.core.data import DataModule, DataPipeline
from flash.core.finetuning import instantiate_default_finetuning_callbacks
from flash.core.utils import get_callable_dict


@@ -150,3 +151,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
checkpoint["pipeline"] = self.data_pipeline

def configure_finetune_callbacks(self, **kwargs) -> List:
return instantiate_default_finetuning_callbacks(kwargs)
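
The new `configure_finetune_callbacks` hook is also an extension point. As a hedged sketch (the `CustomFinetuneTask` class below is illustrative only, not part of this PR), a `Task` subclass can override it to return its own `BaseFinetuning` callbacks and bypass the string-based strategies:

# Hypothetical override, for illustration only.
from flash.core.finetuning import FreezeUnFreeze
from flash.core.model import Task


class CustomFinetuneTask(Task):  # hypothetical subclass
    def configure_finetune_callbacks(self, **kwargs):
        # Always freeze `self.backbone` and unfreeze it at epoch 3,
        # regardless of the kwargs forwarded by Trainer.finetune.
        return [FreezeUnFreeze(attr_names="backbone", train_bn=True, unfreeze_at_epoch=3)]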
95 changes: 35 additions & 60 deletions flash/core/trainer.py
@@ -15,55 +15,11 @@
from typing import List, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning
from torch.optim import Optimizer
from pytorch_lightning.callbacks import BaseFinetuning, Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader


# NOTE: copied from:
# https://github.com/PyTorchLightning/pytorch-lightning/blob/9d165f6f5655a44f1e5cd02ab36f21bc14e2a604/pl_examples/domain_templates/computer_vision_fine_tuning.py#L66
class MilestonesFinetuningCallback(BaseFinetuning):

def __init__(self, milestones: tuple = (5, 10), train_bn: bool = True):
self.milestones = milestones
self.train_bn = train_bn

def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
# TODO: might need some config to say which attribute is model
# maybe something like:
# self.freeze(module=pl_module.getattr(self.feature_attr), train_bn=self.train_bn)
# where self.feature_attr can be "backbone" or "feature_extractor", etc.
# (configured in init)
assert hasattr(
pl_module, "backbone"
), "To use MilestonesFinetuningCallback your model must have a backbone attribute"
self.freeze(module=pl_module.backbone, train_bn=self.train_bn)

def finetunning_function(
self,
pl_module: pl.LightningModule,
epoch: int,
optimizer: Optimizer,
opt_idx: int,
) -> None:
backbone_modules = list(pl_module.backbone.modules())
if epoch == self.milestones[0]:
# unfreeze 5 last layers
# TODO last N layers should be parameter
self.unfreeze_and_add_param_group(
module=backbone_modules[-5:],
optimizer=optimizer,
train_bn=self.train_bn,
)

elif epoch == self.milestones[1]:
# unfreeze remaining layers
# TODO last N layers should be parameter
self.unfreeze_and_add_param_group(
module=backbone_modules[:-5],
optimizer=optimizer,
train_bn=self.train_bn,
)
from flash.core.model import Task


class Trainer(pl.Trainer):
@@ -96,11 +52,12 @@ def fit(

def finetune(
self,
model: pl.LightningModule,
model: Task,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[pl.LightningDataModule] = None,
unfreeze_milestones: tuple = (5, 10),
finetune_strategy: Optional[Union[str, Callback]] = None,
**callbacks_kwargs,
):
r"""
Runs the full optimization routine. Same as pytorch_lightning.Trainer().fit(), but unfreezes layers
@@ -117,18 +74,36 @@ def finetune(
val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped

unfreeze_milestones: A tuple of two integers. First value marks the epoch in which the last 5
layers of the backbone will be unfrozen. The second value marks the epoch in which the full backbone will
be unfrozen.
finetune_strategy: Should either be a string or a finetuning callback subclassing
``pytorch_lightning.callbacks.BaseFinetuning``.

callbacks_kwargs: These arguments will be provided to `model.configure_finetune_callbacks`
to instantiate your own finetuning callbacks.

"""
if hasattr(model, "backbone"):
# TODO: if we find a finetuning callback in the trainer should we change it?
# or should we warn the user?
if not any(isinstance(c, BaseFinetuning) for c in self.callbacks):
# TODO: should pass config from arguments
self.callbacks.append(MilestonesFinetuningCallback(milestones=unfreeze_milestones))
else:
warnings.warn("Warning: model does not have a 'backbone' attribute, will train normally")
if isinstance(finetune_strategy, Callback) and not isinstance(finetune_strategy, BaseFinetuning):
raise Exception("finetune_strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning`` Callback")

self._resolve_callbacks(model, finetune_strategy, **callbacks_kwargs)
return super().fit(model, train_dataloader, val_dataloaders, datamodule)

def _resolve_callbacks(self, model, finetune_strategy, **callbacks_kwargs):
if sum(isinstance(c, BaseFinetuning) for c in self.callbacks + [finetune_strategy]) > 1:
raise MisconfigurationException("Only 1 callback subclassing `BaseFinetuning` should be provided.")
# provided callbacks are higher priorities than model callbacks.
callbacks = self.callbacks
if isinstance(finetune_strategy, str):
callbacks_kwargs["finetune_strategy"] = finetune_strategy
else:
callbacks = self._merge_callbacks(callbacks, [finetune_strategy])
self.callbacks = self._merge_callbacks(callbacks, model.configure_finetune_callbacks(**callbacks_kwargs))

@staticmethod
def _merge_callbacks(current_callbacks: List, new_callbacks: List) -> List:
if len(new_callbacks) == 0:
return current_callbacks
new_callbacks_types = set(type(c) for c in new_callbacks)
current_callbacks_types = set(type(c) for c in current_callbacks)
override_types = new_callbacks_types.intersection(current_callbacks_types)
new_callbacks.extend(c for c in current_callbacks if type(c) not in override_types)
return new_callbacks
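
A small sketch of the intended merge semantics (the `ModelCheckpoint` instance is just an example of an unrelated callback, not something this PR touches): `_merge_callbacks` lets incoming callbacks override existing callbacks of the same type while keeping everything else.

# Illustrative sketch of _merge_callbacks, not part of this diff.
from pytorch_lightning.callbacks import ModelCheckpoint

from flash import Trainer
from flash.core.finetuning import NeverFreeze

current = [ModelCheckpoint(), NeverFreeze()]
new = [NeverFreeze()]  # same type as an existing callback, so it takes its place

merged = Trainer._merge_callbacks(current, new)
# merged == [new NeverFreeze, ModelCheckpoint]: the incoming NeverFreeze replaces
# the old one of the same type; callbacks of other types are carried over.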
4 changes: 2 additions & 2 deletions flash_examples/finetuning/image_classification.py
@@ -18,10 +18,10 @@
model = ImageClassifier(num_classes=datamodule.num_classes)

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)
trainer = flash.Trainer(max_epochs=2)

# 5. Train the model
trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))
trainer.finetune(model, datamodule=datamodule, finetune_strategy='freeze_unfreeze', unfreeze_at_epoch=1)

# 6. Test the model
trainer.test()
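
For comparison, a hedged sketch of the equivalent call using an explicit callback instance instead of a strategy name (assuming the `FreezeUnFreeze` defaults shown in `flash/core/finetuning.py`):

# Hypothetical alternative to the string-based strategy above, not part of this diff.
from flash.core.finetuning import FreezeUnFreeze

trainer.finetune(
    model,
    datamodule=datamodule,
    finetune_strategy=FreezeUnFreeze(attr_names="backbone", train_bn=True, unfreeze_at_epoch=1),
)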
2 changes: 1 addition & 1 deletion flash_examples/finetuning/text_classification.py
@@ -24,7 +24,7 @@
trainer = flash.Trainer(max_epochs=1)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))
trainer.finetune(model, datamodule=datamodule, finetune_strategy='never_freeze')

# 6. Test model
trainer.test()
41 changes: 41 additions & 0 deletions tests/core/test_finetuning.py
@@ -0,0 +1,41 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.nn import functional as F

from flash import ClassificationTask, Trainer
from flash.core.finetuning import NeverFreeze
from tests.core.test_model import DummyDataset


@pytest.mark.parametrize(
"finetune_strategy",
['never_freeze', 'never_unfreeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat']
)
def test_finetuning(tmpdir: str, finetune_strategy):
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax())
train_dl = torch.utils.data.DataLoader(DummyDataset())
val_dl = torch.utils.data.DataLoader(DummyDataset())
task = ClassificationTask(model, F.nll_loss)
trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
if finetune_strategy == "cls":
finetune_strategy = NeverFreeze()
if finetune_strategy == 'chocolat':
with pytest.raises(MisconfigurationException, match="finetune_strategy should be within"):
trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy)
else:
trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy)