From edccf3be226ab71c3a80037fc7b81e5e2322f60f Mon Sep 17 00:00:00 2001 From: chaton Date: Mon, 1 Feb 2021 18:31:08 +0000 Subject: [PATCH] improve finetuning (#39) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * improve finetuning * update changelog * update on comments * typo * update on comments * update on comments * update finetuning * typo * update * update * update notebooks * update typo * Update flash_notebooks/finetuning/image_classification.ipynb Co-authored-by: Carlos Mocholí * resolve comments * remove set -e Co-authored-by: Carlos Mocholí --- .github/workflows/ci-notebook.yml | 5 +- CHANGELOG.md | 9 +- README.md | 16 +-- flash/core/finetuning.py | 135 ++++++++++++++++++ flash/core/model.py | 3 + flash/core/trainer.py | 122 ++++++++-------- .../finetuning/image_classification.py | 7 +- .../finetuning/text_classification.py | 2 +- .../finetuning/image_classification.ipynb | 47 +++--- .../finetuning/text_classification.ipynb | 50 +++---- tests/core/test_finetuning.py | 40 ++++++ tests/core/test_trainer.py | 2 +- 12 files changed, 313 insertions(+), 125 deletions(-) create mode 100644 flash/core/finetuning.py create mode 100644 tests/core/test_finetuning.py diff --git a/.github/workflows/ci-notebook.yml b/.github/workflows/ci-notebook.yml index 97b57cb580..daa5ec4e40 100644 --- a/.github/workflows/ci-notebook.yml +++ b/.github/workflows/ci-notebook.yml @@ -57,15 +57,14 @@ jobs: with: path: flash_examples/predict # This path is specific to Ubuntu # Look to see if there is a cache hit for the corresponding requirements file - key: flash-datasets_predict + key: flash-datasets_predict - name: Run Notebooks run: | - set -e jupyter nbconvert --to script flash_notebooks/finetuning/tabular_classification.ipynb jupyter nbconvert --to script flash_notebooks/predict/classify_image.ipynb jupyter nbconvert --to script flash_notebooks/predict/classify_tabular.ipynb ipython flash_notebooks/finetuning/tabular_classification.py ipython flash_notebooks/predict/classify_image.py - ipython flash_notebooks/predict/classify_tabular.py \ No newline at end of file + ipython flash_notebooks/predict/classify_tabular.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a589789b7..87b983c22e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,11 +9,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/pytorch-lightning/pull/9)) +- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/lightning-flash/pull/9)) + + +- Added `strategy` to `trainer.finetune` with `NoFreeze`, `Freeze`, `FreezeUnfreeze`, `UnfreezeMilestones` Callbacks([#39](https://github.com/PyTorchLightning/lightning-flash/pull/39)) + ### Changed + ### Fixed -### Removed \ No newline at end of file +### Removed diff --git a/README.md b/README.md index dfad959328..6253827031 100644 --- a/README.md +++ b/README.md @@ -127,7 +127,7 @@ model = ImageClassifier(num_classes=datamodule.num_classes) trainer = flash.Trainer(max_epochs=1) # 5. Finetune the model -trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1)) +trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 7. Save it! trainer.save_checkpoint("image_classification_model.pt") @@ -151,13 +151,13 @@ Flash is built as a collection of community-built tasks. 
A task is highly opinionated and laser-focused on solving a single problem well, using state-of-the-art methods.

### Example 1: Image classification
Flash has an ImageClassification task to tackle any image classification problem.

-
+
View example To illustrate, Let's say we wanted to develop a model that could classify between ants and bees. - + - + Here we classify ants vs bees. ```python @@ -208,7 +208,7 @@ Flash has a TextClassification task to tackle any text classification problem.
View example To illustrate, say you wanted to classify movie reviews as positive or negative. - + ```python import flash from flash import download_data @@ -261,9 +261,9 @@ Flash has a TabularClassification task to tackle any tabular classification prob
View example - - To illustrate, say we want to build a model to predict if a passenger survived on the Titanic. - + + To illustrate, say we want to build a model to predict if a passenger survived on the Titanic. + ```python from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall import flash diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py new file mode 100644 index 0000000000..c68f52fde0 --- /dev/null +++ b/flash/core/finetuning.py @@ -0,0 +1,135 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Union + +import pytorch_lightning as pl +from pytorch_lightning.callbacks import BaseFinetuning +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch import nn +from torch.optim import Optimizer + + +class FlashBaseFinetuning(BaseFinetuning): + + def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True): + r""" + + FlashBaseFinetuning can be used to create a custom Flash Finetuning Callback. + + Override ``finetunning_function`` to put your unfreeze logic. + + Args: + attr_names: Name(s) of the module attributes of the model to be frozen. 
+
+            train_bn: Whether to train Batch Norm layer
+
+        """
+
+        self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names
+        self.train_bn = train_bn
+
+    def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
+        self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn)
+
+    def freeze_using_attr_names(self, pl_module, attr_names: List[str], train_bn: bool = True):
+        for attr_name in attr_names:
+            attr = getattr(pl_module, attr_name, None)
+            if attr is None or not isinstance(attr, nn.Module):
+                raise MisconfigurationException(f"Your model must have a {attr_name} attribute")
+            self.freeze(module=attr, train_bn=train_bn)
+
+
+class FreezeUnfreeze(FlashBaseFinetuning):
+
+    def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_epoch: int = 10):
+        super().__init__(attr_names, train_bn)
+        self.unfreeze_epoch = unfreeze_epoch
+
+    def finetunning_function(
+        self,
+        pl_module: pl.LightningModule,
+        epoch: int,
+        optimizer: Optimizer,
+        opt_idx: int,
+    ) -> None:
+        if epoch == self.unfreeze_epoch:
+            modules = [getattr(pl_module, attr_name) for attr_name in self.attr_names]
+            self.unfreeze_and_add_param_group(
+                module=modules,
+                optimizer=optimizer,
+                train_bn=self.train_bn,
+            )
+
+
+class UnfreezeMilestones(FlashBaseFinetuning):
+
+    def __init__(
+        self,
+        attr_names: Union[str, List[str]] = "backbone",
+        train_bn: bool = True,
+        unfreeze_milestones: tuple = (5, 10),
+        num_layers: int = 5
+    ):
+        self.unfreeze_milestones = unfreeze_milestones
+        self.num_layers = num_layers
+
+        super().__init__(attr_names, train_bn)
+
+    def finetunning_function(
+        self,
+        pl_module: pl.LightningModule,
+        epoch: int,
+        optimizer: Optimizer,
+        opt_idx: int,
+    ) -> None:
+        backbone_modules = list(pl_module.backbone.modules())
+        if epoch == self.unfreeze_milestones[0]:
+            # unfreeze the last `num_layers` layers of the backbone
+            self.unfreeze_and_add_param_group(
+                module=backbone_modules[-self.num_layers:],
+                optimizer=optimizer,
+                train_bn=self.train_bn,
+            )
+
+        elif epoch == self.unfreeze_milestones[1]:
+            # unfreeze the remaining layers of the backbone
+            self.unfreeze_and_add_param_group(
+                module=backbone_modules[:-self.num_layers],
+                optimizer=optimizer,
+                train_bn=self.train_bn,
+            )
+
+
+_DEFAULTS_FINETUNE_STRATEGIES = {
+    "no_freeze": BaseFinetuning,
+    "freeze": FlashBaseFinetuning,
+    "freeze_unfreeze": FreezeUnfreeze,
+    "unfreeze_milestones": UnfreezeMilestones
+}
+
+
+def instantiate_default_finetuning_callbacks(strategy):
+    if strategy is None:
+        strategy = "no_freeze"
+        rank_zero_warn("strategy is None. Setting strategy to `no_freeze` by default.", UserWarning)
+    if isinstance(strategy, str):
+        strategy = strategy.lower()
+        if strategy in _DEFAULTS_FINETUNE_STRATEGIES:
+            return [_DEFAULTS_FINETUNE_STRATEGIES[strategy]()]
+    raise MisconfigurationException(
+        f"strategy should be within {list(_DEFAULTS_FINETUNE_STRATEGIES)}"
+        f". 
Found {strategy}" + ) + return [] diff --git a/flash/core/model.py b/flash/core/model.py index 51b1a87d12..3607878ac8 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -150,3 +150,6 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: checkpoint["pipeline"] = self.data_pipeline + + def configure_finetune_callback(self): + return [] diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 7ce9f8cf75..e570d4ae2d 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -16,54 +16,12 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import BaseFinetuning -from torch.optim import Optimizer +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import DataLoader - -# NOTE: copied from: -# https://github.com/PyTorchLightning/pytorch-lightning/blob/9d165f6f5655a44f1e5cd02ab36f21bc14e2a604/pl_examples/domain_templates/computer_vision_fine_tuning.py#L66 -class MilestonesFinetuningCallback(BaseFinetuning): - - def __init__(self, milestones: tuple = (5, 10), train_bn: bool = True): - self.milestones = milestones - self.train_bn = train_bn - - def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - # TODO: might need some config to say which attribute is model - # maybe something like: - # self.freeze(module=pl_module.getattr(self.feature_attr), train_bn=self.train_bn) - # where self.feature_attr can be "backbone" or "feature_extractor", etc. - # (configured in init) - assert hasattr( - pl_module, "backbone" - ), "To use MilestonesFinetuningCallback your model must have a backbone attribute" - self.freeze(module=pl_module.backbone, train_bn=self.train_bn) - - def finetunning_function( - self, - pl_module: pl.LightningModule, - epoch: int, - optimizer: Optimizer, - opt_idx: int, - ) -> None: - backbone_modules = list(pl_module.backbone.modules()) - if epoch == self.milestones[0]: - # unfreeze 5 last layers - # TODO last N layers should be parameter - self.unfreeze_and_add_param_group( - module=backbone_modules[-5:], - optimizer=optimizer, - train_bn=self.train_bn, - ) - - elif epoch == self.milestones[1]: - # unfreeze remaining layers - # TODO last N layers should be parameter - self.unfreeze_and_add_param_group( - module=backbone_modules[:-5], - optimizer=optimizer, - train_bn=self.train_bn, - ) +from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, instantiate_default_finetuning_callbacks class Trainer(pl.Trainer): @@ -96,13 +54,14 @@ def fit( def finetune( self, - model: pl.LightningModule, + model: LightningModule, train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[pl.LightningDataModule] = None, - unfreeze_milestones: tuple = (5, 10), + strategy: Optional[Union[str, BaseFinetuning]] = None, ): r""" + Runs the full optimization routine. Same as pytorch_lightning.Trainer().fit(), but unfreezes layers of the backbone throughout training layers of the backbone throughout training. @@ -117,18 +76,61 @@ def finetune( val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped - unfreeze_milestones: A tuple of two integers. 
First value marks the epoch in which the last 5
-            layers of the backbone will be unfrozen. The second value marks the epoch in which the full backbone will
-            be unfrozen.
+        strategy: Should either be a string or a finetuning callback subclassing
+            ``pytorch_lightning.callbacks.BaseFinetuning``.
 
-        """
-        if hasattr(model, "backbone"):
-            # TODO: if we find a finetuning callback in the trainer should we change it?
-            # or should we warn the user?
-            if not any(isinstance(c, BaseFinetuning) for c in self.callbacks):
-                # TODO: should pass config from arguments
-                self.callbacks.append(MilestonesFinetuningCallback(milestones=unfreeze_milestones))
-        else:
-            warnings.warn("Warning: model does not have a 'backbone' attribute, will train normally")
+
+            Currently, default strategies can be enabled with these strings:
+                - ``no_freeze``,
+                - ``freeze``,
+                - ``freeze_unfreeze``,
+                - ``unfreeze_milestones``
 
+        """
+        self._resolve_callbacks(model, strategy)
         return super().fit(model, train_dataloader, val_dataloaders, datamodule)
+
+    def _resolve_callbacks(self, model, strategy):
+        """
+        This function is used to select the ``BaseFinetuning`` callback to be used for finetuning.
+        """
+        if strategy is not None and not isinstance(strategy, (str, BaseFinetuning)):
+            raise MisconfigurationException(
+                "strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning`` "
+                f"callback or a str within {list(_DEFAULTS_FINETUNE_STRATEGIES.keys())}"
+            )
+
+        if isinstance(strategy, BaseFinetuning):
+            callback = [strategy]
+        else:
+            # todo: change to ``configure_callbacks`` when merged to Lightning.
+            model_callback = model.configure_finetune_callback()
+            if len(model_callback) > 1:
+                raise MisconfigurationException(
+                    f"{model} configure_finetune_callback should create a list with only 1 callback"
+                )
+            if len(model_callback) == 1:
+                if strategy is not None:
+                    rank_zero_warn(
+                        "The model contains a default finetune callback. "
+                        f"The provided {strategy} will be overridden. "
+                        "HINT: Provide a `BaseFinetuning` callback as strategy to make it prioritized. ", UserWarning
+                    )
+                callback = model_callback
+            else:
+                callback = instantiate_default_finetuning_callbacks(strategy)
+
+        self.callbacks = self._merge_callbacks(self.callbacks, callback)
+
+    @staticmethod
+    def _merge_callbacks(old_callbacks: List, new_callbacks: List) -> List:
+        """
+        This function keeps only one instance of each callback type,
+        extending ``new_callbacks`` with ``old_callbacks``.
+        """
+        if len(new_callbacks) == 0:
+            return old_callbacks
+        new_callbacks_types = set(type(c) for c in new_callbacks)
+        old_callbacks_types = set(type(c) for c in old_callbacks)
+        override_types = new_callbacks_types.intersection(old_callbacks_types)
+        new_callbacks.extend(c for c in old_callbacks if type(c) not in override_types)
+        return new_callbacks
diff --git a/flash_examples/finetuning/image_classification.py b/flash_examples/finetuning/image_classification.py
index ef8dac9aa8..b5202c1661 100644
--- a/flash_examples/finetuning/image_classification.py
+++ b/flash_examples/finetuning/image_classification.py
@@ -1,5 +1,6 @@
 import flash
 from flash.core.data import download_data
+from flash.core.finetuning import FreezeUnfreeze
 from flash.vision import ImageClassificationData, ImageClassifier
 
 if __name__ == "__main__":
@@ -17,11 +18,11 @@
     # 3. Build the model
     model = ImageClassifier(num_classes=datamodule.num_classes)
 
-    # 4. Create the trainer. Run once on data
-    trainer = flash.Trainer(max_epochs=1)
+    # 4. Create the trainer. Run twice on data
+    trainer = flash.Trainer(max_epochs=2)
 
     # 5. 
Train the model - trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1)) + trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1)) # 6. Test the model trainer.test() diff --git a/flash_examples/finetuning/text_classification.py b/flash_examples/finetuning/text_classification.py index 2c0ff4b3f9..dd07f46bf9 100644 --- a/flash_examples/finetuning/text_classification.py +++ b/flash_examples/finetuning/text_classification.py @@ -24,7 +24,7 @@ trainer = flash.Trainer(max_epochs=1) # 5. Fine-tune the model - trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1)) + trainer.finetune(model, datamodule=datamodule, strategy='freeze') # 6. Test model trainer.test() diff --git a/flash_notebooks/finetuning/image_classification.ipynb b/flash_notebooks/finetuning/image_classification.ipynb index 1ee71ec15e..4cd82ec404 100644 --- a/flash_notebooks/finetuning/image_classification.ipynb +++ b/flash_notebooks/finetuning/image_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "dominican-savings", + "id": "thousand-manufacturer", "metadata": {}, "source": [ "\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "threaded-coffee", + "id": "smoking-probe", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by finetuning an ImageClassifier on [Hymenoptera Dataset](https://www.kaggle.com/ajayrana/hymenoptera-data) containing ants and bees images.\n", @@ -27,7 +27,9 @@ " \n", " - 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", " \n", - " - 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. At training start, the backbone will be frozen, meaning its parameters won't be updated. Only the model head will be trained to properly distinguish ants and bees. On reaching first finetuning milestone, the backbone latest layers will be unfrozen and start to be trained. On reaching the second finetuning milestone, the remaining layers of the backend will be unfrozen and the entire model will be trained. In Flash, `trainer.finetune(..., unfreeze_milestones=(first_milestone, second_milestone))`.\n", + " - 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `pytorch_lightning.callbacks.BaseFinetuning`.\n", + " \n", + " \n", "\n", " \n", "\n", @@ -41,7 +43,7 @@ { "cell_type": "code", "execution_count": null, - "id": "handmade-timing", + "id": "thermal-fraction", "metadata": {}, "outputs": [], "source": [ @@ -52,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "through-edwards", + "id": "cognitive-haven", "metadata": {}, "outputs": [], "source": [ @@ -63,7 +65,7 @@ }, { "cell_type": "markdown", - "id": "hybrid-adapter", + "id": "afraid-straight", "metadata": {}, "source": [ "## 1. 
Download data\n", @@ -73,7 +75,7 @@ { "cell_type": "code", "execution_count": null, - "id": "amateur-disposal", + "id": "advisory-narrow", "metadata": {}, "outputs": [], "source": [ @@ -82,7 +84,7 @@ }, { "cell_type": "markdown", - "id": "front-metallic", + "id": "trying-group", "metadata": {}, "source": [ "

### 2. Load the data

\n", @@ -105,7 +107,7 @@ { "cell_type": "code", "execution_count": null, - "id": "hazardous-means", + "id": "stuck-composition", "metadata": {}, "outputs": [], "source": [ @@ -118,7 +120,7 @@ }, { "cell_type": "markdown", - "id": "defined-mouse", + "id": "irish-scenario", "metadata": {}, "source": [ "### 3. Build the model\n", @@ -130,7 +132,7 @@ { "cell_type": "code", "execution_count": null, - "id": "internal-playback", + "id": "opening-nomination", "metadata": {}, "outputs": [], "source": [ @@ -139,7 +141,7 @@ }, { "cell_type": "markdown", - "id": "million-tower", + "id": "breathing-element", "metadata": {}, "source": [ "### 4. Create the trainer. Run once on data\n", @@ -156,7 +158,7 @@ { "cell_type": "code", "execution_count": null, - "id": "centered-paris", + "id": "earlier-jordan", "metadata": {}, "outputs": [], "source": [ @@ -165,26 +167,25 @@ }, { "cell_type": "markdown", - "id": "special-fence", + "id": "extreme-scene", "metadata": {}, "source": [ - "### 5. Finetune the model\n", - "The `unfreeze_milestones=(0, 1)` will unfreeze the latest layers of the backbone on epoch `0` and the rest of the backbone on epoch `1`. " + "### 5. Finetune the model" ] }, { "cell_type": "code", "execution_count": null, - "id": "local-taylor", + "id": "tired-underground", "metadata": {}, "outputs": [], "source": [ - "trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))" + "trainer.finetune(model, datamodule=datamodule, strategy=\"freeze_unfreeze\")" ] }, { "cell_type": "markdown", - "id": "municipal-kentucky", + "id": "smooth-european", "metadata": {}, "source": [ "### 6. Test the model" @@ -193,7 +194,7 @@ { "cell_type": "code", "execution_count": null, - "id": "simplified-bundle", + "id": "sexual-tender", "metadata": {}, "outputs": [], "source": [ @@ -202,7 +203,7 @@ }, { "cell_type": "markdown", - "id": "familiar-territory", + "id": "athletic-nutrition", "metadata": {}, "source": [ "### 7. Save it!" @@ -211,7 +212,7 @@ { "cell_type": "code", "execution_count": null, - "id": "injured-mineral", + "id": "pleasant-canon", "metadata": {}, "outputs": [], "source": [ @@ -220,7 +221,7 @@ }, { "cell_type": "markdown", - "id": "tested-experience", + "id": "incident-basket", "metadata": {}, "source": [ "\n", diff --git a/flash_notebooks/finetuning/text_classification.ipynb b/flash_notebooks/finetuning/text_classification.ipynb index 8079575e0e..34411232aa 100644 --- a/flash_notebooks/finetuning/text_classification.ipynb +++ b/flash_notebooks/finetuning/text_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "digital-quilt", + "id": "optical-barrel", "metadata": {}, "source": [ "
\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "empty-request", + "id": "rolled-scoop", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by finetunig a TextClassifier on [IMDB Dataset](https://www.imdb.com/interfaces/).\n", @@ -31,7 +31,7 @@ "- 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head, will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", " \n", "\n", - "- 4. Train the target model on a target dataset, such as IMDB Dataset to learn to predict the associated sentiment of movie reviews. At training start, the backbone will be frozen, meaning its parameters won't be updated. Only the model head will be trained to between negative and positive reviews. On reaching first finetuning milestone, the backbone latest layers will be unfrozen and start to be trained. On reaching the second finetuning milestone, the remaining layers of the backend will be unfrozen and the entire model will be trained. In Flash, `unfreeze_milestones` controls those milestone and be used as such `trainer.finetune(..., unfreeze_milestones=(first_milestone, second_milestone))`.\n", + "- 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If a one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `pytorch_lightning.callbacks.BaseFinetuning`.\n", "\n", "---\n", " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", @@ -42,7 +42,7 @@ }, { "cell_type": "markdown", - "id": "another-might", + "id": "pleasant-benchmark", "metadata": {}, "source": [ "### Setup \n", @@ -52,7 +52,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ideal-summary", + "id": "suspended-announcement", "metadata": {}, "outputs": [], "source": [ @@ -63,7 +63,7 @@ { "cell_type": "code", "execution_count": null, - "id": "straight-commission", + "id": "appreciated-internship", "metadata": {}, "outputs": [], "source": [ @@ -74,7 +74,7 @@ }, { "cell_type": "markdown", - "id": "classical-snake", + "id": "excessive-private", "metadata": {}, "source": [ "### 1. Download the data\n", @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "lined-standing", + "id": "noted-father", "metadata": {}, "outputs": [], "source": [ @@ -93,7 +93,7 @@ }, { "cell_type": "markdown", - "id": "endangered-heavy", + "id": "naval-rogers", "metadata": {}, "source": [ "

### 2. Load the data

\n", @@ -105,7 +105,7 @@ { "cell_type": "code", "execution_count": null, - "id": "posted-chosen", + "id": "monetary-album", "metadata": {}, "outputs": [], "source": [ @@ -121,7 +121,7 @@ }, { "cell_type": "markdown", - "id": "cognitive-compact", + "id": "published-vision", "metadata": { "jupyter": { "outputs_hidden": true @@ -138,7 +138,7 @@ { "cell_type": "code", "execution_count": null, - "id": "underlying-liberia", + "id": "focused-claim", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "democratic-interaction", + "id": "primary-battery", "metadata": { "jupyter": { "outputs_hidden": true @@ -160,7 +160,7 @@ { "cell_type": "code", "execution_count": null, - "id": "adopted-caution", + "id": "great-austria", "metadata": {}, "outputs": [], "source": [ @@ -169,29 +169,31 @@ }, { "cell_type": "markdown", - "id": "integral-access", + "id": "corporate-sequence", "metadata": { "jupyter": { "outputs_hidden": true } }, "source": [ - "### 5. Fine-tune the model" + "### 5. Fine-tune the model\n", + "\n", + "The backbone won't be freezed and the entire model will be finetuned on the imdb dataset " ] }, { "cell_type": "code", "execution_count": null, - "id": "enormous-botswana", + "id": "opponent-visit", "metadata": {}, "outputs": [], "source": [ - "trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))" + "trainer.finetune(model, datamodule=datamodule, strategy=\"no_freeze\")" ] }, { "cell_type": "markdown", - "id": "cellular-baking", + "id": "sunrise-questionnaire", "metadata": { "jupyter": { "outputs_hidden": true @@ -204,7 +206,7 @@ { "cell_type": "code", "execution_count": null, - "id": "demanding-headline", + "id": "certain-pizza", "metadata": {}, "outputs": [], "source": [ @@ -213,7 +215,7 @@ }, { "cell_type": "markdown", - "id": "charged-investigator", + "id": "loose-march", "metadata": { "jupyter": { "outputs_hidden": true @@ -226,7 +228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "early-ridge", + "id": "loose-culture", "metadata": {}, "outputs": [], "source": [ @@ -235,7 +237,7 @@ }, { "cell_type": "markdown", - "id": "detailed-direction", + "id": "quarterly-dominican", "metadata": {}, "source": [ "\n", @@ -296,4 +298,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py new file mode 100644 index 0000000000..e4062eac59 --- /dev/null +++ b/tests/core/test_finetuning.py @@ -0,0 +1,40 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch import nn +from torch.nn import functional as F + +from flash import ClassificationTask, Trainer +from flash.core.finetuning import FlashBaseFinetuning +from tests.core.test_model import DummyDataset + + +@pytest.mark.parametrize( + "strategy", ['no_freeze', 'freeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat'] +) +def test_finetuning(tmpdir: str, strategy): + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) + train_dl = torch.utils.data.DataLoader(DummyDataset()) + val_dl = torch.utils.data.DataLoader(DummyDataset()) + task = ClassificationTask(model, F.nll_loss) + trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) + if strategy == "cls": + strategy = FlashBaseFinetuning() + if strategy == 'chocolat': + with pytest.raises(MisconfigurationException, match="strategy should be within"): + trainer.finetune(task, train_dl, val_dl, strategy=strategy) + else: + trainer.finetune(task, train_dl, val_dl, strategy=strategy) diff --git a/tests/core/test_trainer.py b/tests/core/test_trainer.py index f872de5e55..e0f63f19f2 100644 --- a/tests/core/test_trainer.py +++ b/tests/core/test_trainer.py @@ -51,5 +51,5 @@ def test_task_finetune(tmpdir: str): val_dl = torch.utils.data.DataLoader(DummyDataset()) task = ClassificationTask(model, F.nll_loss) trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) - result = trainer.finetune(task, train_dl, val_dl, unfreeze_milestones=(0, 0)) + result = trainer.finetune(task, train_dl, val_dl) assert result
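As a usage note for reviewers of this patch: the sketch below shows how the new `strategy` argument of `trainer.finetune` is meant to be exercised, either by name or with a custom callback. It is a minimal sketch, not part of the patch itself; the `KeepBackboneFrozen` class is a hypothetical example, and the data/trainer setup is assumed to mirror the finetuning image-classification example touched above.

```python
import flash
from flash.core.data import download_data
from flash.core.finetuning import FlashBaseFinetuning
from flash.vision import ImageClassificationData, ImageClassifier


class KeepBackboneFrozen(FlashBaseFinetuning):
    """Hypothetical custom strategy: freeze the backbone and never unfreeze it."""

    def __init__(self):
        # FlashBaseFinetuning freezes the module attribute(s) named in `attr_names`
        # before training starts.
        super().__init__(attr_names="backbone", train_bn=True)

    def finetunning_function(self, pl_module, epoch, optimizer, opt_idx):
        # A no-op unfreeze hook keeps the backbone frozen for the whole run.
        pass


if __name__ == "__main__":
    # 1. Download and load the data (same dataset as the image classification example)
    download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
    datamodule = ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",
        valid_folder="data/hymenoptera_data/val/",
    )

    # 2. Build the model and trainer
    model = ImageClassifier(num_classes=datamodule.num_classes)
    trainer = flash.Trainer(max_epochs=2)

    # 3. Finetune with a built-in strategy, selected by name:
    #    "no_freeze", "freeze", "freeze_unfreeze" or "unfreeze_milestones"
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")

    # Alternatively, pass any BaseFinetuning subclass instance directly:
    # trainer.finetune(model, datamodule=datamodule, strategy=KeepBackboneFrozen())
```

A string resolves through `instantiate_default_finetuning_callbacks`, while a `BaseFinetuning` instance is used as-is and takes priority over whatever `model.configure_finetune_callback()` returns.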