Add video path and frame indices to metrics (#1396)
* Add `Instance`s and `PredictedInstance`s to metrics

* Add tests

* Add frame/video info to metrics, wip: test writing

* Fix metrics save test
roomrys authored Jul 27, 2023
1 parent b2ad203 commit e94b516
Showing 4 changed files with 146 additions and 19 deletions.
50 changes: 37 additions & 13 deletions sleap/nn/evals.py
@@ -25,7 +25,7 @@
import numpy as np
from typing import Any, Dict, List, Optional, Text, Tuple, Union
import logging
import sleap

from sleap import Labels, LabeledFrame, Instance, PredictedInstance
from sleap.nn.config import (
TrainingJobConfig,
@@ -484,28 +484,45 @@ def compute_generalized_voc_metrics(

def compute_dists(
positive_pairs: List[Tuple[Instance, PredictedInstance, Any]]
-) -> np.ndarray:
+) -> Dict[str, Union[np.ndarray, List[int], List[str]]]:
"""Compute Euclidean distances between matched pairs of instances.
Args:
positive_pairs: A list of tuples of the form `(instance_gt, instance_pr, _)`
containing the matched pair of instances.
Returns:
-        An array of pairwise distances of shape `(n_positive_pairs, n_nodes)`.
+        A dictionary with the following keys:
+            dists: An array of pairwise distances of shape `(n_positive_pairs, n_nodes)`.
+            frame_idxs: A list of frame indices corresponding to the `dists`.
+            video_paths: A list of video paths corresponding to the `dists`.
"""
dists = []
frame_idxs = []
video_paths = []
for instance_gt, instance_pr, _ in positive_pairs:
points_gt = instance_gt.points_array
points_pr = instance_pr.points_array

dists.append(np.linalg.norm(points_pr - points_gt, axis=-1))
frame_idxs.append(instance_gt.frame.frame_idx)
video_paths.append(instance_gt.frame.video.backend.filename)

dists = np.array(dists)

-    return dists
+    # Bundle everything into a dictionary
+    dists_dict = {
+        "dists": dists,
+        "frame_idxs": frame_idxs,
+        "video_paths": video_paths,
+    }
+
+    return dists_dict


-def compute_dist_metrics(dists: np.ndarray) -> Dict[Text, np.ndarray]:
+def compute_dist_metrics(
+    dists_dict: Dict[str, Union[np.ndarray, List[Instance]]]
+) -> Dict[Text, np.ndarray]:
"""Compute the Euclidean distance error at different percentiles.
Args:
@@ -514,7 +531,10 @@ def compute_dist_metrics(dists: np.ndarray) -> Dict[Text, np.ndarray]:
Returns:
A dictionary of distance metrics.
"""
dists = dists_dict["dists"]
results = {
"dist.frame_idxs": dists_dict["frame_idxs"],
"dist.video_paths": dists_dict["video_paths"],
"dist.dists": dists,
"dist.avg": np.nanmean(dists),
"dist.p50": np.nan,
@@ -636,11 +656,11 @@ def evaluate(
threshold=match_threshold,
user_labels_only=user_labels_only,
)
-    dists = compute_dists(positive_pairs)
+    dists_dict = compute_dists(positive_pairs)

metrics.update(compute_visibility_conf(positive_pairs))
-    metrics.update(compute_dist_metrics(dists))
-    metrics.update(compute_pck_metrics(dists))
+    metrics.update(compute_dist_metrics(dists_dict))
+    metrics.update(compute_pck_metrics(dists_dict["dists"]))

pair_oks = np.array([oks for _, _, oks in positive_pairs])
pair_pck = metrics["pck.pcks"].mean(axis=-1).mean(axis=-1)
@@ -662,7 +682,7 @@

def evaluate_model(
cfg: TrainingJobConfig,
-    labels_reader: LabelsReader,
+    labels_gt: Union[LabelsReader, Labels],
model: Model,
save: bool = True,
split_name: Text = "test",
@@ -671,8 +691,8 @@
Args:
cfg: The `TrainingJobConfig` associated with the model.
-        labels_reader: A `LabelsReader` pipeline generator that reads the ground truth
-            data to evaluate.
+        labels_gt: A `LabelsReader` pipeline generator that reads the ground truth
+            data to evaluate or a `Labels` object to be used as ground truth.
model: The `sleap.nn.model.Model` instance to evaluate.
save: If True, save the predictions and metrics to the model folder.
split_name: String name to append to the saved filenames.
Expand Down Expand Up @@ -721,11 +741,13 @@ def evaluate_model(
raise ValueError("Unrecognized model type:", head_config)

# Predict.
-    labels_pr = predictor.predict(labels_reader, make_labels=True)
+    labels_pr: Labels = predictor.predict(labels_gt, make_labels=True)

# Compute metrics.
try:
-        metrics = evaluate(labels_reader.labels, labels_pr)
+        if isinstance(labels_gt, LabelsReader):
+            labels_gt = labels_gt.labels
+        metrics = evaluate(labels_gt, labels_pr)
except:
logger.warning("Failed to compute metrics.")
metrics = None
@@ -776,6 +798,8 @@ def load_metrics(model_path: str, split: str = "val") -> Dict[str, Any]:
- `"dist.p95"`: Distance for 95th percentile
- `"dist.p99"`: Distance for 99th percentile
- `"dist.dists"`: All distances
- `"dist.frame_idxs"`: Frame indices corresponding to `"dist.dists"`
- `"dist.video_paths"`: Video paths corresponding to `"dist.dists"`
- `"pck.mPCK"`: Mean Percentage of Correct Keypoints (PCK)
- `"oks.mOKS"`: Mean Object Keypoint Similarity (OKS)
- `"oks_voc.mAP"`: VOC with OKS scores - mean Average Precision (mAP)
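With the two new keys saved alongside the existing metrics, per-pair errors can now be traced back to their source frames. A minimal sketch of that lookup (the model folder path is hypothetical; the key names and the `load_metrics` signature are taken from the diff above):

    import numpy as np

    from sleap.nn.evals import load_metrics

    metrics = load_metrics("models/my_model", split="val")  # hypothetical model folder

    # Mean error per matched pair (rows of `dist.dists` are pairs, columns are nodes),
    # then report the five worst pairs with their originating video and frame.
    per_pair = np.nanmean(metrics["dist.dists"], axis=-1)
    for i in np.argsort(per_pair)[::-1][:5]:
        video = metrics["dist.video_paths"][i]
        frame = metrics["dist.frame_idxs"][i]
        print(f"{video} @ frame {frame}: {per_pair[i]:.2f} px")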
6 changes: 3 additions & 3 deletions sleap/nn/training.py
@@ -962,22 +962,22 @@ def evaluate(self):
logger.info("Saving evaluation metrics to model folder...")
sleap.nn.evals.evaluate_model(
cfg=self.config,
-            labels_reader=self.data_readers.training_labels_reader,
+            labels_gt=self.data_readers.training_labels_reader,
model=self.model,
save=True,
split_name="train",
)
sleap.nn.evals.evaluate_model(
cfg=self.config,
-            labels_reader=self.data_readers.validation_labels_reader,
+            labels_gt=self.data_readers.validation_labels_reader,
model=self.model,
save=True,
split_name="val",
)
if self.data_readers.test_labels_reader is not None:
sleap.nn.evals.evaluate_model(
cfg=self.config,
-                labels_reader=self.data_readers.test_labels_reader,
+                labels_gt=self.data_readers.test_labels_reader,
model=self.model,
save=True,
split_name="test",
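The trainer callsites above only needed the keyword rename, since they still pass `LabelsReader` instances. Because `evaluate_model` now unwraps a `LabelsReader` itself, a plain `Labels` object is equally valid for `labels_gt`. A sketch of both call forms (the labels path is hypothetical, and `cfg`/`model` are assumed to exist, so the calls are shown commented out):

    import sleap
    from sleap.nn.data.providers import LabelsReader
    from sleap.nn.evals import evaluate_model

    labels = sleap.load_file("ground_truth.slp")  # hypothetical labels file

    # Both forms are now accepted for `labels_gt`:
    reader = LabelsReader(labels=labels, user_instances_only=True)
    # labels_pr, metrics = evaluate_model(cfg=cfg, labels_gt=reader, model=model, save=True, split_name="val")
    # labels_pr, metrics = evaluate_model(cfg=cfg, labels_gt=labels, model=model, save=True, split_name="val")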
10 changes: 8 additions & 2 deletions tests/fixtures/instances.py
@@ -1,16 +1,18 @@
import pytest

-from sleap.instance import Instance, Point, PredictedInstance
+from sleap.instance import Instance, LabeledFrame, Point, PredictedInstance


@pytest.fixture
-def instances(skeleton):
+def instances(skeleton, centered_pair_vid):

# Generate some instances
NUM_INSTANCES = 500

video = centered_pair_vid
instances = []
for i in range(NUM_INSTANCES):

instance = Instance(skeleton=skeleton)
instance["head"] = Point(i * 1, i * 2)
instance["left-wing"] = Point(10 + i * 1, 10 + i * 2)
@@ -19,6 +21,10 @@ def instances(skeleton):
    # Let's make a NaN entry to test skip_nan as well
instance["thorax"]

# Add a LabeledFrame
labeled_frame = LabeledFrame(video=video, frame_idx=i, instances=[instance])
instance.frame = labeled_frame

instances.append(instance)

return instances
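The fixture gains the video and `LabeledFrame` because `compute_dists` now dereferences `instance.frame` for the frame index and video filename. A minimal sketch of that back-reference (the node name and video path are illustrative; the video file is assumed never to be read, since only the backend filename is accessed):

    from sleap.instance import Instance, LabeledFrame, Point
    from sleap.io.video import Video
    from sleap.skeleton import Skeleton

    skeleton = Skeleton()
    skeleton.add_node("head")

    video = Video.from_filename("example.mp4")  # hypothetical path

    inst = Instance(skeleton=skeleton)
    inst["head"] = Point(1.0, 2.0)

    # Attach the frame so `inst.frame.frame_idx` and `inst.frame.video` resolve,
    # exactly as the updated fixture does.
    inst.frame = LabeledFrame(video=video, frame_idx=0, instances=[inst])

    assert inst.frame.frame_idx == 0
    assert inst.frame.video.backend.filename == "example.mp4"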
99 changes: 98 additions & 1 deletion tests/nn/test_evals.py
@@ -1,6 +1,23 @@
from pathlib import Path
import numpy as np
import tensorflow as tf

from typing import List, Tuple

import sleap
-from sleap.nn.evals import load_metrics, compute_oks
+from sleap import Instance, PredictedInstance
+from sleap.instance import Point
+from sleap.nn.config.training_job import TrainingJobConfig
+from sleap.nn.data.providers import LabelsReader
+from sleap.nn.evals import (
+    compute_dists,
+    compute_dist_metrics,
+    compute_oks,
+    load_metrics,
+    evaluate_model,
+)
+from sleap.nn.model import Model


sleap.use_cpu_only()
@@ -48,6 +65,86 @@ def test_compute_oks():
np.testing.assert_allclose(oks, 1)


def test_compute_dists(instances, predicted_instances):
    # Shift the ground-truth points away from the predictions so the errors are known
error_start = 10
error_end = 20
expected_dists = []
for offset, zipped_insts in enumerate(
zip(
instances[error_start:error_end], predicted_instances[error_start:error_end]
)
):

inst, pred_inst = zipped_insts
for node_name in inst.skeleton.node_names:
pred_point = pred_inst[node_name]
            # NaN compares unequal to everything, so `pred_point != np.NaN` would
            # always be True; check the coordinates explicitly instead.
            if not (np.isnan(pred_point.x) or np.isnan(pred_point.y)):
inst[node_name] = Point(
pred_point.x + offset, pred_point.y + offset + 1
)

error = ((offset ** 2) + (offset + 1) ** 2) ** (1 / 2)
expected_dists.append(error)

best_match_oks = np.NaN
    positive_pairs: List[Tuple[Instance, PredictedInstance, float]] = [
(inst, pred_inst, best_match_oks)
for inst, pred_inst in zip(instances, predicted_instances)
]

dists_dict = compute_dists(positive_pairs=positive_pairs)
dists = dists_dict["dists"]

    # Replace NaN with 0
dists_no_nan = np.nan_to_num(dists, nan=0)
np.testing.assert_allclose(dists_no_nan[0:10], 0)

    # Replace NaN with a negative value (which a norm can never produce)
dists_no_nan = np.nan_to_num(dists, nan=-1)

# Check distances are as expected
for idx, error in enumerate(expected_dists):
idx += error_start
dists_idx = dists_no_nan[idx]
dists_idx = dists_idx[dists_idx >= 0]
np.testing.assert_allclose(dists_idx, error)

    # Check frame indices and video paths match the source instances
dists_metric = compute_dist_metrics(dists_dict)
for idx, zipped_metrics in enumerate(
zip(dists_metric["dist.frame_idxs"], dists_metric["dist.video_paths"])
):
frame_idx, video_path = zipped_metrics
assert frame_idx == instances[idx].frame.frame_idx
assert video_path == instances[idx].frame.video.backend.filename


def test_evaluate_model(min_labels_slp, min_bottomup_model_path):

labels_reader = LabelsReader(labels=min_labels_slp, user_instances_only=True)
model_dir: str = min_bottomup_model_path
cfg = TrainingJobConfig.load_json(str(Path(model_dir, "training_config.json")))
model = Model.from_config(
config=cfg.model,
skeleton=labels_reader.labels.skeletons[0],
tracks=labels_reader.labels.tracks,
update_config=True,
)
model.keras_model = tf.keras.models.load_model(
Path(model_dir) / "best_model.h5", compile=False
)

labels_pr, metrics = evaluate_model(
cfg=cfg,
labels_gt=labels_reader,
model=model,
save=True,
split_name="test",
)
assert metrics is not None # If metrics is None, then the metrics were not saved


def test_load_metrics(min_centered_instance_model_path):
model_path = min_centered_instance_model_path

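End to end, the new plumbing can be exercised on a single hand-built pair. The sketch below mirrors the tests above (the video path is hypothetical and never opened; `PredictedInstance.from_instance` is used on the assumption that it copies the points of a regular `Instance`):

    import numpy as np

    from sleap.instance import Instance, LabeledFrame, Point, PredictedInstance
    from sleap.io.video import Video
    from sleap.skeleton import Skeleton
    from sleap.nn.evals import compute_dists, compute_dist_metrics

    skeleton = Skeleton()
    skeleton.add_node("head")
    video = Video.from_filename("example.mp4")  # hypothetical path

    gt = Instance(skeleton=skeleton)
    gt["head"] = Point(0.0, 0.0)
    gt.frame = LabeledFrame(video=video, frame_idx=7, instances=[gt])

    pr_src = Instance(skeleton=skeleton)
    pr_src["head"] = Point(3.0, 4.0)
    pr = PredictedInstance.from_instance(pr_src, score=1.0)

    # (instance_gt, instance_pr, oks) triples, as produced by the matching step.
    dists_dict = compute_dists([(gt, pr, np.nan)])

    assert dists_dict["frame_idxs"] == [7]
    assert dists_dict["video_paths"] == ["example.mp4"]
    np.testing.assert_allclose(dists_dict["dists"], [[5.0]])  # 3-4-5 triangle

    metrics = compute_dist_metrics(dists_dict)
    assert metrics["dist.avg"] == 5.0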
