diff --git a/src/transformers/pipelines/zero_shot_object_detection.py b/src/transformers/pipelines/zero_shot_object_detection.py
index a24b7201c844..83a4969a7f6d 100644
--- a/src/transformers/pipelines/zero_shot_object_detection.py
+++ b/src/transformers/pipelines/zero_shot_object_detection.py
@@ -1,17 +1,7 @@
-from typing import Dict, List, Union
+from typing import Any, Dict, List, Union
 
-import numpy as np
-
-from ..tokenization_utils_base import BatchEncoding
-from ..utils import (
-    add_end_docstrings,
-    is_tf_available,
-    is_torch_available,
-    is_vision_available,
-    logging,
-    requires_backends,
-)
-from .base import PIPELINE_INIT_ARGS, Pipeline
+from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
+from .base import PIPELINE_INIT_ARGS, ChunkPipeline
 
 
 if is_vision_available():
@@ -22,13 +12,15 @@
 if is_torch_available():
     import torch
 
+    from transformers.modeling_outputs import BaseModelOutput
+
     from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
 
 logger = logging.get_logger(__name__)
 
 
 @add_end_docstrings(PIPELINE_INIT_ARGS)
-class ZeroShotObjectDetectionPipeline(Pipeline):
+class ZeroShotObjectDetectionPipeline(ChunkPipeline):
     """
     Zero shot object detection pipeline using `OwlViTForObjectDetection`. This pipeline predicts bounding boxes of
     objects when you provide an image and a set of `candidate_labels`.
@@ -43,13 +35,13 @@ class ZeroShotObjectDetectionPipeline(Pipeline):
     ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
     ...     candidate_labels=["cat", "couch"],
     ... )
-    [[{'score': 0.287, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.254, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.121, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}]]
+    [{'score': 0.287, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.254, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.121, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}]
 
     >>> detector(
     ...     "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
     ...     candidate_labels=["head", "bird"],
     ... )
-    [[{'score': 0.119, 'label': 'bird', 'box': {'xmin': 71, 'ymin': 170, 'xmax': 410, 'ymax': 508}}]]
+    [{'score': 0.119, 'label': 'bird', 'box': {'xmin': 71, 'ymin': 170, 'xmax': 410, 'ymax': 508}}]
     ```
 
     [Learn more about the basics of using a pipeline in the [pipeline tutorial]](../pipeline_tutorial)
@@ -72,24 +64,45 @@ def __init__(self, **kwargs):
 
     def __call__(
         self,
-        images: Union[str, List[str], "Image.Image", List["Image.Image"]],
-        text_queries: Union[str, List[str], List[List[str]]] = None,
+        image: Union[str, "Image.Image", List[Dict[str, Any]]],
+        candidate_labels: Union[str, List[str]] = None,
         **kwargs
     ):
         """
         Detect objects (bounding boxes & classes) in the image(s) passed as inputs.
 
         Args:
-            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+            image (`str`, `PIL.Image` or `List[Dict[str, Any]]`):
                 The pipeline handles three types of images:
 
                 - A string containing an http url pointing to an image
                 - A string containing a local path to an image
                 - An image loaded in PIL directly
 
-            text_queries (`str` or `List[str]` or `List[List[str]]`): Text queries to query the target image with.
-                If given multiple images, `text_queries` should be provided as a list of lists, where each nested list
-                contains the text queries for the corresponding image.
+                You can use this parameter to send directly a list of images, or a dataset or a generator like so:
+
+                ```python
+                >>> from transformers import pipeline
+
+                >>> detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")
+                >>> detector(
+                ...     [
+                ...         {
+                ...             "image": "http://images.cocodataset.org/val2017/000000039769.jpg",
+                ...             "candidate_labels": ["cat", "couch"],
+                ...         },
+                ...         {
+                ...             "image": "http://images.cocodataset.org/val2017/000000039769.jpg",
+                ...             "candidate_labels": ["cat", "couch"],
+                ...         },
+                ...     ]
+                ... )
+                [[{'score': 0.286811888217926, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.2537279725074768, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.12082888185977936, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}], [{'score': 0.286811888217926, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.2537279725074768, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.12082888185977936, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}]]
+                ```
+
+
+            candidate_labels (`str` or `List[str]` or `List[List[str]]`):
+                What the model should recognize in the image.
 
             threshold (`float`, *optional*, defaults to 0.1):
                 The probability necessary to make a prediction.
