3 changes: 3 additions & 0 deletions src/transformers/pipelines/base.py

@@ -1035,6 +1035,9 @@ def save_pretrained(
         if self.image_processor is not None:
             self.image_processor.save_pretrained(save_directory, **kwargs)

+        if self.processor is not None:
+            self.processor.save_pretrained(save_directory, **kwargs)
+
         if self.modelcard is not None:
             self.modelcard.save_pretrained(save_directory)

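For context, a brief sketch of what this `base.py` change enables: a pipeline that carries a `processor` now persists it when the pipeline is saved. The checkpoint name and output path below are illustrative assumptions, not part of the diff.

from transformers import pipeline

# Illustrative checkpoint; any pipeline whose task loads a Processor applies.
detector = pipeline(task="zero-shot-object-detection", model="google/owlvit-base-patch32")

# With the addition above, save_pretrained also writes the processor's files
# (which bundle the tokenizer and image processor) to the target directory,
# alongside the model weights and config.
detector.save_pretrained("./owlvit-detector")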
115 changes: 69 additions & 46 deletions src/transformers/pipelines/zero_shot_object_detection.py

@@ -1,7 +1,8 @@
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union

 from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
-from .base import ChunkPipeline, build_pipeline_init_args
+from ..utils.deprecation import deprecate_kwarg
+from .base import Pipeline, build_pipeline_init_args


 if is_vision_available():
Expand All @@ -12,15 +13,13 @@
if is_torch_available():
import torch

from transformers.modeling_outputs import BaseModelOutput

from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES

logger = logging.get_logger(__name__)


@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ZeroShotObjectDetectionPipeline(ChunkPipeline):
@add_end_docstrings(build_pipeline_init_args(has_processor=True))
class ZeroShotObjectDetectionPipeline(Pipeline):
"""
Zero shot object detection pipeline using `OwlViTForObjectDetection`. This pipeline predicts bounding boxes of
objects when you provide an image and a set of `candidate_labels`.
@@ -53,6 +52,13 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
     [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-object-detection).
     """

+    _load_processor = True
+
+    # set to False because required sub-processors will be loaded with the `Processor` class
+    _load_tokenizer = False
+    _load_image_processor = False
+    _load_feature_extractor = False
+
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
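To ground the refactor, a minimal usage sketch of this pipeline; the checkpoint and image URL are illustrative, mirroring the usual zero-shot object detection example.

from transformers import pipeline

# Illustrative checkpoint and image URL.
detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")
predictions = detector(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    candidate_labels=["cat", "remote control"],
)
# predictions is a list of dicts, each with a "score" (float), a "label" (str),
# and a "box" (dict of corner coordinates in the image's original size).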

@@ -66,6 +72,9 @@ def __call__(
         self,
         image: Union[str, "Image.Image", List[Dict[str, Any]]],
         candidate_labels: Union[str, List[str]] = None,
+        threshold: float = 0.1,
+        top_k: Optional[int] = None,
+        timeout: Optional[float] = None,
         **kwargs,
     ):
         """
@@ -125,14 +134,11 @@ def __call__(
             - **box** (`Dict[str,int]`) -- Bounding box of the detected object in image's original size. It is a
               dictionary with `x_min`, `x_max`, `y_min`, `y_max` keys.
         """
-        if "text_queries" in kwargs:
-            candidate_labels = kwargs.pop("text_queries")
-
         if isinstance(image, (str, Image.Image)):
             inputs = {"image": image, "candidate_labels": candidate_labels}
         else:
             inputs = image
-        results = super().__call__(inputs, **kwargs)
+        results = super().__call__(inputs, timeout=timeout, threshold=threshold, top_k=top_k, **kwargs)
         return results

     def _sanitize_parameters(self, **kwargs):
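Because `threshold`, `top_k`, and `timeout` are now explicit parameters of `__call__` rather than hidden in `**kwargs`, a tuned call can be written as below (reusing the `detector` from the earlier sketch; the values are illustrative):

# Keep only detections scoring above 0.3, return at most 5 of them,
# and abort the image download after 10 seconds.
predictions = detector(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    candidate_labels=["cat", "remote control"],
    threshold=0.3,
    top_k=5,
    timeout=10.0,
)
for prediction in predictions:
    print(prediction["score"], prediction["label"], prediction["box"])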
@@ -146,57 +152,74 @@ def _sanitize_parameters(self, **kwargs):
             postprocess_params["top_k"] = kwargs["top_k"]
         return preprocess_params, {}, postprocess_params

+    @deprecate_kwarg("text_queries", new_name="candidate_labels", version="5.0.0")
+    def _preprocess_input_keys(self, image, candidate_labels):
+        """
+        Convert input keys to pipeline-specific keys, handling backward
+        compatibility for deprecated ones.
+        """
+        return {"image": image, "candidate_labels": candidate_labels}
+
     def preprocess(self, inputs, timeout=None):
+        # convert keys to the unified format: image + candidate_labels
+        inputs = self._preprocess_input_keys(**inputs)
+
         image = load_image(inputs["image"], timeout=timeout)
         candidate_labels = inputs["candidate_labels"]
         if isinstance(candidate_labels, str):
             candidate_labels = candidate_labels.split(",")

-        target_size = torch.tensor([[image.height, image.width]], dtype=torch.int32)
-        for i, candidate_label in enumerate(candidate_labels):
-            text_inputs = self.tokenizer(candidate_label, return_tensors=self.framework)
-            image_features = self.image_processor(image, return_tensors=self.framework)
-            if self.framework == "pt":
-                image_features = image_features.to(self.torch_dtype)
-            yield {
-                "is_last": i == len(candidate_labels) - 1,
-                "target_size": target_size,
-                "candidate_label": candidate_label,
-                **text_inputs,
-                **image_features,
-            }
+        # preprocess the inputs
+        model_inputs = self.processor(images=image, text=candidate_labels, return_tensors=self.framework)
+
+        # save extra data for post processing
+        model_inputs["_target_size"] = [image.height, image.width]
+        model_inputs["_candidate_labels"] = candidate_labels
+
+        return model_inputs

     def _forward(self, model_inputs):
-        target_size = model_inputs.pop("target_size")
-        candidate_label = model_inputs.pop("candidate_label")
-        is_last = model_inputs.pop("is_last")
+        # separate extra data and model inputs for forward
+        extras = {k: v for k, v in model_inputs.items() if k.startswith("_")}
+        inputs = {k: v for k, v in model_inputs.items() if not k.startswith("_")}

+        # forward
+        model_outputs = self.model(**inputs)

-        outputs = self.model(**model_inputs)
+        # pass extra data for post processing
+        outputs = {**model_outputs, **extras}

-        model_outputs = {"target_size": target_size, "candidate_label": candidate_label, "is_last": is_last, **outputs}
-        return model_outputs
+        return outputs

     def postprocess(self, model_outputs, threshold=0.1, top_k=None):
-        results = []
-        for model_output in model_outputs:
-            label = model_output["candidate_label"]
-            model_output = BaseModelOutput(model_output)
-            outputs = self.image_processor.post_process_object_detection(
-                outputs=model_output, threshold=threshold, target_sizes=model_output["target_size"]
+        if hasattr(self.processor, "post_process_grounded_object_detection"):
+            # Grounding DINO and OmDet cases
+            outputs = self.processor.post_process_grounded_object_detection(
+                model_outputs,
+                model_outputs["_input_ids"],
+                box_threshold=threshold,
+                text_threshold=threshold,
+                target_sizes=[model_outputs["_target_size"]],
             )[0]
+        else:
+            # OwlViT and OwlV2 cases
+            outputs = self.processor.post_process_object_detection(
+                outputs=model_outputs, threshold=threshold, target_sizes=[model_outputs["_target_size"]]
+            )[0]
+        labels = model_outputs["_candidate_labels"]
+        outputs["labels"] = [labels[label_id.item()] for label_id in outputs["labels"]]

-            for index in outputs["scores"].nonzero():
-                score = outputs["scores"][index].item()
-                box = self._get_bounding_box(outputs["boxes"][index][0])
+        scores = outputs["scores"].tolist()
+        boxes = [self._get_bounding_box(box) for box in outputs["boxes"]]
+        labels = outputs["labels"]

-                result = {"score": score, "label": label, "box": box}
-                results.append(result)
+        annotations = [
+            {"score": score, "label": label, "box": box} for score, label, box in zip(scores, labels, boxes)
+        ]

-        results = sorted(results, key=lambda x: x["score"], reverse=True)
+        annotations = sorted(annotations, key=lambda x: x["score"], reverse=True)
         if top_k:
-            results = results[:top_k]
+            annotations = annotations[:top_k]

-        return results
+        return annotations

     def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
         """
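One backward-compatibility note: the `@deprecate_kwarg` decorator on `_preprocess_input_keys` keeps the old `text_queries` input key working, remapped to `candidate_labels`, until v5.0.0. A sketch of the two dict-style call forms follows; the warning behavior is an assumption based on the decorator, with the exact warning class determined by the deprecation utility.

url = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Preferred input key after this refactor:
detector({"image": url, "candidate_labels": ["cat", "remote control"]})

# Deprecated key, still accepted until v5.0.0; expected to emit a
# deprecation warning via @deprecate_kwarg before being remapped.
detector({"image": url, "text_queries": ["cat", "remote control"]})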