diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index f07c228cc..a1a083ba5 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -368,6 +368,122 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: def _initialize_inference_model(self): pass + def _process_batch(self, ex: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """Run prediction model on batch. + + This method handles running inference on a batch and postprocessing. + + Args: + ex: a dictionary holding the input for inference. + + Returns: + The input dictionary updated with the predictions. + """ + # Skip inference if model is not loaded + if self.inference_model is None: + return ex + + # Run inference on current batch. + preds = self.inference_model.predict_on_batch(ex, numpy=True) + + # Add model outputs to the input data example. + ex.update(preds) + + # Convert to numpy arrays if not already. + if isinstance(ex["video_ind"], tf.Tensor): + ex["video_ind"] = ex["video_ind"].numpy().flatten() + if isinstance(ex["frame_ind"], tf.Tensor): + ex["frame_ind"] = ex["frame_ind"].numpy().flatten() + + # Adjust for potential SizeMatcher scaling. + offset_x = ex.get("offset_x", 0) + offset_y = ex.get("offset_y", 0) + ex["instance_peaks"] -= np.reshape([offset_x, offset_y], [-1, 1, 1, 2]) + ex["instance_peaks"] /= np.expand_dims( + np.expand_dims(ex["scale"], axis=1), axis=1 + ) + + return ex + + def _run_batch_json( + self, + examples: List[Dict[str, np.ndarray]], + n_total: int, + max_length: int = 30, + ) -> Iterator[Dict[str, np.ndarray]]: + n_processed = 0 + n_recent = deque(maxlen=max_length) + elapsed_recent = deque(maxlen=max_length) + last_report = time() + t0_all = time() + t0_batch = time() + for ex in examples: + # Process batch of examples. + ex = self._process_batch(ex) + + # Track timing and progress. + elapsed_batch = time() - t0_batch + t0_batch = time() + n_batch = len(ex["frame_ind"]) + n_processed += n_batch + elapsed_all = time() - t0_all + + # Compute recent rate. + n_recent.append(n_batch) + elapsed_recent.append(elapsed_batch) + rate = sum(n_recent) / sum(elapsed_recent) + eta = (n_total - n_processed) / rate + + # Report. + if time() > last_report + self.report_period: + print( + json.dumps( + { + "n_processed": n_processed, + "n_total": n_total, + "elapsed": elapsed_all, + "rate": rate, + "eta": eta, + } + ), + flush=True, + ) + last_report = time() + + # Return results. + yield ex + + def _run_batch_rich( + self, + examples: List[Dict[str, np.ndarray]], + n_total: int, + ) -> Iterator[Dict[str, np.ndarray]]: + with rich.progress.Progress( + "{task.description}", + rich.progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + "ETA:", + rich.progress.TimeRemainingColumn(), + RateColumn(), + auto_refresh=False, + refresh_per_second=self.report_rate, + speed_estimate_period=5, + ) as progress: + task = progress.add_task("Predicting...", total=n_total) + last_report = time() + for ex in examples: + ex = self._process_batch(ex) + + progress.update(task, advance=len(ex["frame_ind"])) + + # Handle refreshing manually to support notebooks. + if time() > last_report + self.report_period: + progress.refresh() + last_report = time() + + # Return results. + yield ex + def _predict_generator( self, data_provider: Provider ) -> Iterator[Dict[str, np.ndarray]]: @@ -389,106 +505,22 @@ def _predict_generator( if self.inference_model is None: self._initialize_inference_model() - def process_batch(ex): - # Run inference on current batch. - preds = self.inference_model.predict_on_batch(ex, numpy=True) - - # Add model outputs to the input data example. - ex.update(preds) - - # Convert to numpy arrays if not already. - if isinstance(ex["video_ind"], tf.Tensor): - ex["video_ind"] = ex["video_ind"].numpy().flatten() - if isinstance(ex["frame_ind"], tf.Tensor): - ex["frame_ind"] = ex["frame_ind"].numpy().flatten() - - # Adjust for potential SizeMatcher scaling. - offset_x = ex.get("offset_x", 0) - offset_y = ex.get("offset_y", 0) - ex["instance_peaks"] -= np.reshape([offset_x, offset_y], [-1, 1, 1, 2]) - ex["instance_peaks"] /= np.expand_dims( - np.expand_dims(ex["scale"], axis=1), axis=1 - ) - - return ex - # Compile loop examples before starting time to improve ETA + n_total = len(data_provider) examples = self.pipeline.make_dataset() # Loop over data batches with optional progress reporting. if self.verbosity == "rich": - with rich.progress.Progress( - "{task.description}", - rich.progress.BarColumn(), - "[progress.percentage]{task.percentage:>3.0f}%", - "ETA:", - rich.progress.TimeRemainingColumn(), - RateColumn(), - auto_refresh=False, - refresh_per_second=self.report_rate, - speed_estimate_period=5, - ) as progress: - task = progress.add_task("Predicting...", total=len(data_provider)) - last_report = time() - for ex in examples: - ex = process_batch(ex) - progress.update(task, advance=len(ex["frame_ind"])) - - # Handle refreshing manually to support notebooks. - elapsed_since_last_report = time() - last_report - if elapsed_since_last_report > self.report_period: - progress.refresh() - - # Return results. - yield ex + for ex in self._run_batch_rich(examples, n_total=n_total): + yield ex elif self.verbosity == "json": - n_processed = 0 - n_total = len(data_provider) - n_recent = deque(maxlen=30) - elapsed_recent = deque(maxlen=30) - last_report = time() - t0_all = time() - t0_batch = time() - for ex in examples: - # Process batch of examples. - ex = process_batch(ex) - - # Track timing and progress. - elapsed_batch = time() - t0_batch - t0_batch = time() - n_batch = len(ex["frame_ind"]) - n_processed += n_batch - elapsed_all = time() - t0_all - - # Compute recent rate. - n_recent.append(n_batch) - elapsed_recent.append(elapsed_batch) - rate = sum(n_recent) / sum(elapsed_recent) - eta = (n_total - n_processed) / rate - - # Report. - elapsed_since_last_report = time() - last_report - if elapsed_since_last_report > self.report_period: - print( - json.dumps( - { - "n_processed": n_processed, - "n_total": n_total, - "elapsed": elapsed_all, - "rate": rate, - "eta": eta, - } - ), - flush=True, - ) - last_report = time() - - # Return results. + for ex in self._run_batch_json(examples, n_total=n_total): yield ex + else: for ex in examples: - yield process_batch(ex) + yield self._process_batch(ex) def predict( self, data: Union[Provider, sleap.Labels, sleap.Video], make_labels: bool = True