Skip to content

Commit

Permalink
option to merge instances
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed Aug 2, 2024
1 parent f0e6319 commit ab24b50
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 8 deletions.
1 change: 1 addition & 0 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def make_predict_cli_call(

bool_items_as_ints = (
"tracking.pre_cull_to_target",
"tracking.pre_cull_merge_instances",
"tracking.max_tracking",
"tracking.post_connect_single_breaks",
"tracking.save_shifted_instances",
Expand Down
72 changes: 71 additions & 1 deletion sleap/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
"""

import math
import operator
from functools import reduce
from itertools import chain, combinations

import numpy as np
import cattr

from copy import copy
from typing import Dict, List, Optional, Union, Tuple, ForwardRef
from typing import Dict, List, Optional, Union, Sequence, Tuple, ForwardRef

from numpy.lib.recfunctions import structured_to_unstructured

Expand Down Expand Up @@ -1177,6 +1180,73 @@ def from_numpy(
)


def all_disjoint(x: Sequence[Sequence]) -> bool:
return all((set(p0).isdisjoint(set(p1))) for p0, p1 in combinations(x, 2))


def create_merged_instances(
instances: List[PredictedInstance],
penalty: float = 0.2,
) -> List[PredictedInstance]:
"""Create merged instances from the list of PredictedInstance.
Only instances with non-overlapping visible nodes are merged.
Args:
instances: a list of original PredictedInstances to try to merge.
penalty: a float between 0 and 1. All scores of the merged instance
are multplied by (1 - penalty).
Returns:
a list of PredictedInstance that were merged.
"""
# Ensure same skeleton
skeletons = {inst.skeleton for inst in instances}
if len(skeletons) != 1:
return []
skeleton = list(skeletons)[0]

# Ensure same track
tracks = {inst.track for inst in instances}
if len(tracks) != 1:
return []
track = list(tracks)[0]

# Ensure non-intersecting visible nodes
merged_instances = []
instance_subsets = chain(
*(combinations(instances, n) for n in range(2, len(instances) + 1))
)
for subset in instance_subsets:
if not all_disjoint([s.nodes for s in subset]):
continue

nodes_points = []
for instance in subset:
nodes_points.extend(list(instance.nodes_points))
predicted_points = {node: point for node, point in nodes_points}

instance_score = reduce(lambda x, y: x * y, [s.score for s in subset])

# Penalize scores of merged instances
if 0 < penalty <= 1:
factor = 1 - penalty
instance_score *= factor
for point in predicted_points.values():
point.score *= factor

merged_instance = PredictedInstance(
points=predicted_points,
skeleton=skeleton,
score=instance_score,
track=track,
)

merged_instances.append(merged_instance)

return merged_instances


def make_instance_cattr() -> cattr.Converter:
"""Create a cattr converter for Lists of Instances/PredictedInstances.
Expand Down
24 changes: 18 additions & 6 deletions sleap/nn/tracker/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"""

import operator
from collections import defaultdict
import logging
Expand All @@ -23,6 +24,7 @@

from sleap import PredictedInstance, Instance, Track
from sleap.nn import utils
from sleap.instance import create_merged_instances

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -249,7 +251,6 @@ def nms_fast(boxes, scores, iou_threshold, target_count=None) -> List[int]:

# keep looping while some indexes still remain in the indexes list
while len(idxs) > 0:

# we want to add the best box which is the last box in sorted list
picked_box_idx = idxs[-1]

Expand Down Expand Up @@ -351,6 +352,8 @@ def cull_frame_instances(
instances_list: List[InstanceType],
instance_count: int,
iou_threshold: Optional[float] = None,
merge_instances: bool = False,
merging_penalty: float = 0.2,
) -> List["LabeledFrame"]:
"""
Removes instances (for single frame) over instance per frame threshold.
Expand All @@ -361,13 +364,23 @@ def cull_frame_instances(
iou_threshold: Intersection over Union (IOU) threshold to use when
removing overlapping instances over target count; if None, then
only use score to determine which instances to remove.
merge_instances: If True, allow merging instances with no overlapping
merging_penalty: a float between 0 and 1. All scores of the merged
instance are multplied by (1 - penalty).
Returns:
Updated list of frames, also modifies frames in place.
"""
if not instances_list:
return

# Merge instances
if merge_instances:
merged_instances = create_merged_instances(
instances_list, penalty=merging_penalty
)
instances_list.extend(merged_instances)

if len(instances_list) > instance_count:
# List of instances which we'll pare down
keep_instances = instances_list
Expand All @@ -387,9 +400,10 @@ def cull_frame_instances(
if len(keep_instances) > instance_count:
# Sort by ascending score, get target number of instances
# from the end of list (i.e., with highest score)
extra_instances = sorted(keep_instances, key=operator.attrgetter("score"))[
:-instance_count
]
extra_instances = sorted(
keep_instances,
key=operator.attrgetter("score"),
)[:-instance_count]

# Remove the extra instances
for inst in extra_instances:
Expand Down Expand Up @@ -523,7 +537,6 @@ def from_candidate_instances(
candidate_tracks = []

if candidate_instances:

# Group candidate instances by track.
candidate_instances_by_track = defaultdict(list)
for instance in candidate_instances:
Expand All @@ -536,7 +549,6 @@ def from_candidate_instances(
matching_similarities = np.full(dims, np.nan)

for i, untracked_instance in enumerate(untracked_instances):

for j, candidate_track in enumerate(candidate_tracks):
# Compute similarity between untracked instance and all track
# candidates.
Expand Down
22 changes: 21 additions & 1 deletion sleap/nn/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,8 @@ def make_tracker_by_name(
target_instance_count: int = 0,
pre_cull_to_target: bool = False,
pre_cull_iou_threshold: Optional[float] = None,
pre_cull_merge_instances: bool = False,
pre_cull_merging_penalty: float = 0.2,
# Post-tracking options to connect broken tracks
post_connect_single_breaks: bool = False,
# TODO: deprecate these post-tracking cleaning options
Expand Down Expand Up @@ -999,13 +1001,15 @@ def make_tracker_by_name(
)

pre_cull_function = None
if target_instance_count and pre_cull_to_target:
if (target_instance_count and pre_cull_to_target) or pre_cull_merge_instances:

def pre_cull_function(inst_list):
cull_frame_instances(
inst_list,
instance_count=target_instance_count,
iou_threshold=pre_cull_iou_threshold,
merge_instances=pre_cull_merge_instances,
merging_penalty=pre_cull_merging_penalty,
)

tracker_obj = cls(
Expand Down Expand Up @@ -1084,6 +1088,22 @@ def get_by_name_factory_options(cls):
)
options.append(option)

option = dict(name="pre_cull_merge_instances", default=False)
option["type"] = bool
option["help"] = (
"If True, allow merging instances with non-overlapping visible nodes "
"to create new instances *before* tracking."
)
options.append(option)

option = dict(name="pre_cull_merging_penalty", default=0.2)
option["type"] = float
option["help"] = (
"A float between 0 and 1. All scores of the merged instances "
"are multplied by (1 - penalty)."
)
options.append(option)

option = dict(name="post_connect_single_breaks", default=0)
option["type"] = int
option["help"] = (
Expand Down

0 comments on commit ab24b50

Please sign in to comment.