From 2789b613a41ead7276ddfc51f117790d62bdf95d Mon Sep 17 00:00:00 2001 From: grquach <101067674+grquach@users.noreply.github.com> Date: Wed, 10 Jul 2024 11:18:23 -0700 Subject: [PATCH] create function max_point_displacement, _max_point_displacement_video. Add to yaml file. Create test for new function . . . will need to edit --- sleap/config/suggestions.yaml | 8 +++++- sleap/gui/suggestions.py | 52 +++++++++++++++++++++++++++++++++++ tests/gui/test_suggestions.py | 13 +++++++++ 3 files changed, 72 insertions(+), 1 deletion(-) diff --git a/sleap/config/suggestions.yaml b/sleap/config/suggestions.yaml index 8cf89728a..611f25972 100644 --- a/sleap/config/suggestions.yaml +++ b/sleap/config/suggestions.yaml @@ -3,7 +3,7 @@ main: label: Method type: stacked default: " " - options: " ,image features,sample,prediction score,velocity,frame chunk" + options: " ,image features,sample,prediction score,velocity,frame chunk,max point displacement" " ": sample: @@ -175,6 +175,12 @@ main: type: double default: 0.1 range: 0.1,1.0 + + "max point displacement": + - name: per_video + label: Threshold + type: int + default: 10 - name: target label: Target diff --git a/sleap/gui/suggestions.py b/sleap/gui/suggestions.py index 48b916437..54f59843f 100644 --- a/sleap/gui/suggestions.py +++ b/sleap/gui/suggestions.py @@ -61,6 +61,7 @@ def suggest(cls, params: dict, labels: "Labels" = None) -> List[SuggestionFrame] prediction_score=cls.prediction_score, velocity=cls.velocity, frame_chunk=cls.frame_chunk, + max_point_displacement = cls.max_point_displacement, ) method = str.replace(params["method"], " ", "_") @@ -291,6 +292,57 @@ def _velocity_video( return cls.idx_list_to_frame_list(frame_idxs, video) + def max_point_displacement( + cls, + labels: "Labels", + videos: List[Video], + displacement_threshold: float, + **kwargs, + ): + """Finds frames with maximum point displacement above a threshold.""" + + proposed_suggestions = [] + for video in videos: + proposed_suggestions.extend( + cls._max_point_displacement_video(video, labels, displacement_threshold) + ) + + suggestions = VideoFrameSuggestions.filter_unique_suggestions( + labels, videos, proposed_suggestions + ) + + return suggestions + + @classmethod + def _max_point_displacement_video( + cls, video: Video, labels: "Labels", displacement_threshold: float + ): + lfs = labels.find(video) + frames = len(lfs) + + if frames < 2: + return [] + + displacements = [] + for i in range(1, frames): + prev_lf = lfs[i - 1] + curr_lf = lfs[i] + prev_points = np.array([inst.points_array for inst in prev_lf.instances_to_show]) + curr_points = np.array([inst.points_array for inst in curr_lf.instances_to_show]) + + if prev_points.shape != curr_points.shape: + continue + + displacement = np.linalg.norm(curr_points - prev_points, axis=2).sum() + displacements.append((displacement, curr_lf.frame_idx)) + + frame_idxs = [ + frame_idx for displacement, frame_idx in displacements if displacement > displacement_threshold + ] + + return cls.idx_list_to_frame_list(frame_idxs, video) + + @classmethod def frame_chunk( cls, diff --git a/tests/gui/test_suggestions.py b/tests/gui/test_suggestions.py index bbad73179..e6a62d8fc 100644 --- a/tests/gui/test_suggestions.py +++ b/tests/gui/test_suggestions.py @@ -23,6 +23,19 @@ def test_velocity_suggestions(centered_pair_predictions): assert suggestions[0].frame_idx == 21 assert suggestions[1].frame_idx == 45 +# something like this? +def test_max_point_displacement_suggestions(centered_pair_predictions): + suggestions = VideoFrameSuggestions.suggest( + labels=centered_pair_predictions, + params=dict( + videos=centered_pair_predictions.videos, + method="max_point_displacement", + displacement_threshold = 3 + ), + ) + assert len(suggestions) == 45 + assert suggestions[0].frame_idx == 21 + assert suggestions[1].frame_idx == 45 def test_frame_increment(centered_pair_predictions: Labels): # Testing videos that have less frames than desired Samples per Video (stride)