[Bugfix][Model] Fix Mllama SDPA illegal memory access for batched multi-image#9626
[Bugfix][Model] Fix Mllama SDPA illegal memory access for batched multi-image#9626DarkLight1337 merged 1 commit intomainfrom
Conversation
|
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
🚀 |
Signed-off-by: mgoin <michael@neuralmagic.com>
d1398b4 to
c82a070
Compare
|
LGTM. Thanks for this bug fix! |
|
From more testing it does seem to decrease performance, but I think this is worth it for sanity at the moment |
…ti-image (vllm-project#9626) Signed-off-by: mgoin <michael@neuralmagic.com> Signed-off-by: Alvant <alvasian@yandex.ru>
…ti-image (vllm-project#9626) Signed-off-by: mgoin <michael@neuralmagic.com> Signed-off-by: Erkin Sagiroglu <erkin@infra-aipipeline-1-at1-prox-prod-a.ipa.corp.telnyx.com>
…ti-image (vllm-project#9626) Signed-off-by: mgoin <michael@neuralmagic.com> Signed-off-by: qishuai <ferdinandzhong@gmail.com>
…ti-image (vllm-project#9626) Signed-off-by: mgoin <michael@neuralmagic.com> Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
…ti-image (vllm-project#9626) Signed-off-by: mgoin <michael@neuralmagic.com> Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
Mllama could trigger a CUDA error for illegal memory access within the cross-attention SDPA during inference or even profile run when running batch multi-image requests.
When inspecting the inputs to the SDPA, I noticed that the Q/K/V states were not contiguous in all cases (
attention_maskis contiguous). Once forcing all of them to be contiguous, I found this issue went away.vllm/vllm/model_executor/models/mllama.py
Lines 810 to 814 in fd0e2cf
Script used for testing:
Error when running with
CUDA_LAUNCH_BLOCKING=1:Output when running with this PR: