
fix: Software E2M1 conversion for SM12x NVFP4 activation quantization #35947

Closed
blake-snc wants to merge 7 commits into vllm-project:main from blake-snc:fix/nvfp4-sw-e2m1-sm12x

Conversation

@blake-snc

blake-snc commented Mar 4, 2026

Summary

  • SM12x GPUs (RTX 5090, GB10 / DGX Spark) lack the hardware cvt.rn.satfinite.e2m1x2.f32 PTX instruction used in NVFP4 activation quantization — this instruction is SM100-only, causing an illegal instruction crash when running NVFP4 models with the CUTLASS backend on SM12x
  • Add software E2M1 conversion functions guarded by `#if __CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300` that use float threshold-based rounding to match the hardware instruction's behavior
  • The Marlin backend already works on SM12x (no activation quantization needed), but the CUTLASS path was broken

Validation

Tested on DGX Spark (NVIDIA GB10, SM121a, 128 GB unified LPDDR5X).

Dense model: nvidia/Llama-3.1-8B-Instruct-NVFP4

| Metric | Value |
|---|---|
| Quantization | modelopt_fp4 |
| NVFP4 GEMM Backend | FLASHINFER_CUTLASS |
| Model Memory | 5.65 GiB |
| Throughput | 111.4 tok/s |
| Output correctness | PASS |

MoE model: Sehyo/Qwen3.5-122B-A10B-NVFP4

| Metric | Value |
|---|---|
| Quantization | compressed-tensors (nvfp4-pack, auto-detected) |
| NVFP4 GEMM Backend | FLASHINFER_CUTLASS |
| MoE Backend | FLASHINFER_CUTLASS |
| Architecture | Qwen3_5MoeForConditionalGeneration (122B total, 10B active) |
| Model Memory | 71.32 GiB |
| Throughput | 1.34 tok/s (output) |
| Output correctness | PASS |

Both models produce coherent, correct outputs:

Prompt: "What is the capital of France?"
Llama:  "Paris"
Qwen:   "The capital of France is Paris."

Prompt: "Write a haiku about AI."
Qwen:   "Silent circuits hum, / Learning from the world's vast past, / New thoughts start to bloom."

Note: The Qwen3.5-122B MoE model (71.3 GiB weights) requires ~48 GB additional swap on DGX Spark to handle peak memory during safetensors loading, as CPU intermediates and CUDA allocations compete for the same 128 GB unified memory pool. This is a vLLM loader limitation on unified memory architectures, not related to this PR's E2M1 fix.

Test plan

  • Verify CUTLASS backend boots without illegal instruction crash (dense model)
  • Verify CUTLASS backend boots without illegal instruction crash (MoE model)
  • Verify output correctness — factual queries return correct answers
  • Verify both modelopt_fp4 and compressed-tensors quantization formats work
  • Verify Marlin backend still works (no regression)

Fixes #35519, #30163

Contributed by Second Nature Computing

Contributor

gemini-code-assist bot left a comment


Code Review

This pull request adds a software fallback for E2M1 conversion on SM12x GPUs, which lack the necessary hardware instruction. The changes are logical and well-contained. I've found a critical correctness issue in the rounding logic of the software implementation that needs to be addressed. Additionally, there's an opportunity to refactor duplicated code to improve maintainability. My detailed feedback is in the comments below.

Comment on lines +45 to +66
__device__ __forceinline__ uint8_t sw_float_to_e2m1(float v) {
  uint8_t sign = (__float_as_uint(v) >> 31) & 1;
  float av = fabsf(v);
  uint8_t e2m1;
  if (av < 0.25f)
    e2m1 = 0;  // → 0.0
  else if (av < 0.75f)
    e2m1 = 1;  // → 0.5
  else if (av < 1.25f)
    e2m1 = 2;  // → 1.0
  else if (av < 1.75f)
    e2m1 = 3;  // → 1.5
  else if (av < 2.5f)
    e2m1 = 4;  // → 2.0
  else if (av < 3.5f)
    e2m1 = 5;  // → 3.0
  else if (av < 5.0f)
    e2m1 = 6;  // → 4.0
  else
    e2m1 = 7;  // → 6.0 (satfinite)
  return (sign << 3) | e2m1;
}
Contributor


critical

