diff --git a/CHANGELOG.md b/CHANGELOG.md index 7807247782..73d4f07a3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,7 +38,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `Output` suffix to `Preds`, `FiftyOneDetectionLabels`, `SegmentationLabels`, `FiftyOneDetectionLabels`, `DetectionLabels`, `Classes`, `FiftyOneLabels`, `Labels`, `Logits`, `Probabilities` ([#1011](https://github.com/PyTorchLightning/lightning-flash/pull/1011)) - - Changed `from_files` and `from_folders` from `ObjectDetectionData`, `InstanceSegmentationData`, `KeypointDetectionData` to support only the `predicting` stage ([#1018](https://github.com/PyTorchLightning/lightning-flash/pull/1018)) ### Deprecated @@ -75,8 +74,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed `OutputTransform.save_sample` and `save_data` hooks ([#948](https://github.com/PyTorchLightning/lightning-flash/pull/948)) -- (Removed InputTransform `pre_tensor_transform`, `to_tensor_transform`, `post_tensor_transform` hooks in favour of `per_sample_transform` ([#1010](https://github.com/PyTorchLightning/lightning-flash/pull/1010)) +- Removed InputTransform `pre_tensor_transform`, `to_tensor_transform`, `post_tensor_transform` hooks in favour of `per_sample_transform` ([#1010](https://github.com/PyTorchLightning/lightning-flash/pull/1010)) +- Removed `Task.predict`, use `Trainer.predict` instead ([#1030](https://github.com/PyTorchLightning/lightning-flash/pull/1030)) ## [0.5.2] - 2021-11-05 diff --git a/README.md b/README.md index 4e8294cbe7..4b959193f5 100644 --- a/README.md +++ b/README.md @@ -128,12 +128,6 @@ model.serve() or make predictions from raw data directly. -```py -predictions = model.predict(["data/CameraRGB/F61-1.png", "data/CameraRGB/F62-1.png"]) -``` - -or make predictions with 2 GPUs. - ```py trainer = Trainer(accelerator='ddp', gpus=2) dm = SemanticSegmentationData.from_folders(predict_folder="data/CameraRGB") diff --git a/docs/source/api/core.rst b/docs/source/api/core.rst index 9adc4f9b40..37e08087ea 100644 --- a/docs/source/api/core.rst +++ b/docs/source/api/core.rst @@ -133,4 +133,3 @@ _________ ~flash.core.trainer.from_argparse_args ~flash.core.utilities.apply_func.get_callable_name ~flash.core.utilities.apply_func.get_callable_dict - ~flash.core.model.predict_context diff --git a/docs/source/common/finetuning_example.rst b/docs/source/common/finetuning_example.rst index 60f0d8af7a..4bcd38315b 100644 --- a/docs/source/common/finetuning_example.rst +++ b/docs/source/common/finetuning_example.rst @@ -58,12 +58,13 @@ Once you've finetuned, use the model to predict: # Output predictions as labels, automatically inferred from the training data in part 2. model.output = LabelsOutput() - predictions = model.predict( - [ + predict_datamodule = ImageClassificationData.from_files( + predict_files=[ "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", "data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg", ] ) + predictions = trainer.predict(model, datamodule=predict_datamodule) print(predictions) We get the following output: @@ -76,19 +77,24 @@ We get the following output: .. testcode:: finetune :hide: - assert all([prediction in ["ants", "bees"] for prediction in predictions]) + assert all( + [all([prediction in ["ants", "bees"] for prediction in prediction_batch]) for prediction_batch in predictions] + ) .. code-block:: - ['bees', 'ants'] + [['bees', 'ants']] Or you can use the saved model for prediction anywhere you want! .. 
code-block:: python - from flash.image import ImageClassifier + from flash import Trainer + from flash.image import ImageClassifier, ImageClassificationData # load finetuned checkpoint model = ImageClassifier.load_from_checkpoint("image_classification_model.pt") - predictions = model.predict("path/to/your/own/image.png") + trainer = Trainer() + datamodule = ImageClassificationData.from_files(predict_files=["path/to/your/own/image.png"]) + predictions = trainer.predict(model, datamodule=datamodule) diff --git a/docs/source/general/predictions.rst b/docs/source/general/predictions.rst index 328913215e..e38d12bddb 100644 --- a/docs/source/general/predictions.rst +++ b/docs/source/general/predictions.rst @@ -7,17 +7,13 @@ Predictions (inference) You can use Flash to get predictions on pretrained or finetuned models. -Predict on a single sample of data -================================== - -You can pass in a sample of data (image file path, a string of text, etc) to the :func:`~flash.core.model.Task.predict` method. - +First create a :class:`~flash.core.data.data_module.DataModule` with some predict data, then pass it to the :meth:`Trainer.predict ` method. .. code-block:: python + from flash import Trainer from flash.core.data.utils import download_data - from flash.image import ImageClassifier - + from flash.image import ImageClassifier, ImageClassificationData # 1. Download the data set download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") @@ -28,30 +24,13 @@ You can pass in a sample of data (image file path, a string of text, etc) to the ) # 3. Predict whether the image contains an ant or a bee - predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") - print(predictions) - - - -Predict on a csv file -===================== - -.. code-block:: python - - from flash.core.data.utils import download_data - from flash.tabular import TabularClassifier - - # 1. Download the data - download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/") - - # 2. Load the model from a checkpoint - model = TabularClassifier.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/0.6.0/tabular_classification_model.pt" + trainer = Trainer() + datamodule = ImageClassificationData.from_files( + predict_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"] ) - - # 3. Generate predictions from a csv file! Who would survive? - predictions = model.predict("data/titanic/titanic.csv") + predictions = trainer.predict(model, datamodule=datamodule) print(predictions) + # out: [["bees"]] Serializing predictions @@ -61,7 +40,6 @@ To change the output format of predictions you can attach an :class:`~flash.core :class:`~flash.core.model.Task`. For example, you can choose to output probabilities (for more options see the API reference below). - .. code-block:: python from flash.core.classification import ProbabilitiesOutput @@ -81,6 +59,10 @@ reference below). model.output = ProbabilitiesOutput() # 4. 
Predict whether the image contains an ant or a bee - predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") + trainer = Trainer() + datamodule = ImageClassificationData.from_files( + predict_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"] + ) + predictions = trainer.predict(model, datamodule=datamodule) print(predictions) - # out: [[0.5926494598388672, 0.40735048055648804]] + # out: [[[0.5926494598388672, 0.40735048055648804]]] diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 227c6dcaca..a729684630 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -79,7 +79,7 @@ Inference Inference is the process of generating predictions from trained models. To use a task for inference: 1. Init your task with pretrained weights using a checkpoint (a checkpoint is simply a file that capture the exact value of all parameters used by a model). Local file or URL works. -2. Pass in the data to :func:`flash.core.model.Task.predict`. +2. Load your data into a :class:`~flash.core.data.data_module.DataModule` and pass it to :func:`Trainer.predict `. | @@ -88,19 +88,22 @@ Here's an example of inference: .. testcode:: # import our libraries - from flash.text import TextClassifier + from flash import Trainer + from flash.text import TextClassifier, TextClassificationData # 1. Init the finetuned task from URL model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.6.0/text_classification_model.pt") # 2. Perform inference from list of sequences - predictions = model.predict( - [ + trainer = Trainer() + datamodule = TextClassificationData.from_lists( + predict_data=[ "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", "The worst movie in the history of cinema.", "This guy has done a great job with this movie!", ] ) + predictions = trainer.predict(model, datamodule=datamodule) print(predictions) We get the following output: @@ -113,11 +116,16 @@ We get the following output: .. testcode:: :hide: - assert all([prediction in ["positive", "negative"] for prediction in predictions]) + assert all( + [ + all([prediction in ["positive", "negative"] for prediction in prediction_batch]) + for prediction_batch in predictions + ] + ) .. code-block:: - ["negative", "negative", "positive"] + [["negative", "negative", "positive"]] ------- diff --git a/docs/source/template/tests.rst b/docs/source/template/tests.rst index b06397f99f..a8f5d9184f 100644 --- a/docs/source/template/tests.rst +++ b/docs/source/template/tests.rst @@ -76,7 +76,7 @@ These tests are very similar to ``test_train``, but here they are for completene We also include tests for prediction named ``test_predict_*`` for each of our data sources. In our case, we have ``test_predict_numpy`` and ``test_predict_sklearn``. -These tests should use the ``input`` argument to :meth:`~flash.core.model.Task.predict` to select the required :class:`~flash.core.data.Input`. +These tests should load the data with a :class:`~flash.core.data.data_module.DataModule` and generate predictions with :func:`Trainer.predict `. Here's ``test_predict_sklearn`` as an example: .. 
literalinclude:: ../../../tests/template/classification/test_model.py diff --git a/flash/core/integrations/labelstudio/visualizer.py b/flash/core/integrations/labelstudio/visualizer.py index d585377acc..d902d43145 100644 --- a/flash/core/integrations/labelstudio/visualizer.py +++ b/flash/core/integrations/labelstudio/visualizer.py @@ -22,8 +22,9 @@ def __init__(self, datamodule: DataModule): def show_predictions(self, predictions): """Converts predictions to Label Studio results.""" results = [] - for pred in predictions: - results.append(self._construct_result(pred)) + for prediction_batch in predictions: + for pred in prediction_batch: + results.append(self._construct_result(pred)) return results def show_tasks(self, predictions, export_json=None): diff --git a/flash/core/model.py b/flash/core/model.py index f847f4efcc..f9335ea03b 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -14,7 +14,6 @@ import functools import inspect import pickle -import warnings from abc import ABCMeta from copy import deepcopy from importlib import import_module @@ -41,7 +40,6 @@ from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.data_pipeline import DataPipeline, DataPipelineState from flash.core.data.io.input import Input -from flash.core.data.io.input_base import InputBase as NewInputBase from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output import Output from flash.core.data.io.output_transform import OutputTransform @@ -262,27 +260,6 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul print("Benchmark Successful!") -def predict_context(func: Callable) -> Callable: - """This decorator is used as context manager to put model in eval mode before running predict and reset to - train after.""" - - @functools.wraps(func) - def wrapper(self, *args, **kwargs) -> Any: - grad_enabled = torch.is_grad_enabled() - is_training = self.training - self.eval() - torch.set_grad_enabled(False) - - result = func(self, *args, **kwargs) - - if is_training: - self.train() - torch.set_grad_enabled(grad_enabled) - return result - - return wrapper - - class CheckDependenciesMeta(ABCMeta): def __new__(mcs, *args, **kwargs): result = ABCMeta.__new__(mcs, *args, **kwargs) @@ -472,62 +449,8 @@ def test_step(self, batch: Any, batch_idx: int) -> None: prog_bar=True, ) - @predict_context - def predict( - self, - x: Any, - data_source: Optional[str] = None, - input: Optional[str] = None, - deserializer: Optional[Deserializer] = None, - data_pipeline: Optional[DataPipeline] = None, - ) -> Any: - """Predict function for raw data or processed data. - - Args: - x: Input to predict. Can be raw data or processed data. If str, assumed to be a folder of data. - input: A string that indicates the format of the data source to use which will override - the current data source format used - deserializer: A single :class:`~flash.core.data.process.Deserializer` to deserialize the input - data_pipeline: Use this to override the current data pipeline - - Returns: - The post-processed model predictions - """ - if data_source is not None: - warnings.warn( - "The `data_source` argument has been deprecated since 0.6.0 and will be removed in 0.7.0. 
Use `input` " - "instead.", - FutureWarning, - ) - input = data_source - running_stage = RunningStage.PREDICTING - - data_pipeline = self.build_data_pipeline(None, deserializer, data_pipeline) - - # Temporary fix to support new `Input` object - input = data_pipeline._input_transform_pipeline.input_of_name(input or "default") - - if (inspect.isclass(input) and issubclass(input, NewInputBase)) or ( - isinstance(input, functools.partial) and issubclass(input.func, NewInputBase) - ): - dataset = input(running_stage, x, data_pipeline_state=self._data_pipeline_state) - else: - dataset = input.generate_dataset(x, running_stage) - # - - dataloader = self.process_predict_dataset(dataset) - x = list(dataloader.dataset) - x = data_pipeline.worker_input_transform_processor(running_stage, collate_fn=dataloader.collate_fn)(x) - # todo (tchaton): Remove this when sync with Lightning master. - if len(inspect.signature(self.transfer_batch_to_device).parameters) == 3: - x = self.transfer_batch_to_device(x, self.device, 0) - else: - x = self.transfer_batch_to_device(x, self.device) - x = data_pipeline.device_input_transform_processor(running_stage)(x) - x = x[0] if isinstance(x, list) else x - predictions = self.predict_step(x, 0) # batch_idx is always 0 when running with `model.predict` - predictions = data_pipeline.output_transform_processor(running_stage)(predictions) - return predictions + def predict(self, *args, **kwargs): + raise AttributeError("`flash.Task.predict` has been removed. Use `flash.Trainer.predict` instead.") def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: if isinstance(batch, tuple): diff --git a/flash/image/face_detection/data.py b/flash/image/face_detection/data.py index c15a489235..650b51ccb8 100644 --- a/flash/image/face_detection/data.py +++ b/flash/image/face_detection/data.py @@ -25,7 +25,7 @@ from flash.core.utilities.imports import _FASTFACE_AVAILABLE, _TORCHVISION_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.image.classification.data import ImageClassificationFilesInput, ImageClassificationFolderInput -from flash.image.data import ImageInput +from flash.image.data import ImageFilesInput if _TORCHVISION_AVAILABLE: import torchvision @@ -48,7 +48,6 @@ def fastface_collate_fn(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence if DataKeys.TARGET in samples.keys(): targets = samples[DataKeys.TARGET] - targets = [{"target_boxes": target["boxes"]} for target in targets] for i, (target, scale, padding) in enumerate(zip(targets, scales, paddings)): target["target_boxes"] *= scale @@ -62,15 +61,14 @@ def fastface_collate_fn(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence return samples -class FastFaceInput(ImageInput): +class FastFaceInput(ImageFilesInput): """Logic for loading from FDDBDataset.""" def load_data(self, dataset: Dataset) -> List[Dict[str, Any]]: return [ { DataKeys.INPUT: filepath, - "boxes": targets["target_boxes"], - "labels": [1] * targets["target_boxes"].shape[0], + DataKeys.TARGET: targets, } for filepath, targets in zip(dataset.ids, dataset.targets) ] @@ -112,10 +110,7 @@ def default_transforms(self) -> Dict[str, Callable]: ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()), ApplyToKeys( DataKeys.TARGET, - nn.Sequential( - ApplyToKeys("boxes", torch.as_tensor), - ApplyToKeys("labels", torch.as_tensor), - ), + ApplyToKeys("target_boxes", torch.as_tensor), ), ), "collate": fastface_collate_fn, @@ -174,3 +169,29 @@ def from_datasets( output_transform=cls.output_transform_cls(), 
**data_module_kwargs, ) + + @classmethod + def from_files( + cls, + predict_files: Optional[Sequence[str]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + **data_module_kwargs: Any, + ) -> "FaceDetectionData": + return cls( + predict_dataset=ImageClassificationFilesInput(RunningStage.PREDICTING, predict_files), + input_transform=cls.input_transform_cls(predict_transform=predict_transform), + **data_module_kwargs, + ) + + @classmethod + def from_folders( + cls, + predict_folder: Optional[str] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + **data_module_kwargs: Any, + ) -> "FaceDetectionData": + return cls( + predict_dataset=ImageClassificationFolderInput(RunningStage.PREDICTING, predict_folder), + input_transform=cls.input_transform_cls(predict_transform=predict_transform), + **data_module_kwargs, + ) diff --git a/flash/image/face_detection/model.py b/flash/image/face_detection/model.py index 28cbe3a488..d7f1cb8f1e 100644 --- a/flash/image/face_detection/model.py +++ b/flash/image/face_detection/model.py @@ -78,7 +78,7 @@ def __init__( self.save_hyperparameters() if model in ff.list_pretrained_models(): - self.model = FaceDetector.get_model(model, pretrained, **kwargs) + model = FaceDetector.get_model(model, pretrained, **kwargs) else: ValueError(model + f" is not supported yet, please select one from {ff.list_pretrained_models()}") @@ -108,11 +108,11 @@ def get_model( model.register_buffer("mean", getattr(pl_model, "mean")) model.register_buffer("std", getattr(pl_model, "std")) - # copy pasting `_output_transform` function from `fastface.FaceDetector` to `torch.nn.Module` + # copy pasting `_postprocess` function from `fastface.FaceDetector` to `torch.nn.Module` # set output_transform function # this is called from FaceDetector lightning module form fastface itself # https://github.com/borhanMorphy/fastface/blob/master/fastface/module.py#L200 - setattr(model, "_output_transform", getattr(pl_model, "_output_transform")) + setattr(model, "_postprocess", getattr(pl_model, "_postprocess")) return model diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py index 92cccb8e86..ef36fc56d3 100644 --- a/flash/pointcloud/detection/data.py +++ b/flash/pointcloud/detection/data.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional from torch.utils.data import Dataset @@ -116,6 +116,29 @@ def from_folders( **data_module_kwargs, ) + @classmethod + def from_files( + cls, + predict_files: Optional[List[str]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + scans_folder_name: Optional[str] = "scans", + labels_folder_name: Optional[str] = "labels", + calibrations_folder_name: Optional[str] = "calibs", + data_format: Optional[BaseDataFormat] = PointCloudObjectDetectionDataFormat.KITTI, + **data_module_kwargs: Any, + ) -> "PointCloudObjectDetectorData": + ds_kw = dict( + scans_folder_name=scans_folder_name, + labels_folder_name=labels_folder_name, + calibrations_folder_name=calibrations_folder_name, + data_format=data_format, + ) + return cls( + predict_dataset=PointCloudObjectDetectorFoldersInput(RunningStage.PREDICTING, predict_files, **ds_kw), + input_transform=cls.input_transform_cls(predict_transform), + **data_module_kwargs, + ) + @classmethod def from_datasets( cls, diff --git a/flash/pointcloud/detection/open3d_ml/app.py b/flash/pointcloud/detection/open3d_ml/app.py index dbaddcf4c1..1843bd877f 100644 --- a/flash/pointcloud/detection/open3d_ml/app.py +++ b/flash/pointcloud/detection/open3d_ml/app.py @@ -152,14 +152,15 @@ def show_predictions(self, predictions): lut.add_label(id, id, color=color) viz.set_lut("label", lut) - for pred in predictions: - data = { - "points": pred[DataKeys.INPUT][:, :3], - "name": pred[DataKeys.METADATA], - } - bounding_box = pred[DataKeys.PREDS] - - viz.visualize([data], bounding_boxes=bounding_box) + for prediction_batch in predictions: + for pred in prediction_batch: + data = { + "points": pred[DataKeys.INPUT][:, :3], + "name": pred[DataKeys.METADATA], + } + bounding_box = pred[DataKeys.PREDS] + + viz.visualize([data], bounding_boxes=bounding_box) def launch_app(datamodule: DataModule) -> "App": diff --git a/flash/pointcloud/segmentation/data.py b/flash/pointcloud/segmentation/data.py index 4459bf704c..c0206e5086 100644 --- a/flash/pointcloud/segmentation/data.py +++ b/flash/pointcloud/segmentation/data.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from torch.utils.data import Dataset @@ -114,6 +114,19 @@ def from_folders( **data_module_kwargs, ) + @classmethod + def from_files( + cls, + predict_files: Optional[List[str]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + **data_module_kwargs: Any, + ) -> "PointCloudSegmentationData": + return cls( + predict_dataset=PointCloudSegmentationFoldersInput(RunningStage.PREDICTING, predict_files), + input_transform=cls.input_transform_cls(predict_transform), + **data_module_kwargs, + ) + @classmethod def from_datasets( cls, diff --git a/flash/pointcloud/segmentation/open3d_ml/app.py b/flash/pointcloud/segmentation/open3d_ml/app.py index e9fbfd4c97..f904321dc9 100644 --- a/flash/pointcloud/segmentation/open3d_ml/app.py +++ b/flash/pointcloud/segmentation/open3d_ml/app.py @@ -83,15 +83,16 @@ def show_predictions(self, predictions): color_map = dataset.color_map predictions_visualizations = [] - for pred in predictions: - predictions_visualizations.append( - { - "points": pred[DataKeys.INPUT], - "labels": pred[DataKeys.TARGET], - "predictions": torch.argmax(pred[DataKeys.PREDS], axis=-1) + 1, - "name": pred[DataKeys.METADATA]["name"], - } - ) + for prediction_batch in predictions: + for pred in prediction_batch: + predictions_visualizations.append( + { + "points": pred[DataKeys.INPUT], + "labels": pred[DataKeys.TARGET], + "predictions": torch.argmax(pred[DataKeys.PREDS], axis=-1) + 1, + "name": pred[DataKeys.METADATA]["name"], + } + ) viz = Visualizer() lut = LabelLUT() diff --git a/flash/text/question_answering/data.py b/flash/text/question_answering/data.py index d5bdda536a..3a5368b5e9 100644 --- a/flash/text/question_answering/data.py +++ b/flash/text/question_answering/data.py @@ -20,7 +20,7 @@ import json from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Type, Union import torch from torch import Tensor @@ -542,7 +542,7 @@ def __init__( answer_column_name=answer_column_name, doc_stride=doc_stride, ), - "dict": QuestionAnsweringDictionaryInput( + "dicts": QuestionAnsweringDictionaryInput( self.backbone, max_source_length=max_source_length, max_target_length=max_target_length, @@ -560,7 +560,7 @@ def __init__( doc_stride=doc_stride, ), }, - default_input="dict", + default_input="dicts", ) self.set_state(QuestionAnsweringBackboneState(self.backbone)) @@ -910,3 +910,29 @@ def from_csv( sampler=sampler, **input_transform_kwargs, ) + + @classmethod + def from_dicts( + cls, + predict_data: Optional[Dict[str, Any]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + input_transform: Optional[InputTransform] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: int = 0, + sampler: Optional[Type[Sampler]] = None, + **input_transform_kwargs: Any, + ) -> "DataModule": + return cls.from_input( + "dicts", + predict_data=predict_data, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + input_transform=input_transform, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + sampler=sampler, + **input_transform_kwargs, + ) diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index c4f6d1958f..e539ac52b0 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py 
@@ -13,12 +13,14 @@ # limitations under the License. from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Type, Union import torch from torch import Tensor +from torch.utils.data import Sampler import flash +from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.io.input import Input, InputFormat from flash.core.data.io.input_transform import InputTransform @@ -288,7 +290,7 @@ def __init__( padding=padding, **backbone_kwargs, ), - "sentences": Seq2SeqSentencesInput( + InputFormat.LISTS: Seq2SeqSentencesInput( self.backbone, max_source_length=max_source_length, max_target_length=max_target_length, @@ -296,7 +298,7 @@ def __init__( **backbone_kwargs, ), }, - default_input="sentences", + default_input=InputFormat.LISTS, deserializer=TextDeserializer(backbone, max_source_length), ) @@ -362,3 +364,29 @@ class Seq2SeqData(DataModule): input_transform_cls = Seq2SeqInputTransform output_transform_cls = Seq2SeqOutputTransform + + @classmethod + def from_lists( + cls, + predict_data: Optional[List[str]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + input_transform: Optional[InputTransform] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: int = 0, + sampler: Optional[Type[Sampler]] = None, + **input_transform_kwargs: Any, + ) -> "DataModule": + return cls.from_input( + InputFormat.LISTS, + predict_data=predict_data, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + input_transform=input_transform, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + sampler=sampler, + **input_transform_kwargs, + ) diff --git a/flash_examples/audio_classification.py b/flash_examples/audio_classification.py index e879cd92f5..99e878ffc3 100644 --- a/flash_examples/audio_classification.py +++ b/flash_examples/audio_classification.py @@ -35,13 +35,14 @@ trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze", 1)) # 4. Predict what's on few images! air_conditioner, children_playing, siren e.t.c -predictions = model.predict( - [ +datamodule = AudioClassificationData.from_files( + predict_files=[ "data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg", "data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg", "data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg", ] ) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/face_detection.py b/flash_examples/face_detection.py index 9762cadb01..e4ed1406a5 100644 --- a/flash_examples/face_detection.py +++ b/flash_examples/face_detection.py @@ -30,17 +30,18 @@ model = FaceDetector(model="lffd_slim") # # 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count(), fast_dev_run=True) trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Detect faces in a few images! -predictions = model.predict( - [ +datamodule = FaceDetectionData.from_files( + predict_files=[ "data/2002/07/19/big/img_18.jpg", "data/2002/07/19/big/img_65.jpg", "data/2002/07/19/big/img_255.jpg", ] ) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # # 5. Save the model! 
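The same migration pattern repeats in every example that follows: wrap the raw predict data in a `DataModule`, then hand it to `Trainer.predict`. Here is a minimal before/after sketch of that pattern using the image-classification case already shown above — the checkpoint name and file paths are placeholders, while the APIs (`ImageClassificationData.from_files`, `Trainer.predict`) are the ones this diff itself introduces:

```py
# A minimal sketch of the Task.predict -> Trainer.predict migration applied
# throughout these examples; the checkpoint name and file paths are placeholders.
from flash import Trainer
from flash.image import ImageClassificationData, ImageClassifier

model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")

# Before (removed by this change):
# predictions = model.predict(["image_1.png", "image_2.png"])

# After: wrap the raw predict data in a DataModule and let the Trainer run inference.
datamodule = ImageClassificationData.from_files(predict_files=["image_1.png", "image_2.png"])
trainer = Trainer()
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)  # one list of predictions per batch, e.g. [["bees", "ants"]]
```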
diff --git a/flash_examples/graph_classification.py b/flash_examples/graph_classification.py index a975bcbe4b..39a66fc54d 100644 --- a/flash_examples/graph_classification.py +++ b/flash_examples/graph_classification.py @@ -39,7 +39,8 @@ trainer.fit(model, datamodule=datamodule) # 4. Classify some graphs! -predictions = model.predict(dataset[:3]) +datamodule = GraphClassificationData.from_datasets(predict_dataset=dataset[:3]) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/graph_embedder.py b/flash_examples/graph_embedder.py index 7646b0f5c8..5d0af766fd 100644 --- a/flash_examples/graph_embedder.py +++ b/flash_examples/graph_embedder.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch + +import flash from flash.core.utilities.imports import example_requires -from flash.graph import GraphEmbedder +from flash.graph import GraphClassificationData, GraphEmbedder example_requires("graph") @@ -25,5 +28,7 @@ model = GraphEmbedder.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.6.0/graph_classification_model.pt") # 3. Generate embeddings for the first 3 graphs -predictions = model.predict(dataset[:3]) +trainer = flash.Trainer(gpus=torch.cuda.device_count()) +datamodule = GraphClassificationData.from_datasets(predict_dataset=dataset[:3]) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) diff --git a/flash_examples/image_classification.py b/flash_examples/image_classification.py index 3b9413a629..f28c956972 100644 --- a/flash_examples/image_classification.py +++ b/flash_examples/image_classification.py @@ -33,13 +33,14 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Predict what's on a few images! ants or bees? -predictions = model.predict( - [ +datamodule = ImageClassificationData.from_files( + predict_files=[ "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg", "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg", ] ) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/image_classification_multi_label.py b/flash_examples/image_classification_multi_label.py index cb51698038..fc809f4bd5 100644 --- a/flash_examples/image_classification_multi_label.py +++ b/flash_examples/image_classification_multi_label.py @@ -47,13 +47,14 @@ def resolver(root, file_id): trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Predict the genre of a few movies! -predictions = model.predict( - [ +datamodule = ImageClassificationData.from_files( + predict_files=[ "data/movie_posters/predict/tt0085318.jpg", "data/movie_posters/predict/tt0089461.jpg", "data/movie_posters/predict/tt0097179.jpg", ] ) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/image_embedder.py b/flash_examples/image_embedder.py index 72e81e2bde..5f272a834d 100644 --- a/flash_examples/image_embedder.py +++ b/flash_examples/image_embedder.py @@ -46,11 +46,13 @@ # 5. 
Download the downstream prediction dataset and generate embeddings download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") -embeddings = embedder.predict( - [ +datamodule = ImageClassificationData.from_files( + predict_files=[ "data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg", "data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg", ] ) +embeddings = trainer.predict(embedder, datamodule=datamodule) + # list of embeddings for images sent to the predict function print(embeddings) diff --git a/flash_examples/instance_segmentation.py b/flash_examples/instance_segmentation.py index 76ee2e89e8..86595eaefa 100644 --- a/flash_examples/instance_segmentation.py +++ b/flash_examples/instance_segmentation.py @@ -42,13 +42,14 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Detect objects in a few images! -predictions = model.predict( - [ +datamodule = InstanceSegmentationData.from_files( + predict_files=[ str(data_dir / "images/yorkshire_terrier_9.jpg"), str(data_dir / "images/yorkshire_terrier_12.jpg"), str(data_dir / "images/yorkshire_terrier_13.jpg"), ] ) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/integrations/baal/image_classification_active_learning.py b/flash_examples/integrations/baal/image_classification_active_learning.py index 65864dcc3c..c5e7c828a7 100644 --- a/flash_examples/integrations/baal/image_classification_active_learning.py +++ b/flash_examples/integrations/baal/image_classification_active_learning.py @@ -38,7 +38,6 @@ backbone="resnet18", head=head, num_classes=datamodule.num_classes, output=ProbabilitiesOutput() ) - # 3.1 Create the trainer trainer = flash.Trainer(max_epochs=3) @@ -51,7 +50,10 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Predict what's on a few images! ants or bees? -predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") +datamodule = ImageClassificationData.from_files( + predict_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"] +) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/integrations/labelstudio/image_classification.py b/flash_examples/integrations/labelstudio/image_classification.py index 653efb9751..b36fe0bab4 100644 --- a/flash_examples/integrations/labelstudio/image_classification.py +++ b/flash_examples/integrations/labelstudio/image_classification.py @@ -32,12 +32,13 @@ model = ImageClassifier.load_from_checkpoint("image_classification_model.pt") model.output = LabelsOutput() -predictions = model.predict( - [ +datamodule = ImageClassificationData.from_files( + predict_files=[ "data/test/1.jpg", "data/test/2.jpg", ] ) +predictions = trainer.predict(model, datamodule=datamodule) # 5. Visualize predictions app = launch_app(datamodule) diff --git a/flash_examples/integrations/labelstudio/text_classification.py b/flash_examples/integrations/labelstudio/text_classification.py index 930b75bc07..d89254de23 100644 --- a/flash_examples/integrations/labelstudio/text_classification.py +++ b/flash_examples/integrations/labelstudio/text_classification.py @@ -22,13 +22,14 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Classify a few sentences! How was the movie? 
-predictions = model.predict( - [ +datamodule = TextClassificationData.from_lists( + predict_data=[ "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", "The worst movie in the history of cinema.", "I come from Bulgaria where it 's almost impossible to have a tornado.", ] ) +predictions = trainer.predict(model, datamodule=datamodule) # 5. Save the model! trainer.save_checkpoint("text_classification_model.pt") diff --git a/flash_examples/integrations/labelstudio/video_classification.py b/flash_examples/integrations/labelstudio/video_classification.py index c9e76c88ab..d50102cde8 100644 --- a/flash_examples/integrations/labelstudio/video_classification.py +++ b/flash_examples/integrations/labelstudio/video_classification.py @@ -29,8 +29,8 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 5. Make a prediction -predictions = model.predict(os.path.join(os.getcwd(), "data/test")) -print(predictions) +datamodule = VideoClassificationData.from_folders(predict_folder=os.path.join(os.getcwd(), "data/test")) +predictions = trainer.predict(model, datamodule=datamodule) # 6. Save the model! trainer.save_checkpoint("video_classification.pt") diff --git a/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py b/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py index 012fda8444..10ecc6b305 100644 --- a/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py +++ b/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py @@ -60,7 +60,8 @@ trainer.fit(model, datamodule=datamodule) # 4. Generate predictions -predictions = model.predict(data) +datamodule = TabularForecastingData.from_data_frame(predict_data_frame=data, parameters=datamodule.parameters) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # Plot with PyTorch Forecasting! diff --git a/flash_examples/keypoint_detection.py b/flash_examples/keypoint_detection.py index b1fa29cc02..cc7e4aa0d2 100644 --- a/flash_examples/keypoint_detection.py +++ b/flash_examples/keypoint_detection.py @@ -41,13 +41,14 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Detect objects in a few images! -predictions = model.predict( - [ +datamodule = KeypointDetectionData.from_files( + predict_files=[ str(data_dir / "biwi_sample/images/0.jpg"), str(data_dir / "biwi_sample/images/1.jpg"), str(data_dir / "biwi_sample/images/10.jpg"), ] ) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/object_detection.py b/flash_examples/object_detection.py index 1a5dddbce9..e10b813253 100644 --- a/flash_examples/object_detection.py +++ b/flash_examples/object_detection.py @@ -34,13 +34,14 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Detect objects in a few images! -predictions = model.predict( - [ +datamodule = ObjectDetectionData.from_files( + predict_files=[ "data/coco128/images/train2017/000000000625.jpg", "data/coco128/images/train2017/000000000626.jpg", "data/coco128/images/train2017/000000000629.jpg", ] ) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! 
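One consequence visible in the visualizer and test updates above is that `Trainer.predict` returns one list of predictions per dataloader batch rather than a flat list, so consumers either iterate the nested structure (as the updated `show_predictions` loops do) or flatten it first. A small sketch with illustrative values only — the sample labels and batch layout are made up for demonstration:

```py
# Trainer.predict yields one list of predictions per batch, so previously flat
# consumers now unnest, as the updated show_predictions loops do. The values
# below are illustrative (e.g. three files predicted with batch_size=2).
from itertools import chain

predictions = [["bees", "ants"], ["ants"]]

# Option 1: nested iteration, matching the visualizer changes in this diff.
for prediction_batch in predictions:
    for pred in prediction_batch:
        print(pred)

# Option 2: flatten once up front.
flat_predictions = list(chain.from_iterable(predictions))
assert flat_predictions == ["bees", "ants", "ants"]
```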
diff --git a/flash_examples/pointcloud_detection.py b/flash_examples/pointcloud_detection.py index ff29265355..9e9b6f6f82 100644 --- a/flash_examples/pointcloud_detection.py +++ b/flash_examples/pointcloud_detection.py @@ -36,12 +36,14 @@ trainer.fit(model, datamodule) # 4. Predict what's within a few PointClouds? -predictions = model.predict( - [ +datamodule = PointCloudObjectDetectorData.from_files( + predict_files=[ "data/KITTI_Tiny/Kitti/predict/scans/000000.bin", "data/KITTI_Tiny/Kitti/predict/scans/000001.bin", ] ) +predictions = trainer.predict(model, datamodule=datamodule) +print(predictions) # 5. Save the model! trainer.save_checkpoint("pointcloud_detection_model.pt") diff --git a/flash_examples/pointcloud_segmentation.py b/flash_examples/pointcloud_segmentation.py index 7d1a0eb538..69e1467c9d 100644 --- a/flash_examples/pointcloud_segmentation.py +++ b/flash_examples/pointcloud_segmentation.py @@ -36,12 +36,14 @@ trainer.fit(model, datamodule) # 4. Predict what's within a few PointClouds? -predictions = model.predict( - [ +datamodule = PointCloudSegmentationData.from_files( + predict_files=[ "data/SemanticKittiTiny/predict/000000.bin", "data/SemanticKittiTiny/predict/000001.bin", ] ) +predictions = trainer.predict(model, datamodule=datamodule) +print(predictions) # 5. Save the model! trainer.save_checkpoint("pointcloud_segmentation_model.pt") diff --git a/flash_examples/question_answering.py b/flash_examples/question_answering.py index 8620e5aed1..76b317b797 100644 --- a/flash_examples/question_answering.py +++ b/flash_examples/question_answering.py @@ -31,8 +31,8 @@ trainer.finetune(model, datamodule=datamodule) # 4. Answer some Questions! -predictions = model.predict( - { +datamodule = QuestionAnsweringData.from_dicts( + predict_data={ "id": ["56ddde6b9a695914005b9629", "56ddde6b9a695914005b9628"], "context": [ """ @@ -59,6 +59,7 @@ "question": ["When were the Normans in Normandy?", "In what country is Normandy located?"], } ) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/semantic_segmentation.py b/flash_examples/semantic_segmentation.py index a1f6b76f62..282ec7786b 100644 --- a/flash_examples/semantic_segmentation.py +++ b/flash_examples/semantic_segmentation.py @@ -45,13 +45,14 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Segment a few images! -predictions = model.predict( - [ +datamodule = SemanticSegmentationData.from_files( + predict_files=[ "data/CameraRGB/F61-1.png", "data/CameraRGB/F62-1.png", "data/CameraRGB/F63-1.png", ] ) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/speech_recognition.py b/flash_examples/speech_recognition.py index 4fa00feed0..d7cf4cb264 100644 --- a/flash_examples/speech_recognition.py +++ b/flash_examples/speech_recognition.py @@ -35,7 +35,8 @@ trainer.finetune(model, datamodule=datamodule, strategy="no_freeze") # 4. Predict on audio files! -predictions = model.predict(["data/timit/example.wav"]) +datamodule = SpeechRecognitionData.from_files(predict_files=["data/timit/example.wav"]) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/style_transfer.py b/flash_examples/style_transfer.py index 607f5ad0f6..047b044e71 100644 --- a/flash_examples/style_transfer.py +++ b/flash_examples/style_transfer.py @@ -32,13 +32,14 @@ trainer.fit(model, datamodule=datamodule) # 4. 
Apply style transfer to a few images! -predictions = model.predict( - [ +datamodule = StyleTransferData.from_files( + predict_files=[ "data/coco128/images/train2017/000000000625.jpg", "data/coco128/images/train2017/000000000626.jpg", "data/coco128/images/train2017/000000000629.jpg", ] ) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/summarization.py b/flash_examples/summarization.py index 5433805be3..6d685c314a 100644 --- a/flash_examples/summarization.py +++ b/flash_examples/summarization.py @@ -33,28 +33,31 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Summarize some text! -predictions = model.predict( - """ - Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local - people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue. - They came to Brixton to see work which has started to revitalise the borough. - It was Charles' first visit to the area since 1996, when he was accompanied by the former - South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue - for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit. - ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes. - She asked me were they ripe and I said yes - they're from the Dominican Republic."" - Mr Chong is one of 170 local retailers who accept the Brixton Pound. - Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market - or in participating shops. - During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children - nearby on an estate off Coldharbour Lane. Mr West said: - ""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man."" - He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'"" - Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust. - The trust hopes to restore and refurbish the building, - where once Jimi Hendrix and The Clash played, as a new community and business centre." - """ +datamodule = SummarizationData.from_lists( + predict_data=[ + """ + Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local + people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue. + They came to Brixton to see work which has started to revitalise the borough. + It was Charles' first visit to the area since 1996, when he was accompanied by the former + South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue + for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit. + ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes. + She asked me were they ripe and I said yes - they're from the Dominican Republic."" + Mr Chong is one of 170 local retailers who accept the Brixton Pound. + Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market + or in participating shops. + During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children + nearby on an estate off Coldharbour Lane. 
Mr West said: + ""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man."" + He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'"" + Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust. + The trust hopes to restore and refurbish the building, + where once Jimi Hendrix and The Clash played, as a new community and business centre." + """ + ] ) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/tabular_classification.py b/flash_examples/tabular_classification.py index ef80723afa..03ace27164 100644 --- a/flash_examples/tabular_classification.py +++ b/flash_examples/tabular_classification.py @@ -36,7 +36,11 @@ trainer.fit(model, datamodule=datamodule) # 4. Generate predictions from a CSV -predictions = model.predict("data/titanic/titanic.csv") +datamodule = TabularClassificationData.from_csv( + predict_file="data/titanic/titanic.csv", + parameters=datamodule.parameters, +) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/tabular_forecasting.py b/flash_examples/tabular_forecasting.py index 3dfc26a349..88949212f2 100644 --- a/flash_examples/tabular_forecasting.py +++ b/flash_examples/tabular_forecasting.py @@ -58,7 +58,8 @@ trainer.fit(model, datamodule=datamodule) # 4. Generate predictions -predictions = model.predict(data) +datamodule = TabularForecastingData.from_data_frame(predict_data_frame=data, parameters=datamodule.parameters) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/tabular_regression.py b/flash_examples/tabular_regression.py index 1461371e0c..15cbd1bf0d 100644 --- a/flash_examples/tabular_regression.py +++ b/flash_examples/tabular_regression.py @@ -46,7 +46,11 @@ trainer.fit(model, datamodule=datamodule) # 4. Generate predictions from a CSV -predictions = model.predict("data/SeoulBikeData.csv") +datamodule = TabularRegressionData.from_csv( + predict_file="data/SeoulBikeData.csv", + parameters=datamodule.parameters, +) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/template.py b/flash_examples/template.py index 0d8c7016ed..5cb6418862 100644 --- a/flash_examples/template.py +++ b/flash_examples/template.py @@ -32,13 +32,14 @@ trainer.fit(model, datamodule=datamodule) # 4. Classify a few examples -predictions = model.predict( - [ +datamodule = TemplateData.from_numpy( + predict_data=[ np.array([4.9, 3.0, 1.4, 0.2]), np.array([6.9, 3.2, 5.7, 2.3]), np.array([7.2, 3.0, 5.8, 1.6]), ] ) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/text_classification.py b/flash_examples/text_classification.py index 9e8a0b6856..f6ebfd838f 100644 --- a/flash_examples/text_classification.py +++ b/flash_examples/text_classification.py @@ -36,13 +36,14 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Classify a few sentences! How was the movie? 
-predictions = model.predict( - [ +datamodule = TextClassificationData.from_lists( + predict_data=[ "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", "The worst movie in the history of cinema.", "I come from Bulgaria where it 's almost impossible to have a tornado.", ] ) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/text_classification_multi_label.py b/flash_examples/text_classification_multi_label.py index 72f87b7c81..a66a8334ef 100644 --- a/flash_examples/text_classification_multi_label.py +++ b/flash_examples/text_classification_multi_label.py @@ -42,13 +42,14 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Generate predictions for a few comments! -predictions = model.predict( - [ +datamodule = TextClassificationData.from_lists( + predict_data=[ "No, he is an arrogant, self serving, immature idiot. Get it right.", "U SUCK HANNAH MONTANA", "Would you care to vote? Thx.", ] ) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/translation.py b/flash_examples/translation.py index 30f7c3053a..c9d3b529ab 100644 --- a/flash_examples/translation.py +++ b/flash_examples/translation.py @@ -36,13 +36,14 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Translate something! -predictions = model.predict( - [ +datamodule = TranslationData.from_lists( + predict_data=[ "BBC News went to meet one of the project's first graduates.", "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", "Of course, it's still early in the election cycle.", ] ) +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/video_classification.py b/flash_examples/video_classification.py index 2e36161b05..54682dcd24 100644 --- a/flash_examples/video_classification.py +++ b/flash_examples/video_classification.py @@ -37,7 +37,8 @@ trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Make a prediction -predictions = model.predict("data/kinetics/predict") +datamodule = VideoClassificationData.from_folders(predict_folder="data/kinetics/predict") +predictions = trainer.predict(model, datamodule=datamodule) print(predictions) # 5. Save the model! diff --git a/flash_examples/visualizations/pointcloud_detection.py b/flash_examples/visualizations/pointcloud_detection.py index 899e30a3aa..50b4e62909 100644 --- a/flash_examples/visualizations/pointcloud_detection.py +++ b/flash_examples/visualizations/pointcloud_detection.py @@ -36,7 +36,8 @@ trainer.fit(model, datamodule) # 4. Predict what's within a few PointClouds? -predictions = model.predict(["data/KITTI_Tiny/Kitti/predict/scans/000000.bin"]) +datamodule = PointCloudObjectDetectorData.from_files(predict_files=["data/KITTI_Tiny/Kitti/predict/scans/000000.bin"]) +predictions = trainer.predict(model, datamodule=datamodule) # 5. Save the model! trainer.save_checkpoint("pointcloud_segmentation_model.pt") diff --git a/flash_examples/visualizations/pointcloud_segmentation.py b/flash_examples/visualizations/pointcloud_segmentation.py index c50ea7b958..c2486592b5 100644 --- a/flash_examples/visualizations/pointcloud_segmentation.py +++ b/flash_examples/visualizations/pointcloud_segmentation.py @@ -36,12 +36,13 @@ trainer.fit(model, datamodule) # 4. Predict what's within a few PointClouds? 
-predictions = model.predict( - [ +datamodule = PointCloudSegmentationData.from_files( + predict_files=[ "data/SemanticKittiTiny/predict/000000.bin", "data/SemanticKittiTiny/predict/000001.bin", ] ) +predictions = trainer.predict(model, datamodule=datamodule) # 5. Save the model! trainer.save_checkpoint("pointcloud_segmentation_model.pt") diff --git a/flash_notebooks/image_classification.ipynb b/flash_notebooks/image_classification.ipynb index 16e180b0c9..c0f7d2a58b 100644 --- a/flash_notebooks/image_classification.ipynb +++ b/flash_notebooks/image_classification.ipynb @@ -275,7 +275,7 @@ "id": "individual-recipe", "metadata": {}, "source": [ - "### 2a. Predict what's on a few images! ants or bees?" + "### 2. Predict what's on a few images! ants or bees?" ] }, { @@ -285,31 +285,14 @@ "metadata": {}, "outputs": [], "source": [ - "predictions = model.predict([\n", - " \"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg\",\n", - " \"data/hymenoptera_data/val/bees/590318879_68cf112861.jpg\",\n", - " \"data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg\",\n", - "])\n", - "print(predictions)" - ] - }, - { - "cell_type": "markdown", - "id": "unique-humanitarian", - "metadata": {}, - "source": [ - "### 2b. Or generate predictions with a whole folder!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "empirical-wound", - "metadata": {}, - "outputs": [], - "source": [ - "datamodule = ImageClassificationData.from_folders(predict_folder=\"data/hymenoptera_data/predict/\")\n", - "predictions = flash.Trainer().predict(model, datamodule=datamodule)\n", + "datamodule = ImageClassificationData.from_files(\n", + " predict_files=[\n", + " \"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg\",\n", + " \"data/hymenoptera_data/val/bees/590318879_68cf112861.jpg\",\n", + " \"data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg\",\n", + " ]\n", + ")\n", + "predictions = trainer.predict(model, datamodule=datamodule)\n", "print(predictions)" ] }, diff --git a/flash_notebooks/tabular_classification.ipynb b/flash_notebooks/tabular_classification.ipynb index 43d6ca454f..99bfa7f61a 100644 --- a/flash_notebooks/tabular_classification.ipynb +++ b/flash_notebooks/tabular_classification.ipynb @@ -54,7 +54,7 @@ "\n", "import flash\n", "from flash.core.data.utils import download_data\n", - "from flash.tabular import TabularClassifier, TabularData" + "from flash.tabular import TabularClassifier, TabularClassificationData" ] }, { @@ -94,7 +94,7 @@ "metadata": {}, "outputs": [], "source": [ - "datamodule = TabularData.from_csv(\n", + "datamodule = TabularClassificationData.from_csv(\n", " [\"Sex\", \"Age\", \"SibSp\", \"Parch\", \"Ticket\", \"Cabin\", \"Embarked\"],\n", " [\"Fare\"],\n", " target_fields=\"Survived\",\n", @@ -242,7 +242,11 @@ "metadata": {}, "outputs": [], "source": [ - "predictions = model.predict(\"data/titanic/titanic.csv\")" + "datamodule = TabularClassificationData.from_csv(\n", + " predict_file=\"data/titanic/titanic.csv\",\n", + " parameters=datamodule.parameters,\n", + ")\n", + "predictions = trainer.predict(model, datamodule=datamodule)" ] }, { diff --git a/flash_notebooks/text_classification.ipynb b/flash_notebooks/text_classification.ipynb index 183695e8db..29fae78dcb 100644 --- a/flash_notebooks/text_classification.ipynb +++ b/flash_notebooks/text_classification.ipynb @@ -266,7 +266,7 @@ "id": "worst-consumer", "metadata": {}, "source": [ - "### 2a. Classify a few sentences! How was the movie?" + "### 2. Classify a few sentences! How was the movie?" 
] }, { @@ -276,37 +276,14 @@ "metadata": {}, "outputs": [], "source": [ - "predictions = model.predict([\n", - " \"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.\",\n", - " \"The worst movie in the history of cinema.\",\n", - " \"I come from Bulgaria where it 's almost impossible to have a tornado.\",\n", - " \"Very, very afraid\",\n", - " \"This guy has done a great job with this movie!\",\n", - "])\n", - "print(predictions)" - ] - }, - { - "cell_type": "markdown", - "id": "limited-culture", - "metadata": {}, - "source": [ - "### 2b. Or generate predictions from a sheet file!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "persistent-formula", - "metadata": {}, - "outputs": [], - "source": [ - "datamodule = TextClassificationData.from_csv(\n", - " \"review\",\n", - " predict_file=\"data/imdb/predict.csv\",\n", - " backbone=\"prajjwal1/bert-tiny\",\n", + "datamodule = TextClassificationData.from_lists(\n", + " predict_data=[\n", + " \"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.\",\n", + " \"The worst movie in the history of cinema.\",\n", + " \"I come from Bulgaria where it 's almost impossible to have a tornado.\",\n", + " ]\n", ")\n", - "predictions = flash.Trainer().predict(model, datamodule=datamodule)\n", + "predictions = trainer.predict(model, datamodule=datamodule)\n", "print(predictions)" ] }, diff --git a/tests/core/integrations/labelstudio/test_labelstudio.py b/tests/core/integrations/labelstudio/test_labelstudio.py index 6999a817a6..b007bd4067 100644 --- a/tests/core/integrations/labelstudio/test_labelstudio.py +++ b/tests/core/integrations/labelstudio/test_labelstudio.py @@ -215,7 +215,7 @@ def test_label_studio_predictions_visualization(): ) assert datamodule app = launch_app(datamodule) - predictions = [0, 1, 1, 0] + predictions = [[0, 1], [1, 0]] vis_predictions = app.show_predictions(predictions) assert len(vis_predictions) == 4 assert vis_predictions[0]["result"][0]["id"] != vis_predictions[3]["result"][0]["id"] diff --git a/tests/core/test_model.py b/tests/core/test_model.py index a21615545a..1da4e8dabd 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -16,11 +16,8 @@ from copy import deepcopy from itertools import chain from numbers import Number -from pathlib import Path from typing import Any, Tuple -from unittest import mock -import numpy as np import pytest import pytorch_lightning as pl import torch @@ -35,11 +32,10 @@ from flash.audio import SpeechRecognition from flash.core.adapter import Adapter from flash.core.classification import ClassificationTask -from flash.core.data.io.input_transform import DefaultInputTransform from flash.core.data.io.output_transform import OutputTransform -from flash.core.utilities.imports import _TORCH_OPTIMIZER_AVAILABLE, _TRANSFORMERS_AVAILABLE, Image +from flash.core.utilities.imports import _TORCH_OPTIMIZER_AVAILABLE, _TRANSFORMERS_AVAILABLE from flash.graph import GraphClassifier, GraphEmbedder -from flash.image import ImageClassificationData, ImageClassifier, SemanticSegmentation +from flash.image import ImageClassifier, SemanticSegmentation from flash.tabular import TabularClassifier from flash.text import SummarizationTask, TextClassifier, TranslationTask from tests.helpers.utils import _AUDIO_TESTING, _GRAPH_TESTING, _IMAGE_TESTING, _TABULAR_TESTING, _TEXT_TESTING @@ -160,6 +156,13 @@ def test_classificationtask_train(tmpdir: str, metrics: Any): assert "test_nll_loss" in result[0] +def test_task_predict_raises(): + with 
+    with pytest.raises(AttributeError, match="`flash.Task.predict` has been removed."):
+        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
+        task = ClassificationTask(model, loss_fn=F.nll_loss)
+        task.predict("args", kwarg="test")
+
+
 @pytest.mark.parametrize("task", [Parent, GrandParent, AdapterParent])
 def test_nested_tasks(tmpdir, task):
     model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
@@ -175,45 +178,6 @@ def test_nested_tasks(tmpdir, task):
     assert "test_nll_loss" in result[0]


-def test_classificationtask_task_predict():
-    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
-    task = ClassificationTask(model, input_transform=DefaultInputTransform())
-    ds = DummyDataset()
-    expected = list(range(10))
-    # single item
-    x0, _ = ds[0]
-    pred0 = task.predict(x0)
-    assert pred0[0] in expected
-    # list
-    x1, _ = ds[1]
-    pred1 = task.predict([x0, x1])
-    assert all(c in expected for c in pred1)
-    assert pred0[0] == pred1[0]
-
-
-@mock.patch("flash._IS_TESTING", True)
-@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
-def test_classification_task_predict_folder_path(tmpdir):
-    train_dir = Path(tmpdir / "train")
-    train_dir.mkdir()
-
-    def _rand_image():
-        return Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype="uint8"))
-
-    _rand_image().save(train_dir / "1.png")
-    _rand_image().save(train_dir / "2.png")
-
-    datamodule = ImageClassificationData.from_folders(predict_folder=train_dir)
-
-    task = ImageClassifier(num_classes=10)
-    predictions = task.predict(
-        str(train_dir),
-        input="folders",
-        data_pipeline=datamodule.data_pipeline,
-    )
-    assert len(predictions) == 2
-
-
 def test_classification_task_trainer_predict(tmpdir):
     model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
     task = ClassificationTask(model)
diff --git a/tests/graph/classification/test_model.py b/tests/graph/classification/test_model.py
index 3a5d7f2c6f..470224a01b 100644
--- a/tests/graph/classification/test_model.py
+++ b/tests/graph/classification/test_model.py
@@ -73,9 +73,11 @@ def test_predict_dataset(tmpdir):
     """Tests that we can generate predictions from a pytorch geometric dataset."""
     tudataset = datasets.TUDataset(root=tmpdir, name="KKI")
     model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes)
-    data_pipe = DataPipeline(input_transform=GraphClassificationInputTransform())
-    out = model.predict(tudataset, input="datasets", data_pipeline=data_pipe)
-    assert isinstance(out[0], int)
+    model.data_pipeline = DataPipeline(input_transform=GraphClassificationInputTransform())
+    predict_dl = torch.utils.data.DataLoader(tudataset, batch_size=4)
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    out = trainer.predict(model, predict_dl)
+    assert isinstance(out[0][0], int)


 @pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed")
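Editor's note: the graph hunks above and below take a second, DataModule-free route, attaching a `DataPipeline` to the task and handing a plain `DataLoader` straight to `Trainer.predict`. A minimal sketch of that route, assuming the same Flash 0.6 API as the tests; the `predict_dataset` helper is hypothetical.

```py
# Sketch only, not part of the patch: predicting from a raw dataset via a
# DataLoader, mirroring the graph tests. `predict_dataset` is a hypothetical helper.
import torch

from flash import Trainer
from flash.core.data.data_pipeline import DataPipeline


def predict_dataset(model, dataset, input_transform, batch_size=4):
    # Attach the pipeline the task would otherwise get from a DataModule.
    model.data_pipeline = DataPipeline(input_transform=input_transform)
    predict_dl = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    trainer = Trainer(fast_dev_run=True)
    # Returns a list of per-batch prediction lists.
    return trainer.predict(model, predict_dl)
```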
diff --git a/tests/graph/embedding/test_model.py b/tests/graph/embedding/test_model.py
index c06da323e9..e68af709e8 100644
--- a/tests/graph/embedding/test_model.py
+++ b/tests/graph/embedding/test_model.py
@@ -58,6 +58,8 @@ def test_predict_dataset(tmpdir):
     model = GraphEmbedder(
         GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes).backbone
     )
-    data_pipe = DataPipeline(input_transform=GraphClassificationInputTransform())
-    out = model.predict(tudataset, input="datasets", data_pipeline=data_pipe)
-    assert isinstance(out[0], torch.Tensor)
+    model.data_pipeline = DataPipeline(input_transform=GraphClassificationInputTransform())
+    predict_dl = torch.utils.data.DataLoader(tudataset, batch_size=4)
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    out = trainer.predict(model, predict_dl)
+    assert isinstance(out[0][0], torch.Tensor)
diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py
index 42ea267362..7165398af8 100644
--- a/tests/image/classification/test_model.py
+++ b/tests/image/classification/test_model.py
@@ -108,12 +108,10 @@ def test_multilabel(tmpdir):
     train_dl = torch.utils.data.DataLoader(ds, batch_size=2)
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, limit_train_batches=5)
     trainer.finetune(model, train_dl, strategy=("freeze_unfreeze", 1))
-    image, label = ds[0][DataKeys.INPUT], ds[0][DataKeys.TARGET]
-    predictions = model.predict([{DataKeys.INPUT: image}])
+    predictions = trainer.predict(model, train_dl)[0]
     assert (torch.tensor(predictions) > 1).sum() == 0
     assert (torch.tensor(predictions) < 0).sum() == 0
-    assert len(predictions[0]) == num_classes == len(label)
-    assert len(torch.unique(label)) <= 2
+    assert len(predictions[0]) == num_classes


 @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
diff --git a/tests/image/detection/test_data_model_integration.py b/tests/image/detection/test_data_model_integration.py
index 3343b64aa7..4a06daeb22 100644
--- a/tests/image/detection/test_data_model_integration.py
+++ b/tests/image/detection/test_data_model_integration.py
@@ -46,12 +46,12 @@ def test_detection(tmpdir, head, backbone):

     train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir)

-    data = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1)
-    model = ObjectDetector(head=head, backbone=backbone, num_classes=data.num_classes)
+    datamodule = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1)
+    model = ObjectDetector(head=head, backbone=backbone, num_classes=datamodule.num_classes)

     trainer = flash.Trainer(fast_dev_run=True, gpus=torch.cuda.device_count())

-    trainer.finetune(model, data, strategy="freeze")
+    trainer.finetune(model, datamodule=datamodule, strategy="freeze")

     test_image_one = os.fspath(tmpdir / "test_one.png")
     test_image_two = os.fspath(tmpdir / "test_two.png")
@@ -59,8 +59,8 @@ def test_detection(tmpdir, head, backbone):
     Image.new("RGB", (512, 512)).save(test_image_one)
     Image.new("RGB", (512, 512)).save(test_image_two)

-    test_images = [str(test_image_one), str(test_image_two)]
-    model.predict(test_images)
+    datamodule = ObjectDetectionData.from_files(predict_files=[str(test_image_one), str(test_image_two)])
+    trainer.predict(model, datamodule=datamodule)


 @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@@ -70,12 +70,12 @@ def test_detection_fiftyone(tmpdir, head, backbone):

     train_dataset = _create_synth_fiftyone_dataset(tmpdir)

-    data = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1)
-    model = ObjectDetector(head=head, backbone=backbone, num_classes=data.num_classes)
+    datamodule = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1)
+    model = ObjectDetector(head=head, backbone=backbone, num_classes=datamodule.num_classes)

     trainer = flash.Trainer(fast_dev_run=True, gpus=torch.cuda.device_count())

-    trainer.finetune(model, data, strategy="freeze")
+    trainer.finetune(model, datamodule, strategy="freeze")
     test_image_one = os.fspath(tmpdir / "test_one.png")
     test_image_two = os.fspath(tmpdir / "test_two.png")
@@ -83,5 +83,5 @@ def test_detection_fiftyone(tmpdir, head, backbone):
     Image.new("RGB", (512, 512)).save(test_image_one)
     Image.new("RGB", (512, 512)).save(test_image_two)

-    test_images = [str(test_image_one), str(test_image_two)]
-    model.predict(test_images)
+    datamodule = ObjectDetectionData.from_files(predict_files=[str(test_image_one), str(test_image_two)])
+    trainer.predict(model, datamodule=datamodule)
diff --git a/tests/image/instance_segmentation/test_model.py b/tests/image/instance_segmentation/test_model.py
index d432889639..1f7833682e 100644
--- a/tests/image/instance_segmentation/test_model.py
+++ b/tests/image/instance_segmentation/test_model.py
@@ -68,24 +68,26 @@ def test_instance_segmentation_inference(tmpdir):
     trainer = flash.Trainer(max_epochs=1, fast_dev_run=True)
     trainer.finetune(model, datamodule=datamodule, strategy="freeze")

-    predictions = model.predict(
-        [
+    datamodule = InstanceSegmentationData.from_files(
+        predict_files=[
             str(data_dir / "images/yorkshire_terrier_9.jpg"),
             str(data_dir / "images/yorkshire_terrier_12.jpg"),
             str(data_dir / "images/yorkshire_terrier_13.jpg"),
         ]
     )
-    assert len(predictions) == 3
+    predictions = trainer.predict(model, datamodule=datamodule)
+    assert len(predictions[0]) == 3

     model_path = os.path.join(tmpdir, "model.pt")
     trainer.save_checkpoint(model_path)
     InstanceSegmentation.load_from_checkpoint(model_path)

-    predictions = model.predict(
-        [
+    datamodule = InstanceSegmentationData.from_files(
+        predict_files=[
             str(data_dir / "images/yorkshire_terrier_9.jpg"),
             str(data_dir / "images/yorkshire_terrier_12.jpg"),
-            str(data_dir / "images/yorkshire_terrier_15.jpg"),
+            str(data_dir / "images/yorkshire_terrier_13.jpg"),
         ]
     )
-    assert len(predictions) == 3
+    predictions = trainer.predict(model, datamodule=datamodule)
+    assert len(predictions[0]) == 3
diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py
index 82e767bdaa..67cf0946e8 100644
--- a/tests/image/segmentation/test_model.py
+++ b/tests/image/segmentation/test_model.py
@@ -22,11 +22,10 @@

 from flash import Trainer
 from flash.__main__ import main
-from flash.core.data.data_pipeline import DataPipeline
 from flash.core.data.io.input import DataKeys
 from flash.core.utilities.imports import _IMAGE_AVAILABLE
 from flash.image import SemanticSegmentation
-from flash.image.segmentation.data import SemanticSegmentationInputTransform
+from flash.image.segmentation.data import SemanticSegmentationData, SemanticSegmentationInputTransform
 from tests.helpers.utils import _IMAGE_TESTING, _SERVE_TESTING

 # ======== Mock functions ========
@@ -106,22 +105,24 @@ def test_unfreeze():
 def test_predict_tensor():
     img = torch.rand(1, 3, 64, 64)
     model = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
-    data_pipe = DataPipeline(input_transform=SemanticSegmentationInputTransform())
-    out = model.predict(img, input="tensors", data_pipeline=data_pipe)
-    assert isinstance(out[0], list)
-    assert len(out[0]) == 64
+    datamodule = SemanticSegmentationData.from_tensors(predict_data=img)
+    trainer = Trainer()
+    out = trainer.predict(model, datamodule=datamodule)
+    assert isinstance(out[0][0], list)
     assert len(out[0][0]) == 64
+    assert len(out[0][0][0]) == 64


 @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
 def test_predict_numpy():
     img = np.ones((1, 3, 64, 64))
     model = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
-    data_pipe = DataPipeline(input_transform=SemanticSegmentationInputTransform())
-    out = model.predict(img, input="numpy", data_pipeline=data_pipe)
-    assert isinstance(out[0], list)
-    assert len(out[0]) == 64
+    datamodule = SemanticSegmentationData.from_numpy(predict_data=img)
+    trainer = Trainer()
+    out = trainer.predict(model, datamodule=datamodule)
+    assert isinstance(out[0][0], list)
     assert len(out[0][0]) == 64
+    assert len(out[0][0][0]) == 64


 @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
diff --git a/tests/pointcloud/detection/test_data.py b/tests/pointcloud/detection/test_data.py
index 56ebf4b078..510a5e47a8 100644
--- a/tests/pointcloud/detection/test_data.py
+++ b/tests/pointcloud/detection/test_data.py
@@ -34,7 +34,9 @@ def test_pointcloud_object_detection_data(tmpdir):

     download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_micro.zip", tmpdir)

-    dm = PointCloudObjectDetectorData.from_folders(train_folder=join(tmpdir, "KITTI_Micro", "Kitti", "train"))
+    datamodule = PointCloudObjectDetectorData.from_folders(
+        train_folder=join(tmpdir, "KITTI_Micro", "Kitti", "train"),
+    )

     class MockModel(PointCloudObjectDetector):
         def training_step(self, batch, batch_idx: int):
@@ -48,11 +50,12 @@ def training_step(self, batch, batch_idx: int):
     num_classes = 19
     model = MockModel(backbone="pointpillars_kitti", num_classes=num_classes)
     trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0)
-    trainer.fit(model, dm)
+    trainer.fit(model, datamodule=datamodule)

-    predict_path = join(tmpdir, "KITTI_Micro", "Kitti", "predict")
-    model.eval()
+    datamodule = PointCloudObjectDetectorData.from_files(
+        predict_files=[join(tmpdir, "KITTI_Micro", "Kitti", "predict", "scans", "000000.bin")]
+    )

-    predictions = model.predict([join(predict_path, "scans/000000.bin")])
+    predictions = trainer.predict(model, datamodule=datamodule)[0]
     assert predictions[0][DataKeys.INPUT].shape[1] == 4
     assert len(predictions[0][DataKeys.PREDS]) == 158
diff --git a/tests/pointcloud/segmentation/test_data.py b/tests/pointcloud/segmentation/test_data.py
index bce4693f2c..b17372813d 100644
--- a/tests/pointcloud/segmentation/test_data.py
+++ b/tests/pointcloud/segmentation/test_data.py
@@ -31,7 +31,10 @@ def test_pointcloud_segmentation_data(tmpdir):

     download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiMicro.zip", tmpdir)

-    dm = PointCloudSegmentationData.from_folders(train_folder=join(tmpdir, "SemanticKittiMicro", "train"))
+    datamodule = PointCloudSegmentationData.from_folders(
+        train_folder=join(tmpdir, "SemanticKittiMicro", "train"),
+        predict_folder=join(tmpdir, "SemanticKittiMicro", "predict"),
+    )

     class MockModel(PointCloudSegmentation):
         def training_step(self, batch, batch_idx: int):
@@ -48,9 +51,9 @@ def training_step(self, batch, batch_idx: int):
     num_classes = 19
     model = MockModel(backbone="randlanet", num_classes=num_classes)
     trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0)
-    trainer.fit(model, dm)
+    trainer.fit(model, datamodule=datamodule)

-    predictions = model.predict(join(tmpdir, "SemanticKittiMicro", "predict"))
+    predictions = trainer.predict(model, datamodule=datamodule)[0]
     assert predictions[0][DataKeys.INPUT].shape == torch.Size([45056, 3])
     assert predictions[0][DataKeys.PREDS].shape == torch.Size([45056, 19])
     assert predictions[0][DataKeys.TARGET].shape == torch.Size([45056])
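Editor's note on a convention the pointcloud tests above rely on: `trainer.predict(...)[0]` selects the first (here, only) predict batch, inside which each sample is a dict keyed by `DataKeys`. A minimal sketch of that access pattern, assuming the Flash 0.6 point cloud API from the hunks above; the predict file path is illustrative, reused from the docs hunk at the top of this section.

```py
# Sketch only, not part of the patch: unpacking batched point cloud predictions.
from flash import Trainer
from flash.core.data.io.input import DataKeys
from flash.pointcloud import PointCloudSegmentation, PointCloudSegmentationData

model = PointCloudSegmentation(backbone="randlanet", num_classes=19)
datamodule = PointCloudSegmentationData.from_files(
    predict_files=["data/SemanticKittiTiny/predict/000000.bin"]  # illustrative path
)
trainer = Trainer()
batches = trainer.predict(model, datamodule=datamodule)
sample = batches[0][0]  # first batch, first sample: a dict keyed by DataKeys
points, preds = sample[DataKeys.INPUT], sample[DataKeys.PREDS]
```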
diff --git a/tests/template/classification/test_model.py b/tests/template/classification/test_model.py
index a3ab60d97c..b2a09133d7 100644
--- a/tests/template/classification/test_model.py
+++ b/tests/template/classification/test_model.py
@@ -18,11 +18,10 @@
 import torch

 from flash import Trainer
-from flash.core.data.data_pipeline import DataPipeline
 from flash.core.data.io.input import DataKeys
 from flash.core.utilities.imports import _SKLEARN_AVAILABLE
 from flash.template import TemplateSKLearnClassifier
-from flash.template.classification.data import TemplateInputTransform
+from flash.template.classification.data import TemplateData

 if _SKLEARN_AVAILABLE:
     from sklearn import datasets
@@ -105,9 +104,10 @@ def test_predict_numpy():
     """Tests that we can generate predictions from a numpy array."""
     row = np.random.rand(1, DummyDataset.num_features)
     model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
-    data_pipe = DataPipeline(input_transform=TemplateInputTransform())
-    out = model.predict(row, data_pipeline=data_pipe)
-    assert isinstance(out[0], int)
+    datamodule = TemplateData.from_numpy(predict_data=row)
+    trainer = Trainer()
+    out = trainer.predict(model, datamodule=datamodule)
+    assert isinstance(out[0][0], int)


 @pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
@@ -115,9 +115,10 @@ def test_predict_sklearn():
     """Tests that we can generate predictions from a scikit-learn ``Bunch``."""
     bunch = datasets.load_iris()
     model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
-    data_pipe = DataPipeline(input_transform=TemplateInputTransform())
-    out = model.predict(bunch, input="sklearn", data_pipeline=data_pipe)
-    assert isinstance(out[0], int)
+    datamodule = TemplateData.from_sklearn(predict_bunch=bunch)
+    trainer = Trainer()
+    out = trainer.predict(model, datamodule=datamodule)
+    assert isinstance(out[0][0], int)


 @pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")