From f2b08da8b91627ca056056fea81aebc2db688187 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 4 Nov 2021 12:44:16 +0000 Subject: [PATCH] Version pretrained models (#931) --- docs/source/general/predictions.rst | 12 ++- docs/source/quickstart.rst | 2 +- .../fiftyone/image_classification.py | 4 +- .../image_classification_fiftyone_datasets.py | 4 +- .../image_classification/inference_server.py | 4 +- .../semantic_segmentation/inference_server.py | 2 +- .../speech_recognition/inference_server.py | 4 +- .../serve/summarization/inference_server.py | 4 +- .../inference_server.py | 4 +- .../text_classification/inference_server.py | 2 +- .../serve/translation/inference_server.py | 2 +- flash_notebooks/tabular_classification.ipynb | 2 +- tests/core/test_model.py | 74 ++++++++++++++----- 13 files changed, 87 insertions(+), 33 deletions(-) diff --git a/docs/source/general/predictions.rst b/docs/source/general/predictions.rst index 4bd260db99..88d7e5cd9d 100644 --- a/docs/source/general/predictions.rst +++ b/docs/source/general/predictions.rst @@ -23,7 +23,9 @@ You can pass in a sample of data (image file path, a string of text, etc) to the download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") # 2. Load the model from a checkpoint - model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") + model = ImageClassifier.load_from_checkpoint( + "https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt" + ) # 3. Predict whether the image contains an ant or a bee predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") @@ -43,7 +45,9 @@ Predict on a csv file download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/") # 2. Load the model from a checkpoint - model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabnet_classification_model.pt") + model = TabularClassifier.load_from_checkpoint( + "https://flash-weights.s3.amazonaws.com/0.5.2/tabular_classification_model.pt" + ) # 3. Generate predictions from a csv file! Who would survive? predictions = model.predict("data/titanic/titanic.csv") @@ -69,7 +73,9 @@ reference below). download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") # 2. Load the model from a checkpoint - model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") + model = ImageClassifier.load_from_checkpoint( + "https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt" + ) # 3. Attach the Serializer model.serializer = Probabilities() diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 36c3ba3475..c82eaa0faf 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -91,7 +91,7 @@ Here's an example of inference: from flash.text import TextClassifier # 1. Init the finetuned task from URL - model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt") + model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.5.2/text_classification_model.pt") # 2. 
Perform inference from list of sequences predictions = model.predict( diff --git a/flash_examples/integrations/fiftyone/image_classification.py b/flash_examples/integrations/fiftyone/image_classification.py index b1f5fb56cf..1d539b4eaf 100644 --- a/flash_examples/integrations/fiftyone/image_classification.py +++ b/flash_examples/integrations/fiftyone/image_classification.py @@ -53,7 +53,9 @@ trainer.save_checkpoint("image_classification_model.pt") # 4 Predict from checkpoint -model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") +model = ImageClassifier.load_from_checkpoint( + "https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt" +) model.serializer = FiftyOneLabels(return_filepath=True) # output FiftyOne format predictions = trainer.predict(model, datamodule=datamodule) predictions = list(chain.from_iterable(predictions)) # flatten batches diff --git a/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py index 9ef31609d5..b7c12f79ca 100644 --- a/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py +++ b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py @@ -66,7 +66,9 @@ trainer.save_checkpoint("image_classification_model.pt") # 5 Predict from checkpoint on data with ground truth -model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") +model = ImageClassifier.load_from_checkpoint( + "https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt" +) model.serializer = FiftyOneLabels(return_filepath=False) # output FiftyOne format datamodule = ImageClassificationData.from_fiftyone(predict_dataset=test_dataset) predictions = trainer.predict(model, datamodule=datamodule) diff --git a/flash_examples/serve/image_classification/inference_server.py b/flash_examples/serve/image_classification/inference_server.py index 95dbc66200..a20e147c97 100644 --- a/flash_examples/serve/image_classification/inference_server.py +++ b/flash_examples/serve/image_classification/inference_server.py @@ -13,5 +13,7 @@ # limitations under the License. 
from flash.image import ImageClassifier -model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") +model = ImageClassifier.load_from_checkpoint( + "https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt" +) model.serve() diff --git a/flash_examples/serve/semantic_segmentation/inference_server.py b/flash_examples/serve/semantic_segmentation/inference_server.py index 78d51a0c0e..140b3c6c34 100644 --- a/flash_examples/serve/semantic_segmentation/inference_server.py +++ b/flash_examples/serve/semantic_segmentation/inference_server.py @@ -15,7 +15,7 @@ from flash.image.segmentation.serialization import SegmentationLabels model = SemanticSegmentation.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt" + "https://flash-weights.s3.amazonaws.com/0.5.2/semantic_segmentation_model.pt" ) model.serializer = SegmentationLabels(visualize=False) model.serve() diff --git a/flash_examples/serve/speech_recognition/inference_server.py b/flash_examples/serve/speech_recognition/inference_server.py index bbc4479624..34e21ca319 100644 --- a/flash_examples/serve/speech_recognition/inference_server.py +++ b/flash_examples/serve/speech_recognition/inference_server.py @@ -13,5 +13,7 @@ # limitations under the License. from flash.audio import SpeechRecognition -model = SpeechRecognition.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/speech_recognition_model.pt") +model = SpeechRecognition.load_from_checkpoint( + "https://flash-weights.s3.amazonaws.com/0.5.2/speech_recognition_model.pt" +) model.serve() diff --git a/flash_examples/serve/summarization/inference_server.py b/flash_examples/serve/summarization/inference_server.py index 1d0270e4a8..8dea17bd40 100644 --- a/flash_examples/serve/summarization/inference_server.py +++ b/flash_examples/serve/summarization/inference_server.py @@ -13,5 +13,7 @@ # limitations under the License. from flash.text import SummarizationTask -model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt") +model = SummarizationTask.load_from_checkpoint( + "https://flash-weights.s3.amazonaws.com/0.5.2/summarization_model_xsum.pt" +) model.serve() diff --git a/flash_examples/serve/tabular_classification/inference_server.py b/flash_examples/serve/tabular_classification/inference_server.py index 4b58b8f691..08975b6cfb 100644 --- a/flash_examples/serve/tabular_classification/inference_server.py +++ b/flash_examples/serve/tabular_classification/inference_server.py @@ -14,6 +14,8 @@ from flash.core.classification import Labels from flash.tabular import TabularClassifier -model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt") +model = TabularClassifier.load_from_checkpoint( + "https://flash-weights.s3.amazonaws.com/0.5.2/tabular_classification_model.pt" +) model.serializer = Labels(["Did not survive", "Survived"]) model.serve() diff --git a/flash_examples/serve/text_classification/inference_server.py b/flash_examples/serve/text_classification/inference_server.py index 37a952c906..ad8ff098f0 100644 --- a/flash_examples/serve/text_classification/inference_server.py +++ b/flash_examples/serve/text_classification/inference_server.py @@ -13,5 +13,5 @@ # limitations under the License. 
from flash.text import TextClassifier -model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt") +model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.5.2/text_classification_model.pt") model.serve() diff --git a/flash_examples/serve/translation/inference_server.py b/flash_examples/serve/translation/inference_server.py index f8f9c8dbce..0c9ed2f894 100644 --- a/flash_examples/serve/translation/inference_server.py +++ b/flash_examples/serve/translation/inference_server.py @@ -13,5 +13,5 @@ # limitations under the License. from flash.text import TranslationTask -model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt") +model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.5.2/translation_model_en_ro.pt") model.serve() diff --git a/flash_notebooks/tabular_classification.ipynb b/flash_notebooks/tabular_classification.ipynb index 6e7fbac057..ac8ec85242 100644 --- a/flash_notebooks/tabular_classification.ipynb +++ b/flash_notebooks/tabular_classification.ipynb @@ -222,7 +222,7 @@ "outputs": [], "source": [ "model = TabularClassifier.load_from_checkpoint(\n", - " \"https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt\")" + " \"https://flash-weights.s3.amazonaws.com/0.5.2/tabular_classification_model.pt\")" ] }, { diff --git a/tests/core/test_model.py b/tests/core/test_model.py index f31bba3e70..6a885fc3b6 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -32,18 +32,15 @@ from torchmetrics import Accuracy import flash +from flash.audio import SpeechRecognition from flash.core.adapter import Adapter from flash.core.classification import ClassificationTask from flash.core.data.process import DefaultPreprocess, Postprocess -from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TORCH_OPTIMIZER_AVAILABLE, _TRANSFORMERS_AVAILABLE, Image -from flash.image import ImageClassificationData, ImageClassifier -from tests.helpers.utils import _IMAGE_TESTING, _TABULAR_TESTING - -if _TABULAR_AVAILABLE: - from flash.tabular import TabularClassifier -else: - TabularClassifier = None - +from flash.core.utilities.imports import _TORCH_OPTIMIZER_AVAILABLE, _TRANSFORMERS_AVAILABLE, Image +from flash.image import ImageClassificationData, ImageClassifier, SemanticSegmentation +from flash.tabular import TabularClassifier +from flash.text import SummarizationTask, TextClassifier, TranslationTask +from tests.helpers.utils import _AUDIO_TESTING, _IMAGE_TESTING, _TABULAR_TESTING, _TEXT_TESTING # ======== Mock functions ======== @@ -251,23 +248,62 @@ def test_task_datapipeline_save(tmpdir): @pytest.mark.parametrize( ["cls", "filename"], [ - # needs to be updated. 
- # pytest.param( - # ImageClassifier, - # "image_classification_model.pt", - # marks=pytest.mark.skipif( - # not _IMAGE_TESTING, - # reason="image packages aren't installed", - # ), - # ), + pytest.param( + ImageClassifier, + "0.5.2/image_classification_model.pt", + marks=pytest.mark.skipif( + not _IMAGE_TESTING, + reason="image packages aren't installed", + ), + ), + pytest.param( + SemanticSegmentation, + "0.5.2/semantic_segmentation_model.pt", + marks=pytest.mark.skipif( + not _IMAGE_TESTING, + reason="image packages aren't installed", + ), + ), + pytest.param( + SpeechRecognition, + "0.5.2/speech_recognition_model.pt", + marks=pytest.mark.skipif( + not _AUDIO_TESTING, + reason="audio packages aren't installed", + ), + ), pytest.param( TabularClassifier, - "tabular_classification_model.pt", + "0.5.2/tabular_classification_model.pt", marks=pytest.mark.skipif( not _TABULAR_TESTING, reason="tabular packages aren't installed", ), ), + pytest.param( + TextClassifier, + "0.5.2/text_classification_model.pt", + marks=pytest.mark.skipif( + not _TEXT_TESTING, + reason="text packages aren't installed", + ), + ), + pytest.param( + SummarizationTask, + "0.5.2/summarization_model_xsum.pt", + marks=pytest.mark.skipif( + not _TEXT_TESTING, + reason="text packages aren't installed", + ), + ), + pytest.param( + TranslationTask, + "0.5.2/translation_model_en_ro.pt", + marks=pytest.mark.skipif( + not _TEXT_TESTING, + reason="text packages aren't installed", + ), + ), ], ) def test_model_download(tmpdir, cls, filename):
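All of the checkpoint URLs above hard-code the new 0.5.2/ prefix at each call site. As a purely illustrative sketch (the helper pretrained_url and the two constants below do not exist in this patch), the versioned URL could instead be assembled in one place, so a future weights release would only need to touch a single constant:

    from flash.image import ImageClassifier

    # Illustrative constants: the bucket URL is the one used throughout this patch,
    # and 0.5.2 is the version prefix that this PR adds to every checkpoint path.
    FLASH_WEIGHTS_ROOT = "https://flash-weights.s3.amazonaws.com"
    WEIGHTS_VERSION = "0.5.2"


    def pretrained_url(filename: str) -> str:
        # Build a versioned URL such as .../0.5.2/image_classification_model.pt
        return f"{FLASH_WEIGHTS_ROOT}/{WEIGHTS_VERSION}/{filename}"


    # Same call as in the examples above, only with the URL built by the helper.
    model = ImageClassifier.load_from_checkpoint(
        pretrained_url("image_classification_model.pt")
    )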