diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 57f53ad45..4e969c16b 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -46,6 +46,11 @@ from threading import Thread from queue import Queue +if sys.version_info >= (3, 8): + from functools import cached_property +else: # cached_property is define only for python >=3.8 + cached_property = property + import tensorflow as tf import numpy as np @@ -54,7 +59,7 @@ from sleap.nn.config import TrainingJobConfig, DataConfig from sleap.nn.data.resizing import SizeMatcher from sleap.nn.model import Model -from sleap.nn.tracking import Tracker, run_tracker +from sleap.nn.tracking import Tracker from sleap.nn.paf_grouping import PAFScorer from sleap.nn.data.pipelines import ( Provider, @@ -69,7 +74,7 @@ ) from sleap.nn.utils import reset_input_layer from sleap.io.dataset import Labels -from sleap.util import frame_list, make_scoped_dictionary +from sleap.util import frame_list, make_scoped_dictionary, RateColumn from sleap.instance import PredictedInstance, LabeledFrame from tensorflow.python.framework.convert_to_constants import ( @@ -144,17 +149,6 @@ def get_keras_model_path(path: Text) -> str: return os.path.join(path, "best_model.h5") -class RateColumn(rich.progress.ProgressColumn): - """Renders the progress rate.""" - - def render(self, task: "Task") -> rich.progress.Text: - """Show progress rate.""" - speed = task.speed - if speed is None: - return rich.progress.Text("?", style="progress.data.speed") - return rich.progress.Text(f"{speed:.1f} FPS", style="progress.data.speed") - - @attr.s(auto_attribs=True) class Predictor(ABC): """Base interface class for predictors.""" @@ -167,7 +161,7 @@ class Predictor(ABC): report_rate: float = attr.ib(default=2.0, kw_only=True) model_paths: List[str] = attr.ib(factory=list, kw_only=True) - @property + @cached_property def report_period(self) -> float: """Time between progress reports in seconds.""" return 1.0 / self.report_rate @@ -5487,7 +5481,10 @@ def _make_tracker_from_cli(args: argparse.Namespace) -> Optional[Tracker]: """ policy_args = make_scoped_dictionary(vars(args), exclude_nones=True) if "tracking" in policy_args: - tracker = Tracker.make_tracker_by_name(**policy_args["tracking"]) + tracker = Tracker.make_tracker_by_name( + progress_reporting=args.verbosity, + **policy_args["tracking"], + ) return tracker return None diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 558aa9309..6d7c423b5 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -1,12 +1,16 @@ """Tracking tools for linking grouped instances over time.""" -from collections import deque, defaultdict import abc +import json +import sys +from collections import deque +from time import time +from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple + import attr -import numpy as np import cv2 -import functools -from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple +import numpy as np +import rich.progress from sleap import Track, LabeledFrame, Skeleton @@ -26,8 +30,13 @@ Match, ) from sleap.nn.tracker.kalman import BareKalmanTracker - from sleap.nn.data.normalization import ensure_int +from sleap.util import RateColumn + +if sys.version_info >= (3, 8): + from functools import cached_property +else: # cached_property is define only for python >=3.8 + cached_property = property @attr.s(eq=False, slots=True, auto_attribs=True) @@ -66,7 +75,6 @@ def from_instance( shift_score: float = 0.0, with_skeleton: bool = False, ): - points_array = new_points_array if points_array is None: points_array = ref_instance.points_array @@ -511,10 +519,142 @@ def get_candidates( class BaseTracker(abc.ABC): """Abstract base class for tracker.""" + verbosity: str + report_rate: float + @property def is_valid(self): return False + @cached_property + def report_period(self) -> float: + """Time between progress reports in seconds.""" + return 1.0 / self.report_rate + + def run_step(self, lf: LabeledFrame) -> LabeledFrame: + # Clear the tracks + for inst in lf.instances: + inst.track = None + + track_args = dict(untracked_instances=lf.instances, t=lf.frame_idx) + if self.uses_image: + track_args["img"] = lf.video[lf.frame_idx] + else: + track_args["img"] = None + track_args["img_hw"] = lf.image.shape[-3:-1] + + return LabeledFrame( + frame_idx=lf.frame_idx, + video=lf.video, + instances=self.track(**track_args), + ) + + def run_tracker( + self, + frames: List[LabeledFrame], + *, + verbosity: Optional[str] = None, + final_pass: bool = True, + ) -> List[LabeledFrame]: + """Run the tracker on a set of labeled frames. + + Args: + frames: A list of labeled frames with instances. + + Returns: + The input frames with the new tracks assigned. If the frames already had tracks, + they will be cleared if the tracker has been re-initialized. + """ + # Return original frames if we aren't retracking + if not self.is_valid: + return frames + + verbosity = verbosity or self.verbosity + new_lfs = [] + + # Run tracking on every frame + if verbosity == "rich": + with rich.progress.Progress( + "{task.description}", + rich.progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + "ETA:", + rich.progress.TimeRemainingColumn(), + RateColumn(), + auto_refresh=False, + refresh_per_second=self.report_rate, + speed_estimate_period=5, + ) as progress: + task = progress.add_task("Tracking...", total=len(frames)) + last_report = time() + for lf in frames: + new_lf = self.run_step(lf) + new_lfs.append(new_lf) + + progress.update(task, advance=1) + + # Handle refreshing manually to support notebooks. + elapsed_since_last_report = time() - last_report + if elapsed_since_last_report > self.report_period: + progress.refresh() + + elif verbosity == "json": + n_total = len(frames) + n_processed = 0 + n_batch = 0 + elapsed_all = 0 + n_recent = deque(maxlen=30) + elapsed_recent = deque(maxlen=30) + last_report = time() + t0_all = time() + t0_batch = time() + for lf in frames: + new_lf = self.run_step(lf) + new_lfs.append(new_lf) + + # Track timing and progress. + elapsed_all = time() - t0_all + n_processed += 1 + n_batch += 1 + + # Report. + elapsed_since_last_report = time() - last_report + if elapsed_since_last_report > self.report_period: + elapsed_batch = time() - t0_batch + t0_batch = time() + + # Compute recent rate. + n_recent.append(n_batch) + n_batch = 0 + elapsed_recent.append(elapsed_batch) + rate = sum(n_recent) / sum(elapsed_recent) + eta = (n_total - n_processed) / rate + + print( + json.dumps( + { + "n_processed": n_processed, + "n_total": n_total, + "elapsed": elapsed_all, + "rate": rate, + "eta": eta, + } + ), + flush=True, + ) + last_report = time() + + else: + for lf in frames: + new_lf = self.run_step(lf) + new_lfs.append(new_lf) + + # Run final_pass + if final_pass: + self.final_pass(new_lfs) + + return new_lfs + @abc.abstractmethod def track( self, @@ -564,6 +704,12 @@ class Tracker(BaseTracker): use the max similarity (non-robust). For selecting a robust score, 0.95 is a good value. max_tracking: Max tracking is incorporated when this is set to true. + verbosity: Mode of inference progress reporting. If `"rich"` (the + default), an updating progress bar is displayed in the console or notebook. + If `"json"`, a JSON-serialized message is printed out which can be captured + for programmatic progress monitoring. If `"none"`, nothing is displayed + during tracking -- this is recommended when running on clusters or headless + machines where the output is captured to a log file. """ max_tracks: int = None @@ -596,6 +742,12 @@ class Tracker(BaseTracker): last_matches: Optional[FrameMatches] = None + verbosity: str = attr.ib( + validator=attr.validators.in_(["none", "rich", "json"]), + default="none", + ) + report_rate: float = 2.0 + @property def is_valid(self): return self.similarity_function is not None @@ -670,7 +822,6 @@ def track( if t is None: if self.has_max_tracking: if len(self.track_matching_queue_dict) > 0: - # Default to last timestep + 1 if available. # Here we find the track that has the most instances. track_with_max_instances = max( @@ -686,7 +837,6 @@ def track( t = 0 else: if len(self.track_matching_queue) > 0: - # Default to last timestep + 1 if available. t = self.track_matching_queue[-1].t + 1 @@ -701,7 +851,6 @@ def track( # Process untracked instances. if untracked_instances: - if self.pre_cull_function: self.pre_cull_function(untracked_instances) @@ -791,7 +940,6 @@ def spawn_for_untracked_instances( ) -> List[InstanceType]: results = [] for inst in unmatched_instances: - # Skip if this instance is too small to spawn a new track with. if inst.n_visible_points < self.min_new_track_points: continue @@ -868,6 +1016,8 @@ def make_tracker_by_name( oks_errors: Optional[list] = None, oks_score_weighting: bool = False, oks_normalization: str = "all", + progress_reporting: str = "rich", + report_rate: float = 2.0, **kwargs, ) -> BaseTracker: # Parse max_tracking arguments, only True if max_tracks is not None and > 0 @@ -942,6 +1092,8 @@ def pre_cull_function(inst_list): max_tracks=max_tracks, target_instance_count=target_instance_count, post_connect_single_breaks=post_connect_single_breaks, + verbosity=progress_reporting, + report_rate=report_rate, ) if target_instance_count and kf_init_frame_count: @@ -961,7 +1113,6 @@ def pre_cull_function(inst_list): @classmethod def get_by_name_factory_options(cls): - options = [] option = dict(name="tracker", default="None") @@ -1230,7 +1381,6 @@ def add_frame_instances( # "usuable" instances—i.e., instances with the nodes that we'll track # using Kalman filters. elif frame_match.has_only_first_choice_matches: - good_instances = [ inst for inst in instances if self.is_usable_instance(inst) ] @@ -1321,6 +1471,12 @@ class KalmanTracker(BaseTracker): last_t: int = 0 last_init_t: int = 0 + verbosity: str = attr.ib( + validator=attr.validators.in_(["none", "rich", "json"]), + default="none", + ) + report_rate: float = 2.0 + @property def is_valid(self): """Do we have everything we need to run tracking?""" @@ -1444,7 +1600,6 @@ def track( # Check whether we've been getting good results from the Kalman filters. # First, has it been a while since the filters were initialized? if self.init_done and (t - self.last_init_t) > self.re_init_cooldown: - # If it's been a while, then see if it's also been a while since # the filters successfully matched tracks to the instances. if self.kalman_tracker.last_frame_with_tracks < t - self.re_init_after: @@ -1501,47 +1656,6 @@ def run(self, frames: List[LabeledFrame]): connect_single_track_breaks(frames, self.instance_count) -def run_tracker(frames: List[LabeledFrame], tracker: BaseTracker) -> List[LabeledFrame]: - """Run a tracker on a set of labeled frames. - - Args: - frames: A list of labeled frames with instances. - tracker: An initialized Tracker. - - Returns: - The input frames with the new tracks assigned. If the frames already had tracks, - they will be cleared if the tracker has been re-initialized. - """ - # Return original frames if we aren't retracking - if not tracker.is_valid: - return frames - - new_lfs = [] - - # Run tracking on every frame - for lf in frames: - - # Clear the tracks - for inst in lf.instances: - inst.track = None - - track_args = dict(untracked_instances=lf.instances) - if tracker.uses_image: - track_args["img"] = lf.video[lf.frame_idx] - else: - track_args["img"] = None - track_args["img_hw"] = lf.image.shape[-3:-1] - - new_lf = LabeledFrame( - frame_idx=lf.frame_idx, - video=lf.video, - instances=tracker.track(**track_args), - ) - new_lfs.append(new_lf) - - return new_lfs - - def retrack(): import argparse import operator @@ -1579,8 +1693,7 @@ def retrack(): print(f"Done loading predictions in {time.time() - t0} seconds.") print("Starting tracker...") - frames = run_tracker(frames=frames, tracker=tracker) - tracker.final_pass(frames) + frames = tracker.run_tracker(frames=frames) new_labels = Labels(labeled_frames=frames) diff --git a/sleap/util.py b/sleap/util.py index bc3389b7d..b92606868 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -1,32 +1,51 @@ -"""A miscellaneous set of utility functions. +"""A miscellaneous set of utility functions. Try not to put things in here unless they really have no other place. """ +from __future__ import annotations + +import base64 import json import os import re import shutil from collections import defaultdict from pathlib import Path -from typing import Any, Dict, Hashable, Iterable, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Hashable, Iterable, List, Optional from urllib.parse import unquote, urlparse from urllib.request import url2pathname +try: + from importlib.resources import files # New in 3.9+ +except ImportError: + from importlib_resources import files # TODO(LM): Upgrade to importlib.resources. + import attr import h5py as h5 import numpy as np import psutil import rapidjson +import rich.progress import yaml - -try: - from importlib.resources import files # New in 3.9+ -except ImportError: - from importlib_resources import files # TODO(LM): Upgrade to importlib.resources. +from PIL import Image import sleap.version as sleap_version +if TYPE_CHECKING: + from rich.progress import Task + + +class RateColumn(rich.progress.ProgressColumn): + """Renders the progress rate.""" + + def render(self, task: Task) -> rich.progress.Text: + """Show progress rate.""" + speed = task.speed + if speed is None: + return rich.progress.Text("?", style="progress.data.speed") + return rich.progress.Text(f"{speed:.1f} FPS", style="progress.data.speed") + def json_loads(json_str: str) -> Dict: """A simple wrapper around the JSON decoder we are using. diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index 0c7ba2b0a..9d3b65b38 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -15,23 +15,21 @@ from sleap.skeleton import Skeleton -def tracker_by_name(frames=None, **kwargs): - t = Tracker.make_tracker_by_name(**kwargs) - print(kwargs) - print(t.candidate_maker) - if frames is None: - t.track([]) - t.final_pass([]) - return - - for lf in frames: - # Clear the tracks - for inst in lf.instances: - inst.track = None - - track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx]) - t.track(**track_args, img_hw=(1, 1)) - t.final_pass(frames) +def run_tracker_by_name(frames=None, img_scale: float = 0, **kwargs): + # Create tracker + t = Tracker.make_tracker_by_name(verbosity="none", **kwargs) + # Update img_scale + if img_scale: + if hasattr(t, "candidate_maker") and hasattr(t.candidate_maker, "img_scale"): + t.candidate_maker.img_scale = img_scale + else: + # Do not even run tracking as it can be slow + pytest.skip("img_scale is not defined for this tracker") + return + + # Run tracking + new_frames = t.run_tracker(frames or []) + assert len(new_frames) == len(frames) @pytest.mark.parametrize( @@ -42,22 +40,25 @@ def tracker_by_name(frames=None, **kwargs): ["instance", "normalized_instance", "iou", "centroid", "object_keypoint"], ) @pytest.mark.parametrize("match", ["greedy", "hungarian"]) +@pytest.mark.parametrize("img_scale", [0, 1, 0.25]) @pytest.mark.parametrize("count", [0, 2]) def test_tracker_by_name( centered_pair_predictions_sorted, tracker, similarity, match, + img_scale, count, ): # This is slow, so limit to 5 time points frames = centered_pair_predictions_sorted[:5] - tracker_by_name( + run_tracker_by_name( frames=frames, tracker=tracker, similarity=similarity, match=match, + img_scale=img_scale, max_tracks=count, ) @@ -76,7 +77,7 @@ def test_oks_tracker_by_name( # This is slow, so limit to 5 time points frames = centered_pair_predictions_sorted[:5] - tracker_by_name( + run_tracker_by_name( frames=frames, tracker=tracker, similarity="object_keypoint", diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 625302fd0..c7c25476d 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -19,7 +19,7 @@ def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path): inference_cli(cli.split(" ")) labels = sleap.load_file(f"{tmpdir}/simpletracks.slp") - assert len(labels.tracks) == 27 + assert len(labels.tracks) == 8 def test_simplemax_tracker(tmpdir, centered_pair_predictions_slp_path): @@ -37,18 +37,19 @@ def test_simplemax_tracker(tmpdir, centered_pair_predictions_slp_path): # TODO: Refactor the below things into a real test suite. +# running an equivalent to `make_ground_truth` is done as a test in tests/nn/test_tracker_components.py def make_ground_truth(frames, tracker, gt_filename): t0 = time.time() - new_labels = run_tracker(frames, tracker) + new_labels = tracker.run_tracker(frames, verbosity="none") print(f"{gt_filename}\t{len(tracker.spawned_tracks)}\t{time.time()-t0}") Labels.save_file(new_labels, gt_filename) def compare_ground_truth(frames, tracker, gt_filename): t0 = time.time() - new_labels = run_tracker(frames, tracker) + new_labels = tracker.run_tracker(frames, verbosity="none") print(f"{gt_filename}\t{time.time() - t0}") does_match = check_tracks(new_labels, gt_filename) @@ -78,43 +79,6 @@ def check_tracks(labels, gt_filename, limit=None): return True -def run_tracker(frames, tracker): - sig = inspect.signature(tracker.track) - takes_img = "img" in sig.parameters - - # t0 = time.time() - - new_lfs = [] - - # Run tracking on every frame - for lf in frames: - - # Clear the tracks - for inst in lf.instances: - inst.track = None - - track_args = dict(untracked_instances=lf.instances) - if takes_img: - track_args["img"] = lf.video[lf.frame_idx] - else: - track_args["img"] = None - - new_lf = LabeledFrame( - frame_idx=lf.frame_idx, - video=lf.video, - instances=tracker.track(**track_args, img_hw=lf.image.shape[-3:-1]), - ) - new_lfs.append(new_lf) - - # if lf.frame_idx % 100 == 0: print(lf.frame_idx, time.time()-t0) - - # print(time.time() - t0) - - new_labels = Labels() - new_labels.extend(new_lfs) - return new_labels - - def main(f, dir): filename = "tests/data/json_format_v2/centered_pair_predictions.json" @@ -166,7 +130,10 @@ def make_tracker( return tracker def make_filename(tracker_name, matcher_name, sim_name, scale=0): - return f"{dir}{tracker_name}_{int(scale * 100)}_{matcher_name}_{sim_name}.h5" + return os.path.join( + dir, + f"{tracker_name}_{int(scale * 100)}_{matcher_name}_{sim_name}.h5", + ) def make_tracker_and_filename(*args, **kwargs): tracker = make_tracker(*args, **kwargs) @@ -180,7 +147,6 @@ def make_tracker_and_filename(*args, **kwargs): for tracker_name in trackers.keys(): for matcher_name in matchers.keys(): for sim_name in similarities.keys(): - if tracker_name == "flow": # If this tracker supports scale, try multiple scales for scale in scales: