Skip to content

Add list_api script#3341

Merged
aleozlx merged 2 commits into
flashinfer-ai:mainfrom
aleozlx:list_api
May 21, 2026
Merged

Add list_api script#3341
aleozlx merged 2 commits into
flashinfer-ai:mainfrom
aleozlx:list_api

Conversation

@aleozlx
Copy link
Copy Markdown
Collaborator

@aleozlx aleozlx commented May 17, 2026

📌 Description

A utility script that can be used for API review/QA purposes

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Examples

List all @flashinfer_api-decorated APIs (module-level and class methods), grouped by class (or [Global Functions] per file), with full multi-line signatures.

List current API surface

scripts/list_apis.sh

Signatures only (no paths, no line numbers)

scripts/list_apis.sh -p

Class methods only (skip module-level functions)

scripts/list_apis.sh -M

Inspect a single file

scripts/list_apis.sh -p flashinfer/cascade.py

Run against a git tag / branch / SHA

Auto-fetches from upstream or origin if not present locally; uses a throwaway worktree.

scripts/list_apis.sh --ref v0.6.9
scripts/list_apis.sh --ref main -p

Diff the API surface between two revisions

Use -d for stable, byte-identical output so the diff only reflects real API changes.

diff -u \
  <(scripts/list_apis.sh -d -p --ref v0.6.9) \
  <(scripts/list_apis.sh -d -p)

Flags

Flag Effect
-n, --no-lines Omit line numbers
-p, --no-paths Omit paths and line numbers (signatures-only)
-M, --methods-only Skip module-level functions; only show class methods
-d, --deterministic Stable, diff-friendly output (sorts files; slightly slower)
-r, --ref REF Run against a git revision via temp worktree
-h, --help Show help

Summary by CodeRabbit

  • Chores
    • Added a command-line utility to discover and list decorated Python APIs across the codebase.
    • Preserves full multi-line function signatures and offers output options to hide line numbers or file paths.
    • Can limit results to class methods and supports scanning a specific repository revision via a temporary checkout.
    • Includes an option for stable, deterministic ordering of results.

Review Change Stack

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 17, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 0a47b92b-9f78-42cd-8d56-ad9d150086c8

📥 Commits

Reviewing files that changed from the base of the PR and between 9debd04 and 3d97460.

📒 Files selected for processing (1)
  • scripts/list_apis.sh
🚧 Files skipped from review as they are similar to previous changes (1)
  • scripts/list_apis.sh

📝 Walkthrough

Walkthrough

Adds scripts/list_apis.sh, a bash utility that finds @flashinfer_api-decorated Python APIs (preserving multi-line signatures), supports output toggles and methods-only filtering, and can scan a specific git revision via a temporary detached worktree.

Changes

API Enumeration Script

Layer / File(s) Summary
Header, CLI parsing, and validation
scripts/list_apis.sh
Script header, usage docs, strict shell settings, ripgrep presence check, option defaults, and CLI argument parsing/validation.
Git --ref resolution and detached worktree
scripts/list_apis.sh
Resolves repo root, validates/fetches the requested ref, creates a temporary detached git worktree for the ref, rewrites scan paths to the worktree, and registers cleanup via trap.
Scan path defaults and ripgrep invocation
scripts/list_apis.sh
Sets default target paths (when none provided), optionally enables deterministic rg ordering, runs multi-line rg regex over targets, and pipes results to awk with format/path options.
AWK emit/flush helpers and state machine
scripts/list_apis.sh
Defines emit/flush helpers and the main awk state machine that parses ripgrep lines, tracks class context, records @flashinfer_api decorators, buffers multi-line def signatures, and emits decorator–definition pairs (final flush at EOF).

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

🐰 I hop through code, with ripgrep in paw,
Finding decorators with awk's gentle law,
Multi-line signatures kept safe and sound,
Worktrees spun up, then quietly unbound,
A tiny script hums — APIs found!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 inconclusive)

