Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Adding support for loading datasets and visualizing model predictions via FiftyOne #360

Merged
merged 64 commits into from
Jun 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
a2caa9e
add fiftyone module availability
ehofesmann May 29, 2021
9e7b38e
add fiftyone datasource
ehofesmann May 29, 2021
2d55c9e
add video classification data source
ehofesmann May 29, 2021
35ef87e
add fiftyone classification serializer
ehofesmann May 29, 2021
2c94a5c
optimizations, rework fo serializer
ehofesmann Jun 2, 2021
c6587ea
support classification, detection, segmentation
ehofesmann Jun 3, 2021
3289ffb
values list, load segmentation dataset in load sample
ehofesmann Jun 3, 2021
3748dbf
FiftyOneLabels test
ehofesmann Jun 3, 2021
7ca3d92
serializer and detection tests
ehofesmann Jun 3, 2021
95690ac
fiftyone classification tests
ehofesmann Jun 3, 2021
5883aae
segmentation and video tests
ehofesmann Jun 3, 2021
ab34980
add detections serializiation test
ehofesmann Jun 3, 2021
036a28b
cleanup
brimoor Jun 3, 2021
1489954
cleanup
brimoor Jun 3, 2021
d4616b1
fix test
ehofesmann Jun 3, 2021
7131cf6
Merge branch 'feature/fiftyone' of github.com:voxel51/lightning-flash…
ehofesmann Jun 3, 2021
51d8046
inherit fiftyonedatasource
ehofesmann Jun 3, 2021
68e1ce3
tweaks
brimoor Jun 3, 2021
6ac7ced
Merge branch 'feature/fiftyone' of https://github.com/voxel51/lightni…
brimoor Jun 3, 2021
78e033a
fix class index
ehofesmann Jun 3, 2021
00cb79d
Merge branch 'feature/fiftyone' of github.com:voxel51/lightning-flash…
ehofesmann Jun 3, 2021
073ce6d
adding helper functions for common operations
brimoor Jun 3, 2021
e57ca24
updating interface
brimoor Jun 3, 2021
4a64b11
always use a Label class
brimoor Jun 4, 2021
42474d8
exposing base class params
brimoor Jun 4, 2021
31229fd
merge
ehofesmann Jun 4, 2021
5b1b09c
indent
ehofesmann Jun 4, 2021
64163b3
revert segmentation optimization
ehofesmann Jun 4, 2021
3763262
revert to mutli
ehofesmann Jun 4, 2021
30f0ec3
linting
brimoor Jun 4, 2021
9b91ea1
adding support for label thresholding
brimoor Jun 4, 2021
09858af
linting
ehofesmann Jun 4, 2021
62d280f
merge
ehofesmann Jun 4, 2021
0561388
Merge branch 'master' into feature/fiftyone
ehofesmann Jun 4, 2021
0ce6ede
update changelog
ehofesmann Jun 4, 2021
5278af2
resolve some issues, clean API
tchaton Jun 4, 2021
f204bb2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 4, 2021
d282d42
normalize detection labels
ehofesmann Jun 7, 2021
aa2de97
normalize detection labels
ehofesmann Jun 7, 2021
d56e5d3
fiftyone serializer return filepaths, paths data sources store filepaths
ehofesmann Jun 9, 2021
55b11c6
formatting
ehofesmann Jun 9, 2021
8e54021
formatting
ehofesmann Jun 9, 2021
1e1022b
remove fiftyone from dir, rename fiftyone_visualize()
ehofesmann Jun 9, 2021
fd94802
merge master
ehofesmann Jun 11, 2021
7c44cb5
update metadata to contain dictionaries
ehofesmann Jun 11, 2021
080d81a
add fiftyone docs
ehofesmann Jun 11, 2021
f34ab3b
update fiftyone examples
ehofesmann Jun 12, 2021
ba21d61
rename from_fiftyone_datasets to from_fiftyone
ehofesmann Jun 14, 2021
9a67c85
fiftyone_visualize to visualize, update docs
ehofesmann Jun 15, 2021
bfd4051
merge master
ehofesmann Jun 15, 2021
49d63bf
resolve comments
ehofesmann Jun 15, 2021
427d280
resolve test issues
ehofesmann Jun 15, 2021
8fe7d76
formatting
ehofesmann Jun 15, 2021
a908b9d
formatting
ehofesmann Jun 15, 2021
895e0bf
yapf formatting
ehofesmann Jun 15, 2021
45dded6
update for current FO version
ehofesmann Jun 15, 2021
60c3947
resolve metadata batch issue
ehofesmann Jun 15, 2021
1525514
Merge remote-tracking branch 'upstream/master' into feature/fiftyone
ehofesmann Jun 15, 2021
5120209
use current FO release, update test requirements
ehofesmann Jun 15, 2021
bec6461
syntax
ehofesmann Jun 15, 2021
6eef60e
update test
ehofesmann Jun 15, 2021
9c9b321
update tests
ehofesmann Jun 15, 2021
6d6aad3
yapf formatting
ehofesmann Jun 15, 2021
691fc7f
one more test...
ehofesmann Jun 15, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added integration with FiftyOne ([#360](https://github.com/PyTorchLightning/lightning-flash/pull/360))
- Added support for `torch.jit` to tasks where possible and documented task JIT compatibility ([#389](https://github.com/PyTorchLightning/lightning-flash/pull/389))
- Added option to provide a `Sampler` to the `DataModule` to use when creating a `DataLoader` ([#390](https://github.com/PyTorchLightning/lightning-flash/pull/390))
- Added support for multi-label text classification and toxic comments example ([#401](https://github.com/PyTorchLightning/lightning-flash/pull/401))
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def _load_py_module(fname, pkg="flash"):
"PIL": ("https://pillow.readthedocs.io/en/stable/", None),
"pytorchvideo": ("https://pytorchvideo.readthedocs.io/en/latest/", None),
"pytorch_lightning": ("https://pytorch-lightning.readthedocs.io/en/stable/", None),
"fiftyone": ("https://voxel51.com/docs/fiftyone/", None),
}

# -- Options for HTML output -------------------------------------------------
Expand Down
7 changes: 7 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ Lightning Flash
code/text
code/video

.. toctree::
:maxdepth: 1
:caption: Integrations

integrations/fiftyone


.. toctree::
:maxdepth: 1
:caption: Contributing a Task
Expand Down
148 changes: 148 additions & 0 deletions docs/source/integrations/fiftyone.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
########
FiftyOne
########

We have collaborated with the team at
`Voxel51 <https://voxel51.com>`_ to integrate their tool,
`FiftyOne <https://fiftyone.ai>`_, into Lightning Flash.

FiftyOne is an open-source tool for building high-quality
datasets and computer vision models. The FiftyOne API and App enable you to
visualize datasets and interpret models faster and more effectively.

This integration allows you to view predictions generated by your tasks in the
:ref:`FiftyOne App <fiftyone:fiftyone-app>`, as well as easily incorporate
:ref:`FiftyOne Datasets <fiftyone:fiftyone-basics>` into your tasks. All image and video tasks
are supported!

.. raw:: html

<div style="margin-top: 20px; margin-bottom: 20px">
<video controls muted poster="https://pl-flash-data.s3.amazonaws.com/assets/fiftyone/fiftyone_poster.png" style="width: 100%;">
<source src="https://pl-flash-data.s3.amazonaws.com/assets/fiftyone/fiftyone_long_sizzle.mp4" type="video/mp4">
</video>
</div>

************
Installation
************

In order to utilize this integration with FiftyOne, you will need to
:ref:`install the tool<fiftyone:installing-fiftyone>`:

.. code:: shell

pip install fiftyone


*****************************
Visualizing Flash predictions
*****************************

This section shows you how to augment your existing Lightning Flash workflows
with a couple of lines of code that let you visualize predictions in FiftyOne.
You can visualize predictions for classification, object detection, and
semantic segmentation tasks. Doing so is as easy as updating your model to use
one of the following serializers:

* :class:`FiftyOneLabels(return_filepath=True)<flash.core.classification.FiftyOneLabels>`
* :class:`FiftyOneSegmentationLabels(return_filepath=True)<flash.image.segmentation.serialization.FiftyOneSegmentationLabels>`
* :class:`FiftyOneDetectionLabels(return_filepath=True)<flash.image.detection.serialization.FiftyOneDetectionLabels>`

The :func:`~flash.core.integrations.fiftyone.visualize` function then lets you visualize
your predictions in the
:ref:`FiftyOne App <fiftyone:fiftyone-app>`. This function accepts a list of
dictionaries containing :ref:`FiftyOne Label<fiftyone:using-labels>` objects
and filepaths which is the exact output of the FiftyOne serializers when the flag
``return_filepath=True`` is specified.

.. literalinclude:: ../../../flash_examples/integrations/fiftyone/image_classification.py
:language: python
:lines: 14-


***********************
Using FiftyOne datasets
***********************

The above workflow is great for visualizing model predictions. However, if you
store your data in a FiftyOne Dataset initially, then you can also visualize
ground truth annotations. This allows you to perform more complex analysis with
:ref:`views <fiftyone:using-views>` into your data and
:ref:`evaluation <fiftyone:evaluating-models>` of your model results.

The
:meth:`~flash.core.data.data_module.DataModule.from_fiftyone`
method allows you to load your FiftyOne Datasets directly into a
:class:`~flash.core.data.data_module.DataModule` to be used for training,
testing, or inference.

.. literalinclude:: ../../../flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py
:language: python
:lines: 14-


**********************
Visualizing embeddings
**********************

FiftyOne provides the methods for
:ref:`dimensionality reduction<fiftyone:brain-embeddings-visualization>` and
:ref:`interactive plotting<fiftyone:embeddings-plots>`. When combined with
:ref:`embedding tasks <image_embedder>` in Flash, you can accomplish
powerful workflows like clustering, similarity search, pre-annotation, and more
in only a few lines of code.

.. literalinclude:: ../../../flash_examples/integrations/fiftyone/image_embedding.py
:language: python
:lines: 14-

.. image:: https://pl-flash-data.s3.amazonaws.com/assets/fiftyone/embeddings.png
:alt: embeddings_example
:align: center

------

*************
API reference
*************

.. _from_fiftyone:

DataModule.from_fiftyone
------------------------

.. automethod:: flash.core.data.data_module.DataModule.from_fiftyone
:noindex:

.. _fiftyone_labels:

FiftyOneLabels
--------------

.. autoclass:: flash.core.classification.FiftyOneLabels
:members:

.. _fiftyone_segmentation_labels:

FiftyOneSegmentationLabels
--------------------------

.. autoclass:: flash.image.segmentation.serialization.FiftyOneSegmentationLabels
:members:

.. _fiftyone_detection_labels:

FiftyOneDetectionLabels
-----------------------

.. autoclass:: flash.image.detection.serialization.FiftyOneDetectionLabels
:members:


.. _fiftyone_visualize:

visualize
---------

.. autofunction:: flash.core.integrations.fiftyone.visualize
2 changes: 1 addition & 1 deletion docs/source/template/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ Here's how it looks (from `video/classification.data.py <https://github.com/PyTo
.. literalinclude:: ../../../flash/video/classification/data.py
:language: python
:dedent: 4
:pyobject: VideoClassificationPathsDataSource.load_data
:pyobject: BaseVideoClassification.load_data

Preprocess
^^^^^^^^^^
Expand Down
144 changes: 142 additions & 2 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Mapping, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
import torchmetrics
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.data.data_source import LabelsState
from flash.core.data.data_source import DefaultDataKeys, LabelsState
from flash.core.data.process import Serializer
from flash.core.model import Task
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE

if _FIFTYONE_AVAILABLE:
import fiftyone as fo
from fiftyone.core.labels import Classification, Classifications
else:
Classification, Classifications = None, None


def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -80,6 +87,8 @@ class Logits(ClassificationSerializer):
"""A :class:`.Serializer` which simply converts the model outputs (assumed to be logits) to a list."""

def serialize(self, sample: Any) -> Any:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
return sample.tolist()


Expand All @@ -88,6 +97,8 @@ class Probabilities(ClassificationSerializer):
list."""

def serialize(self, sample: Any) -> Any:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
if self.multi_label:
return torch.sigmoid(sample).tolist()
return torch.softmax(sample, -1).tolist()
Expand All @@ -109,6 +120,8 @@ def __init__(self, multi_label: bool = False, threshold: float = 0.5):
self.threshold = threshold

def serialize(self, sample: Any) -> Union[int, List[int]]:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
if self.multi_label:
one_hot = (sample.sigmoid() > self.threshold).int().tolist()
result = []
Expand Down Expand Up @@ -140,6 +153,8 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False
self.set_state(LabelsState(labels))

def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
labels = None

if self._labels is not None:
Expand All @@ -158,3 +173,128 @@ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
else:
rank_zero_warn("No LabelsState was found, this serializer will act as a Classes serializer.", UserWarning)
return classes


class FiftyOneLabels(ClassificationSerializer):
"""A :class:`.Serializer` which converts the model outputs to FiftyOne classification format.

Args:
labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not
provided, will attempt to get them from the :class:`.LabelsState`.
multi_label: If true, treats outputs as multi label logits.
threshold: A threshold to use to filter candidate labels. In the single label case, predictions below this
threshold will be replaced with None
store_logits: Boolean determining whether to store logits in the FiftyOne labels
return_filepath: Boolean determining whether to return a dict
containing filepath and FiftyOne labels (True) or only a
list of FiftyOne labels (False)
"""

def __init__(
self,
labels: Optional[List[str]] = None,
multi_label: bool = False,
threshold: Optional[float] = None,
store_logits: bool = False,
return_filepath: bool = False,
):
if not _FIFTYONE_AVAILABLE:
raise ModuleNotFoundError("Please, run `pip install fiftyone`.")

if multi_label and threshold is None:
threshold = 0.5

super().__init__(multi_label=multi_label)
self._labels = labels
self.threshold = threshold
self.store_logits = store_logits
self.return_filepath = return_filepath

if labels is not None:
self.set_state(LabelsState(labels))

def serialize(
self,
sample: Any,
) -> Union[Classification, Classifications, Dict[str, Any], Dict[str, Any]]:
pred = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
pred = torch.tensor(pred)

labels = None

if self._labels is not None:
labels = self._labels
else:
state = self.get_state(LabelsState)
if state is not None:
labels = state.labels

logits = None
if self.store_logits:
logits = pred.tolist()

if self.multi_label:
one_hot = (pred.sigmoid() > self.threshold).int().tolist()
classes = []
for index, value in enumerate(one_hot):
if value == 1:
classes.append(index)
probabilities = torch.sigmoid(pred).tolist()
else:
classes = torch.argmax(pred, -1).tolist()
probabilities = torch.softmax(pred, -1).tolist()

if labels is not None:
if self.multi_label:
classifications = []
for idx in classes:
fo_cls = Classification(
label=labels[idx],
confidence=probabilities[idx],
)
classifications.append(fo_cls)
fo_predictions = Classifications(
classifications=classifications,
logits=logits,
)
else:
confidence = max(probabilities)
if self.threshold is not None and confidence < self.threshold:
fo_predictions = None
else:
fo_predictions = Classification(
label=labels[classes],
confidence=confidence,
logits=logits,
)
else:
rank_zero_warn("No LabelsState was found, int targets will be used as label strings", UserWarning)

if self.multi_label:
classifications = []
for idx in classes:
fo_cls = Classification(
label=str(idx),
confidence=probabilities[idx],
)
classifications.append(fo_cls)
fo_predictions = Classifications(
classifications=classifications,
logits=logits,
)
else:
confidence = max(probabilities)
if self.threshold is not None and confidence < self.threshold:
fo_predictions = None
else:
fo_predictions = Classification(
label=str(classes),
confidence=confidence,
logits=logits,
)

if self.return_filepath:
filepath = sample[DefaultDataKeys.METADATA]["filepath"]
return {"filepath": filepath, "predictions": fo_predictions}
else:
return fo_predictions
2 changes: 1 addition & 1 deletion flash/core/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def forward(self, samples: Sequence[Any]) -> Any:
with self._collate_context:
samples, metadata = self._extract_metadata(samples)
samples = self.collate_fn(samples)
if metadata:
if metadata and isinstance(samples, dict):
samples[DefaultDataKeys.METADATA] = metadata
self.callback.on_collate(samples, self.stage)

Expand Down
Loading