Inference results data structures #46

Open · talmo opened this issue May 3, 2024 · 0 comments
Labels: enhancement (New feature or request)

talmo commented May 3, 2024

We support many model types, and they all produce different kinds of outputs during inference, including intermediate ones (e.g., centroids in the first stage of top-down models) and optional ones (e.g., confidence maps, since they're expensive to transfer off the GPU).

Historically, organizing these tensors was hard, particularly in graph mode in TensorFlow, since autograph barely supported data containers other than tuples (e.g., dictionaries, dataclasses, or arbitrary Python classes), in part because they're harder to traverse for shape and dtype tracing given that Python is not strongly typed.

It would be great to take advantage of the fact that we're no longer constrained by these limitations and organize our outputs a bit better.

Right now, we're using dictionaries, which are helpful but a bit brittle. For example, here's the output of our stage-2 (centered instance) model:

# Build outputs.
outputs = {"pred_instance_peaks": peak_points, "pred_peak_values": peak_vals}
if self.return_confmaps:
outputs["pred_confmaps"] = cms.detach()
inputs.update(outputs)
return inputs

Compare that to using a dedicated results class, like this one from ultralytics:

class Results(SimpleClass):
    """
    A class for storing and manipulating inference results.

    Attributes:
        orig_img (numpy.ndarray): Original image as a numpy array.
        orig_shape (tuple): Original image shape in (height, width) format.
        boxes (Boxes, optional): Object containing detection bounding boxes.
        masks (Masks, optional): Object containing detection masks.
        probs (Probs, optional): Object containing class probabilities for classification tasks.
        keypoints (Keypoints, optional): Object containing detected keypoints for each object.
        speed (dict): Dictionary of preprocess, inference, and postprocess speeds (ms/image).
        names (dict): Dictionary of class names.
        path (str): Path to the image file.

    Methods:
        update(boxes=None, masks=None, probs=None, obb=None): Updates object attributes with new detection results.
        cpu(): Returns a copy of the Results object with all tensors on CPU memory.
        numpy(): Returns a copy of the Results object with all tensors as numpy arrays.
        cuda(): Returns a copy of the Results object with all tensors on GPU memory.
        to(*args, **kwargs): Returns a copy of the Results object with tensors on a specified device and dtype.
        new(): Returns a new Results object with the same image, path, and names.
        plot(...): Plots detection results on an input image, returning an annotated image.
        show(): Show annotated results to screen.
        save(filename): Save annotated results to file.
        verbose(): Returns a log string for each task, detailing detections and classifications.
        save_txt(txt_file, save_conf=False): Saves detection results to a text file.
        save_crop(save_dir, file_name=Path("im.jpg")): Saves cropped detection images.
        tojson(normalize=False): Converts detection results to JSON format.
    """

(source)

These get produced as outputs from their high-level APIs (~= Predictors in our repo):

    def predict(
        self,
        source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
        stream: bool = False,
        predictor=None,
        **kwargs,
    ) -> list:
        """
        Performs predictions on the given image source using the YOLO model.

        This method facilitates the prediction process, allowing various configurations through keyword arguments.
        It supports predictions with custom predictors or the default predictor method. The method handles different
        types of image sources and can operate in a streaming mode. It also provides support for SAM-type models
        through 'prompts'.

        The method sets up a new predictor if not already present and updates its arguments with each call.
        It also issues a warning and uses default assets if the 'source' is not provided. The method determines if it
        is being called from the command line interface and adjusts its behavior accordingly, including setting defaults
        for confidence threshold and saving behavior.

        Args:
            source (str | int | PIL.Image | np.ndarray, optional): The source of the image for making predictions.
                Accepts various types, including file paths, URLs, PIL images, and numpy arrays. Defaults to ASSETS.
            stream (bool, optional): Treats the input source as a continuous stream for predictions. Defaults to False.
            predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions.
                If None, the method uses a default predictor. Defaults to None.
            **kwargs (any): Additional keyword arguments for configuring the prediction process. These arguments allow
                for further customization of the prediction behavior.

        Returns:
            (List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in the Results class.

        Raises:
            AttributeError: If the predictor is not properly set up.
        """

(source)

These get kicked up from their Predictors (~= InferenceModels in core SLEAP, ~= LightningModules in our sleap_nn.inference):

    def postprocess(self, preds, img, orig_imgs):
        """Return detection results for a given input image or list of images."""
        preds = ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            agnostic=self.args.agnostic_nms,
            max_det=self.args.max_det,
            classes=self.args.classes,
            nc=len(self.model.names),
        )

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for i, pred in enumerate(preds):
            orig_img = orig_imgs[i]
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()
            pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
            pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
            img_path = self.batch[0][i]
            results.append(
                Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
            )
        return results

(source)

For us, it would be handy to be able to store arbitrary tensors that belong together, like indexing information, which would help with tasks like re-batching in top-down models. A rough sketch of what this could look like is below.
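
For example (all names here are hypothetical, not a concrete API proposal):

    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class CenteredInstanceResult:
        """Hypothetical container for stage-2 (centered instance) outputs.

        Attributes:
            pred_instance_peaks: Predicted peak coordinates for each instance.
            pred_peak_values: Confidence values at each predicted peak.
            instance_inds: Indices mapping each instance back to its source
                frame, e.g., for re-batching in top-down models.
            pred_confmaps: Optional confidence maps, off by default since
                they're expensive to transfer off the GPU.
        """

        pred_instance_peaks: torch.Tensor
        pred_peak_values: torch.Tensor
        instance_inds: torch.Tensor
        pred_confmaps: Optional[torch.Tensor] = None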

It would also enable handy utilities like a to() method that moves all of the tensors it contains to the same device, for example:

    def to(self, map_location):
        """Move instance to different device or change dtype. (See `torch.to` for more info).

        Args:
            map_location: Either the device or dtype for the instance to be moved.

        Returns:
            self: reference to the instance moved to correct device/dtype.
        """
        if map_location is not None and map_location != "":
            self._gt_track_id = self._gt_track_id.to(map_location)
            self._pred_track_id = self._pred_track_id.to(map_location)
            self._bbox = self._bbox.to(map_location)
            self._crop = self._crop.to(map_location)
            self._features = self._features.to(map_location)
            self.device = map_location

        return self

(source)
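
As a minimal sketch (to_device is hypothetical, not an existing API), a more generic variant could traverse the dataclass fields instead of hardcoding each attribute:

    from dataclasses import fields, replace

    import torch


    def to_device(result, map_location):
        """Return a copy of `result` with all of its tensor fields moved.

        Works for any dataclass-based container (e.g., the hypothetical
        CenteredInstanceResult above) without hardcoding attribute names.
        """
        moved = {
            f.name: getattr(result, f.name).to(map_location)
            for f in fields(result)
            if isinstance(getattr(result, f.name), torch.Tensor)
        }
        return replace(result, **moved)

Usage would look like result = to_device(result, "cuda:0"), and the same pattern extends naturally to cpu()/numpy() helpers like the ones on ultralytics' Results.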

talmo added the enhancement (New feature or request) label May 3, 2024