
bump version to 0.6.7 & fix api breaking changes #2832

Merged
aleozlx merged 5 commits into flashinfer-ai:main from aleozlx:fix_0.6.7 on Mar 24, 2026

Conversation

@aleozlx (Collaborator) commented Mar 20, 2026

📌 Description

fix api breaking changes for 0.6.7 release

🔍 Related Issues (Gated-by PRs)

https://github.com/flashinfer-ai/flashinfer/issues?q=state%3Aopen%20label%3Av0.6.7

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

API changes review

API changes since v0.6.6

PR #2520 + commit e35c19e (fixed to be compatible)

Function: xqa()
Change: Added k_sf_cache=None, v_sf_cache=None as keyword-only params (after *). Backward-compatible.
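For illustration, a minimal sketch of the keyword-only pattern this relies on. The parameter names before the bare * are hypothetical placeholders, not the real xqa() signature; only k_sf_cache/v_sf_cache come from this PR.

# Hypothetical sketch, not the real xqa() signature: parameters added after a
# bare `*` keep every pre-0.6.7 call site valid and can only be passed by keyword.
def xqa_sketch(q, kv_cache, *, k_sf_cache=None, v_sf_cache=None):
    return q, kv_cache, k_sf_cache, v_sf_cache

xqa_sketch("q", "kv")                                        # old call site: unchanged
xqa_sketch("q", "kv", k_sf_cache="k_sf", v_sf_cache="v_sf")  # new keyword-only args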

PR #2618 (has PR #2730 to fix it)

Function: gated_delta_rule_mtp()
Change: disable_state_update: bool = True → Optional[bool] = None. Still defaults to True at runtime but emits a deprecation
warning; will flip to False in 0.7.0.
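A hedged sketch of the None-sentinel deprecation mechanics described above; this is not the FlashInfer implementation, and the warning category and message are illustrative.

import warnings

def gated_delta_rule_mtp_sketch(*, disable_state_update=None):
    # Illustrative only: None means the caller did not choose, so warn and keep
    # the current default of True until the planned flip to False in 0.7.0.
    if disable_state_update is None:
        warnings.warn(
            "disable_state_update will default to False in 0.7.0; pass it explicitly.",
            DeprecationWarning,
            stacklevel=2,
        )
        disable_state_update = True
    return disable_state_update

gated_delta_rule_mtp_sketch(disable_state_update=True)  # explicit value: no warning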

PR #2775 (expected — cute DSL MoE cleanup)

Function: blockscaled_contiguous_grouped_gemm_nvfp4()
Change: Entire @flashinfer_api decorated function deleted.

Function: blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4()
Change: Entire @flashinfer_api decorated function deleted.

Function: blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4()
Change: @flashinfer_api decorator removed; added enable_pdl: bool = True param.

Function: blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4()
Change: @flashinfer_api decorator removed; added enable_pdl: bool = True param.

Function: CuteDslMoEWrapper.__init__()
Change: Added enable_pdl: bool = True param. Backward-compatible.

Function: cute_dsl_fused_moe_nvfp4()
Change: Added enable_pdl: bool = True param. Backward-compatible.
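Because enable_pdl defaults to True in all of the functions above, existing call sites need no changes. A toy constructor sketch (not the real CuteDslMoEWrapper.__init__ signature; the other parameters are hypothetical stand-ins) shows why:

class CuteDslMoEWrapperSketch:
    # Hypothetical constructor: only enable_pdl mirrors the PR; the remaining
    # parameters stand in for the real ones, which are elided here.
    def __init__(self, num_experts, top_k, enable_pdl=True):
        self.num_experts = num_experts
        self.top_k = top_k
        self.enable_pdl = enable_pdl

CuteDslMoEWrapperSketch(num_experts=8, top_k=2)                    # pre-0.6.7 call still works
CuteDslMoEWrapperSketch(num_experts=8, top_k=2, enable_pdl=False)  # opt out of PDL explicitly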

PR #2428

Function: rmsnorm_quant()
Change: scale: float → scale: Union[float, torch.Tensor]; return type torch.Tensor → None.

Function: fused_add_rmsnorm_quant()
Change: scale: float → scale: Union[float, torch.Tensor].
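A hedged migration sketch for the scale argument (only the part of the signature this PR touches is shown; building a shape-(1,) float32 tensor on the input's device matches the conversion the compatibility shim performs internally, and a CUDA device is assumed):

import torch

x = torch.randn(4, 128, device="cuda", dtype=torch.float16)

scale_deprecated = 0.5  # still accepted in 0.6.7, but emits a FutureWarning
scale_preferred = torch.tensor([0.5], dtype=torch.float32, device=x.device)

Callers of rmsnorm_quant() should also stop relying on the return value, since its return type is now None.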

Quantization functions (relocated, not removed)

All quantization APIs (fp4_quantize, block_scale_interleave, e2m1_and_ufp8sf_scale_to_float, shuffle_matrix_a, shuffle_matrix_sf_a,
nvfp4_quantize, nvfp4_batched_quantize, scaled_fp4_grouped_quantize, mxfp4_quantize, mxfp4_dequantize, mxfp4_dequantize_host,
mxfp8_quantize, mxfp8_dequantize_host) were moved from flashinfer/fp4_quantization.py and flashinfer/fp8_quantization.py to
flashinfer/quantization/. Signatures, @flashinfer_api decorators, and __init__.py exports are preserved. No breakage.
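Since the __init__.py re-exports are preserved, import paths do not change; for example (assuming the existing top-level exports):

# Existing imports keep working after the move into flashinfer/quantization/.
from flashinfer import fp4_quantize, block_scale_interleave, mxfp8_quantize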

$ git diff v0.6.6 | grep -A20 "@flashinfer_api"                                               
     @flashinfer_api
@@ -1215,6 +1227,9 @@ class BatchDecodeWithPagedKVCacheWrapper:
         sinks: Optional[torch.Tensor] = None,
         q_len_per_req: Optional[int] = 1,
         skip_softmax_threshold_scale_factor: Optional[float] = None,
+        kv_block_scales: Optional[
+            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+        ] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         r"""Compute batch decode attention between query and paged kv cache.

@@ -1273,6 +1288,15 @@ class BatchDecodeWithPagedKVCacheWrapper:
             enable_pdl = device_support_pdl(q.device)
         k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout)

+        # Unpack kv_block_scales
+        key_block_scales = None
+        value_block_scales = None
+        if kv_block_scales is not None:
+            if isinstance(kv_block_scales, tuple):
+                key_block_scales, value_block_scales = kv_block_scales
--
-@flashinfer_api
-def fp4_quantize(
-    input: torch.Tensor,
-    global_scale: Optional[torch.Tensor] = None,
-    sf_vec_size: int = 16,
-    sf_use_ue8m0: bool = False,
-    is_sf_swizzled_layout: bool = True,
-    is_sf_8x4_layout: bool = False,
-    enable_pdl: Optional[bool] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Quantize input tensor to FP4 format.
-
-    This function implements FP4 quantization that converts input tensors to a compressed FP4 format
-    with associated scale factors. It supports various input data types and scale factor layouts.
-
-    Args:
-        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
-        global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False.
-        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
--
-@flashinfer_api
-def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
-    """Swizzle block scale tensor for FP4 format.
-
-    This function swizzles the block scale tensor to optimize memory access patterns
-    for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128.
-
-    Args:
-        unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16.
-
-    Returns:
-        torch.Tensor: Swizzled tensor with the same shape as input.
-
-    Raises:
-        AssertionError: If input dtype is not uint8 or bfloat16.
-    """
-    # TODO(shuw): check input dtype is uint8
-    assert (
-        unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16
-    ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}"
-
--
-@flashinfer_api
-def e2m1_and_ufp8sf_scale_to_float(
-    e2m1_tensor: torch.Tensor,
-    ufp8_scale_tensor: torch.Tensor,
-    global_scale_tensor: Optional[torch.Tensor] = None,
-    sf_vec_size: int = 16,
-    ufp8_type: int = 1,
-    is_sf_swizzled_layout: bool = True,
-) -> torch.Tensor:
-    """Convert E2M1 format tensor and UFP8 scale factors to float tensor.
-
-    This function performs dequantization by converting a packed FP4 tensor in E2M1 format
-    back to float values using the associated UFP8 scale factors and global scale.
-
-    Args:
-        e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
-        ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8.
-        global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
-        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
--
-@flashinfer_api
-def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
-    """
-    PyTorch equivalent of trtllm-gen `shuffleMatrixA`
-    """
-    row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
-
-    return input_tensor[row_indices.to(input_tensor.device)]
-
-
-@flashinfer_api
-def shuffle_matrix_sf_a(
-    input_tensor: torch.Tensor,
-    epilogue_tile_m: int,
-    num_elts_per_sf: int = 16,
-):
-    """
-    Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat.
-    `shuffleMatrixSfA` expects the input to be in 128x4 layout and then
-    apply the same shuffling in `shuffleMatrixA` and writes out in 128x4
-    layout.
-    This function expects the input to be in linear layout. It's done this
-    way because the scaling factors in the NVFP4 checkpoints are quantized
-    and are in linear layout.
-    This function doesn't add padding.
-    """
-
-    row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m)
-
-    w_shuffled = input_tensor[row_indices.to(input_tensor.device)]
-
--
-@flashinfer_api
-def nvfp4_quantize(
-    a,
-    a_global_sf,
-    sfLayout=SfLayout.layout_128x4,
-    do_shuffle=False,
-    sf_vec_size=16,
-    enable_pdl=None,
-):
-    """
-    Quantize input tensor to NVFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4.
-        do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
-            If None, automatically detects based on device capability. Defaults to None.
-
--
-@flashinfer_api
-def mxfp4_quantize(a):
-    """
-    Quantize input tensor to MXFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-            - Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-    """
-    a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
-    a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True)
-    return a_fp4, a_sf
-
-
-@flashinfer_api
-def mxfp4_dequantize(a_fp4, a_sf):
-    """
-    Dequantize input tensor from MXFP4 format.
-
-    Parameters:
-        a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-        a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-
-    Returns:
-        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
-    """
-    return e2m1_and_ufp8sf_scale_to_float(
-        a_fp4.cpu().view(torch.uint8),
-        a_sf.cpu().view(torch.uint8).reshape(-1),
-        torch.tensor([1.0], device=a_fp4.device),
-        32,
-        0,
-        True,
-    )
-
--
-@flashinfer_api
-def mxfp4_dequantize_host(
-    weight: torch.Tensor,
-    scale: torch.Tensor,
-    group_size: int = 32,
-) -> torch.Tensor:
-    """
-    Dequantize input tensor from MXFP4 format on host.
-
-    Parameters:
-        weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-        scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-        group_size (int, optional): Group size for dequantization. Defaults to 32.
-
-    Returns:
-        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
-    """
-    # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future
-    major, minor = get_compute_capability(
-        torch.device("cuda:0")
-    )  # use any cuda device to get a compute capability
--
-@flashinfer_api
-def nvfp4_batched_quantize(
-    a,
-    a_global_sf,
-    sf_vec_size=16,
-):
-    """
-    Quantize batched input tensor to NVFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
-            - Scale factors tensor with shape determined by layout and sf_vec_size
-    """
-    major, minor = get_compute_capability(a.device)
-    device_arch = f"{major * 10 + minor}"
--
-@flashinfer_api
-def scaled_fp4_grouped_quantize(
-    a,
-    mask,
-    a_global_sf,
-):
-    """
-    quantize batched input tensor to NVFP4 format with mask.
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        mask (torch.Tensor): Mask tensor to apply before quantization.
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
-            - Scale factors tensor with shape determined by layout and sf_vec_size
-    """
-    major, minor = get_compute_capability(a.device)
-    device_arch = f"{major * 10 + minor}"
-    a_fp4, a_sf = get_fp4_quantization_module(
-        device_arch
--
-@flashinfer_api
-def mxfp8_quantize(
-    input: torch.Tensor,
-    is_sf_swizzled_layout: bool = True,
-    alignment: int = 32,
-    enable_pdl: Optional[bool] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Quantize input tensor to MxFP8 format.
-
-    This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format
-    with associated scale factors. It supports various input data types and scale factor layouts.
-
-    Args:
-        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
-        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
-        alignment (int, optional): sfVecSize. Defaults to 32.
-        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
-            If None, automatically detects based on device capability. Defaults to None.
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3
--
-@flashinfer_api
-def mxfp8_dequantize_host(
-    input: torch.Tensor,
-    scale_tensor: torch.Tensor,
-    is_sf_swizzled_layout: bool = True,
-) -> torch.Tensor:
-    """Dequantize input tensor from MxFP8 format.
-
-    This function performs dequantization by converting a packed FP8 tensor in MxFP8 format
-    back to float values using the associated scale factors.
-
-    Args:
-        input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3.
-        scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size.
-        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
-
-    Returns:
-        torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32.
-
-    """
-
--
-@flashinfer_api
 def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -323,6 +324,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     vectorized_f32: bool = True,
     raster_along_m: bool = False,
     sm_count: Optional[int] = None,
+    enable_pdl: bool = True,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     """Blockscaled Contiguous Gather Grouped GEMM with SwiGLU Fusion for MoE workloads.

@@ -423,7 +425,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     major, minor = get_compute_capability(a.device)
     if major != 10:
         raise ValueError(
-            f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103, SM110). "
+            f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103). "
             f"Got SM{major}{minor}."
         )

--
-@flashinfer_api
-def blockscaled_contiguous_grouped_gemm_nvfp4(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_scale: torch.Tensor,
-    b_scale: torch.Tensor,
-    alpha: torch.Tensor,
-    tile_idx_to_group_idx: torch.Tensor,
-    num_non_exiting_tiles: torch.Tensor,
-    out: Optional[torch.Tensor] = None,
-    *,
-    ab_dtype: str = "float4_e2m1fn",
-    sf_dtype: str = "float8_e4m3fn",
-    c_dtype: str = "bfloat16",
-    sf_vec_size: int = 16,
-    mma_tiler_mn: Tuple[int, int] = (128, 128),
-    cluster_shape_mn: Tuple[int, int] = (1, 1),
-    sm_count: Optional[int] = None,
-) -> torch.Tensor:
-    """Blockscaled Contiguous Grouped GEMM for MoE workloads with NVFP4 quantization.
-
--
-@flashinfer_api
 def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -272,6 +279,7 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
     cluster_shape_mn: Tuple[int, int] = (2, 1),
     raster_along_m: bool = False,
     sm_count: Optional[int] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Blockscaled Contiguous Grouped GEMM with Finalize Fusion for MoE workloads.

@@ -298,7 +306,11 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
             expanded_idx = token_idx * topk + topk_idx. Invalid rows have -1.
         token_final_scales: Router scaling factors, shape (seq_len, topk), float32/bf16/fp16
         out: Optional output tensor, shape (seq_len, n). Created if None.
-             This tensor is used for atomic accumulation, so it should be zero-initialized.
+             This tensor is used for atomic accumulation. If `out` is
+             provided, it must already be zero-initialized by the caller.
+             If `out` is None, this function allocates a zero-initialized
+             output tensor. Passing a non-zeroed `out` buffer will silently
--
-@flashinfer_api
-def blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_scale: torch.Tensor,
-    b_scale: torch.Tensor,
-    alpha: torch.Tensor,
-    tile_idx_to_group_idx: torch.Tensor,
-    num_non_exiting_tiles: torch.Tensor,
-    out: Optional[torch.Tensor] = None,
-    out_scale: Optional[torch.Tensor] = None,
-    global_scale: Optional[torch.Tensor] = None,
-    *,
-    ab_dtype: str = "float4_e2m1fn",
-    sf_dtype: str = "float8_e4m3fn",
-    c_dtype: str = "bfloat16",
-    sf_vec_size: int = 16,
-    mma_tiler_mn: Tuple[int, int] = (256, 128),
-    cluster_shape_mn: Tuple[int, int] = (2, 1),
-    vectorized_f32: bool = True,
-    sm_count: Optional[int] = None,
--
     @flashinfer_api
     def __init__(
         self,
@@ -347,6 +355,7 @@ class CuteDslMoEWrapper:
         sf_vec_size: int = 16,
         output_dtype: torch.dtype = torch.bfloat16,
         device: str = "cuda",
+        enable_pdl: bool = True,
     ):
         """Initialize the MoE wrapper.

@@ -363,6 +372,7 @@ class CuteDslMoEWrapper:
             sf_vec_size: Scale factor vector size. Default: 16.
             output_dtype: Output data type. Default: torch.bfloat16.
             device: Device for buffer allocation. Default: "cuda".
+            enable_pdl: Enable Programmatic Dependent Launch. Default: True.
         """
         self.num_experts = num_experts
         self.top_k = top_k
@@ -376,6 +386,7 @@ class CuteDslMoEWrapper:
         self.sf_vec_size = sf_vec_size
--
     @flashinfer_api
@@ -550,9 +570,10 @@ class CuteDslMoEWrapper:
                 f"num_tokens ({num_tokens}) exceeds max_num_tokens ({self.max_num_tokens})"
             )

-        # Allocate output buffer if not using pre-allocated one
+        # Slice the pre-allocated buffer to the active batch so that
+        # _moe_core_impl only zeros num_tokens rows, not max_num_tokens.
         if self.use_cuda_graph:
-            moe_output = self._moe_output
+            moe_output = self._moe_output[:num_tokens]
         else:
             moe_output = torch.empty(
                 (num_tokens, self.hidden_size),
@@ -627,6 +648,7 @@ def _cute_dsl_fused_moe_nvfp4_impl(
     use_fused_finalize: bool = True,
     moe_output: Optional[torch.Tensor] = None,
     aux_stream: Optional[torch.cuda.Stream] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Internal implementation called by auto-tuner for functional API."""
--
 @flashinfer_api
 def cute_dsl_fused_moe_nvfp4(
     x: torch.Tensor,
@@ -678,9 +702,12 @@ def cute_dsl_fused_moe_nvfp4(
     use_fused_finalize: bool = True,
     moe_output: Optional[torch.Tensor] = None,
     aux_stream: Optional[torch.cuda.Stream] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Run fused MoE computation using CuteDSL NVFP4 kernels.

+    Supported architectures: SM100, SM103.
+
     This is the simple functional API. For CUDA graph support, use
     `CuteDslMoEWrapper` instead.

@@ -736,6 +763,7 @@ def cute_dsl_fused_moe_nvfp4(
         local_expert_offset=local_expert_offset,
         use_fused_finalize=use_fused_finalize,
         output_dtype=output_dtype,
+        enable_pdl=enable_pdl,
--
 @flashinfer_api
 def gated_delta_rule_decode_pretranspose(
     q: torch.Tensor,
@@ -1002,8 +174,9 @@ def gated_delta_rule_decode_pretranspose(
         - State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16
           and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used
           (supports both the direct ``state`` path and the pool+indices path).
-        - pool+indices (``initial_state``/``initial_state_indices``) only supported
-          via the bf16 fast path; float32 state raises an error.
+        - pool+indices (``initial_state``/``initial_state_indices``) supported on
+          both the bf16 fast path (T in 1..4, K=V=128) and the float32 legacy path
+          (T=1). The float32 path also supports negative indices for padding.
         - Legacy path (float32 state, T=1): K and V must be multiples of 4.
     """
     # Validate input shapes
@@ -1069,13 +242,17 @@ def gated_delta_rule_decode_pretranspose(
         return_state = initial_state if use_pool else state
         return output, return_state

-    # Legacy path: T=1 only, float32 state (no pool+indices support)
-    assert not use_pool, (
--
 @flashinfer_api
 def gated_delta_rule_mtp(
     q: torch.Tensor,
@@ -2427,7 +489,7 @@ def gated_delta_rule_mtp(
     scale: Optional[float] = None,
     output: Optional[torch.Tensor] = None,
     intermediate_states_buffer: Optional[torch.Tensor] = None,
-    disable_state_update: bool = True,
+    disable_state_update: Optional[bool] = None,
     use_qk_l2norm: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
@@ -2463,8 +525,15 @@ def gated_delta_rule_mtp(
         intermediate_states_buffer (Optional[torch.Tensor]):
             Buffer for caching intermediate states, shape ``[pool_size, T, HV, V, K]``.
             If None, intermediate states are not cached.
-        disable_state_update (bool):
-            If True, the initial state is not updated. Default: ``True``.
+        disable_state_update (Optional[bool]):
+            If True, the initial state is not updated. Currently defaults to ``True``.
+            Please pass this argument explicitly — the default will change to ``False``
--
 @flashinfer_api
@@ -60,16 +120,14 @@ def rmsnorm(
     output: torch.Tensor
         Normalized tensor, 2D shape (batch_size, hidden_size) or 3D shape (batch_size, num_heads, hidden_size).
     """
-    if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
     if out is None:
         out = torch.empty_like(input)
-    _rmsnorm(out, input, weight, eps, enable_pdl)
+    _rmsnorm_impl(out, input, weight, eps, enable_pdl)
     return out


 @register_custom_op("flashinfer::rmsnorm", mutates_args=("out",))
-def _rmsnorm(
+def _rmsnorm_impl(
     out: torch.Tensor,
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -78,11 +136,21 @@ def _rmsnorm(
--
 @flashinfer_api
 def fmha_v2_prefill_deepseek(
     query: torch.Tensor,
@@ -3865,18 +4029,11 @@ def fmha_v2_prefill_deepseek(
         If return_lse is False, the output will be a single tensor.
     """
     if not is_sm12x_supported(query.device):
-        major, minor = get_compute_capability(query.device)
-        if major == 12:
-            min_cuda = "13.0" if minor >= 1 else "12.8"
-            raise ValueError(
-                f"fmha_v2_prefill_deepseek requires CUDA >= {min_cuda} "
-                f"for SM12{minor}x GPUs."
-            )
         raise ValueError("fmha_v2_prefill_deepseek is only supported on SM12x GPUs.")
     assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, (
         "currently only support deepseek r1 192 query and 128 value"
     )
-    module = get_trtllm_fmha_v2_module()
+    module = get_trtllm_fmha_v2_sm120_module()
     is_e4m3 = query.dtype == torch.float8_e4m3fn
--
+@flashinfer_api
+def trtllm_fmha_v2_prefill(
+    qkv: Union[
+        torch.Tensor,
+        Tuple[torch.Tensor, torch.Tensor],
+        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    ],
+    input_layout: str,
+    workspace_buffer: torch.Tensor,
+    seq_lens: torch.Tensor,
+    max_q_len: int,
+    max_kv_len: int,
+    bmm1_scale: float,
+    bmm2_scale: float,
+    batch_size: int,
+    cum_seq_lens_q: torch.Tensor,
+    cum_seq_lens_kv: torch.Tensor,
+    block_tables: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    out_dtype: Optional[Union[torch.dtype, str]] = None,
+    sinks: Optional[List[torch.Tensor]] = None,
--
+@flashinfer_api
+def fp4_quantize(
+    input: torch.Tensor,
+    global_scale: Optional[torch.Tensor] = None,
+    sf_vec_size: int = 16,
+    sf_use_ue8m0: bool = False,
+    is_sf_swizzled_layout: bool = True,
+    is_sf_8x4_layout: bool = False,
+    enable_pdl: Optional[bool] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize input tensor to FP4 format.
+
+    This function implements FP4 quantization that converts input tensors to a compressed FP4 format
+    with associated scale factors. It supports various input data types and scale factor layouts.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
+        global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
--
+@flashinfer_api
+def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
+    """Swizzle block scale tensor for FP4 format.
+
+    This function swizzles the block scale tensor to optimize memory access patterns
+    for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128.
+
+    Args:
+        unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16.
+
+    Returns:
+        torch.Tensor: Swizzled tensor with the same shape as input.
+
+    Raises:
+        AssertionError: If input dtype is not uint8 or bfloat16.
+    """
+    # TODO(shuw): check input dtype is uint8
+    assert (
+        unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16
+    ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}"
+
--
+@flashinfer_api
+def e2m1_and_ufp8sf_scale_to_float(
+    e2m1_tensor: torch.Tensor,
+    ufp8_scale_tensor: torch.Tensor,
+    global_scale_tensor: Optional[torch.Tensor] = None,
+    sf_vec_size: int = 16,
+    ufp8_type: int = 1,
+    is_sf_swizzled_layout: bool = True,
+) -> torch.Tensor:
+    """Convert E2M1 format tensor and UFP8 scale factors to float tensor.
+
+    This function performs dequantization by converting a packed FP4 tensor in E2M1 format
+    back to float values using the associated UFP8 scale factors and global scale.
+
+    Args:
+        e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
+        ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8.
+        global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
+        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
--
+@flashinfer_api
+def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
+    """
+    PyTorch equivalent of trtllm-gen `shuffleMatrixA`
+    """
+    row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
+
+    return input_tensor[row_indices.to(input_tensor.device)]
+
+
+@flashinfer_api
+def shuffle_matrix_sf_a(
+    input_tensor: torch.Tensor,
+    epilogue_tile_m: int,
+    num_elts_per_sf: int = 16,
+):
+    """
+    Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat.
+    `shuffleMatrixSfA` expects the input to be in 128x4 layout and then
+    apply the same shuffling in `shuffleMatrixA` and writes out in 128x4
+    layout.
+    This function expects the input to be in linear layout. It's done this
+    way because the scaling factors in the NVFP4 checkpoints are quantized
+    and are in linear layout.
+    This function doesn't add padding.
+    """
+
+    row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m)
+
+    w_shuffled = input_tensor[row_indices.to(input_tensor.device)]
+
--
+@flashinfer_api
+def nvfp4_quantize(
+    a,
+    a_global_sf,
+    sfLayout=SfLayout.layout_128x4,
+    do_shuffle=False,
+    sf_vec_size=16,
+    enable_pdl=None,
+):
+    """
+    Quantize input tensor to NVFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4.
+        do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability. Defaults to None.
+
--
+@flashinfer_api
+def mxfp4_quantize(
+    a: torch.Tensor,
+    backend: str = "cuda",
+    enable_pdl: Optional[bool] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
+        backend (str, optional): Backend to use for quantization.
+            - "cuda": Use CUDA kernel (default, stable)
+            - "cute-dsl": Use CuTe-DSL kernel (requires SM100+, **experimental**)
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic
+            Dependent Launch). Only used when backend="cute-dsl".
+            If None, automatically detects based on device capability.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
--
+@flashinfer_api
+def mxfp4_dequantize(a_fp4, a_sf):
+    """
+    Dequantize input tensor from MXFP4 format.
+
+    Parameters:
+        a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
+        a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
+    """
+    return e2m1_and_ufp8sf_scale_to_float(
+        a_fp4.cpu().view(torch.uint8),
+        a_sf.cpu().view(torch.uint8).reshape(-1),
+        torch.tensor([1.0], device=a_fp4.device),
+        32,
+        0,
+        True,
+    )
+
--
+@flashinfer_api
+def mxfp4_dequantize_host(
+    weight: torch.Tensor,
+    scale: torch.Tensor,
+    group_size: int = 32,
+) -> torch.Tensor:
+    """
+    Dequantize input tensor from MXFP4 format on host.
+
+    Parameters:
+        weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
+        scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
+        group_size (int, optional): Group size for dequantization. Defaults to 32.
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
+    """
+    # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future
+    major, minor = get_compute_capability(
+        torch.device("cuda:0")
+    )  # use any cuda device to get a compute capability
--
+@flashinfer_api
+def nvfp4_batched_quantize(
+    a,
+    a_global_sf,
+    sf_vec_size=16,
+):
+    """
+    Quantize batched input tensor to NVFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
+            - Scale factors tensor with shape determined by layout and sf_vec_size
+    """
+    major, minor = get_compute_capability(a.device)
+    device_arch = f"{major * 10 + minor}"
--
+@flashinfer_api
+def nvfp4_quantize_paged_kv_cache(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    kv_layout: str = "HND",
+    k_global_sf: Optional[torch.Tensor] = None,
+    v_global_sf: Optional[torch.Tensor] = None,
+) -> Tuple[
+    Tuple[torch.Tensor, torch.Tensor],
+    Tuple[torch.Tensor, torch.Tensor],
+    float,
+    float,
+]:
+    """Quantize paged KV cache to NVFP4 format for trtllm-gen MHA.
+
+    Quantizes BF16/FP16 K/V caches to NVFP4 with two-level scaling
+    (global FP32 + per-block FP8), and swizzles scale factors
+    for the SM100 trtllm-gen MHA kernel layout.
+
+    Args:
+        k_cache: Key cache tensor.
--
+@flashinfer_api
+def scaled_fp4_grouped_quantize(
+    a,
+    mask,
+    a_global_sf,
+):
+    """
+    quantize batched input tensor to NVFP4 format with mask.
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        mask (torch.Tensor): Mask tensor to apply before quantization.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
+            - Scale factors tensor with shape determined by layout and sf_vec_size
+    """
+    major, minor = get_compute_capability(a.device)
+    device_arch = f"{major * 10 + minor}"
+    a_fp4, a_sf = get_fp4_quantization_module(
+        device_arch
--
+@flashinfer_api
+def nvfp4_kv_dequantize(
+    fp4_data: torch.Tensor,
+    block_scales: torch.Tensor,
+    global_scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.bfloat16,
+) -> torch.Tensor:
+    """GPU dequantization of NVFP4 KV cache data with linear block scale layout.
+
+    Requires SM80+.
+
+    Args:
+        fp4_data (torch.Tensor): Packed FP4 data of shape ``[M, K/2]`` with dtype uint8.
+        block_scales (torch.Tensor): Per-block FP8 E4M3 scales of shape ``[M, K/16]``
+            with dtype uint8.
+        global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32,
+            on the same CUDA device as fp4_data.
+        output_dtype (torch.dtype): Output dtype, either ``torch.bfloat16`` or ``torch.float16``.
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape ``[M, K]`` with the specified output dtype.
--
+@flashinfer_api
+def nvfp4_kv_quantize(
+    input: torch.Tensor,
+    global_scale: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """GPU quantization to NVFP4 KV cache format with linear block scale layout.
+
+    Requires SM100+ (Blackwell) for the cvt.rn.satfinite.e2m1x2.f32 PTX instruction.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype bf16 or fp16.
+            K must be divisible by 16.
+        global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32,
+            on the same CUDA device as input.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]:
+            - fp4_output: Packed FP4 data of shape ``[M, K/2]`` with dtype uint8.
+            - block_scales: Per-block FP8 E4M3 scales of shape ``[M, K/16]`` with dtype uint8.
+    """
+    M, K = input.shape
--
+@flashinfer_api
+def mxfp8_quantize(
+    input: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    alignment: int = 32,
+    enable_pdl: Optional[bool] = None,
+    backend: Literal["cuda", "cute-dsl"] = "cuda",
+    sf_swizzle_layout: Optional[SfLayout] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize input tensor to MxFP8 format.
+
+    This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format
+    with associated scale factors. It supports various input data types and scale factor layouts.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
+        alignment (int, optional): sfVecSize. Defaults to 32.
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability (SM >= 9.0). Defaults to None.
+        backend (Literal["cuda", "cute-dsl"], optional): Backend to use for quantization. Options are:
--
+@flashinfer_api
+def mxfp8_dequantize_host(
+    input: torch.Tensor,
+    scale_tensor: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    sf_swizzle_layout: Optional[SfLayout] = None,
+) -> torch.Tensor:
+    """Dequantize input tensor from MxFP8 format.
+
+    This function performs dequantization by converting a packed FP8 tensor in MxFP8 format
+    back to float values using the associated scale factors.
+
+    Args:
+        input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3.
+        scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
+        sf_swizzle_layout (Optional[SfLayout], optional): Swizzle layout for scale factors.
+            If provided,it overrides is_sf_swizzled_layout. Defaults to None.
+            Available options are 1. SfLayout.layout_128x4; 2. SfLayout.layout_linear.
+
+    Returns:
--
+@flashinfer_api
+def mxfp4_quantize_cute_dsl(
+    input: torch.Tensor,
+    enable_pdl: bool | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP4 format using CuTe-DSL kernel.
+
+    This is a GPU implementation matching FlashInfer's mxfp4_quantize() behavior:
+    - Global scale computed as (448 * 6) / max(|input|)
+    - UE8M0 scale factors
+    - E2M1 output format (4-bit, 2 values per byte)
+    - Swizzled (128x4) scale factor layout
+
+    The kernel is compiled once per (K, dtype, pdl) combination and handles
+    varying M (batch size) at runtime without recompilation.
+
+    Args:
+        input: Input tensor of shape [M, K] with dtype fp16/bf16
+        enable_pdl: Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability (SM >= 9.0).
--
+@flashinfer_api
+def mxfp8_quantize_cute_dsl(
+    input: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    alignment: int = 32,
+    enable_pdl: bool | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP8 format using CuTe-DSL kernel.
+
+    This is a GPU implementation with dual-path optimization:
+    - LINEAR layout: SF-block based iteration (fast)
+    - SWIZZLED layout: Row-based iteration with padding fast path (optimized)
+
+    The kernel is compiled once per (K, dtype, pdl) combination and handles
+    varying M (batch size) at runtime without recompilation.
+
+    Args:
+        input: Input tensor of shape [M, K] with dtype fp16/bf16
+        is_sf_swizzled_layout: Whether to use 128x4 swizzled layout (True) or linear (False)
+        alignment: Alignment for K dimension (default 32, must be multiple of SF_VEC_SIZE)

Summary by CodeRabbit

  • Enhancements
    • Normalization now accepts scale as either a float or tensor; passing a float emits a deprecation warning and is auto-converted for compatibility.
    • Attention/decoding API: cache-scale parameters are now optional keyword-only arguments with sensible defaults, simplifying common call patterns.
  • Tests
    • Tests updated to match the adjusted attention/decoding call signature.
  • Chores
    • Release version bumped to 0.6.7.

@aleozlx added the v0.6.7 label (release blocker label for 0.6.7) on Mar 20, 2026
@gemini-code-assist (Contributor)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses API breaking changes for the upcoming 0.6.7 release by refining function signatures across several modules. It primarily focuses on making certain cache-related parameters keyword-only to improve API stability and clarity. Additionally, it enhances the flexibility of quantization scale handling in normalization functions by temporarily allowing float inputs while guiding users towards a more robust torch.Tensor approach with a deprecation warning.

Highlights

  • API Signature Changes: The xqa_batch_decode_with_kv_cache function in flashinfer/decode.py and the xqa function in flashinfer/xqa.py had their k_cache_sf and v_cache_sf (or k_sf_cache and v_sf_cache) parameters moved to keyword-only arguments to prevent breaking changes.
  • Quantization Scale Handling: The _normalize_scale_tensor, rmsnorm_quant, and fused_add_rmsnorm_quant functions in flashinfer/norm/__init__.py were updated to accept scale as either a float or torch.Tensor, with a deprecation warning issued when a float is provided to encourage future use of torch.Tensor.


@coderabbitai bot (Contributor) commented Mar 20, 2026

📝 Walkthrough

The PR makes KV-cache scale tensors keyword-only in the xqa API and updates callers; relaxes normalization APIs to accept scale as float or torch.Tensor (with a deprecation warning and conversion); and bumps the package version.

Changes

  • XQA API & Callers (flashinfer/xqa.py, flashinfer/decode.py, tests/attention/test_xqa.py):
    xqa signature changed: k_sf_cache and v_sf_cache removed from positional args and added as keyword-only optional params (*, k_sf_cache=None, v_sf_cache=None). Call sites updated to pass k_sf_cache=/v_sf_cache= (tests adjusted).
  • Normalization API Flexibility (flashinfer/norm/__init__.py):
    _normalize_scale_tensor now accepts scale: Union[float, torch.Tensor]; a non-tensor scale emits a FutureWarning and is converted to a torch.tensor(...). Public annotations for rmsnorm_quant and fused_add_rmsnorm_quant updated accordingly.
  • Misc (version.txt):
    Package version bumped from 0.6.6 to 0.6.7.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested labels

run-ci

Suggested reviewers

  • sricketts
  • yzh119
  • nv-yunzheq
  • cyx-6
  • yyihuang


🚥 Pre-merge checks: ✅ 3 passed

  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient (required threshold: 80.00%).
  • Title check: ✅ Passed. The title accurately captures the main purpose of the PR: bumping the version to 0.6.7 and fixing API breaking changes, both of which are clearly reflected in the changeset.
  • Description check: ✅ Passed. The PR description provides a clear overview of the changes and includes proper documentation of API changes, but uses reviewer notes instead of the template's description section.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces several API changes to improve backward compatibility and future-proofing. The changes in flashinfer/decode.py and flashinfer/xqa.py correctly adapt to making k_sf_cache and v_sf_cache keyword-only arguments, which is a good API design practice. The modifications in flashinfer/norm/__init__.py add backward compatibility for the scale parameter by allowing a float value, while issuing a helpful FutureWarning. The implementation is sound. I have one minor suggestion regarding code style in flashinfer/norm/__init__.py to improve adherence to PEP 8.

@coderabbitai bot (Contributor) left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/norm/__init__.py (1)

186-187: ⚠️ Potential issue | 🟡 Minor

Update scale parameter docs to match the new API contract.

The docstrings still say scale: torch.Tensor, but the function now accepts float (deprecated) as well. This mismatch will confuse users.

Proposed doc update
-    scale: torch.Tensor
-        Scale factor for quantization, shape (1,).
+    scale: Union[float, torch.Tensor]
+        Quantization scale. `torch.Tensor` of shape (1,) is preferred.
+        Passing `float` is deprecated and kept temporarily for compatibility.

Also applies to: 301-302

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/norm/__init__.py` around lines 186 - 187, Update the docstring for
the parameter named "scale" to reflect the new API contract: change the type
from "torch.Tensor" to indicate it accepts either a torch.Tensor or a float
(with a note that float usage is deprecated), and clarify expected
shape/semantics (e.g., torch.Tensor shape (1,) or scalar float). Locate the two
occurrences in this module where "scale: torch.Tensor" is documented (the block
around the earlier occurrence and the second occurrence near the later
docstring) and update both to "scale: Union[torch.Tensor, float] — Scale factor
for quantization; preferred as torch.Tensor of shape (1,), float is accepted but
deprecated."
🧹 Nitpick comments (1)
flashinfer/norm/__init__.py (1)

65-77: Tighten non-tensor input validation for scale.

Current logic warns for any non-tensor value, but only float input is intended here. Add an explicit type gate so invalid inputs fail with a clear TypeError instead of implicit tensor-construction errors.

Proposed patch
 def _normalize_scale_tensor(
     scale: Union[float, torch.Tensor], ref_tensor: torch.Tensor
 ) -> torch.Tensor:
     """Normalize quantization scale to 1D tensor of shape (1,) on target device."""
-    if not isinstance(scale, torch.Tensor):
+    if not isinstance(scale, torch.Tensor):
+        if not isinstance(scale, float):
+            raise TypeError(
+                f"scale must be float or torch.Tensor, got {type(scale).__name__}"
+            )
         import warnings
 
         warnings.warn(
             "Passing scale as a float is deprecated and will be removed in a future "
             "release. Use a torch.Tensor of shape (1,) instead.",
             FutureWarning,
             stacklevel=3,
         )
         scale = torch.tensor([scale], dtype=torch.float32, device=ref_tensor.device)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/norm/__init__.py` around lines 65 - 77, The current normalization
helper accepts any non-torch.Tensor and attempts to convert it, which masks
invalid types; update the input validation so that if scale is a torch.Tensor
proceed as before, if it's a float emit the existing FutureWarning and convert
via torch.tensor([scale], dtype=torch.float32, device=ref_tensor.device), but if
scale is neither float nor torch.Tensor raise a TypeError with a clear message;
adjust the branch around scale/ref_tensor and the FutureWarning usage to enforce
this type gate and avoid implicit tensor-construction errors.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 87962e9b-6d94-4598-9a79-016ec25b4bbd

📥 Commits

Reviewing files that changed from the base of the PR and between 6f0928c and e35c19e.

📒 Files selected for processing (3)
  • flashinfer/decode.py
  • flashinfer/norm/__init__.py
  • flashinfer/xqa.py

@aleozlx (Collaborator, Author) commented Mar 20, 2026

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !440 has been created, and the CI pipeline #46621950 is currently running. I'll report back once the pipeline job completes.

@jimmyzho (Contributor) left a comment

lgtm for decode, just left a question for clarity

@bkryu (Collaborator) left a comment

Approving norm changes. Thanks @aleozlx

@flashinfer-bot (Collaborator)

[FAILED] Pipeline #46621950: 6/20 passed

@aleozlx (Collaborator, Author) commented Mar 20, 2026

ugh internal CI has caught errors on xqa.. i'll fix them later today

@aleozlx (Collaborator, Author) commented Mar 23, 2026

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !440 has been updated with latest changes, and the CI pipeline #46807950 is currently running. I'll report back once the pipeline job completes.

@aleozlx (Collaborator, Author) commented Mar 23, 2026

api changes are clean now

@aleozlx changed the title from "fix api breaking changes" to "bump version to 0.6.7 & fix api breaking changes" on Mar 23, 2026
@aleozlx (Collaborator, Author) commented Mar 23, 2026

waiting on pipeline run #46807950

@aleozlx (Collaborator, Author) commented Mar 23, 2026

The B200/B300 issues on CUDA 12.9 are present on the main branch as well.

"""Normalize quantization scale to 1D tensor of shape (1,) on target device."""
if not isinstance(scale, torch.Tensor):
raise TypeError(f"scale must be torch.Tensor, got {type(scale)}")
warnings.warn(
Member commented:

Why does this interface have to be stable version over version? Naively I would think that _-prefixed functions are not part of the public interface and therefore have no external stability guarantees.

@aleozlx (Collaborator, Author) replied:

rmsnorm_quant and fused_add_rmsnorm_quant were the breaking changes (the type changed). The compatibility fix landed in this helper function, so both the old and the new signatures are supported.

@aleozlx enabled auto-merge (squash) on March 23, 2026, 23:28
@aleozlx merged commit 1de1b97 into flashinfer-ai:main on Mar 24, 2026 (32 of 37 checks passed)
@aleozlx mentioned this pull request on Mar 24, 2026 (5 tasks)

Labels

v0.6.7 (release blocker label for 0.6.7)