Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/elise/add-csv-and-text-file-supp…
Browse files Browse the repository at this point in the history
…ort-to-sleap-track' into elise/add-csv-and-text-file-support-to-sleap-track
  • Loading branch information
emdavis02 committed Jul 30, 2024
2 parents 9860fe8 + 12ffe7e commit df65469
Show file tree
Hide file tree
Showing 12 changed files with 436 additions and 37 deletions.
44 changes: 38 additions & 6 deletions sleap/config/pipeline_form.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ training:
This pipeline uses two models: a "<u>centroid</u>" model to
locate and crop around each animal in the frame, and a
"<u>centered-instance confidence map</u>" model for predicted node locations
for each individual animal predicted by the centroid model.'
for each individual animal predicted by the centroid model.'
- label: Max Instances
name: max_instances
type: optional_int
Expand Down Expand Up @@ -217,7 +217,7 @@ training:
- name: controller_port
label: Controller Port
type: int
default: 9000
default: 9000
range: 1024,65535

- name: publish_port
Expand Down Expand Up @@ -388,7 +388,7 @@ inference:
tracking-only:

- name: batch_size
label: Batch Size
label: Batch Size
type: int
default: 4
range: 1,512
Expand Down Expand Up @@ -439,7 +439,7 @@ inference:
label: Similarity Method
type: list
default: instance
options: instance,centroid,iou
options: "instance,centroid,iou,object keypoint"
- name: tracking.match
label: Matching Method
type: list
Expand Down Expand Up @@ -478,6 +478,22 @@ inference:
label: Nodes to use for Tracking
type: string
default: 0,1,2
- type: text
text: '<b>Object keypoint similarity options</b>:<br />
Only used if this similarity method is selected.'
- name: tracking.oks_errors
label: Keypoints errors in pixels
help: 'Standard error in pixels of the distance for each keypoint.
If the list is empty, defaults to 1. If singleton list, each keypoint has
the same error. Otherwise, the length should be the same as the number of
keypoints in the skeleton.'
type: string
default:
- name: tracking.oks_score_weighting
label: Use prediction score for weighting
help: 'Use prediction scores to weight the similarity of each keypoint'
type: bool
default: false
- type: text
text: '<b>Post-tracker data cleaning</b>:'
- name: tracking.post_connect_single_breaks
Expand Down Expand Up @@ -521,8 +537,8 @@ inference:
- name: tracking.similarity
label: Similarity Method
type: list
default: iou
options: instance,centroid,iou
default: instance
options: "instance,centroid,iou,object keypoint"
- name: tracking.match
label: Matching Method
type: list
Expand Down Expand Up @@ -557,6 +573,22 @@ inference:
label: Nodes to use for Tracking
type: string
default: 0,1,2
- type: text
text: '<b>Object keypoint similarity options</b>:<br />
Only used if this similarity method is selected.'
- name: tracking.oks_errors
label: Keypoints errors in pixels
help: 'Standard error in pixels of the distance for each keypoint.
If the list is empty, defaults to 1. If singleton list, each keypoint has
the same error. Otherwise, the length should be the same as the number of
keypoints in the skeleton.'
type: string
default:
- name: tracking.oks_score_weighting
label: Use prediction score for weighting
help: 'Use prediction scores to weight the similarity of each keypoint'
type: bool
default: false
- type: text
text: '<b>Post-tracker data cleaning</b>:'
- name: tracking.post_connect_single_breaks
Expand Down
9 changes: 8 additions & 1 deletion sleap/config/suggestions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ main:
label: Method
type: stacked
default: " "
options: " ,image features,sample,prediction score,velocity,frame chunk"
options: " ,image features,sample,prediction score,velocity,frame chunk,max point displacement"
" ":

sample:
Expand Down Expand Up @@ -175,6 +175,13 @@ main:
type: double
default: 0.1
range: 0.1,1.0

"max point displacement":
- name: displacement_threshold
label: Maximum Displacement Threshold
type: int
default: 10
range: 0,999

- name: target
label: Target
Expand Down
61 changes: 58 additions & 3 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2913,6 +2913,8 @@ def create_new_instance(
copy_instance=copy_instance,
new_instance=new_instance,
mark_complete=mark_complete,
init_method=init_method,
location=location,
)

