Skip to content

Commit

Permalink
Generate suggestions using max point displacement threshold (#1862)
Browse files Browse the repository at this point in the history
* create function max_point_displacement, _max_point_displacement_video. Add to yaml file. Create test for new function . . . will need to edit

* remove unnecessary for loop, calculate proper displacement, adjusted tests accordingly

* Increase range for displacement threshold

* Fix frames not found bug

* Return the latter frame index

* Lint

---------

Co-authored-by: roomrys <[email protected]>
  • Loading branch information
gqcpm and roomrys authored Jul 24, 2024
1 parent 38a5ca7 commit 1581506
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 1 deletion.
9 changes: 8 additions & 1 deletion sleap/config/suggestions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ main:
label: Method
type: stacked
default: " "
options: " ,image features,sample,prediction score,velocity,frame chunk"
options: " ,image features,sample,prediction score,velocity,frame chunk,max point displacement"
" ":

sample:
Expand Down Expand Up @@ -175,6 +175,13 @@ main:
type: double
default: 0.1
range: 0.1,1.0

"max point displacement":
- name: displacement_threshold
label: Maximum Displacement Threshold
type: int
default: 10
range: 0,999

- name: target
label: Target
Expand Down
52 changes: 52 additions & 0 deletions sleap/gui/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def suggest(cls, params: dict, labels: "Labels" = None) -> List[SuggestionFrame]
prediction_score=cls.prediction_score,
velocity=cls.velocity,
frame_chunk=cls.frame_chunk,
max_point_displacement=cls.max_point_displacement,
)

method = str.replace(params["method"], " ", "_")
Expand Down Expand Up @@ -213,6 +214,7 @@ def _prediction_score_video(
):
lfs = labels.find(video)
frames = len(lfs)

# initiate an array filled with -1 to store frame index (starting from 0).
idxs = np.full((frames), -1, dtype="int")

Expand Down Expand Up @@ -291,6 +293,56 @@ def _velocity_video(

return cls.idx_list_to_frame_list(frame_idxs, video)

@classmethod
def max_point_displacement(
cls,
labels: "Labels",
videos: List[Video],
displacement_threshold: float,
**kwargs,
):
"""Finds frames with maximum point displacement above a threshold."""

proposed_suggestions = []
for video in videos:
proposed_suggestions.extend(
cls._max_point_displacement_video(video, labels, displacement_threshold)
)

suggestions = VideoFrameSuggestions.filter_unique_suggestions(
labels, videos, proposed_suggestions
)

return suggestions

@classmethod
def _max_point_displacement_video(
cls, video: Video, labels: "Labels", displacement_threshold: float
):
# Get numpy of shape (frames, tracks, nodes, x, y)
labels_numpy = labels.numpy(video=video, all_frames=True, untracked=False)

# Return empty list if not enough frames
n_frames, n_tracks, n_nodes, _ = labels_numpy.shape

if n_frames < 2:
return []

# Calculate displacements
diff = labels_numpy[1:] - labels_numpy[:-1] # (frames - 1, tracks, nodes, x, y)
euc_norm = np.linalg.norm(diff, axis=-1) # (frames - 1, tracks, nodes)
mean_euc_norm = np.nanmean(euc_norm, axis=-1) # (frames - 1, tracks)

# Find frames where mean displacement is above threshold
threshold_mask = np.any(
mean_euc_norm > displacement_threshold, axis=-1
) # (frames - 1,)
frame_idxs = list(
np.argwhere(threshold_mask).flatten() + 1
) # [0, len(frames - 1)]

return cls.idx_list_to_frame_list(frame_idxs, video)

@classmethod
def frame_chunk(
cls,
Expand Down
14 changes: 14 additions & 0 deletions tests/gui/test_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ def test_velocity_suggestions(centered_pair_predictions):
assert suggestions[1].frame_idx == 45


def test_max_point_displacement_suggestions(centered_pair_predictions):
suggestions = VideoFrameSuggestions.suggest(
labels=centered_pair_predictions,
params=dict(
videos=centered_pair_predictions.videos,
method="max_point_displacement",
displacement_threshold=6,
),
)
assert len(suggestions) == 19
assert suggestions[0].frame_idx == 28
assert suggestions[1].frame_idx == 82


def test_frame_increment(centered_pair_predictions: Labels):
# Testing videos that have less frames than desired Samples per Video (stride)
# Expected result is there should be n suggestions where n is equal to the frames
Expand Down

0 comments on commit 1581506

Please sign in to comment.