diff --git a/sleap/config/suggestions.yaml b/sleap/config/suggestions.yaml index 8cf89728a..1440530fc 100644 --- a/sleap/config/suggestions.yaml +++ b/sleap/config/suggestions.yaml @@ -3,7 +3,7 @@ main: label: Method type: stacked default: " " - options: " ,image features,sample,prediction score,velocity,frame chunk" + options: " ,image features,sample,prediction score,velocity,frame chunk,max point displacement" " ": sample: @@ -175,6 +175,13 @@ main: type: double default: 0.1 range: 0.1,1.0 + + "max point displacement": + - name: displacement_threshold + label: Maximum Displacement Threshold + type: int + default: 10 + range: 0,999 - name: target label: Target diff --git a/sleap/gui/suggestions.py b/sleap/gui/suggestions.py index 48b916437..b85d6ac32 100644 --- a/sleap/gui/suggestions.py +++ b/sleap/gui/suggestions.py @@ -61,6 +61,7 @@ def suggest(cls, params: dict, labels: "Labels" = None) -> List[SuggestionFrame] prediction_score=cls.prediction_score, velocity=cls.velocity, frame_chunk=cls.frame_chunk, + max_point_displacement=cls.max_point_displacement, ) method = str.replace(params["method"], " ", "_") @@ -213,6 +214,7 @@ def _prediction_score_video( ): lfs = labels.find(video) frames = len(lfs) + # initiate an array filled with -1 to store frame index (starting from 0). idxs = np.full((frames), -1, dtype="int") @@ -291,6 +293,56 @@ def _velocity_video( return cls.idx_list_to_frame_list(frame_idxs, video) + @classmethod + def max_point_displacement( + cls, + labels: "Labels", + videos: List[Video], + displacement_threshold: float, + **kwargs, + ): + """Finds frames with maximum point displacement above a threshold.""" + + proposed_suggestions = [] + for video in videos: + proposed_suggestions.extend( + cls._max_point_displacement_video(video, labels, displacement_threshold) + ) + + suggestions = VideoFrameSuggestions.filter_unique_suggestions( + labels, videos, proposed_suggestions + ) + + return suggestions + + @classmethod + def _max_point_displacement_video( + cls, video: Video, labels: "Labels", displacement_threshold: float + ): + # Get numpy of shape (frames, tracks, nodes, x, y) + labels_numpy = labels.numpy(video=video, all_frames=True, untracked=False) + + # Return empty list if not enough frames + n_frames, n_tracks, n_nodes, _ = labels_numpy.shape + + if n_frames < 2: + return [] + + # Calculate displacements + diff = labels_numpy[1:] - labels_numpy[:-1] # (frames - 1, tracks, nodes, x, y) + euc_norm = np.linalg.norm(diff, axis=-1) # (frames - 1, tracks, nodes) + mean_euc_norm = np.nanmean(euc_norm, axis=-1) # (frames - 1, tracks) + + # Find frames where mean displacement is above threshold + threshold_mask = np.any( + mean_euc_norm > displacement_threshold, axis=-1 + ) # (frames - 1,) + frame_idxs = list( + np.argwhere(threshold_mask).flatten() + 1 + ) # [0, len(frames - 1)] + + return cls.idx_list_to_frame_list(frame_idxs, video) + @classmethod def frame_chunk( cls, diff --git a/tests/gui/test_suggestions.py b/tests/gui/test_suggestions.py index bbad73179..196ff2d35 100644 --- a/tests/gui/test_suggestions.py +++ b/tests/gui/test_suggestions.py @@ -24,6 +24,20 @@ def test_velocity_suggestions(centered_pair_predictions): assert suggestions[1].frame_idx == 45 +def test_max_point_displacement_suggestions(centered_pair_predictions): + suggestions = VideoFrameSuggestions.suggest( + labels=centered_pair_predictions, + params=dict( + videos=centered_pair_predictions.videos, + method="max_point_displacement", + displacement_threshold=6, + ), + ) + assert len(suggestions) == 19 + assert suggestions[0].frame_idx == 28 + assert suggestions[1].frame_idx == 82 + + def test_frame_increment(centered_pair_predictions: Labels): # Testing videos that have less frames than desired Samples per Video (stride) # Expected result is there should be n suggestions where n is equal to the frames