[Proposal] Breaking change zero-shot-object-detection for improved consistency. #20280
src/transformers/pipelines/zero_shot_object_detection.py

````diff
@@ -1,17 +1,7 @@
-from typing import Dict, List, Union
+from typing import Any, Dict, List, Union
 
-import numpy as np
-
-from ..tokenization_utils_base import BatchEncoding
-from ..utils import (
-    add_end_docstrings,
-    is_tf_available,
-    is_torch_available,
-    is_vision_available,
-    logging,
-    requires_backends,
-)
-from .base import PIPELINE_INIT_ARGS, Pipeline
+from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
+from .base import PIPELINE_INIT_ARGS, ChunkPipeline
 
 
 if is_vision_available():
@@ -22,13 +12,15 @@
 if is_torch_available():
     import torch
 
+    from transformers.modeling_outputs import BaseModelOutput
+
     from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
 
 logger = logging.get_logger(__name__)
 
 
 @add_end_docstrings(PIPELINE_INIT_ARGS)
-class ZeroShotObjectDetectionPipeline(Pipeline):
+class ZeroShotObjectDetectionPipeline(ChunkPipeline):
     """
     Zero shot object detection pipeline using `OwlViTForObjectDetection`. This pipeline predicts bounding boxes of
     objects when you provide an image and a set of `candidate_labels`.
@@ -43,13 +35,13 @@ class ZeroShotObjectDetectionPipeline(Pipeline):
     ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
     ...     candidate_labels=["cat", "couch"],
     ... )
-    [[{'score': 0.287, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.254, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.121, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}]]
+    [{'score': 0.287, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.254, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.121, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}]
 
     >>> detector(
     ...     "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
     ...     candidate_labels=["head", "bird"],
     ... )
-    [[{'score': 0.119, 'label': 'bird', 'box': {'xmin': 71, 'ymin': 170, 'xmax': 410, 'ymax': 508}}]]
+    [{'score': 0.119, 'label': 'bird', 'box': {'xmin': 71, 'ymin': 170, 'xmax': 410, 'ymax': 508}}]
     ```
 
     [Learn more about the basics of using a pipeline in the [pipeline tutorial]](../pipeline_tutorial)
@@ -72,24 +64,45 @@ def __init__(self, **kwargs):
 
     def __call__(
         self,
-        images: Union[str, List[str], "Image.Image", List["Image.Image"]],
-        text_queries: Union[str, List[str], List[List[str]]] = None,
+        image: Union[str, "Image.Image", List[Dict[str, Any]]],
+        candidate_labels: Union[str, List[str]] = None,
         **kwargs
     ):
         """
         Detect objects (bounding boxes & classes) in the image(s) passed as inputs.
 
         Args:
-            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+            image (`str`, `PIL.Image` or `List[Dict[str, Any]]`):
                 The pipeline handles three types of images:
 
                 - A string containing an http url pointing to an image
                 - A string containing a local path to an image
                 - An image loaded in PIL directly
 
-            text_queries (`str` or `List[str]` or `List[List[str]]`): Text queries to query the target image with.
-                If given multiple images, `text_queries` should be provided as a list of lists, where each nested list
-                contains the text queries for the corresponding image.
+                You can use this parameter to send directly a list of images, or a dataset or a generator like so:
+
+                ```python
+                >>> from transformers import pipeline
+
+                >>> detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")
+                >>> detector(
+                ...     [
+                ...         {
+                ...             "image": "http://images.cocodataset.org/val2017/000000039769.jpg",
+                ...             "candidate_labels": ["cat", "couch"],
+                ...         },
+                ...         {
+                ...             "image": "http://images.cocodataset.org/val2017/000000039769.jpg",
+                ...             "candidate_labels": ["cat", "couch"],
+                ...         },
+                ...     ]
+                ... )
+                [[{'score': 0.286811888217926, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.2537279725074768, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.12082888185977936, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}], [{'score': 0.286811888217926, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.2537279725074768, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.12082888185977936, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}]]
+                ```
+
+            candidate_labels (`str` or `List[str]` or `List[List[str]]`):
+                What the model should recognize in the image.
 
             threshold (`float`, *optional*, defaults to 0.1):
                 The probability necessary to make a prediction.
@@ -108,28 +121,13 @@ def __call__(
             - **box** (`Dict[str,int]`) -- Bounding box of the detected object in image's original size. It is a
               dictionary with `x_min`, `x_max`, `y_min`, `y_max` keys.
         """
-        if "candidate_labels" in kwargs:
-            text_queries = kwargs.pop("candidate_labels")
-        if isinstance(text_queries, str) or (isinstance(text_queries, List) and not isinstance(text_queries[0], List)):
-            if isinstance(images, (str, Image.Image)):
-                inputs = {"images": images, "text_queries": text_queries}
-            elif isinstance(images, List):
-                assert len(images) == 1, "Input text_queries and images must have correspondance"
-                inputs = {"images": images[0], "text_queries": text_queries}
-            else:
-                raise TypeError(f"Innapropriate type of images: {type(images)}")
-
-        elif isinstance(text_queries, str) or (isinstance(text_queries, List) and isinstance(text_queries[0], List)):
-            if isinstance(images, (Image.Image, str)):
-                images = [images]
-            assert len(images) == len(text_queries), "Input text_queries and images must have correspondance"
-            inputs = {"images": images, "text_queries": text_queries}
+        if "text_queries" in kwargs:
+            candidate_labels = kwargs.pop("text_queries")
+
+        if isinstance(image, (str, Image.Image)):
+            inputs = {"image": image, "candidate_labels": candidate_labels}
         else:
-            """
-            Supports the following format
-            - {"images": images, "text_queries": text_queries}
-            """
-            inputs = images
+            inputs = image
         results = super().__call__(inputs, **kwargs)
         return results
 
@@ -142,49 +140,54 @@ def _sanitize_parameters(self, **kwargs):
         return {}, {}, postprocess_params
 
     def preprocess(self, inputs):
-        if not isinstance(inputs["images"], List):
-            inputs["images"] = [inputs["images"]]
-        images = [load_image(img) for img in inputs["images"]]
-        text_queries = inputs["text_queries"]
-        if isinstance(text_queries, str) or isinstance(text_queries[0], str):
-            text_queries = [text_queries]
-
-        target_sizes = [torch.IntTensor([[img.height, img.width]]) for img in images]
-        target_sizes = torch.cat(target_sizes)
-        inputs = self._processor(text=inputs["text_queries"], images=images, return_tensors="pt")
-        return {"target_sizes": target_sizes, "text_queries": text_queries, **inputs}
+        image = load_image(inputs["image"])
+        candidate_labels = inputs["candidate_labels"]
+        if isinstance(candidate_labels, str):
+            candidate_labels = candidate_labels.split(",")
+
+        target_size = torch.tensor([[image.height, image.width]], dtype=torch.int32)
+        for i, candidate_label in enumerate(candidate_labels):
+            text_inputs = self.tokenizer(candidate_label, return_tensors=self.framework)
+            image_features = self.feature_extractor(image, return_tensors=self.framework)
+            yield {
+                "is_last": i == len(candidate_labels) - 1,
+                "target_size": target_size,
+                "candidate_label": candidate_label,
+                **text_inputs,
+                **image_features,
+            }
 
     def _forward(self, model_inputs):
-        target_sizes = model_inputs.pop("target_sizes")
-        text_queries = model_inputs.pop("text_queries")
+        target_size = model_inputs.pop("target_size")
+        candidate_label = model_inputs.pop("candidate_label")
+        is_last = model_inputs.pop("is_last")
 
         outputs = self.model(**model_inputs)
 
-        model_outputs = outputs.__class__({"target_sizes": target_sizes, "text_queries": text_queries, **outputs})
+        model_outputs = {"target_size": target_size, "candidate_label": candidate_label, "is_last": is_last, **outputs}
         return model_outputs
 
     def postprocess(self, model_outputs, threshold=0.1, top_k=None):
-        texts = model_outputs["text_queries"]
-
-        outputs = self.feature_extractor.post_process(
-            outputs=model_outputs, target_sizes=model_outputs["target_sizes"]
-        )
-
         results = []
-        for i in range(len(outputs)):
-            keep = outputs[i]["scores"] >= threshold
-            labels = outputs[i]["labels"][keep].tolist()
-            scores = outputs[i]["scores"][keep].tolist()
-            boxes = [self._get_bounding_box(box) for box in outputs[i]["boxes"][keep]]
-
-            result = [
-                {"score": score, "label": texts[i][label], "box": box}
-                for score, label, box in zip(scores, labels, boxes)
-            ]
-
-            result = sorted(result, key=lambda x: x["score"], reverse=True)
-            if top_k:
-                result = result[:top_k]
-            results.append(result)
+        for model_output in model_outputs:
+            label = model_output["candidate_label"]
+            model_output = BaseModelOutput(model_output)
+            outputs = self.feature_extractor.post_process(
+                outputs=model_output, target_sizes=model_output["target_size"]
+            )[0]
````
Review thread on the `self.feature_extractor.post_process` call above:

Contributor: @Narsil This is where I assume the non-max suppression would be done. We can check once though.

Contributor: This is something weird: the `self.feature_extractor.post_process` method itself doesn't have NMS, but `self.feature_extractor.post_process_image_guided_detection` does.

Contributor (Author): Happy to switch though, this will need rework since I guess we need to stack all the outputs together before calling this. Is

Contributor: The old code only had the probability threshold but no NMS, could you add it back @Narsil?

Contributor: It depends, supporting NMS in all post-processing methods would be really handy and I'll add it to all relevant models' feature extractors anyway. Or, as you suggested, I can create a util function and edit the post-processing methods.

Contributor (Author): Whichever you prefer. I like independent functions usually because there is less state to them (no

Contributor: Makes sense, I will go for the common util function for NMS + improved post-processing methods with

Contributor (Author): I will merge this and we can fix NMS in another PR, ok?

Contributor: Sounds good!
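For readers following the thread: it ends with an agreement to add a shared NMS utility in a follow-up PR. The sketch below is purely illustrative of what such a helper does — the function name and signature are invented here, this is not code from this PR or from transformers, and `torchvision.ops.nms` is an optimized equivalent in practice.

```python
# Hypothetical standalone NMS helper of the kind the thread proposes.
# Illustrative only; not the implementation that landed.
import torch


def non_max_suppression(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float = 0.5) -> torch.Tensor:
    """Greedy NMS: keep the highest-scoring box, drop boxes overlapping it, repeat.

    `boxes` is (N, 4) in (xmin, ymin, xmax, ymax) format, `scores` is (N,).
    Returns indices of kept boxes, sorted by descending score.
    """
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        best = order[0]
        keep.append(best.item())
        if order.numel() == 1:
            break
        rest = order[1:]
        # Intersection of the kept box with every remaining candidate
        xmin = torch.maximum(boxes[best, 0], boxes[rest, 0])
        ymin = torch.maximum(boxes[best, 1], boxes[rest, 1])
        xmax = torch.minimum(boxes[best, 2], boxes[rest, 2])
        ymax = torch.minimum(boxes[best, 3], boxes[rest, 3])
        intersection = (xmax - xmin).clamp(min=0) * (ymax - ymin).clamp(min=0)
        area_best = (boxes[best, 2] - boxes[best, 0]) * (boxes[best, 3] - boxes[best, 1])
        area_rest = (boxes[rest, 2] - boxes[rest, 0]) * (boxes[rest, 3] - boxes[rest, 1])
        iou = intersection / (area_best + area_rest - intersection)
        # Keep only candidates that do not overlap the chosen box too much
        order = rest[iou <= iou_threshold]
    return torch.tensor(keep, dtype=torch.long)
```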
````diff
+            keep = outputs["scores"] >= threshold
+
+            for index in keep.nonzero():
+                score = outputs["scores"][index].item()
+                box = self._get_bounding_box(outputs["boxes"][index][0])
+
+                result = {"score": score, "label": label, "box": box}
+                results.append(result)
+
+        results = sorted(results, key=lambda x: x["score"], reverse=True)
+        if top_k:
+            results = results[:top_k]
 
         return results
 
@@ -208,94 +211,3 @@ def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
             "ymax": ymax,
         }
         return bbox
-
-    # Replication of OwlViTProcessor __call__ method, since pipelines don't auto infer processor's yet!
-    def _processor(self, text=None, images=None, padding="max_length", return_tensors="np", **kwargs):
-        """
-        Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and
-        `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode:
-        the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
-        CLIPFeatureExtractor's [`~CLIPFeatureExtractor.__call__`] if `images` is not `None`. Please refer to the
-        doctsring of the above two methods for more information.
-
-        Args:
-            text (`str`, `List[str]`, `List[List[str]]`):
-                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
-                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
-                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
-            `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-            return_tensors (`str` or [`~utils.TensorType`], *optional*):
-                If set, will return tensors of a particular framework. Acceptable values are:
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-        Returns:
-            [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
-            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
-            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
-              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
-              `None`).
-            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
-        """
-
-        if text is None and images is None:
-            raise ValueError("You have to specify at least one text or image. Both cannot be none.")
-
-        if text is not None:
-            if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)):
-                encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)]
-
-            elif isinstance(text, List) and isinstance(text[0], List):
-                encodings = []
-
-                # Maximum number of queries across batch
-                max_num_queries = max([len(t) for t in text])
-
-                # Pad all batch samples to max number of text queries
-                for t in text:
-                    if len(t) != max_num_queries:
-                        t = t + [" "] * (max_num_queries - len(t))
-
-                    encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs)
-                    encodings.append(encoding)
-            else:
-                raise TypeError("Input text should be a string, a list of strings or a nested list of strings")
-
-            if return_tensors == "np":
-                input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0)
-                attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0)
-
-            elif return_tensors == "pt" and is_torch_available():
-                import torch
-
-                input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0)
-                attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0)
-
-            elif return_tensors == "tf" and is_tf_available():
-                import tensorflow as tf
-
-                input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0)
-                attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0)
-
-            else:
-                raise ValueError("Target return tensor type could not be returned")
-
-            encoding = BatchEncoding()
-            encoding["input_ids"] = input_ids
-            encoding["attention_mask"] = attention_mask
-
-        if images is not None:
-            image_features = self.feature_extractor(images, return_tensors=return_tensors, **kwargs)
-
-        if text is not None and images is not None:
-            encoding["pixel_values"] = image_features.pixel_values
-            return encoding
-        elif text is not None:
-            return encoding
-        else:
-            return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
````
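To summarize the caller-side impact of the breaking change, here is a before/after sketch. It is not part of the diff; the inputs and outputs are taken from the docstring examples above and abbreviated, and it assumes a transformers build that includes this PR.

```python
from transformers import pipeline

detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")

# Before this PR: plural `images` + `text_queries`, result nested one level per image.
# results = detector(
#     images="http://images.cocodataset.org/val2017/000000039769.jpg",
#     text_queries=["cat", "couch"],
# )  # -> [[{'score': ..., 'label': 'cat', 'box': {...}}, ...]]

# After this PR: singular `image` + `candidate_labels`, flat result for a single image.
results = detector(
    image="http://images.cocodataset.org/val2017/000000039769.jpg",
    candidate_labels=["cat", "couch"],
)  # -> [{'score': ..., 'label': 'cat', 'box': {...}}, ...]

# Per the new `preprocess`, a single comma-separated string also works:
# candidate_labels="cat,couch" is split on "," into ["cat", "couch"].
# The old `text_queries` keyword is still remapped to `candidate_labels` by `__call__`.
```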