Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Feature/fiftyone 2 #365

Closed
wants to merge 37 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
a2caa9e
add fiftyone module availability
ehofesmann May 29, 2021
9e7b38e
add fiftyone datasource
ehofesmann May 29, 2021
2d55c9e
add video classification data source
ehofesmann May 29, 2021
35ef87e
add fiftyone classification serializer
ehofesmann May 29, 2021
2c94a5c
optimizations, rework fo serializer
ehofesmann Jun 2, 2021
c6587ea
support classification, detection, segmentation
ehofesmann Jun 3, 2021
3289ffb
values list, load segmentation dataset in load sample
ehofesmann Jun 3, 2021
3748dbf
FiftyOneLabels test
ehofesmann Jun 3, 2021
7ca3d92
serializer and detection tests
ehofesmann Jun 3, 2021
95690ac
fiftyone classification tests
ehofesmann Jun 3, 2021
5883aae
segmentation and video tests
ehofesmann Jun 3, 2021
ab34980
add detections serializiation test
ehofesmann Jun 3, 2021
036a28b
cleanup
brimoor Jun 3, 2021
1489954
cleanup
brimoor Jun 3, 2021
d4616b1
fix test
ehofesmann Jun 3, 2021
7131cf6
Merge branch 'feature/fiftyone' of github.com:voxel51/lightning-flash…
ehofesmann Jun 3, 2021
51d8046
inherit fiftyonedatasource
ehofesmann Jun 3, 2021
68e1ce3
tweaks
brimoor Jun 3, 2021
6ac7ced
Merge branch 'feature/fiftyone' of https://github.com/voxel51/lightni…
brimoor Jun 3, 2021
78e033a
fix class index
ehofesmann Jun 3, 2021
00cb79d
Merge branch 'feature/fiftyone' of github.com:voxel51/lightning-flash…
ehofesmann Jun 3, 2021
073ce6d
adding helper functions for common operations
brimoor Jun 3, 2021
e57ca24
updating interface
brimoor Jun 3, 2021
4a64b11
always use a Label class
brimoor Jun 4, 2021
42474d8
exposing base class params
brimoor Jun 4, 2021
31229fd
merge
ehofesmann Jun 4, 2021
5b1b09c
indent
ehofesmann Jun 4, 2021
64163b3
revert segmentation optimization
ehofesmann Jun 4, 2021
3763262
revert to mutli
ehofesmann Jun 4, 2021
30f0ec3
linting
brimoor Jun 4, 2021
9b91ea1
adding support for label thresholding
brimoor Jun 4, 2021
09858af
linting
ehofesmann Jun 4, 2021
62d280f
merge
ehofesmann Jun 4, 2021
0561388
Merge branch 'master' into feature/fiftyone
ehofesmann Jun 4, 2021
0ce6ede
update changelog
ehofesmann Jun 4, 2021
5278af2
resolve some issues, clean API
tchaton Jun 4, 2021
f204bb2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 4, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.3.1] - YYYY-MM-DD

### Added

