Fix NaN from stale FP4 scale padding in create_fp4_scale_tensor#38148
Merged: tlrmchlsmth merged 3 commits into vllm-project:main (Apr 1, 2026)
Conversation
Padding rows in the swizzled scale tensor were uninitialized (torch.empty), containing stale NaN from prior GPU allocations. The TRT-LLM mm_fp4 kernel with use_8x4_sf_layout=True reads padding scales and applies them to real rows, contaminating the output with NaN. Zero-filling ensures padding scales contribute 0 * data = 0.

Fixes: flashinfer-ai/flashinfer#2861

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
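The failure mode described above can be reproduced in miniature without any GPU code. The sketch below is a pure-Python simulation (the function and variable names are hypothetical, and it stands in for the TRT-LLM kernel's scale-weighted reduction rather than calling it): a NaN left in an uninitialized padding scale poisons the accumulated result, while a zero-filled padding scale contributes exactly 0 * data = 0.

```python
import math

def scale_weighted_sum(data, scales):
    """Stand-in for a kernel reduction that multiplies each row's value
    by its scale and accumulates over ALL rows, padding included."""
    return sum(s * d for s, d in zip(scales, data))

m, rounded_m = 3, 8                       # 3 real rows, padded to a tile boundary
data = [1.0, 2.0, 3.0] + [0.0] * (rounded_m - m)

# torch.empty-style allocation: padding scales hold stale bytes (here, NaN).
stale_scales = [1.0, 1.0, 1.0] + [math.nan] * (rounded_m - m)
nan_out = scale_weighted_sum(data, stale_scales)   # NaN * 0.0 is NaN in IEEE 754

# torch.zeros-style allocation: padding scales are 0 and contribute nothing.
zero_scales = [1.0, 1.0, 1.0] + [0.0] * (rounded_m - m)
clean_out = scale_weighted_sum(data, zero_scales)

print(nan_out, clean_out)   # nan 6.0
```

Note that multiplying NaN by zero still yields NaN, which is why "the padding rows multiply zeros anyway" is not a defense; the scale itself must be a well-defined number.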
Contributor
Code Review
This pull request modifies vllm/_custom_ops.py to initialize FP4 scale tensors with zeros using torch.zeros instead of torch.empty. This change ensures that the tensors are predictably initialized, preventing potential issues from uninitialized memory. There are no review comments to address.
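The change itself is a one-line allocation swap inside create_fp4_scale_tensor. The sketch below is illustrative only (round_up and the variable names are not vLLM's actual helpers); it shows the shape arithmetic the PR description refers to, where the row count is rounded up to a 128-row tile boundary and everything between m and rounded_m is padding:

```python
def round_up(x, multiple):
    """Round x up to the nearest multiple (e.g. the 128-row tile boundary)."""
    return (x + multiple - 1) // multiple * multiple

m = 3                        # real rows (e.g. tokens routed to one expert)
rounded_m = round_up(m, 128)
print(rounded_m)             # 128 -> rows m..127 are padding rows

# With a torch.empty-style allocation, those 125 padding rows' scales hold
# whatever bytes the allocator returns (possibly FP8 NaN, 0x7F in
# float8_e4m3fn); with torch.zeros they are guaranteed to be 0.
```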
Force-pushed from 51ccde0 to 2432520
tlrmchlsmth approved these changes on Mar 30, 2026
Auto-merge was automatically disabled on March 30, 2026 12:26 (Pull Request is not mergeable)
tlrmchlsmth added a commit to tlrmchlsmth/vllm that referenced this pull request on Mar 30, 2026:

Cherry-pick d4a41a9: Revert "Zero-init MLA attention output buffers to prevent NaN from CUDA graph padding (vllm-project#37442)"
Apply PR vllm-project#38148: Fix NaN from stale FP4 scale padding in create_fp4_scale_tensor
Signed-off-by: Travis Stephens <travis@anthropic.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request on Apr 1, 2026:

…-project#38148)
Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: EricccYang <yangyang4991@gmail.com>
Summary
create_fp4_scale_tensor: torch.empty → torch.zeros

Root cause
create_fp4_scale_tensor allocates the swizzled scale tensor with torch.empty. When the number of rows m is less than rounded_m (rounded up to 128 for the tile boundary), the padding rows' scales contain stale GPU memory. If that memory holds FP8 NaN (0x7F in float8_e4m3fn), the TRT-LLM mm_fp4 kernel with use_8x4_sf_layout=True (triggered when m <= 32) reads these padding scales and applies them to real rows, producing NaN output.

In practice, this manifests as sporadic NaN in MoE layer outputs during prefill: experts receiving ≤32 tokens (common with 256 experts) hit the use_8x4_sf_layout path. The NaN then cascades through all subsequent layers, corrupts KV cache entries, and propagates to decode servers via KV transfer (NIXL).

Fix
Replace torch.empty with torch.zeros for the scale tensor allocation. Zero scales ensure padding contributes 0 × data = 0 to the real rows' output, neutralizing the kernel bug.

Related
Test plan
🤖 Generated with Claude Code