[None][feat] FlashInfer NVFP4 MoE backend (SM120/SM121) for Nemotron …#13773
[None][feat] FlashInfer NVFP4 MoE backend (SM120/SM121) for Nemotron …#13773farazkh80 wants to merge 12 commits into
Conversation
📝 WalkthroughWalkthroughThis PR adds a new MoE backend backend option ( ChangesFlashInfer MoE Backend Addition
Sequence DiagramsequenceDiagram
participant User
participant MoeFactory as MoE Factory<br/>(create_moe)
participant FI as FlashInferFusedMoE
participant B12x as B12xMoEWrapper<br/>(FlashInfer)
participant Weights as Weight Storage
User->>MoeFactory: get_moe_cls(backend="FLASHINFER", quant_config)
MoeFactory->>MoeFactory: Validate NVFP4 & SM versions
MoeFactory-->>User: Return FlashInferFusedMoE class
User->>FI: __init__(model_config)
FI->>FI: Validate ep_size==1, no alltoall
FI-->>User: Instance ready
User->>Weights: Load model weights
User->>FI: post_load_weights()
FI->>FI: Import B12xMoEWrapper
FI->>FI: Convert scales to MMA layout
FI->>FI: Build FP4 uint8 weight views
FI->>B12x: B12xMoEWrapper(experts, weights)
B12x-->>FI: Wrapper instantiated
FI-->>User: Initialization complete
loop Forward Pass
User->>FI: quantize_input(x)
FI-->>User: (x, None) — passthrough
User->>FI: run_moe(x, routing_tensors)
FI->>B12x: run(x, token_selected_experts, ...)
B12x->>B12x: Compute FP4 MoE with internal quant
B12x-->>FI: Output tensor
FI-->>User: MoE result
end
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Tip 💬 Introducing Slack Agent: The best way for teams to turn conversations into code.Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.
Built for teams:
One agent for your entire SDLC. Right inside Slack. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/modules/fused_moe/create_moe.py (1)
216-238:⚠️ Potential issue | 🔴 CriticalThe
issubclass()dispatch at line 216 makes theCuteDslFusedMoEandDeepGemmFusedMoEbranches unreachable and will cause runtime errors.Both
CuteDslFusedMoEandDeepGemmFusedMoEinherit fromCutlassFusedMoE, so they are now captured by the broaderissubclass(moe_cls, CutlassFusedMoE)check before their exact-class branches (lines 269 and 285) can execute. The generic Cutlass path then attempts to pass arguments likeswiglu_alpha,swiglu_beta, andswiglu_limitthat these subclasses' narrower constructors do not accept, resulting in unexpected keyword argument errors. OnlyFlashInferFusedMoEis compatible because its__init__(self, *args, **kwargs)forwards all arguments.Either revert to exact-class checks (
moe_cls ==) for all three subclasses, or verify that the constructors accept the full argument set being passed.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/fused_moe/create_moe.py` around lines 216 - 238, The dispatch using issubclass(moe_cls, CutlassFusedMoE) wrongly catches CuteDslFusedMoE and DeepGemmFusedMoE (both subclassing CutlassFusedMoE) and passes unsupported kwargs (swiglu_alpha, swiglu_beta, swiglu_limit), causing runtime errors; fix by making the dispatch check exact-class comparisons for CuteDslFusedMoE and DeepGemmFusedMoE (i.e. moe_cls == CuteDslFusedMoE and moe_cls == DeepGemmFusedMoE) or move those subclass-specific branches before the generic issubclass(CutlassFusedMoE) branch so their narrower constructors run, ensuring only FlashInferFusedMoE (which accepts arbitrary kwargs) is handled by the broad issubclass(CutlassFusedMoE) path.
🧹 Nitpick comments (1)
tests/unittest/_torch/modules/moe/test_flashinfer_moe_backend.py (1)
34-131: ⚡ Quick winCover the remaining no-GPU guard rails too.
This file exercises selection-time gating, but the backend's other hard rejects are still untested:
ep_size != 1, alltoall enabled,Fp4QuantizedTensorinput, andx_sf is not None. Those are pure-Python validation paths, so adding them here would catch the runtime guard rails most likely to regress during refactors.As per coding guidelines, "Coverage expectations: Assess whether new/changed tests cover happy path, important edge cases, and failure modes relevant to the feature or fix."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/modules/moe/test_flashinfer_moe_backend.py` around lines 34 - 131, Add unit tests in tests/unittest/_torch/modules/moe/test_flashinfer_moe_backend.py that exercise the remaining pure-Python guard rails on selection/creation: call FlashInferFusedMoE.can_implement (and get_moe_cls where appropriate) with ep_size != 1, with alltoall enabled, with an input type simulated as Fp4QuantizedTensor, and with x_sf set (not None) to assert they return (or raise) the expected hard rejects; reference FlashInferFusedMoE.can_implement and get_moe_cls to locate validation logic, patch get_sm_version to a supported SM (e.g., 120) so only these specific guards are hit, and assert the returned ok is False and/or get_moe_cls raises ValueError matching the relevant reason text for each case.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md`:
- Around line 145-163: Add the runtime dependency floor for the FlashInfer
backend to this section: state the required minimum versions of the external
packages (e.g., flashinfer and the CUTLASS/DSL package) and any toolchain
constraints so users know the exact package combo needed before selecting
FLASHINFER; place this note alongside the "FlashInferFusedMoE — additional
constraints" paragraph and reference the selection point
(get_moe_cls("FLASHINFER", ...)) and the lazy wrapper initialization in
post_load_weights() so developers see the dependency requirement when reading
the backend constraints.
In `@tests/unittest/_torch/modules/moe/test_flashinfer_moe_backend.py`:
- Around line 15-21: Add integration test definitions that exercise the
FlashInferFusedMoE backend on SM120/SM121 hardware: create a perf test entry
named with the l0_* pattern (so it runs in pre-merge CI) that sets
moe_backend=FLASHINFER and targets the SM120/SM121 GPU variant, and add
corresponding scheduled QA entries following the llm_perf_* naming convention
that also specify moe_backend=FLASHINFER/FlashInferFusedMoE and the same GPU
targets; ensure the new YAML entries include the job name, test selector,
hardware requirements, and any perf thresholds so the backend is executed in
both pre-merge CI and scheduled QA runs.
---
Outside diff comments:
In `@tensorrt_llm/_torch/modules/fused_moe/create_moe.py`:
- Around line 216-238: The dispatch using issubclass(moe_cls, CutlassFusedMoE)
wrongly catches CuteDslFusedMoE and DeepGemmFusedMoE (both subclassing
CutlassFusedMoE) and passes unsupported kwargs (swiglu_alpha, swiglu_beta,
swiglu_limit), causing runtime errors; fix by making the dispatch check
exact-class comparisons for CuteDslFusedMoE and DeepGemmFusedMoE (i.e. moe_cls
== CuteDslFusedMoE and moe_cls == DeepGemmFusedMoE) or move those
subclass-specific branches before the generic issubclass(CutlassFusedMoE) branch
so their narrower constructors run, ensuring only FlashInferFusedMoE (which
accepts arbitrary kwargs) is handled by the broad issubclass(CutlassFusedMoE)
path.
---
Nitpick comments:
In `@tests/unittest/_torch/modules/moe/test_flashinfer_moe_backend.py`:
- Around line 34-131: Add unit tests in
tests/unittest/_torch/modules/moe/test_flashinfer_moe_backend.py that exercise
the remaining pure-Python guard rails on selection/creation: call
FlashInferFusedMoE.can_implement (and get_moe_cls where appropriate) with
ep_size != 1, with alltoall enabled, with an input type simulated as
Fp4QuantizedTensor, and with x_sf set (not None) to assert they return (or
raise) the expected hard rejects; reference FlashInferFusedMoE.can_implement and
get_moe_cls to locate validation logic, patch get_sm_version to a supported SM
(e.g., 120) so only these specific guards are hit, and assert the returned ok is
False and/or get_moe_cls raises ValueError matching the relevant reason text for
each case.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 8dc4a17c-d10b-410b-8630-d20e1b303ed2
📒 Files selected for processing (6)
tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.mdtensorrt_llm/_torch/modules/fused_moe/__init__.pytensorrt_llm/_torch/modules/fused_moe/create_moe.pytensorrt_llm/_torch/modules/fused_moe/fused_moe_flashinfer.pytensorrt_llm/llmapi/llm_args.pytests/unittest/_torch/modules/moe/test_flashinfer_moe_backend.py
… diff, helper scripts Captures the investigation findings + reproducible artifacts for the ``B12xLukeFusedMoE`` backend committed in 374b483. Lives under .claude_docs/ rather than docs/ since the artifacts are working-files (bench logs, container scripts, PR body) rather than user-facing docs. Files: - ``B12X_LUKE_RESULTS.md``: bench numbers vs the FI hybrid baseline (TPOT 12.97 ms vs FI 11.49 ms = +12.8% regression) + token-parity result + success-criteria table + root-cause writeup + follow-up probes. - ``FI_VS_LUKE_DELTA.md``: side-by-side architectural comparison of flashinfer-vendored b12x and lukealonso/b12x master HEAD. Includes bisect data (986a405a / c9cc90ec / 1378cea7 all ~13.5 ms TPOT, gap to FI predates published luke history), trace-probe evidence (luke's MoEMicroKernel does fire; the slowdown is intrinsic), and the kernel- source diff that locates the root cause: FI uses Blackwell's warp- specialized producer/consumer pattern (5 warps/CTA: 4 MMA + 1 dedicated TMA-load with cute.arch.setmaxregister_increase/_decrease register repartitioning), luke uses a flat 16-warp design with no producer/ consumer split. Concludes that closing the gap requires either an upstream rewrite or hand-porting FI's kernel (~12k lines of CuTe DSL cascade — out of scope here). - ``PR_BODY_b12x_luke.md``: PR description used when opening the stacked PR against ``faraz/b12x-flashinfer-moe-pr`` on the farazkh80/TensorRT-LLM fork. - ``start_runtime_container_b12x_luke.sh``: docker run helper that installs flashinfer + lukealonso/b12x @ 1378cea7 + cutlass-dsl 4.4.2 trio + cache_dit (rc14 dep absent in rc12 base) + LD_LIBRARY_PATH fix-up so docker exec inherits libnvonnxparser. - ``sync_b12x_luke_files.sh``: syncs edited fused_moe submodule files from the host source tree into the container's wheel-installed site-packages (the rc12 base image's tensorrt_llm has new imports like cache_dit that block PYTHONPATH overlay; targeted file copy is safer). - ``bench_kvoff_b12x_luke.yml``: bench yaml clone of bench_kvoff_flashinfer.yml with moe_config.backend swapped to B12X_LUKE. - ``parity_check_b12x_luke.py``: token-parity script with --moe-backend flag for FLASHINFER vs B12X_LUKE A/B (skipped this run, kept for future use). - ``_patch_tp_moe_trace.py``: idempotent patch that injects [trace-luke] prints into b12x.integration.tp_moe._launch_compact_static, used to prove luke's micro path actually fires (not a fall-through bug). Not needed at runtime; kept for reproducibility. The bench logs themselves and parent-PR (NVIDIA#13773) artifacts (HYBRID_DOC.md, HYBRID_RESULTS.md, etc.) are intentionally NOT committed: bench logs live under /home/farazkh_scratch/logs/ and parent-PR docs belong on the parent branch. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
…ide Draft PR The fork-side PR_BODY_b12x_luke.md targets the GitHub UI on farazkh80/TensorRT-LLM and assumes a stacked base of `b12x-hybrid`. The NVIDIA-side PR (filed as Draft against NVIDIA:main) carries the 3 NVIDIA#13773 commits as overlap, so its body needs: - A prominent DRAFT-blocked-on-NVIDIA#13773 header at the top. - A condensed framing that names the perf regression upfront. - Same bench data + warp-spec swap evidence + recommendation. Open URL: https://github.com/NVIDIA/TensorRT-LLM/compare/main...farazkh80:b12x-luke-decode?expand=1 Use the "Create draft pull request" dropdown. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
b4c7031 to
1570171
Compare
|
Disclosure: I work on Atlas. We've been running NVFP4 MoE on Two things that bit us on FlashInfer + NVFP4 on consumer Blackwell that may or may not already be on your radar: 1. E2M1 conversion PTX is Hopper/SM10-only. FlashInfer's CUTLASS headers gate
About 30 lines. Disables the hardware E2M1 path for SM121 specifically and falls back to the software conversion. Roughly 32x speedup over the broken-PTX path on a 35B baseline (1.1 to 35 tok/s for us) so worth catching on the SM120 side too. 2. NVFP4 MoE backend dispatch bug. Upstream FlashInfer's
If FlashInfer's gating already handles this on SM120 cleanly in your branch, ignore. If not, this is the call site that was lying to callers. For end-to-end NVFP4 numbers on
Glad to see Nemotron NVFP4 landing for SM120/121. Happy to dig into either of the patches above if you want to confirm whether your backend wiring handles them differently. |
1 similar comment
|
Disclosure: I work on Atlas. We've been running NVFP4 MoE on Two things that bit us on FlashInfer + NVFP4 on consumer Blackwell that may or may not already be on your radar: 1. E2M1 conversion PTX is Hopper/SM10-only. FlashInfer's CUTLASS headers gate
About 30 lines. Disables the hardware E2M1 path for SM121 specifically and falls back to the software conversion. Roughly 32x speedup over the broken-PTX path on a 35B baseline (1.1 to 35 tok/s for us) so worth catching on the SM120 side too. 2. NVFP4 MoE backend dispatch bug. Upstream FlashInfer's
If FlashInfer's gating already handles this on SM120 cleanly in your branch, ignore. If not, this is the call site that was lying to callers. For end-to-end NVFP4 numbers on
Glad to see Nemotron NVFP4 landing for SM120/121. Happy to dig into either of the patches above if you want to confirm whether your backend wiring handles them differently. |
1570171 to
11dcf4a
Compare
The b12x MoE kernel introduced by PR NVIDIA#13773 (FLASHINFER_NVFP4SM12X) JIT-compiles via nvidia-cutlass-dsl, whose CUDA 13 runtime libraries ship as a separate optional wheel (nvidia-cutlass-dsl-libs-cu13) and are NOT pulled automatically by the main nvidia-cutlass-dsl wheel. Without this wheel, executor initialization on SM120/SM121 hosts dies with ptxas "Unexpected instruction types specified for '_mma'" because the chip->compute_target conversion falls back to a path that strips the 'a' suffix (sm_120a -> sm_120), and ptxas is then invoked with -opt-arch=sm_120 against PTX that has .target sm_120a with sm_120a-only mma instruction forms. The runtime requirement was documented in the PR body but never made binding via requirements.txt. Pin it explicitly at the same version as the main wheel so fresh builds reproduce the same working environment. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
8f589ea to
bbd21f2
Compare
The b12x MoE kernel introduced by PR NVIDIA#13773 (FLASHINFER_NVFP4SM12X) JIT-compiles via nvidia-cutlass-dsl, whose CUDA 13 runtime libraries ship as a separate optional wheel (nvidia-cutlass-dsl-libs-cu13) and are NOT pulled automatically by the main nvidia-cutlass-dsl wheel. Without this wheel, executor initialization on SM120/SM121 hosts dies with ptxas "Unexpected instruction types specified for '_mma'" because the chip->compute_target conversion falls back to a path that strips the 'a' suffix (sm_120a -> sm_120), and ptxas is then invoked with -opt-arch=sm_120 against PTX that has .target sm_120a with sm_120a-only mma instruction forms. The runtime requirement was documented in the PR body but never made binding via requirements.txt. Pin it explicitly at the same version as the main wheel so fresh builds reproduce the same working environment. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
|
/bot run --disable-fail-fast |
|
PR_Github #48025 [ run ] triggered by Bot. Commit: |
| } | ||
|
|
||
|
|
||
| class FlashInferNvfp4Sm12xFusedMoE(CutlassFusedMoE): |
There was a problem hiding this comment.
Could you help to explain why we have to introduce a new backend? Do we have a another way to do this, something like heuristic selection?
What's the relation between FlashInfer and CutlassFusedMoE?
May I know the motivation here to inherit from CutlassFusedMoE?
It is highly recommended to do a refactor with the skills https://gitlab-master.nvidia.com/ftp/trtllm-agent-toolkit/-/merge_requests/41#143127cc66c0ca5221561b39ccb27e7766f76cf3.
There was a problem hiding this comment.
A following question: do we have any chance to add more quantization in the FalshInfer path, if yes, then I would like to remove the quantization type and platform info from the FlashInferNvfp4Sm12xFusedMoE.
There was a problem hiding this comment.
Thanks for the comment @xxi-nv, and for the moe-develop skills. I used them and made some design changes.
On FI precision support: as of flashinfer-python 0.6.8/0.6.10, B12xMoEWrapper is NVFP4-only and SM120/SM121-only. Because of that, B12x is no longer user-facing: moe_backend: CUTLASS auto-promotes to FlashInferNvfp4Sm12xFusedMoE on SM120/121 + NVFP4 + importable FlashInfer, otherwise it falls back to CutlassFusedMoE, following the existing can_implement-gated promotion pattern (similar to MEGAMOE_DEEPGEMM-style can_implement)
I also removed FLASHINFER_NVFP4SM12X from MoeConfig.backend, so there is no user-facing FLASHINFER_* config value anymore. The class name is internal only, so I’d prefer to defer renaming until we have a second FlashInfer-backed MoE path. FlashInferNvfp4Sm12xFusedMoE still subclasses CutlassFusedMoE because it reuses the NVFP4 weight-loading lifecycle and only overrides quantize_input, run_moe, and post_load_weights; we also keep CutlassFusedMoE for prefill since it performs better.
There was a problem hiding this comment.
Thanks for the refactoring.
May I confirm that the kernel implementation of FLASHINFER_NVFP4SM12X is also Cutlass C++?
There was a problem hiding this comment.
it is actually in cuteDSL. but for prefill we use cutlass so I thought we could put it here.
Cutlass is the suggested MoE backend in sm120/121. so by adding this special handling, we essentially let the user get better output tp/s without changing their config.
There was a problem hiding this comment.
Wow, if we want to make it consistent with the current design of backend. I will recommend to integrate the newly added FLASHINFER_NVFP4SM12X to CuteDSLFusedMoE, because they both implement the kernel by cuteDSL.
And you can update the heuristic in 'AUTO' mode in MOE, but the AUTO mode in MOE is not so ready yet.
Please check https://github.com/xxi-nv/TensorRT-LLM/blob/bfcf0f31a52335622341c820270d0695bac1feaf/tensorrt_llm/_torch/model_config.py#L270-L293.
But I have to say, it is kind of complicated and not a good scalable way to add so many kernels in a easy way.
I need to take some time to rethink the definition of backends. @litaotju Any suggestions?
…to-promote Replace the user-facing `moe_backend: FLASHINFER_NVFP4SM12X` knob with transparent heuristic auto-promotion on the `CUTLASS` path. When the user selects `moe_backend: CUTLASS` (the default), `get_moe_cls()` now returns `FlashInferNvfp4Sm12xFusedMoE` automatically when: - quant_config has NVFP4 - SM version is 120 or 121 - `import flashinfer` succeeds Otherwise it returns `CutlassFusedMoE` (the pre-PR behaviour). The class itself, its weight lifecycle, and its hybrid `m >= 64` decode dispatch are unchanged — only the selection plumbing moves. This responds to xxi-nv's review comment on PR NVIDIA#13773 asking whether the b12x backend could be selected via a heuristic rather than an explicit name. Mirrors the existing `MEGAMOE_DEEPGEMM` pattern of `can_implement`-gated promotion with a CUTLASS fallback. Drops `"FLASHINFER_NVFP4SM12X"` from `MoeConfig.backend` Literal — the class stays importable as an internal API for tests and for direct construction, but is no longer a valid user-facing config string. Tests in `test_flashinfer_nvfp4_sm12x_moe_backend.py` flipped from "explicit name raises on bad config" to "heuristic auto-promotes vs falls back to CutlassFusedMoE". Internal `MoeBackendType` entry kept so `test_moe_backend.py` parametrization continues to cover the backend; `create_test_backend` routes the enum through `moe_backend="CUTLASS"` to exercise the same code path users hit. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
…MOE guide test_moe_module.py: register MoeBackendType.FLASHINFER_NVFP4SM12X in BACKEND_TYPES so the unified ConfigurableMoE matrix exercises it. _create_model_config maps the internal enum value to moe_backend="CUTLASS" before passing into ModelConfig — the enum is internal-only after the heuristic auto-promote landed; users reach the backend via the CUTLASS path. MOE_DEVELOPER_GUIDE.md: remove the dedicated FlashInferNvfp4Sm12xFusedMoE section (composition / dispatch policy / weight-conversion algebra / hard-reject list) and drop the Nvfp4Sm12x matrix column. The class's NVFP4 support on SM120/121 is already covered by the CUTLASS row in the matrix (auto-promote target). Only the single inventory-table entry under "Backends" remains, pointing at the backend file for anyone who wants the details. Both changes respond to xxi-nv's review comments on PR NVIDIA#13773 asking that test_moe_module.py / test_moe_backend.py cover the new backend and that the MoE guide stay high-level. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
…to-promote Replace the user-facing `moe_backend: FLASHINFER_NVFP4SM12X` knob with transparent heuristic auto-promotion on the `CUTLASS` path. When the user selects `moe_backend: CUTLASS` (the default), `get_moe_cls()` now returns `FlashInferNvfp4Sm12xFusedMoE` automatically when: - quant_config has NVFP4 - SM version is 120 or 121 - `import flashinfer` succeeds Otherwise it returns `CutlassFusedMoE` (the pre-PR behaviour). The class itself, its weight lifecycle, and its hybrid `m >= 64` decode dispatch are unchanged — only the selection plumbing moves. This responds to xxi-nv's review comment on PR NVIDIA#13773 asking whether the b12x backend could be selected via a heuristic rather than an explicit name. Mirrors the existing `MEGAMOE_DEEPGEMM` pattern of `can_implement`-gated promotion with a CUTLASS fallback. Drops `"FLASHINFER_NVFP4SM12X"` from `MoeConfig.backend` Literal — the class stays importable as an internal API for tests and for direct construction, but is no longer a valid user-facing config string. Tests in `test_flashinfer_nvfp4_sm12x_moe_backend.py` flipped from "explicit name raises on bad config" to "heuristic auto-promotes vs falls back to CutlassFusedMoE". Internal `MoeBackendType` entry kept so `test_moe_backend.py` parametrization continues to cover the backend; `create_test_backend` routes the enum through `moe_backend="CUTLASS"` to exercise the same code path users hit. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
…MOE guide test_moe_module.py: register MoeBackendType.FLASHINFER_NVFP4SM12X in BACKEND_TYPES so the unified ConfigurableMoE matrix exercises it. _create_model_config maps the internal enum value to moe_backend="CUTLASS" before passing into ModelConfig — the enum is internal-only after the heuristic auto-promote landed; users reach the backend via the CUTLASS path. MOE_DEVELOPER_GUIDE.md: remove the dedicated FlashInferNvfp4Sm12xFusedMoE section (composition / dispatch policy / weight-conversion algebra / hard-reject list) and drop the Nvfp4Sm12x matrix column. The class's NVFP4 support on SM120/121 is already covered by the CUTLASS row in the matrix (auto-promote target). Only the single inventory-table entry under "Backends" remains, pointing at the backend file for anyone who wants the details. Both changes respond to xxi-nv's review comments on PR NVIDIA#13773 asking that test_moe_module.py / test_moe_backend.py cover the new backend and that the MoE guide stay high-level. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
df2234a to
cfc5d8b
Compare
…brid CUTLASS-prefill / b12x-decode)
Adds the FLASHINFER_NVFP4SM12X MoE backend, selectable via
moe_config.backend: FLASHINFER_NVFP4SM12X. Targets Nemotron-Super-120B-NVFP4
on SM120 (RTX PRO 6000 / GB202) and SM121 (DGX Spark / GB10).
Composition (see MOE_DEVELOPER_GUIDE.md for the full explainer):
- Prefill (m >= 64) routes through the inherited CutlassFusedMoE NVFP4
GroupGEMM. The b12x kernel's 12-CTA-per-token MMA pattern is suboptimal
at large m.
- Decode (m < 64) dispatches to FlashInfer's B12xMoEWrapper.run, a kernel
purpose-built for m=1 / small routed-row counts.
NVFP4 weights are loaded once via the inherited NVFP4 quant method;
post_load_weights then prepares the b12x-shaped weight tensors alongside
the existing CUTLASS layout (un-normalize FP8 block scales, apply
convert_sf_to_mma_layout, prep w*_alpha for b12x's dual-use convention).
Both layouts coexist; the dispatcher picks per call based on x.shape[0].
CUDA graph capture only covers decode in TRT-LLM, so captured graphs
always replay the b12x path; eager prefill always runs CUTLASS — no
graph-capture conflict.
Hard-rejects EP, MoE alltoall, Fp4QuantizedTensor input on the decode
path, swiglu_gptoss_style biased SwiGLU, and activations outside
{Relu2, Swiglu}. Misconfigured selection raises at get_moe_cls time
rather than silently falling back to CUTLASS.
Replaces the prior FLASHINFER backend identifier (which exposed only the
pure-FlashInfer / b12x path with a +48.6% TTFT regression at prefill).
The hybrid composition eliminates that regression and beats CUTLASS by
+21.7% throughput / -17.6% TPOT on Nemotron-Super-120B-NVFP4 at conc=1.
Bench numbers and full investigation in
.claude_docs/nemo-fp4-moe-b12x-mr/HYBRID_RESULTS.md.
Tests: 23 unit tests in
tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py
(19 negative-path can_implement / get_moe_cls tests + 4 hybrid dispatch
shape-predicate tests). All pass on container (no GPU required).
Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
The b12x MoE kernel introduced by PR NVIDIA#13773 (FLASHINFER_NVFP4SM12X) JIT-compiles via nvidia-cutlass-dsl, whose CUDA 13 runtime libraries ship as a separate optional wheel (nvidia-cutlass-dsl-libs-cu13) and are NOT pulled automatically by the main nvidia-cutlass-dsl wheel. Without this wheel, executor initialization on SM120/SM121 hosts dies with ptxas "Unexpected instruction types specified for '_mma'" because the chip->compute_target conversion falls back to a path that strips the 'a' suffix (sm_120a -> sm_120), and ptxas is then invoked with -opt-arch=sm_120 against PTX that has .target sm_120a with sm_120a-only mma instruction forms. The runtime requirement was documented in the PR body but never made binding via requirements.txt. Pin it explicitly at the same version as the main wheel so fresh builds reproduce the same working environment. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
Adds test_flashinfer_nvfp4_sm12x_moe_backend.py (23 tests covering can_implement gating, get_moe_cls error paths, and the hybrid CUTLASS-prefill / b12x-decode dispatch predicate) to the pre-merge 1-GPU PyTorch section of the SM120 (RTX PRO 6000) test list. No GPU required, ~12s run time, so it stays in the pre-merge tier. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
The `elif issubclass(moe_cls, CutlassFusedMoE)` dispatch added with this backend captured `CuteDslFusedMoE` and `DeepGemmFusedMoE` (both subclass `CutlassFusedMoE`) before their dedicated branches at L326 / L343 could match, leaking `swiglu_alpha` / `swiglu_beta` / `swiglu_limit` into constructors that don't accept them. Restrict the branch to an explicit allowlist (`CutlassFusedMoE`, `FlashInferNvfp4Sm12xFusedMoE`) so the existing exact-class branches keep working. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
…framework Per review feedback, register the new backend in the shared MoE test framework so it participates in the same parametrized harness as the other backends: - Add `FLASHINFER_NVFP4SM12X` to `MoeBackendType` enum and `get_backend_class()` map. - Add `should_skip_flashinfer_nvfp4_sm12x()` covering the EP / alltoall hard rejects from `__init__` (SM, quant, dtype and gptoss are already handled by `can_implement()`). - Exclude the backend from `supports_autotuner_capture()` (b12x decode does not go through the autotuner). - Hook the new helper into `get_quick_skip_reason()` skip chain. - Add the new enum value to `test_moe_backend.py::BACKEND_TYPES_TO_TEST` so the existing parametrization picks it up on SM120/SM121 + NVFP4 configs and skips elsewhere via `can_implement()`. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
…to-promote Replace the user-facing `moe_backend: FLASHINFER_NVFP4SM12X` knob with transparent heuristic auto-promotion on the `CUTLASS` path. When the user selects `moe_backend: CUTLASS` (the default), `get_moe_cls()` now returns `FlashInferNvfp4Sm12xFusedMoE` automatically when: - quant_config has NVFP4 - SM version is 120 or 121 - `import flashinfer` succeeds Otherwise it returns `CutlassFusedMoE` (the pre-PR behaviour). The class itself, its weight lifecycle, and its hybrid `m >= 64` decode dispatch are unchanged — only the selection plumbing moves. This responds to xxi-nv's review comment on PR NVIDIA#13773 asking whether the b12x backend could be selected via a heuristic rather than an explicit name. Mirrors the existing `MEGAMOE_DEEPGEMM` pattern of `can_implement`-gated promotion with a CUTLASS fallback. Drops `"FLASHINFER_NVFP4SM12X"` from `MoeConfig.backend` Literal — the class stays importable as an internal API for tests and for direct construction, but is no longer a valid user-facing config string. Tests in `test_flashinfer_nvfp4_sm12x_moe_backend.py` flipped from "explicit name raises on bad config" to "heuristic auto-promotes vs falls back to CutlassFusedMoE". Internal `MoeBackendType` entry kept so `test_moe_backend.py` parametrization continues to cover the backend; `create_test_backend` routes the enum through `moe_backend="CUTLASS"` to exercise the same code path users hit. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
…MOE guide test_moe_module.py: register MoeBackendType.FLASHINFER_NVFP4SM12X in BACKEND_TYPES so the unified ConfigurableMoE matrix exercises it. _create_model_config maps the internal enum value to moe_backend="CUTLASS" before passing into ModelConfig — the enum is internal-only after the heuristic auto-promote landed; users reach the backend via the CUTLASS path. MOE_DEVELOPER_GUIDE.md: remove the dedicated FlashInferNvfp4Sm12xFusedMoE section (composition / dispatch policy / weight-conversion algebra / hard-reject list) and drop the Nvfp4Sm12x matrix column. The class's NVFP4 support on SM120/121 is already covered by the CUTLASS row in the matrix (auto-promote target). Only the single inventory-table entry under "Backends" remains, pointing at the backend file for anyone who wants the details. Both changes respond to xxi-nv's review comments on PR NVIDIA#13773 asking that test_moe_module.py / test_moe_backend.py cover the new backend and that the MoE guide stay high-level. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
cfc5d8b to
60a2634
Compare
| blobfile | ||
| openai-harmony==0.0.4 | ||
| nvidia-cutlass-dsl==4.5.0; python_version >= "3.10" | ||
| nvidia-cutlass-dsl-libs-cu13==4.5.0; python_version >= "3.10" |
There was a problem hiding this comment.
CUDA 13 runtime libs for cutlass-dsl; not pulled automatically by the main wheel, but required on SM120/SM121 to avoid a ptxas "Unexpected Instruction types specified for '_mma'" ICE during b12x kernel JIT.
Pre-commit hooks flagged 3 cosmetic formatting tweaks (collapse multi-line ternary/f-string/blank-line) in the MoE test files added/edited earlier in this PR. No behaviour change. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com>
xxi-nv
left a comment
There was a problem hiding this comment.
Thanks for the refactoring. It looks much better. However, there are still some tasks that could make it more consistent with our design.
Approved and please address the comment.
| if sm_version in FlashInferNvfp4Sm12xFusedMoE._SUPPORTED_SM_VERSIONS: | ||
| try: | ||
| import flashinfer # noqa: F401 | ||
| logger.info( |
There was a problem hiding this comment.
May I know whether the Prefill and Decode has different platforms?
If Prefill and Decode run both on SM120, then the heuristic here will not works, right?
There was a problem hiding this comment.
It looks that you are using _route_to_cutlass as the runtime switch.
| return isinstance(x, torch.Tensor) and x.shape[0] >= self._PREFILL_VIA_CUTLASS_THRESHOLD | ||
|
|
||
| def post_load_weights(self): | ||
| """Build the b12x weight dict and instantiate ``B12xMoEWrapper``. |
There was a problem hiding this comment.
It seems that the skills work partially.
It is suggested to apply the post_load_weights to a dedicated quantization_method.
The principle here is to leave all the implementation details into the quantization_method.
E.G.
https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/modules/fused_moe/quantization.py#L877
Results doc for the b12x W4A16 dense GEMM work split across v1 (Triton, production), v2 (CuTe-DSL bit-exact 5-step build-up), and v3.1 (CuTe-DSL gmem widening). * v1 Triton: bit-exact across Nano35 dense shapes, ~7x behind TRT-LLM. * v2 CuTe-DSL: 5 commits (5373882..ec85630) for smem alloc, gmem->smem copy, FP4 dequant, full-K-loop MMA, smem-staged epilogue. Bit-exact on every M=16 Nano35 shape (cos=1.0). * v3.1: 128-bit gmem copies via Uint32 recast (dd7ead9). Bit-exact; side win is JIT compile time dropping from >3 min to ~5 s on the canonical probe. * v3.2 LdMatrix/StMatrix attempt reverted (7aaab43) -- needs swizzled smem layouts; tracked as the next biggest perf lever. TRT-LLM + Triton M=16 baselines captured; cute v3.1 perf on q_proj M=16 deferred (long JIT compile on this host). Signed-off-by: Faraz Khoubsirat <fkhoubsirat@nvidia.com>
Captured fresh M=16 Nano35 dense-shape latencies for the in-flight W4A16 work (b12x master @ 7aaab43). | Shape | K | N | TRT-LLM us | b12x.triton us | ratio | |--------------------|------|--------|------------|----------------|-------| | q_proj | 2688 | 4096 | 25.0 | 149.0 | 6.0x | | k_proj | 2688 | 256 | 19.1 | 149.0 | 7.8x | | v_proj | 2688 | 256 | 19.0 | 149.0 | 7.8x | | o_proj | 4096 | 2688 | 23.1 | 216.9 | 9.4x | | shared_expert.up | 2688 | 3712 | 23.9 | 147.8 | 6.2x | | shared_expert.down | 3712 | 2688 | 24.0 | 199.3 | 8.3x | | lm_head | 2688 | 131072 | - | 966.1 | n/a | TRT-LLM rejects lm_head (N=131072 out of cuda_core_nvfp4_gemm dispatch envelope); b12x is the only working kernel there. Cute v3.1 perf not captured (~5 min JIT compile per shape on this host); rough Triton parity expected until v3.2 LdMatrix lands. Companion: W4A16_DENSE_RESULTS.md (full v1+v2+v3.1 narrative). Signed-off-by: Faraz Khoubsirat <fkhoubsirat@nvidia.com>
…to RESULTS Three companion docs for the b12x W4A16 dense GEMM work alongside the already-tracked W4A16_DENSE_RESULTS.md: * W4A16_DENSE_DESIGN.md -- architecture, file layout, kernel API, target MNK from nsys, accuracy/perf gates, build flow. * W4A16_DENSE_PLAN.md -- 13-task TDD implementation plan that drove the v1 Triton landing. * W4A16_DENSE_PERF_SNAPSHOT.md -- fresh M=16 Nano35 medians (TRT-LLM cuda_core_nvfp4_gemm vs b12x.triton), with reproduction snippet. Signed-off-by: Faraz Khoubsirat <fkhoubsirat@nvidia.com>
…arlin Captures the v4 forked kernel (b12x/gemm/w4a16/_cute_dense_kernel.py in the b12x repo, master @ 3d2025f) and its head-to-head against the silicon Marlin baseline on p4242-0053 (NVIDIA GB10, SM121). Headline: v4 lands at 56-2040 us across the Nano3.5 dense decode shapes at M=16 — 100x faster than the previous v3 minimal kernel. v4 beats Marlin on shared_fc2 (0.97x ratio), and sits 1.5-2.3x off on the other shapes. Includes tuning sweep tables for (tile_n, tile_k, ab_stage) and the n_per_cta A-reuse experiment; both confirmed the default config (32, 64, 64, ab_stage=2, n_per_cta=1) is the right point in the performance space for the current architecture.
DSV3-Lite NVFP4 / B300 sm_103 spike for TRTLLM-12510. Captures the hands-on evaluation that produced: - ~10% MLA decode kernel-level speedup (TokenSpeed CuTe DSL vs FlashInfer / trtllm-gen) at q_len_per_req=1, 32-token decode - Parity unit-test result: 5 of 7 runnable cases pass; 2 spec-decode (BS=8 / q_len=4) cases diverge (max abs 0.33, max rel ~1166x) - exactly the regime where TokenSpeed claims its 2x headline win - Finding that the default DSV3-Lite NVFP4 + sm_103 path dispatches MLA decode through C++ thop direct cubin launch, bypassing the Python flashinfer wrapper the spike env-var swap was patched into - rc14 trtllm-bench upstream bug: q_len_per_req computed as (1 - input_length) under the TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION=1 path when run under trtllm-bench multi-request scheduling - tokenspeed-mla 0.1.2 LSE wrapper bug (workaround in patches/) - Design draft for TokenSpeedMLAAttention(TrtllmAttention) following PR NVIDIA#13773's MoE backend integration pattern step-1-env through step-7-bench document each phase; summary.md is the cross-cutting writeup. design-tokenspeed-attn-backend.md is the follow-up integration design. .gitignore in this directory excludes the ~1.2 GB of nsys traces (*.nsys-rep, *.sqlite) the spike produced - findings are cited in the markdown reports; raw traces stay on local disk. Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Description
Adds a new MoE backend,
FlashInferNvfp4Sm12xFusedMoE, for Nemotron-Super-120B-NVFP4 on SM120 (RTX 5090 / RTX PRO 6000 / GB202) and SM121 (DGX Spark / GB10). It is auto-selected on theCUTLASSpath whenquant_algo == NVFP4, SM is 120/121, andflashinferis importable — otherwise falls back to plainCutlassFusedMoE. There is no new user-facingmoe_backendvalue; the heuristic inget_moe_cls()mirrors the existingMEGAMOE_DEEPGEMMcan_implement-gated pattern.Hybrid composition:
x.shape[0] >= 64) routes through the inheritedCutlassFusedMoENVFP4 GroupGEMM. b12x's 12-CTA-per-token MMA pattern is suboptimal at largem.x.shape[0] < 64) dispatches to FlashInfer'sB12xMoEWrapper.run. Beats CUTLASS by ~17.6 % TPOT.Hardware constraints (rejected by
can_implementor__init__)NVFP4 only; SM120/121 only; bf16/fp16 activation only;
ep_size == 1only; no MoE alltoall; activations limited toRelu2andSwiglu; noswiglu_gptoss_style.Fp4QuantizedTensorinput rejected on the decode path (b12x quantizes activations internally).Performance
Single RTX PRO 6000 Blackwell (SM120, 97 GB), Nemotron-Super-120B-NVFP4, TRT-LLM 1.3.0rc14, FlashInfer 0.6.8, cutlass-dsl 4.5.0; ISL=2048, OSL=1024, 5 reqs, conc=1, KV reuse off,
cuda_graph_config: {batch_sizes: [1]},max_num_tokens=2048.moe_backend: CUTLASS(pre-PR)moe_backend: CUTLASS(this PR, auto-promoted)Matches CUTLASS on TTFT (prefill on CUTLASS), matches pure b12x on TPOT (decode on b12x), beats both on total throughput. Tokens/Watt +20.4 % vs CUTLASS.
GSM8K accuracy parity (1319 samples, 8-shot CoT, greedy)
CUTLASS(baseline)CUTLASS(auto-promoted, this PR)Statistically indistinguishable (≈ 2 questions out of 1319; well within 95 % binomial CI of ±1.5 pp).
PR Checklist
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.Summary by CodeRabbit
New Features
Documentation