From 3d557bd62b9ee4580cee7eb4c686f13db62b7cdb Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 13:36:16 +0000 Subject: [PATCH 01/15] improve finetuning --- flash/core/finetuning.py | 153 ++++++++++++++++++ flash/core/model.py | 6 +- flash/core/trainer.py | 95 ++++------- .../finetuning/image_classification.py | 4 +- .../finetuning/text_classification.py | 2 +- tests/core/test_finetuning.py | 41 +++++ 6 files changed, 237 insertions(+), 64 deletions(-) create mode 100644 flash/core/finetuning.py create mode 100644 tests/core/test_finetuning.py diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py new file mode 100644 index 0000000000..11fdb0b407 --- /dev/null +++ b/flash/core/finetuning.py @@ -0,0 +1,153 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import List, Union + +import pytorch_lightning as pl +from pytorch_lightning.callbacks import BaseFinetuning +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch import nn +from torch.optim import Optimizer + +_EXCLUDE_PARAMTERS = ["self", "args", "kwargs"] + + +class NeverFreeze(BaseFinetuning): + pass + + +class NeverUnFreeze(BaseFinetuning): + + def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True): + self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names + self.train_bn = train_bn + + def freeze_before_training(self, pl_module: pl.LightningModule) -> None: + for attr_name in self.attr_names: + attr = getattr(pl_module, attr_name, None) + if attr is None or not isinstance(attr, nn.Module): + MisconfigurationException("To use NeverUnFreeze your model must have a {attr} attribute") + self.freeze(module=attr, train_bn=self.train_bn) + + +class FreezeUnFreeze(NeverUnFreeze): + + def __init__( + self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_at_epoch: int = 10 + ): + super().__init__(attr_names, train_bn) + self.unfreeze_at_epoch = unfreeze_at_epoch + + def finetunning_function( + self, + pl_module: pl.LightningModule, + epoch: int, + optimizer: Optimizer, + opt_idx: int, + ) -> None: + if epoch == self.unfreeze_at_epoch: + modules = [] + for attr_name in self.attr_names: + modules.append(getattr(pl_module, attr_name)) + + self.unfreeze_and_add_param_group( + module=modules, + optimizer=optimizer, + train_bn=self.train_bn, + ) + + +# NOTE: copied from: +# https://github.com/PyTorchLightning/pytorch-lightning/blob/9d165f6f5655a44f1e5cd02ab36f21bc14e2a604/pl_examples/domain_templates/computer_vision_fine_tuning.py#L66 +class MilestonesFinetuning(BaseFinetuning): + + def __init__(self, unfreeze_milestones: tuple = (5, 10), train_bn: bool = True, num_layers: int = 5): + self.unfreeze_milestones = unfreeze_milestones + self.train_bn = train_bn + self.num_layers = num_layers + + def freeze_before_training(self, pl_module: pl.LightningModule) -> None: + # TODO: might need some config to say which attribute is model + # maybe 
something like: + # self.freeze(module=pl_module.getattr(self.feature_attr), train_bn=self.train_bn) + # where self.feature_attr can be "backbone" or "feature_extractor", etc. + # (configured in init) + assert hasattr(pl_module, "backbone"), "To use MilestonesFinetuning your model must have a backbone attribute" + self.freeze(module=pl_module.backbone, train_bn=self.train_bn) + + def finetunning_function( + self, + pl_module: pl.LightningModule, + epoch: int, + optimizer: Optimizer, + opt_idx: int, + ) -> None: + backbone_modules = list(pl_module.backbone.modules()) + if epoch == self.unfreeze_milestones[0]: + # unfreeze 5 last layers + # TODO last N layers should be parameter + self.unfreeze_and_add_param_group( + module=backbone_modules[-self.num_layers:], + optimizer=optimizer, + train_bn=self.train_bn, + ) + + elif epoch == self.unfreeze_milestones[1]: + # unfreeze remaining layers + # TODO last N layers should be parameter + self.unfreeze_and_add_param_group( + module=backbone_modules[:-self.num_layers], + optimizer=optimizer, + train_bn=self.train_bn, + ) + + +def instantiate_cls(cls, kwargs): + parameters = list(inspect.signature(cls.__init__).parameters.keys()) + parameters = [p for p in parameters if p not in _EXCLUDE_PARAMTERS] + cls_kwargs = {} + for p in parameters: + if p in kwargs: + cls_kwargs[p] = kwargs.pop(p) + if len(kwargs) > 0: + raise MisconfigurationException(f"Available parameters are: {parameters}. Found {kwargs} left") + return cls(**cls_kwargs) + + +_DEFAULTS_FINETUNE_STRATEGIES = { + "never_freeze": NeverFreeze, + "never_unfreeze": NeverUnFreeze, + "freeze_unfreeze": FreezeUnFreeze, + "unfreeze_milestones": MilestonesFinetuning +} + + +def instantiate_default_finetuning_callbacks(kwargs): + finetune_strategy = kwargs.pop("finetune_strategy", None) + if isinstance(finetune_strategy, str): + finetune_strategy = finetune_strategy.lower() + if finetune_strategy in _DEFAULTS_FINETUNE_STRATEGIES: + return [instantiate_cls(_DEFAULTS_FINETUNE_STRATEGIES[finetune_strategy], kwargs)] + else: + msg = "\n Extra arguments can be: \n" + for n, cls in _DEFAULTS_FINETUNE_STRATEGIES.items(): + parameters = list(inspect.signature(cls.__init__).parameters.keys()) + parameters = [p for p in parameters if p not in _EXCLUDE_PARAMTERS] + msg += f"{n}: {parameters} \n" + raise MisconfigurationException( + f"finetune_strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}" + f"{msg}" + f". Found {finetune_strategy}" + ) + return [] diff --git a/flash/core/model.py b/flash/core/model.py index 51b1a87d12..0bc675dd14 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union import pytorch_lightning as pl import torch from torch import nn from flash.core.data import DataModule, DataPipeline +from flash.core.finetuning import instantiate_default_finetuning_callbacks from flash.core.utils import get_callable_dict @@ -150,3 +151,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: checkpoint["pipeline"] = self.data_pipeline + + def configure_finetune_callbacks(self, **kwargs) -> List: + return instantiate_default_finetuning_callbacks(kwargs) diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 7ce9f8cf75..ebfd1a2042 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -15,55 +15,11 @@ from typing import List, Optional, Union import pytorch_lightning as pl -from pytorch_lightning.callbacks import BaseFinetuning -from torch.optim import Optimizer +from pytorch_lightning.callbacks import BaseFinetuning, Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import DataLoader - -# NOTE: copied from: -# https://github.com/PyTorchLightning/pytorch-lightning/blob/9d165f6f5655a44f1e5cd02ab36f21bc14e2a604/pl_examples/domain_templates/computer_vision_fine_tuning.py#L66 -class MilestonesFinetuningCallback(BaseFinetuning): - - def __init__(self, milestones: tuple = (5, 10), train_bn: bool = True): - self.milestones = milestones - self.train_bn = train_bn - - def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - # TODO: might need some config to say which attribute is model - # maybe something like: - # self.freeze(module=pl_module.getattr(self.feature_attr), train_bn=self.train_bn) - # where self.feature_attr can be "backbone" or "feature_extractor", etc. - # (configured in init) - assert hasattr( - pl_module, "backbone" - ), "To use MilestonesFinetuningCallback your model must have a backbone attribute" - self.freeze(module=pl_module.backbone, train_bn=self.train_bn) - - def finetunning_function( - self, - pl_module: pl.LightningModule, - epoch: int, - optimizer: Optimizer, - opt_idx: int, - ) -> None: - backbone_modules = list(pl_module.backbone.modules()) - if epoch == self.milestones[0]: - # unfreeze 5 last layers - # TODO last N layers should be parameter - self.unfreeze_and_add_param_group( - module=backbone_modules[-5:], - optimizer=optimizer, - train_bn=self.train_bn, - ) - - elif epoch == self.milestones[1]: - # unfreeze remaining layers - # TODO last N layers should be parameter - self.unfreeze_and_add_param_group( - module=backbone_modules[:-5], - optimizer=optimizer, - train_bn=self.train_bn, - ) +from flash.core.model import Task class Trainer(pl.Trainer): @@ -96,11 +52,12 @@ def fit( def finetune( self, - model: pl.LightningModule, + model: Task, train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[pl.LightningDataModule] = None, - unfreeze_milestones: tuple = (5, 10), + finetune_strategy: Optional[Union[str, Callback]] = None, + **callbacks_kwargs, ): r""" Runs the full optimization routine. Same as pytorch_lightning.Trainer().fit(), but unfreezes layers @@ -117,18 +74,36 @@ def finetune( val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. 
If the model has a predefined val_dataloaders method this will be skipped - unfreeze_milestones: A tuple of two integers. First value marks the epoch in which the last 5 - layers of the backbone will be unfrozen. The second value marks the epoch in which the full backbone will - be unfrozen. + finetune_strategy: Should either be a string or a finetuning callback subclassing + ``pytorch_lightning.callbacks.BaseFinetuning``. + + callbacks_kwargs: Those arguments will be provided to `model.configure_finetune_callbacks` + to instantiante your own finetuning callbacks. """ - if hasattr(model, "backbone"): - # TODO: if we find a finetuning callback in the trainer should we change it? - # or should we warn the user? - if not any(isinstance(c, BaseFinetuning) for c in self.callbacks): - # TODO: should pass config from arguments - self.callbacks.append(MilestonesFinetuningCallback(milestones=unfreeze_milestones)) - else: - warnings.warn("Warning: model does not have a 'backbone' attribute, will train normally") + if isinstance(finetune_strategy, Callback) and not isinstance(finetune_strategy, BaseFinetuning): + raise Exception("finetune_strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning`` Callback") + self._resolve_callbacks(model, finetune_strategy, **callbacks_kwargs) return super().fit(model, train_dataloader, val_dataloaders, datamodule) + + def _resolve_callbacks(self, model, finetune_strategy, **callbacks_kwargs): + if sum((isinstance(c, BaseFinetuning) for c in [finetune_strategy])) > 1: + raise MisconfigurationException("Only 1 callback subclassing `BaseFinetuning` should be provided.") + # provided callbacks are higher priorities than model callbacks. + callbacks = self.callbacks + if isinstance(finetune_strategy, str): + callbacks_kwargs["finetune_strategy"] = finetune_strategy + else: + callbacks = self._merge_callbacks(callbacks, [finetune_strategy]) + self.callbacks = self._merge_callbacks(callbacks, model.configure_finetune_callbacks(**callbacks_kwargs)) + + @staticmethod + def _merge_callbacks(current_callbacks: List, new_callbacks: List) -> List: + if len(new_callbacks): + return current_callbacks + new_callbacks_types = set(type(c) for c in new_callbacks) + current_callbacks_types = set(type(c) for c in current_callbacks) + override_types = new_callbacks_types.intersection(current_callbacks_types) + new_callbacks.extend(c for c in current_callbacks if type(c) not in override_types) + return new_callbacks diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index ef8dac9aa8..3e64a4867c 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -18,10 +18,10 @@ model = ImageClassifier(num_classes=datamodule.num_classes) # 4. Create the trainer. Run once on data - trainer = flash.Trainer(max_epochs=1) + trainer = flash.Trainer(max_epochs=2) # 5. Train the model - trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1)) + trainer.finetune(model, datamodule=datamodule, finetune_strategy='freeze_unfreeze', unfreeze_at_epoch=1) # 6. Test the model trainer.test() diff --git a/flash_examples/finetuning/text_classification.py b/flash_examples/finetuning/text_classification.py index 2c0ff4b3f9..445cbb4328 100644 --- a/flash_examples/finetuning/text_classification.py +++ b/flash_examples/finetuning/text_classification.py @@ -24,7 +24,7 @@ trainer = flash.Trainer(max_epochs=1) # 5. 
Fine-tune the model - trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1)) + trainer.finetune(model, datamodule=datamodule, finetune_strategy='never_freeze') # 6. Test model trainer.test() diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py new file mode 100644 index 0000000000..de49b77f7b --- /dev/null +++ b/tests/core/test_finetuning.py @@ -0,0 +1,41 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch import nn +from torch.nn import functional as F + +from flash import ClassificationTask, Trainer +from flash.core.finetuning import NeverFreeze +from tests.core.test_model import DummyDataset + + +@pytest.mark.parametrize( + "finetune_strategy", + ['never_freeze', 'never_unfreeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat'] +) +def test_finetuning(tmpdir: str, finetune_strategy): + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) + train_dl = torch.utils.data.DataLoader(DummyDataset()) + val_dl = torch.utils.data.DataLoader(DummyDataset()) + task = ClassificationTask(model, F.nll_loss) + trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) + if finetune_strategy == "cls": + finetune_strategy = NeverFreeze() + if finetune_strategy == 'chocolat': + with pytest.raises(MisconfigurationException, match="finetune_strategy should be within"): + trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) + else: + trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) From 72421021bb8eae7af114a1fc429228c8c4bda8a3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 13:39:00 +0000 Subject: [PATCH 02/15] update changelog --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a589789b7..57ab7ffc1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,9 +11,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/pytorch-lightning/pull/9)) + +- Added `3 BaseFinetuning Callbacks` and `configure_finetuning_callbacks` ([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) + ### Changed ### Fixed -### Removed \ No newline at end of file +### Removed From fe83588acfc2d64d82087bd9a772cae01ef85149 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 14:19:26 +0000 Subject: [PATCH 03/15] update on comments --- flash/core/finetuning.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 11fdb0b407..cc84ff1c44 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -20,14 +20,14 @@ from torch import nn from torch.optim import Optimizer -_EXCLUDE_PARAMTERS = ["self", "args", "kwargs"] +_EXCLUDE_PARAMTERS = ("self", "args", "kwargs") class NeverFreeze(BaseFinetuning): pass -class NeverUnFreeze(BaseFinetuning): +class NeverUnfreeze(BaseFinetuning): def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True): self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names @@ -37,11 +37,11 @@ def freeze_before_training(self, pl_module: pl.LightningModule) -> None: for attr_name in self.attr_names: attr = getattr(pl_module, attr_name, None) if attr is None or not isinstance(attr, nn.Module): - MisconfigurationException("To use NeverUnFreeze your model must have a {attr} attribute") + MisconfigurationException("To use NeverUnfreeze your model must have a {attr} attribute") self.freeze(module=attr, train_bn=self.train_bn) -class FreezeUnFreeze(NeverUnFreeze): +class FreezeUnFreeze(NeverUnfreeze): def __init__( self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_at_epoch: int = 10 @@ -68,8 +68,6 @@ def finetunning_function( ) -# NOTE: copied from: -# https://github.com/PyTorchLightning/pytorch-lightning/blob/9d165f6f5655a44f1e5cd02ab36f21bc14e2a604/pl_examples/domain_templates/computer_vision_fine_tuning.py#L66 class MilestonesFinetuning(BaseFinetuning): def __init__(self, unfreeze_milestones: tuple = (5, 10), train_bn: bool = True, num_layers: int = 5): @@ -127,7 +125,7 @@ def instantiate_cls(cls, kwargs): _DEFAULTS_FINETUNE_STRATEGIES = { "never_freeze": NeverFreeze, - "never_unfreeze": NeverUnFreeze, + "never_unfreeze": NeverUnfreeze, "freeze_unfreeze": FreezeUnFreeze, "unfreeze_milestones": MilestonesFinetuning } From ed4cf423a3944857be2a5450b73176f22d3a814b Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 14:20:08 +0000 Subject: [PATCH 04/15] typo --- flash/core/finetuning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index cc84ff1c44..7ac2df6271 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -41,7 +41,7 @@ def freeze_before_training(self, pl_module: pl.LightningModule) -> None: self.freeze(module=attr, train_bn=self.train_bn) -class FreezeUnFreeze(NeverUnfreeze): +class FreezeUnfreeze(NeverUnfreeze): def __init__( self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_at_epoch: int = 10 @@ -126,7 +126,7 @@ def instantiate_cls(cls, kwargs): _DEFAULTS_FINETUNE_STRATEGIES = { "never_freeze": NeverFreeze, "never_unfreeze": NeverUnfreeze, - "freeze_unfreeze": FreezeUnFreeze, + "freeze_unfreeze": FreezeUnfreeze, "unfreeze_milestones": MilestonesFinetuning } From 
80d2b96279a7f7ed940abc72aef5ff7ce7191027 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 16:05:10 +0000 Subject: [PATCH 05/15] update on comments --- CHANGELOG.md | 4 +- flash/core/finetuning.py | 83 +++++++++---------- flash/core/model.py | 4 - flash/core/trainer.py | 28 ++++--- .../finetuning/image_classification.py | 2 +- tests/core/test_finetuning.py | 10 ++- 6 files changed, 62 insertions(+), 69 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 57ab7ffc1f..5d0695869b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,10 +12,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/pytorch-lightning/pull/9)) -- Added `3 BaseFinetuning Callbacks` and `configure_finetuning_callbacks` ([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) +- Added `configure_finetuning_callbacks` ([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) + ### Changed + ### Fixed diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 7ac2df6271..9f9e55e09b 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -23,31 +23,40 @@ _EXCLUDE_PARAMTERS = ("self", "args", "kwargs") -class NeverFreeze(BaseFinetuning): +class NoFreeze(BaseFinetuning): pass -class NeverUnfreeze(BaseFinetuning): +def freeze_using_attr_names(pl_module, attr_names: List[str], train_bn: bool = True): + for attr_name in attr_names: + attr = getattr(pl_module, attr_name, None) + if attr is None or not isinstance(attr, nn.Module): + MisconfigurationException("To use Freeze your model must have a {attr} attribute") + BaseFinetuning.freeze(module=attr, train_bn=train_bn) + + +class FlashBaseBaseFinetuning(BaseFinetuning): def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True): self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names self.train_bn = train_bn def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - for attr_name in self.attr_names: - attr = getattr(pl_module, attr_name, None) - if attr is None or not isinstance(attr, nn.Module): - MisconfigurationException("To use NeverUnfreeze your model must have a {attr} attribute") - self.freeze(module=attr, train_bn=self.train_bn) + freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) -class FreezeUnfreeze(NeverUnfreeze): +class Freeze(FlashBaseBaseFinetuning): + pass - def __init__( - self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_at_epoch: int = 10 - ): + +class FreezeUnfreeze(FlashBaseBaseFinetuning): + + def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_epoch: int = 10): super().__init__(attr_names, train_bn) - self.unfreeze_at_epoch = unfreeze_at_epoch + self.unfreeze_epoch = unfreeze_epoch + + def freeze_before_training(self, pl_module: pl.LightningModule) -> None: + freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) def finetunning_function( self, @@ -56,7 +65,7 @@ def finetunning_function( optimizer: Optimizer, opt_idx: int, ) -> None: - if epoch == self.unfreeze_at_epoch: + if epoch == self.unfreeze_epoch: modules = [] for attr_name in self.attr_names: modules.append(getattr(pl_module, attr_name)) @@ -68,21 +77,22 @@ def finetunning_function( ) -class MilestonesFinetuning(BaseFinetuning): +class MilestonesFinetuning(FlashBaseBaseFinetuning): - def __init__(self, unfreeze_milestones: tuple = 
(5, 10), train_bn: bool = True, num_layers: int = 5): + def __init__( + self, + attr_names: Union[str, List[str]] = "backbone", + train_bn: bool = True, + unfreeze_milestones: tuple = (5, 10), + num_layers: int = 5 + ): self.unfreeze_milestones = unfreeze_milestones - self.train_bn = train_bn self.num_layers = num_layers + super().__init__(attr_names, train_bn) + def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - # TODO: might need some config to say which attribute is model - # maybe something like: - # self.freeze(module=pl_module.getattr(self.feature_attr), train_bn=self.train_bn) - # where self.feature_attr can be "backbone" or "feature_extractor", etc. - # (configured in init) - assert hasattr(pl_module, "backbone"), "To use MilestonesFinetuning your model must have a backbone attribute" - self.freeze(module=pl_module.backbone, train_bn=self.train_bn) + freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) def finetunning_function( self, @@ -111,41 +121,22 @@ def finetunning_function( ) -def instantiate_cls(cls, kwargs): - parameters = list(inspect.signature(cls.__init__).parameters.keys()) - parameters = [p for p in parameters if p not in _EXCLUDE_PARAMTERS] - cls_kwargs = {} - for p in parameters: - if p in kwargs: - cls_kwargs[p] = kwargs.pop(p) - if len(kwargs) > 0: - raise MisconfigurationException(f"Available parameters are: {parameters}. Found {kwargs} left") - return cls(**cls_kwargs) - - _DEFAULTS_FINETUNE_STRATEGIES = { - "never_freeze": NeverFreeze, - "never_unfreeze": NeverUnfreeze, + "no_freeze": NoFreeze, + "freeze": Freeze, "freeze_unfreeze": FreezeUnfreeze, "unfreeze_milestones": MilestonesFinetuning } -def instantiate_default_finetuning_callbacks(kwargs): - finetune_strategy = kwargs.pop("finetune_strategy", None) +def instantiate_default_finetuning_callbacks(finetune_strategy): if isinstance(finetune_strategy, str): finetune_strategy = finetune_strategy.lower() if finetune_strategy in _DEFAULTS_FINETUNE_STRATEGIES: - return [instantiate_cls(_DEFAULTS_FINETUNE_STRATEGIES[finetune_strategy], kwargs)] + return [_DEFAULTS_FINETUNE_STRATEGIES[finetune_strategy]()] else: - msg = "\n Extra arguments can be: \n" - for n, cls in _DEFAULTS_FINETUNE_STRATEGIES.items(): - parameters = list(inspect.signature(cls.__init__).parameters.keys()) - parameters = [p for p in parameters if p not in _EXCLUDE_PARAMTERS] - msg += f"{n}: {parameters} \n" raise MisconfigurationException( f"finetune_strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}" - f"{msg}" f". 
Found {finetune_strategy}" ) return [] diff --git a/flash/core/model.py b/flash/core/model.py index 0bc675dd14..40b7f28be7 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -18,7 +18,6 @@ from torch import nn from flash.core.data import DataModule, DataPipeline -from flash.core.finetuning import instantiate_default_finetuning_callbacks from flash.core.utils import get_callable_dict @@ -151,6 +150,3 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: checkpoint["pipeline"] = self.data_pipeline - - def configure_finetune_callbacks(self, **kwargs) -> List: - return instantiate_default_finetuning_callbacks(kwargs) diff --git a/flash/core/trainer.py b/flash/core/trainer.py index ebfd1a2042..9e7fecb57c 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -19,6 +19,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import DataLoader +from flash.core.finetuning import instantiate_default_finetuning_callbacks from flash.core.model import Task @@ -57,7 +58,6 @@ def finetune( val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[pl.LightningDataModule] = None, finetune_strategy: Optional[Union[str, Callback]] = None, - **callbacks_kwargs, ): r""" Runs the full optimization routine. Same as pytorch_lightning.Trainer().fit(), but unfreezes layers @@ -76,27 +76,29 @@ def finetune( finetune_strategy: Should either be a string or a finetuning callback subclassing ``pytorch_lightning.callbacks.BaseFinetuning``. - - callbacks_kwargs: Those arguments will be provided to `model.configure_finetune_callbacks` - to instantiante your own finetuning callbacks. + Currently default strategies can be create with strings such as: + * ``no_freeze``, + * ``freeze`` + * ``freeze_unfreeze`` + * ``unfreeze_milestones`` """ - if isinstance(finetune_strategy, Callback) and not isinstance(finetune_strategy, BaseFinetuning): - raise Exception("finetune_strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning`` Callback") + if not isinstance(finetune_strategy, (BaseFinetuning, str)): + raise MisconfigurationException( + "finetune_strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning`` Callback or a str" + ) - self._resolve_callbacks(model, finetune_strategy, **callbacks_kwargs) + self._resolve_callbacks(finetune_strategy) return super().fit(model, train_dataloader, val_dataloaders, datamodule) - def _resolve_callbacks(self, model, finetune_strategy, **callbacks_kwargs): + def _resolve_callbacks(self, finetune_strategy): if sum((isinstance(c, BaseFinetuning) for c in [finetune_strategy])) > 1: raise MisconfigurationException("Only 1 callback subclassing `BaseFinetuning` should be provided.") - # provided callbacks are higher priorities than model callbacks. 
+ # todo: change to ``configure_callbacks`` when callbacks = self.callbacks if isinstance(finetune_strategy, str): - callbacks_kwargs["finetune_strategy"] = finetune_strategy - else: - callbacks = self._merge_callbacks(callbacks, [finetune_strategy]) - self.callbacks = self._merge_callbacks(callbacks, model.configure_finetune_callbacks(**callbacks_kwargs)) + finetune_strategy = instantiate_default_finetuning_callbacks(finetune_strategy) + self.callbacks = self._merge_callbacks(callbacks, [finetune_strategy]) @staticmethod def _merge_callbacks(current_callbacks: List, new_callbacks: List) -> List: diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index 3e64a4867c..275d2d58ea 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -21,7 +21,7 @@ trainer = flash.Trainer(max_epochs=2) # 5. Train the model - trainer.finetune(model, datamodule=datamodule, finetune_strategy='freeze_unfreeze', unfreeze_at_epoch=1) + trainer.finetune(model, datamodule=datamodule, finetune_strategy='freeze_unfreeze') # 6. Test the model trainer.test() diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py index de49b77f7b..91667985a6 100644 --- a/tests/core/test_finetuning.py +++ b/tests/core/test_finetuning.py @@ -18,13 +18,12 @@ from torch.nn import functional as F from flash import ClassificationTask, Trainer -from flash.core.finetuning import NeverFreeze +from flash.core.finetuning import NoFreeze from tests.core.test_model import DummyDataset @pytest.mark.parametrize( - "finetune_strategy", - ['never_freeze', 'never_unfreeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat'] + "finetune_strategy", ['no_freeze', 'freeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat'] ) def test_finetuning(tmpdir: str, finetune_strategy): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) @@ -33,9 +32,12 @@ def test_finetuning(tmpdir: str, finetune_strategy): task = ClassificationTask(model, F.nll_loss) trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) if finetune_strategy == "cls": - finetune_strategy = NeverFreeze() + finetune_strategy = NoFreeze() if finetune_strategy == 'chocolat': with pytest.raises(MisconfigurationException, match="finetune_strategy should be within"): trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) + elif finetune_strategy is None: + with pytest.raises(MisconfigurationException, match="finetune_strategy should"): + trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) else: trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) From d17df45b12b7e55ea216c7aff0dc7b9312c48419 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 16:09:51 +0000 Subject: [PATCH 06/15] update on comments --- flash/core/finetuning.py | 14 ++++++------ flash/core/trainer.py | 22 +++++++++---------- .../finetuning/image_classification.py | 3 ++- .../finetuning/text_classification.py | 2 +- tests/core/test_finetuning.py | 22 +++++++++---------- 5 files changed, 32 insertions(+), 31 deletions(-) diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 9f9e55e09b..f49356d8e8 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -129,14 +129,14 @@ def finetunning_function( } -def instantiate_default_finetuning_callbacks(finetune_strategy): - if isinstance(finetune_strategy, str): - 
finetune_strategy = finetune_strategy.lower() - if finetune_strategy in _DEFAULTS_FINETUNE_STRATEGIES: - return [_DEFAULTS_FINETUNE_STRATEGIES[finetune_strategy]()] +def instantiate_default_finetuning_callbacks(strategy): + if isinstance(strategy, str): + strategy = strategy.lower() + if strategy in _DEFAULTS_FINETUNE_STRATEGIES: + return [_DEFAULTS_FINETUNE_STRATEGIES[strategy]()] else: raise MisconfigurationException( - f"finetune_strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}" - f". Found {finetune_strategy}" + f"strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}" + f". Found {strategy}" ) return [] diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 9e7fecb57c..9a1491f537 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -57,7 +57,7 @@ def finetune( train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[pl.LightningDataModule] = None, - finetune_strategy: Optional[Union[str, Callback]] = None, + strategy: Optional[Union[str, BaseFinetuning]] = None, ): r""" Runs the full optimization routine. Same as pytorch_lightning.Trainer().fit(), but unfreezes layers @@ -74,7 +74,7 @@ def finetune( val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped - finetune_strategy: Should either be a string or a finetuning callback subclassing + strategy: Should either be a string or a finetuning callback subclassing ``pytorch_lightning.callbacks.BaseFinetuning``. Currently default strategies can be create with strings such as: * ``no_freeze``, @@ -83,22 +83,22 @@ def finetune( * ``unfreeze_milestones`` """ - if not isinstance(finetune_strategy, (BaseFinetuning, str)): + if not isinstance(strategy, (BaseFinetuning, str)): raise MisconfigurationException( - "finetune_strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning`` Callback or a str" + "strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning`` Callback or a str" ) - self._resolve_callbacks(finetune_strategy) + self._resolve_callbacks(strategy) return super().fit(model, train_dataloader, val_dataloaders, datamodule) - def _resolve_callbacks(self, finetune_strategy): - if sum((isinstance(c, BaseFinetuning) for c in [finetune_strategy])) > 1: + def _resolve_callbacks(self, strategy): + if sum((isinstance(c, BaseFinetuning) for c in [strategy])) > 1: raise MisconfigurationException("Only 1 callback subclassing `BaseFinetuning` should be provided.") - # todo: change to ``configure_callbacks`` when + # todo: change to ``configure_callbacks`` when merged to Lightning. 
callbacks = self.callbacks - if isinstance(finetune_strategy, str): - finetune_strategy = instantiate_default_finetuning_callbacks(finetune_strategy) - self.callbacks = self._merge_callbacks(callbacks, [finetune_strategy]) + if isinstance(strategy, str): + strategy = instantiate_default_finetuning_callbacks(strategy) + self.callbacks = self._merge_callbacks(callbacks, [strategy]) @staticmethod def _merge_callbacks(current_callbacks: List, new_callbacks: List) -> List: diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index 275d2d58ea..7643e85131 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -1,5 +1,6 @@ import flash from flash.core.data import download_data +from flash.core.finetuning import FreezeUnfreeze from flash.vision import ImageClassificationData, ImageClassifier if __name__ == "__main__": @@ -21,7 +22,7 @@ trainer = flash.Trainer(max_epochs=2) # 5. Train the model - trainer.finetune(model, datamodule=datamodule, finetune_strategy='freeze_unfreeze') + trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1)) # 6. Test the model trainer.test() diff --git a/flash_examples/finetuning/text_classification.py b/flash_examples/finetuning/text_classification.py index 445cbb4328..dd07f46bf9 100644 --- a/flash_examples/finetuning/text_classification.py +++ b/flash_examples/finetuning/text_classification.py @@ -24,7 +24,7 @@ trainer = flash.Trainer(max_epochs=1) # 5. Fine-tune the model - trainer.finetune(model, datamodule=datamodule, finetune_strategy='never_freeze') + trainer.finetune(model, datamodule=datamodule, strategy='freeze') # 6. Test model trainer.test() diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py index 91667985a6..1e46385584 100644 --- a/tests/core/test_finetuning.py +++ b/tests/core/test_finetuning.py @@ -23,21 +23,21 @@ @pytest.mark.parametrize( - "finetune_strategy", ['no_freeze', 'freeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat'] + "strategy", ['no_freeze', 'freeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat'] ) -def test_finetuning(tmpdir: str, finetune_strategy): +def test_finetuning(tmpdir: str, strategy): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) train_dl = torch.utils.data.DataLoader(DummyDataset()) val_dl = torch.utils.data.DataLoader(DummyDataset()) task = ClassificationTask(model, F.nll_loss) trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) - if finetune_strategy == "cls": - finetune_strategy = NoFreeze() - if finetune_strategy == 'chocolat': - with pytest.raises(MisconfigurationException, match="finetune_strategy should be within"): - trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) - elif finetune_strategy is None: - with pytest.raises(MisconfigurationException, match="finetune_strategy should"): - trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) + if strategy == "cls": + strategy = NoFreeze() + if strategy == 'chocolat': + with pytest.raises(MisconfigurationException, match="strategy should be within"): + trainer.finetune(task, train_dl, val_dl, strategy=strategy) + elif strategy is None: + with pytest.raises(MisconfigurationException, match="strategy should"): + trainer.finetune(task, train_dl, val_dl, strategy=strategy) else: - trainer.finetune(task, train_dl, val_dl, finetune_strategy=finetune_strategy) + 
trainer.finetune(task, train_dl, val_dl, strategy=strategy) From b4bffaf5e0a1b59b4c675a974d23e003ba2e351d Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 16:16:21 +0000 Subject: [PATCH 07/15] update finetuning --- CHANGELOG.md | 2 +- flash/core/finetuning.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d0695869b..feaf4c5f64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/pytorch-lightning/pull/9)) -- Added `configure_finetuning_callbacks` ([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) +- Added `strategy` to `trainer.finetune` with `NoFreeze`, `Freeze`, `FreezeUnfreeze`, `UnfreezeMilestones` ([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) ### Changed diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index f49356d8e8..e3ed6701ae 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -77,7 +77,7 @@ def finetunning_function( ) -class MilestonesFinetuning(FlashBaseBaseFinetuning): +class UnfreezeMilestones(FlashBaseBaseFinetuning): def __init__( self, @@ -125,7 +125,7 @@ def finetunning_function( "no_freeze": NoFreeze, "freeze": Freeze, "freeze_unfreeze": FreezeUnfreeze, - "unfreeze_milestones": MilestonesFinetuning + "unfreeze_milestones": UnfreezeMilestones } From 2b955f1fda5a58a2fe627a0f2763e1f7c8237e6e Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 16:17:08 +0000 Subject: [PATCH 08/15] typo --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index feaf4c5f64..dd92c849c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/pytorch-lightning/pull/9)) -- Added `strategy` to `trainer.finetune` with `NoFreeze`, `Freeze`, `FreezeUnfreeze`, `UnfreezeMilestones` ([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) +- Added `strategy` to `trainer.finetune` with `NoFreeze`, `Freeze`, `FreezeUnfreeze`, `UnfreezeMilestones` Callbacks([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) ### Changed From cb6a905de6b8f9dbdcb279e6396c9556c853756c Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 16:19:05 +0000 Subject: [PATCH 09/15] update --- flash/core/finetuning.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index e3ed6701ae..ff569a7835 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -27,14 +27,6 @@ class NoFreeze(BaseFinetuning): pass -def freeze_using_attr_names(pl_module, attr_names: List[str], train_bn: bool = True): - for attr_name in attr_names: - attr = getattr(pl_module, attr_name, None) - if attr is None or not isinstance(attr, nn.Module): - MisconfigurationException("To use Freeze your model must have a {attr} attribute") - BaseFinetuning.freeze(module=attr, train_bn=train_bn) - - class FlashBaseBaseFinetuning(BaseFinetuning): def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True): @@ -42,7 +34,15 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo self.train_bn = train_bn def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) + self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) + + @staticmethod + def freeze_using_attr_names(pl_module, attr_names: List[str], train_bn: bool = True): + for attr_name in attr_names: + attr = getattr(pl_module, attr_name, None) + if attr is None or not isinstance(attr, nn.Module): + MisconfigurationException("To use Freeze your model must have a {attr} attribute") + BaseFinetuning.freeze(module=attr, train_bn=train_bn) class Freeze(FlashBaseBaseFinetuning): @@ -56,7 +56,7 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo self.unfreeze_epoch = unfreeze_epoch def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) + self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) def finetunning_function( self, @@ -92,7 +92,7 @@ def __init__( super().__init__(attr_names, train_bn) def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) + self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) def finetunning_function( self, From 6ab9daf7e7b83872fb0eb8828d3279ca49285313 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 17:01:40 +0000 Subject: [PATCH 10/15] update --- CHANGELOG.md | 4 +-- flash/core/finetuning.py | 44 ++++++++++++++++----------- flash/core/model.py | 5 +++- flash/core/trainer.py | 56 +++++++++++++++++++++++------------ tests/core/test_finetuning.py | 7 ++--- tests/core/test_trainer.py | 2 +- 6 files changed, 72 insertions(+), 46 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dd92c849c0..87b983c22e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,10 +9,10 
@@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/pytorch-lightning/pull/9)) +- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/lightning-flash/pull/9)) -- Added `strategy` to `trainer.finetune` with `NoFreeze`, `Freeze`, `FreezeUnfreeze`, `UnfreezeMilestones` Callbacks([#39](https://github.com/PyTorchLightning/pytorch-lightning/pull/39)) +- Added `strategy` to `trainer.finetune` with `NoFreeze`, `Freeze`, `FreezeUnfreeze`, `UnfreezeMilestones` Callbacks([#39](https://github.com/PyTorchLightning/lightning-flash/pull/39)) ### Changed diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index ff569a7835..5e6298b03b 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -16,6 +16,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import BaseFinetuning +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.optim import Optimizer @@ -23,13 +24,22 @@ _EXCLUDE_PARAMTERS = ("self", "args", "kwargs") -class NoFreeze(BaseFinetuning): - pass +class FlashBaseFinetuning(BaseFinetuning): + def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True): + r""" -class FlashBaseBaseFinetuning(BaseFinetuning): + FlashBaseFinetuning can be used to create a custom Flash Finetuning Callback. + + Override ``finetunning_function`` to put your unfreeze logic. + + Args: + attr_names: Name(s) of the module attributes of the model to be frozen. + + train_bn: Wether to train Batch Norm layer + + """ - def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True): self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names self.train_bn = train_bn @@ -41,15 +51,11 @@ def freeze_using_attr_names(pl_module, attr_names: List[str], train_bn: bool = T for attr_name in attr_names: attr = getattr(pl_module, attr_name, None) if attr is None or not isinstance(attr, nn.Module): - MisconfigurationException("To use Freeze your model must have a {attr} attribute") + MisconfigurationException(f"Your model must have a {attr} attribute") BaseFinetuning.freeze(module=attr, train_bn=train_bn) -class Freeze(FlashBaseBaseFinetuning): - pass - - -class FreezeUnfreeze(FlashBaseBaseFinetuning): +class FreezeUnfreeze(FlashBaseFinetuning): def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_epoch: int = 10): super().__init__(attr_names, train_bn) @@ -77,7 +83,7 @@ def finetunning_function( ) -class UnfreezeMilestones(FlashBaseBaseFinetuning): +class UnfreezeMilestones(FlashBaseFinetuning): def __init__( self, @@ -122,21 +128,23 @@ def finetunning_function( _DEFAULTS_FINETUNE_STRATEGIES = { - "no_freeze": NoFreeze, - "freeze": Freeze, + "no_freeze": BaseFinetuning, + "freeze": FlashBaseFinetuning, "freeze_unfreeze": FreezeUnfreeze, "unfreeze_milestones": UnfreezeMilestones } def instantiate_default_finetuning_callbacks(strategy): + if strategy is None: + strategy = "no_freeze" + rank_zero_warn("strategy is None. 
Setting strategy to `no_freeze` by default.", UserWarning) if isinstance(strategy, str): strategy = strategy.lower() if strategy in _DEFAULTS_FINETUNE_STRATEGIES: return [_DEFAULTS_FINETUNE_STRATEGIES[strategy]()] - else: - raise MisconfigurationException( - f"strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}" - f". Found {strategy}" - ) + raise MisconfigurationException( + f"strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}" + f". Found {strategy}" + ) return [] diff --git a/flash/core/model.py b/flash/core/model.py index 40b7f28be7..3607878ac8 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union import pytorch_lightning as pl import torch @@ -150,3 +150,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: checkpoint["pipeline"] = self.data_pipeline + + def configure_finetune_callback(self): + return [] diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 9a1491f537..75c74f6190 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -15,12 +15,13 @@ from typing import List, Optional, Union import pytorch_lightning as pl -from pytorch_lightning.callbacks import BaseFinetuning, Callback +from pytorch_lightning.callbacks import BaseFinetuning +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import DataLoader -from flash.core.finetuning import instantiate_default_finetuning_callbacks -from flash.core.model import Task +from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, instantiate_default_finetuning_callbacks class Trainer(pl.Trainer): @@ -53,7 +54,7 @@ def fit( def finetune( self, - model: Task, + model: LightningModule, train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[pl.LightningDataModule] = None, @@ -76,29 +77,46 @@ def finetune( strategy: Should either be a string or a finetuning callback subclassing ``pytorch_lightning.callbacks.BaseFinetuning``. 
- Currently default strategies can be create with strings such as: + Currently, default strategies can be enabled with these strings: * ``no_freeze``, - * ``freeze`` - * ``freeze_unfreeze`` + * ``freeze``, + * ``freeze_unfreeze``, * ``unfreeze_milestones`` """ - if not isinstance(strategy, (BaseFinetuning, str)): + self._resolve_callbacks(model, strategy) + return super().fit(model, train_dataloader, val_dataloaders, datamodule) + + def _resolve_callbacks(self, model, strategy): + if strategy is not None and not isinstance(strategy, (str, BaseFinetuning)): raise MisconfigurationException( - "strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning`` Callback or a str" + "strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning``" + f"callback or a str within {list(_DEFAULTS_FINETUNE_STRATEGIES.keys())}" ) - self._resolve_callbacks(strategy) - return super().fit(model, train_dataloader, val_dataloaders, datamodule) - - def _resolve_callbacks(self, strategy): - if sum((isinstance(c, BaseFinetuning) for c in [strategy])) > 1: - raise MisconfigurationException("Only 1 callback subclassing `BaseFinetuning` should be provided.") - # todo: change to ``configure_callbacks`` when merged to Lightning. callbacks = self.callbacks - if isinstance(strategy, str): - strategy = instantiate_default_finetuning_callbacks(strategy) - self.callbacks = self._merge_callbacks(callbacks, [strategy]) + + if isinstance(strategy, BaseFinetuning): + callback = strategy + else: + # todo: change to ``configure_callbacks`` when merged to Lightning. + model_callback = model.configure_finetune_callback() + if len(model_callback) > 1: + raise MisconfigurationException( + f"{model} configure_finetune_callback should create a list with only 1 callback" + ) + if len(model_callback) == 1: + if strategy is not None: + rank_zero_warn( + "The model contains a default finetune callback. " + f"The provided {strategy} will be overriden. " + "HINT: Provide a `BaseFinetuning callback as strategy to be prioritized. 
", UserWarning + ) + callback = [model_callback] + else: + callback = instantiate_default_finetuning_callbacks(strategy) + + self.callbacks = self._merge_callbacks(callbacks, [callback]) @staticmethod def _merge_callbacks(current_callbacks: List, new_callbacks: List) -> List: diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py index 1e46385584..e4062eac59 100644 --- a/tests/core/test_finetuning.py +++ b/tests/core/test_finetuning.py @@ -18,7 +18,7 @@ from torch.nn import functional as F from flash import ClassificationTask, Trainer -from flash.core.finetuning import NoFreeze +from flash.core.finetuning import FlashBaseFinetuning from tests.core.test_model import DummyDataset @@ -32,12 +32,9 @@ def test_finetuning(tmpdir: str, strategy): task = ClassificationTask(model, F.nll_loss) trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) if strategy == "cls": - strategy = NoFreeze() + strategy = FlashBaseFinetuning() if strategy == 'chocolat': with pytest.raises(MisconfigurationException, match="strategy should be within"): trainer.finetune(task, train_dl, val_dl, strategy=strategy) - elif strategy is None: - with pytest.raises(MisconfigurationException, match="strategy should"): - trainer.finetune(task, train_dl, val_dl, strategy=strategy) else: trainer.finetune(task, train_dl, val_dl, strategy=strategy) diff --git a/tests/core/test_trainer.py b/tests/core/test_trainer.py index f872de5e55..e0f63f19f2 100644 --- a/tests/core/test_trainer.py +++ b/tests/core/test_trainer.py @@ -51,5 +51,5 @@ def test_task_finetune(tmpdir: str): val_dl = torch.utils.data.DataLoader(DummyDataset()) task = ClassificationTask(model, F.nll_loss) trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) - result = trainer.finetune(task, train_dl, val_dl, unfreeze_milestones=(0, 0)) + result = trainer.finetune(task, train_dl, val_dl) assert result From 583ce0cb3ca01d0967a55294e421c2ea02765728 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 17:15:34 +0000 Subject: [PATCH 11/15] update notebooks --- flash/core/trainer.py | 23 +++++---- .../finetuning/image_classification.ipynb | 47 ++++++++--------- .../finetuning/text_classification.ipynb | 50 ++++++++++--------- 3 files changed, 64 insertions(+), 56 deletions(-) diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 75c74f6190..12b801612f 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -88,14 +88,15 @@ def finetune( return super().fit(model, train_dataloader, val_dataloaders, datamodule) def _resolve_callbacks(self, model, strategy): + """ + This function is used to select the `BaseFinetuning` to be used for finetuning. + """ if strategy is not None and not isinstance(strategy, (str, BaseFinetuning)): raise MisconfigurationException( "strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning``" f"callback or a str within {list(_DEFAULTS_FINETUNE_STRATEGIES.keys())}" ) - callbacks = self.callbacks - if isinstance(strategy, BaseFinetuning): callback = strategy else: @@ -110,20 +111,24 @@ def _resolve_callbacks(self, model, strategy): rank_zero_warn( "The model contains a default finetune callback. " f"The provided {strategy} will be overriden. " - "HINT: Provide a `BaseFinetuning callback as strategy to be prioritized. ", UserWarning + "HINT: Provide a `BaseFinetuning` callback as strategy to make it prioritized. 
", UserWarning ) callback = [model_callback] else: callback = instantiate_default_finetuning_callbacks(strategy) - self.callbacks = self._merge_callbacks(callbacks, [callback]) + self.callbacks = self._merge_callbacks(self.callbacks, [callback]) @staticmethod - def _merge_callbacks(current_callbacks: List, new_callbacks: List) -> List: + def _merge_callbacks(old_callbacks: List, new_callbacks: List) -> List: + """ + This function keeps only 1 instance of each callback type, + extending new_callbacks with old_callbacks + """ if len(new_callbacks): - return current_callbacks + return old_callbacks new_callbacks_types = set(type(c) for c in new_callbacks) - current_callbacks_types = set(type(c) for c in current_callbacks) - override_types = new_callbacks_types.intersection(current_callbacks_types) - new_callbacks.extend(c for c in current_callbacks if type(c) not in override_types) + old_callbacks_types = set(type(c) for c in old_callbacks) + override_types = new_callbacks_types.intersection(old_callbacks_types) + new_callbacks.extend(c for c in old_callbacks if type(c) not in override_types) return new_callbacks diff --git a/flash_notebooks/finetuning/image_classification.ipynb b/flash_notebooks/finetuning/image_classification.ipynb index 1ee71ec15e..1959e99df7 100644 --- a/flash_notebooks/finetuning/image_classification.ipynb +++ b/flash_notebooks/finetuning/image_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "dominican-savings", + "id": "thousand-manufacturer", "metadata": {}, "source": [ "\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "threaded-coffee", + "id": "smoking-probe", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by finetuning an ImageClassifier on [Hymenoptera Dataset](https://www.kaggle.com/ajayrana/hymenoptera-data) containing ants and bees images.\n", @@ -27,7 +27,9 @@ " \n", " - 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", " \n", - " - 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. At training start, the backbone will be frozen, meaning its parameters won't be updated. Only the model head will be trained to properly distinguish ants and bees. On reaching first finetuning milestone, the backbone latest layers will be unfrozen and start to be trained. On reaching the second finetuning milestone, the remaining layers of the backend will be unfrozen and the entire model will be trained. In Flash, `trainer.finetune(..., unfreeze_milestones=(first_milestone, second_milestone))`.\n", + " - 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. 
If a one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `from pytorch_lightning.callbacks import BaseFinetuning`.\n", + " \n", + " \n", "\n", " \n", "\n", @@ -41,7 +43,7 @@ { "cell_type": "code", "execution_count": null, - "id": "handmade-timing", + "id": "thermal-fraction", "metadata": {}, "outputs": [], "source": [ @@ -52,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "through-edwards", + "id": "cognitive-haven", "metadata": {}, "outputs": [], "source": [ @@ -63,7 +65,7 @@ }, { "cell_type": "markdown", - "id": "hybrid-adapter", + "id": "afraid-straight", "metadata": {}, "source": [ "## 1. Download data\n", @@ -73,7 +75,7 @@ { "cell_type": "code", "execution_count": null, - "id": "amateur-disposal", + "id": "advisory-narrow", "metadata": {}, "outputs": [], "source": [ @@ -82,7 +84,7 @@ }, { "cell_type": "markdown", - "id": "front-metallic", + "id": "trying-group", "metadata": {}, "source": [ "
### 2. Load the data
\n", @@ -105,7 +107,7 @@ { "cell_type": "code", "execution_count": null, - "id": "hazardous-means", + "id": "stuck-composition", "metadata": {}, "outputs": [], "source": [ @@ -118,7 +120,7 @@ }, { "cell_type": "markdown", - "id": "defined-mouse", + "id": "irish-scenario", "metadata": {}, "source": [ "### 3. Build the model\n", @@ -130,7 +132,7 @@ { "cell_type": "code", "execution_count": null, - "id": "internal-playback", + "id": "opening-nomination", "metadata": {}, "outputs": [], "source": [ @@ -139,7 +141,7 @@ }, { "cell_type": "markdown", - "id": "million-tower", + "id": "breathing-element", "metadata": {}, "source": [ "### 4. Create the trainer. Run once on data\n", @@ -156,7 +158,7 @@ { "cell_type": "code", "execution_count": null, - "id": "centered-paris", + "id": "earlier-jordan", "metadata": {}, "outputs": [], "source": [ @@ -165,26 +167,25 @@ }, { "cell_type": "markdown", - "id": "special-fence", + "id": "extreme-scene", "metadata": {}, "source": [ - "### 5. Finetune the model\n", - "The `unfreeze_milestones=(0, 1)` will unfreeze the latest layers of the backbone on epoch `0` and the rest of the backbone on epoch `1`. " + "### 5. Finetune the model" ] }, { "cell_type": "code", "execution_count": null, - "id": "local-taylor", + "id": "tired-underground", "metadata": {}, "outputs": [], "source": [ - "trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))" + "trainer.finetune(model, datamodule=datamodule, strategy=\"freeze_unfreeze\")" ] }, { "cell_type": "markdown", - "id": "municipal-kentucky", + "id": "smooth-european", "metadata": {}, "source": [ "### 6. Test the model" @@ -193,7 +194,7 @@ { "cell_type": "code", "execution_count": null, - "id": "simplified-bundle", + "id": "sexual-tender", "metadata": {}, "outputs": [], "source": [ @@ -202,7 +203,7 @@ }, { "cell_type": "markdown", - "id": "familiar-territory", + "id": "athletic-nutrition", "metadata": {}, "source": [ "### 7. Save it!" @@ -211,7 +212,7 @@ { "cell_type": "code", "execution_count": null, - "id": "injured-mineral", + "id": "pleasant-canon", "metadata": {}, "outputs": [], "source": [ @@ -220,7 +221,7 @@ }, { "cell_type": "markdown", - "id": "tested-experience", + "id": "incident-basket", "metadata": {}, "source": [ "\n", diff --git a/flash_notebooks/finetuning/text_classification.ipynb b/flash_notebooks/finetuning/text_classification.ipynb index 8079575e0e..28cb748793 100644 --- a/flash_notebooks/finetuning/text_classification.ipynb +++ b/flash_notebooks/finetuning/text_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "digital-quilt", + "id": "prerequisite-straight", "metadata": {}, "source": [ "
\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "empty-request", + "id": "coastal-bible", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by finetunig a TextClassifier on [IMDB Dataset](https://www.imdb.com/interfaces/).\n", @@ -31,7 +31,7 @@ "- 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head, will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", " \n", "\n", - "- 4. Train the target model on a target dataset, such as IMDB Dataset to learn to predict the associated sentiment of movie reviews. At training start, the backbone will be frozen, meaning its parameters won't be updated. Only the model head will be trained to between negative and positive reviews. On reaching first finetuning milestone, the backbone latest layers will be unfrozen and start to be trained. On reaching the second finetuning milestone, the remaining layers of the backend will be unfrozen and the entire model will be trained. In Flash, `unfreeze_milestones` controls those milestone and be used as such `trainer.finetune(..., unfreeze_milestones=(first_milestone, second_milestone))`.\n", + "- 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If a one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `from pytorch_lightning.callbacks import BaseFinetuning`.\n", "\n", "---\n", " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", @@ -42,7 +42,7 @@ }, { "cell_type": "markdown", - "id": "another-might", + "id": "sharp-techno", "metadata": {}, "source": [ "### Setup \n", @@ -52,7 +52,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ideal-summary", + "id": "posted-blair", "metadata": {}, "outputs": [], "source": [ @@ -63,7 +63,7 @@ { "cell_type": "code", "execution_count": null, - "id": "straight-commission", + "id": "double-swedish", "metadata": {}, "outputs": [], "source": [ @@ -74,7 +74,7 @@ }, { "cell_type": "markdown", - "id": "classical-snake", + "id": "outside-garlic", "metadata": {}, "source": [ "### 1. Download the data\n", @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "lined-standing", + "id": "tired-lender", "metadata": {}, "outputs": [], "source": [ @@ -93,7 +93,7 @@ }, { "cell_type": "markdown", - "id": "endangered-heavy", + "id": "daily-marijuana", "metadata": {}, "source": [ "
### 2. Load the data
\n", @@ -105,7 +105,7 @@ { "cell_type": "code", "execution_count": null, - "id": "posted-chosen", + "id": "standing-commons", "metadata": {}, "outputs": [], "source": [ @@ -121,7 +121,7 @@ }, { "cell_type": "markdown", - "id": "cognitive-compact", + "id": "fantastic-mortality", "metadata": { "jupyter": { "outputs_hidden": true @@ -138,7 +138,7 @@ { "cell_type": "code", "execution_count": null, - "id": "underlying-liberia", + "id": "prompt-azerbaijan", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "democratic-interaction", + "id": "cubic-crystal", "metadata": { "jupyter": { "outputs_hidden": true @@ -160,7 +160,7 @@ { "cell_type": "code", "execution_count": null, - "id": "adopted-caution", + "id": "mineral-phrase", "metadata": {}, "outputs": [], "source": [ @@ -169,29 +169,31 @@ }, { "cell_type": "markdown", - "id": "integral-access", + "id": "brown-scoop", "metadata": { "jupyter": { "outputs_hidden": true } }, "source": [ - "### 5. Fine-tune the model" + "### 5. Fine-tune the model\n", + "\n", + "The backbone won't be freezed and the entire model will be finetuned on the imdb dataset " ] }, { "cell_type": "code", "execution_count": null, - "id": "enormous-botswana", + "id": "reliable-hampshire", "metadata": {}, "outputs": [], "source": [ - "trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))" + "trainer.finetune(model, datamodule=datamodule, strategy=\"no_freeze\")" ] }, { "cell_type": "markdown", - "id": "cellular-baking", + "id": "unlimited-duplicate", "metadata": { "jupyter": { "outputs_hidden": true @@ -204,7 +206,7 @@ { "cell_type": "code", "execution_count": null, - "id": "demanding-headline", + "id": "federal-quarter", "metadata": {}, "outputs": [], "source": [ @@ -213,7 +215,7 @@ }, { "cell_type": "markdown", - "id": "charged-investigator", + "id": "defensive-committee", "metadata": { "jupyter": { "outputs_hidden": true @@ -226,7 +228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "early-ridge", + "id": "disciplinary-background", "metadata": {}, "outputs": [], "source": [ @@ -235,7 +237,7 @@ }, { "cell_type": "markdown", - "id": "detailed-direction", + "id": "increased-filter", "metadata": {}, "source": [ "\n", @@ -296,4 +298,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} From 6bf0064ab3b0184e864c4827d339f9f16aba19b8 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 17:18:18 +0000 Subject: [PATCH 12/15] update typo --- flash/core/finetuning.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 5e6298b03b..d94de89fd1 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect from typing import List, Union import pytorch_lightning as pl @@ -21,8 +20,6 @@ from torch import nn from torch.optim import Optimizer -_EXCLUDE_PARAMTERS = ("self", "args", "kwargs") - class FlashBaseFinetuning(BaseFinetuning): @@ -46,13 +43,12 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo def freeze_before_training(self, pl_module: pl.LightningModule) -> None: self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) - @staticmethod - def freeze_using_attr_names(pl_module, attr_names: List[str], train_bn: bool = True): + def freeze_using_attr_names(self, pl_module, attr_names: List[str], train_bn: bool = True): for attr_name in attr_names: attr = getattr(pl_module, attr_name, None) if attr is None or not isinstance(attr, nn.Module): MisconfigurationException(f"Your model must have a {attr} attribute") - BaseFinetuning.freeze(module=attr, train_bn=train_bn) + self.freeze(module=attr, train_bn=train_bn) class FreezeUnfreeze(FlashBaseFinetuning): From 504ddfeab5c584099067414dc43e854235db40f4 Mon Sep 17 00:00:00 2001 From: chaton Date: Mon, 1 Feb 2021 17:40:53 +0000 Subject: [PATCH 13/15] Update flash_notebooks/finetuning/image_classification.ipynb MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- flash_notebooks/finetuning/image_classification.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_notebooks/finetuning/image_classification.ipynb b/flash_notebooks/finetuning/image_classification.ipynb index 1959e99df7..4cd82ec404 100644 --- a/flash_notebooks/finetuning/image_classification.ipynb +++ b/flash_notebooks/finetuning/image_classification.ipynb @@ -27,7 +27,7 @@ " \n", " - 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", " \n", - " - 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If a one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `from pytorch_lightning.callbacks import BaseFinetuning`.\n", + " - 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. 
If one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `pytorch_lightning.callbacks.BaseFinetuning`.\n", " \n", " \n", "\n", From 2f37a368ef996970319c9b25b3e1d2d113657a7d Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 17:45:29 +0000 Subject: [PATCH 14/15] resolve comments --- README.md | 16 +++---- flash/core/finetuning.py | 15 +------ flash/core/trainer.py | 10 +++-- .../finetuning/image_classification.py | 2 +- .../finetuning/text_classification.ipynb | 42 +++++++++---------- 5 files changed, 38 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index dfad959328..6253827031 100644 --- a/README.md +++ b/README.md @@ -127,7 +127,7 @@ model = ImageClassifier(num_classes=datamodule.num_classes) trainer = flash.Trainer(max_epochs=1) # 5. Finetune the model -trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1)) +trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 7. Save it! trainer.save_checkpoint("image_classification_model.pt") @@ -151,13 +151,13 @@ Flash is built as a collection of community-built tasks. A task is highly opinio ### Example 1: Image classification Flash has an ImageClassification task to tackle any image classification problem. - +
View example To illustrate, Let's say we wanted to develop a model that could classify between ants and bees. - + - + Here we classify ants vs bees. ```python @@ -208,7 +208,7 @@ Flash has a TextClassification task to tackle any text classification problem.
View example To illustrate, say you wanted to classify movie reviews as positive or negative. - + ```python import flash from flash import download_data @@ -261,9 +261,9 @@ Flash has a TabularClassification task to tackle any tabular classification prob
View example - - To illustrate, say we want to build a model to predict if a passenger survived on the Titanic. - + + To illustrate, say we want to build a model to predict if a passenger survived on the Titanic. + ```python from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall import flash diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index d94de89fd1..c68f52fde0 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -57,9 +57,6 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo super().__init__(attr_names, train_bn) self.unfreeze_epoch = unfreeze_epoch - def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) - def finetunning_function( self, pl_module: pl.LightningModule, @@ -68,10 +65,7 @@ def finetunning_function( opt_idx: int, ) -> None: if epoch == self.unfreeze_epoch: - modules = [] - for attr_name in self.attr_names: - modules.append(getattr(pl_module, attr_name)) - + modules = [getattr(pl_module, attr_name) for attr_name in self.attr_names] self.unfreeze_and_add_param_group( module=modules, optimizer=optimizer, @@ -93,9 +87,6 @@ def __init__( super().__init__(attr_names, train_bn) - def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn) - def finetunning_function( self, pl_module: pl.LightningModule, @@ -105,8 +96,7 @@ def finetunning_function( ) -> None: backbone_modules = list(pl_module.backbone.modules()) if epoch == self.unfreeze_milestones[0]: - # unfreeze 5 last layers - # TODO last N layers should be parameter + # unfreeze num_layers last layers self.unfreeze_and_add_param_group( module=backbone_modules[-self.num_layers:], optimizer=optimizer, @@ -115,7 +105,6 @@ def finetunning_function( elif epoch == self.unfreeze_milestones[1]: # unfreeze remaining layers - # TODO last N layers should be parameter self.unfreeze_and_add_param_group( module=backbone_modules[:-self.num_layers], optimizer=optimizer, diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 12b801612f..e570d4ae2d 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -61,6 +61,7 @@ def finetune( strategy: Optional[Union[str, BaseFinetuning]] = None, ): r""" + Runs the full optimization routine. Same as pytorch_lightning.Trainer().fit(), but unfreezes layers of the backbone throughout training layers of the backbone throughout training. @@ -77,11 +78,12 @@ def finetune( strategy: Should either be a string or a finetuning callback subclassing ``pytorch_lightning.callbacks.BaseFinetuning``. + Currently, default strategies can be enabled with these strings: - * ``no_freeze``, - * ``freeze``, - * ``freeze_unfreeze``, - * ``unfreeze_milestones`` + - ``no_freeze``, + - ``freeze``, + - ``freeze_unfreeze``, + - ``unfreeze_milestones`` """ self._resolve_callbacks(model, strategy) diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py index 7643e85131..b5202c1661 100644 --- a/flash_examples/finetuning/image_classification.py +++ b/flash_examples/finetuning/image_classification.py @@ -18,7 +18,7 @@ # 3. Build the model model = ImageClassifier(num_classes=datamodule.num_classes) - # 4. Create the trainer. Run once on data + # 4. Create the trainer. Run twice on data trainer = flash.Trainer(max_epochs=2) # 5. 
Train the model diff --git a/flash_notebooks/finetuning/text_classification.ipynb b/flash_notebooks/finetuning/text_classification.ipynb index 28cb748793..34411232aa 100644 --- a/flash_notebooks/finetuning/text_classification.ipynb +++ b/flash_notebooks/finetuning/text_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "prerequisite-straight", + "id": "optical-barrel", "metadata": {}, "source": [ "\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "coastal-bible", + "id": "rolled-scoop", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by finetunig a TextClassifier on [IMDB Dataset](https://www.imdb.com/interfaces/).\n", @@ -31,7 +31,7 @@ "- 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head, will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", " \n", "\n", - "- 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If a one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `from pytorch_lightning.callbacks import BaseFinetuning`.\n", + "- 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If a one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `pytorch_lightning.callbacks.BaseFinetuning`.\n", "\n", "---\n", " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", @@ -42,7 +42,7 @@ }, { "cell_type": "markdown", - "id": "sharp-techno", + "id": "pleasant-benchmark", "metadata": {}, "source": [ "### Setup \n", @@ -52,7 +52,7 @@ { "cell_type": "code", "execution_count": null, - "id": "posted-blair", + "id": "suspended-announcement", "metadata": {}, "outputs": [], "source": [ @@ -63,7 +63,7 @@ { "cell_type": "code", "execution_count": null, - "id": "double-swedish", + "id": "appreciated-internship", "metadata": {}, "outputs": [], "source": [ @@ -74,7 +74,7 @@ }, { "cell_type": "markdown", - "id": "outside-garlic", + "id": "excessive-private", "metadata": {}, "source": [ "### 1. Download the data\n", @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "tired-lender", + "id": "noted-father", "metadata": {}, "outputs": [], "source": [ @@ -93,7 +93,7 @@ }, { "cell_type": "markdown", - "id": "daily-marijuana", + "id": "naval-rogers", "metadata": {}, "source": [ "
### 2. Load the data
\n", @@ -105,7 +105,7 @@ { "cell_type": "code", "execution_count": null, - "id": "standing-commons", + "id": "monetary-album", "metadata": {}, "outputs": [], "source": [ @@ -121,7 +121,7 @@ }, { "cell_type": "markdown", - "id": "fantastic-mortality", + "id": "published-vision", "metadata": { "jupyter": { "outputs_hidden": true @@ -138,7 +138,7 @@ { "cell_type": "code", "execution_count": null, - "id": "prompt-azerbaijan", + "id": "focused-claim", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "cubic-crystal", + "id": "primary-battery", "metadata": { "jupyter": { "outputs_hidden": true @@ -160,7 +160,7 @@ { "cell_type": "code", "execution_count": null, - "id": "mineral-phrase", + "id": "great-austria", "metadata": {}, "outputs": [], "source": [ @@ -169,7 +169,7 @@ }, { "cell_type": "markdown", - "id": "brown-scoop", + "id": "corporate-sequence", "metadata": { "jupyter": { "outputs_hidden": true @@ -184,7 +184,7 @@ { "cell_type": "code", "execution_count": null, - "id": "reliable-hampshire", + "id": "opponent-visit", "metadata": {}, "outputs": [], "source": [ @@ -193,7 +193,7 @@ }, { "cell_type": "markdown", - "id": "unlimited-duplicate", + "id": "sunrise-questionnaire", "metadata": { "jupyter": { "outputs_hidden": true @@ -206,7 +206,7 @@ { "cell_type": "code", "execution_count": null, - "id": "federal-quarter", + "id": "certain-pizza", "metadata": {}, "outputs": [], "source": [ @@ -215,7 +215,7 @@ }, { "cell_type": "markdown", - "id": "defensive-committee", + "id": "loose-march", "metadata": { "jupyter": { "outputs_hidden": true @@ -228,7 +228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "disciplinary-background", + "id": "loose-culture", "metadata": {}, "outputs": [], "source": [ @@ -237,7 +237,7 @@ }, { "cell_type": "markdown", - "id": "increased-filter", + "id": "quarterly-dominican", "metadata": {}, "source": [ "\n", From 7ccfd814d043e7beed668dce3642159c87aaa649 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Feb 2021 18:04:54 +0000 Subject: [PATCH 15/15] remove set -e --- .github/workflows/ci-notebook.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-notebook.yml b/.github/workflows/ci-notebook.yml index 97b57cb580..daa5ec4e40 100644 --- a/.github/workflows/ci-notebook.yml +++ b/.github/workflows/ci-notebook.yml @@ -57,15 +57,14 @@ jobs: with: path: flash_examples/predict # This path is specific to Ubuntu # Look to see if there is a cache hit for the corresponding requirements file - key: flash-datasets_predict + key: flash-datasets_predict - name: Run Notebooks run: | - set -e jupyter nbconvert --to script flash_notebooks/finetuning/tabular_classification.ipynb jupyter nbconvert --to script flash_notebooks/predict/classify_image.ipynb jupyter nbconvert --to script flash_notebooks/predict/classify_tabular.ipynb ipython flash_notebooks/finetuning/tabular_classification.py ipython flash_notebooks/predict/classify_image.py - ipython flash_notebooks/predict/classify_tabular.py \ No newline at end of file + ipython flash_notebooks/predict/classify_tabular.py