diff --git a/README.md b/README.md
index b1590b77bb..1aa4f2ea0d 100644
--- a/README.md
+++ b/README.md
@@ -92,19 +92,17 @@ in AI research embedded into each task so you don't have to be a deep learning P
 ### Predictions
 
 ```python
+# import our libraries
-from flash.text import TextClassifier
+from flash.text import TranslationTask
 
 # 1. Load finetuned task
-model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")
+model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")
 
-# 2. Perform inference from list of sequences
+# 2. Translate a few sentences!
 predictions = model.predict([
-    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
-    "The worst movie in the history of cinema.",
-    "I come from Bulgaria where it 's almost impossible to have a tornado."
-    "Very, very afraid"
-    "This guy has done a great job with this movie!",
+    "BBC News went to meet one of the project's first graduates.",
+    "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",
 ])
 print(predictions)
 ```
@@ -114,6 +112,7 @@ print(predictions)
 First, finetune:
 
 ```python
+# import our libraries
 import flash
 from flash import download_data
 from flash.vision import ImageClassificationData, ImageClassifier
@@ -157,109 +156,72 @@ print(predictions)
 ## Tasks
 Flash is built as a collection of community-built tasks. A task is highly opinionated and laser-focused on solving a single problem well, using state-of-the-art methods.
 
-### Example 1: Image classification
-Flash has an ImageClassification task to tackle any image classification problem.
+### Example 1: Image embedding
+Flash has an image embedding task that encodes an image into a vector of image features, which can be used for anything from clustering and similarity search to classification.
 <details>
   <summary>View example</summary>
-  To illustrate, Let's say we wanted to develop a model that could classify between ants and bees.
-
-
-
-
-  Here we classify ants vs bees.
 
   ```python
-  import flash
-  from flash import download_data
-  from flash.vision import ImageClassificationData, ImageClassifier
+  # import our libraries
+  import torch
+
+  from flash.core.data import download_data
+  from flash.vision import ImageEmbedder
 
   # 1. Download the data
   download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
 
-  # 2. Load the data
-  datamodule = ImageClassificationData.from_folders(
-      train_folder="data/hymenoptera_data/train/",
-      valid_folder="data/hymenoptera_data/val/",
-      test_folder="data/hymenoptera_data/test/",
-  )
-
-  # 3. Build the model
-  model = ImageClassifier(num_classes=datamodule.num_classes)
-
-  # 4. Create the trainer. Run once on data
-  trainer = flash.Trainer(max_epochs=1)
-
-  # 5. Train the model
-  trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")
+  # 2. Create an ImageEmbedder with resnet50 trained on imagenet.
+  embedder = ImageEmbedder(backbone="resnet50", embedding_dim=128)
 
-  # 6. Test the model
-  trainer.test()
-
-  # 7. Predict!
-  predictions = model.predict([
-      "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
-      "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
-      "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
-  ])
-  print(predictions)
-  ```
+  # 3. Generate an embedding from an image path.
+  embeddings = embedder.predict('data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg')
 
-  To run the example:
-  ```bash
-  python flash_examples/finetuning/image_classifier.py
+  # 4. Print embeddings shape
+  print(embeddings.shape)
   ```
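+
+  The embedding can then feed a downstream step such as clustering or similarity search. As a rough sketch (not part of the example script; it assumes `predict` returns a tensor-like vector, and the second image path is only a placeholder), two images can be compared with cosine similarity:
+
+  ```python
+  # Hypothetical follow-up: compare two embeddings with cosine similarity.
+  # Assumes `embedder.predict` returns a tensor-like vector per image;
+  # the second path below is a placeholder for any other image.
+  import torch
+
+  emb_a = torch.as_tensor(embedder.predict('data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg')).flatten()
+  emb_b = torch.as_tensor(embedder.predict('path/to/another_image.jpg')).flatten()
+
+  # Values close to 1 mean the backbone sees the two images as similar.
+  print(torch.nn.functional.cosine_similarity(emb_a, emb_b, dim=0).item())
+  ```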
 </details>
 
-### Example 2: Text Classification
-Flash has a TextClassification task to tackle any text classification problem.
+### Example 2: Text Summarization
+Flash has a Summarization task that condenses a longer article into a short description.
 <details>
   <summary>View example</summary>
-  To illustrate, say you wanted to classify movie reviews as positive or negative.
 
   ```python
+  # import our libraries
   import flash
   from flash import download_data
-  from flash.text import TextClassificationData, TextClassifier
+  from flash.text import SummarizationData, SummarizationTask
 
   # 1. Download the data
-  download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
+  download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/')
 
   # 2. Load the data
-  datamodule = TextClassificationData.from_files(
-      train_file="data/imdb/train.csv",
-      valid_file="data/imdb/valid.csv",
-      test_file="data/imdb/test.csv",
-      input="review",
-      target="sentiment",
-      batch_size=512
+  datamodule = SummarizationData.from_files(
+      train_file="data/xsum/train.csv",
+      valid_file="data/xsum/valid.csv",
+      test_file="data/xsum/test.csv",
+      input="input",
+      target="target"
   )
 
   # 3. Build the model
-  model = TextClassifier(num_classes=datamodule.num_classes)
+  model = SummarizationTask()
 
   # 4. Create the trainer. Run once on data
-  trainer = flash.Trainer(max_epochs=1)
+  trainer = flash.Trainer(max_epochs=1, gpus=1, precision=16)
 
   # 5. Fine-tune the model
-  trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")
+  trainer.finetune(model, datamodule=datamodule)
 
   # 6. Test model
   trainer.test()
-
-  # 7. Classify a few sentences! How was the movie?
-  predictions = model.predict([
-      "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
-      "The worst movie in the history of cinema.",
-      "I come from Bulgaria where it 's almost impossible to have a tornado."
-      "Very, very afraid"
-      "This guy has done a great job with this movie!",
-  ])
-  print(predictions)
   ```
 
   To run the example:
   ```bash
-  python flash_examples/finetuning/classify_text.py
+  python flash_examples/finetuning/summarization.py
   ```
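+
+  Once finetuned, the task can be used for inference in the same way as the other tasks in this README (a rough sketch, not part of the example script; the input text below is only a placeholder article):
+
+  ```python
+  # Hypothetical follow-up: summarize raw text with the finetuned task,
+  # mirroring the predict() pattern used by the other tasks above.
+  # The input string is a placeholder article.
+  predictions = model.predict([
+      "The local council has approved plans for a new cycle path along the river, "
+      "with construction expected to start next spring and finish within a year.",
+  ])
+  print(predictions)
+  ```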
 </details>
@@ -273,6 +235,7 @@ Flash has a TabularClassification task to tackle any tabular classification prob
   To illustrate, say we want to build a model to predict if a passenger survived on the Titanic.
 
   ```python
+  # import our libraries
   from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall
   import flash
   from flash import download_data