The software implementation of sw_float_to_e2m1 does not correctly emulate the round-to-nearest-even behavior of the cvt.rn.satfinite.e2m1x2.f32 PTX instruction for midpoint values. The current implementation uses round-half-up, which will lead to correctness issues. For example, for a value of 0.25, which is a midpoint between 0.0 (rep 0) and 0.5 (rep 1), it should round to 0.0 because its integer representation 0 is even. The current code rounds it to 0.5. This is a critical issue that will cause divergence from the hardware implementation.

__device__ __forceinline__ uint8_t sw_float_to_e2m1(float v) {
  uint8_t sign = (__float_as_uint(v) >> 31) & 1;
  float av = fabsf(v);
  uint8_t e2m1;
  // E2M1 representable values (integer representation):
  // 0.0 (0), 0.5 (1), 1.0 (2), 1.5 (3), 2.0 (4), 3.0 (5), 4.0 (6), 6.0 (7)
  if (av <= 0.25f) e2m1 = 0;      // Midpoint 0.25 rounds to 0 (even)
  else if (av < 0.75f) e2m1 = 1;
  else if (av <= 1.25f) e2m1 = 2; // Midpoints 0.75, 1.25 round to 2 (even)
  else if (av < 1.75f) e2m1 = 3;
  else if (av <= 2.5f) e2m1 = 4;  // Midpoints 1.75, 2.5 round to 4 (even)
  else if (av < 3.5f) e2m1 = 5;
  else if (av <= 5.0f) e2m1 = 6;  // Midpoints 3.5, 5.0 round to 6 (even)
  else e2m1 = 7;
  return (sign << 3) | e2m1;
}

Author


The current code already implements round-to-nearest-even correctly. The <= comparisons at midpoints (0.25, 1.25, 2.5, 5.0) direct ties to the lower even codes (0, 2, 4, 6), while the < comparisons at midpoints (0.75, 1.75, 3.5) direct ties to the higher even codes (2, 4, 6 in the next branch). This matches the hardware cvt.rn.satfinite.e2m1x2.f32 behavior — verified empirically on SM121a by comparing hardware vs software outputs for all E2M1 midpoint values.

Comment on lines +98 to +104
// SM12x: software E2M1 conversion (no hardware cvt.rn.satfinite.e2m1x2.f32)
uint8_t b0 = sw_e2m1x2_from_f32(array[1], array[0]);
uint8_t b1 = sw_e2m1x2_from_f32(array[3], array[2]);
uint8_t b2 = sw_e2m1x2_from_f32(array[5], array[4]);
uint8_t b3 = sw_e2m1x2_from_f32(array[7], array[6]);
val = (uint32_t)b0 | ((uint32_t)b1 << 8) | ((uint32_t)b2 << 16) |
      ((uint32_t)b3 << 24);
Contributor


high

The software E2M1 conversion logic is duplicated across fp32_vec8_to_e2m1 (both overloads) and fp32_vec16_to_e2m1. This can be refactored to improve maintainability by introducing a helper function.
I suggest adding the following helper function inside the #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300 block (e.g., after sw_e2m1x2_from_f32):

__device__ __forceinline__ uint32_t sw_fp32_vec8_to_e2m1_impl(const float2* array) {
  uint8_t b0 = sw_e2m1x2_from_f32(array[0].y, array[0].x);
  uint8_t b1 = sw_e2m1x2_from_f32(array[1].y, array[1].x);
  uint8_t b2 = sw_e2m1x2_from_f32(array[2].y, array[2].x);
  uint8_t b3 = sw_e2m1x2_from_f32(array[3].y, array[3].x);
  return (uint32_t)b0 | ((uint32_t)b1 << 8) | ((uint32_t)b2 << 16) |
         ((uint32_t)b3 << 24);
}

Using this helper, this function can be simplified as shown in the suggestion.
The other overloads can be simplified as follows:

  • For fp32_vec8_to_e2m1(float2 (&array)[4]), the body can be val = sw_fp32_vec8_to_e2m1_impl(array);.
  • For fp32_vec16_to_e2m1(float2 (&array)[8]), the body can be out.lo = sw_fp32_vec8_to_e2m1_impl(array); out.hi = sw_fp32_vec8_to_e2m1_impl(array + 4);.
Suggested replacement for the float[8] overload:

  // SM12x: software E2M1 conversion (no hardware cvt.rn.satfinite.e2m1x2.f32)
  val = sw_fp32_vec8_to_e2m1_impl(reinterpret_cast<const float2*>(array));

Author


