Skip to content

Commit

Permalink
coderabbit suggestion for division by zero error, cache_property and …
Browse files Browse the repository at this point in the history
…attr default
  • Loading branch information
getzze committed Oct 25, 2024
1 parent 5ef39f8 commit 347578d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 18 deletions.
10 changes: 8 additions & 2 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@

if sys.version_info >= (3, 8):
from functools import cached_property
else: # cached_property is define only for python >=3.8
cached_property = property

else: # cached_property is defined only for python >=3.8
from functools import lru_cache

def cached_property(func):
return property(lru_cache()(func))

import tensorflow as tf
import numpy as np
Expand Down Expand Up @@ -164,6 +168,8 @@ class Predictor(ABC):
@cached_property
def report_period(self) -> float:
"""Time between progress reports in seconds."""
if self.report_rate <= 0:
raise ValueError("report_rate must be positive")
return 1.0 / self.report_rate

@classmethod
Expand Down
31 changes: 15 additions & 16 deletions sleap/nn/tracking.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tracking tools for linking grouped instances over time."""

import abc
import functools
import json
import sys
from collections import deque
Expand Down Expand Up @@ -35,8 +36,12 @@

if sys.version_info >= (3, 8):
from functools import cached_property
else: # cached_property is define only for python >=3.8
cached_property = property

else: # cached_property is defined only for python >=3.8
from functools import lru_cache

def cached_property(func):
return property(lru_cache()(func))


@attr.s(eq=False, slots=True, auto_attribs=True)
Expand Down Expand Up @@ -519,8 +524,12 @@ def get_candidates(
class BaseTracker(abc.ABC):
"""Abstract base class for tracker."""

verbosity: str
report_rate: float
verbosity: str = attr.ib(
validator=attr.validators.in_(["none", "rich", "json"]),
default="none",
kw_only=True,
)
report_rate: float = attr.ib(default=2.0, kw_only=True)

@property
def is_valid(self):
Expand All @@ -529,6 +538,8 @@ def is_valid(self):
@cached_property
def report_period(self) -> float:
"""Time between progress reports in seconds."""
if self.report_rate <= 0:
raise ValueError("report_rate must be positive")
return 1.0 / self.report_rate

def run_step(self, lf: LabeledFrame) -> LabeledFrame:
Expand Down Expand Up @@ -751,12 +762,6 @@ class Tracker(BaseTracker):

last_matches: Optional[FrameMatches] = None

verbosity: str = attr.ib(
validator=attr.validators.in_(["none", "rich", "json"]),
default="none",
)
report_rate: float = 2.0

@property
def is_valid(self):
return self.similarity_function is not None
Expand Down Expand Up @@ -1493,12 +1498,6 @@ class KalmanTracker(BaseTracker):
last_t: int = 0
last_init_t: int = 0

verbosity: str = attr.ib(
validator=attr.validators.in_(["none", "rich", "json"]),
default="none",
)
report_rate: float = 2.0

@property
def is_valid(self):
"""Do we have everything we need to run tracking?"""
Expand Down

0 comments on commit 347578d

Please sign in to comment.