[ML3] Optimized Router Gemm #2323
Conversation
📝 Walkthrough

This PR generalizes the DSV3 router GEMM to separate input/output types (`Tin`, `Tout`), adds a bfloat16 entrypoint (ML3 / N128) alongside the existing float32 path (DSV3 / N256), updates kernel and host dispatch signatures, exposes the new Python API and shape checks, and updates tests and benchmarks.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Py as Python API\n(mm_M1_16_K7168_N128 / N256)
    participant FFI as FFI Dispatcher
    participant Host as Host Dispatcher\n(generic_router_gemm_op)
    participant Kernel as CUDA Kernel\n(router_gemm_kernel<Tin,Tout>)
    Py->>FFI: call mm_M1_16_K7168_N{128|256}(mat_a, mat_b, out, pdl)
    FFI->>Host: invoke ml3_router_gemm_op / dsv3_router_gemm_op
    Host->>Host: select Tin/Tout, kNumExperts, launch params
    Host->>Kernel: launch router_gemm_kernel<Tin,Tout>(out, mat_a, mat_b)
    Kernel-->>Host: complete
    Host-->>FFI: return
    FFI-->>Py: complete
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~55 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes

Hello @dbari, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces support for the Mistral Large 3 model's Mixture of Experts (MoE) routing by extending the existing optimized router GEMM implementation. The core changes involve making the CUDA kernel more generic to handle different output data types and expert counts, removing explicit template instantiations, and providing a new Python API tailored for Mistral Large 3's specific parameters. This ensures efficient and correct computation for the new model while maintaining compatibility with the previous DSV3 implementation.
Code Review
This pull request effectively extends the optimized router GEMM for Mistral Large 3 by generalizing the existing implementation for DeepSeek-V3. The changes are well-executed, introducing template parameters for the output data type and number of experts in the CUDA/C++ code, which improves code reuse. The corresponding Python bindings and tests are properly updated to support the new model configuration. The removal of explicit template instantiations in favor of relying on the LoopUnroller is a nice cleanup. Overall, this is a high-quality contribution. I have one minor suggestion to add a missing API decorator for consistency.
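Functionally, the router GEMM discussed here just computes per-token expert logits. A dependency-free Python sketch of the reference computation may help orient reviewers (the function name and plain-list tensor representation are illustrative, not the flashinfer implementation; in this PR the real shapes are up to 16 tokens, K=7168, and N=128 experts for ML3 or 256 for DSV3):

```python
def router_gemm_ref(mat_a, mat_b):
    """out[m][n] = sum_k mat_a[m][k] * mat_b[k][n]; plain-Python reference."""
    m_dim, k_dim, n_dim = len(mat_a), len(mat_a[0]), len(mat_b[0])
    out = [[0.0] * n_dim for _ in range(m_dim)]
    for m in range(m_dim):
        for n in range(n_dim):
            acc = 0.0  # accumulate in full precision, as the kernel does
            for k in range(k_dim):
                acc += mat_a[m][k] * mat_b[k][n]
            out[m][n] = acc
    return out

# tiny example: 2 tokens, K=3, 2 "experts"
logits = router_gemm_ref([[1, 2, 3], [0, 1, 0]], [[1, 0], [0, 1], [1, 1]])
# → [[4.0, 5.0], [0.0, 1.0]]
```

The generalization in this PR is purely about the types of `mat_a`/`out` and the value of N, not about this arithmetic.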
Actionable comments posted: 0
🧹 Nitpick comments (1)
flashinfer/gemm/routergemm_dsv3.py (1)
127-173: Missing `@flashinfer_api` decorator.

The new `mm_M1_16_K7168_N128` function is missing the `@flashinfer_api` decorator that is present on `mm_M1_16_K7168_N256` (line 176). Per coding guidelines, use the `@flashinfer_api` decorator for debugging API calls.

Proposed fix:

```diff
+@flashinfer_api
 @backend_requirement({}, common_check=_mm_M1_16_K7168_N128_shape_checks)
 def mm_M1_16_K7168_N128(
     mat_a: torch.Tensor,
```
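For reviewers unfamiliar with the decorator, a minimal sketch of what a logging decorator of this kind might look like. The real `@flashinfer_api` lives in `flashinfer/utils.py` and is gated by `FLASHINFER_LOGLEVEL`; everything below is a hypothetical stand-in, not the actual code:

```python
import functools
import os

def flashinfer_api_sketch(fn):
    """Hypothetical stand-in for @flashinfer_api: log calls when enabled."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Log only when the (assumed) log level is at least "basic".
        if int(os.environ.get("FLASHINFER_LOGLEVEL", "0")) >= 1:
            print(f"[flashinfer] calling {fn.__name__}")
        return fn(*args, **kwargs)
    return wrapper

@flashinfer_api_sketch
def mm_stub(a, b):
    return a + b
```

The point of the review comment is simply that both public entry points should carry the same decorator so logging behaves uniformly.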
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- `csrc/dsv3_router_gemm.cu`
- `flashinfer/dsv3_ops/__init__.py`
- `flashinfer/gemm/__init__.py`
- `flashinfer/gemm/routergemm_dsv3.py`
- `include/flashinfer/gemm/dsv3_router_gemm.cuh`
- `tests/model_optimizations/test_dsv3_router_gemm.py`
🧰 Additional context used
📓 Path-based instructions (4)
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use the `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation.
Use the `@flashinfer_api` decorator for debugging API calls; enable via the `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats).
Files:
- `flashinfer/dsv3_ops/__init__.py`
- `flashinfer/gemm/__init__.py`
- `flashinfer/gemm/routergemm_dsv3.py`
include/**/*.cuh
📄 CodeRabbit inference engine (CLAUDE.md)
include/**/*.cuh: Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers.
Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed.
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Files:
include/flashinfer/gemm/dsv3_router_gemm.cuh
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
tests/**/*.py: Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures.
For testing with `mpirun` on multi-GPU systems, use the pattern: `mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function`.
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon.
Files:
tests/model_optimizations/test_dsv3_router_gemm.py
csrc/**/*.cu
📄 CodeRabbit inference engine (CLAUDE.md)
Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers.
Files:
csrc/dsv3_router_gemm.cu
🧠 Learnings (3)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API
Applied to files:
- `flashinfer/dsv3_ops/__init__.py`
- `flashinfer/gemm/__init__.py`
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support
Applied to files:
flashinfer/dsv3_ops/__init__.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
Applied to files:
csrc/dsv3_router_gemm.cu
🧬 Code graph analysis (3)
flashinfer/dsv3_ops/__init__.py (1)
flashinfer/gemm/routergemm_dsv3.py (2)
- `mm_M1_16_K7168_N128` (101-107), `mm_M1_16_K7168_N128` (128-173)
flashinfer/gemm/__init__.py (1)
flashinfer/gemm/routergemm_dsv3.py (2)
- `mm_M1_16_K7168_N128` (101-107), `mm_M1_16_K7168_N128` (128-173)
flashinfer/gemm/routergemm_dsv3.py (5)

- flashinfer/utils.py (2): `supported_compute_capability` (819-899), `backend_requirement` (902-1184)
- flashinfer/jit/dsv3_optimizations.py (1): `gen_dsv3_router_gemm_module` (18-24)
- flashinfer/jit/core.py (1): `build_and_load` (303-315)
- csrc/tvm_ffi_utils.h (1): `Tensor` (316-318)
- csrc/dsv3_router_gemm.cu (2): `ml3_router_gemm_op` (95-98), `ml3_router_gemm_op` (95-95)
🪛 Ruff (0.14.10)
flashinfer/gemm/routergemm_dsv3.py
14-14: Unused function argument: launch_with_pdl
(ARG001)
62-62: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (17)
flashinfer/dsv3_ops/__init__.py (1)
1-10: LGTM! The new `mm_M1_16_K7168_N128` export is correctly added alongside the existing N256 variant, maintaining consistency in the public API. Based on learnings, this follows the pattern of exporting new operations to make them available as public API.

flashinfer/gemm/__init__.py (2)

19-22: LGTM! The import and re-export of `mm_M1_16_K7168_N128` follows the established pattern in this module.

38-39: LGTM! The `__all__` list is correctly updated to include the new N128 variant.

tests/model_optimizations/test_dsv3_router_gemm.py (4)
10-16: Well-structured test parametrization. The parametrization cleanly associates each configuration (N256/N128) with its expected output dtype and function reference, enabling comprehensive testing of both variants.

69-92: Good boundary testing for N128 variant. The negative tests correctly validate that N128 rejects num_experts values outside the expected 128. The shape checks validate `num_experts` before `dtype`, so using `out_dtype=torch.float32` here doesn't affect the test outcome.

167-178: Correct dtype validation test for N128. This test correctly verifies that N128 rejects `float32` output tensors since it expects `bfloat16` output.

252-254: LGTM! The loop correctly iterates over all functions in `fn_array`, ensuring each variant is tested against the invalid inputs.

csrc/dsv3_router_gemm.cu (4)
8-28: Clean template generalization. The `invokeRouterGemm` function is well-templated over `Tin`/`Tout`, and the `VPT` calculation correctly uses `sizeof(Tin)` to handle different input types. The kernel launch configuration is correctly parameterized.

30-57: LGTM! The `LoopUnroller` correctly propagates the `Tout` template parameter through the unroll chain, enabling output type flexibility while keeping the input type as `__nv_bfloat16`.

59-89: Well-factored generic dispatch. The `generic_router_gemm_op` template cleanly abstracts the configuration differences between DSv3 (N256/float32) and ML3 (N128/bfloat16), avoiding code duplication.

91-103: LGTM! Both entry points are correctly implemented and exported via TVM-FFI. As per coding guidelines, framework bindings are properly implemented in `csrc/` via TVM-FFI.

include/flashinfer/gemm/dsv3_router_gemm.cuh (3)
42-45: LGTM! The kernel signature is correctly generalized with `Tin`/`Tout` template parameters. As per coding guidelines, Torch headers are not included in this `include/` directory file, keeping the kernel framework-agnostic.

62-62: LGTM! The pointer type correctly uses `Tin` for the B matrix column pointer.

153-153: Implicit float-to-Tout conversion relies on CUDA's built-in conversion. The accumulation in `float` and final conversion to `Tout` at write time is numerically sound. For `Tout=__nv_bfloat16`, CUDA handles the implicit conversion correctly.

flashinfer/gemm/routergemm_dsv3.py (3)
13-64: Well-factored shape validation. The generalized `_mm_M1_16_K7168_shape_checks` cleanly parameterizes the expected number of experts and output dtype, avoiding code duplication between the N128 and N256 variants.

Note: the `launch_with_pdl` parameter (flagged by static analysis) is intentionally passed through for signature consistency with the wrapped operation, even though it's not used in shape validation.

67-90: LGTM! The specialized shape check wrappers correctly delegate to the generalized function with appropriate parameters for each variant.

97-124: LGTM! The module correctly registers both N128 and N256 custom ops and exposes them via the `SimpleNamespace`.
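The parameterized-check pattern praised here can be sketched in a few lines of plain Python. Shapes and dtype strings stand in for tensors, and all names are illustrative rather than the flashinfer implementation:

```python
def make_shape_check(expected_num_experts, expected_out_dtype):
    """Return a validator closed over one variant's constraints."""
    def check(mat_a_shape, mat_b_shape, out_dtype):
        k, n = mat_b_shape
        if n != expected_num_experts:
            raise ValueError(
                f"mat_b must have {expected_num_experts} columns, got {n}"
            )
        if mat_a_shape[1] != k:
            raise ValueError("inner dimensions of mat_a and mat_b must match")
        if out_dtype != expected_out_dtype:
            raise ValueError(f"out must be {expected_out_dtype}, got {out_dtype}")
    return check

# One generalized checker specialized per variant, as in the review above.
check_n128 = make_shape_check(128, "bfloat16")
check_n256 = make_shape_check(256, "float32")

check_n128((16, 7168), (7168, 128), "bfloat16")  # passes silently
```

This is the same idea as delegating `_mm_M1_16_K7168_N128_shape_checks` and `_mm_M1_16_K7168_N256_shape_checks` to one shared function.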
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/gemm/routergemm_dsv3.py (1)
13-65: Tighten stride/contiguity validation; current checks may accept unsupported layouts.

For a specialized kernel, allowing padded leading strides (while `stride(1) == 1`) can silently produce wrong results if the kernel assumes contiguous row-major/column-major. Also, the "before dimension checks" comment is stale, and Ruff's ARG001 can be avoided by naming the arg unused.

Proposed fix:

```diff
 def _mm_M1_16_K7168_shape_checks(
-    mat_a, mat_b, out, launch_with_pdl, expected_num_experts, expected_out_dtype
+    mat_a, mat_b, out, _launch_with_pdl, expected_num_experts, expected_out_dtype
 ):
     # Dimension checks
     if mat_a.dim() != 2:
         raise ValueError("mat_a must be a 2D tensor")
     if mat_b.dim() != 2:
         raise ValueError("mat_b must be a 2D tensor")
     if out.dim() != 2:
         raise ValueError("out must be a 2D tensor")
-    # Stride checks (check these before dimension checks to give better error messages)
+    # Stride/layout checks
     if mat_a.stride(1) != 1:
         raise ValueError("mat_a must be row-major")
+    if mat_a.stride(0) != mat_a.shape[1]:
+        raise ValueError("mat_a must be contiguous row-major")
     if out.stride(1) != 1:
         raise ValueError("out must be row-major")
+    if out.stride(0) != out.shape[1]:
+        raise ValueError("out must be contiguous row-major")
     if mat_b.stride(0) != 1:
         raise ValueError("mat_b must be column-major")
+    if mat_b.stride(1) != mat_b.shape[0]:
+        raise ValueError("mat_b must be contiguous column-major")
```
🧹 Nitpick comments (1)
flashinfer/gemm/routergemm_dsv3.py (1)
93-124: Symbol wiring is correct, but consider renaming the inner wrapper function to avoid confusion.

The `ml3_router_gemm_op` symbol is properly exported from `csrc/dsv3_router_gemm.cu` (lines 102-103) and correctly called in `module.ml3_router_gemm_op(...)`. However, the inner wrapper function `mm_M1_16_K7168_N128` shadows the name of the public API function in the outer scope (around line 155), which can be confusing during debugging. Consider renaming the inner wrapper (e.g., `_ml3_router_gemm_wrapper`) to clarify the distinction.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/gemm/routergemm_dsv3.py
🔇 Additional comments (3)

flashinfer/gemm/routergemm_dsv3.py (3)

67-90: Nice: N128/N256 wrappers keep the common shape checks centralized and typed. This keeps the API surface clean while enabling ML3 vs DSV3 constraints.

127-175: Public ML3 API looks consistent (dtype/shape constraints + cached module). Docstring is clear, and routing through `get_dsv3_router_gemm_module()` keeps compilation cached as per guidelines.

177-224: DSV3 path preserved cleanly while reusing the generalized checks. The N256 wrapper still enforces `out` as `float32`, matching the PR objective.
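The cached-module behavior referenced above ("keeps compilation cached") boils down to memoizing the build step with `functools.cache`. A minimal illustration, with a hypothetical `get_module` standing in for `get_dsv3_router_gemm_module` and a lambda standing in for the JIT-compiled op:

```python
import functools
from types import SimpleNamespace

build_count = 0

@functools.cache
def get_module():
    """Build (JIT-compile) once; later calls reuse the cached namespace."""
    global build_count
    build_count += 1
    return SimpleNamespace(router_gemm=lambda a, b: [[x * y for y in b] for x in a])

m1 = get_module()
m2 = get_module()
assert m1 is m2 and build_count == 1  # "compiled" exactly once
```

This is why both public entry points can call the getter on every invocation without paying recompilation cost.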
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>

dbari force-pushed from 80a7e23 to 27bb501
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/model_optimizations/test_dsv3_router_gemm.py (1)
167-178: Incorrect expected_error for N128 invalid output dtype test.This test case uses
out_dtype=torch.float32formm_M1_16_K7168_N128, which expectsbfloat16output. Theexpected_error="bfloat16"will match the error message since the kernel expects bfloat16 but receives float32. However, for clarity and consistency with the N256 test at lines 203-214 (which usesexpected_error="float32"when float32 is expected), consider using a more specific pattern.That said, since the error message will contain "bfloat16" when the dtype mismatch is detected, this test should still pass. The logic is correct, just the naming/pattern is less intuitive compared to the N256 case.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- `csrc/dsv3_router_gemm.cu`
- `flashinfer/dsv3_ops/__init__.py`
- `flashinfer/gemm/__init__.py`
- `flashinfer/gemm/routergemm_dsv3.py`
- `include/flashinfer/gemm/dsv3_router_gemm.cuh`
- `tests/model_optimizations/test_dsv3_router_gemm.py`
🚧 Files skipped from review as they are similar to previous changes (1)
- include/flashinfer/gemm/dsv3_router_gemm.cuh
🧬 Code graph analysis (4)

tests/model_optimizations/test_dsv3_router_gemm.py (1)
- flashinfer/gemm/routergemm_dsv3.py (4): `mm_M1_16_K7168_N128` (101-107), `mm_M1_16_K7168_N128` (129-174), `mm_M1_16_K7168_N256` (113-119), `mm_M1_16_K7168_N256` (179-224)

flashinfer/dsv3_ops/__init__.py (1)
- flashinfer/gemm/routergemm_dsv3.py (4): `mm_M1_16_K7168_N128` (101-107), `mm_M1_16_K7168_N128` (129-174), `mm_M1_16_K7168_N256` (113-119), `mm_M1_16_K7168_N256` (179-224)

csrc/dsv3_router_gemm.cu (3)
- flashinfer/jit/core.py (1): `status` (151-155)
- flashinfer/comm/cuda_ipc.py (1): `cudaGetErrorString` (146-147)
- csrc/tvm_ffi_utils.h (1): `encode_dlpack_dtype` (30-32)

flashinfer/gemm/__init__.py (1)
- flashinfer/gemm/routergemm_dsv3.py (2): `mm_M1_16_K7168_N128` (101-107), `mm_M1_16_K7168_N128` (129-174)
🔇 Additional comments (12)
csrc/dsv3_router_gemm.cu (4)
5-28: LGTM! Clean generalization of the kernel invocation. The templated `invokeRouterGemm` function correctly separates input (`Tin`) and output (`Tout`) types, and the `VPT` calculation properly uses `sizeof(Tin)` for vector processing. The kernel launch configuration and error handling are appropriate.

30-57: LGTM! LoopUnroller correctly propagates output type. The template specializations properly unroll over `Tout` and delegate to `invokeRouterGemm` with the correct type parameters. The base case at `kEnd` correctly throws for out-of-range token counts.

59-89: LGTM! Well-structured generic dispatch path. The `generic_router_gemm_op` function cleanly parameterizes `Tout`, `tout_code`, `kNumExperts`, and token range bounds. The validation checks for dimensions, strides, and data types are comprehensive.

91-103: LGTM! Entry points and exports are correctly configured. The `dsv3_router_gemm_op` and `ml3_router_gemm_op` functions correctly delegate to the generic path with appropriate template arguments:

- DSV3: `float` output, 256 experts
- ML3: `__nv_bfloat16` output, 128 experts

The TVM_FFI exports properly expose both functions.

flashinfer/dsv3_ops/__init__.py (1)

1-10: LGTM! Public API correctly updated. The new `mm_M1_16_K7168_N128` symbol is properly imported from `flashinfer.gemm` and exposed in `__all__`, consistent with the existing `mm_M1_16_K7168_N256` pattern.

flashinfer/gemm/__init__.py (1)

21-24: LGTM! Import and export correctly added. The new `mm_M1_16_K7168_N128` is properly imported from `routergemm_dsv3` and added to `__all__`, following the same pattern as the existing N256 variant. Based on learnings, this aligns with the guideline to export new operations in module `__init__.py` files. Also applies to: 42-43.

tests/model_optimizations/test_dsv3_router_gemm.py (2)

9-16: LGTM! Positive test correctly parametrized for both variants. The test correctly parametrizes over both router GEMM variants with their respective expected output dtypes (float32 for N256, bfloat16 for N128) and function references. The output tensor is correctly allocated with the parametrized `output_dtype`. Based on coding guidelines, the test appropriately uses `get_compute_capability` to skip on unsupported architectures. Also applies to: 19-21, 31-32.

230-254: LGTM! Negative test structure is sound. The test correctly iterates over `fn_array` to verify that each function raises `ValueError` with the expected error message pattern. This allows testing both functions for common error conditions while still supporting function-specific tests.

flashinfer/gemm/routergemm_dsv3.py (4)

13-64: LGTM! Well-designed generalized shape validation. The `_mm_M1_16_K7168_shape_checks` function cleanly parameterizes `expected_num_experts` and `expected_out_dtype`, enabling reuse for both N128 and N256 variants while maintaining comprehensive validation. The `launch_with_pdl` parameter is intentionally passed through for decorator consistency, even though it's unused in the checks themselves.

67-90: LGTM! Shape check wrappers correctly configured. Both `_mm_M1_16_K7168_N256_shape_checks` and `_mm_M1_16_K7168_N128_shape_checks` correctly delegate to the generalized checker with appropriate expected values:

- N256: 256 experts, float32 output
- N128: 128 experts, bfloat16 output

The `@supported_compute_capability([100])` decorator appropriately restricts to SM100 (Blackwell).

93-124: LGTM! Module registration correctly exposes both ops. The `get_dsv3_router_gemm_module` function properly registers both custom ops with `@register_custom_op` and returns a namespace containing both functions. The `@functools.cache` decorator ensures module-level caching per coding guidelines.

127-174: LGTM! New public API function follows established patterns. The `mm_M1_16_K7168_N128` function correctly uses `@backend_requirement` with the N128 shape check and the `@flashinfer_api` decorator per coding guidelines. The comprehensive docstring accurately describes the Mistral Large 3 router GEMM constraints (128 experts, bfloat16 output).
/bot run
Actionable comments posted: 0
🧹 Nitpick comments (2)
benchmarks/bench_router_gemm.py (2)
32-43: Cold L2 benchmarking won't work with closure-captured tensors.

Per the `bench_gpu_time_with_cudagraph` documentation, GPU tensors must be passed via `input_args`/`input_kwargs` (not captured in a closure) for cold L2 cache benchmarking to work. Currently, `data` is captured in the lambda, so the rotating buffer logic cannot detect or clone the tensors.

Additionally, the parameter names `reps`/`warmup_reps` are misleading since they represent time in milliseconds, not iteration counts.

♻️ Suggested fix to enable cold L2 benchmarking:

```diff
-def bench_router_gemm(gemm_fn, data, M, N, K, reps=1000, warmup_reps=1000):
+def bench_router_gemm(gemm_fn, data, M, N, K, repeat_time_ms=1000, warmup_time_ms=1000):
     measurements = bench_gpu_time_with_cudagraph(
-        lambda: gemm_fn(*data),
-        dry_run_time_ms=warmup_reps,
-        repeat_time_ms=reps,
+        lambda *args: gemm_fn(*args),
+        dry_run_time_ms=warmup_time_ms,
+        repeat_time_ms=repeat_time_ms,
+        input_args=data,
     )
```
41-43: Consider using `gemm_fn.__name__` for readable output.

Printing the function object directly will show the memory address (e.g., `<function reference_torch at 0x...>`). Using `gemm_fn.__name__` would produce cleaner output.

♻️ Suggested improvement:

```diff
 print(
-    f"Router GEMM function {gemm_fn} | num_tokens={M}, num_experts={N}{add_desc} | Median execution time: {1000 * ms:.3f} us | TFLOPs/s: {flops:.3f}"
+    f"Router GEMM function {gemm_fn.__name__} | num_tokens={M}, num_experts={N}{add_desc} | Median execution time: {1000 * ms:.3f} us | TFLOPs/s: {flops:.3f}"
)
```
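For reference, the TFLOP/s figure printed by the benchmark is just the GEMM flop count (2·M·N·K) divided by the measured time. A standalone sketch of that arithmetic, with an illustrative timing value rather than a real measurement:

```python
def router_gemm_tflops(m, n, k, median_time_ms):
    """2*M*N*K flops per GEMM, divided by the median kernel time."""
    flops = 2 * m * n * k
    return flops / (median_time_ms * 1e-3) / 1e12

# e.g. M=16 tokens, N=128 experts, K=7168, hypothetical 3 us median time
tflops = router_gemm_tflops(16, 128, 7168, 0.003)
# → about 9.79 TFLOP/s
```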
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
benchmarks/bench_router_gemm.py
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_router_gemm.py (1)
- flashinfer/testing/utils.py (1): `bench_gpu_time_with_cudagraph` (1283-1505)
🔇 Additional comments (1)

benchmarks/bench_router_gemm.py (1)

46-75: LGTM! The main function correctly iterates over token counts and both expert configurations (N128/bfloat16 for ML3, N256/float32 for DSV3), benchmarking the torch reference and FlashInfer implementations with both PDL launch modes.
📌 Description
This PR extends #2019 by @nvmbreughe for use with Mistral Large 3, which has similar MoE routing to DSV3. The differences are:

- 128 experts (DSV3: 256)
- output dtype `bfloat16` (DSV3: `float32`)

The code has been extended with a template argument for the output type, which is passed to the kernel. Also, the explicit instantiations have been removed to make the code more concise, because they are handled by the call to the `LoopUnroller`.

Performance measurement for batch size 16 on a B200:
- `torch.nn.functional.linear`
- `mm_M1_16_K7168_N128`

🚀 Pull Request Checklist
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (unittest, etc.).

Summary by CodeRabbit
New Features
Tests
Chores
Benchmarks