From 48ec088835a7f877c1e14a8023626db6f3edba76 Mon Sep 17 00:00:00 2001 From: chaton Date: Mon, 1 Feb 2021 19:07:18 +0000 Subject: [PATCH] Notebooks update (#41) * improve finetuning * update changelog * update on comments * typo * update on comments * update on comments * update finetuning * typo * update * update * update notebooks * update typo * update notebooks * update ci --- .github/workflows/ci-notebook.yml | 10 +- flash_notebooks/generic_task.ipynb | 463 ++++++++++++++++-- .../image_classification.ipynb | 42 +- flash_notebooks/predict/classify_image.ipynb | 227 --------- .../predict/classify_tabular.ipynb | 184 ------- flash_notebooks/predict/classify_text.ipynb | 233 --------- .../tabular_classification.ipynb | 105 +++- .../text_classification.ipynb | 42 +- 8 files changed, 562 insertions(+), 744 deletions(-) rename flash_notebooks/{finetuning => }/image_classification.ipynb (92%) delete mode 100644 flash_notebooks/predict/classify_image.ipynb delete mode 100644 flash_notebooks/predict/classify_tabular.ipynb delete mode 100644 flash_notebooks/predict/classify_text.ipynb rename flash_notebooks/{finetuning => }/tabular_classification.ipynb (78%) rename flash_notebooks/{finetuning => }/text_classification.ipynb (94%) diff --git a/.github/workflows/ci-notebook.yml b/.github/workflows/ci-notebook.yml index daa5ec4e40..fce2cf21b8 100644 --- a/.github/workflows/ci-notebook.yml +++ b/.github/workflows/ci-notebook.yml @@ -61,10 +61,8 @@ jobs: - name: Run Notebooks run: | - jupyter nbconvert --to script flash_notebooks/finetuning/tabular_classification.ipynb - jupyter nbconvert --to script flash_notebooks/predict/classify_image.ipynb - jupyter nbconvert --to script flash_notebooks/predict/classify_tabular.ipynb + jupyter nbconvert --to script flash_notebooks/image_classification.ipynb + jupyter nbconvert --to script flash_notebooks/tabular_classification.ipynb - ipython flash_notebooks/finetuning/tabular_classification.py - ipython flash_notebooks/predict/classify_image.py - ipython flash_notebooks/predict/classify_tabular.py + ipython flash_notebooks/image_classification.py + ipython flash_notebooks/tabular_classification.py diff --git a/flash_notebooks/generic_task.ipynb b/flash_notebooks/generic_task.ipynb index f7886b39e0..f0133b429c 100644 --- a/flash_notebooks/generic_task.ipynb +++ b/flash_notebooks/generic_task.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "improved-texture", + "id": "unlike-price", "metadata": {}, "source": [ "\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "integral-japanese", + "id": "requested-hostel", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by creating a ClassificationTask with a custom Convolutional Model and train it on [MNIST Dataset](http://yann.lecun.com/exdb/mnist/)\n", @@ -24,10 +24,18 @@ " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" ] }, + { + "cell_type": "markdown", + "id": "historic-cowboy", + "metadata": {}, + "source": [ + "# Training" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "serial-penetration", + "id": "intermediate-rebecca", "metadata": {}, "outputs": [], "source": [ @@ -37,10 +45,18 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "tough-american", + "execution_count": 1, + "id": "proof-plenty", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "PyTorch version 1.7.1 available.\n" + ] + } + ], "source": [ "import pytorch_lightning as pl\n", "from torch import nn, optim\n", @@ -52,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "regulated-memory", + "id": "smart-factor", "metadata": {}, "source": [ "### 1. Load a basic backbone" @@ -60,8 +76,8 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "external-indonesia", + "execution_count": 2, + "id": "entitled-opera", "metadata": {}, "outputs": [], "source": [ @@ -75,7 +91,7 @@ }, { "cell_type": "markdown", - "id": "artificial-carpet", + "id": "restricted-tooth", "metadata": {}, "source": [ "### 2. Load a dataset" @@ -83,17 +99,122 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "sensitive-inquiry", + "execution_count": 3, + "id": "polish-duncan", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cef56b96ec38400b8d8b2acadcd20f58", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7b13a282cce148f9ac5946aa949935ec", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2b79c71d03464940a48de53bda953f82", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6d1a06fc73f34a988a333f386d4fed67", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", + "Processing...\n", + "Done!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/thomas/Documents/GitHub/lightning-flash/.venv/lib/python3.6/site-packages/torchvision/datasets/mnist.py:480: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:141.)\n", + " return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n" + ] + } + ], "source": [ "dataset = datasets.MNIST('./data', download=True, transform=transforms.ToTensor())" ] }, { "cell_type": "markdown", - "id": "hollow-deployment", + "id": "starting-edmonton", "metadata": {}, "source": [ "### 3. Split the data randomly" @@ -101,17 +222,17 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "weighted-mathematics", + "execution_count": 12, + "id": "indonesian-arlington", "metadata": {}, "outputs": [], "source": [ - "train, val, test = random_split(dataset, [50000, 5000, 5000])" + "train, val, test, predict = random_split(dataset, [50000, 5000, 4975, 25])" ] }, { "cell_type": "markdown", - "id": "changed-calculator", + "id": "configured-bones", "metadata": {}, "source": [ "### 4. Create the model" @@ -119,8 +240,8 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "hungarian-beads", + "execution_count": 5, + "id": "fleet-breast", "metadata": {}, "outputs": [], "source": [ @@ -129,7 +250,7 @@ }, { "cell_type": "markdown", - "id": "mature-trance", + "id": "vulnerable-shirt", "metadata": {}, "source": [ "### 5. Create the trainer" @@ -137,10 +258,19 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "designed-community", + "execution_count": 6, + "id": "assigned-bahamas", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: False, used: False\n", + "TPU available: None, using: 0 TPU cores\n" + ] + } + ], "source": [ "trainer = pl.Trainer(\n", " max_epochs=10,\n", @@ -151,7 +281,7 @@ }, { "cell_type": "markdown", - "id": "wound-species", + "id": "unavailable-sodium", "metadata": {}, "source": [ "### 6. Train the model" @@ -159,17 +289,222 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "commercial-empty", + "execution_count": 7, + "id": "grave-complaint", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "---------------------------------------\n", + "0 | model | Sequential | 101 K \n", + "1 | metrics | ModuleDict | 0 \n", + "---------------------------------------\n", + "101 K Trainable params\n", + "0 Non-trainable params\n", + "101 K Total params\n", + "0.407 Total estimated model params size (MB)\n", + "/Users/thomas/Documents/GitHub/lightning-flash/.venv/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " warnings.warn(*args, **kwargs)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation sanity check: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/thomas/Documents/GitHub/lightning-flash/.venv/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " warnings.warn(*args, **kwargs)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "95bfb98e9b7b4b5ca6d8cb94f3825e15", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "trainer.fit(classifier, DataLoader(train), DataLoader(val))" ] }, { "cell_type": "markdown", - "id": "great-intent", + "id": "excellent-detail", "metadata": {}, "source": [ "### 7. Test the model" @@ -177,17 +512,79 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "extreme-principal", + "execution_count": 8, + "id": "sized-string", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/thomas/Documents/GitHub/lightning-flash/.venv/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: The dataloader, test dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " warnings.warn(*args, **kwargs)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0c7202f53c22480a8ec2b9bbdbdde631", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Testing: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--------------------------------------------------------------------------------\n", + "DATALOADER:0 TEST RESULTS\n", + "{'test_cross_entropy': 1.5057185888290405}\n", + "--------------------------------------------------------------------------------\n" + ] + } + ], "source": [ "results = trainer.test(classifier, test_dataloaders=DataLoader(test))" ] }, { "cell_type": "markdown", - "id": "bulgarian-cursor", + "id": "endless-contrary", + "metadata": {}, + "source": [ + "# Predicting" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "informative-arnold", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[9, 1, 8, 5, 5, 6, 6, 3, 3, 5, 5, 3, 8, 1, 2, 7, 3, 9, 8, 1, 4, 3, 8, 0, 3]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "classifier.predict(predict)" + ] + }, + { + "cell_type": "markdown", + "id": "dimensional-breakfast", "metadata": {}, "source": [ "\n", diff --git a/flash_notebooks/finetuning/image_classification.ipynb b/flash_notebooks/image_classification.ipynb similarity index 92% rename from flash_notebooks/finetuning/image_classification.ipynb rename to flash_notebooks/image_classification.ipynb index 4cd82ec404..ea39880573 100644 --- a/flash_notebooks/finetuning/image_classification.ipynb +++ b/flash_notebooks/image_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "thousand-manufacturer", + "id": "western-queue", "metadata": {}, "source": [ "\n", @@ -12,10 +12,10 @@ }, { "cell_type": "markdown", - "id": "smoking-probe", + "id": "democratic-alpha", "metadata": {}, "source": [ - "In this notebook, we'll go over the basics of lightning Flash by finetuning an ImageClassifier on [Hymenoptera Dataset](https://www.kaggle.com/ajayrana/hymenoptera-data) containing ants and bees images.\n", + "In this notebook, we'll go over the basics of lightning Flash by finetuning/predictin with an ImageClassifier on [Hymenoptera Dataset](https://www.kaggle.com/ajayrana/hymenoptera-data) containing ants and bees images.\n", "\n", "# Finetuning\n", "\n", @@ -27,7 +27,7 @@ " \n", " - 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", " \n", - " - 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `pytorch_lightning.callbacks.BaseFinetuning`.\n", + " - 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is subclassing `pytorch_lightning.callbacks.BaseFinetuning`.\n", " \n", " \n", "\n", @@ -43,7 +43,7 @@ { "cell_type": "code", "execution_count": null, - "id": "thermal-fraction", + "id": "parallel-integrity", "metadata": {}, "outputs": [], "source": [ @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cognitive-haven", + "id": "worth-wealth", "metadata": {}, "outputs": [], "source": [ @@ -65,7 +65,7 @@ }, { "cell_type": "markdown", - "id": "afraid-straight", + "id": "frequent-memorial", "metadata": {}, "source": [ "## 1. Download data\n", @@ -75,7 +75,7 @@ { "cell_type": "code", "execution_count": null, - "id": "advisory-narrow", + "id": "planned-greene", "metadata": {}, "outputs": [], "source": [ @@ -84,7 +84,7 @@ }, { "cell_type": "markdown", - "id": "trying-group", + "id": "changed-perry", "metadata": {}, "source": [ "

2. Load the data

\n", @@ -107,7 +107,7 @@ { "cell_type": "code", "execution_count": null, - "id": "stuck-composition", + "id": "synthetic-hamburg", "metadata": {}, "outputs": [], "source": [ @@ -120,7 +120,7 @@ }, { "cell_type": "markdown", - "id": "irish-scenario", + "id": "common-testing", "metadata": {}, "source": [ "### 3. Build the model\n", @@ -132,7 +132,7 @@ { "cell_type": "code", "execution_count": null, - "id": "opening-nomination", + "id": "religious-pasta", "metadata": {}, "outputs": [], "source": [ @@ -141,7 +141,7 @@ }, { "cell_type": "markdown", - "id": "breathing-element", + "id": "accurate-thread", "metadata": {}, "source": [ "### 4. Create the trainer. Run once on data\n", @@ -158,7 +158,7 @@ { "cell_type": "code", "execution_count": null, - "id": "earlier-jordan", + "id": "rural-silly", "metadata": {}, "outputs": [], "source": [ @@ -167,7 +167,7 @@ }, { "cell_type": "markdown", - "id": "extreme-scene", + "id": "trained-unemployment", "metadata": {}, "source": [ "### 5. Finetune the model" @@ -176,7 +176,7 @@ { "cell_type": "code", "execution_count": null, - "id": "tired-underground", + "id": "bound-printer", "metadata": {}, "outputs": [], "source": [ @@ -185,7 +185,7 @@ }, { "cell_type": "markdown", - "id": "smooth-european", + "id": "european-incentive", "metadata": {}, "source": [ "### 6. Test the model" @@ -194,7 +194,7 @@ { "cell_type": "code", "execution_count": null, - "id": "sexual-tender", + "id": "ceramic-dress", "metadata": {}, "outputs": [], "source": [ @@ -203,7 +203,7 @@ }, { "cell_type": "markdown", - "id": "athletic-nutrition", + "id": "cheap-residence", "metadata": {}, "source": [ "### 7. Save it!" @@ -212,7 +212,7 @@ { "cell_type": "code", "execution_count": null, - "id": "pleasant-canon", + "id": "micro-favor", "metadata": {}, "outputs": [], "source": [ @@ -221,7 +221,7 @@ }, { "cell_type": "markdown", - "id": "incident-basket", + "id": "associate-demonstration", "metadata": {}, "source": [ "\n", diff --git a/flash_notebooks/predict/classify_image.ipynb b/flash_notebooks/predict/classify_image.ipynb deleted file mode 100644 index 6e91662d00..0000000000 --- a/flash_notebooks/predict/classify_image.ipynb +++ /dev/null @@ -1,227 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "renewable-brooks", - "metadata": {}, - "source": [ - "
\n", - " \"Open\n", - "" - ] - }, - { - "cell_type": "markdown", - "id": "norwegian-ivory", - "metadata": {}, - "source": [ - "In this notebook, we'll go over the basics of lightning Flash for making predictions with ImageClassifier on [Hymenoptera Dataset](https://www.kaggle.com/ajayrana/hymenoptera-data) containing ants and bees images.\n", - "\n", - "---\n", - " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", - " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", - " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n", - " - Find finetuning notebook used to generate the weights \n", - " \"Open\n", - "" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "rural-gabriel", - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "! pip install lightning-flash" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "tight-villa", - "metadata": {}, - "outputs": [], - "source": [ - "from flash import Trainer\n", - "from flash.core.data import download_data\n", - "from flash.vision import ImageClassificationData, ImageClassifier" - ] - }, - { - "cell_type": "markdown", - "id": "prescribed-works", - "metadata": {}, - "source": [ - "### 1. Download the data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "adequate-population", - "metadata": {}, - "outputs": [], - "source": [ - "download_data(\"https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip\", 'data/')" - ] - }, - { - "cell_type": "markdown", - "id": "latest-designer", - "metadata": {}, - "source": [ - "### 2. Load the model from a checkpoint\n", - "\n", - "`ImageClassifier.load_from_checkpoint` supports both url or local_path to a checkpoint. If provided with an url, the checkpoint will first be downloaded and laoded to re-create the model. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "liked-stomach", - "metadata": {}, - "outputs": [], - "source": [ - "model = ImageClassifier.load_from_checkpoint(\"https://flash-weights.s3.amazonaws.com/image_classification_model.pt\")" - ] - }, - { - "cell_type": "markdown", - "id": "funky-concentrate", - "metadata": {}, - "source": [ - "### 3a. Predict what's on a few images! ants or bees?\n", - "\n", - "`ImageClassifier.predict` supports a list of image paths to make an inference on." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "instructional-purse", - "metadata": {}, - "outputs": [], - "source": [ - "predictions = model.predict([\n", - " \"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg\",\n", - " \"data/hymenoptera_data/val/bees/590318879_68cf112861.jpg\",\n", - " \"data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg\",\n", - "])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "returning-richmond", - "metadata": {}, - "outputs": [], - "source": [ - "print(predictions)" - ] - }, - { - "cell_type": "markdown", - "id": "competent-diagram", - "metadata": {}, - "source": [ - "### 3b. Or generate predictions with a whole folder!\n", - "\n", - "For scaling for inference on 32 gpus, it is as simple as `Trainer(num_gpus=32).predict(...)`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "numeric-torture", - "metadata": {}, - "outputs": [], - "source": [ - "datamodule = ImageClassificationData.from_folder(folder=\"data/hymenoptera_data/predict/\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "junior-blink", - "metadata": {}, - "outputs": [], - "source": [ - "predictions = Trainer().predict(model, datamodule=datamodule)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "third-charlotte", - "metadata": {}, - "outputs": [], - "source": [ - "print(predictions)" - ] - }, - { - "cell_type": "markdown", - "id": "devoted-injection", - "metadata": {}, - "source": [ - "\n", - "

Congratulations - Time to Join the Community!

\n", - "
\n", - "\n", - "Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the Lightning movement, you can do so in the following ways!\n", - "\n", - "### Help us build Flash by adding support for new data-types and new tasks.\n", - "Flash aims at becoming the first task hub, so anyone can get started to great amazing application using deep learning. \n", - "If you are interested, please open a PR with your contributions !!! \n", - "\n", - "\n", - "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", - "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", - "\n", - "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", - "\n", - "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", - "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", - "\n", - "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", - "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", - "\n", - "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", - "\n", - "### Contributions !\n", - "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", - "\n", - "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", - "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", - "* You can also contribute your own notebooks with useful examples !\n", - "\n", - "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", - "\n", - "" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/flash_notebooks/predict/classify_tabular.ipynb b/flash_notebooks/predict/classify_tabular.ipynb deleted file mode 100644 index 9da5e6be09..0000000000 --- a/flash_notebooks/predict/classify_tabular.ipynb +++ /dev/null @@ -1,184 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "impaired-nightlife", - "metadata": {}, - "source": [ - "\n", - " \"Open\n", - "" - ] - }, - { - "cell_type": "markdown", - "id": "developed-farming", - "metadata": {}, - "source": [ - "In this notebook, we'll go over the basics of lightning Flash for making predictions with TabularClassifier on [Titanic Dataset](https://www.kaggle.com/c/titanic).\n", - "\n", - "---\n", - " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", - " - Check out [Flash documentation](https://lightning-flash.readthedocs.io/en/latest/)\n", - " - Check out [Lightning documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", - " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n", - " - Find finetuning notebook used to generate the weights \n", - " \"Open\n", - "" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "linear-consent", - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "! pip install lightning-flash" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "skilled-fifty", - "metadata": {}, - "outputs": [], - "source": [ - "from flash.core.data import download_data\n", - "from flash.tabular import TabularClassifier" - ] - }, - { - "cell_type": "markdown", - "id": "excessive-temperature", - "metadata": {}, - "source": [ - "### 1. Download the data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "younger-material", - "metadata": {}, - "outputs": [], - "source": [ - "download_data(\"https://pl-flash-data.s3.amazonaws.com/titanic.zip\", 'data/')" - ] - }, - { - "cell_type": "markdown", - "id": "military-supplier", - "metadata": {}, - "source": [ - "### 2. Load the model from a checkpoint\n", - "\n", - "`TabularClassifier.load_from_checkpoint` supports both url or local_path to a checkpoint. If provided with an url, the checkpoint will first be downloaded and laoded to re-create the model. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "guilty-relation", - "metadata": {}, - "outputs": [], - "source": [ - "model = TabularClassifier.load_from_checkpoint(\n", - " \"https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt\")" - ] - }, - { - "cell_type": "markdown", - "id": "cheap-regular", - "metadata": {}, - "source": [ - "### 3. Generate predictions from a sheet file! Who would survive?\n", - "\n", - "`TabularClassifier.predict` support both DataFrame and path to `.csv` file." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "forward-store", - "metadata": {}, - "outputs": [], - "source": [ - "predictions = model.predict(\"data/titanic/titanic.csv\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "healthy-supervision", - "metadata": {}, - "outputs": [], - "source": [ - "print(predictions)" - ] - }, - { - "cell_type": "markdown", - "id": "atlantic-catalyst", - "metadata": {}, - "source": [ - "\n", - "

Congratulations - Time to Join the Community!

\n", - "
\n", - "\n", - "Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the Lightning movement, you can do so in the following ways!\n", - "\n", - "### Help us build Flash by adding support for new data-types and new tasks.\n", - "Flash aims at becoming the first task hub, so anyone can get started to great amazing application using deep learning. \n", - "If you are interested, please open a PR with your contributions !!! \n", - "\n", - "\n", - "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", - "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", - "\n", - "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", - "\n", - "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", - "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", - "\n", - "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", - "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", - "\n", - "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", - "\n", - "### Contributions !\n", - "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", - "\n", - "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", - "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", - "* You can also contribute your own notebooks with useful examples !\n", - "\n", - "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", - "\n", - "" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/flash_notebooks/predict/classify_text.ipynb b/flash_notebooks/predict/classify_text.ipynb deleted file mode 100644 index d6c5b92eb0..0000000000 --- a/flash_notebooks/predict/classify_text.ipynb +++ /dev/null @@ -1,233 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "prime-swedish", - "metadata": {}, - "source": [ - "\n", - " \"Open\n", - "" - ] - }, - { - "cell_type": "markdown", - "id": "rubber-north", - "metadata": {}, - "source": [ - "In this notebook, we'll go over the basics of lightning Flash for making predictions with TextClassifier on [IMDB Dataset](https://www.imdb.com/interfaces/).(https://www.kaggle.com/ajayrana/hymenoptera-data).\n", - "\n", - "---\n", - " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", - " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", - " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n", - " - Find finetuning notebook used to generate the weights \n", - " \"Open\n", - "" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "floppy-horror", - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "! pip install lightning-flash" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "western-stick", - "metadata": {}, - "outputs": [], - "source": [ - "from pytorch_lightning import Trainer\n", - "\n", - "from flash.core.data import download_data\n", - "from flash.text import TextClassificationData, TextClassifier" - ] - }, - { - "cell_type": "markdown", - "id": "cubic-directive", - "metadata": {}, - "source": [ - "### 1. Download the data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "partial-colonial", - "metadata": {}, - "outputs": [], - "source": [ - "download_data(\"https://pl-flash-data.s3.amazonaws.com/imdb.zip\", 'data/')" - ] - }, - { - "cell_type": "markdown", - "id": "frozen-performance", - "metadata": {}, - "source": [ - "### 2. Load the model from a checkpoint\n", - "\n", - "`TextClassifier.load_from_checkpoint` supports both url or local_path to a checkpoint. If provided with an url, the checkpoint will first be downloaded and laoded to re-create the model. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "continental-haven", - "metadata": {}, - "outputs": [], - "source": [ - "model = TextClassifier.load_from_checkpoint(\"https://flash-weights.s3.amazonaws.com/text_classification_model.pt\")" - ] - }, - { - "cell_type": "markdown", - "id": "gorgeous-effect", - "metadata": {}, - "source": [ - "### 2a. Classify a few sentences! How was the movie?\n", - "\n", - "The model can perform sentimennt predictions directly from a list of sentences." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "silent-atmosphere", - "metadata": {}, - "outputs": [], - "source": [ - "predictions = model.predict([\n", - " \"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.\",\n", - " \"The worst movie in the history of cinema.\",\n", - " \"I come from Bulgaria where it 's almost impossible to have a tornado.\"\n", - " \"Very, very afraid\"\n", - " \"This guy has done a great job with this movie!\",\n", - "])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "earned-country", - "metadata": {}, - "outputs": [], - "source": [ - "print(predictions)" - ] - }, - { - "cell_type": "markdown", - "id": "adolescent-taxation", - "metadata": {}, - "source": [ - "### 2b. Or generate predictions from a sheet file!\n", - "\n", - "For scaling for inference on 32 gpus, it is as simple as `Trainer(num_gpus=32).predict(...)`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "standing-stamp", - "metadata": {}, - "outputs": [], - "source": [ - "datamodule = TextClassificationData.from_file(\n", - " predict_file=\"data/imdb/predict.csv\",\n", - " input=\"review\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "czech-essay", - "metadata": {}, - "outputs": [], - "source": [ - "predictions = Trainer().predict(model, datamodule=datamodule)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "impossible-sharp", - "metadata": {}, - "outputs": [], - "source": [ - "print(predictions)" - ] - }, - { - "cell_type": "markdown", - "id": "metric-coaching", - "metadata": {}, - "source": [ - "\n", - "

Congratulations - Time to Join the Community!

\n", - "
\n", - "\n", - "Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the Lightning movement, you can do so in the following ways!\n", - "\n", - "### Help us build Flash by adding support for new data-types and new tasks.\n", - "Flash aims at becoming the first task hub, so anyone can get started to great amazing application using deep learning. \n", - "If you are interested, please open a PR with your contributions !!! \n", - "\n", - "\n", - "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", - "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", - "\n", - "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", - "\n", - "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", - "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", - "\n", - "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", - "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", - "\n", - "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", - "\n", - "### Contributions !\n", - "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", - "\n", - "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", - "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", - "* You can also contribute your own notebooks with useful examples !\n", - "\n", - "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", - "\n", - "" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/flash_notebooks/finetuning/tabular_classification.ipynb b/flash_notebooks/tabular_classification.ipynb similarity index 78% rename from flash_notebooks/finetuning/tabular_classification.ipynb rename to flash_notebooks/tabular_classification.ipynb index bdf57fd18b..4401e7e071 100644 --- a/flash_notebooks/finetuning/tabular_classification.ipynb +++ b/flash_notebooks/tabular_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "incorrect-shield", + "id": "stunning-event", "metadata": {}, "source": [ "\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "subtle-rogers", + "id": "fatty-sudan", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by training a TabularClassifier on [Titanic Dataset](https://www.kaggle.com/c/titanic).\n", @@ -24,10 +24,18 @@ " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" ] }, + { + "cell_type": "markdown", + "id": "organizational-savage", + "metadata": {}, + "source": [ + "# Training" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "early-matter", + "id": "affected-cuisine", "metadata": {}, "outputs": [], "source": [ @@ -38,7 +46,7 @@ { "cell_type": "code", "execution_count": null, - "id": "limiting-beach", + "id": "editorial-contribution", "metadata": {}, "outputs": [], "source": [ @@ -51,7 +59,7 @@ }, { "cell_type": "markdown", - "id": "female-abortion", + "id": "parliamentary-cookbook", "metadata": {}, "source": [ "### 1. Download the data\n", @@ -61,7 +69,7 @@ { "cell_type": "code", "execution_count": null, - "id": "billion-rabbit", + "id": "black-removal", "metadata": {}, "outputs": [], "source": [ @@ -70,7 +78,7 @@ }, { "cell_type": "markdown", - "id": "documentary-tourism", + "id": "chubby-spell", "metadata": {}, "source": [ "### 2. Load the data\n", @@ -82,7 +90,7 @@ { "cell_type": "code", "execution_count": null, - "id": "specialized-wallet", + "id": "academic-prime", "metadata": {}, "outputs": [], "source": [ @@ -98,7 +106,7 @@ }, { "cell_type": "markdown", - "id": "contrary-brook", + "id": "promotional-shelf", "metadata": {}, "source": [ "### 3. Build the model\n", @@ -109,7 +117,7 @@ { "cell_type": "code", "execution_count": null, - "id": "liable-slave", + "id": "minus-easter", "metadata": {}, "outputs": [], "source": [ @@ -118,7 +126,7 @@ }, { "cell_type": "markdown", - "id": "large-westminster", + "id": "handy-complex", "metadata": {}, "source": [ "### 4. Create the trainer. Run 10 times on data" @@ -127,7 +135,7 @@ { "cell_type": "code", "execution_count": null, - "id": "whole-index", + "id": "liberal-smooth", "metadata": {}, "outputs": [], "source": [ @@ -136,7 +144,7 @@ }, { "cell_type": "markdown", - "id": "graduate-shopping", + "id": "israeli-interstate", "metadata": {}, "source": [ "### 5. Train the model" @@ -145,7 +153,7 @@ { "cell_type": "code", "execution_count": null, - "id": "weighted-gibraltar", + "id": "pleased-retail", "metadata": {}, "outputs": [], "source": [ @@ -154,7 +162,7 @@ }, { "cell_type": "markdown", - "id": "helpful-christianity", + "id": "covered-completion", "metadata": {}, "source": [ "### 6. Test model" @@ -163,7 +171,7 @@ { "cell_type": "code", "execution_count": null, - "id": "laughing-mention", + "id": "heard-ability", "metadata": {}, "outputs": [], "source": [ @@ -172,7 +180,7 @@ }, { "cell_type": "markdown", - "id": "failing-heaven", + "id": "humanitarian-geography", "metadata": {}, "source": [ "### 7. Save it!" @@ -181,7 +189,7 @@ { "cell_type": "code", "execution_count": null, - "id": "touched-burner", + "id": "ongoing-flooring", "metadata": {}, "outputs": [], "source": [ @@ -190,7 +198,66 @@ }, { "cell_type": "markdown", - "id": "previous-corporation", + "id": "dense-integral", + "metadata": {}, + "source": [ + "# Predicting" + ] + }, + { + "cell_type": "markdown", + "id": "conventional-accommodation", + "metadata": {}, + "source": [ + "### 8. Load the model from a checkpoint\n", + "\n", + "`TabularClassifier.load_from_checkpoint` supports both url or local_path to a checkpoint. If provided with an url, the checkpoint will first be downloaded and laoded to re-create the model. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "former-paradise", + "metadata": {}, + "outputs": [], + "source": [ + "model = TabularClassifier.load_from_checkpoint(\n", + " \"https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt\")" + ] + }, + { + "cell_type": "markdown", + "id": "empirical-growing", + "metadata": {}, + "source": [ + "### 9. Generate predictions from a sheet file! Who would survive?\n", + "\n", + "`TabularClassifier.predict` support both DataFrame and path to `.csv` file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "residential-absorption", + "metadata": {}, + "outputs": [], + "source": [ + "predictions = model.predict(\"data/titanic/titanic.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "trying-client", + "metadata": {}, + "outputs": [], + "source": [ + "print(predictions)" + ] + }, + { + "cell_type": "markdown", + "id": "offensive-latest", "metadata": {}, "source": [ "\n", diff --git a/flash_notebooks/finetuning/text_classification.ipynb b/flash_notebooks/text_classification.ipynb similarity index 94% rename from flash_notebooks/finetuning/text_classification.ipynb rename to flash_notebooks/text_classification.ipynb index 34411232aa..411886c941 100644 --- a/flash_notebooks/finetuning/text_classification.ipynb +++ b/flash_notebooks/text_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "optical-barrel", + "id": "satellite-bidding", "metadata": {}, "source": [ "\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "rolled-scoop", + "id": "minute-father", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by finetunig a TextClassifier on [IMDB Dataset](https://www.imdb.com/interfaces/).\n", @@ -31,7 +31,7 @@ "- 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head, will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", " \n", "\n", - "- 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If a one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is sub-classing `pytorch_lightning.callbacks.BaseFinetuning`.\n", + "- 4. Train the target model on a target dataset, such as Hymenoptera Dataset with ants and bees. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If a one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is subclassing `pytorch_lightning.callbacks.BaseFinetuning`.\n", "\n", "---\n", " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", @@ -42,7 +42,7 @@ }, { "cell_type": "markdown", - "id": "pleasant-benchmark", + "id": "recent-footwear", "metadata": {}, "source": [ "### Setup \n", @@ -52,7 +52,7 @@ { "cell_type": "code", "execution_count": null, - "id": "suspended-announcement", + "id": "proprietary-sheriff", "metadata": {}, "outputs": [], "source": [ @@ -63,7 +63,7 @@ { "cell_type": "code", "execution_count": null, - "id": "appreciated-internship", + "id": "signal-doctrine", "metadata": {}, "outputs": [], "source": [ @@ -74,7 +74,7 @@ }, { "cell_type": "markdown", - "id": "excessive-private", + "id": "sweet-insight", "metadata": {}, "source": [ "### 1. Download the data\n", @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "noted-father", + "id": "prescribed-circuit", "metadata": {}, "outputs": [], "source": [ @@ -93,7 +93,7 @@ }, { "cell_type": "markdown", - "id": "naval-rogers", + "id": "appointed-syndicate", "metadata": {}, "source": [ "

2. Load the data

\n", @@ -105,7 +105,7 @@ { "cell_type": "code", "execution_count": null, - "id": "monetary-album", + "id": "suffering-sacramento", "metadata": {}, "outputs": [], "source": [ @@ -121,7 +121,7 @@ }, { "cell_type": "markdown", - "id": "published-vision", + "id": "leading-latitude", "metadata": { "jupyter": { "outputs_hidden": true @@ -138,7 +138,7 @@ { "cell_type": "code", "execution_count": null, - "id": "focused-claim", + "id": "educational-toner", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "primary-battery", + "id": "limiting-iceland", "metadata": { "jupyter": { "outputs_hidden": true @@ -160,7 +160,7 @@ { "cell_type": "code", "execution_count": null, - "id": "great-austria", + "id": "potential-hypothesis", "metadata": {}, "outputs": [], "source": [ @@ -169,7 +169,7 @@ }, { "cell_type": "markdown", - "id": "corporate-sequence", + "id": "special-costume", "metadata": { "jupyter": { "outputs_hidden": true @@ -184,7 +184,7 @@ { "cell_type": "code", "execution_count": null, - "id": "opponent-visit", + "id": "lined-phoenix", "metadata": {}, "outputs": [], "source": [ @@ -193,7 +193,7 @@ }, { "cell_type": "markdown", - "id": "sunrise-questionnaire", + "id": "derived-haven", "metadata": { "jupyter": { "outputs_hidden": true @@ -206,7 +206,7 @@ { "cell_type": "code", "execution_count": null, - "id": "certain-pizza", + "id": "bearing-israel", "metadata": {}, "outputs": [], "source": [ @@ -215,7 +215,7 @@ }, { "cell_type": "markdown", - "id": "loose-march", + "id": "enormous-resort", "metadata": { "jupyter": { "outputs_hidden": true @@ -228,7 +228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "loose-culture", + "id": "caroline-jewelry", "metadata": {}, "outputs": [], "source": [ @@ -237,7 +237,7 @@ }, { "cell_type": "markdown", - "id": "quarterly-dominican", + "id": "arctic-directive", "metadata": {}, "source": [ "\n",