diff --git a/flash_notebooks/finetuning/finetuning_image_classification.ipynb b/flash_notebooks/finetuning/finetuning_image_classification.ipynb new file mode 100644 index 0000000000..119f621a81 --- /dev/null +++ b/flash_notebooks/finetuning/finetuning_image_classification.ipynb @@ -0,0 +1,167 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "fatal-doctor", + "metadata": {}, + "outputs": [], + "source": [ + "import flash\n", + "from flash.core.data import download_data\n", + "from flash.vision import ImageClassificationData, ImageClassifier" + ] + }, + { + "cell_type": "markdown", + "id": "important-underground", + "metadata": {}, + "source": [ + "# 1. Download the data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "celtic-publisher", + "metadata": {}, + "outputs": [], + "source": [ + "download_data(\"https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip\", 'data/')" + ] + }, + { + "cell_type": "markdown", + "id": "atmospheric-batch", + "metadata": {}, + "source": [ + "# 2. Load the data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "broadband-massachusetts", + "metadata": {}, + "outputs": [], + "source": [ + "datamodule = ImageClassificationData.from_folders(\n", + " train_folder=\"data/hymenoptera_data/train/\",\n", + " valid_folder=\"data/hymenoptera_data/val/\",\n", + " test_folder=\"data/hymenoptera_data/test/\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "biblical-boundary", + "metadata": {}, + "source": [ + "# 3. Build the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "upper-staff", + "metadata": {}, + "outputs": [], + "source": [ + "model = ImageClassifier(num_classes=datamodule.num_classes)" + ] + }, + { + "cell_type": "markdown", + "id": "sixth-lancaster", + "metadata": {}, + "source": [ + "# 4. Create the trainer. Run once on data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "generous-paraguay", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = flash.Trainer(max_epochs=1)" + ] + }, + { + "cell_type": "markdown", + "id": "physical-prophet", + "metadata": {}, + "source": [ + "# 5. Train the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "oriented-hudson", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))" + ] + }, + { + "cell_type": "markdown", + "id": "sexual-puzzle", + "metadata": {}, + "source": [ + "# 6. Test the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eight-russell", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.test()" + ] + }, + { + "cell_type": "markdown", + "id": "ultimate-halloween", + "metadata": {}, + "source": [ + "# 7. Save it!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "literary-establishment", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.save_checkpoint(\"image_classification_model.pt\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/flash_notebooks/finetuning/finetuning_tabular_classification.ipynb b/flash_notebooks/finetuning/finetuning_tabular_classification.ipynb new file mode 100644 index 0000000000..fc1d67dae4 --- /dev/null +++ b/flash_notebooks/finetuning/finetuning_tabular_classification.ipynb @@ -0,0 +1,279 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "id": "pacific-unemployment", + "metadata": {}, + "outputs": [], + "source": [ + "from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall\n", + "\n", + "import flash\n", + "from flash.core.data import download_data\n", + "from flash.tabular import TabularClassifier, TabularData" + ] + }, + { + "cell_type": "markdown", + "id": "political-guitar", + "metadata": {}, + "source": [ + "# 1. Download the data" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "romance-penalty", + "metadata": {}, + "outputs": [], + "source": [ + "download_data(\"https://pl-flash-data.s3.amazonaws.com/titanic.zip\", 'data/')" + ] + }, + { + "cell_type": "markdown", + "id": "atmospheric-justice", + "metadata": {}, + "source": [ + "# 2. Load the data" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "unauthorized-differential", + "metadata": {}, + "outputs": [], + "source": [ + "datamodule = TabularData.from_csv(\n", + " \"./data/titanic/titanic.csv\",\n", + " test_csv=\"./data/titanic/test.csv\",\n", + " categorical_input=[\"Sex\", \"Age\", \"SibSp\", \"Parch\", \"Ticket\", \"Cabin\", \"Embarked\"],\n", + " numerical_input=[\"Fare\"],\n", + " target=\"Survived\",\n", + " val_size=0.25,\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "id": "burning-carroll", + "metadata": {}, + "source": [ + "# 3. Build the model" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "coupled-hindu", + "metadata": {}, + "outputs": [], + "source": [ + "model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()])" + ] + }, + { + "cell_type": "markdown", + "id": "civil-greeting", + "metadata": {}, + "source": [ + "# 4. Create the trainer. Run 10 times on data" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "processed-congo", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: False, used: False\n", + "TPU available: None, using: 0 TPU cores\n" + ] + } + ], + "source": [ + "trainer = flash.Trainer(max_epochs=10)" + ] + }, + { + "cell_type": "markdown", + "id": "contained-sequence", + "metadata": {}, + "source": [ + "# 5. Train the model" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "changed-harvest", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params\n", + "----------------------------------------\n", + "0 | model | Sequential | 59.1 K\n", + "1 | metrics | ModuleDict | 0 \n", + "2 | embs | ModuleList | 15.1 K\n", + "3 | bn_num | BatchNorm1d | 2 \n", + "----------------------------------------\n", + "74.2 K Trainable params\n", + "0 Non-trainable params\n", + "74.2 K Total params\n", + "0.297 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation sanity check: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f826a8a8b79a4ec29a9bea8a216fd11c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.fit(model, datamodule=datamodule)" + ] + }, + { + "cell_type": "markdown", + "id": "defensive-studio", + "metadata": {}, + "source": [ + "# 6. Test model" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "injured-andrew", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f79f25a5991d4fb4be9e4653b25821c5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Testing: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--------------------------------------------------------------------------------\n", + "DATALOADER:0 TEST RESULTS\n", + "{'test_accuracy': 0.7333333492279053,\n", + " 'test_cross_entropy': 0.6910470128059387,\n", + " 'test_precision': 0.7333333492279053,\n", + " 'test_recall': 0.7333333492279053}\n", + "--------------------------------------------------------------------------------\n" + ] + }, + { + "data": { + "text/plain": [ + "[{'test_accuracy': 0.7333333492279053,\n", + " 'test_precision': 0.7333333492279053,\n", + " 'test_recall': 0.7333333492279053,\n", + " 'test_cross_entropy': 0.6910470128059387}]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.test()" + ] + }, + { + "cell_type": "markdown", + "id": "concrete-continent", + "metadata": {}, + "source": [ + "# 7. Save it!" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "opponent-chuck", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.save_checkpoint(\"tabular_classification_model.pt\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/flash_notebooks/finetuning/finetuning_text_classification.ipynb b/flash_notebooks/finetuning/finetuning_text_classification.ipynb new file mode 100644 index 0000000000..05836e701f --- /dev/null +++ b/flash_notebooks/finetuning/finetuning_text_classification.ipynb @@ -0,0 +1,190 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "aggressive-accounting", + "metadata": {}, + "outputs": [], + "source": [ + "import flash\n", + "from flash.core.data import download_data\n", + "from flash.text import TextClassificationData, TextClassifier" + ] + }, + { + "cell_type": "markdown", + "id": "complicated-pitch", + "metadata": {}, + "source": [ + "# 1. Download the data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "temporal-lender", + "metadata": {}, + "outputs": [], + "source": [ + "download_data(\"https://pl-flash-data.s3.amazonaws.com/imdb.zip\", 'data/')" + ] + }, + { + "cell_type": "markdown", + "id": "southern-junior", + "metadata": {}, + "source": [ + "# 2. Load the data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "copyrighted-adjustment", + "metadata": {}, + "outputs": [], + "source": [ + "datamodule = TextClassificationData.from_files(\n", + " train_file=\"data/imdb/train.csv\",\n", + " valid_file=\"data/imdb/valid.csv\",\n", + " test_file=\"data/imdb/test.csv\",\n", + " input=\"review\",\n", + " target=\"sentiment\",\n", + " batch_size=512\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "prime-complexity", + "metadata": { + "jupyter": { + "outputs_hidden": true + } + }, + "source": [ + "# 3. Build the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "international-agriculture", + "metadata": {}, + "outputs": [], + "source": [ + "model = TextClassifier(num_classes=datamodule.num_classes)" + ] + }, + { + "cell_type": "markdown", + "id": "small-programmer", + "metadata": { + "jupyter": { + "outputs_hidden": true + } + }, + "source": [ + "# 4. Create the trainer. Run once on data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fallen-subcommittee", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = flash.Trainer(max_epochs=1)" + ] + }, + { + "cell_type": "markdown", + "id": "miniature-complexity", + "metadata": { + "jupyter": { + "outputs_hidden": true + } + }, + "source": [ + "# 5. Fine-tune the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "critical-palace", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))" + ] + }, + { + "cell_type": "markdown", + "id": "detected-return", + "metadata": { + "jupyter": { + "outputs_hidden": true + } + }, + "source": [ + "# 6. Test model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "american-claim", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.test()" + ] + }, + { + "cell_type": "markdown", + "id": "quantitative-array", + "metadata": { + "jupyter": { + "outputs_hidden": true + } + }, + "source": [ + "# 7. Save it!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aggregate-river", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.save_checkpoint(\"text_classification_model.pt\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/general_model.py b/flash_notebooks/general_model.py similarity index 100% rename from notebooks/general_model.py rename to flash_notebooks/general_model.py diff --git a/notebooks/image-classification.ipynb b/flash_notebooks/image-classification.ipynb similarity index 100% rename from notebooks/image-classification.ipynb rename to flash_notebooks/image-classification.ipynb diff --git a/notebooks/image_classifier.py b/flash_notebooks/image_classifier.py similarity index 100% rename from notebooks/image_classifier.py rename to flash_notebooks/image_classifier.py diff --git a/flash_notebooks/predict/image_classification.ipynb b/flash_notebooks/predict/image_classification.ipynb new file mode 100644 index 0000000000..50a1b13348 --- /dev/null +++ b/flash_notebooks/predict/image_classification.ipynb @@ -0,0 +1,182 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "unlike-brain", + "metadata": {}, + "outputs": [], + "source": [ + "from flash import Trainer\n", + "from flash.core.data import download_data\n", + "from flash.vision import ImageClassificationData, ImageClassifier" + ] + }, + { + "cell_type": "markdown", + "id": "joined-rehabilitation", + "metadata": {}, + "source": [ + "# 1. Download the data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "standing-schema", + "metadata": {}, + "outputs": [], + "source": [ + "download_data(\"https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip\", 'data/')" + ] + }, + { + "cell_type": "markdown", + "id": "threaded-thomas", + "metadata": {}, + "source": [ + "# 2. Load the model from a checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "dedicated-mauritius", + "metadata": {}, + "outputs": [], + "source": [ + "model = ImageClassifier.load_from_checkpoint(\"https://flash-weights.s3.amazonaws.com/image_classification_model.pt\")" + ] + }, + { + "cell_type": "markdown", + "id": "fleet-paraguay", + "metadata": {}, + "source": [ + "# 3a. Predict what's on a few images! ants or bees?" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "regional-cherry", + "metadata": {}, + "outputs": [], + "source": [ + "predictions = model.predict([\n", + " \"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg\",\n", + " \"data/hymenoptera_data/val/bees/590318879_68cf112861.jpg\",\n", + " \"data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg\",\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "under-character", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1, 1, 0]\n" + ] + } + ], + "source": [ + "print(predictions)" + ] + }, + { + "cell_type": "markdown", + "id": "baking-messaging", + "metadata": {}, + "source": [ + "# 3b. Or generate predictions with a whole folder!" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "respective-sensitivity", + "metadata": {}, + "outputs": [], + "source": [ + "datamodule = ImageClassificationData.from_folder(folder=\"data/hymenoptera_data/predict/\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "leading-mailing", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: False, used: False\n", + "TPU available: None, using: 0 TPU cores\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2ace19df45c34b39bc8a305fa14ba1e3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Testing: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "predictions = Trainer().predict(model, datamodule=datamodule)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "imperial-lyric", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[[0, 0, 0, 0, 0, 0, 1, 0, 0, 1]]]\n" + ] + } + ], + "source": [ + "print(predictions)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/flash_notebooks/predict/tabular_classification.ipynb b/flash_notebooks/predict/tabular_classification.ipynb new file mode 100644 index 0000000000..d806e7ac81 --- /dev/null +++ b/flash_notebooks/predict/tabular_classification.ipynb @@ -0,0 +1,167 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "attempted-identifier", + "metadata": {}, + "outputs": [], + "source": [ + "import flash\n", + "from flash.core.data import download_data\n", + "from flash.vision import ImageClassificationData, ImageClassifier" + ] + }, + { + "cell_type": "markdown", + "id": "answering-toolbox", + "metadata": {}, + "source": [ + "# 1. Download the data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "sized-license", + "metadata": {}, + "outputs": [], + "source": [ + "download_data(\"https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip\", 'data/')" + ] + }, + { + "cell_type": "markdown", + "id": "complete-memory", + "metadata": {}, + "source": [ + "# 2. Load the data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "geographic-literacy", + "metadata": {}, + "outputs": [], + "source": [ + "datamodule = ImageClassificationData.from_folders(\n", + " train_folder=\"data/hymenoptera_data/train/\",\n", + " valid_folder=\"data/hymenoptera_data/val/\",\n", + " test_folder=\"data/hymenoptera_data/test/\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "sharp-spouse", + "metadata": {}, + "source": [ + "# 3. Build the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "following-mapping", + "metadata": {}, + "outputs": [], + "source": [ + "model = ImageClassifier(num_classes=datamodule.num_classes)" + ] + }, + { + "cell_type": "markdown", + "id": "talented-arena", + "metadata": {}, + "source": [ + "# 4. Create the trainer. Run once on data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "pharmaceutical-acrobat", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = flash.Trainer(max_epochs=1)" + ] + }, + { + "cell_type": "markdown", + "id": "agreed-somewhere", + "metadata": {}, + "source": [ + "# 5. Train the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "juvenile-violation", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))" + ] + }, + { + "cell_type": "markdown", + "id": "weekly-swimming", + "metadata": {}, + "source": [ + "# 6. Test the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "viral-salem", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.test()" + ] + }, + { + "cell_type": "markdown", + "id": "fluid-leave", + "metadata": {}, + "source": [ + "# 7. Save it!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "alpine-testimony", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.save_checkpoint(\"image_classification_model.pt\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/flash_notebooks/predict/text_classification.ipynb b/flash_notebooks/predict/text_classification.ipynb new file mode 100644 index 0000000000..80ad1e345c --- /dev/null +++ b/flash_notebooks/predict/text_classification.ipynb @@ -0,0 +1,149 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "residential-identification", + "metadata": {}, + "outputs": [], + "source": [ + "from pytorch_lightning import Trainer\n", + "\n", + "from flash.core.data import download_data\n", + "from flash.text import TextClassificationData, TextClassifier" + ] + }, + { + "cell_type": "markdown", + "id": "green-moderator", + "metadata": {}, + "source": [ + "# 1. Download the data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fuzzy-boost", + "metadata": {}, + "outputs": [], + "source": [ + "download_data(\"https://pl-flash-data.s3.amazonaws.com/imdb.zip\", 'data/')" + ] + }, + { + "cell_type": "markdown", + "id": "considerable-learning", + "metadata": {}, + "source": [ + "# 2. Load the model from a checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "academic-asset", + "metadata": {}, + "outputs": [], + "source": [ + "model = TextClassifier.load_from_checkpoint(\"https://flash-weights.s3.amazonaws.com/text_classification_model.pt\")" + ] + }, + { + "cell_type": "markdown", + "id": "clear-bowling", + "metadata": {}, + "source": [ + "# 2a. Classify a few sentences! How was the movie?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "impressive-disclaimer", + "metadata": {}, + "outputs": [], + "source": [ + "predictions = model.predict([\n", + " \"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.\",\n", + " \"The worst movie in the history of cinema.\",\n", + " \"I come from Bulgaria where it 's almost impossible to have a tornado.\"\n", + " \"Very, very afraid\"\n", + " \"This guy has done a great job with this movie!\",\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "streaming-denial", + "metadata": {}, + "outputs": [], + "source": [ + "print(predictions)" + ] + }, + { + "cell_type": "markdown", + "id": "ignored-airfare", + "metadata": {}, + "source": [ + "# 2b. Or generate predictions from a sheet file!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "perfect-latin", + "metadata": {}, + "outputs": [], + "source": [ + "datamodule = TextClassificationData.from_file(\n", + " predict_file=\"data/imdb/predict.csv\",\n", + " input=\"review\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "julian-bargain", + "metadata": {}, + "outputs": [], + "source": [ + "predictions = Trainer().predict(model, datamodule=datamodule)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "organized-illinois", + "metadata": {}, + "outputs": [], + "source": [ + "print(predictions)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/text-classification.ipynb b/flash_notebooks/text-classification.ipynb similarity index 100% rename from notebooks/text-classification.ipynb rename to flash_notebooks/text-classification.ipynb