Minor FiftyOne integration tweaks (#418)
* add fiftyone module availability

* add fiftyone datasource

* add video classification data source

* add fiftyone classification serializer

* optimizations, rework fo serializer

* support classification, detection, segmentation

* values list, load segmentation dataset in load sample

* FiftyOneLabels test

* serializer and detection tests

* fiftyone classification tests

* segmentation and video tests

* add detections serialization test

* cleanup

* fix test

* inherit fiftyonedatasource

* tweaks

* fix class index

* adding helper functions for common operations

* updating interface

* always use a Label class

* exposing base class params

* indent

* revert segmentation optimization

* revert to multi

* linting

* adding support for label thresholding

* linting

* update changelog

* resolve some issues, clean API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* normalize detection labels

* normalize detection labels

* fiftyone serializer return filepaths, paths data sources store filepaths

* formatting

* formatting

* remove fiftyone from dir, rename fiftyone_visualize()

* update metadata to contain dictionaries

* add fiftyone docs

* update fiftyone examples

* rename from_fiftyone_datasets to from_fiftyone

* fiftyone_visualize to visualize, update docs

* resolve comments

* resolve test issues

* formatting

* formatting

* yapf formatting

* update for current FO version

* resolve metadata batch issue

* use current FO release, update test requirements

* syntax

* update test

* update tests

* yapf formatting

* one more test...

* update readme

* update readme

* update readme

* improving visualize() method docs

* making docstrings sphinx-friendly

* docs tweaks

* fiftyone integration docs update

* simpler

* using FiftyOne page title

* update embedder example

* isort

* yapf formatting

* fix docstring

Co-authored-by: Ethan Harris <[email protected]>

* update docstrings

Co-authored-by: Ethan Harris <[email protected]>

Co-authored-by: brimoor <[email protected]>
Co-authored-by: tchaton <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <[email protected]>
5 people authored Jun 16, 2021
1 parent 146e05d commit 113efab
Showing 8 changed files with 134 additions and 72 deletions.
35 changes: 35 additions & 0 deletions README.md
@@ -557,6 +557,41 @@ classifier = LinearClassifier(128, 10)
When you reach the limits of the flexibility provided by Flash, then seamlessly transition to PyTorch Lightning which
gives you the most flexibility because it is simply organized PyTorch.

## Visualization

Predictions from image and video tasks can be visualized through an [integration with FiftyOne](https://lightning-flash.readthedocs.io/en/latest/integrations/fiftyone.html), allowing you to better understand and analyze how your model is performing.

```python
from flash.core.data.utils import download_data
from flash.core.integrations.fiftyone import visualize
from flash.image import ObjectDetector
from flash.image.detection.serialization import FiftyOneDetectionLabels

# 1. Download the data
# Dataset Credit: https://www.kaggle.com/ultralytics/coco128
download_data(
    "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip",
    "data/",
)

# 2. Load the model from a checkpoint and use the FiftyOne serializer
model = ObjectDetector.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/object_detection_model.pt"
)
model.serializer = FiftyOneDetectionLabels()

# 3. Detect objects in the images
filepaths = [
    "data/coco128/images/train2017/000000000025.jpg",
    "data/coco128/images/train2017/000000000520.jpg",
    "data/coco128/images/train2017/000000000532.jpg",
]
predictions = model.predict(filepaths)

# 4. Visualize predictions in FiftyOne App
session = visualize(predictions, filepaths=filepaths)
```

## Contribute!
The Lightning + Flash team is hard at work building more tasks for common deep-learning use cases. But we're looking for incredible contributors like you to submit new tasks!

1 change: 0 additions & 1 deletion docs/source/index.rst
@@ -70,7 +70,6 @@ Lightning Flash

integrations/fiftyone


.. toctree::
:maxdepth: 1
:caption: Contributing a Task
31 changes: 23 additions & 8 deletions docs/source/integrations/fiftyone.rst
@@ -27,14 +27,13 @@ are supported!
Installation
************

In order to utilize this integration with FiftyOne, you will need to
:ref:`install the tool<fiftyone:installing-fiftyone>`:
In order to utilize this integration, you will need to
:ref:`install FiftyOne <fiftyone:installing-fiftyone>`:

.. code:: shell

    pip install fiftyone
*****************************
Visualizing Flash predictions
*****************************
@@ -52,14 +51,30 @@ one of the following serializers:
The :func:`~flash.core.integrations.fiftyone.visualize` function then lets you visualize
your predictions in the
:ref:`FiftyOne App <fiftyone:fiftyone-app>`. This function accepts a list of
dictionaries containing :ref:`FiftyOne Label<fiftyone:using-labels>` objects
and filepaths which is the exact output of the FiftyOne serializers when the flag
``return_filepath=True`` is specified.
dictionaries containing :ref:`FiftyOne Label <fiftyone:using-labels>` objects
and filepaths, which is exactly the output of the FiftyOne serializers when the
``return_filepath=True`` option is specified.

.. literalinclude:: ../../../flash_examples/integrations/fiftyone/image_classification.py
:language: python
:lines: 14-

The :func:`~flash.core.integrations.fiftyone.visualize` function can be used in
all of the following environments:

- **Local Python shell**: The App will launch in a new tab in your default
web browser
- **Remote Python shell**: Pass the ``remote=True`` option and then follow
the instructions printed to your remote shell to open the App in your
browser on your local machine
- **Jupyter notebook**: The App will launch in the output of your current
cell
- **Google Colab**: The App will launch in the output of your current cell
- **Python script**: Pass the ``wait=True`` option to block execution of your
script until the App is closed

See :ref:`this page <fiftyone:environments>` for more information about
using the FiftyOne App in different environments.
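A minimal sketch of this flow for image classification (the checkpoint URL is the one used elsewhere in this commit; the image paths are placeholders, and the `FiftyOneLabels` import path is an assumption, not taken from this diff):

```python
# Sketch: serialize classification predictions to FiftyOne format and
# visualize them in the FiftyOne App. Image paths are placeholders.
from flash.core.classification import FiftyOneLabels  # import path assumed
from flash.core.integrations.fiftyone import visualize
from flash.image import ImageClassifier

model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/image_classification_model.pt"
)
model.serializer = FiftyOneLabels(return_filepath=True)  # output FiftyOne format

# Each prediction is a dict with "filepath" and "predictions" keys because
# return_filepath=True was set above
predictions = model.predict([
    "path/to/image1.jpg",
    "path/to/image2.jpg",
])

# Launch the FiftyOne App; pass wait=True in a script to block until closed
session = visualize(predictions, wait=True)
```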

***********************
Using FiftyOne datasets
@@ -73,15 +88,14 @@ ground truth annotations. This allows you to perform more complex analysis with

The
:meth:`~flash.core.data.data_module.DataModule.from_fiftyone`
method allows you to load your FiftyOne Datasets directly into a
method allows you to load your FiftyOne datasets directly into a
:class:`~flash.core.data.data_module.DataModule` to be used for training,
testing, or inference.

.. literalinclude:: ../../../flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py
:language: python
:lines: 14-
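A short sketch of the pattern used in the referenced example (directory paths are placeholders):

```python
# Sketch: build FiftyOne datasets from classification directory trees and
# load them directly into a Flash DataModule. Paths are placeholders.
import fiftyone as fo

from flash.image import ImageClassificationData

train_dataset = fo.Dataset.from_dir(
    "data/train",
    dataset_type=fo.types.ImageClassificationDirectoryTree,
)
val_dataset = fo.Dataset.from_dir(
    "data/val",
    dataset_type=fo.types.ImageClassificationDirectoryTree,
)

# Ground truth labels from the FiftyOne datasets are used for training
datamodule = ImageClassificationData.from_fiftyone(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
)
```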


**********************
Visualizing embeddings
**********************
@@ -130,6 +144,7 @@ FiftyOneSegmentationLabels

.. autoclass:: flash.image.segmentation.serialization.FiftyOneSegmentationLabels
:members:
:noindex:

.. _fiftyone_detection_labels:

67 changes: 44 additions & 23 deletions flash/core/integrations/fiftyone/utils.py
@@ -20,51 +20,72 @@


def visualize(
labels: Union[List[Label], List[Dict[str, Label]]],
predictions: Union[List[Label], List[Dict[str, Label]]],
filepaths: Optional[List[str]] = None,
wait: Optional[bool] = True,
label_field: Optional[str] = "predictions",
wait: Optional[bool] = False,
**kwargs
) -> Optional[Session]:
"""Use the result of a FiftyOne serializer to visualize predictions in the
FiftyOne App.
"""Visualizes predictions from a model with a FiftyOne Serializer in the
:ref:`FiftyOne App <fiftyone:fiftyone-app>`.
This method can be used in all of the following environments:
- **Local Python shell**: The App will launch in a new tab in your
default web browser.
- **Remote Python shell**: Pass the ``remote=True`` option to this method
and then follow the instructions printed to your remote shell to open
the App in your browser on your local machine.
- **Jupyter notebook**: The App will launch in the output of your current
cell.
- **Google Colab**: The App will launch in the output of your current
cell.
- **Python script**: Pass the ``wait=True`` option to block execution of
your script until the App is closed.
See :ref:`this page <fiftyone:environments>` for more information about
using the FiftyOne App in different environments.
Args:
labels: Either a list of FiftyOne labels that will be applied to the
corresponding filepaths provided with through `filepath` or
`datamodule`. Or a list of dictionaries containing image/video
filepaths and corresponding FiftyOne labels.
predictions: Can be either a list of FiftyOne labels that will be
matched with the corresponding ``filepaths``, or a list of
dictionaries with "filepath" and "predictions" keys that contains
the filepaths and predictions.
filepaths: A list of filepaths to images or videos corresponding to the
provided `labels`.
wait: A boolean determining whether to launch the FiftyOne session and
wait until the session is closed or whether to return immediately.
label_field: The string of the label field in the FiftyOne dataset
containing predictions
provided ``predictions``.
label_field: The name of the label field in which to store the
predictions in the FiftyOne dataset.
wait: Whether to block execution until the FiftyOne App is closed.
**kwargs: Optional keyword arguments for
:meth:`fiftyone:fiftyone.core.session.launch_app`.
Returns:
a :class:`fiftyone:fiftyone.core.session.Session`
"""
if not _FIFTYONE_AVAILABLE:
raise ModuleNotFoundError("Please, `pip install fiftyone`.")
if flash._IS_TESTING:
return None

# Flatten list if batches were used
if all(isinstance(fl, list) for fl in labels):
labels = list(chain.from_iterable(labels))
if all(isinstance(fl, list) for fl in predictions):
predictions = list(chain.from_iterable(predictions))

if all(isinstance(fl, dict) for fl in labels):
filepaths = [lab["filepath"] for lab in labels]
labels = [lab["predictions"] for lab in labels]
if all(isinstance(fl, dict) for fl in predictions):
filepaths = [lab["filepath"] for lab in predictions]
labels = [lab["predictions"] for lab in predictions]
else:
labels = predictions

if filepaths is None:
raise ValueError("The `filepaths` argument is required if filepaths are not provided in `labels`.")

dataset = fo.Dataset()
if filepaths:
dataset.add_labeled_images(
list(zip(filepaths, labels)),
LabeledImageTupleSampleParser(),
label_field=label_field,
)
dataset.add_samples([fo.Sample(filepath=f, **{label_field: l}) for f, l in zip(filepaths, labels)])

session = fo.launch_app(dataset, **kwargs)
if wait:
session.wait()

return session
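For reference, the two input forms accepted by the reworked `visualize()` function look roughly like this (a sketch with placeholder filepaths and hand-built FiftyOne labels):

```python
# Sketch of the two call forms accepted by visualize(). Filepaths are
# placeholders; the labels are hand-built FiftyOne Classification objects.
import fiftyone as fo

from flash.core.integrations.fiftyone import visualize

filepaths = ["path/to/image1.jpg", "path/to/image2.jpg"]
labels = [fo.Classification(label="cat"), fo.Classification(label="dog")]

# (a) A list of FiftyOne Label objects plus explicit filepaths
session = visualize(labels, filepaths=filepaths, label_field="predictions")

# (b) A list of {"filepath": ..., "predictions": ...} dicts, which is what the
# FiftyOne serializers return when constructed with return_filepath=True
predictions = [{"filepath": f, "predictions": l} for f, l in zip(filepaths, labels)]
session = visualize(predictions, wait=True)
```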
34 changes: 18 additions & 16 deletions flash/image/segmentation/serialization.py
@@ -39,15 +39,16 @@


class SegmentationLabels(Serializer):
"""A :class:`.Serializer` which converts the model outputs to the label of
the argmax classification per pixel in the image for semantic segmentation
tasks.
def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualize: bool = False):
"""A :class:`.Serializer` which converts the model outputs to the label of the argmax classification
per pixel in the image for semantic segmentation tasks.
Args:
labels_map: A dictionary that maps the label ids to pixel intensities.
visualize: Whether to visualize the image labels.
"""

Args:
labels_map: A dictionary that maps the label ids to pixel intensities.
visualize: Whether to visualize the image labels.
"""
def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualize: bool = False):
super().__init__()
self.labels_map = labels_map
self.visualize = visualize
@@ -89,22 +90,23 @@ def serialize(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor:


class FiftyOneSegmentationLabels(SegmentationLabels):
"""A :class:`.Serializer` which converts the model outputs to FiftyOne
segmentation format.
Args:
labels_map: A dictionary that maps the label ids to pixel intensities.
visualize: whether to visualize the image labels.
return_filepath: Boolean determining whether to return a dict
containing filepath and FiftyOne labels (True) or only a list of
FiftyOne labels (False).
"""

def __init__(
self,
labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None,
visualize: bool = False,
return_filepath: bool = False,
):
"""A :class:`.Serializer` which converts the model outputs to FiftyOne segmentation format.
Args:
labels_map: A dictionary that maps the label ids to pixel intensities.
visualize: Whether to visualize the image labels.
return_filepath: Boolean determining whether to return a dict
containing filepath and FiftyOne labels (True) or only a
list of FiftyOne labels (False)
"""
if not _FIFTYONE_AVAILABLE:
raise ModuleNotFoundError("Please, run `pip install fiftyone`.")

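A hypothetical usage sketch for the segmentation serializer above (the checkpoint path and image paths are placeholders, and the `SemanticSegmentation` task is assumed; none of this is part of the diff):

```python
# Sketch: attach the FiftyOne segmentation serializer to a Flash semantic
# segmentation model and visualize its predictions. Paths are placeholders.
from flash.core.integrations.fiftyone import visualize
from flash.image import SemanticSegmentation
from flash.image.segmentation.serialization import FiftyOneSegmentationLabels

model = SemanticSegmentation.load_from_checkpoint("path/to/semantic_segmentation_model.pt")
model.serializer = FiftyOneSegmentationLabels(return_filepath=True)  # FiftyOne format

predictions = model.predict([
    "path/to/image1.png",
    "path/to/image2.png",
])

# Launch the FiftyOne App with the predicted segmentation masks
session = visualize(predictions, wait=True)
```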
7 changes: 3 additions & 4 deletions flash_examples/integrations/fiftyone/image_classification.py
@@ -51,11 +51,10 @@

# 4 Predict from checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
model.serializer = FiftyOneLabels(return_filepath=True)
model.serializer = FiftyOneLabels(return_filepath=True) # output FiftyOne format
predictions = trainer.predict(model, datamodule=datamodule)

predictions = list(chain.from_iterable(predictions)) # flatten batches

# 5. Visualize predictions in FiftyOne
# Note: this blocks until the FiftyOne App is closed
# 5 Visualize predictions in FiftyOne App
# Optional: pass `wait=True` to block execution until App is closed
session = visualize(predictions)
flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py
@@ -39,14 +39,14 @@
dataset_type=fo.types.ImageClassificationDirectoryTree,
)

# 3 Load FiftyOne datasets
# 3 Load data into Flash
datamodule = ImageClassificationData.from_fiftyone(
train_dataset=train_dataset,
val_dataset=val_dataset,
test_dataset=test_dataset,
)

# 4 Fine tune a model
# 4 Fine tune model
model = ImageClassifier(
backbone="resnet18",
num_classes=datamodule.num_classes,
@@ -66,28 +66,22 @@

# 5 Predict from checkpoint on data with ground truth
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
model.serializer = FiftyOneLabels(return_filepath=False)
model.serializer = FiftyOneLabels(return_filepath=False) # output FiftyOne format
datamodule = ImageClassificationData.from_fiftyone(predict_dataset=test_dataset)
predictions = trainer.predict(model, datamodule=datamodule)

predictions = list(chain.from_iterable(predictions)) # flatten batches

# 6 Add predictions to dataset
test_dataset.set_values("predictions", predictions)

# 7 Visualize labels in the App
session = fo.launch_app(test_dataset)

# 8 Evaluate your model
results = test_dataset.evaluate_classifications(
"predictions",
gt_field="ground_truth",
eval_key="eval",
)
# 7 Evaluate your model
results = test_dataset.evaluate_classifications("predictions", gt_field="ground_truth", eval_key="eval")
results.print_report()
plot = results.plot_confusion_matrix()
plot.show()

# Only when running this in a script
# Block until the FiftyOne App is closed
# 8 Visualize results in the App
session = fo.launch_app(test_dataset)

# Optional: block execution until App is closed
session.wait()
7 changes: 2 additions & 5 deletions flash_examples/integrations/fiftyone/image_embedding.py
@@ -29,20 +29,17 @@
)

# 3 Load model
embedder = ImageEmbedder(backbone="swav-imagenet", embedding_dim=128)
embedder = ImageEmbedder(backbone="resnet101", embedding_dim=128)

# 4 Generate embeddings
filepaths = dataset.values("filepath")
embeddings = np.stack(embedder.predict(filepaths))

# 5 Visualize in FiftyOne App
results = fob.compute_visualization(dataset, embeddings=embeddings)

session = fo.launch_app(dataset)

plot = results.visualize(labels="ground_truth.label")
plot.show()

# Only when running this in a script
# Block until the FiftyOne App is closed
# Optional: block execution until App is closed
session.wait()