@@ -108,28 +121,13 @@ def __call__(
             - **box** (`Dict[str,int]`) -- Bounding box of the detected object in image's original size. It is a
               dictionary with `x_min`, `x_max`, `y_min`, `y_max` keys.
""" - if "candidate_labels" in kwargs: - text_queries = kwargs.pop("candidate_labels") - if isinstance(text_queries, str) or (isinstance(text_queries, List) and not isinstance(text_queries[0], List)): - if isinstance(images, (str, Image.Image)): - inputs = {"images": images, "text_queries": text_queries} - elif isinstance(images, List): - assert len(images) == 1, "Input text_queries and images must have correspondance" - inputs = {"images": images[0], "text_queries": text_queries} - else: - raise TypeError(f"Innapropriate type of images: {type(images)}") - - elif isinstance(text_queries, str) or (isinstance(text_queries, List) and isinstance(text_queries[0], List)): - if isinstance(images, (Image.Image, str)): - images = [images] - assert len(images) == len(text_queries), "Input text_queries and images must have correspondance" - inputs = {"images": images, "text_queries": text_queries} + if "text_queries" in kwargs: + candidate_labels = kwargs.pop("text_queries") + + if isinstance(image, (str, Image.Image)): + inputs = {"image": image, "candidate_labels": candidate_labels} else: - """ - Supports the following format - - {"images": images, "text_queries": text_queries} - """ - inputs = images + inputs = image results = super().__call__(inputs, **kwargs) return results @@ -142,49 +140,54 @@ def _sanitize_parameters(self, **kwargs): return {}, {}, postprocess_params def preprocess(self, inputs): - if not isinstance(inputs["images"], List): - inputs["images"] = [inputs["images"]] - images = [load_image(img) for img in inputs["images"]] - text_queries = inputs["text_queries"] - if isinstance(text_queries, str) or isinstance(text_queries[0], str): - text_queries = [text_queries] - - target_sizes = [torch.IntTensor([[img.height, img.width]]) for img in images] - target_sizes = torch.cat(target_sizes) - inputs = self._processor(text=inputs["text_queries"], images=images, return_tensors="pt") - return {"target_sizes": target_sizes, "text_queries": text_queries, **inputs} + image = load_image(inputs["image"]) + candidate_labels = inputs["candidate_labels"] + if isinstance(candidate_labels, str): + candidate_labels = candidate_labels.split(",") + + target_size = torch.tensor([[image.height, image.width]], dtype=torch.int32) + for i, candidate_label in enumerate(candidate_labels): + text_inputs = self.tokenizer(candidate_label, return_tensors=self.framework) + image_features = self.feature_extractor(image, return_tensors=self.framework) + yield { + "is_last": i == len(candidate_labels) - 1, + "target_size": target_size, + "candidate_label": candidate_label, + **text_inputs, + **image_features, + } def _forward(self, model_inputs): - target_sizes = model_inputs.pop("target_sizes") - text_queries = model_inputs.pop("text_queries") + target_size = model_inputs.pop("target_size") + candidate_label = model_inputs.pop("candidate_label") + is_last = model_inputs.pop("is_last") + outputs = self.model(**model_inputs) - model_outputs = outputs.__class__({"target_sizes": target_sizes, "text_queries": text_queries, **outputs}) + model_outputs = {"target_size": target_size, "candidate_label": candidate_label, "is_last": is_last, **outputs} return model_outputs def postprocess(self, model_outputs, threshold=0.1, top_k=None): - texts = model_outputs["text_queries"] - - outputs = self.feature_extractor.post_process( - outputs=model_outputs, target_sizes=model_outputs["target_sizes"] - ) results = [] - for i in range(len(outputs)): - keep = outputs[i]["scores"] >= threshold - labels = 
-            scores = outputs[i]["scores"][keep].tolist()
-            boxes = [self._get_bounding_box(box) for box in outputs[i]["boxes"][keep]]
-
-            result = [
-                {"score": score, "label": texts[i][label], "box": box}
-                for score, label, box in zip(scores, labels, boxes)
-            ]
-
-            result = sorted(result, key=lambda x: x["score"], reverse=True)
-            if top_k:
-                result = result[:top_k]
-            results.append(result)
+        for model_output in model_outputs:
+            label = model_output["candidate_label"]
+            model_output = BaseModelOutput(model_output)
+            outputs = self.feature_extractor.post_process(
+                outputs=model_output, target_sizes=model_output["target_size"]
+            )[0]
+            keep = outputs["scores"] >= threshold
+
+            for index in keep.nonzero():
+                score = outputs["scores"][index].item()
+                box = self._get_bounding_box(outputs["boxes"][index][0])
+
+                result = {"score": score, "label": label, "box": box}
+                results.append(result)
+
+        results = sorted(results, key=lambda x: x["score"], reverse=True)
+        if top_k:
+            results = results[:top_k]
 
         return results
 
@@ -208,94 +211,3 @@ def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
             "ymax": ymax,
         }
         return bbox
-
-    # Replication of OwlViTProcessor __call__ method, since pipelines don't auto infer processor's yet!
-    def _processor(self, text=None, images=None, padding="max_length", return_tensors="np", **kwargs):
-        """
-        Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and
-        `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode:
-        the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
-        CLIPFeatureExtractor's [`~CLIPFeatureExtractor.__call__`] if `images` is not `None`. Please refer to the
-        doctsring of the above two methods for more information.
-
-        Args:
-            text (`str`, `List[str]`, `List[List[str]]`):
-                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
-                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
-                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
-            `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-            return_tensors (`str` or [`~utils.TensorType`], *optional*):
-                If set, will return tensors of a particular framework. Acceptable values are:
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-        Returns:
-            [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
-            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
-            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
-              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
-              `None`).
-            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
-        """
-
-        if text is None and images is None:
-            raise ValueError("You have to specify at least one text or image. Both cannot be none.")
-
-        if text is not None:
-            if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)):
-                encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)]
-
-            elif isinstance(text, List) and isinstance(text[0], List):
-                encodings = []
-
-                # Maximum number of queries across batch
-                max_num_queries = max([len(t) for t in text])
-
-                # Pad all batch samples to max number of text queries
-                for t in text:
-                    if len(t) != max_num_queries:
-                        t = t + [" "] * (max_num_queries - len(t))
-
-                    encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs)
-                    encodings.append(encoding)
-            else:
-                raise TypeError("Input text should be a string, a list of strings or a nested list of strings")
-
-            if return_tensors == "np":
-                input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0)
-                attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0)
-
-            elif return_tensors == "pt" and is_torch_available():
-                import torch
-
-                input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0)
-                attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0)
-
-            elif return_tensors == "tf" and is_tf_available():
-                import tensorflow as tf
-
-                input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0)
-                attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0)
-
-            else:
-                raise ValueError("Target return tensor type could not be returned")
-
-            encoding = BatchEncoding()
-            encoding["input_ids"] = input_ids
-            encoding["attention_mask"] = attention_mask
-
-        if images is not None:
-            image_features = self.feature_extractor(images, return_tensors=return_tensors, **kwargs)
-
-        if text is not None and images is not None:
-            encoding["pixel_values"] = image_features.pixel_values
-            return encoding
-        elif text is not None:
-            return encoding
-        else:
-            return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
diff --git a/tests/pipelines/test_pipelines_zero_shot_object_detection.py b/tests/pipelines/test_pipelines_zero_shot_object_detection.py
index aef64e7db27b..06a611b53af2 100644
--- a/tests/pipelines/test_pipelines_zero_shot_object_detection.py
+++ b/tests/pipelines/test_pipelines_zero_shot_object_detection.py
@@ -43,28 +43,28 @@ def get_test_pipeline(self, model, tokenizer, feature_extractor):
 
         examples = [
             {
-                "images": "./tests/fixtures/tests_samples/COCO/000000039769.png",
-                "text_queries": ["cat", "remote", "couch"],
+                "image": "./tests/fixtures/tests_samples/COCO/000000039769.png",
+                "candidate_labels": ["cat", "remote", "couch"],
             }
         ]
         return object_detector, examples
 
     def run_pipeline_test(self, object_detector, examples):
-        batch_outputs = object_detector(examples, threshold=0.0)
-
-        self.assertEqual(len(examples), len(batch_outputs))
-        for outputs in batch_outputs:
-            for output_per_image in outputs:
-                self.assertGreater(len(output_per_image), 0)
-                for detected_object in output_per_image:
-                    self.assertEqual(
-                        detected_object,
-                        {
-                            "score": ANY(float),
-                            "label": ANY(str),
-                            "box": {"xmin": ANY(int), "ymin": ANY(int), "xmax": ANY(int), "ymax": ANY(int)},
-                        },
-                    )
+        outputs = object_detector(examples[0], threshold=0.0)
+
+        n = len(outputs)
+        self.assertGreater(n, 0)
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "score": ANY(float),
+                    "label": ANY(str),
+                    "box": {"xmin": ANY(int), "ymin": ANY(int), "xmax": ANY(int), "ymax": ANY(int)},
+                }
+                for i in range(n)
+            ],
+        )
 
     @require_tf
     @unittest.skip("Zero Shot Object Detection not implemented in TF")
@@ -79,43 +79,32 @@ def test_small_model_pt(self):
 
         outputs = object_detector(
             "./tests/fixtures/tests_samples/COCO/000000039769.png",
-            text_queries=["cat", "remote", "couch"],
+            candidate_labels=["cat", "remote", "couch"],
             threshold=0.64,
         )
 
         self.assertEqual(
             nested_simplify(outputs, decimals=4),
             [
-                [
-                    {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
-                    {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
-                    {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
-                    {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
-                ]
+                {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
+                {"score": 0.7218, "label": "remote", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
+                {"score": 0.7184, "label": "couch", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
+                {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
+                {"score": 0.6656, "label": "cat", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
+                {"score": 0.6614, "label": "couch", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
+                {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
+                {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
+                {"score": 0.6419, "label": "cat", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
             ],
         )
 
         outputs = object_detector(
-            ["./tests/fixtures/tests_samples/COCO/000000039769.png"],
-            text_queries=["cat", "remote", "couch"],
-            threshold=0.64,
-        )
-
-        self.assertEqual(
-            nested_simplify(outputs, decimals=4),
             [
-                [
-                    {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
-                    {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
-                    {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
-                    {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
-                ]
+                {
+                    "image": "./tests/fixtures/tests_samples/COCO/000000039769.png",
+                    "candidate_labels": ["cat", "remote", "couch"],
+                }
             ],
-        )
-
-        outputs = object_detector(
-            "./tests/fixtures/tests_samples/COCO/000000039769.png",
-            text_queries=[["cat", "remote", "couch"]],
             threshold=0.64,
         )
 
@@ -124,67 +113,48 @@ def test_small_model_pt(self):
             [
                 [
                     {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
+                    {"score": 0.7218, "label": "remote", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
+                    {"score": 0.7184, "label": "couch", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
                     {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
+                    {"score": 0.6656, "label": "cat", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
+                    {"score": 0.6614, "label": "couch", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
                     {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
                     {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
+                    {"score": 0.6419, "label": "cat", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
                 ]
             ],
         )
 
-        outputs = object_detector(
-            [
-                "./tests/fixtures/tests_samples/COCO/000000039769.png",
-                "http://images.cocodataset.org/val2017/000000039769.jpg",
-            ],
-            text_queries=[["cat", "remote", "couch"], ["cat", "remote", "couch"]],
-            threshold=0.64,
-        )
-
-        self.assertEqual(
-            nested_simplify(outputs, decimals=4),
-            [
-                [
-                    {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
-                    {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
-                    {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
-                    {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
-                ],
-                [
-                    {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
-                    {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
-                    {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
-                    {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
-                ],
-            ],
-        )
-
     @require_torch
     @slow
     def test_large_model_pt(self):
         object_detector = pipeline("zero-shot-object-detection")
 
         outputs = object_detector(
-            "http://images.cocodataset.org/val2017/000000039769.jpg", text_queries=["cat", "remote", "couch"]
+            "http://images.cocodataset.org/val2017/000000039769.jpg", candidate_labels=["cat", "remote", "couch"]
         )
         self.assertEqual(
             nested_simplify(outputs, decimals=4),
             [
-                [
-                    {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
-                    {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
-                    {"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}},
-                    {"score": 0.1474, "label": "remote", "box": {"xmin": 335, "ymin": 74, "xmax": 371, "ymax": 187}},
-                    {"score": 0.1208, "label": "couch", "box": {"xmin": 4, "ymin": 0, "xmax": 642, "ymax": 476}},
-                ]
+                {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
+                {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
+                {"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}},
+                {"score": 0.1474, "label": "remote", "box": {"xmin": 335, "ymin": 74, "xmax": 371, "ymax": 187}},
+                {"score": 0.1208, "label": "couch", "box": {"xmin": 4, "ymin": 0, "xmax": 642, "ymax": 476}},
             ],
         )
 
         outputs = object_detector(
             [
-                "http://images.cocodataset.org/val2017/000000039769.jpg",
-                "http://images.cocodataset.org/val2017/000000039769.jpg",
+                {
+                    "image": "http://images.cocodataset.org/val2017/000000039769.jpg",
+                    "candidate_labels": ["cat", "remote", "couch"],
+                },
+                {
+                    "image": "http://images.cocodataset.org/val2017/000000039769.jpg",
+                    "candidate_labels": ["cat", "remote", "couch"],
+                },
             ],
-            text_queries=[["cat", "remote", "couch"], ["cat", "remote", "couch"]],
         )
         self.assertEqual(
             nested_simplify(outputs, decimals=4),
@@ -219,17 +189,15 @@ def test_threshold(self):
 
         outputs = object_detector(
             "http://images.cocodataset.org/val2017/000000039769.jpg",
-            text_queries=["cat", "remote", "couch"],
+            candidate_labels=["cat", "remote", "couch"],
             threshold=threshold,
         )
 
         self.assertEqual(
             nested_simplify(outputs, decimals=4),
             [
-                [
-                    {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
-                    {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
- {"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}}, - ] + {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}}, + {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}}, + {"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}}, ], ) @@ -241,15 +209,13 @@ def test_top_k(self): outputs = object_detector( "http://images.cocodataset.org/val2017/000000039769.jpg", - text_queries=["cat", "remote", "couch"], + candidate_labels=["cat", "remote", "couch"], top_k=top_k, ) self.assertEqual( nested_simplify(outputs, decimals=4), [ - [ - {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}}, - {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}}, - ] + {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}}, + {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}}, ], )