From f943441222d4833f373a07aeaf2808d2bc4a59be Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 22 Sep 2021 19:40:14 +0100 Subject: [PATCH] [Doc] Add learn2learn integrations documentation (#788) Co-authored-by: Ethan Harris --- docs/source/index.rst | 1 + docs/source/integrations/learn2learn.rst | 81 +++++++++++++++++++ .../image_classification_meta_learning.py | 38 --------- .../image_classification_imagenette_mini.py | 5 +- 4 files changed, 84 insertions(+), 41 deletions(-) create mode 100644 docs/source/integrations/learn2learn.rst delete mode 100644 flash_examples/image_classification_meta_learning.py diff --git a/docs/source/index.rst b/docs/source/index.rst index dc5836be15..be643f9372 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -90,6 +90,7 @@ Lightning Flash integrations/providers integrations/fiftyone + integrations/learn2learn integrations/icevision .. toctree:: diff --git a/docs/source/integrations/learn2learn.rst b/docs/source/integrations/learn2learn.rst new file mode 100644 index 0000000000..18ae188a0a --- /dev/null +++ b/docs/source/integrations/learn2learn.rst @@ -0,0 +1,81 @@ +.. _learn2learn: + +########### +Learn2Learn +########### + +`Learn2Learn `__ is a software library for meta-learning research by `Sébastien M. R. Arnold et al.` (Aug 2020) + +.. raw:: html + +
+ +
+
+ + + +What is Meta-Learning and why you should care? +---------------------------------------------- + +Humans can distinguish between new objects with little or no training data. +However, machine learning models often require thousands, millions, or billions of annotated data samples +to achieve good performance while extrapolating their learned knowledge on unseen objects. + +A machine learning model which could learn or learn to learn from only a few new samples (K-shot learning) would have tremendous applications +once deployed in production. +In an extreme case, a model performing 1-shot or 0-shot learning could be the source of new kinds of AI applications. + +Meta-Learning is a sub-field of AI dedicated to the study of few-shot learning algorithms. +This is often characterized as teaching deep learning models to learn with only a few labeled samples. +The goal is to repeatedly learn from K-shot examples during training that match the structure of the final K-shot used in production. +It is important to note that the K-shot examples seen in production are very likely to be completely out-of-distribution with new objects. + + +How does Meta-Learning work? +---------------------------- + +In meta-learning, the model is trained over multiple meta tasks. +A meta task is the smallest unit of data and it represents the data available to the model once in its deployment environment. +By doing so, we can optimise the model and achieve better results. + +.. raw:: html + +
+ +
+
+ +For image classification, a meta task is comprised of shot + query elements for each class. +The shot samples are used to adapt the parameters and the query ones to update the original model weights. +The classes used in the validation and testing shouldn't be present within the training dataset, +as the goal is to optimise the model performance on out-of-distribution (OOD) data with little labeled data. + +When training the model with the meta-learning algorithm, +the model will average its gradients over meta_batch_size meta tasks before performing an optimizer step. +Traditionally, a meta epoch is composed of multiple meta batches. + +Use Meta-Learning with Flash +---------------------------- + +With its integration within Flash, Meta Learning has never been simpler. +Flash takes care of all the hard work: the tasks sampling, meta optimizer update, distributed training, etc... + +.. note:: + + Users are required to provide a training dataset and a testing dataset with no overlapping classes. + Flash doesn't support this feature out-of-the-box. + +Once done, the users are left to play with the hyper-parameters associated with the meta-learning algorithm. + +Here is an example using `miniImageNet dataset `_ containing 100 classes divided into 64 training, 16 validation, and 20 test classes. + +.. literalinclude:: ../../../flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py + :language: python + :lines: 15- + + +You can read their paper `Learn2Learn: A Library for Meta-Learning Research `_. + +And don't forget to cite the `Learn2Learn `__ repository in your academic publications. +Find their BibTeX on their repository. diff --git a/flash_examples/image_classification_meta_learning.py b/flash_examples/image_classification_meta_learning.py deleted file mode 100644 index 510fe634c9..0000000000 --- a/flash_examples/image_classification_meta_learning.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright The PyTorch Lightning team. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import flash -from flash.core.data.utils import download_data -from flash.image import ImageClassificationData, ImageClassifier - -# 1. Create the DataModule -download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data") - -datamodule = ImageClassificationData.from_folders( - train_folder="data/hymenoptera_data/train/", - val_folder="data/hymenoptera_data/val/", -) - -# 2. Build the task -model = ImageClassifier( - backbone="resnet18", - training_strategy="prototypicalnetworks", - training_strategy_kwargs={"ways": datamodule.num_classes, "shots": 4, "meta_batch_size": 10}, -) - -# 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=1) -trainer.finetune(model, datamodule=datamodule, strategy="no_freeze") - -# 5. Save the model! 
-trainer.save_checkpoint("image_classification_model.pt") diff --git a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py index 5a45199bad..38bd6c2e7e 100644 --- a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py +++ b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py @@ -77,10 +77,7 @@ model = ImageClassifier( backbone="resnet18", - pretrained=False, training_strategy="prototypicalnetworks", - optimizer=torch.optim.Adam, - optimizer_kwargs={"lr": 0.001}, training_strategy_kwargs={ "epoch_length": 10 * 16, "meta_batch_size": 4, @@ -92,6 +89,8 @@ "test_shots": 1, "test_queries": 15, }, + optimizer=torch.optim.Adam, + optimizer_kwargs={"lr": 0.001}, ) trainer = flash.Trainer(