Skip to content

Commit

Permalink
refactor tracking progress
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed Sep 11, 2024
1 parent 88699b1 commit 43b54e9
Showing 1 changed file with 80 additions and 72 deletions.
152 changes: 80 additions & 72 deletions sleap/nn/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
from collections import deque
from time import time
from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple
from typing import Callable, Deque, Dict, Iterable, Iterator, List, Optional, Tuple

import attr
import cv2
Expand Down Expand Up @@ -546,6 +546,82 @@ def run_step(self, lf: LabeledFrame) -> LabeledFrame:
instances=self.track(**track_args),
)

def _run_tracker_json(
self,
frames: List[LabeledFrame],
max_length: int = 30,
) -> Iterator[LabeledFrame]:
n_total = len(frames)
n_processed = 0
n_batch = 0
n_recent = deque(maxlen=max_length)
elapsed_recent = deque(maxlen=max_length)
last_report = time()
t0_all = time()
t0_batch = time()

for lf in frames:
new_lf = self.run_step(lf)

# Track timing and progress
elapsed_all = time() - t0_all
n_processed += 1
n_batch += 1

# Report
if time() > last_report + self.report_period:
elapsed_batch = time() - t0_batch
t0_batch = time()

# Compute recent rate
n_recent.append(n_batch)
n_batch = 0
elapsed_recent.append(elapsed_batch)
rate = sum(n_recent) / sum(elapsed_recent)
eta = (n_total - n_processed) / rate

print(
json.dumps(
{
"n_processed": n_processed,
"n_total": n_total,
"elapsed": elapsed_all,
"rate": rate,
"eta": eta,
}
),
flush=True,
)
last_report = time()

yield new_lf

def _run_tracker_rich(self, frames: List[LabeledFrame]) -> Iterator[LabeledFrame]:
with rich.progress.Progress(
"{task.description}",
rich.progress.BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
"ETA:",
rich.progress.TimeRemainingColumn(),
RateColumn(),
auto_refresh=False,
refresh_per_second=self.report_rate,
speed_estimate_period=5,
) as progress:
task = progress.add_task("Tracking...", total=len(frames))
last_report = time()
for lf in frames:
new_lf = self.run_step(lf)

progress.update(task, advance=1)

# Handle refreshing manually to support notebooks.
if time() > last_report + self.report_period:
progress.refresh()
last_report = time()

yield new_lf

def run_tracker(
self,
frames: List[LabeledFrame],
Expand All @@ -567,84 +643,16 @@ def run_tracker(
return frames

verbosity = verbosity or self.verbosity
new_lfs = []

# Run tracking on every frame
if verbosity == "rich":
with rich.progress.Progress(
"{task.description}",
rich.progress.BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
"ETA:",
rich.progress.TimeRemainingColumn(),
RateColumn(),
auto_refresh=False,
refresh_per_second=self.report_rate,
speed_estimate_period=5,
) as progress:
task = progress.add_task("Tracking...", total=len(frames))
last_report = time()
for lf in frames:
new_lf = self.run_step(lf)
new_lfs.append(new_lf)

progress.update(task, advance=1)

# Handle refreshing manually to support notebooks.
elapsed_since_last_report = time() - last_report
if elapsed_since_last_report > self.report_period:
progress.refresh()
new_lfs = list(self._run_tracker_rich(frames))

elif verbosity == "json":
n_total = len(frames)
n_processed = 0
n_batch = 0
elapsed_all = 0
n_recent = deque(maxlen=30)
elapsed_recent = deque(maxlen=30)
last_report = time()
t0_all = time()
t0_batch = time()
for lf in frames:
new_lf = self.run_step(lf)
new_lfs.append(new_lf)

# Track timing and progress.
elapsed_all = time() - t0_all
n_processed += 1
n_batch += 1

# Report.
elapsed_since_last_report = time() - last_report
if elapsed_since_last_report > self.report_period:
elapsed_batch = time() - t0_batch
t0_batch = time()

# Compute recent rate.
n_recent.append(n_batch)
n_batch = 0
elapsed_recent.append(elapsed_batch)
rate = sum(n_recent) / sum(elapsed_recent)
eta = (n_total - n_processed) / rate

print(
json.dumps(
{
"n_processed": n_processed,
"n_total": n_total,
"elapsed": elapsed_all,
"rate": rate,
"eta": eta,
}
),
flush=True,
)
last_report = time()
new_lfs = list(self._run_tracker_json(frames))

else:
for lf in frames:
new_lf = self.run_step(lf)
new_lfs.append(new_lf)
new_lfs = list(self.run_step(lf) for lf in frames)

# Run final_pass
if final_pass:
Expand Down

0 comments on commit 43b54e9

Please sign in to comment.