diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 042958cbb0c6..5d75cc29880c 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -1035,6 +1035,9 @@ def save_pretrained(
         if self.image_processor is not None:
             self.image_processor.save_pretrained(save_directory, **kwargs)
 
+        if self.processor is not None:
+            self.processor.save_pretrained(save_directory, **kwargs)
+
         if self.modelcard is not None:
             self.modelcard.save_pretrained(save_directory)
 
diff --git a/src/transformers/pipelines/zero_shot_object_detection.py b/src/transformers/pipelines/zero_shot_object_detection.py
index 9ad575202266..4c6b2dbc0862 100644
--- a/src/transformers/pipelines/zero_shot_object_detection.py
+++ b/src/transformers/pipelines/zero_shot_object_detection.py
@@ -1,7 +1,8 @@
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
-from .base import ChunkPipeline, build_pipeline_init_args
+from ..utils.deprecation import deprecate_kwarg
+from .base import Pipeline, build_pipeline_init_args
 
 
 if is_vision_available():
@@ -12,15 +13,13 @@
 if is_torch_available():
     import torch
 
-    from transformers.modeling_outputs import BaseModelOutput
-
     from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
 
 logger = logging.get_logger(__name__)
 
 
-@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
-class ZeroShotObjectDetectionPipeline(ChunkPipeline):
+@add_end_docstrings(build_pipeline_init_args(has_processor=True))
+class ZeroShotObjectDetectionPipeline(Pipeline):
     """
     Zero shot object detection pipeline using `OwlViTForObjectDetection`. This pipeline predicts bounding boxes of
     objects when you provide an image and a set of `candidate_labels`.
@@ -53,6 +52,13 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
     [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-object-detection).
     """
 
+    _load_processor = True
+
+    # set to False because required sub-processors will be loaded with the `Processor` class
+    _load_tokenizer = False
+    _load_image_processor = False
+    _load_feature_extractor = False
+
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
@@ -66,6 +72,9 @@ def __call__(
         self,
         image: Union[str, "Image.Image", List[Dict[str, Any]]],
         candidate_labels: Union[str, List[str]] = None,
+        threshold: float = 0.1,
+        top_k: Optional[int] = None,
+        timeout: Optional[float] = None,
         **kwargs,
     ):
         """
@@ -125,14 +134,11 @@ def __call__(
             - **box** (`Dict[str,int]`) -- Bounding box of the detected object in image's original size. It is a
               dictionary with `x_min`, `x_max`, `y_min`, `y_max` keys.
""" - if "text_queries" in kwargs: - candidate_labels = kwargs.pop("text_queries") - if isinstance(image, (str, Image.Image)): inputs = {"image": image, "candidate_labels": candidate_labels} else: inputs = image - results = super().__call__(inputs, **kwargs) + results = super().__call__(inputs, timeout=timeout, threshold=threshold, top_k=top_k, **kwargs) return results def _sanitize_parameters(self, **kwargs): @@ -146,57 +152,74 @@ def _sanitize_parameters(self, **kwargs): postprocess_params["top_k"] = kwargs["top_k"] return preprocess_params, {}, postprocess_params + @deprecate_kwarg("text_queries", new_name="candidate_labels", version="5.0.0") + def _preprocess_input_keys(self, image, candidate_labels): + """ + The method is used to convert input keys to pipeline specific keys, taking into consideration + backward compatibility for deprecated ones. + """ + return {"image": image, "candidate_labels": candidate_labels} + def preprocess(self, inputs, timeout=None): + # convert keys to unified format: image + candidate_labels + inputs = self._preprocess_input_keys(**inputs) + image = load_image(inputs["image"], timeout=timeout) candidate_labels = inputs["candidate_labels"] - if isinstance(candidate_labels, str): - candidate_labels = candidate_labels.split(",") - - target_size = torch.tensor([[image.height, image.width]], dtype=torch.int32) - for i, candidate_label in enumerate(candidate_labels): - text_inputs = self.tokenizer(candidate_label, return_tensors=self.framework) - image_features = self.image_processor(image, return_tensors=self.framework) - if self.framework == "pt": - image_features = image_features.to(self.torch_dtype) - yield { - "is_last": i == len(candidate_labels) - 1, - "target_size": target_size, - "candidate_label": candidate_label, - **text_inputs, - **image_features, - } + + # preprocess the inputs + model_inputs = self.processor(images=image, text=candidate_labels, return_tensors=self.framework) + + # save extra data for post processing + model_inputs["_target_size"] = [image.height, image.width] + model_inputs["_candidate_labels"] = candidate_labels + + return model_inputs def _forward(self, model_inputs): - target_size = model_inputs.pop("target_size") - candidate_label = model_inputs.pop("candidate_label") - is_last = model_inputs.pop("is_last") + # separate extra data and model inputs for forward + extras = {k: v for k, v in model_inputs.items() if k.startswith("_")} + inputs = {k: v for k, v in model_inputs.items() if not k.startswith("_")} + + # forward + model_outputs = self.model(**inputs) - outputs = self.model(**model_inputs) + # pass extra data for post processing + outputs = {**model_outputs, **extras} - model_outputs = {"target_size": target_size, "candidate_label": candidate_label, "is_last": is_last, **outputs} - return model_outputs + return outputs def postprocess(self, model_outputs, threshold=0.1, top_k=None): - results = [] - for model_output in model_outputs: - label = model_output["candidate_label"] - model_output = BaseModelOutput(model_output) - outputs = self.image_processor.post_process_object_detection( - outputs=model_output, threshold=threshold, target_sizes=model_output["target_size"] + if hasattr(self.processor, "post_process_grounded_object_detection"): + # Grounding Dino and OmDet cases + outputs = self.processor.post_process_grounded_object_detection( + model_outputs, + model_outputs["_input_ids"], + box_threshold=threshold, + text_threshold=threshold, + target_sizes=[model_outputs["_target_size"]], )[0] + else: + # OwlViT and OwlV2 
+            outputs = self.processor.post_process_object_detection(
+                outputs=model_outputs, threshold=threshold, target_sizes=[model_outputs["_target_size"]]
+            )[0]
+            labels = model_outputs["_candidate_labels"]
+            outputs["labels"] = [labels[label_id.item()] for label_id in outputs["labels"]]
 
-            for index in outputs["scores"].nonzero():
-                score = outputs["scores"][index].item()
-                box = self._get_bounding_box(outputs["boxes"][index][0])
+        scores = outputs["scores"].tolist()
+        boxes = [self._get_bounding_box(box) for box in outputs["boxes"]]
+        labels = outputs["labels"]
 
-                result = {"score": score, "label": label, "box": box}
-                results.append(result)
+        annotations = [
+            {"score": score, "label": label, "box": box} for score, label, box in zip(scores, labels, boxes)
+        ]
 
-            results = sorted(results, key=lambda x: x["score"], reverse=True)
+        annotations = sorted(annotations, key=lambda x: x["score"], reverse=True)
         if top_k:
-            results = results[:top_k]
+            annotations = annotations[:top_k]
 
-        return results
+        return annotations
 
     def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
         """
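
For reviewers, a minimal usage sketch of the refactored pipeline. The checkpoint name and image URL are illustrative assumptions, not part of this diff; `threshold`, `top_k`, and `timeout` are the explicit `__call__` parameters added above:

```python
from transformers import pipeline

# Any zero-shot object detection checkpoint should work here; OwlViT is assumed.
detector = pipeline("zero-shot-object-detection", model="google/owlvit-base-patch32")

predictions = detector(
    "http://images.cocodataset.org/val2017/000000039769.jpg",  # illustrative image URL
    candidate_labels=["cat", "remote control"],
    threshold=0.1,  # minimum score for a detection to be kept
    top_k=3,        # keep only the 3 highest-scoring detections
    timeout=10.0,   # max seconds to wait when fetching the image
)

# postprocess() returns annotations sorted by descending score:
# [{"score": float, "label": str, "box": {...}}, ...]
for prediction in predictions:
    print(prediction["label"], round(prediction["score"], 3), prediction["box"])
```

Dict-style inputs that still use the deprecated `text_queries` key are renamed to `candidate_labels` by the `@deprecate_kwarg` shim on `_preprocess_input_keys` until v5.0.0.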
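
The `base.py` hunk teaches `save_pretrained` to serialize the pipeline's `processor` as well, which is what makes a save/reload round-trip of this processor-backed pipeline possible. A sketch, with a hypothetical local path:

```python
# Persist the pipeline; with this diff the processor (which bundles the tokenizer
# and image processor) is written to the directory alongside the model.
detector.save_pretrained("./my-zero-shot-detector")  # hypothetical path

# Since the pipeline class sets `_load_processor = True`, reloading from the saved
# directory goes through the processor rather than separate sub-processors.
restored = pipeline("zero-shot-object-detection", model="./my-zero-shot-detector")
```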