if has_missing_nodes:
Expand Down Expand Up @@ -2984,6 +2986,8 @@ def set_visible_nodes(
copy_instance: Optional[Union[Instance, PredictedInstance]],
new_instance: Instance,
mark_complete: bool,
init_method: str,
location: Optional[QtCore.QPoint] = None,
) -> bool:
"""Sets visible nodes for new instance.
Expand All @@ -3010,6 +3014,25 @@ def set_visible_nodes(
scale_width = new_size_width / old_size_width
scale_height = new_size_height / old_size_height

# Default the offset is 0
offset_x = 0
offset_y = 0

# Using the menu or the hotkey
if init_method == "best":
offset_x = 10
offset_y = 10

# Using right click and context menu
if location is not None:
reference_node = next(
(node for node in copy_instance if not node.isnan()), None
)
reference_x = reference_node.x
reference_y = reference_node.y
offset_x = location.x() - (reference_x * scale_width)
offset_y = location.y() - (reference_y * scale_height)

# Go through each node in skeleton.
for node in context.state["skeleton"].node_names:
# If we're copying from a skeleton that has this node.
Expand All @@ -3018,13 +3041,45 @@ def set_visible_nodes(
# We don't want to copy a PredictedPoint or score attribute.
x_old = copy_instance[node].x
y_old = copy_instance[node].y
x_new = x_old * scale_width
y_new = y_old * scale_height

# Copy the instance without scale or offset if predicted
if isinstance(copy_instance, PredictedInstance):
x_new = x_old
y_new = y_old
else:
x_new = x_old * scale_width
y_new = y_old * scale_height

# Apply offset if in bounds
x_new_offset = x_new + offset_x
y_new_offset = y_new + offset_y

# Default visibility is same as copied instance.
visible = copy_instance[node].visible

# If the node is offset to outside the frame, mark as not visible.
if x_new_offset < 0:
x_new = 0
visible = False
elif x_new_offset > new_size_width:
x_new = new_size_width
visible = False
else:
x_new = x_new_offset
if y_new_offset < 0:
y_new = 0
visible = False
elif y_new_offset > new_size_height:
y_new = new_size_height
visible = False
else:
y_new = y_new_offset

# Update the new instance with the new x, y, and visibility.
new_instance[node] = Point(
x=x_new,
y=y_new,
visible=copy_instance[node].visible,
visible=visible,
complete=mark_complete,
)
else:
Expand Down
8 changes: 8 additions & 0 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,20 @@ def make_predict_cli_call(
"tracking.max_tracking",
"tracking.post_connect_single_breaks",
"tracking.save_shifted_instances",
"tracking.oks_score_weighting",
)

for key in bool_items_as_ints:
if key in self.inference_params:
self.inference_params[key] = int(self.inference_params[key])

remove_spaces_items = ("tracking.similarity",)

for key in remove_spaces_items:
if key in self.inference_params:
value = self.inference_params[key]
self.inference_params[key] = value.replace(" ", "_")

for key, val in self.inference_params.items():
if not key.startswith(("_", "outputs.", "model.", "data.")):
cli_args.extend((f"--{key}", str(val)))
Expand Down
52 changes: 52 additions & 0 deletions sleap/gui/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def suggest(cls, params: dict, labels: "Labels" = None) -> List[SuggestionFrame]
prediction_score=cls.prediction_score,
velocity=cls.velocity,
frame_chunk=cls.frame_chunk,
max_point_displacement=cls.max_point_displacement,
)

method = str.replace(params["method"], " ", "_")
Expand Down Expand Up @@ -213,6 +214,7 @@ def _prediction_score_video(
):
lfs = labels.find(video)
frames = len(lfs)

# initiate an array filled with -1 to store frame index (starting from 0).
idxs = np.full((frames), -1, dtype="int")

Expand Down Expand Up @@ -291,6 +293,56 @@ def _velocity_video(

return cls.idx_list_to_frame_list(frame_idxs, video)

@classmethod
def max_point_displacement(
cls,
labels: "Labels",
videos: List[Video],
displacement_threshold: float,
**kwargs,
):
"""Finds frames with maximum point displacement above a threshold."""

proposed_suggestions = []
for video in videos:
proposed_suggestions.extend(
cls._max_point_displacement_video(video, labels, displacement_threshold)
)

suggestions = VideoFrameSuggestions.filter_unique_suggestions(
labels, videos, proposed_suggestions
)

return suggestions

@classmethod
def _max_point_displacement_video(
cls, video: Video, labels: "Labels", displacement_threshold: float
):
# Get numpy of shape (frames, tracks, nodes, x, y)
labels_numpy = labels.numpy(video=video, all_frames=True, untracked=False)

# Return empty list if not enough frames
n_frames, n_tracks, n_nodes, _ = labels_numpy.shape

if n_frames < 2:
return []

# Calculate displacements
diff = labels_numpy[1:] - labels_numpy[:-1] # (frames - 1, tracks, nodes, x, y)
euc_norm = np.linalg.norm(diff, axis=-1) # (frames - 1, tracks, nodes)
mean_euc_norm = np.nanmean(euc_norm, axis=-1) # (frames - 1, tracks)

# Find frames where mean displacement is above threshold
threshold_mask = np.any(
mean_euc_norm > displacement_threshold, axis=-1
) # (frames - 1,)
frame_idxs = list(
np.argwhere(threshold_mask).flatten() + 1
) # [0, len(frames - 1)]

return cls.idx_list_to_frame_list(frame_idxs, video)

@classmethod
def frame_chunk(
cls,
Expand Down
5 changes: 4 additions & 1 deletion sleap/gui/widgets/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,10 @@ def show_contextual_menu(self, where: QtCore.QPoint):

menu.addAction("Add Instance:").setEnabled(False)

menu.addAction("Default", lambda: self.context.newInstance(init_method="best"))
menu.addAction(
"Default",
lambda: self.context.newInstance(init_method="best", location=scene_pos),
)

menu.addAction(
"Average",
Expand Down
Loading

0 comments on commit df65469

Please sign in to comment.