Already done — sw_fp32_vec8_to_e2m1(const float2* array) is defined at line 78 and used by all three overloads (fp32_vec8_to_e2m1(float[8]) at line 111, fp32_vec8_to_e2m1(float2[4]) at line 136, and fp32_vec16_to_e2m1(float2[8]) at lines 166-167).

blake-snc force-pushed the fix/nvfp4-sw-e2m1-sm12x branch from 746a112 to f071e9c on March 5, 2026 at 20:23
blake-snc and others added 6 commits March 5, 2026 12:25
…ariants)

`get_marlin_input_dtype()` uses `is_device_capability(120)` which is an
exact match — SM121 devices (DGX Spark GB10, RTX 5090) return capability
(12, 1) and fail the check, blocking Marlin W4A8-FP8 with a misleading
"only support SM89 or SM120" error.

Changed to `has_device_capability(120)` which uses >= comparison,
allowing SM120 and all Blackwell variants (SM121, SM121a, etc.) while
still correctly blocking SM90 (Hopper) where Marlin FP8 is slower than
W4A16.

The SM89 (Ada) check remains as `is_device_capability(89)` since there
are no Ada variants.

Validated on DGX Spark (NVIDIA GB10, SM121a / capability 12.1):
- Before: `is_device_capability(120)` → False → ValueError raised
- After:  `has_device_capability(120)` → True  → FP8 dtype returned
- SM90 still correctly blocked (has_device_capability(120) → False)
- SM89 still correctly allowed (is_device_capability(89) → True)

Fixes vllm-project#35432
Relates to vllm-project#30135

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
SM121 (DGX Spark GB10) shares the same FP8 MMA capabilities as SM120
(RTX 5090) but is excluded by exact-match arch guards throughout the
Marlin and CUTLASS FP8 codepaths. This fixes 8 locations:

- generate_kernels.py (Marlin + MoE): `arch in [89, 120]` → `arch == 89
  or arch >= 120` so SM121 FP8 kernel templates are generated
- ops.cu (MoE Marlin): `== 120` → `>= 120` in runtime FP8 activation
  gate
- scaled_mm_sm120_fp8_dispatch.cuh + scaled_mm.cuh: `enable_sm120_only`
  → `enable_sm120_family` so CUTLASS FP8 GEMM kernels run on SM121
- test_moe.py + test_marlin_gemm.py: fix FP8 test skip using proper
  `is_device_capability(89)` / `is_device_capability_family(120)` APIs
  instead of broken `get_device_capability() not in [89, 120]`
  (NamedTuple vs int comparison)
- marlin_utils.py: `is_device_capability(120)` →
  `is_device_capability_family(120)` for Python-side FP8 input check

Companion to vllm-project#35568 which fixes the runtime Marlin FP8 gate in
marlin.cu.

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Address review feedback: arch >= 120 would incorrectly match future
arch families (SM130+). Use arch // 10 == 12 for codegen and
major_capability == 12 for runtime to scope checks to the SM12x family.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
…tion

SM12x (RTX 5090, GB10 / DGX Spark) lacks the hardware
cvt.rn.satfinite.e2m1x2.f32 PTX instruction used in NVFP4 activation
quantization. This instruction is SM100-only, causing an illegal instruction
crash when running NVFP4 models on SM12x with the CUTLASS backend.

Add a software E2M1 conversion path guarded by #if __CUDA_ARCH__ >= 1200 &&
< 1300 that uses float threshold-based rounding to match the hardware
instruction's rounding behavior (round-to-nearest-even, satfinite).

Validated on DGX Spark (SM121a) with nvidia/Llama-3.1-8B-Instruct-NVFP4:

| Metric            | Marlin backend | CUTLASS (sw E2M1) |
|--------------------|----------------|-------------------|
| Boot success       | YES            | YES               |
| Decode tok/s       | ~15.8          | ~18.3             |
| TTFT (ms)          | ~190           | ~138              |
| GPU memory (MiB)   | 61,147         | 60,448            |
| Output correctness | PASS           | PASS              |

Fixes: vllm-project#35519, vllm-project#30163

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Address review feedback:

1. Fix rounding behavior to match hardware cvt.rn (round-to-nearest-even).
   At midpoints (0.25, 1.25, 2.5, 5.0), the previous code rounded to the
   upper value (odd code), but IEEE 754 RNE breaks ties to the even code.
   Changed < to <= at those midpoints.

2. Extract shared sw_fp32_vec8_to_e2m1() helper to deduplicate the packing
   logic across all three fp32_vec*_to_e2m1 function overloads.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
