Skip to content

Commit

Permalink
add tracking progress reporting
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed Oct 25, 2024
1 parent 0a8d5d2 commit 2b26078
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 139 deletions.
27 changes: 12 additions & 15 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@
from threading import Thread
from queue import Queue

if sys.version_info >= (3, 8):
from functools import cached_property
else: # cached_property is define only for python >=3.8
cached_property = property

import tensorflow as tf
import numpy as np

Expand All @@ -54,7 +59,7 @@
from sleap.nn.config import TrainingJobConfig, DataConfig
from sleap.nn.data.resizing import SizeMatcher
from sleap.nn.model import Model
from sleap.nn.tracking import Tracker, run_tracker
from sleap.nn.tracking import Tracker
from sleap.nn.paf_grouping import PAFScorer
from sleap.nn.data.pipelines import (
Provider,
Expand All @@ -69,7 +74,7 @@
)
from sleap.nn.utils import reset_input_layer
from sleap.io.dataset import Labels
from sleap.util import frame_list, make_scoped_dictionary
from sleap.util import frame_list, make_scoped_dictionary, RateColumn
from sleap.instance import PredictedInstance, LabeledFrame

from tensorflow.python.framework.convert_to_constants import (
Expand Down Expand Up @@ -144,17 +149,6 @@ def get_keras_model_path(path: Text) -> str:
return os.path.join(path, "best_model.h5")


class RateColumn(rich.progress.ProgressColumn):
"""Renders the progress rate."""

def render(self, task: "Task") -> rich.progress.Text:
"""Show progress rate."""
speed = task.speed
if speed is None:
return rich.progress.Text("?", style="progress.data.speed")
return rich.progress.Text(f"{speed:.1f} FPS", style="progress.data.speed")


@attr.s(auto_attribs=True)
class Predictor(ABC):
"""Base interface class for predictors."""
Expand All @@ -167,7 +161,7 @@ class Predictor(ABC):
report_rate: float = attr.ib(default=2.0, kw_only=True)
model_paths: List[str] = attr.ib(factory=list, kw_only=True)

@property
@cached_property
def report_period(self) -> float:
"""Time between progress reports in seconds."""
return 1.0 / self.report_rate
Expand Down Expand Up @@ -5487,7 +5481,10 @@ def _make_tracker_from_cli(args: argparse.Namespace) -> Optional[Tracker]:
"""
policy_args = make_scoped_dictionary(vars(args), exclude_nones=True)
if "tracking" in policy_args:
tracker = Tracker.make_tracker_by_name(**policy_args["tracking"])
tracker = Tracker.make_tracker_by_name(
progress_reporting=args.verbosity,
**policy_args["tracking"],
)
return tracker
return None

Expand Down
Loading

0 comments on commit 2b26078

Please sign in to comment.