diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml
index c730fa9c4..d130b9cb9 100644
--- a/sleap/config/pipeline_form.yaml
+++ b/sleap/config/pipeline_form.yaml
@@ -52,7 +52,7 @@ training:
This pipeline uses two models: a "centroid" model to
locate and crop around each animal in the frame, and a
"centered-instance confidence map" model for predicted node locations
- for each individual animal predicted by the centroid model.'
+ for each individual animal predicted by the centroid model.'
- label: Max Instances
name: max_instances
type: optional_int
@@ -217,7 +217,7 @@ training:
- name: controller_port
label: Controller Port
type: int
- default: 9000
+ default: 9000
range: 1024,65535
- name: publish_port
@@ -388,7 +388,7 @@ inference:
tracking-only:
- name: batch_size
- label: Batch Size
+ label: Batch Size
type: int
default: 4
range: 1,512
@@ -439,7 +439,7 @@ inference:
label: Similarity Method
type: list
default: instance
- options: instance,centroid,iou
+ options: "instance,centroid,iou,object keypoint"
- name: tracking.match
label: Matching Method
type: list
@@ -478,6 +478,22 @@ inference:
label: Nodes to use for Tracking
type: string
default: 0,1,2
+ - type: text
+ text: 'Object keypoint similarity options:
+ Only used if this similarity method is selected.'
+ - name: tracking.oks_errors
+      label: Keypoint errors in pixels
+ help: 'Standard error in pixels of the distance for each keypoint.
+      If the list is empty, defaults to 1. If a singleton list, each keypoint has
+ the same error. Otherwise, the length should be the same as the number of
+ keypoints in the skeleton.'
+ type: string
+ default:
+ - name: tracking.oks_score_weighting
+ label: Use prediction score for weighting
+ help: 'Use prediction scores to weight the similarity of each keypoint'
+ type: bool
+ default: false
- type: text
text: 'Post-tracker data cleaning:'
- name: tracking.post_connect_single_breaks
@@ -521,8 +537,8 @@ inference:
- name: tracking.similarity
label: Similarity Method
type: list
- default: iou
- options: instance,centroid,iou
+ default: instance
+ options: "instance,centroid,iou,object keypoint"
- name: tracking.match
label: Matching Method
type: list
@@ -557,6 +573,22 @@ inference:
label: Nodes to use for Tracking
type: string
default: 0,1,2
+ - type: text
+ text: 'Object keypoint similarity options:
+ Only used if this similarity method is selected.'
+ - name: tracking.oks_errors
+      label: Keypoint errors in pixels
+ help: 'Standard error in pixels of the distance for each keypoint.
+      If the list is empty, defaults to 1. If a singleton list, each keypoint has
+ the same error. Otherwise, the length should be the same as the number of
+ keypoints in the skeleton.'
+ type: string
+ default:
+ - name: tracking.oks_score_weighting
+ label: Use prediction score for weighting
+ help: 'Use prediction scores to weight the similarity of each keypoint'
+ type: bool
+ default: false
- type: text
text: 'Post-tracker data cleaning:'
- name: tracking.post_connect_single_breaks
diff --git a/sleap/config/suggestions.yaml b/sleap/config/suggestions.yaml
index 8cf89728a..1440530fc 100644
--- a/sleap/config/suggestions.yaml
+++ b/sleap/config/suggestions.yaml
@@ -3,7 +3,7 @@ main:
label: Method
type: stacked
default: " "
- options: " ,image features,sample,prediction score,velocity,frame chunk"
+ options: " ,image features,sample,prediction score,velocity,frame chunk,max point displacement"
" ":
sample:
@@ -175,6 +175,13 @@ main:
type: double
default: 0.1
range: 0.1,1.0
+
+ "max point displacement":
+ - name: displacement_threshold
+ label: Maximum Displacement Threshold
+ type: int
+ default: 10
+ range: 0,999
- name: target
label: Target
diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py
index 1a64a071c..8df85fc8e 100644
--- a/sleap/gui/commands.py
+++ b/sleap/gui/commands.py
@@ -2913,6 +2913,8 @@ def create_new_instance(
copy_instance=copy_instance,
new_instance=new_instance,
mark_complete=mark_complete,
+ init_method=init_method,
+ location=location,
)
if has_missing_nodes:
@@ -2984,6 +2986,8 @@ def set_visible_nodes(
copy_instance: Optional[Union[Instance, PredictedInstance]],
new_instance: Instance,
mark_complete: bool,
+ init_method: str,
+ location: Optional[QtCore.QPoint] = None,
) -> bool:
"""Sets visible nodes for new instance.
@@ -3010,6 +3014,25 @@ def set_visible_nodes(
scale_width = new_size_width / old_size_width
scale_height = new_size_height / old_size_height
+        # The default offset is 0
+ offset_x = 0
+ offset_y = 0
+
+ # Using the menu or the hotkey
+ if init_method == "best":
+ offset_x = 10
+ offset_y = 10
+
+        # Using right click and context menu
+        if location is not None:
+            reference_node = next(
+                (node for node in copy_instance if not node.isnan()), None
+            )
+            # Guard against copied instances with no valid nodes.
+            if reference_node is not None:
+                offset_x = location.x() - (reference_node.x * scale_width)
+                offset_y = location.y() - (reference_node.y * scale_height)
+
# Go through each node in skeleton.
for node in context.state["skeleton"].node_names:
# If we're copying from a skeleton that has this node.
@@ -3018,13 +3041,45 @@ def set_visible_nodes(
# We don't want to copy a PredictedPoint or score attribute.
x_old = copy_instance[node].x
y_old = copy_instance[node].y
- x_new = x_old * scale_width
- y_new = y_old * scale_height
+                # Do not rescale points copied from a predicted instance
+ if isinstance(copy_instance, PredictedInstance):
+ x_new = x_old
+ y_new = y_old
+ else:
+ x_new = x_old * scale_width
+ y_new = y_old * scale_height
+
+ # Apply offset if in bounds
+ x_new_offset = x_new + offset_x
+ y_new_offset = y_new + offset_y
+
+ # Default visibility is same as copied instance.
+ visible = copy_instance[node].visible
+
+ # If the node is offset to outside the frame, mark as not visible.
+ if x_new_offset < 0:
+ x_new = 0
+ visible = False
+ elif x_new_offset > new_size_width:
+ x_new = new_size_width
+ visible = False
+ else:
+ x_new = x_new_offset
+ if y_new_offset < 0:
+ y_new = 0
+ visible = False
+ elif y_new_offset > new_size_height:
+ y_new = new_size_height
+ visible = False
+ else:
+ y_new = y_new_offset
+
+ # Update the new instance with the new x, y, and visibility.
new_instance[node] = Point(
x=x_new,
y=y_new,
- visible=copy_instance[node].visible,
+ visible=visible,
complete=mark_complete,
)
else:
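
The clamping behavior added above is easier to see outside of the Qt and SLEAP types. The sketch below is illustrative only: `clamp_to_frame` is a hypothetical helper, not part of `sleap.gui.commands`, but it mirrors the per-axis rule used in `set_visible_nodes` (pin an offset point to the frame border and mark it not visible when it lands out of bounds).

```python
from typing import Tuple


def clamp_to_frame(
    x: float, y: float, width: float, height: float, visible: bool = True
) -> Tuple[float, float, bool]:
    """Clamp an offset point to the frame and hide it if it fell outside."""
    if x < 0:
        x, visible = 0, False
    elif x > width:
        x, visible = width, False
    if y < 0:
        y, visible = 0, False
    elif y > height:
        y, visible = height, False
    return x, y, visible


# A point pushed past the right edge is pinned to the border and hidden.
print(clamp_to_frame(1030.0, 200.0, width=1024, height=768))  # (1024, 200.0, False)
```
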
diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py
index 7569607a0..d0bb1f3ba 100644
--- a/sleap/gui/learning/runners.py
+++ b/sleap/gui/learning/runners.py
@@ -260,12 +260,20 @@ def make_predict_cli_call(
"tracking.max_tracking",
"tracking.post_connect_single_breaks",
"tracking.save_shifted_instances",
+ "tracking.oks_score_weighting",
)
for key in bool_items_as_ints:
if key in self.inference_params:
self.inference_params[key] = int(self.inference_params[key])
+ remove_spaces_items = ("tracking.similarity",)
+
+ for key in remove_spaces_items:
+ if key in self.inference_params:
+ value = self.inference_params[key]
+ self.inference_params[key] = value.replace(" ", "_")
+
for key, val in self.inference_params.items():
if not key.startswith(("_", "outputs.", "model.", "data.")):
cli_args.extend((f"--{key}", str(val)))
diff --git a/sleap/gui/suggestions.py b/sleap/gui/suggestions.py
index 48b916437..b85d6ac32 100644
--- a/sleap/gui/suggestions.py
+++ b/sleap/gui/suggestions.py
@@ -61,6 +61,7 @@ def suggest(cls, params: dict, labels: "Labels" = None) -> List[SuggestionFrame]
prediction_score=cls.prediction_score,
velocity=cls.velocity,
frame_chunk=cls.frame_chunk,
+ max_point_displacement=cls.max_point_displacement,
)
method = str.replace(params["method"], " ", "_")
@@ -213,6 +214,7 @@ def _prediction_score_video(
):
lfs = labels.find(video)
frames = len(lfs)
+
# initiate an array filled with -1 to store frame index (starting from 0).
idxs = np.full((frames), -1, dtype="int")
@@ -291,6 +293,56 @@ def _velocity_video(
return cls.idx_list_to_frame_list(frame_idxs, video)
+ @classmethod
+ def max_point_displacement(
+ cls,
+ labels: "Labels",
+ videos: List[Video],
+ displacement_threshold: float,
+ **kwargs,
+ ):
+ """Finds frames with maximum point displacement above a threshold."""
+
+ proposed_suggestions = []
+ for video in videos:
+ proposed_suggestions.extend(
+ cls._max_point_displacement_video(video, labels, displacement_threshold)
+ )
+
+ suggestions = VideoFrameSuggestions.filter_unique_suggestions(
+ labels, videos, proposed_suggestions
+ )
+
+ return suggestions
+
+ @classmethod
+ def _max_point_displacement_video(
+ cls, video: Video, labels: "Labels", displacement_threshold: float
+ ):
+        # Get numpy array of shape (frames, tracks, nodes, 2); last axis is (x, y)
+ labels_numpy = labels.numpy(video=video, all_frames=True, untracked=False)
+
+ # Return empty list if not enough frames
+ n_frames, n_tracks, n_nodes, _ = labels_numpy.shape
+
+ if n_frames < 2:
+ return []
+
+ # Calculate displacements
+        diff = labels_numpy[1:] - labels_numpy[:-1]  # (frames - 1, tracks, nodes, 2)
+ euc_norm = np.linalg.norm(diff, axis=-1) # (frames - 1, tracks, nodes)
+ mean_euc_norm = np.nanmean(euc_norm, axis=-1) # (frames - 1, tracks)
+
+ # Find frames where mean displacement is above threshold
+ threshold_mask = np.any(
+ mean_euc_norm > displacement_threshold, axis=-1
+ ) # (frames - 1,)
+        frame_idxs = list(
+            np.argwhere(threshold_mask).flatten() + 1
+        )  # Shift by 1 so indices refer to the later frame of each pair
+
+ return cls.idx_list_to_frame_list(frame_idxs, video)
+
@classmethod
def frame_chunk(
cls,
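
The displacement computation above can be checked with a tiny hand-made array. The sketch below is a standalone illustration using fabricated coordinates rather than real SLEAP labels, but it follows the same steps: per-frame differencing, per-node Euclidean norm, per-track mean, thresholding, and shifting indices by one so they point at the later frame of each pair.

```python
import numpy as np

# Fabricated (frames, tracks, nodes, 2) array: 3 frames, 1 track, 2 nodes.
tracks = np.array(
    [
        [[[0.0, 0.0], [1.0, 1.0]]],   # frame 0
        [[[0.0, 0.0], [1.0, 1.0]]],   # frame 1: no movement
        [[[9.0, 12.0], [1.0, 1.0]]],  # frame 2: node 0 jumps 15 px
    ]
)

diff = tracks[1:] - tracks[:-1]                # (frames - 1, tracks, nodes, 2)
euc_norm = np.linalg.norm(diff, axis=-1)       # (frames - 1, tracks, nodes)
mean_euc_norm = np.nanmean(euc_norm, axis=-1)  # (frames - 1, tracks)

threshold_mask = np.any(mean_euc_norm > 6, axis=-1)  # (frames - 1,)
frame_idxs = list(np.argwhere(threshold_mask).flatten() + 1)
print(frame_idxs)  # only frame 2 exceeds the 6 px mean-displacement threshold
```
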
diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py
index 502ea388e..745908048 100644
--- a/sleap/gui/widgets/video.py
+++ b/sleap/gui/widgets/video.py
@@ -367,7 +367,10 @@ def show_contextual_menu(self, where: QtCore.QPoint):
menu.addAction("Add Instance:").setEnabled(False)
- menu.addAction("Default", lambda: self.context.newInstance(init_method="best"))
+ menu.addAction(
+ "Default",
+ lambda: self.context.newInstance(init_method="best", location=scene_pos),
+ )
menu.addAction(
"Average",
diff --git a/sleap/nn/tracker/components.py b/sleap/nn/tracker/components.py
index 10b2953b7..b2f35b21f 100644
--- a/sleap/nn/tracker/components.py
+++ b/sleap/nn/tracker/components.py
@@ -14,7 +14,8 @@
"""
import operator
from collections import defaultdict
-from typing import List, Tuple, Optional, TypeVar, Callable
+import logging
+from typing import List, Tuple, Union, Optional, TypeVar, Callable
import attr
import numpy as np
@@ -23,6 +24,8 @@
from sleap import PredictedInstance, Instance, Track
from sleap.nn import utils
+logger = logging.getLogger(__name__)
+
InstanceType = TypeVar("InstanceType", Instance, PredictedInstance)
@@ -40,6 +43,95 @@ def instance_similarity(
return similarity
+def factory_object_keypoint_similarity(
+ keypoint_errors: Optional[Union[List, int, float]] = None,
+ score_weighting: bool = False,
+ normalization_keypoints: str = "all",
+) -> Callable:
+ """Factory for similarity function based on object keypoints.
+
+ Args:
+ keypoint_errors: The standard error of the distance between the predicted
+ keypoint and the true value, in pixels.
+ If None or empty list, defaults to 1.
+ If a scalar or singleton list, every keypoint has the same error.
+ If a list, defines the error for each keypoint, the length should be equal
+ to the number of keypoints in the skeleton.
+        score_weighting: If True, weight each keypoint's contribution by the
+            `score` of its `PredictedPoint`. If False, weight all keypoints equally.
+        normalization_keypoints: How to normalize the similarity score. One of
+ ["all", "ref", "union"]. If "all", similarity score is normalized by number
+ of reference points. If "ref", similarity score is normalized by number of
+ visible reference points. If "union", similarity score is normalized by
+ number of points both visible in query and reference instance.
+ Default is "all".
+
+ Returns:
+ Callable that returns object keypoint similarity between two `Instance`s.
+
+ """
+ keypoint_errors = 1 if keypoint_errors is None else keypoint_errors
+ with np.errstate(divide="ignore"):
+ kp_precision = 1 / (2 * np.array(keypoint_errors) ** 2)
+
+ def object_keypoint_similarity(
+ ref_instance: InstanceType, query_instance: InstanceType
+ ) -> float:
+ nonlocal kp_precision
+ # Keypoints
+ ref_points = ref_instance.points_array
+ query_points = query_instance.points_array
+ # Keypoint scores
+ if score_weighting:
+ ref_scores = getattr(ref_instance, "scores", np.ones(len(ref_points)))
+ query_scores = getattr(query_instance, "scores", np.ones(len(query_points)))
+ else:
+ ref_scores = 1
+ query_scores = 1
+        # Number of keypoints used for normalization
+ if normalization_keypoints in ("ref", "union"):
+ ref_visible = ~(np.isnan(ref_points).any(axis=1))
+ if normalization_keypoints == "ref":
+ max_n_keypoints = np.sum(ref_visible)
+ elif normalization_keypoints == "union":
+ query_visible = ~(np.isnan(query_points).any(axis=1))
+ max_n_keypoints = np.sum(np.logical_and(ref_visible, query_visible))
+ else: # if normalization_keypoints == "all":
+ max_n_keypoints = len(ref_points)
+ if max_n_keypoints == 0:
+ return 0
+
+        # Make sure the size of kp_precision matches the number of keypoints
+ if kp_precision.size > 1 and 2 * kp_precision.size != ref_points.size:
+ # Correct kp_precision size to fit number of points
+ n_points = ref_points.size // 2
+ mess = (
+ "keypoint_errors array should have the same size as the number of "
+ f"keypoints in the instance: {kp_precision.size} != {n_points}"
+ )
+
+ if kp_precision.size > n_points:
+ kp_precision = kp_precision[:n_points]
+ mess += "\nTruncating keypoint_errors array."
+
+ else: # elif kp_precision.size < n_points:
+ pad = n_points - kp_precision.size
+ kp_precision = np.pad(kp_precision, (0, pad), "edge")
+ mess += "\nPadding keypoint_errors array by repeating the last value."
+ logger.warning(mess)
+
+ # Compute distances
+ dists = np.sum((query_points - ref_points) ** 2, axis=1) * kp_precision
+
+ similarity = (
+ np.nansum(ref_scores * query_scores * np.exp(-dists)) / max_n_keypoints
+ )
+
+ return similarity
+
+ return object_keypoint_similarity
+
+
def centroid_distance(
ref_instance: InstanceType, query_instance: InstanceType, cache: dict = dict()
) -> float:
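
The similarity returned by the factory above is an object-keypoint-similarity (OKS) style score: each keypoint contributes roughly exp(-d**2 / (2 * sigma**2)), optionally weighted by prediction scores, and the sum is divided by the chosen keypoint count. The sketch below reimplements only the unweighted, "all"-normalized case on two made-up instances; `oks_similarity` is an illustrative standalone function, not the one defined in `sleap.nn.tracker.components`.

```python
import numpy as np


def oks_similarity(ref_points, query_points, keypoint_errors=1.0):
    """Unweighted OKS, normalized by the total number of reference keypoints."""
    kp_precision = 1 / (2 * np.asarray(keypoint_errors, dtype=float) ** 2)
    dists = np.sum((query_points - ref_points) ** 2, axis=1) * kp_precision
    return np.nansum(np.exp(-dists)) / len(ref_points)


ref = np.array([[10.0, 10.0], [20.0, 20.0], [30.0, 30.0]])
query = np.array([[10.0, 11.0], [20.0, 20.0], [np.nan, np.nan]])  # one missing node

# About 0.54: the two nearby keypoints contribute, the missing one is skipped
# by nansum, and the result is divided by all three reference keypoints.
print(oks_similarity(ref, query, keypoint_errors=1.0))
```
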
diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py
index 9865b7db5..2b02839de 100644
--- a/sleap/nn/tracking.py
+++ b/sleap/nn/tracking.py
@@ -10,6 +10,7 @@
from sleap import Track, LabeledFrame, Skeleton
from sleap.nn.tracker.components import (
+ factory_object_keypoint_similarity,
instance_similarity,
centroid_distance,
instance_iou,
@@ -391,6 +392,7 @@ def get_ref_instances(
def get_candidates(
self,
track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]],
+ max_tracking: bool,
t: int,
img: np.ndarray,
*args,
@@ -404,7 +406,7 @@ def get_candidates(
tracks = []
for track, matched_items in track_matching_queue_dict.items():
- if len(tracks) <= self.max_tracks:
+ if not max_tracking or len(tracks) < self.max_tracks:
tracks.append(track)
for matched_item in matched_items:
ref_t, ref_img = (
@@ -466,6 +468,7 @@ class SimpleMaxTracksCandidateMaker(SimpleCandidateMaker):
def get_candidates(
self,
track_matching_queue_dict: Dict,
+ max_tracking: bool,
*args,
**kwargs,
) -> List[InstanceType]:
@@ -473,7 +476,7 @@ def get_candidates(
candidate_instances = []
tracks = []
for track, matched_instances in track_matching_queue_dict.items():
- if len(tracks) <= self.max_tracks:
+ if not max_tracking or len(tracks) < self.max_tracks:
tracks.append(track)
for ref_instance in matched_instances:
if ref_instance.instance_t.n_visible_points >= self.min_points:
@@ -492,6 +495,7 @@ def get_candidates(
instance=instance_similarity,
centroid=centroid_distance,
iou=instance_iou,
+ object_keypoint=instance_similarity,
)
match_policies = dict(
@@ -598,8 +602,15 @@ def _init_matching_queue(self):
"""Factory for instantiating default matching queue with specified size."""
return deque(maxlen=self.track_window)
+ @property
+ def has_max_tracking(self) -> bool:
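+        """Whether the tracker is configured with a max-tracks candidate maker."""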
+ return isinstance(
+ self.candidate_maker,
+ (SimpleMaxTracksCandidateMaker, FlowMaxTracksCandidateMaker),
+ )
+
def reset_candidates(self):
- if self.max_tracking:
+ if self.has_max_tracking:
for track in self.track_matching_queue_dict:
self.track_matching_queue_dict[track] = deque(maxlen=self.track_window)
else:
@@ -610,14 +621,15 @@ def unique_tracks_in_queue(self) -> List[Track]:
"""Returns the unique tracks in the matching queue."""
unique_tracks = set()
- for match_item in self.track_matching_queue:
- for instance in match_item.instances_t:
- unique_tracks.add(instance.track)
-
- if self.max_tracking:
+ if self.has_max_tracking:
for track in self.track_matching_queue_dict.keys():
unique_tracks.add(track)
+ else:
+ for match_item in self.track_matching_queue:
+ for instance in match_item.instances_t:
+ unique_tracks.add(instance.track)
+
return list(unique_tracks)
@property
@@ -646,7 +658,7 @@ def track(
# Infer timestep if not provided.
if t is None:
- if self.max_tracking:
+ if self.has_max_tracking:
if len(self.track_matching_queue_dict) > 0:
# Default to last timestep + 1 if available.
@@ -684,10 +696,10 @@ def track(
self.pre_cull_function(untracked_instances)
# Build a pool of matchable candidate instances.
- if self.max_tracking:
+ if self.has_max_tracking:
candidate_instances = self.candidate_maker.get_candidates(
track_matching_queue_dict=self.track_matching_queue_dict,
- max_tracks=self.max_tracks,
+ max_tracking=self.max_tracking,
t=t,
img=img,
)
@@ -721,13 +733,16 @@ def track(
)
# Add the tracked instances to the dictionary of matched instances.
- if self.max_tracking:
+ if self.has_max_tracking:
for tracked_instance in tracked_instances:
if tracked_instance.track in self.track_matching_queue_dict:
self.track_matching_queue_dict[tracked_instance.track].append(
MatchedFrameInstance(t, tracked_instance, img)
)
- elif len(self.track_matching_queue_dict) < self.max_tracks:
+ elif (
+ not self.max_tracking
+ or len(self.track_matching_queue_dict) < self.max_tracks
+ ):
self.track_matching_queue_dict[tracked_instance.track] = deque(
maxlen=self.track_window
)
@@ -773,7 +788,8 @@ def spawn_for_untracked_instances(
# Skip if we've reached the maximum number of tracks.
if (
- self.max_tracking
+ self.has_max_tracking
+ and self.max_tracking
and len(self.track_matching_queue_dict) >= self.max_tracks
):
break
@@ -838,8 +854,17 @@ def make_tracker_by_name(
# Max tracking options
max_tracks: Optional[int] = None,
max_tracking: bool = False,
+ # Object keypoint similarity options
+ oks_errors: Optional[list] = None,
+ oks_score_weighting: bool = False,
+ oks_normalization: str = "all",
**kwargs,
) -> BaseTracker:
+        # Only enable max_tracking if max_tracks is not None and greater than 0
+ max_tracking = max_tracking if max_tracks else False
+ if max_tracking and tracker in ("simple", "flow"):
+ # Force a candidate maker of 'maxtracks' type
+ tracker += "maxtracks"
if tracker.lower() == "none":
candidate_maker = None
@@ -858,7 +883,14 @@ def make_tracker_by_name(
raise ValueError(f"{match} is not a valid tracker matching function.")
candidate_maker = tracker_policies[tracker](min_points=min_match_points)
- similarity_function = similarity_policies[similarity]
+ if similarity == "object_keypoint":
+ similarity_function = factory_object_keypoint_similarity(
+ keypoint_errors=oks_errors,
+ score_weighting=oks_score_weighting,
+ normalization_keypoints=oks_normalization,
+ )
+ else:
+ similarity_function = similarity_policies[similarity]
matching_function = match_policies[match]
if tracker == "flow":
@@ -931,7 +963,10 @@ def get_by_name_factory_options(cls):
option = dict(name="max_tracking", default=False)
option["type"] = bool
- option["help"] = "If true then the tracker will cap the max number of tracks."
+ option["help"] = (
+ "If true then the tracker will cap the max number of tracks. "
+ "Falls back to false if `max_tracks` is not defined or 0."
+ )
options.append(option)
option = dict(name="max_tracks", default=None)
@@ -1054,6 +1089,42 @@ def int_list_func(s):
] = "For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used."
options.append(option)
+ def float_list_func(s):
+ return [float(x.strip()) for x in s.split(",")] if s else None
+
+ option = dict(name="oks_errors", default="1")
+ option["type"] = float_list_func
+ option["help"] = (
+ "For Object Keypoint similarity: the standard error of the distance "
+ "between the predicted keypoint and the true value, in pixels.\n"
+ "If None or empty list, defaults to 1. If a scalar or singleton list, "
+ "every keypoint has the same error. If a list, defines the error for each "
+ "keypoint, the length should be equal to the number of keypoints in the "
+ "skeleton."
+ )
+ options.append(option)
+
+ option = dict(name="oks_score_weighting", default="0")
+ option["type"] = int
+ option["help"] = (
+ "For Object Keypoint similarity: if 0 (default), only the distance between the reference "
+ "and query keypoint is used to compute the similarity. If 1, each distance is weighted "
+ "by the prediction scores of the reference and query keypoint."
+ )
+ options.append(option)
+
+ option = dict(name="oks_normalization", default="all")
+ option["type"] = str
+ option["options"] = ["all", "ref", "union"]
+ option["help"] = (
+ "For Object Keypoint similarity: Determine how to normalize similarity score. "
+ "If 'all', similarity score is normalized by number of reference points. "
+ "If 'ref', similarity score is normalized by number of visible reference points. "
+ "If 'union', similarity score is normalized by number of points both visible "
+ "in query and reference instance."
+ )
+ options.append(option)
+
return options
@classmethod
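
Putting the new factory options together, a tracker using the object keypoint similarity could be built roughly as follows. The argument values are illustrative, but the parameter names (`oks_errors`, `oks_score_weighting`, `oks_normalization`) and the `similarity_function` attribute are the ones introduced in this diff.

```python
from sleap.nn.tracking import Tracker

# Simple candidate maker + object keypoint similarity, capped at 2 tracks.
tracker = Tracker.make_tracker_by_name(
    tracker="simple",
    similarity="object_keypoint",
    match="greedy",
    max_tracks=2,
    max_tracking=True,         # switched to the "simplemaxtracks" variant internally
    oks_errors=[1.0],          # one standard error (px) shared by every keypoint
    oks_score_weighting=False,
    oks_normalization="all",
)
assert tracker.similarity_function.__name__ == "object_keypoint_similarity"
```
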
diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py
index 801fcc092..ec5dfbc29 100644
--- a/tests/fixtures/datasets.py
+++ b/tests/fixtures/datasets.py
@@ -41,6 +41,13 @@ def centered_pair_predictions():
return Labels.load_file(TEST_JSON_PREDICTIONS)
+@pytest.fixture
+def centered_pair_predictions_sorted(centered_pair_predictions):
+ labels: Labels = centered_pair_predictions
+ labels.labeled_frames.sort(key=lambda lf: lf.frame_idx)
+ return labels
+
+
@pytest.fixture
def min_labels():
return Labels.load_file(TEST_JSON_MIN_LABELS)
diff --git a/tests/gui/test_suggestions.py b/tests/gui/test_suggestions.py
index bbad73179..196ff2d35 100644
--- a/tests/gui/test_suggestions.py
+++ b/tests/gui/test_suggestions.py
@@ -24,6 +24,20 @@ def test_velocity_suggestions(centered_pair_predictions):
assert suggestions[1].frame_idx == 45
+def test_max_point_displacement_suggestions(centered_pair_predictions):
+ suggestions = VideoFrameSuggestions.suggest(
+ labels=centered_pair_predictions,
+ params=dict(
+ videos=centered_pair_predictions.videos,
+ method="max_point_displacement",
+ displacement_threshold=6,
+ ),
+ )
+ assert len(suggestions) == 19
+ assert suggestions[0].frame_idx == 28
+ assert suggestions[1].frame_idx == 82
+
+
def test_frame_increment(centered_pair_predictions: Labels):
# Testing videos that have less frames than desired Samples per Video (stride)
# Expected result is there should be n suggestions where n is equal to the frames
diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py
index d13180591..fd615ea81 100644
--- a/tests/nn/test_inference.py
+++ b/tests/nn/test_inference.py
@@ -1375,7 +1375,7 @@ def test_retracking(
# Create sleap-track command
cmd = (
f"{slp_path} --tracking.tracker {tracker_method} --video.index 0 --frames 1-3 "
- "--cpu"
+ "--tracking.similarity object_keypoint --cpu"
)
if tracker_method == "flow":
cmd += " --tracking.save_shifted_instances 1"
@@ -1395,6 +1395,8 @@ def test_retracking(
parser = _make_cli_parser()
args, _ = parser.parse_known_args(args=args)
tracker = _make_tracker_from_cli(args)
+ # Additional check for similarity method
+ assert tracker.similarity_function.__name__ == "object_keypoint_similarity"
output_path = f"{slp_path}.{tracker.get_name()}.slp"
# Assert tracked predictions file exists
@@ -1909,9 +1911,9 @@ def test_sleap_track_text_file_input(
assert Path(expected_output_file).exists()
-def test_flow_tracker(centered_pair_predictions: Labels, tmpdir):
+def test_flow_tracker(centered_pair_predictions_sorted: Labels, tmpdir):
"""Test flow tracker instances are pruned."""
- labels: Labels = centered_pair_predictions
+ labels: Labels = centered_pair_predictions_sorted
track_window = 5
# Setup tracker
@@ -1921,7 +1923,7 @@ def test_flow_tracker(centered_pair_predictions: Labels, tmpdir):
tracker.candidate_maker = cast(FlowCandidateMaker, tracker.candidate_maker)
# Run tracking
- frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx)
+ frames = labels.labeled_frames
     # Run tracking on subset of frames using pseudo-implementation of
# sleap.nn.tracking.run_tracker
diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py
index f861241ee..5786945fb 100644
--- a/tests/nn/test_tracker_components.py
+++ b/tests/nn/test_tracker_components.py
@@ -9,23 +9,79 @@
FrameMatches,
greedy_matching,
)
+from sleap.io.dataset import Labels
from sleap.instance import PredictedInstance
from sleap.skeleton import Skeleton
+def tracker_by_name(frames=None, **kwargs):
+ t = Tracker.make_tracker_by_name(**kwargs)
+ print(kwargs)
+ print(t.candidate_maker)
+ if frames is None:
+ t.track([])
+ t.final_pass([])
+ return
+
+ for lf in frames:
+ # Clear the tracks
+ for inst in lf.instances:
+ inst.track = None
+
+ track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx])
+ t.track(**track_args)
+ t.final_pass(frames)
+
+
@pytest.mark.parametrize(
"tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"]
)
@pytest.mark.parametrize("similarity", ["instance", "iou", "centroid"])
@pytest.mark.parametrize("match", ["greedy", "hungarian"])
@pytest.mark.parametrize("count", [0, 2])
-def test_tracker_by_name(tracker, similarity, match, count):
- t = Tracker.make_tracker_by_name(
- "flow", "instance", "greedy", clean_instance_count=2
+def test_tracker_by_name(
+ centered_pair_predictions_sorted,
+ tracker,
+ similarity,
+ match,
+ count,
+):
+ # This is slow, so limit to 5 time points
+ frames = centered_pair_predictions_sorted[:5]
+
+ tracker_by_name(
+ frames=frames,
+ tracker=tracker,
+ similarity=similarity,
+ match=match,
+ max_tracks=count,
+ )
+
+
+@pytest.mark.parametrize(
+ "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"]
+)
+@pytest.mark.parametrize("oks_score_weighting", ["True", "False"])
+@pytest.mark.parametrize("oks_normalization", ["all", "ref", "union"])
+def test_oks_tracker_by_name(
+ centered_pair_predictions_sorted,
+ tracker,
+ oks_score_weighting,
+ oks_normalization,
+):
+ # This is slow, so limit to 5 time points
+ frames = centered_pair_predictions_sorted[:5]
+
+ tracker_by_name(
+ frames=frames,
+ tracker=tracker,
+ similarity="object_keypoint",
+        match="greedy",
+ oks_score_weighting=oks_score_weighting,
+ oks_normalization=oks_normalization,
+ max_tracks=2,
)
- t.track([])
- t.final_pass([])
def test_cull_instances(centered_pair_predictions):