Check name Status Explanation Resolution
Title check ❓ Inconclusive The title is vague and generic, using non-descriptive phrasing that doesn't convey meaningful information about the script's purpose for API review/QA. Consider a more descriptive title such as 'Add list_apis utility script for API review and QA' to better convey the script's purpose.
✅ Passed checks (4 passed)
Check name Status Explanation
Description check ✅ Passed The description includes comprehensive examples, flag documentation, and usage instructions, though the Description section is minimal. All required sections are present with the pre-commit and testing checklists completed.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@aleozlx
Copy link
Copy Markdown
Collaborator Author

aleozlx commented May 17, 2026

cc @cindyzxq

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new bash script, scripts/list_apis.sh, which extracts and groups Python class methods decorated with @flashinfer_api. The script supports git revisions via temporary worktrees and offers various output formatting options. Review feedback suggested several improvements to enhance robustness, including adding a check for the ripgrep dependency, hardening argument parsing for the --ref flag, refining path parsing in awk to handle filenames with colons, and ensuring that decorated top-level functions are included in the output.

Comment thread scripts/list_apis.sh
Comment thread scripts/list_apis.sh Outdated
Comment thread scripts/list_apis.sh Outdated
Comment thread scripts/list_apis.sh
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@scripts/list_apis.sh`:
- Line 32: The current help extraction in the -h|--help) branch prints lines
2..first blank and thus includes non-comment code; replace the sed pipeline with
a command that only prints contiguous leading comment lines (starting at line 2)
and stops at the first non-comment line. Update the -h|--help) handler to use a
filter like an awk expression that checks NR>=2 and prints lines matching /^`#/`
(stripping the leading "# " via sub) and exits on the first non-# line so only
the leading comment block is shown.
- Line 31: The case branch handling -r|--ref directly assigns ref="$2" and
shifts without verifying a next argument; because set -u is enabled this will
crash if -r/--ref is the last token. Fix the -r|--ref) branch in
scripts/list_apis.sh (the option-parsing case) by first guarding that a next
argument exists (e.g. check $# -ge 2 or that ${2-} is non-empty and not another
option) and if not, print an error/usage and exit non‑zero; only then set
ref="$2" and shift 2.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 97850ad5-a1ea-427a-8684-3bd549746e7a

📥 Commits

Reviewing files that changed from the base of the PR and between ce43023 and 9debd04.

📒 Files selected for processing (1)
  • scripts/list_apis.sh

Comment thread scripts/list_apis.sh Outdated
Comment thread scripts/list_apis.sh
@aleozlx aleozlx merged commit 209c16a into flashinfer-ai:main May 21, 2026
31 checks passed
@aleozlx aleozlx mentioned this pull request May 21, 2026
aleozlx added a commit that referenced this pull request May 22, 2026
## Description

Bump version to 0.6.12 for release.

## Related Issues (Gated-by PRs)


https://github.com/flashinfer-ai/flashinfer/issues?q=is%3Aopen+label%3Av0.6.12

## Reviewer Notes

**API changes review**

API changes since v0.6.11.post3, using new tool
* #3341