blake-snc force-pushed the fix/nvfp4-sw-e2m1-sm12x branch from f071e9c to fb63fc9 on March 5, 2026 at 20:32
@johnnynunez
Contributor

cc @mgoin could you run tests and merge it?

@meena-at-work

@pavanimajety -- can you please review this change? I've verified that this change improves NVFP4 perf on SM121 via marlin.

@pavanimajety
Collaborator

@meena-at-work/@blake-snc Could you also add gsm8k results for Qwen3.5 NVFP4 on Spark?

@RobTand

RobTand commented Mar 20, 2026

This fix is critical — without it, all NVFP4 inference on SM12x produces NaN. I'm running Nemotron-3-Super-120B and Qwen3.5-122B on DGX Spark with this exact patch applied locally.

Re @pavanimajety's request for gsm8k results: I can run the full gsm8k evaluation on DGX Spark. Note that Qwen3.5-122B is a thinking/reasoning model that routes output to reasoning_content, which lm-eval can't parse — so results require either disabling thinking mode or a custom eval harness. Happy to provide whichever format is most useful.

I've also submitted the root cause fix to CUTLASS directly: NVIDIA/cutlass#3120 (excludes SM12x from CUDA_PTX_FP4FP6_CVT_ENABLED in float_subbyte.h). Both fixes are needed — this PR covers vLLM's copy in nvfp4_utils.cuh, the CUTLASS PR covers the upstream header.

@blake-snc
Author

> @meena-at-work/@blake-snc Could you also add gsm8k results for Qwen3.5 NVFP4 on Spark?

Sorry I have been out of town, will take care of this ASAP!

@RobTand

RobTand commented Mar 20, 2026

> @meena-at-work/@blake-snc Could you also add gsm8k results for Qwen3.5 NVFP4 on Spark?

> Sorry I have been out of town, will take care of this ASAP!

thank you sir!

@depaulmillz

When compiling this kernel, what specific flags are being passed to NVCC? Also, which CUDA version are you using? cvt.rn.satfinite.e2m1x2.f32 is supported on DGX Spark starting in CUDA 12.9.

@RobTand

RobTand commented Mar 20, 2026

You are a scholar and a gentleman, sir. You are indeed correct; I just wrote a small script that proves it. There appears to be a bug in vLLM that doesn't preserve the suffix: extract_unique_cuda_archs_ascending drops the suffix via string_to_ver, so by the time we reach the intersection, 12.1a has become 12.1, and 12.1 isn't in CUDA_SUPPORTED_ARCHS. I'll pull my PR and submit a new one (I had one that was a derivative of this), but I need to test it first.

Way to come through on a friday night :)

@RobTand

RobTand commented Mar 20, 2026

Follow-up: after investigating @depaulmillz's comment, we confirmed the native E2M1 PTX instruction does work on SM121 when compiled with the correct architecture flags (sm_121a).

The root cause is that vLLM's cmake strips the a/f suffix from gencode flags, compiling SM12x as plain sm_120 instead of sm_121a. Without the suffix, __CUDA_ARCH_FAMILY_SPECIFIC__ is undefined and cuda_fp4.hpp disables the native PTX path.

Fix submitted in #37725 — three small changes to cmake/utils.cmake and CMakeLists.txt to preserve the architecture suffix.

If #37725 is adopted, the software E2M1 fallback in this PR may no longer be needed for the JIT path (since the native instruction would be available). However, the AOT-compiled path in _C.abi3.so would still need the cmake fix to compile with the correct flags.

@blake-snc
Author

blake-snc commented Mar 24, 2026

GSM8K Results on SM121a (DGX Spark GB10) — txn545/Qwen3.5-35B-A3B-NVFP4

Running gsm8k_cot (lm-eval) with --apply_chat_template --fewshot_as_multiturn --system_instruction "/no_think" on DGX Spark (SM121a, 128 GB unified LPDDR5X), with this PR's SW E2M1 fix applied.

Backend configuration

| Layer type | Backend | Notes |
|---|---|---|
| Linear layers (NVFP4) | MARLIN | VLLM_NVFP4_GEMM_BACKEND=marlin |
| MoE layers (NVFP4) | MARLIN | VLLM_TEST_FORCE_FP8_MARLIN=1 |
| Attention | FLASHINFER | Works fine |

