From b900e6f6e04bcf670bd867e1ddae0974105d01a7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 16 Nov 2022 14:33:46 +0100 Subject: [PATCH 1/2] Adding `zero-shot-object-detection` pipeline doctest. --- .../pipelines/zero_shot_object_detection.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/transformers/pipelines/zero_shot_object_detection.py b/src/transformers/pipelines/zero_shot_object_detection.py index 8c18bd502e6f..91f4284ebee3 100644 --- a/src/transformers/pipelines/zero_shot_object_detection.py +++ b/src/transformers/pipelines/zero_shot_object_detection.py @@ -33,6 +33,34 @@ class ZeroShotObjectDetectionPipeline(Pipeline): Zero shot object detection pipeline using `OwlViTForObjectDetection`. This pipeline predicts bounding boxes of objects when you provide an image and a set of `candidate_labels`. + Example: + + ```python + >>> from transformers import pipeline + + >>> detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection") + >>> answers = detector( + ... "http://images.cocodataset.org/val2017/000000039769.jpg", + ... candidate_labels=["cat", "couch"], + ... ) + >>> from transformers.testing_utils import ( + ... nested_simplify, + ... ) # Actual scores might vary slightly depending on PyTorch version or Tensorflow + + >>> nested_simplify(answers) + [[{'score': 0.287, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.254, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.121, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}]] + + >>> nested_simplify( + ... detector( + ... "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", + ... candidate_labels=["head", "bird"], + ... ) + ... 
) 
[[{'score': 0.119, 'label': 'bird', 'box': {'xmin': 71, 'ymin': 170, 'xmax': 410, 'ymax': 508}}]]
```

Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

This object detection pipeline can currently be loaded from [`pipeline`] using the following task identifier:
`"zero-shot-object-detection"`.

@@ -87,6 +115,8 @@ def __call__(
- **box** (`Dict[str,int]`) -- Bounding box of the detected object in image's original size. It is a dictionary with `x_min`, `x_max`, `y_min`, `y_max` keys.
"""
+ if "candidate_labels" in kwargs:
+ text_queries = kwargs.pop("candidate_labels")
if isinstance(text_queries, str) or (isinstance(text_queries, List) and not isinstance(text_queries[0], List)):
if isinstance(images, (str, Image.Image)):
inputs = {"images": images, "text_queries": text_queries}

From 06f184571c34f69b0828c1284c778d6e47efd207 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 16 Nov 2022 16:51:43 +0100
Subject: [PATCH 2/2] Remove nested_simplify.

---
.../pipelines/zero_shot_object_detection.py | 15 ++++-----------
1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/src/transformers/pipelines/zero_shot_object_detection.py b/src/transformers/pipelines/zero_shot_object_detection.py
index 91f4284ebee3..a24b7201c844 100644
--- a/src/transformers/pipelines/zero_shot_object_detection.py
+++ b/src/transformers/pipelines/zero_shot_object_detection.py
@@ -39,22 +39,15 @@ class ZeroShotObjectDetectionPipeline(Pipeline):
>>> from transformers import pipeline

>>> detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")
- >>> answers = detector(
+ >>> detector(
... "http://images.cocodataset.org/val2017/000000039769.jpg",
... candidate_labels=["cat", "couch"],
... )
- >>> from transformers.testing_utils import (
- ... nested_simplify,
- ... 
) # Actual scores might vary slightly depending on PyTorch version or Tensorflow - - >>> nested_simplify(answers) [[{'score': 0.287, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.254, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.121, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}]] - >>> nested_simplify( - ... detector( - ... "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", - ... candidate_labels=["head", "bird"], - ... ) + >>> detector( + ... "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", + ... candidate_labels=["head", "bird"], ... ) [[{'score': 0.119, 'label': 'bird', 'box': {'xmin': 71, 'ymin': 170, 'xmax': 410, 'ymax': 508}}]] ```