```diff
diff -u \
  <(scripts/list_apis.sh -d -p --ref v0.6.11.post3) \
  <(scripts/list_apis.sh -d -p)

--- /tmp/api_baseline.txt	2026-05-21 16:07:23.252004287 -0700
+++ /tmp/api_head.txt	2026-05-21 16:07:23.316004287 -0700
@@ -251,6 +251,8 @@
     shared_expert_output: Optional[torch.Tensor] = None,
     # ===== Group quant parameters =====
     block_quant_group_size: Optional[int] = None,
+    # ===== RMSNorm variant =====
+    weight_bias: float = 0.0,
 ) -> torch.Tensor:
 [Global Functions]
 @flashinfer_api
@@ -513,6 +515,7 @@
     out_dtype: Optional[torch.dtype] = None,
     is_var_seq: bool = True,
     enable_pdl: Optional[bool] = None,
+    sinks: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 class BatchPrefillCuteDSLWrapper:
     @flashinfer_api
@@ -759,7 +762,11 @@
     skip_softmax_threshold_scale_factor: Optional[float] = None,
     kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     uses_shared_paged_kv_idx: bool = True,
-) -> Union[torch.Tensor, FP4Tensor]:
+    lse: Optional[torch.Tensor] = None,
+    return_lse: bool = False,
+) -> Union[
+    torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]
+]:
 @flashinfer_api(trace=xqa_batch_decode_trace)
 def xqa_batch_decode_with_kv_cache(
     query: torch.Tensor,
@@ -898,6 +905,7 @@
     weight_layout: int = WeightLayout.BlockMajorK,
     do_finalize: bool = True,
     enable_pdl: bool = True,
+    gemm1_lora_delta: Optional[torch.Tensor] = None,
     tune_max_num_tokens: int = 8192,
     activation_type: int = ActivationType.Swiglu.value,
     routing_replay_out: Optional[torch.Tensor] = None,
@@ -987,6 +995,7 @@
     weight_layout: int = 0,
     do_finalize: bool = True,
     enable_pdl: Optional[bool] = None,
+    gemm1_lora_delta: Optional[torch.Tensor] = None,
     output: Optional[torch.Tensor] = None,
     tune_max_num_tokens: int = 8192,
     fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
@@ -1034,7 +1043,7 @@
 
 @flashinfer_api(trace=trtllm_fp4_block_scale_routed_moe_trace)
 def trtllm_fp4_block_scale_routed_moe(
-    topk_ids: torch.Tensor,
+    topk_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
     routing_bias: Optional[torch.Tensor],
     hidden_states: torch.Tensor,
     hidden_states_scale: Optional[torch.Tensor],
@@ -1096,6 +1105,34 @@
     norm_topk_prob: bool = True,
     routing_replay_out: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:
+
+
+@flashinfer_api
+def trtllm_mxint4_block_scale_routed_moe(
+    topk_ids: torch.Tensor,
+    hidden_states: torch.Tensor,
+    gemm1_weights: torch.Tensor,
+    gemm1_weights_scale: torch.Tensor,
+    gemm1_alpha: Optional[torch.Tensor],
+    gemm1_beta: Optional[torch.Tensor],
+    gemm1_clamp_limit: Optional[torch.Tensor],
+    gemm2_weights: torch.Tensor,
+    gemm2_weights_scale: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
+    intermediate_size: int,
+    local_expert_offset: int,
+    local_num_experts: int,
+    routed_scaling_factor: Optional[float],
+    routing_method_type: int = 0,
+    do_finalize: bool = True,
+    enable_pdl: Optional[bool] = None,
+    gemm1_lora_delta: Optional[torch.Tensor] = None,
+    output: Optional[torch.Tensor] = None,
+    tune_max_num_tokens: int = 8192,
+) -> List[torch.Tensor]:
 [Global Functions]
 @flashinfer_api(trace=b12x_fused_moe_trace)
 def b12x_fused_moe(
@@ -1117,8 +1154,6 @@
     output_dtype: torch.dtype = torch.bfloat16,
     activation: str = "silu",
     activation_precision: str = "fp4",
-    quant_mode: Optional[str] = None,
-    source_format: str = "modelopt",
 ) -> torch.Tensor:
 class B12xMoEWrapper:
     @flashinfer_api
@@ -1136,8 +1171,6 @@
         device: str = "cuda",
         activation: str = "silu",
         activation_precision: str = "fp4",
-        quant_mode: Optional[str] = None,
-        source_format: str = "modelopt",
     ):
 
     @flashinfer_api(trace=b12x_moe_wrapper_run_trace)
@@ -1477,8 +1510,6 @@
     out: Optional[torch.Tensor] = None,
     backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas",
 ):
-
-
 @flashinfer_api(trace=bmm_fp8_trace)
 def bmm_fp8(
     A: torch.Tensor,
@@ -1524,7 +1555,7 @@
     out_dtype: Optional[torch.dtype] = None,
     backend: Literal["cutlass", "trtllm"] = "cutlass",
 ):
-@flashinfer_api
+@flashinfer_api(trace=gemm_fp8_nt_groupwise_trace)
 def gemm_fp8_nt_groupwise(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -1712,8 +1743,17 @@
     sf_dtype: str,
     c_dtype: str,
     sf_vec_size: int,
+    topk_weights: Optional[torch.Tensor] = None,
+    idx_src_info: Optional[torch.Tensor] = None,
+    rank_src_info: Optional[torch.Tensor] = None,
+    out_ptrs: Optional[torch.Tensor] = None,
+    num_ranks: int = 0,
     dst_signals: Optional[torch.Tensor] = None,
     sm_count: Optional[int] = None,
+    barrier_flag_local: Optional[torch.Tensor] = None,
+    barrier_flag_multicast: Optional[torch.Tensor] = None,
+    is_combine_fusion: bool = False,
+    is_swap_ab: bool = False,
     **kwargs,
 ):
 [Global Functions]
@@ -1722,14 +1762,21 @@
     mat_a: torch.Tensor,
     mat_b: torch.Tensor,
     out: torch.Tensor,
-    launch_with_pdl: bool = False,
+    launch_with_pdl: bool = True,
 ) -> None:
 @flashinfer_api(trace=mm_M1_16_K7168_N256_trace)
 def mm_M1_16_K7168_N256(
     mat_a: torch.Tensor,
     mat_b: torch.Tensor,
     out: torch.Tensor,
-    launch_with_pdl: bool = False,
+    launch_with_pdl: bool = True,
+) -> None:
+@flashinfer_api(trace=mm_M1_16_K6144_N256_trace)
+def mm_M1_16_K6144_N256(
+    mat_a: torch.Tensor,
+    mat_b: torch.Tensor,
+    out: torch.Tensor,
+    launch_with_pdl: bool = True,
 ) -> None:
 @flashinfer_api(trace=tinygemm_bf16_trace)
 def tinygemm_bf16(
@@ -1826,6 +1873,36 @@
     tactic: int = -1,
 ) -> torch.Tensor:
 [Global Functions]
+@flashinfer_api
+def checkpointing_ssu(
+    state: torch.Tensor,
+    old_x: torch.Tensor,
+    old_B: torch.Tensor,
+    old_dt: torch.Tensor,
+    old_cumAdt: torch.Tensor,
+    cache_buf_idx: torch.Tensor,
+    prev_num_accepted_tokens: torch.Tensor,
+    x: torch.Tensor,
+    dt: torch.Tensor,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    out: torch.Tensor,
+    D: Optional[torch.Tensor] = None,
+    z: Optional[torch.Tensor] = None,
+    dt_bias: Optional[torch.Tensor] = None,
+    dt_softplus: bool = False,
+    state_batch_indices: Optional[torch.Tensor] = None,
+    pad_slot_id: int = -1,
+    state_scale: Optional[torch.Tensor] = None,
+    rand_seed: Optional[torch.Tensor] = None,
+    philox_rounds: int = 10,
+    d_split: Optional[int] = None,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+    enable_pdl: bool = False,
+) -> torch.Tensor:
+[Global Functions]
 @flashinfer_api(trace=selective_state_update_trace)
 def selective_state_update(
     state: torch.Tensor,
@@ -1966,6 +2043,7 @@
         kv_len: Optional[torch.Tensor] = None,
         page_table: Optional[torch.Tensor] = None,
         return_lse_base_on_e: bool = False,
+        o_scale: Optional[float] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
 
 
@@ -1991,7 +2069,10 @@
     backend: str = "auto",
     is_var_seq: bool = True,
     uses_shared_paged_kv_idx: bool = True,
-) -> torch.Tensor:
+    lse: Optional[torch.Tensor] = None,
+    return_lse: bool = False,
+    cute_dsl_impl: str = "auto",
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
 
 
 @flashinfer_api(trace=xqa_batch_decode_mla_trace)
@@ -2252,6 +2333,44 @@
     norm_out: Optional[torch.Tensor] = None,
     sf_out: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    qkv,
+    q_weight,
+    k_weight,
+    **kwargs,
+):
+
+
+@flashinfer_api
+def fused_qk_rmsnorm_rope(
+    qkv: torch.Tensor,
+    q_weight: torch.Tensor,
+    k_weight: torch.Tensor,
+    *,
+    ppf: int,
+    pph: int,
+    ppw: int,
+    num_frame_channels: int,
+    num_height_channels: int,
+    num_width_channels: int,
+    num_heads_q: int,
+    num_heads_k: int,
+    num_heads_v: int,
+    head_dim: int,
+    eps: float = 1e-6,
+    base: float = 10000.0,
+    interleave: bool = True,
+    factor: float = 1.0,
+    low: float = 0.0,
+    high: float = 0.0,
+    attention_factor: float = 1.0,
+    is_qk_norm: bool = True,
+    output_fp8: bool = False,
+    output_quant_scale: float = 1.0,
+    v_quant_scale: float = 1.0,
+    q_out: Optional[torch.Tensor] = None,
+    k_out: Optional[torch.Tensor] = None,
+    v_out: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 [Global Functions]
 @flashinfer_api
 def get_batch_indices_positions(
@@ -2730,7 +2849,11 @@
     skip_softmax_threshold_scale_factor: Optional[float] = None,
     uses_shared_paged_kv_idx: bool = True,
     causal: bool = True,
-) -> Union[torch.Tensor, FP4Tensor]:
+    lse: Optional[torch.Tensor] = None,
+    return_lse: bool = False,
+) -> Union[
+    torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]
+]:
 
 
 @flashinfer_api(trace=fmha_v2_prefill_deepseek_trace)
@@ -2942,6 +3065,7 @@
     is_sf_swizzled_layout: bool = True,
     alignment: int = 32,
     enable_pdl: bool | None = None,
+    is_sf_8x4_layout: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
```

