
[NVIDIA] fix(jit): enable GDC for CUTLASS fused MoE PDL — prevent random crashes on SM12x#2913

Open
johnnynunez wants to merge 2 commits into flashinfer-ai:main from johnnynunez:main

Conversation

@johnnynunez (Contributor) commented Mar 29, 2026

Summary

  • Add missing -DCUTLASS_ENABLE_GDC_FOR_SM100=1 compile flag to all CUTLASS fused MoE JIT modules (SM100/SM103/SM120) and -DCUTLASS_ENABLE_GDC_FOR_SM90=1 to SM90 modules
  • Sync nv_internal grid_dependency_control.h with upstream CUTLASS to support SM100/SM103/SM110/SM120/SM121 GDC
  • Add -DCUTLASS_ENABLE_GDC_FOR_SM90=1 to FP8 blockscale GEMM SM90 module

Problem

Random cudaErrorIllegalInstruction crashes on DGX Spark (SM121) and RTX 50-series (SM120) when running NVFP4 MoE models (e.g., Nemotron, Qwen3.5-122B) under load. The crashes are intermittent and worsen with longer context lengths and higher concurrency.

Root cause: PR #2780 fixed the missing GDC compile flags for GEMM modules (flashinfer/jit/gemm/core.py), but the CUTLASS fused MoE modules in flashinfer/jit/fused_moe.py and the FP8 blockscale GEMM module were not fixed. This is the exact same class of bug as #2708.

Without -DCUTLASS_ENABLE_GDC_FOR_SM100=1, CUTLASS's grid_dependency_control.h compiles wait_on_dependent_grids() and launch_dependent_grids() as empty no-ops:

CUTLASS_DEVICE void wait_on_dependent_grids() {
#if (defined(CUTLASS_GDC_ENABLED))   // ← not defined without the flag
  asm volatile("griddepcontrol.wait;");
#endif
}

Meanwhile, the host-side code still sets programmaticStreamSerializationAllowed = true (PDL enabled) via device_support_pdl(), which returns True for every device with compute capability major >= 9, including SM12x. This means:

  1. Host enables PDL → CUDA runtime overlaps consecutive kernels
  2. Device GDC barriers are no-ops → No synchronization between overlapping kernels
  3. Race condition → Dependent kernel reads stale global memory → corruption → cudaErrorIllegalInstruction

The crash is random because it depends on exact kernel scheduling timing, which varies per request.
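The host/device mismatch above can be sketched in a few lines of Python. This is an illustrative model only: `device_support_pdl` is named after the helper mentioned in this PR, and `gdc_barrier_compiled` is a hypothetical stand-in for the preprocessor gate; the real flashinfer/CUTLASS logic lives in C++ and may differ in detail.

```python
# Sketch of the bug, assuming the behavior described in this PR.
# `device_support_pdl` and `gdc_barrier_compiled` are illustrative
# names, not the actual flashinfer API.

def device_support_pdl(major: int, minor: int) -> bool:
    # Host side: PDL is enabled for every compute capability >= 9.0,
    # which includes Hopper (SM90) and all Blackwell parts (SM12x).
    return major >= 9

def gdc_barrier_compiled(cuda_arch: int, sm100_flag_set: bool) -> bool:
    # Device side: without -DCUTLASS_ENABLE_GDC_FOR_SM100=1 the
    # griddepcontrol barriers compile to empty no-ops.
    blackwell = {1000, 1030, 1100, 1200, 1210}
    return sm100_flag_set and cuda_arch in blackwell

# SM121 (DGX Spark) before this fix: PDL on, barriers missing -> race.
pdl_enabled = device_support_pdl(12, 1)            # True
barriers_real = gdc_barrier_compiled(1210, False)  # False
assert pdl_enabled and not barriers_real  # the unsafe combination
```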

Fix

flashinfer/jit/fused_moe.py — Added GDC flags to all CUTLASS fused MoE modules:

| Module | Flag | Architectures Covered |
|:---|:---|:---|
| gen_cutlass_fused_moe_sm120_module() | -DCUTLASS_ENABLE_GDC_FOR_SM100=1 | SM120, SM121 |
| gen_cutlass_fused_moe_sm103_module() | -DCUTLASS_ENABLE_GDC_FOR_SM100=1 | SM103, SM120, SM121 |
| gen_cutlass_fused_moe_sm100_module() | -DCUTLASS_ENABLE_GDC_FOR_SM100=1 | SM100, SM110, SM120, SM121 |
| gen_cutlass_fused_moe_sm90_module() | -DCUTLASS_ENABLE_GDC_FOR_SM90=1 | SM90 |
| gen_trtllm_gen_fused_moe_sm100_module() | -DCUTLASS_ENABLE_GDC_FOR_SM100=1 | SM100+, SM120, SM121 |

flashinfer/jit/gemm/fp8_blockscale.py — Added -DCUTLASS_ENABLE_GDC_FOR_SM90=1 to gen_fp8_blockscale_gemm_sm90_module().
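The shape of the fix can be sketched as follows. This is illustrative only, not the actual flashinfer source: the real gen_*_module() functions in flashinfer/jit/fused_moe.py construct their flag lists differently, and `GDC_FLAG` and `nvcc_flags_for` are hypothetical names.

```python
# Illustrative sketch of the fix: each JIT module generator appends the
# GDC define matching its architecture family to its NVCC flag list.
# (Hypothetical helper; the actual flashinfer code is structured differently.)

GDC_FLAG = {
    "sm90": "-DCUTLASS_ENABLE_GDC_FOR_SM90=1",    # Hopper
    "sm100": "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",  # whole Blackwell family
    "sm103": "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
    "sm120": "-DCUTLASS_ENABLE_GDC_FOR_SM100=1",
}

def nvcc_flags_for(family: str, base_flags: list[str]) -> list[str]:
    # Before this PR the fused MoE modules effectively returned only
    # base_flags, so device-side GDC barriers compiled to no-ops while
    # the host still enabled PDL.
    return [*base_flags, GDC_FLAG[family]]

flags = nvcc_flags_for("sm120", ["-O3"])
# -> ["-O3", "-DCUTLASS_ENABLE_GDC_FOR_SM100=1"]
```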

csrc/nv_internal/.../grid_dependency_control.h — Synced with upstream CUTLASS (3rdparty/cutlass/include/cutlass/arch/grid_dependency_control.h) to add SM100+ GDC support. Previously only handled SM90, so any nv_internal TensorRT-LLM code compiled for SM12x would have GDC barriers silently compiled as no-ops.

Why -DCUTLASS_ENABLE_GDC_FOR_SM100=1 covers SM12x

CUTLASS uses a single flag for the entire Blackwell family. From grid_dependency_control.h:

#if(CUDA_BARRIER_ENABLED && defined(CUTLASS_ENABLE_GDC_FOR_SM100) && defined(__CUDA_ARCH__) && \
    ((__CUDA_ARCH__ == 1000 && ...) ||   // SM100
     (__CUDA_ARCH__ == 1030 && ...) ||   // SM103
     (__CUDA_ARCH__ == 1100 && ...) ||   // SM110
     (__CUDA_ARCH__ == 1200 && ...) ||   // SM120 (RTX 50-series)
     (__CUDA_ARCH__ == 1210 && ...)))    // SM121 (DGX Spark)
#define CUTLASS_GDC_ENABLED
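The preprocessor gate above can be mirrored in Python to make the coverage explicit. This is a sketch: the real macro also ANDs per-architecture feature conditions (elided as `...` in the excerpt), which this model omits.

```python
# Python mirror of the CUTLASS_GDC_ENABLED gate shown above. A sketch:
# the real macro also checks per-arch conditions elided as "..." in the
# excerpt. The point is that one flag covers the whole Blackwell family.

BLACKWELL_FAMILY = {
    1000,  # SM100
    1030,  # SM103
    1100,  # SM110
    1200,  # SM120 (RTX 50-series)
    1210,  # SM121 (DGX Spark)
}

def cutlass_gdc_enabled(cuda_arch: int,
                        sm100_flag: bool,
                        barrier_enabled: bool = True) -> bool:
    return barrier_enabled and sm100_flag and cuda_arch in BLACKWELL_FAMILY

# SM120/SM121 are covered by the single SM100 flag:
assert cutlass_gdc_enabled(1200, sm100_flag=True)
assert cutlass_gdc_enabled(1210, sm100_flag=True)
# Hopper (arch 900) is gated by the separate SM90 flag, not this one:
assert not cutlass_gdc_enabled(900, sm100_flag=True)
```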

Why SM90 GDC flag was NOT added to SM100+ modules

PR #2716 attempted to add both -DCUTLASS_ENABLE_GDC_FOR_SM90=1 and -DCUTLASS_ENABLE_GDC_FOR_SM100=1 to all modules. It broke AOT builds because sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp checks CUTLASS_ENABLE_GDC_FOR_SM90 and calls scheduler.is_last_tile() — a method not present on the SM120 scheduler. PR #2780 corrected this by using only the SM100 flag for SM100+ modules. This PR follows the same approach.

Related

Test plan

  • Clear JIT cache: rm -rf ~/.cache/flashinfer/
  • Run NVFP4 MoE model on SM121 (DGX Spark) with 128K context under load — verify no cudaErrorIllegalInstruction
  • Run NVFP4 MoE model on SM120 (RTX 50-series) with concurrent requests — verify no NaN/garbage output
  • Verify CUDA_LAUNCH_BLOCKING=1 workaround is no longer needed
  • AOT build with FLASHINFER_CUDA_ARCH_LIST="12.1a" completes without errors
  • SM90 (Hopper) fused MoE tests pass: pytest tests/moe/
  • SM100 GEMM tests still pass (no regression from existing GDC flags)

Summary by CodeRabbit

  • New Features
    • Expanded GPU kernel compilation support: enabled additional optimizations for NVIDIA SM100 and SM90 GPUs, activating dependency-control optimizations where available.
    • Updated JIT/GEMM build configs to include these architecture-specific compile options, improving performance and compatibility on supported hardware.


@coderabbitai (bot) commented Mar 29, 2026

📝 Walkthrough

Extended CUTLASS Grid Dependency Control (GDC) compile-time enablement to cover additional SM100-family CUDA architectures and added corresponding NVCC defines to JIT build pipelines for fused MoE and FP8 blockscale kernels.

Changes

| Cohort / File(s) | Summary |
|:---|:---|
| **CUTLASS GDC Header Extension**<br>csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/grid_dependency_control.h | Expanded preprocessor checks to enable CUTLASS_GDC_ENABLED when CUDA_BARRIER_ENABLED and CUTLASS_ENABLE_GDC_FOR_SM100 are present for the specified SM100-family __CUDA_ARCH__ values; activates inline barrier emission and IsGdcGloballyEnabled. |
| **Fused MoE JIT Build Configuration**<br>flashinfer/jit/fused_moe.py | Added -DCUTLASS_ENABLE_GDC_FOR_SM100=1 to SM100/SM103/SM120 JIT NVCC flags and -DCUTLASS_ENABLE_GDC_FOR_SM90=1 to the SM90 variant; also added the GDC define to TRT-LLM SM100 module generation. |
| **FP8 Blockscale JIT Build Configuration**<br>flashinfer/jit/gemm/fp8_blockscale.py | Appended -DCUTLASS_ENABLE_GDC_FOR_SM90=1 to SM90 FP8 blockscale GEMM JIT NVCC flags (no API/signature changes). |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related issues

  • Issue #2708: Adds/enables CUTLASS GDC compile flags that trigger grid_dependency_control.h to emit device barriers — directly related to these changes.

Possibly related PRs

  • PR #2716: Also adjusts compile-time GDC enablement for SM90/SM100 across header and build flags — strong overlap.
  • PR #2798: Touches CUTLASS GDC flag propagation and SM target enablement — closely related.
  • PR #2780: Adds CUTLASS GDC compile flags for SM100 and ties into header gating — directly connected.

Suggested labels

run-ci, op: moe

Suggested reviewers

  • jimmyzho
  • kahyunnam
  • yzh119
  • nv-yunzheq
  • bkryu
  • cyx-6
  • jiahanc

Poem

🐰 I hopped through code with tiny feet,
Flags set true so barriers meet,
SM1xx hums in careful tune,
Dependencies land—in time, in June. ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|:---|:---|:---|:---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 12.50%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|:---|:---|:---|
| Title check | ✅ Passed | The title clearly summarizes the main change: enabling GDC for CUTLASS fused MoE to fix random crashes on SM12x architectures. |
| Description check | ✅ Passed | The PR description is comprehensive and well-structured, covering problem statement, root cause analysis, specific fixes, rationale, and test plan. |



@coderabbitai (bot) left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/grid_dependency_control.h`:
- Around line 36-54: The new multiline preprocessor block for enabling
CUTLASS_GDC (the `#if` that checks CUDA_BARRIER_ENABLED,
CUTLASS_ENABLE_GDC_FOR_SM100 and various __CUDA_ARCH__ cases including
CUDA_ARCH_FAMILY and CUDA_ARCH_CONDITIONAL_OR_FAMILY) is misformatted and
failing clang-format; run clang-format on
cutlass_extensions/arch/grid_dependency_control.h (or the changed file) and
reformat the `#if/`#endif block so line wrapping and indentation follow the
project's clang-format rules, then commit the formatted file ensuring the
symbols CUTLASS_GDC_ENABLED, CUDA_BARRIER_ENABLED, CUTLASS_ENABLE_GDC_FOR_SM100,
__CUDA_ARCH__, CUDA_ARCH_FAMILY and CUDA_ARCH_CONDITIONAL_OR_FAMILY remain
unchanged semantically.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2e539431-67bf-4479-a371-4d64c698e324

📥 Commits

Reviewing files that changed from the base of the PR and between 779c24d and 0f260d6.

📒 Files selected for processing (3)
  • csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/grid_dependency_control.h
  • flashinfer/jit/fused_moe.py
  • flashinfer/jit/gemm/fp8_blockscale.py

@johnnynunez (Contributor, Author) commented Mar 30, 2026

Steps to replicate:

Run vLLM on Thor & Spark

Step-by-step guide to building and running vLLM with FlashInfer on NVIDIA Thor (SM110) and Spark (SM121) platforms.

1. Install uv

Clear any stale cache, then install the uv package manager:

sudo rm -rf ~/.cache/
sudo apt install ccache
curl -LsSf https://astral.sh/uv/install.sh | sh

2. Create a Virtual Environment

sudo apt install python3-dev
uv venv .vllm --python 3.12
source .vllm/bin/activate

3. Install PyTorch

uv pip install --force-reinstall torch torchvision

4. Build and Install vLLM

Note: The build must include vllm-project/vllm#38423, so use the fork below.

git clone --recursive https://github.com/johnnynunez/vllm.git
cd vllm

export VLLM_VERSION=0.18.1
export TORCH_CUDA_ARCH_LIST=12.1a
export USE_CUDNN=1
export VERBOSE=1
export CUDA_HOME=/usr/local/cuda
export PATH="${CUDA_HOME}/bin:$PATH"
export SETUPTOOLS_SCM_PRETEND_VERSION="${VLLM_VERSION}"
export DG_JIT_USE_NVRTC=1  # DeepGEMM NVRTC support — up to 10x compilation speedup

python3 use_existing_torch.py || echo "Skipping use_existing_torch.py"

uv pip install -r requirements/build.txt -v
python3 -m setuptools_scm

# Constrain parallelism on aarch64 to avoid OOM during compilation
ARCH=$(uname -i)
if [ "${ARCH}" = "aarch64" ]; then
    export NVCC_THREADS=1
    export CUDA_NVCC_FLAGS="-Xcudafe --threads=1"
    export MAKEFLAGS='-j2'
    export CMAKE_BUILD_PARALLEL_LEVEL=2  # match -j2 above ($MAX_JOBS was never set)
    export NINJAFLAGS='-j2'
fi

uv build --wheel --no-build-isolation -v --out-dir ./wheels .
uv pip install ./wheels/vllm*.whl

cd /opt/vllm
uv pip install compressed-tensors

5. Uninstall Pre-built FlashInfer Packages

Remove any pre-compiled FlashInfer packages to avoid conflicts with the editable install:

uv pip uninstall flashinfer-cubin flashinfer-python

6. Install FlashInfer from Source

Note: The build must include flashinfer-ai/flashinfer#2913, so use the fork below.

sudo rm -rf ~/.cache/
git clone --recursive https://github.com/johnnynunez/flashinfer.git
cd flashinfer
uv pip install --force-reinstall --no-build-isolation -e .

7. Export Environment Variables

Set the CUDA architecture target and related paths. Use 12.1a for Spark or 11.0a for Thor:

export TORCH_CUDA_ARCH_LIST=12.1a  # Spark: 12.1a — Thor: 11.0a
export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas
export CUDA_HOME=/usr/local/cuda
export CPATH=$CUDA_HOME/include:${CPATH}
export C_INCLUDE_PATH=$CUDA_HOME/include:${C_INCLUDE_PATH}
export CPLUS_INCLUDE_PATH=$CUDA_HOME/include:${CPLUS_INCLUDE_PATH}

# Recommended on Jetson platforms
export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$CUDA_HOME/lib:${LD_LIBRARY_PATH}
export LIBRARY_PATH=$CUDA_HOME/lib64:$CUDA_HOME/lib:${LIBRARY_PATH}

8. Clear Memory

Drop filesystem caches to free up memory before serving:

sudo sysctl -w vm.drop_caches=3

9. Serve the Model (Speculative Decoding with MTP)

Launch vLLM with Qwen3.5-122B using 3 speculative tokens via MTP:

vllm serve Sehyo/Qwen3.5-122B-A10B-NVFP4 \
    --port 9000 \
    --max-num-seqs 2 \
    --tensor-parallel-size 1 \
    --max-model-len 131072 \
    --trust-remote-code \
    --gpu-memory-utilization 0.80 \
    --kv-cache-dtype fp8 \
    --speculative_config '{"method":"mtp","num_speculative_tokens":3}'

10. Run a Stress Test (Separate Terminal)

In another terminal with the .vllm environment activated, run the following script to send 10 long-context requests:

python3 -c "
import requests, time, sys, concurrent.futures

MODEL = 'Sehyo/Qwen3.5-122B-A10B-NVFP4'
PORT = 9000

# ~100K tokens — safely under 131072 - 1024 = 130048 limit
parts = []
for i in range(3000):
    parts.append(f'Section {i}: The quick brown fox jumps over the lazy dog. Technology advances rapidly in quantum computing and distributed systems. ')
prompt = 'Write a comprehensive analysis: ' + ' '.join(parts)
print(f'Approx words: {len(prompt.split())}')
sys.stdout.flush()

def send_request(idx):
    t0 = time.time()
    try:
        r = requests.post(f'http://localhost:{PORT}/v1/completions', json={
            'model': MODEL,
            'prompt': prompt,
            'max_tokens': 1024,
            'temperature': 0.7,
        }, timeout=600)
        elapsed = time.time() - t0
        if r.status_code == 200:
            data = r.json()
            text = data['choices'][0]['text']
            usage = data.get('usage', {})
            return f'[{idx}] OK - {len(text)}ch, prompt={usage.get(\"prompt_tokens\",\"?\")}, gen={usage.get(\"completion_tokens\",\"?\")}, {elapsed:.1f}s'
        else:
            err = r.json().get('error',{}).get('message','')[:200]
            return f'[{idx}] FAIL ({r.status_code}): {err}'
    except Exception as e:
        elapsed = time.time() - t0
        return f'[{idx}] CRASH - {type(e).__name__}: {e} ({elapsed:.1f}s)'

# Phase 1: Single ~100K token request
print('=== Phase 1: Single ~100K token request ===')
sys.stdout.flush()
print(send_request(1)); sys.stdout.flush()

# Phase 2: 2 concurrent
print('=== Phase 2: 2 concurrent ===')
sys.stdout.flush()
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
    futs = [pool.submit(send_request, i) for i in range(2, 4)]
    for f in concurrent.futures.as_completed(futs):
        print(f.result()); sys.stdout.flush()

# Phase 3: 10 rapid
print('=== Phase 3: 10 rapid sequential ===')
sys.stdout.flush()
for i in range(4, 14):
    r = send_request(i)
    print(r); sys.stdout.flush()
    if 'CRASH' in r: break

print('Done.')
" 2>&1

@johnnynunez (Contributor, Author) commented Mar 30, 2026

Now it is working perfectly and B200 accuracy tests passed for NVFP4.
Related vLLM: vllm-project/vllm#38423

Nemotron Super NVFP4 - DGX Spark

export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
vllm serve nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 \
--kv-cache-dtype fp8 \
--trust-remote-code \
--gpu-memory-utilization 0.7 \
--max-model-len 262144 \
--max-num-seqs 10 \
--enable-prefix-caching \
--host 0.0.0.0 \
--port 8000 \
--enable-auto-tool-choice \
--load-format fastsafetensors \
--tool-call-parser qwen3_coder \
--reasoning-parser nemotron_v3 \
--mamba_ssm_cache_dtype float32

Results (benchmark & stress test) from uvx llama-benchy --base-url http://spark:8000/v1 --depth 0 4096 8192 16384 32768 65535 100000 200000:

Auto-detected HF model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 (served as: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4)
llama-benchy (0.3.5)
Date: 2026-03-30 01:35:34
Benchmarking model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 at http://localhost:8000/v1
Concurrency levels: [1]
Loading text from cache: /home/johnny/.cache/llama-benchy/cc6a0b5782734ee3b9069aa3b64cc62c.txt
Total tokens available in text corpus: 143827
Warming up...
Warmup (User only) complete. Delta: 16 tokens (Server: 38, Local: 22)
Warmup (System+Empty) complete. Delta: 16 tokens (Server: 38, Local: 22)

Running coherence test...
Coherence test PASSED.
Measuring latency using mode: api...
Average latency (api): 1.63 ms
Running test: pp=2048, tg=32, depth=0, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=4096, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=8192, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=16384, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=32768, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=65535, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=100000, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Running test: pp=2048, tg=32, depth=200000, concurrency=1
  Run 1/3 (batch size 1)...
  Run 2/3 (batch size 1)...
  Run 3/3 (batch size 1)...
Printing results in MD format:



| model                                          |             test |              t/s |     peak t/s |          ttfr (ms) |       est_ppt (ms) |      e2e_ttft (ms) |
|:-----------------------------------------------|-----------------:|-----------------:|-------------:|-------------------:|-------------------:|-------------------:|
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |           pp2048 | 1722.48 ± 394.11 |              |   1269.76 ± 345.98 |   1268.14 ± 345.98 |   1269.84 ± 345.98 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |             tg32 |     12.76 ± 0.01 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |   pp2048 @ d4096 |  1948.06 ± 80.28 |              |   3161.05 ± 134.07 |   3159.43 ± 134.07 |   3161.13 ± 134.05 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |     tg32 @ d4096 |     12.75 ± 0.01 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |   pp2048 @ d8192 |   1964.84 ± 4.14 |              |    5213.28 ± 10.99 |    5211.65 ± 10.99 |    5213.35 ± 10.97 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |     tg32 @ d8192 |     12.71 ± 0.01 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |  pp2048 @ d16384 |   1934.31 ± 5.53 |              |    9530.67 ± 27.20 |    9529.04 ± 27.20 |    9530.74 ± 27.22 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |    tg32 @ d16384 |     12.64 ± 0.01 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |  pp2048 @ d32768 |  1857.07 ± 14.17 |              |  18750.32 ± 143.56 |  18748.69 ± 143.56 |  18750.39 ± 143.57 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |    tg32 @ d32768 |     12.64 ± 0.02 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |  pp2048 @ d65535 |   1759.29 ± 5.89 |              |  38416.91 ± 128.78 |  38415.28 ± 128.78 |  38416.98 ± 128.78 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |    tg32 @ d65535 |     12.64 ± 0.04 | 13.00 ± 0.00 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | pp2048 @ d100000 |   1656.44 ± 4.33 |              |  61608.98 ± 160.90 |  61607.35 ± 160.90 |  61609.06 ± 160.91 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |   tg32 @ d100000 |     12.69 ± 0.08 | 13.67 ± 0.47 |                    |                    |                    |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | pp2048 @ d200000 |   1397.08 ± 7.47 |              | 144626.89 ± 771.10 | 144625.26 ± 771.10 | 144626.94 ± 771.11 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 |   tg32 @ d200000 |     12.59 ± 0.12 | 14.00 ± 0.00 |                    |                    |                    |

llama-benchy (0.3.5)
date: 2026-03-30 01:35:34 | latency mode: api
(APIServer pid=33932) INFO 03-30 01:50:49 [loggers.py:259] Engine 000: Avg prompt throughput: 20205.7 tokens/s, Avg generation throughput: 3.2 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
(APIServer pid=33932) INFO 03-30 01:50:59 [loggers.py:259] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%

@johnnynunez (Contributor, Author) commented

cc @aleozlx @yzh119

@aleozlx (Collaborator) commented Mar 30, 2026

/bot run

@flashinfer-bot (Collaborator) commented
GitLab MR !473 has been created, and the CI pipeline #47268416 is currently running. I'll report back once the pipeline job completes.

@eugr commented Mar 30, 2026

Looks like this PR eliminates the NVFP4 crashes with the flashinfer_cutlass kernel. I built from main with this PR applied on top.

@johnnynunez - I'm getting slightly better numbers than you:

| model | test | t/s | peak t/s | ttfr (ms) | est_ppt (ms) | e2e_ttft (ms) |
|:---|---:|---:|---:|---:|---:|---:|
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | pp2048 | 2205.39 ± 13.67 | | 934.82 ± 5.78 | 928.67 ± 5.78 | 934.99 ± 5.79 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | tg32 | 14.42 ± 0.07 | 15.00 ± 0.00 | | | |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | pp2048 @ d4096 | 2195.56 ± 7.67 | | 2804.56 ± 9.76 | 2798.40 ± 9.76 | 2804.75 ± 9.75 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | tg32 @ d4096 | 14.36 ± 0.01 | 15.00 ± 0.00 | | | |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | pp2048 @ d8192 | 2159.83 ± 19.23 | | 4747.63 ± 42.48 | 4741.48 ± 42.48 | 4747.77 ± 42.55 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | tg32 @ d8192 | 14.47 ± 0.11 | 15.00 ± 0.00 | | | |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | pp2048 @ d16384 | 2122.38 ± 5.93 | | 8690.80 ± 24.31 | 8684.65 ± 24.31 | 8690.93 ± 24.28 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | tg32 @ d16384 | 14.50 ± 0.19 | 15.33 ± 0.47 | | | |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | pp2048 @ d32078 | 2003.24 ± 7.67 | | 17041.63 ± 65.53 | 17035.48 ± 65.53 | 17041.90 ± 65.76 |
| nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 | tg32 @ d32078 | 14.33 ± 0.04 | 15.00 ± 0.00 | | | |

llama-benchy (0.3.5)
date: 2026-03-30 11:29:12 | latency mode: api | pp basis: ttfr
