
[Bugfix][Model] Fix Mllama SDPA illegal memory access for batched multi-image#9626

Merged
DarkLight1337 merged 1 commit into main from fix-mllama-multi-image on Oct 24, 2024

Conversation

@mgoin mgoin commented Oct 23, 2024

Mllama could trigger a CUDA illegal memory access error in the cross-attention SDPA, during inference or even the profile run, when serving batched multi-image requests.

When inspecting the inputs to the SDPA, I noticed that the Q/K/V states were not contiguous in all cases (the attention_mask was contiguous). After forcing all of them to be contiguous, the issue went away.

output = F.scaled_dot_product_attention(q.contiguous(),
                                        k.contiguous(),
                                        v.contiguous(),
                                        attn_mask=attention_mask,
                                        is_causal=False)
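For readers unfamiliar with tensor contiguity: operations such as transposes and strided slices return views whose strides are no longer row-major, and some kernels assume a dense layout. A minimal sketch of the idea, using NumPy rather than PyTorch since the layout semantics are the same (`np.ascontiguousarray` plays the role of `Tensor.contiguous()` here):

```python
import numpy as np

# A transpose is a view over the same buffer with permuted strides,
# so it is no longer C-contiguous; an explicit contiguous copy
# restores a dense, row-major layout.
x = np.arange(12, dtype=np.float32).reshape(3, 4)
view = x.T  # transposed view, permuted strides

print(view.flags["C_CONTIGUOUS"])   # False
fixed = np.ascontiguousarray(view)  # dense copy with row-major strides
print(fixed.flags["C_CONTIGUOUS"])  # True
```

In PyTorch the analogous check is `tensor.is_contiguous()`, and `tensor.contiguous()` materializes the dense copy that the SDPA kernel path evidently expects here.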

Script used for testing:

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

sampling_params = SamplingParams(temperature=0.0, max_tokens=100)

model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
llm = LLM(
    model=model_name,
    max_model_len=4096,
    max_num_seqs=16,
    enforce_eager=True,
    limit_mm_per_prompt={"image": 4},
)

image1 = ImageAsset("cherry_blossom").pil_image.convert("RGB")
image2 = ImageAsset("stop_sign").pil_image.convert("RGB")
input1 = {
    "prompt": "<|image|><|begin_of_text|>Describe the image.",
    "multi_modal_data": {"image": image1},
}
input2 = {
    "prompt": "<|image|><|image|><|image|><|image|><|begin_of_text|>How many duplicated images are there?",
    "multi_modal_data": {"image": [image1, image2, image1, image1]},
}
input3 = {
    "prompt": "<|image|><|image|><|begin_of_text|>Are the images the same?",
    "multi_modal_data": {"image": [image1, image1]},
}
outputs = llm.generate([input1, input2, input3], sampling_params=sampling_params)

for i, output in enumerate(outputs):
    print(f"\nOutput #{i}:", output.outputs[0].text)

Error when running with CUDA_LAUNCH_BLOCKING=1:

[rank0]:   File "/home/mgoin/code/vllm/vllm/model_executor/models/mllama.py", line 870, in forward
[rank0]:     hidden_states = self.cross_attn(
[rank0]:   File "/home/mgoin/venvs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/mgoin/venvs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/mgoin/code/vllm/vllm/model_executor/models/mllama.py", line 755, in forward
[rank0]:     output = self.attention_with_mask(q, k, v, kv_cache,
[rank0]:   File "/home/mgoin/code/vllm/vllm/model_executor/models/mllama.py", line 811, in attention_with_mask
[rank0]:     output = F.scaled_dot_product_attention(q,
[rank0]: RuntimeError: CUDA error: an illegal memory access was encountered
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Output when running with this PR:

Output #0: The image depicts a serene scene of a white tower, likely a skyscraper or monument, set against a backdrop of vibrant pink cherry blossoms and a clear blue sky. The purpose of the image is to showcase the beauty of nature and architecture in harmony.

* A white tower:
        + The tower is tall and slender, with a distinctive shape that tapers towards the top.
        + It has a series of windows and a pointed roof, giving it a futuristic appearance.
        + The tower is

Output #1:  There are 2 duplicated images in the image. One is the image of the tower and the other is the image of the cherry blossoms. The image of the tower is duplicated in the background, and the image of the cherry blossoms is duplicated in the foreground. The duplicated images are not identical, as the tower is in the background and the cherry blossoms are in the foreground. The duplicated images are not symmetrical, as the tower is on the left side of the image and the cherry

Output #2:  The first image shows a white tower with a round top, surrounded by pink flowers. The second image shows a white tower with a round top, surrounded by pink flowers. The images are similar, but not identical. The first image has a more vibrant color scheme, while the second image has a more muted tone. The first image also has a more prominent shadow on the left side of the tower, while the second image has a more even lighting. Overall, the two images are similar, but with


@mgoin mgoin changed the title Fix Mllama SDPA illegal memory access for batched multi-image [Bugfix][Model] Fix Mllama SDPA illegal memory access for batched multi-image Oct 23, 2024
Signed-off-by: mgoin <michael@neuralmagic.com>
@mgoin mgoin force-pushed the fix-mllama-multi-image branch from d1398b4 to c82a070 Compare October 23, 2024 17:23
@heheda12345
Collaborator

LGTM. Thanks for this bug fix!

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 23, 2024
@mgoin
Member Author

mgoin commented Oct 23, 2024

From more testing, it does seem to decrease performance, but I think this is worth it for correctness at the moment.
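On the performance point: PyTorch's `Tensor.contiguous()` already returns its input unchanged when the tensor is contiguous, so the overhead is the copy taken for the non-contiguous Q/K/V states. A NumPy sketch of that copy-only-when-needed contract (a hypothetical helper, not code from this PR):

```python
import numpy as np

# Hypothetical helper mirroring the contiguous() contract: return the
# input untouched if it is already dense, otherwise take a dense copy.
def ensure_contiguous(arr: np.ndarray) -> np.ndarray:
    if arr.flags["C_CONTIGUOUS"]:
        return arr  # no copy needed
    return np.ascontiguousarray(arr)

dense = np.zeros((4, 4), dtype=np.float32)
strided = dense[:, ::2]  # strided slice: non-contiguous view

assert ensure_contiguous(dense) is dense          # passthrough, no copy
assert ensure_contiguous(strided) is not strided  # dense copy made
```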

@DarkLight1337 DarkLight1337 left a comment


Thanks for fixing!

@DarkLight1337 DarkLight1337 merged commit bb01f29 into main Oct 24, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
…ti-image (vllm-project#9626)

Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: Alvant <alvasian@yandex.ru>
MErkinSag pushed a commit to MErkinSag/vllm that referenced this pull request Oct 26, 2024
…ti-image (vllm-project#9626)

Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: Erkin Sagiroglu <erkin@infra-aipipeline-1-at1-prox-prod-a.ipa.corp.telnyx.com>
@simon-mo simon-mo deleted the fix-mllama-multi-image branch October 28, 2024 16:50
FerdinandZhong pushed a commit to FerdinandZhong/vllm that referenced this pull request Oct 29, 2024
…ti-image (vllm-project#9626)

Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: qishuai <ferdinandzhong@gmail.com>
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
…ti-image (vllm-project#9626)

Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
…ti-image (vllm-project#9626)

Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>