-
Notifications
You must be signed in to change notification settings - Fork 101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generate suggestions using max point displacement threshold #1862
Changes from all commits
2789b61
bee834d
a17e5c8
8cc046c
3baf219
61fe572
a883246
e5ce91f
d8a6335
7e2ef29
47a8d06
9dddd6b
078c153
b0c47af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -3,7 +3,7 @@ main: | |||||||||||||||||||||||||||||
label: Method | ||||||||||||||||||||||||||||||
type: stacked | ||||||||||||||||||||||||||||||
default: " " | ||||||||||||||||||||||||||||||
options: " ,image features,sample,prediction score,velocity,frame chunk" | ||||||||||||||||||||||||||||||
options: " ,image features,sample,prediction score,velocity,frame chunk,max point displacement" | ||||||||||||||||||||||||||||||
" ": | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
sample: | ||||||||||||||||||||||||||||||
|
@@ -175,6 +175,13 @@ main: | |||||||||||||||||||||||||||||
type: double | ||||||||||||||||||||||||||||||
default: 0.1 | ||||||||||||||||||||||||||||||
range: 0.1,1.0 | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
"max point displacement": | ||||||||||||||||||||||||||||||
- name: displacement_threshold | ||||||||||||||||||||||||||||||
label: Maximum Displacement Threshold | ||||||||||||||||||||||||||||||
type: int | ||||||||||||||||||||||||||||||
default: 10 | ||||||||||||||||||||||||||||||
range: 0,999 | ||||||||||||||||||||||||||||||
Comment on lines
+178
to
+184
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove trailing spaces. Trailing spaces are unnecessary and should be removed for clean code. -
+ Committable suggestion
Suggested change
Toolsyamllint
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
- name: target | ||||||||||||||||||||||||||||||
label: Target | ||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -61,6 +61,7 @@ def suggest(cls, params: dict, labels: "Labels" = None) -> List[SuggestionFrame] | |||||||||||||
prediction_score=cls.prediction_score, | ||||||||||||||
velocity=cls.velocity, | ||||||||||||||
frame_chunk=cls.frame_chunk, | ||||||||||||||
max_point_displacement=cls.max_point_displacement, | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
method = str.replace(params["method"], " ", "_") | ||||||||||||||
|
@@ -213,6 +214,7 @@ def _prediction_score_video( | |||||||||||||
): | ||||||||||||||
lfs = labels.find(video) | ||||||||||||||
frames = len(lfs) | ||||||||||||||
|
||||||||||||||
# initiate an array filled with -1 to store frame index (starting from 0). | ||||||||||||||
idxs = np.full((frames), -1, dtype="int") | ||||||||||||||
|
||||||||||||||
|
@@ -291,6 +293,56 @@ def _velocity_video( | |||||||||||||
|
||||||||||||||
return cls.idx_list_to_frame_list(frame_idxs, video) | ||||||||||||||
|
||||||||||||||
@classmethod | ||||||||||||||
def max_point_displacement( | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In order to access this as a sleap/sleap/gui/suggestions.py Line 64 in 2789b61
, make sure to wrap it with:
Suggested change
|
||||||||||||||
cls, | ||||||||||||||
labels: "Labels", | ||||||||||||||
videos: List[Video], | ||||||||||||||
displacement_threshold: float, | ||||||||||||||
**kwargs, | ||||||||||||||
): | ||||||||||||||
"""Finds frames with maximum point displacement above a threshold.""" | ||||||||||||||
|
||||||||||||||
proposed_suggestions = [] | ||||||||||||||
for video in videos: | ||||||||||||||
proposed_suggestions.extend( | ||||||||||||||
cls._max_point_displacement_video(video, labels, displacement_threshold) | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
suggestions = VideoFrameSuggestions.filter_unique_suggestions( | ||||||||||||||
labels, videos, proposed_suggestions | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
return suggestions | ||||||||||||||
|
||||||||||||||
@classmethod | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix undefined name The - cls, video: Video, labels: "Labels", displacement_threshold: float
+ cls, video: Video, labels: "Labels", displacement_threshold: float
+ ):
+ from sleap.io.dataset import Labels Committable suggestion
Suggested change
|
||||||||||||||
def _max_point_displacement_video( | ||||||||||||||
cls, video: Video, labels: "Labels", displacement_threshold: float | ||||||||||||||
): | ||||||||||||||
# Get numpy of shape (frames, tracks, nodes, x, y) | ||||||||||||||
labels_numpy = labels.numpy(video=video, all_frames=True, untracked=False) | ||||||||||||||
|
||||||||||||||
# Return empty list if not enough frames | ||||||||||||||
n_frames, n_tracks, n_nodes, _ = labels_numpy.shape | ||||||||||||||
|
||||||||||||||
if n_frames < 2: | ||||||||||||||
return [] | ||||||||||||||
|
||||||||||||||
# Calculate displacements | ||||||||||||||
diff = labels_numpy[1:] - labels_numpy[:-1] # (frames - 1, tracks, nodes, x, y) | ||||||||||||||
euc_norm = np.linalg.norm(diff, axis=-1) # (frames - 1, tracks, nodes) | ||||||||||||||
mean_euc_norm = np.nanmean(euc_norm, axis=-1) # (frames - 1, tracks) | ||||||||||||||
|
||||||||||||||
# Find frames where mean displacement is above threshold | ||||||||||||||
threshold_mask = np.any( | ||||||||||||||
mean_euc_norm > displacement_threshold, axis=-1 | ||||||||||||||
) # (frames - 1,) | ||||||||||||||
frame_idxs = list( | ||||||||||||||
np.argwhere(threshold_mask).flatten() + 1 | ||||||||||||||
) # [0, len(frames - 1)] | ||||||||||||||
|
||||||||||||||
return cls.idx_list_to_frame_list(frame_idxs, video) | ||||||||||||||
|
||||||||||||||
Comment on lines
+318
to
+345
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handle The suggested approach handles
# Get numpy of shape (frames, tracks, nodes, x, y)
labels_numpy = labels.numpy(video=video, all_frames=True, untracked=False)
# Return empty list if not enough frames
n_frames, n_tracks, n_nodes, _ = labels_numpy.shape
if n_frames < 2:
return []
# Calculate displacements
diff = labels_numpy[1:] - labels_numpy[:-1] # (frames - 1, tracks, nodes, x, y)
euc_norm = np.linalg.norm(diff, axis=-1) # (frames - 1, tracks, nodes)
mean_euc_norm = np.nanmean(euc_norm, axis=-1) # (frames - 1, tracks)
# Find frames where mean displacement is above threshold
threshold_mask = np.any(
mean_euc_norm > displacement_threshold, axis=-1
) # (frames - 1,)
frame_idxs = list(
np.argwhere(threshold_mask).flatten() + 1
) # [0, len(frames - 1)]
return cls.idx_list_to_frame_list(frame_idxs, video)
ToolsRuff
|
||||||||||||||
@classmethod | ||||||||||||||
def frame_chunk( | ||||||||||||||
cls, | ||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove trailing spaces.
Trailing spaces are unnecessary and should be removed for clean code.
Committable suggestion
Tools
yamllint