From c656bb8180e8d55483fad07c68fe0d71959de92d Mon Sep 17 00:00:00 2001 From: getzze Date: Mon, 2 Sep 2024 16:57:13 +0100 Subject: [PATCH] remove max_tracking argument, just use max_tracks --- docs/guides/cli.md | 10 ++++------ sleap/config/pipeline_form.yaml | 4 ---- sleap/gui/learning/runners.py | 6 ------ sleap/nn/inference.py | 4 ---- sleap/nn/tracking.py | 16 ++++++---------- tests/nn/test_inference.py | 3 --- tests/nn/test_tracker_components.py | 18 ++++++------------ tests/nn/test_tracking_integration.py | 7 ++----- 8 files changed, 18 insertions(+), 50 deletions(-) diff --git a/docs/guides/cli.md b/docs/guides/cli.md index 1bc666fc1..b8b8762df 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -124,7 +124,7 @@ usage: sleap-track [-h] [-m MODELS] [--frames FRAMES] [--only-labeled-frames] [- [--verbosity {none,rich,json}] [--video.dataset VIDEO.DATASET] [--video.input_format VIDEO.INPUT_FORMAT] [--video.index VIDEO.INDEX] [--cpu | --first-gpu | --last-gpu | --gpu GPU] [--max_edge_length_ratio MAX_EDGE_LENGTH_RATIO] [--dist_penalty_weight DIST_PENALTY_WEIGHT] [--batch_size BATCH_SIZE] [--open-in-gui] [--peak_threshold PEAK_THRESHOLD] - [-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER] [--tracking.max_tracking TRACKING.MAX_TRACKING] + [-n MAX_INSTANCES] [--tracking.tracker TRACKING.TRACKER] [--tracking.max_tracks TRACKING.MAX_TRACKS] [--tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT] [--tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET] [--tracking.pre_cull_iou_threshold TRACKING.PRE_CULL_IOU_THRESHOLD] [--tracking.post_connect_single_breaks TRACKING.POST_CONNECT_SINGLE_BREAKS] @@ -184,10 +184,8 @@ optional arguments: Limit maximum number of instances in multi-instance models. Not available for ID models. Defaults to None. --tracking.tracker TRACKING.TRACKER Options: simple, flow, simplemaxtracks, flowmaxtracks, None (default: None) - --tracking.max_tracking TRACKING.MAX_TRACKING - If true then the tracker will cap the max number of tracks. (default: False) --tracking.max_tracks TRACKING.MAX_TRACKS - Maximum number of tracks to be tracked by the tracker. (default: None) + Maximum number of tracks to be tracked by the tracker. No limit if None or -1. (default: None) --tracking.target_instance_count TRACKING.TARGET_INSTANCE_COUNT Target number of instances to track per frame. (default: 0) --tracking.pre_cull_to_target TRACKING.PRE_CULL_TO_TARGET @@ -261,13 +259,13 @@ sleap-track -m "models/my_model" --tracking.tracker simple -o "output_prediction **5. Inference with max tracks limit:** ```none -sleap-track -m "models/my_model" --tracking.tracker simple --tracking.max_tracking 1 --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4" +sleap-track -m "models/my_model" --tracking.tracker simple --tracking.max_tracks 4 -o "output_predictions.slp" "input_video.mp4" ``` **6. Re-tracking without pose inference:** ```none -sleap-track --tracking.tracker simple --tracking.max_tracking 1 --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp" +sleap-track --tracking.tracker simple --tracking.max_tracks 4 -o "retracked.slp" "input_predictions.slp" ``` **7. Select GPU for pose inference:** diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index c34faea55..406a02ea8 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -521,10 +521,6 @@ inference: text: 'Tracking:
This tracker assigns track identities by matching instances from prior frames to instances on subsequent frames.' - # - name: tracking.max_tracking - # label: Limit max number of tracks - # type: bool - # default: false - name: tracking.max_tracks label: Max number of tracks type: optional_int diff --git a/sleap/gui/learning/runners.py b/sleap/gui/learning/runners.py index 3f78a2924..fb2d799e0 100644 --- a/sleap/gui/learning/runners.py +++ b/sleap/gui/learning/runners.py @@ -244,11 +244,6 @@ def make_predict_cli_call( if self.inference_params["tracking.tracker"] in compat_trackers: tname = self.inference_params["tracking.tracker"][: -len("maxtracks")] self.inference_params["tracking.tracker"] = tname - self.inference_params["tracking.max_tracking"] = True - - # Setting max_tracks to a value means we want to use the max_tracking mode. - if self.inference_params.get("tracking.max_tracks") is not None: - self.inference_params["tracking.max_tracking"] = True # --tracking.kf_init_frame_count enables the kalman filter tracking # so if not set, then remove other (unused) args @@ -259,7 +254,6 @@ def make_predict_cli_call( bool_items_as_ints = ( "tracking.pre_cull_to_target", "tracking.pre_cull_merge_instances", - "tracking.max_tracking", "tracking.post_connect_single_breaks", "tracking.save_shifted_instances", "tracking.oks_score_weighting", diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 75d2a1ae2..dc7d19f69 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -4914,14 +4914,10 @@ def unpack_sleap_model(model_path): ) predictor.verbosity = progress_reporting if tracker is not None: - use_max_tracker = ( - tracker_max_instances is not None and tracker_max_instances > 0 - ) predictor.tracker = Tracker.make_tracker_by_name( tracker=tracker, track_window=tracker_window, post_connect_single_breaks=True, - max_tracking=use_max_tracker, max_tracks=tracker_max_instances, # clean_instance_count=tracker_max_instances, ) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 1891bde9c..7ecf8fbcb 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -945,7 +945,6 @@ def make_tracker_by_name( kf_node_indices: Optional[list] = None, # Max tracking options max_tracks: Optional[int] = None, - max_tracking: bool = False, prefer_reassigning_track: bool = False, allow_reassigning_track: bool = False, # Object keypoint similarity options @@ -956,9 +955,8 @@ def make_tracker_by_name( report_rate: float = 2.0, **kwargs, ) -> BaseTracker: - # Parse max_tracking arguments, only True if max_tracks is not None and > 0 - max_tracking = max_tracking and max_tracks is not None and max_tracks > 0 - max_tracks = max_tracks if max_tracking else -1 + # Parse max_tracks, set to -1 if None + max_tracks = max_tracks if max_tracks is not None and max_tracks >= 0 else -1 if tracker.lower() == "none": candidate_maker = None @@ -1056,14 +1054,12 @@ def get_by_name_factory_options(cls): ] options.append(option) - option = dict(name="max_tracking", default=False) - option["type"] = bool - option["help"] = "If true then the tracker will cap the max number of tracks." - options.append(option) - option = dict(name="max_tracks", default=None) option["type"] = int - option["help"] = "Maximum number of tracks to be tracked by the tracker." + option["help"] = ( + "Maximum number of tracks to be tracked by the tracker. " + "No maximum if set to -1." + ) options.append(option) option = dict(name="target_instance_count", default=0) diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index 8e24145bf..e474071f3 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -1377,7 +1377,6 @@ def test_retracking( if tracker_method == "flow": cmd += " --tracking.save_shifted_instances 1" elif tracker_method == "simplemaxtracks" or tracker_method == "flowmaxtracks": - cmd += " --tracking.max_tracking 1" cmd += " --tracking.max_tracks 2" if output_path == "not_default": output_path = Path(tmpdir, "tracked_slp.slp") @@ -1790,7 +1789,6 @@ def test_max_tracks_matching_queue( ): """Test flow max tracks instance generation.""" labels: Labels = centered_pair_predictions - max_tracking = True track_window = 5 # Setup flow max tracker @@ -1798,7 +1796,6 @@ def test_max_tracks_matching_queue( tracker=trackername, track_window=track_window, save_shifted_instances=True, - max_tracking=max_tracking, max_tracks=max_tracks, ) diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index 787f7747d..7eb522791 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -245,7 +245,7 @@ def make_inst(x, y): return insts -def test_max_tracking_large_gap_single_track(): +def test_max_tracks_large_gap_single_track(): # Track 2 instances with gap > window size preds = make_insts( [ @@ -282,8 +282,7 @@ def test_max_tracking_large_gap_single_track(): tracker="simple", match="hungarian", track_window=2, - # max_tracks=2, - max_tracking=False, + max_tracks=-1, ) tracked = [] @@ -299,7 +298,6 @@ def test_max_tracking_large_gap_single_track(): match="hungarian", track_window=2, max_tracks=2, - max_tracking=True, ) tracked = [] @@ -311,7 +309,7 @@ def test_max_tracking_large_gap_single_track(): assert len(all_tracks) == 2 -def test_max_tracking_small_gap_on_both_tracks(): +def test_max_tracks_small_gap_on_both_tracks(): # Test 2 instances with both tracks with gap > window size preds = make_insts( [ @@ -344,8 +342,7 @@ def test_max_tracking_small_gap_on_both_tracks(): tracker="simple", match="hungarian", track_window=2, - # max_tracks=2, - max_tracking=False, + max_tracks=-1, ) tracked = [] @@ -361,7 +358,6 @@ def test_max_tracking_small_gap_on_both_tracks(): match="hungarian", track_window=2, max_tracks=2, - max_tracking=True, ) tracked = [] @@ -373,7 +369,7 @@ def test_max_tracking_small_gap_on_both_tracks(): assert len(all_tracks) == 2 -def test_max_tracking_extra_detections(): +def test_max_tracks_extra_detections(): # Test having more than 2 detected instances in a frame preds = make_insts( [ @@ -411,8 +407,7 @@ def test_max_tracking_extra_detections(): tracker="simple", match="hungarian", track_window=2, - # max_tracks=2, - max_tracking=False, + max_tracks=-1, ) tracked = [] @@ -428,7 +423,6 @@ def test_max_tracking_extra_detections(): match="hungarian", track_window=2, max_tracks=2, - max_tracking=True, ) tracked = [] diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 0509557f0..9335900d8 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -25,7 +25,7 @@ def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path): def test_simple_max_tracks(tmpdir, centered_pair_predictions_slp_path): cli = ( "--tracking.tracker simple " - "--tracking.max_tracking 1 --tracking.max_tracks 2 " + "--tracking.max_tracks 2 " "--frames 200-300 " f"-o {tmpdir}/simplemaxtracks.slp " f"{centered_pair_predictions_slp_path}" @@ -107,13 +107,12 @@ def main(f, dir): ) def make_tracker( - tracker_name, matcher_name, sim_name, max_tracks, max_tracking=False, scale=0 + tracker_name, matcher_name, sim_name, max_tracks, scale=0 ): tracker = trackers[tracker_name]( matching_function=matchers[matcher_name], similarity_function=similarities[sim_name], max_tracks=max_tracks, - max_tracking=max_tracking, ) if scale: tracker.candidate_maker.img_scale = scale @@ -142,7 +141,6 @@ def make_tracker_and_filename(*args, **kwargs): tracker_name=tracker_name, matcher_name=matcher_name, max_tracks=2, - max_tracking=True, sim_name=sim_name, scale=scale, ) @@ -152,7 +150,6 @@ def make_tracker_and_filename(*args, **kwargs): tracker_name=tracker_name, matcher_name=matcher_name, max_tracks=2, - max_tracking=True, sim_name=sim_name, scale=0, )