From 38a5ca785d02bae74e09b4102635d6711f096c46 Mon Sep 17 00:00:00 2001 From: getzze Date: Tue, 23 Jul 2024 17:48:27 +0100 Subject: [PATCH 1/3] Add object keypoint similarity method (#1003) * Add object keypoint similarity method * fix max_tracking * correct off-by-one error * correct off-by-one error --- sleap/config/pipeline_form.yaml | 44 ++++++++++-- sleap/gui/learning/runners.py | 8 +++ sleap/nn/tracker/components.py | 94 ++++++++++++++++++++++++- sleap/nn/tracking.py | 103 +++++++++++++++++++++++----- tests/fixtures/datasets.py | 7 ++ tests/nn/test_inference.py | 10 +-- tests/nn/test_tracker_components.py | 66 ++++++++++++++++-- 7 files changed, 300 insertions(+), 32 deletions(-) diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index c730fa9c4..d130b9cb9 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -52,7 +52,7 @@ training: This pipeline uses two models: a "centroid" model to locate and crop around each animal in the frame, and a "centered-instance confidence map" model for predicted node locations - for each individual animal predicted by the centroid model.' + for each individual animal predicted by the centroid model.' - label: Max Instances name: max_instances type: optional_int @@ -217,7 +217,7 @@ training: - name: controller_port label: Controller Port type: int - default: 9000 + default: 9000 range: 1024,65535 - name: publish_port @@ -388,7 +388,7 @@ inference: tracking-only: - name: batch_size - label: Batch Size + label: Batch Size type: int default: 4 range: 1,512 @@ -439,7 +439,7 @@ inference: label: Similarity Method type: list default: instance - options: instance,centroid,iou + options: "instance,centroid,iou,object keypoint" - name: tracking.match label: Matching Method type: list @@ -478,6 +478,22 @@ inference: label: Nodes to use for Tracking type: string default: 0,1,2 + - type: text + text: 'Object keypoint similarity options:
+ Only used if this similarity method is selected.' + - name: tracking.oks_errors + label: Keypoints errors in pixels + help: 'Standard error in pixels of the distance for each keypoint. + If the list is empty, defaults to 1. If singleton list, each keypoint has + the same error. Otherwise, the length should be the same as the number of + keypoints in the skeleton.' + type: string + default: + - name: tracking.oks_score_weighting + label: Use prediction score for weighting + help: 'Use prediction scores to weight the similarity of each keypoint' + type: bool + default: false - type: text text: 'Post-tracker data cleaning:' - name: tracking.post_connect_single_breaks @@ -521,8 +537,8 @@ inference: - name: tracking.similarity label: Similarity Method type: list - default: iou - options: instance,centroid,iou + default: instance + options: "instance,centroid,iou,object keypoint" - name: tracking.match label: Matching Method type: list @@ -557,6 +573,22 @@ inference: label: Nodes to use for Tracking type: string default: 0,1,2 + - type: text + text: 'Object keypoint similarity options:
+ Only used if this similarity method is selected.' + - name: tracking.oks_errors + label: Keypoints errors in pixels + help: 'Standard error in pixels of the distance for each keypoint. + If the list is empty, defaults to 1. If singleton list, each keypoint has + the same error. Otherwise, the length should be the same as the number of + keypoints in the skeleton.' + type: string + default: + - name: tracking.oks_score_weighting + label: Use prediction score for weighting + help: 'Use prediction scores to weight the similarity of each keypoint' + type: bool + default: false - type: text text: 'Post-tracker data cleaning:' - name: tracking.post_connect_single_breaks diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index 7569607a0..d0bb1f3ba 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -260,12 +260,20 @@ def make_predict_cli_call( "tracking.max_tracking", "tracking.post_connect_single_breaks", "tracking.save_shifted_instances", + "tracking.oks_score_weighting", ) for key in bool_items_as_ints: if key in self.inference_params: self.inference_params[key] = int(self.inference_params[key]) + remove_spaces_items = ("tracking.similarity",) + + for key in remove_spaces_items: + if key in self.inference_params: + value = self.inference_params[key] + self.inference_params[key] = value.replace(" ", "_") + for key, val in self.inference_params.items(): if not key.startswith(("_", "outputs.", "model.", "data.")): cli_args.extend((f"--{key}", str(val))) diff --git a/sleap/nn/tracker/components.py b/sleap/nn/tracker/components.py index 10b2953b7..b2f35b21f 100644 --- a/sleap/nn/tracker/components.py +++ b/sleap/nn/tracker/components.py @@ -14,7 +14,8 @@ """ import operator from collections import defaultdict -from typing import List, Tuple, Optional, TypeVar, Callable +import logging +from typing import List, Tuple, Union, Optional, TypeVar, Callable import attr import numpy as np @@ -23,6 +24,8 @@ from sleap import PredictedInstance, Instance, Track from sleap.nn import utils +logger = logging.getLogger(__name__) + InstanceType = TypeVar("InstanceType", Instance, PredictedInstance) @@ -40,6 +43,95 @@ def instance_similarity( return similarity +def factory_object_keypoint_similarity( + keypoint_errors: Optional[Union[List, int, float]] = None, + score_weighting: bool = False, + normalization_keypoints: str = "all", +) -> Callable: + """Factory for similarity function based on object keypoints. + + Args: + keypoint_errors: The standard error of the distance between the predicted + keypoint and the true value, in pixels. + If None or empty list, defaults to 1. + If a scalar or singleton list, every keypoint has the same error. + If a list, defines the error for each keypoint, the length should be equal + to the number of keypoints in the skeleton. + score_weighting: If True, use `score` of `PredictedPoint` to weigh + `keypoint_errors`. If False, do not add a weight to `keypoint_errors`. + normalization_keypoints: Determine how to normalize similarity score. One of + ["all", "ref", "union"]. If "all", similarity score is normalized by number + of reference points. If "ref", similarity score is normalized by number of + visible reference points. If "union", similarity score is normalized by + number of points both visible in query and reference instance. + Default is "all". + + Returns: + Callable that returns object keypoint similarity between two `Instance`s. + + """ + keypoint_errors = 1 if keypoint_errors is None else keypoint_errors + with np.errstate(divide="ignore"): + kp_precision = 1 / (2 * np.array(keypoint_errors) ** 2) + + def object_keypoint_similarity( + ref_instance: InstanceType, query_instance: InstanceType + ) -> float: + nonlocal kp_precision + # Keypoints + ref_points = ref_instance.points_array + query_points = query_instance.points_array + # Keypoint scores + if score_weighting: + ref_scores = getattr(ref_instance, "scores", np.ones(len(ref_points))) + query_scores = getattr(query_instance, "scores", np.ones(len(query_points))) + else: + ref_scores = 1 + query_scores = 1 + # Number of keypoint for normalization + if normalization_keypoints in ("ref", "union"): + ref_visible = ~(np.isnan(ref_points).any(axis=1)) + if normalization_keypoints == "ref": + max_n_keypoints = np.sum(ref_visible) + elif normalization_keypoints == "union": + query_visible = ~(np.isnan(query_points).any(axis=1)) + max_n_keypoints = np.sum(np.logical_and(ref_visible, query_visible)) + else: # if normalization_keypoints == "all": + max_n_keypoints = len(ref_points) + if max_n_keypoints == 0: + return 0 + + # Make sure the sizes of kp_precision and n_points match + if kp_precision.size > 1 and 2 * kp_precision.size != ref_points.size: + # Correct kp_precision size to fit number of points + n_points = ref_points.size // 2 + mess = ( + "keypoint_errors array should have the same size as the number of " + f"keypoints in the instance: {kp_precision.size} != {n_points}" + ) + + if kp_precision.size > n_points: + kp_precision = kp_precision[:n_points] + mess += "\nTruncating keypoint_errors array." + + else: # elif kp_precision.size < n_points: + pad = n_points - kp_precision.size + kp_precision = np.pad(kp_precision, (0, pad), "edge") + mess += "\nPadding keypoint_errors array by repeating the last value." + logger.warning(mess) + + # Compute distances + dists = np.sum((query_points - ref_points) ** 2, axis=1) * kp_precision + + similarity = ( + np.nansum(ref_scores * query_scores * np.exp(-dists)) / max_n_keypoints + ) + + return similarity + + return object_keypoint_similarity + + def centroid_distance( ref_instance: InstanceType, query_instance: InstanceType, cache: dict = dict() ) -> float: diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 9865b7db5..2b02839de 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -10,6 +10,7 @@ from sleap import Track, LabeledFrame, Skeleton from sleap.nn.tracker.components import ( + factory_object_keypoint_similarity, instance_similarity, centroid_distance, instance_iou, @@ -391,6 +392,7 @@ def get_ref_instances( def get_candidates( self, track_matching_queue_dict: Dict[Track, Deque[MatchedFrameInstance]], + max_tracking: bool, t: int, img: np.ndarray, *args, @@ -404,7 +406,7 @@ def get_candidates( tracks = [] for track, matched_items in track_matching_queue_dict.items(): - if len(tracks) <= self.max_tracks: + if not max_tracking or len(tracks) < self.max_tracks: tracks.append(track) for matched_item in matched_items: ref_t, ref_img = ( @@ -466,6 +468,7 @@ class SimpleMaxTracksCandidateMaker(SimpleCandidateMaker): def get_candidates( self, track_matching_queue_dict: Dict, + max_tracking: bool, *args, **kwargs, ) -> List[InstanceType]: @@ -473,7 +476,7 @@ def get_candidates( candidate_instances = [] tracks = [] for track, matched_instances in track_matching_queue_dict.items(): - if len(tracks) <= self.max_tracks: + if not max_tracking or len(tracks) < self.max_tracks: tracks.append(track) for ref_instance in matched_instances: if ref_instance.instance_t.n_visible_points >= self.min_points: @@ -492,6 +495,7 @@ def get_candidates( instance=instance_similarity, centroid=centroid_distance, iou=instance_iou, + object_keypoint=instance_similarity, ) match_policies = dict( @@ -598,8 +602,15 @@ def _init_matching_queue(self): """Factory for instantiating default matching queue with specified size.""" return deque(maxlen=self.track_window) + @property + def has_max_tracking(self) -> bool: + return isinstance( + self.candidate_maker, + (SimpleMaxTracksCandidateMaker, FlowMaxTracksCandidateMaker), + ) + def reset_candidates(self): - if self.max_tracking: + if self.has_max_tracking: for track in self.track_matching_queue_dict: self.track_matching_queue_dict[track] = deque(maxlen=self.track_window) else: @@ -610,14 +621,15 @@ def unique_tracks_in_queue(self) -> List[Track]: """Returns the unique tracks in the matching queue.""" unique_tracks = set() - for match_item in self.track_matching_queue: - for instance in match_item.instances_t: - unique_tracks.add(instance.track) - - if self.max_tracking: + if self.has_max_tracking: for track in self.track_matching_queue_dict.keys(): unique_tracks.add(track) + else: + for match_item in self.track_matching_queue: + for instance in match_item.instances_t: + unique_tracks.add(instance.track) + return list(unique_tracks) @property @@ -646,7 +658,7 @@ def track( # Infer timestep if not provided. if t is None: - if self.max_tracking: + if self.has_max_tracking: if len(self.track_matching_queue_dict) > 0: # Default to last timestep + 1 if available. @@ -684,10 +696,10 @@ def track( self.pre_cull_function(untracked_instances) # Build a pool of matchable candidate instances. - if self.max_tracking: + if self.has_max_tracking: candidate_instances = self.candidate_maker.get_candidates( track_matching_queue_dict=self.track_matching_queue_dict, - max_tracks=self.max_tracks, + max_tracking=self.max_tracking, t=t, img=img, ) @@ -721,13 +733,16 @@ def track( ) # Add the tracked instances to the dictionary of matched instances. - if self.max_tracking: + if self.has_max_tracking: for tracked_instance in tracked_instances: if tracked_instance.track in self.track_matching_queue_dict: self.track_matching_queue_dict[tracked_instance.track].append( MatchedFrameInstance(t, tracked_instance, img) ) - elif len(self.track_matching_queue_dict) < self.max_tracks: + elif ( + not self.max_tracking + or len(self.track_matching_queue_dict) < self.max_tracks + ): self.track_matching_queue_dict[tracked_instance.track] = deque( maxlen=self.track_window ) @@ -773,7 +788,8 @@ def spawn_for_untracked_instances( # Skip if we've reached the maximum number of tracks. if ( - self.max_tracking + self.has_max_tracking + and self.max_tracking and len(self.track_matching_queue_dict) >= self.max_tracks ): break @@ -838,8 +854,17 @@ def make_tracker_by_name( # Max tracking options max_tracks: Optional[int] = None, max_tracking: bool = False, + # Object keypoint similarity options + oks_errors: Optional[list] = None, + oks_score_weighting: bool = False, + oks_normalization: str = "all", **kwargs, ) -> BaseTracker: + # Parse max_tracking arguments, only True if max_tracks is not None and > 0 + max_tracking = max_tracking if max_tracks else False + if max_tracking and tracker in ("simple", "flow"): + # Force a candidate maker of 'maxtracks' type + tracker += "maxtracks" if tracker.lower() == "none": candidate_maker = None @@ -858,7 +883,14 @@ def make_tracker_by_name( raise ValueError(f"{match} is not a valid tracker matching function.") candidate_maker = tracker_policies[tracker](min_points=min_match_points) - similarity_function = similarity_policies[similarity] + if similarity == "object_keypoint": + similarity_function = factory_object_keypoint_similarity( + keypoint_errors=oks_errors, + score_weighting=oks_score_weighting, + normalization_keypoints=oks_normalization, + ) + else: + similarity_function = similarity_policies[similarity] matching_function = match_policies[match] if tracker == "flow": @@ -931,7 +963,10 @@ def get_by_name_factory_options(cls): option = dict(name="max_tracking", default=False) option["type"] = bool - option["help"] = "If true then the tracker will cap the max number of tracks." + option["help"] = ( + "If true then the tracker will cap the max number of tracks. " + "Falls back to false if `max_tracks` is not defined or 0." + ) options.append(option) option = dict(name="max_tracks", default=None) @@ -1054,6 +1089,42 @@ def int_list_func(s): ] = "For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used." options.append(option) + def float_list_func(s): + return [float(x.strip()) for x in s.split(",")] if s else None + + option = dict(name="oks_errors", default="1") + option["type"] = float_list_func + option["help"] = ( + "For Object Keypoint similarity: the standard error of the distance " + "between the predicted keypoint and the true value, in pixels.\n" + "If None or empty list, defaults to 1. If a scalar or singleton list, " + "every keypoint has the same error. If a list, defines the error for each " + "keypoint, the length should be equal to the number of keypoints in the " + "skeleton." + ) + options.append(option) + + option = dict(name="oks_score_weighting", default="0") + option["type"] = int + option["help"] = ( + "For Object Keypoint similarity: if 0 (default), only the distance between the reference " + "and query keypoint is used to compute the similarity. If 1, each distance is weighted " + "by the prediction scores of the reference and query keypoint." + ) + options.append(option) + + option = dict(name="oks_normalization", default="all") + option["type"] = str + option["options"] = ["all", "ref", "union"] + option["help"] = ( + "For Object Keypoint similarity: Determine how to normalize similarity score. " + "If 'all', similarity score is normalized by number of reference points. " + "If 'ref', similarity score is normalized by number of visible reference points. " + "If 'union', similarity score is normalized by number of points both visible " + "in query and reference instance." + ) + options.append(option) + return options @classmethod diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index 801fcc092..ec5dfbc29 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -41,6 +41,13 @@ def centered_pair_predictions(): return Labels.load_file(TEST_JSON_PREDICTIONS) +@pytest.fixture +def centered_pair_predictions_sorted(centered_pair_predictions): + labels: Labels = centered_pair_predictions + labels.labeled_frames.sort(key=lambda lf: lf.frame_idx) + return labels + + @pytest.fixture def min_labels(): return Labels.load_file(TEST_JSON_MIN_LABELS) diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index f99f136ab..98f5fbcec 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -1373,7 +1373,7 @@ def test_retracking( # Create sleap-track command cmd = ( f"{slp_path} --tracking.tracker {tracker_method} --video.index 0 --frames 1-3 " - "--cpu" + "--tracking.similarity object_keypoint --cpu" ) if tracker_method == "flow": cmd += " --tracking.save_shifted_instances 1" @@ -1393,6 +1393,8 @@ def test_retracking( parser = _make_cli_parser() args, _ = parser.parse_known_args(args=args) tracker = _make_tracker_from_cli(args) + # Additional check for similarity method + assert tracker.similarity_function.__name__ == "object_keypoint_similarity" output_path = f"{slp_path}.{tracker.get_name()}.slp" # Assert tracked predictions file exists @@ -1747,9 +1749,9 @@ def test_sleap_track_invalid_input( sleap_track(args=args) -def test_flow_tracker(centered_pair_predictions: Labels, tmpdir): +def test_flow_tracker(centered_pair_predictions_sorted: Labels, tmpdir): """Test flow tracker instances are pruned.""" - labels: Labels = centered_pair_predictions + labels: Labels = centered_pair_predictions_sorted track_window = 5 # Setup tracker @@ -1759,7 +1761,7 @@ def test_flow_tracker(centered_pair_predictions: Labels, tmpdir): tracker.candidate_maker = cast(FlowCandidateMaker, tracker.candidate_maker) # Run tracking - frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx) + frames = labels.labeled_frames # Run tracking on subset of frames using psuedo-implementation of # sleap.nn.tracking.run_tracker diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index f861241ee..5786945fb 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -9,23 +9,79 @@ FrameMatches, greedy_matching, ) +from sleap.io.dataset import Labels from sleap.instance import PredictedInstance from sleap.skeleton import Skeleton +def tracker_by_name(frames=None, **kwargs): + t = Tracker.make_tracker_by_name(**kwargs) + print(kwargs) + print(t.candidate_maker) + if frames is None: + t.track([]) + t.final_pass([]) + return + + for lf in frames: + # Clear the tracks + for inst in lf.instances: + inst.track = None + + track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx]) + t.track(**track_args) + t.final_pass(frames) + + @pytest.mark.parametrize( "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"] ) @pytest.mark.parametrize("similarity", ["instance", "iou", "centroid"]) @pytest.mark.parametrize("match", ["greedy", "hungarian"]) @pytest.mark.parametrize("count", [0, 2]) -def test_tracker_by_name(tracker, similarity, match, count): - t = Tracker.make_tracker_by_name( - "flow", "instance", "greedy", clean_instance_count=2 +def test_tracker_by_name( + centered_pair_predictions_sorted, + tracker, + similarity, + match, + count, +): + # This is slow, so limit to 5 time points + frames = centered_pair_predictions_sorted[:5] + + tracker_by_name( + frames=frames, + tracker=tracker, + similarity=similarity, + match=match, + max_tracks=count, + ) + + +@pytest.mark.parametrize( + "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"] +) +@pytest.mark.parametrize("oks_score_weighting", ["True", "False"]) +@pytest.mark.parametrize("oks_normalization", ["all", "ref", "union"]) +def test_oks_tracker_by_name( + centered_pair_predictions_sorted, + tracker, + oks_score_weighting, + oks_normalization, +): + # This is slow, so limit to 5 time points + frames = centered_pair_predictions_sorted[:5] + + tracker_by_name( + frames=frames, + tracker=tracker, + similarity="object_keypoint", + matching="greedy", + oks_score_weighting=oks_score_weighting, + oks_normalization=oks_normalization, + max_tracks=2, ) - t.track([]) - t.final_pass([]) def test_cull_instances(centered_pair_predictions): From 1581506ce888647dfc2cff07569026137e60a45f Mon Sep 17 00:00:00 2001 From: gqcpm <63070177+gqcpm@users.noreply.github.com> Date: Wed, 24 Jul 2024 15:51:35 -0700 Subject: [PATCH 2/3] Generate suggestions using max point displacement threshold (#1862) * create function max_point_displacement, _max_point_displacement_video. Add to yaml file. Create test for new function . . . will need to edit * remove unnecessary for loop, calculate proper displacement, adjusted tests accordingly * Increase range for displacement threshold * Fix frames not found bug * Return the latter frame index * Lint --------- Co-authored-by: roomrys <38435167+roomrys@users.noreply.github.com> --- sleap/config/suggestions.yaml | 9 +++++- sleap/gui/suggestions.py | 52 +++++++++++++++++++++++++++++++++++ tests/gui/test_suggestions.py | 14 ++++++++++ 3 files changed, 74 insertions(+), 1 deletion(-) diff --git a/sleap/config/suggestions.yaml b/sleap/config/suggestions.yaml index 8cf89728a..1440530fc 100644 --- a/sleap/config/suggestions.yaml +++ b/sleap/config/suggestions.yaml @@ -3,7 +3,7 @@ main: label: Method type: stacked default: " " - options: " ,image features,sample,prediction score,velocity,frame chunk" + options: " ,image features,sample,prediction score,velocity,frame chunk,max point displacement" " ": sample: @@ -175,6 +175,13 @@ main: type: double default: 0.1 range: 0.1,1.0 + + "max point displacement": + - name: displacement_threshold + label: Maximum Displacement Threshold + type: int + default: 10 + range: 0,999 - name: target label: Target diff --git a/sleap/gui/suggestions.py b/sleap/gui/suggestions.py index 48b916437..b85d6ac32 100644 --- a/sleap/gui/suggestions.py +++ b/sleap/gui/suggestions.py @@ -61,6 +61,7 @@ def suggest(cls, params: dict, labels: "Labels" = None) -> List[SuggestionFrame] prediction_score=cls.prediction_score, velocity=cls.velocity, frame_chunk=cls.frame_chunk, + max_point_displacement=cls.max_point_displacement, ) method = str.replace(params["method"], " ", "_") @@ -213,6 +214,7 @@ def _prediction_score_video( ): lfs = labels.find(video) frames = len(lfs) + # initiate an array filled with -1 to store frame index (starting from 0). idxs = np.full((frames), -1, dtype="int") @@ -291,6 +293,56 @@ def _velocity_video( return cls.idx_list_to_frame_list(frame_idxs, video) + @classmethod + def max_point_displacement( + cls, + labels: "Labels", + videos: List[Video], + displacement_threshold: float, + **kwargs, + ): + """Finds frames with maximum point displacement above a threshold.""" + + proposed_suggestions = [] + for video in videos: + proposed_suggestions.extend( + cls._max_point_displacement_video(video, labels, displacement_threshold) + ) + + suggestions = VideoFrameSuggestions.filter_unique_suggestions( + labels, videos, proposed_suggestions + ) + + return suggestions + + @classmethod + def _max_point_displacement_video( + cls, video: Video, labels: "Labels", displacement_threshold: float + ): + # Get numpy of shape (frames, tracks, nodes, x, y) + labels_numpy = labels.numpy(video=video, all_frames=True, untracked=False) + + # Return empty list if not enough frames + n_frames, n_tracks, n_nodes, _ = labels_numpy.shape + + if n_frames < 2: + return [] + + # Calculate displacements + diff = labels_numpy[1:] - labels_numpy[:-1] # (frames - 1, tracks, nodes, x, y) + euc_norm = np.linalg.norm(diff, axis=-1) # (frames - 1, tracks, nodes) + mean_euc_norm = np.nanmean(euc_norm, axis=-1) # (frames - 1, tracks) + + # Find frames where mean displacement is above threshold + threshold_mask = np.any( + mean_euc_norm > displacement_threshold, axis=-1 + ) # (frames - 1,) + frame_idxs = list( + np.argwhere(threshold_mask).flatten() + 1 + ) # [0, len(frames - 1)] + + return cls.idx_list_to_frame_list(frame_idxs, video) + @classmethod def frame_chunk( cls, diff --git a/tests/gui/test_suggestions.py b/tests/gui/test_suggestions.py index bbad73179..196ff2d35 100644 --- a/tests/gui/test_suggestions.py +++ b/tests/gui/test_suggestions.py @@ -24,6 +24,20 @@ def test_velocity_suggestions(centered_pair_predictions): assert suggestions[1].frame_idx == 45 +def test_max_point_displacement_suggestions(centered_pair_predictions): + suggestions = VideoFrameSuggestions.suggest( + labels=centered_pair_predictions, + params=dict( + videos=centered_pair_predictions.videos, + method="max_point_displacement", + displacement_threshold=6, + ), + ) + assert len(suggestions) == 19 + assert suggestions[0].frame_idx == 28 + assert suggestions[1].frame_idx == 82 + + def test_frame_increment(centered_pair_predictions: Labels): # Testing videos that have less frames than desired Samples per Video (stride) # Expected result is there should be n suggestions where n is equal to the frames From 28c34e22e0cf78e1774476d0ac76c7ea0b4814fe Mon Sep 17 00:00:00 2001 From: Andrew Park Date: Thu, 25 Jul 2024 20:15:04 -0700 Subject: [PATCH 3/3] Added Three Different Cases for Adding a New Instance (#1859) * implemented paste with offset * right click and then default will paste the new instance at the location of the cursor * modified the logics for creating new instance * refined the logic * fixed the logic for right click * refined logics for adding new instance at a specific location * Remove print statements * Comment code * Ensure that we choose a non nan reference node * Move OOB nodes to closest in-bounds position --------- Co-authored-by: roomrys <38435167+roomrys@users.noreply.github.com> --- sleap/gui/commands.py | 61 ++++++++++++++++++++++++++++++++++++-- sleap/gui/widgets/video.py | 5 +++- 2 files changed, 62 insertions(+), 4 deletions(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 1a64a071c..8df85fc8e 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -2913,6 +2913,8 @@ def create_new_instance( copy_instance=copy_instance, new_instance=new_instance, mark_complete=mark_complete, + init_method=init_method, + location=location, ) if has_missing_nodes: @@ -2984,6 +2986,8 @@ def set_visible_nodes( copy_instance: Optional[Union[Instance, PredictedInstance]], new_instance: Instance, mark_complete: bool, + init_method: str, + location: Optional[QtCore.QPoint] = None, ) -> bool: """Sets visible nodes for new instance. @@ -3010,6 +3014,25 @@ def set_visible_nodes( scale_width = new_size_width / old_size_width scale_height = new_size_height / old_size_height + # Default the offset is 0 + offset_x = 0 + offset_y = 0 + + # Using the menu or the hotkey + if init_method == "best": + offset_x = 10 + offset_y = 10 + + # Using right click and context menu + if location is not None: + reference_node = next( + (node for node in copy_instance if not node.isnan()), None + ) + reference_x = reference_node.x + reference_y = reference_node.y + offset_x = location.x() - (reference_x * scale_width) + offset_y = location.y() - (reference_y * scale_height) + # Go through each node in skeleton. for node in context.state["skeleton"].node_names: # If we're copying from a skeleton that has this node. @@ -3018,13 +3041,45 @@ def set_visible_nodes( # We don't want to copy a PredictedPoint or score attribute. x_old = copy_instance[node].x y_old = copy_instance[node].y - x_new = x_old * scale_width - y_new = y_old * scale_height + # Copy the instance without scale or offset if predicted + if isinstance(copy_instance, PredictedInstance): + x_new = x_old + y_new = y_old + else: + x_new = x_old * scale_width + y_new = y_old * scale_height + + # Apply offset if in bounds + x_new_offset = x_new + offset_x + y_new_offset = y_new + offset_y + + # Default visibility is same as copied instance. + visible = copy_instance[node].visible + + # If the node is offset to outside the frame, mark as not visible. + if x_new_offset < 0: + x_new = 0 + visible = False + elif x_new_offset > new_size_width: + x_new = new_size_width + visible = False + else: + x_new = x_new_offset + if y_new_offset < 0: + y_new = 0 + visible = False + elif y_new_offset > new_size_height: + y_new = new_size_height + visible = False + else: + y_new = y_new_offset + + # Update the new instance with the new x, y, and visibility. new_instance[node] = Point( x=x_new, y=y_new, - visible=copy_instance[node].visible, + visible=visible, complete=mark_complete, ) else: diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py index 502ea388e..745908048 100644 --- a/sleap/gui/widgets/video.py +++ b/sleap/gui/widgets/video.py @@ -367,7 +367,10 @@ def show_contextual_menu(self, where: QtCore.QPoint): menu.addAction("Add Instance:").setEnabled(False) - menu.addAction("Default", lambda: self.context.newInstance(init_method="best")) + menu.addAction( + "Default", + lambda: self.context.newInstance(init_method="best", location=scene_pos), + ) menu.addAction( "Average",