- Added integration with FiftyOne ([#360](https://github.com/PyTorchLightning/lightning-flash/pull/360))

### Fixed

- Fixed `flash.Trainer.add_argparse_args` not adding any arguments ([#343](https://github.com/PyTorchLightning/lightning-flash/pull/343))
Expand Down
132 changes: 130 additions & 2 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Mapping, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union

import torch
import torch.nn.functional as F
import torchmetrics
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.data.data_source import LabelsState
from flash.core.data.data_source import DefaultDataKeys, LabelsState
from flash.core.data.process import Serializer
from flash.core.model import Task
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE

if _FIFTYONE_AVAILABLE:
import fiftyone as fo
from fiftyone.core.labels import Classification, Classifications
else:
Classification, Classifications = None, None


def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -80,6 +87,8 @@ class Logits(ClassificationSerializer):
"""A :class:`.Serializer` which simply converts the model outputs (assumed to be logits) to a list."""

def serialize(self, sample: Any) -> Any:
    """Return the raw model outputs (assumed to be logits) as a plain Python list.

    Args:
        sample: Either the model output itself (tensor or nested sequence of
            numbers) or a dict holding the output under
            ``DefaultDataKeys.PREDS``.

    Returns:
        The predictions converted with ``Tensor.tolist()``.
    """
    preds = sample[DefaultDataKeys.PREDS] if isinstance(sample, dict) else sample
    # ``torch.tensor`` on an existing tensor makes a copy and emits a
    # "To copy construct from a tensor..." UserWarning on every call.
    # ``as_tensor`` reuses tensors and still converts lists; ``detach`` keeps
    # the previous guarantee that the result is cut off from autograd.
    return torch.as_tensor(preds).detach().tolist()


Expand All @@ -88,6 +97,8 @@ class Probabilities(ClassificationSerializer):
list."""

def serialize(self, sample: Any) -> Any:
    """Convert the model outputs (assumed to be logits) to probabilities.

    Args:
        sample: Either the model output itself (tensor or nested sequence of
            numbers) or a dict holding the output under
            ``DefaultDataKeys.PREDS``.

    Returns:
        A list of probabilities: element-wise sigmoid when ``self.multi_label``
        is true, otherwise softmax over the last dimension.
    """
    preds = sample[DefaultDataKeys.PREDS] if isinstance(sample, dict) else sample
    # ``torch.tensor`` on an existing tensor makes a copy and emits a
    # "To copy construct from a tensor..." UserWarning on every call.
    # ``as_tensor`` reuses tensors and still converts lists; ``detach`` keeps
    # the previous guarantee that the result is cut off from autograd.
    preds = torch.as_tensor(preds).detach()
    if self.multi_label:
        return torch.sigmoid(preds).tolist()
    return torch.softmax(preds, -1).tolist()
Expand All @@ -109,6 +120,8 @@ def __init__(self, multi_label: bool = False, threshold: float = 0.5):
self.threshold = threshold

def serialize(self, sample: Any) -> Union[int, List[int]]:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
if self.multi_label:
one_hot = (sample.sigmoid() > self.threshold).int().tolist()
result = []
Expand Down Expand Up @@ -140,6 +153,8 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False
self.set_state(LabelsState(labels))

def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
sample = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
sample = torch.tensor(sample)
labels = None

if self._labels is not None:
Expand All @@ -158,3 +173,116 @@ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
else:
rank_zero_warn("No LabelsState was found, this serializer will act as a Classes serializer.", UserWarning)
return classes


class FiftyOneLabels(ClassificationSerializer):
    """A :class:`.Serializer` which converts the model outputs to FiftyOne classification format.

    Args:
        labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not
            provided, will attempt to get them from the :class:`.LabelsState`.
        multi_label: If true, treats outputs as multi label logits.
        threshold: A threshold to use to filter candidate labels. In the single label case, predictions below this
            threshold will be replaced with None. In the multi label case, only classes whose sigmoid score exceeds
            the threshold are kept; it defaults to ``0.5`` there when not given.
        store_logits: Boolean determining whether to store logits in the FiftyOne labels.

    Raises:
        ModuleNotFoundError: If FiftyOne is not installed.
    """

    def __init__(
        self,
        labels: Optional[List[str]] = None,
        multi_label: bool = False,
        threshold: Optional[float] = None,
        store_logits: bool = False,
    ):
        if not _FIFTYONE_AVAILABLE:
            raise ModuleNotFoundError("Please, run `pip install fiftyone`.")

        # Multi-label selection always needs a cut-off; single-label only
        # filters when the caller explicitly asks for one.
        if multi_label and threshold is None:
            threshold = 0.5

        super().__init__(multi_label=multi_label)
        self._labels = labels
        self.threshold = threshold
        self.store_logits = store_logits

        if labels is not None:
            self.set_state(LabelsState(labels))

    @staticmethod
    def _label_for(labels: Optional[List[str]], index: int) -> str:
        """Return the label string for ``index``, falling back to the index itself when no labels are known."""
        return labels[index] if labels is not None else str(index)

    def serialize(self, sample: Any) -> Any:
        """Build a FiftyOne sample holding the prediction for ``sample``.

        Args:
            sample: A dict with the model output under ``DefaultDataKeys.PREDS`` and, under
                ``DefaultDataKeys.METADATA``, an object exposing a ``filepath`` attribute.

        Returns:
            A ``fo.Sample`` whose ``predictions`` field is a ``Classifications`` (multi label), a ``Classification``
            (single label), or ``None`` when the single label confidence is below ``threshold``.

        Raises:
            TypeError: If ``sample`` is not a dict, since the file path needed to create the FiftyOne sample is
                only available through the sample metadata.
        """
        if not isinstance(sample, dict):
            # Previously a bare prediction slipped past the PREDS lookup and then
            # failed with an opaque indexing error on the METADATA lookup.
            raise TypeError(
                "FiftyOneLabels.serialize expects a dict sample containing the predictions and the metadata "
                f"(the filepath is required to create the FiftyOne sample), but got {type(sample).__name__}."
            )

        # ``as_tensor`` avoids the copy and the UserWarning that ``torch.tensor``
        # raises for tensor inputs; ``detach`` keeps the result out of autograd.
        pred = torch.as_tensor(sample[DefaultDataKeys.PREDS]).detach()
        metadata = sample[DefaultDataKeys.METADATA]

        # Explicit labels win; otherwise fall back to the shared LabelsState.
        labels = self._labels
        if labels is None:
            state = self.get_state(LabelsState)
            if state is not None:
                labels = state.labels
        if labels is None:
            rank_zero_warn("No LabelsState was found, int targets will be used as label strings", UserWarning)

        logits = pred.tolist() if self.store_logits else None

        if self.multi_label:
            scores = torch.sigmoid(pred)
            # Compare on the tensor so thresholding happens in the tensor's dtype.
            selected = (scores > self.threshold).tolist()
            probabilities = scores.tolist()
            classifications = [
                Classification(
                    label=self._label_for(labels, idx),
                    confidence=probabilities[idx],
                ) for idx, keep in enumerate(selected) if keep
            ]
            fo_predictions = Classifications(
                classifications=classifications,
                logits=logits,
            )
        else:
            class_index = torch.argmax(pred, -1).tolist()
            confidence = max(torch.softmax(pred, -1).tolist())
            if self.threshold is not None and confidence < self.threshold:
                fo_predictions = None
            else:
                fo_predictions = Classification(
                    label=self._label_for(labels, class_index),
                    confidence=confidence,
                    logits=logits,
                )

        return fo.Sample(filepath=metadata.filepath, predictions=fo_predictions)
Loading