This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 212
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
1,754 additions
and
1,187 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Changelog | ||
|
||
All notable changes to this project will be documented in this file. | ||
|
||
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). | ||
|
||
|
||
## [0.1.0] - 01/02/2021 | ||
|
||
### Added | ||
|
||
- Added flash_notebook examples ([#9](https://github.com/PyTorchLightning/pytorch-lightning/pull/9)) | ||
|
||
### Changed | ||
|
||
### Fixed | ||
|
||
|
||
### Removed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,282 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "purple-muscle", | ||
"metadata": {}, | ||
"source": [ | ||
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PyTorchLightning/lightning-flash/blob/master/flash_notebooks/finetuning/image_classification.ipynb)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "usual-israeli", | ||
"metadata": {}, | ||
"source": [ | ||
"In this notebook, we'll go over the basics of lightning Flash by finetuning an ImageClassifier on [Hymenoptera Dataset](https://www.kaggle.com/ajayrana/hymenoptera-data) containing ants and bees images.\n", | ||
"\n", | ||
"# Finetuning\n", | ||
"\n", | ||
"Finetuning consists of four steps:\n", | ||
" \n", | ||
" - 1. Train a source neural network model on a source dataset. For computer vision, it is traditionally the [ImageNet dataset](http://www.image-net.org/search?q=cat). As training is costly, library such as [Torchvion](https://pytorch.org/docs/stable/torchvision/index.html) library supports popular pre-trainer model architectures . In this notebook, we will be using their [resnet-18](https://pytorch.org/hub/pytorch_vision_resnet/).\n", | ||
" \n", | ||
" - 2. Create a new neural network called the target model. Its architecture replicates the source model and parameters, expect the latest layer which is removed. This model without its latest layer is traditionally called a backbone\n", | ||
" \n", | ||
" - 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", | ||
" \n", | ||
" - 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. At training start, the backbone will be frozen, meaning its parameters won't be updated. Only the model head will be trained to properly distinguish ants and bees. On reaching first finetuning milestone, the backbone latest layers will be unfrozen and start to be trained. On reaching the second finetuning milestone, the remaining layers of the backend will be unfrozen and the entire model will be trained. In Flash, `trainer.finetune(..., unfreeze_milestones=(first_milestone, second_milestone))`.\n", | ||
"\n", | ||
" \n", | ||
"\n", | ||
"---\n", | ||
" - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", | ||
" - Check out [Flash documentation](https://lightning-flash.readthedocs.io/en/latest/)\n", | ||
" - Check out [Lightning documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", | ||
" - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "sapphire-counter", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%%capture\n", | ||
"! pip install lightning-flash" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "chubby-incidence", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import flash\n", | ||
"from flash.core.data import download_data\n", | ||
"from flash.vision import ImageClassificationData, ImageClassifier" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "central-netscape", | ||
"metadata": {}, | ||
"source": [ | ||
"## 1. Download data\n", | ||
"The data are downloaded from a URL, and save in a 'data' directory." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "through-munich", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"download_data(\"https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip\", 'data/')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "chief-footwear", | ||
"metadata": {}, | ||
"source": [ | ||
"<h2>2. Load the data</h2>\n", | ||
"\n", | ||
"Flash Tasks have built-in DataModules that you can abuse to organize your data. Pass in a train, validation and test folders and Flash will take care of the rest.\n", | ||
"Creates a ImageClassificationData object from folders of images arranged in this way:</h4>\n", | ||
"\n", | ||
"\n", | ||
" train/dog/xxx.png\n", | ||
" train/dog/xxy.png\n", | ||
" train/dog/xxz.png\n", | ||
" train/cat/123.png\n", | ||
" train/cat/nsdf3.png\n", | ||
" train/cat/asd932.png\n", | ||
"\n", | ||
"\n", | ||
"Note: Each sub-folder content will be considered as a new class." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "helpful-glass", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"datamodule = ImageClassificationData.from_folders(\n", | ||
" train_folder=\"data/hymenoptera_data/train/\",\n", | ||
" valid_folder=\"data/hymenoptera_data/val/\",\n", | ||
" test_folder=\"data/hymenoptera_data/test/\",\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "extraordinary-tablet", | ||
"metadata": {}, | ||
"source": [ | ||
"### 3. Build the model\n", | ||
"Create the ImageClassifier task. By default, the ImageClassifier task uses a [resnet-18](https://pytorch.org/hub/pytorch_vision_resnet/) backbone to train or finetune your model.\n", | ||
"For [Hymenoptera Dataset](https://www.kaggle.com/ajayrana/hymenoptera-data) containing ants and bees images, ``datamodule.num_classes`` will be 2.\n", | ||
"Backbone can easily be changed with `ImageClassifier(backbone=\"resnet50\")` or you could provide your own `ImageClassifier(backbone=my_backbone)`" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "adjusted-acrobat", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model = ImageClassifier(num_classes=datamodule.num_classes)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "sweet-pottery", | ||
"metadata": {}, | ||
"source": [ | ||
"### 4. Create the trainer. Run once on data\n", | ||
"\n", | ||
"The trainer object can be used for training or fine-tuning tasks on new sets of data. \n", | ||
"\n", | ||
"You can pass in parameters to control the training routine- limit the number of epochs, run on GPUs or TPUs, etc.\n", | ||
"\n", | ||
"For more details, read the [Trainer Documentation](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html).\n", | ||
"\n", | ||
"In this demo, we will limit the fine-tuning to run just one epoch using max_epochs=2." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "molecular-string", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"trainer = flash.Trainer(max_epochs=3)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "criminal-string", | ||
"metadata": {}, | ||
"source": [ | ||
"### 5. Finetune the model\n", | ||
"The `unfreeze_milestones=(0, 1)` will unfreeze the latest layers of the backbone on epoch `0` and the rest of the backbone on epoch `1`. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "documentary-donna", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "civic-wednesday", | ||
"metadata": {}, | ||
"source": [ | ||
"### 6. Test the model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "public-regard", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"trainer.test()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "above-dietary", | ||
"metadata": {}, | ||
"source": [ | ||
"### 7. Save it!" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "canadian-nudist", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"trainer.save_checkpoint(\"image_classification_model.pt\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "worthy-february", | ||
"metadata": {}, | ||
"source": [ | ||
"<code style=\"color:#792ee5;\">\n", | ||
" <h1> <strong> Congratulations - Time to Join the Community! </strong> </h1>\n", | ||
"</code>\n", | ||
"\n", | ||
"Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the Lightning movement, you can do so in the following ways!\n", | ||
"\n", | ||
"### Help us build Flash by adding support for new data-types and new tasks.\n", | ||
"Flash aims at becoming the first task hub, so anyone can get started to great amazing application using deep learning. \n", | ||
"If you are interested, please open a PR with your contributions !!! \n", | ||
"\n", | ||
"\n", | ||
"### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", | ||
"The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", | ||
"\n", | ||
"* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", | ||
"\n", | ||
"### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", | ||
"The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", | ||
"\n", | ||
"### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", | ||
"Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", | ||
"\n", | ||
"* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", | ||
"\n", | ||
"### Contributions !\n", | ||
"The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", | ||
"\n", | ||
"* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", | ||
"* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", | ||
"* You can also contribute your own notebooks with useful examples !\n", | ||
"\n", | ||
"### Great thanks from the entire Pytorch Lightning Team for your interest !\n", | ||
"\n", | ||
"<img src=\"https://github.com/PyTorchLightning/lightning-flash/blob/master/docs/source/_images/flash_logo.png?raw=true\" width=\"800\" height=\"200\" />" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.9" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.