API changes since v0.6.11.post3 (old approach)

```diff
$ git diff v0.6.11.post3..main -- "*.py" | grep -B5 -A20 "@flashinfer_api"
-def _reconstruct_value(value: Any) -> Any:
+def flush_graph_dumps(synchronize: bool = True) -> int:
+    """Write CUDA-graph-deferred level-10 dumps to disk.
+
+    When ``FLASHINFER_LOGLEVEL=10`` is active inside ``torch.cuda.graph(...)``,
+    each ``@flashinfer_api`` call records input/output tensor references instead
+    of writing immediately or inserting D2H copies into the captured graph.
+    After ``g.replay()`` completes, calling this function materializes current
+    tensor values to CPU and serializes them to two places:
+
+    1. ``inputs.pt``/``outputs.pt`` (or the safetensors equivalents) in the
+       original dump directory, for backwards compatibility. These files
+       always reflect the most recent flush.
+    2. ``graph_flushes/flush_XXXX/`` under the original dump directory. These
+       immutable snapshots preserve every explicit flush, so callers can keep
+       every replay by calling ``flush_graph_dumps()`` after every replay.
+
+    Parameters
+    ----------
+    synchronize : bool, default True
+        Synchronize the current stream first to ensure the most recent
+        ``g.replay()`` has completed before materializing tensors. Set to
+        ``False`` only if you've already synchronized externally.
+
+    Returns
+    -------
--
         routing_logits,
         None,
         None,
@@ -3199,7 +3362,7 @@ def trtllm_fp4_block_scale_moe(
 
 @flashinfer_api(trace=trtllm_fp4_block_scale_routed_moe_trace)
 def trtllm_fp4_block_scale_routed_moe(
-    topk_ids: torch.Tensor,
+    topk_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
     routing_bias: Optional[torch.Tensor],
     hidden_states: torch.Tensor,
     hidden_states_scale: Optional[torch.Tensor],
@@ -3231,13 +3394,20 @@ def trtllm_fp4_block_scale_routed_moe(
     output: Optional[torch.Tensor] = None,
     tune_max_num_tokens: int = 8192,
 ) -> List[torch.Tensor]:
-    """FP4 block scale MoE operation.
+    """FP4 block scale MoE operation with pre-computed routing.
+
+    This function supports two pre-computed routing formats:
+    1. Packed format: topk_ids is a single tensor with packed (score << 16 | expert_id)
+    2. Unpacked format: topk_ids is a tuple of (topk_ids, topk_weights) tensors
 
     Args:
-        topk_ids (torch.Tensor): shape [seq_len, top_k]
-            Tensor of top-k indices and expert weights. Dtype must be int32.
--
         norm_topk_prob,
         routing_replay_out,
     )
+
+
+@flashinfer_api
+def trtllm_mxint4_block_scale_routed_moe(
+    topk_ids: torch.Tensor,
+    hidden_states: torch.Tensor,
+    gemm1_weights: torch.Tensor,
+    gemm1_weights_scale: torch.Tensor,
+    gemm1_alpha: Optional[torch.Tensor],
+    gemm1_beta: Optional[torch.Tensor],
+    gemm1_clamp_limit: Optional[torch.Tensor],
+    gemm2_weights: torch.Tensor,
+    gemm2_weights_scale: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
+    intermediate_size: int,
+    local_expert_offset: int,
+    local_num_experts: int,
+    routed_scaling_factor: Optional[float],
+    routing_method_type: int = 0,
+    do_finalize: bool = True,
--
-    except Exception:
-        return False
-
-
 @supported_compute_capability([120, 121])
 @flashinfer_api(trace=b12x_fused_moe_trace)
 def b12x_fused_moe(
@@ -74,13 +67,11 @@ def b12x_fused_moe(
     output_dtype: torch.dtype = torch.bfloat16,
     activation: str = "silu",
     activation_precision: str = "fp4",
-    quant_mode: Optional[str] = None,
-    source_format: str = "modelopt",
 ) -> torch.Tensor:
     """Run fused MoE on SM120/SM121 using b12x CuTe DSL kernels.
 
-    The kernel takes bf16 input and runs routing, FC1, activation, FC2,
-    and scatter through the selected backend.
+    The kernel takes bf16 input and fuses quantization + routing +
+    FC1 + activation + FC2 + scatter in a single launch.
     Automatically selects micro (decode), static, or dynamic backend
     based on routed row count.
 
@@ -99,19 +90,16 @@ def b12x_fused_moe(
         w1_alpha: Per-expert global scale for FC1.
         w2_alpha: Per-expert global scale for FC2.
--
 
@@ -6387,7 +6276,7 @@ def _check_gemm_fp8_nt_groupwise_problem_size(
     },
     common_check=_check_gemm_fp8_nt_groupwise_problem_size,
 )
-@flashinfer_api
+@flashinfer_api(trace=gemm_fp8_nt_groupwise_trace)
 def gemm_fp8_nt_groupwise(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -8031,7 +7920,7 @@ def _calculate_block_scale_dims(
 
 
 @functools.lru_cache(maxsize=1024)
-def create_cudnn_execution_plans_mxfp8_gemm(
+def build_cudnn_gemm_mxfp8_graph(
     a_shape,
     a_stride,
     a_type,  # cudnn.data_type, FP8_E4M3 or FP8_E5M2
@@ -8041,7 +7930,11 @@ def create_cudnn_execution_plans_mxfp8_gemm(
     block_size,
     o_type,  # cudnn.data_type, BF16 or FP16
     device,
+    policy=None,
 ):
+    if policy is None:
+        policy = cudnn.build_plan_policy.HEURISTICS_CHOICE
--
@@ -229,6 +264,54 @@ def mm_M1_16_K7168_N256(
     )
 
 
+@backend_requirement({}, common_check=_mm_M1_16_K6144_N256_shape_checks)
+@flashinfer_api(trace=mm_M1_16_K6144_N256_trace)
+def mm_M1_16_K6144_N256(
+    mat_a: torch.Tensor,
+    mat_b: torch.Tensor,
+    out: torch.Tensor,
+    launch_with_pdl: bool = True,
+) -> None:
+    """Optimized GEMM for the router operation in GLM-MoE-DSA.
+
+    This function performs a highly optimized matrix multiplication specifically tailored
+    for the expert routing GEMM in GLM-MoE-DSA's Mixture of Experts (MoE) architecture.
+    It computes out = mat_a @ mat_b where mat_a contains token embeddings and mat_b
+    contains expert routing weights.
+
+    The implementation is optimized for the specific problem dimensions used in GLM-MoE-DSA:
+    - Hidden dimension (K): 6144
+    - Number of experts (N): 256
+    - Number of tokens (M): 1-16
+
+    Args:
+        mat_a (torch.Tensor): Input token embeddings of shape (M, K) where M is the number
--
+) -> None:
+    """Fake implementation for torch.compile() meta tensor propagation."""
+    pass
+
+
+@flashinfer_api
+def checkpointing_ssu(
+    state: torch.Tensor,
+    old_x: torch.Tensor,
+    old_B: torch.Tensor,
+    old_dt: torch.Tensor,
+    old_cumAdt: torch.Tensor,
+    cache_buf_idx: torch.Tensor,
+    prev_num_accepted_tokens: torch.Tensor,
+    x: torch.Tensor,
+    dt: torch.Tensor,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    out: torch.Tensor,
+    D: Optional[torch.Tensor] = None,
+    z: Optional[torch.Tensor] = None,
+    dt_bias: Optional[torch.Tensor] = None,
+    dt_softplus: bool = False,
+    state_batch_indices: Optional[torch.Tensor] = None,
+    pad_slot_id: int = -1,
--
         page_table: Optional[torch.Tensor] = None,
         return_lse_base_on_e: bool = False,
+        o_scale: Optional[float] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
     @flashinfer_api(trace=mla_paged_decode_trace)
@@ -489,6 +915,7 @@ class BatchMLAPagedAttentionWrapper:
         kv_len: Optional[torch.Tensor] = None,
         page_table: Optional[torch.Tensor] = None,
         return_lse_base_on_e: bool = False,
+        o_scale: Optional[float] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         r"""Run the MLA attention computation.
 
@@ -506,6 +933,7 @@ class BatchMLAPagedAttentionWrapper:
             ``head_dim_kpe`` is 64 in DeepSeek v2/v3 models.
         out : Optional[torch.Tensor]
             The output tensor, if not provided, will be allocated internally.
+            When ``o_scale`` is provided, this should be an FP8 tensor.
         lse : Optional[torch.Tensor]
             The log-sum-exp of attention logits, if not provided, will be allocated internally.
         return_lse : bool, optional
@@ -516,6 +944,10 @@ class BatchMLAPagedAttentionWrapper:
             The query length of each request, shape: ``[batch_size]``. Required when ``backend`` is ``cutlass``.
         page_table : Optional[torch.Tensor]
             The page table of the paged kv-cache, shape: ``[batch_size, num_pages]``. Required when ``backend`` is ``cutlass``.
--
+            )
+
+    return True
+
+
+@flashinfer_api
+@backend_requirement(backend_checks={}, common_check=_check_fused_qk_rmsnorm_rope)
+def fused_qk_rmsnorm_rope(
+    qkv: torch.Tensor,
+    q_weight: torch.Tensor,
+    k_weight: torch.Tensor,
+    *,
+    ppf: int,
+    pph: int,
+    ppw: int,
+    num_frame_channels: int,
+    num_height_channels: int,
+    num_width_channels: int,
+    num_heads_q: int,
+    num_heads_k: int,
+    num_heads_v: int,
+    head_dim: int,
+    eps: float = 1e-6,
+    base: float = 10000.0,
+    interleave: bool = True,
+    factor: float = 1.0,```

