Add list_api script #3341
cc @cindyzxq
Code Review
This pull request introduces a new bash script, scripts/list_apis.sh, which extracts and groups Python class methods decorated with @flashinfer_api. The script supports git revisions via temporary worktrees and offers various output formatting options. Review feedback suggested several improvements to enhance robustness, including adding a check for the ripgrep dependency, hardening argument parsing for the --ref flag, refining path parsing in awk to handle filenames with colons, and ensuring that decorated top-level functions are included in the output.
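The dependency check suggested in the review could look roughly like the sketch below. `require_cmd` is a hypothetical helper name, not something taken from `scripts/list_apis.sh`; the only assumption about the script is the review's statement that it depends on ripgrep, whose binary is `rg`.

```shell
# Hypothetical helper: fail early with a readable message when a tool is missing.
# `command -v` is the portable way to ask the shell whether a command resolves.
require_cmd() {
  if ! command -v "$1" >/dev/null 2>&1; then
    echo "error: required command '$1' not found in PATH" >&2
    return 1
  fi
}

# Assumed usage near the top of the script, before any search is attempted:
#   require_cmd rg || exit 1
```

Checking up front keeps the failure message close to its cause instead of surfacing later as an empty API listing.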
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@scripts/list_apis.sh`:
- Line 32: The current help extraction in the -h|--help) branch prints lines
2..first blank and thus includes non-comment code; replace the sed pipeline with
a command that only prints contiguous leading comment lines (starting at line 2)
and stops at the first non-comment line. Update the -h|--help) handler to use a
filter like an awk expression that checks NR>=2 and prints lines matching `/^#/`
(stripping the leading "# " via sub) and exits on the first non-# line so only
the leading comment block is shown.
- Line 31: The case branch handling -r|--ref directly assigns ref="$2" and
shifts without verifying a next argument; because set -u is enabled this will
crash if -r/--ref is the last token. Fix the -r|--ref) branch in
scripts/list_apis.sh (the option-parsing case) by first guarding that a next
argument exists (e.g. check $# -ge 2 or that ${2-} is non-empty and not another
option) and if not, print an error/usage and exit non‑zero; only then set
ref="$2" and shift 2.
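Both inline comments above can be illustrated with a small sketch. The function names `print_leading_comments` and `parse_ref` are hypothetical and do not come from `scripts/list_apis.sh`; the code only shows one way to implement the two fixes the review describes, under the assumption that the help text lives in the comment block starting at line 2 of the script.

```shell
# Print the contiguous comment block that starts at line 2 of a script,
# stripping the leading "# " (or a bare "#"), and stop at the first
# non-comment line so no code leaks into the help output.
print_leading_comments() {
  awk 'NR < 2 { next }
       /^#/   { sub(/^# ?/, ""); print; next }
              { exit }' "$1"
}

# Parse -r/--ref while refusing a missing, empty, or option-like value.
# Checking "$#" first means "$2" is never expanded when it is unset,
# which matters when the caller runs with `set -u`.
parse_ref() {
  ref=""
  while [ "$#" -gt 0 ]; do
    case "$1" in
      -r|--ref)
        if [ "$#" -lt 2 ] || [ -z "${2-}" ] || [ "${2#-}" != "$2" ]; then
          echo "error: $1 requires a REF argument" >&2
          return 2
        fi
        ref="$2"
        shift 2
        ;;
      *)
        shift
        ;;
    esac
  done
  printf '%s\n' "$ref"
}
```

With this shape, `parse_ref --ref v0.6.11.post3` prints the ref, while `parse_ref --ref` and `parse_ref -r -d` both report an error and return non-zero instead of aborting on an unbound variable.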
## Description Bump version to 0.6.12 for release. ## Related Issues (Gated-by PRs) https://github.com/flashinfer-ai/flashinfer/issues?q=is%3Aopen+label%3Av0.6.12 ## Reviewer Notes **API changes review** API changes since v0.6.11.post3, using new tool * #3341 ```diff diff -u \ <(scripts/list_apis.sh -d -p --ref v0.6.11.post3) \ <(scripts/list_apis.sh -d -p) --- /tmp/api_baseline.txt 2026-05-21 16:07:23.252004287 -0700 +++ /tmp/api_head.txt 2026-05-21 16:07:23.316004287 -0700 @@ -251,6 +251,8 @@ shared_expert_output: Optional[torch.Tensor] = None, # ===== Group quant parameters ===== block_quant_group_size: Optional[int] = None, + # ===== RMSNorm variant ===== + weight_bias: float = 0.0, ) -> torch.Tensor: [Global Functions] @flashinfer_api @@ -513,6 +515,7 @@ out_dtype: Optional[torch.dtype] = None, is_var_seq: bool = True, enable_pdl: Optional[bool] = None, + sinks: Optional[torch.Tensor] = None, ) -> torch.Tensor: class BatchPrefillCuteDSLWrapper: @flashinfer_api @@ -759,7 +762,11 @@ skip_softmax_threshold_scale_factor: Optional[float] = None, kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, uses_shared_paged_kv_idx: bool = True, -) -> Union[torch.Tensor, FP4Tensor]: + lse: Optional[torch.Tensor] = None, + return_lse: bool = False, +) -> Union[ + torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor] +]: @flashinfer_api(trace=xqa_batch_decode_trace) def xqa_batch_decode_with_kv_cache( query: torch.Tensor, @@ -898,6 +905,7 @@ weight_layout: int = WeightLayout.BlockMajorK, do_finalize: bool = True, enable_pdl: bool = True, + gemm1_lora_delta: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, activation_type: int = ActivationType.Swiglu.value, routing_replay_out: Optional[torch.Tensor] = None, @@ -987,6 +995,7 @@ weight_layout: int = 0, do_finalize: bool = True, enable_pdl: Optional[bool] = None, + gemm1_lora_delta: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 
8192, fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, @@ -1034,7 +1043,7 @@ @flashinfer_api(trace=trtllm_fp4_block_scale_routed_moe_trace) def trtllm_fp4_block_scale_routed_moe( - topk_ids: torch.Tensor, + topk_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], routing_bias: Optional[torch.Tensor], hidden_states: torch.Tensor, hidden_states_scale: Optional[torch.Tensor], @@ -1096,6 +1105,34 @@ norm_topk_prob: bool = True, routing_replay_out: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: + + +@flashinfer_api +def trtllm_mxint4_block_scale_routed_moe( + topk_ids: torch.Tensor, + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + gemm1_alpha: Optional[torch.Tensor], + gemm1_beta: Optional[torch.Tensor], + gemm1_clamp_limit: Optional[torch.Tensor], + gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, + num_experts: int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routed_scaling_factor: Optional[float], + routing_method_type: int = 0, + do_finalize: bool = True, + enable_pdl: Optional[bool] = None, + gemm1_lora_delta: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None, + tune_max_num_tokens: int = 8192, +) -> List[torch.Tensor]: [Global Functions] @flashinfer_api(trace=b12x_fused_moe_trace) def b12x_fused_moe( @@ -1117,8 +1154,6 @@ output_dtype: torch.dtype = torch.bfloat16, activation: str = "silu", activation_precision: str = "fp4", - quant_mode: Optional[str] = None, - source_format: str = "modelopt", ) -> torch.Tensor: class B12xMoEWrapper: @flashinfer_api @@ -1136,8 +1171,6 @@ device: str = "cuda", activation: str = "silu", activation_precision: str = "fp4", - quant_mode: Optional[str] = None, - source_format: str = "modelopt", ): @flashinfer_api(trace=b12x_moe_wrapper_run_trace) @@ -1477,8 +1510,6 @@ out: Optional[torch.Tensor] 
= None, backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", ): - - @flashinfer_api(trace=bmm_fp8_trace) def bmm_fp8( A: torch.Tensor, @@ -1524,7 +1555,7 @@ out_dtype: Optional[torch.dtype] = None, backend: Literal["cutlass", "trtllm"] = "cutlass", ): -@flashinfer_api +@flashinfer_api(trace=gemm_fp8_nt_groupwise_trace) def gemm_fp8_nt_groupwise( a: torch.Tensor, b: torch.Tensor, @@ -1712,8 +1743,17 @@ sf_dtype: str, c_dtype: str, sf_vec_size: int, + topk_weights: Optional[torch.Tensor] = None, + idx_src_info: Optional[torch.Tensor] = None, + rank_src_info: Optional[torch.Tensor] = None, + out_ptrs: Optional[torch.Tensor] = None, + num_ranks: int = 0, dst_signals: Optional[torch.Tensor] = None, sm_count: Optional[int] = None, + barrier_flag_local: Optional[torch.Tensor] = None, + barrier_flag_multicast: Optional[torch.Tensor] = None, + is_combine_fusion: bool = False, + is_swap_ab: bool = False, **kwargs, ): [Global Functions] @@ -1722,14 +1762,21 @@ mat_a: torch.Tensor, mat_b: torch.Tensor, out: torch.Tensor, - launch_with_pdl: bool = False, + launch_with_pdl: bool = True, ) -> None: @flashinfer_api(trace=mm_M1_16_K7168_N256_trace) def mm_M1_16_K7168_N256( mat_a: torch.Tensor, mat_b: torch.Tensor, out: torch.Tensor, - launch_with_pdl: bool = False, + launch_with_pdl: bool = True, +) -> None: +@flashinfer_api(trace=mm_M1_16_K6144_N256_trace) +def mm_M1_16_K6144_N256( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + out: torch.Tensor, + launch_with_pdl: bool = True, ) -> None: @flashinfer_api(trace=tinygemm_bf16_trace) def tinygemm_bf16( @@ -1826,6 +1873,36 @@ tactic: int = -1, ) -> torch.Tensor: [Global Functions] +@flashinfer_api +def checkpointing_ssu( + state: torch.Tensor, + old_x: torch.Tensor, + old_B: torch.Tensor, + old_dt: torch.Tensor, + old_cumAdt: torch.Tensor, + cache_buf_idx: torch.Tensor, + prev_num_accepted_tokens: torch.Tensor, + x: torch.Tensor, + dt: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + out: 
torch.Tensor, + D: Optional[torch.Tensor] = None, + z: Optional[torch.Tensor] = None, + dt_bias: Optional[torch.Tensor] = None, + dt_softplus: bool = False, + state_batch_indices: Optional[torch.Tensor] = None, + pad_slot_id: int = -1, + state_scale: Optional[torch.Tensor] = None, + rand_seed: Optional[torch.Tensor] = None, + philox_rounds: int = 10, + d_split: Optional[int] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + enable_pdl: bool = False, +) -> torch.Tensor: +[Global Functions] @flashinfer_api(trace=selective_state_update_trace) def selective_state_update( state: torch.Tensor, @@ -1966,6 +2043,7 @@ kv_len: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, return_lse_base_on_e: bool = False, + o_scale: Optional[float] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: @@ -1991,7 +2069,10 @@ backend: str = "auto", is_var_seq: bool = True, uses_shared_paged_kv_idx: bool = True, -) -> torch.Tensor: + lse: Optional[torch.Tensor] = None, + return_lse: bool = False, + cute_dsl_impl: str = "auto", +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: @flashinfer_api(trace=xqa_batch_decode_mla_trace) @@ -2252,6 +2333,44 @@ norm_out: Optional[torch.Tensor] = None, sf_out: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + qkv, + q_weight, + k_weight, + **kwargs, +): + + +@flashinfer_api +def fused_qk_rmsnorm_rope( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + *, + ppf: int, + pph: int, + ppw: int, + num_frame_channels: int, + num_height_channels: int, + num_width_channels: int, + num_heads_q: int, + num_heads_k: int, + num_heads_v: int, + head_dim: int, + eps: float = 1e-6, + base: float = 10000.0, + interleave: bool = True, + factor: float = 1.0, + low: float = 0.0, + high: float = 0.0, + attention_factor: float = 1.0, + is_qk_norm: bool = True, + output_fp8: bool = False, + output_quant_scale: float = 1.0, + v_quant_scale: 
float = 1.0, + q_out: Optional[torch.Tensor] = None, + k_out: Optional[torch.Tensor] = None, + v_out: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: [Global Functions] @flashinfer_api def get_batch_indices_positions( @@ -2730,7 +2849,11 @@ skip_softmax_threshold_scale_factor: Optional[float] = None, uses_shared_paged_kv_idx: bool = True, causal: bool = True, -) -> Union[torch.Tensor, FP4Tensor]: + lse: Optional[torch.Tensor] = None, + return_lse: bool = False, +) -> Union[ + torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor] +]: @flashinfer_api(trace=fmha_v2_prefill_deepseek_trace) @@ -2942,6 +3065,7 @@ is_sf_swizzled_layout: bool = True, alignment: int = 32, enable_pdl: bool | None = None, + is_sf_8x4_layout: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: ``` API changes since v0.6.11.post3 (old approach) ```diff $ git diff v0.6.11.post3..main -- "*.py" | grep -B5 -A20 "@flashinfer_api" -def _reconstruct_value(value: Any) -> Any: +def flush_graph_dumps(synchronize: bool = True) -> int: + """Write CUDA-graph-deferred level-10 dumps to disk. + + When ``FLASHINFER_LOGLEVEL=10`` is active inside ``torch.cuda.graph(...)``, + each ``@flashinfer_api`` call records input/output tensor references instead + of writing immediately or inserting D2H copies into the captured graph. + After ``g.replay()`` completes, calling this function materializes current + tensor values to CPU and serializes them to two places: + + 1. ``inputs.pt``/``outputs.pt`` (or the safetensors equivalents) in the + original dump directory, for backwards compatibility. These files + always reflect the most recent flush. + 2. ``graph_flushes/flush_XXXX/`` under the original dump directory. These + immutable snapshots preserve every explicit flush, so callers can keep + every replay by calling ``flush_graph_dumps()`` after every replay. 
+ + Parameters + ---------- + synchronize : bool, default True + Synchronize the current stream first to ensure the most recent + ``g.replay()`` has completed before materializing tensors. Set to + ``False`` only if you've already synchronized externally. + + Returns + ------- -- routing_logits, None, None, @@ -3199,7 +3362,7 @@ def trtllm_fp4_block_scale_moe( @flashinfer_api(trace=trtllm_fp4_block_scale_routed_moe_trace) def trtllm_fp4_block_scale_routed_moe( - topk_ids: torch.Tensor, + topk_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], routing_bias: Optional[torch.Tensor], hidden_states: torch.Tensor, hidden_states_scale: Optional[torch.Tensor], @@ -3231,13 +3394,20 @@ def trtllm_fp4_block_scale_routed_moe( output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, ) -> List[torch.Tensor]: - """FP4 block scale MoE operation. + """FP4 block scale MoE operation with pre-computed routing. + + This function supports two pre-computed routing formats: + 1. Packed format: topk_ids is a single tensor with packed (score << 16 | expert_id) + 2. Unpacked format: topk_ids is a tuple of (topk_ids, topk_weights) tensors Args: - topk_ids (torch.Tensor): shape [seq_len, top_k] - Tensor of top-k indices and expert weights. Dtype must be int32. 
-- norm_topk_prob, routing_replay_out, ) + + +@flashinfer_api +def trtllm_mxint4_block_scale_routed_moe( + topk_ids: torch.Tensor, + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + gemm1_alpha: Optional[torch.Tensor], + gemm1_beta: Optional[torch.Tensor], + gemm1_clamp_limit: Optional[torch.Tensor], + gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, + num_experts: int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routed_scaling_factor: Optional[float], + routing_method_type: int = 0, + do_finalize: bool = True, -- - except Exception: - return False - - @supported_compute_capability([120, 121]) @flashinfer_api(trace=b12x_fused_moe_trace) def b12x_fused_moe( @@ -74,13 +67,11 @@ def b12x_fused_moe( output_dtype: torch.dtype = torch.bfloat16, activation: str = "silu", activation_precision: str = "fp4", - quant_mode: Optional[str] = None, - source_format: str = "modelopt", ) -> torch.Tensor: """Run fused MoE on SM120/SM121 using b12x CuTe DSL kernels. - The kernel takes bf16 input and runs routing, FC1, activation, FC2, - and scatter through the selected backend. + The kernel takes bf16 input and fuses quantization + routing + + FC1 + activation + FC2 + scatter in a single launch. Automatically selects micro (decode), static, or dynamic backend based on routed row count. @@ -99,19 +90,16 @@ def b12x_fused_moe( w1_alpha: Per-expert global scale for FC1. w2_alpha: Per-expert global scale for FC2. 
-- @@ -6387,7 +6276,7 @@ def _check_gemm_fp8_nt_groupwise_problem_size( }, common_check=_check_gemm_fp8_nt_groupwise_problem_size, ) -@flashinfer_api +@flashinfer_api(trace=gemm_fp8_nt_groupwise_trace) def gemm_fp8_nt_groupwise( a: torch.Tensor, b: torch.Tensor, @@ -8031,7 +7920,7 @@ def _calculate_block_scale_dims( @functools.lru_cache(maxsize=1024) -def create_cudnn_execution_plans_mxfp8_gemm( +def build_cudnn_gemm_mxfp8_graph( a_shape, a_stride, a_type, # cudnn.data_type, FP8_E4M3 or FP8_E5M2 @@ -8041,7 +7930,11 @@ def create_cudnn_execution_plans_mxfp8_gemm( block_size, o_type, # cudnn.data_type, BF16 or FP16 device, + policy=None, ): + if policy is None: + policy = cudnn.build_plan_policy.HEURISTICS_CHOICE -- @@ -229,6 +264,54 @@ def mm_M1_16_K7168_N256( ) +@backend_requirement({}, common_check=_mm_M1_16_K6144_N256_shape_checks) +@flashinfer_api(trace=mm_M1_16_K6144_N256_trace) +def mm_M1_16_K6144_N256( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + out: torch.Tensor, + launch_with_pdl: bool = True, +) -> None: + """Optimized GEMM for the router operation in GLM-MoE-DSA. + + This function performs a highly optimized matrix multiplication specifically tailored + for the expert routing GEMM in GLM-MoE-DSA's Mixture of Experts (MoE) architecture. + It computes out = mat_a @ mat_b where mat_a contains token embeddings and mat_b + contains expert routing weights. 
+ + The implementation is optimized for the specific problem dimensions used in GLM-MoE-DSA: + - Hidden dimension (K): 6144 + - Number of experts (N): 256 + - Number of tokens (M): 1-16 + + Args: + mat_a (torch.Tensor): Input token embeddings of shape (M, K) where M is the number -- +) -> None: + """Fake implementation for torch.compile() meta tensor propagation.""" + pass + + +@flashinfer_api +def checkpointing_ssu( + state: torch.Tensor, + old_x: torch.Tensor, + old_B: torch.Tensor, + old_dt: torch.Tensor, + old_cumAdt: torch.Tensor, + cache_buf_idx: torch.Tensor, + prev_num_accepted_tokens: torch.Tensor, + x: torch.Tensor, + dt: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + out: torch.Tensor, + D: Optional[torch.Tensor] = None, + z: Optional[torch.Tensor] = None, + dt_bias: Optional[torch.Tensor] = None, + dt_softplus: bool = False, + state_batch_indices: Optional[torch.Tensor] = None, + pad_slot_id: int = -1, -- page_table: Optional[torch.Tensor] = None, return_lse_base_on_e: bool = False, + o_scale: Optional[float] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... @flashinfer_api(trace=mla_paged_decode_trace) @@ -489,6 +915,7 @@ class BatchMLAPagedAttentionWrapper: kv_len: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, return_lse_base_on_e: bool = False, + o_scale: Optional[float] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Run the MLA attention computation. @@ -506,6 +933,7 @@ class BatchMLAPagedAttentionWrapper: ``head_dim_kpe`` is 64 in DeepSeek v2/v3 models. out : Optional[torch.Tensor] The output tensor, if not provided, will be allocated internally. + When ``o_scale`` is provided, this should be an FP8 tensor. lse : Optional[torch.Tensor] The log-sum-exp of attention logits, if not provided, will be allocated internally. return_lse : bool, optional @@ -516,6 +944,10 @@ class BatchMLAPagedAttentionWrapper: The query length of each request, shape: ``[batch_size]``. 
Required when ``backend`` is ``cutlass``. page_table : Optional[torch.Tensor] The page table of the paged kv-cache, shape: ``[batch_size, num_pages]``. Required when ``backend`` is ``cutlass``. -- + ) + + return True + + +@flashinfer_api +@backend_requirement(backend_checks={}, common_check=_check_fused_qk_rmsnorm_rope) +def fused_qk_rmsnorm_rope( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + *, + ppf: int, + pph: int, + ppw: int, + num_frame_channels: int, + num_height_channels: int, + num_width_channels: int, + num_heads_q: int, + num_heads_k: int, + num_heads_v: int, + head_dim: int, + eps: float = 1e-6, + base: float = 10000.0, + interleave: bool = True, + factor: float = 1.0,``` **Supplemental: class-wrapper overload stub changes (BatchMLAPagedAttentionWrapper.run gained `o_scale`)** ```diff $ git diff v0.6.11.post3..main -- "flashinfer/mla/_core.py" | grep -B5 -A10 "o_scale" mod = gen_trtllm_gen_fmha_module() @@ -457,6 +881,7 @@ class BatchMLAPagedAttentionWrapper: kv_len: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, return_lse_base_on_e: bool = False, + o_scale: Optional[float] = None, ) -> torch.Tensor: ... @overload @@ -473,6 +898,7 @@ class BatchMLAPagedAttentionWrapper: kv_len: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, return_lse_base_on_e: bool = False, + o_scale: Optional[float] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... @flashinfer_api(trace=mla_paged_decode_trace) @@ -489,6 +915,7 @@ class BatchMLAPagedAttentionWrapper: kv_len: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, return_lse_base_on_e: bool = False, + o_scale: Optional[float] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Run the MLA attention computation. @@ -506,6 +933,7 @@ class BatchMLAPagedAttentionWrapper: ``head_dim_kpe`` is 64 in DeepSeek v2/v3 models. 
out : Optional[torch.Tensor] The output tensor, if not provided, will be allocated internally. + When ``o_scale`` is provided, this should be an FP8 tensor. lse : Optional[torch.Tensor] The log-sum-exp of attention logits, if not provided, will be allocated internally. return_lse : bool, optional @@ -516,6 +944,10 @@ class BatchMLAPagedAttentionWrapper: The query length of each request, shape: ``[batch_size]``. Required when ``backend`` is ``cutlass``. page_table : Optional[torch.Tensor] The page table of the paged kv-cache, shape: ``[batch_size, num_pages]``. Required when ``backend`` is ``cutlass``. + o_scale : Optional[float] + FP8 output dequantization scale (``real = quantized * o_scale``). + When provided, ``out`` must be an FP8 tensor. Only supported with + the ``cutlass`` backend. """ if self._backend == "cutlass": if return_lse: @@ -525,7 +957,26 @@ class BatchMLAPagedAttentionWrapper: "profiler_buffer does not support cutlass backend for now." ) self._cached_module = get_mla_module() - if out is None: + output_scale = 1.0 + if o_scale is not None: + output_scale = float(o_scale) + if not math.isfinite(output_scale) or output_scale <= 0.0: + raise ValueError( + f"o_scale must be a finite positive value, got {o_scale}" + ) + if out is None: + raise ValueError( + "out tensor must be provided when o_scale is used for FP8 output." + ) + if out.dtype not in ( + torch.float8_e4m3fn, + torch.float8_e5m2, + ): + raise ValueError( + f"out must be an FP8 tensor when o_scale is provided, got {out.dtype}" + ) + check_shape_dtype_device(out, q_nope.shape, None, q_nope.device, "out") + elif out is None: out = torch.empty_like(q_nope) else: check_shape_dtype_device( @@ -543,9 +994,14 @@ class BatchMLAPagedAttentionWrapper: ckv_kpe_cache, kv_len, page_table, + output_scale, ) return out + if o_scale is not None: + raise ValueError( + "o_scale is only supported with the cutlass backend for now." 
+ ) if profiler_buffer is None: if self._use_profiler: raise ValueError( @@ -615,7 +1071,10 @@ def trtllm_batch_decode_with_kv_cache_mla( backend: str = "auto", is_var_seq: bool = True, uses_shared_paged_kv_idx: bool = True, -) -> torch.Tensor: + lse: Optional[torch.Tensor] = None, ``` **Supplemental: `trtllm_batch_decode_with_kv_cache` / `trtllm_batch_context_with_kv_cache` gained `lse` and `return_lse` parameters (signature widening — BC)** ```diff $ git diff v0.6.11.post3..main -- "flashinfer/decode.py" "flashinfer/prefill.py" | grep -B3 -A6 "return_lse: bool = False" uses_shared_paged_kv_idx: bool = True, -) -> Union[torch.Tensor, FP4Tensor]: + lse: Optional[torch.Tensor] = None, + return_lse: bool = False, +) -> Union[ + torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor] +]: """ Parameters ---------- -- causal: bool = True, -) -> Union[torch.Tensor, FP4Tensor]: + lse: Optional[torch.Tensor] = None, + return_lse: bool = False, +) -> Union[ + torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor] +]: """ Parameters ---------- ``` <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Chores** * Version bumped to 0.6.12. <!-- review_stack_entry_start --> [](https://app.coderabbit.ai/change-stack/flashinfer-ai/flashinfer/pull/3388?utm_source=github_walkthrough&utm_medium=github&utm_campaign=change_stack) <!-- review_stack_entry_end --> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
📌 Description
A utility script that can be used for API review/QA purposes.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- Tests pass (`unittest`, etc.).
Reviewer Notes
Examples
List all `@flashinfer_api`-decorated APIs (module-level and class methods), grouped by class (or `[Global Functions]` per file), with full multi-line signatures.
List current API surface
Signatures only (no paths, no line numbers)
Class methods only (skip module-level functions)
Inspect a single file
Run against a git tag / branch / SHA
Auto-fetches from `upstream` or `origin` if not present locally; uses a throwaway worktree.
Diff the API surface between two revisions
Use `-d` for stable, byte-identical output so the diff only reflects real API changes.
Flags
- `-n`, `--no-lines`
- `-p`, `--no-paths`
- `-M`, `--methods-only`
- `-d`, `--deterministic`
- `-r`, `--ref REF`
- `-h`, `--help`
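The revision-to-revision diff described under Examples can be sketched as a small wrapper. It follows the command quoted in the referencing release PR, `diff -u <(scripts/list_apis.sh -d -p --ref v0.6.11.post3) <(scripts/list_apis.sh -d -p)`, but `api_surface_diff` is a hypothetical name, and temporary files replace process substitution only so the sketch also runs under a plain POSIX shell.

```shell
# Hypothetical wrapper: compare the API surface at a base ref with the
# current checkout. The body runs in a subshell, so its variables do not
# leak into the caller.
api_surface_diff() (
  script="$1"     # path to list_apis.sh (assumed to accept -d, -p, --ref)
  base_ref="$2"   # tag, branch, or SHA to compare against

  base_out=$(mktemp) || exit 2
  head_out=$(mktemp) || { rm -f "$base_out"; exit 2; }

  # -d asks for deterministic output and -p drops paths, so the diff
  # reflects signature changes rather than file moves.
  "$script" -d -p --ref "$base_ref" > "$base_out"
  "$script" -d -p > "$head_out"

  diff -u "$base_out" "$head_out"
  status=$?       # 0 = identical, 1 = differences found

  rm -f "$base_out" "$head_out"
  exit "$status"
)

# Assumed usage from the repository root:
#   api_surface_diff scripts/list_apis.sh v0.6.11.post3
```

Because `diff` exits with status 1 when the two listings differ, callers running under `set -e` may want to append `|| true` when a difference is an expected result.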