The FLASHINFER_CUTLASS MoE backend crashes on SM121a with illegal instruction because its autotuner selects cutlass_kernel_file_gemm_grouped_sm120_*.cu tactics compiled for plain sm_120 (no _a suffix) — different from sm_121a. MARLIN runs cleanly.

Results — 250 samples, max_model_len=4096, max_gen_toks=1024

| Configuration | Filter | n-shot | exact_match | stderr |
|---|---|---|---|---|
| 8-shot CoT + multiturn | strict-match | 8 | 0.312 | ±0.0294 |
| 8-shot CoT + multiturn | flexible-extract | 8 | 0.084 | ±0.0177 |
| 8-shot CoT, no multiturn | strict-match | 8 | 0.292 | ±0.0288 |
| 8-shot CoT, no multiturn | flexible-extract | 8 | 0.100 | ±0.0190 |
| 0-shot | flexible-extract | 0 | 0.312 | ±0.0294 |
| 0-shot | strict-match | 0 | 0.028 | ±0.0105 |

Best: 31.2% exact match (8-shot strict, or 0-shot flexible-extract).

Key observation on 8-shot strict-match: ~39% of responses are truncated early because the model's "Thinking Process:" meta-analysis writes the question back as Q: Janet's ducks..., hitting the Q: stop token before completing the answer. This is a model behavior artifact of /no_think with few-shot examples, not a quantization issue.

Zero-shot flexible-extract is the cleaner metric here — 31.2% — where we just check that the last number in the response is correct, bypassing the stop-token truncation issue.

No crashes — 250/250 samples completed

With this PR applied, all 250 samples ran to completion with no CUDA illegal instruction errors and no NaN outputs. Before this fix, the NVFP4 activation path would crash immediately on SM121a.


Re: @depaulmillz and @RobTand's follow-up

@depaulmillz correctly identified that cvt.rn.satfinite.e2m1x2.f32 works on SM121a when compiled with sm_121a flags + CUDA ≥ 12.9. @RobTand's #37725 fixes vLLM's cmake to preserve the arch suffix so that path becomes available for source builds.

These two fixes are at different layers and both remain necessary:

NVIDIA/cutlass#3120 (excluding SM12x from CUDA_PTX_FP4FP6_CVT_ENABLED in CUTLASS's float_subbyte.h) was closed without merging, so the upstream CUTLASS header is unchanged.

@johnnynunez
Contributor

cc @mgoin @LucasWilkinson

Comment on lines +48 to +69
__device__ __forceinline__ uint8_t sw_float_to_e2m1(float v) {
  uint8_t sign = (__float_as_uint(v) >> 31) & 1;
  float av = fabsf(v);
  uint8_t e2m1;
  // Midpoint tie-breaking: <= rounds to lower (even) code, < rounds to upper.
  if (av <= 0.25f)
    e2m1 = 0;  // 0.0; midpoint 0.25 → code 0 (even)
  else if (av < 0.75f)
    e2m1 = 1;  // 0.5; midpoint 0.75 → code 2 (even, next branch)
  else if (av <= 1.25f)
    e2m1 = 2;  // 1.0; midpoint 1.25 → code 2 (even)
  else if (av < 1.75f)
    e2m1 = 3;  // 1.5; midpoint 1.75 → code 4 (even, next branch)
  else if (av <= 2.5f)
    e2m1 = 4;  // 2.0; midpoint 2.5 → code 4 (even)
  else if (av < 3.5f)
    e2m1 = 5;  // 3.0; midpoint 3.5 → code 6 (even, next branch)
  else if (av <= 5.0f)
    e2m1 = 6;  // 4.0; midpoint 5.0 → code 6 (even)
  else
    e2m1 = 7;  // 6.0 (satfinite)
  return (sign << 3) | e2m1;
Member


Is there really not a more efficient implementation than this? It seems like it would be quite slow.

// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
  uint32_t val;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300
Member


Might be worth redefining this as a new define like `SOFTWARE_E2M1_CONVERT`.

@mgoin
Member

mgoin commented Mar 25, 2026

Sorry, I didn't read all the comments since @johnnynunez pinged me here. I will focus on #37725. Can we close this?

mgoin closed this Mar 25, 2026
github-project-automation bot moved this to Done in NVIDIA on Mar 25, 2026

Labels

nvidia · qwen (Related to Qwen models) · ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Bug]: Qwen3.5 NVFP4 models crash on ARM64 GB10 DGX Spark (CUDA illegal instruction during generation)

7 participants