**Supplemental: class-wrapper overload stub changes (BatchMLAPagedAttentionWrapper.run gained `o_scale`)**

```diff
$ git diff v0.6.11.post3..main -- "flashinfer/mla/_core.py" | grep -B5
-A10 "o_scale"
     mod = gen_trtllm_gen_fmha_module()
@@ -457,6 +881,7 @@ class BatchMLAPagedAttentionWrapper:
         kv_len: Optional[torch.Tensor] = None,
         page_table: Optional[torch.Tensor] = None,
         return_lse_base_on_e: bool = False,
+        o_scale: Optional[float] = None,
     ) -> torch.Tensor: ...
 
     @overload
@@ -473,6 +898,7 @@ class BatchMLAPagedAttentionWrapper:
         kv_len: Optional[torch.Tensor] = None,
         page_table: Optional[torch.Tensor] = None,
         return_lse_base_on_e: bool = False,
+        o_scale: Optional[float] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
     @flashinfer_api(trace=mla_paged_decode_trace)
@@ -489,6 +915,7 @@ class BatchMLAPagedAttentionWrapper:
         kv_len: Optional[torch.Tensor] = None,
         page_table: Optional[torch.Tensor] = None,
         return_lse_base_on_e: bool = False,
+        o_scale: Optional[float] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         r"""Run the MLA attention computation.
 
@@ -506,6 +933,7 @@ class BatchMLAPagedAttentionWrapper:
             ``head_dim_kpe`` is 64 in DeepSeek v2/v3 models.
         out : Optional[torch.Tensor]
The output tensor, if not provided, will be allocated internally.
+            When ``o_scale`` is provided, this should be an FP8 tensor.
         lse : Optional[torch.Tensor]
The log-sum-exp of attention logits, if not provided, will be allocated
internally.
         return_lse : bool, optional
@@ -516,6 +944,10 @@ class BatchMLAPagedAttentionWrapper:
The query length of each request, shape: ``[batch_size]``. Required when
``backend`` is ``cutlass``.
         page_table : Optional[torch.Tensor]
The page table of the paged kv-cache, shape: ``[batch_size,
num_pages]``. Required when ``backend`` is ``cutlass``.
+        o_scale : Optional[float]
+ FP8 output dequantization scale (``real = quantized * o_scale``).
+ When provided, ``out`` must be an FP8 tensor. Only supported with
+            the ``cutlass`` backend.
         """
         if self._backend == "cutlass":
             if return_lse:
@@ -525,7 +957,26 @@ class BatchMLAPagedAttentionWrapper:
"profiler_buffer does not support cutlass backend for now."
                 )
             self._cached_module = get_mla_module()
-            if out is None:
+            output_scale = 1.0
+            if o_scale is not None:
+                output_scale = float(o_scale)
+ if not math.isfinite(output_scale) or output_scale <= 0.0:
+                    raise ValueError(
+ f"o_scale must be a finite positive value, got {o_scale}"
+                    )
+                if out is None:
+                    raise ValueError(
+ "out tensor must be provided when o_scale is used for FP8 output."
+                    )
+                if out.dtype not in (
+                    torch.float8_e4m3fn,
+                    torch.float8_e5m2,
+                ):
+                    raise ValueError(
+ f"out must be an FP8 tensor when o_scale is provided, got {out.dtype}"
+                    )
+ check_shape_dtype_device(out, q_nope.shape, None, q_nope.device,
"out")
+            elif out is None:
                 out = torch.empty_like(q_nope)
             else:
                 check_shape_dtype_device(
@@ -543,9 +994,14 @@ class BatchMLAPagedAttentionWrapper:
                 ckv_kpe_cache,
                 kv_len,
                 page_table,
+                output_scale,
             )
             return out
 
+        if o_scale is not None:
+            raise ValueError(
+ "o_scale is only supported with the cutlass backend for now."
+            )
         if profiler_buffer is None:
             if self._use_profiler:
                 raise ValueError(
@@ -615,7 +1071,10 @@ def trtllm_batch_decode_with_kv_cache_mla(
     backend: str = "auto",
     is_var_seq: bool = True,
     uses_shared_paged_kv_idx: bool = True,
-) -> torch.Tensor:
+    lse: Optional[torch.Tensor] = None,
```

