
[ML3] Optimized Router Gemm#2323

Merged
yzh119 merged 4 commits into flashinfer-ai:main from dbari:dbariamis/ml3-routergemm
Jan 13, 2026

Conversation

@dbari dbari commented Jan 9, 2026

📌 Description

This PR extends #2019 by @nvmbreughe for use with Mistral Large 3, which has similar MoE routing to DSV3. The differences are:

  • The number of experts is 128 (DSV3: 256)
  • The output type is bfloat16 (DSV3: float32)

The code has been extended with a template argument for the output type, which is passed to the kernel.

The explicit instantiations have also been removed to make the code more concise; the LoopUnroller already generates the required specializations.
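The dispatch mechanism can be sketched in Python as a table of per-token-count specializations (an illustrative analogy with hypothetical names — the real CUDA code generates this table at compile time via template recursion):

```python
# Illustrative analogy (hypothetical names): one specialized "kernel" per
# supported token count, selected at call time by the runtime value.
# The real code builds this at compile time via the LoopUnroller templates,
# so no explicit instantiations are needed.

def make_specialized_kernel(num_tokens: int):
    # Stands in for router_gemm_kernel<Tin, Tout> instantiated at num_tokens.
    def kernel():
        return f"kernel<num_tokens={num_tokens}>"
    return kernel

# The supported range is 1..16 tokens, as in mm_M1_16_K7168_*.
DISPATCH_TABLE = {n: make_specialized_kernel(n) for n in range(1, 17)}

def dispatch(num_tokens: int) -> str:
    if num_tokens not in DISPATCH_TABLE:
        raise ValueError(f"num_tokens={num_tokens} outside supported range 1..16")
    return DISPATCH_TABLE[num_tokens]()

print(dispatch(5))  # → kernel<num_tokens=5>
```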

Performance measurement for batch size 16 on a B200:

Method                       Runtime
torch.nn.functional.linear   45.6 µs
mm_M1_16_K7168_N128          9.41 µs
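Measurements of this kind typically follow a warmup-then-average pattern; a minimal host-side sketch (illustrative only — the numbers above come from the PR's benchmark script, and real GPU timing additionally requires device synchronization):

```python
import time

def benchmark(fn, warmup: int = 10, iters: int = 100) -> float:
    """Average seconds per call after a warmup phase (hypothetical helper)."""
    for _ in range(warmup):
        fn()  # warm caches / trigger any lazy compilation
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

# Trivial stand-in workload; timing a CUDA kernel would also need
# torch.cuda.synchronize() (or CUDA events) around the timed region.
avg_s = benchmark(lambda: sum(range(1000)))
print(f"{avg_s * 1e6:.2f} us/iter")
```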

🚀 Pull Request Checklist

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Summary by CodeRabbit

  • New Features

    • Added a bfloat16-capable router GEMM variant (128-expert) alongside the existing float32 256-expert variant; both are available via the public API.
  • Tests

    • Expanded tests to cover multiple data types and both configuration variants.
  • Chores

    • Generalized shape-validation and streamlined public exports to support multiple output dtypes.
  • Benchmarks

    • Added a benchmark script to measure and compare router GEMM performance across configs.


@coderabbitai coderabbitai bot commented Jan 9, 2026

📝 Walkthrough

Walkthrough

This PR generalizes the DSV3 router GEMM to separate input/output types (Tin, Tout), adds a bfloat16 entrypoint (ml3 / N128) alongside the existing float32 path (dsv3 / N256), updates kernel and host dispatch signatures, exposes the new Python API and shape checks, and updates tests and benchmarks.

Changes

Cohort / File(s) Summary
Core CUDA Kernel Generalization
`include/flashinfer/gemm/dsv3_router_gemm.cuh`
Kernel templated over Tin/Tout; signature/pointer types changed from single T to Tin/Tout (router_gemm_kernel(Tout* out, Tin const* mat_a, Tin const* mat_b)).
Host-side Kernel Dispatch
`csrc/dsv3_router_gemm.cu`
invokeRouterGemm and launch paths templated over Tin/Tout, added use_pdl param, introduced generic_router_gemm_op dispatcher, added ml3_router_gemm_op (bf16) while retaining dsv3_router_gemm_op (float32), removed many explicit instantiations.
Python API & Shape Validation
`flashinfer/gemm/routergemm_dsv3.py`
Consolidated shape checks into _mm_M1_16_K7168_shape_checks(...); added _mm_M1_16_K7168_N256 and new _mm_M1_16_K7168_N128 wrappers; registered "flashinfer::ml3_router_gemm_op" and exposed mm_M1_16_K7168_N128 alongside N256.
Package Exports
`flashinfer/gemm/__init__.py`, `flashinfer/dsv3_ops/__init__.py`
Re-exported and added mm_M1_16_K7168_N128 to module exports/__all__.
Bindings/Exports
`csrc/*` (export table updated)
Export table updated to include ml3_router_gemm_op in addition to dsv3_router_gemm_op.
Tests
`tests/model_optimizations/test_dsv3_router_gemm.py`
Tests parametrized for both N128 (bf16) and N256 (float32) variants; test signatures updated to accept output dtype and function handles/arrays; negative tests iterate over function variants.
Benchmarks
`benchmarks/bench_router_gemm.py`
New benchmark script added to exercise Torch and FlashInfer paths across configurations and report TFLOPs/s.
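The dispatch structure summarized above — one generic entry point with two thin config-specific wrappers — can be sketched as a hypothetical Python analogy of the C++ generic_router_gemm_op (shapes and bounds taken from the PR description):

```python
# Hypothetical Python analogy of the C++ generic_router_gemm_op dispatcher:
# one generic entry point parameterized by expert count and output dtype,
# with two thin config-specific wrappers.

def generic_router_gemm_op(mat_a_shape, mat_b_shape, num_experts, out_dtype):
    m, k = mat_a_shape
    k_b, n = mat_b_shape
    if k != 7168 or k_b != 7168:
        raise ValueError("K must be 7168")
    if n != num_experts:
        raise ValueError(f"mat_b must have {num_experts} columns")
    if not 1 <= m <= 16:
        raise ValueError("num_tokens must be in 1..16")
    # The real code launches router_gemm_kernel<Tin, Tout> here.
    return f"launch<{out_dtype}, N={num_experts}>"

def dsv3_router_gemm_op(mat_a_shape, mat_b_shape):
    return generic_router_gemm_op(mat_a_shape, mat_b_shape, 256, "float32")

def ml3_router_gemm_op(mat_a_shape, mat_b_shape):
    return generic_router_gemm_op(mat_a_shape, mat_b_shape, 128, "bfloat16")

print(ml3_router_gemm_op((16, 7168), (7168, 128)))  # → launch<bfloat16, N=128>
```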

Sequence Diagram(s)

sequenceDiagram
    participant Py as Python API\n(mm_M1_16_K7168_N128 / N256)
    participant FFI as FFI Dispatcher
    participant Host as Host Dispatcher\n(generic_router_gemm_op)
    participant Kernel as CUDA Kernel\n(router_gemm_kernel<Tin,Tout>)

    Py->>FFI: call mm_M1_16_K7168_N{128|256}(mat_a, mat_b, out, pdl)
    FFI->>Host: invoke ml3_router_gemm_op / dsv3_router_gemm_op
    Host->>Host: select Tin/Tout, kNumExperts, launch params
    Host->>Kernel: launch router_gemm_kernel<Tin,Tout>(out, mat_a, mat_b)
    Kernel-->>Host: complete
    Host-->>FFI: return
    FFI-->>Py: complete

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~55 minutes

Possibly related PRs

  • [DSV3] Optimized Router Gemm #2019: Prior PR adding bf16-optimized router GEMM and initial dsv3_router_gemm_op/launch machinery; touches the same kernel/launch areas and is directly related.

Suggested reviewers

  • ttyio
  • aleozlx
  • djmmoss
  • yongwww
  • cyx-6
  • bkryu
  • nvmbreughe
  • kahyunnam
  • jimmyzho
  • jiahanc

Poem

🐰 Tin and Tout, I leap and sing,

Kernels learn a bolder thing,
N128 joins N256's tune,
Templates hop beneath the moon,
A happy rabbit cheers—huzzah!

🚥 Pre-merge checks: 2 passed, 1 failed
❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 9.52%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
✅ Passed checks (2)
  • Title check ✅ — The title '[ML3] Optimized Router Gemm' clearly summarizes the main change: adding optimized router GEMM support for Mistral Large 3 (ML3).
  • Description check ✅ — The description covers all key requirements: it explains what the PR does (extends #2019 for ML3), specifies differences from DSV3 (128 experts vs 256, bfloat16 vs float32 output), documents the template changes, includes performance metrics, and confirms pre-commit and testing checkboxes.


@gemini-code-assist

Summary of Changes

Hello @dbari, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for the Mistral Large 3 model's Mixture of Experts (MoE) routing by extending the existing optimized router GEMM implementation. The core changes involve making the CUDA kernel more generic to handle different output data types and expert counts, removing explicit template instantiations, and providing a new Python API tailored for Mistral Large 3's specific parameters. This ensures efficient and correct computation for the new model while maintaining compatibility with the previous DSV3 implementation.

Highlights

  • Mistral Large 3 Support: Extends the router GEMM to support Mistral Large 3, which uses 128 experts and bfloat16 output, differing from DSV3's 256 experts and float32 output.
  • Generic Output Type: The invokeRouterGemm CUDA kernel now uses a template argument (Tout) for the output data type, making it more flexible.
  • Removed Explicit Instantiations: Redundant explicit template instantiations for invokeRouterGemm have been removed, as the LoopUnroller mechanism handles the necessary specializations.
  • Generalized Router GEMM Logic: A new generic_router_gemm_op template function was introduced in the C++ backend to consolidate common logic for different router GEMM configurations.
  • New Python API for ML3: A new Python function mm_M1_16_K7168_N128 and its corresponding C++ binding ml3_router_gemm_op are added to expose the Mistral Large 3 optimized GEMM.
  • Comprehensive Testing: Existing test cases have been updated and expanded to cover both DSV3 and Mistral Large 3 configurations, including positive and negative tests for various input parameters and data types.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request effectively extends the optimized router GEMM for Mistral Large 3 by generalizing the existing implementation for DeepSeek-V3. The changes are well-executed, introducing template parameters for the output data type and number of experts in the CUDA/C++ code, which improves code reuse. The corresponding Python bindings and tests are properly updated to support the new model configuration. The removal of explicit template instantiations in favor of relying on the LoopUnroller is a nice cleanup. Overall, this is a high-quality contribution. I have one minor suggestion to add a missing API decorator for consistency.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
flashinfer/gemm/routergemm_dsv3.py (1)

127-173: Missing @flashinfer_api decorator.

The new mm_M1_16_K7168_N128 function is missing the @flashinfer_api decorator that is present on mm_M1_16_K7168_N256 (line 176). Per coding guidelines, use @flashinfer_api decorator for debugging API calls.

Proposed fix
+@flashinfer_api
 @backend_requirement({}, common_check=_mm_M1_16_K7168_N128_shape_checks)
 def mm_M1_16_K7168_N128(
     mat_a: torch.Tensor,
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bd2b033 and 5b26f49.

📒 Files selected for processing (6)
  • csrc/dsv3_router_gemm.cu
  • flashinfer/dsv3_ops/__init__.py
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/routergemm_dsv3.py
  • include/flashinfer/gemm/dsv3_router_gemm.cuh
  • tests/model_optimizations/test_dsv3_router_gemm.py
🧰 Additional context used
📓 Path-based instructions (4)
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/dsv3_ops/__init__.py
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/routergemm_dsv3.py
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/gemm/dsv3_router_gemm.cuh
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/model_optimizations/test_dsv3_router_gemm.py
csrc/**/*.cu

📄 CodeRabbit inference engine (CLAUDE.md)

Framework bindings and PyTorch tensor handling should be implemented in csrc/ via TVM-FFI, not in include/ headers

Files:

  • csrc/dsv3_router_gemm.cu
🧠 Learnings (3)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API

Applied to files:

  • flashinfer/dsv3_ops/__init__.py
  • flashinfer/gemm/__init__.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support

Applied to files:

  • flashinfer/dsv3_ops/__init__.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers

Applied to files:

  • csrc/dsv3_router_gemm.cu
🧬 Code graph analysis (3)
flashinfer/dsv3_ops/__init__.py (1)
flashinfer/gemm/routergemm_dsv3.py (2)
  • mm_M1_16_K7168_N128 (101-107)
  • mm_M1_16_K7168_N128 (128-173)
flashinfer/gemm/__init__.py (1)
flashinfer/gemm/routergemm_dsv3.py (2)
  • mm_M1_16_K7168_N128 (101-107)
  • mm_M1_16_K7168_N128 (128-173)
flashinfer/gemm/routergemm_dsv3.py (5)
flashinfer/utils.py (2)
  • supported_compute_capability (819-899)
  • backend_requirement (902-1184)
flashinfer/jit/dsv3_optimizations.py (1)
  • gen_dsv3_router_gemm_module (18-24)
flashinfer/jit/core.py (1)
  • build_and_load (303-315)
csrc/tvm_ffi_utils.h (1)
  • Tensor (316-318)
csrc/dsv3_router_gemm.cu (2)
  • ml3_router_gemm_op (95-98)
  • ml3_router_gemm_op (95-95)
🪛 Ruff (0.14.10)
flashinfer/gemm/routergemm_dsv3.py

14-14: Unused function argument: launch_with_pdl

(ARG001)


62-62: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (17)
flashinfer/dsv3_ops/__init__.py (1)

1-10: LGTM!

The new mm_M1_16_K7168_N128 export is correctly added alongside the existing N256 variant, maintaining consistency in the public API. Based on learnings, this follows the pattern of exporting new operations to make them available as public API.

flashinfer/gemm/__init__.py (2)

19-22: LGTM!

The import and re-export of mm_M1_16_K7168_N128 follows the established pattern in this module.


38-39: LGTM!

The __all__ list is correctly updated to include the new N128 variant.

tests/model_optimizations/test_dsv3_router_gemm.py (4)

10-16: Well-structured test parametrization.

The parametrization cleanly associates each configuration (N256/N128) with its expected output dtype and function reference, enabling comprehensive testing of both variants.


69-92: Good boundary testing for N128 variant.

The negative tests correctly validate that N128 rejects num_experts values outside the expected 128. The shape checks validate num_experts before dtype, so using out_dtype=torch.float32 here doesn't affect the test outcome.


167-178: Correct dtype validation test for N128.

This test correctly verifies that N128 rejects float32 output tensors since it expects bfloat16 output.


252-254: LGTM!

The loop correctly iterates over all functions in fn_array, ensuring each variant is tested against the invalid inputs.

csrc/dsv3_router_gemm.cu (4)

8-28: Clean template generalization.

The invokeRouterGemm function is well-templated over Tin/Tout, and the VPT calculation correctly uses sizeof(Tin) to handle different input types. The kernel launch configuration is correctly parameterized.
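Assuming VPT stands for values-per-thread derived from a fixed-width vectorized load (an assumption — the exact load width is defined in the kernel source), the arithmetic reduces to:

```python
def values_per_thread(load_bytes: int, sizeof_tin: int) -> int:
    # Elements delivered by one fixed-width vector load; 16 bytes (128 bits)
    # is the typical CUDA vectorized-load width assumed here.
    assert load_bytes % sizeof_tin == 0
    return load_bytes // sizeof_tin

print(values_per_thread(16, 2))  # bfloat16 input (2 bytes) → 8
print(values_per_thread(16, 4))  # float32 input (4 bytes) → 4
```

Using sizeof(Tin) rather than a fixed element size is what lets the same launch math serve both input widths.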


30-57: LGTM!

The LoopUnroller correctly propagates the Tout template parameter through the unroll chain, enabling output type flexibility while keeping input type as __nv_bfloat16.


59-89: Well-factored generic dispatch.

The generic_router_gemm_op template cleanly abstracts the configuration differences between DSv3 (N256/float32) and ML3 (N128/bfloat16), avoiding code duplication.


91-103: LGTM!

Both entry points are correctly implemented and exported via TVM-FFI. As per coding guidelines, framework bindings are properly implemented in csrc/ via TVM-FFI.

include/flashinfer/gemm/dsv3_router_gemm.cuh (3)

42-45: LGTM!

The kernel signature is correctly generalized with Tin/Tout template parameters. As per coding guidelines, Torch headers are not included in this include/ directory file, keeping the kernel framework-agnostic.


62-62: LGTM!

The pointer type correctly uses Tin for the B matrix column pointer.


153-153: Implicit float-to-Tout conversion relies on CUDA's built-in conversion.

The accumulation in float and final conversion to Tout at write time is numerically sound. For Tout=__nv_bfloat16, CUDA handles the implicit conversion correctly.
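The accumulate-in-float32, convert-on-store pattern can be checked numerically by emulating bfloat16 as the top 16 bits of a float32 with round-to-nearest-even (a standard emulation, not FlashInfer code):

```python
import struct

def float_to_bf16_bits(x: float) -> int:
    # Round a float32 to bfloat16: keep the top 16 bits of the float32
    # representation, with round-to-nearest-even on the discarded half.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF

def bf16_to_float(bits16: int) -> float:
    # A bfloat16 value widened back to float32 (low 16 bits zeroed).
    return struct.unpack("<f", struct.pack("<I", bits16 << 16))[0]

# Accumulate in float, convert once at store time (as the kernel does).
acc = sum([0.1] * 10)
stored = bf16_to_float(float_to_bf16_bits(acc))
print(stored)  # → 1.0
```

Accumulating in bfloat16 instead would lose precision on every addition; converting once at the store keeps the rounding error to a single step.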

flashinfer/gemm/routergemm_dsv3.py (3)

13-64: Well-factored shape validation.

The generalized _mm_M1_16_K7168_shape_checks cleanly parameterizes the expected number of experts and output dtype, avoiding code duplication between N128 and N256 variants.

Note: The launch_with_pdl parameter (flagged by static analysis) is intentionally passed through for signature consistency with the wrapped operation, even though it's not used in shape validation.


67-90: LGTM!

The specialized shape check wrappers correctly delegate to the generalized function with appropriate parameters for each variant.


97-124: LGTM!

The module correctly registers both N128 and N256 custom ops and exposes them via the SimpleNamespace.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gemm/routergemm_dsv3.py (1)

13-65: Tighten stride/contiguity validation; current checks may accept unsupported layouts.

For a specialized kernel, allowing padded leading strides (while stride(1)==1) can silently produce wrong results if the kernel assumes contiguous row-major/column-major. Also, the “before dimension checks” comment is stale, and Ruff’s ARG001 can be avoided by naming the arg unused.

Proposed fix
 def _mm_M1_16_K7168_shape_checks(
-    mat_a, mat_b, out, launch_with_pdl, expected_num_experts, expected_out_dtype
+    mat_a, mat_b, out, _launch_with_pdl, expected_num_experts, expected_out_dtype
 ):
     # Dimension checks
     if mat_a.dim() != 2:
         raise ValueError("mat_a must be a 2D tensor")
     if mat_b.dim() != 2:
         raise ValueError("mat_b must be a 2D tensor")
     if out.dim() != 2:
         raise ValueError("out must be a 2D tensor")

-    # Stride checks (check these before dimension checks to give better error messages)
+    # Stride/layout checks
     if mat_a.stride(1) != 1:
         raise ValueError("mat_a must be row-major")
+    if mat_a.stride(0) != mat_a.shape[1]:
+        raise ValueError("mat_a must be contiguous row-major")
     if out.stride(1) != 1:
         raise ValueError("out must be row-major")
+    if out.stride(0) != out.shape[1]:
+        raise ValueError("out must be contiguous row-major")
     if mat_b.stride(0) != 1:
         raise ValueError("mat_b must be column-major")
+    if mat_b.stride(1) != mat_b.shape[0]:
+        raise ValueError("mat_b must be contiguous column-major")
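The same contiguity rules can be exercised without torch against raw (shape, stride) tuples — a standalone sketch of what the proposed checks enforce:

```python
def check_layouts(a_shape, a_stride, b_shape, b_stride):
    # Contiguous row-major A: innermost stride 1, leading stride == columns.
    if a_stride[1] != 1 or a_stride[0] != a_shape[1]:
        raise ValueError("mat_a must be contiguous row-major")
    # Contiguous column-major B: stride(0) == 1, stride(1) == rows.
    if b_stride[0] != 1 or b_stride[1] != b_shape[0]:
        raise ValueError("mat_b must be contiguous column-major")
    return True

# A contiguous (16, 7168) row-major tensor has strides (7168, 1);
# a contiguous (7168, 128) column-major tensor has strides (1, 7168).
print(check_layouts((16, 7168), (7168, 1), (7168, 128), (1, 7168)))  # → True
```

A padded leading stride (e.g. a row-major A with stride(0) > shape[1]) passes the original stride(1) == 1 test but fails the contiguity test here — exactly the silent-wrong-results case the comment warns about.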
🧹 Nitpick comments (1)
flashinfer/gemm/routergemm_dsv3.py (1)

93-124: Symbol wiring is correct, but consider renaming the inner wrapper function to avoid confusion.

The ml3_router_gemm_op symbol is properly exported from csrc/dsv3_router_gemm.cu (lines 102–103) and correctly called in module.ml3_router_gemm_op(...). However, the inner wrapper function mm_M1_16_K7168_N128 shadows the name of the public API function in the outer scope (around line 155), which can be confusing during debugging. Consider renaming the inner wrapper (e.g., _ml3_router_gemm_wrapper) to clarify the distinction.

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 5b26f49 and 80a7e23.

📒 Files selected for processing (1)
  • flashinfer/gemm/routergemm_dsv3.py
🔇 Additional comments (3)
flashinfer/gemm/routergemm_dsv3.py (3)

67-90: Nice: N128/N256 wrappers keep the common shape checks centralized and typed.

This keeps the API surface clean while enabling ML3 vs DSV3 constraints.


127-175: Public ML3 API looks consistent (dtype/shape constraints + cached module).

Docstring is clear, and routing through get_dsv3_router_gemm_module() keeps compilation cached as per guidelines.


177-224: DSV3 path preserved cleanly while reusing the generalized checks.

The N256 wrapper still enforces out as float32, matching the PR objective.

@dbari dbari force-pushed the dbariamis/ml3-routergemm branch from 80a7e23 to 27bb501 on January 12, 2026 at 08:57

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/model_optimizations/test_dsv3_router_gemm.py (1)

167-178: Less intuitive expected_error pattern for the N128 invalid output dtype test.

This test case uses out_dtype=torch.float32 for mm_M1_16_K7168_N128, which expects bfloat16 output. The expected_error="bfloat16" will match the error message since the kernel expects bfloat16 but receives float32. However, for clarity and consistency with the N256 test at lines 203-214 (which uses expected_error="float32" when float32 is expected), consider using a more specific pattern.

That said, since the error message will contain "bfloat16" when the dtype mismatch is detected, this test should still pass. The logic is correct, just the naming/pattern is less intuitive compared to the N256 case.

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 80a7e23 and 27bb501.

📒 Files selected for processing (6)
  • csrc/dsv3_router_gemm.cu
  • flashinfer/dsv3_ops/__init__.py
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/routergemm_dsv3.py
  • include/flashinfer/gemm/dsv3_router_gemm.cuh
  • tests/model_optimizations/test_dsv3_router_gemm.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • include/flashinfer/gemm/dsv3_router_gemm.cuh
🔇 Additional comments (12)
csrc/dsv3_router_gemm.cu (4)

5-28: LGTM! Clean generalization of the kernel invocation.

The templated invokeRouterGemm function correctly separates input (Tin) and output (Tout) types, and the VPT calculation properly uses sizeof(Tin) for vector processing. The kernel launch configuration and error handling are appropriate.


30-57: LGTM! LoopUnroller correctly propagates output type.

The template specializations properly unroll over Tout and delegate to invokeRouterGemm with the correct type parameters. The base case at kEnd correctly throws for out-of-range token counts.
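The compile-time dispatch the unroller performs can be sketched at runtime in Python (a hypothetical stand-in for the template machinery; `make_kernel` and the bounds are illustrative, not the real implementation):

```python
# Hypothetical sketch of the LoopUnroller idea: pick a specialization by
# token count, and raise for out-of-range counts (the kEnd base case).
K_BEGIN, K_END = 1, 17  # token counts 1..16 each get a specialized kernel

def make_kernel(num_tokens):
    # Stand-in for a template instantiation specialized on num_tokens.
    def kernel(x):
        return [num_tokens * v for v in x]
    return kernel

KERNELS = {m: make_kernel(m) for m in range(K_BEGIN, K_END)}

def dispatch(num_tokens, x):
    if num_tokens not in KERNELS:
        raise ValueError(f"unsupported token count: {num_tokens}")
    return KERNELS[num_tokens](x)
```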


59-89: LGTM! Well-structured generic dispatch path.

The generic_router_gemm_op function cleanly parameterizes Tout, tout_code, kNumExperts, and token range bounds. The validation checks for dimensions, strides, and data types are comprehensive.


91-103: LGTM! Entry points and exports are correctly configured.

The dsv3_router_gemm_op and ml3_router_gemm_op functions correctly delegate to the generic path with appropriate template arguments:

  • DSV3: float output, 256 experts
  • ML3: __nv_bfloat16 output, 128 experts

The TVM_FFI exports properly expose both functions.

flashinfer/dsv3_ops/__init__.py (1)

1-10: LGTM! Public API correctly updated.

The new mm_M1_16_K7168_N128 symbol is properly imported from flashinfer.gemm and exposed in __all__, consistent with the existing mm_M1_16_K7168_N256 pattern.

flashinfer/gemm/__init__.py (1)

21-24: LGTM! Import and export correctly added.

The new mm_M1_16_K7168_N128 is properly imported from routergemm_dsv3 and added to __all__, following the same pattern as the existing N256 variant. Based on learnings, this aligns with the guideline to export new operations in module __init__.py files.

Also applies to: 42-43

tests/model_optimizations/test_dsv3_router_gemm.py (2)

9-16: LGTM! Positive test correctly parametrized for both variants.

The test correctly parametrizes over both router GEMM variants with their respective expected output dtypes (float32 for N256, bfloat16 for N128) and function references. The output tensor is correctly allocated with the parametrized output_dtype. Based on coding guidelines, the test appropriately uses get_compute_capability to skip on unsupported architectures.

Also applies to: 19-21, 31-32


230-254: LGTM! Negative test structure is sound.

The test correctly iterates over fn_array to verify that each function raises ValueError with the expected error message pattern. This allows testing both functions for common error conditions while still supporting function-specific tests.

flashinfer/gemm/routergemm_dsv3.py (4)

13-64: LGTM! Well-designed generalized shape validation.

The _mm_M1_16_K7168_shape_checks function cleanly parameterizes expected_num_experts and expected_out_dtype, enabling reuse for both N128 and N256 variants while maintaining comprehensive validation. The launch_with_pdl parameter is intentionally passed through for decorator consistency, even though it's unused in the checks themselves.
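In spirit, the generalized checker looks like the following sketch (hypothetical names and signature, not the actual flashinfer code):

```python
# Illustrative shape check parameterized over the expected expert count,
# mirroring how one helper can serve both the N128 and N256 variants.
def check_router_gemm_shapes(hidden_states_shape, out_shape,
                             expected_num_experts, k=7168):
    m, in_k = hidden_states_shape
    if not (1 <= m <= 16):
        raise ValueError(f"num_tokens must be in [1, 16], got {m}")
    if in_k != k:
        raise ValueError(f"hidden size must be {k}, got {in_k}")
    if out_shape != (m, expected_num_experts):
        raise ValueError(
            f"output must be ({m}, {expected_num_experts}), got {out_shape}")
```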


67-90: LGTM! Shape check wrappers correctly configured.

Both _mm_M1_16_K7168_N256_shape_checks and _mm_M1_16_K7168_N128_shape_checks correctly delegate to the generalized checker with appropriate expected values:

  • N256: 256 experts, float32 output
  • N128: 128 experts, bfloat16 output

The @supported_compute_capability([100]) decorator appropriately restricts to SM100 (Blackwell).


93-124: LGTM! Module registration correctly exposes both ops.

The get_dsv3_router_gemm_module function properly registers both custom ops with @register_custom_op and returns a namespace containing both functions. The @functools.cache decorator ensures module-level caching per coding guidelines.


127-174: LGTM! New public API function follows established patterns.

The mm_M1_16_K7168_N128 function correctly uses @backend_requirement with the N128 shape check and @flashinfer_api decorator per coding guidelines. The comprehensive docstring accurately describes the Mistral Large 3 router GEMM constraints (128 experts, bfloat16 output).
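For reference, the computation both public functions perform is a plain router-logits GEMM; a minimal pure-Python model (illustrative only, ignoring dtypes and the CUDA kernel):

```python
# logits[m][e] = sum_k hidden[m][k] * weight[e][k], with weight laid out
# as [num_experts, K] — the shape both N128 and N256 variants assume.
def router_gemm_ref(hidden, weight):
    M, K = len(hidden), len(hidden[0])
    E = len(weight)
    return [[sum(hidden[m][k] * weight[e][k] for k in range(K))
             for e in range(E)] for m in range(M)]
```

One token with K=2 against two experts, for example, yields a 1x2 logits row: `router_gemm_ref([[1, 2]], [[3, 4], [5, 6]])`.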

@yzh119
Copy link
Copy Markdown
Collaborator

yzh119 commented Jan 13, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !240 has been created, and the CI pipeline #41614696 is currently running. I'll report back once the pipeline job completes.

Collaborator

@yzh119 yzh119 left a comment


LGTM, should be ready to merge as long as CI passed.

Also @dbari, would you mind adding the 128-expert case to the benchmarks as well?

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
benchmarks/bench_router_gemm.py (2)

32-43: Cold L2 benchmarking won't work with closure-captured tensors.

Per the bench_gpu_time_with_cudagraph documentation, GPU tensors must be passed via input_args/input_kwargs (not captured in a closure) for cold L2 cache benchmarking to work. Currently, data is captured in the lambda, so the rotating buffer logic cannot detect or clone the tensors.

Additionally, parameter names reps/warmup_reps are misleading since they represent time in milliseconds, not iteration counts.

♻️ Suggested fix to enable cold L2 benchmarking
-def bench_router_gemm(gemm_fn, data, M, N, K, reps=1000, warmup_reps=1000):
+def bench_router_gemm(gemm_fn, data, M, N, K, repeat_time_ms=1000, warmup_time_ms=1000):
     measurements = bench_gpu_time_with_cudagraph(
-        lambda: gemm_fn(*data),
-        dry_run_time_ms=warmup_reps,
-        repeat_time_ms=reps,
+        lambda *args: gemm_fn(*args),
+        dry_run_time_ms=warmup_time_ms,
+        repeat_time_ms=repeat_time_ms,
+        input_args=data,
     )

41-43: Consider using gemm_fn.__name__ for readable output.

Printing the function object directly will show the memory address (e.g., <function reference_torch at 0x...>). Using gemm_fn.__name__ would produce cleaner output.

♻️ Suggested improvement
     print(
-        f"Router GEMM function {gemm_fn} | num_tokens={M}, num_experts={N}{add_desc} | Median execution time: {1000 * ms:.3f} us | TFLOPs/s: {flops:.3f}"
+        f"Router GEMM function {gemm_fn.__name__} | num_tokens={M}, num_experts={N}{add_desc} | Median execution time: {1000 * ms:.3f} us | TFLOPs/s: {flops:.3f}"
     )
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 27bb501 and 4befbfa.

📒 Files selected for processing (1)
  • benchmarks/bench_router_gemm.py
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_router_gemm.py (1)
flashinfer/testing/utils.py (1)
  • bench_gpu_time_with_cudagraph (1283-1505)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
benchmarks/bench_router_gemm.py (1)

46-75: LGTM!

The main function correctly iterates over token counts and both expert configurations (N128/bfloat16 for ML3, N256/float32 for DSV3), benchmarking the torch reference and FlashInfer implementations with both PDL launch modes.

@yzh119 yzh119 merged commit 4ccc133 into flashinfer-ai:main Jan 13, 2026
7 checks passed
@dbari dbari deleted the dbariamis/ml3-routergemm branch January 15, 2026 08:28
