From ed789cb062685ca67480ab20346263e221f6a057 Mon Sep 17 00:00:00 2001
From: Manushree Gangwar
Date: Wed, 4 Dec 2024 14:01:56 -0700
Subject: [PATCH] Evaluate regression with metric operators

---
 fiftyone/utils/eval/regression.py | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 51 insertions(+), 7 deletions(-)

diff --git a/fiftyone/utils/eval/regression.py b/fiftyone/utils/eval/regression.py
index 09663ab6af5..8e33090057c 100644
--- a/fiftyone/utils/eval/regression.py
+++ b/fiftyone/utils/eval/regression.py
@@ -34,6 +34,7 @@ def evaluate_regressions(
     pred_field,
     gt_field="ground_truth",
     eval_key=None,
+    eval_metrics=None,
     missing=None,
     method=None,
     progress=None,
@@ -69,6 +70,8 @@ def evaluate_regressions(
         gt_field ("ground_truth"): the name of the field containing the
             ground truth :class:`fiftyone.core.labels.Regression` instances
         eval_key (None): a string key to use to refer to this evaluation
+        eval_metrics (None): an optional list of ``(operator, kwargs)`` tuples
+            specifying metric operators to apply and their keyword arguments
         missing (None): a missing value. Any None-valued regressions are
             given this value for results purposes
         method (None): a string specifying the evaluation method to use. The
@@ -97,7 +100,11 @@ def evaluate_regressions(
     eval_method.register_samples(samples, eval_key)
 
     results = eval_method.evaluate_samples(
-        samples, eval_key=eval_key, missing=missing, progress=progress
+        samples,
+        eval_metrics=eval_metrics,
+        eval_key=eval_key,
+        missing=missing,
+        progress=progress,
     )
     eval_method.save_run_results(samples, eval_key, results)
 
@@ -155,7 +162,12 @@ def register_samples(self, samples, eval_key):
             dataset.add_sample_field(eval_key, fof.FloatField)
 
     def evaluate_samples(
-        self, samples, eval_key=None, missing=None, progress=None
+        self,
+        samples,
+        eval_metrics=None,
+        eval_key=None,
+        missing=None,
+        progress=None,
     ):
         """Evaluates the regression predictions in the given samples with
         respect to the specified ground truth values.
@@ -254,10 +266,14 @@ class SimpleEvaluation(RegressionEvaluation):
     """
 
     def evaluate_samples(
-        self, samples, eval_key=None, missing=None, progress=None
+        self,
+        samples,
+        eval_metrics=None,
+        eval_key=None,
+        missing=None,
+        progress=None,
     ):
         metric = self.config._metric
-
         if metric == "squared_error":
             error_fcn = lambda yp, yt: (yp - yt) ** 2
         elif metric == "absolute_error":
@@ -293,12 +309,27 @@ def evaluate_samples(
             _confs = confs
             _ids = ids
 
+        # Metric operators: run per-sample metrics now, defer aggregates
+        agg_metric_ops = []
+        if eval_metrics:
+            for metric_op, metric_kwargs in eval_metrics:
+                if not metric_op.is_aggregate:  # per-sample/frame metric
+                    try:
+                        metric_op(
+                            samples, ytrue, ypred, eval_key, **metric_kwargs
+                        )
+                    except Exception as e:
+                        print(e)
+                else:
+                    agg_metric_ops.append((metric_op, metric_kwargs))
+
         results = RegressionResults(
             samples,
             self.config,
             eval_key,
             _ytrue,
             _ypred,
+            agg_metric_ops=agg_metric_ops,
             confs=_confs,
             ids=_ids,
             missing=missing,
@@ -312,10 +343,8 @@ def compute_error(yp, yt):
         if missing is not None:
             if yp is None:
                 yp = missing
-
             if yt is None:
                 yt = missing
-
         try:
             return error_fcn(yp, yt)
         except:
@@ -371,6 +400,7 @@ def __init__(
         eval_key,
         ytrue,
         ypred,
+        agg_metric_ops=None,
         confs=None,
         ids=None,
         missing=None,
@@ -387,6 +417,7 @@ def __init__(
         self.confs = confs
         self.ids = ids
         self.missing = missing
+        self.agg_metric_ops = agg_metric_ops
 
     def metrics(self, weights=None):
         """Computes various popular regression metrics for the results.
@@ -431,7 +462,7 @@ def metrics(self, weights=None):
             max_error = 0.0
             support = 0
 
-        return {
+        results = {
             "mean_squared_error": mse,
             "root_mean_squared_error": rmse,
             "mean_absolute_error": mae,
@@ -442,6 +473,19 @@ def metrics(self, weights=None):
             "support": support,
         }
 
+        if self.agg_metric_ops:
+            for metric_op, metric_kwargs in self.agg_metric_ops:
+                metric_val = metric_op(self.samples, **metric_kwargs)
+                sample_eval_key = metric_kwargs.get("sample_eval_key", None)
+                metric_name = (
+                    f"{metric_op.name}_{sample_eval_key}"
+                    if sample_eval_key
+                    else metric_op.name
+                )
+                results[metric_name] = metric_val.result
+
+        return results
+
     def print_metrics(self, weights=None, digits=2):
         """Prints the regression metrics computed via :meth:`metrics`.
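
Usage sketch (not part of the patch): the snippet below shows how the new
`eval_metrics` argument might be exercised. The `MeanErrorMetric` operator is
hypothetical and only illustrates the interface this patch assumes: metric
operators expose `is_aggregate` and `name`; per-sample operators are invoked as
`metric_op(samples, ytrue, ypred, eval_key, **kwargs)` during evaluation, while
aggregate operators are deferred to `RegressionResults.metrics()`, invoked as
`metric_op(samples, **kwargs)`, and must return an object with a `.result`
attribute.

    from types import SimpleNamespace

    import fiftyone as fo
    from fiftyone.utils.eval.regression import evaluate_regressions

    class MeanErrorMetric:
        """Hypothetical aggregate metric operator matching the interface
        this patch expects."""

        is_aggregate = True
        name = "mean_error"

        def __call__(self, samples, sample_eval_key=None, **kwargs):
            # Average the per-sample errors that evaluation stored under
            # `sample_eval_key`; the patch reads the value via `.result`
            return SimpleNamespace(result=samples.mean(sample_eval_key))

    dataset = fo.load_dataset("my-regression-dataset")  # assumed to exist

    results = evaluate_regressions(
        dataset,
        "predictions",
        gt_field="ground_truth",
        eval_key="eval",
        eval_metrics=[(MeanErrorMetric(), {"sample_eval_key": "eval"})],
    )

    # Aggregate operator values are merged into the standard metrics dict
    # under "<operator name>_<sample_eval_key>", i.e. "mean_error_eval" here
    print(results.metrics())

Because `metrics()` keys aggregate values by operator name plus the optional
`sample_eval_key`, the same operator can be applied to multiple evaluation
runs without its results colliding.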