Skip to content

fix image-to-text batch incorrect output issue#29342

Merged
amyeroberts merged 3 commits into
huggingface:mainfrom
sywangyi:image-to-text-pipeline
Mar 8, 2024
Merged

fix image-to-text batch incorrect output issue#29342
amyeroberts merged 3 commits into
huggingface:mainfrom
sywangyi:image-to-text-pipeline

Conversation

@sywangyi
Copy link
Copy Markdown
Contributor

fix image to text multi-batch input , but output incorrect issue

@sywangyi
Copy link
Copy Markdown
Contributor Author

before the fix . the output is
image
after the fix , the output is
image

@sywangyi
Copy link
Copy Markdown
Contributor Author

@amyeroberts please help review

@amyeroberts
Copy link
Copy Markdown
Contributor

Hi @sywangyi, thanks for opening a PR!

To show what this PR addresses could you:

  • Provide a code snippet to reproduce?
  • Copy-paste the output of the terminal to show it working, rather than take a screen shot? This is 1) more readable, 2) makes the relevant code searchable and 3) means others can copy the text to be able to test on their own setup

Have you tested with non-batched input to confirm it's equivalent?

@sywangyi
Copy link
Copy Markdown
Contributor Author

sywangyi commented Mar 1, 2024

from transformers import pipeline
import torch
import requests
import PIL.Image
import time

image_url = "https://ankur3107.github.io/assets/images/image-captioning-example.png"
image = []
image.append(PIL.Image.open(requests.get(image_url, stream=True, timeout=3000).raw))
image.append(PIL.Image.open(requests.get(image_url, stream=True, timeout=3000).raw))

generator = pipeline(
   "image-to-text",
    model="Salesforce/blip-image-captioning-large",
    torch_dtype=torch.bfloat16,
    device="cuda",
)

result = generator(image, batch_size=2)
print(f"{result}")

@sywangyi
Copy link
Copy Markdown
Contributor Author

sywangyi commented Mar 1, 2024

before the fix:

