diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index b6aae344b..2c1dee615 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -5,7 +5,7 @@ import sys from collections import deque from time import time -from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple +from typing import Callable, Deque, Dict, Iterable, Iterator, List, Optional, Tuple import attr import cv2 @@ -546,6 +546,82 @@ def run_step(self, lf: LabeledFrame) -> LabeledFrame: instances=self.track(**track_args), ) + def _run_tracker_json( + self, + frames: List[LabeledFrame], + max_length: int = 30, + ) -> Iterator[LabeledFrame]: + n_total = len(frames) + n_processed = 0 + n_batch = 0 + n_recent = deque(maxlen=max_length) + elapsed_recent = deque(maxlen=max_length) + last_report = time() + t0_all = time() + t0_batch = time() + + for lf in frames: + new_lf = self.run_step(lf) + + # Track timing and progress + elapsed_all = time() - t0_all + n_processed += 1 + n_batch += 1 + + # Report + if time() > last_report + self.report_period: + elapsed_batch = time() - t0_batch + t0_batch = time() + + # Compute recent rate + n_recent.append(n_batch) + n_batch = 0 + elapsed_recent.append(elapsed_batch) + rate = sum(n_recent) / sum(elapsed_recent) + eta = (n_total - n_processed) / rate + + print( + json.dumps( + { + "n_processed": n_processed, + "n_total": n_total, + "elapsed": elapsed_all, + "rate": rate, + "eta": eta + } + ), + flush=True, + ) + last_report = time() + + yield new_lf + + def _run_tracker_rich(self, frames: List[LabeledFrame]) -> Iterator[LabeledFrame]: + with rich.progress.Progress( + "{task.description}", + rich.progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + "ETA:", + rich.progress.TimeRemainingColumn(), + RateColumn(), + auto_refresh=False, + refresh_per_second=self.report_rate, + speed_estimate_period=5, + ) as progress: + task = progress.add_task("Tracking...", total=len(frames)) + last_report = time() + for lf in frames: + new_lf = self.run_step(lf) + + progress.update(task, advance=1) + + # Handle refreshing manually to support notebooks. + if time() > last_report + self.report_period: + progress.refresh() + last_report = time() + + yield new_lf + def run_tracker( self, frames: List[LabeledFrame], @@ -567,84 +643,16 @@ def run_tracker( return frames verbosity = verbosity or self.verbosity - new_lfs = [] # Run tracking on every frame if verbosity == "rich": - with rich.progress.Progress( - "{task.description}", - rich.progress.BarColumn(), - "[progress.percentage]{task.percentage:>3.0f}%", - "ETA:", - rich.progress.TimeRemainingColumn(), - RateColumn(), - auto_refresh=False, - refresh_per_second=self.report_rate, - speed_estimate_period=5, - ) as progress: - task = progress.add_task("Tracking...", total=len(frames)) - last_report = time() - for lf in frames: - new_lf = self.run_step(lf) - new_lfs.append(new_lf) - - progress.update(task, advance=1) - - # Handle refreshing manually to support notebooks. - elapsed_since_last_report = time() - last_report - if elapsed_since_last_report > self.report_period: - progress.refresh() + new_lfs = list(self._run_tracker_rich(frames)) elif verbosity == "json": - n_total = len(frames) - n_processed = 0 - n_batch = 0 - elapsed_all = 0 - n_recent = deque(maxlen=30) - elapsed_recent = deque(maxlen=30) - last_report = time() - t0_all = time() - t0_batch = time() - for lf in frames: - new_lf = self.run_step(lf) - new_lfs.append(new_lf) - - # Track timing and progress. - elapsed_all = time() - t0_all - n_processed += 1 - n_batch += 1 - - # Report. - elapsed_since_last_report = time() - last_report - if elapsed_since_last_report > self.report_period: - elapsed_batch = time() - t0_batch - t0_batch = time() - - # Compute recent rate. - n_recent.append(n_batch) - n_batch = 0 - elapsed_recent.append(elapsed_batch) - rate = sum(n_recent) / sum(elapsed_recent) - eta = (n_total - n_processed) / rate - - print( - json.dumps( - { - "n_processed": n_processed, - "n_total": n_total, - "elapsed": elapsed_all, - "rate": rate, - "eta": eta, - } - ), - flush=True, - ) - last_report = time() + new_lfs = list(self._run_tracker_json(frames)) else: - for lf in frames: - new_lf = self.run_step(lf) - new_lfs.append(new_lf) + new_lfs = list(self.run_step(lf) for lf in frames) # Run final_pass if final_pass: