From b61af69eefc0bc0e6b7a710cd5243c46855afa8b Mon Sep 17 00:00:00 2001
From: getzze
Date: Fri, 25 Oct 2024 16:29:13 +0100
Subject: [PATCH] undo coderabbit suggestion for cached_property

---
 sleap/nn/inference.py | 92 +++++++++++++++++++++----------------------
 sleap/nn/tracking.py  |  9 +++--
 2 files changed, 49 insertions(+), 52 deletions(-)

diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py
index 14d14e7d8..0923b6979 100644
--- a/sleap/nn/inference.py
+++ b/sleap/nn/inference.py
@@ -21,69 +21,67 @@
 function which provides a simplified interface for creating `Predictor`s.
 """
 
-import attr
 import argparse
+import atexit
+import json
 import logging
-import warnings
 import os
-import sys
-import tempfile
 import platform
 import shutil
-import atexit
 import subprocess
-import rich.progress
-import pandas as pd
-from rich.pretty import pprint
+import sys
+import tempfile
+import warnings
+from abc import ABC, abstractmethod
 from collections import deque
-import json
-from time import time
 from datetime import datetime
 from pathlib import Path
-import tensorflow_hub as hub
-from abc import ABC, abstractmethod
-from typing import Text, Optional, List, Dict, Union, Iterator, Tuple
-from threading import Thread
 from queue import Queue
+from threading import Thread
+from time import time
+from typing import Dict, Iterator, List, Optional, Text, Tuple, Union
 
 if sys.version_info >= (3, 8):
     from functools import cached_property
 else:
     # cached_property is defined only for python >=3.8
-    from functools import lru_cache
-
-    def cached_property(func):
-        return property(lru_cache()(func))
+    cached_property = property
 
-import tensorflow as tf
+import attr
 import numpy as np
+import pandas as pd
+import rich.progress
+import tensorflow as tf
+import tensorflow_hub as hub
+from rich.pretty import pprint
+from tensorflow.python.framework.convert_to_constants import (
+    convert_variables_to_constants_v2,
+)
 
 import sleap
-
-from sleap.nn.config import TrainingJobConfig, DataConfig
-from sleap.nn.data.resizing import SizeMatcher
-from sleap.nn.model import Model
-from sleap.nn.tracking import Tracker
-from sleap.nn.paf_grouping import PAFScorer
+from sleap.instance import LabeledFrame, PredictedInstance
+from sleap.io.dataset import Labels
+from sleap.nn.config import DataConfig, TrainingJobConfig
 from sleap.nn.data.pipelines import (
-    Provider,
-    Pipeline,
+    Batcher,
+    InstanceCentroidFinder,
+    KerasModelPredictor,
     LabelsReader,
-    VideoReader,
     Normalizer,
-    Resizer,
+    Pipeline,
     Prefetcher,
-    InstanceCentroidFinder,
-    KerasModelPredictor,
+    Provider,
+    Resizer,
+    VideoReader,
 )
+from sleap.nn.data.resizing import SizeMatcher
+from sleap.nn.model import Model
+from sleap.nn.paf_grouping import PAFScorer
+from sleap.nn.tracking import Tracker
 from sleap.nn.utils import reset_input_layer
-from sleap.io.dataset import Labels
-from sleap.util import frame_list, make_scoped_dictionary, RateColumn
-from sleap.instance import PredictedInstance, LabeledFrame
+from sleap.util import RateColumn, frame_list, make_scoped_dictionary
 
-from tensorflow.python.framework.convert_to_constants import (
-    convert_variables_to_constants_v2,
-)
+logger = logging.getLogger(__name__)
 
 MOVENET_MODELS = {
     "lightning": {
@@ -135,8 +133,6 @@ def cached_property(func):
     ],
 )
 
-logger = logging.getLogger(__name__)
-
 
 def get_keras_model_path(path: Text) -> str:
     """Utility method for finding the path to a saved Keras model.
@@ -169,7 +165,8 @@ class Predictor(ABC):
     def report_period(self) -> float:
         """Time between progress reports in seconds."""
         if self.report_rate <= 0:
-            raise ValueError("report_rate must be positive")
+            logger.warning("report_rate must be positive; falling back to 1.0")
+            return 1.0
         return 1.0 / self.report_rate
 
     @classmethod
@@ -360,7 +357,7 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
             ensure_rgb=(not self.is_grayscale),
         )
 
-        pipeline += sleap.nn.data.pipelines.Batcher(
+        pipeline += Batcher(
             batch_size=self.batch_size, drop_remainder=False, unrag=False
         )
 
@@ -617,7 +614,7 @@ def export_model(
         ) + (keras_model_shape[3],)
         tracing_batch = np.zeros((1,) + sample_shape, dtype="uint8")
-        outputs = self.inference_model.predict(tracing_batch)
+        _ = self.inference_model.predict(tracing_batch)
 
         self.inference_model.export_model(
             save_path, signatures, save_traces, model_name, tensors, unrag_outputs
         )
@@ -2570,7 +2567,7 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
             skeletons=self.confmap_config.data.labels.skeletons,
         )
 
-        pipeline += sleap.nn.data.pipelines.Batcher(
+        pipeline += Batcher(
             batch_size=self.batch_size, drop_remainder=False, unrag=False
         )
 
@@ -4422,13 +4419,13 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
 
         if self.centroid_model is None:
             anchor_part = self.confmap_config.data.instance_cropping.center_on_part
-            pipeline += sleap.nn.data.pipelines.InstanceCentroidFinder(
+            pipeline += InstanceCentroidFinder(
                 center_on_anchor_part=anchor_part is not None,
                 anchor_part_names=anchor_part,
                 skeletons=self.confmap_config.data.labels.skeletons,
            )
 
-        pipeline += sleap.nn.data.pipelines.Batcher(
+        pipeline += Batcher(
             batch_size=self.batch_size, drop_remainder=False, unrag=False
         )
 
@@ -4650,7 +4647,7 @@ def __init__(self, model_name="lightning"):
         )
 
     def call(self, ex):
-        if type(ex) == dict:
+        if isinstance(ex, dict):
             img = ex["image"]
 
         else:
@@ -5496,7 +5493,7 @@ def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor:
         max_instances=args.max_instances,
     )
 
-    if type(predictor) == BottomUpPredictor:
+    if isinstance(predictor, BottomUpPredictor):
         predictor.inference_model.bottomup_layer.paf_scorer.max_edge_length_ratio = (
             args.max_edge_length_ratio
         )
@@ -5608,7 +5605,6 @@ def main(args: Optional[list] = None):
 
     # Either run inference (and tracking) or just run tracking (if using an existing prediction where inference has already been run)
     if args.models is not None:
-        # Run inference on all files inputed
         for i, (data_path, provider) in enumerate(zip(data_path_list, provider_list)):
 
             # Setup models.
diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py
index b65bb2f90..55170ea36 100644
--- a/sleap/nn/tracking.py
+++ b/sleap/nn/tracking.py
@@ -3,6 +3,7 @@
 import abc
 import functools
 import json
+import logging
 import sys
 from collections import deque
 from time import time
@@ -38,10 +39,9 @@
     from functools import cached_property
 else:
     # cached_property is defined only for python >=3.8
-    from functools import lru_cache
+    cached_property = property
 
-    def cached_property(func):
-        return property(lru_cache()(func))
+logger = logging.getLogger(__name__)
 
 
 @attr.s(eq=False, slots=True, auto_attribs=True)
@@ -539,7 +539,8 @@ def is_valid(self):
     def report_period(self) -> float:
         """Time between progress reports in seconds."""
         if self.report_rate <= 0:
-            raise ValueError("report_rate must be positive")
+            logger.warning("report_rate must be positive; falling back to 1.0")
+            return 1.0
         return 1.0 / self.report_rate
 
     def run_step(self, lf: LabeledFrame) -> LabeledFrame:
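
Note on the cached_property fallback above: on Python >= 3.8, functools.cached_property computes the value once and stores it on the instance, while the plain `cached_property = property` fallback used here for Python < 3.8 recomputes the value on every access. The reverted `property(lru_cache()(func))` shim did cache, but it keyed the cache on `self`, so instances had to be hashable and were kept alive by the cache. Below is a minimal sketch of the difference; the `Expensive` class is made up for illustration and is not part of SLEAP.

import sys

if sys.version_info >= (3, 8):
    from functools import cached_property
else:
    # Same fallback as in the patch: no caching, the attribute is recomputed per access.
    cached_property = property


class Expensive:
    """Toy class with a derived attribute (illustration only)."""

    def __init__(self, n):
        self.n = n
        self.calls = 0  # counts how many times `squared` is computed

    @cached_property
    def squared(self):
        self.calls += 1
        return self.n * self.n


e = Expensive(3)
assert e.squared == 9 and e.squared == 9
# Python >= 3.8: e.calls == 1 (value cached on the instance by functools.cached_property).
# Python < 3.8 with this patch's fallback: e.calls == 2 (recomputed on each access).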
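
Note on the report_period change (applied identically in inference.py and tracking.py): a non-positive report_rate now logs a warning and falls back to a one-second reporting period instead of raising ValueError. A condensed sketch of the new behavior follows; `ReporterExample` is a made-up stand-in, not one of the actual SLEAP classes.

import logging

logger = logging.getLogger(__name__)


class ReporterExample:
    """Toy stand-in for the classes touched by this patch (illustration only)."""

    def __init__(self, report_rate=2.0):
        self.report_rate = report_rate  # progress reports per second

    @property
    def report_period(self):
        """Time between progress reports in seconds."""
        if self.report_rate <= 0:
            # New behavior: warn and fall back instead of raising ValueError.
            logger.warning("report_rate must be positive; falling back to 1.0")
            return 1.0
        return 1.0 / self.report_rate


assert ReporterExample(4.0).report_period == 0.25
assert ReporterExample(0).report_period == 1.0  # warns instead of raising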