Merge branch 'release/v1.2.0' of https://github.com/voxel51/fiftyone into develop
voxel51-bot committed Dec 18, 2024
2 parents 6f49543 + ae494a5 commit 4d83471
Showing 2 changed files with 126 additions and 24 deletions.
@@ -1282,6 +1282,14 @@ export default function Evaluation(props: EvaluationProps)
].join(" <br>") + "<extra></extra>",
},
]}
onClick={({ points }) => {
const firstPoint = points[0];
loadView("matrix", {
x: firstPoint.x,
y: firstPoint.y,
key: compareKey,
});
}}
layout={{
yaxis: {
autorange: "reversed",
@@ -1674,7 +1682,7 @@ function useActiveFilter(evaluation, compareEvaluation) {
const evalKey = evaluation?.info?.key;
const compareKey = compareEvaluation?.info?.key;
const [stages] = useRecoilState(view);
if (stages?.length === 1) {
if (stages?.length >= 1) {
const stage = stages[0];
const { _cls, kwargs } = stage;
if (_cls.endsWith("FilterLabels")) {
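
For context on the data flow: the onClick handler added above forwards the clicked confusion-matrix cell to the panel as loadView("matrix", { x, y, key }), and the Python load_view method in the second file below translates those coordinates into a dataset view. The following is a minimal standalone sketch of the classification case, written against the public FiftyOne API; the dataset and field names are hypothetical and not taken from this commit.

```python
import fiftyone as fo
from fiftyone import ViewField as F

# Assumed setup (illustrative only): a dataset with `Classification`
# fields "ground_truth" and "predictions"
dataset = fo.load_dataset("my-classification-dataset")  # hypothetical name

x = "cat"  # predicted class of the clicked cell (matrix x-axis)
y = "dog"  # ground truth class of the clicked cell (matrix y-axis)

# Samples in the clicked cell: ground truth `y` that was predicted as `x`,
# mirroring the classification "matrix" branch of load_view below
cell_view = dataset.match(
    (F("ground_truth.label") == y) & (F("predictions.label") == x)
)

print(cell_view.count())
```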
140 changes: 117 additions & 23 deletions fiftyone/operators/builtins/panels/model_evaluation/__init__.py
@@ -98,12 +98,6 @@ def on_load(self, ctx):
ctx.panel.set_data("permissions", permissions)
self.load_pending_evaluations(ctx)

def is_binary_classification(self, info):
return (
info.config.type == "classification"
and info.config.method == "binary"
)

def get_avg_confidence(self, per_class_metrics):
count = 0
total = 0
@@ -115,7 +109,10 @@ def get_avg_confidence(self, per_class_metrics):

def get_tp_fp_fn(self, info, results):
# Binary classification
if self.is_binary_classification(info):
if (
info.config.type == "classification"
and info.config.method == "binary"
):
neg_label, pos_label = results.classes
tp_count = np.count_nonzero(
(results.ytrue == pos_label) & (results.ypred == pos_label)
@@ -341,6 +338,8 @@ def load_evaluation(self, ctx):
"per_class_metrics": per_class_metrics,
"mask_targets": mask_targets,
}
ctx.panel.set_state("missing", results.missing)

if ENABLE_CACHING:
# Cache the evaluation data
try:
@@ -435,32 +434,127 @@ def load_view(self, ctx):
return

view_state = ctx.panel.get_state("view") or {}
view_options = ctx.params.get("options", {})

eval_key = view_state.get("key")
eval_key = view_options.get("key", eval_key)
eval_view = ctx.dataset.load_evaluation_view(eval_key)
info = ctx.dataset.get_evaluation_info(eval_key)
pred_field = info.config.pred_field
gt_field = info.config.gt_field
view_options = ctx.params.get("options", {})

eval_key2 = view_state.get("compareKey", None)
pred_field2 = None
gt_field2 = None
if eval_key2:
info2 = ctx.dataset.get_evaluation_info(eval_key2)
pred_field2 = info2.config.pred_field
if info2.config.gt_field != gt_field:
gt_field2 = info2.config.gt_field

x = view_options.get("x", None)
y = view_options.get("y", None)
field = view_options.get("field", None)
computed_eval_key = view_options.get("key", eval_key)
missing = ctx.panel.get_state("missing", "(none)")

view = None
if view_type == "class":
view = ctx.dataset.filter_labels(pred_field, F("label") == x)
elif view_type == "matrix":
view = ctx.dataset.filter_labels(
gt_field, F("label") == y
).filter_labels(pred_field, F("label") == x)
elif view_type == "field":
if self.is_binary_classification(info):
uppercase_field = field.upper()
view = ctx.dataset.match(
{computed_eval_key: {"$eq": uppercase_field}}
if info.config.type == "classification":
if view_type == "class":
# All GT/predictions of class `x`
expr = F(f"{gt_field}.label") == x
expr |= F(f"{pred_field}.label") == x
if gt_field2 is not None:
expr |= F(f"{gt_field2}.label") == x
if pred_field2 is not None:
expr |= F(f"{pred_field2}.label") == x
view = eval_view.match(expr)
elif view_type == "matrix":
# Specific confusion matrix cell (including FP/FN)
expr = F(f"{gt_field}.label") == y
expr &= F(f"{pred_field}.label") == x
view = eval_view.match(expr)
elif view_type == "field":
if info.config.method == "binary":
# All TP/FP/FN
expr = F(f"{eval_key}") == field.upper()
view = eval_view.match(expr)
else:
# Correct/incorrect
expr = F(f"{eval_key}") == field
view = eval_view.match(expr)
elif info.config.type == "detection":
_, gt_root = ctx.dataset._get_label_field_path(gt_field)
_, pred_root = ctx.dataset._get_label_field_path(pred_field)
if gt_field2 is not None:
_, gt_root2 = ctx.dataset._get_label_field_path(gt_field2)
if pred_field2 is not None:
_, pred_root2 = ctx.dataset._get_label_field_path(pred_field2)

if view_type == "class":
# All GT/predictions of class `x`
view = eval_view.filter_labels(
gt_field, F("label") == x, only_matches=False
)
else:
view = ctx.dataset.filter_labels(
pred_field, F(computed_eval_key) == field
expr = F(gt_root).length() > 0
view = view.filter_labels(
pred_field, F("label") == x, only_matches=False
)
expr |= F(pred_root).length() > 0
if gt_field2 is not None:
view = view.filter_labels(
gt_field2, F("label") == x, only_matches=False
)
expr |= F(gt_root2).length() > 0
if pred_field2 is not None:
view = view.filter_labels(
pred_field2, F("label") == x, only_matches=False
)
expr |= F(pred_root2).length() > 0
view = view.match(expr)
elif view_type == "matrix":
if y == missing:
# False positives of class `x`
expr = (F("label") == x) & (F(eval_key) == "fp")
view = eval_view.filter_labels(
pred_field, expr, only_matches=True
)
elif x == missing:
# False negatives of class `y`
expr = (F("label") == y) & (F(eval_key) == "fn")
view = eval_view.filter_labels(
gt_field, expr, only_matches=True
)
else:
# All class `y` GT and class `x` predictions in same sample
view = eval_view.filter_labels(
gt_field, F("label") == y, only_matches=False
)
expr = F(gt_root).length() > 0
view = view.filter_labels(
pred_field, F("label") == x, only_matches=False
)
expr &= F(pred_root).length() > 0
view = view.match(expr)
elif view_type == "field":
if field == "tp":
# All true positives
view = eval_view.filter_labels(
gt_field, F(eval_key) == field, only_matches=False
)
view = view.filter_labels(
pred_field, F(eval_key) == field, only_matches=True
)
elif field == "fn":
# All false negatives
view = eval_view.filter_labels(
gt_field, F(eval_key) == field, only_matches=True
)
else:
# All false positives
view = eval_view.filter_labels(
pred_field, F(eval_key) == field, only_matches=True
)

if view is not None:
ctx.ops.set_view(view)

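To make the detection branch of load_view above concrete, here is a hedged standalone sketch of the confusion-matrix cell views it builds, written against the public FiftyOne API. The dataset name, field names, eval_key, and the "(none)" missing label are assumptions for illustration, not values from this commit.

```python
import fiftyone as fo
from fiftyone import ViewField as F

dataset = fo.load_dataset("my-detection-dataset")  # hypothetical name
eval_key = "eval"    # assumes evaluate_detections(..., eval_key="eval") was run
missing = "(none)"   # label used for the unmatched row/column of the matrix
x = "cat"            # predicted class of the clicked cell (matrix x-axis)
y = "dog"            # ground truth class of the clicked cell (matrix y-axis)

if y == missing:
    # False positives of class `x`: predictions with no matching ground truth
    view = dataset.filter_labels(
        "predictions", (F("label") == x) & (F(eval_key) == "fp"), only_matches=True
    )
elif x == missing:
    # False negatives of class `y`: ground truth objects with no matching prediction
    view = dataset.filter_labels(
        "ground_truth", (F("label") == y) & (F(eval_key) == "fn"), only_matches=True
    )
else:
    # Samples containing both class `y` ground truth and class `x` predictions
    view = dataset.filter_labels("ground_truth", F("label") == y, only_matches=False)
    view = view.filter_labels("predictions", F("label") == x, only_matches=False)
    view = view.match(
        (F("ground_truth.detections").length() > 0)
        & (F("predictions.detections").length() > 0)
    )

print(view)
```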
