
fix: Exclude SM12x (desktop Blackwell) from CUDA_PTX_FP4FP6_CVT_ENABLED #3120

Closed

RobTand wants to merge 1 commit into NVIDIA:main from RobTand:fix/sm121-e2m1-ptx-exclusion


Conversation


@RobTand RobTand commented Mar 20, 2026

Summary

SM12x GPUs (RTX 5090/5080/PRO 6000 = SM120, DGX Spark GB10 = SM121) have mma.e2m1 tensor cores but lack the cvt.rn.satfinite.e2m1x2.f32 PTX instruction for native FP4/FP6 conversion. This instruction is SM100-family only.

When SM120A/F or SM121A/F is included in the CUDA_PTX_FP4FP6_CVT_ENABLED guard in float_subbyte.h, CUTLASS emits the missing PTX instruction, which produces NaN during all NVFP4 inference on SM12x hardware.

Fix

Remove all SM12x variants (SM120A, SM120F, SM121A, SM121F) from the CUDA_PTX_FP4FP6_CVT_ENABLED preprocessor guard. SM12x falls through to the existing software E2M1 conversion path, which works correctly.

Testing

Tested on DGX Spark (SM121, 128 GB unified LPDDR5X) running:

  • Nemotron-3-Super-120B-A12B-NVFP4 — 24 tok/s via vLLM + FlashInfer CUTLASS MoE
  • Qwen3.5-122B-A10B-NVFP4 — 26 tok/s via vLLM + FlashInfer CUTLASS MoE

Without this fix, both models produce NaN output on SM12x.

Related work

This CUTLASS fix addresses the root cause. Downstream projects carry independent software E2M1 workarounds that complement it:

  • vLLM #35947 by @blake-snc — software E2M1 fallback in vLLM's nvfp4_utils.cuh (same root cause, different code path)
  • FlashInfer — software E2M1 in vendored TRT-LLM quantization_utils.cuh (no PR yet)

cc @blake-snc — your vLLM PR #35947 addresses the same hardware limitation in vLLM's copy of the quantization code. Happy to withdraw this if you'd prefer to upstream the CUTLASS-side fix yourself, otherwise I'll work this through the review process.

Impact

This is a correctness blocker for all NVFP4 inference on desktop Blackwell GPUs (RTX 50-series and DGX Spark). Without this fix, no NVFP4 model produces valid output on SM12x.

SM12x GPUs (RTX 5090/5080/PRO 6000, DGX Spark GB10) have mma.e2m1
tensor cores but lack the cvt.rn.satfinite.e2m1x2.f32 PTX instruction
for native FP4/FP6 conversion — this instruction is SM100-family only.

When SM120A/F or SM121A/F is included in the CUDA_PTX_FP4FP6_CVT_ENABLED
guard, CUTLASS emits the missing PTX instruction, which produces NaN
during NVFP4 inference.

This change removes all SM12x variants from the guard, causing SM12x to
fall through to the existing software E2M1 conversion path.

Tested on DGX Spark (SM121) running Nemotron-3-Super-120B and
Qwen3.5-122B NVFP4 models via vLLM + FlashInfer. Without this fix,
all NVFP4 inference on SM12x produces NaN output.

Signed-off-by: Rob Tand <robert.tand@icloud.com>
@johnnynunez

cc @depaulmillz

@RobTand
Author

RobTand commented Mar 20, 2026

Withdrawing this PR. After further investigation prompted by @depaulmillz's comment on vllm-project/vllm#35947, we confirmed that cvt.rn.satfinite.e2m1x2.f32 IS supported on SM12x when compiled with the correct architecture flags (sm_121a or sm_120f).

The root cause is that vLLM's cmake strips the a/f suffix from TORCH_CUDA_ARCH_LIST="12.0a;12.1a", compiling as plain sm_120 instead of sm_120a/sm_121a. Plain sm_120 does not define __CUDA_ARCH_FAMILY_SPECIFIC__, which causes NVIDIA's cuda_fp8.h to disable the PTX instruction path. The CUTLASS guard in float_subbyte.h is correct as-is.

The fix belongs in vLLM's build system, not in CUTLASS. Thank you for the guidance.

@RobTand RobTand closed this Mar 20, 2026
