bump version to 0.6.7 & fix api breaking changes#2832
bump version to 0.6.7 & fix api breaking changes#2832aleozlx merged 5 commits intoflashinfer-ai:mainfrom
Conversation
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses API breaking changes for the upcoming 0.6.7 release by refining function signatures across several modules. It primarily focuses on making certain cache-related parameters keyword-only to improve API stability and clarity. Additionally, it enhances the flexibility of quantization scale handling in normalization functions by temporarily allowing float inputs while guiding users towards a more robust Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. Footnotes
|
📝 WalkthroughWalkthroughThe PR makes KV-cache scale tensors keyword-only in the Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 3✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces several API changes to improve backward compatibility and future-proofing. The changes in flashinfer/decode.py and flashinfer/xqa.py correctly adapt to making k_sf_cache and v_sf_cache keyword-only arguments, which is a good API design practice. The modifications in flashinfer/norm/__init__.py add backward compatibility for the scale parameter by allowing a float value, while issuing a helpful FutureWarning. The implementation is sound. I have one minor suggestion regarding code style in flashinfer/norm/__init__.py to improve adherence to PEP 8.
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/norm/__init__.py (1)
186-187:⚠️ Potential issue | 🟡 MinorUpdate
scaleparameter docs to match the new API contract.The docstrings still say
scale: torch.Tensor, but the function now accepts float (deprecated) as well. This mismatch will confuse users.Proposed doc update
- scale: torch.Tensor - Scale factor for quantization, shape (1,). + scale: Union[float, torch.Tensor] + Quantization scale. `torch.Tensor` of shape (1,) is preferred. + Passing `float` is deprecated and kept temporarily for compatibility.Also applies to: 301-302
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/norm/__init__.py` around lines 186 - 187, Update the docstring for the parameter named "scale" to reflect the new API contract: change the type from "torch.Tensor" to indicate it accepts either a torch.Tensor or a float (with a note that float usage is deprecated), and clarify expected shape/semantics (e.g., torch.Tensor shape (1,) or scalar float). Locate the two occurrences in this module where "scale: torch.Tensor" is documented (the block around the earlier occurrence and the second occurrence near the later docstring) and update both to "scale: Union[torch.Tensor, float] — Scale factor for quantization; preferred as torch.Tensor of shape (1,), float is accepted but deprecated."
🧹 Nitpick comments (1)
flashinfer/norm/__init__.py (1)
65-77: Tighten non-tensor input validation forscale.Current logic warns for any non-tensor value, but only float input is intended here. Add an explicit type gate so invalid inputs fail with a clear
TypeErrorinstead of implicit tensor-construction errors.Proposed patch
def _normalize_scale_tensor( scale: Union[float, torch.Tensor], ref_tensor: torch.Tensor ) -> torch.Tensor: """Normalize quantization scale to 1D tensor of shape (1,) on target device.""" - if not isinstance(scale, torch.Tensor): + if not isinstance(scale, torch.Tensor): + if not isinstance(scale, float): + raise TypeError( + f"scale must be float or torch.Tensor, got {type(scale).__name__}" + ) import warnings warnings.warn( "Passing scale as a float is deprecated and will be removed in a future " "release. Use a torch.Tensor of shape (1,) instead.", FutureWarning, stacklevel=3, ) scale = torch.tensor([scale], dtype=torch.float32, device=ref_tensor.device)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/norm/__init__.py` around lines 65 - 77, The current normalization helper accepts any non-torch.Tensor and attempts to convert it, which masks invalid types; update the input validation so that if scale is a torch.Tensor proceed as before, if it's a float emit the existing FutureWarning and convert via torch.tensor([scale], dtype=torch.float32, device=ref_tensor.device), but if scale is neither float nor torch.Tensor raise a TypeError with a clear message; adjust the branch around scale/ref_tensor and the FutureWarning usage to enforce this type gate and avoid implicit tensor-construction errors.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@flashinfer/norm/__init__.py`:
- Around line 186-187: Update the docstring for the parameter named "scale" to
reflect the new API contract: change the type from "torch.Tensor" to indicate it
accepts either a torch.Tensor or a float (with a note that float usage is
deprecated), and clarify expected shape/semantics (e.g., torch.Tensor shape (1,)
or scalar float). Locate the two occurrences in this module where "scale:
torch.Tensor" is documented (the block around the earlier occurrence and the
second occurrence near the later docstring) and update both to "scale:
Union[torch.Tensor, float] — Scale factor for quantization; preferred as
torch.Tensor of shape (1,), float is accepted but deprecated."
---
Nitpick comments:
In `@flashinfer/norm/__init__.py`:
- Around line 65-77: The current normalization helper accepts any
non-torch.Tensor and attempts to convert it, which masks invalid types; update
the input validation so that if scale is a torch.Tensor proceed as before, if
it's a float emit the existing FutureWarning and convert via
torch.tensor([scale], dtype=torch.float32, device=ref_tensor.device), but if
scale is neither float nor torch.Tensor raise a TypeError with a clear message;
adjust the branch around scale/ref_tensor and the FutureWarning usage to enforce
this type gate and avoid implicit tensor-construction errors.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 87962e9b-6d94-4598-9a79-016ec25b4bbd
📒 Files selected for processing (3)
flashinfer/decode.pyflashinfer/norm/__init__.pyflashinfer/xqa.py
|
/bot run |
jimmyzho
left a comment
There was a problem hiding this comment.
lgtm for decode, just left question for clarity
|
[FAILED] Pipeline #46621950: 6/20 passed |
|
ugh internal CI has caught errors on xqa.. i'll fix them later today |
|
/bot run |
|
api changes are clean now |
|
waiting on pipeline run #46807950 |
|
B200/300 issues on cuda 12.9 is on main branch |
| """Normalize quantization scale to 1D tensor of shape (1,) on target device.""" | ||
| if not isinstance(scale, torch.Tensor): | ||
| raise TypeError(f"scale must be torch.Tensor, got {type(scale)}") | ||
| warnings.warn( |
There was a problem hiding this comment.
Why does this interface have to be stable version over version? Naively I would think that _-prefixed functions are not part of the public interface and therefore have no external stability guarantees.
There was a problem hiding this comment.
rmsnorm_quant and fused_add_rmsnorm_quant were the breaking changes (type changed). the compatibility fix landed in this helper function so both old and new signature are supported
📌 Description
fix api breaking changes for 0.6.7 release
🔍 Related Issues (Gated-by PRs)
https://github.com/flashinfer-ai/flashinfer/issues?q=state%3Aopen%20label%3Av0.6.7
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
API changes review
API changes since v0.6.6
PR #2520 + commit e35c19e (fixed to be compatible)
Function: xqa()
Change: Added k_sf_cache=None, v_sf_cache=None as keyword-only params (after *). Backward-compatible.
PR #2618 (has PR #2730 to fix it)
Function: gated_delta_rule_mtp()
Change: disable_state_update: bool = True → Optional[bool] = None. Still defaults to True at runtime but emits a deprecation
warning; will flip to False in 0.7.0.
PR #2775 (expected — cute DSL MoE cleanup)
Function: blockscaled_contiguous_grouped_gemm_nvfp4()
Change: Entire @flashinfer_api decorated function deleted.
Function: blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4()
Change: Entire @flashinfer_api decorated function deleted.
Function: blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4()
Change: @flashinfer_api decorator removed; added enable_pdl: bool = True param.
Function: blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4()
Change: @flashinfer_api decorator removed; added enable_pdl: bool = True param.
Function: CuteDslMoEWrapper.init()
Change: Added enable_pdl: bool = True param. Backward-compatible.
Function: cute_dsl_fused_moe_nvfp4()
Change: Added enable_pdl: bool = True param. Backward-compatible.
PR #2428
Function: rmsnorm_quant()
Change: scale: float → scale: Union[float, torch.Tensor]; return type torch.Tensor → None.
Function: fused_add_rmsnorm_quant()
Change: scale: float → scale: Union[float, torch.Tensor].
Quantization functions (relocated, not removed)
All quantization APIs (fp4_quantize, block_scale_interleave, e2m1_and_ufp8sf_scale_to_float, shuffle_matrix_a, shuffle_matrix_sf_a,
nvfp4_quantize, nvfp4_batched_quantize, scaled_fp4_grouped_quantize, mxfp4_quantize, mxfp4_dequantize, mxfp4_dequantize_host,
mxfp8_quantize, mxfp8_dequantize_host) were moved from flashinfer/fp4_quantization.py and flashinfer/fp8_quantization.py to
flashinfer/quantization/. Signatures, @flashinfer_api decorators, and init.py exports are preserved. No breakage.
$ git diff v0.6.6 | grep -A20 "@flashinfer_api" @flashinfer_api @@ -1215,6 +1227,9 @@ class BatchDecodeWithPagedKVCacheWrapper: sinks: Optional[torch.Tensor] = None, q_len_per_req: Optional[int] = 1, skip_softmax_threshold_scale_factor: Optional[float] = None, + kv_block_scales: Optional[ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute batch decode attention between query and paged kv cache. @@ -1273,6 +1288,15 @@ class BatchDecodeWithPagedKVCacheWrapper: enable_pdl = device_support_pdl(q.device) k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout) + # Unpack kv_block_scales + key_block_scales = None + value_block_scales = None + if kv_block_scales is not None: + if isinstance(kv_block_scales, tuple): + key_block_scales, value_block_scales = kv_block_scales -- -@flashinfer_api -def fp4_quantize( - input: torch.Tensor, - global_scale: Optional[torch.Tensor] = None, - sf_vec_size: int = 16, - sf_use_ue8m0: bool = False, - is_sf_swizzled_layout: bool = True, - is_sf_8x4_layout: bool = False, - enable_pdl: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize input tensor to FP4 format. - - This function implements FP4 quantization that converts input tensors to a compressed FP4 format - with associated scale factors. It supports various input data types and scale factor layouts. - - Args: - input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. - global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False. - is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. -- -@flashinfer_api -def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor: - """Swizzle block scale tensor for FP4 format. - - This function swizzles the block scale tensor to optimize memory access patterns - for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128. - - Args: - unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16. - - Returns: - torch.Tensor: Swizzled tensor with the same shape as input. - - Raises: - AssertionError: If input dtype is not uint8 or bfloat16. - """ - # TODO(shuw): check input dtype is uint8 - assert ( - unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16 - ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}" - -- -@flashinfer_api -def e2m1_and_ufp8sf_scale_to_float( - e2m1_tensor: torch.Tensor, - ufp8_scale_tensor: torch.Tensor, - global_scale_tensor: Optional[torch.Tensor] = None, - sf_vec_size: int = 16, - ufp8_type: int = 1, - is_sf_swizzled_layout: bool = True, -) -> torch.Tensor: - """Convert E2M1 format tensor and UFP8 scale factors to float tensor. - - This function performs dequantization by converting a packed FP4 tensor in E2M1 format - back to float values using the associated UFP8 scale factors and global scale. - - Args: - e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8. - ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8. - global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1. - is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. -- -@flashinfer_api -def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor: - """ - PyTorch equivalent of trtllm-gen `shuffleMatrixA` - """ - row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) - - return input_tensor[row_indices.to(input_tensor.device)] - - -@flashinfer_api -def shuffle_matrix_sf_a( - input_tensor: torch.Tensor, - epilogue_tile_m: int, - num_elts_per_sf: int = 16, -): - """ - Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat. - `shuffleMatrixSfA` expects the input to be in 128x4 layout and then - apply the same shuffling in `shuffleMatrixA` and writes out in 128x4 - layout. - This function expects the input to be in linear layout. It's done this - way because the scaling factors in the NVFP4 checkpoints are quantized - and are in linear layout. - This function doesn't add padding. - """ - - row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m) - - w_shuffled = input_tensor[row_indices.to(input_tensor.device)] - -- -@flashinfer_api -def nvfp4_quantize( - a, - a_global_sf, - sfLayout=SfLayout.layout_128x4, - do_shuffle=False, - sf_vec_size=16, - enable_pdl=None, -): - """ - Quantize input tensor to NVFP4 format. - - Parameters: - a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. - a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. - sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4. - do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). - If None, automatically detects based on device capability. Defaults to None. - -- -@flashinfer_api -def mxfp4_quantize(a): - """ - Quantize input tensor to MXFP4 format. - - Parameters: - a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) - - Scale factors tensor with shape determined by layout and sf_vec_size (uint8) - """ - a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max() - a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True) - return a_fp4, a_sf - - -@flashinfer_api -def mxfp4_dequantize(a_fp4, a_sf): - """ - Dequantize input tensor from MXFP4 format. - - Parameters: - a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) - a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) - - Returns: - torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. - """ - return e2m1_and_ufp8sf_scale_to_float( - a_fp4.cpu().view(torch.uint8), - a_sf.cpu().view(torch.uint8).reshape(-1), - torch.tensor([1.0], device=a_fp4.device), - 32, - 0, - True, - ) - -- -@flashinfer_api -def mxfp4_dequantize_host( - weight: torch.Tensor, - scale: torch.Tensor, - group_size: int = 32, -) -> torch.Tensor: - """ - Dequantize input tensor from MXFP4 format on host. - - Parameters: - weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) - scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) - group_size (int, optional): Group size for dequantization. Defaults to 32. - - Returns: - torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. - """ - # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future - major, minor = get_compute_capability( - torch.device("cuda:0") - ) # use any cuda device to get a compute capability -- -@flashinfer_api -def nvfp4_batched_quantize( - a, - a_global_sf, - sf_vec_size=16, -): - """ - Quantize batched input tensor to NVFP4 format. - - Parameters: - a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. - a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 - - Scale factors tensor with shape determined by layout and sf_vec_size - """ - major, minor = get_compute_capability(a.device) - device_arch = f"{major * 10 + minor}" -- -@flashinfer_api -def scaled_fp4_grouped_quantize( - a, - mask, - a_global_sf, -): - """ - quantize batched input tensor to NVFP4 format with mask. - Parameters: - a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. - a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. - mask (torch.Tensor): Mask tensor to apply before quantization. - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 - - Scale factors tensor with shape determined by layout and sf_vec_size - """ - major, minor = get_compute_capability(a.device) - device_arch = f"{major * 10 + minor}" - a_fp4, a_sf = get_fp4_quantization_module( - device_arch -- -@flashinfer_api -def mxfp8_quantize( - input: torch.Tensor, - is_sf_swizzled_layout: bool = True, - alignment: int = 32, - enable_pdl: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize input tensor to MxFP8 format. - - This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format - with associated scale factors. It supports various input data types and scale factor layouts. - - Args: - input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. - is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. - alignment (int, optional): sfVecSize. Defaults to 32. - enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). - If None, automatically detects based on device capability. Defaults to None. - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3 -- -@flashinfer_api -def mxfp8_dequantize_host( - input: torch.Tensor, - scale_tensor: torch.Tensor, - is_sf_swizzled_layout: bool = True, -) -> torch.Tensor: - """Dequantize input tensor from MxFP8 format. - - This function performs dequantization by converting a packed FP8 tensor in MxFP8 format - back to float values using the associated scale factors. - - Args: - input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3. - scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size. - is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. - - Returns: - torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. - - """ - -- -@flashinfer_api def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( a: torch.Tensor, b: torch.Tensor, @@ -323,6 +324,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( vectorized_f32: bool = True, raster_along_m: bool = False, sm_count: Optional[int] = None, + enable_pdl: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Blockscaled Contiguous Gather Grouped GEMM with SwiGLU Fusion for MoE workloads. @@ -423,7 +425,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( major, minor = get_compute_capability(a.device) if major != 10: raise ValueError( - f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103, SM110). " + f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103). " f"Got SM{major}{minor}." ) -- -@flashinfer_api -def blockscaled_contiguous_grouped_gemm_nvfp4( - a: torch.Tensor, - b: torch.Tensor, - a_scale: torch.Tensor, - b_scale: torch.Tensor, - alpha: torch.Tensor, - tile_idx_to_group_idx: torch.Tensor, - num_non_exiting_tiles: torch.Tensor, - out: Optional[torch.Tensor] = None, - *, - ab_dtype: str = "float4_e2m1fn", - sf_dtype: str = "float8_e4m3fn", - c_dtype: str = "bfloat16", - sf_vec_size: int = 16, - mma_tiler_mn: Tuple[int, int] = (128, 128), - cluster_shape_mn: Tuple[int, int] = (1, 1), - sm_count: Optional[int] = None, -) -> torch.Tensor: - """Blockscaled Contiguous Grouped GEMM for MoE workloads with NVFP4 quantization. - -- -@flashinfer_api def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( a: torch.Tensor, b: torch.Tensor, @@ -272,6 +279,7 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( cluster_shape_mn: Tuple[int, int] = (2, 1), raster_along_m: bool = False, sm_count: Optional[int] = None, + enable_pdl: bool = True, ) -> torch.Tensor: """Blockscaled Contiguous Grouped GEMM with Finalize Fusion for MoE workloads. @@ -298,7 +306,11 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( expanded_idx = token_idx * topk + topk_idx. Invalid rows have -1. token_final_scales: Router scaling factors, shape (seq_len, topk), float32/bf16/fp16 out: Optional output tensor, shape (seq_len, n). Created if None. - This tensor is used for atomic accumulation, so it should be zero-initialized. + This tensor is used for atomic accumulation. If `out` is + provided, it must already be zero-initialized by the caller. + If `out` is None, this function allocates a zero-initialized + output tensor. Passing a non-zeroed `out` buffer will silently -- -@flashinfer_api -def blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4( - a: torch.Tensor, - b: torch.Tensor, - a_scale: torch.Tensor, - b_scale: torch.Tensor, - alpha: torch.Tensor, - tile_idx_to_group_idx: torch.Tensor, - num_non_exiting_tiles: torch.Tensor, - out: Optional[torch.Tensor] = None, - out_scale: Optional[torch.Tensor] = None, - global_scale: Optional[torch.Tensor] = None, - *, - ab_dtype: str = "float4_e2m1fn", - sf_dtype: str = "float8_e4m3fn", - c_dtype: str = "bfloat16", - sf_vec_size: int = 16, - mma_tiler_mn: Tuple[int, int] = (256, 128), - cluster_shape_mn: Tuple[int, int] = (2, 1), - vectorized_f32: bool = True, - sm_count: Optional[int] = None, -- @flashinfer_api def __init__( self, @@ -347,6 +355,7 @@ class CuteDslMoEWrapper: sf_vec_size: int = 16, output_dtype: torch.dtype = torch.bfloat16, device: str = "cuda", + enable_pdl: bool = True, ): """Initialize the MoE wrapper. @@ -363,6 +372,7 @@ class CuteDslMoEWrapper: sf_vec_size: Scale factor vector size. Default: 16. output_dtype: Output data type. Default: torch.bfloat16. device: Device for buffer allocation. Default: "cuda". + enable_pdl: Enable Programmatic Dependent Launch. Default: True. """ self.num_experts = num_experts self.top_k = top_k @@ -376,6 +386,7 @@ class CuteDslMoEWrapper: self.sf_vec_size = sf_vec_size -- @flashinfer_api @@ -550,9 +570,10 @@ class CuteDslMoEWrapper: f"num_tokens ({num_tokens}) exceeds max_num_tokens ({self.max_num_tokens})" ) - # Allocate output buffer if not using pre-allocated one + # Slice the pre-allocated buffer to the active batch so that + # _moe_core_impl only zeros num_tokens rows, not max_num_tokens. if self.use_cuda_graph: - moe_output = self._moe_output + moe_output = self._moe_output[:num_tokens] else: moe_output = torch.empty( (num_tokens, self.hidden_size), @@ -627,6 +648,7 @@ def _cute_dsl_fused_moe_nvfp4_impl( use_fused_finalize: bool = True, moe_output: Optional[torch.Tensor] = None, aux_stream: Optional[torch.cuda.Stream] = None, + enable_pdl: bool = True, ) -> torch.Tensor: """Internal implementation called by auto-tuner for functional API.""" -- @flashinfer_api def cute_dsl_fused_moe_nvfp4( x: torch.Tensor, @@ -678,9 +702,12 @@ def cute_dsl_fused_moe_nvfp4( use_fused_finalize: bool = True, moe_output: Optional[torch.Tensor] = None, aux_stream: Optional[torch.cuda.Stream] = None, + enable_pdl: bool = True, ) -> torch.Tensor: """Run fused MoE computation using CuteDSL NVFP4 kernels. + Supported architectures: SM100, SM103. + This is the simple functional API. For CUDA graph support, use `CuteDslMoEWrapper` instead. @@ -736,6 +763,7 @@ def cute_dsl_fused_moe_nvfp4( local_expert_offset=local_expert_offset, use_fused_finalize=use_fused_finalize, output_dtype=output_dtype, + enable_pdl=enable_pdl, -- @flashinfer_api def gated_delta_rule_decode_pretranspose( q: torch.Tensor, @@ -1002,8 +174,9 @@ def gated_delta_rule_decode_pretranspose( - State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16 and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used (supports both the direct ``state`` path and the pool+indices path). - - pool+indices (``initial_state``/``initial_state_indices``) only supported - via the bf16 fast path; float32 state raises an error. + - pool+indices (``initial_state``/``initial_state_indices``) supported on + both the bf16 fast path (T in 1..4, K=V=128) and the float32 legacy path + (T=1). The float32 path also supports negative indices for padding. - Legacy path (float32 state, T=1): K and V must be multiples of 4. """ # Validate input shapes @@ -1069,13 +242,17 @@ def gated_delta_rule_decode_pretranspose( return_state = initial_state if use_pool else state return output, return_state - # Legacy path: T=1 only, float32 state (no pool+indices support) - assert not use_pool, ( -- @flashinfer_api def gated_delta_rule_mtp( q: torch.Tensor, @@ -2427,7 +489,7 @@ def gated_delta_rule_mtp( scale: Optional[float] = None, output: Optional[torch.Tensor] = None, intermediate_states_buffer: Optional[torch.Tensor] = None, - disable_state_update: bool = True, + disable_state_update: Optional[bool] = None, use_qk_l2norm: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -2463,8 +525,15 @@ def gated_delta_rule_mtp( intermediate_states_buffer (Optional[torch.Tensor]): Buffer for caching intermediate states, shape ``[pool_size, T, HV, V, K]``. If None, intermediate states are not cached. - disable_state_update (bool): - If True, the initial state is not updated. Default: ``True``. + disable_state_update (Optional[bool]): + If True, the initial state is not updated. Currently defaults to ``True``. + Please pass this argument explicitly — the default will change to ``False`` -- @flashinfer_api @@ -60,16 +120,14 @@ def rmsnorm( output: torch.Tensor Normalized tensor, 2D shape (batch_size, hidden_size) or 3D shape (batch_size, num_heads, hidden_size). """ - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) if out is None: out = torch.empty_like(input) - _rmsnorm(out, input, weight, eps, enable_pdl) + _rmsnorm_impl(out, input, weight, eps, enable_pdl) return out @register_custom_op("flashinfer::rmsnorm", mutates_args=("out",)) -def _rmsnorm( +def _rmsnorm_impl( out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, @@ -78,11 +136,21 @@ def _rmsnorm( -- @flashinfer_api def fmha_v2_prefill_deepseek( query: torch.Tensor, @@ -3865,18 +4029,11 @@ def fmha_v2_prefill_deepseek( If return_lse is False, the output will be a single tensor. """ if not is_sm12x_supported(query.device): - major, minor = get_compute_capability(query.device) - if major == 12: - min_cuda = "13.0" if minor >= 1 else "12.8" - raise ValueError( - f"fmha_v2_prefill_deepseek requires CUDA >= {min_cuda} " - f"for SM12{minor}x GPUs." - ) raise ValueError("fmha_v2_prefill_deepseek is only supported on SM12x GPUs.") assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, ( "currently only support deepseek r1 192 query and 128 value" ) - module = get_trtllm_fmha_v2_module() + module = get_trtllm_fmha_v2_sm120_module() is_e4m3 = query.dtype == torch.float8_e4m3fn -- +@flashinfer_api +def trtllm_fmha_v2_prefill( + qkv: Union[ + torch.Tensor, + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + ], + input_layout: str, + workspace_buffer: torch.Tensor, + seq_lens: torch.Tensor, + max_q_len: int, + max_kv_len: int, + bmm1_scale: float, + bmm2_scale: float, + batch_size: int, + cum_seq_lens_q: torch.Tensor, + cum_seq_lens_kv: torch.Tensor, + block_tables: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[Union[torch.dtype, str]] = None, + sinks: Optional[List[torch.Tensor]] = None, -- +@flashinfer_api +def fp4_quantize( + input: torch.Tensor, + global_scale: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + sf_use_ue8m0: bool = False, + is_sf_swizzled_layout: bool = True, + is_sf_8x4_layout: bool = False, + enable_pdl: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize input tensor to FP4 format. + + This function implements FP4 quantization that converts input tensors to a compressed FP4 format + with associated scale factors. It supports various input data types and scale factor layouts. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. + global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. -- +@flashinfer_api +def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor: + """Swizzle block scale tensor for FP4 format. + + This function swizzles the block scale tensor to optimize memory access patterns + for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128. + + Args: + unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16. + + Returns: + torch.Tensor: Swizzled tensor with the same shape as input. + + Raises: + AssertionError: If input dtype is not uint8 or bfloat16. + """ + # TODO(shuw): check input dtype is uint8 + assert ( + unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16 + ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}" + -- +@flashinfer_api +def e2m1_and_ufp8sf_scale_to_float( + e2m1_tensor: torch.Tensor, + ufp8_scale_tensor: torch.Tensor, + global_scale_tensor: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + ufp8_type: int = 1, + is_sf_swizzled_layout: bool = True, +) -> torch.Tensor: + """Convert E2M1 format tensor and UFP8 scale factors to float tensor. + + This function performs dequantization by converting a packed FP4 tensor in E2M1 format + back to float values using the associated UFP8 scale factors and global scale. + + Args: + e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8. + ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8. + global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1. + is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. -- +@flashinfer_api +def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor: + """ + PyTorch equivalent of trtllm-gen `shuffleMatrixA` + """ + row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) + + return input_tensor[row_indices.to(input_tensor.device)] + + +@flashinfer_api +def shuffle_matrix_sf_a( + input_tensor: torch.Tensor, + epilogue_tile_m: int, + num_elts_per_sf: int = 16, +): + """ + Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat. + `shuffleMatrixSfA` expects the input to be in 128x4 layout and then + apply the same shuffling in `shuffleMatrixA` and writes out in 128x4 + layout. + This function expects the input to be in linear layout. It's done this + way because the scaling factors in the NVFP4 checkpoints are quantized + and are in linear layout. + This function doesn't add padding. + """ + + row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m) + + w_shuffled = input_tensor[row_indices.to(input_tensor.device)] + -- +@flashinfer_api +def nvfp4_quantize( + a, + a_global_sf, + sfLayout=SfLayout.layout_128x4, + do_shuffle=False, + sf_vec_size=16, + enable_pdl=None, +): + """ + Quantize input tensor to NVFP4 format. + + Parameters: + a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. + a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. + sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4. + do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability. Defaults to None. + -- +@flashinfer_api +def mxfp4_quantize( + a: torch.Tensor, + backend: str = "cuda", + enable_pdl: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to MXFP4 format. + + Parameters: + a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. + backend (str, optional): Backend to use for quantization. + - "cuda": Use CUDA kernel (default, stable) + - "cute-dsl": Use CuTe-DSL kernel (requires SM100+, **experimental**) + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic + Dependent Launch). Only used when backend="cute-dsl". + If None, automatically detects based on device capability. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) -- +@flashinfer_api +def mxfp4_dequantize(a_fp4, a_sf): + """ + Dequantize input tensor from MXFP4 format. + + Parameters: + a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) + a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) + + Returns: + torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. + """ + return e2m1_and_ufp8sf_scale_to_float( + a_fp4.cpu().view(torch.uint8), + a_sf.cpu().view(torch.uint8).reshape(-1), + torch.tensor([1.0], device=a_fp4.device), + 32, + 0, + True, + ) + -- +@flashinfer_api +def mxfp4_dequantize_host( + weight: torch.Tensor, + scale: torch.Tensor, + group_size: int = 32, +) -> torch.Tensor: + """ + Dequantize input tensor from MXFP4 format on host. + + Parameters: + weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) + scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) + group_size (int, optional): Group size for dequantization. Defaults to 32. + + Returns: + torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. + """ + # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future + major, minor = get_compute_capability( + torch.device("cuda:0") + ) # use any cuda device to get a compute capability -- +@flashinfer_api +def nvfp4_batched_quantize( + a, + a_global_sf, + sf_vec_size=16, +): + """ + Quantize batched input tensor to NVFP4 format. + + Parameters: + a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. + a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 + - Scale factors tensor with shape determined by layout and sf_vec_size + """ + major, minor = get_compute_capability(a.device) + device_arch = f"{major * 10 + minor}" -- +@flashinfer_api +def nvfp4_quantize_paged_kv_cache( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + kv_layout: str = "HND", + k_global_sf: Optional[torch.Tensor] = None, + v_global_sf: Optional[torch.Tensor] = None, +) -> Tuple[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor], + float, + float, +]: + """Quantize paged KV cache to NVFP4 format for trtllm-gen MHA. + + Quantizes BF16/FP16 K/V caches to NVFP4 with two-level scaling + (global FP32 + per-block FP8), and swizzles scale factors + for the SM100 trtllm-gen MHA kernel layout. + + Args: + k_cache: Key cache tensor. -- +@flashinfer_api +def scaled_fp4_grouped_quantize( + a, + mask, + a_global_sf, +): + """ + quantize batched input tensor to NVFP4 format with mask. + Parameters: + a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. + a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. + mask (torch.Tensor): Mask tensor to apply before quantization. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 + - Scale factors tensor with shape determined by layout and sf_vec_size + """ + major, minor = get_compute_capability(a.device) + device_arch = f"{major * 10 + minor}" + a_fp4, a_sf = get_fp4_quantization_module( + device_arch -- +@flashinfer_api +def nvfp4_kv_dequantize( + fp4_data: torch.Tensor, + block_scales: torch.Tensor, + global_scale: torch.Tensor, + output_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """GPU dequantization of NVFP4 KV cache data with linear block scale layout. + + Requires SM80+. + + Args: + fp4_data (torch.Tensor): Packed FP4 data of shape ``[M, K/2]`` with dtype uint8. + block_scales (torch.Tensor): Per-block FP8 E4M3 scales of shape ``[M, K/16]`` + with dtype uint8. + global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32, + on the same CUDA device as fp4_data. + output_dtype (torch.dtype): Output dtype, either ``torch.bfloat16`` or ``torch.float16``. + + Returns: + torch.Tensor: Dequantized tensor of shape ``[M, K]`` with the specified output dtype. -- +@flashinfer_api +def nvfp4_kv_quantize( + input: torch.Tensor, + global_scale: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """GPU quantization to NVFP4 KV cache format with linear block scale layout. + + Requires SM100+ (Blackwell) for the cvt.rn.satfinite.e2m1x2.f32 PTX instruction. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype bf16 or fp16. + K must be divisible by 16. + global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32, + on the same CUDA device as input. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - fp4_output: Packed FP4 data of shape ``[M, K/2]`` with dtype uint8. + - block_scales: Per-block FP8 E4M3 scales of shape ``[M, K/16]`` with dtype uint8. + """ + M, K = input.shape -- +@flashinfer_api +def mxfp8_quantize( + input: torch.Tensor, + is_sf_swizzled_layout: bool = True, + alignment: int = 32, + enable_pdl: Optional[bool] = None, + backend: Literal["cuda", "cute-dsl"] = "cuda", + sf_swizzle_layout: Optional[SfLayout] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize input tensor to MxFP8 format. + + This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format + with associated scale factors. It supports various input data types and scale factor layouts. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. + alignment (int, optional): sfVecSize. Defaults to 32. + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability (SM >= 9.0). Defaults to None. + backend (Literal["cuda", "cute-dsl"], optional): Backend to use for quantization. Options are: -- +@flashinfer_api +def mxfp8_dequantize_host( + input: torch.Tensor, + scale_tensor: torch.Tensor, + is_sf_swizzled_layout: bool = True, + sf_swizzle_layout: Optional[SfLayout] = None, +) -> torch.Tensor: + """Dequantize input tensor from MxFP8 format. + + This function performs dequantization by converting a packed FP8 tensor in MxFP8 format + back to float values using the associated scale factors. + + Args: + input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3. + scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. + sf_swizzle_layout (Optional[SfLayout], optional): Swizzle layout for scale factors. + If provided,it overrides is_sf_swizzled_layout. Defaults to None. + Available options are 1. SfLayout.layout_128x4; 2. SfLayout.layout_linear. + + Returns: -- +@flashinfer_api +def mxfp4_quantize_cute_dsl( + input: torch.Tensor, + enable_pdl: bool | None = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to MXFP4 format using CuTe-DSL kernel. + + This is a GPU implementation matching FlashInfer's mxfp4_quantize() behavior: + - Global scale computed as (448 * 6) / max(|input|) + - UE8M0 scale factors + - E2M1 output format (4-bit, 2 values per byte) + - Swizzled (128x4) scale factor layout + + The kernel is compiled once per (K, dtype, pdl) combination and handles + varying M (batch size) at runtime without recompilation. + + Args: + input: Input tensor of shape [M, K] with dtype fp16/bf16 + enable_pdl: Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability (SM >= 9.0). -- +@flashinfer_api +def mxfp8_quantize_cute_dsl( + input: torch.Tensor, + is_sf_swizzled_layout: bool = True, + alignment: int = 32, + enable_pdl: bool | None = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to MXFP8 format using CuTe-DSL kernel. + + This is a GPU implementation with dual-path optimization: + - LINEAR layout: SF-block based iteration (fast) + - SWIZZLED layout: Row-based iteration with padding fast path (optimized) + + The kernel is compiled once per (K, dtype, pdl) combination and handles + varying M (batch size) at runtime without recompilation. + + Args: + input: Input tensor of shape [M, K] with dtype fp16/bf16 + is_sf_swizzled_layout: Whether to use 128x4 swizzled layout (True) or linear (False) + alignment: Alignment for K dimension (default 32, must be multiple of SF_VEC_SIZE)Summary by CodeRabbit