Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge 19caecd into 5aecfad
Browse files Browse the repository at this point in the history
  • Loading branch information
edenlightning authored Feb 8, 2021
2 parents 5aecfad + 19caecd commit a945331
Showing 1 changed file with 36 additions and 73 deletions.
109 changes: 36 additions & 73 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,17 @@ in AI research embedded into each task so you don't have to be a deep learning P
### Predictions

```python

# import our libraries
from flash.text import TextClassifier
from flash.text import TranslationTask

# 1. Load finetuned task
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")
model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")

# 2. Perform inference from list of sequences
# 2. Translate a few sentences!
predictions = model.predict([
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
"The worst movie in the history of cinema.",
"I come from Bulgaria where it 's almost impossible to have a tornado."
"Very, very afraid"
"This guy has done a great job with this movie!",
"BBC News went to meet one of the project's first graduates.",
"A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",
])
print(predictions)
```
Expand All @@ -114,6 +112,7 @@ print(predictions)
First, finetune:

```python
# import our libraries
import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier
Expand Down Expand Up @@ -157,109 +156,72 @@ print(predictions)
## Tasks
Flash is built as a collection of community-built tasks. A task is highly opinionated and laser-focused on solving a single problem well, using state-of-the-art methods.

### Example 1: Image classification
Flash has an ImageClassification task to tackle any image classification problem.
### Example 1: Image embedding
Flash has an Image embedding task to encodes an image into a vector of image features which can be used for anything like clustering, similarity search or classification.

<details>
<summary>View example</summary>
To illustrate, Let's say we wanted to develop a model that could classify between ants and bees.

<img src="https://pl-flash-data.s3.amazonaws.com/images/ant_bee.png" width="300px">

Here we classify ants vs bees.

```python
import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier
# import our libraries
import torch

from flash.core.data import download_data
from flash.vision import ImageEmbedder

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
)

# 3. Build the model
model = ImageClassifier(num_classes=datamodule.num_classes)

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Train the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")
# 2. Create an ImageEmbedder with resnet50 trained on imagenet.
embedder = ImageEmbedder(backbone="resnet50", embedding_dim=128)

# 6. Test the model
trainer.test()

# 7. Predict!
predictions = model.predict([
"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
"data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
"data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
])
print(predictions)
```
# 3. Generate an embedding from an image path.
embeddings = embedder.predict('data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg')

To run the example:
```bash
python flash_examples/finetuning/image_classifier.py
# 4. Print embeddings shape
print(embeddings.shape)
```
</details>

### Example 2: Text Classification
Flash has a TextClassification task to tackle any text classification problem.
### Example 2: Text Summarization
Flash has a Summarization task to sum up text from a larger article into a short description.

<details>
<summary>View example</summary>
To illustrate, say you wanted to classify movie reviews as positive or negative.

```python
# import our libraries
import flash
from flash import download_data
from flash.text import TextClassificationData, TextClassifier
from flash.text import SummarizationData, SummarizationTask

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/')

# 2. Load the data
datamodule = TextClassificationData.from_files(
train_file="data/imdb/train.csv",
valid_file="data/imdb/valid.csv",
test_file="data/imdb/test.csv",
input="review",
target="sentiment",
batch_size=512
datamodule = SummarizationData.from_files(
train_file="data/xsum/train.csv",
valid_file="data/xsum/valid.csv",
test_file="data/xsum/test.csv",
input="input",
target="target"
)

# 3. Build the model
model = TextClassifier(num_classes=datamodule.num_classes)
model = SummarizationTask()

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)
trainer = flash.Trainer(max_epochs=1, gpus=1, precision=16)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")
trainer.finetune(model, datamodule=datamodule)

# 6. Test model
trainer.test()

# 7. Classify a few sentences! How was the movie?
predictions = model.predict([
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
"The worst movie in the history of cinema.",
"I come from Bulgaria where it 's almost impossible to have a tornado."
"Very, very afraid"
"This guy has done a great job with this movie!",
])
print(predictions)
```
To run the example:
```bash
python flash_examples/finetuning/classify_text.py
python flash_examples/finetuning/summarization.py
```
</details>

Expand All @@ -273,6 +235,7 @@ Flash has a TabularClassification task to tackle any tabular classification prob
To illustrate, say we want to build a model to predict if a passenger survived on the Titanic.

```python
# import our libraries
from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall
import flash
from flash import download_data
Expand Down

0 comments on commit a945331

Please sign in to comment.