
Conversation

@nv-yunzheq
Contributor

@nv-yunzheq nv-yunzheq commented Nov 17, 2025

πŸ“Œ Description

πŸ” Related Issues

πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

βœ… Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

πŸ§ͺ Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added a GPU Top‑K reduction utility and fast warp‑level Top‑K paths plus a high‑performance fused MoE NoAuxTc routing op with grouped/hierarchical top‑k, multi‑dtype support (fp32/fp16/bf16), deterministic tie handling, and a public Python entry point with JIT build helper.
  • Tests

    • Added an end‑to‑end CUDA test validating fused routing outputs against a reference implementation across dtypes and configurations.

@coderabbitai
Contributor

coderabbitai bot commented Nov 17, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a CUDA Top‑K reduction utility and a DSV3 deepseek fused‑routing CUDA kernel with launcher and header, Python JIT build + bindings and package exports, a public NoAuxTc FFI entry, and a parameterized CUDA unit test validating the fused routing path.

Changes

Cohort / File(s) and summary:

  • CUDA Top‑K Reduction Framework (csrc/fused_moe/moeTopKFuncs.cuh): new header defining the tensorrt_llm::kernels::reduce_topk utilities: TopKRedType<T> (value/index packing, comparison key, warp reduce), TopKIdx, Sort<N> specializations (N=1..4), the TOPK_SWAP macro, and the device APIs reduceTopK / reduceTopKFunc supporting N up to 16 with architecture-guarded fast paths.
  • NoAuxTc Kernel Implementation (csrc/fused_moe/noAuxTcKernels.cu): adds deepseek_v3_topk_kernel (grouped/ungrouped), warp-level reductions, sigmoid+bias scoring, normalization, the invokeNoAuxTc launcher with dtype dispatch and instantiations, and a public NoAuxTc FFI entry with input validation and dtype permutations.
  • Public Kernel Header (include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h): new declaration of the tensorrt_llm::kernels::invokeNoAuxTc template (InputT, BiasT, OutputT, IdxT) for the kernel launcher.
  • Python fused routing module (flashinfer/fused_moe/fused_routing_dsv3.py): adds the lazy builder get_dsv3_fused_routing_module(), registers the custom op flashinfer::NoAuxTc, exposes a SimpleNamespace with NoAuxTc, and provides a Python wrapper that forwards arguments to the compiled module.
  • Python package exports (flashinfer/fused_moe/__init__.py, flashinfer/dsv3_ops/__init__.py): re-exports NoAuxTc in the package public APIs (imports from fused_routing_dsv3 / fused_moe) and adds it to __all__.
  • JIT spec for build (flashinfer/jit/dsv3_optimizations.py, flashinfer/jit/__init__.py): adds gen_dsv3_fused_routing_module() returning a JitSpec with source files and include paths, and re-exports it via flashinfer.jit.__init__.
  • Unit test (tests/model_optimizations/test_dsv3_fused_routing.py): new comprehensive, parameterized CUDA test that builds a ground-truth sigmoid+bias hierarchical top‑k reference, runs NoAuxTc, and validates selected groups/experts and output values across fp32/fp16/bf16 with tie-aware comparisons and detailed diagnostics.

Sequence Diagram(s)

sequenceDiagram
    participant Py as Python caller
    participant PyModule as fused_routing_dsv3.py
    participant JIT as JIT build (gen_dsv3_fused_routing_module)
    participant FFI as NoAuxTc FFI
    participant Launcher as invokeNoAuxTc
    participant Kernel as deepseek_v3_topk_kernel
    participant TopK as reduce_topk (moeTopKFuncs.cuh)

    Py->>PyModule: NoAuxTc(scores,bias,n_group,topk_group,topk,...)
    alt module not built
        PyModule->>JIT: build/compile JitSpec
        JIT-->>PyModule: compiled module
    end
    PyModule->>FFI: call NoAuxTc (mutates outputs)
    FFI->>Launcher: dtype dispatch -> invokeNoAuxTc<...>(...)
    Launcher->>Kernel: launch kernel on CUDA stream
    rect rgb(235,245,255)
      Kernel->>Kernel: compute sigmoid(scores + bias)
      Kernel->>TopK: call reduceTopK / reduceTopKFunc (warp-level)
      TopK-->>Kernel: top-k values & indices
      Kernel->>Kernel: normalize & store outputs
    end
    Kernel-->>FFI: outputs ready
    FFI-->>Py: return (tensors mutated)
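For reviewers who want to exercise the new op locally, here is a minimal usage sketch of the call flow above. It assumes the Python NoAuxTc signature quoted later in this thread and the output-buffer dtypes used by the new unit test; the per-expert bias shape and the example routing configuration are illustrative assumptions, not values taken from the PR.

# Hedged usage sketch; confirm argument order and dtypes against
# flashinfer/fused_moe/fused_routing_dsv3.py and the new unit test.
import torch
from flashinfer.dsv3_ops import NoAuxTc

num_tokens, num_experts = 16, 256
n_group, topk_group, topk = 8, 4, 8   # example DSv3-style configuration (assumed)
routed_scaling_factor = 2.5           # example value (assumed)

scores = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)
bias = torch.randn(num_experts, device="cuda", dtype=torch.float32)  # per-expert bias (assumed shape)

# Output buffers are mutated in place by the op: fp32 values, int32 indices.
topk_values = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
topk_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)

NoAuxTc(scores, bias, n_group, topk_group, topk, routed_scaling_factor,
        topk_values, topk_indices, launch_with_pdl=True)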

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Pay special attention to:
    • TopKRedType packing/compare semantics and deterministic tie-breaking.
    • Warp reduction correctness, lane shuffles, and arch-conditional fast-paths in reduceTopK variants.
    • Kernel launch config, shared-memory sizing, dtype instantiations, and error checks in noAuxTcKernels.cu.
    • Python JIT lazy-build thread-safety and custom op registration in fused_routing_dsv3.py.
    • Coverage and stability of tests/model_optimizations/test_dsv3_fused_routing.py across device/SM and dtypes.

Possibly related PRs

Suggested reviewers

  • djmmoss
  • cyx-6
  • yongwww
  • wenscarl
  • aleozlx
  • joker-eph
  • nvmbreughe
  • yzh119

Poem

🐰 I hopped through kernels, bits in tow,

Packed values, indices β€” row by row.
Warp by warp I chased the best,
Sorted carrots, passed the test.
Routes now hum β€” a rabbit’s glow.

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings, 1 inconclusive)
  • Description check (⚠️ Warning): the description is entirely the template with all sections left as placeholders; no concrete information about the changes, rationale, or implementation details is provided. Resolution: fill in the Description section with specifics about what the PR adds (the DSV3 fused routing kernel) and why it is needed, complete the required checklist sections, and verify that pre-commit checks and tests pass.
  • Docstring coverage (⚠️ Warning): docstring coverage is 58.82%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve coverage.
  • Title check (❓ Inconclusive): the title is vague and generic, using non-descriptive terms that don't convey which optimization or routing-kernel feature is being implemented. Resolution: revise the title to name the core change, e.g. "Add DSV3 fused top-K routing kernel" or "Implement optimized DeepSeek V3 MoE routing".
✨ Finishing touches
  • πŸ“ Generate docstrings
πŸ§ͺ Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Contributor

Summary of Changes

Hello @nv-yunzheq, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the routing capabilities for DeepSeek-V3 Mixture-of-Experts (MoE) models within the FlashInfer library. It introduces highly optimized CUDA kernels for Top-K selection and expert routing, designed to improve performance by leveraging GPU-specific features and efficient reduction strategies. The changes include new C++ CUDA files for core logic, a utility for architecture-specific compilation, and Python bindings to expose this functionality, alongside a dedicated test suite to ensure accuracy.

Highlights

  • New Top-K Reduction Kernel: Introduced moeTopKFuncs.cuh which provides highly optimized CUDA kernels for Top-K reduction, essential for efficient Mixture-of-Experts (MoE) routing.
  • DeepSeek-V3 Routing Kernel: Added noAuxTcKernels.cu implementing a specialized deepseek_v3_topk_kernel that handles both single and multi-group expert selection, improving performance for DeepSeek-V3 models.
  • CUDA Architecture Conditionals: Included archCondition.h to enable architecture-specific optimizations and checks, ensuring the kernels leverage the full capabilities of modern NVIDIA GPUs.
  • Python Integration: The new routing functionality is exposed to Python through flashinfer.fused_moe.fused_routing_dsv3.py and integrated into the flashinfer.dsv3_ops module, making it accessible for users.
  • Comprehensive Testing: A new test file test_dsv3_fused_routing.py has been added to validate the correctness of the fused routing operation against a reference implementation.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review: /gemini review performs a code review for the current pull request in its current state.
  • Pull Request Summary: /gemini summary provides a summary of the current pull request in its current state.
  • Comment: @gemini-code-assist responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help: /gemini help displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with πŸ‘ and πŸ‘Ž on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces optimized CUDA kernels for DeepSeek V3-style MoE routing, including helper functions for Top-K reduction and a new JIT-compiled operator NoAuxTc. The changes are well-structured and include new Python wrappers and tests. My review focuses on potential issues in the CUDA kernel logic, leftover development comments, and opportunities to improve test coverage. I've identified a potential bug in a loop bound calculation, some leftover debug code, and suggested expanding the test suite to cover more execution paths.

RedType topK{value, idx};
typename RedType::TypeCmp packedMax{};
#pragma unroll
for (int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct
Contributor

medium

This //@todo comment seems to be a leftover from development. If the logic for actualK has been verified, this comment should be removed to improve code clarity.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

🧹 Nitpick comments (2)
tests/model_optimizations/test_dsv3_fused_routing.py (1)

1-69: Test logic matches kernel behavior; consider a few robustness and style tweaks

  • The reference dsv3_ref_check closely mirrors the CUDA kernel’s sigmoid β†’ bias β†’ group top‑K β†’ per‑expert top‑K β†’ normalization flow and looks correct for the tested configuration.

  • To avoid failures on environments without CUDA, you may want to guard the test with torch.cuda.is_available() and pytest.skip before calling get_compute_capability(torch.device("cuda")); a minimal sketch follows below.

  • Currently only n_group=1, topk_group=1, topk=1 is exercised. If you expect multi‑group usage, adding a second parametrization (e.g., n_group > 1, topk_group > 1) would give better coverage of the is_multi_group path.

  • Ruff’s RUF005 suggestions around scores_shape[:-1] + [...] and similar concatenations are purely stylistic; if you want to quiet the linter, you can switch to iterable unpacking, e.g.:

    scores_with_bias.view(*scores_shape[:-1], n_group, scores_shape[-1] // n_group)

Overall, the test looks solid; these are optional cleanups/robustness improvements.
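A minimal sketch of the suggested guard, assuming the test keeps its current imports; purely illustrative:

# Sketch of the suggested CUDA guard at the top of the test; adapt to the
# test's existing parametrization and imports.
import pytest
import torch

def test_dsv3_fused_routing():
    if not torch.cuda.is_available():
        pytest.skip("CUDA device required for the DSv3 fused routing kernel")
    # ... existing body, including get_compute_capability(torch.device("cuda")) ...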

csrc/fused_moe/moeTopKFuncs.cuh (1)

47-69: Clarify/index-guard the 16-bit index packing (kMaxIdx = 65535).

TopKRedType encodes the index into the lower 16 bits of TypeCmp:

static constexpr int kMaxIdx = 65535;
// ...
compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx));
// ...
index = kMaxIdx - static_cast<int32_t>((cmp & 0xFFFF));

This silently assumes idx is in [0, 65535]. If a caller ever passes an index beyond 65535, the packed index will wrap and unpack will return an incorrect index without any diagnostics.

I suggest either:

  • enforcing the assumption with a debug-time check, or
  • documenting it clearly next to kMaxIdx, so future callers know the constraint.

For example:

-    static constexpr int kMaxIdx = 65535;
+    // Indices must be in [0, 65535]; stored in the lower 16 bits of compValIdx.
+    static constexpr int kMaxIdx = 65535;

and, optionally, in a debug build:

#ifdef DEBUG
    assert(idx >= 0 && idx <= kMaxIdx);
#endif
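For reviewers checking the tie-breaking behavior, a small host-side Python model of the packing scheme quoted above may help: an order-preserving key for the float value goes in the upper bits and (kMaxIdx - idx) in the lower 16 bits, so taking the maximum over packed keys prefers the larger value and, on exact value ties, the smaller index. This is an illustrative model only, not the CUDA header's actual compare type or shift width.

# Illustrative Python model of the value/index packing and tie-breaking; the
# real header packs into its own compare type via kMoveBits, this only shows
# the ordering property. Indices must stay within [0, 65535].
import struct

K_MAX_IDX = 65535

def float_key(x: float) -> int:
    # Map a float's bit pattern to a uint32 whose unsigned order matches float order.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (~bits & 0xFFFFFFFF) if (bits & 0x80000000) else (bits | 0x80000000)

def pack(value: float, idx: int) -> int:
    assert 0 <= idx <= K_MAX_IDX
    return (float_key(value) << 16) | (K_MAX_IDX - idx)

def unpack_idx(key: int) -> int:
    return K_MAX_IDX - (key & 0xFFFF)

candidates = [(0.25, 7), (0.75, 3), (0.75, 1), (0.10, 0)]
best = max(pack(v, i) for v, i in candidates)
print(unpack_idx(best))  # prints 1: larger value wins; ties resolve to the smaller index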
πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between 4aed50c and 5c7bbd8.

πŸ“’ Files selected for processing (10)
  • csrc/fused_moe/moeTopKFuncs.cuh (1 hunks)
  • csrc/fused_moe/noAuxTcKernels.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/archCondition.h (1 hunks)
  • flashinfer/dsv3_ops/__init__.py (1 hunks)
  • flashinfer/fused_moe/__init__.py (2 hunks)
  • flashinfer/fused_moe/fused_routing_dsv3.py (1 hunks)
  • flashinfer/jit/__init__.py (1 hunks)
  • flashinfer/jit/dsv3_optimizations.py (1 hunks)
  • include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h (1 hunks)
  • tests/model_optimizations/test_dsv3_fused_routing.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (9)
flashinfer/fused_moe/__init__.py (2)
csrc/fused_moe/noAuxTcKernels.cu (2)
  • NoAuxTc (327-447)
  • NoAuxTc (327-327)
flashinfer/fused_moe/fused_routing_dsv3.py (2)
  • NoAuxTc (19-20)
  • NoAuxTc (26-27)
csrc/fused_moe/noAuxTcKernels.cu (3)
include/flashinfer/trtllm/fused_moe/runner.h (3)
  • num_experts (263-263)
  • n_group (271-271)
  • topk_group (273-273)
csrc/tvm_ffi_utils.h (2)
  • encode_dlpack_dtype (29-31)
  • get_stream (272-274)
flashinfer/fused_moe/fused_routing_dsv3.py (2)
  • NoAuxTc (19-20)
  • NoAuxTc (26-27)
flashinfer/dsv3_ops/__init__.py (2)
csrc/fused_moe/noAuxTcKernels.cu (2)
  • NoAuxTc (327-447)
  • NoAuxTc (327-327)
flashinfer/fused_moe/fused_routing_dsv3.py (2)
  • NoAuxTc (19-20)
  • NoAuxTc (26-27)
flashinfer/jit/dsv3_optimizations.py (1)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
tests/model_optimizations/test_dsv3_fused_routing.py (3)
csrc/fused_moe/noAuxTcKernels.cu (2)
  • NoAuxTc (327-447)
  • NoAuxTc (327-327)
flashinfer/fused_moe/fused_routing_dsv3.py (2)
  • NoAuxTc (19-20)
  • NoAuxTc (26-27)
flashinfer/utils.py (1)
  • get_compute_capability (252-255)
include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h (2)
csrc/fused_moe/noAuxTcKernels.cu (3)
  • void (28-242)
  • invokeNoAuxTc (245-298)
  • invokeNoAuxTc (245-247)
include/flashinfer/trtllm/fused_moe/runner.h (3)
  • num_experts (263-263)
  • n_group (271-271)
  • topk_group (273-273)
flashinfer/jit/__init__.py (1)
flashinfer/jit/dsv3_optimizations.py (1)
  • gen_dsv3_fused_routing_module (14-45)
csrc/nv_internal/tensorrt_llm/kernels/archCondition.h (1)
include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h (1)
  • tensorrt_llm (25-33)
flashinfer/fused_moe/fused_routing_dsv3.py (3)
flashinfer/jit/dsv3_optimizations.py (1)
  • gen_dsv3_fused_routing_module (14-45)
flashinfer/jit/core.py (1)
  • build_and_load (300-312)
csrc/fused_moe/noAuxTcKernels.cu (2)
  • NoAuxTc (327-447)
  • NoAuxTc (327-327)
πŸͺ› Clang (14.0.6)
include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h

[error] 20-20: 'cuda_bf16.h' file not found

(clang-diagnostic-error)

csrc/nv_internal/tensorrt_llm/kernels/archCondition.h

[error] 19-19: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 19-19: expected ';' after top level declarator

(clang-diagnostic-error)


[error] 19-19: expected identifier or '('

(clang-diagnostic-error)

πŸͺ› GitHub Actions: pre-commit
csrc/fused_moe/noAuxTcKernels.cu

[error] 1-1: clang-format formatting failed. Run 'clang-format' to fix code style issues in this file.

tests/model_optimizations/test_dsv3_fused_routing.py

[warning] 1-1: pre-commit: mixed line ending detected; hooks may modify files on re-run

include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h

[error] 1-1: clang-format formatting failed. Run 'clang-format' to fix code style issues in this file.


[warning] 1-1: pre-commit: mixed line ending detected; hooks may modify files on re-run

csrc/nv_internal/tensorrt_llm/kernels/archCondition.h

[error] 1-1: clang-format formatting failed. Run 'clang-format' to fix code style issues in this file.

flashinfer/fused_moe/fused_routing_dsv3.py

[error] 4-4: ruff: F401 'torch' imported but unused


[error] 5-5: ruff: F401 'supported_compute_capability' imported but unused


[error] 1-1: ruff: F401 'torch' imported but unused


[error] 1-1: ruff: F401 'supported_compute_capability' imported but unused


[error] 1-1: end-of-file-fixer hook: files modified


[error] 1-1: Trailing whitespace: files modified


[error] 1-1: clang-format: some files were reformatted by clang-format hook

csrc/fused_moe/moeTopKFuncs.cuh

[error] 1-1: clang-format formatting failed. Run 'clang-format' to fix code style issues in this file.

πŸͺ› Ruff (0.14.5)
flashinfer/fused_moe/__init__.py

35-35: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tests/model_optimizations/test_dsv3_fused_routing.py

13-14: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)


28-29: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
πŸ”‡ Additional comments (6)
flashinfer/jit/dsv3_optimizations.py (1)

14-45: JIT spec wiring for dsv3_fused_routing looks consistent; just verify includes

The new gen_dsv3_fused_routing_module spec and source list line up with the CUDA/C++ implementation and nv_internal helpers, and the include roots look reasonable for the tensorrt_llm headers.

Please just double‑check that your default JIT includes plus this extra_include_paths set are sufficient for all of:

  • tensorrt_llm/common/*
  • flashinfer/trtllm/fused_moe/noAuxTcKernels.h
  • any cutlass kernels headers

so NVCC/clang don’t fail with missing headers depending on the environment.

flashinfer/dsv3_ops/__init__.py (1)

2-7: NoAuxTc re‑export from fused_moe looks good

Re‑exporting NoAuxTc in flashinfer.dsv3_ops is consistent with the fused‑routing API surface and the tests that import from this namespace. No issues from the snippet.

flashinfer/jit/__init__.py (1)

79-84: JIT API re‑export is consistent

Re‑exporting gen_dsv3_fused_routing_module here matches the existing pattern for other JIT specs (e.g., gen_dsv3_router_gemm_module) and cleanly exposes the new fused routing module.

flashinfer/fused_moe/fused_routing_dsv3.py (1)

1-27: Remove unused imports to fix Ruff F401 errors

The lazy JIT build via get_dsv3_fused_routing_module and register_custom_op("flashinfer::NoAuxTc", mutates_args=...) registration are solid. However, verification confirms the unused imports are blocking pre-commit:

  • torch (line 4) is imported but never used
  • supported_compute_capability and backend_requirement (lines 7-8) are imported but never used

Remove these three unused imports:

-from flashinfer.jit import gen_dsv3_fused_routing_module
-import functools
-from types import SimpleNamespace
-import torch
-from flashinfer.utils import (
-    register_custom_op,
-    supported_compute_capability,
-    backend_requirement,
-)
+from flashinfer.jit import gen_dsv3_fused_routing_module
+import functools
+from types import SimpleNamespace
+from flashinfer.utils import register_custom_op

Then re-run pre-commit to resolve remaining formatting issues.

csrc/fused_moe/moeTopKFuncs.cuh (1)

1-286: The clang-format fix has been successfully applied.

The file csrc/fused_moe/moeTopKFuncs.cuh now passes clang-format validation (exit code 0 confirms no formatting errors). The pre-commit style check requirement has been satisfied.

csrc/nv_internal/tensorrt_llm/kernels/archCondition.h (1)

1-114: Clang-format check passes.

The file has been successfully formatted and now conforms to clang-format style requirements. The pre-commit formatting check should no longer fail.

Comment on lines 17 to 23
#pragma once

namespace tensorrt_llm::kernels
{

namespace detail
{
Contributor

πŸ› οΈ Refactor suggestion | 🟠 Major

Make the header self-contained by including <type_traits>.

std::bool_constant (used from Line 75 onward) is defined in <type_traits>, but this header does not include it directly. That makes archCondition.h fragile and dependent on transitive includes from its consumers.

I recommend explicitly including <type_traits> here so this header can be used in isolation without relying on include order.

-#pragma once
-
- namespace tensorrt_llm::kernels
+#pragma once
+
+#include <type_traits>
+
+namespace tensorrt_llm::kernels
 {

Also applies to: 75-96

🧰 Tools
πŸͺ› Clang (14.0.6)

[error] 19-19: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 19-19: expected ';' after top level declarator

(clang-diagnostic-error)


[error] 19-19: expected identifier or '('

(clang-diagnostic-error)

πŸ€– Prompt for AI Agents
In csrc/nv_internal/tensorrt_llm/kernels/archCondition.h around lines 17-23 (and
affecting lines 75-96), the header uses std::bool_constant but doesn't include
<type_traits>, making it rely on transitive includes; add a direct #include
<type_traits> near the top of the file (with other includes/pragma once) so the
header is self-contained and rebuilds cleanly without depending on include
order.

Comment on lines 83 to 96
template <int Arch>
struct is_match : std::bool_constant<is_device::value && detail::arch_info::mArch == Arch>
{
};

template <int Major>
struct is_major : std::bool_constant<is_device::value && detail::arch_info::mMajor == Major>
{
};

template <int Arch>
struct is_compatible : std::bool_constant<is_major<Arch>::value && detail::arch_info::mArch >= Arch>
{
};
Contributor

⚠️ Potential issue | 🟠 Major

is_compatible trait likely has incorrect semantics.

As written:

template <int Arch>
struct is_compatible : std::bool_constant<is_major<Arch>::value && detail::arch_info::mArch >= Arch> {};

you are treating the template parameter Arch both as a major (for is_major<Arch>) and as the mArch β€œSM code” (80, 90, 100, ...). For example, on SM90 (mMajor = 9, mArch = 90), is_compatible<90> will always be false because is_major<90> can never be true.

If the intent is β€œcompilation device SM is at least Arch (80, 90, 100, ...) and we are on device”, a cleaner definition is:

-template <int Arch>
-struct is_compatible : std::bool_constant<is_major<Arch>::value && detail::arch_info::mArch >= Arch>
-{
-};
+template <int Arch>
+struct is_compatible
+    : std::bool_constant<is_device::value && detail::arch_info::mArch >= Arch>
+{
+};

This makes is_compatible_v<90> read as β€œcompiled for device and mArch >= 90”, which matches the naming.

Comment on lines +35 to +37
from .fused_routing_dsv3 import ( # noqa: F401
NoAuxTc as NoAuxTc,
)
Contributor

⚠️ Potential issue | 🟑 Minor

🧩 Analysis chain

Remove unused noqa and simplify NoAuxTc re‑export

Ruff reports the # noqa: F401 as unused. You can simplify the import and satisfy the linter without changing behavior:

-from .fused_routing_dsv3 import (  # noqa: F401
-    NoAuxTc as NoAuxTc,
-)
+from .fused_routing_dsv3 import NoAuxTc

NoAuxTc is still exported via __all__, so there’s no functional change.


🏁 Script executed:

#!/bin/bash
# Check the current state of __init__.py around lines 35-37
cat flashinfer/fused_moe/__init__.py | head -50

Length of output: 1504


🏁 Script executed:

#!/bin/bash
# Search for __all__ definition in the file
grep -n "__all__" flashinfer/fused_moe/__init__.py

Length of output: 14


🏁 Script executed:

#!/bin/bash
# Search for NoAuxTc usage in the codebase
rg -n "NoAuxTc" --type py

Length of output: 1381


🏁 Script executed:

#!/bin/bash
# Check if fused_routing_dsv3 module exists and what it exports
fd "fused_routing_dsv3" --type py

Length of output: 169


🏁 Script executed:

#!/bin/bash
# Get the full __all__ list
sed -n '39,60p' flashinfer/fused_moe/__init__.py

Length of output: 529


Simplify NoAuxTc re-export by removing redundant pattern

The refactoring is correct. NoAuxTc is confirmed in __all__ at line 55, so the simplified import form maintains the same export behavior while removing unnecessary redundancy:

-from .fused_routing_dsv3 import (  # noqa: F401
-    NoAuxTc as NoAuxTc,
-)
+from .fused_routing_dsv3 import NoAuxTc

This change is safe and recommended.

πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-from .fused_routing_dsv3 import (  # noqa: F401
-    NoAuxTc as NoAuxTc,
-)
+from .fused_routing_dsv3 import NoAuxTc
🧰 Tools
πŸͺ› Ruff (0.14.5)

35-35: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

πŸ€– Prompt for AI Agents
In flashinfer/fused_moe/__init__.py around lines 35 to 37, the current re-export
unnecessarily aliases NoAuxTc via "NoAuxTc as NoAuxTc"; replace that with a
direct import of NoAuxTc from .fused_routing_dsv3 (e.g., "from
.fused_routing_dsv3 import NoAuxTc") to simplify the code while preserving the
existing export behavior already declared in __all__.

@jiahanc
Collaborator

jiahanc commented Nov 17, 2025

cc @ChristinaZ for vis

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (8)
csrc/fused_moe/noAuxTcKernels.cu (3)

337-340: Clarify the topk_indices dtype error message.

You enforce topk_indices to be int32:

TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code)
    << "topk_indices must have the same dtype as scores";

but the message says β€œsame dtype as scores”, which is misleading (scores are float/bfloat16). Updating the message improves debuggability:

-  TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code)
-      << "topk_indices must have the same dtype as scores";
+  TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code)
+      << "topk_indices must have dtype int32";

94-103: Enforce topk ≀ MaxNumTopExperts to avoid out‑of‑bounds accesses.

topScores/topExperts and related buffers are sized with MaxNumTopExperts = 8:

float topScores[MaxNumTopExperts];      // bound of topk
int32_t topExperts[MaxNumTopExperts];

but are indexed with laneIdx < topk and kk < topk (via reduceTopK(..., topk)) and written into shared memory when laneIdx < topk. The host only checks topk <= 32, so topk > MaxNumTopExperts will:

  • Overrun topScores/topExperts inside reduceTopK (loop kk < actualK), and
  • Overrun them again in the final writeback (expertIdx = laneIdx < topk ? topExperts[laneIdx] : ...).

At minimum, the host entry should enforce the tighter bound:

-  TVM_FFI_ICHECK(topk <= 32)
-      << "topk should be smaller than or equal to 32 for now";  //@todo: remove this restriction
-                                                               // later
+  TVM_FFI_ICHECK(topk <= tensorrt_llm::kernels::MaxNumTopExperts)
+      << "topk must not exceed " << tensorrt_llm::kernels::MaxNumTopExperts
+      << " for the optimized DSv3 fused routing kernel";

If you intend to support larger topk, you’ll need to increase MaxNumTopExperts (and adjust shared/register usage and the reduceTopK template K) instead of only loosening the host check.

Also applies to: 138-141, 162-165, 201-212


169-182: Fix NumInterTopKPerThread and initialise intermediate buffers to avoid using uninitialized values.

In the multi‑warp, no‑groups path you currently have:

int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WARP_SIZE + 1;
float intermidiateScore[NumInterTopKPerThread];
int32_t intermidiateExpert[NumInterTopKPerThread];

for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; i += WARP_SIZE) {
  int ii = i / WARP_SIZE;
  if (i < NumInterTopK) {
    intermidiateScore[ii] = smemInterTopScores[i];
    intermidiateExpert[ii] = smemInterTopExperts[i];
  } else {
    intermidiateScore[ii] = invalidScoreFloat;
    intermidiateExpert[ii] = MaxNumExperts - 1;
  }
}

Given NumInterTopK = NumExpertWarps * MaxNumTopExperts and MaxNumExperts ≀ 384, this formula makes NumInterTopKPerThread larger than ceil(NumInterTopK / WARP_SIZE), so some intermidiateScore[ii] / intermidiateExpert[ii] entries are never touched on some lanes, yet reduceTopKFunc is called with N = NumInterTopKPerThread and will read all of them, leading to undefined behaviour.

A safer and simpler formulation is:

-      int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WARP_SIZE + 1;
-      float intermidiateScore[NumInterTopKPerThread];
-      int32_t intermidiateExpert[NumInterTopKPerThread];
+      int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1;
+      float intermidiateScore[NumInterTopKPerThread];
+      int32_t intermidiateExpert[NumInterTopKPerThread];
+      #pragma unroll
+      for (int ii = 0; ii < NumInterTopKPerThread; ++ii) {
+        intermidiateScore[ii] = invalidScoreFloat;
+        intermidiateExpert[ii] = MaxNumExperts - 1;
+      }

This matches the intended β€œceil(NumInterTopK / WARP_SIZE)” per‑thread budget and ensures every slot has a well‑defined sentinel value before it is used in the subsequent reduction.

csrc/fused_moe/moeTopKFuncs.cuh (1)

142-159: Clean up TODOs in reduceTopK and document the invariants instead.

There are a couple of lingering TODOs:

for (int kk = 0; kk < actualK; ++kk)  //@todo: check if actualK is correct
...
topKBufferIdx[ii] = ii * kWARP_SIZE - 1;  //@todo: check if this is correct

The surrounding logic for actualK and the sentinel ii * kWARP_SIZE - 1 looks intentional and stable now. Leaving these TODOs in place suggests the implementation is still suspect, which makes maintenance harder.

Either remove these TODOs or replace them with brief explanatory comments (e.g., why actualK can be less than K, and how the -1 index acts as a safe sentinel that can never collide with a valid index) so future readers don’t have to re‑audit the algorithm.

Also applies to: 219-222

csrc/nv_internal/tensorrt_llm/kernels/archCondition.h (2)

17-21: Include <type_traits> so the trait types are defined and the header is self‑contained.

This header uses std::bool_constant in the arch traits but does not include <type_traits>, so it relies on transitive includes and may fail to compile when included first in a TU. Please add an explicit include near the top:

-#pragma once
-
-namespace tensorrt_llm::kernels {
+#pragma once
+
+#include <type_traits>
+
+namespace tensorrt_llm::kernels {

This keeps the header robust and independent of include order.

Also applies to: 71-74


81-84: is_compatible mixes major and SM codes; semantics are likely incorrect.

is_compatible currently does:

template <int Arch>
struct is_compatible
    : std::bool_constant<is_major<Arch>::value && detail::arch_info::mArch >= Arch> {};

Here Arch is used both as a major for is_major<Arch> and as an SM code (e.g., 80, 90, 100) for mArch >= Arch. On SM90 (mMajor = 9, mArch = 90), is_compatible<90> will always be false because is_major<90> can never be true.

If the intended meaning is β€œcompiled for device and mArch >= Arch (where Arch is 80/90/100...)”, consider:

 template <int Arch>
-struct is_compatible
-    : std::bool_constant<is_major<Arch>::value && detail::arch_info::mArch >= Arch> {};
+struct is_compatible
+    : std::bool_constant<is_device::value && detail::arch_info::mArch >= Arch> {};

This makes is_compatible_v<90> read as β€œdevice build and SM >= 90”, which matches the name.

include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h (2)

27-31: Fix invokeNoAuxTc declaration to match the .cu definition (missing launch_with_pdl).

The implementation in csrc/fused_moe/noAuxTcKernels.cu takes an extra bool launch_with_pdl parameter, but the header omits it. This will cause conflicting declarations/ODR issues once both are seen.

Please align the declaration with the definition:

 template <typename InputT, typename BiasT, typename OutputT, typename IdxT>
-void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk_indices,
-                   int64_t const num_tokens, int64_t const num_experts, int64_t const n_group,
-                   int64_t const topk_group, int64_t const topk, double const routed_scaling_factor,
-                   cudaStream_t const stream = 0);
+void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk_indices,
+                   int64_t const num_tokens, int64_t const num_experts, int64_t const n_group,
+                   int64_t const topk_group, int64_t const topk,
+                   double const routed_scaling_factor, bool const launch_with_pdl,
+                   cudaStream_t const stream = 0);

(The default for stream can stay only in the declaration.)


20-21: Remove unnecessary <cuda_bf16.h> from this public header and include it only where needed.

This header itself never references __nv_bfloat16, but including <cuda_bf16.h> here causes clang errors in environments without CUDA headers and increases header dependencies.

Since the BF16 specializations are instantiated in csrc/fused_moe/noAuxTcKernels.cu, it’s cleaner to:

-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
+#include <cuda_fp16.h>

and add:

#include <cuda_bf16.h>

near the top of csrc/fused_moe/noAuxTcKernels.cu before the BF16 instantiations.

This keeps the public header lighter and avoids toolchain issues when CUDA headers are not configured for host-only builds.

πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between 5c7bbd8 and 521fe01.

πŸ“’ Files selected for processing (5)
  • csrc/fused_moe/moeTopKFuncs.cuh (1 hunks)
  • csrc/fused_moe/noAuxTcKernels.cu (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/archCondition.h (1 hunks)
  • flashinfer/fused_moe/fused_routing_dsv3.py (1 hunks)
  • include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/fused_moe/fused_routing_dsv3.py
🧰 Additional context used
🧠 Learnings (1)
πŸ“š Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/fused_moe/noAuxTcKernels.cu
🧬 Code graph analysis (3)
csrc/fused_moe/noAuxTcKernels.cu (3)
include/flashinfer/trtllm/fused_moe/runner.h (3)
  • num_experts (263-263)
  • n_group (271-271)
  • topk_group (273-273)
csrc/tvm_ffi_utils.h (2)
  • encode_dlpack_dtype (29-31)
  • get_stream (272-274)
flashinfer/fused_moe/fused_routing_dsv3.py (2)
  • NoAuxTc (19-40)
  • NoAuxTc (47-68)
csrc/nv_internal/tensorrt_llm/kernels/archCondition.h (1)
include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h (1)
  • tensorrt_llm (25-33)
include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h (2)
csrc/fused_moe/noAuxTcKernels.cu (3)
  • void (26-218)
  • invokeNoAuxTc (221-272)
  • invokeNoAuxTc (221-224)
include/flashinfer/trtllm/fused_moe/runner.h (3)
  • num_experts (263-263)
  • n_group (271-271)
  • topk_group (273-273)
πŸͺ› Clang (14.0.6)
csrc/nv_internal/tensorrt_llm/kernels/archCondition.h

[error] 19-19: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 19-19: expected ';' after top level declarator

(clang-diagnostic-error)


[error] 19-19: expected identifier or '('

(clang-diagnostic-error)

include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h

[error] 20-20: 'cuda_bf16.h' file not found

(clang-diagnostic-error)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
csrc/fused_moe/moeTopKFuncs.cuh (1)

152-152: Resolve or replace TODOs with explicit invariants / comments

There are two //@todo notes that look like leftover uncertainty:

  • Line 152: //@todo: check if actualK is correct in the scalar reduceTopK loop.
  • Line 221: //@todo: check if this is correct for initializing topKBufferIdx[ii] = ii * kWARP_SIZE - 1.

If these code paths have now been validated (e.g., by parity with the reference implementation and tests), consider:

  • Removing the //@todo markers, and
  • Replacing them with brief comments documenting the intended invariants (e.g., actualK <= K and why the -1 sentinel index is safe when packing/unpacking).

That keeps future readers from wondering whether there is still an open correctness question here.

Also applies to: 221-221

🧹 Nitpick comments (4)
csrc/fused_moe/moeTopKFuncs.cuh (4)

35-84: TopKRedType packing/reduction logic looks solid, with only minor aliasing caveats

The value/index packing, deterministic tie-breaking via kMaxIdx - idx, and warp reduction path (including the fast redux.sync.max.u32 specialization) all look coherent and in line with typical Top‑K implementations. The only nit is that the reinterpret_cast-based conversions in makeCmpVal/unpack rely on CUB’s Traits<T> patterns and are a bit aggressive from a strict C++ aliasing perspective; if this ever needs to be made more portable, using memcpy or an explicitly-sized intermediate would be safer, but it’s probably fine given CUDA/CUB constraints.


142-159: Guard against actualK > K or document the invariant for reduceTopK (scalar input)

reduceTopK writes to out[kk]/outIdx[kk] for kk < actualK, but the arrays are sized as K. If actualK can ever exceed K at runtime, this will lead to out‑of‑bounds writes on the stack. If the intended contract is 0 < actualK <= K, consider either:

  • Clamping: int kEff = min(actualK, K); and looping to kEff, or
  • Adding an explicit runtime check / debug assert, or at least documenting clearly that callers must enforce actualK <= K.

This also applies to the higher‑level overloads that forward actualK unchanged.


161-193: Clarify N constraint message and reuse actualK invariant for reduceTopKFunc

The reduceTopKFunc<K, Type, N, IsSorted> implementation looks consistent with the scalar reduceTopK (shift-out of the previous max via update, and incremental warp reductions). Two minor points:

  • The static assert static_assert(N < 5, "Only support candidates number less than or equal to 128"); is slightly confusing because the template constraint is per-thread (N <= 4), while the message refers to the per-warp total (≀ 4 * 32). Consider tightening the message to reflect that ("N per thread must be <= 4 (<= 128 per warp)") to avoid ambiguity.
  • The same potential actualK > K hazard exists here as in the scalar version, since out/outIdx are sized by K but indexed up to actualK - 1. Any solution you adopt for the scalar overload (clamping, assert, or documenting a hard invariant) should be mirrored here for consistency.

195-247: Multi-chunk (N > 4) path: lane indexing and numResults usage could be tightened

The multi-chunk overload reduceTopK<K, Type, N> generally makes sense (processing 4-candidate blocks and then reducing partial Top‑K results), but a few details are worth tightening:

  • Lane index source: laneIdx is computed as threadIdx.x % kWARP_SIZE. Since you already have a cg::thread_block_tile<kWARP_SIZE> warp, it would be more robust to use warp.thread_rank() instead of assuming the tile corresponds to threadIdx.x modulo 32. This avoids surprises if the warp tile is ever created from a nontrivial block layout.
  • Unused numResults slots: topKBufferValue/topKBufferIdx are sized as numResults, but only indices 0 and (in the last loop) 1 are ever written. For combinations where numResults > 2 (e.g., larger K with N = 12 or 16), the extra slots remain at the sentinel minValue/ii * kWARP_SIZE - 1. That’s functionally safe but suggests either:
    • You never intend to support such (N, K) combinations and could enforce it with a static_assert on numLoops * K (or directly on K), or
    • The distribution logic could be extended to actually populate all numResults entries so that no partial Top‑K candidates are dropped when numLoops * K > 64.

Clarifying which of these is intended (and possibly adding a static_assert or brief comment) would make the behavior easier to reason about.

πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between 521fe01 and 51812f1.

πŸ“’ Files selected for processing (1)
  • csrc/fused_moe/moeTopKFuncs.cuh (1 hunks)

@yzh119
Collaborator

yzh119 commented Nov 18, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !143 has been created, and the CI pipeline #38685617 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
tests/model_optimizations/test_dsv3_fused_routing.py (1)

47-49: Narrow test coverage for multi-group routing.

The test currently only exercises the single-group case (n_group=1). As noted in previous reviews, coverage should be expanded to include multi-group routing paths, different topk values, and varying num_experts configurations.

🧹 Nitpick comments (2)
tests/model_optimizations/test_dsv3_fused_routing.py (2)

8-44: Consider documenting the DSv3-specific constant.

The hardcoded k=2 on line 17 appears to be a DSv3-specific constant for the group scoring mechanism. Adding a comment or constant definition would improve clarity for future maintainers.

Optionally, you can also address the static analysis hints by replacing list concatenation with iterable unpacking on lines 15 and 31:

-            scores_shape[:-1] + [n_group, scores_shape[-1] // n_group]
+            [*scores_shape[:-1], n_group, scores_shape[-1] // n_group]

61-62: Consider using torch.empty for output buffers.

Since topk_values and topk_indices are output tensors that will be fully overwritten by the kernel, using torch.empty instead of torch.randn would better communicate intent and avoid unnecessary random initialization.

-    topk_values = torch.randn(num_tokens, topk, device="cuda", dtype=torch.float32)
-    topk_indices = torch.randn(num_tokens, topk, device="cuda").to(torch.int32)
+    topk_values = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
+    topk_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)
πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between 51812f1 and b4af548.

πŸ“’ Files selected for processing (1)
  • tests/model_optimizations/test_dsv3_fused_routing.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/model_optimizations/test_dsv3_fused_routing.py (3)
flashinfer/fused_moe/fused_routing_dsv3.py (2)
  • NoAuxTc (19-40)
  • NoAuxTc (47-68)
csrc/fused_moe/noAuxTcKernels.cu (2)
  • NoAuxTc (303-445)
  • NoAuxTc (303-305)
flashinfer/utils.py (1)
  • get_compute_capability (252-255)
πŸͺ› Ruff (0.14.5)
tests/model_optimizations/test_dsv3_fused_routing.py

15-15: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)


31-31: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)

πŸ”‡ Additional comments (4)
tests/model_optimizations/test_dsv3_fused_routing.py (4)

1-5: LGTM!

All imports are necessary and correctly used throughout the test.


50-53: LGTM!

The compute capability check correctly restricts the test to SM100 hardware as required for DSv3 fused routing.


64-74: LGTM!

The kernel call correctly passes all required parameters and uses the PDL launch path as intended.


76-80: LGTM!

The validation correctly compares kernel outputs against the reference implementation with appropriate tolerances for floating-point values and proper type casting for index comparison.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #38685617: 4/18 passed

@nv-yunzheq nv-yunzheq changed the title Optimized routing kernels dskv3 [DSV3] Optimized routing kernels dsv3 Nov 19, 2025
@nv-yunzheq
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !143 has been updated with latest changes, and the CI pipeline #38749791 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
csrc/fused_moe/noAuxTcKernels.cu (2)

16-23: Enforce topk and n_group bounds at the C++ boundary to avoid kernel OOB

The kernel hardcodes several capacity limits:

  • static constexpr int MaxNumTopExperts = 8;
  • static constexpr int MaxNumTopGroups = 4;
  • int constexpr NumWarps = MaxNumExperts / WARP_SIZE; with MaxNumExperts instantiated as NumDeepseekExperts = 256 for the grouped path, so NumWarps == 8.
  • float topScores[MaxNumTopExperts]; int32_t topExperts[MaxNumTopExperts];
  • __shared__ float smemGroupScores[NumWarps];

But the host FFI entry only checks:

TVM_FFI_ICHECK(n_group <= 32) << "...";
TVM_FFI_ICHECK(topk <= 32) << "...";

Two concrete issues follow:

  1. Grouped path n_group > 8
    For the grouped instantiation (UseGroups = true with MaxNumExperts = NumDeepseekExperts):

    • smemGroupScores has length 8, but group selection at warp 0 reads smemGroupScores[laneIdx] for laneIdx < numGroup. With n_group > 8, this reads past the shared array.
    • Only 8 warps exist, so groups beyond 8 also never get a proper group score.
  2. topk > MaxNumTopExperts
    The final selection uses topScores[MaxNumTopExperts] / topExperts[MaxNumTopExperts] and laneIdx < topk to index and write outputs. Allowing topk up to 32 while the arrays are sized to 8 is inconsistent with the kernel’s capacity and risks undefined behavior, depending on reduce_topk’s implementation.

Python’s _check_dsv3_fused_routing_supported currently restricts DSv3 to topk <= 8, but this C++ NoAuxTc can be reached from other frontends and should enforce the same invariants.

Consider tightening the host checks along these lines:

-  TVM_FFI_ICHECK(n_group <= 32)
-      << "n_group should be smaller than or equal to 32 for now";
-  TVM_FFI_ICHECK(topk <= 32)
-      << "topk should be smaller than or equal to 32 for now";
+  TVM_FFI_ICHECK(n_group <= NumDeepseekExperts / WARP_SIZE)
+      << "n_group must be <= " << (NumDeepseekExperts / WARP_SIZE)
+      << " for the optimized DSv3 fused routing kernel";
+  TVM_FFI_ICHECK(topk <= MaxNumTopExperts)
+      << "topk must be <= " << MaxNumTopExperts
+      << " for the optimized DSv3 fused routing kernel";

and then keeping the more relaxed 32 limit only if/when a generic fallback path (without these static buffers) is implemented.

Also applies to: 96-105, 120-145, 204-215, 323-329


340-343: Clarify topk_indices dtype error message to match the check

Here:

TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code)
    << "topk_indices must have the same dtype as scores";

the check enforces that topk_indices is int32, not β€œsame dtype as scores”. The mismatched message can make debugging confusing (especially since the Python docstring also mentions int64 as an option).

Suggest tightening the message:

-  TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code)
-      << "topk_indices must have the same dtype as scores";
+  TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code)
+      << "topk_indices must have dtype int32";
🧹 Nitpick comments (3)
flashinfer/fused_moe/fused_routing_dsv3.py (1)

12-23: Tighten Python precheck to mirror C++ kernel invariants and silence unused-arg lint

The configuration checks are aligned with the optimized kernel’s assumptions (group product limits, topk ≀ 8, experts/group ≀ 32), but two small gaps remain:

  • The C++ entry point enforces num_experts % n_group == 0; encoding the same constraint here (e.g., via scores.shape[1] % n_group == 0) would fail fast in Python instead of surfacing as a TVM assertion later.
  • For multi‑group configurations, the kernel effectively only supports up to NumDeepseekExperts / WARP_SIZE groups (8 for 256 experts). If you intend to keep that invariant, consider also enforcing n_group <= 8 here so Python callers can't create configs that the CUDA path can't handle (a combined sketch follows below).

Also, _check_dsv3_fused_routing_supported must accept the full NoAuxTc signature for backend_requirement, but bias, routed_scaling_factor, topk_values, topk_indices, and launch_with_pdl are unused. If Ruff’s ARG001 is noisy, you can explicitly mark them as unused, e.g.:

    # Unused but required by backend_requirement signature
    _ = (bias, routed_scaling_factor, topk_values, topk_indices, launch_with_pdl)

Also applies to: 40-79
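A sketch of what a tightened precheck could look like, folding in the divisibility and group-count constraints above. The argument names mirror the NoAuxTc signature shown in this review, and the concrete bounds (topk <= 8, at most 32 experts per group, n_group <= 8) are the ones discussed here, not values read from the final implementation:

# Illustrative sketch only; not the file's actual _check_dsv3_fused_routing_supported.
import torch

def _check_dsv3_fused_routing_supported(
    scores: torch.Tensor,
    bias: torch.Tensor,
    n_group: int,
    topk_group: int,
    topk: int,
    routed_scaling_factor: float,
    topk_values: torch.Tensor,
    topk_indices: torch.Tensor,
    launch_with_pdl: bool = True,
) -> bool:
    # Unused here but required by the backend_requirement signature.
    _ = (bias, topk_group, routed_scaling_factor, topk_values, topk_indices, launch_with_pdl)
    num_experts = scores.shape[1]
    if n_group > 0 and num_experts % n_group != 0:
        raise ValueError("num_experts must be divisible by n_group")
    if n_group > 8:
        raise ValueError("the optimized DSv3 routing kernel supports at most 8 groups")
    if n_group > 0 and num_experts // n_group > 32:
        raise ValueError("at most 32 experts per group are supported")
    if topk > 8:
        raise ValueError("the optimized DSv3 routing kernel supports topk <= 8")
    return True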

tests/model_optimizations/test_dsv3_fused_routing.py (2)

373-449: Remove or use the unused topk_values_kernel parameter in validate_and_debug

validate_and_debug takes topk_values_kernel but never uses it; all logic is based on topk_indices_kernel and the ground-truth object. That’s triggering Ruff’s ARG001 and can confuse readers.

You can either:

  • Drop the parameter entirely:
-def validate_and_debug(ground_truth, topk_indices_kernel, topk_values_kernel):
+def validate_and_debug(ground_truth, topk_indices_kernel):
@@
-    all_valid, tokens_with_different_experts = validate_and_debug(
-        ground_truth, topk_indices, sorted_vals
-    )
+    all_valid, tokens_with_different_experts = validate_and_debug(
+        ground_truth, topk_indices
+    )

or

  • Actually use topk_values_kernel in the debug printout (e.g., printing kernel values alongside indices for failing tokens).

Either approach will resolve the unused-argument warning and make the intent clearer.


513-593: Optionally gate the test on supported compute capability to avoid hard failures

NoAuxTc is decorated with backend_requirement and a supported_compute_capability([89, 90, 100, 103, 120, 121]) common check, so calling it on GPUs outside that set will raise a BackendSupportedError. Right now, the test unconditionally constructs CUDA tensors and calls NoAuxTc, which will cause a test failure rather than a skip on unsupported hardware.

Given you already have a commented import for get_compute_capability, you could do something like:

-import pytest
-from flashinfer.dsv3_ops import NoAuxTc
-# from flashinfer.utils import get_compute_capability
+import pytest
+from flashinfer.dsv3_ops import NoAuxTc
+from flashinfer.utils import get_compute_capability
@@
 def test_dsv3_fused_routing_op(
     num_tokens, num_experts, topk, n_group, topk_group, data_type, bias_type
 ):
+    # Skip on unsupported compute capability
+    cc = get_compute_capability()
+    if not NoAuxTc.is_compute_capability_supported(cc):
+        pytest.skip(f"NoAuxTc not supported on compute capability {cc}")

This keeps the test suite green on older or different GPUs while still enforcing correctness where the kernel is intended to run.

πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between b4af548 and cd0a0ef.

πŸ“’ Files selected for processing (3)
  • csrc/fused_moe/noAuxTcKernels.cu (1 hunks)
  • flashinfer/fused_moe/fused_routing_dsv3.py (1 hunks)
  • tests/model_optimizations/test_dsv3_fused_routing.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
πŸ“š Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/fused_moe/noAuxTcKernels.cu
🧬 Code graph analysis (3)
csrc/fused_moe/noAuxTcKernels.cu (3)
include/flashinfer/trtllm/fused_moe/runner.h (3)
  • num_experts (263-263)
  • n_group (271-271)
  • topk_group (273-273)
csrc/tvm_ffi_utils.h (2)
  • encode_dlpack_dtype (29-31)
  • get_stream (272-274)
flashinfer/fused_moe/fused_routing_dsv3.py (2)
  • NoAuxTc (90-111)
  • NoAuxTc (119-194)
flashinfer/fused_moe/fused_routing_dsv3.py (4)
flashinfer/jit/dsv3_optimizations.py (1)
  • gen_dsv3_fused_routing_module (14-45)
flashinfer/utils.py (2)
  • supported_compute_capability (773-853)
  • backend_requirement (856-1131)
flashinfer/jit/core.py (1)
  • build_and_load (300-312)
csrc/fused_moe/noAuxTcKernels.cu (2)
  • NoAuxTc (306-448)
  • NoAuxTc (306-308)
tests/model_optimizations/test_dsv3_fused_routing.py (2)
flashinfer/fused_moe/fused_routing_dsv3.py (2)
  • NoAuxTc (90-111)
  • NoAuxTc (119-194)
csrc/fused_moe/noAuxTcKernels.cu (2)
  • NoAuxTc (306-448)
  • NoAuxTc (306-308)
πŸͺ› Ruff (0.14.5)
flashinfer/fused_moe/fused_routing_dsv3.py

15-15: Unused function argument: bias

(ARG001)


19-19: Unused function argument: routed_scaling_factor

(ARG001)


20-20: Unused function argument: topk_values

(ARG001)


21-21: Unused function argument: topk_indices

(ARG001)


22-22: Unused function argument: launch_with_pdl

(ARG001)


45-48: Avoid specifying long messages outside the exception class

(TRY003)


56-58: Avoid specifying long messages outside the exception class

(TRY003)


60-63: Avoid specifying long messages outside the exception class

(TRY003)


65-68: Avoid specifying long messages outside the exception class

(TRY003)


71-73: Avoid specifying long messages outside the exception class

(TRY003)


75-77: Avoid specifying long messages outside the exception class

(TRY003)

tests/model_optimizations/test_dsv3_fused_routing.py

373-373: Unused function argument: topk_values_kernel

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
πŸ”‡ Additional comments (3)
flashinfer/fused_moe/fused_routing_dsv3.py (1)

82-115: JIT build + custom-op registration pattern looks solid

The @functools.cache around get_dsv3_fused_routing_module() plus the nested, registered NoAuxTc custom op that forwards to module.NoAuxTc gives you a single build per process and a clear mutation contract on topk_values / topk_indices. Argument order and types match the C++ NoAuxTc entry point, so the wiring looks correct.
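For contributors new to this pattern, here is a stripped-down, self-contained sketch of the same idea (dummy module and op name; the real code builds the CUDA extension via gen_dsv3_fused_routing_module().build_and_load() and registers flashinfer::NoAuxTc):

import functools
from types import SimpleNamespace

import torch


@functools.cache
def get_routing_module():
    # Stand-in for gen_dsv3_fused_routing_module().build_and_load(); the cache
    # guarantees the (expensive) JIT build happens at most once per process.
    def _no_aux_tc_impl(scores, topk_values, topk_indices):
        vals, idx = scores.topk(topk_indices.shape[1], dim=-1)
        topk_values.copy_(vals)
        topk_indices.copy_(idx.to(topk_indices.dtype))

    return SimpleNamespace(NoAuxTc=_no_aux_tc_impl)


@torch.library.custom_op("demo::no_aux_tc", mutates_args=("topk_values", "topk_indices"))
def no_aux_tc(scores: torch.Tensor, topk_values: torch.Tensor, topk_indices: torch.Tensor) -> None:
    # The public wrapper only forwards; mutates_args documents the in-place contract.
    get_routing_module().NoAuxTc(scores, topk_values, topk_indices)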

csrc/fused_moe/noAuxTcKernels.cu (1)

223-275: Dispatch heuristics for single-/multi-group kernels look consistent

The invokeNoAuxTc logic around is_single_group vs is_multi_group and the choice of MaxNumExperts (128, 256, 384) matches the Python precheck and the test ranges (num_experts ∈ {256, 384}, topk ≀ 8, group-wise capacity constraints). The use of cudaLaunchKernelEx with programmatic stream serialization gated by launch_with_pdl also looks correct for DSv3 usage.
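To make the constraint easier to eyeball, here is a small host-side predicate distilled from the test's skip conditions (an approximation for readers; the authoritative checks live in the Python precheck and in invokeNoAuxTc):

def dsv3_routing_config_is_supported(num_experts, topk, n_group, topk_group):
    # Distilled from the test's skip logic: topk <= 8 and topk_group <= n_group always;
    # multi-group routing additionally wants <= 32 experts per group and at most 128
    # candidate experts after group selection; single-group routing caps num_experts at 384.
    if topk > 8 or topk_group > n_group:
        return False
    if n_group > 1:
        experts_per_group = num_experts // n_group
        return experts_per_group <= 32 and experts_per_group * topk_group <= 128
    return num_experts <= 384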

tests/model_optimizations/test_dsv3_fused_routing.py (1)

125-224: Ground-truth implementation closely matches the kernel’s routing spec

The DSv3RoutingGroundTruth class mirrors the documented algorithm (sigmoid + bias, per-group top-2 sums, group top‑k, masked expert top‑k, normalization, and final sorting) in float32 and incorporates per‑dtype tie thresholds. This gives strong coverage of both selection and normalization behavior and is a good reference for future kernel changes.
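For reference, the same algorithm fits in a short batched PyTorch sketch (an independent restatement, not the test's code; it uses -inf masking for the group filter and omits the tie handling):

import torch

def dsv3_reference_routing(scores, bias, n_group, topk_group, topk, scale):
    # scores: (T, E) router logits, bias: (E,); float32 mirrors the kernel's internals.
    s = torch.sigmoid(scores.float())
    biased = s + bias.float()
    T, E = biased.shape
    g = biased.view(T, n_group, E // n_group)
    group_scores = g.topk(2, dim=-1).values.sum(-1)              # top-2 sum per group
    top_groups = group_scores.topk(topk_group, dim=-1).indices   # group top-k
    mask = torch.full_like(group_scores, float("-inf")).scatter(1, top_groups, 0.0)
    masked = (g + mask.unsqueeze(-1)).reshape(T, E)              # hide non-selected groups
    top_idx = masked.topk(topk, dim=-1).indices                  # expert top-k
    weights = s.gather(1, top_idx)
    weights = weights / (weights.sum(-1, keepdim=True) + 1e-20) * scale
    return weights, top_idx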

Comment on lines +118 to +193
@backend_requirement({}, common_check=_check_dsv3_fused_routing_supported)
def NoAuxTc(
    scores: torch.Tensor,
    bias: torch.Tensor,
    n_group: int,
    topk_group: int,
    topk: int,
    routed_scaling_factor: float,
    topk_values: torch.Tensor,
    topk_indices: torch.Tensor,
    launch_with_pdl: bool = True,
) -> None:
    """Fused expert routing with top-k selection for DeepSeek-V3.

    This function performs a highly optimized fused routing operation specifically
    designed for DeepSeek-V3's Mixture of Experts (MoE) architecture with grouped
    expert routing and no auxiliary loss. It combines score computation, expert
    selection, and normalization into a single kernel operation.

    The routing algorithm consists of the following steps:
    1. Compute biased scores: sigmoid(scores) + bias for each expert
    2. Group experts and compute group scores (sum of top-2 experts per group)
    3. Select top-k groups based on group scores
    4. From selected groups, select top-k experts based on biased scores
    5. Normalize selected expert weights: sigmoid_scores / sum(sigmoid_scores) * scale

    Args:
        scores (torch.Tensor): Input routing scores of shape (num_tokens, num_experts).
            The logits produced by the router network before activation. Supports
            bfloat16, float16, or float32.
        bias (torch.Tensor): Per-expert routing bias of shape (num_experts,). Added to
            sigmoid-activated scores to produce biased scores for expert selection.
            Must match the dtype of scores.
        n_group (int): Number of expert groups. Experts are divided into groups for
            hierarchical selection. Typical value is 8 for DeepSeek-V3 with 256 experts
            (32 experts per group).
        topk_group (int): Number of top groups to select. Must be <= n_group. Typical
            value is 4, meaning the top 4 groups are selected from 8 groups.
        topk (int): Number of top experts to select per token. Must be <= num_experts.
            Typical value is 8, meaning 8 experts are routed per token.
        routed_scaling_factor (float): Scaling factor applied to normalized expert
            weights. The final output weights are:
            sigmoid_scores / sum(sigmoid_scores) * routed_scaling_factor.
        topk_values (torch.Tensor): Pre-allocated output tensor of shape
            (num_tokens, topk) for the normalized expert weights. Must be float32.
            This tensor is mutated in-place.
        topk_indices (torch.Tensor): Pre-allocated output tensor of shape
            (num_tokens, topk) for the selected expert indices. Must be int32 or int64.
            This tensor is mutated in-place.
        launch_with_pdl (bool, optional): Whether to launch the kernel using Persistent
            Device-side Launch. Defaults to True.

    Returns:
        None: Results are written directly to `topk_values` and `topk_indices` tensors.

    Note:
        - The kernel uses float32 internally for all computations to ensure numerical
          precision, even when inputs are float16 or bfloat16.
        - This implementation is optimized for Hopper (compute capability 90, 100),
          Ada (compute capability 89), and Blackwell (compute capability 120, 121)
          architectures.
        - The "NoAux" prefix indicates this variant does not compute auxiliary losses
          (e.g., load balancing loss) during routing.
        - The "Tc" suffix indicates the use of Tensor Core optimizations in the
          underlying CUDA kernel.
    """
    get_dsv3_fused_routing_module().NoAuxTc(
        scores,
        bias,
        n_group,
        topk_group,
        topk,
        routed_scaling_factor,
        topk_values,
        topk_indices,
        launch_with_pdl,

⚠️ Potential issue | 🟠 Major

Docstring is inconsistent with actual dtype requirements of the kernel

The Python docstring states that:

  • topk_values β€œMust be float32”.
  • topk_indices β€œMust be int32 or int64”.

However, the C++ entry point (flashinfer::trtllm_dsv3_fused_routing::NoAuxTc) enforces:

  • topk_values.dtype() == scores.dtype() (via TVM_FFI_ICHECK(topk_values.dtype() == data_type)).
  • topk_indices dtype encoded as int32_code (only int32 is accepted).

The tests also construct topk_values with dtype=data_type and topk_indices as torch.int32, matching the C++ checks, not the docstring.

To avoid confusing users and spurious runtime errors, please update the docstring to match the real contract, e.g.:

-        topk_values (torch.Tensor): Pre-allocated output tensor of shape
-            (num_tokens, topk) for the normalized expert weights. Must be float32.
+        topk_values (torch.Tensor): Pre-allocated output tensor of shape
+            (num_tokens, topk) for the normalized expert weights. Must have the
+            same dtype as `scores`. This tensor is mutated in-place.
@@
-        topk_indices (torch.Tensor): Pre-allocated output tensor of shape
-            (num_tokens, topk) for the selected expert indices. Must be int32 or int64.
-            This tensor is mutated in-place.
+        topk_indices (torch.Tensor): Pre-allocated output tensor of shape
+            (num_tokens, topk) for the selected expert indices. Must be int32.
+            This tensor is mutated in-place.
πŸ€– Prompt for AI Agents
In flashinfer/fused_moe/fused_routing_dsv3.py around lines 118 to 193, the
docstring incorrectly states that topk_values must be float32 and topk_indices
can be int32 or int64; update it to reflect the C++ entry-point contract:
topk_values must have the same dtype as scores (i.e., use scores.dtype) and
topk_indices must be int32 only (torch.int32). Change the Args section lines for
topk_values and topk_indices to state these exact requirements and note that the
kernel checks enforce them.
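For illustration, a call that satisfies the enforced contract would allocate the outputs like this (shapes and the scaling factor are arbitrary example values):

import torch
from flashinfer.dsv3_ops import NoAuxTc

num_tokens, num_experts, topk = 16, 256, 8
scores = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(num_experts, device="cuda", dtype=torch.bfloat16)

# topk_values shares scores' dtype; topk_indices must be int32.
topk_values = torch.empty(num_tokens, topk, device="cuda", dtype=scores.dtype)
topk_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)

NoAuxTc(scores, bias, n_group=8, topk_group=4, topk=topk,
        routed_scaling_factor=2.5, topk_values=topk_values, topk_indices=topk_indices)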

@aleozlx
Collaborator

aleozlx commented Nov 19, 2025

There's also an alternative path, csrc/nv_internal/cpp/kernels, for storing cpp/h files; for the sake of minimizing differences, it's worth considering for imported kernels.


@aleozlx aleozlx left a comment


Only posted comments that are more nitpicks; adding my approval as a vote up.


@jiahanc jiahanc left a comment


LGTM, thanks for the contribution.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🧹 Nitpick comments (1)
tests/model_optimizations/test_dsv3_fused_routing.py (1)

455-486: Add CUDA availability guard and consider parametrizing launch modes

The test lacks a @pytest.mark.skipif decorator for CUDA availability. Currently, it will error in non-CUDA environments before pytest can properly report a skip. Add:

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for DSv3 NoAuxTc tests")

immediately before the existing @pytest.mark.parametrize decorators at line 421.

Additionally, launch_with_pdl=True is hardcoded at line 484. Consider parametrizing it over [True, False] to exercise both launch paths:

@pytest.mark.parametrize("launch_with_pdl", [True, False])

This ensures test coverage of both runtime modes.

πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between cd0a0ef and f45018b.

πŸ“’ Files selected for processing (1)
  • tests/model_optimizations/test_dsv3_fused_routing.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/model_optimizations/test_dsv3_fused_routing.py (2)
flashinfer/fused_moe/fused_routing_dsv3.py (2)
  • NoAuxTc (90-111)
  • NoAuxTc (119-194)
csrc/fused_moe/noAuxTcKernels.cu (2)
  • NoAuxTc (306-448)
  • NoAuxTc (306-308)
πŸͺ› Ruff (0.14.5)
tests/model_optimizations/test_dsv3_fused_routing.py

333-333: Unused function argument: topk_values_kernel

(ARG001)


350-350: Unpacked variable reason is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

πŸ”‡ Additional comments (2)
tests/model_optimizations/test_dsv3_fused_routing.py (2)

1-117: High-level test documentation is excellent

The top-level docstring is clear and accurately describes both the DSv3 routing algorithm and the two-stage validation strategy, including dtype-dependent thresholds and tie semantics. This will be very helpful for future maintainers when debugging failures.


488-501: Overall test structure and coverage look strong

Aside from the issues noted above, the test structure is solid:

  • Sorting outputs before comparison gives deterministic expert ordering.
  • Two-stage validation (selection, then values) with dtype-specific tolerances is well thought out.
  • Parameterization over num_tokens, num_experts, topk, n_group, topk_group, data_type, and bias_type provides broad coverage, especially for multi-group routing.

Once the masking, tie-checking, device, and skip-condition issues are addressed, this should be a very robust regression test for the NoAuxTc path.

Comment on lines +131 to +223
    def __init__(
        self, scores, bias, n_group, topk_group, topk, routed_scaling_factor, data_type
    ):
        self.num_tokens = scores.shape[0]
        self.num_experts = scores.shape[1]
        self.n_group = n_group
        self.topk_group = topk_group
        self.topk = topk
        self.routed_scaling_factor = routed_scaling_factor
        self.experts_per_group = self.num_experts // n_group
        self.device = scores.device

        # Set thresholds based on data type
        if data_type == torch.bfloat16:
            self.expert_tie_threshold = 1.0
            self.group_tie_threshold = 0.05
        elif data_type == torch.float16:
            self.expert_tie_threshold = 0.5
            self.group_tie_threshold = 0.02
        else:  # float32
            self.expert_tie_threshold = 0.2
            self.group_tie_threshold = 0.01

        # Convert to float32 to match kernel's internal computation
        scores_f32 = scores.to(torch.float32)
        bias_f32 = bias.to(torch.float32)

        # Compute sigmoid and biased scores
        self.sigmoid_scores = torch.sigmoid(scores_f32)
        self.biased_scores = self.sigmoid_scores + bias_f32

        # Reshape for group-wise operations
        scores_reshaped = self.biased_scores.view(
            self.num_tokens, n_group, self.experts_per_group
        )

        # Compute group scores (sum of top-2 experts per group)
        top2_per_group = torch.topk(
            scores_reshaped, k=2, dim=-1, largest=True, sorted=True
        )[0]
        self.group_scores = torch.sum(top2_per_group, dim=-1)

        # Reference group selection
        _, self.ref_group_indices = torch.topk(
            self.group_scores, k=topk_group, dim=-1, largest=True, sorted=True
        )

        # Identify tied groups for each token
        self.tied_group_sets = []
        for token_idx in range(self.num_tokens):
            tied_groups = set()
            group_scores_token = self.group_scores[token_idx]

            for g1 in range(n_group):
                for g2 in range(g1 + 1, n_group):
                    score_diff = abs(group_scores_token[g1] - group_scores_token[g2])
                    if score_diff < self.group_tie_threshold:
                        tied_groups.add(g1)
                        tied_groups.add(g2)

            self.tied_group_sets.append(tied_groups)

        # Compute reference expert selection and normalization
        self.ref_expert_indices = torch.zeros(
            self.num_tokens, topk, dtype=torch.long, device=self.device
        )
        self.ref_expert_values = torch.zeros(
            self.num_tokens, topk, dtype=torch.float32, device=self.device
        )

        for token_idx in range(self.num_tokens):
            # Create mask for selected groups
            group_mask = torch.zeros(n_group, dtype=torch.float32, device=self.device)
            group_mask[self.ref_group_indices[token_idx]] = 1.0
            expert_mask = group_mask.repeat_interleave(self.experts_per_group)

            # Mask and select top-k experts
            masked_biased_scores = self.biased_scores[token_idx] * expert_mask
            _, topk_idx = torch.topk(
                masked_biased_scores, k=topk, dim=-1, largest=True, sorted=True
            )

            # Normalize selected experts
            selected_sigmoid_scores = self.sigmoid_scores[token_idx][topk_idx]
            score_sum = selected_sigmoid_scores.sum() + 1e-20
            normalized_scores = (
                selected_sigmoid_scores / score_sum * routed_scaling_factor
            )

            # Sort by normalized scores
            sorted_vals, sorted_idx = torch.sort(normalized_scores, descending=True)
            self.ref_expert_values[token_idx] = sorted_vals
            self.ref_expert_indices[token_idx] = topk_idx[sorted_idx]

⚠️ Potential issue | 🟠 Major

Masking non-selected groups with 0 can pick experts from disallowed groups

In both the reference implementation and _get_topk_experts_from_groups, experts from non-selected groups are masked by multiplying with 0:

  • Lines 201–211: masked_biased_scores = self.biased_scores[token_idx] * expert_mask
  • Lines 325–327: same pattern

Because biased_scores can be negative (sigmoid in (0,1) plus possibly negative bias), zeroing out non-selected groups can make those β€œmasked” experts larger than valid but negative scores in selected groups. That lets torch.topk choose experts from groups that are supposed to be excluded, which breaks the algorithm described in the docstring.

To enforce the β€œonly from selected groups” constraint regardless of sign, it’s safer to mask with -inf instead of 0, e.g.:

@@
-            # Create mask for selected groups
-            group_mask = torch.zeros(n_group, dtype=torch.float32, device=self.device)
-            group_mask[self.ref_group_indices[token_idx]] = 1.0
-            expert_mask = group_mask.repeat_interleave(self.experts_per_group)
-
-            # Mask and select top-k experts
-            masked_biased_scores = self.biased_scores[token_idx] * expert_mask
-            _, topk_idx = torch.topk(
-                masked_biased_scores, k=topk, dim=-1, largest=True, sorted=True
-            )
+            # Create mask for selected groups
+            group_mask = torch.zeros(n_group, dtype=torch.bool, device=self.device)
+            group_mask[self.ref_group_indices[token_idx]] = True
+            expert_mask = group_mask.repeat_interleave(self.experts_per_group)
+
+            # Mask and select top-k experts: force non-selected groups to -inf
+            masked_biased_scores = self.biased_scores[token_idx].clone()
+            masked_biased_scores[~expert_mask] = float("-inf")
+            _, topk_idx = torch.topk(
+                masked_biased_scores, k=topk, dim=-1, largest=True, sorted=True
+            )
@@
-        # Create mask for specified groups
-        group_mask = torch.zeros(self.n_group, dtype=torch.float32, device=self.device)
-        for g in groups:
-            group_mask[g] = 1.0
-        expert_mask = group_mask.repeat_interleave(self.experts_per_group)
-
-        # Mask and select top-k experts
-        masked_biased_scores = self.biased_scores[token_idx] * expert_mask
-        _, topk_idx = torch.topk(
-            masked_biased_scores, k=self.topk, dim=-1, largest=True, sorted=True
-        )
-
-        return set(topk_idx.tolist())
+        # Create mask for specified groups
+        group_mask = torch.zeros(self.n_group, dtype=torch.bool, device=self.device)
+        for g in groups:
+            group_mask[g] = True
+        expert_mask = group_mask.repeat_interleave(self.experts_per_group)
+
+        # Mask and select top-k experts, restricting strictly to these groups
+        masked_biased_scores = self.biased_scores[token_idx].clone()
+        masked_biased_scores[~expert_mask] = float("-inf")
+        _, topk_idx = torch.topk(
+            masked_biased_scores, k=self.topk, dim=-1, largest=True, sorted=True
+        )
+
+        return set(topk_idx.tolist())

This keeps the ground truth aligned with the routing algorithm even when biased scores are negative.

Also applies to: 313-328
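A tiny standalone repro of the leak (made-up numbers; two groups of two experts, group 0 selected):

import torch

biased = torch.tensor([-0.3, -0.1, 0.2, 0.4])    # experts 2, 3 are in the excluded group
mask_mult = torch.tensor([1.0, 1.0, 0.0, 0.0])   # zero-masking, as in the current reference

# Picks experts from the excluded group: the masked 0.0 beats the valid negative scores.
print(torch.topk(biased * mask_mult, k=2).indices)
masked = biased.clone()
masked[2:] = float("-inf")
# Stays inside the selected group: tensor([1, 0]).
print(torch.topk(masked, k=2).indices)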

Comment on lines +178 to +191
        # Identify tied groups for each token
        self.tied_group_sets = []
        for token_idx in range(self.num_tokens):
            tied_groups = set()
            group_scores_token = self.group_scores[token_idx]

            for g1 in range(n_group):
                for g2 in range(g1 + 1, n_group):
                    score_diff = abs(group_scores_token[g1] - group_scores_token[g2])
                    if score_diff < self.group_tie_threshold:
                        tied_groups.add(g1)
                        tied_groups.add(g2)

            self.tied_group_sets.append(tied_groups)

⚠️ Potential issue | 🟠 Major

Group tie detection can accept invalid group selections as β€œtied”

The current group tie logic precomputes a flat tied_group_sets[token_idx] by unioning all pairwise ties (lines 178–191), and then considers any symmetric difference of groups that lies in this union as β€œtied” (lines 237–243). This can incorrectly accept cases where:

  • ref groups = {A}, kernel groups = {C}, and
  • A is only tied with B, while C is only tied with D.

All of {A, B, C, D} end up in tied_groups, so {A, C} is seen as β€œtied groups” even though A and C themselves may have very different scores. That can allow completely wrong group selections to pass as valid and propagate into the expert-level checks in is_valid_expert_selection.

You can tighten this by checking ties only among the groups actually involved (ref βˆͺ selected) using self.group_scores instead of the global precomputed set; e.g.:

-    def is_valid_group_selection(self, token_idx, selected_groups):
-        """Check if a set of selected groups is valid (exact match or tied)."""
-        ref_groups = set(self.ref_group_indices[token_idx].tolist())
-        selected_groups_set = set(selected_groups)
-
-        if selected_groups_set == ref_groups:
-            return True, "exact"
-
-        if self.n_group > 1:
-            diff_groups = selected_groups_set.symmetric_difference(ref_groups)
-            tied_groups = self.tied_group_sets[token_idx]
-
-            if diff_groups and diff_groups.issubset(tied_groups):
-                return True, "tied_groups"
-
-        return False, "different_groups"
+    def is_valid_group_selection(self, token_idx, selected_groups):
+        """Check if a set of selected groups is valid (exact match or tied)."""
+        ref_groups = set(self.ref_group_indices[token_idx].tolist())
+        selected_groups_set = set(selected_groups)
+
+        if selected_groups_set == ref_groups:
+            return True, "exact"
+
+        if self.n_group > 1:
+            # Only consider groups actually involved in this comparison
+            groups_union = sorted(selected_groups_set | ref_groups)
+            group_scores_token = self.group_scores[token_idx, groups_union]
+            score_range = group_scores_token.max() - group_scores_token.min()
+            if score_range < self.group_tie_threshold:
+                return True, "tied_groups"
+
+        return False, "different_groups"

With this, is_valid_expert_selection’s group branch (lines 259–293) remains the same conceptually but will no longer accept unrelated groups just because each participated in some tie with some other group.

Also applies to: 229-245, 259-293

πŸ€– Prompt for AI Agents
In tests/model_optimizations/test_dsv3_fused_routing.py around lines 178-191 and
the related logic at 229-245 and 259-293, the current approach unions all
pairwise ties into tied_group_sets and later treats any symmetric difference
inside that union as "tied", letting unrelated groups appear tied; instead,
remove/stop relying on the global tied_group_sets and perform tie checks only
among the actually involved groups (ref βˆͺ selected) for each token using
self.group_scores[token_idx] and self.group_tie_threshold: compute pairwise
abs(score_i - score_j) < threshold for groups in the involved set (on-the-fly)
and use that result to decide if the selection is tied/valid, so unrelated ties
elsewhere won't make two unrelated groups pass as tied.
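A minimal numeric illustration of the false positive (hypothetical group scores, threshold 0.01):

group_tie_threshold = 0.01
group_scores = [1.000, 1.004, 0.500, 0.504]  # groups A=0, B=1, C=2, D=3

tied = {
    g
    for g in range(4)
    for h in range(4)
    if g != h and abs(group_scores[g] - group_scores[h]) < group_tie_threshold
}
ref_groups, kernel_groups = {0}, {2}
print((ref_groups ^ kernel_groups) <= tied)  # True: the flat union treats A vs C as a tie
print(abs(group_scores[0] - group_scores[2]) < group_tie_threshold)  # False: they differ by 0.5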

Comment on lines +333 to +357
def validate_expert_selection(ground_truth, topk_indices_kernel, topk_values_kernel):
    """Validate kernel outputs and provide detailed debug info for failures."""
    num_tokens = topk_indices_kernel.shape[0]
    tokens_with_different_experts = set()

    for token_idx in range(num_tokens):
        kernel_experts = topk_indices_kernel[token_idx].tolist()
        ref_experts = ground_truth.ref_expert_indices[token_idx].tolist()

        # Same experts - valid
        if set(kernel_experts) == set(ref_experts):
            continue

        # Different experts - mark for value comparison skip
        tokens_with_different_experts.add(token_idx)

        # Validate the selection
        is_valid, reason = ground_truth.is_valid_expert_selection(
            token_idx, kernel_experts
        )

        if not is_valid:
            return False, tokens_with_different_experts

    return True, tokens_with_different_experts

⚠️ Potential issue | 🟑 Minor

Address Ruff warnings: unused parameter and unused reason

Static analysis flags two small issues here:

  • topk_values_kernel is never used.
  • reason returned from ground_truth.is_valid_expert_selection is unused.

If you don’t plan to use topk_values_kernel and reason for debug reporting, you can explicitly mark them as intentionally unused to satisfy Ruff:

-def validate_expert_selection(ground_truth, topk_indices_kernel, topk_values_kernel):
+def validate_expert_selection(ground_truth, topk_indices_kernel, _topk_values_kernel):
@@
-        # Validate the selection
-        is_valid, reason = ground_truth.is_valid_expert_selection(
+        # Validate the selection
+        is_valid, _reason = ground_truth.is_valid_expert_selection(
             token_idx, kernel_experts
         )

Alternatively, if you want richer failure messages, you could propagate reason back to the test and include it in the pytest.fail message; but the above is the minimal, low-noise fix.
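If you go that way, one possible shape for the richer variant (a sketch; it changes the helper's return contract, so the caller needs a matching update):

def validate_expert_selection(ground_truth, topk_indices_kernel, _topk_values_kernel):
    """Variant that also surfaces the first failure reason for pytest.fail()."""
    tokens_with_different_experts = set()
    for token_idx in range(topk_indices_kernel.shape[0]):
        kernel_experts = topk_indices_kernel[token_idx].tolist()
        ref_experts = ground_truth.ref_expert_indices[token_idx].tolist()
        if set(kernel_experts) == set(ref_experts):
            continue
        tokens_with_different_experts.add(token_idx)
        is_valid, reason = ground_truth.is_valid_expert_selection(token_idx, kernel_experts)
        if not is_valid:
            return False, tokens_with_different_experts, f"token {token_idx}: {reason}"
    return True, tokens_with_different_experts, None

The test can then pass the third element straight into pytest.fail when the first element is False.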

πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-def validate_expert_selection(ground_truth, topk_indices_kernel, topk_values_kernel):
+def validate_expert_selection(ground_truth, topk_indices_kernel, _topk_values_kernel):
     """Validate kernel outputs and provide detailed debug info for failures."""
     num_tokens = topk_indices_kernel.shape[0]
     tokens_with_different_experts = set()

     for token_idx in range(num_tokens):
         kernel_experts = topk_indices_kernel[token_idx].tolist()
         ref_experts = ground_truth.ref_expert_indices[token_idx].tolist()

         # Same experts - valid
         if set(kernel_experts) == set(ref_experts):
             continue

         # Different experts - mark for value comparison skip
         tokens_with_different_experts.add(token_idx)

         # Validate the selection
-        is_valid, reason = ground_truth.is_valid_expert_selection(
+        is_valid, _reason = ground_truth.is_valid_expert_selection(
             token_idx, kernel_experts
         )

         if not is_valid:
             return False, tokens_with_different_experts

     return True, tokens_with_different_experts
🧰 Tools
πŸͺ› Ruff (0.14.5)

333-333: Unused function argument: topk_values_kernel

(ARG001)


350-350: Unpacked variable reason is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

πŸ€– Prompt for AI Agents
In tests/model_optimizations/test_dsv3_fused_routing.py around lines 333–357,
the function validate_expert_selection has an unused parameter
topk_values_kernel and the local variable reason is never used; to quiet Ruff,
mark the parameter and variable as intentionally unused by renaming the
parameter to _topk_values_kernel (or prefix with an underscore) and change the
call to ground_truth.is_valid_expert_selection to capture the second return as _
or _reason (e.g., is_valid, _ = ...), or alternatively explicitly del
topk_values_kernel and del reason after they are created β€” make one of these
minimal changes so the linter no longer flags unused names.

Comment on lines +360 to +419
def validate_values(ground_truth, topk_values_kernel, tokens_to_skip, data_type):
    """Validate that output values match reference within tolerance."""
    # Set tolerance based on data type
    if data_type == torch.bfloat16:
        rtol, atol = 0.1, 0.1
    elif data_type == torch.float16:
        rtol, atol = 0.05, 0.05
    else:  # float32
        rtol, atol = 0.01, 0.01

    num_tokens = topk_values_kernel.shape[0]

    # Create mask for tokens to check
    tokens_to_check = torch.ones(num_tokens, dtype=torch.bool)
    for token_idx in tokens_to_skip:
        tokens_to_check[token_idx] = False

    if not tokens_to_check.any():
        return

    # Compare values
    ref_values = ground_truth.ref_expert_values[tokens_to_check].float()
    kernel_values = topk_values_kernel[tokens_to_check].float()

    try:
        torch.testing.assert_close(
            ref_values,
            kernel_values,
            rtol=rtol,
            atol=atol,
        )
    except AssertionError:
        # Find and report first mismatch
        for token_idx in range(num_tokens):
            if not tokens_to_check[token_idx]:
                continue

            ref_vals = ground_truth.ref_expert_values[token_idx].float()
            kernel_vals = topk_values_kernel[token_idx].float()

            if not torch.allclose(ref_vals, kernel_vals, rtol=rtol, atol=atol):
                diff = (kernel_vals - ref_vals).abs()
                max_diff = diff.max().item()
                max_diff_idx = diff.argmax().item()

                print(f"\n{'=' * 80}")
                print(f"VALUE MISMATCH - Token {token_idx}")
                print(f"{'=' * 80}")
                print(f"Tolerance: rtol={rtol}, atol={atol}")
                print(f"Max difference: {max_diff:.6f} at position {max_diff_idx}")
                print(f"\nReference values: {ref_vals.tolist()}")
                print(f"Kernel values: {kernel_vals.tolist()}")
                print(f"Absolute diff: {diff.tolist()}")
                print(
                    f"Expert indices: {ground_truth.ref_expert_indices[token_idx].tolist()}"
                )
                break

        raise


⚠️ Potential issue | πŸ”΄ Critical

🧩 Analysis chain

Device mismatch in value validation when using boolean mask with CUDA tensors

In validate_values, tokens_to_check is a CPU bool tensor, while ground_truth.ref_expert_values and topk_values_kernel are CUDA tensors. Indexing CUDA tensors with a CPU boolean mask (lines 381–382) will fail at runtime. A simple fix is to move the values to CPU before applying the mask and doing comparisons:

-    # Compare values
-    ref_values = ground_truth.ref_expert_values[tokens_to_check].float()
-    kernel_values = topk_values_kernel[tokens_to_check].float()
+    # Compare values on CPU to avoid device mismatch with the boolean mask
+    ref_values = ground_truth.ref_expert_values.cpu()[tokens_to_check].float()
+    kernel_values = topk_values_kernel.cpu()[tokens_to_check].float()
@@
-            ref_vals = ground_truth.ref_expert_values[token_idx].float()
-            kernel_vals = topk_values_kernel[token_idx].float()
+            ref_vals = ground_truth.ref_expert_values[token_idx].float().cpu()
+            kernel_vals = topk_values_kernel[token_idx].float().cpu()

This keeps the mask on CPU and avoids CUDA/CPU indexing incompatibilities while still validating the numeric accuracy.


🏁 Script executed:

#!/bin/bash
# Quick grep to confirm no other places index CUDA tensors with CPU masks
rg -n "tokens_to_check" -C3

Length of output: 2053





🏁 Script executed:

#!/bin/bash
# Find where ground_truth is created and used in the test
rg -n "ground_truth\s*=" tests/model_optimizations/test_dsv3_fused_routing.py -A2

Length of output: 100


🏁 Script executed:

#!/bin/bash
# Find where validate_values is called to see what tensors are passed
rg -n "validate_values" tests/model_optimizations/test_dsv3_fused_routing.py -B3 -A3

Length of output: 512





🏁 Script executed:

#!/bin/bash
# Find DSv3RoutingGroundTruth class definition
rg -n "class DSv3RoutingGroundTruth" -A10

Length of output: 1002


🏁 Script executed:

#!/bin/bash
# Find where sorted_vals is created (lines around 501)
sed -n '450,510p' tests/model_optimizations/test_dsv3_fused_routing.py

Length of output: 1629





🏁 Script executed:

#!/bin/bash
# Find ref_expert_values definition in DSv3RoutingGroundTruth class
sed -n '125,250p' tests/model_optimizations/test_dsv3_fused_routing.py | grep -n "ref_expert_values" -B2 -A2

Length of output: 502


🏁 Script executed:

#!/bin/bash
# Get the full init method of DSv3RoutingGroundTruth to see tensor allocation
sed -n '131,220p' tests/model_optimizations/test_dsv3_fused_routing.py

Length of output: 3622


Device mismatch in value validation when using boolean mask with CUDA tensors

In validate_values, tokens_to_check is a CPU bool tensor (line 373), while ground_truth.ref_expert_values and topk_values_kernel are CUDA tensors. Advanced indexing CUDA tensors with a CPU boolean mask (lines 381–382) will fail at runtime with a device mismatch error.

Fix by moving the tensors to CPU before applying the mask:

-    # Compare values
-    ref_values = ground_truth.ref_expert_values[tokens_to_check].float()
-    kernel_values = topk_values_kernel[tokens_to_check].float()
+    # Compare values on CPU to avoid device mismatch with the boolean mask
+    ref_values = ground_truth.ref_expert_values.cpu()[tokens_to_check].float()
+    kernel_values = topk_values_kernel.cpu()[tokens_to_check].float()

Also update lines 397–398 in the exception handler for consistency:

-            ref_vals = ground_truth.ref_expert_values[token_idx].float()
-            kernel_vals = topk_values_kernel[token_idx].float()
+            ref_vals = ground_truth.ref_expert_values[token_idx].float().cpu()
+            kernel_vals = topk_values_kernel[token_idx].float().cpu()
πŸ€– Prompt for AI Agents
In tests/model_optimizations/test_dsv3_fused_routing.py around lines 360–419,
the boolean mask tokens_to_check is created on CPU but used to index CUDA
tensors, causing a device mismatch; create the mask on the same device as the
tensors (e.g., device = topk_values_kernel.device; tokens_to_check =
torch.ones(num_tokens, dtype=torch.bool, device=device)) so indexing works, and
ensure values used for comparison/printing are moved to CPU for readability in
the exception handler (e.g., call .cpu() before .tolist() or printing) for lines
~397–398.
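A drop-in version of that alternative for the quoted validate_values (keeps all comparisons on the GPU instead of copying to CPU):

    # Create mask for tokens to check, on the same device as the kernel outputs
    device = topk_values_kernel.device
    tokens_to_check = torch.ones(num_tokens, dtype=torch.bool, device=device)
    for token_idx in tokens_to_skip:
        tokens_to_check[token_idx] = False

    if not tokens_to_check.any():
        return

    # Boolean indexing now stays on one device
    ref_values = ground_truth.ref_expert_values[tokens_to_check].float()
    kernel_values = topk_values_kernel[tokens_to_check].float()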

Comment on lines +421 to +454
@pytest.mark.parametrize("num_tokens", [1, 8, 16, 64])
@pytest.mark.parametrize("num_experts", [256, 384])
@pytest.mark.parametrize("topk", [1, 2, 4, 8])
@pytest.mark.parametrize("n_group", [1, 2, 4, 8])
@pytest.mark.parametrize("topk_group", [1, 2, 4, 8])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("bias_type", [torch.float32, torch.float16, torch.bfloat16])
def test_dsv3_fused_routing_op(
num_tokens, num_experts, topk, n_group, topk_group, data_type, bias_type
):
"""
Test NoAuxTc kernel against reference implementation.
Validates:
1. Expert selection equivalence (allowing for ties)
2. Value correctness within numerical precision tolerance
"""

# Skip invalid configurations
if topk_group * n_group < topk or topk_group > n_group:
pytest.skip(
"Invalid configuration: topk_group * n_group < topk or topk_group > n_group"
)
if n_group > 1:
if (
topk > 8
or num_experts / n_group > 32
or num_experts / n_group * topk_group > 128
):
pytest.skip("Invalid configuration: exceeds kernel limits for n_group > 1")
else:
if num_experts > 384 or topk > 8:
pytest.skip("Invalid configuration: exceeds kernel limits for n_group = 1")


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

Skip-condition logic may be stricter than necessary for valid configurations

The initial skip condition:

if topk_group * n_group < topk or topk_group > n_group:
    pytest.skip("Invalid configuration: topk_group * n_group < topk or topk_group > n_group")

uses topk_group * n_group to bound topk. Given the algorithm, the natural upper bound on the number of routable experts from the selected groups is experts_per_group * topk_group, not topk_group * n_group. As written, this skips configurations that seem valid from the algorithm’s perspective (e.g., n_group=1, topk_group=1, topk>1), thereby reducing coverage for β€œsingle-group, multi-expert” routing.

Consider tightening this to match the actual capacity constraint instead:

-    if topk_group * n_group < topk or topk_group > n_group:
-        pytest.skip(
-            "Invalid configuration: topk_group * n_group < topk or topk_group > n_group"
-        )
+    experts_per_group = num_experts // n_group
+    if topk_group * experts_per_group < topk or topk_group > n_group:
+        pytest.skip(
+            "Invalid configuration: not enough experts in selected groups "
+            "or topk_group > n_group"
+        )

This preserves the topk_group <= n_group constraint while not discarding otherwise valid and interesting configurations.


🏁 Script executed:

#!/bin/bash
# Inspect other uses of topk_group / n_group to confirm intended constraints.
rg -n "topk_group" -C3

Length of output: 93651


Correct and apply the skip-condition logic to match kernel constraints

The test's skip condition uses topk_group * n_group as an upper bound, but the kernel constraint requires topk < topk_group * (num_experts / n_group). The current logic rejects valid single-group multi-expert configurations (e.g., n_group=1, topk_group=1, topk>1) that the kernel actually supports.

Update line 440 to:

experts_per_group = num_experts // n_group
if topk_group * experts_per_group < topk or topk_group > n_group:
    pytest.skip(
        "Invalid configuration: not enough experts in selected groups "
        "or topk_group > n_group"
    )

This preserves the topk_group <= n_group constraint while allowing all algorithmically valid configurations and improving test coverage.

πŸ€– Prompt for AI Agents
In tests/model_optimizations/test_dsv3_fused_routing.py around lines 421 to 454,
the skip logic incorrectly uses topk_group * n_group to validate capacity and
rejects valid configs; compute experts_per_group = num_experts // n_group and
replace the current check with: if topk_group * experts_per_group < topk or
topk_group > n_group: pytest.skip("Invalid configuration: not enough experts in
selected groups or topk_group > n_group"); this preserves the topk_group <=
n_group rule and ensures the kernel constraint topk < topk_group * (num_experts
/ n_group) is enforced.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #38749791: 13/18 passed

@yzh119 yzh119 merged commit 3a23405 into flashinfer-ai:main Nov 19, 2025
4 checks passed