Fix NaN from stale FP4 scale padding in create_fp4_scale_tensor #38148

Merged
tlrmchlsmth merged 3 commits into vllm-project:main from elvircrn:fix/fp4-scale-padding-nan
Apr 1, 2026

Conversation

@elvircrn
Contributor

Summary

  • Zero-fill FP4 scale tensors in create_fp4_scale_tensor (torch.empty → torch.zeros)
  • Fixes NaN contamination in MoE expert outputs on Blackwell (GB200) with NVFP4 quantization

Root cause

create_fp4_scale_tensor allocates the swizzled scale tensor with torch.empty. When the number of rows m is less than rounded_m (rounded up to 128 for the tile boundary), the padding rows' scales contain stale GPU memory. If that memory holds FP8 NaN (0x7F in float8_e4m3fn), the TRT-LLM mm_fp4 kernel with use_8x4_sf_layout=True (triggered when m <= 32) reads these padding scales and applies them to real rows, producing NaN output.

In practice, this manifests as sporadic NaN in MoE layer outputs during prefill — experts receiving ≤32 tokens (common with 256 experts) hit the use_8x4_sf_layout path. The NaN then cascades through all subsequent layers, corrupts KV cache entries, and propagates to decode servers via KV transfer (NIXL).
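The row-padding arithmetic described above can be sketched as follows. The helper name, the token count, and the `FP8_E4M3FN_NAN_BITS` constant are illustrative, not the actual vLLM code:

```python
def round_up(x: int, multiple: int) -> int:
    """Round x up to the next multiple (here: the 128-row tile boundary)."""
    return ((x + multiple - 1) // multiple) * multiple

# Hypothetical numbers: an expert receiving 7 tokens (m <= 32, so the
# use_8x4_sf_layout path is taken).
m = 7
rounded_m = round_up(m, 128)   # padded up to the tile boundary: 128
padding_rows = rounded_m - m   # rows of scales never written by real data: 121

# With torch.empty, those padding rows hold whatever bytes the allocator
# hands back; the byte 0x7F decodes as NaN in float8_e4m3fn, and the kernel
# then multiplies that NaN into real rows.
FP8_E4M3FN_NAN_BITS = 0x7F
```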

Fix

Replace torch.empty with torch.zeros for the scale tensor allocation. Zero scales ensure padding contributes 0 × data = 0 to real rows' output, neutralizing the kernel bug.
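A minimal sketch of the change, using a simplified stand-in for create_fp4_scale_tensor (the real function lives in vllm/_custom_ops.py and takes more parameters; the name and shape here are illustrative):

```python
import torch

def create_fp4_scale_tensor_sketch(m: int, scale_cols: int) -> torch.Tensor:
    """Simplified stand-in: allocate the swizzled scale tensor with rows
    padded up to the 128-row tile boundary."""
    rounded_m = (m + 127) // 128 * 128
    # Was: torch.empty(...) -- padding rows kept stale (possibly NaN) bytes.
    # Now: torch.zeros(...) -- padding scales are 0, so padding contributes
    # 0 * data = 0 to real rows even if the kernel reads them.
    return torch.zeros(rounded_m, scale_cols, dtype=torch.uint8)

scales = create_fp4_scale_tensor_sketch(7, 4)  # shape (128, 4), all zeros
```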

Test plan

  • Verified NaN reproduction on GB200 cluster with DeepSeek-R1-0528-NVFP4-v2
  • Confirmed NaN originates at MoE layer 41 on prefill (clean input → NaN after MoE)
  • Confirmed fix neutralizes the padding scale contamination
  • Ran existing FP4 unit tests

🤖 Generated with Claude Code

Padding rows in the swizzled scale tensor were uninitialized (torch.empty),
containing stale NaN from prior GPU allocations. The TRT-LLM mm_fp4 kernel
with use_8x4_sf_layout=True reads padding scales and applies them to real
rows, contaminating output with NaN.

Zero-filling ensures padding scales contribute 0 * data = 0.

Fixes: flashinfer-ai/flashinfer#2861

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>

@claude claude bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request modifies vllm/_custom_ops.py to initialize FP4 scale tensors with zeros using torch.zeros instead of torch.empty. This change ensures that the tensors are predictably initialized, preventing potential issues from uninitialized memory. There are no review comments to address.

@mergify mergify bot added the v1 label Mar 28, 2026
@elvircrn elvircrn force-pushed the fix/fp4-scale-padding-nan branch 3 times, most recently from 51ccde0 to 2432520 on March 28, 2026 18:25
@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 30, 2026
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) March 30, 2026 11:50
auto-merge was automatically disabled March 30, 2026 12:26

Pull Request is not mergeable

tlrmchlsmth added a commit to tlrmchlsmth/vllm that referenced this pull request Mar 30, 2026
Cherry-pick d4a41a9: Revert "Zero-init MLA attention output buffers
to prevent NaN from CUDA graph padding (vllm-project#37442)"

Apply PR vllm-project#38148: Fix NaN from stale FP4 scale padding in
create_fp4_scale_tensor

Signed-off-by: Travis Stephens <travis@anthropic.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) March 31, 2026 19:56
@tlrmchlsmth tlrmchlsmth merged commit 0fab52f into vllm-project:main Apr 1, 2026
48 checks passed
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026
…-project#38148)

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: EricccYang <yangyang4991@gmail.com>

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1