**Supplemental: `trtllm_batch_decode_with_kv_cache` / `trtllm_batch_context_with_kv_cache` gained `lse` and `return_lse` parameters (signature widening — BC)**

```diff
$ git diff v0.6.11.post3..main -- "flashinfer/decode.py"
"flashinfer/prefill.py" | grep -B3 -A6 "return_lse: bool = False"
     uses_shared_paged_kv_idx: bool = True,
-) -> Union[torch.Tensor, FP4Tensor]:
+    lse: Optional[torch.Tensor] = None,
+    return_lse: bool = False,
+) -> Union[
+ torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor],
torch.Tensor]
+]:
     """
     Parameters
     ----------
--
     causal: bool = True,
-) -> Union[torch.Tensor, FP4Tensor]:
+    lse: Optional[torch.Tensor] = None,
+    return_lse: bool = False,
+) -> Union[
+ torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor],
torch.Tensor]
+]:
     """
     Parameters
     ----------
```

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **Chores**
  * Version bumped to 0.6.12.

<!-- review_stack_entry_start -->

[![Review Change Stack](https://storage.googleapis.com/coderabbit_public_assets/review-stack-in-coderabbit-ui.svg)](https://app.coderabbit.ai/change-stack/flashinfer-ai/flashinfer/pull/3388?utm_source=github_walkthrough&utm_medium=github&utm_campaign=change_stack)

<!-- review_stack_entry_end -->

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants