Skip to content

Commit

Permalink
undo coderabbit suggestion for cached_property
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed Oct 25, 2024
1 parent 347578d commit b61af69
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 52 deletions.
92 changes: 44 additions & 48 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,69 +21,67 @@
function which provides a simplified interface for creating `Predictor`s.
"""

import attr
import argparse
import atexit
import json
import logging
import warnings
import os
import sys
import tempfile
import platform
import shutil
import atexit
import subprocess
import rich.progress
import pandas as pd
from rich.pretty import pprint
import sys
import tempfile
import warnings
from abc import ABC, abstractmethod
from collections import deque
import json
from time import time
from datetime import datetime
from pathlib import Path
import tensorflow_hub as hub
from abc import ABC, abstractmethod
from typing import Text, Optional, List, Dict, Union, Iterator, Tuple
from threading import Thread
from queue import Queue
from threading import Thread
from time import time
from typing import Dict, Iterator, List, Optional, Text, Tuple, Union

if sys.version_info >= (3, 8):
from functools import cached_property

else: # cached_property is defined only for python >=3.8
from functools import lru_cache

def cached_property(func):
return property(lru_cache()(func))
cached_property = property

import tensorflow as tf
import attr
import numpy as np
import pandas as pd
import rich.progress
import tensorflow as tf
import tensorflow_hub as hub
from rich.pretty import pprint
from tensorflow.python.framework.convert_to_constants import (
convert_variables_to_constants_v2,
)

import sleap

from sleap.nn.config import TrainingJobConfig, DataConfig
from sleap.nn.data.resizing import SizeMatcher
from sleap.nn.model import Model
from sleap.nn.tracking import Tracker
from sleap.nn.paf_grouping import PAFScorer
from sleap.instance import LabeledFrame, PredictedInstance
from sleap.io.dataset import Labels
from sleap.nn.config import DataConfig, TrainingJobConfig
from sleap.nn.data.pipelines import (
Provider,
Pipeline,
Batcher,
InstanceCentroidFinder,
KerasModelPredictor,
LabelsReader,
VideoReader,
Normalizer,
Resizer,
Pipeline,
Prefetcher,
InstanceCentroidFinder,
KerasModelPredictor,
Provider,
Resizer,
VideoReader,
)
from sleap.nn.data.resizing import SizeMatcher
from sleap.nn.model import Model
from sleap.nn.paf_grouping import PAFScorer
from sleap.nn.tracking import Tracker
from sleap.nn.utils import reset_input_layer
from sleap.io.dataset import Labels
from sleap.util import frame_list, make_scoped_dictionary, RateColumn
from sleap.instance import PredictedInstance, LabeledFrame
from sleap.util import RateColumn, frame_list, make_scoped_dictionary

from tensorflow.python.framework.convert_to_constants import (
convert_variables_to_constants_v2,
)
logger = logging.getLogger(__name__)

MOVENET_MODELS = {
"lightning": {
Expand Down Expand Up @@ -135,8 +133,6 @@ def cached_property(func):
],
)

logger = logging.getLogger(__name__)


def get_keras_model_path(path: Text) -> str:
"""Utility method for finding the path to a saved Keras model.
Expand Down Expand Up @@ -169,7 +165,8 @@ class Predictor(ABC):
def report_period(self) -> float:
"""Time between progress reports in seconds."""
if self.report_rate <= 0:
raise ValueError("report_rate must be positive")
logger.warning("report_rate must be positive, fallback to 1")
return 1.0
return 1.0 / self.report_rate

@classmethod
Expand Down Expand Up @@ -360,7 +357,7 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
ensure_rgb=(not self.is_grayscale),
)

pipeline += sleap.nn.data.pipelines.Batcher(
pipeline += Batcher(
batch_size=self.batch_size, drop_remainder=False, unrag=False
)

Expand Down Expand Up @@ -617,7 +614,7 @@ def export_model(
) + (keras_model_shape[3],)

tracing_batch = np.zeros((1,) + sample_shape, dtype="uint8")
outputs = self.inference_model.predict(tracing_batch)
_ = self.inference_model.predict(tracing_batch)

self.inference_model.export_model(
save_path, signatures, save_traces, model_name, tensors, unrag_outputs
Expand Down Expand Up @@ -2570,7 +2567,7 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
skeletons=self.confmap_config.data.labels.skeletons,
)

pipeline += sleap.nn.data.pipelines.Batcher(
pipeline += Batcher(
batch_size=self.batch_size, drop_remainder=False, unrag=False
)

Expand Down Expand Up @@ -4422,13 +4419,13 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:

if self.centroid_model is None:
anchor_part = self.confmap_config.data.instance_cropping.center_on_part
pipeline += sleap.nn.data.pipelines.InstanceCentroidFinder(
pipeline += InstanceCentroidFinder(
center_on_anchor_part=anchor_part is not None,
anchor_part_names=anchor_part,
skeletons=self.confmap_config.data.labels.skeletons,
)

pipeline += sleap.nn.data.pipelines.Batcher(
pipeline += Batcher(
batch_size=self.batch_size, drop_remainder=False, unrag=False
)

Expand Down Expand Up @@ -4650,7 +4647,7 @@ def __init__(self, model_name="lightning"):
)

def call(self, ex):
if type(ex) == dict:
if isinstance(ex, dict):
img = ex["image"]

else:
Expand Down Expand Up @@ -5496,7 +5493,7 @@ def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor:
max_instances=args.max_instances,
)

if type(predictor) == BottomUpPredictor:
if isinstance(predictor, BottomUpPredictor):
predictor.inference_model.bottomup_layer.paf_scorer.max_edge_length_ratio = (
args.max_edge_length_ratio
)
Expand Down Expand Up @@ -5608,7 +5605,6 @@ def main(args: Optional[list] = None):

# Either run inference (and tracking) or just run tracking (if using an existing prediction where inference has already been run)
if args.models is not None:

# Run inference on all files inputed
for i, (data_path, provider) in enumerate(zip(data_path_list, provider_list)):
# Setup models.
Expand Down
9 changes: 5 additions & 4 deletions sleap/nn/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import abc
import functools
import json
import logging
import sys
from collections import deque
from time import time
Expand Down Expand Up @@ -38,10 +39,9 @@
from functools import cached_property

else: # cached_property is defined only for python >=3.8
from functools import lru_cache
cached_property = property

def cached_property(func):
return property(lru_cache()(func))
logger = logging.getLogger(__name__)


@attr.s(eq=False, slots=True, auto_attribs=True)
Expand Down Expand Up @@ -539,7 +539,8 @@ def is_valid(self):
def report_period(self) -> float:
"""Time between progress reports in seconds."""
if self.report_rate <= 0:
raise ValueError("report_rate must be positive")
logger.warning("report_rate must be positive, fallback to 1")
return 1.0
return 1.0 / self.report_rate

def run_step(self, lf: LabeledFrame) -> LabeledFrame:
Expand Down

0 comments on commit b61af69

Please sign in to comment.