
[NVIDIA] Bugfix NVFP4 DGX Spark and RTX50 #38423

Open
johnnynunez wants to merge 18 commits into vllm-project:main from johnnynunez:main

Conversation

@johnnynunez
Contributor

@johnnynunez johnnynunez commented Mar 28, 2026

Summary

Fix cudaErrorIllegalInstruction when running NVFP4 models (e.g. nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4) on SM12x GPUs (RTX 50 series SM120, DGX Spark SM121).

Root causes

  1. CUTLASS v4.2.2 lacks SM12x NVFP4 tile constraints — The bundled CUTLASS was missing SM120f family-level compilation support for NVFP4/MX Grouped GEMM and SM121-specific tile configurations (DGX Spark). This caused IllegalInstruction during decode when small-M tile variants were selected. Related upstream: NVIDIA/cutlass#3038.

  2. FlashInfer 0.6.6 bundles CUTLASS 4.2.1 — The FlashInfer CUTLASS MoE backend failed on SM12x with "Failed to initialize cutlass TMA WS grouped gemm" due to the same missing tile constraints. Fixed upstream in flashinfer-ai/flashinfer#2798.

  3. cutlass_scaled_mm_supports_fp4() falsely reported availability — It checked only the CUDA runtime version (>= 12080), not whether the SM-specific kernel was actually compiled. On a build with only ENABLE_NVFP4_SM100, it incorrectly reported CUTLASS as available for SM12x, then failed at dispatch.

  4. Quantization kernels had no SM runtime guard — The scaled_fp4_quant, silu_and_mul_nvfp4_quant, and expert quant entry points dispatched to _sm1xxa kernels if any SM1xx was compiled, with no runtime check. If only SM100 SASS existed, CUDA would JIT-compile SM100 PTX for SM120 (different major arch), producing illegal instructions asynchronously — surfacing later at synchronize() as an opaque CUDA error.

  5. FlashInfer CUTLASS backend bypassed the quant kernel checks: select_nvfp4_linear_backend() selected FlashInfer CUTLASS solely on has_device_capability(100), without verifying that the vLLM quantization kernels (used by all non-Marlin backends) were compiled for the current SM.
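The gating pattern described in points 3–5 can be sketched in Python. This is an illustrative sketch only; the function name `nvfp4_supported` and the guard strings stand in for vLLM's actual compile-time macros and capability checks:

```python
# Sketch of the gating pattern this PR introduces: a kernel is reported as
# available only if it was compiled for the device's SM family, instead of
# a blanket "CUDA >= 12.8" or "capability >= 100" check. Names here are
# hypothetical, not vLLM's real API.

def nvfp4_supported(device_capability: tuple[int, int],
                    compiled_guards: set[str]) -> bool:
    """True only if an NVFP4 kernel was compiled for this SM family.

    SM100 (B200) and SM12x (RTX 50 / DGX Spark) are different major
    architectures: SM100 PTX JIT-compiled onto SM120 yields illegal
    instructions, so each family needs its own compile-time guard.
    """
    major, _minor = device_capability
    if major == 10:
        return "ENABLE_NVFP4_SM100" in compiled_guards
    if major == 12:
        return "ENABLE_NVFP4_SM120" in compiled_guards
    return False

# A build with only SM100 kernels must report "unavailable" on SM120 so the
# caller falls back to Marlin instead of crashing at dispatch:
print(nvfp4_supported((12, 0), {"ENABLE_NVFP4_SM100"}))  # → False (Marlin fallback)
print(nvfp4_supported((12, 1), {"ENABLE_NVFP4_SM100",
                                "ENABLE_NVFP4_SM120"}))  # → True (native CUTLASS)
```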

Changes

| File | Change |
|:--|:--|
| `CMakeLists.txt` | Bump CUTLASS from v4.2.2 to v4.4.2 — enables SM120f (family) compilation for NVFP4/MX Grouped GEMM, covering RTX 50 (SM120) and DGX Spark (SM121) |
| `docker/Dockerfile` | Bump FlashInfer from 0.6.6 to 0.6.7 (includes CUTLASS 4.4.2, fixes TMA grouped GEMM on SM12x) |
| `docker/Dockerfile.nightly_torch` | Same FlashInfer bump (source build) |
| `docker/versions.json` | `FLASHINFER_VERSION`: 0.6.6 → 0.6.7 |
| `nvfp4_scaled_mm_entry.cu` | `cutlass_scaled_mm_supports_fp4()` now checks compile-time `ENABLE_NVFP4_SM100`/`ENABLE_NVFP4_SM120` guards per SM range instead of a blanket >= 100 check |
| `nvfp4_quant_entry.cu` | Added `nvfp4_quant_sm_supported()` runtime guard to all four quant entry points (`scaled_fp4_quant`, `scaled_fp4_experts_quant`, `silu_and_mul_nvfp4_quant`, `silu_and_mul_scaled_fp4_experts_quant`) |
| `nvfp4_utils.py` | `select_nvfp4_linear_backend()` gates FlashInfer CUTLASS on `cutlass_fp4_supported()` and adds a validation assert for all FlashInfer backends |

What is NOT changed

Marlin remains a valid fallback on SM12x. Marlin FP4 uses weight-only dequantization to BF16 — it does not use native FP4 tensor core instructions and works correctly on all Blackwell architectures including DGX Spark. Benchmarks confirm Marlin is stable on SM121 (~558 tok/s, on par with vLLM CUTLASS at ~562 tok/s). The Marlin path (apply_fp4_marlin_linear) bypasses the vLLM quant kernels entirely, so the SM guards in nvfp4_quant_entry.cu do not affect it.
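Marlin's weight-only approach can be illustrated in a few lines of Python: FP4 (E2M1) codes are expanded to ordinary floats through a 16-entry lookup table before a regular BF16 GEMM, which is why no native FP4 tensor-core instruction (and hence no SM-specific NVFP4 kernel) is required. This is a conceptual illustration, not Marlin's actual kernel code, and the nibble order is an assumption:

```python
# Conceptual illustration of weight-only FP4 dequantization (not Marlin's
# actual kernel code): each 4-bit E2M1 code maps to one of 16 values via a
# lookup table, so the GEMM itself runs on ordinary 16-bit floats.

E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # magnitudes, codes 0..7
LUT = E2M1 + [-v for v in E2M1]                   # bit 3 is the sign bit

def dequant_fp4(packed: bytes) -> list[float]:
    """Expand two 4-bit E2M1 codes per byte (low nibble first, assumed)."""
    out: list[float] = []
    for byte in packed:
        out.append(LUT[byte & 0xF])         # low nibble
        out.append(LUT[(byte >> 4) & 0xF])  # high nibble
    return out

print(dequant_fp4(bytes([0x31])))  # → [0.5, 1.5]
print(dequant_fp4(bytes([0x9C])))  # → [-2.0, -0.5]
```

In NVFP4 each decoded value would additionally be multiplied by a per-16-element FP8 block scale; that step is omitted here for brevity.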

Behavior on SM12x after this PR

| Scenario | Before | After |
|:--|:--|:--|
| Build includes ENABLE_NVFP4_SM120 + CUTLASS v4.4.2 | IllegalInstruction | Native CUTLASS backend selected, works correctly |
| Build lacks ENABLE_NVFP4_SM120 | IllegalInstruction (SM100 PTX JIT to SM120) | Native CUTLASS correctly reports unavailable; Marlin selected as fallback — works correctly |
| FlashInfer CUTLASS MoE on SM12x | "Failed to initialize cutlass TMA WS grouped gemm" (CUTLASS 4.2.1 in FlashInfer 0.6.6) | Works correctly with FlashInfer 0.6.7 (CUTLASS 4.4.2) |

Follow-up: FlashInfer 0.6.8

flashinfer-ai/flashinfer#2738 (merged March 28, 2026) adds native NVFP4 and MXFP4 group GEMM support for SM120/SM121 (RTX 50 / DGX Spark) directly in FlashInfer. This will land in FlashInfer 0.6.8. Once released, FLASHINFER_VERSION should be bumped in docker/Dockerfile, docker/Dockerfile.nightly_torch, and docker/versions.json to unlock FlashInfer's own SM12x NVFP4/MXFP4 kernels (including GDC unguarding and PDL group GEMM fixes). TODO comments have been added to both Dockerfiles tracking this.

Test plan

  • Build with CUDA_ARCHS="12.0a;12.1a" on DGX Spark (SM121), verify NVFP4 model serves with vLLM CUTLASS backend (VLLM_NVFP4_GEMM_BACKEND=cutlass --moe-backend=cutlass)
  • Verify FlashInfer CUTLASS MoE on SM12x no longer hits TMA init error
  • Build with CUDA_ARCHS="12.0a;12.1a", verify Marlin fallback still works (VLLM_NVFP4_GEMM_BACKEND=marlin --moe-backend=marlin)
  • Build with CUDA_ARCHS="10.0a" only, verify Marlin fallback on SM12x (no IllegalInstruction)
  • Verify SM100 (B200) still works with native CUTLASS (no regression from CUTLASS bump)
  • Verify SM89/SM90 still works (pre-Blackwell unaffected)
  • Run tests/models/quantization/test_nvfp4.py on SM120+
  • Docker build completes with FlashInfer 0.6.7 for both Dockerfile and Dockerfile.nightly_torch


@claude claude bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@johnnynunez johnnynunez changed the title fix NVFP4 DGX Spark and RTX50 [NVIDIA] Bugfix NVFP4 DGX Spark and RTX50 Mar 28, 2026
@mergify mergify bot added ci/build nvidia bug Something isn't working labels Mar 28, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request updates the CUTLASS revision to v4.4.2 and upgrades FlashInfer to version 0.6.7 across the Dockerfiles and requirement files. It also introduces runtime checks to verify that NVFP4 quantization kernels are compiled for the current GPU's SM version (SM100 or SM120) before use, preventing invalid backend selection or runtime failures. I have no feedback to provide.

Signed-off-by: johnnynunez <johnnynuca14@gmail.com>
Signed-off-by: johnnynunez <johnnynuca14@gmail.com>
@johnnynunez
Contributor Author

johnnynunez commented Mar 28, 2026

@johnnynunez
Contributor Author

Could a maintainer please add the ready label so CI can run? I have 3 merged PRs but need 4 to bypass the label requirement. Thank you!

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 28, 2026
@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed ready-run-all-tests Trigger CI with all tests for wide-ranging PRs labels Mar 28, 2026
@mergify

mergify bot commented Mar 28, 2026

Hi @johnnynunez, the pre-commit checks have failed. Please run:

uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

CUTLASS v4.4.2 added ArchTag to DispatchPolicy in
sm90_gemm_tma_warpspecialized_cooperative.hpp to distinguish SM90 from
SM120 kernel paths. Machete's custom MacheteCollectiveMma defines its
own DispatchPolicy but was missing this field, causing all 18 Machete
template instantiations to fail with "has no member ArchTag".
Also reformats nvfp4_scaled_mm_entry.cu to satisfy pre-commit linter.

Signed-off-by: johnnynunez <johnnynuca14@gmail.com>
@eugr

eugr commented Mar 29, 2026

Getting consistent Illegal Instruction crashes with this PR.

Building Flashinfer from main with FLASHINFER_CUDA_ARCH_LIST=12.1a
vLLM from main with this PR applied with TORCH_CUDA_ARCH_LIST=12.1a

Exception raised from currentStreamCaptureStatusMayInitCtx at /pytorch/c10/cuda/CUDAGraphsC10Utils.h:71 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xc8 (0xf152462b6778 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #1: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, unsigned int, bool) + 0x224 (0xf152463e4714 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0xf1d388 (0xf15246f8d388 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x477e40 (0xf15246297e40 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #4: c10::TensorImpl::~TensorImpl() + 0x14 (0xf15246256d84 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #5: <unknown function> + 0x5fa548 (0xf1526c7ea548 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0xb46d1c (0xf1526cd36d1c in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #7: VLLM::EngineCore() [0x524b64]
frame #8: _PyObject_ClearManagedDict + 0x10c (0x4fd240 in VLLM::EngineCore)
frame #9: VLLM::EngineCore() [0x527adc]
frame #10: VLLM::EngineCore() [0x5b3cac]
frame #11: VLLM::EngineCore() [0x5b2fec]
frame #12: VLLM::EngineCore() [0x58cf5c]
frame #13: _PyEval_EvalFrameDefault + 0x8fdc (0x56cf40 in VLLM::EngineCore)
frame #14: VLLM::EngineCore() [0x4c4d74]
frame #15: PyObject_CallMethodObjArgs + 0xa8 (0x4c6958 in VLLM::EngineCore)
frame #16: PyImport_ImportModuleLevelObject + 0x36c (0x58eaa0 in VLLM::EngineCore)
frame #17: _PyEval_EvalFrameDefault + 0x4cc0 (0x568c24 in VLLM::EngineCore)
frame #18: PyEval_EvalCode + 0x130 (0x562b54 in VLLM::EngineCore)
frame #19: VLLM::EngineCore() [0x55fd48]
frame #20: VLLM::EngineCore() [0x5045cc]
frame #21: _PyEval_EvalFrameDefault + 0x3e54 (0x567db8 in VLLM::EngineCore)
frame #22: VLLM::EngineCore() [0x4c4d74]
frame #23: PyObject_CallMethodObjArgs + 0xa8 (0x4c6958 in VLLM::EngineCore)
frame #24: PyImport_ImportModuleLevelObject + 0x36c (0x58eaa0 in VLLM::EngineCore)
frame #25: _PyEval_EvalFrameDefault + 0x4cc0 (0x568c24 in VLLM::EngineCore)
frame #26: PyEval_EvalCode + 0x130 (0x562b54 in VLLM::EngineCore)
frame #27: VLLM::EngineCore() [0x55fd48]
frame #28: VLLM::EngineCore() [0x5045cc]
frame #29: _PyEval_EvalFrameDefault + 0x3e54 (0x567db8 in VLLM::EngineCore)
frame #30: VLLM::EngineCore() [0x4c4d74]
frame #31: PyObject_CallMethodObjArgs + 0xa8 (0x4c6958 in VLLM::EngineCore)
frame #32: PyImport_ImportModuleLevelObject + 0x36c (0x58eaa0 in VLLM::EngineCore)
frame #33: _PyEval_EvalFrameDefault + 0x4cc0 (0x568c24 in VLLM::EngineCore)
frame #34: PyEval_EvalCode + 0x130 (0x562b54 in VLLM::EngineCore)
frame #35: VLLM::EngineCore() [0x55fd48]
frame #36: VLLM::EngineCore() [0x5045cc]
frame #37: _PyEval_EvalFrameDefault + 0x3e54 (0x567db8 in VLLM::EngineCore)
frame #38: VLLM::EngineCore() [0x4c4d74]
frame #39: PyObject_CallMethodObjArgs + 0xa8 (0x4c6958 in VLLM::EngineCore)
frame #40: PyImport_ImportModuleLevelObject + 0x36c (0x58eaa0 in VLLM::EngineCore)
frame #41: VLLM::EngineCore() [0x560238]
frame #42: VLLM::EngineCore() [0x5045cc]
frame #43: _PyEval_EvalFrameDefault + 0x3e54 (0x567db8 in VLLM::EngineCore)
frame #44: VLLM::EngineCore() [0x4c4d74]
frame #45: PyObject_CallMethodObjArgs + 0xa8 (0x4c6958 in VLLM::EngineCore)
frame #46: PyImport_ImportModuleLevelObject + 0x36c (0x58eaa0 in VLLM::EngineCore)
frame #47: _PyEval_EvalFrameDefault + 0x4cc0 (0x568c24 in VLLM::EngineCore)
frame #48: VLLM::EngineCore() [0x6c14a8]
frame #49: Py_FinalizeEx + 0x58 (0x67b088 in VLLM::EngineCore)
frame #50: Py_Exit + 0x18 (0x67c518 in VLLM::EngineCore)
frame #51: VLLM::EngineCore() [0x6811d0]
frame #52: VLLM::EngineCore() [0x680f04]
frame #53: PyRun_SimpleStringFlags + 0x7c (0x67ef1c in VLLM::EngineCore)
frame #54: Py_RunMain + 0x390 (0x68b690 in VLLM::EngineCore)
frame #55: Py_BytesMain + 0x28 (0x68b198 in VLLM::EngineCore)
frame #56: <unknown function> + 0x284c4 (0xf152f23d84c4 in /usr/lib/aarch64-linux-gnu/libc.so.6)
frame #57: __libc_start_main + 0x98 (0xf152f23d8598 in /usr/lib/aarch64-linux-gnu/libc.so.6)
frame #58: _start + 0x30 (0x5f66f0 in VLLM::EngineCore)

@johnnynunez
Contributor Author

> Getting consistent Illegal Instruction crashes with this PR. […]

Please check that you applied the PR and the CUTLASS version correctly.

Signed-off-by: johnnynunez <johnnynuca14@gmail.com>
@mergify

mergify bot commented Mar 29, 2026

Hi @johnnynunez, the pre-commit checks have failed (same notice as above).

Signed-off-by: johnnynunez <johnnynuca14@gmail.com>
@johnnynunez johnnynunez requested a review from WoosukKwon as a code owner March 29, 2026 02:47
@mergify

mergify bot commented Mar 29, 2026

Hi @johnnynunez, the pre-commit checks have failed (same notice as above).

@johnnynunez
Contributor Author

johnnynunez commented Mar 29, 2026

I've applied these fixes from the folks at NVIDIA:
#38188
#38215

@johnnynunez
Contributor Author

Related bug affecting some models:
NVIDIA/cutlass#3121
SM120/SM121 (DGX Spark, RTX 50) has only 99KB of SMEM vs 228KB on SM100. The K=128 block-scaled MoE GEMM tiles compile but overflow SMEM at runtime on SM120, and the K=64 tiles that would fit can't compile yet due to two unfixed CUTLASS bugs. So the real problem isn't instruction incompatibility; it's that the K=128 tiles overflow SM120's much smaller SMEM at runtime.
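A back-of-envelope calculation shows why this matters. The tile shape, stage count, and scale overhead below are assumed round numbers for illustration, not CUTLASS's actual kernel configurations:

```python
# Rough SMEM budget for a multi-stage block-scaled FP4 GEMM pipeline,
# illustrating why K=128 tiles that fit on SM100 (228 KB SMEM) can
# overflow SM120/SM121 (99 KB). All shapes and counts are assumptions.

def tile_smem_kb(m: int, n: int, k: int, stages: int,
                 fp4_bits: int = 4, scale_bytes_per_16: int = 1) -> float:
    """Approximate shared-memory footprint in KB."""
    a_bytes = m * k * fp4_bits // 8                       # A operand tile
    b_bytes = n * k * fp4_bits // 8                       # B operand tile
    scales = (m * k + n * k) // 16 * scale_bytes_per_16   # FP8 block scales
    return (a_bytes + b_bytes + scales) * stages / 1024

k128 = tile_smem_kb(128, 128, 128, stages=6)  # → 108.0 KB
k64 = tile_smem_kb(128, 128, 64, stages=6)    # → 54.0 KB
print(f"K=128: {k128:.0f} KB, K=64: {k64:.0f} KB")
```

Under these assumptions the K=128 tile needs ~108 KB, over SM12x's 99 KB budget yet comfortably inside SM100's 228 KB, while halving K roughly halves the footprint, matching the observed "compiles but overflows at runtime" behavior.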

@johnnynunez johnnynunez mentioned this pull request Mar 29, 2026
6 tasks
@gbanyan

gbanyan commented Mar 29, 2026

Test Report: PR #38423 on DGX Spark SM121 with Qwen3.5-122B-A10B-NVFP4

Hardware: Single DGX Spark GB10 (SM121, 128GB UMA)
Model: Sehyo/Qwen3.5-122B-A10B-NVFP4 (compressed-tensors)
vLLM: 0.18.1rc1 from main + this PR
FlashInfer: main branch (CUTLASS 4.4.2 bundled)

Results

| Config | 128K context | 32K context |
|:--|:--|:--|
| No MTP | PASS (15.3 tok/s, 1024 tokens) | PASS (16.2 tok/s) |
| MTP n=1 | FAIL (cudaErrorIllegalInstruction) | PASS (23.2 tok/s) |
| MTP n=3 | Not tested | PASS (32.7 tok/s) |

What this PR fixed

Previously (before this PR), 128K context crashed with cudaErrorIllegalInstruction even without MTP. Now 128K works without MTP — the CUTLASS 4.4.2 upgrade resolved that.

What still crashes

MTP + 128K context. Short MTP requests (64 tokens) succeed, but longer decode (1024 tokens) hits cudaErrorIllegalInstruction. This is likely the SM121 99KB SMEM limitation you noted — MTP's additional forward passes may select tile configurations that overflow SMEM.

32K context + MTP works perfectly at all token lengths.

@johnnynunez
Contributor Author

johnnynunez commented Mar 29, 2026

> Update: 128K without MTP is also intermittently unstable. It passed initial tests (1024 tokens completed) but later crashes with cudaErrorIllegalInstruction on subsequent requests. The SMEM overflow may trigger non-deterministically depending on prompt length or batch scheduling.
>
> 32K context (with or without MTP) remains stable across all our tests.

Thank you for all the tests. Yes, some users have found that; it has been reported to the CUTLASS team.

Signed-off-by: Johnny <johnnynuca14@gmail.com>
Fix fp8 trtllm gen routing bias dtype

Signed-off-by: Johnny <johnnynuca14@gmail.com>
@johnnynunez
Contributor Author

johnnynunez commented Mar 29, 2026

> Test Report: PR #38423 on DGX Spark SM121 with Qwen3.5-122B-A10B-NVFP4 […]

Could you try again? It is working for me now... you were right, it was a race condition.

Run vLLM on Thor & Spark

Step-by-step guide to building and running vLLM with FlashInfer on NVIDIA Thor (SM110) and Spark (SM121) platforms.

1. Install uv

Clear any stale cache, then install the uv package manager:

sudo rm -rf ~/.cache/
sudo apt install ccache
curl -LsSf https://astral.sh/uv/install.sh | sh

2. Create a Virtual Environment

sudo apt install python3-dev
uv venv .vllm --python 3.12
source .vllm/bin/activate

3. Install PyTorch

uv pip install --force-reinstall torch torchvision

4. Build and Install vLLM

Note: The build must include vllm-project/vllm#38423, so use the fork below.

git clone --recursive https://github.com/johnnynunez/vllm.git
cd vllm

export VLLM_VERSION=0.18.1
export TORCH_CUDA_ARCH_LIST=12.1a
export USE_CUDNN=1
export VERBOSE=1
export CUDA_HOME=/usr/local/cuda
export PATH="${CUDA_HOME}/bin:$PATH"
export SETUPTOOLS_SCM_PRETEND_VERSION="${VLLM_VERSION}"
export DG_JIT_USE_NVRTC=1  # DeepGEMM NVRTC support — up to 10x compilation speedup

python3 use_existing_torch.py || echo "Skipping use_existing_torch.py"

uv pip install -r requirements/build.txt -v
python3 -m setuptools_scm

# Constrain parallelism on aarch64 to avoid OOM during compilation
ARCH=$(uname -i)
if [ "${ARCH}" = "aarch64" ]; then
    export NVCC_THREADS=1
    export CUDA_NVCC_FLAGS="-Xcudafe --threads=1"
    export MAX_JOBS=2
    export MAKEFLAGS='-j2'
    export CMAKE_BUILD_PARALLEL_LEVEL=$MAX_JOBS
    export NINJAFLAGS='-j2'
fi

uv build --wheel --no-build-isolation -v --out-dir ./wheels .
uv pip install ./wheels/vllm*.whl

cd /opt/vllm
uv pip install compressed-tensors

5. Uninstall Pre-built FlashInfer Packages

Remove any pre-compiled FlashInfer packages to avoid conflicts with the editable install:

uv pip uninstall flashinfer-cubin flashinfer-python

6. Install FlashInfer from Source

Note: The build must include flashinfer-ai/flashinfer#2913, so use the fork below.

sudo rm -rf ~/.cache/
git clone --recursive https://github.com/johnnynunez/flashinfer.git
cd flashinfer
uv pip install --force-reinstall --no-build-isolation -e .

7. Export Environment Variables

Set the CUDA architecture target and related paths. Use 12.1a for Spark or 11.0a for Thor:

export TORCH_CUDA_ARCH_LIST=12.1a  # Spark: 12.1a — Thor: 11.0a
export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas
export CUDA_HOME=/usr/local/cuda
export CPATH=$CUDA_HOME/include:${CPATH}
export C_INCLUDE_PATH=$CUDA_HOME/include:${C_INCLUDE_PATH}
export CPLUS_INCLUDE_PATH=$CUDA_HOME/include:${CPLUS_INCLUDE_PATH}

# Recommended on Jetson platforms
export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$CUDA_HOME/lib:${LD_LIBRARY_PATH}
export LIBRARY_PATH=$CUDA_HOME/lib64:$CUDA_HOME/lib:${LIBRARY_PATH}

8. Clear Memory

Drop filesystem caches to free up memory before serving:

sudo sysctl -w vm.drop_caches=3

9. Serve the Model (Speculative Decoding with MTP)

Launch vLLM with Qwen3.5-122B using 3 speculative tokens via MTP:

vllm serve Sehyo/Qwen3.5-122B-A10B-NVFP4 \
    --port 9000 \
    --max-num-seqs 2 \
    --tensor-parallel-size 1 \
    --max-model-len 131072 \
    --trust-remote-code \
    --gpu-memory-utilization 0.80 \
    --kv-cache-dtype fp8 \
    --speculative_config '{"method":"mtp","num_speculative_tokens":3}'

10. Run a Stress Test (Separate Terminal)

In another terminal with the .vllm environment activated, run the following script to send 10 long-context requests:

python3 -c "
import requests, time, sys, concurrent.futures

MODEL = 'Sehyo/Qwen3.5-122B-A10B-NVFP4'
PORT = 9000

# ~100K tokens — safely under 131072 - 1024 = 130048 limit
parts = []
for i in range(3000):
    parts.append(f'Section {i}: The quick brown fox jumps over the lazy dog. Technology advances rapidly in quantum computing and distributed systems. ')
prompt = 'Write a comprehensive analysis: ' + ' '.join(parts)
print(f'Approx words: {len(prompt.split())}')
sys.stdout.flush()

def send_request(idx):
    t0 = time.time()
    try:
        r = requests.post(f'http://localhost:{PORT}/v1/completions', json={
            'model': MODEL,
            'prompt': prompt,
            'max_tokens': 1024,
            'temperature': 0.7,
        }, timeout=600)
        elapsed = time.time() - t0
        if r.status_code == 200:
            data = r.json()
            text = data['choices'][0]['text']
            usage = data.get('usage', {})
            return f'[{idx}] OK - {len(text)}ch, prompt={usage.get(\"prompt_tokens\",\"?\")}, gen={usage.get(\"completion_tokens\",\"?\")}, {elapsed:.1f}s'
        else:
            err = r.json().get('error',{}).get('message','')[:200]
            return f'[{idx}] FAIL ({r.status_code}): {err}'
    except Exception as e:
        elapsed = time.time() - t0
        return f'[{idx}] CRASH - {type(e).__name__}: {e} ({elapsed:.1f}s)'

# Phase 1: Single ~100K token request
print('=== Phase 1: Single ~100K token request ===')
sys.stdout.flush()
print(send_request(1)); sys.stdout.flush()

# Phase 2: 2 concurrent
print('=== Phase 2: 2 concurrent ===')
sys.stdout.flush()
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
    futs = [pool.submit(send_request, i) for i in range(2, 4)]
    for f in concurrent.futures.as_completed(futs):
        print(f.result()); sys.stdout.flush()

# Phase 3: 10 rapid
print('=== Phase 3: 10 rapid sequential ===')
sys.stdout.flush()
for i in range(4, 14):
    r = send_request(i)
    print(r); sys.stdout.flush()
    if 'CRASH' in r: break

print('Done.')
" 2>&1

@johnnynunez
Contributor Author

johnnynunez commented Mar 30, 2026

ready to merge! @mgoin

Now it is working perfectly and B200 accuracy tests passed for NVFP4

Nemotron Super NVFP4 - DGX Spark

export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
vllm serve nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 \
--kv-cache-dtype fp8 \
--trust-remote-code \
--gpu-memory-utilization 0.7 \
--max-model-len 262144 \
--max-num-seqs 10 \
--enable-prefix-caching \
--host 0.0.0.0 \
--port 8000 \
--enable-auto-tool-choice \
--load-format fastsafetensors \
--tool-call-parser qwen3_coder \
--reasoning-parser nemotron_v3 \
--mamba_ssm_cache_dtype float32

Results (Benchmark & Stress Test) --> uvx llama-benchy --base-url http://spark:8000/v1 --depth 0 4096 8192 16384 32768 65535 100000 200000:

Auto-detected HF model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 (served as: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4)
llama-benchy (0.3.5)
Date: 2026-03-30 01:35:34
Benchmarking model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 at http://localhost:8000/v1
Concurrency levels: [1]
Loading text from cache: /home/johnny/.cache/llama-benchy/cc6a0b5782734ee3b9069aa3b64cc62c.txt
Total tokens available in text corpus: 143827
Warming up...
Warmup (User only) complete. Delta: 16 tokens (Server: 38, Local: 22)
Warmup (System+Empty) complete. Delta: 16 tokens (Server: 38, Local: 22)

Running coherence test...
Coherence test PASSED.
Measuring latency using mode: api...
Average latency (api): 1.63 ms
Running test: pp=2048, tg=32, depth=0, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=4096, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=8192, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=16384, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=32768, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=65535, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=100000, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=200000, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Printing results in MD format:



| model                                          |             test |              t/s |     peak t/s |          ttfr (ms) |       est_ppt (ms) |      e2e_ttft (ms) |
|:-----------------------------------------------|-----------------:|-----------------:|-------------:|-------------------:|-------------------:|-------------------:|
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |           pp2048 | 1722.48 ± 394.11 |              |   1269.76 ± 345.98 |   1268.14 ± 345.98 |   1269.84 ± 345.98 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |             tg32 |     12.76 ± 0.01 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |   pp2048 @ d4096 |  1948.06 ± 80.28 |              |   3161.05 ± 134.07 |   3159.43 ± 134.07 |   3161.13 ± 134.05 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |     tg32 @ d4096 |     12.75 ± 0.01 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |   pp2048 @ d8192 |   1964.84 ± 4.14 |              |    5213.28 ± 10.99 |    5211.65 ± 10.99 |    5213.35 ± 10.97 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |     tg32 @ d8192 |     12.71 ± 0.01 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |  pp2048 @ d16384 |   1934.31 ± 5.53 |              |    9530.67 ± 27.20 |    9529.04 ± 27.20 |    9530.74 ± 27.22 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |    tg32 @ d16384 |     12.64 ± 0.01 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |  pp2048 @ d32768 |  1857.07 ± 14.17 |              |  18750.32 ± 143.56 |  18748.69 ± 143.56 |  18750.39 ± 143.57 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |    tg32 @ d32768 |     12.64 ± 0.02 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |  pp2048 @ d65535 |   1759.29 ± 5.89 |              |  38416.91 ± 128.78 |  38415.28 ± 128.78 |  38416.98 ± 128.78 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |    tg32 @ d65535 |     12.64 ± 0.04 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | pp2048 @ d100000 |   1656.44 ± 4.33 |              |  61608.98 ± 160.90 |  61607.35 ± 160.90 |  61609.06 ± 160.91 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |   tg32 @ d100000 |     12.69 ± 0.08 | 13.67 ± 0.47 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | pp2048 @ d200000 |   1397.08 ± 7.47 |              | 144626.89 ± 771.10 | 144625.26 ± 771.10 | 144626.94 ± 771.11 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |   tg32 @ d200000 |     12.59 ± 0.12 | 14.00 ± 0.00 |                    |                    |                    |

llama-benchy (0.3.5)
date: 2026-03-30 01:35:34 | latency mode: api
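The ± figures in the table are the per-test spread across the three runs. A minimal sketch of that aggregation (assuming mean ± population standard deviation, which may differ from llama-benchy's exact method; `summarize` is an illustrative name, not from the tool):

```python
import statistics

def summarize(runs):
    # Collapse per-run throughput (t/s) into the "mean ± std" form used
    # in the table above. Population vs. sample std is an assumption here.
    mean = statistics.fmean(runs)
    std = statistics.pstdev(runs)
    return f"{mean:.2f} ± {std:.2f}"

print(summarize([1722.0, 1500.0, 1944.0]))
```

Each `pp`/`tg`/`depth` cell in the table would be one such call over its three "Run 1/3 … Run 3/3" measurements.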
(APIServer pid=33932) INFO 03-30 01:50:49 [loggers.py:259] Engine 000: Avg prompt throughput: 20205.7 tokens/s, Avg generation throughput: 3.2 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
(APIServer pid=33932) INFO 03-30 01:50:59 [loggers.py:259] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%

Comment on lines +402 to +405
# Currently FI requires bfloat16 routing bias.
# https://github.com/flashinfer-ai/flashinfer/issues/2909
if e_score_correction_bias is not None:
    e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)
Member
@pavanimajety do you know if this is right? I thought we fixed this issue for trtllm MoE across the board

Contributor Author
The fix was done by @wzhao18.

Comment on lines +315 to +319
# Currently FI requires bfloat16 routing bias.
# https://github.com/flashinfer-ai/flashinfer/issues/2909
if e_score_correction_bias is not None:
    e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)

Member
Ditto

Contributor Author
same @wzhao18
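For context, the quoted snippets in both threads apply the same cast. A standalone sketch of that workaround (the helper name and the test tensors are illustrative, not from the PR; the real code runs inline before the FlashInfer kernel call):

```python
import torch

def prepare_routing_bias(e_score_correction_bias):
    # FlashInfer's CUTLASS MoE path currently expects a bfloat16 routing
    # bias (flashinfer-ai/flashinfer#2909), so cast any other dtype here.
    # A None bias is passed through untouched.
    if e_score_correction_bias is not None:
        e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)
    return e_score_correction_bias

bias = prepare_routing_bias(torch.ones(8, dtype=torch.float32))
print(bias.dtype)  # torch.bfloat16
```

`Tensor.to` is a no-op when the tensor is already bfloat16, so the cast is safe to apply unconditionally on the non-None path.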

Signed-off-by: Johnny <johnnynuca14@gmail.com>
@johnnynunez johnnynunez requested a review from noooop as a code owner March 30, 2026 02:20

Labels

- bug: Something isn't working
- ci/build
- nvidia
- ready: ONLY add when PR is ready to merge/full CI is needed
- ready-run-all-tests: Trigger CI with all tests for wide-ranging PRs

Projects

Status: Ready
