[AMD][Kernel][Bugfix] Cast offsets tensor bn to tl.int64 to avoid GPU segfault#23692
Merged
gshtras merged 5 commits into vllm-project:main on Sep 2, 2025
Conversation
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
…libi Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Contributor
Code Review
This pull request correctly addresses a critical integer overflow in the `_fwd_kernel` for AMD GPUs by casting the `bn` tensor to `tl.int64`, which prevents a potential GPU segfault. However, as noted in the pull request description, similar vulnerabilities exist in `_fwd_kernel_flash_attn_v2` and `_fwd_kernel_alibi`. The fixes for these functions are currently missing from the patch. It is crucial to include these changes to ensure the bug is fully resolved across all relevant kernels.
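For illustration only (hypothetical numbers, not vLLM code), the overflow mode is easy to reproduce with 32-bit arithmetic:

```python
import numpy as np

# Hypothetical values: a physical block number and a per-block stride large
# enough that their product exceeds 2**31 - 1.
bn = np.array([70_000], dtype=np.int32)
stride_k_cache_bs = np.array([40_960], dtype=np.int32)

# The 32-bit multiply wraps around and yields a negative offset, which an
# attention kernel would then use as an out-of-bounds address.
print(bn * stride_k_cache_bs)                   # [-1427767296]

# Promoting to 64 bits before the multiply keeps the offset valid.
print(bn.astype(np.int64) * stride_k_cache_bs)  # [2867200000]
```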
jatseng-ai approved these changes on Aug 27, 2025
SageMoore approved these changes on Aug 28, 2025
gshtras approved these changes on Aug 28, 2025
845473182 pushed a commit to 845473182/vllm that referenced this pull request on Sep 3, 2025
* 'main' of https://github.com/845473182/vllm: (457 commits)
  [BugFix] Fix routed_scaling_factor double mul for dots1 and glm4 MoE models (vllm-project#24132)
  [Misc] Add check for dual_chunk_attention (vllm-project#24070)
  [Doc]: fix typos in Python comments (vllm-project#24115)
  [Doc]: fix typos in Python comments (vllm-project#24093)
  [Compile] Fix Compile Warning for `w4a8_mm_entry.cu` (vllm-project#23660)
  fix some typos (vllm-project#24071)
  [V1] Wrapper which plumbs request-level logits processors into vLLM batch-level logits processing (vllm-project#23656)
  Upgrade xgrammar to 0.1.23 (vllm-project#22988)
  Update release pipeline post PyTorch 2.8.0 update (vllm-project#24073)
  [XPU] Fix the bug of LoRA logits on the XPU platform (vllm-project#24081)
  [CI/Build] Disable SiluMul NVFP4 quant fusion tests (vllm-project#24121)
  [Bug] R1 Accuracy: Fix `routed_scaling_factor` Double Mul Issue (vllm-project#24119)
  [AMD][Kernel][Bugfix] Cast offsets tensor bn to tl.int64 to avoid GPU segfault (vllm-project#23692)
  [CI] Enable all hf transformers baselines in test_hybrid (vllm-project#23936)
  [Log] Only Print Profiler Results on Rank 0 (vllm-project#23370)
  Fix weights loading for Apertus (vllm-project#24100)
  [Metrics] Deprecate TPOT in favor of ITL (vllm-project#24110)
  [Bugfix] Fix packed_factor missing attribute error (vllm-project#23902)
  Run ruff format on a few files. (vllm-project#24075)
  [Bugfix] Fix transform_config parsing in Compressed Tensors (vllm-project#23945)
  ...
eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request on Sep 9, 2025
… segfault (vllm-project#23692) Signed-off-by: Randall Smith <Randall.Smith@amd.com>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request on Sep 25, 2025
… segfault (vllm-project#23692) Signed-off-by: Randall Smith <Randall.Smith@amd.com>
The tensor loaded into `bn` is multiplied by `stride_k_cache_bs` in `_fwd_kernel` in `prefix_prefill.py`, which produces an integer overflow; the resulting negative offsets cause a GPU segfault. Changing `stride_k_cache_bs` to `tl.int64` in the function signature did not work. Casting the `bn` tensor to `tl.int64` fixes the problem. I added some additional casts in `_fwd_kernel_flash_attn_v2` and `_fwd_kernel_alibi` as well.
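As a minimal sketch, assuming a heavily simplified Triton kernel (parameter names other than `bn` and `stride_k_cache_bs` are illustrative and do not reflect the actual `prefix_prefill.py` signature), the fix places the cast before the stride multiply:

```python
# Minimal sketch, not the actual vLLM kernel: shows where the int64 cast goes.
import triton
import triton.language as tl


@triton.jit
def _fwd_kernel_sketch(K_cache, B_Loc, Out,
                       stride_b_loc_b, stride_k_cache_bs,
                       BLOCK_N: tl.constexpr):
    offs_n = tl.arange(0, BLOCK_N)
    # Physical block numbers for this sequence; loaded as 32-bit integers.
    bn = tl.load(B_Loc + offs_n * stride_b_loc_b)
    # Cast to 64-bit *before* multiplying by the cache stride so the element
    # offset cannot overflow int32 and wrap to a negative value.
    off_k = bn.to(tl.int64) * stride_k_cache_bs
    k = tl.load(K_cache + off_k)
    tl.store(Out + offs_n, k)
```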