2024-03-01 16:56:47.196759: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-01 16:56:47.231194: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-01 16:56:47.231226: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-01 16:56:47.232068: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-01 16:56:47.237363: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-03-01 16:56:47.891839: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/mnt/disk1/wangyi/transformers/src/transformers/generation/utils.py:1181: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
  warnings.warn(
[[{'generated_text': ''}, {'generated_text': 'there'}, {'generated_text': 'are'}, {'generated_text': 'two'}, {'generated_text': 'soccer'}, {'generated_text': 'players'}, {'generated_text': 'playing'}, {'generated_text': 'soccer'}, {'generated_text': 'on'}, {'generated_text': 'the'}, {'generated_text': 'field'}, {'generated_text': ''}], [{'generated_text': ''}, {'generated_text': 'there'}, {'generated_text': 'are'}, {'generated_text': 'two'}, {'generated_text': 'soccer'}, {'generated_text': 'players'}, {'generated_text': 'playing'}, {'generated_text': 'soccer'}, {'generated_text': 'on'}, {'generated_text': 'the'}, {'generated_text': 'field'}, {'generated_text': ''}]]

after the fix

2024-03-01 16:58:59.545767: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-01 16:58:59.580669: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-01 16:58:59.580702: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-01 16:58:59.581568: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-01 16:58:59.586970: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-03-01 16:59:00.239049: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/mnt/disk1/wangyi/transformers/src/transformers/generation/utils.py:1181: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
  warnings.warn(
[[{'generated_text': 'there are two soccer players playing soccer on the field'}], [{'generated_text': 'there are two soccer players playing soccer on the field'}]]

@sywangyi
Copy link
Copy Markdown
Contributor Author

sywangyi commented Mar 1, 2024

I have tested with non-batched input and confirm it's equivalent.

@amyeroberts
Copy link
Copy Markdown
Contributor

Thanks for providing the info @sywangyi. Looking at your example, I think the issue is coming from the argument batch_size=2. If I don't pass this, then my pipeline behaves as expected:

In [9]: from transformers import pipeline
   ...: import torch
   ...: import requests
   ...: import PIL.Image
   ...: import time
   ...:
   ...: image_url = "https://ankur3107.github.io/assets/images/image-captioning-example.png"
   ...: image = []
   ...: image.append(PIL.Image.open(requests.get(image_url, stream=True, timeout=3000).raw))
   ...: image.append(PIL.Image.open(requests.get(image_url, stream=True, timeout=3000).raw))
   ...:
   ...: generator = pipeline(
   ...:    "image-to-text",
   ...:     model="Salesforce/blip-image-captioning-large",
   ...: )
   ...:
   ...: result = generator(image)
   ...: print(f"{result}")
/Users/amyroberts/code/transformers/src/transformers/generation/utils.py:1181: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
  warnings.warn(
[[{'generated_text': 'arafed image of a soccer player kicking a soccer ball'}], [{'generated_text': 'arafed image of a soccer player kicking a soccer ball'}]]

@sywangyi
Copy link
Copy Markdown
Contributor Author

sywangyi commented Mar 2, 2024

Hi, @amyeroberts with and wo batch_size=2, it goes to different running logic.
with batch_size=2, the running logic is like
preprocess - > forward (bs=2)->postprocess
wo batch_size = 2, the running logic is like
preprocess->forward(bs=1)->postprocess->preprocess->forward(bs=1)->postprocess

@amyeroberts
Copy link
Copy Markdown
Contributor

@sywangyi Right. My understanding is that batch_size is to be used when passing in e.g. a dataset object to the pipeline. Is there a reason you passed it in here, rather than calling the pipeline directly?

My concern is that this is going to change the behaviour for the cases when datasets are passed - it's not obvious from the change or PR description what is being fixed here.

Could you add some tests to make sure the pipeline still have the intended behaviour for the following cases:

  • Passing in a list of two elements with batch_size=2
  • Passing in a list of two elements
  • Passing in a dataset wit batch_size=2 c.f. the docs for an example to use

@sywangyi
Copy link
Copy Markdown
Contributor Author

sywangyi commented Mar 5, 2024

from transformers import pipeline
import torch
import requests
import PIL.Image
import time

image_url = "https://ankur3107.github.io/assets/images/image-captioning-example.png"
image = []
image.append(PIL.Image.open(requests.get(image_url, stream=True, timeout=3000).raw))
image.append(PIL.Image.open(requests.get(image_url, stream=True, timeout=3000).raw))

generator = pipeline(
"image-to-text",
model="Salesforce/blip-image-captioning-large",
torch_dtype=torch.bfloat16,
device="cuda",
)

result = generator(image, batch_size=2)
print(f"{result}")

Hi, @amyeroberts . I write the test for the third case you mentioned

from transformers import pipeline
import torch
import requests
import PIL.Image
import time
from torch.utils.data import Dataset

image2_url = "https://ankur3107.github.io/assets/images/image-captioning-example.png"
image_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
class MyDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, i):
        if i % 2 == 0:
            return PIL.Image.open(requests.get(image_url, stream=True, timeout=3000).raw)
        else:
            return PIL.Image.open(requests.get(image2_url, stream=True, timeout=3000).raw)

dataset = MyDataset()


generator = pipeline(
   "image-to-text",
    model="Salesforce/blip-image-captioning-large",
    torch_dtype=torch.bfloat16,
    device="cuda",
)

result = generator(dataset, batch_size=2)
print(f"{list(result)}")

without the PR. the behavior of the 3 cases you mentioned is

  • Passing in a list of two elements with batch_size=2 (incorrect)
  • Passing in a list of two elements (correct)
  • Passing in a dataset wit batch_size=2 (incorrect)

with the PR.

  • Passing in a list of two elements with batch_size=2 (correct)
  • Passing in a list of two elements (correct)
  • Passing in a dataset wit batch_size=2 (correct)

@amyeroberts
Copy link
Copy Markdown
Contributor

@sywangyi Sorry, what I meant was add a test to our testing suite.

@sywangyi
Copy link
Copy Markdown
Contributor Author

sywangyi commented Mar 6, 2024

I add some test, is it enough? @amyeroberts

Copy link
Copy Markdown
Contributor

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this fix and tests!

Just a small note on the tests. One addressed I think we'll be good to merge!

Comment thread tests/pipelines/test_pipelines_image_to_text.py Outdated
@sywangyi sywangyi force-pushed the image-to-text-pipeline branch from 38898b6 to 8965fc9 Compare March 8, 2024 10:51
Copy link
Copy Markdown
Contributor

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great - thanks a lot for working on this and adding tests!

@amyeroberts amyeroberts merged commit 8ee1d47 into huggingface:main Mar 8, 2024
sywangyi added 3 commits March 8, 2024 14:56
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
@sywangyi sywangyi deleted the image-to-text-pipeline branch November 19, 2025 04:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants