Skip to content

Commit

Permalink
Evaluate regression with metric operators
Browse files Browse the repository at this point in the history
  • Loading branch information
manushreegangwar committed Dec 6, 2024
1 parent 48c52c4 commit ed789cb
Showing 1 changed file with 51 additions and 7 deletions.
58 changes: 51 additions & 7 deletions fiftyone/utils/eval/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def evaluate_regressions(
pred_field,
gt_field="ground_truth",
eval_key=None,
eval_metrics=None,
missing=None,
method=None,
progress=None,
Expand Down Expand Up @@ -69,6 +70,8 @@ def evaluate_regressions(
gt_field ("ground_truth"): the name of the field containing the
ground truth :class:`fiftyone.core.labels.Regression` instances
eval_key (None): a string key to use to refer to this evaluation
eval_metrics (None): a list of ``(metric, kwargs)`` tuples, where each
``metric`` is a ``fiftyone.operators.Operator`` that computes a metric
and ``kwargs`` is a dict of keyword arguments to pass to that operator
missing (None): a missing value. Any None-valued regressions are
given this value for results purposes
method (None): a string specifying the evaluation method to use. The
Expand Down Expand Up @@ -97,7 +100,11 @@ def evaluate_regressions(
eval_method.register_samples(samples, eval_key)

results = eval_method.evaluate_samples(
samples, eval_key=eval_key, missing=missing, progress=progress
samples,
eval_metrics=eval_metrics,
eval_key=eval_key,
missing=missing,
progress=progress,
)
eval_method.save_run_results(samples, eval_key, results)

Expand Down Expand Up @@ -155,7 +162,12 @@ def register_samples(self, samples, eval_key):
dataset.add_sample_field(eval_key, fof.FloatField)

def evaluate_samples(
self, samples, eval_key=None, missing=None, progress=None
self,
samples,
eval_metrics=None,
eval_key=None,
missing=None,
progress=None,
):
"""Evaluates the regression predictions in the given samples with
respect to the specified ground truth values.
Expand Down Expand Up @@ -254,10 +266,14 @@ class SimpleEvaluation(RegressionEvaluation):
"""

def evaluate_samples(
self, samples, eval_key=None, missing=None, progress=None
self,
samples,
eval_metrics=None,
eval_key=None,
missing=None,
progress=None,
):
metric = self.config._metric

if metric == "squared_error":
error_fcn = lambda yp, yt: (yp - yt) ** 2
elif metric == "absolute_error":
Expand Down Expand Up @@ -293,12 +309,27 @@ def evaluate_samples(
_confs = confs
_ids = ids

# Metric operators.
agg_metric_ops = []
if eval_metrics:
for metric_op, metric_kwargs in eval_metrics:
if not metric_op.is_aggregate: # check if per-sample/frame
try:
metric_op(
samples, ytrue, ypred, eval_key, **metric_kwargs
)
except Exception as e:
print(e)
else:
agg_metric_ops.append((metric_op, metric_kwargs))

results = RegressionResults(
samples,
self.config,
eval_key,
_ytrue,
_ypred,
agg_metric_ops=agg_metric_ops,
confs=_confs,
ids=_ids,
missing=missing,
Expand All @@ -312,10 +343,8 @@ def compute_error(yp, yt):
if missing is not None:
if yp is None:
yp = missing

if yt is None:
yt = missing

try:
return error_fcn(yp, yt)
except:
Expand Down Expand Up @@ -371,6 +400,7 @@ def __init__(
eval_key,
ytrue,
ypred,
agg_metric_ops=None,
confs=None,
ids=None,
missing=None,
Expand All @@ -387,6 +417,7 @@ def __init__(
self.confs = confs
self.ids = ids
self.missing = missing
self.agg_metrics_ops = agg_metric_ops

def metrics(self, weights=None):
"""Computes various popular regression metrics for the results.
Expand Down Expand Up @@ -431,7 +462,7 @@ def metrics(self, weights=None):
max_error = 0.0
support = 0

return {
results = {
"mean_squared_error": mse,
"root_mean_squared_error": rmse,
"mean_absolute_error": mae,
Expand All @@ -442,6 +473,19 @@ def metrics(self, weights=None):
"support": support,
}

if self.agg_metrics_ops:
for metric_op, metric_kwargs in self.agg_metrics_ops:
metric_val = metric_op(self.samples, **metric_kwargs)
sample_eval_key = metric_kwargs.get("sample_eval_key", None)
metric_name = (
f"{metric_op.name}_{sample_eval_key}"
if sample_eval_key
else metric_op.name
)
results[metric_name] = metric_val.result

return results

def print_metrics(self, weights=None, digits=2):
"""Prints the regression metrics computed via :meth:`metrics`.
Expand Down

0 comments on commit ed789cb

Please sign in to comment.