Skip to content

Commit

Permalink
adding support for custom evaluation metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
brimoor committed Dec 21, 2024
1 parent 0bd6b39 commit 4fa72a0
Show file tree
Hide file tree
Showing 8 changed files with 346 additions and 37 deletions.
13 changes: 12 additions & 1 deletion fiftyone/utils/eval/activitynet.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class ActivityNetEvaluationConfig(DetectionEvaluationConfig):
that mAP and PR curves can be generated
iou_threshs (None): a list of IoU thresholds to use when computing mAP
and PR curves. Only applicable when ``compute_mAP`` is True
custom_metrics (None): an optional list of custom metrics to compute
or dict mapping metric names to kwargs dicts
"""

def __init__(
Expand All @@ -50,10 +52,16 @@ def __init__(
classwise=None,
compute_mAP=False,
iou_threshs=None,
custom_metrics=None,
**kwargs,
):
super().__init__(
pred_field, gt_field, iou=iou, classwise=classwise, **kwargs
pred_field,
gt_field,
iou=iou,
classwise=classwise,
custom_metrics=custom_metrics,
**kwargs,
)

if compute_mAP and iou_threshs is None:
Expand Down Expand Up @@ -323,6 +331,7 @@ class ActivityNetDetectionResults(DetectionResults):
``num_iou_threshs x num_classes x num_recall``
missing (None): a missing label string. Any unmatched segments are
given this label for evaluation purposes
custom_metrics (None): an optional dict of custom metrics
backend (None): a :class:`ActivityNetEvaluation` backend
"""

Expand All @@ -339,6 +348,7 @@ def __init__(
classes,
thresholds=None,
missing=None,
custom_metrics=None,
backend=None,
):
super().__init__(
Expand All @@ -348,6 +358,7 @@ def __init__(
matches,
classes=classes,
missing=missing,
custom_metrics=custom_metrics,
backend=backend,
)

Expand Down
112 changes: 110 additions & 2 deletions fiftyone/utils/eval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,124 @@
|
"""
import itertools
import logging

import numpy as np
import sklearn.metrics as skm

import fiftyone.core.evaluation as foe
import fiftyone.core.plots as fop
import fiftyone.core.utils as fou

foo = fou.lazy_import("fiftyone.operators")


logger = logging.getLogger(__name__)


class BaseEvaluationMethodConfig(foe.EvaluationMethodConfig):
"""Base class for configuring evaluation methods.
Args:
**kwargs: any leftover keyword arguments after subclasses have done
their parsing
"""

pass


class BaseEvaluationMethod(foe.EvaluationMethod):
"""Base class for evaluation methods.
Args:
config: an :class:`BaseEvaluationMethodConfig`
"""

def _get_custom_metrics(self):
if not self.config.custom_metrics:
return {}

if isinstance(self.config.custom_metrics, list):
return {m: None for m in self.config.custom_metrics}

return self.config.custom_metrics

def compute_custom_metrics(self, samples, eval_key, results):
results.custom_metrics = {}

for metric, kwargs in self._get_custom_metrics().items():
try:
operator = foo.get_operator(metric)
value = operator.compute(
samples, eval_key, results, **kwargs or {}
)
if value is not None:
results.custom_metrics[operator.config.label] = value
except Exception as e:
logger.warning(
"Failed to compute metric '%s': Reason: %s",
operator.uri,
e,
)

def get_custom_metric_fields(self, samples, eval_key):
fields = []

for metric in self._get_custom_metrics().keys():
try:
operator = foo.get_operator(metric)
fields.extend(operator.get_fields(samples, eval_key))
except Exception as e:
logger.warning(
"Failed to get fields for metric '%s': Reason: %s",
operator.uri,
e,
)

return fields

def rename_custom_metrics(self, samples, eval_key, new_eval_key):
for metric in self._get_custom_metrics().keys():
try:
operator = foo.get_operator(metric)
operator.rename(samples, eval_key, new_eval_key)
except Exception as e:
logger.warning(
"Failed to rename fields for metric '%s': Reason: %s",
operator.uri,
e,
)

def cleanup_custom_metrics(self, samples, eval_key):
for metric in self._get_custom_metrics().keys():
try:
operator = foo.get_operator(metric)
operator.cleanup(samples, eval_key)
except Exception as e:
logger.warning(
"Failed to cleanup metric '%s': Reason: %s",
operator.uri,
e,
)


class BaseEvaluationResults(foe.EvaluationResults):
"""Base class for evaluation results.
Args:
samples: the :class:`fiftyone.core.collections.SampleCollection` used
config: the :class:`BaseEvaluationMethodConfig` used
eval_key: the evaluation key
backend (None): an :class:`EvaluationMethod` backend
"""

pass


class BaseClassificationResults(BaseEvaluationResults):
"""Base class for evaluation results that expose classification metrics
like P/R/F1 and confusion matrices.
Args:
samples: the :class:`fiftyone.core.collections.SampleCollection` used
config: the :class:`fiftyone.core.evaluation.EvaluationMethodConfig`
Expand All @@ -32,8 +139,7 @@ class BaseEvaluationResults(foe.EvaluationResults):
observed ground truth/predicted labels are used
missing (None): a missing label string. Any None-valued labels are
given this label for evaluation purposes
samples (None): the :class:`fiftyone.core.collections.SampleCollection`
for which the results were computed
custom_metrics (None): an optional dict of custom metrics
backend (None): a :class:`fiftyone.core.evaluation.EvaluationMethod`
backend
"""
Expand All @@ -51,6 +157,7 @@ def __init__(
ypred_ids=None,
classes=None,
missing=None,
custom_metrics=None,
backend=None,
):
super().__init__(samples, config, eval_key, backend=backend)
Expand All @@ -72,6 +179,7 @@ def __init__(
)
self.classes = np.asarray(classes)
self.missing = missing
self.custom_metrics = custom_metrics

def report(self, classes=None):
"""Generates a classification report for the results via
Expand Down
Loading

0 comments on commit 4fa72a0

Please sign in to comment.