Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Version pretrained models (#931)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Nov 4, 2021
1 parent 47d5d35 commit f2b08da
Show file tree
Hide file tree
Showing 13 changed files with 87 additions and 33 deletions.
12 changes: 9 additions & 3 deletions docs/source/general/predictions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ You can pass in a sample of data (image file path, a string of text, etc) to the
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
model = ImageClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt"
)
# 3. Predict whether the image contains an ant or a bee
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
Expand All @@ -43,7 +45,9 @@ Predict on a csv file
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")
# 2. Load the model from a checkpoint
model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabnet_classification_model.pt")
model = TabularClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/tabular_classification_model.pt"
)
# 3. Generate predictions from a csv file! Who would survive?
predictions = model.predict("data/titanic/titanic.csv")
Expand All @@ -69,7 +73,9 @@ reference below).
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
model = ImageClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt"
)
# 3. Attach the Serializer
model.serializer = Probabilities()
Expand Down
2 changes: 1 addition & 1 deletion docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ Here's an example of inference:
from flash.text import TextClassifier

# 1. Init the finetuned task from URL
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.5.2/text_classification_model.pt")

# 2. Perform inference from list of sequences
predictions = model.predict(
Expand Down
4 changes: 3 additions & 1 deletion flash_examples/integrations/fiftyone/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@
trainer.save_checkpoint("image_classification_model.pt")

# 4 Predict from checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
model = ImageClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt"
)
model.serializer = FiftyOneLabels(return_filepath=True) # output FiftyOne format
predictions = trainer.predict(model, datamodule=datamodule)
predictions = list(chain.from_iterable(predictions)) # flatten batches
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@
trainer.save_checkpoint("image_classification_model.pt")

# 5 Predict from checkpoint on data with ground truth
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
model = ImageClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt"
)
model.serializer = FiftyOneLabels(return_filepath=False) # output FiftyOne format
datamodule = ImageClassificationData.from_fiftyone(predict_dataset=test_dataset)
predictions = trainer.predict(model, datamodule=datamodule)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@
# limitations under the License.
from flash.image import ImageClassifier

model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
model = ImageClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt"
)
model.serve()
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from flash.image.segmentation.serialization import SegmentationLabels

model = SemanticSegmentation.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
"https://flash-weights.s3.amazonaws.com/0.5.2/semantic_segmentation_model.pt"
)
model.serializer = SegmentationLabels(visualize=False)
model.serve()
4 changes: 3 additions & 1 deletion flash_examples/serve/speech_recognition/inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@
# limitations under the License.
from flash.audio import SpeechRecognition

model = SpeechRecognition.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/speech_recognition_model.pt")
model = SpeechRecognition.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/speech_recognition_model.pt"
)
model.serve()
4 changes: 3 additions & 1 deletion flash_examples/serve/summarization/inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@
# limitations under the License.
from flash.text import SummarizationTask

model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt")
model = SummarizationTask.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/summarization_model_xsum.pt"
)
model.serve()
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from flash.core.classification import Labels
from flash.tabular import TabularClassifier

model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt")
model = TabularClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/tabular_classification_model.pt"
)
model.serializer = Labels(["Did not survive", "Survived"])
model.serve()
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
# limitations under the License.
from flash.text import TextClassifier

model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.5.2/text_classification_model.pt")
model.serve()
2 changes: 1 addition & 1 deletion flash_examples/serve/translation/inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
# limitations under the License.
from flash.text import TranslationTask

model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")
model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.5.2/translation_model_en_ro.pt")
model.serve()
2 changes: 1 addition & 1 deletion flash_notebooks/tabular_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@
"outputs": [],
"source": [
"model = TabularClassifier.load_from_checkpoint(\n",
" \"https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt\")"
" \"https://flash-weights.s3.amazonaws.com/0.5.2/tabular_classification_model.pt\")"
]
},
{
Expand Down
74 changes: 55 additions & 19 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,15 @@
from torchmetrics import Accuracy

import flash
from flash.audio import SpeechRecognition
from flash.core.adapter import Adapter
from flash.core.classification import ClassificationTask
from flash.core.data.process import DefaultPreprocess, Postprocess
from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TORCH_OPTIMIZER_AVAILABLE, _TRANSFORMERS_AVAILABLE, Image
from flash.image import ImageClassificationData, ImageClassifier
from tests.helpers.utils import _IMAGE_TESTING, _TABULAR_TESTING

if _TABULAR_AVAILABLE:
from flash.tabular import TabularClassifier
else:
TabularClassifier = None

from flash.core.utilities.imports import _TORCH_OPTIMIZER_AVAILABLE, _TRANSFORMERS_AVAILABLE, Image
from flash.image import ImageClassificationData, ImageClassifier, SemanticSegmentation
from flash.tabular import TabularClassifier
from flash.text import SummarizationTask, TextClassifier, TranslationTask
from tests.helpers.utils import _AUDIO_TESTING, _IMAGE_TESTING, _TABULAR_TESTING, _TEXT_TESTING

# ======== Mock functions ========

Expand Down Expand Up @@ -251,23 +248,62 @@ def test_task_datapipeline_save(tmpdir):
@pytest.mark.parametrize(
["cls", "filename"],
[
# needs to be updated.
# pytest.param(
# ImageClassifier,
# "image_classification_model.pt",
# marks=pytest.mark.skipif(
# not _IMAGE_TESTING,
# reason="image packages aren't installed",
# ),
# ),
pytest.param(
ImageClassifier,
"0.5.2/image_classification_model.pt",
marks=pytest.mark.skipif(
not _IMAGE_TESTING,
reason="image packages aren't installed",
),
),
pytest.param(
SemanticSegmentation,
"0.5.2/semantic_segmentation_model.pt",
marks=pytest.mark.skipif(
not _IMAGE_TESTING,
reason="image packages aren't installed",
),
),
pytest.param(
SpeechRecognition,
"0.5.2/speech_recognition_model.pt",
marks=pytest.mark.skipif(
not _AUDIO_TESTING,
reason="audio packages aren't installed",
),
),
pytest.param(
TabularClassifier,
"tabular_classification_model.pt",
"0.5.2/tabular_classification_model.pt",
marks=pytest.mark.skipif(
not _TABULAR_TESTING,
reason="tabular packages aren't installed",
),
),
pytest.param(
TextClassifier,
"0.5.2/text_classification_model.pt",
marks=pytest.mark.skipif(
not _TEXT_TESTING,
reason="text packages aren't installed",
),
),
pytest.param(
SummarizationTask,
"0.5.2/summarization_model_xsum.pt",
marks=pytest.mark.skipif(
not _TEXT_TESTING,
reason="text packages aren't installed",
),
),
pytest.param(
TranslationTask,
"0.5.2/translation_model_en_ro.pt",
marks=pytest.mark.skipif(
not _TEXT_TESTING,
reason="text packages aren't installed",
),
),
],
)
def test_model_download(tmpdir, cls, filename):
Expand Down

0 comments on commit f2b08da

Please sign in to comment.