diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index caeaba605..3f78a2924 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -258,6 +258,7 @@ def make_predict_cli_call( bool_items_as_ints = ( "tracking.pre_cull_to_target", + "tracking.pre_cull_merge_instances", "tracking.max_tracking", "tracking.post_connect_single_breaks", "tracking.save_shifted_instances", diff --git a/sleap/instance.py b/sleap/instance.py index 67e96f330..299fdba17 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -18,12 +18,15 @@ """ import math +import operator +from functools import reduce +from itertools import chain, combinations import numpy as np import cattr from copy import copy -from typing import Dict, List, Optional, Union, Tuple, ForwardRef +from typing import Dict, List, Optional, Union, Sequence, Tuple, ForwardRef from numpy.lib.recfunctions import structured_to_unstructured @@ -1177,6 +1180,73 @@ def from_numpy( ) +def all_disjoint(x: Sequence[Sequence]) -> bool: + return all((set(p0).isdisjoint(set(p1))) for p0, p1 in combinations(x, 2)) + + +def create_merged_instances( + instances: List[PredictedInstance], + penalty: float = 0.2, +) -> List[PredictedInstance]: + """Create merged instances from the list of PredictedInstance. + + Only instances with non-overlapping visible nodes are merged. + + Args: + instances: a list of original PredictedInstances to try to merge. + penalty: a float between 0 and 1. All scores of the merged instance + are multplied by (1 - penalty). + + Returns: + a list of PredictedInstance that were merged. + """ + # Ensure same skeleton + skeletons = {inst.skeleton for inst in instances} + if len(skeletons) != 1: + return [] + skeleton = list(skeletons)[0] + + # Ensure same track + tracks = {inst.track for inst in instances} + if len(tracks) != 1: + return [] + track = list(tracks)[0] + + # Ensure non-intersecting visible nodes + merged_instances = [] + instance_subsets = chain( + *(combinations(instances, n) for n in range(2, len(instances) + 1)) + ) + for subset in instance_subsets: + if not all_disjoint([s.nodes for s in subset]): + continue + + nodes_points = [] + for instance in subset: + nodes_points.extend(list(instance.nodes_points)) + predicted_points = {node: point for node, point in nodes_points} + + instance_score = reduce(lambda x, y: x * y, [s.score for s in subset]) + + # Penalize scores of merged instances + if 0 < penalty <= 1: + factor = 1 - penalty + instance_score *= factor + for point in predicted_points.values(): + point.score *= factor + + merged_instance = PredictedInstance( + points=predicted_points, + skeleton=skeleton, + score=instance_score, + track=track, + ) + + merged_instances.append(merged_instance) + + return merged_instances + + def make_instance_cattr() -> cattr.Converter: """Create a cattr converter for Lists of Instances/PredictedInstances. diff --git a/sleap/nn/tracker/components.py b/sleap/nn/tracker/components.py index b2f35b21f..58b0b6029 100644 --- a/sleap/nn/tracker/components.py +++ b/sleap/nn/tracker/components.py @@ -12,6 +12,7 @@ """ + import operator from collections import defaultdict import logging @@ -23,6 +24,7 @@ from sleap import PredictedInstance, Instance, Track from sleap.nn import utils +from sleap.instance import create_merged_instances logger = logging.getLogger(__name__) @@ -249,7 +251,6 @@ def nms_fast(boxes, scores, iou_threshold, target_count=None) -> List[int]: # keep looping while some indexes still remain in the indexes list while len(idxs) > 0: - # we want to add the best box which is the last box in sorted list picked_box_idx = idxs[-1] @@ -351,6 +352,8 @@ def cull_frame_instances( instances_list: List[InstanceType], instance_count: int, iou_threshold: Optional[float] = None, + merge_instances: bool = False, + merging_penalty: float = 0.2, ) -> List["LabeledFrame"]: """ Removes instances (for single frame) over instance per frame threshold. @@ -361,6 +364,9 @@ def cull_frame_instances( iou_threshold: Intersection over Union (IOU) threshold to use when removing overlapping instances over target count; if None, then only use score to determine which instances to remove. + merge_instances: If True, allow merging instances with no overlapping + merging_penalty: a float between 0 and 1. All scores of the merged + instance are multplied by (1 - penalty). Returns: Updated list of frames, also modifies frames in place. @@ -368,6 +374,13 @@ def cull_frame_instances( if not instances_list: return + # Merge instances + if merge_instances: + merged_instances = create_merged_instances( + instances_list, penalty=merging_penalty + ) + instances_list.extend(merged_instances) + if len(instances_list) > instance_count: # List of instances which we'll pare down keep_instances = instances_list @@ -387,9 +400,10 @@ def cull_frame_instances( if len(keep_instances) > instance_count: # Sort by ascending score, get target number of instances # from the end of list (i.e., with highest score) - extra_instances = sorted(keep_instances, key=operator.attrgetter("score"))[ - :-instance_count - ] + extra_instances = sorted( + keep_instances, + key=operator.attrgetter("score"), + )[:-instance_count] # Remove the extra instances for inst in extra_instances: @@ -523,7 +537,6 @@ def from_candidate_instances( candidate_tracks = [] if candidate_instances: - # Group candidate instances by track. candidate_instances_by_track = defaultdict(list) for instance in candidate_instances: @@ -536,7 +549,6 @@ def from_candidate_instances( matching_similarities = np.full(dims, np.nan) for i, untracked_instance in enumerate(untracked_instances): - for j, candidate_track in enumerate(candidate_tracks): # Compute similarity between untracked instance and all track # candidates. diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index d10dca420..1891bde9c 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -933,6 +933,8 @@ def make_tracker_by_name( target_instance_count: int = 0, pre_cull_to_target: bool = False, pre_cull_iou_threshold: Optional[float] = None, + pre_cull_merge_instances: bool = False, + pre_cull_merging_penalty: float = 0.2, # Post-tracking options to connect broken tracks post_connect_single_breaks: bool = False, # TODO: deprecate these post-tracking cleaning options @@ -999,13 +1001,15 @@ def make_tracker_by_name( ) pre_cull_function = None - if target_instance_count and pre_cull_to_target: + if (target_instance_count and pre_cull_to_target) or pre_cull_merge_instances: def pre_cull_function(inst_list): cull_frame_instances( inst_list, instance_count=target_instance_count, iou_threshold=pre_cull_iou_threshold, + merge_instances=pre_cull_merge_instances, + merging_penalty=pre_cull_merging_penalty, ) tracker_obj = cls( @@ -1084,6 +1088,22 @@ def get_by_name_factory_options(cls): ) options.append(option) + option = dict(name="pre_cull_merge_instances", default=False) + option["type"] = bool + option["help"] = ( + "If True, allow merging instances with non-overlapping visible nodes " + "to create new instances *before* tracking." + ) + options.append(option) + + option = dict(name="pre_cull_merging_penalty", default=0.2) + option["type"] = float + option["help"] = ( + "A float between 0 and 1. All scores of the merged instances " + "are multplied by (1 - penalty)." + ) + options.append(option) + option = dict(name="post_connect_single_breaks", default=0) option["type"] = int option["help"] = (