Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Bump pretrained weights to 0.6.0 (#940)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Nov 7, 2021
1 parent eec8485 commit 4a242e6
Show file tree
Hide file tree
Showing 13 changed files with 21 additions and 21 deletions.
6 changes: 3 additions & 3 deletions docs/source/general/predictions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ You can pass in a sample of data (image file path, a string of text, etc) to the
# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt"
"https://flash-weights.s3.amazonaws.com/0.6.0/image_classification_model.pt"
)
# 3. Predict whether the image contains an ant or a bee
Expand All @@ -46,7 +46,7 @@ Predict on a csv file
# 2. Load the model from a checkpoint
model = TabularClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/tabular_classification_model.pt"
"https://flash-weights.s3.amazonaws.com/0.6.0/tabular_classification_model.pt"
)
# 3. Generate predictions from a csv file! Who would survive?
Expand Down Expand Up @@ -74,7 +74,7 @@ reference below).
# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt"
"https://flash-weights.s3.amazonaws.com/0.6.0/image_classification_model.pt"
)
# 3. Attach the Output
Expand Down
2 changes: 1 addition & 1 deletion docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ Here's an example of inference:
from flash.text import TextClassifier

# 1. Init the finetuned task from URL
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.5.2/text_classification_model.pt")
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.6.0/text_classification_model.pt")

# 2. Perform inference from list of sequences
predictions = model.predict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

# 4 Predict from checkpoint
model = ImageClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt"
"https://flash-weights.s3.amazonaws.com/0.6.0/image_classification_model.pt"
)
model.output = FiftyOneLabels(return_filepath=True) # output FiftyOne format
predictions = trainer.predict(model, datamodule=datamodule)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@

# 5 Predict from checkpoint on data with ground truth
model = ImageClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt"
"https://flash-weights.s3.amazonaws.com/0.6.0/image_classification_model.pt"
)
model.output = FiftyOneLabels(return_filepath=False) # output FiftyOne format
datamodule = ImageClassificationData.from_fiftyone(predict_dataset=test_dataset)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@
from flash.image import ImageClassifier

model = ImageClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt"
"https://flash-weights.s3.amazonaws.com/0.6.0/image_classification_model.pt"
)
model.serve()
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from flash.image.segmentation.output import SegmentationLabels

model = SemanticSegmentation.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/semantic_segmentation_model.pt"
"https://flash-weights.s3.amazonaws.com/0.6.0/semantic_segmentation_model.pt"
)
model.output = SegmentationLabels(visualize=False)
model.serve()
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@
from flash.audio import SpeechRecognition

model = SpeechRecognition.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/speech_recognition_model.pt"
"https://flash-weights.s3.amazonaws.com/0.6.0/speech_recognition_model.pt"
)
model.serve()
2 changes: 1 addition & 1 deletion flash_examples/serve/summarization/inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@
from flash.text import SummarizationTask

model = SummarizationTask.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/summarization_model_xsum.pt"
"https://flash-weights.s3.amazonaws.com/0.6.0/summarization_model_xsum.pt"
)
model.serve()
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from flash.tabular import TabularClassifier

model = TabularClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.5.2/tabular_classification_model.pt"
"https://flash-weights.s3.amazonaws.com/0.6.0/tabular_classification_model.pt"
)
model.output = Labels(["Did not survive", "Survived"])
model.serve()
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
# limitations under the License.
from flash.text import TextClassifier

model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.5.2/text_classification_model.pt")
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.6.0/text_classification_model.pt")
model.serve()
2 changes: 1 addition & 1 deletion flash_examples/serve/translation/inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
# limitations under the License.
from flash.text import TranslationTask

model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.5.2/translation_model_en_ro.pt")
model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.6.0/translation_model_en_ro.pt")
model.serve()
2 changes: 1 addition & 1 deletion flash_notebooks/tabular_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@
"outputs": [],
"source": [
"model = TabularClassifier.load_from_checkpoint(\n",
" \"https://flash-weights.s3.amazonaws.com/0.5.2/tabular_classification_model.pt\")"
" \"https://flash-weights.s3.amazonaws.com/0.6.0/tabular_classification_model.pt\")"
]
},
{
Expand Down
14 changes: 7 additions & 7 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,55 +250,55 @@ def test_task_datapipeline_save(tmpdir):
[
pytest.param(
ImageClassifier,
"0.5.2/image_classification_model.pt",
"0.6.0/image_classification_model.pt",
marks=pytest.mark.skipif(
not _IMAGE_TESTING,
reason="image packages aren't installed",
),
),
pytest.param(
SemanticSegmentation,
"0.5.2/semantic_segmentation_model.pt",
"0.6.0/semantic_segmentation_model.pt",
marks=pytest.mark.skipif(
not _IMAGE_TESTING,
reason="image packages aren't installed",
),
),
pytest.param(
SpeechRecognition,
"0.5.2/speech_recognition_model.pt",
"0.6.0/speech_recognition_model.pt",
marks=pytest.mark.skipif(
not _AUDIO_TESTING,
reason="audio packages aren't installed",
),
),
pytest.param(
TabularClassifier,
"0.5.2/tabular_classification_model.pt",
"0.6.0/tabular_classification_model.pt",
marks=pytest.mark.skipif(
not _TABULAR_TESTING,
reason="tabular packages aren't installed",
),
),
pytest.param(
TextClassifier,
"0.5.2/text_classification_model.pt",
"0.6.0/text_classification_model.pt",
marks=pytest.mark.skipif(
not _TEXT_TESTING,
reason="text packages aren't installed",
),
),
pytest.param(
SummarizationTask,
"0.5.2/summarization_model_xsum.pt",
"0.6.0/summarization_model_xsum.pt",
marks=pytest.mark.skipif(
not _TEXT_TESTING,
reason="text packages aren't installed",
),
),
pytest.param(
TranslationTask,
"0.5.2/translation_model_en_ro.pt",
"0.6.0/translation_model_en_ro.pt",
marks=pytest.mark.skipif(
not _TEXT_TESTING,
reason="text packages aren't installed",
Expand Down

0 comments on commit 4a242e6

Please sign in to comment.