diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f4a050f76..c8e1f16c16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added integration with FiftyOne ([#360](https://github.com/PyTorchLightning/lightning-flash/pull/360)) - Added support for `torch.jit` to tasks where possible and documented task JIT compatibility ([#389](https://github.com/PyTorchLightning/lightning-flash/pull/389)) - Added option to provide a `Sampler` to the `DataModule` to use when creating a `DataLoader` ([#390](https://github.com/PyTorchLightning/lightning-flash/pull/390)) - Added support for multi-label text classification and toxic comments example ([#401](https://github.com/PyTorchLightning/lightning-flash/pull/401)) diff --git a/docs/source/conf.py b/docs/source/conf.py index c5b2b6c64c..48f1966183 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -98,6 +98,7 @@ def _load_py_module(fname, pkg="flash"): "PIL": ("https://pillow.readthedocs.io/en/stable/", None), "pytorchvideo": ("https://pytorchvideo.readthedocs.io/en/latest/", None), "pytorch_lightning": ("https://pytorch-lightning.readthedocs.io/en/stable/", None), + "fiftyone": ("https://voxel51.com/docs/fiftyone/", None), } # -- Options for HTML output ------------------------------------------------- diff --git a/docs/source/index.rst b/docs/source/index.rst index 3e81b3a461..29ac911fda 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -64,6 +64,13 @@ Lightning Flash code/text code/video +.. toctree:: + :maxdepth: 1 + :caption: Integrations + + integrations/fiftyone + + .. toctree:: :maxdepth: 1 :caption: Contributing a Task diff --git a/docs/source/integrations/fiftyone.rst b/docs/source/integrations/fiftyone.rst new file mode 100644 index 0000000000..d12392e5ec --- /dev/null +++ b/docs/source/integrations/fiftyone.rst @@ -0,0 +1,148 @@ +######## +FiftyOne +######## + +We have collaborated with the team at +`Voxel51 `_ to integrate their tool, +`FiftyOne `_, into Lightning Flash. + +FiftyOne is an open-source tool for building high-quality +datasets and computer vision models. The FiftyOne API and App enable you to +visualize datasets and interpret models faster and more effectively. + +This integration allows you to view predictions generated by your tasks in the +:ref:`FiftyOne App `, as well as easily incorporate +:ref:`FiftyOne Datasets ` into your tasks. All image and video tasks +are supported! + +.. raw:: html + +
+    <!-- FiftyOne App demo embed -->
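
To give a flavor of the end-to-end workflow documented on this page, here is a
minimal sketch (the checkpoint path and predict folder are placeholders, and
``fiftyone`` must be installed):

.. code-block:: python

    import flash
    from flash.core.classification import FiftyOneLabels
    from flash.core.integrations.fiftyone import visualize
    from flash.image import ImageClassificationData, ImageClassifier

    # Any trained Flash classifier works; attach a FiftyOne serializer to it
    model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")
    model.serializer = FiftyOneLabels(return_filepath=True)

    datamodule = ImageClassificationData.from_folders(predict_folder="data/predict/")
    predictions = flash.Trainer().predict(model, datamodule=datamodule)

    # Launches the FiftyOne App and blocks until it is closed
    session = visualize(predictions)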
+ +************ +Installation +************ + +In order to utilize this integration with FiftyOne, you will need to +:ref:`install the tool`: + +.. code:: shell + + pip install fiftyone + + +***************************** +Visualizing Flash predictions +***************************** + +This section shows you how to augment your existing Lightning Flash workflows +with a couple of lines of code that let you visualize predictions in FiftyOne. +You can visualize predictions for classification, object detection, and +semantic segmentation tasks. Doing so is as easy as updating your model to use +one of the following serializers: + +* :class:`FiftyOneLabels(return_filepath=True)` +* :class:`FiftyOneSegmentationLabels(return_filepath=True)` +* :class:`FiftyOneDetectionLabels(return_filepath=True)` + +The :func:`~flash.core.integrations.fiftyone.visualize` function then lets you visualize +your predictions in the +:ref:`FiftyOne App `. This function accepts a list of +dictionaries containing :ref:`FiftyOne Label` objects +and filepaths which is the exact output of the FiftyOne serializers when the flag +``return_filepath=True`` is specified. + +.. literalinclude:: ../../../flash_examples/integrations/fiftyone/image_classification.py + :language: python + :lines: 14- + + +*********************** +Using FiftyOne datasets +*********************** + +The above workflow is great for visualizing model predictions. However, if you +store your data in a FiftyOne Dataset initially, then you can also visualize +ground truth annotations. This allows you to perform more complex analysis with +:ref:`views ` into your data and +:ref:`evaluation ` of your model results. + +The +:meth:`~flash.core.data.data_module.DataModule.from_fiftyone` +method allows you to load your FiftyOne Datasets directly into a +:class:`~flash.core.data.data_module.DataModule` to be used for training, +testing, or inference. + +.. literalinclude:: ../../../flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py + :language: python + :lines: 14- + + +********************** +Visualizing embeddings +********************** + +FiftyOne provides the methods for +:ref:`dimensionality reduction` and +:ref:`interactive plotting`. When combined with +:ref:`embedding tasks ` in Flash, you can accomplish +powerful workflows like clustering, similarity search, pre-annotation, and more +in only a few lines of code. + +.. literalinclude:: ../../../flash_examples/integrations/fiftyone/image_embedding.py + :language: python + :lines: 14- + +.. image:: https://pl-flash-data.s3.amazonaws.com/assets/fiftyone/embeddings.png + :alt: embeddings_example + :align: center + +------ + +************* +API reference +************* + +.. _from_fiftyone: + +DataModule.from_fiftyone +------------------------ + +.. automethod:: flash.core.data.data_module.DataModule.from_fiftyone + :noindex: + +.. _fiftyone_labels: + +FiftyOneLabels +-------------- + +.. autoclass:: flash.core.classification.FiftyOneLabels + :members: + +.. _fiftyone_segmentation_labels: + +FiftyOneSegmentationLabels +-------------------------- + +.. autoclass:: flash.image.segmentation.serialization.FiftyOneSegmentationLabels + :members: + +.. _fiftyone_detection_labels: + +FiftyOneDetectionLabels +----------------------- + +.. autoclass:: flash.image.detection.serialization.FiftyOneDetectionLabels + :members: + + +.. _fiftyone_visualize: + +visualize +--------- + +.. 
autofunction:: flash.core.integrations.fiftyone.visualize diff --git a/docs/source/template/data.rst b/docs/source/template/data.rst index 92ee56ae28..8b8881ce24 100644 --- a/docs/source/template/data.rst +++ b/docs/source/template/data.rst @@ -85,7 +85,7 @@ Here's how it looks (from `video/classification.data.py torch.Tensor: @@ -80,6 +87,8 @@ class Logits(ClassificationSerializer): """A :class:`.Serializer` which simply converts the model outputs (assumed to be logits) to a list.""" def serialize(self, sample: Any) -> Any: + sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample + sample = torch.tensor(sample) return sample.tolist() @@ -88,6 +97,8 @@ class Probabilities(ClassificationSerializer): list.""" def serialize(self, sample: Any) -> Any: + sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample + sample = torch.tensor(sample) if self.multi_label: return torch.sigmoid(sample).tolist() return torch.softmax(sample, -1).tolist() @@ -109,6 +120,8 @@ def __init__(self, multi_label: bool = False, threshold: float = 0.5): self.threshold = threshold def serialize(self, sample: Any) -> Union[int, List[int]]: + sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample + sample = torch.tensor(sample) if self.multi_label: one_hot = (sample.sigmoid() > self.threshold).int().tolist() result = [] @@ -140,6 +153,8 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False self.set_state(LabelsState(labels)) def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]: + sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample + sample = torch.tensor(sample) labels = None if self._labels is not None: @@ -158,3 +173,128 @@ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]: else: rank_zero_warn("No LabelsState was found, this serializer will act as a Classes serializer.", UserWarning) return classes + + +class FiftyOneLabels(ClassificationSerializer): + """A :class:`.Serializer` which converts the model outputs to FiftyOne classification format. + + Args: + labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not + provided, will attempt to get them from the :class:`.LabelsState`. + multi_label: If true, treats outputs as multi label logits. + threshold: A threshold to use to filter candidate labels. 
In the single label case, predictions below this + threshold will be replaced with None + store_logits: Boolean determining whether to store logits in the FiftyOne labels + return_filepath: Boolean determining whether to return a dict + containing filepath and FiftyOne labels (True) or only a + list of FiftyOne labels (False) + """ + + def __init__( + self, + labels: Optional[List[str]] = None, + multi_label: bool = False, + threshold: Optional[float] = None, + store_logits: bool = False, + return_filepath: bool = False, + ): + if not _FIFTYONE_AVAILABLE: + raise ModuleNotFoundError("Please, run `pip install fiftyone`.") + + if multi_label and threshold is None: + threshold = 0.5 + + super().__init__(multi_label=multi_label) + self._labels = labels + self.threshold = threshold + self.store_logits = store_logits + self.return_filepath = return_filepath + + if labels is not None: + self.set_state(LabelsState(labels)) + + def serialize( + self, + sample: Any, + ) -> Union[Classification, Classifications, Dict[str, Any], Dict[str, Any]]: + pred = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample + pred = torch.tensor(pred) + + labels = None + + if self._labels is not None: + labels = self._labels + else: + state = self.get_state(LabelsState) + if state is not None: + labels = state.labels + + logits = None + if self.store_logits: + logits = pred.tolist() + + if self.multi_label: + one_hot = (pred.sigmoid() > self.threshold).int().tolist() + classes = [] + for index, value in enumerate(one_hot): + if value == 1: + classes.append(index) + probabilities = torch.sigmoid(pred).tolist() + else: + classes = torch.argmax(pred, -1).tolist() + probabilities = torch.softmax(pred, -1).tolist() + + if labels is not None: + if self.multi_label: + classifications = [] + for idx in classes: + fo_cls = Classification( + label=labels[idx], + confidence=probabilities[idx], + ) + classifications.append(fo_cls) + fo_predictions = Classifications( + classifications=classifications, + logits=logits, + ) + else: + confidence = max(probabilities) + if self.threshold is not None and confidence < self.threshold: + fo_predictions = None + else: + fo_predictions = Classification( + label=labels[classes], + confidence=confidence, + logits=logits, + ) + else: + rank_zero_warn("No LabelsState was found, int targets will be used as label strings", UserWarning) + + if self.multi_label: + classifications = [] + for idx in classes: + fo_cls = Classification( + label=str(idx), + confidence=probabilities[idx], + ) + classifications.append(fo_cls) + fo_predictions = Classifications( + classifications=classifications, + logits=logits, + ) + else: + confidence = max(probabilities) + if self.threshold is not None and confidence < self.threshold: + fo_predictions = None + else: + fo_predictions = Classification( + label=str(classes), + confidence=confidence, + logits=logits, + ) + + if self.return_filepath: + filepath = sample[DefaultDataKeys.METADATA]["filepath"] + return {"filepath": filepath, "predictions": fo_predictions} + else: + return fo_predictions diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py index cf524d2cef..587207a0a0 100644 --- a/flash/core/data/batch.py +++ b/flash/core/data/batch.py @@ -230,7 +230,7 @@ def forward(self, samples: Sequence[Any]) -> Any: with self._collate_context: samples, metadata = self._extract_metadata(samples) samples = self.collate_fn(samples) - if metadata: + if metadata and isinstance(samples, dict): samples[DefaultDataKeys.METADATA] = metadata 
self.callback.on_collate(samples, self.stage) diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 7a21498608..33ed2b020d 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -31,6 +31,13 @@ from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataSources from flash.core.data.splits import SplitDataset from flash.core.data.utils import _STAGES_PREFIX +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE + +if _FIFTYONE_AVAILABLE: + import fiftyone as fo + from fiftyone.core.collections import SampleCollection +else: + SampleCollection = None class DataModule(pl.LightningDataModule): @@ -336,7 +343,8 @@ def data_pipeline(self) -> DataPipeline: return DataPipeline(self.data_source, self.preprocess, self.postprocess) def available_data_sources(self) -> Sequence[str]: - """Get the list of available data source names for use with this :class:`~flash.core.data.data_module.DataModule`. + """Get the list of available data source names for use with this + :class:`~flash.core.data.data_module.DataModule`. Returns: The list of data source names. @@ -1060,3 +1068,88 @@ def from_datasets( sampler=sampler, **preprocess_kwargs, ) + + @classmethod + def from_fiftyone( + cls, + train_dataset: Optional[SampleCollection] = None, + val_dataset: Optional[SampleCollection] = None, + test_dataset: Optional[SampleCollection] = None, + predict_dataset: Optional[SampleCollection] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ) -> 'DataModule': + """Creates a :class:`~flash.core.data.data_module.DataModule` object + from the given FiftyOne Datasets using the + :class:`~flash.core.data.data_source.DataSource` of name + :attr:`~flash.core.data.data_source.DefaultDataSources.FIFTYONE` + from the passed or constructed :class:`~flash.core.data.process.Preprocess`. + + Args: + train_dataset: The ``fiftyone.core.collections.SampleCollection`` containing the train data. + val_dataset: The ``fiftyone.core.collections.SampleCollection`` containing the validation data. + test_dataset: The ``fiftyone.core.collections.SampleCollection`` containing the test data. + predict_dataset: The ``fiftyone.core.collections.SampleCollection`` containing the predict data. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. 
+            preprocess: The :class:`~flash.core.data.process.Preprocess` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
+                will be constructed and used.
+            val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
+                if ``preprocess = None``.
+
+        Returns:
+            The constructed data module.
+
+        Examples::
+
+            train_dataset = fo.Dataset.from_dir(
+                "/path/to/dataset",
+                dataset_type=fo.types.ImageClassificationDirectoryTree,
+            )
+            data_module = DataModule.from_fiftyone(
+                train_dataset=train_dataset,
+                train_transform={
+                    "to_tensor_transform": torch.as_tensor,
+                },
+            )
+        """
+        if not _FIFTYONE_AVAILABLE:
+            raise ModuleNotFoundError("Please, run `pip install fiftyone`.")
+
+        return cls.from_data_source(
+            DefaultDataSources.FIFTYONE,
+            train_dataset,
+            val_dataset,
+            test_dataset,
+            predict_dataset,
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            data_fetcher=data_fetcher,
+            preprocess=preprocess,
+            val_split=val_split,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            **preprocess_kwargs,
+        )
diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py
index 618d4b40db..507db14877 100644
--- a/flash/core/data/data_source.py
+++ b/flash/core/data/data_source.py
@@ -41,6 +41,13 @@
 from flash.core.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset
 from flash.core.data.properties import ProcessState, Properties
 from flash.core.data.utils import CurrentRunningStageFuncContext
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
+
+if _FIFTYONE_AVAILABLE:
+    from fiftyone.core.collections import SampleCollection
+    from fiftyone.core.labels import Label
+else:
+    Label, SampleCollection = None, None
 
 
 # Credit to the PyTorchVision Team:
@@ -145,6 +152,7 @@ class DefaultDataSources(LightningEnum):
     CSV = "csv"
     JSON = "json"
     DATASET = "dataset"
+    FIFTYONE = "fiftyone"
 
     # TODO: Create a FlashEnum class???
     def __hash__(self) -> int:
@@ -217,7 +225,8 @@ def load_data(self,
         return data
 
     def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any:
-        """Given an element from the output of a call to :meth:`~flash.core.data.data_source.DataSource.load_data`, this hook
+        """Given an element from the output of a call to
+        :meth:`~flash.core.data.data_source.DataSource.load_data`, this hook
         should load a single data sample. The keys and values in the ``sample`` argument will be same as the keys and
         values in the outputs of :meth:`~flash.core.data.data_source.DataSource.load_data`.
 
@@ -278,8 +287,8 @@ def generate_dataset(
         data: Optional[DATA_TYPE],
         running_stage: RunningStage,
     ) -> Optional[Union[AutoDataset, IterableAutoDataset]]:
-        """Generate a single dataset with the given input to :meth:`~flash.core.data.data_source.DataSource.load_data` for
-        the given ``running_stage``.
+        """Generate a single dataset with the given input to
+        :meth:`~flash.core.data.data_source.DataSource.load_data` for the given ``running_stage``.
 
         Args:
             data: The input to :meth:`~flash.core.data.data_source.DataSource.load_data` to use to create the dataset.
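
For reference, here is the ``from_fiftyone`` docstring example above in runnable
form, a sketch assuming the hymenoptera-style image classification directory
tree (``<root>/<class_name>/<image>``) used by the bundled examples:

```python
import fiftyone as fo

from flash.image import ImageClassificationData

# Build a FiftyOne dataset from an image classification directory tree
train_dataset = fo.Dataset.from_dir(
    dataset_dir="data/hymenoptera_data/train/",
    dataset_type=fo.types.ImageClassificationDirectoryTree,
)

# The "fiftyone" data source reads filepaths and labels straight from the
# SampleCollection, so no folder or CSV plumbing is needed
datamodule = ImageClassificationData.from_fiftyone(
    train_dataset=train_dataset,
    batch_size=4,
)
```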
@@ -445,10 +454,12 @@ def predict_load_data(self, if not isinstance(data, list): data = [data] + data = [{DefaultDataKeys.INPUT: input} for input in data] + return list( filter( lambda sample: has_file_allowed_extension(sample[DefaultDataKeys.INPUT], self.extensions), - super().predict_load_data(data), + data, ) ) @@ -461,3 +472,69 @@ class TensorDataSource(SequenceDataSource[torch.Tensor]): class NumpyDataSource(SequenceDataSource[np.ndarray]): """The ``NumpyDataSource`` is a ``SequenceDataSource`` which expects the input to :meth:`~flash.core.data.data_source.DataSource.load_data` to be a sequence of ``np.ndarray`` objects.""" + + +class FiftyOneDataSource(DataSource[SampleCollection]): + """The ``FiftyOneDataSource`` expects the input to + :meth:`~flash.core.data.data_source.DataSource.load_data` to be a ``fiftyone.core.collections.SampleCollection``.""" + + def __init__(self, label_field: str = "ground_truth"): + if not _FIFTYONE_AVAILABLE: + raise ModuleNotFoundError("Please, run `pip install fiftyone`.") + super().__init__() + self.label_field = label_field + + @property + def label_cls(self): + return Label + + def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: + self._validate(data) + + label_path = data._get_label_field_path(self.label_field, "label")[1] + + filepaths = data.values("filepath") + targets = data.values(label_path) + + classes = self._get_classes(data) + + if dataset is not None: + dataset.num_classes = len(classes) + + class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + + if targets and isinstance(targets[0], list): + + def to_idx(t): + return [class_to_idx[x] for x in t] + else: + + def to_idx(t): + return class_to_idx[t] + + return [{ + DefaultDataKeys.INPUT: f, + DefaultDataKeys.TARGET: to_idx(t), + } for f, t in zip(filepaths, targets)] + + def predict_load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: + return [{DefaultDataKeys.INPUT: f} for f in data.values("filepath")] + + def _validate(self, data): + label_type = data._get_label_field_type(self.label_field) + if not issubclass(label_type, self.label_cls): + raise ValueError( + "Expected field '%s' to have type %s; found %s" % (self.label_field, self.label_cls, label_type) + ) + + def _get_classes(self, data): + classes = data.classes.get(self.label_field, None) + + if not classes: + classes = data.default_classes + + if not classes: + label_path = data._get_label_field_path(self.label_field, "label")[1] + classes = data.distinct(label_path) + + return classes diff --git a/flash/core/data/process.py b/flash/core/data/process.py index 2c43b996e1..c461768a36 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -510,7 +510,7 @@ def _save_sample(self, sample: Any) -> None: class Serializer(Properties): - """A :class:`.Serializer` encapsulates a single ``serialize`` method which is used to convert the model ouptut into + """A :class:`.Serializer` encapsulates a single ``serialize`` method which is used to convert the model output into the desired output format when predicting.""" def __init__(self): diff --git a/flash/core/integrations/__init__.py b/flash/core/integrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flash/core/integrations/fiftyone/__init__.py b/flash/core/integrations/fiftyone/__init__.py new file mode 100644 index 0000000000..cc7b22cbb7 --- /dev/null +++ b/flash/core/integrations/fiftyone/__init__.py @@ -0,0 +1 @@ +from 
flash.core.integrations.fiftyone.utils import visualize diff --git a/flash/core/integrations/fiftyone/utils.py b/flash/core/integrations/fiftyone/utils.py new file mode 100644 index 0000000000..5498a801ed --- /dev/null +++ b/flash/core/integrations/fiftyone/utils.py @@ -0,0 +1,70 @@ +from itertools import chain +from typing import Dict, List, Optional, Union + +import flash +from flash.core.data.data_source import DefaultDataKeys +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE + +if _FIFTYONE_AVAILABLE: + import fiftyone as fo + from fiftyone.core.labels import Label + from fiftyone.core.sample import Sample + from fiftyone.core.session import Session + from fiftyone.utils.data.parsers import LabeledImageTupleSampleParser +else: + fo = None + SampleCollection = None + Label = None + Sample = None + Session = None + + +def visualize( + labels: Union[List[Label], List[Dict[str, Label]]], + filepaths: Optional[List[str]] = None, + wait: Optional[bool] = True, + label_field: Optional[str] = "predictions", + **kwargs +) -> Optional[Session]: + """Use the result of a FiftyOne serializer to visualize predictions in the + FiftyOne App. + + Args: + labels: Either a list of FiftyOne labels that will be applied to the + corresponding filepaths provided with through `filepath` or + `datamodule`. Or a list of dictionaries containing image/video + filepaths and corresponding FiftyOne labels. + filepaths: A list of filepaths to images or videos corresponding to the + provided `labels`. + wait: A boolean determining whether to launch the FiftyOne session and + wait until the session is closed or whether to return immediately. + label_field: The string of the label field in the FiftyOne dataset + containing predictions + """ + if not _FIFTYONE_AVAILABLE: + raise ModuleNotFoundError("Please, `pip install fiftyone`.") + if flash._IS_TESTING: + return None + + # Flatten list if batches were used + if all(isinstance(fl, list) for fl in labels): + labels = list(chain.from_iterable(labels)) + + if all(isinstance(fl, dict) for fl in labels): + filepaths = [lab["filepath"] for lab in labels] + labels = [lab["predictions"] for lab in labels] + + if filepaths is None: + raise ValueError("The `filepaths` argument is required if filepaths are not provided in `labels`.") + + dataset = fo.Dataset() + if filepaths: + dataset.add_labeled_images( + list(zip(filepaths, labels)), + LabeledImageTupleSampleParser(), + label_field=label_field, + ) + session = fo.launch_app(dataset, **kwargs) + if wait: + session.wait() + return session diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 93797ba20a..75dc93e605 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -76,6 +76,7 @@ def _compare_version(package: str, op, version) -> bool: _MATPLOTLIB_AVAILABLE = _module_available("matplotlib") _TRANSFORMERS_AVAILABLE = _module_available("transformers") _PYSTICHE_AVAILABLE = _module_available("pystiche") +_FIFTYONE_AVAILABLE = _module_available("fiftyone") _FASTAPI_AVAILABLE = _module_available("fastapi") _PYDANTIC_AVAILABLE = _module_available("pydantic") _GRAPHVIZ_AVAILABLE = _module_available("graphviz") diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index eb9626817e..474037e176 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -27,7 +27,7 @@ from flash.core.data.process import Deserializer, Preprocess from flash.core.utilities.imports import _IMAGE_AVAILABLE, 
_MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE from flash.image.classification.transforms import default_transforms, train_default_transforms -from flash.image.data import ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource +from flash.image.data import ImageFiftyOneDataSource, ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt @@ -72,6 +72,7 @@ def __init__( predict_transform: Optional[Dict[str, Callable]] = None, image_size: Tuple[int, int] = (196, 196), deserializer: Optional[Deserializer] = None, + **data_source_kwargs: Any, ): self.image_size = image_size @@ -81,6 +82,7 @@ def __init__( test_transform=test_transform, predict_transform=predict_transform, data_sources={ + DefaultDataSources.FIFTYONE: ImageFiftyOneDataSource(**data_source_kwargs), DefaultDataSources.FILES: ImagePathsDataSource(), DefaultDataSources.FOLDERS: ImagePathsDataSource(), DefaultDataSources.NUMPY: ImageNumpyDataSource(), diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index a802ca425a..75a2dd0a49 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -120,8 +120,10 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch = (batch[DefaultDataKeys.INPUT]) - return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + batch[DefaultDataKeys.PREDS] = super().predict_step((batch[DefaultDataKeys.INPUT]), + batch_idx, + dataloader_idx=dataloader_idx) + return batch def forward(self, x) -> torch.Tensor: x = self.backbone(x) diff --git a/flash/image/data.py b/flash/image/data.py index 06cee5bf7f..69fd25e657 100644 --- a/flash/image/data.py +++ b/flash/image/data.py @@ -15,7 +15,13 @@ import torch -from flash.core.data.data_source import DefaultDataKeys, NumpyDataSource, PathsDataSource, TensorDataSource +from flash.core.data.data_source import ( + DefaultDataKeys, + FiftyOneDataSource, + NumpyDataSource, + PathsDataSource, + TensorDataSource, +) from flash.core.utilities.imports import _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: @@ -31,19 +37,46 @@ def __init__(self): super().__init__(extensions=IMG_EXTENSIONS) def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - sample[DefaultDataKeys.INPUT] = default_loader(sample[DefaultDataKeys.INPUT]) + img_path = sample[DefaultDataKeys.INPUT] + img = default_loader(img_path) + sample[DefaultDataKeys.INPUT] = img + w, h = img.size # WxH + sample[DefaultDataKeys.METADATA] = { + "filepath": img_path, + "size": (h, w), + } return sample class ImageTensorDataSource(TensorDataSource): def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - sample[DefaultDataKeys.INPUT] = to_pil_image(sample[DefaultDataKeys.INPUT]) + img = to_pil_image(sample[DefaultDataKeys.INPUT]) + sample[DefaultDataKeys.INPUT] = img + w, h = img.size # WxH + sample[DefaultDataKeys.METADATA] = {"size": (h, w)} return sample class ImageNumpyDataSource(NumpyDataSource): def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - sample[DefaultDataKeys.INPUT] = to_pil_image(torch.from_numpy(sample[DefaultDataKeys.INPUT])) + img = to_pil_image(torch.from_numpy(sample[DefaultDataKeys.INPUT])) + sample[DefaultDataKeys.INPUT] = img + w, h = img.size # WxH + 
sample[DefaultDataKeys.METADATA] = {"size": (h, w)}
         return sample
+
+
+class ImageFiftyOneDataSource(FiftyOneDataSource):
+
+    def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
+        img_path = sample[DefaultDataKeys.INPUT]
+        img = default_loader(img_path)
+        sample[DefaultDataKeys.INPUT] = img
+        w, h = img.size  # WxH
+        sample[DefaultDataKeys.METADATA] = {
+            "filepath": img_path,
+            "size": (h, w),
+        }
+        return sample
diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py
index 66cb43d4f0..f7f3416bf9 100644
--- a/flash/image/detection/data.py
+++ b/flash/image/detection/data.py
@@ -16,15 +16,21 @@
 from flash.core.data.callback import BaseDataFetcher
 from flash.core.data.data_module import DataModule
-from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources
-from flash.core.data.process import Preprocess
-from flash.core.utilities.imports import _COCO_AVAILABLE, _TORCHVISION_AVAILABLE
+from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources, FiftyOneDataSource
+from flash.core.data.process import Preprocess, Serializer
+from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _TORCHVISION_AVAILABLE
 from flash.image.data import ImagePathsDataSource
 from flash.image.detection.transforms import default_transforms
 
 if _COCO_AVAILABLE:
     from pycocotools.coco import COCO
 
+if _FIFTYONE_AVAILABLE:
+    from fiftyone.core.collections import SampleCollection
+    from fiftyone.core.labels import Detections
+else:
+    Detections, SampleCollection = None, None
+
 if _TORCHVISION_AVAILABLE:
     from torchvision.datasets.folder import default_loader
 
@@ -86,9 +92,99 @@ def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Seq
         return data
 
     def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
-        sample[DefaultDataKeys.INPUT] = default_loader(sample[DefaultDataKeys.INPUT])
+        filepath = sample[DefaultDataKeys.INPUT]
+        img = default_loader(filepath)
+        sample[DefaultDataKeys.INPUT] = img
+        w, h = img.size  # WxH
+        sample[DefaultDataKeys.METADATA] = {
+            "filepath": filepath,
+            "size": (h, w),
+        }
         return sample
+
+
+class ObjectDetectionFiftyOneDataSource(FiftyOneDataSource):
+
+    def __init__(self, label_field: str = "ground_truth", iscrowd: str = "iscrowd"):
+        super().__init__(label_field=label_field)
+        self.iscrowd = iscrowd
+
+    @property
+    def label_cls(self):
+        return Detections
+
+    def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
+        self._validate(data)
+
+        data.compute_metadata()
+
+        filepaths = data.values("filepath")
+        widths = data.values("metadata.width")
+        heights = data.values("metadata.height")
+        labels = data.values(self.label_field + ".detections.label")
+        bboxes = data.values(self.label_field + ".detections.bounding_box")
+        iscrowds = data.values(self.label_field + ".detections."
+ self.iscrowd) + + classes = self._get_classes(data) + class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + if dataset is not None: + dataset.num_classes = len(classes) + + output_data = [] + img_id = 1 + for fp, w, h, sample_labs, sample_boxes, sample_iscrowd in zip( + filepaths, widths, heights, labels, bboxes, iscrowds + ): + output_boxes = [] + output_labs = [] + output_iscrowd = [] + output_areas = [] + for lab, box, iscrowd in zip(sample_labs, sample_boxes, sample_iscrowd): + output_box, output_area = self._reformat_bbox(box[0], box[1], box[2], box[3], w, h) + output_areas.append(output_area) + output_labs.append(class_to_idx[lab]) + output_boxes.append(output_box) + if iscrowd is None: + iscrowd = 0 + output_iscrowd.append(iscrowd) + output_data.append( + dict( + input=fp, + target=dict( + boxes=output_boxes, + labels=output_labs, + image_id=img_id, + area=output_areas, + iscrowd=output_iscrowd, + ) + ) + ) + img_id += 1 + + return output_data + + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: + filepath = sample[DefaultDataKeys.INPUT] + img = default_loader(filepath) + sample[DefaultDataKeys.INPUT] = img + w, h = img.size # WxH + sample[DefaultDataKeys.METADATA] = { + "filepath": filepath, + "size": (h, w), + } return sample + def _reformat_bbox(self, xmin, ymin, box_w, box_h, img_w, img_h): + xmin *= img_w + ymin *= img_h + box_w *= img_w + box_h *= img_h + xmax = xmin + box_w + ymax = ymin + box_h + output_bbox = [xmin, ymin, xmax, ymax] + return output_bbox, box_w * box_h + class ObjectDetectionPreprocess(Preprocess): @@ -98,6 +194,7 @@ def __init__( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, + **data_source_kwargs: Any, ): super().__init__( train_transform=train_transform, @@ -105,6 +202,7 @@ def __init__( test_transform=test_transform, predict_transform=predict_transform, data_sources={ + DefaultDataSources.FIFTYONE: ObjectDetectionFiftyOneDataSource(**data_source_kwargs), DefaultDataSources.FILES: ImagePathsDataSource(), DefaultDataSources.FOLDERS: ImagePathsDataSource(), "coco": COCODataSource(), diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index f26d219224..ec192dda02 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -18,11 +18,13 @@ from torch.optim import Optimizer from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.process import Serializer from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _IMAGE_AVAILABLE from flash.image.backbones import OBJ_DETECTION_BACKBONES from flash.image.detection.finetuning import ObjectDetectionFineTuning +from flash.image.detection.serialization import DetectionLabels if _IMAGE_AVAILABLE: import torchvision @@ -90,6 +92,7 @@ def __init__( metrics: Union[Callable, nn.Module, Mapping, Sequence, None] = None, optimizer: Type[Optimizer] = torch.optim.AdamW, learning_rate: float = 1e-3, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, **kwargs: Any, ): @@ -112,6 +115,7 @@ def __init__( metrics=metrics, learning_rate=learning_rate, optimizer=optimizer, + serializer=serializer or DetectionLabels(), ) @staticmethod @@ -197,7 +201,8 @@ def test_step(self, batch, batch_idx): def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: images = 
batch[DefaultDataKeys.INPUT] - return self(images) + batch[DefaultDataKeys.PREDS] = self(images) + return batch def configure_finetune_callback(self): return [ObjectDetectionFineTuning(train_bn=True)] diff --git a/flash/image/detection/serialization.py b/flash/image/detection/serialization.py new file mode 100644 index 0000000000..4d6bbe6ea7 --- /dev/null +++ b/flash/image/detection/serialization.py @@ -0,0 +1,113 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Optional, Tuple, Union + +from pytorch_lightning.utilities import rank_zero_warn + +from flash.core.data.data_source import DefaultDataKeys, LabelsState +from flash.core.data.process import Serializer +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE + +if _FIFTYONE_AVAILABLE: + from fiftyone.core.labels import Detection, Detections +else: + Detection, Detections = None, None + + +class DetectionLabels(Serializer): + """A :class:`.Serializer` which extracts predictions from sample dict.""" + + def serialize(self, sample: Any) -> Dict[str, Any]: + sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample + return sample + + +class FiftyOneDetectionLabels(Serializer): + """A :class:`.Serializer` which converts model outputs to FiftyOne detection format. + + Args: + labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not + provided, will attempt to get them from the :class:`.LabelsState`. + threshold: a score threshold to apply to candidate detections. 
+ return_filepath: Boolean determining whether to return a dict + containing filepath and FiftyOne labels (True) or only a + list of FiftyOne labels (False) + """ + + def __init__( + self, + labels: Optional[List[str]] = None, + threshold: Optional[float] = None, + return_filepath: bool = False, + ): + if not _FIFTYONE_AVAILABLE: + raise ModuleNotFoundError("Please, run `pip install fiftyone`.") + + super().__init__() + self._labels = labels + self.threshold = threshold + self.return_filepath = return_filepath + + if labels is not None: + self.set_state(LabelsState(labels)) + + def serialize(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]]: + if DefaultDataKeys.METADATA not in sample: + raise ValueError("sample requires DefaultDataKeys.METADATA to use a FiftyOneDetectionLabels serializer.") + + labels = None + if self._labels is not None: + labels = self._labels + else: + state = self.get_state(LabelsState) + if state is not None: + labels = state.labels + else: + rank_zero_warn("No LabelsState was found, int targets will be used as label strings", UserWarning) + + height, width = sample[DefaultDataKeys.METADATA]["size"] + + detections = [] + + for det in sample[DefaultDataKeys.PREDS]: + confidence = det["scores"].tolist() + + if self.threshold is not None and confidence < self.threshold: + continue + + xmin, ymin, xmax, ymax = [c.tolist() for c in det["boxes"]] + box = [ + xmin / width, + ymin / height, + (xmax - xmin) / width, + (ymax - ymin) / height, + ] + + label = det["labels"].tolist() + if labels is not None: + label = labels[label] + else: + label = str(int(label)) + + detections.append(Detection( + label=label, + bounding_box=box, + confidence=confidence, + )) + fo_predictions = Detections(detections=detections) + if self.return_filepath: + filepath = sample[DefaultDataKeys.METADATA]["filepath"] + return {"filepath": filepath, "predictions": fo_predictions} + else: + return fo_predictions diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index 924e5d0cb9..1b10137ce7 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -30,16 +30,24 @@ from flash.core.data.data_source import ( DefaultDataKeys, DefaultDataSources, + FiftyOneDataSource, ImageLabelsMap, NumpyDataSource, PathsDataSource, TensorDataSource, ) from flash.core.data.process import Deserializer, Preprocess -from flash.core.utilities.imports import _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE from flash.image.segmentation.serialization import SegmentationLabels from flash.image.segmentation.transforms import default_transforms, train_default_transforms +if _FIFTYONE_AVAILABLE: + import fiftyone as fo + from fiftyone.core.collections import SampleCollection + from fiftyone.core.labels import Segmentation +else: + fo, Segmentation, SampleCollection = None, None, None + if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt else: @@ -64,7 +72,7 @@ class SemanticSegmentationNumpyDataSource(NumpyDataSource): def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: img = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float() sample[DefaultDataKeys.INPUT] = img - sample[DefaultDataKeys.METADATA] = img.shape + sample[DefaultDataKeys.METADATA] = {"size": img.shape} return sample @@ -73,7 +81,7 @@ class SemanticSegmentationTensorDataSource(TensorDataSource): def load_sample(self, sample: Dict[str, Any], dataset: 
Optional[Any] = None) -> Dict[str, Any]: img = sample[DefaultDataKeys.INPUT].float() sample[DefaultDataKeys.INPUT] = img - sample[DefaultDataKeys.METADATA] = img.shape + sample[DefaultDataKeys.METADATA] = {"size": img.shape} return sample @@ -139,18 +147,72 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW img_labels = img_labels[0] # HxW - return { - DefaultDataKeys.INPUT: img.float(), - DefaultDataKeys.TARGET: img_labels.float(), - DefaultDataKeys.METADATA: img.shape, + sample[DefaultDataKeys.INPUT] = img.float() + sample[DefaultDataKeys.TARGET] = img_labels.float() + sample[DefaultDataKeys.METADATA] = { + "filepath": img_path, + "size": img.shape, } + return sample def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: - img = torchvision.io.read_image(sample[DefaultDataKeys.INPUT]).float() - return { - DefaultDataKeys.INPUT: img, - DefaultDataKeys.METADATA: img.shape, + img_path = sample[DefaultDataKeys.INPUT] + img = torchvision.io.read_image(img_path).float() + + sample[DefaultDataKeys.INPUT] = img + sample[DefaultDataKeys.METADATA] = { + "filepath": img_path, + "size": img.shape, + } + return sample + + +class SemanticSegmentationFiftyOneDataSource(FiftyOneDataSource): + + def __init__(self, label_field: str = "ground_truth"): + if not _IMAGE_AVAILABLE: + raise ModuleNotFoundError("Please, pip install -e '.[image]'") + + super().__init__(label_field=label_field) + self._fo_dataset_name = None + + @property + def label_cls(self): + return Segmentation + + def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: + self._validate(data) + + self._fo_dataset_name = data.name + return [{DefaultDataKeys.INPUT: f} for f in data.values("filepath")] + + def load_sample(self, sample: Mapping[str, str]) -> Mapping[str, Union[torch.Tensor, torch.Size]]: + _fo_dataset = fo.load_dataset(self._fo_dataset_name) + + img_path = sample[DefaultDataKeys.INPUT] + fo_sample = _fo_dataset[img_path] + + img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW + img_labels: torch.Tensor = torch.from_numpy(fo_sample[self.label_field].mask) # HxW + + sample[DefaultDataKeys.INPUT] = img.float() + sample[DefaultDataKeys.TARGET] = img_labels.float() + sample[DefaultDataKeys.METADATA] = { + "filepath": img_path, + "size": img.shape, } + return sample + + def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + img_path = sample[DefaultDataKeys.INPUT] + img = torchvision.io.read_image(img_path).float() + + sample[DefaultDataKeys.INPUT] = img + sample[DefaultDataKeys.METADATA] = { + "filepath": img_path, + "size": img.shape, + } + return sample class SemanticSegmentationDeserializer(Deserializer): @@ -180,6 +242,7 @@ def __init__( deserializer: Optional['Deserializer'] = None, num_classes: int = None, labels_map: Dict[int, Tuple[int, int, int]] = None, + **data_source_kwargs: Any, ) -> None: """Preprocess pipeline for semantic segmentation tasks. @@ -189,6 +252,7 @@ def __init__( test_transform: Dictionary with the set of transforms to apply during testing. predict_transform: Dictionary with the set of transforms to apply during prediction. image_size: A tuple with the expected output image size. + **data_source_kwargs: Additional arguments passed on to the data source constructors. 
""" if not _IMAGE_AVAILABLE: raise ModuleNotFoundError("Please, pip install 'lightning-flash[image]'") @@ -203,6 +267,7 @@ def __init__( test_transform=test_transform, predict_transform=predict_transform, data_sources={ + DefaultDataSources.FIFTYONE: SemanticSegmentationFiftyOneDataSource(**data_source_kwargs), DefaultDataSources.FILES: SemanticSegmentationPathsDataSource(), DefaultDataSources.FOLDERS: SemanticSegmentationPathsDataSource(), DefaultDataSources.TENSORS: SemanticSegmentationTensorDataSource(), @@ -301,7 +366,8 @@ def from_data_source( **preprocess_kwargs ) - dm.train_dataset.num_classes = num_classes + if dm.train_dataset is not None: + dm.train_dataset.num_classes = num_classes return dm @classmethod diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index 1a810afa7f..8d89c4d4a8 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -33,7 +33,7 @@ class SemanticSegmentationPostprocess(Postprocess): def per_sample_transform(self, sample: Any) -> Any: - resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA][-2:], interpolation='bilinear') + resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA]["size"][-2:], interpolation='bilinear') sample[DefaultDataKeys.PREDS] = resize(torch.stack(sample[DefaultDataKeys.PREDS])) sample[DefaultDataKeys.INPUT] = resize(torch.stack(sample[DefaultDataKeys.INPUT])) return super().per_sample_transform(sample) diff --git a/flash/image/segmentation/serialization.py b/flash/image/segmentation/serialization.py index 90918b0a97..574be138d9 100644 --- a/flash/image/segmentation/serialization.py +++ b/flash/image/segmentation/serialization.py @@ -12,14 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union +import numpy as np import torch import flash from flash.core.data.data_source import DefaultDataKeys, ImageLabelsMap from flash.core.data.process import Serializer -from flash.core.utilities.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE + +if _FIFTYONE_AVAILABLE: + from fiftyone.core.labels import Segmentation +else: + Segmentation = None if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt @@ -40,7 +46,7 @@ def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, Args: labels_map: A dictionary that map the labels ids to pixel intensities. - visualise: Wether to visualise the image labels. + visualize: Wether to visualize the image labels. """ super().__init__() self.labels_map = labels_map @@ -80,3 +86,37 @@ def serialize(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: plt.imshow(labels_vis) plt.show() return labels.tolist() + + +class FiftyOneSegmentationLabels(SegmentationLabels): + + def __init__( + self, + labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, + visualize: bool = False, + return_filepath: bool = False, + ): + """A :class:`.Serializer` which converts the model outputs to FiftyOne segmentation format. + + Args: + labels_map: A dictionary that map the labels ids to pixel intensities. + visualize: Wether to visualize the image labels. 
+ return_filepath: Boolean determining whether to return a dict + containing filepath and FiftyOne labels (True) or only a + list of FiftyOne labels (False) + """ + if not _FIFTYONE_AVAILABLE: + raise ModuleNotFoundError("Please, run `pip install fiftyone`.") + + super().__init__(labels_map=labels_map, visualize=visualize) + + self.return_filepath = return_filepath + + def serialize(self, sample: Dict[str, torch.Tensor]) -> Union[Segmentation, Dict[str, Any]]: + labels = super().serialize(sample) + fo_predictions = Segmentation(mask=np.array(labels)) + if self.return_filepath: + filepath = sample[DefaultDataKeys.METADATA]["filepath"] + return {"filepath": filepath, "predictions": fo_predictions} + else: + return fo_predictions diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index 11c40bdc6c..c75d575784 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -20,10 +20,22 @@ from torch.utils.data import RandomSampler, Sampler from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, LabelsState, PathsDataSource +from flash.core.data.data_source import ( + DefaultDataKeys, + DefaultDataSources, + FiftyOneDataSource, + LabelsState, + PathsDataSource, +) from flash.core.data.process import Preprocess from flash.core.data.transforms import merge_transforms -from flash.core.utilities.imports import _KORNIA_AVAILABLE, _PYTORCHVIDEO_AVAILABLE +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _KORNIA_AVAILABLE, _PYTORCHVIDEO_AVAILABLE + +if _FIFTYONE_AVAILABLE: + from fiftyone.core.collections import SampleCollection + from fiftyone.core.labels import Classification +else: + Classification, SampleCollection = None, None if _KORNIA_AVAILABLE: import kornia.augmentation as K @@ -32,6 +44,7 @@ from pytorchvideo.data.clip_sampling import ClipSampler, make_clip_sampler from pytorchvideo.data.encoded_video import EncodedVideo from pytorchvideo.data.encoded_video_dataset import EncodedVideoDataset, labeled_encoded_video_dataset + from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths from pytorchvideo.transforms import ( ApplyTransformToKey, RandomShortSideScale, @@ -45,7 +58,7 @@ _PYTORCHVIDEO_DATA = Dict[str, Union[str, torch.Tensor, int, float, List]] -class VideoClassificationPathsDataSource(PathsDataSource): +class BaseVideoClassification(object): def __init__( self, @@ -54,26 +67,25 @@ def __init__( decode_audio: bool = True, decoder: str = "pyav", ): - super().__init__(extensions=("mp4", "avi")) self.clip_sampler = clip_sampler self.video_sampler = video_sampler self.decode_audio = decode_audio self.decoder = decoder def load_data(self, data: str, dataset: Optional[Any] = None) -> 'EncodedVideoDataset': - ds: EncodedVideoDataset = labeled_encoded_video_dataset( - pathlib.Path(data), - self.clip_sampler, - video_sampler=self.video_sampler, - decode_audio=self.decode_audio, - decoder=self.decoder, - ) + ds = self._make_encoded_video_dataset(data) if self.training: label_to_class_mapping = {p[1]: p[0].split("/")[-2] for p in ds._labeled_videos._paths_and_labels} self.set_state(LabelsState(label_to_class_mapping)) dataset.num_classes = len(np.unique([s[1]['label'] for s in ds._labeled_videos])) return ds + def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + video_path = sample[DefaultDataKeys.INPUT] + sample.update(self._encoded_video_to_dict(EncodedVideo.from_path(video_path))) + sample[DefaultDataKeys.METADATA] 
= {"filepath": video_path} + return sample + def _encoded_video_to_dict(self, video) -> Dict[str, Any]: ( clip_start, @@ -107,8 +119,87 @@ def _encoded_video_to_dict(self, video) -> Dict[str, Any]: } if audio_samples is not None else {}), } - def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - return self._encoded_video_to_dict(EncodedVideo.from_path(sample[DefaultDataKeys.INPUT])) + def _make_encoded_video_dataset(self, data) -> 'EncodedVideoDataset': + raise NotImplementedError("Subclass must implement _make_encoded_video_dataset()") + + +class VideoClassificationPathsDataSource(BaseVideoClassification, PathsDataSource): + + def __init__( + self, + clip_sampler: 'ClipSampler', + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, + decode_audio: bool = True, + decoder: str = "pyav", + ): + super().__init__( + clip_sampler, + video_sampler=video_sampler, + decode_audio=decode_audio, + decoder=decoder, + ) + PathsDataSource.__init__( + self, + extensions=("mp4", "avi"), + ) + + def _make_encoded_video_dataset(self, data) -> 'EncodedVideoDataset': + ds: EncodedVideoDataset = labeled_encoded_video_dataset( + pathlib.Path(data), + self.clip_sampler, + video_sampler=self.video_sampler, + decode_audio=self.decode_audio, + decoder=self.decoder, + ) + return ds + + +class VideoClassificationFiftyOneDataSource( + BaseVideoClassification, + FiftyOneDataSource, +): + + def __init__( + self, + clip_sampler: 'ClipSampler', + video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, + decode_audio: bool = True, + decoder: str = "pyav", + label_field: str = "ground_truth", + ): + super().__init__( + clip_sampler=clip_sampler, + video_sampler=video_sampler, + decode_audio=decode_audio, + decoder=decoder, + ) + FiftyOneDataSource.__init__( + self, + label_field=label_field, + ) + + @property + def label_cls(self): + return Classification + + def _make_encoded_video_dataset(self, data: SampleCollection) -> 'EncodedVideoDataset': + classes = self._get_classes(data) + label_to_class_mapping = dict(enumerate(classes)) + class_to_label_mapping = {c: lab for lab, c in label_to_class_mapping.items()} + + filepaths = data.values("filepath") + labels = data.values(self.label_field + ".label") + targets = [class_to_label_mapping[lab] for lab in labels] + labeled_video_paths = LabeledVideoPaths(list(zip(filepaths, targets))) + + ds: EncodedVideoDataset = EncodedVideoDataset( + labeled_video_paths, + self.clip_sampler, + video_sampler=self.video_sampler, + decode_audio=self.decode_audio, + decoder=self.decoder, + ) + return ds class VideoClassificationPreprocess(Preprocess): @@ -125,6 +216,7 @@ def __init__( video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, decode_audio: bool = True, decoder: str = "pyav", + **data_source_kwargs: Any, ): self.clip_sampler = clip_sampler self.clip_duration = clip_duration @@ -164,6 +256,13 @@ def __init__( decode_audio=decode_audio, decoder=decoder, ), + DefaultDataSources.FIFTYONE: VideoClassificationFiftyOneDataSource( + clip_sampler, + video_sampler=video_sampler, + decode_audio=decode_audio, + decoder=decoder, + **data_source_kwargs, + ), }, default_data_source=DefaultDataSources.FILES, ) diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py index 194dae99bc..4601bcaf5a 100644 --- a/flash/video/classification/model.py +++ b/flash/video/classification/model.py @@ -27,6 +27,7 @@ import flash from flash.core.classification import ClassificationTask, Labels +from flash.core.data.data_source import 
DefaultDataKeys from flash.core.data.process import Serializer from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE @@ -153,7 +154,8 @@ def forward(self, x: Any) -> Any: def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: predictions = self(batch["video"]) - return predictions + batch[DefaultDataKeys.PREDS] = predictions + return batch def configure_finetune_callback(self) -> List[Callback]: return [VideoClassifierFinetuning()] diff --git a/flash_examples/integrations/fiftyone/image_classification.py b/flash_examples/integrations/fiftyone/image_classification.py new file mode 100644 index 0000000000..2b9e13717b --- /dev/null +++ b/flash_examples/integrations/fiftyone/image_classification.py @@ -0,0 +1,61 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from itertools import chain + +import flash +from flash.core.classification import FiftyOneLabels, Labels, Probabilities +from flash.core.data.utils import download_data +from flash.core.finetuning import FreezeUnfreeze +from flash.core.integrations.fiftyone import visualize +from flash.image import ImageClassificationData, ImageClassifier + +# 1 Download data +download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip") + +# 2 Load data +datamodule = ImageClassificationData.from_folders( + train_folder="data/hymenoptera_data/train/", + val_folder="data/hymenoptera_data/val/", + test_folder="data/hymenoptera_data/test/", + predict_folder="data/hymenoptera_data/predict/", +) + +# 3 Fine tune a model +model = ImageClassifier( + backbone="resnet18", + num_classes=datamodule.num_classes, + serializer=Labels(), +) +trainer = flash.Trainer( + max_epochs=1, + limit_train_batches=1, + limit_val_batches=1, +) +trainer.finetune( + model, + datamodule=datamodule, + strategy=FreezeUnfreeze(unfreeze_epoch=1), +) +trainer.save_checkpoint("image_classification_model.pt") + +# 4 Predict from checkpoint +model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") +model.serializer = FiftyOneLabels(return_filepath=True) +predictions = trainer.predict(model, datamodule=datamodule) + +predictions = list(chain.from_iterable(predictions)) # flatten batches + +# 5. Visualize predictions in FiftyOne +# Note: this blocks until the FiftyOne App is closed +session = visualize(predictions) diff --git a/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py new file mode 100644 index 0000000000..d7bc4cb72a --- /dev/null +++ b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py @@ -0,0 +1,93 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
diff --git a/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py
new file mode 100644
index 0000000000..d7bc4cb72a
--- /dev/null
+++ b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py
@@ -0,0 +1,93 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from itertools import chain
+
+import fiftyone as fo
+
+import flash
+from flash.core.classification import FiftyOneLabels, Labels, Probabilities
+from flash.core.data.utils import download_data
+from flash.core.finetuning import FreezeUnfreeze
+from flash.core.integrations.fiftyone import visualize
+from flash.image import ImageClassificationData, ImageClassifier
+
+# 1. Download data
+download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")
+
+# 2. Load data into FiftyOne
+train_dataset = fo.Dataset.from_dir(
+    dataset_dir="data/hymenoptera_data/train/",
+    dataset_type=fo.types.ImageClassificationDirectoryTree,
+)
+val_dataset = fo.Dataset.from_dir(
+    dataset_dir="data/hymenoptera_data/val/",
+    dataset_type=fo.types.ImageClassificationDirectoryTree,
+)
+test_dataset = fo.Dataset.from_dir(
+    dataset_dir="data/hymenoptera_data/test/",
+    dataset_type=fo.types.ImageClassificationDirectoryTree,
+)
+
+# 3. Load FiftyOne datasets
+datamodule = ImageClassificationData.from_fiftyone(
+    train_dataset=train_dataset,
+    val_dataset=val_dataset,
+    test_dataset=test_dataset,
+)
+
+# 4. Fine-tune a model
+model = ImageClassifier(
+    backbone="resnet18",
+    num_classes=datamodule.num_classes,
+    serializer=Labels(),
+)
+trainer = flash.Trainer(
+    max_epochs=1,
+    limit_train_batches=1,
+    limit_val_batches=1,
+)
+trainer.finetune(
+    model,
+    datamodule=datamodule,
+    strategy=FreezeUnfreeze(unfreeze_epoch=1),
+)
+trainer.save_checkpoint("image_classification_model.pt")
+
+# 5. Predict from checkpoint on data with ground truth
+model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
+model.serializer = FiftyOneLabels(return_filepath=False)
+datamodule = ImageClassificationData.from_fiftyone(predict_dataset=test_dataset)
+predictions = trainer.predict(model, datamodule=datamodule)
+
+predictions = list(chain.from_iterable(predictions))  # flatten batches
+
+# 6. Add predictions to dataset
+test_dataset.set_values("predictions", predictions)
+
+# 7. Visualize labels in the App
+session = fo.launch_app(test_dataset)
+
+# 8. Evaluate your model
+results = test_dataset.evaluate_classifications(
+    "predictions",
+    gt_field="ground_truth",
+    eval_key="eval",
+)
+results.print_report()
+plot = results.plot_confusion_matrix()
+plot.show()
+
+# Only when running this in a script
+# Block until the FiftyOne App is closed
+session.wait()
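Since `evaluate_classifications(..., eval_key="eval")` writes a per-sample correctness field, the App session launched above can be restricted to just the failures. A short sketch using FiftyOne view expressions (assuming the `test_dataset` and `session` created in the example):

    from fiftyone import ViewField as F

    # Match samples whose "eval" field is False, i.e. the misclassifications,
    # and point the open App session at that view
    mistakes = test_dataset.match(F("eval") == False)  # noqa: E712
    session.view = mistakes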
diff --git a/flash_examples/integrations/fiftyone/image_embedding.py b/flash_examples/integrations/fiftyone/image_embedding.py
new file mode 100644
index 0000000000..71b9ef68e3
--- /dev/null
+++ b/flash_examples/integrations/fiftyone/image_embedding.py
@@ -0,0 +1,48 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import fiftyone as fo
+import fiftyone.brain as fob
+import numpy as np
+import torch
+
+from flash.core.data.utils import download_data
+from flash.image import ImageEmbedder
+
+# 1. Download data
+download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")
+
+# 2. Load data into FiftyOne
+dataset = fo.Dataset.from_dir(
+    "data/hymenoptera_data/test/",
+    fo.types.ImageClassificationDirectoryTree,
+)
+
+# 3. Load model
+embedder = ImageEmbedder(backbone="swav-imagenet", embedding_dim=128)
+
+# 4. Generate embeddings
+filepaths = dataset.values("filepath")
+embeddings = np.stack(embedder.predict(filepaths))
+
+# 5. Visualize in FiftyOne App
+results = fob.compute_visualization(dataset, embeddings=embeddings)
+
+session = fo.launch_app(dataset)
+
+plot = results.visualize(labels="ground_truth.label")
+plot.show()
+
+# Only when running this in a script
+# Block until the FiftyOne App is closed
+session.wait()
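`fob.compute_visualization` reduces the 128-dimensional SWAV embeddings to 2-D points (UMAP by default, which may additionally require the `umap-learn` package). In a notebook, the scatterplot can also be attached to the session so that lasso selections filter the sample grid; a one-line sketch, assuming the `session` and `plot` objects created above and a recent FiftyOne version:

    # Keep plot selections and the App's sample grid in sync (notebook only)
    session.plots.attach(plot)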
diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt
index 128176ebde..42dad56450 100644
--- a/requirements/datatype_image.txt
+++ b/requirements/datatype_image.txt
@@ -5,3 +5,4 @@ Pillow>=7.2
 kornia>=0.5.1,<0.5.4
 matplotlib
 pycocotools>=2.0.2 ; python_version >= "3.7"
+fiftyone
diff --git a/requirements/datatype_video.txt b/requirements/datatype_video.txt
index da7209cd44..85bc82a5df 100644
--- a/requirements/datatype_video.txt
+++ b/requirements/datatype_video.txt
@@ -2,3 +2,4 @@ torchvision
 Pillow>=7.2
 kornia>=0.5.1,<0.5.4
 pytorchvideo==0.1.0
+fiftyone
diff --git a/tests/core/test_classification.py b/tests/core/test_classification.py
index fd4de14d7e..9281c36ab4 100644
--- a/tests/core/test_classification.py
+++ b/tests/core/test_classification.py
@@ -11,9 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pytest
 import torch

-from flash.core.classification import Classes, Labels, Logits, Probabilities
+from flash.core.classification import Classes, FiftyOneLabels, Labels, Logits, Probabilities
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE


 def test_classification_serializers():
@@ -37,3 +40,31 @@ def test_classification_serializers_multi_label():
     )
     assert Classes(multi_label=True).serialize(example_output) == [1, 2]
     assert Labels(labels, multi_label=True).serialize(example_output) == ['class_2', 'class_3']
+
+
+@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
+def test_classification_serializers_fiftyone():
+
+    logits = torch.tensor([-0.1, 0.2, 0.3])
+    example_output = {DefaultDataKeys.PREDS: logits, DefaultDataKeys.METADATA: {"filepath": "something"}}  # 3 classes
+    labels = ['class_1', 'class_2', 'class_3']
+
+    predictions = FiftyOneLabels(return_filepath=True).serialize(example_output)
+    assert predictions["predictions"].label == '2'
+    assert predictions["filepath"] == "something"
+    predictions = FiftyOneLabels(labels, return_filepath=True).serialize(example_output)
+    assert predictions["predictions"].label == 'class_3'
+    assert predictions["filepath"] == "something"
+
+    predictions = FiftyOneLabels(store_logits=True).serialize(example_output)
+    assert torch.allclose(torch.tensor(predictions.logits), logits)
+    assert torch.allclose(torch.tensor(predictions.confidence), torch.softmax(logits, -1)[-1])
+    assert predictions.label == '2'
+    predictions = FiftyOneLabels(labels, store_logits=True).serialize(example_output)
+    assert predictions.label == 'class_3'
+
+    predictions = FiftyOneLabels(store_logits=True, multi_label=True).serialize(example_output)
+    assert torch.allclose(torch.tensor(predictions.logits), logits)
+    assert [c.label for c in predictions.classifications] == ['1', '2']
+    predictions = FiftyOneLabels(labels, multi_label=True).serialize(example_output)
+    assert [c.label for c in predictions.classifications] == ['class_2', 'class_3']
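The expected labels in `test_classification_serializers_fiftyone` follow directly from the logits. A standalone worked check of the same arithmetic (assuming, as the assertions above imply, that single-label serializers take the argmax of the softmax and multi-label serializers threshold sigmoid probabilities at 0.5):

    import torch

    logits = torch.tensor([-0.1, 0.2, 0.3])

    # single-label: softmax ~= [0.260, 0.351, 0.388] -> argmax index 2 -> 'class_3'
    assert int(torch.argmax(torch.softmax(logits, -1))) == 2

    # multi-label: sigmoid(logits) ~= [0.475, 0.550, 0.574]; thresholding at 0.5
    # keeps indices 1 and 2 -> labels ['class_2', 'class_3']
    assert (torch.sigmoid(logits) > 0.5).nonzero().flatten().tolist() == [1, 2]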
diff --git a/tests/examples/test_integrations.py b/tests/examples/test_integrations.py
new file mode 100644
index 0000000000..1dea0a9a81
--- /dev/null
+++ b/tests/examples/test_integrations.py
@@ -0,0 +1,40 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import subprocess
+import sys
+from pathlib import Path
+from typing import List, Optional, Tuple
+from unittest import mock
+
+import pytest
+
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
+from tests.examples.utils import run_test
+
+root = Path(__file__).parent.parent.parent
+
+
+@mock.patch.dict(os.environ, {"FLASH_TESTING": "1"})
+@pytest.mark.parametrize(
+    "folder, file", [
+        pytest.param(
+            "fiftyone",
+            "image_classification.py",
+            marks=pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone library isn't installed")
+        ),
+    ]
+)
+def test_integrations(tmpdir, folder, file):
+    run_test(str(root / "flash_examples" / "integrations" / folder / file))
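The test above simply shells the example script out through the shared helper, so a single integration example can also be exercised directly when debugging; a sketch, run from the repository root:

    from tests.examples.utils import run_test

    # Runs the script under coverage with a deterministic seed prepended,
    # then asserts a zero exit code
    run_test("flash_examples/integrations/fiftyone/image_classification.py")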
diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py
index 3a9d104a9b..9b440e6f6a 100644
--- a/tests/examples/test_scripts.py
+++ b/tests/examples/test_scripts.py
@@ -29,50 +29,13 @@
     _TORCHVISION_GREATER_EQUAL_0_9,
     _VIDEO_AVAILABLE,
 )
+from tests.examples.utils import run_test

 _IMAGE_AVAILABLE = _IMAGE_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_9

 root = Path(__file__).parent.parent.parent


-def call_script(
-    filepath: str,
-    args: Optional[List[str]] = None,
-    timeout: Optional[int] = 60 * 5,
-) -> Tuple[int, str, str]:
-    with open(filepath, 'r') as original:
-        data = original.read()
-
-    with open(filepath, 'w') as modified:
-        modified.write("import pytorch_lightning as pl\npl.seed_everything(42)\n" + data)
-
-    if args is None:
-        args = []
-    args = [str(a) for a in args]
-    command = [sys.executable, "-m", "coverage", "run", filepath] + args
-    print(" ".join(command))
-    p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-    try:
-        stdout, stderr = p.communicate(timeout=timeout)
-    except subprocess.TimeoutExpired:
-        p.kill()
-        stdout, stderr = p.communicate()
-    stdout = stdout.decode("utf-8")
-    stderr = stderr.decode("utf-8")
-
-    with open(filepath, 'w') as modified:
-        modified.write(data)
-
-    return p.returncode, stdout, stderr
-
-
-def run_test(filepath):
-    code, stdout, stderr = call_script(filepath)
-    print(f"{filepath} STDOUT: {stdout}")
-    print(f"{filepath} STDERR: {stderr}")
-    assert not code
-
-
 @mock.patch.dict(os.environ, {"FLASH_TESTING": "1"})
 @pytest.mark.parametrize(
     "folder, file",
diff --git a/tests/examples/utils.py b/tests/examples/utils.py
new file mode 100644
index 0000000000..9cf5a4e765
--- /dev/null
+++ b/tests/examples/utils.py
@@ -0,0 +1,57 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import subprocess
+import sys
+from pathlib import Path
+from typing import List, Optional, Tuple
+from unittest import mock
+
+
+def call_script(
+    filepath: str,
+    args: Optional[List[str]] = None,
+    timeout: Optional[int] = 60 * 5,
+) -> Tuple[int, str, str]:
+    with open(filepath, 'r') as original:
+        data = original.read()
+
+    with open(filepath, 'w') as modified:
+        modified.write("import pytorch_lightning as pl\npl.seed_everything(42)\n" + data)
+
+    if args is None:
+        args = []
+    args = [str(a) for a in args]
+    command = [sys.executable, "-m", "coverage", "run", filepath] + args
+    print(" ".join(command))
+    p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    try:
+        stdout, stderr = p.communicate(timeout=timeout)
+    except subprocess.TimeoutExpired:
+        p.kill()
+        stdout, stderr = p.communicate()
+    stdout = stdout.decode("utf-8")
+    stderr = stderr.decode("utf-8")
+
+    with open(filepath, 'w') as modified:
+        modified.write(data)
+
+    return p.returncode, stdout, stderr
+
+
+def run_test(filepath):
+    code, stdout, stderr = call_script(filepath)
+    print(f"{filepath} STDOUT: {stdout}")
+    print(f"{filepath} STDERR: {stderr}")
+    assert not code
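`call_script` is also usable on its own when the exit code and captured output are needed, for example to pass CLI arguments or tighten the timeout; a sketch with illustrative arguments:

    from tests.examples.utils import call_script

    code, stdout, stderr = call_script(
        "flash_examples/integrations/fiftyone/image_embedding.py",
        args=[],      # extra argv forwarded to the script
        timeout=120,  # seconds before the subprocess is killed
    )
    assert code == 0, stderr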
diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py
index ae7159a34f..d16a9c0d0a 100644
--- a/tests/image/classification/test_data.py
+++ b/tests/image/classification/test_data.py
@@ -21,7 +21,7 @@
 from flash.core.data.data_source import DefaultDataKeys
 from flash.core.data.transforms import ApplyToKeys
-from flash.core.utilities.imports import _IMAGE_AVAILABLE
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE
 from flash.image import ImageClassificationData

 if _IMAGE_AVAILABLE:
@@ -29,6 +29,9 @@
     import torchvision
     from PIL import Image

+if _FIFTYONE_AVAILABLE:
+    import fiftyone as fo
+

 def _dummy_image_loader(_):
     return torch.rand(3, 196, 196)
@@ -381,3 +384,60 @@ def test_from_data(data, from_function):
     assert imgs.shape == (2, 3, 196, 196)
     assert labels.shape == (2, )
     assert list(labels.numpy()) == [2, 5]
+
+
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
+@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed.")
+def test_from_fiftyone(tmpdir):
+    tmpdir = Path(tmpdir)
+
+    (tmpdir / "a").mkdir()
+    (tmpdir / "b").mkdir()
+    _rand_image().save(tmpdir / "a_1.png")
+    _rand_image().save(tmpdir / "b_1.png")
+
+    train_images = [
+        str(tmpdir / "a_1.png"),
+        str(tmpdir / "b_1.png"),
+    ]
+
+    dataset = fo.Dataset.from_dir(str(tmpdir), dataset_type=fo.types.ImageDirectory)
+    s1 = dataset[train_images[0]]
+    s2 = dataset[train_images[1]]
+    s1["test"] = fo.Classification(label="1")
+    s2["test"] = fo.Classification(label="2")
+    s1.save()
+    s2.save()
+
+    img_data = ImageClassificationData.from_fiftyone(
+        train_dataset=dataset,
+        test_dataset=dataset,
+        val_dataset=dataset,
+        label_field="test",
+        batch_size=2,
+        num_workers=0,
+    )
+    assert img_data.train_dataloader() is not None
+    assert img_data.val_dataloader() is not None
+    assert img_data.test_dataloader() is not None
+
+    # check train data
+    data = next(iter(img_data.train_dataloader()))
+    imgs, labels = data['input'], data['target']
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, )
+    assert sorted(list(labels.numpy())) == [0, 1]
+
+    # check val data
+    data = next(iter(img_data.val_dataloader()))
+    imgs, labels = data['input'], data['target']
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, )
+    assert sorted(list(labels.numpy())) == [0, 1]
+
+    # check test data
+    data = next(iter(img_data.test_dataloader()))
+    imgs, labels = data['input'], data['target']
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, )
+    assert sorted(list(labels.numpy())) == [0, 1]
diff --git a/tests/image/classification/test_data_model_integration.py b/tests/image/classification/test_data_model_integration.py
index 002f445f8e..711bcc329f 100644
--- a/tests/image/classification/test_data_model_integration.py
+++ b/tests/image/classification/test_data_model_integration.py
@@ -18,12 +18,15 @@
 import torch

 from flash import Trainer
-from flash.core.utilities.imports import _IMAGE_AVAILABLE
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE
 from flash.image import ImageClassificationData, ImageClassifier

 if _IMAGE_AVAILABLE:
     from PIL import Image

+if _FIFTYONE_AVAILABLE:
+    import fiftyone as fo
+

 def _dummy_image_loader(_):
     return torch.rand(3, 224, 224)
@@ -56,3 +59,39 @@ def test_classification(tmpdir):
     model = ImageClassifier(num_classes=2, backbone="resnet18")
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     trainer.finetune(model, datamodule=data, strategy="freeze")
+
+
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
+@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed.")
+def test_classification_fiftyone(tmpdir):
+    tmpdir = Path(tmpdir)
+
+    (tmpdir / "a").mkdir()
+    (tmpdir / "b").mkdir()
+    _rand_image().save(tmpdir / "a_1.png")
+    _rand_image().save(tmpdir / "b_1.png")
+
+    train_images = [
+        str(tmpdir / "a_1.png"),
+        str(tmpdir / "b_1.png"),
+    ]
+
+    train_dataset = fo.Dataset.from_dir(str(tmpdir), dataset_type=fo.types.ImageDirectory)
+    s1 = train_dataset[train_images[0]]
+    s2 = train_dataset[train_images[1]]
+    s1["test"] = fo.Classification(label="1")
+    s2["test"] = fo.Classification(label="2")
+    s1.save()
+    s2.save()
+
+    data = ImageClassificationData.from_fiftyone(
+        train_dataset=train_dataset,
+        label_field="test",
+        batch_size=2,
+        num_workers=0,
+        image_size=(64, 64),
+    )
+
+    model = ImageClassifier(num_classes=2, backbone="resnet18")
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer.finetune(model, datamodule=data, strategy="freeze")
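Note why `sorted(list(labels.numpy())) == [0, 1]` holds in the checks above: the integer targets come from the distinct label strings, so `"1"` and `"2"` map to 0 and 1 regardless of sample order. A tiny standalone sketch of that mapping (assuming, as the assertions imply, that the class list is the sorted set of distinct labels):

    labels = ["2", "1", "2"]                         # per-sample label strings
    classes = sorted(set(labels))                    # ['1', '2']
    class_to_idx = {c: i for i, c in enumerate(classes)}
    targets = [class_to_idx[lab] for lab in labels]  # [1, 0, 1]
    assert sorted(set(targets)) == [0, 1]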
diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py
index 7e5fc5ad9a..b87ba8dec5 100644
--- a/tests/image/detection/test_data.py
+++ b/tests/image/detection/test_data.py
@@ -5,12 +5,15 @@
 import pytest

 from flash.core.data.data_source import DefaultDataKeys
-from flash.core.utilities.imports import _COCO_AVAILABLE, _IMAGE_AVAILABLE
+from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE
 from flash.image.detection.data import ObjectDetectionData

 if _IMAGE_AVAILABLE:
     from PIL import Image

+if _FIFTYONE_AVAILABLE:
+    import fiftyone as fo
+

 def _create_dummy_coco_json(dummy_json_path):
@@ -77,6 +80,46 @@ def _create_synth_coco_dataset(tmpdir):
     return train_folder, coco_ann_path


+def _create_synth_fiftyone_dataset(tmpdir):
+    img_dir = Path(tmpdir / "fo_imgs")
+    img_dir.mkdir()
+
+    Image.new('RGB', (1920, 1080)).save(img_dir / "sample_one.png")
+    Image.new('RGB', (1920, 1080)).save(img_dir / "sample_two.png")
+
+    dataset = fo.Dataset.from_dir(
+        img_dir,
+        dataset_type=fo.types.ImageDirectory,
+    )
+
+    sample1 = dataset[str(img_dir / "sample_one.png")]
+    sample2 = dataset[str(img_dir / "sample_two.png")]
+
+    d1 = fo.Detection(
+        label="person",
+        bounding_box=[0.3, 0.4, 0.2, 0.2],
+    )
+    d2 = fo.Detection(
+        label="person",
+        bounding_box=[0.05, 0.10, 0.28, 0.15],
+    )
+    d3 = fo.Detection(
+        label="person",
+        bounding_box=[0.23, 0.14, 0.09, 0.18],
+    )
+    d1["iscrowd"] = 1
+    d2["iscrowd"] = 0
+    d3["iscrowd"] = 0
+
+    sample1["ground_truth"] = fo.Detections(detections=[d1])
+    sample2["ground_truth"] = fo.Detections(detections=[d2, d3])
+
+    sample1.save()
+    sample2.save()
+
+    return dataset
+
+
 @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="pycocotools is not installed for testing")
 def test_image_detector_data_from_coco(tmpdir):
@@ -121,3 +164,47 @@ def test_image_detector_data_from_coco(tmpdir):
     assert imgs[0].shape == (3, 1080, 1920)
     assert len(labels) == 1
     assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd']
+
+
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed")
+@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
+def test_image_detector_data_from_fiftyone(tmpdir):
+
+    train_dataset = _create_synth_fiftyone_dataset(tmpdir)
+
+    datamodule = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1)
+
+    data = next(iter(datamodule.train_dataloader()))
+    imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+
+    assert len(imgs) == 1
+    assert imgs[0].shape == (3, 1080, 1920)
+    assert len(labels) == 1
+    assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd']
+
+    assert datamodule.val_dataloader() is None
+    assert datamodule.test_dataloader() is None
+
+    datamodule = ObjectDetectionData.from_fiftyone(
+        train_dataset=train_dataset,
+        val_dataset=train_dataset,
+        test_dataset=train_dataset,
+        batch_size=1,
+        num_workers=0,
+    )
+
+    data = next(iter(datamodule.val_dataloader()))
+    imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+
+    assert len(imgs) == 1
+    assert imgs[0].shape == (3, 1080, 1920)
+    assert len(labels) == 1
+    assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd']
+
+    data = next(iter(datamodule.test_dataloader()))
+    imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+
+    assert len(imgs) == 1
+    assert imgs[0].shape == (3, 1080, 1920)
+    assert len(labels) == 1
+    assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd']
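FiftyOne stores each `Detection.bounding_box` as `[x, y, width, height]` in coordinates relative to the image size, which is why the synthetic boxes above live in `[0, 1]`. A helper sketch for converting absolute corner coordinates (the values are illustrative and match the serialization test further below):

    def to_fiftyone_box(x1, y1, x2, y2, img_w, img_h):
        """Convert absolute (x1, y1, x2, y2) corners to FiftyOne's
        relative [x, y, w, h] format."""
        return [x1 / img_w, y1 / img_h, (x2 - x1) / img_w, (y2 - y1) / img_h]

    # e.g. box (20, 30, 40, 50) in a 100x100 image -> [0.2, 0.3, 0.2, 0.2]
    assert to_fiftyone_box(20, 30, 40, 50, 100, 100) == [0.2, 0.3, 0.2, 0.2]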
diff --git a/tests/image/detection/test_data_model_integration.py b/tests/image/detection/test_data_model_integration.py
index 428a053b75..a20a4c06d3 100644
--- a/tests/image/detection/test_data_model_integration.py
+++ b/tests/image/detection/test_data_model_integration.py
@@ -16,7 +16,7 @@
 import pytest

 import flash
-from flash.core.utilities.imports import _COCO_AVAILABLE, _IMAGE_AVAILABLE
+from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE
 from flash.image import ObjectDetector
 from flash.image.detection import ObjectDetectionData
@@ -28,6 +28,9 @@
 if _COCO_AVAILABLE:
     from tests.image.detection.test_data import _create_synth_coco_dataset

+if _FIFTYONE_AVAILABLE:
+    from tests.image.detection.test_data import _create_synth_fiftyone_dataset
+

 @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="pycocotools is not installed for testing")
 @pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing")
@@ -51,3 +54,27 @@ def test_detection(tmpdir, model, backbone):

     test_images = [str(test_image_one), str(test_image_two)]
     model.predict(test_images)
+
+
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed for testing")
+@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
+@pytest.mark.parametrize(["model", "backbone"], [("fasterrcnn", "resnet18")])
+def test_detection_fiftyone(tmpdir, model, backbone):
+
+    train_dataset = _create_synth_fiftyone_dataset(tmpdir)
+
+    data = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1)
+    model = ObjectDetector(model=model, backbone=backbone, num_classes=data.num_classes)
+
+    trainer = flash.Trainer(fast_dev_run=True)
+
+    trainer.finetune(model, data)
+
+    test_image_one = os.fspath(tmpdir / "test_one.png")
+    test_image_two = os.fspath(tmpdir / "test_two.png")
+
+    Image.new('RGB', (512, 512)).save(test_image_one)
+    Image.new('RGB', (512, 512)).save(test_image_two)
+
+    test_images = [str(test_image_one), str(test_image_two)]
+    model.predict(test_images)
diff --git a/tests/image/detection/test_serialization.py b/tests/image/detection/test_serialization.py
new file mode 100644
index 0000000000..d4f7384786
--- /dev/null
+++ b/tests/image/detection/test_serialization.py
@@ -0,0 +1,57 @@
+import pytest
+import torch
+
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
+from flash.image.detection.serialization import FiftyOneDetectionLabels
+
+
+@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
+class TestFiftyOneDetectionLabels:
+
+    def test_smoke(self):
+        serial = FiftyOneDetectionLabels()
+        assert serial is not None
+
+    def test_serialize_fiftyone(self):
+        labels = ['class_1', 'class_2', 'class_3']
+        serial = FiftyOneDetectionLabels()
+        filepath_serial = FiftyOneDetectionLabels(return_filepath=True)
+        threshold_serial = FiftyOneDetectionLabels(threshold=0.9)
+        labels_serial = FiftyOneDetectionLabels(labels=labels)
+
+        sample = {
+            DefaultDataKeys.PREDS: [
+                {
+                    "boxes": [torch.tensor(20), torch.tensor(30),
+                              torch.tensor(40), torch.tensor(50)],
+                    "labels": torch.tensor(0),
+                    "scores": torch.tensor(0.5),
+                },
+            ],
+            DefaultDataKeys.METADATA: {
+                "filepath": "something",
+                "size": (100, 100),
+            },
+        }
+
+        detections = serial.serialize(sample)
+        assert len(detections.detections) == 1
+        assert detections.detections[0].bounding_box == [0.2, 0.3, 0.2, 0.2]
+        assert detections.detections[0].confidence == 0.5
+        assert detections.detections[0].label == "0"
+
+        detections = filepath_serial.serialize(sample)
+        assert len(detections["predictions"].detections) == 1
+        assert detections["predictions"].detections[0].bounding_box == [0.2, 0.3, 0.2, 0.2]
+        assert detections["predictions"].detections[0].confidence == 0.5
+        assert detections["filepath"] == "something"
+
+        detections = threshold_serial.serialize(sample)
+        assert len(detections.detections) == 0
+
+        detections = labels_serial.serialize(sample)
+        assert len(detections.detections) == 1
+        assert detections.detections[0].bounding_box == [0.2, 0.3, 0.2, 0.2]
+        assert detections.detections[0].confidence == 0.5
+        assert detections.detections[0].label == "class_1"
diff --git a/tests/image/segmentation/test_data.py b/tests/image/segmentation/test_data.py
index a45f0a947a..089871fedb 100644
--- a/tests/image/segmentation/test_data.py
+++ b/tests/image/segmentation/test_data.py
@@ -9,12 +9,15 @@
 from flash import Trainer
 from flash.core.data.data_source import DefaultDataKeys
-from flash.core.utilities.imports import _IMAGE_AVAILABLE
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE
 from flash.image import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess

 if _IMAGE_AVAILABLE:
     from PIL import Image

+if _FIFTYONE_AVAILABLE:
+    import fiftyone as fo
+

 def build_checkboard(n, m, k=8):
     x = np.zeros((n, m))
@@ -248,6 +251,74 @@ def test_from_files_warning(self, tmpdir):
             num_classes=num_classes
         )

+    @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
+    def test_from_fiftyone(self, tmpdir):
+        tmp_dir = Path(tmpdir)
+
+        # create random dummy data
+
+        images = [
+            str(tmp_dir / "img1.png"),
+            str(tmp_dir / "img2.png"),
+            str(tmp_dir / "img3.png"),
+        ]
+
+        num_classes: int = 2
+        img_size: Tuple[int, int] = (196, 196)
+
+        for img_file in images:
+            _rand_image(img_size).save(img_file)
+
+        targets = [np.array(_rand_labels(img_size, num_classes)) for _ in range(3)]
+
+        dataset = fo.Dataset.from_dir(
+            str(tmp_dir),
+            dataset_type=fo.types.ImageDirectory,
+        )
+
+        for idx, sample in enumerate(dataset):
+            sample["ground_truth"] = fo.Segmentation(mask=targets[idx][:, :, 0])
+            sample.save()
+
+        # instantiate the data module
+
+        dm = SemanticSegmentationData.from_fiftyone(
+            train_dataset=dataset,
+            val_dataset=dataset,
+            test_dataset=dataset,
+            predict_dataset=dataset,
+            batch_size=2,
+            num_workers=0,
+            num_classes=num_classes,
+        )
+        assert dm is not None
+        assert dm.train_dataloader() is not None
+        assert dm.val_dataloader() is not None
+        assert dm.test_dataloader() is not None
+
+        # check training data
+        data = next(iter(dm.train_dataloader()))
+        imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+        assert imgs.shape == (2, 3, 196, 196)
+        assert labels.shape == (2, 196, 196)
+
+        # check val data
+        data = next(iter(dm.val_dataloader()))
+        imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+        assert imgs.shape == (2, 3, 196, 196)
+        assert labels.shape == (2, 196, 196)
+
+        # check test data
+        data = next(iter(dm.test_dataloader()))
+        imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+        assert imgs.shape == (2, 3, 196, 196)
+        assert labels.shape == (2, 196, 196)
+
+        # check predict data
+        data = next(iter(dm.predict_dataloader()))
+        imgs = data[DefaultDataKeys.INPUT]
+        assert imgs.shape == (2, 3, 196, 196)
+
     def test_map_labels(self, tmpdir):
         tmp_dir = Path(tmpdir)
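For reference, `fo.Segmentation` ground truth wraps a 2-D array of integer class ids, one per pixel, which is what the test above builds from the random targets. A self-contained sketch of attaching such a mask to a sample (the filepath and mask values are illustrative):

    import numpy as np
    import fiftyone as fo

    dataset = fo.Dataset()
    sample = fo.Sample(filepath="/path/to/img1.png")  # hypothetical image

    # One integer class id per pixel
    sample["ground_truth"] = fo.Segmentation(
        mask=np.random.randint(0, 2, size=(196, 196), dtype=np.uint8)
    )
    dataset.add_sample(sample)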
diff --git a/tests/image/segmentation/test_serialization.py b/tests/image/segmentation/test_serialization.py
index bb6599fd0e..a865223189 100644
--- a/tests/image/segmentation/test_serialization.py
+++ b/tests/image/segmentation/test_serialization.py
@@ -2,7 +2,8 @@
 import torch

 from flash.core.data.data_source import DefaultDataKeys
-from flash.image.segmentation.serialization import SegmentationLabels
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
+from flash.image.segmentation.serialization import FiftyOneSegmentationLabels, SegmentationLabels


 class TestSemanticSegmentationLabels:
@@ -35,6 +36,31 @@ def test_serialize(self):
         assert torch.tensor(classes)[1, 2] == 1
         assert torch.tensor(classes)[0, 1] == 3

+    @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
+    def test_serialize_fiftyone(self):
+        serial = FiftyOneSegmentationLabels()
+        filepath_serial = FiftyOneSegmentationLabels(return_filepath=True)
+
+        preds = torch.zeros(5, 2, 3)
+        preds[1, 1, 2] = 1  # add peak in class 2
+        preds[3, 0, 1] = 1  # add peak in class 4
+
+        sample = {
+            DefaultDataKeys.PREDS: preds,
+            DefaultDataKeys.METADATA: {
+                "filepath": "something"
+            },
+        }
+
+        segmentation = serial.serialize(sample)
+        assert segmentation.mask[1, 2] == 1
+        assert segmentation.mask[0, 1] == 3
+
+        segmentation = filepath_serial.serialize(sample)
+        assert segmentation["predictions"].mask[1, 2] == 1
+        assert segmentation["predictions"].mask[0, 1] == 3
+        assert segmentation["filepath"] == "something"
+
     # TODO: implement me
     def test_create_random_labels(self):
         pass
diff --git a/tests/video/classification/test_model.py b/tests/video/classification/test_model.py
index a9830cea26..fcb1bc68bd 100644
--- a/tests/video/classification/test_model.py
+++ b/tests/video/classification/test_model.py
@@ -14,13 +14,17 @@
 import contextlib
 import os
 import tempfile
+from pathlib import Path

 import pytest
 import torch
 from torch.utils.data import SequentialSampler

 import flash
-from flash.core.utilities.imports import _VIDEO_AVAILABLE
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _VIDEO_AVAILABLE
+
+if _FIFTYONE_AVAILABLE:
+    import fiftyone as fo

 if _VIDEO_AVAILABLE:
     import kornia.augmentation as K
@@ -45,7 +49,7 @@ def create_dummy_video_frames(num_frames: int, height: int, width: int):
 # https://github.com/facebookresearch/pytorchvideo/blob/4feccb607d7a16933d485495f91d067f177dd8db/tests/utils.py#L33
 @contextlib.contextmanager
-def temp_encoded_video(num_frames: int, fps: int, height=10, width=10, prefix=None):
+def temp_encoded_video(num_frames: int, fps: int, height=10, width=10, prefix=None, directory=None):
     """
     Creates a temporary lossless, mp4 video with synthetic content. Uses a context which
     deletes the video after exit.
@@ -54,7 +58,7 @@ def temp_encoded_video(num_frames: int, fps: int, height=10, width=10, prefix=None):
     video_codec = "libx264rgb"
     options = {"crf": "0"}
     data = create_dummy_video_frames(num_frames, height, width)
-    with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".mp4") as f:
+    with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".mp4", dir=directory) as f:
         f.close()
         io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options)
         yield f.name, thwc_to_cthw(data).to(torch.float32)
@@ -94,8 +98,33 @@ def mock_encoded_video_dataset_file():
         yield f.name, label_videos, video_duration


+@contextlib.contextmanager
+def mock_encoded_video_dataset_folder(tmpdir):
+    """
+    Creates a temporary mock encoded video directory tree with 2 videos labeled 1, 2.
+    Returns the path to this mock encoded video dataset and the video duration in seconds.
+ """ + num_frames = 10 + fps = 5 + + tmp_dir = Path(tmpdir) + os.makedirs(str(tmp_dir / "c1")) + os.makedirs(str(tmp_dir / "c2")) + + with temp_encoded_video(num_frames=num_frames, fps=fps, directory=str(tmp_dir / "c1")) as ( + video_file_name_1, + data_1, + ): + with temp_encoded_video(num_frames=num_frames, fps=fps, directory=str(tmp_dir / "c2")) as ( + video_file_name_2, + data_2, + ): + video_duration = num_frames / fps + yield str(tmp_dir), video_duration + + @pytest.mark.skipif(not _VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") -def test_image_classifier_finetune(tmpdir): +def test_video_classifier_finetune(tmpdir): with mock_encoded_video_dataset_file() as ( mock_csv, @@ -160,6 +189,76 @@ def test_image_classifier_finetune(tmpdir): trainer.finetune(model, datamodule=datamodule) +@pytest.mark.skipif(not _VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed.") +def test_video_classifier_finetune_fiftyone(tmpdir): + + with mock_encoded_video_dataset_folder(tmpdir) as ( + dir_name, + total_duration, + ): + + half_duration = total_duration / 2 - 1e-9 + + train_dataset = fo.Dataset.from_dir( + dir_name, + dataset_type=fo.types.VideoClassificationDirectoryTree, + ) + datamodule = VideoClassificationData.from_fiftyone( + train_dataset=train_dataset, + clip_sampler="uniform", + clip_duration=half_duration, + video_sampler=SequentialSampler, + decode_audio=False, + ) + + for sample in datamodule.train_dataset.data: + expected_t_shape = 5 + assert sample["video"].shape[1] == expected_t_shape + + assert len(VideoClassifier.available_backbones()) > 5 + + train_transform = { + "post_tensor_transform": Compose([ + ApplyTransformToKey( + key="video", + transform=Compose([ + UniformTemporalSubsample(8), + RandomShortSideScale(min_size=256, max_size=320), + RandomCrop(244), + RandomHorizontalFlip(p=0.5), + ]), + ), + ]), + "per_batch_transform_on_device": Compose([ + ApplyTransformToKey( + key="video", + transform=K.VideoSequential( + K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), + K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0), + data_format="BCTHW", + same_on_frame=False + ) + ), + ]), + } + + datamodule = VideoClassificationData.from_fiftyone( + train_dataset=train_dataset, + clip_sampler="uniform", + clip_duration=half_duration, + video_sampler=SequentialSampler, + decode_audio=False, + train_transform=train_transform + ) + + model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False) + + trainer = flash.Trainer(fast_dev_run=True) + + trainer.finetune(model, datamodule=datamodule) + + @pytest.mark.skipif(not _VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") def test_jit(tmpdir): sample_input = torch.rand(1, 3, 32, 256, 256)