Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Adding support for loading datasets and visualizing model predictions…
Browse files Browse the repository at this point in the history
… via FiftyOne (#360)

* add fiftyone module availability

* add fiftyone datasource

* add video classification data source

* add fiftyone classification serializer

* optimizations, rework fo serializer

* support classification, detection, segmentation

* values list, load segmentation dataset in load sample

* FiftyOneLabels test

* serializer and detection tests

* fiftyone classification tests

* segmentation and video tests

* add detections serializiation test

* cleanup

* fix test

* inherit fiftyonedatasource

* tweaks

* fix class index

* adding helper functions for common operations

* updating interface

* always use a Label class

* exposing base class params

* indent

* revert segmentation optimization

* revert to mutli

* linting

* adding support for label thresholding

* linting

* update changelog

* resolve some issues, clean API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* normalize detection labels

* normalize detection labels

* fiftyone serializer return filepaths, paths data sources store filepaths

* formatting

* formatting

* remove fiftyone from dir, rename fiftyone_visualize()

* update metadata to contain dictionaries

* add fiftyone docs

* update fiftyone examples

* rename from_fiftyone_datasets to from_fiftyone

* fiftyone_visualize to visualize, update docs

* resolve comments

* resolve test issues

* formatting

* formatting

* yapf formatting

* update for current FO version

* resolve metadata batch issue

* use current FO release, update test requirements

* syntax

* update test

* update tests

* yapf formatting

* one more test...

Co-authored-by: brimoor <[email protected]>
Co-authored-by: tchaton <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Jun 15, 2021
1 parent 01fef90 commit 25d6633
Show file tree
Hide file tree
Showing 43 changed files with 1,901 additions and 101 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added integration with FiftyOne ([#360](https://github.com/PyTorchLightning/lightning-flash/pull/360))
- Added support for `torch.jit` to tasks where possible and documented task JIT compatibility ([#389](https://github.com/PyTorchLightning/lightning-flash/pull/389))
- Added option to provide a `Sampler` to the `DataModule` to use when creating a `DataLoader` ([#390](https://github.com/PyTorchLightning/lightning-flash/pull/390))
- Added support for multi-label text classification and toxic comments example ([#401](https://github.com/PyTorchLightning/lightning-flash/pull/401))
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def _load_py_module(fname, pkg="flash"):
"PIL": ("https://pillow.readthedocs.io/en/stable/", None),
"pytorchvideo": ("https://pytorchvideo.readthedocs.io/en/latest/", None),
"pytorch_lightning": ("https://pytorch-lightning.readthedocs.io/en/stable/", None),
"fiftyone": ("https://voxel51.com/docs/fiftyone/", None),
}

# -- Options for HTML output -------------------------------------------------
Expand Down
7 changes: 7 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ Lightning Flash
code/text
code/video

.. toctree::
:maxdepth: 1
:caption: Integrations

integrations/fiftyone


.. toctree::
:maxdepth: 1
:caption: Contributing a Task
Expand Down
148 changes: 148 additions & 0 deletions docs/source/integrations/fiftyone.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
########
FiftyOne
########

We have collaborated with the team at
`Voxel51 <https://voxel51.com>`_ to integrate their tool,
`FiftyOne <https://fiftyone.ai>`_, into Lightning Flash.

FiftyOne is an open-source tool for building high-quality
datasets and computer vision models. The FiftyOne API and App enable you to
visualize datasets and interpret models faster and more effectively.

This integration allows you to view predictions generated by your tasks in the
:ref:`FiftyOne App <fiftyone:fiftyone-app>`, as well as easily incorporate
:ref:`FiftyOne Datasets <fiftyone:fiftyone-basics>` into your tasks. All image and video tasks
are supported!

.. raw:: html

<div style="margin-top: 20px; margin-bottom: 20px">
<video controls muted poster="https://pl-flash-data.s3.amazonaws.com/assets/fiftyone/fiftyone_poster.png" style="width: 100%;">
<source src="https://pl-flash-data.s3.amazonaws.com/assets/fiftyone/fiftyone_long_sizzle.mp4" type="video/mp4">
</video>
</div>

************
Installation
************

In order to utilize this integration with FiftyOne, you will need to
:ref:`install the tool<fiftyone:installing-fiftyone>`:

.. code:: shell
pip install fiftyone
*****************************
Visualizing Flash predictions
*****************************

This section shows you how to augment your existing Lightning Flash workflows
with a couple of lines of code that let you visualize predictions in FiftyOne.
You can visualize predictions for classification, object detection, and
semantic segmentation tasks. Doing so is as easy as updating your model to use
one of the following serializers:

* :class:`FiftyOneLabels(return_filepath=True)<flash.core.classification.FiftyOneLabels>`
* :class:`FiftyOneSegmentationLabels(return_filepath=True)<flash.image.segmentation.serialization.FiftyOneSegmentationLabels>`
* :class:`FiftyOneDetectionLabels(return_filepath=True)<flash.image.detection.serialization.FiftyOneDetectionLabels>`

The :func:`~flash.core.integrations.fiftyone.visualize` function then lets you visualize
your predictions in the
:ref:`FiftyOne App <fiftyone:fiftyone-app>`. This function accepts a list of
dictionaries containing :ref:`FiftyOne Label<fiftyone:using-labels>` objects
and filepaths which is the exact output of the FiftyOne serializers when the flag
``return_filepath=True`` is specified.

.. literalinclude:: ../../../flash_examples/integrations/fiftyone/image_classification.py
:language: python
:lines: 14-


***********************
Using FiftyOne datasets
***********************

The above workflow is great for visualizing model predictions. However, if you
store your data in a FiftyOne Dataset initially, then you can also visualize
ground truth annotations. This allows you to perform more complex analysis with
:ref:`views <fiftyone:using-views>` into your data and
:ref:`evaluation <fiftyone:evaluating-models>` of your model results.

The
:meth:`~flash.core.data.data_module.DataModule.from_fiftyone`
method allows you to load your FiftyOne Datasets directly into a
:class:`~flash.core.data.data_module.DataModule` to be used for training,
testing, or inference.

.. literalinclude:: ../../../flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py
:language: python
:lines: 14-


**********************
Visualizing embeddings
**********************

FiftyOne provides the methods for
:ref:`dimensionality reduction<fiftyone:brain-embeddings-visualization>` and
:ref:`interactive plotting<fiftyone:embeddings-plots>`. When combined with
:ref:`embedding tasks <image_embedder>` in Flash, you can accomplish
powerful workflows like clustering, similarity search, pre-annotation, and more
in only a few lines of code.

.. literalinclude:: ../../../flash_examples/integrations/fiftyone/image_embedding.py
:language: python
:lines: 14-

.. image:: https://pl-flash-data.s3.amazonaws.com/assets/fiftyone/embeddings.png
:alt: embeddings_example
:align: center

------

*************
API reference
*************

.. _from_fiftyone:

DataModule.from_fiftyone
------------------------

.. automethod:: flash.core.data.data_module.DataModule.from_fiftyone
:noindex:

.. _fiftyone_labels:

FiftyOneLabels
--------------

.. autoclass:: flash.core.classification.FiftyOneLabels
:members:

.. _fiftyone_segmentation_labels:

FiftyOneSegmentationLabels
--------------------------

.. autoclass:: flash.image.segmentation.serialization.FiftyOneSegmentationLabels
:members:

.. _fiftyone_detection_labels:

FiftyOneDetectionLabels
-----------------------

.. autoclass:: flash.image.detection.serialization.FiftyOneDetectionLabels
:members:


.. _fiftyone_visualize:

visualize
---------

.. autofunction:: flash.core.integrations.fiftyone.visualize
2 changes: 1 addition & 1 deletion docs/source/template/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ Here's how it looks (from `video/classification.data.py <https://github.com/PyTo
.. literalinclude:: ../../../flash/video/classification/data.py
:language: python
:dedent: 4
:pyobject: VideoClassificationPathsDataSource.load_data
:pyobject: BaseVideoClassification.load_data

Preprocess
^^^^^^^^^^
Expand Down
144 changes: 142 additions & 2 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Mapping, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
import torchmetrics
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.data.data_source import LabelsState
from flash.core.data.data_source import DefaultDataKeys, LabelsState
from flash.core.data.process import Serializer
from flash.core.model import Task
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE

if _FIFTYONE_AVAILABLE:
import fiftyone as fo
from fiftyone.core.labels import Classification, Classifications
else:
Classification, Classifications = None, None


def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -80,6 +87,8 @@ class Logits(ClassificationSerializer):
"""A :class:`.Serializer` which simply converts the model outputs (assumed to be logits) to a list."""

def serialize(self, sample: Any) -> Any:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
return sample.tolist()


Expand All @@ -88,6 +97,8 @@ class Probabilities(ClassificationSerializer):
list."""

def serialize(self, sample: Any) -> Any:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
if self.multi_label:
return torch.sigmoid(sample).tolist()
return torch.softmax(sample, -1).tolist()
Expand All @@ -109,6 +120,8 @@ def __init__(self, multi_label: bool = False, threshold: float = 0.5):
self.threshold = threshold

def serialize(self, sample: Any) -> Union[int, List[int]]:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
if self.multi_label:
one_hot = (sample.sigmoid() > self.threshold).int().tolist()
result = []
Expand Down Expand Up @@ -140,6 +153,8 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False
self.set_state(LabelsState(labels))

def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
labels = None

if self._labels is not None:
Expand All @@ -158,3 +173,128 @@ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
else:
rank_zero_warn("No LabelsState was found, this serializer will act as a Classes serializer.", UserWarning)
return classes


class FiftyOneLabels(ClassificationSerializer):
"""A :class:`.Serializer` which converts the model outputs to FiftyOne classification format.
Args:
labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not
provided, will attempt to get them from the :class:`.LabelsState`.
multi_label: If true, treats outputs as multi label logits.
threshold: A threshold to use to filter candidate labels. In the single label case, predictions below this
threshold will be replaced with None
store_logits: Boolean determining whether to store logits in the FiftyOne labels
return_filepath: Boolean determining whether to return a dict
containing filepath and FiftyOne labels (True) or only a
list of FiftyOne labels (False)
"""

def __init__(
self,
labels: Optional[List[str]] = None,
multi_label: bool = False,
threshold: Optional[float] = None,
store_logits: bool = False,
return_filepath: bool = False,
):
if not _FIFTYONE_AVAILABLE:
raise ModuleNotFoundError("Please, run `pip install fiftyone`.")

if multi_label and threshold is None:
threshold = 0.5

super().__init__(multi_label=multi_label)
self._labels = labels
self.threshold = threshold
self.store_logits = store_logits
self.return_filepath = return_filepath

if labels is not None:
self.set_state(LabelsState(labels))

def serialize(
self,
sample: Any,
) -> Union[Classification, Classifications, Dict[str, Any], Dict[str, Any]]:
pred = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
pred = torch.tensor(pred)

labels = None

if self._labels is not None:
labels = self._labels
else:
state = self.get_state(LabelsState)
if state is not None:
labels = state.labels

logits = None
if self.store_logits:
logits = pred.tolist()

if self.multi_label:
one_hot = (pred.sigmoid() > self.threshold).int().tolist()
classes = []
for index, value in enumerate(one_hot):
if value == 1:
classes.append(index)
probabilities = torch.sigmoid(pred).tolist()
else:
classes = torch.argmax(pred, -1).tolist()
probabilities = torch.softmax(pred, -1).tolist()

if labels is not None:
if self.multi_label:
classifications = []
for idx in classes:
fo_cls = Classification(
label=labels[idx],
confidence=probabilities[idx],
)
classifications.append(fo_cls)
fo_predictions = Classifications(
classifications=classifications,
logits=logits,
)
else:
confidence = max(probabilities)
if self.threshold is not None and confidence < self.threshold:
fo_predictions = None
else:
fo_predictions = Classification(
label=labels[classes],
confidence=confidence,
logits=logits,
)
else:
rank_zero_warn("No LabelsState was found, int targets will be used as label strings", UserWarning)

if self.multi_label:
classifications = []
for idx in classes:
fo_cls = Classification(
label=str(idx),
confidence=probabilities[idx],
)
classifications.append(fo_cls)
fo_predictions = Classifications(
classifications=classifications,
logits=logits,
)
else:
confidence = max(probabilities)
if self.threshold is not None and confidence < self.threshold:
fo_predictions = None
else:
fo_predictions = Classification(
label=str(classes),
confidence=confidence,
logits=logits,
)

if self.return_filepath:
filepath = sample[DefaultDataKeys.METADATA]["filepath"]
return {"filepath": filepath, "predictions": fo_predictions}
else:
return fo_predictions
2 changes: 1 addition & 1 deletion flash/core/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def forward(self, samples: Sequence[Any]) -> Any:
with self._collate_context:
samples, metadata = self._extract_metadata(samples)
samples = self.collate_fn(samples)
if metadata:
if metadata and isinstance(samples, dict):
samples[DefaultDataKeys.METADATA] = metadata
self.callback.on_collate(samples, self.stage)

Expand Down
Loading

0 comments on commit 25d6633

Please sign in to comment.