This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Remove model.predict (#1030)
ethanwharris authored Dec 6, 2021
1 parent d046de4 commit 5dd695f
Showing 63 changed files with 420 additions and 407 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -38,7 +38,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `Output` suffix to `Preds`, `FiftyOneDetectionLabels`, `SegmentationLabels`, `FiftyOneDetectionLabels`, `DetectionLabels`, `Classes`, `FiftyOneLabels`, `Labels`, `Logits`, `Probabilities` ([#1011](https://github.com/PyTorchLightning/lightning-flash/pull/1011))


- Changed `from_files` and `from_folders` from `ObjectDetectionData`, `InstanceSegmentationData`, `KeypointDetectionData` to support only the `predicting` stage ([#1018](https://github.com/PyTorchLightning/lightning-flash/pull/1018))

### Deprecated
@@ -75,8 +74,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed `OutputTransform.save_sample` and `save_data` hooks ([#948](https://github.com/PyTorchLightning/lightning-flash/pull/948))

- (Removed InputTransform `pre_tensor_transform`, `to_tensor_transform`, `post_tensor_transform` hooks in favour of `per_sample_transform` ([#1010](https://github.com/PyTorchLightning/lightning-flash/pull/1010))
- Removed InputTransform `pre_tensor_transform`, `to_tensor_transform`, `post_tensor_transform` hooks in favour of `per_sample_transform` ([#1010](https://github.com/PyTorchLightning/lightning-flash/pull/1010))

- Removed `Task.predict`, use `Trainer.predict` instead ([#1030](https://github.com/PyTorchLightning/lightning-flash/pull/1030))

## [0.5.2] - 2021-11-05

6 changes: 0 additions & 6 deletions README.md
@@ -128,12 +128,6 @@ model.serve()

or make predictions from raw data directly.

```py
predictions = model.predict(["data/CameraRGB/F61-1.png", "data/CameraRGB/F62-1.png"])
```

or make predictions with 2 GPUs.

```py
trainer = Trainer(accelerator='ddp', gpus=2)
dm = SemanticSegmentationData.from_folders(predict_folder="data/CameraRGB")
1 change: 0 additions & 1 deletion docs/source/api/core.rst
@@ -133,4 +133,3 @@ _________
~flash.core.trainer.from_argparse_args
~flash.core.utilities.apply_func.get_callable_name
~flash.core.utilities.apply_func.get_callable_dict
~flash.core.model.predict_context
18 changes: 12 additions & 6 deletions docs/source/common/finetuning_example.rst
@@ -58,12 +58,13 @@ Once you've finetuned, use the model to predict:
# Output predictions as labels, automatically inferred from the training data in part 2.
model.output = LabelsOutput()

predictions = model.predict(
[
predict_datamodule = ImageClassificationData.from_files(
predict_files=[
"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
"data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg",
]
)
predictions = trainer.predict(model, datamodule=predict_datamodule)
print(predictions)

We get the following output:
@@ -76,19 +77,24 @@ We get the following output:
.. testcode:: finetune
:hide:

assert all([prediction in ["ants", "bees"] for prediction in predictions])
assert all(
[all([prediction in ["ants", "bees"] for prediction in prediction_batch]) for prediction_batch in predictions]
)

.. code-block::
['bees', 'ants']
[['bees', 'ants']]
Or you can use the saved model for prediction anywhere you want!

.. code-block:: python
from flash.image import ImageClassifier
from flash import Trainer
from flash.image import ImageClassifier, ImageClassificationData
# load finetuned checkpoint
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")
predictions = model.predict("path/to/your/own/image.png")
trainer = Trainer()
datamodule = ImageClassificationData.from_files(predict_files=["path/to/your/own/image.png"])
predictions = trainer.predict(model, datamodule=datamodule)
46 changes: 14 additions & 32 deletions docs/source/general/predictions.rst
@@ -7,17 +7,13 @@ Predictions (inference)

You can use Flash to get predictions on pretrained or finetuned models.

Predict on a single sample of data
==================================

You can pass in a sample of data (image file path, a string of text, etc) to the :func:`~flash.core.model.Task.predict` method.

First create a :class:`~flash.core.data.data_module.DataModule` with some predict data, then pass it to the :meth:`Trainer.predict <flash.core.trainer.Trainer.predict>` method.

.. code-block:: python
from flash import Trainer
from flash.core.data.utils import download_data
from flash.image import ImageClassifier
from flash.image import ImageClassifier, ImageClassificationData
# 1. Download the data set
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
@@ -28,30 +24,13 @@ You can pass in a sample of data (image file path, a string of text, etc) to the
)
# 3. Predict whether the image contains an ant or a bee
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
print(predictions)
Predict on a csv file
=====================

.. code-block:: python
from flash.core.data.utils import download_data
from flash.tabular import TabularClassifier
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")
# 2. Load the model from a checkpoint
model = TabularClassifier.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/0.6.0/tabular_classification_model.pt"
trainer = Trainer()
datamodule = ImageClassificationData.from_files(
predict_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"]
)
# 3. Generate predictions from a csv file! Who would survive?
predictions = model.predict("data/titanic/titanic.csv")
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
# out: [["bees"]]
Serializing predictions
@@ -61,7 +40,6 @@ To change the output format of predictions you can attach an :class:`~flash.core
:class:`~flash.core.model.Task`. For example, you can choose to output probabilities (for more options see the API
reference below).


.. code-block:: python
from flash.core.classification import ProbabilitiesOutput
@@ -81,6 +59,10 @@ reference below).
model.output = ProbabilitiesOutput()
# 4. Predict whether the image contains an ant or a bee
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
trainer = Trainer()
datamodule = ImageClassificationData.from_files(
predict_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"]
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
# out: [[0.5926494598388672, 0.40735048055648804]]
# out: [[[0.5926494598388672, 0.40735048055648804]]]
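A note on the changed output shape: the `# out:` lines above gain one level of nesting, which suggests that `Trainer.predict` returns one list per batch. Below is a minimal sketch, assuming that shape and using only the values shown in the diff, of reading the most likely class index for each sample. No class names are assumed, since the diff does not state the label order.

```python
# Nested output copied from the `# out:` line in the diff:
# batches -> samples -> class probabilities.
predictions = [[[0.5926494598388672, 0.40735048055648804]]]

# For every sample in every batch, pick the index of the largest probability.
class_indices = [
    [max(range(len(probs)), key=probs.__getitem__) for probs in batch]
    for batch in predictions
]
print(class_indices)
```

The result keeps the per-batch grouping, so it has the same outer structure as the predictions it was computed from.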
20 changes: 14 additions & 6 deletions docs/source/quickstart.rst
@@ -79,7 +79,7 @@ Inference
Inference is the process of generating predictions from trained models. To use a task for inference:

1. Init your task with pretrained weights using a checkpoint (a checkpoint is simply a file that captures the exact values of all parameters used by a model). A local file or a URL works.
2. Pass in the data to :func:`flash.core.model.Task.predict`.
2. Load your data into a :class:`~flash.core.data.data_module.DataModule` and pass it to :func:`Trainer.predict <flash.core.trainer.Trainer.predict>`.

|
@@ -88,19 +88,22 @@ Here's an example of inference:
.. testcode::

# import our libraries
from flash.text import TextClassifier
from flash import Trainer
from flash.text import TextClassifier, TextClassificationData

# 1. Init the finetuned task from URL
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.6.0/text_classification_model.pt")

# 2. Perform inference from list of sequences
predictions = model.predict(
[
trainer = Trainer()
datamodule = TextClassificationData.from_lists(
predict_data=[
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
"The worst movie in the history of cinema.",
"This guy has done a great job with this movie!",
]
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

We get the following output:
@@ -113,11 +116,16 @@ We get the following output:
.. testcode::
:hide:

assert all([prediction in ["positive", "negative"] for prediction in predictions])
assert all(
[
all([prediction in ["positive", "negative"] for prediction in prediction_batch])
for prediction_batch in predictions
]
)

.. code-block::
["negative", "negative", "positive"]
[["negative", "negative", "positive"]]
-------

2 changes: 1 addition & 1 deletion docs/source/template/tests.rst
@@ -76,7 +76,7 @@ These tests are very similar to ``test_train``, but here they are for completene
We also include tests for prediction named ``test_predict_*`` for each of our data sources.
In our case, we have ``test_predict_numpy`` and ``test_predict_sklearn``.
These tests should use the ``input`` argument to :meth:`~flash.core.model.Task.predict` to select the required :class:`~flash.core.data.Input`.
These tests should load the data with a :class:`~flash.core.data.data_module.DataModule` and generate predictions with :func:`Trainer.predict <flash.core.trainer.Trainer.predict>`.
Here's ``test_predict_sklearn`` as an example:

.. literalinclude:: ../../../tests/template/classification/test_model.py
5 changes: 3 additions & 2 deletions flash/core/integrations/labelstudio/visualizer.py
@@ -22,8 +22,9 @@ def __init__(self, datamodule: DataModule):
def show_predictions(self, predictions):
"""Converts predictions to Label Studio results."""
results = []
for pred in predictions:
results.append(self._construct_result(pred))
for prediction_batch in predictions:
for pred in prediction_batch:
results.append(self._construct_result(pred))
return results

def show_tasks(self, predictions, export_json=None):
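The nested loop in `show_predictions` reflects that predictions now arrive grouped by batch. Below is a minimal sketch of the same flattening with the standard library, where `batched_predictions` is a hypothetical stand-in for the value returned by `trainer.predict(...)`:

```python
from itertools import chain

# Hypothetical stand-in for the return value of `trainer.predict(...)`:
# one inner list per batch.
batched_predictions = [["bees", "ants"], ["ants"]]

# Concatenate the per-batch lists into a single flat list of predictions.
flat_predictions = list(chain.from_iterable(batched_predictions))
print(flat_predictions)
```

Callers that relied on the flat list returned by the removed `model.predict` could apply a flattening step like this one.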
81 changes: 2 additions & 79 deletions flash/core/model.py
@@ -14,7 +14,6 @@
import functools
import inspect
import pickle
import warnings
from abc import ABCMeta
from copy import deepcopy
from importlib import import_module
@@ -41,7 +40,6 @@
from flash.core.data.auto_dataset import BaseAutoDataset
from flash.core.data.data_pipeline import DataPipeline, DataPipelineState
from flash.core.data.io.input import Input
from flash.core.data.io.input_base import InputBase as NewInputBase
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.io.output import Output
from flash.core.data.io.output_transform import OutputTransform
@@ -262,27 +260,6 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
print("Benchmark Successful!")


def predict_context(func: Callable) -> Callable:
"""This decorator is used as context manager to put model in eval mode before running predict and reset to
train after."""

@functools.wraps(func)
def wrapper(self, *args, **kwargs) -> Any:
grad_enabled = torch.is_grad_enabled()
is_training = self.training
self.eval()
torch.set_grad_enabled(False)

result = func(self, *args, **kwargs)

if is_training:
self.train()
torch.set_grad_enabled(grad_enabled)
return result

return wrapper


class CheckDependenciesMeta(ABCMeta):
def __new__(mcs, *args, **kwargs):
result = ABCMeta.__new__(mcs, *args, **kwargs)
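The removed `predict_context` decorator switched the model to eval mode, disabled gradients, ran the wrapped function, and then restored both. As written, an exception inside the wrapped function would skip the restore. The sketch below is not the removed code: it is a hypothetical variant (`eval_mode` and `ToyModule` are illustrative names) that uses `try`/`finally` and tracks only the `training` flag, so it runs without PyTorch and leaves gradient handling out.

```python
import functools


def eval_mode(func):
    """Hypothetical variant: run `func` in eval mode and restore the flag even on error."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        try:
            return func(self, *args, **kwargs)
        finally:
            # Runs on both normal return and exception.
            if was_training:
                self.train()

    return wrapper


class ToyModule:
    """Stand-in for a torch.nn.Module; it only tracks the `training` flag."""

    def __init__(self):
        self.training = True

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

    @eval_mode
    def predict(self, x):
        assert not self.training  # eval mode is active inside the call
        return x * 2

    @eval_mode
    def broken(self):
        raise ValueError("boom")


module = ToyModule()
doubled = module.predict(3)
```

With `Trainer.predict`, this bookkeeping is handled by the trainer rather than by a decorator on the task.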
@@ -472,62 +449,8 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
prog_bar=True,
)

@predict_context
def predict(
self,
x: Any,
data_source: Optional[str] = None,
input: Optional[str] = None,
deserializer: Optional[Deserializer] = None,
data_pipeline: Optional[DataPipeline] = None,
) -> Any:
"""Predict function for raw data or processed data.
Args:
x: Input to predict. Can be raw data or processed data. If str, assumed to be a folder of data.
input: A string that indicates the format of the data source to use which will override
the current data source format used
deserializer: A single :class:`~flash.core.data.process.Deserializer` to deserialize the input
data_pipeline: Use this to override the current data pipeline
Returns:
The post-processed model predictions
"""
if data_source is not None:
warnings.warn(
"The `data_source` argument has been deprecated since 0.6.0 and will be removed in 0.7.0. Use `input` "
"instead.",
FutureWarning,
)
input = data_source
running_stage = RunningStage.PREDICTING

data_pipeline = self.build_data_pipeline(None, deserializer, data_pipeline)

# <hack> Temporary fix to support new `Input` object
input = data_pipeline._input_transform_pipeline.input_of_name(input or "default")

if (inspect.isclass(input) and issubclass(input, NewInputBase)) or (
isinstance(input, functools.partial) and issubclass(input.func, NewInputBase)
):
dataset = input(running_stage, x, data_pipeline_state=self._data_pipeline_state)
else:
dataset = input.generate_dataset(x, running_stage)
# </hack>

dataloader = self.process_predict_dataset(dataset)
x = list(dataloader.dataset)
x = data_pipeline.worker_input_transform_processor(running_stage, collate_fn=dataloader.collate_fn)(x)
# todo (tchaton): Remove this when sync with Lightning master.
if len(inspect.signature(self.transfer_batch_to_device).parameters) == 3:
x = self.transfer_batch_to_device(x, self.device, 0)
else:
x = self.transfer_batch_to_device(x, self.device)
x = data_pipeline.device_input_transform_processor(running_stage)(x)
x = x[0] if isinstance(x, list) else x
predictions = self.predict_step(x, 0) # batch_idx is always 0 when running with `model.predict`
predictions = data_pipeline.output_transform_processor(running_stage)(predictions)
return predictions
def predict(self, *args, **kwargs):
raise AttributeError("`flash.Task.predict` has been removed. Use `flash.Trainer.predict` instead.")

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
if isinstance(batch, tuple):
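Because `predict` is kept as a stub that raises `AttributeError` rather than being deleted, the attribute lookup still succeeds and only the call fails. Below is a minimal sketch of that behaviour, using a hypothetical `Task` stand-in instead of the real `flash.Task`:

```python
class Task:
    """Hypothetical stand-in mirroring only the stub added in this commit."""

    def predict(self, *args, **kwargs):
        raise AttributeError("`flash.Task.predict` has been removed. Use `flash.Trainer.predict` instead.")


task = Task()

# The method is still defined, so feature detection with hasattr() reports True.
still_defined = hasattr(task, "predict")

# Calling it is what raises, with a message that points at the replacement.
try:
    task.predict("path/to/your/own/image.png")
    message = None
except AttributeError as err:
    message = str(err)
print(message)
```

Code that probes for `predict` with `hasattr` before calling it would therefore still reach the call and see the migration message.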