From 253063de1dead21bb26384a46cd7c8280be93efe Mon Sep 17 00:00:00 2001
From: edenlightning <66261195+edenlightning@users.noreply.github.com>
Date: Mon, 8 Feb 2021 18:59:45 -0500
Subject: [PATCH] Update flash examples in README (#91)
---
README.md | 109 ++++++++++++++++++------------------------------------
1 file changed, 36 insertions(+), 73 deletions(-)
diff --git a/README.md b/README.md
index b1590b77bb..1aa4f2ea0d 100644
--- a/README.md
+++ b/README.md
@@ -92,19 +92,17 @@ in AI research embedded into each task so you don't have to be a deep learning P
### Predictions
```python
+
# import our libraries
-from flash.text import TextClassifier
+from flash.text import TranslationTask
# 1. Load finetuned task
-model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")
+model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")
-# 2. Perform inference from list of sequences
+# 2. Translate a few sentences!
predictions = model.predict([
- "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
- "The worst movie in the history of cinema.",
- "I come from Bulgaria where it 's almost impossible to have a tornado."
- "Very, very afraid"
- "This guy has done a great job with this movie!",
+ "BBC News went to meet one of the project's first graduates.",
+ "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",
])
print(predictions)
```
@@ -114,6 +112,7 @@ print(predictions)
First, finetune:
```python
+# import our libraries
import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier
@@ -157,109 +156,72 @@ print(predictions)
## Tasks
Flash is built as a collection of community-built tasks. A task is highly opinionated and laser-focused on solving a single problem well, using state-of-the-art methods.
-### Example 1: Image classification
-Flash has an ImageClassification task to tackle any image classification problem.
+### Example 1: Image embedding
+Flash has an Image embedding task to encodes an image into a vector of image features which can be used for anything like clustering, similarity search or classification.
View example
- To illustrate, Let's say we wanted to develop a model that could classify between ants and bees.
-
-
-
- Here we classify ants vs bees.
```python
- import flash
- from flash import download_data
- from flash.vision import ImageClassificationData, ImageClassifier
+ # import our libraries
+ import torch
+
+ from flash.core.data import download_data
+ from flash.vision import ImageEmbedder
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
- # 2. Load the data
- datamodule = ImageClassificationData.from_folders(
- train_folder="data/hymenoptera_data/train/",
- valid_folder="data/hymenoptera_data/val/",
- test_folder="data/hymenoptera_data/test/",
- )
-
- # 3. Build the model
- model = ImageClassifier(num_classes=datamodule.num_classes)
-
- # 4. Create the trainer. Run once on data
- trainer = flash.Trainer(max_epochs=1)
-
- # 5. Train the model
- trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")
+ # 2. Create an ImageEmbedder with resnet50 trained on imagenet.
+ embedder = ImageEmbedder(backbone="resnet50", embedding_dim=128)
- # 6. Test the model
- trainer.test()
-
- # 7. Predict!
- predictions = model.predict([
- "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
- "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
- "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
- ])
- print(predictions)
- ```
+ # 3. Generate an embedding from an image path.
+ embeddings = embedder.predict('data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg')
- To run the example:
- ```bash
- python flash_examples/finetuning/image_classifier.py
+ # 4. Print embeddings shape
+ print(embeddings.shape)
```
-### Example 2: Text Classification
-Flash has a TextClassification task to tackle any text classification problem.
+### Example 2: Text Summarization
+Flash has a Summarization task to sum up text from a larger article into a short description.
View example
- To illustrate, say you wanted to classify movie reviews as positive or negative.
```python
+ # import our libraries
import flash
from flash import download_data
- from flash.text import TextClassificationData, TextClassifier
+ from flash.text import SummarizationData, SummarizationTask
# 1. Download the data
- download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
+ download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/')
# 2. Load the data
- datamodule = TextClassificationData.from_files(
- train_file="data/imdb/train.csv",
- valid_file="data/imdb/valid.csv",
- test_file="data/imdb/test.csv",
- input="review",
- target="sentiment",
- batch_size=512
+ datamodule = SummarizationData.from_files(
+ train_file="data/xsum/train.csv",
+ valid_file="data/xsum/valid.csv",
+ test_file="data/xsum/test.csv",
+ input="input",
+ target="target"
)
# 3. Build the model
- model = TextClassifier(num_classes=datamodule.num_classes)
+ model = SummarizationTask()
# 4. Create the trainer. Run once on data
- trainer = flash.Trainer(max_epochs=1)
+ trainer = flash.Trainer(max_epochs=1, gpus=1, precision=16)
# 5. Fine-tune the model
- trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")
+ trainer.finetune(model, datamodule=datamodule)
# 6. Test model
trainer.test()
-
- # 7. Classify a few sentences! How was the movie?
- predictions = model.predict([
- "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
- "The worst movie in the history of cinema.",
- "I come from Bulgaria where it 's almost impossible to have a tornado."
- "Very, very afraid"
- "This guy has done a great job with this movie!",
- ])
- print(predictions)
```
To run the example:
```bash
- python flash_examples/finetuning/classify_text.py
+ python flash_examples/finetuning/summarization.py
```
@@ -273,6 +235,7 @@ Flash has a TabularClassification task to tackle any tabular classification prob
To illustrate, say we want to build a model to predict if a passenger survived on the Titanic.
```python
+ # import our libraries
from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall
import flash
from flash import download_data