[DSV3] Optimized routing kernels dsv3 #2099
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough
Adds a CUDA Top-K reduction utility and a DSV3 deepseek fused-routing CUDA kernel with launcher and header, Python JIT build + bindings and package exports, a public NoAuxTc FFI entry, and a parameterized CUDA unit test validating the fused routing path.

Changes

Sequence Diagram(s)
sequenceDiagram
participant Py as Python caller
participant PyModule as fused_routing_dsv3.py
participant JIT as JIT build (gen_dsv3_fused_routing_module)
participant FFI as NoAuxTc FFI
participant Launcher as invokeNoAuxTc
participant Kernel as deepseek_v3_topk_kernel
participant TopK as reduce_topk (moeTopKFuncs.cuh)
Py->>PyModule: NoAuxTc(scores,bias,n_group,topk_group,topk,...)
alt module not built
PyModule->>JIT: build/compile JitSpec
JIT-->>PyModule: compiled module
end
PyModule->>FFI: call NoAuxTc (mutates outputs)
FFI->>Launcher: dtype dispatch -> invokeNoAuxTc<...>(...)
Launcher->>Kernel: launch kernel on CUDA stream
rect rgb(235,245,255)
Kernel->>Kernel: compute sigmoid(scores + bias)
Kernel->>TopK: call reduceTopK / reduceTopKFunc (warp-level)
TopK-->>Kernel: top-k values & indices
Kernel->>Kernel: normalize & store outputs
end
Kernel-->>FFI: outputs ready
FFI-->>Py: return (tensors mutated)
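To make the call flow above concrete, here is a hedged usage sketch from the Python caller's side. Shapes, dtypes, and the scaling factor are illustrative; the output-tensor contract (topk_values matching the scores dtype, topk_indices as int32) follows the checks discussed later in this review rather than any official documentation.

import torch
from flashinfer.dsv3_ops import NoAuxTc

num_tokens, num_experts, n_group, topk_group, topk = 16, 256, 8, 4, 8

# Router logits and per-expert bias (illustrative values).
scores = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(num_experts, device="cuda", dtype=torch.bfloat16)

# Pre-allocated outputs; the kernel mutates them in place.
topk_values = torch.empty(num_tokens, topk, device="cuda", dtype=scores.dtype)
topk_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)

NoAuxTc(
    scores, bias, n_group, topk_group, topk,
    routed_scaling_factor=2.5,
    topk_values=topk_values,
    topk_indices=topk_indices,
    launch_with_pdl=True,
)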
Estimated code review effort: 4 (Complex) | ~45 minutes
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches: Failed checks (2 warnings, 1 inconclusive)
Finishing touches
Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Summary of Changes

Hello @nv-yunzheq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the routing capabilities for DeepSeek-V3 Mixture-of-Experts (MoE) models within the FlashInfer library. It introduces highly optimized CUDA kernels for Top-K selection and expert routing, designed to improve performance by leveraging GPU-specific features and efficient reduction strategies. The changes include new C++ CUDA files for core logic, a utility for architecture-specific compilation, and Python bindings to expose this functionality, alongside a dedicated test suite to ensure accuracy.

Highlights
Code Review
This pull request introduces optimized CUDA kernels for DeepSeek V3-style MoE routing, including helper functions for Top-K reduction and a new JIT-compiled operator NoAuxTc. The changes are well-structured and include new Python wrappers and tests. My review focuses on potential issues in the CUDA kernel logic, leftover development comments, and opportunities to improve test coverage. I've identified a potential bug in a loop bound calculation, some leftover debug code, and suggested expanding the test suite to cover more execution paths.
csrc/fused_moe/moeTopKFuncs.cuh
Outdated
RedType topK{value, idx};
typename RedType::TypeCmp packedMax{};
#pragma unroll
for (int kk = 0; kk < actualK; ++kk)  //@todo: check if actualK is correct
Actionable comments posted: 5
Nitpick comments (2)
tests/model_optimizations/test_dsv3_fused_routing.py (1)
1-69: Test logic matches kernel behavior; consider a few robustness and style tweaks
The reference dsv3_ref_check closely mirrors the CUDA kernel's sigmoid → bias → group top-K → per-expert top-K → normalization flow and looks correct for the tested configuration.

To avoid failures on environments without CUDA, you may want to guard the test with torch.cuda.is_available() and pytest.skip before calling get_compute_capability(torch.device("cuda")); a sketch of such a guard follows below.

Currently only n_group=1, topk_group=1, topk=1 is exercised. If you expect multi-group usage, adding a second parametrization (e.g., n_group > 1, topk_group > 1) would give better coverage of the is_multi_group path.

Ruff's RUF005 suggestions around scores_shape[:-1] + [...] and similar concatenations are purely stylistic; if you want to quiet the linter, you can switch to iterable unpacking, e.g.: scores_with_bias.view(*scores_shape[:-1], n_group, scores_shape[-1] // n_group)

Overall, the test looks solid; these are optional cleanups/robustness improvements.
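A minimal sketch of that guard, assuming the test module's existing imports; the exact wording and placement are up to the test author:

import pytest
import torch

# Skip the whole module when CUDA is unavailable, before any CUDA helpers run.
if not torch.cuda.is_available():
    pytest.skip("CUDA is required for the DSv3 fused routing tests", allow_module_level=True)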
csrc/fused_moe/moeTopKFuncs.cuh (1)
47-69: Clarify/index-guard the 16-bit index packing (kMaxIdx = 65535).
TopKRedType encodes the index into the lower 16 bits of TypeCmp:

static constexpr int kMaxIdx = 65535;
// ...
compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx));
// ...
index = kMaxIdx - static_cast<int32_t>((cmp & 0xFFFF));

This silently assumes idx is in [0, 65535]. If a caller ever passes an index beyond 65535, the packed index will wrap and unpack will return an incorrect index without any diagnostics.

I suggest either:
- enforcing the assumption with a debug-time check, or
- documenting it clearly next to kMaxIdx, so future callers know the constraint.

For example:

- static constexpr int kMaxIdx = 65535;
+ // Indices must be in [0, 65535]; stored in the lower 16 bits of compValIdx.
+ static constexpr int kMaxIdx = 65535;

and, optionally, in a debug build:

#ifdef DEBUG
  assert(idx >= 0 && idx <= kMaxIdx);
#endif
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Files selected for processing (10)
- csrc/fused_moe/moeTopKFuncs.cuh (1 hunks)
- csrc/fused_moe/noAuxTcKernels.cu (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/archCondition.h (1 hunks)
- flashinfer/dsv3_ops/__init__.py (1 hunks)
- flashinfer/fused_moe/__init__.py (2 hunks)
- flashinfer/fused_moe/fused_routing_dsv3.py (1 hunks)
- flashinfer/jit/__init__.py (1 hunks)
- flashinfer/jit/dsv3_optimizations.py (1 hunks)
- include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h (1 hunks)
- tests/model_optimizations/test_dsv3_fused_routing.py (1 hunks)
Additional context used
Code graph analysis (9)
flashinfer/fused_moe/__init__.py (2)
csrc/fused_moe/noAuxTcKernels.cu (2)
NoAuxTc(327-447)NoAuxTc(327-327)flashinfer/fused_moe/fused_routing_dsv3.py (2)
NoAuxTc(19-20)NoAuxTc(26-27)
csrc/fused_moe/noAuxTcKernels.cu (3)
include/flashinfer/trtllm/fused_moe/runner.h (3)
num_experts(263-263)n_group(271-271)topk_group(273-273)csrc/tvm_ffi_utils.h (2)
encode_dlpack_dtype(29-31)get_stream(272-274)flashinfer/fused_moe/fused_routing_dsv3.py (2)
NoAuxTc(19-20)NoAuxTc(26-27)
flashinfer/dsv3_ops/__init__.py (2)
csrc/fused_moe/noAuxTcKernels.cu (2)
NoAuxTc(327-447)NoAuxTc(327-327)flashinfer/fused_moe/fused_routing_dsv3.py (2)
NoAuxTc(19-20)NoAuxTc(26-27)
flashinfer/jit/dsv3_optimizations.py (1)
flashinfer/jit/core.py (2)
JitSpec(213-312)gen_jit_spec(315-381)
tests/model_optimizations/test_dsv3_fused_routing.py (3)
csrc/fused_moe/noAuxTcKernels.cu (2)
NoAuxTc(327-447)NoAuxTc(327-327)flashinfer/fused_moe/fused_routing_dsv3.py (2)
NoAuxTc(19-20)NoAuxTc(26-27)flashinfer/utils.py (1)
get_compute_capability(252-255)
include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h (2)
csrc/fused_moe/noAuxTcKernels.cu (3)
void(28-242)invokeNoAuxTc(245-298)invokeNoAuxTc(245-247)include/flashinfer/trtllm/fused_moe/runner.h (3)
num_experts(263-263)n_group(271-271)topk_group(273-273)
flashinfer/jit/__init__.py (1)
flashinfer/jit/dsv3_optimizations.py (1)
gen_dsv3_fused_routing_module(14-45)
csrc/nv_internal/tensorrt_llm/kernels/archCondition.h (1)
include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h (1)
tensorrt_llm(25-33)
flashinfer/fused_moe/fused_routing_dsv3.py (3)
flashinfer/jit/dsv3_optimizations.py (1)
gen_dsv3_fused_routing_module(14-45)flashinfer/jit/core.py (1)
build_and_load(300-312)csrc/fused_moe/noAuxTcKernels.cu (2)
NoAuxTc(327-447)NoAuxTc(327-327)
Clang (14.0.6)
include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h
[error] 20-20: 'cuda_bf16.h' file not found
(clang-diagnostic-error)
csrc/nv_internal/tensorrt_llm/kernels/archCondition.h
[error] 19-19: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 19-19: expected ';' after top level declarator
(clang-diagnostic-error)
[error] 19-19: expected identifier or '('
(clang-diagnostic-error)
GitHub Actions: pre-commit
csrc/fused_moe/noAuxTcKernels.cu
[error] 1-1: clang-format formatting failed. Run 'clang-format' to fix code style issues in this file.
tests/model_optimizations/test_dsv3_fused_routing.py
[warning] 1-1: pre-commit: mixed line ending detected; hooks may modify files on re-run
include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h
[error] 1-1: clang-format formatting failed. Run 'clang-format' to fix code style issues in this file.
[warning] 1-1: pre-commit: mixed line ending detected; hooks may modify files on re-run
csrc/nv_internal/tensorrt_llm/kernels/archCondition.h
[error] 1-1: clang-format formatting failed. Run 'clang-format' to fix code style issues in this file.
flashinfer/fused_moe/fused_routing_dsv3.py
[error] 4-4: ruff: F401 'torch' imported but unused
[error] 5-5: ruff: F401 'supported_compute_capability' imported but unused
[error] 1-1: ruff: F401 'torch' imported but unused
[error] 1-1: ruff: F401 'supported_compute_capability' imported but unused
[error] 1-1: end-of-file-fixer hook: files modified
[error] 1-1: Trailing whitespace: files modified
[error] 1-1: clang-format: some files were reformatted by clang-format hook
csrc/fused_moe/moeTopKFuncs.cuh
[error] 1-1: clang-format formatting failed. Run 'clang-format' to fix code style issues in this file.
Ruff (0.14.5)
flashinfer/fused_moe/__init__.py
35-35: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tests/model_optimizations/test_dsv3_fused_routing.py
13-14: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
28-29: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
Additional comments (6)
flashinfer/jit/dsv3_optimizations.py (1)
14-45: JIT spec wiring for dsv3_fused_routing looks consistent; just verify includes

The new gen_dsv3_fused_routing_module spec and source list line up with the CUDA/C++ implementation and nv_internal helpers, and the include roots look reasonable for the tensorrt_llm headers. Please just double-check that your default JIT includes plus this extra_include_paths set are sufficient for all of:
- tensorrt_llm/common/*
- flashinfer/trtllm/fused_moe/noAuxTcKernels.h
- any cutlass kernels headers
so NVCC/clang don't fail with missing headers depending on the environment.
flashinfer/dsv3_ops/__init__.py (1)
2-7:NoAuxTcreβexport fromfused_moelooks goodReβexporting
NoAuxTcinflashinfer.dsv3_opsis consistent with the fusedβrouting API surface and the tests that import from this namespace. No issues from the snippet.flashinfer/jit/__init__.py (1)
79-84: JIT API reβexport is consistentReβexporting
gen_dsv3_fused_routing_modulehere matches the existing pattern for other JIT specs (e.g.,gen_dsv3_router_gemm_module) and cleanly exposes the new fused routing module.flashinfer/fused_moe/fused_routing_dsv3.py (1)
1-27: Remove unused imports to fix Ruff F401 errorsThe lazy JIT build via
get_dsv3_fused_routing_moduleandregister_custom_op("flashinfer::NoAuxTc", mutates_args=...)registration are solid. However, verification confirms the unused imports are blocking pre-commit:
torch(line 4) is imported but never usedsupported_compute_capabilityandbackend_requirement(lines 7-8) are imported but never usedRemove these three unused imports:
-from flashinfer.jit import gen_dsv3_fused_routing_module -import functools -from types import SimpleNamespace -import torch -from flashinfer.utils import ( - register_custom_op, - supported_compute_capability, - backend_requirement, -) +from flashinfer.jit import gen_dsv3_fused_routing_module +import functools +from types import SimpleNamespace +from flashinfer.utils import register_custom_opThen re-run pre-commit to resolve remaining formatting issues.
csrc/fused_moe/moeTopKFuncs.cuh (1)
1-286: The clang-format fix has been successfully applied.The file
csrc/fused_moe/moeTopKFuncs.cuhnow passes clang-format validation (exit code 0 confirms no formatting errors). The pre-commit style check requirement has been satisfied.csrc/nv_internal/tensorrt_llm/kernels/archCondition.h (1)
1-114: Clang-format check passes.The file has been successfully formatted and now conforms to clang-format style requirements. The pre-commit formatting check should no longer fail.
#pragma once

namespace tensorrt_llm::kernels
{

namespace detail
{
Refactor suggestion | Major
Make the header self-contained by including <type_traits>.
std::bool_constant (used from Line 75 onward) is defined in <type_traits>, but this header does not include it directly. That makes archCondition.h fragile and dependent on transitive includes from its consumers.
I recommend explicitly including <type_traits> here so this header can be used in isolation without relying on include order.
-#pragma once
-
- namespace tensorrt_llm::kernels
+#pragma once
+
+#include <type_traits>
+
+namespace tensorrt_llm::kernels
{Also applies to: 75-96
Tools
Clang (14.0.6)
[error] 19-19: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 19-19: expected ';' after top level declarator
(clang-diagnostic-error)
[error] 19-19: expected identifier or '('
(clang-diagnostic-error)
Prompt for AI Agents
In csrc/nv_internal/tensorrt_llm/kernels/archCondition.h around lines 17-23 (and
affecting lines 75-96), the header uses std::bool_constant but doesn't include
<type_traits>, making it rely on transitive includes; add a direct #include
<type_traits> near the top of the file (with other includes/pragma once) so the
header is self-contained and rebuilds cleanly without depending on include
order.
template <int Arch>
struct is_match : std::bool_constant<is_device::value && detail::arch_info::mArch == Arch>
{
};

template <int Major>
struct is_major : std::bool_constant<is_device::value && detail::arch_info::mMajor == Major>
{
};

template <int Arch>
struct is_compatible : std::bool_constant<is_major<Arch>::value && detail::arch_info::mArch >= Arch>
{
};
is_compatible trait likely has incorrect semantics.
As written:
template <int Arch>
struct is_compatible : std::bool_constant<is_major<Arch>::value && detail::arch_info::mArch >= Arch> {};you are treating the template parameter Arch both as a major (for is_major<Arch>) and as the mArch βSM codeβ (80, 90, 100, ...). For example, on SM90 (mMajor = 9, mArch = 90), is_compatible<90> will always be false because is_major<90> can never be true.
If the intent is "compilation device SM is at least Arch (80, 90, 100, ...) and we are on device", a cleaner definition is:
-template <int Arch>
-struct is_compatible : std::bool_constant<is_major<Arch>::value && detail::arch_info::mArch >= Arch>
-{
-};
+template <int Arch>
+struct is_compatible
+ : std::bool_constant<is_device::value && detail::arch_info::mArch >= Arch>
+{
This makes is_compatible_v<90> read as "compiled for device and mArch >= 90", which matches the naming.
from .fused_routing_dsv3 import (  # noqa: F401
    NoAuxTc as NoAuxTc,
)
Analysis chain
Remove unused noqa and simplify NoAuxTc re-export
Ruff reports the # noqa: F401 as unused. You can simplify the import and satisfy the linter without changing behavior:
-from .fused_routing_dsv3 import ( # noqa: F401
- NoAuxTc as NoAuxTc,
-)
+from .fused_routing_dsv3 import NoAuxTcNoAuxTc is still exported via __all__, so thereβs no functional change.
Script executed:
#!/bin/bash
# Check the current state of __init__.py around lines 35-37
cat flashinfer/fused_moe/__init__.py | head -50Length of output: 1504
Script executed:
#!/bin/bash
# Search for __all__ definition in the file
grep -n "__all__" flashinfer/fused_moe/__init__.pyLength of output: 14
Script executed:
#!/bin/bash
# Search for NoAuxTc usage in the codebase
rg -n "NoAuxTc" --type pyLength of output: 1381
Script executed:
#!/bin/bash
# Check if fused_routing_dsv3 module exists and what it exports
fd "fused_routing_dsv3" --type pyLength of output: 169
Script executed:
#!/bin/bash
# Get the full __all__ list
sed -n '39,60p' flashinfer/fused_moe/__init__.pyLength of output: 529
Simplify NoAuxTc re-export by removing redundant pattern
The refactoring is correct. NoAuxTc is confirmed in __all__ at line 55, so the simplified import form maintains the same export behavior while removing unnecessary redundancy:
-from .fused_routing_dsv3 import ( # noqa: F401
- NoAuxTc as NoAuxTc,
-)
+from .fused_routing_dsv3 import NoAuxTcThis change is safe and recommended.
Committable suggestion
IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
- from .fused_routing_dsv3 import (  # noqa: F401
-     NoAuxTc as NoAuxTc,
- )
+ from .fused_routing_dsv3 import NoAuxTc
Tools
Ruff (0.14.5)
35-35: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
Prompt for AI Agents
In flashinfer/fused_moe/__init__.py around lines 35 to 37, the current re-export
unnecessarily aliases NoAuxTc via "NoAuxTc as NoAuxTc"; replace that with a
direct import of NoAuxTc from .fused_routing_dsv3 (e.g., "from
.fused_routing_dsv3 import NoAuxTc") to simplify the code while preserving the
existing export behavior already declared in __all__.
cc @ChristinaZ for vis
Actionable comments posted: 1
Duplicate comments (8)
csrc/fused_moe/noAuxTcKernels.cu (3)
337-340: Clarify thetopk_indicesdtype error message.You enforce
topk_indicesto be int32:TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code) << "topk_indices must have the same dtype as scores";but the message says βsame dtype as scoresβ, which is misleading (scores are float/bfloat16). Updating the message improves debuggability:
- TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code) - << "topk_indices must have the same dtype as scores"; + TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code) + << "topk_indices must have dtype int32";
94-103: Enforce topk <= MaxNumTopExperts to avoid out-of-bounds accesses.
topScores/topExpertsand related buffers are sized withMaxNumTopExperts = 8:float topScores[MaxNumTopExperts]; // bound of topk int32_t topExperts[MaxNumTopExperts];but are indexed with
laneIdx < topkandkk < topk(viareduceTopK(..., topk)) and written into shared memory whenlaneIdx < topk. The host only checkstopk <= 32, sotopk > MaxNumTopExpertswill:
- Overrun
topScores/topExpertsinsidereduceTopK(loopkk < actualK), and- Overrun them again in the final writeback (
expertIdx = laneIdx < topk ? topExperts[laneIdx] : ...).At minimum, the host entry should enforce the tighter bound:
- TVM_FFI_ICHECK(topk <= 32) - << "topk should be smaller than or equal to 32 for now"; //@todo: remove this restriction - // later + TVM_FFI_ICHECK(topk <= tensorrt_llm::kernels::MaxNumTopExperts) + << "topk must not exceed " << tensorrt_llm::kernels::MaxNumTopExperts + << " for the optimized DSv3 fused routing kernel";If you intend to support larger
topk, youβll need to increaseMaxNumTopExperts(and adjust shared/register usage and thereduceTopKtemplateK) instead of only loosening the host check.Also applies to: 138-141, 162-165, 201-212
169-182: FixNumInterTopKPerThreadand initialise intermediate buffers to avoid using uninitialized values.In the multiβwarp, noβgroups path you currently have:
int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WARP_SIZE + 1; float intermidiateScore[NumInterTopKPerThread]; int32_t intermidiateExpert[NumInterTopKPerThread]; for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; i += WARP_SIZE) { int ii = i / WARP_SIZE; if (i < NumInterTopK) { intermidiateScore[ii] = smemInterTopScores[i]; intermidiateExpert[ii] = smemInterTopExperts[i]; } else { intermidiateScore[ii] = invalidScoreFloat; intermidiateExpert[ii] = MaxNumExperts - 1; } }Given
NumInterTopK = NumExpertWarps * MaxNumTopExpertsandMaxNumExperts β€ 384, this formula makesNumInterTopKPerThreadlarger thanceil(NumInterTopK / WARP_SIZE), so someintermidiateScore[ii]/intermidiateExpert[ii]entries are never touched on some lanes, yetreduceTopKFuncis called withN = NumInterTopKPerThreadand will read all of them, leading to undefined behaviour.A safer and simpler formulation is:
- int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WARP_SIZE + 1; - float intermidiateScore[NumInterTopKPerThread]; - int32_t intermidiateExpert[NumInterTopKPerThread]; + int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1; + float intermidiateScore[NumInterTopKPerThread]; + int32_t intermidiateExpert[NumInterTopKPerThread]; + #pragma unroll + for (int ii = 0; ii < NumInterTopKPerThread; ++ii) { + intermidiateScore[ii] = invalidScoreFloat; + intermidiateExpert[ii] = MaxNumExperts - 1; + }This matches the intended βceil(NumInterTopK / WARP_SIZE)β perβthread budget and ensures every slot has a wellβdefined sentinel value before it is used in the subsequent reduction.
csrc/fused_moe/moeTopKFuncs.cuh (1)
142-159: Clean up TODOs inreduceTopKand document the invariants instead.There are a couple of lingering TODOs:
for (int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct ... topKBufferIdx[ii] = ii * kWARP_SIZE - 1; //@todo: check if this is correctThe surrounding logic for
actualKand the sentinelii * kWARP_SIZE - 1looks intentional and stable now. Leaving these TODOs in place suggests the implementation is still suspect, which makes maintenance harder.Either remove these TODOs or replace them with brief explanatory comments (e.g., why
actualKcan be less thanK, and how the-1index acts as a safe sentinel that can never collide with a valid index) so future readers donβt have to reβaudit the algorithm.Also applies to: 219-222
csrc/nv_internal/tensorrt_llm/kernels/archCondition.h (2)
17-21: Include<type_traits>so the trait types are defined and the header is selfβcontained.This header uses
std::bool_constantin thearchtraits but does not include<type_traits>, so it relies on transitive includes and may fail to compile when included first in a TU. Please add an explicit include near the top:-#pragma once - -namespace tensorrt_llm::kernels { +#pragma once + +#include <type_traits> + +namespace tensorrt_llm::kernels {This keeps the header robust and independent of include order.
Also applies to: 71-74
81-84:is_compatiblemixes major and SM codes; semantics are likely incorrect.
is_compatiblecurrently does:template <int Arch> struct is_compatible : std::bool_constant<is_major<Arch>::value && detail::arch_info::mArch >= Arch> {};Here
Archis used both as a major foris_major<Arch>and as an SM code (e.g., 80, 90, 100) formArch >= Arch. On SM90 (mMajor = 9,mArch = 90),is_compatible<90>will always be false becauseis_major<90>can never be true.If the intended meaning is βcompiled for device and
mArch >= Arch(whereArchis 80/90/100...)β, consider:template <int Arch> -struct is_compatible - : std::bool_constant<is_major<Arch>::value && detail::arch_info::mArch >= Arch> {}; +struct is_compatible + : std::bool_constant<is_device::value && detail::arch_info::mArch >= Arch> {};This makes
is_compatible_v<90>read as βdevice build and SM >= 90β, which matches the name.include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h (2)
27-31: FixinvokeNoAuxTcdeclaration to match the.cudefinition (missinglaunch_with_pdl).The implementation in
csrc/fused_moe/noAuxTcKernels.cutakes an extrabool launch_with_pdlparameter, but the header omits it. This will cause conflicting declarations/ODR issues once both are seen.Please align the declaration with the definition:
template <typename InputT, typename BiasT, typename OutputT, typename IdxT> -void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk_indices, - int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, - int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, - cudaStream_t const stream = 0); +void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk_indices, + int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, + int64_t const topk_group, int64_t const topk, + double const routed_scaling_factor, bool const launch_with_pdl, + cudaStream_t const stream = 0);(The default for
streamcan stay only in the declaration.)
20-21: Remove unnecessary<cuda_bf16.h>from this public header and include it only where needed.This header itself never references
__nv_bfloat16, but including<cuda_bf16.h>here causes clang errors in environments without CUDA headers and increases header dependencies.Since the BF16 specializations are instantiated in
csrc/fused_moe/noAuxTcKernels.cu, itβs cleaner to:-#include <cuda_bf16.h> -#include <cuda_fp16.h> +#include <cuda_fp16.h>and add:
#include <cuda_bf16.h>near the top of
csrc/fused_moe/noAuxTcKernels.cubefore the BF16 instantiations.This keeps the public header lighter and avoids toolchain issues when CUDA headers are not configured for host-only builds.
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
π Files selected for processing (5)
csrc/fused_moe/moeTopKFuncs.cuh(1 hunks)csrc/fused_moe/noAuxTcKernels.cu(1 hunks)csrc/nv_internal/tensorrt_llm/kernels/archCondition.h(1 hunks)flashinfer/fused_moe/fused_routing_dsv3.py(1 hunks)include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h(1 hunks)
π§ Files skipped from review as they are similar to previous changes (1)
- flashinfer/fused_moe/fused_routing_dsv3.py
π§° Additional context used
π§ Learnings (1)
π Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
csrc/fused_moe/noAuxTcKernels.cu
𧬠Code graph analysis (3)
csrc/fused_moe/noAuxTcKernels.cu (3)
include/flashinfer/trtllm/fused_moe/runner.h (3)
num_experts(263-263)n_group(271-271)topk_group(273-273)csrc/tvm_ffi_utils.h (2)
encode_dlpack_dtype(29-31)get_stream(272-274)flashinfer/fused_moe/fused_routing_dsv3.py (2)
NoAuxTc(19-40)NoAuxTc(47-68)
csrc/nv_internal/tensorrt_llm/kernels/archCondition.h (1)
include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h (1)
tensorrt_llm(25-33)
include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h (2)
csrc/fused_moe/noAuxTcKernels.cu (3)
void(26-218)invokeNoAuxTc(221-272)invokeNoAuxTc(221-224)include/flashinfer/trtllm/fused_moe/runner.h (3)
num_experts(263-263)n_group(271-271)topk_group(273-273)
Clang (14.0.6)
csrc/nv_internal/tensorrt_llm/kernels/archCondition.h
[error] 19-19: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 19-19: expected ';' after top level declarator
(clang-diagnostic-error)
[error] 19-19: expected identifier or '('
(clang-diagnostic-error)
include/flashinfer/trtllm/fused_moe/noAuxTcKernels.h
[error] 20-20: 'cuda_bf16.h' file not found
(clang-diagnostic-error)
Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
Actionable comments posted: 0
Duplicate comments (1)
csrc/fused_moe/moeTopKFuncs.cuh (1)
152-152: Resolve or replace TODOs with explicit invariants / commentsThere are two
//@todonotes that look like leftover uncertainty:
- Line 152:
//@todo: check if actualK is correctin the scalarreduceTopKloop.- Line 221:
//@todo: check if this is correctfor initializingtopKBufferIdx[ii] = ii * kWARP_SIZE - 1.If these code paths have now been validated (e.g., by parity with the reference implementation and tests), consider:
- Removing the
//@todomarkers, and- Replacing them with brief comments documenting the intended invariants (e.g.,
actualK <= Kand why the-1sentinel index is safe when packing/unpacking).That keeps future readers from wondering whether there is still an open correctness question here.
Also applies to: 221-221
Nitpick comments (4)
csrc/fused_moe/moeTopKFuncs.cuh (4)
35-84: TopKRedType packing/reduction logic looks solid, with only minor aliasing caveatsThe value/index packing, deterministic tie-breaking via
kMaxIdx - idx, and warp reduction path (including the fastredux.sync.max.u32specialization) all look coherent and in line with typical TopβK implementations. The only nit is that thereinterpret_cast-based conversions inmakeCmpVal/unpackrely on CUBβsTraits<T>patterns and are a bit aggressive from a strict C++ aliasing perspective; if this ever needs to be made more portable, usingmemcpyor an explicitly-sized intermediate would be safer, but itβs probably fine given CUDA/CUB constraints.
142-159: Guard againstactualK > Kor document the invariant for reduceTopK (scalar input)
reduceTopKwrites toout[kk]/outIdx[kk]forkk < actualK, but the arrays are sized asK. IfactualKcan ever exceedKat runtime, this will lead to outβofβbounds writes on the stack. If the intended contract is0 < actualK <= K, consider either:
- Clamping:
int kEff = min(actualK, K);and looping tokEff, or- Adding an explicit runtime check / debug assert, or at least documenting clearly that callers must enforce
actualK <= K.This also applies to the higherβlevel overloads that forward
actualKunchanged.
161-193: Clarify N constraint message and reuseactualKinvariant for reduceTopKFuncThe
reduceTopKFunc<K, Type, N, IsSorted>implementation looks consistent with the scalarreduceTopK(shift-out of the previous max viaupdate, and incremental warp reductions). Two minor points:
- The static assert
static_assert(N < 5, "Only support candidates number less than or equal to 128");is slightly confusing because the template constraint is per-thread (N <= 4), while the message refers to the per-warp total (β€ 4 * 32). Consider tightening the message to reflect that ("N per thread must be <= 4 (<= 128 per warp)") to avoid ambiguity.- The same potential
actualK > Khazard exists here as in the scalar version, sinceout/outIdxare sized byKbut indexed up toactualK - 1. Any solution you adopt for the scalar overload (clamping, assert, or documenting a hard invariant) should be mirrored here for consistency.
195-247: Multi-chunk (N > 4) path: lane indexing andnumResultsusage could be tightenedThe multi-chunk overload
reduceTopK<K, Type, N>generally makes sense (processing 4-candidate blocks and then reducing partial TopβK results), but a few details are worth tightening:
- Lane index source:
laneIdxis computed asthreadIdx.x % kWARP_SIZE. Since you already have acg::thread_block_tile<kWARP_SIZE> warp, it would be more robust to usewarp.thread_rank()instead of assuming the tile corresponds tothreadIdx.xmodulo 32. This avoids surprises if the warp tile is ever created from a nontrivial block layout.- Unused
numResultsslots:topKBufferValue/topKBufferIdxare sized asnumResults, but only indices0and (in the last loop)1are ever written. For combinations wherenumResults > 2(e.g., largerKwithN = 12 or 16), the extra slots remain at the sentinelminValue/ii * kWARP_SIZE - 1. Thatβs functionally safe but suggests either:
- You never intend to support such
(N, K)combinations and could enforce it with a static_assert onnumLoops * K(or directly onK), or- The distribution logic could be extended to actually populate all
numResultsentries so that no partial TopβK candidates are dropped whennumLoops * K > 64.Clarifying which of these is intended (and possibly adding a static_assert or brief comment) would make the behavior easier to reason about.
/bot run
Actionable comments posted: 0
Duplicate comments (1)
tests/model_optimizations/test_dsv3_fused_routing.py (1)
47-49: Narrow test coverage for multi-group routing.

The test currently only exercises the single-group case (n_group=1). As noted in previous reviews, coverage should be expanded to include multi-group routing paths, different topk values, and varying num_experts configurations (see the sketch below).
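A hedged sketch of what a broader parametrization could look like; the specific value combinations are illustrative, not a validated support matrix:

import pytest

# Exercise single-group and multi-group routing, plus different expert counts.
@pytest.mark.parametrize("num_experts", [256, 384])
@pytest.mark.parametrize(
    "n_group,topk_group,topk",
    [(1, 1, 1), (8, 4, 8)],
)
def test_dsv3_fused_routing_configs(num_experts, n_group, topk_group, topk):
    ...  # build scores/bias and call NoAuxTc as in the existing test body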
Nitpick comments (2)
tests/model_optimizations/test_dsv3_fused_routing.py (2)
8-44: Consider documenting the DSv3-specific constant.The hardcoded
k=2on line 17 appears to be a DSv3-specific constant for the group scoring mechanism. Adding a comment or constant definition would improve clarity for future maintainers.Optionally, you can also address the static analysis hints by replacing list concatenation with iterable unpacking on lines 15 and 31:
- scores_shape[:-1] + [n_group, scores_shape[-1] // n_group] + [*scores_shape[:-1], n_group, scores_shape[-1] // n_group]
61-62: Consider usingtorch.emptyfor output buffers.Since
topk_valuesandtopk_indicesare output tensors that will be fully overwritten by the kernel, usingtorch.emptyinstead oftorch.randnwould better communicate intent and avoid unnecessary random initialization.- topk_values = torch.randn(num_tokens, topk, device="cuda", dtype=torch.float32) - topk_indices = torch.randn(num_tokens, topk, device="cuda").to(torch.int32) + topk_values = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32) + topk_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
π Files selected for processing (1)
tests/model_optimizations/test_dsv3_fused_routing.py(1 hunks)
π§° Additional context used
𧬠Code graph analysis (1)
tests/model_optimizations/test_dsv3_fused_routing.py (3)
flashinfer/fused_moe/fused_routing_dsv3.py (2)
NoAuxTc(19-40)NoAuxTc(47-68)csrc/fused_moe/noAuxTcKernels.cu (2)
NoAuxTc(303-445)NoAuxTc(303-305)flashinfer/utils.py (1)
get_compute_capability(252-255)
πͺ Ruff (0.14.5)
tests/model_optimizations/test_dsv3_fused_routing.py
15-15: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
31-31: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
Additional comments (4)
tests/model_optimizations/test_dsv3_fused_routing.py (4)
1-5: LGTM!All imports are necessary and correctly used throughout the test.
50-53: LGTM!The compute capability check correctly restricts the test to SM100 hardware as required for DSv3 fused routing.
64-74: LGTM!The kernel call correctly passes all required parameters and uses the PDL launch path as intended.
76-80: LGTM!The validation correctly compares kernel outputs against the reference implementation with appropriate tolerances for floating-point values and proper type casting for index comparison.
[FAILED] Pipeline #38685617: 4/18 passed

/bot run
Actionable comments posted: 1
Duplicate comments (2)
csrc/fused_moe/noAuxTcKernels.cu (2)
16-23: Enforcetopkandn_groupbounds at the C++ boundary to avoid kernel OOBThe kernel hardcodes several capacity limits:
static constexpr int MaxNumTopExperts = 8;static constexpr int MaxNumTopGroups = 4;int constexpr NumWarps = MaxNumExperts / WARP_SIZE;withMaxNumExpertsinstantiated asNumDeepseekExperts = 256for the grouped path, soNumWarps == 8.float topScores[MaxNumTopExperts]; int32_t topExperts[MaxNumTopExperts];__shared__ float smemGroupScores[NumWarps];But the host FFI entry only checks:
TVM_FFI_ICHECK(n_group <= 32) << "..."; TVM_FFI_ICHECK(topk <= 32) << "...";Two concrete issues follow:
Grouped path
n_group > 8
For the grouped instantiation (UseGroups = truewithMaxNumExperts = NumDeepseekExperts):
smemGroupScoreshas length 8, but group selection at warp 0 readssmemGroupScores[laneIdx]forlaneIdx < numGroup. Withn_group > 8, this reads past the shared array.- Only 8 warps exist, so groups beyond 8 also never get a proper group score.
topk > MaxNumTopExperts
The final selection usestopScores[MaxNumTopExperts]/topExperts[MaxNumTopExperts]andlaneIdx < topkto index and write outputs. Allowingtopkup to 32 while the arrays are sized to 8 is inconsistent with the kernelβs capacity and risks undefined behavior, depending onreduce_topkβs implementation.Pythonβs
_check_dsv3_fused_routing_supportedcurrently restricts DSv3 totopk <= 8, but this C++NoAuxTccan be reached from other frontends and should enforce the same invariants.Consider tightening the host checks along these lines:
- TVM_FFI_ICHECK(n_group <= 32) - << "n_group should be smaller than or equal to 32 for now"; - TVM_FFI_ICHECK(topk <= 32) - << "topk should be smaller than or equal to 32 for now"; + TVM_FFI_ICHECK(n_group <= NumDeepseekExperts / WARP_SIZE) + << "n_group must be <= " << (NumDeepseekExperts / WARP_SIZE) + << " for the optimized DSv3 fused routing kernel"; + TVM_FFI_ICHECK(topk <= MaxNumTopExperts) + << "topk must be <= " << MaxNumTopExperts + << " for the optimized DSv3 fused routing kernel";and then keeping the more relaxed 32 limit only if/when a generic fallback path (without these static buffers) is implemented.
Also applies to: 96-105, 120-145, 204-215, 323-329
340-343: Clarifytopk_indicesdtype error message to match the checkHere:
TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code) << "topk_indices must have the same dtype as scores";the check enforces that
topk_indicesis int32, not βsame dtype as scoresβ. The mismatched message can make debugging confusing (especially since the Python docstring also mentions int64 as an option).Suggest tightening the message:
- TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code) - << "topk_indices must have the same dtype as scores"; + TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code) + << "topk_indices must have dtype int32";
Nitpick comments (3)
flashinfer/fused_moe/fused_routing_dsv3.py (1)
12-23: Tighten Python precheck to mirror C++ kernel invariants and silence unused-arg lint

The configuration checks are aligned with the optimized kernel's assumptions (group product limits, topk <= 8, experts/group <= 32), but two small gaps remain (a sketch of a tightened precheck follows below):
- The C++ entry point enforces num_experts % n_group == 0; encoding the same constraint here (e.g., via scores.shape[1] % n_group == 0) would fail fast in Python instead of surfacing as a TVM assertion later.
- For multi-group configurations, the kernel effectively only supports up to NumDeepseekExperts / WARP_SIZE groups (8 for 256 experts). If you intend to keep that invariant, consider also enforcing n_group <= 8 here so Python callers can't create configs that the CUDA path can't handle.

Also, _check_dsv3_fused_routing_supported must accept the full NoAuxTc signature for backend_requirement, but bias, routed_scaling_factor, topk_values, topk_indices, and launch_with_pdl are unused. If Ruff's ARG001 is noisy, you can explicitly mark them as unused, e.g.:

# Unused but required by backend_requirement signature
_ = (bias, routed_scaling_factor, topk_values, topk_indices, launch_with_pdl)

Also applies to: 40-79
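A minimal sketch of such a tightened precheck, under the assumption that backend_requirement expects a raising check with the full NoAuxTc signature; the exact constants and messages are illustrative:

def _check_dsv3_fused_routing_supported(
    scores, bias, n_group, topk_group, topk,
    routed_scaling_factor, topk_values, topk_indices, launch_with_pdl=True,
):
    num_experts = scores.shape[1]
    # Mirror the C++ entry point: experts must split evenly into groups.
    if num_experts % n_group != 0:
        raise ValueError("num_experts must be divisible by n_group")
    # The optimized grouped path only supports up to 8 groups (256 experts / 32-lane warps).
    if n_group > 8:
        raise ValueError("n_group must be <= 8 for the DSv3 fused routing kernel")
    if topk_group > n_group:
        raise ValueError("topk_group must be <= n_group")
    # Keep topk within the kernel's MaxNumTopExperts capacity.
    if topk > 8:
        raise ValueError("topk must be <= 8 for the DSv3 fused routing kernel")
    # Required by the backend_requirement signature but not used by the checks.
    _ = (bias, routed_scaling_factor, topk_values, topk_indices, launch_with_pdl)
    return True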
tests/model_optimizations/test_dsv3_fused_routing.py (2)
373-449: Remove or use the unusedtopk_values_kernelparameter invalidate_and_debug
validate_and_debugtakestopk_values_kernelbut never uses it; all logic is based ontopk_indices_kerneland the ground-truth object. Thatβs triggering Ruffβs ARG001 and can confuse readers.You can either:
- Drop the parameter entirely:
-def validate_and_debug(ground_truth, topk_indices_kernel, topk_values_kernel): +def validate_and_debug(ground_truth, topk_indices_kernel): @@ - all_valid, tokens_with_different_experts = validate_and_debug( - ground_truth, topk_indices, sorted_vals - ) + all_valid, tokens_with_different_experts = validate_and_debug( + ground_truth, topk_indices + )or
- Actually use
topk_values_kernelin the debug printout (e.g., printing kernel values alongside indices for failing tokens).Either approach will resolve the unused-argument warning and make the intent clearer.
513-593: Optionally gate the test on supported compute capability to avoid hard failures
NoAuxTcis decorated withbackend_requirementand asupported_compute_capability([89, 90, 100, 103, 120, 121])common check, so calling it on GPUs outside that set will raise aBackendSupportedError. Right now, the test unconditionally constructs CUDA tensors and callsNoAuxTc, which will cause a test failure rather than a skip on unsupported hardware.Given you already have a commented import for
get_compute_capability, you could do something like:-import pytest -from flashinfer.dsv3_ops import NoAuxTc -# from flashinfer.utils import get_compute_capability +import pytest +from flashinfer.dsv3_ops import NoAuxTc +from flashinfer.utils import get_compute_capability @@ def test_dsv3_fused_routing_op( num_tokens, num_experts, topk, n_group, topk_group, data_type, bias_type ): + # Skip on unsupported compute capability + cc = get_compute_capability() + if not NoAuxTc.is_compute_capability_supported(cc): + pytest.skip(f"NoAuxTc not supported on compute capability {cc}")This keeps the test suite green on older or different GPUs while still enforcing correctness where the kernel is intended to run.
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
π Files selected for processing (3)
csrc/fused_moe/noAuxTcKernels.cu(1 hunks)flashinfer/fused_moe/fused_routing_dsv3.py(1 hunks)tests/model_optimizations/test_dsv3_fused_routing.py(1 hunks)
π§° Additional context used
π§ Learnings (1)
π Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
csrc/fused_moe/noAuxTcKernels.cu
𧬠Code graph analysis (3)
csrc/fused_moe/noAuxTcKernels.cu (3)
include/flashinfer/trtllm/fused_moe/runner.h (3)
num_experts(263-263)n_group(271-271)topk_group(273-273)csrc/tvm_ffi_utils.h (2)
encode_dlpack_dtype(29-31)get_stream(272-274)flashinfer/fused_moe/fused_routing_dsv3.py (2)
NoAuxTc(90-111)NoAuxTc(119-194)
flashinfer/fused_moe/fused_routing_dsv3.py (4)
flashinfer/jit/dsv3_optimizations.py (1)
gen_dsv3_fused_routing_module(14-45)flashinfer/utils.py (2)
supported_compute_capability(773-853)backend_requirement(856-1131)flashinfer/jit/core.py (1)
build_and_load(300-312)csrc/fused_moe/noAuxTcKernels.cu (2)
NoAuxTc(306-448)NoAuxTc(306-308)
tests/model_optimizations/test_dsv3_fused_routing.py (2)
flashinfer/fused_moe/fused_routing_dsv3.py (2)
NoAuxTc(90-111)NoAuxTc(119-194)csrc/fused_moe/noAuxTcKernels.cu (2)
NoAuxTc(306-448)NoAuxTc(306-308)
πͺ Ruff (0.14.5)
flashinfer/fused_moe/fused_routing_dsv3.py
15-15: Unused function argument: bias
(ARG001)
19-19: Unused function argument: routed_scaling_factor
(ARG001)
20-20: Unused function argument: topk_values
(ARG001)
21-21: Unused function argument: topk_indices
(ARG001)
22-22: Unused function argument: launch_with_pdl
(ARG001)
45-48: Avoid specifying long messages outside the exception class
(TRY003)
56-58: Avoid specifying long messages outside the exception class
(TRY003)
60-63: Avoid specifying long messages outside the exception class
(TRY003)
65-68: Avoid specifying long messages outside the exception class
(TRY003)
71-73: Avoid specifying long messages outside the exception class
(TRY003)
75-77: Avoid specifying long messages outside the exception class
(TRY003)
tests/model_optimizations/test_dsv3_fused_routing.py
373-373: Unused function argument: topk_values_kernel
(ARG001)
β° Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
Additional comments (3)
flashinfer/fused_moe/fused_routing_dsv3.py (1)
82-115: JIT build + custom-op registration pattern looks solid

The @functools.cache around get_dsv3_fused_routing_module() plus the nested, registered NoAuxTc custom op that forwards to module.NoAuxTc gives you a single build per process and a clear mutation contract on topk_values/topk_indices. Argument order and types match the C++ NoAuxTc entry point, so the wiring looks correct.

csrc/fused_moe/noAuxTcKernels.cu (1)
223-275: Dispatch heuristics for single-/multi-group kernels look consistentThe
invokeNoAuxTclogic aroundis_single_groupvsis_multi_groupand the choice ofMaxNumExperts(128, 256, 384) matches the Python precheck and the test ranges (num_experts β {256, 384}, topk β€ 8, group-wise capacity constraints). The use ofcudaLaunchKernelExwith programmatic stream serialization gated bylaunch_with_pdlalso looks correct for DSv3 usage.tests/model_optimizations/test_dsv3_fused_routing.py (1)
125-224: Ground-truth implementation closely matches the kernel's routing spec

The DSv3RoutingGroundTruth class mirrors the documented algorithm (sigmoid + bias, per-group top-2 sums, group top-k, masked expert top-k, normalization, and final sorting) in float32 and incorporates per-dtype tie thresholds. This gives strong coverage of both selection and normalization behavior and is a good reference for future kernel changes.
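For readers who want the reference flow at a glance, here is a condensed, batched torch sketch of the same algorithm; it mirrors the description above rather than the exact DSv3RoutingGroundTruth code, and omits the tie-threshold handling:

import torch

def dsv3_reference_routing(scores, bias, n_group, topk_group, topk, routed_scaling_factor):
    num_tokens, num_experts = scores.shape
    sig = torch.sigmoid(scores.float())           # sigmoid of router logits
    biased = sig + bias.float()                   # bias only affects selection

    # Group score = sum of the top-2 biased scores within each group.
    grouped = biased.view(num_tokens, n_group, num_experts // n_group)
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)

    # Keep only experts that belong to the top-k groups.
    top_groups = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, top_groups, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape(num_tokens, num_experts)
    masked = biased.masked_fill(expert_mask == 0, float("-inf"))

    # Select top-k experts by biased score, then normalize the sigmoid scores.
    topk_idx = masked.topk(topk, dim=-1).indices
    topk_sig = sig.gather(1, topk_idx)
    topk_val = topk_sig / topk_sig.sum(dim=-1, keepdim=True) * routed_scaling_factor
    return topk_val, topk_idx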
| @backend_requirement({}, common_check=_check_dsv3_fused_routing_supported) | ||
| def NoAuxTc( | ||
| scores: torch.Tensor, | ||
| bias: torch.Tensor, | ||
| n_group: int, | ||
| topk_group: int, | ||
| topk: int, | ||
| routed_scaling_factor: float, | ||
| topk_values: torch.Tensor, | ||
| topk_indices: torch.Tensor, | ||
| launch_with_pdl: bool = True, | ||
| ) -> None: | ||
| """Fused expert routing with top-k selection for DeepSeek-V3. | ||
| This function performs a highly optimized fused routing operation specifically | ||
| designed for DeepSeek-V3's Mixture of Experts (MoE) architecture with grouped | ||
| expert routing and no auxiliary loss. It combines score computation, expert | ||
| selection, and normalization into a single kernel operation. | ||
| The routing algorithm consists of the following steps: | ||
| 1. Compute biased scores: sigmoid(scores) + bias for each expert | ||
| 2. Group experts and compute group scores (sum of top-2 experts per group) | ||
| 3. Select top-k groups based on group scores | ||
| 4. From selected groups, select top-k experts based on biased scores | ||
| 5. Normalize selected expert weights: sigmoid_scores / sum(sigmoid_scores) * scale | ||
| Args: | ||
| scores (torch.Tensor): Input routing scores of shape (num_tokens, num_experts). | ||
| The logits produced by the router network before activation. Supports | ||
| bfloat16, float16, or float32. | ||
| bias (torch.Tensor): Per-expert routing bias of shape (num_experts,). Added to | ||
| sigmoid-activated scores to produce biased scores for expert selection. | ||
| Must match the dtype of scores. | ||
| n_group (int): Number of expert groups. Experts are divided into groups for | ||
| hierarchical selection. Typical value is 8 for DeepSeek-V3 with 256 experts | ||
| (32 experts per group). | ||
| topk_group (int): Number of top groups to select. Must be <= n_group. Typical | ||
| value is 4, meaning the top 4 groups are selected from 8 groups. | ||
| topk (int): Number of top experts to select per token. Must be <= num_experts. | ||
| Typical value is 8, meaning 8 experts are routed per token. | ||
| routed_scaling_factor (float): Scaling factor applied to normalized expert | ||
| weights. The final output weights are: | ||
| sigmoid_scores / sum(sigmoid_scores) * routed_scaling_factor. | ||
| topk_values (torch.Tensor): Pre-allocated output tensor of shape | ||
| (num_tokens, topk) for the normalized expert weights. Must be float32. | ||
| This tensor is mutated in-place. | ||
| topk_indices (torch.Tensor): Pre-allocated output tensor of shape | ||
| (num_tokens, topk) for the selected expert indices. Must be int32 or int64. | ||
| This tensor is mutated in-place. | ||
| launch_with_pdl (bool, optional): Whether to launch the kernel using Persistent | ||
| Device-side Launch. Defaults to True. | ||
| Returns: | ||
| None: Results are written directly to `topk_values` and `topk_indices` tensors. | ||
| Note: | ||
| - The kernel uses float32 internally for all computations to ensure numerical | ||
| precision, even when inputs are float16 or bfloat16. | ||
| - This implementation is optimized for Hopper (compute capability 90, 100), | ||
| Ada (compute capability 89), and Blackwell (compute capability 120, 121) | ||
| architectures. | ||
| - The "NoAux" prefix indicates this variant does not compute auxiliary losses | ||
| (e.g., load balancing loss) during routing. | ||
| - The "Tc" suffix indicates the use of Tensor Core optimizations in the | ||
| underlying CUDA kernel. | ||
| """ | ||
| get_dsv3_fused_routing_module().NoAuxTc( | ||
| scores, | ||
| bias, | ||
| n_group, | ||
| topk_group, | ||
| topk, | ||
| routed_scaling_factor, | ||
| topk_values, | ||
| topk_indices, | ||
| launch_with_pdl, |
Docstring is inconsistent with actual dtype requirements of the kernel
The Python docstring states that:
- topk_values: "Must be float32".
- topk_indices: "Must be int32 or int64".
However, the C++ entry point (flashinfer::trtllm_dsv3_fused_routing::NoAuxTc) enforces:
- topk_values.dtype() == scores.dtype() (via TVM_FFI_ICHECK(topk_values.dtype() == data_type)).
- topk_indices dtype encoded as int32_code (only int32 is accepted).
The tests also construct topk_values with dtype=data_type and topk_indices as torch.int32, matching the C++ checks, not the docstring.
To avoid confusing users and spurious runtime errors, please update the docstring to match the real contract, e.g.:
- topk_values (torch.Tensor): Pre-allocated output tensor of shape
- (num_tokens, topk) for the normalized expert weights. Must be float32.
+ topk_values (torch.Tensor): Pre-allocated output tensor of shape
+ (num_tokens, topk) for the normalized expert weights. Must have the
+ same dtype as `scores`. This tensor is mutated in-place.
@@
- topk_indices (torch.Tensor): Pre-allocated output tensor of shape
- (num_tokens, topk) for the selected expert indices. Must be int32 or int64.
- This tensor is mutated in-place.
+ topk_indices (torch.Tensor): Pre-allocated output tensor of shape
+ (num_tokens, topk) for the selected expert indices. Must be int32.
+ This tensor is mutated in-place.π€ Prompt for AI Agents
In flashinfer/fused_moe/fused_routing_dsv3.py around lines 118 to 193, the
docstring incorrectly states that topk_values must be float32 and topk_indices
can be int32 or int64; update it to reflect the C++ entry-point contract:
topk_values must have the same dtype as scores (i.e., use scores.dtype) and
topk_indices must be int32 only (torch.int32). Change the Args section lines for
topk_values and topk_indices to state these exact requirements and note that the
kernel checks enforce them.
there's also alternative path csrc/nv_internal/cpp/kernels to store cpp/h files, for the sake of minimizing difference, worth considering for imported kernels
aleozlx
left a comment
only posted comments that are more of nitpicking
adding my approval as vote up
jiahanc
left a comment
LGTM thanks for the contribution
Actionable comments posted: 5
Nitpick comments (1)
tests/model_optimizations/test_dsv3_fused_routing.py (1)
455-486: Add CUDA availability guard and consider parametrizing launch modes

The test lacks a @pytest.mark.skipif decorator for CUDA availability. Currently, it will error in non-CUDA environments before pytest can properly report a skip. Add:

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for DSv3 NoAuxTc tests")

immediately before the existing @pytest.mark.parametrize decorators at line 421.

Additionally, launch_with_pdl=True is hardcoded at line 484. Consider parametrizing it over [True, False] to exercise both launch paths:

@pytest.mark.parametrize("launch_with_pdl", [True, False])

This ensures test coverage of both runtime modes.
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
π Files selected for processing (1)
tests/model_optimizations/test_dsv3_fused_routing.py(1 hunks)
π§° Additional context used
𧬠Code graph analysis (1)
tests/model_optimizations/test_dsv3_fused_routing.py (2)
flashinfer/fused_moe/fused_routing_dsv3.py (2)
NoAuxTc(90-111)NoAuxTc(119-194)csrc/fused_moe/noAuxTcKernels.cu (2)
NoAuxTc(306-448)NoAuxTc(306-308)
Ruff (0.14.5)

tests/model_optimizations/test_dsv3_fused_routing.py

- 333-333: Unused function argument: topk_values_kernel (ARG001)
- 350-350: Unpacked variable reason is never used; prefix it with an underscore or any other dummy variable pattern (RUF059)
Additional comments (2)

tests/model_optimizations/test_dsv3_fused_routing.py (2)

1-117: High-level test documentation is excellent

The top-level docstring is clear and accurately describes both the DSv3 routing algorithm and the two-stage validation strategy, including dtype-dependent thresholds and tie semantics. This will be very helpful for future maintainers when debugging failures.

488-501: Overall test structure and coverage look strong

Aside from the issues noted above, the test structure is solid:

- Sorting outputs before comparison gives deterministic expert ordering.
- Two-stage validation (selection, then values) with dtype-specific tolerances is well thought out.
- Parameterization over num_tokens, num_experts, topk, n_group, topk_group, data_type, and bias_type provides broad coverage, especially for multi-group routing.

Once the masking, tie-checking, device, and skip-condition issues are addressed, this should be a very robust regression test for the NoAuxTc path.
def __init__(
    self, scores, bias, n_group, topk_group, topk, routed_scaling_factor, data_type
):
    self.num_tokens = scores.shape[0]
    self.num_experts = scores.shape[1]
    self.n_group = n_group
    self.topk_group = topk_group
    self.topk = topk
    self.routed_scaling_factor = routed_scaling_factor
    self.experts_per_group = self.num_experts // n_group
    self.device = scores.device

    # Set thresholds based on data type
    if data_type == torch.bfloat16:
        self.expert_tie_threshold = 1.0
        self.group_tie_threshold = 0.05
    elif data_type == torch.float16:
        self.expert_tie_threshold = 0.5
        self.group_tie_threshold = 0.02
    else:  # float32
        self.expert_tie_threshold = 0.2
        self.group_tie_threshold = 0.01

    # Convert to float32 to match kernel's internal computation
    scores_f32 = scores.to(torch.float32)
    bias_f32 = bias.to(torch.float32)

    # Compute sigmoid and biased scores
    self.sigmoid_scores = torch.sigmoid(scores_f32)
    self.biased_scores = self.sigmoid_scores + bias_f32

    # Reshape for group-wise operations
    scores_reshaped = self.biased_scores.view(
        self.num_tokens, n_group, self.experts_per_group
    )

    # Compute group scores (sum of top-2 experts per group)
    top2_per_group = torch.topk(
        scores_reshaped, k=2, dim=-1, largest=True, sorted=True
    )[0]
    self.group_scores = torch.sum(top2_per_group, dim=-1)

    # Reference group selection
    _, self.ref_group_indices = torch.topk(
        self.group_scores, k=topk_group, dim=-1, largest=True, sorted=True
    )

    # Identify tied groups for each token
    self.tied_group_sets = []
    for token_idx in range(self.num_tokens):
        tied_groups = set()
        group_scores_token = self.group_scores[token_idx]

        for g1 in range(n_group):
            for g2 in range(g1 + 1, n_group):
                score_diff = abs(group_scores_token[g1] - group_scores_token[g2])
                if score_diff < self.group_tie_threshold:
                    tied_groups.add(g1)
                    tied_groups.add(g2)

        self.tied_group_sets.append(tied_groups)

    # Compute reference expert selection and normalization
    self.ref_expert_indices = torch.zeros(
        self.num_tokens, topk, dtype=torch.long, device=self.device
    )
    self.ref_expert_values = torch.zeros(
        self.num_tokens, topk, dtype=torch.float32, device=self.device
    )

    for token_idx in range(self.num_tokens):
        # Create mask for selected groups
        group_mask = torch.zeros(n_group, dtype=torch.float32, device=self.device)
        group_mask[self.ref_group_indices[token_idx]] = 1.0
        expert_mask = group_mask.repeat_interleave(self.experts_per_group)

        # Mask and select top-k experts
        masked_biased_scores = self.biased_scores[token_idx] * expert_mask
        _, topk_idx = torch.topk(
            masked_biased_scores, k=topk, dim=-1, largest=True, sorted=True
        )

        # Normalize selected experts
        selected_sigmoid_scores = self.sigmoid_scores[token_idx][topk_idx]
        score_sum = selected_sigmoid_scores.sum() + 1e-20
        normalized_scores = (
            selected_sigmoid_scores / score_sum * routed_scaling_factor
        )

        # Sort by normalized scores
        sorted_vals, sorted_idx = torch.sort(normalized_scores, descending=True)
        self.ref_expert_values[token_idx] = sorted_vals
        self.ref_expert_indices[token_idx] = topk_idx[sorted_idx]
Masking non-selected groups with 0 can pick experts from disallowed groups
In both the reference implementation and _get_topk_experts_from_groups, experts from non-selected groups are masked by multiplying with 0:
- Lines 201-211: masked_biased_scores = self.biased_scores[token_idx] * expert_mask
- Lines 325-327: same pattern

Because biased_scores can be negative (sigmoid in (0,1) plus possibly negative bias), zeroing out non-selected groups can make those "masked" experts larger than valid but negative scores in selected groups. That lets torch.topk choose experts from groups that are supposed to be excluded, which breaks the algorithm described in the docstring.
To enforce the "only from selected groups" constraint regardless of sign, it's safer to mask with -inf instead of 0, e.g.:
@@
- # Create mask for selected groups
- group_mask = torch.zeros(n_group, dtype=torch.float32, device=self.device)
- group_mask[self.ref_group_indices[token_idx]] = 1.0
- expert_mask = group_mask.repeat_interleave(self.experts_per_group)
-
- # Mask and select top-k experts
- masked_biased_scores = self.biased_scores[token_idx] * expert_mask
- _, topk_idx = torch.topk(
- masked_biased_scores, k=topk, dim=-1, largest=True, sorted=True
- )
+ # Create mask for selected groups
+ group_mask = torch.zeros(n_group, dtype=torch.bool, device=self.device)
+ group_mask[self.ref_group_indices[token_idx]] = True
+ expert_mask = group_mask.repeat_interleave(self.experts_per_group)
+
+ # Mask and select top-k experts: force non-selected groups to -inf
+ masked_biased_scores = self.biased_scores[token_idx].clone()
+ masked_biased_scores[~expert_mask] = float("-inf")
+ _, topk_idx = torch.topk(
+ masked_biased_scores, k=topk, dim=-1, largest=True, sorted=True
+ )
@@
- # Create mask for specified groups
- group_mask = torch.zeros(self.n_group, dtype=torch.float32, device=self.device)
- for g in groups:
- group_mask[g] = 1.0
- expert_mask = group_mask.repeat_interleave(self.experts_per_group)
-
- # Mask and select top-k experts
- masked_biased_scores = self.biased_scores[token_idx] * expert_mask
- _, topk_idx = torch.topk(
- masked_biased_scores, k=self.topk, dim=-1, largest=True, sorted=True
- )
-
- return set(topk_idx.tolist())
+ # Create mask for specified groups
+ group_mask = torch.zeros(self.n_group, dtype=torch.bool, device=self.device)
+ for g in groups:
+ group_mask[g] = True
+ expert_mask = group_mask.repeat_interleave(self.experts_per_group)
+
+ # Mask and select top-k experts, restricting strictly to these groups
+ masked_biased_scores = self.biased_scores[token_idx].clone()
+ masked_biased_scores[~expert_mask] = float("-inf")
+ _, topk_idx = torch.topk(
+ masked_biased_scores, k=self.topk, dim=-1, largest=True, sorted=True
+ )
+
+     return set(topk_idx.tolist())

This keeps the ground truth aligned with the routing algorithm even when biased scores are negative.
Also applies to: 313-328
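To make the failure mode concrete, here is a small standalone sketch (made-up scores, not code from the PR) showing how the zero mask can pull in experts from a disallowed group while an -inf mask cannot:

import torch

# Two groups of two experts; group 0 is "selected", group 1 is not.
biased_scores = torch.tensor([-0.30, -0.10, 0.05, 0.02])
expert_mask = torch.tensor([1.0, 1.0, 0.0, 0.0])  # zero out group 1

# Zero-masking: masked entries become 0.0, which beats the negative valid scores.
zero_masked = biased_scores * expert_mask
print(torch.topk(zero_masked, k=2).indices)  # picks experts 2/3 from the disallowed group

# -inf masking keeps the selection inside the allowed group regardless of sign.
inf_masked = biased_scores.clone()
inf_masked[expert_mask == 0] = float("-inf")
print(torch.topk(inf_masked, k=2).indices)   # tensor([1, 0]) -> allowed group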
# Identify tied groups for each token
self.tied_group_sets = []
for token_idx in range(self.num_tokens):
    tied_groups = set()
    group_scores_token = self.group_scores[token_idx]

    for g1 in range(n_group):
        for g2 in range(g1 + 1, n_group):
            score_diff = abs(group_scores_token[g1] - group_scores_token[g2])
            if score_diff < self.group_tie_threshold:
                tied_groups.add(g1)
                tied_groups.add(g2)

    self.tied_group_sets.append(tied_groups)
Group tie detection can accept invalid group selections as "tied"

The current group tie logic precomputes a flat tied_group_sets[token_idx] by unioning all pairwise ties (lines 178-191), and then considers any symmetric difference of groups that lies in this union as "tied" (lines 237-243). This can incorrectly accept cases where:

- ref groups = {A}, kernel groups = {C}, and A is only tied with B, while C is only tied with D.

All of {A, B, C, D} end up in tied_groups, so {A, C} is seen as "tied groups" even though A and C themselves may have very different scores. That can allow completely wrong group selections to pass as valid and propagate into the expert-level checks in is_valid_expert_selection.
You can tighten this by checking ties only among the groups actually involved (ref ∪ selected) using self.group_scores instead of the global precomputed set; e.g.:
- def is_valid_group_selection(self, token_idx, selected_groups):
- """Check if a set of selected groups is valid (exact match or tied)."""
- ref_groups = set(self.ref_group_indices[token_idx].tolist())
- selected_groups_set = set(selected_groups)
-
- if selected_groups_set == ref_groups:
- return True, "exact"
-
- if self.n_group > 1:
- diff_groups = selected_groups_set.symmetric_difference(ref_groups)
- tied_groups = self.tied_group_sets[token_idx]
-
- if diff_groups and diff_groups.issubset(tied_groups):
- return True, "tied_groups"
-
- return False, "different_groups"
+ def is_valid_group_selection(self, token_idx, selected_groups):
+ """Check if a set of selected groups is valid (exact match or tied)."""
+ ref_groups = set(self.ref_group_indices[token_idx].tolist())
+ selected_groups_set = set(selected_groups)
+
+ if selected_groups_set == ref_groups:
+ return True, "exact"
+
+ if self.n_group > 1:
+ # Only consider groups actually involved in this comparison
+ groups_union = sorted(selected_groups_set | ref_groups)
+ group_scores_token = self.group_scores[token_idx, groups_union]
+ score_range = group_scores_token.max() - group_scores_token.min()
+ if score_range < self.group_tie_threshold:
+ return True, "tied_groups"
+
+         return False, "different_groups"

With this, is_valid_expert_selection's group branch (lines 259-293) remains the same conceptually but will no longer accept unrelated groups just because each participated in some tie with some other group.
Also applies to: 229-245, 259-293
Prompt for AI Agents
In tests/model_optimizations/test_dsv3_fused_routing.py around lines 178-191 and
the related logic at 229-245 and 259-293, the current approach unions all
pairwise ties into tied_group_sets and later treats any symmetric difference
inside that union as "tied", letting unrelated groups appear tied; instead,
remove/stop relying on the global tied_group_sets and perform tie checks only
among the actually involved groups (ref βͺ selected) for each token using
self.group_scores[token_idx] and self.group_tie_threshold: compute pairwise
abs(score_i - score_j) < threshold for groups in the involved set (on-the-fly)
and use that result to decide if the selection is tied/valid, so unrelated ties
elsewhere won't make two unrelated groups pass as tied.
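For intuition, a toy standalone example of the false positive described above (hypothetical group scores, not taken from the test):

import torch

group_tie_threshold = 0.01
# A=0.500 is tied with B=0.505; C=0.300 is tied with D=0.295; A and C are far apart.
group_scores = torch.tensor([0.500, 0.505, 0.300, 0.295])  # groups A, B, C, D

tied_groups = set()
for g1 in range(4):
    for g2 in range(g1 + 1, 4):
        if abs(group_scores[g1] - group_scores[g2]) < group_tie_threshold:
            tied_groups.update((g1, g2))

print(tied_groups)  # {0, 1, 2, 3}: the union makes {A} vs {C} look "tied"

# Checking only the involved groups (ref ∪ selected = {A, C}) exposes the gap.
involved = torch.tensor([0, 2])
score_range = group_scores[involved].max() - group_scores[involved].min()
print(score_range.item() < group_tie_threshold)  # False: A and C are not actually tied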
def validate_expert_selection(ground_truth, topk_indices_kernel, topk_values_kernel):
    """Validate kernel outputs and provide detailed debug info for failures."""
    num_tokens = topk_indices_kernel.shape[0]
    tokens_with_different_experts = set()

    for token_idx in range(num_tokens):
        kernel_experts = topk_indices_kernel[token_idx].tolist()
        ref_experts = ground_truth.ref_expert_indices[token_idx].tolist()

        # Same experts - valid
        if set(kernel_experts) == set(ref_experts):
            continue

        # Different experts - mark for value comparison skip
        tokens_with_different_experts.add(token_idx)

        # Validate the selection
        is_valid, reason = ground_truth.is_valid_expert_selection(
            token_idx, kernel_experts
        )

        if not is_valid:
            return False, tokens_with_different_experts

    return True, tokens_with_different_experts
Address Ruff warnings: unused parameter and unused reason
Static analysis flags two small issues here:
- topk_values_kernel is never used.
- reason returned from ground_truth.is_valid_expert_selection is unused.

If you don't plan to use topk_values_kernel and reason for debug reporting, you can explicitly mark them as intentionally unused to satisfy Ruff:
-def validate_expert_selection(ground_truth, topk_indices_kernel, topk_values_kernel):
+def validate_expert_selection(ground_truth, topk_indices_kernel, _topk_values_kernel):
@@
- # Validate the selection
- is_valid, reason = ground_truth.is_valid_expert_selection(
+ # Validate the selection
+ is_valid, _reason = ground_truth.is_valid_expert_selection(
token_idx, kernel_experts
)

Alternatively, if you want richer failure messages, you could propagate reason back to the test and include it in the pytest.fail message; but the above is the minimal, low-noise fix.
Committable suggestion

IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def validate_expert_selection(ground_truth, topk_indices_kernel, _topk_values_kernel):
    """Validate kernel outputs and provide detailed debug info for failures."""
    num_tokens = topk_indices_kernel.shape[0]
    tokens_with_different_experts = set()

    for token_idx in range(num_tokens):
        kernel_experts = topk_indices_kernel[token_idx].tolist()
        ref_experts = ground_truth.ref_expert_indices[token_idx].tolist()

        # Same experts - valid
        if set(kernel_experts) == set(ref_experts):
            continue

        # Different experts - mark for value comparison skip
        tokens_with_different_experts.add(token_idx)

        # Validate the selection
        is_valid, _reason = ground_truth.is_valid_expert_selection(
            token_idx, kernel_experts
        )

        if not is_valid:
            return False, tokens_with_different_experts

    return True, tokens_with_different_experts
Tools

Ruff (0.14.5)

- 333-333: Unused function argument: topk_values_kernel (ARG001)
- 350-350: Unpacked variable reason is never used; prefix it with an underscore or any other dummy variable pattern (RUF059)
Prompt for AI Agents
In tests/model_optimizations/test_dsv3_fused_routing.py around lines 333-357,
the function validate_expert_selection has an unused parameter
topk_values_kernel and the local variable reason is never used; to quiet Ruff,
mark the parameter and variable as intentionally unused by renaming the
parameter to _topk_values_kernel (or prefix with an underscore) and change the
call to ground_truth.is_valid_expert_selection to capture the second return as _
or _reason (e.g., is_valid, _ = ...), or alternatively explicitly del
topk_values_kernel and del reason after they are created; make one of these
minimal changes so the linter no longer flags unused names.
def validate_values(ground_truth, topk_values_kernel, tokens_to_skip, data_type):
    """Validate that output values match reference within tolerance."""
    # Set tolerance based on data type
    if data_type == torch.bfloat16:
        rtol, atol = 0.1, 0.1
    elif data_type == torch.float16:
        rtol, atol = 0.05, 0.05
    else:  # float32
        rtol, atol = 0.01, 0.01

    num_tokens = topk_values_kernel.shape[0]

    # Create mask for tokens to check
    tokens_to_check = torch.ones(num_tokens, dtype=torch.bool)
    for token_idx in tokens_to_skip:
        tokens_to_check[token_idx] = False

    if not tokens_to_check.any():
        return

    # Compare values
    ref_values = ground_truth.ref_expert_values[tokens_to_check].float()
    kernel_values = topk_values_kernel[tokens_to_check].float()

    try:
        torch.testing.assert_close(
            ref_values,
            kernel_values,
            rtol=rtol,
            atol=atol,
        )
    except AssertionError:
        # Find and report first mismatch
        for token_idx in range(num_tokens):
            if not tokens_to_check[token_idx]:
                continue

            ref_vals = ground_truth.ref_expert_values[token_idx].float()
            kernel_vals = topk_values_kernel[token_idx].float()

            if not torch.allclose(ref_vals, kernel_vals, rtol=rtol, atol=atol):
                diff = (kernel_vals - ref_vals).abs()
                max_diff = diff.max().item()
                max_diff_idx = diff.argmax().item()

                print(f"\n{'=' * 80}")
                print(f"VALUE MISMATCH - Token {token_idx}")
                print(f"{'=' * 80}")
                print(f"Tolerance: rtol={rtol}, atol={atol}")
                print(f"Max difference: {max_diff:.6f} at position {max_diff_idx}")
                print(f"\nReference values: {ref_vals.tolist()}")
                print(f"Kernel values: {kernel_vals.tolist()}")
                print(f"Absolute diff: {diff.tolist()}")
                print(
                    f"Expert indices: {ground_truth.ref_expert_indices[token_idx].tolist()}"
                )
                break

        raise
Analysis chain
Device mismatch in value validation when using boolean mask with CUDA tensors
In validate_values, tokens_to_check is a CPU bool tensor, while ground_truth.ref_expert_values and topk_values_kernel are CUDA tensors. Indexing CUDA tensors with a CPU boolean mask (lines 381-382) will fail at runtime. A simple fix is to move the values to CPU before applying the mask and doing comparisons:
- # Compare values
- ref_values = ground_truth.ref_expert_values[tokens_to_check].float()
- kernel_values = topk_values_kernel[tokens_to_check].float()
+ # Compare values on CPU to avoid device mismatch with the boolean mask
+ ref_values = ground_truth.ref_expert_values.cpu()[tokens_to_check].float()
+ kernel_values = topk_values_kernel.cpu()[tokens_to_check].float()
@@
- ref_vals = ground_truth.ref_expert_values[token_idx].float()
- kernel_vals = topk_values_kernel[token_idx].float()
+ ref_vals = ground_truth.ref_expert_values[token_idx].float().cpu()
+      kernel_vals = topk_values_kernel[token_idx].float().cpu()

This keeps the mask on CPU and avoids CUDA/CPU indexing incompatibilities while still validating the numeric accuracy.
Script executed:

#!/bin/bash
# Quick grep to confirm no other places index CUDA tensors with CPU masks
rg -n "tokens_to_check" -C3

Length of output: 2053

Script executed:

#!/bin/bash
# Find where ground_truth is created and used in the test
rg -n "ground_truth\s*=" tests/model_optimizations/test_dsv3_fused_routing.py -A2

Length of output: 100

Script executed:

#!/bin/bash
# Find where validate_values is called to see what tensors are passed
rg -n "validate_values" tests/model_optimizations/test_dsv3_fused_routing.py -B3 -A3

Length of output: 512

Script executed:

#!/bin/bash
# Find DSv3RoutingGroundTruth class definition
rg -n "class DSv3RoutingGroundTruth" -A10

Length of output: 1002

Script executed:

#!/bin/bash
# Find where sorted_vals is created (lines around 501)
sed -n '450,510p' tests/model_optimizations/test_dsv3_fused_routing.py

Length of output: 1629

Script executed:

#!/bin/bash
# Find ref_expert_values definition in DSv3RoutingGroundTruth class
sed -n '125,250p' tests/model_optimizations/test_dsv3_fused_routing.py | grep -n "ref_expert_values" -B2 -A2

Length of output: 502

Script executed:

#!/bin/bash
# Get the full init method of DSv3RoutingGroundTruth to see tensor allocation
sed -n '131,220p' tests/model_optimizations/test_dsv3_fused_routing.py

Length of output: 3622
Device mismatch in value validation when using boolean mask with CUDA tensors
In validate_values, tokens_to_check is a CPU bool tensor (line 373), while ground_truth.ref_expert_values and topk_values_kernel are CUDA tensors. Advanced indexing CUDA tensors with a CPU boolean mask (lines 381-382) will fail at runtime with a device mismatch error.
Fix by moving the tensors to CPU before applying the mask:
- # Compare values
- ref_values = ground_truth.ref_expert_values[tokens_to_check].float()
- kernel_values = topk_values_kernel[tokens_to_check].float()
+ # Compare values on CPU to avoid device mismatch with the boolean mask
+ ref_values = ground_truth.ref_expert_values.cpu()[tokens_to_check].float()
+     kernel_values = topk_values_kernel.cpu()[tokens_to_check].float()

Also update lines 397-398 in the exception handler for consistency:
- ref_vals = ground_truth.ref_expert_values[token_idx].float()
- kernel_vals = topk_values_kernel[token_idx].float()
+ ref_vals = ground_truth.ref_expert_values[token_idx].float().cpu()
+      kernel_vals = topk_values_kernel[token_idx].float().cpu()

Prompt for AI Agents
In tests/model_optimizations/test_dsv3_fused_routing.py around lines 360-419,
the boolean mask tokens_to_check is created on CPU but used to index CUDA
tensors, causing a device mismatch; create the mask on the same device as the
tensors (e.g., device = topk_values_kernel.device; tokens_to_check =
torch.ones(num_tokens, dtype=torch.bool, device=device)) so indexing works, and
ensure values used for comparison/printing are moved to CPU for readability in
the exception handler (e.g., call .cpu() before .tolist() or printing) for lines
~397-398.
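Either fix works; as a rough sketch of the device-aware variant the prompt describes (the helper name build_token_mask is hypothetical, and the variable names only mirror the test for readability):

import torch

def build_token_mask(num_tokens, tokens_to_skip, device):
    # Allocate the boolean mask on the same device as the kernel outputs so that
    # boolean indexing of CUDA tensors does not hit a CPU/CUDA mismatch.
    tokens_to_check = torch.ones(num_tokens, dtype=torch.bool, device=device)
    for token_idx in tokens_to_skip:
        tokens_to_check[token_idx] = False
    return tokens_to_check

# Example use inside validate_values (illustrative):
# tokens_to_check = build_token_mask(num_tokens, tokens_to_skip, topk_values_kernel.device)
# ref_values = ground_truth.ref_expert_values[tokens_to_check].float()
# kernel_values = topk_values_kernel[tokens_to_check].float()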
| @pytest.mark.parametrize("num_tokens", [1, 8, 16, 64]) | ||
| @pytest.mark.parametrize("num_experts", [256, 384]) | ||
| @pytest.mark.parametrize("topk", [1, 2, 4, 8]) | ||
| @pytest.mark.parametrize("n_group", [1, 2, 4, 8]) | ||
| @pytest.mark.parametrize("topk_group", [1, 2, 4, 8]) | ||
| @pytest.mark.parametrize("data_type", [torch.float32, torch.float16, torch.bfloat16]) | ||
| @pytest.mark.parametrize("bias_type", [torch.float32, torch.float16, torch.bfloat16]) | ||
| def test_dsv3_fused_routing_op( | ||
| num_tokens, num_experts, topk, n_group, topk_group, data_type, bias_type | ||
| ): | ||
| """ | ||
| Test NoAuxTc kernel against reference implementation. | ||
| Validates: | ||
| 1. Expert selection equivalence (allowing for ties) | ||
| 2. Value correctness within numerical precision tolerance | ||
| """ | ||
|
|
||
| # Skip invalid configurations | ||
| if topk_group * n_group < topk or topk_group > n_group: | ||
| pytest.skip( | ||
| "Invalid configuration: topk_group * n_group < topk or topk_group > n_group" | ||
| ) | ||
| if n_group > 1: | ||
| if ( | ||
| topk > 8 | ||
| or num_experts / n_group > 32 | ||
| or num_experts / n_group * topk_group > 128 | ||
| ): | ||
| pytest.skip("Invalid configuration: exceeds kernel limits for n_group > 1") | ||
| else: | ||
| if num_experts > 384 or topk > 8: | ||
| pytest.skip("Invalid configuration: exceeds kernel limits for n_group = 1") | ||
|
|
Analysis chain
Skip-condition logic may be stricter than necessary for valid configurations
The initial skip condition:
if topk_group * n_group < topk or topk_group > n_group:
    pytest.skip("Invalid configuration: topk_group * n_group < topk or topk_group > n_group")

uses topk_group * n_group to bound topk. Given the algorithm, the natural upper bound on the number of routable experts from the selected groups is experts_per_group * topk_group, not topk_group * n_group. As written, this skips configurations that seem valid from the algorithm's perspective (e.g., n_group=1, topk_group=1, topk>1), thereby reducing coverage for "single-group, multi-expert" routing.
Consider tightening this to match the actual capacity constraint instead:
- if topk_group * n_group < topk or topk_group > n_group:
- pytest.skip(
- "Invalid configuration: topk_group * n_group < topk or topk_group > n_group"
- )
+ experts_per_group = num_experts // n_group
+ if topk_group * experts_per_group < topk or topk_group > n_group:
+ pytest.skip(
+ "Invalid configuration: not enough experts in selected groups "
+ "or topk_group > n_group"
+         )

This preserves the topk_group <= n_group constraint while not discarding otherwise valid and interesting configurations.
Script executed:

#!/bin/bash
# Inspect other uses of topk_group / n_group to confirm intended constraints.
rg -n "topk_group" -C3

Length of output: 93651
Correct and apply the skip-condition logic to match kernel constraints
The test's skip condition uses topk_group * n_group as an upper bound, but the kernel constraint requires topk < topk_group * (num_experts / n_group). The current logic rejects valid single-group multi-expert configurations (e.g., n_group=1, topk_group=1, topk>1) that the kernel actually supports.
Update line 440 to:
experts_per_group = num_experts // n_group
if topk_group * experts_per_group < topk or topk_group > n_group:
    pytest.skip(
        "Invalid configuration: not enough experts in selected groups "
        "or topk_group > n_group"
    )

This preserves the topk_group <= n_group constraint while allowing all algorithmically valid configurations and improving test coverage.

Prompt for AI Agents
π€ Prompt for AI Agents
In tests/model_optimizations/test_dsv3_fused_routing.py around lines 421 to 454,
the skip logic incorrectly uses topk_group * n_group to validate capacity and
rejects valid configs; compute experts_per_group = num_experts // n_group and
replace the current check with: if topk_group * experts_per_group < topk or
topk_group > n_group: pytest.skip("Invalid configuration: not enough experts in
selected groups or topk_group > n_group"); this preserves the topk_group <=
n_group rule and ensures the kernel constraint topk < topk_group * (num_experts
/ n_group) is enforced.
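As a quick sanity check of the difference between the two conditions (standalone arithmetic with an assumed configuration, not test code):

# Assumed configuration: single group, multiple experts per token.
num_experts, n_group, topk_group, topk = 256, 1, 1, 4
experts_per_group = num_experts // n_group  # 256

old_skip = topk_group * n_group < topk or topk_group > n_group            # 1 < 4 -> True (test skipped)
new_skip = topk_group * experts_per_group < topk or topk_group > n_group  # 256 < 4 -> False (test runs)
print(old_skip, new_skip)  # True False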
[FAILED] Pipeline #38749791: 13/18 passed
Description

Related Issues

Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

Pre-commit Checks

- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

Tests

- Tests have been added or updated as needed (e.g., unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Tests