chore: cute dsl nvfp4 moe clean up #2775
Conversation
📝 Walkthrough: Threads a per-instance enable_pdl flag through the CuteDSL MoE wrapper, tuner, and kernels.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant Wrapper as CuteDslMoEWrapper
    participant Tuner as CuteDslFusedMoENvfp4Runner
    participant GEMM1 as GEMM1 Kernel
    participant GEMM2 as GEMM2 Finalize Kernel
    User->>Wrapper: cute_dsl_fused_moe_nvfp4(..., enable_pdl=True)
    Wrapper->>Tuner: init(..., enable_pdl=True)
    Tuner->>Tuner: store self.enable_pdl
    Wrapper->>Tuner: forward(inputs, enable_pdl=self.enable_pdl)
    Tuner->>Tuner: validate/get_valid_tactics()
    Tuner->>GEMM1: launch(gemm1_args..., enable_pdl=self.enable_pdl)
    GEMM1-->>Tuner: intermediate output
    Tuner->>GEMM2: launch(finalize_args..., enable_pdl=self.enable_pdl)
    GEMM2-->>Tuner: final output
    Tuner-->>Wrapper: result
    Wrapper-->>User: MoE output
```
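The flow above can be sketched as plain Python. Everything below is a hypothetical stand-in mirroring the diagram's participants, not the real flashinfer classes:

```python
# Hypothetical sketch of how a per-instance enable_pdl flag threads from the
# wrapper through the runner into both kernel launches.
class Gemm1Kernel:
    def launch(self, x, enable_pdl):
        # Stand-in for the real GEMM1 kernel launch.
        return ("gemm1", x, enable_pdl)

class Gemm2FinalizeKernel:
    def launch(self, x, enable_pdl):
        # Stand-in for the real GEMM2 finalize kernel launch.
        return ("gemm2", x, enable_pdl)

class Runner:
    def __init__(self, enable_pdl):
        self.enable_pdl = enable_pdl  # stored once at construction
        self.gemm1, self.gemm2 = Gemm1Kernel(), Gemm2FinalizeKernel()

    def forward(self, x):
        # The same stored flag is forwarded to both launches.
        intermediate = self.gemm1.launch(x, enable_pdl=self.enable_pdl)
        return self.gemm2.launch(intermediate, enable_pdl=self.enable_pdl)

class Wrapper:
    def __init__(self, enable_pdl=True):
        self.runner = Runner(enable_pdl=enable_pdl)

    def __call__(self, x):
        return self.runner.forward(x)
```

The point of the design is that the flag is captured once at wrapper construction and never read from the environment again, so two wrappers with different PDL settings can coexist in one process.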
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
/bot run
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request focuses on refining and improving the CuteDSL NVFP4 Mixture-of-Experts (MoE) implementation, primarily targeting the Blackwell architecture. The changes include correcting supported compute capabilities, integrating the MoE kernels into the benchmarking system, and enhancing the autotuning mechanism with more robust tactic validation and explicit parameterization for features like Programmatic Dependent Launch (PDL). These updates aim to improve the correctness, testability, and configurability of the CuteDSL MoE kernels.
Code Review
This pull request provides a comprehensive cleanup and enhancement for the CuteDSL NVFP4 MoE implementation. The changes are well-structured and improve the codebase in several key areas. Notably, the refactoring to remove environment variable dependencies in favor of function parameters makes the kernels more modular and easier to use. The autotuning strategy has been significantly improved by dynamically filtering valid tactics based on problem size, which increases robustness. The removal of unused code and the addition of thorough unit tests for all valid tactics are excellent for maintainability and correctness. The new benchmarks and compute capability decorators are also valuable additions. Overall, this is a high-quality contribution that significantly refines the feature.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
flashinfer/fused_moe/cute_dsl/tuner.py (2)
299-308: ⚠️ Potential issue | 🟠 Major
Validate GEMM2 with the configured output dtype, not hard-coded BF16.

`CuteDslFusedMoENvfp4Runner` exposes `output_dtype`, but `get_valid_tactics()` always asks the finalize kernel about `cutlass.BFloat16`. On a non-BF16 runner this can admit or reject tactics against the wrong alignment rules. Either map `self.output_dtype` here or fail fast when only BF16 is supported.

Also applies to: 361-367, 397-403
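One way to implement the suggested mapping, sketched with string keys rather than the real torch/cutlass types (hypothetical helper name, assuming BF16 and FP16 are the only supported outputs):

```python
# Hypothetical sketch: resolve the runner's output dtype to the cutlass type
# used for GEMM2/finalize validation, and fail fast on anything unsupported.
_OUTPUT_DTYPE_TO_CUTLASS = {
    "torch.bfloat16": "cutlass.BFloat16",
    "torch.float16": "cutlass.Float16",
}

def cutlass_dtype_for(output_dtype: str) -> str:
    try:
        return _OUTPUT_DTYPE_TO_CUTLASS[output_dtype]
    except KeyError:
        # Failing fast here beats mis-validating tactics against BF16 rules.
        raise ValueError(
            f"CuteDSL MoE finalize kernel does not support output dtype {output_dtype}"
        ) from None
```

Either branch (map or raise) keeps `get_valid_tactics()` consistent with the dtype the kernel will actually be launched with.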
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/tuner.py` around lines 299 - 308, get_valid_tactics() currently validates GEMM2 using a hard-coded cutlass.BFloat16 type which mis-validates tactics when CuteDslFusedMoENvfp4Runner was constructed with a different output_dtype; update get_valid_tactics() to map/convert self.output_dtype to the correct cutlass dtype (or raise a clear error if only BF16 is supported) before querying the finalize kernel and GEMM2 validation so the alignment/packing rules match the runner’s configured output_dtype (references: CuteDslFusedMoENvfp4Runner, get_valid_tactics, self.output_dtype, GEMM2, finalize kernel).
299-308: ⚠️ Potential issue | 🟡 Minor
Include `enable_pdl` in the runner hash.

`enable_pdl` changes the execution path and the downstream JIT cache key, but `__hash__()` still ignores it. That lets autotune results learned with PDL on leak into a runner created with PDL off, and vice versa.

Suggested fix

```diff
 def __hash__(self):
     return hash(
         (
             self.num_experts,
             self.top_k,
             self.num_local_experts,
             self.local_expert_offset,
             self.use_fused_finalize,
             self.output_dtype,
+            self.enable_pdl,
         )
     )
```

Also applies to: 310-320
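The effect of the fix can be illustrated with a tiny stand-in key (a hypothetical `RunnerKey`, not the real runner class): once `enable_pdl` is part of the identity, PDL-on and PDL-off runners stop comparing equal and therefore stop sharing cache entries.

```python
from dataclasses import dataclass

# Hypothetical stand-in for the runner's identity fields; frozen dataclasses
# derive __hash__/__eq__ from every listed field, so adding enable_pdl here
# is the moral equivalent of adding it to the hand-written __hash__().
@dataclass(frozen=True)
class RunnerKey:
    num_experts: int
    top_k: int
    output_dtype: str
    enable_pdl: bool  # newly included field

pdl_on = RunnerKey(256, 8, "bfloat16", True)
pdl_off = RunnerKey(256, 8, "bfloat16", False)
assert pdl_on != pdl_off  # distinct cache identities, no cross-contamination
```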
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/tuner.py` around lines 299 - 308, The runner's __hash__ currently omits enable_pdl so autotune/JIT cache keys mix PDL and non-PDL runs; update the class's __hash__ implementation to include self.enable_pdl alongside the other fields used (forward_impl, num_experts, top_k, num_local_experts, local_expert_offset, use_fused_finalize, output_dtype) so the hash reflects the PDL flag; do the same change for the other similar class/section referenced (the block around where enable_pdl is set at lines ~310-320) so both runner hashes incorporate enable_pdl.

flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py (1)
162-191: ⚠️ Potential issue | 🟠 Major
The finalize-kernel cache key is still too weak for `cute.compile()`.

This helper now separates PDL/non-PDL kernels, but it still ignores the typed-pointer parts of the compiled signature. Reusing a kernel first compiled for one combination of `ab_dtype`/`sf_dtype`/`out_dtype` or one `token_final_scales.dtype` with a different combination will dispatch the wrong wrapper specialization.

Also applies to: 203-211
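A sketch of the stronger key the comment asks for (hypothetical names; the real helper would derive these fields from the actual kernel arguments, and `compile_fn` stands in for `cute.compile(...)`):

```python
# Hypothetical sketch: key compiled finalize kernels by every element dtype in
# the typed signature plus the PDL flag, so a cache hit implies the cached
# wrapper specialization is actually compatible with the call.
_compiled_cache = {}

def finalize_cache_key(ab_dtype, sf_dtype, out_dtype, token_scales_dtype, enable_pdl):
    return (ab_dtype, sf_dtype, out_dtype, token_scales_dtype, bool(enable_pdl))

def get_compiled_finalize_kernel(key, compile_fn):
    # Compile once per distinct key; different dtype combinations never alias.
    if key not in _compiled_cache:
        _compiled_cache[key] = compile_fn()
    return _compiled_cache[key]
```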
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py` around lines 162 - 191, The cache key for _get_compiled_finalize_kernel must include the element dtypes of the typed pointers and token scales so compiled wrappers aren’t reused across incompatible pointer/scale types; update the key construction to incorporate the element types for a_ptr/b_ptr/c_ptr/a_sf_ptr/b_sf_ptr/alpha_ptr and the dtype of token_scales_ptr (and keep the existing enable_pdl flag), and apply the same change to the companion key-building site that creates the finalize-kernel key for the other code path so both PDL and non-PDL kernels are keyed by pointer element types and token_scales dtype.tests/moe/test_cute_dsl_fused_moe.py (1)
38-56:⚠️ Potential issue | 🟠 MajorUse
flashinfer.utilsarchitecture helpers for skip gating.The custom
is_sm100_family()check should useflashinfer.utils.get_compute_capability()instead of directtorch.cudacalls. Additionally, the skip condition mentions SM110 in the reason but only checks formajor == 10; SM110 ismajor == 11. Either correct the reason to exclude SM110, or extend the check to include both SM100 (major 10) and SM110 (major 11) families:from flashinfer.utils import get_compute_capability def is_sm100_family(): """Check for SM100/SM110 family (Blackwell).""" if not torch.cuda.is_available(): return False major, _ = get_compute_capability(torch.device("cuda")) return major in [10, 11] # SM100/SM103 (10) or SM110 (11)Or use the composed helpers:
is_sm100a_supported() or is_sm110a_supported().🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/moe/test_cute_dsl_fused_moe.py` around lines 38 - 56, Replace the custom torch.cuda-based is_sm100_family() with a compute-capability check using flashinfer.utils (e.g., call get_compute_capability(torch.device("cuda")) to obtain major) or use the composed helpers is_sm100a_supported() or is_sm110a_supported(); update the function docstring to say SM100/SM110 (Blackwell) and make the return condition accept major in [10, 11], and update the sm100_required pytest reason string to match (or restrict it to only SM100 if you intentionally keep major == 10).
🧹 Nitpick comments (2)
flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py (1)
395-443: Document the new `enable_pdl` knob.

`enable_pdl` is now part of the public constructor, but the constructor docstring still omits it. A short `:param enable_pdl:` entry would make the launch behavior easier to discover for callers.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py` around lines 395 - 443, The constructor __init__ now accepts the enable_pdl parameter but the docstring is missing documentation for it; update the __init__ docstring to add a short ":param enable_pdl: bool" description (and optionally a brief note in the top descriptive section) explaining what enabling/disabling PDL does and its default (True) so callers can discover this knob; reference the enable_pdl attribute set on self and keep wording consistent with existing param entries like vectorized_f32 and topk.

flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py (1)
357-387: Document `enable_pdl` in the constructor docs.

This is now a public configuration parameter, but the constructor docstring still stops before describing it. Adding a `:param enable_pdl:` entry would make the API surface clearer.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py` around lines 357 - 387, The constructor docstring for __init__ is missing documentation for the new public parameter enable_pdl; update the docstring to add a ":param enable_pdl: Boolean, True to enable PDL (Programmatic Dependent Launch) in the kernel." and a corresponding ":type enable_pdl: bool" line, placing it alongside the other param entries (near mma_tiler_mn, cluster_shape_mn, use_blkred, raster_along_m) so callers of the class and generated docs clearly see the new configuration option referenced by the enable_pdl attribute.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_moe_deepseek.py`:
- Around line 71-74: The runtime guard message in main() is out of sync with the
updated docstring: change the error/exit message emitted by main() (the SM
check) to reference SM100/SM103 (Blackwell) instead of SM110, so the runtime
error text matches the docstring and tells users that CuteDSL MoE NVFP4 kernels
are optimized for SM100/SM103 only; update the string used in main()
accordingly.
In `@benchmarks/routines/moe.py`:
- Around line 1176-1179: The test data creator _create_cute_dsl_moe_test_data
currently always stages activations/weights as bfloat16 (variables like x_bf16
and w_bf16) then quantizes, causing reported args.input_dtype/args.weight_dtype
to be incorrect; change the tensor constructions to use args.input_dtype for
input activations and args.weight_dtype for weights (and any similarly staged
tensors) so the initial dtype matches the CLI metadata and the subsequent
quantization path remains unchanged; update all occurrences where tensors are
hardcoded to torch.bfloat16 before quantization so metadata/bandwidth reporting
reflects the actual requested dtypes.
In `@flashinfer/fused_moe/cute_dsl/tuner.py`:
- Around line 418-429: The current code appends DEFAULT_MOE_TACTIC when
valid_tactics is empty even though every candidate was rejected by
can_implement(), which hides the real incompatibility; instead, remove the
fallback to DEFAULT_MOE_TACTIC and surface a clear error: when valid_tactics is
empty after running can_implement() checks (the block that currently references
valid_tactics and DEFAULT_MOE_TACTIC), log a detailed warning including the
problem dims (num_tokens, hidden_size, intermediate_size, num_local_experts,
self.top_k) and raise an exception (e.g., ValueError or RuntimeError) explaining
that no tactics can implement the problem and include suggestions or the
rejected candidate list if available so callers get an actionable failure
instead of deferring to kernel launch.
In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 995-1010: Add a public wrapper method on CuteDslMoEWrapper named
get_valid_tactics_for_inputs(self, inputs: List[torch.Tensor]) that delegates to
the internal runner by calling self._runner.get_valid_tactics(inputs, None) and
returns the filtered tactics; update tests to call
moe.get_valid_tactics_for_inputs(inputs) instead of touching the private
_runner. Ensure the method signature accepts a list of torch.Tensor and simply
forwards the call/result without changing semantics.
---
Outside diff comments:
In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 162-191: The cache key for _get_compiled_finalize_kernel must
include the element dtypes of the typed pointers and token scales so compiled
wrappers aren’t reused across incompatible pointer/scale types; update the key
construction to incorporate the element types for
a_ptr/b_ptr/c_ptr/a_sf_ptr/b_sf_ptr/alpha_ptr and the dtype of token_scales_ptr
(and keep the existing enable_pdl flag), and apply the same change to the
companion key-building site that creates the finalize-kernel key for the other
code path so both PDL and non-PDL kernels are keyed by pointer element types and
token_scales dtype.
In `@flashinfer/fused_moe/cute_dsl/tuner.py`:
- Around line 299-308: get_valid_tactics() currently validates GEMM2 using a
hard-coded cutlass.BFloat16 type which mis-validates tactics when
CuteDslFusedMoENvfp4Runner was constructed with a different output_dtype; update
get_valid_tactics() to map/convert self.output_dtype to the correct cutlass
dtype (or raise a clear error if only BF16 is supported) before querying the
finalize kernel and GEMM2 validation so the alignment/packing rules match the
runner’s configured output_dtype (references: CuteDslFusedMoENvfp4Runner,
get_valid_tactics, self.output_dtype, GEMM2, finalize kernel).
- Around line 299-308: The runner's __hash__ currently omits enable_pdl so
autotune/JIT cache keys mix PDL and non-PDL runs; update the class's __hash__
implementation to include self.enable_pdl alongside the other fields used
(forward_impl, num_experts, top_k, num_local_experts, local_expert_offset,
use_fused_finalize, output_dtype) so the hash reflects the PDL flag; do the same
change for the other similar class/section referenced (the block around where
enable_pdl is set at lines ~310-320) so both runner hashes incorporate
enable_pdl.
In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 38-56: Replace the custom torch.cuda-based is_sm100_family() with
a compute-capability check using flashinfer.utils (e.g., call
get_compute_capability(torch.device("cuda")) to obtain major) or use the
composed helpers is_sm100a_supported() or is_sm110a_supported(); update the
function docstring to say SM100/SM110 (Blackwell) and make the return condition
accept major in [10, 11], and update the sm100_required pytest reason string to
match (or restrict it to only SM100 if you intentionally keep major == 10).
---
Nitpick comments:
In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`:
- Around line 395-443: The constructor __init__ now accepts the enable_pdl
parameter but the docstring is missing documentation for it; update the __init__
docstring to add a short ":param enable_pdl: bool" description (and optionally a
brief note in the top descriptive section) explaining what enabling/disabling
PDL does and its default (True) so callers can discover this knob; reference the
enable_pdl attribute set on self and keep wording consistent with existing param
entries like vectorized_f32 and topk.
In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 357-387: The constructor docstring for __init__ is missing
documentation for the new public parameter enable_pdl; update the docstring to
add a ":param enable_pdl: Boolean, True to enable PDL
(Programmatic Dependent Launch) in the kernel." and a corresponding
":type enable_pdl: bool" line, placing it alongside the other param entries
(near mma_tiler_mn, cluster_shape_mn, use_blkred, raster_along_m) so callers of
the class and generated docs clearly see the new configuration option referenced
by the enable_pdl attribute.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4bb78545-d18c-43a4-81f0-b9474ed91c80
📒 Files selected for processing (18)
- benchmarks/bench_moe_deepseek.py
- benchmarks/routines/flashinfer_benchmark_utils.py
- benchmarks/routines/moe.py
- benchmarks/samples/sample_testlist.txt
- flashinfer/fused_moe/cute_dsl/blackwell/__init__.py
- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm.py
- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py
- flashinfer/fused_moe/cute_dsl/blackwell/utils.py
- flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
- flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm.py
- flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
- flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py
- flashinfer/fused_moe/cute_dsl/fused_moe.py
- flashinfer/fused_moe/cute_dsl/tuner.py
- flashinfer/jit/moe_utils.py
- tests/moe/test_cute_dsl_fused_moe.py
💤 Files with no reviewable changes (5)
- flashinfer/fused_moe/cute_dsl/blackwell/__init__.py
- flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm.py
- flashinfer/fused_moe/cute_dsl/blackwell/utils.py
- flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py
- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm.py
```python
# Input activations
x_bf16 = (
    torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=device) / 10
)
```
This benchmark can report the wrong input/weight dtypes.
The new CuteDSL route records args.input_dtype and args.weight_dtype, but _create_cute_dsl_moe_test_data() always stages BF16 activations and weights before quantization. Any non-default CLI dtype will benchmark BF16 and then publish misleading metadata/bandwidth numbers.
Also applies to: 1192-1199, 1223-1230, 1292-1293, 1317-1326
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmarks/routines/moe.py` around lines 1176 - 1179, The test data creator
_create_cute_dsl_moe_test_data currently always stages activations/weights as
bfloat16 (variables like x_bf16 and w_bf16) then quantizes, causing reported
args.input_dtype/args.weight_dtype to be incorrect; change the tensor
constructions to use args.input_dtype for input activations and
args.weight_dtype for weights (and any similarly staged tensors) so the initial
dtype matches the CLI metadata and the subsequent quantization path remains
unchanged; update all occurrences where tensors are hardcoded to torch.bfloat16
before quantization so metadata/bandwidth reporting reflects the actual
requested dtypes.
```python
if not valid_tactics:
    logger.warning(
        "No valid tactics found for problem dims "
        "(tokens=%d, hidden=%d, intermediate=%d, experts=%d, top_k=%d). "
        "Falling back to default tactic.",
        num_tokens,
        hidden_size,
        intermediate_size,
        num_local_experts,
        self.top_k,
    )
    valid_tactics = [DEFAULT_MOE_TACTIC]
```
Don't fall back to a tactic that can_implement() already rejected.
If valid_tactics is empty, every candidate failed the actual shape/alignment checks. Returning DEFAULT_MOE_TACTIC here just defers the failure to kernel launch with a much less actionable error.
Suggested fix
```diff
-if not valid_tactics:
-    logger.warning(
-        "No valid tactics found for problem dims "
-        "(tokens=%d, hidden=%d, intermediate=%d, experts=%d, top_k=%d). "
-        "Falling back to default tactic.",
-        num_tokens,
-        hidden_size,
-        intermediate_size,
-        num_local_experts,
-        self.top_k,
-    )
-    valid_tactics = [DEFAULT_MOE_TACTIC]
+if not valid_tactics:
+    raise ValueError(
+        "No valid CuteDSL MoE tactics for the current problem dimensions."
+    )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if not valid_tactics:
    raise ValueError(
        "No valid CuteDSL MoE tactics for the current problem dimensions."
    )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fused_moe/cute_dsl/tuner.py` around lines 418 - 429, The current
code appends DEFAULT_MOE_TACTIC when valid_tactics is empty even though every
candidate was rejected by can_implement(), which hides the real incompatibility;
instead, remove the fallback to DEFAULT_MOE_TACTIC and surface a clear error:
when valid_tactics is empty after running can_implement() checks (the block that
currently references valid_tactics and DEFAULT_MOE_TACTIC), log a detailed
warning including the problem dims (num_tokens, hidden_size, intermediate_size,
num_local_experts, self.top_k) and raise an exception (e.g., ValueError or
RuntimeError) explaining that no tactics can implement the problem and include
suggestions or the rejected candidate list if available so callers get an
actionable failure instead of deferring to kernel launch.
```python
# Get the filtered list of valid tactics for this problem size
inputs = [
    tensors["x"],
    tensors["x_sf"],
    tensors["token_selected_experts"],
    tensors["token_final_scales"],
    tensors["w1_weight"],
    tensors["w1_weight_sf"],
    tensors["w1_alpha"],
    tensors["fc2_input_scale"],
    tensors["w2_weight"],
    tensors["w2_weight_sf"],
    tensors["w2_alpha"],
]
valid_tactics = moe._runner.get_valid_tactics(inputs, None)
assert len(valid_tactics) > 0, "No valid tactics found"
```
🧩 Analysis chain

🏁 Script executed:

```shell
#!/bin/bash
# Verify whether a public tactic-query API exists and where private access is used.
rg -nP --type=py 'class\s+CuteDslMoEWrapper\b|def\s+get_valid_tactics\s*\('
rg -nP --type=py '\._runner\.get_valid_tactics\s*\('
```

Repository: flashinfer-ai/flashinfer
Length of output: 1718

🏁 Script executed:

```shell
#!/bin/bash
# Examine the CuteDslMoEWrapper class and get_valid_tactics method
sed -n '301,650p' flashinfer/fused_moe/cute_dsl/fused_moe.py | cat -n
```

Repository: flashinfer-ai/flashinfer
Length of output: 15680

🏁 Script executed:

```shell
#!/bin/bash
# Also check what _runner is and its type
rg -A 5 'self\._runner\s*=' flashinfer/fused_moe/cute_dsl/fused_moe.py | head -30
```

Repository: flashinfer-ai/flashinfer
Length of output: 339

🏁 Script executed:

```shell
#!/bin/bash
# Check the test file to see how moe object is instantiated
sed -n '990,1020p' tests/moe/test_cute_dsl_fused_moe.py | cat -n
```

Repository: flashinfer-ai/flashinfer
Length of output: 1398

🏁 Script executed:

```shell
#!/bin/bash
# Find CuteDslFusedMoENvfp4Runner and its get_valid_tactics method
rg -B 3 -A 10 'class CuteDslFusedMoENvfp4Runner' flashinfer/fused_moe/cute_dsl/tuner.py
```

Repository: flashinfer-ai/flashinfer
Length of output: 507

🏁 Script executed:

```shell
#!/bin/bash
# Check the get_valid_tactics implementation in tuner.py
sed -n '320,350p' flashinfer/fused_moe/cute_dsl/tuner.py | cat -n
```

Repository: flashinfer-ai/flashinfer
Length of output: 1384

🏁 Script executed:

```shell
#!/bin/bash
# Check if there are other accesses to _runner.get_valid_tactics in the codebase
rg -n '\._runner\.get_valid_tactics' --type=py
```

Repository: flashinfer-ai/flashinfer
Length of output: 177

🏁 Script executed:

```shell
#!/bin/bash
# Check what ALL_MOE_TACTICS contains
rg -B 2 -A 5 'ALL_MOE_TACTICS\s*=' flashinfer/fused_moe/cute_dsl/
```

Repository: flashinfer-ai/flashinfer
Length of output: 656

🏁 Script executed:

```shell
#!/bin/bash
# Check the full context around line 612 to see if get_valid_tactics should be updated
sed -n '300,320p' flashinfer/fused_moe/cute_dsl/fused_moe.py | cat -n
```

Repository: flashinfer-ai/flashinfer
Length of output: 1053
Add a public method to CuteDslMoEWrapper for getting filtered tactics instead of accessing private _runner.
The test accesses moe._runner.get_valid_tactics(inputs, None) to retrieve tactics filtered by the actual problem dimensions. This couples the test to private internals. Create a public method like get_valid_tactics_for_inputs(self, inputs: List[torch.Tensor]) in CuteDslMoEWrapper that delegates to self._runner.get_valid_tactics(), allowing tests to retrieve properly filtered tactics through the public API.
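The suggested method is a one-line delegation; a hypothetical sketch (with a stub runner standing in for the real tuner so the snippet runs standalone):

```python
# Hypothetical sketch of the proposed public accessor on CuteDslMoEWrapper.
class StubRunner:
    def get_valid_tactics(self, inputs, profile):
        # Stand-in for the real shape/alignment filtering in the tuner.
        return [("tactic", len(inputs))]

class CuteDslMoEWrapper:
    def __init__(self, runner):
        self._runner = runner

    def get_valid_tactics_for_inputs(self, inputs):
        # Forward unchanged; callers no longer touch the private _runner.
        return self._runner.get_valid_tactics(inputs, None)

moe = CuteDslMoEWrapper(StubRunner())
assert moe.get_valid_tactics_for_inputs(["x", "x_sf"]) == [("tactic", 2)]
```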
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/moe/test_cute_dsl_fused_moe.py` around lines 995 - 1010, Add a public
wrapper method on CuteDslMoEWrapper named get_valid_tactics_for_inputs(self,
inputs: List[torch.Tensor]) that delegates to the internal runner by calling
self._runner.get_valid_tactics(inputs, None) and returns the filtered tactics;
update tests to call moe.get_valid_tactics_for_inputs(inputs) instead of
touching the private _runner. Ensure the method signature accepts a list of
torch.Tensor and simply forwards the call/result without changing semantics.
/bot run

[FAILED] Pipeline #46013906: 7/20 passed

I cancelled the PR test because CI won't pass before #2781 lands; please re-trigger the test after that PR is merged.
Actionable comments posted: 1
🧹 Nitpick comments (1)
benchmarks/routines/moe.py (1)
1155-1155: Consider adding an error message to the assertion.

A bare `assert` without a message makes debugging harder when the condition fails.

💡 Suggested improvement

```diff
-assert sizes[dim] % (group_size * 2) == 0
+assert sizes[dim] % (group_size * 2) == 0, (
+    f"Dimension {dim} size ({sizes[dim]}) must be divisible by {group_size * 2}"
+)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/moe.py` at line 1155, The assertion "assert sizes[dim] % (group_size * 2) == 0" is bare and should include a helpful message; update it to either an assert with a formatted message (e.g. assert sizes[dim] % (group_size * 2) == 0, f\"sizes[{dim}] ({sizes[dim]}) must be divisible by 2*group_size ({group_size*2})\") or raise a ValueError with the same formatted message to make failures clear; modify the assertion in benchmarks/routines/moe.py where sizes, dim, and group_size are used so the runtime error explains the violated condition.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/routines/moe.py`:
- Around line 1402-1415: This routine's autotune block ignores a provided
--autotune_cache; update the autotune invocation to accept and pass the cache
path from args (check getattr(args, "autotune_cache", None) and only pass it
when present) so the autotune context manager is called with the cache option;
specifically, modify the with autotune(True): line to something like with
autotune(True, cache_path=args.autotune_cache) (or the correct parameter name
used by the autotune context manager), leaving run_cute_dsl_moe and
autotune_args unchanged.
---
Nitpick comments:
In `@benchmarks/routines/moe.py`:
- Line 1155: The assertion "assert sizes[dim] % (group_size * 2) == 0" is bare
and should include a helpful message; update it to either an assert with a
formatted message (e.g. assert sizes[dim] % (group_size * 2) == 0,
f\"sizes[{dim}] ({sizes[dim]}) must be divisible by 2*group_size
({group_size*2})\") or raise a ValueError with the same formatted message to
make failures clear; modify the assertion in benchmarks/routines/moe.py where
sizes, dim, and group_size are used so the runtime error explains the violated
condition.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ee3c9014-4684-4e17-a8aa-beb50311e03e
📒 Files selected for processing (1)
benchmarks/routines/moe.py
```python
if getattr(args, "autotune", False):
    warmup_iters = (
        args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
    )
    backend = "cute-dsl_autotune"
    if args.verbose >= 1:
        print(f"[INFO] Autotune warmup for CuteDSL NVFP4 MoE: {warmup_iters} iters")
    autotune_args = tuple(
        t.clone() if isinstance(t, torch.Tensor) else t for t in input_args
    )
    with autotune(True):
        for _ in range(warmup_iters):
            run_cute_dsl_moe(*autotune_args)
    del autotune_args
```
Autotune cache path not supported.
Unlike other test functions (testTrtllmFp4BlockScaleMoe at line 671, testCutlassFusedMoe at line 1068), this function doesn't support the autotune_cache argument. Users who pass --autotune_cache will have their cache ignored for this routine.
🔧 Proposed fix to add cache support
```diff
+cache_path = getattr(args, "autotune_cache", None)
+
 # Optional autotune warmup.
 # Clone input_args so autotune tactic exploration doesn't corrupt the
 # original tensors used by the subsequent benchmark.
 if getattr(args, "autotune", False):
     warmup_iters = (
         args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
     )
     backend = "cute-dsl_autotune"
     if args.verbose >= 1:
         print(f"[INFO] Autotune warmup for CuteDSL NVFP4 MoE: {warmup_iters} iters")
     autotune_args = tuple(
         t.clone() if isinstance(t, torch.Tensor) else t for t in input_args
     )
-    with autotune(True):
+    with autotune(True, cache=cache_path):
         for _ in range(warmup_iters):
             run_cute_dsl_moe(*autotune_args)
     del autotune_args
+elif cache_path:
+    with autotune(False, cache=cache_path):
+        pass
```
+ pass🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmarks/routines/moe.py` around lines 1402 - 1415, This routine's autotune
block ignores a provided --autotune_cache; update the autotune invocation to
accept and pass the cache path from args (check getattr(args, "autotune_cache",
None) and only pass it when present) so the autotune context manager is called
with the cache option; specifically, modify the with autotune(True): line to
something like with autotune(True, cache_path=args.autotune_cache) (or the
correct parameter name used by the autotune context manager), leaving
run_cute_dsl_moe and autotune_args unchanged.
bkryu left a comment
Approved. What LGTM:
- Benchmark changes look correct
- Removal of unused less-fused kernel variants -- vanilla grouped gemm and grouped gemm + swiglu -- make sense to me
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/fused_moe/cute_dsl/fused_moe.py`:
- Around lines 342-343: Replace the `@supported_compute_capability` decorator on CuteDslMoEWrapper.__init__ with `@backend_requirement` so the public API exposes both is_compute_capability_supported(cc) and is_backend_supported(). To do this, create a separate small check function (e.g. cute_dsl_moe_compute_check) decorated with `@supported_compute_capability([100, 103])` that performs the compute-capability check, then apply `@backend_requirement({}, common_check=cute_dsl_moe_compute_check)` and `@flashinfer_api` to CuteDslMoEWrapper.__init__ (mirroring the nvfp4_kv_dequantize pattern), so backend discovery and compute-capability checks are both available.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: a24feca8-7413-4278-98c4-4ade10a80466
📒 Files selected for processing (6)
- benchmarks/routines/flashinfer_benchmark_utils.py
- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
- flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
- flashinfer/fused_moe/cute_dsl/fused_moe.py
- flashinfer/jit/moe_utils.py
✅ Files skipped from review due to trivial changes (1)
- flashinfer/jit/moe_utils.py
🚧 Files skipped from review as they are similar to previous changes (2)
- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
- flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
```python
@supported_compute_capability([100, 103])
@flashinfer_api
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Decorator definitions ==="
rg -n -C5 'def supported_compute_capability\b|def backend_requirement\b' flashinfer
echo
echo "=== Backend metadata helpers exposed in flashinfer/ ==="
rg -n -C2 'is_compute_capability_supported|is_backend_supported' flashinfer -g '!**/tests/**'
echo
echo "=== Class-based API decoration patterns ==="
rg -n -C3 'class .*Wrapper|@backend_requirement|@supported_compute_capability' flashinfer -g '!**/tests/**'
Repository: flashinfer-ai/flashinfer
Length of output: 40762
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== CuteDslMoEWrapper class definition and methods ==="
sed -n '302,700p' flashinfer/fused_moe/cute_dsl/fused_moe.py | head -150
echo
echo "=== Check if CuteDslMoEWrapper or __init__ expose helper methods ==="
rg -n -A30 'class CuteDslMoEWrapper' flashinfer/fused_moe/cute_dsl/fused_moe.py | grep -E 'is_compute_capability_supported|is_backend_supported|@property|def ' | head -20
Repository: flashinfer-ai/flashinfer
Length of output: 6159
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Check if CuteDslMoEWrapper has is_compute_capability_supported or is_backend_supported as instance/class methods ==="
grep -n 'is_compute_capability_supported\|is_backend_supported' flashinfer/fused_moe/cute_dsl/fused_moe.py
echo
echo "=== Check how other Wrapper classes are decorated (if at all) ==="
rg -n 'class.*Wrapper' flashinfer -g '!**/tests/**' | head -10
echo
echo "=== Example: Check BatchDecodeWithPagedKVCacheWrapper for comparison ==="
sed -n '593,650p' flashinfer/decode.py
echo
echo "=== Check if functions (not classes) are the primary target for compute capability decorators ==="
rg -B2 '@backend_requirement' flashinfer -g '!**/tests/**' | grep -E '(def |@backend_requirement)' | head -20
Repository: flashinfer-ai/flashinfer
Length of output: 188
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Example: Check BatchDecodeWithPagedKVCacheWrapper for comparison ==="
sed -n '593,650p' flashinfer/decode.py
echo
echo "=== Check if functions (not classes) are the primary target for compute capability decorators ==="
rg -B2 '@backend_requirement' flashinfer -g '!**/tests/**' | grep -E '(def |@backend_requirement)' | head -30
echo
echo "=== Check CuteDslMoEWrapper.run method ==="
rg -n -A10 'def run' flashinfer/fused_moe/cute_dsl/fused_moe.py | head -30
Repository: flashinfer-ai/flashinfer
Length of output: 4388
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Check if @flashinfer_api handles the compute capability check for methods ==="
rg -n -B5 -A15 'def flashinfer_api' flashinfer/utils.py | head -50
echo
echo "=== Verify if CuteDslMoEWrapper.__init__ actually has the helper methods attached ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '/repo')
# Try to import and check if __init__ has the helper
try:
from flashinfer.fused_moe.cute_dsl.fused_moe import CuteDslMoEWrapper
init_method = CuteDslMoEWrapper.__init__
print("Checking __init__ attributes:")
print(f" has _supported_ccs: {hasattr(init_method, '_supported_ccs')}")
print(f" has is_compute_capability_supported: {hasattr(init_method, 'is_compute_capability_supported')}")
print(f" has is_backend_supported: {hasattr(init_method, 'is_backend_supported')}")
print("\nChecking CuteDslMoEWrapper class attributes:")
print(f" has _supported_ccs: {hasattr(CuteDslMoEWrapper, '_supported_ccs')}")
print(f" has is_compute_capability_supported: {hasattr(CuteDslMoEWrapper, 'is_compute_capability_supported')}")
print(f" has is_backend_supported: {hasattr(CuteDslMoEWrapper, 'is_backend_supported')}")
except Exception as e:
print(f"Error: {e}")
import traceback
traceback.print_exc()
EOF
Repository: flashinfer-ai/flashinfer
Length of output: 149
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Search for flashinfer_api decorator definition ==="
rg -n 'flashinfer_api' flashinfer/utils.py | head -5
echo
echo "=== Get decorator definition ==="
rg -n -A20 'def flashinfer_api' flashinfer
echo
echo "=== Verify class instance does NOT have helper methods at runtime ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '/repo')
try:
# Check if __init__ method itself has the helpers
from flashinfer.utils import supported_compute_capability
import inspect
    @supported_compute_capability([100, 103])
def example_func():
pass
print("Function decorated with `@supported_compute_capability`:")
print(f" has _supported_ccs: {hasattr(example_func, '_supported_ccs')}")
print(f" has is_compute_capability_supported: {hasattr(example_func, 'is_compute_capability_supported')}")
# Now check a class method
class ExampleClass:
        @supported_compute_capability([100, 103])
def __init__(self):
pass
print("\nClass method decorated with `@supported_compute_capability`:")
print(f" __init__ has _supported_ccs: {hasattr(ExampleClass.__init__, '_supported_ccs')}")
print(f" __init__ has is_compute_capability_supported: {hasattr(ExampleClass.__init__, 'is_compute_capability_supported')}")
instance = ExampleClass()
print(f" instance has is_compute_capability_supported: {hasattr(instance, 'is_compute_capability_supported')}")
except Exception as e:
print(f"Error: {e}")
import traceback
traceback.print_exc()
EOF
Repository: flashinfer-ai/flashinfer
Length of output: 122
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Find flashinfer_api definition ==="
fd flashinfer_api flashinfer --type f
echo
echo "=== Search more broadly for flashinfer_api ==="
rg 'def flashinfer_api' flashinfer
echo
echo "=== Check what @flashinfer_api is in the imports/exports ==="
rg -n 'flashinfer_api' flashinfer/__init__.py
echo
echo "=== Verify actual decorator stacking on CuteDslMoEWrapper.__init__ ==="
python3 -c "
import sys
sys.path.insert(0, '/repo')
from flashinfer.fused_moe.cute_dsl.fused_moe import CuteDslMoEWrapper
# Check what the __init__ actually is
init_fn = CuteDslMoEWrapper.__init__
print('Type of __init__:', type(init_fn))
print('__init__ has _supported_ccs:', hasattr(init_fn, '_supported_ccs'))
print('__init__ has is_compute_capability_supported:', hasattr(init_fn, 'is_compute_capability_supported'))
print('__init__ attributes:', [a for a in dir(init_fn) if not a.startswith('_')])
" 2>&1 || echo "(import failed)"
echo
echo "=== Compare with a decorated standalone function ==="
rg -A5 'def cute_dsl_fused_moe_nvfp4' flashinfer/fused_moe/cute_dsl/fused_moe.py
echo
echo "=== Check if standalone functions use backend_requirement or supported_compute_capability directly ==="
rg -B3 'def cute_dsl_fused_moe_nvfp4' flashinfer/fused_moe/cute_dsl/fused_moe.py
Repository: flashinfer-ai/flashinfer
Length of output: 374
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Check flashinfer_api definition in api_logging.py ==="
rg -n -A20 'def flashinfer_api' flashinfer/api_logging.py
echo
echo "=== Check how @backend_requirement and @flashinfer_api work together ==="
rg -B5 -A10 '@backend_requirement.*\n.*@flashinfer_api' flashinfer/quantization/fp4_quantization.py | head -30
echo
echo "=== Review the decorated function pattern to confirm backend_requirement is the standard ==="
sed -n '1269,1280p' flashinfer/quantization/fp4_quantization.py
echo
echo "=== Now check CuteDslMoEWrapper pattern more carefully ==="
sed -n '342,365p' flashinfer/fused_moe/cute_dsl/fused_moe.py
Repository: flashinfer-ai/flashinfer
Length of output: 1621
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Check decorator pattern for standalone function ==="
sed -n '1269,1276p' flashinfer/quantization/fp4_quantization.py
echo
echo "=== Check CuteDslMoEWrapper.__init__ decoration ==="
sed -n '342,344p' flashinfer/fused_moe/cute_dsl/fused_moe.py
echo
echo "=== Verify the key difference: backend_requirement on function vs supported_compute_capability on method ==="
rg -B2 'def cute_dsl_fused_moe_nvfp4' flashinfer/fused_moe/cute_dsl/fused_moe.py | head -10
echo
echo "=== Check if there are ANY Wrapper classes with backend_requirement on __init__ ==="
rg -B3 'class.*Wrapper' flashinfer -g '!**/tests/**' -A50 | rg '@backend_requirement' | head -5
echo
echo "=== Confirm: standalone functions use @backend_requirement, not class methods ==="
rg '@backend_requirement' flashinfer -B3 -g '!**/tests/**' | grep -E '(def |@backend)' | head -15
Repository: flashinfer-ai/flashinfer
Length of output: 824
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Check flashinfer_api full implementation ==="
sed -n '1420,1500p' flashinfer/api_logging.py
echo
echo "=== Does flashinfer_api handle backend requirement checking ==="
rg -A50 'def flashinfer_api' flashinfer/api_logging.py | grep -E '(is_backend|is_compute|backend_requirement)' | head -10
echo
echo "=== Examine if @supported_compute_capability on standalone functions provides is_backend_supported ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '/repo')
from flashinfer.fused_moe.cute_dsl.fused_moe import cute_dsl_fused_moe_nvfp4
print("cute_dsl_fused_moe_nvfp4 attributes:")
print(f" has _supported_ccs: {hasattr(cute_dsl_fused_moe_nvfp4, '_supported_ccs')}")
print(f" has is_compute_capability_supported: {hasattr(cute_dsl_fused_moe_nvfp4, 'is_compute_capability_supported')}")
print(f" has is_backend_supported: {hasattr(cute_dsl_fused_moe_nvfp4, 'is_backend_supported')}")
if hasattr(cute_dsl_fused_moe_nvfp4, 'is_compute_capability_supported'):
print(f" is_compute_capability_supported(100): {cute_dsl_fused_moe_nvfp4.is_compute_capability_supported(100)}")
EOF
Repository: flashinfer-ai/flashinfer
Length of output: 3899
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Check if cute_dsl_fused_moe_nvfp4 has is_backend_supported ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '/repo')
from flashinfer.fused_moe.cute_dsl.fused_moe import cute_dsl_fused_moe_nvfp4
print("cute_dsl_fused_moe_nvfp4 attributes:")
print(f" has _supported_ccs: {hasattr(cute_dsl_fused_moe_nvfp4, '_supported_ccs')}")
print(f" has is_compute_capability_supported: {hasattr(cute_dsl_fused_moe_nvfp4, 'is_compute_capability_supported')}")
print(f" has is_backend_supported: {hasattr(cute_dsl_fused_moe_nvfp4, 'is_backend_supported')}")
EOF
echo
echo "=== Compare with a function that uses backend_requirement ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '/repo')
from flashinfer.quantization.fp4_quantization import nvfp4_kv_dequantize
print("nvfp4_kv_dequantize (uses `@backend_requirement`) attributes:")
print(f" has _supported_ccs: {hasattr(nvfp4_kv_dequantize, '_supported_ccs')}")
print(f" has is_compute_capability_supported: {hasattr(nvfp4_kv_dequantize, 'is_compute_capability_supported')}")
print(f" has is_backend_supported: {hasattr(nvfp4_kv_dequantize, 'is_backend_supported')}")
EOF
echo
echo "=== Summary: Check decorator stacking on functions in fused_moe.py ==="
rg -B1 'def cute_dsl_fused_moe_nvfp4\|def cute_dsl_fused_moe_fp4_block' flashinfer/fused_moe/cute_dsl/fused_moe.py | grep -E '(@|def )' | head -10
Repository: flashinfer-ai/flashinfer
Length of output: 583
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Direct code examination: Check standalone function decorators ==="
sed -n '683,690p' flashinfer/fused_moe/cute_dsl/fused_moe.py
echo
echo "=== Direct code examination: Check function with backend_requirement for comparison ==="
sed -n '1274,1280p' flashinfer/quantization/fp4_quantization.py
echo
echo "=== Key finding: Verify learning requirement one more time ==="
cat << 'EOF'
From the retrieval learning:
"Use `backend_requirement` decorator on APIs that have compute capability requirements
and provide `is_compute_capability_supported(cc)` and `is_backend_supported()` methods"
Current code patterns observed:
1. Standalone function (cute_dsl_fused_moe_nvfp4): Uses `@supported_compute_capability` directly
2. Wrapper function (nvfp4_kv_dequantize): Uses `@backend_requirement` with common_check
CuteDslMoEWrapper.__init__ currently uses:
  @supported_compute_capability([100, 103])
  @flashinfer_api
This does NOT match the learning requirement for APIs to use `@backend_requirement`
EOF
echo
echo "=== Conclusion: Review the original concern ==="
echo "The original review comment states:"
echo " 'Wrapping CuteDslMoEWrapper.__init__ with the compute-capability decorator'"
echo " 'only guards construction; it may not make CuteDslMoEWrapper itself expose'"
echo " 'the standard backend discovery helpers'"
echo ""
echo "This is CORRECT. The @supported_compute_capability decorator adds helpers to"
echo "the __init__ function object, not to the class instance. Following the learning,"
echo "this should use @backend_requirement instead."
Repository: flashinfer-ai/flashinfer
Length of output: 1874
Use @backend_requirement instead of @supported_compute_capability on CuteDslMoEWrapper.__init__.
The @supported_compute_capability decorator provides only is_compute_capability_supported on the method object. Public APIs should use @backend_requirement to expose both is_compute_capability_supported(cc) and is_backend_supported() methods, following the library's standard backend-discovery pattern.
Move the compute capability constraint into a separate check function decorated with @supported_compute_capability, then apply @backend_requirement({}, common_check=...) to __init__, mirroring the pattern used by nvfp4_kv_dequantize and other public APIs.
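The gap is easy to reproduce with a toy decorator (a simplified stand-in; flashinfer's real `supported_compute_capability` does more, but attribute placement behaves the same way): helpers attached to `__init__` live on that function object, so neither the class nor its instances expose them.

```python
# Toy compute-capability decorator: attaches a helper to the decorated
# function object, mimicking where the real decorator puts its helpers.
def supported_compute_capability(ccs):
    def wrap(fn):
        fn.is_compute_capability_supported = lambda cc: cc in ccs
        return fn
    return wrap

class Wrapper:
    @supported_compute_capability([100, 103])
    def __init__(self):
        pass

# The helper lands on the __init__ function object...
assert Wrapper.__init__.is_compute_capability_supported(100)
assert not Wrapper.__init__.is_compute_capability_supported(90)
# ...but is not discoverable on the class or on instances,
# so callers probing Wrapper for backend support find nothing.
assert not hasattr(Wrapper, "is_compute_capability_supported")
assert not hasattr(Wrapper(), "is_compute_capability_supported")
```

This is why the review suggests routing the check through a pattern that exposes the discovery helpers where callers actually look.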
<!-- .github/pull_request_template.md -->

## 📌 Description

fix api breaking changes for 0.6.7 release

## 🔍 Related Issues (Gated-by PRs)

https://github.com/flashinfer-ai/flashinfer/issues?q=state%3Aopen%20label%3Av0.6.7

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

**API changes review**

API changes since v0.6.6

PR #2520 + commit e35c19e (fixed to be compatible)
- Function: xqa(). Change: Added k_sf_cache=None, v_sf_cache=None as keyword-only params (after *). Backward-compatible.

PR #2618 (has PR #2730 to fix it)
- Function: gated_delta_rule_mtp(). Change: disable_state_update: bool = True → Optional[bool] = None. Still defaults to True at runtime but emits a deprecation warning; will flip to False in 0.7.0.

PR #2775 (expected — cute DSL MoE cleanup)
- Function: blockscaled_contiguous_grouped_gemm_nvfp4(). Change: Entire @flashinfer_api decorated function deleted.
- Function: blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4(). Change: Entire @flashinfer_api decorated function deleted.
- Function: blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(). Change: @flashinfer_api decorator removed; added enable_pdl: bool = True param.
- Function: blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(). Change: @flashinfer_api decorator removed; added enable_pdl: bool = True param.
- Function: CuteDslMoEWrapper.__init__(). Change: Added enable_pdl: bool = True param. Backward-compatible.
- Function: cute_dsl_fused_moe_nvfp4(). Change: Added enable_pdl: bool = True param. Backward-compatible.

PR #2428
- Function: rmsnorm_quant(). Change: scale: float → scale: Union[float, torch.Tensor]; return type torch.Tensor → None.
- Function: fused_add_rmsnorm_quant(). Change: scale: float → scale: Union[float, torch.Tensor].

Quantization functions (relocated, not removed): All quantization APIs (fp4_quantize, block_scale_interleave, e2m1_and_ufp8sf_scale_to_float, shuffle_matrix_a, shuffle_matrix_sf_a, nvfp4_quantize, nvfp4_batched_quantize, scaled_fp4_grouped_quantize, mxfp4_quantize, mxfp4_dequantize, mxfp4_dequantize_host, mxfp8_quantize, mxfp8_dequantize_host) were moved from flashinfer/fp4_quantization.py and flashinfer/fp8_quantization.py to flashinfer/quantization/. Signatures, @flashinfer_api decorators, and __init__.py exports are preserved. No breakage.

```diff
$ git diff v0.6.6 | grep -A20 "@flashinfer_api"
 @flashinfer_api
@@ -1215,6 +1227,9 @@ class BatchDecodeWithPagedKVCacheWrapper:
         sinks: Optional[torch.Tensor] = None,
         q_len_per_req: Optional[int] = 1,
         skip_softmax_threshold_scale_factor: Optional[float] = None,
+        kv_block_scales: Optional[
+            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+        ] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         r"""Compute batch decode attention between query and paged kv cache.
@@ -1273,6 +1288,15 @@ class BatchDecodeWithPagedKVCacheWrapper:
             enable_pdl = device_support_pdl(q.device)
         k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout)
+        # Unpack kv_block_scales
+        key_block_scales = None
+        value_block_scales = None
+        if kv_block_scales is not None:
+            if isinstance(kv_block_scales, tuple):
+                key_block_scales, value_block_scales = kv_block_scales
--
-@flashinfer_api
-def fp4_quantize(
-    input: torch.Tensor,
-    global_scale: Optional[torch.Tensor] = None,
-    sf_vec_size: int = 16,
-    sf_use_ue8m0: bool = False,
-    is_sf_swizzled_layout: bool = True,
-    is_sf_8x4_layout: bool = False,
-    enable_pdl: Optional[bool] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Quantize input tensor to FP4 format.
-
-    This function implements FP4 quantization that converts input tensors to a compressed FP4 format
-    with associated scale factors. It supports various input data types and scale factor layouts.
-
-    Args:
-        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
-        global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False.
-        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
--
-@flashinfer_api
-def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
-    """Swizzle block scale tensor for FP4 format.
-
-    This function swizzles the block scale tensor to optimize memory access patterns
-    for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128.
-
-    Args:
-        unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16.
-
-    Returns:
-        torch.Tensor: Swizzled tensor with the same shape as input.
-
-    Raises:
-        AssertionError: If input dtype is not uint8 or bfloat16.
-    """
-    # TODO(shuw): check input dtype is uint8
-    assert (
-        unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16
-    ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}"
-
--
-@flashinfer_api
-def e2m1_and_ufp8sf_scale_to_float(
-    e2m1_tensor: torch.Tensor,
-    ufp8_scale_tensor: torch.Tensor,
-    global_scale_tensor: Optional[torch.Tensor] = None,
-    sf_vec_size: int = 16,
-    ufp8_type: int = 1,
-    is_sf_swizzled_layout: bool = True,
-) -> torch.Tensor:
-    """Convert E2M1 format tensor and UFP8 scale factors to float tensor.
-
-    This function performs dequantization by converting a packed FP4 tensor in E2M1 format
-    back to float values using the associated UFP8 scale factors and global scale.
-
-    Args:
-        e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
-        ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8.
-        global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
-        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
--
-@flashinfer_api
-def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
-    """
-    PyTorch equivalent of trtllm-gen `shuffleMatrixA`
-    """
-    row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
-
-    return input_tensor[row_indices.to(input_tensor.device)]
-
-
-@flashinfer_api
-def shuffle_matrix_sf_a(
-    input_tensor: torch.Tensor,
-    epilogue_tile_m: int,
-    num_elts_per_sf: int = 16,
-):
-    """
-    Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat.
-    `shuffleMatrixSfA` expects the input to be in 128x4 layout and then
-    apply the same shuffling in `shuffleMatrixA` and writes out in 128x4
-    layout.
-    This function expects the input to be in linear layout. It's done this
-    way because the scaling factors in the NVFP4 checkpoints are quantized
-    and are in linear layout.
-    This function doesn't add padding.
-    """
-
-    row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m)
-
-    w_shuffled = input_tensor[row_indices.to(input_tensor.device)]
-
--
-@flashinfer_api
-def nvfp4_quantize(
-    a,
-    a_global_sf,
-    sfLayout=SfLayout.layout_128x4,
-    do_shuffle=False,
-    sf_vec_size=16,
-    enable_pdl=None,
-):
-    """
-    Quantize input tensor to NVFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4.
-        do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
-            If None, automatically detects based on device capability. Defaults to None.
-
--
-@flashinfer_api
-def mxfp4_quantize(a):
-    """
-    Quantize input tensor to MXFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-            - Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-    """
-    a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
-    a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True)
-    return a_fp4, a_sf
-
-
-@flashinfer_api
-def mxfp4_dequantize(a_fp4, a_sf):
-    """
-    Dequantize input tensor from MXFP4 format.
-
-    Parameters:
-        a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-        a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-
-    Returns:
-        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
-    """
-    return e2m1_and_ufp8sf_scale_to_float(
-        a_fp4.cpu().view(torch.uint8),
-        a_sf.cpu().view(torch.uint8).reshape(-1),
-        torch.tensor([1.0], device=a_fp4.device),
-        32,
-        0,
-        True,
-    )
-
--
-@flashinfer_api
-def mxfp4_dequantize_host(
-    weight: torch.Tensor,
-    scale: torch.Tensor,
-    group_size: int = 32,
-) -> torch.Tensor:
-    """
-    Dequantize input tensor from MXFP4 format on host.
-
-    Parameters:
-        weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-        scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-        group_size (int, optional): Group size for dequantization. Defaults to 32.
-
-    Returns:
-        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
-    """
-    # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future
-    major, minor = get_compute_capability(
-        torch.device("cuda:0")
-    )  # use any cuda device to get a compute capability
--
-@flashinfer_api
-def nvfp4_batched_quantize(
-    a,
-    a_global_sf,
-    sf_vec_size=16,
-):
-    """
-    Quantize batched input tensor to NVFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
-            - Scale factors tensor with shape determined by layout and sf_vec_size
-    """
-    major, minor = get_compute_capability(a.device)
-    device_arch = f"{major * 10 + minor}"
--
-@flashinfer_api
-def scaled_fp4_grouped_quantize(
-    a,
-    mask,
-    a_global_sf,
-):
-    """
-    quantize batched input tensor to NVFP4 format with mask.
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        mask (torch.Tensor): Mask tensor to apply before quantization.
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
-            - Scale factors tensor with shape determined by layout and sf_vec_size
-    """
-    major, minor = get_compute_capability(a.device)
-    device_arch = f"{major * 10 + minor}"
-    a_fp4, a_sf = get_fp4_quantization_module(
-        device_arch
--
-@flashinfer_api
-def mxfp8_quantize(
-    input: torch.Tensor,
-    is_sf_swizzled_layout: bool = True,
-    alignment: int = 32,
-    enable_pdl: Optional[bool] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Quantize input tensor to MxFP8 format.
-
-    This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format
-    with associated scale factors. It supports various input data types and scale factor layouts.
-
-    Args:
-        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
-        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
-        alignment (int, optional): sfVecSize. Defaults to 32.
-        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
-            If None, automatically detects based on device capability. Defaults to None.
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3
--
-@flashinfer_api
-def mxfp8_dequantize_host(
-    input: torch.Tensor,
-    scale_tensor: torch.Tensor,
-    is_sf_swizzled_layout: bool = True,
-) -> torch.Tensor:
-    """Dequantize input tensor from MxFP8 format.
-
-    This function performs dequantization by converting a packed FP8 tensor in MxFP8 format
-    back to float values using the associated scale factors.
-
-    Args:
-        input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3.
-        scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size.
-        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
-
-    Returns:
-        torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32.
-
-    """
-
--
-@flashinfer_api
 def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -323,6 +324,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     vectorized_f32: bool = True,
     raster_along_m: bool = False,
     sm_count: Optional[int] = None,
+    enable_pdl: bool = True,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     """Blockscaled Contiguous Gather Grouped GEMM with SwiGLU Fusion for MoE workloads.
@@ -423,7 +425,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     major, minor = get_compute_capability(a.device)
     if major != 10:
         raise ValueError(
-            f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103, SM110). "
+            f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103). "
             f"Got SM{major}{minor}."
         )
--
-@flashinfer_api
-def blockscaled_contiguous_grouped_gemm_nvfp4(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_scale: torch.Tensor,
-    b_scale: torch.Tensor,
-    alpha: torch.Tensor,
-    tile_idx_to_group_idx: torch.Tensor,
-    num_non_exiting_tiles: torch.Tensor,
-    out: Optional[torch.Tensor] = None,
-    *,
-    ab_dtype: str = "float4_e2m1fn",
-    sf_dtype: str = "float8_e4m3fn",
-    c_dtype: str = "bfloat16",
-    sf_vec_size: int = 16,
-    mma_tiler_mn: Tuple[int, int] = (128, 128),
-    cluster_shape_mn: Tuple[int, int] = (1, 1),
-    sm_count: Optional[int] = None,
-) -> torch.Tensor:
-    """Blockscaled Contiguous Grouped GEMM for MoE workloads with NVFP4 quantization.
-
--
-@flashinfer_api
 def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -272,6 +279,7 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
     cluster_shape_mn: Tuple[int, int] = (2, 1),
     raster_along_m: bool = False,
     sm_count: Optional[int] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Blockscaled Contiguous Grouped GEMM with Finalize Fusion for MoE workloads.
@@ -298,7 +306,11 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
             expanded_idx = token_idx * topk + topk_idx. Invalid rows have -1.
         token_final_scales: Router scaling factors, shape (seq_len, topk), float32/bf16/fp16
         out: Optional output tensor, shape (seq_len, n). Created if None.
-            This tensor is used for atomic accumulation, so it should be zero-initialized.
+            This tensor is used for atomic accumulation. If `out` is
+            provided, it must already be zero-initialized by the caller.
+            If `out` is None, this function allocates a zero-initialized
+            output tensor. Passing a non-zeroed `out` buffer will silently
--
-@flashinfer_api
-def blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_scale: torch.Tensor,
-    b_scale: torch.Tensor,
-    alpha: torch.Tensor,
-    tile_idx_to_group_idx: torch.Tensor,
-    num_non_exiting_tiles: torch.Tensor,
-    out: Optional[torch.Tensor] = None,
-    out_scale: Optional[torch.Tensor] = None,
-    global_scale: Optional[torch.Tensor] = None,
-    *,
-    ab_dtype: str = "float4_e2m1fn",
-    sf_dtype: str = "float8_e4m3fn",
-    c_dtype: str = "bfloat16",
-    sf_vec_size: int = 16,
-    mma_tiler_mn: Tuple[int, int] = (256, 128),
-    cluster_shape_mn: Tuple[int, int] = (2, 1),
-    vectorized_f32: bool = True,
-    sm_count: Optional[int] = None,
--
 @flashinfer_api
 def __init__(
     self,
@@ -347,6 +355,7 @@ class CuteDslMoEWrapper:
         sf_vec_size: int = 16,
         output_dtype: torch.dtype = torch.bfloat16,
         device: str = "cuda",
+        enable_pdl: bool = True,
     ):
         """Initialize the MoE wrapper.
@@ -363,6 +372,7 @@ class CuteDslMoEWrapper:
             sf_vec_size: Scale factor vector size. Default: 16.
             output_dtype: Output data type. Default: torch.bfloat16.
             device: Device for buffer allocation. Default: "cuda".
+            enable_pdl: Enable Programmatic Dependent Launch. Default: True.
         """
         self.num_experts = num_experts
         self.top_k = top_k
@@ -376,6 +386,7 @@ class CuteDslMoEWrapper:
         self.sf_vec_size = sf_vec_size
--
 @flashinfer_api
@@ -550,9 +570,10 @@ class CuteDslMoEWrapper:
                 f"num_tokens ({num_tokens}) exceeds max_num_tokens ({self.max_num_tokens})"
             )
-        # Allocate output buffer if not using pre-allocated one
+        # Slice the pre-allocated buffer to the active batch so that
+        # _moe_core_impl only zeros num_tokens rows, not max_num_tokens.
if self.use_cuda_graph: - moe_output = self._moe_output + moe_output = self._moe_output[:num_tokens] else: moe_output = torch.empty( (num_tokens, self.hidden_size), @@ -627,6 +648,7 @@ def _cute_dsl_fused_moe_nvfp4_impl( use_fused_finalize: bool = True, moe_output: Optional[torch.Tensor] = None, aux_stream: Optional[torch.cuda.Stream] = None, + enable_pdl: bool = True, ) -> torch.Tensor: """Internal implementation called by auto-tuner for functional API.""" -- @flashinfer_api def cute_dsl_fused_moe_nvfp4( x: torch.Tensor, @@ -678,9 +702,12 @@ def cute_dsl_fused_moe_nvfp4( use_fused_finalize: bool = True, moe_output: Optional[torch.Tensor] = None, aux_stream: Optional[torch.cuda.Stream] = None, + enable_pdl: bool = True, ) -> torch.Tensor: """Run fused MoE computation using CuteDSL NVFP4 kernels. + Supported architectures: SM100, SM103. + This is the simple functional API. For CUDA graph support, use `CuteDslMoEWrapper` instead. @@ -736,6 +763,7 @@ def cute_dsl_fused_moe_nvfp4( local_expert_offset=local_expert_offset, use_fused_finalize=use_fused_finalize, output_dtype=output_dtype, + enable_pdl=enable_pdl, -- @flashinfer_api def gated_delta_rule_decode_pretranspose( q: torch.Tensor, @@ -1002,8 +174,9 @@ def gated_delta_rule_decode_pretranspose( - State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16 and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used (supports both the direct ``state`` path and the pool+indices path). - - pool+indices (``initial_state``/``initial_state_indices``) only supported - via the bf16 fast path; float32 state raises an error. + - pool+indices (``initial_state``/``initial_state_indices``) supported on + both the bf16 fast path (T in 1..4, K=V=128) and the float32 legacy path + (T=1). The float32 path also supports negative indices for padding. - Legacy path (float32 state, T=1): K and V must be multiples of 4. 
""" # Validate input shapes @@ -1069,13 +242,17 @@ def gated_delta_rule_decode_pretranspose( return_state = initial_state if use_pool else state return output, return_state - # Legacy path: T=1 only, float32 state (no pool+indices support) - assert not use_pool, ( -- @flashinfer_api def gated_delta_rule_mtp( q: torch.Tensor, @@ -2427,7 +489,7 @@ def gated_delta_rule_mtp( scale: Optional[float] = None, output: Optional[torch.Tensor] = None, intermediate_states_buffer: Optional[torch.Tensor] = None, - disable_state_update: bool = True, + disable_state_update: Optional[bool] = None, use_qk_l2norm: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -2463,8 +525,15 @@ def gated_delta_rule_mtp( intermediate_states_buffer (Optional[torch.Tensor]): Buffer for caching intermediate states, shape ``[pool_size, T, HV, V, K]``. If None, intermediate states are not cached. - disable_state_update (bool): - If True, the initial state is not updated. Default: ``True``. + disable_state_update (Optional[bool]): + If True, the initial state is not updated. Currently defaults to ``True``. + Please pass this argument explicitly — the default will change to ``False`` -- @flashinfer_api @@ -60,16 +120,14 @@ def rmsnorm( output: torch.Tensor Normalized tensor, 2D shape (batch_size, hidden_size) or 3D shape (batch_size, num_heads, hidden_size). """ - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) if out is None: out = torch.empty_like(input) - _rmsnorm(out, input, weight, eps, enable_pdl) + _rmsnorm_impl(out, input, weight, eps, enable_pdl) return out @register_custom_op("flashinfer::rmsnorm", mutates_args=("out",)) -def _rmsnorm( +def _rmsnorm_impl( out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, @@ -78,11 +136,21 @@ def _rmsnorm( -- @flashinfer_api def fmha_v2_prefill_deepseek( query: torch.Tensor, @@ -3865,18 +4029,11 @@ def fmha_v2_prefill_deepseek( If return_lse is False, the output will be a single tensor. 
""" if not is_sm12x_supported(query.device): - major, minor = get_compute_capability(query.device) - if major == 12: - min_cuda = "13.0" if minor >= 1 else "12.8" - raise ValueError( - f"fmha_v2_prefill_deepseek requires CUDA >= {min_cuda} " - f"for SM12{minor}x GPUs." - ) raise ValueError("fmha_v2_prefill_deepseek is only supported on SM12x GPUs.") assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, ( "currently only support deepseek r1 192 query and 128 value" ) - module = get_trtllm_fmha_v2_module() + module = get_trtllm_fmha_v2_sm120_module() is_e4m3 = query.dtype == torch.float8_e4m3fn -- +@flashinfer_api +def trtllm_fmha_v2_prefill( + qkv: Union[ + torch.Tensor, + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + ], + input_layout: str, + workspace_buffer: torch.Tensor, + seq_lens: torch.Tensor, + max_q_len: int, + max_kv_len: int, + bmm1_scale: float, + bmm2_scale: float, + batch_size: int, + cum_seq_lens_q: torch.Tensor, + cum_seq_lens_kv: torch.Tensor, + block_tables: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[Union[torch.dtype, str]] = None, + sinks: Optional[List[torch.Tensor]] = None, -- +@flashinfer_api +def fp4_quantize( + input: torch.Tensor, + global_scale: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + sf_use_ue8m0: bool = False, + is_sf_swizzled_layout: bool = True, + is_sf_8x4_layout: bool = False, + enable_pdl: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize input tensor to FP4 format. + + This function implements FP4 quantization that converts input tensors to a compressed FP4 format + with associated scale factors. It supports various input data types and scale factor layouts. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. + global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. 
+ sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. -- +@flashinfer_api +def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor: + """Swizzle block scale tensor for FP4 format. + + This function swizzles the block scale tensor to optimize memory access patterns + for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128. + + Args: + unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16. + + Returns: + torch.Tensor: Swizzled tensor with the same shape as input. + + Raises: + AssertionError: If input dtype is not uint8 or bfloat16. + """ + # TODO(shuw): check input dtype is uint8 + assert ( + unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16 + ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}" + -- +@flashinfer_api +def e2m1_and_ufp8sf_scale_to_float( + e2m1_tensor: torch.Tensor, + ufp8_scale_tensor: torch.Tensor, + global_scale_tensor: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + ufp8_type: int = 1, + is_sf_swizzled_layout: bool = True, +) -> torch.Tensor: + """Convert E2M1 format tensor and UFP8 scale factors to float tensor. + + This function performs dequantization by converting a packed FP4 tensor in E2M1 format + back to float values using the associated UFP8 scale factors and global scale. + + Args: + e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8. + ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8. + global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. 
+ ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1. + is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. -- +@flashinfer_api +def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor: + """ + PyTorch equivalent of trtllm-gen `shuffleMatrixA` + """ + row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) + + return input_tensor[row_indices.to(input_tensor.device)] + + +@flashinfer_api +def shuffle_matrix_sf_a( + input_tensor: torch.Tensor, + epilogue_tile_m: int, + num_elts_per_sf: int = 16, +): + """ + Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat. + `shuffleMatrixSfA` expects the input to be in 128x4 layout and then + apply the same shuffling in `shuffleMatrixA` and writes out in 128x4 + layout. + This function expects the input to be in linear layout. It's done this + way because the scaling factors in the NVFP4 checkpoints are quantized + and are in linear layout. + This function doesn't add padding. + """ + + row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m) + + w_shuffled = input_tensor[row_indices.to(input_tensor.device)] + -- +@flashinfer_api +def nvfp4_quantize( + a, + a_global_sf, + sfLayout=SfLayout.layout_128x4, + do_shuffle=False, + sf_vec_size=16, + enable_pdl=None, +): + """ + Quantize input tensor to NVFP4 format. + + Parameters: + a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. + a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. + sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4. + do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. 
+ enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability. Defaults to None. + -- +@flashinfer_api +def mxfp4_quantize( + a: torch.Tensor, + backend: str = "cuda", + enable_pdl: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to MXFP4 format. + + Parameters: + a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. + backend (str, optional): Backend to use for quantization. + - "cuda": Use CUDA kernel (default, stable) + - "cute-dsl": Use CuTe-DSL kernel (requires SM100+, **experimental**) + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic + Dependent Launch). Only used when backend="cute-dsl". + If None, automatically detects based on device capability. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) -- +@flashinfer_api +def mxfp4_dequantize(a_fp4, a_sf): + """ + Dequantize input tensor from MXFP4 format. + + Parameters: + a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) + a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) + + Returns: + torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. + """ + return e2m1_and_ufp8sf_scale_to_float( + a_fp4.cpu().view(torch.uint8), + a_sf.cpu().view(torch.uint8).reshape(-1), + torch.tensor([1.0], device=a_fp4.device), + 32, + 0, + True, + ) + -- +@flashinfer_api +def mxfp4_dequantize_host( + weight: torch.Tensor, + scale: torch.Tensor, + group_size: int = 32, +) -> torch.Tensor: + """ + Dequantize input tensor from MXFP4 format on host. 
+ + Parameters: + weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) + scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) + group_size (int, optional): Group size for dequantization. Defaults to 32. + + Returns: + torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. + """ + # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future + major, minor = get_compute_capability( + torch.device("cuda:0") + ) # use any cuda device to get a compute capability -- +@flashinfer_api +def nvfp4_batched_quantize( + a, + a_global_sf, + sf_vec_size=16, +): + """ + Quantize batched input tensor to NVFP4 format. + + Parameters: + a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. + a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 + - Scale factors tensor with shape determined by layout and sf_vec_size + """ + major, minor = get_compute_capability(a.device) + device_arch = f"{major * 10 + minor}" -- +@flashinfer_api +def nvfp4_quantize_paged_kv_cache( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + kv_layout: str = "HND", + k_global_sf: Optional[torch.Tensor] = None, + v_global_sf: Optional[torch.Tensor] = None, +) -> Tuple[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor], + float, + float, +]: + """Quantize paged KV cache to NVFP4 format for trtllm-gen MHA. + + Quantizes BF16/FP16 K/V caches to NVFP4 with two-level scaling + (global FP32 + per-block FP8), and swizzles scale factors + for the SM100 trtllm-gen MHA kernel layout. + + Args: + k_cache: Key cache tensor. 
-- +@flashinfer_api +def scaled_fp4_grouped_quantize( + a, + mask, + a_global_sf, +): + """ + quantize batched input tensor to NVFP4 format with mask. + Parameters: + a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. + a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. + mask (torch.Tensor): Mask tensor to apply before quantization. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 + - Scale factors tensor with shape determined by layout and sf_vec_size + """ + major, minor = get_compute_capability(a.device) + device_arch = f"{major * 10 + minor}" + a_fp4, a_sf = get_fp4_quantization_module( + device_arch -- +@flashinfer_api +def nvfp4_kv_dequantize( + fp4_data: torch.Tensor, + block_scales: torch.Tensor, + global_scale: torch.Tensor, + output_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """GPU dequantization of NVFP4 KV cache data with linear block scale layout. + + Requires SM80+. + + Args: + fp4_data (torch.Tensor): Packed FP4 data of shape ``[M, K/2]`` with dtype uint8. + block_scales (torch.Tensor): Per-block FP8 E4M3 scales of shape ``[M, K/16]`` + with dtype uint8. + global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32, + on the same CUDA device as fp4_data. + output_dtype (torch.dtype): Output dtype, either ``torch.bfloat16`` or ``torch.float16``. + + Returns: + torch.Tensor: Dequantized tensor of shape ``[M, K]`` with the specified output dtype. -- +@flashinfer_api +def nvfp4_kv_quantize( + input: torch.Tensor, + global_scale: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """GPU quantization to NVFP4 KV cache format with linear block scale layout. + + Requires SM100+ (Blackwell) for the cvt.rn.satfinite.e2m1x2.f32 PTX instruction. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype bf16 or fp16. + K must be divisible by 16. 
+ global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32, + on the same CUDA device as input. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - fp4_output: Packed FP4 data of shape ``[M, K/2]`` with dtype uint8. + - block_scales: Per-block FP8 E4M3 scales of shape ``[M, K/16]`` with dtype uint8. + """ + M, K = input.shape -- +@flashinfer_api +def mxfp8_quantize( + input: torch.Tensor, + is_sf_swizzled_layout: bool = True, + alignment: int = 32, + enable_pdl: Optional[bool] = None, + backend: Literal["cuda", "cute-dsl"] = "cuda", + sf_swizzle_layout: Optional[SfLayout] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize input tensor to MxFP8 format. + + This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format + with associated scale factors. It supports various input data types and scale factor layouts. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. + alignment (int, optional): sfVecSize. Defaults to 32. + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability (SM >= 9.0). Defaults to None. + backend (Literal["cuda", "cute-dsl"], optional): Backend to use for quantization. Options are: -- +@flashinfer_api +def mxfp8_dequantize_host( + input: torch.Tensor, + scale_tensor: torch.Tensor, + is_sf_swizzled_layout: bool = True, + sf_swizzle_layout: Optional[SfLayout] = None, +) -> torch.Tensor: + """Dequantize input tensor from MxFP8 format. + + This function performs dequantization by converting a packed FP8 tensor in MxFP8 format + back to float values using the associated scale factors. + + Args: + input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3. 
+ scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. + sf_swizzle_layout (Optional[SfLayout], optional): Swizzle layout for scale factors. + If provided,it overrides is_sf_swizzled_layout. Defaults to None. + Available options are 1. SfLayout.layout_128x4; 2. SfLayout.layout_linear. + + Returns: -- +@flashinfer_api +def mxfp4_quantize_cute_dsl( + input: torch.Tensor, + enable_pdl: bool | None = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to MXFP4 format using CuTe-DSL kernel. + + This is a GPU implementation matching FlashInfer's mxfp4_quantize() behavior: + - Global scale computed as (448 * 6) / max(|input|) + - UE8M0 scale factors + - E2M1 output format (4-bit, 2 values per byte) + - Swizzled (128x4) scale factor layout + + The kernel is compiled once per (K, dtype, pdl) combination and handles + varying M (batch size) at runtime without recompilation. + + Args: + input: Input tensor of shape [M, K] with dtype fp16/bf16 + enable_pdl: Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability (SM >= 9.0). -- +@flashinfer_api +def mxfp8_quantize_cute_dsl( + input: torch.Tensor, + is_sf_swizzled_layout: bool = True, + alignment: int = 32, + enable_pdl: bool | None = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to MXFP8 format using CuTe-DSL kernel. + + This is a GPU implementation with dual-path optimization: + - LINEAR layout: SF-block based iteration (fast) + - SWIZZLED layout: Row-based iteration with padding fast path (optimized) + + The kernel is compiled once per (K, dtype, pdl) combination and handles + varying M (batch size) at runtime without recompilation. 
+ + Args: + input: Input tensor of shape [M, K] with dtype fp16/bf16 + is_sf_swizzled_layout: Whether to use 128x4 swizzled layout (True) or linear (False) + alignment: Alignment for K dimension (default 32, must be multiple of SF_VEC_SIZE) ``` <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Enhancements** * Normalization now accepts scale as either a float or tensor; passing a float emits a deprecation warning and is auto-converted for compatibility. * Attention/decoding API: cache-scale parameters are now optional keyword-only arguments with sensible defaults, simplifying common call patterns. * **Tests** * Tests updated to match the adjusted attention/decoding call signature. * **Chores** * Release version bumped to 0.6.7. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
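The zero-initialization contract in the finalize-fusion docstring (the output tensor is used for atomic accumulation, so a caller-provided `out` must already be zeroed) can be illustrated with a plain-Python mock. `finalize_accumulate` below is a hypothetical stand-in for the kernel's scatter-add epilogue, not the real API:

```python
def finalize_accumulate(partial, expanded_idx, scales, topk, seq_len, out=None):
    """Scatter-add scaled top-k partial results into per-token output rows.

    partial:      per-expanded-row vectors (one per row of the grouped GEMM output)
    expanded_idx: expanded row -> source slot, where slot = token_idx * topk + topk_idx;
                  -1 marks invalid (padded) rows
    scales:       per-expanded-row router scaling factors
    out:          optional accumulation buffer; must be zeroed by the caller
    """
    n = len(partial[0])
    if out is None:
        # Like the kernel: allocate a zero-initialized output when none is given.
        out = [[0.0] * n for _ in range(seq_len)]
    for row, src in enumerate(expanded_idx):
        if src < 0:  # padded row: contributes nothing
            continue
        token = src // topk  # recover the destination token from the expanded index
        for j in range(n):
            out[token][j] += scales[row] * partial[row][j]
    return out
```

With a zeroed buffer the result is exactly the scaled sum of each token's top-k partials; reusing a stale buffer silently adds the leftover values into the output, which is the failure mode the updated docstring warns about.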
📌 Description
This PR cleans up the CuteDSL NVFP4 MoE implementation.
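One part of the cleanup reuses a single pre-allocated output buffer under CUDA graphs, slicing it to the active batch so only `num_tokens` rows are re-zeroed rather than `max_num_tokens`. A plain-Python sketch of that pattern, using a hypothetical `MoeOutputBuffer` class rather than the real wrapper:

```python
class MoeOutputBuffer:
    """Sketch of a persistent output buffer sized for the worst-case batch."""

    def __init__(self, max_num_tokens, hidden_size):
        # One allocation, reused across (graph-captured) calls.
        self._buf = [[0.0] * hidden_size for _ in range(max_num_tokens)]
        self.max_num_tokens = max_num_tokens

    def view(self, num_tokens):
        if num_tokens > self.max_num_tokens:
            raise ValueError(
                f"num_tokens ({num_tokens}) exceeds "
                f"max_num_tokens ({self.max_num_tokens})"
            )
        # Analogous to self._moe_output[:num_tokens] in the wrapper:
        # the slice aliases the persistent rows, so writes land in the
        # long-lived buffer, but only the active rows are re-zeroed.
        active = self._buf[:num_tokens]
        for row in active:
            for j in range(len(row)):
                row[j] = 0.0
        return active
```

The zeroing matters because the downstream accumulation adds into the buffer in place; skipping the inactive tail is what saves work relative to clearing the whole allocation each call.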
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Bug Fixes
Refactor
Tests