[CuTe DSL] Add modular FMHA prefill and MLA decode attention kernels#2805
pgera wants to merge 41 commits into flashinfer-ai:main from
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the CodeRabbit settings.
📝 Walkthrough

Adds a full CuTe-DSL Blackwell attention implementation: new FMHA prefill and modular MLA decode kernels, many CuTe role primitives, pipeline/topology and scheduler abstractions, configuration/fusion/masking APIs, PyTorch wrappers (prefill & MLA), FP8 variants, a benchmarking script, and extensive CUDA tests.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host as Host (PyTorch)
    participant Wrapper as BatchPrefillWrapper
    participant Kernel as Compiled Kernel
    participant Loader as LoaderRole (TMA)
    participant MMA as MmaRole
    participant Softmax as SoftmaxRole
    participant Corr as Correction/Epilogue
    participant GMEM as Global Memory / TMEM
    Host->>Wrapper: plan()/run(q, k, v, indptr...)
    Wrapper->>Kernel: launch with DLPack tensors & params
    Kernel->>Loader: request Q/K/V tiles (TMA producers)
    Loader->>GMEM: TMA load Q/K/V -> TMEM/SMEM
    Loader-->>MMA: hand off Q/K/P/V fragments
    MMA->>Softmax: produce QK scores / write logits to TMEM
    Softmax->>GMEM: apply mask/variant, write exponentials / row stats
    Softmax->>MMA: provide P fragments for PV GEMM
    MMA->>Corr: commit partial outputs (handles)
    Corr->>GMEM: rescale and store final outputs (epilogue)
    Kernel-->>Host: return out tensor
```
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a refactored and improved FMHA prefill kernel with a modular design. The new implementation addresses several bugs present in the previous version and offers a more flexible, testable architecture: by composing roles and using declarative pipeline topologies, the kernel supports a wide range of configurations and customizations, making it suitable for various attention mechanisms.

Highlights
Code Review
This pull request introduces a significant and well-structured modular rewrite of the FMHA prefill kernel using CuTe DSL. The new architecture, based on composable roles and declarative pipeline topologies, is a major improvement for maintainability and extensibility. The comprehensive test suite covers a wide range of configurations, ensuring the robustness of the new implementation. My review includes a few suggestions for code cleanup, improving maintainability by adding documentation, and addressing potential issues like unused code and a missing JIT decorator.
```python
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

"""Shared TMEM utilities for compute roles.

Provides tmem_load_partition() — partitions TMEM output accumulator for
load/store by the rescale and epilogue roles.
"""

from types import SimpleNamespace

import cutlass
import cutlass.cute as cute
import cutlass.cute.nvgpu.tcgen05 as tcgen05


@cute.jit
def tmem_load_partition(
    tmem_ptr: cutlass.Int32,
    tmem_o_offset: int,
    acc_dtype: cutlass.Constexpr,
    mma_pv_tiler: cutlass.Constexpr,
    cluster_shape_mnk: cutlass.Constexpr,
    warps_in_n: int,
    num_compute_warps: int,
    threads_per_warp: int,
    common_params: SimpleNamespace,
    tiled_mma_pv: cute.TiledMma,
    iter_n: int,
) -> tuple[
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
]:
    tOtO_shape = tiled_mma_pv.partition_shape_C(
        cute.select(mma_pv_tiler, mode=[0, 1])
    )
    tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape)
    tOtO_layout = cute.append(
        tOtO.layout,
        cute.make_layout(
            common_params.L // mma_pv_tiler[1],
            stride=mma_pv_tiler[1] // warps_in_n,
        ),
    )
    tOtO = cute.make_tensor(tmem_ptr + tmem_o_offset, tOtO_layout)
    tOtO = tOtO[None, None, None, iter_n]

    tAcc = tOtO[(None, None), 0, 0]

    tmem_load_atom = cute.make_copy_atom(
        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), acc_dtype
    )
    tmem_load_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc)
    tmem_load_thr_copy = tmem_load_tiled_copy.get_slice(
        common_params.tidx % (num_compute_warps * threads_per_warp)
    )

    cta_pv_tiler = (
        mma_pv_tiler[0] // cluster_shape_mnk[0],
        mma_pv_tiler[1],
        mma_pv_tiler[2],
    )
    cta_pv_tiler_mn = cute.select(cta_pv_tiler, mode=[0, 1])

    gO = None
    if cutlass.const_expr(common_params.mAccO is not None):
        gO = cute.local_tile(
            common_params.mAccO[None, common_params.blk_coord[3], None, None],
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )
        cO = cute.local_tile(
            cute.make_identity_tensor(
                common_params.mAccO[
                    None, common_params.blk_coord[3], None, None
                ].shape
            ),
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )
    else:
        gO = cute.local_tile(
            common_params.mO,
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )
        cO = cute.local_tile(
            cute.make_identity_tensor(common_params.mO.shape),
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )

    tTR_tAcc = tmem_load_thr_copy.partition_S(tAcc)
    tTR_gO = tmem_load_thr_copy.partition_D(gO)
    tTR_cO = tmem_load_thr_copy.partition_D(cO)
    tTR_rAcc = cute.make_fragment_like(tTR_gO, acc_dtype)
    return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc
```
```python
if s_k.shape[0] > 1:
    for i in range(len(s_k)):
        if s_k[i] % self._mma_tiler_mn[1] != 0:
            self._mask_type = MaskType.RESIDUAL_MASK
else:
    if s_k % self._mma_tiler_mn[1] != 0:
        self._mask_type = MaskType.RESIDUAL_MASK
```
The logic to determine whether `RESIDUAL_MASK` is needed can be simplified. The current implementation iterates over the `s_k` tensor and has a branch for `s_k.shape[0] > 1` that is always taken, since `s_k` is derived from `kv_indptr` and will be a 1D tensor. You can use `torch.any` for a more concise and efficient check:

```python
if torch.any(s_k % self._mma_tiler_mn[1] != 0):
    self._mask_type = MaskType.RESIDUAL_MASK
```
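The suggested `torch.any` form can be exercised standalone. `TILE_N` below is a made-up stand-in for `self._mma_tiler_mn[1]`, and `needs_residual_mask` is a hypothetical helper, not code from the PR:

```python
import torch

TILE_N = 128  # assumed tile width, standing in for self._mma_tiler_mn[1]


def needs_residual_mask(s_k: torch.Tensor, tile_n: int = TILE_N) -> bool:
    # True when any per-request KV length is not a multiple of the tile
    # width — the same condition the loop above checks element by element
    return bool(torch.any(s_k % tile_n != 0))


print(needs_residual_mask(torch.tensor([256, 384])))  # all multiples of 128
print(needs_residual_mask(torch.tensor([256, 300])))  # 300 % 128 == 44
```

Because `torch.any` reduces the whole tensor at once, the single expression covers both the multi-request and single-request cases without branching on `s_k.shape[0]`.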
```python
from typing import Optional, Type, Tuple
```
This file provides a custom implementation of pipeline participants, which appears to be a patch on top of `cutlass.pipeline`. However, it is missing a file-level docstring explaining why this custom implementation is necessary and what it changes relative to the original. Adding one would greatly improve maintainability and make it easier for other developers to understand the purpose of this module.
```python
def sink_M_D_update(params, kv_tile_idx, qo_head_idx, m, d, scale):
    # m is in the RAW (unscaled) domain; convert sink from scaled-logit to RAW
    log2_e = math.log2(math.exp(1.0))
    sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf
```
The condition `qo_head_idx < NUM_QO_HEADS` is redundant because `qo_head_idx` is a grid coordinate over the heads dimension, which is sized to `NUM_QO_HEADS`, so the condition always holds. You can simplify the expression:

```diff
-    sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf
+    sink_raw = params.sink[qo_head_idx] * log2_e / scale if kv_tile_idx == 0 else -math.inf
```
```python
@cute.jit
def sink_M_D_update(params, kv_tile_idx, qo_head_idx, m, d, scale):
    log2_e = math.log2(math.exp(1.0))
    sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf
```
The condition `qo_head_idx < NUM_QO_HEADS` is redundant because `qo_head_idx` is a grid coordinate over the heads dimension, which is sized to `NUM_QO_HEADS`, so the condition always holds. You can simplify the expression:

```diff
-    sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf
+    sink_raw = params.sink[qo_head_idx] * log2_e / scale if kv_tile_idx == 0 else -math.inf
```
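The `sink * log2_e / scale` conversion in the snippet is just a change of exponent base. A minimal arithmetic check, with `scale` and `sink_logit` as made-up values (not taken from the kernel):

```python
import math

log2_e = math.log2(math.exp(1.0))  # log2(e) = 1 / ln(2)

scale = 0.125       # assumed softmax scale
sink_logit = -3.0   # assumed sink value in the logit domain

# Conversion used in the snippet: scaled-logit -> raw (pre-scale) domain
sink_raw = sink_logit * log2_e / scale

# Multiplying back by scale recovers the base-2 exponent of the sink,
# whose value equals exp(sink_logit) by the base-change identity
assert math.isclose(sink_raw * scale, sink_logit * log2_e)
assert math.isclose(2.0 ** (sink_logit * log2_e), math.exp(sink_logit))
```

The identities hold regardless of the chosen `scale`, which is why the division by `scale` here cancels against the scaling applied later in the exponentiation.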
Actionable comments posted: 9
🧹 Nitpick comments (13)
flashinfer/cute_dsl/attention/tmem_layout.py (1)
35-49: Consider extracting `SM100_TMEM_CAPACITY_COLUMNS` as a module-level constant. The SM100 TMEM capacity is a hardware characteristic that may be referenced elsewhere; extracting it improves discoverability and avoids magic numbers.
Proposed refactor
```diff
+SM100_TMEM_CAPACITY_COLUMNS = 512
+
 @dataclass(frozen=True)
 class TmemLayout:
     ...

     @staticmethod
     def from_config(config: AttentionConfig) -> TmemLayout:
         tile_m = config.mma_tiler[0]
-        SM100_TMEM_CAPACITY_COLUMNS = 512
         return TmemLayout(
             ...
             alloc_cols=SM100_TMEM_CAPACITY_COLUMNS,
         )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/tmem_layout.py` around lines 35 - 49, Extract the literal 512 used for TMEM capacity into a module-level constant (e.g., SM100_TMEM_CAPACITY_COLUMNS = 512) and replace the local variable in TmemLayout.from_config so the function uses that constant instead of a magic number; update the top of the module with the constant and ensure TmemLayout.from_config (which takes AttentionConfig and reads config.mma_tiler[0]) references the new constant for alloc_cols so other code can reuse the hardware-capacity value.flashinfer/cute_dsl/attention/scheduler/persistent.py (2)
38-45: Add `strict=True` to `zip()` for safer MLIR value reconstruction. In `__new_from_mlir_values__`, the `zip()` call iterates over `[self.is_persistent, self.problem_shape_mbh]` and `self._values_pos`. If these lists have mismatched lengths (e.g., due to a maintenance error), `zip()` will silently truncate, potentially causing subtle bugs during MLIR reconstruction. Also, the `ip` parameter is not forwarded to the new `FmhaStaticTileSchedulerParams` instance on line 45.

Proposed fix
```diff
 def __new_from_mlir_values__(self, values):
     obj_list = []
     for obj, n_items in zip(
-        [self.is_persistent, self.problem_shape_mbh], self._values_pos
+        [self.is_persistent, self.problem_shape_mbh], self._values_pos, strict=True
     ):
         obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
         values = values[n_items:]
-    return FmhaStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc)
+    return FmhaStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc, ip=self._ip)
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/scheduler/persistent.py` around lines 38 - 45, In __new_from_mlir_values__ update the zip over [self.is_persistent, self.problem_shape_mbh] and self._values_pos to use zip(..., strict=True) to fail loudly on length mismatches, and when returning the FmhaStaticTileSchedulerParams instance forward the current object's ip parameter (pass loc=self._loc, ip=self.ip) so the new instance receives ip as well; this touches the __new_from_mlir_values__ method, the attributes self.is_persistent, self.problem_shape_mbh, self._values_pos, and the FmhaStaticTileSchedulerParams constructor call.
148-158: Hardcoded MLIR value count is fragile. The assertion `assert len(values) == 10` couples the implementation to a specific MLIR representation; if any constituent object's MLIR value count changes, this will fail without a clear message. Consider deriving the expected count dynamically or providing a descriptive error message.

Proposed improvement
```diff
 def __new_from_mlir_values__(self, values):
-    assert len(values) == 10
+    expected = 3 + 1 + 3 + 3  # params(3) + work_idx(1) + blk_coord(3) + grid_shape(3)
+    assert len(values) == expected, f"Expected {expected} MLIR values, got {len(values)}"
     new_params = cutlass.new_from_mlir_values(self._params, values[0:3])
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/scheduler/persistent.py` around lines 148 - 158, The hardcoded assertion in __new_from_mlir_values__ (assert len(values) == 10) is fragile; change it to compute the expected MLIR value count by summing the MLIR-value counts of the constituent objects (self._params, self._current_work_linear_idx, self._blk_coord, self._grid_shape) using whatever helper/attribute your cutlass layer exposes (e.g., a mlir value count helper or by querying each object's MLIR representation), then compare len(values) to that computed total and raise a ValueError with a descriptive message if mismatched; update the slicing logic that builds new_params, new_current_work_linear_idx, new_blk_coord, and new_grid_shape to use those computed per-object counts instead of fixed indices so FmhaStaticTileScheduler construction remains correct.flashinfer/cute_dsl/attention/collective_builder.py (1)
163-186: Consider using a typed dataclass instead of `SimpleNamespace` for better IDE support. The returned `SimpleNamespace` contains 20+ fields; a typed dataclass or `NamedTuple` would provide autocompletion and type checking for consumers.
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/collective_builder.py` around lines 163 - 186, Replace the anonymous SimpleNamespace return with a typed dataclass (e.g., define `@dataclass` class AttentionCollective or AttentionCollectiveConfig) that declares typed fields for each symbol currently passed (qk_tiled_mma, pv_tiled_mma, tma_atom_q, tma_tensor_q, tma_atom_k, tma_tensor_k, tma_atom_v, tma_tensor_v, tma_atom_o, tma_tensor_o, q_smem_layout_staged, k_smem_layout_staged, p_tmem_layout_staged, v_smem_layout_staged, o_smem_layout_staged, SharedStorage, tma_copy_q_bytes, tma_copy_kv_bytes, cluster_shape_mnk, cluster_layout_vmnk, epi_tile, o_layout), add appropriate type hints (use typing.Any or more specific types if known), import dataclasses and typing, instantiate and return that dataclass instead of SimpleNamespace, and update any consumers to accept the new dataclass type for improved IDE autocompletion and type checking.benchmarks/bench_blackwell_attention_cutedsl.py (1)
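The trade-off is easy to see in isolation. The field names below are hypothetical, chosen only to echo two of the builder's fields, not its real contract:

```python
from dataclasses import dataclass
from types import SimpleNamespace

# SimpleNamespace declares no fields, so a misspelled attribute only fails
# (with AttributeError) at the point of access, far from the bug:
ns = SimpleNamespace(epi_tile=(128, 64), tma_copy_q_bytes=16384)


# A frozen dataclass makes the contract explicit: type checkers and IDEs
# can see the fields, and construction validates the keyword names.
@dataclass(frozen=True)
class CollectiveArtifacts:  # hypothetical name for illustration
    epi_tile: tuple
    tma_copy_q_bytes: int


art = CollectiveArtifacts(epi_tile=(128, 64), tma_copy_q_bytes=16384)
assert art.epi_tile == ns.epi_tile

# Constructing with a misspelled field fails immediately:
try:
    CollectiveArtifacts(epi_tile=(128, 64), tma_copy_bytes=16384)  # wrong name
except TypeError as exc:
    print("caught:", exc)
```

With 20+ fields, shifting the error from attribute-access time to construction time is the main maintainability win.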
7-8: Use the public `flashinfer.testing` benchmark helper. This benchmark already relies on the standard timing helper, but it pulls it from `flashinfer.testing.utils`, which couples the script to a private module path.

♻️ Suggested change

```diff
-from flashinfer.testing.utils import bench_gpu_time
+from flashinfer.testing import bench_gpu_time
```

Based on learnings
Use flashinfer.testing.bench_gpu_time() for benchmarking kernels, preferring CUPTI timing with auto-fallback to CUDA events.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_blackwell_attention_cutedsl.py` around lines 7 - 8, The benchmark imports bench_gpu_time from a private path (flashinfer.testing.utils); update the import to use the public helper by replacing references to flashinfer.testing.utils with the public module flashinfer.testing and import bench_gpu_time from flashinfer.testing (i.e., use flashinfer.testing.bench_gpu_time) so the benchmark relies on the supported public API rather than a private module.tests/test_blackwell_fmha_attention.py (1)
1-13: Please move this suite under a feature-specific tests subdirectory. This is kernel-specific CuTe DSL attention coverage, but the new module sits at the `tests/` root. Putting it under a matching subdirectory keeps the test surface organized with the rest of the kernel-category suites.

As per coding guidelines (`tests/**/*.py`): prefix test functions with `test_` and structure tests by feature in `tests/` subdirectories matching kernel categories.
Verify each finding against the current code and only fix it if needed. In `@tests/test_blackwell_fmha_attention.py` around lines 1 - 13, The test module test_blackwell_fmha_attention.py is at the tests/ root but belongs in the attention-specific kernel tests; move this suite into a feature-specific tests subdirectory matching the kernel category (e.g., an attention/ or blackwell_fmha/ tests folder), update any relative imports inside the module to the new location, and ensure all test callables in the file are properly prefixed with test_ so pytest discovers them (check function names and any parametrized fixtures used by functions in this module).flashinfer/cute_dsl/attention/wrappers/batch_prefill.py (3)
393-396: Add `strict=True` to `zip()` for early shape-mismatch detection. Using `strict=True` catches mismatched lengths between `padding` and `shape_` early, improving debuggability.

Suggested fix

```diff
-    slices = tuple(slice(s, e) for s, e in zip(padding, shape_))
+    slices = tuple(slice(s, e) for s, e in zip(padding, shape_, strict=True))
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py` around lines 393 - 396, The slice construction using zip(padding, shape_) can silently ignore length mismatches; update the tuple comprehension that defines slices (used to create torch_tensor from torch_tensor_full and assigned to torch_tensor) to call zip with strict=True (i.e., zip(padding, shape_, strict=True)) so any mismatch between padding and shape_ raises immediately and makes debugging easier.
129-157: Prefix unused unpacked variables with underscore. The variables `q_ref`, `q_torch`, `k_ref`, `k_torch`, `v_ref`, `v_torch`, and `o_torch` from `create_and_pad_tensor()` are intentionally unused (they're dummy tensors for CuTe JIT tracing). Prefix them with `_` to indicate intent and silence linter warnings.

Suggested fix

```diff
-    q_ref, q_cute, q_torch = create_and_pad_tensor(
+    _q_ref, q_cute, _q_torch = create_and_pad_tensor(
         qo_shape,
         ...
     )
-    k_ref, k_cute, k_torch = create_and_pad_tensor(
+    _k_ref, k_cute, _k_torch = create_and_pad_tensor(
         kv_shape,
         ...
     )
-    v_ref, v_cute, v_torch = create_and_pad_tensor(
+    _v_ref, v_cute, _v_torch = create_and_pad_tensor(
         kv_shape,
         ...
     )
-    _, o_cute, o_torch = create_and_pad_tensor(
+    _, o_cute, _o_torch = create_and_pad_tensor(
         qo_shape,
         ...
     )
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py` around lines 129 - 157, The unpacked dummy tensors returned by create_and_pad_tensor (q_ref, q_torch, k_ref, k_torch, v_ref, v_torch, o_torch) are unused and should be prefixed with an underscore to indicate intentional unused variables and silence linters; update the unpacking lines where create_and_pad_tensor is called (for q_, k_, v_, and o_) to rename those specific variables to _q_ref/_q_torch, _k_ref/_k_torch, _v_ref/_v_torch, and _o_torch (or similar underscore-prefixed names) while keeping the used names q_cute/k_cute/v_cute/o_cute unchanged.
318-319: Minor: `device=q.device` is redundant with `torch.empty_like`. `torch.empty_like(q, ...)` already inherits `q`'s device by default.

Suggested fix

```diff
-    out = torch.empty_like(q, device=q.device)
+    out = torch.empty_like(q)
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py` around lines 318 - 319, In batch_prefill.py replace the redundant explicit device argument when creating the empty tensor so that out is created with torch.empty_like(q) instead of torch.empty_like(q, device=q.device); locate the assignment that sets out when out is None (the one referencing variables out and q) and remove the device=q.device parameter to rely on torch.empty_like inheriting q's device.flashinfer/cute_dsl/attention/roles/softmax.py (1)
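The inheritance behavior can be checked on CPU without any kernel machinery; the tensor shape and dtype below are arbitrary:

```python
import torch

q = torch.zeros(4, 8, dtype=torch.float16)

# empty_like inherits shape, dtype, layout, and device from its input,
# so an explicit device=q.device argument adds nothing
out = torch.empty_like(q)

assert out.shape == q.shape
assert out.dtype == q.dtype
assert out.device == q.device
```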
336-344: Redundant `thread_idx` computation. `thread_idx` is computed identically at lines 337-344 and again at lines 366-373; the second computation overwrites the first with the same value.

Remove duplicate computation

```diff
         thread_idx = tidx % (
             self.threads_per_warp
             * (
                 len(self.softmax0_warp_ids)
                 if stage == 0
                 else len(self.softmax1_warp_ids)
             )
         )
         ...
         tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi)
-        thread_idx = tidx % (
-            self.threads_per_warp
-            * (
-                len(self.softmax0_warp_ids)
-                if stage == 0
-                else len(self.softmax1_warp_ids)
-            )
-        )
         thr_tmem_load = tiled_tmem_load.get_slice(thread_idx)
```

Also applies to: 366-373
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/roles/softmax.py` around lines 336 - 344, The duplicated computation of thread_idx (calling cute.arch.thread_idx(), taking tidx and computing tidx % (self.threads_per_warp * (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids)))) appears twice; remove the redundant second block (the one at lines 366-373) so thread_idx remains computed once and subsequent code uses the already-computed thread_idx from the first occurrence; ensure any references after the removed block still rely on the existing thread_idx variable and that no logic dependent on re-calling cute.arch.thread_idx() is lost.flashinfer/cute_dsl/attention/prefill.py (3)
155-156: Prefix unused `s_k` with underscore. `s_k` is unpacked but never used; prefix it with `_` to indicate intent.

```diff
-    b, s_q, s_k, h_q, h_k, d = problem_size
+    b, s_q, _s_k, h_q, h_k, d = problem_size
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/prefill.py` around lines 155 - 156, The tuple unpacking of problem_size currently binds an unused variable s_k; change the unpacking to use _s_k (or simply _ ) instead of s_k to signal it's intentionally unused (e.g., replace "b, s_q, s_k, h_q, h_k, d = problem_size" with an unpacking that prefixes s_k with an underscore) in the prefill logic where variables b, s_q, h_q, h_k, d are used and h_r is computed from h_q and h_k.
45-51: Overly broad warning suppression may hide legitimate issues. Suppressing all `UserWarning` messages (line 51) could mask important warnings from other parts of the codebase or dependencies. Consider scoping the suppression more narrowly, or applying it only within the specific context where the unrolling warning occurs.

Alternative: use a context manager at call sites

```python
# Remove the global filter at module level
# warnings.filterwarnings("ignore", category=UserWarning)

# Instead, wrap specific calls that generate the warning:
import contextlib

@contextlib.contextmanager
def suppress_loop_unroll_warning():
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="This loop is no longer unrolled and may cause performance regression",
        )
        yield
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/prefill.py` around lines 45 - 51, The module currently suppresses all UserWarning globally by calling warnings.filterwarnings("ignore", category=UserWarning); remove that broad module-level filter and instead scope suppression to only the specific unroll warning by introducing a context manager (e.g., suppress_loop_unroll_warning using warnings.catch_warnings and warnings.filterwarnings with message="This loop is no longer unrolled and may cause performance regression") and use that context manager at the specific call sites in prefill.py where the unrolling warning is raised so other UserWarnings remain visible.
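The scoping behavior the alternative relies on can be demonstrated with plain `warnings`; `noisy()` and its messages are stand-ins, not the kernel's real warnings:

```python
import warnings


def noisy():
    # Stand-ins for warnings emitted during kernel compilation
    warnings.warn("This loop is no longer unrolled and may cause performance regression")
    warnings.warn("something else worth seeing")


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Scoped filter: 'message' is a regex matched against the start of the
    # warning text, so the prefix is enough to catch the unroll warning
    warnings.filterwarnings("ignore", message="This loop is no longer unrolled")
    noisy()

# Only the unrelated warning was recorded; the filter died with the context,
# so warnings elsewhere in the process remain visible
assert [str(w.message) for w in caught] == ["something else worth seeing"]
```

`catch_warnings` restores the previous filter list on exit, which is exactly what a module-level `filterwarnings("ignore", category=UserWarning)` cannot do.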
385-386: Prefix unused `tidx` with underscore. `tidx` from `thread_idx()` is unpacked but unused in the kernel entry; it is only used by roles that call `thread_idx()` themselves.

```diff
-    tidx, _, _ = cute.arch.thread_idx()
+    _tidx, _, _ = cute.arch.thread_idx()
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/prefill.py` around lines 385 - 386, The unpacked thread index variable tidx from cute.arch.thread_idx() is unused in the kernel entry; change its name to _tidx to mark it as intentionally unused (i.e., replace "tidx, _, _ = cute.arch.thread_idx()" with "_tidx, _, _ = cute.arch.thread_idx()") so linters/readers know it's unused while keeping the other unpacked values and the existing warp_idx assignment (warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())) intact.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_blackwell_attention_cutedsl.py`:
- Around line 153-161: The script currently unconditionally runs an SM100-only
kernel in the __main__ block (calls to bench_fmha_cutedsl), which will
JIT/launch-fail on non-SM100 GPUs; add a GPU capability check before running the
default sweep: use torch.cuda.is_available() and
torch.cuda.get_device_capability() or
torch.cuda.get_device_properties(device).major/minor (or device name) to detect
whether the current GPU supports SM100, and if not, skip the default
bench_fmha_cutedsl(...) calls and exit or print a clear message; update the
__main__ section so the SM100-only sweep only runs when the capability check
passes.
In `@flashinfer/cute_dsl/attention/collective_builder.py`:
- Around line 96-98: The p_tmem_layout_staged is being created with the wrong
dtype (q_dtype) causing a mismatch with pv_tiled_mma which was created for V;
update the call to sm100_utils.make_smem_layout_a in collective_builder so
p_tmem_layout_staged uses v_dtype instead of q_dtype (the call that takes
pv_tiled_mma, config.pv_mma_tiler, q_dtype, mainloop.acc_stage should pass
v_dtype) to align the P buffer TMEM layout with the V buffer.
In `@flashinfer/cute_dsl/attention/fusion/mask.py`:
- Around line 45-53: Sliding-window masking currently centers the KV window on
raw Q indices (see MaskType.SLIDING_WINDOW_MASK handling using blk_coord,
tile_shape, window_left, seqlen_k) and ignores the Q/K length offset used by the
causal path; compute q_k_offset = seqlen_k - seqlen_q and add it to first_q and
last_q (or otherwise shift Q indices into KV space) before calculating min_kv,
max_kv, start_block, end_block, and any element masks; apply the same fix to the
other sliding-window blocks noted (around the other occurrences at the given
ranges) so all sliding-window computations use shifted Q indices into KV
coordinate space.
In `@flashinfer/cute_dsl/attention/fusion/variant.py`:
- Around line 551-554: SoftCappingAttention.score_mod calls non-existent
cute.arch.tanh; replace it with a local tanh approximation implemented using
available primitives (e.g., cute.arch.exp2 and cute.arch.rcp_approx) or a cheap
rational polynomial and call that from score_mod. Add a helper function (e.g.,
_tanh_approx(x)) in the same class or module and use it in
SoftCappingAttention.score_mod (referencing self.cap and self.rcp_cap as
before), implementing tanh(x) via exp2 by computing exp(-2*abs(x)) with
exp2(-2*abs(x)/ln2) plus sign handling or by a stable rational approximation
(polynomial numerator/denominator) and ensure the helper uses
cute.jit-compatible operations only.
- Around line 367-378: Update the class and relevant parameter docstrings to
state that sink values are expected in the logit domain (raw Q·K dot-product
units, unnormalized), not pre-scaled to log2; specifically mention this near the
documentation for the sink parameter(s) used by update_statistics and the
self.params/sink_raw conversion (which divides by scale/log2_e), and add a
cross-reference to sink_softmax in sink_attention_reference.py so callers know
sinks are concatenated to logits before any log2 scaling.
In `@flashinfer/cute_dsl/attention/pipeline_topology.py`:
- Around line 68-79: The Pipeline.dataclass field cluster_scale is ignored by
create_pipelines(), causing incorrect participant and barrier arrive counts;
either (preferred) honor it by multiplying the all-thread side's participant
counts when constructing producer/consumer groups and computing barrier arrive
counts for PipelineType values UMMA_ASYNC and ASYNC_UMMA (but leave TMA_UMMA
unchanged), i.e., when building groups from producer_warp_ids/consumer_warp_ids
in create_pipelines() multiply the thread counts by pipeline.cluster_scale and
use that scaled value when setting arrive counts for barriers/tx_count_key, or
fail fast by adding a check in create_pipelines() that raises a clear exception
if pipeline.cluster_scale != 1 so callers must handle scaling explicitly.
In `@flashinfer/cute_dsl/attention/roles/epilogue.py`:
- Around line 41-66: partition_output is incorrectly decorated with `@cute.jit`
while returning tensor objects (tOsO, tOgO) which violates the CuTe JIT
limitation; either remove the `@cute.jit` decorator from partition_output so it
runs as a normal Python method, or refactor it to avoid returning tensors by (a)
accepting preallocated output containers/handles and writing into them, or (b)
moving the cute.nvgpu.cpasync.tma_partition call out of the `@cute.jit` function
into a non-jit wrapper (e.g., create partition_output_nonjit that calls
cute.nvgpu.cpasync.tma_partition and returns tensors or change partition_output
to populate passed-in tensor references); update references to partition_output
accordingly so no `@cute.jit` function returns tensors (symbols: partition_output,
tOsO, tOgO, tma_partition, tma_atom_o).
In `@flashinfer/cute_dsl/attention/warp_schedule.py`:
- Around line 17-71: Add a fail-fast validation in WarpSchedule (implement in a
__post_init__ method) that verifies: 1) all_warp_ids (built from
softmax0_warp_ids, softmax1_warp_ids, correction_warp_ids, mma_warp_id,
load_warp_id, epilogue_warp_id, empty_warp_id) contain unique values and form a
contiguous range starting at 0 up to len(all_warp_ids)-1, and 2) the total
number of softmax warps (len(softmax0_warp_ids)+len(softmax1_warp_ids)) is
divisible by num_warps_per_warpgroup; on violation raise ValueError with a clear
message referencing the failing condition so consumers of num_warps,
threads_per_cta, and softmax_warpgroup_count cannot silently compute incorrect
sizes.
In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py`:
- Around line 159-169: The NameError risk comes from params_cute being defined
only inside the if self._has_params block yet referenced later; fix by defining
params_cute = None before the if and only assigning it inside the block (where
you call from_dlpack) so later code can safely use the conditional expression
(params_cute.iterator if self._has_params else None); update references
involving self._has_params, _params_torch, and from_dlpack accordingly to rely
on the initialized params_cute variable.
---
Nitpick comments:
In `@benchmarks/bench_blackwell_attention_cutedsl.py`:
- Around line 7-8: The benchmark imports bench_gpu_time from a private path
(flashinfer.testing.utils); update the import to use the public helper by
replacing references to flashinfer.testing.utils with the public module
flashinfer.testing and import bench_gpu_time from flashinfer.testing (i.e., use
flashinfer.testing.bench_gpu_time) so the benchmark relies on the supported
public API rather than a private module.
In `@flashinfer/cute_dsl/attention/collective_builder.py`:
- Around line 163-186: Replace the anonymous SimpleNamespace return with a typed
dataclass (e.g., define `@dataclass` class AttentionCollective or
AttentionCollectiveConfig) that declares typed fields for each symbol currently
passed (qk_tiled_mma, pv_tiled_mma, tma_atom_q, tma_tensor_q, tma_atom_k,
tma_tensor_k, tma_atom_v, tma_tensor_v, tma_atom_o, tma_tensor_o,
q_smem_layout_staged, k_smem_layout_staged, p_tmem_layout_staged,
v_smem_layout_staged, o_smem_layout_staged, SharedStorage, tma_copy_q_bytes,
tma_copy_kv_bytes, cluster_shape_mnk, cluster_layout_vmnk, epi_tile, o_layout),
add appropriate type hints (use typing.Any or more specific types if known),
import dataclasses and typing, instantiate and return that dataclass instead of
SimpleNamespace, and update any consumers to accept the new dataclass type for
improved IDE autocompletion and type checking.
In `@flashinfer/cute_dsl/attention/prefill.py`:
- Around line 155-156: The tuple unpacking of problem_size currently binds an
unused variable s_k; change the unpacking to use _s_k (or simply _ ) instead of
s_k to signal it's intentionally unused (e.g., replace "b, s_q, s_k, h_q, h_k, d
= problem_size" with an unpacking that prefixes s_k with an underscore) in the
prefill logic where variables b, s_q, h_q, h_k, d are used and h_r is computed
from h_q and h_k.
- Around line 45-51: The module currently suppresses all UserWarning globally by
calling warnings.filterwarnings("ignore", category=UserWarning); remove that
broad module-level filter and instead scope suppression to only the specific
unroll warning by introducing a context manager (e.g.,
suppress_loop_unroll_warning using warnings.catch_warnings and
warnings.filterwarnings with message="This loop is no longer unrolled and may
cause performance regression") and use that context manager at the specific call
sites in prefill.py where the unrolling warning is raised so other UserWarnings
remain visible.
- Around line 385-386: The unpacked thread index variable tidx from
cute.arch.thread_idx() is unused in the kernel entry; change its name to _tidx
to mark it as intentionally unused (i.e., replace "tidx, _, _ =
cute.arch.thread_idx()" with "_tidx, _, _ = cute.arch.thread_idx()") so
linters/readers know it's unused while keeping the other unpacked values and the
existing warp_idx assignment (warp_idx =
cute.arch.make_warp_uniform(cute.arch.warp_idx())) intact.
In `@flashinfer/cute_dsl/attention/roles/softmax.py`:
- Around line 336-344: The duplicated computation of thread_idx (calling
cute.arch.thread_idx(), taking tidx and computing tidx % (self.threads_per_warp
* (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids))))
appears twice; remove the redundant second block (the one at lines 366-373) so
thread_idx remains computed once and subsequent code uses the already-computed
thread_idx from the first occurrence; ensure any references after the removed
block still rely on the existing thread_idx variable and that no logic dependent
on re-calling cute.arch.thread_idx() is lost.
In `@flashinfer/cute_dsl/attention/scheduler/persistent.py`:
- Around line 38-45: In __new_from_mlir_values__ update the zip over
[self.is_persistent, self.problem_shape_mbh] and self._values_pos to use
zip(..., strict=True) to fail loudly on length mismatches, and when returning
the FmhaStaticTileSchedulerParams instance forward the current object's ip
parameter (pass loc=self._loc, ip=self.ip) so the new instance receives ip as
well; this touches the __new_from_mlir_values__ method, the attributes
self.is_persistent, self.problem_shape_mbh, self._values_pos, and the
FmhaStaticTileSchedulerParams constructor call.
- Around line 148-158: The hardcoded assertion in __new_from_mlir_values__
(assert len(values) == 10) is fragile; change it to compute the expected MLIR
value count by summing the MLIR-value counts of the constituent objects
(self._params, self._current_work_linear_idx, self._blk_coord, self._grid_shape)
using whatever helper/attribute your cutlass layer exposes (e.g., a mlir value
count helper or by querying each object's MLIR representation), then compare
len(values) to that computed total and raise a ValueError with a descriptive
message if mismatched; update the slicing logic that builds new_params,
new_current_work_linear_idx, new_blk_coord, and new_grid_shape to use those
computed per-object counts instead of fixed indices so FmhaStaticTileScheduler
construction remains correct.
In `@flashinfer/cute_dsl/attention/tmem_layout.py`:
- Around line 35-49: Extract the literal 512 used for TMEM capacity into a
module-level constant (e.g., SM100_TMEM_CAPACITY_COLUMNS = 512) and replace the
local variable in TmemLayout.from_config so the function uses that constant
instead of a magic number; update the top of the module with the constant and
ensure TmemLayout.from_config (which takes AttentionConfig and reads
config.mma_tiler[0]) references the new constant for alloc_cols so other code
can reuse the hardware-capacity value.
In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py`:
- Around line 393-396: The slice construction using zip(padding, shape_) can
silently ignore length mismatches; update the tuple comprehension that defines
slices (used to create torch_tensor from torch_tensor_full and assigned to
torch_tensor) to call zip with strict=True (i.e., zip(padding, shape_,
strict=True)) so any mismatch between padding and shape_ raises immediately and
makes debugging easier.
- Around line 129-157: The unpacked dummy tensors returned by
create_and_pad_tensor (q_ref, q_torch, k_ref, k_torch, v_ref, v_torch, o_torch)
are unused and should be prefixed with an underscore to indicate intentional
unused variables and silence linters; update the unpacking lines where
create_and_pad_tensor is called (for q_, k_, v_, and o_) to rename those
specific variables to _q_ref/_q_torch, _k_ref/_k_torch, _v_ref/_v_torch, and
_o_torch (or similar underscore-prefixed names) while keeping the used names
q_cute/k_cute/v_cute/o_cute unchanged.
- Around line 318-319: In batch_prefill.py replace the redundant explicit device
argument when creating the empty tensor so that out is created with
torch.empty_like(q) instead of torch.empty_like(q, device=q.device); locate the
assignment that sets out when out is None (the one referencing variables out and
q) and remove the device=q.device parameter to rely on torch.empty_like
inheriting q's device.
In `@tests/test_blackwell_fmha_attention.py`:
- Around line 1-13: The test module test_blackwell_fmha_attention.py is at the
tests/ root but belongs in the attention-specific kernel tests; move this suite
into a feature-specific tests subdirectory matching the kernel category (e.g.,
an attention/ or blackwell_fmha/ tests folder), update any relative imports
inside the module to the new location, and ensure all test callables in the file
are properly prefixed with test_ so pytest discovers them (check function names
and any parametrized fixtures used by functions in this module).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f4d60aef-cbff-41af-8cdd-f4914f9854d7
📒 Files selected for processing (24)
benchmarks/bench_blackwell_attention_cutedsl.pyflashinfer/cute_dsl/attention/__init__.pyflashinfer/cute_dsl/attention/collective_builder.pyflashinfer/cute_dsl/attention/config.pyflashinfer/cute_dsl/attention/fusion/__init__.pyflashinfer/cute_dsl/attention/fusion/mask.pyflashinfer/cute_dsl/attention/fusion/variant.pyflashinfer/cute_dsl/attention/mainloop_spec.pyflashinfer/cute_dsl/attention/pipeline_topology.pyflashinfer/cute_dsl/attention/prefill.pyflashinfer/cute_dsl/attention/roles/__init__.pyflashinfer/cute_dsl/attention/roles/correction.pyflashinfer/cute_dsl/attention/roles/epilogue.pyflashinfer/cute_dsl/attention/roles/loader_tma.pyflashinfer/cute_dsl/attention/roles/mma.pyflashinfer/cute_dsl/attention/roles/softmax.pyflashinfer/cute_dsl/attention/roles/softmax_math.pyflashinfer/cute_dsl/attention/scheduler/__init__.pyflashinfer/cute_dsl/attention/scheduler/persistent.pyflashinfer/cute_dsl/attention/tmem_layout.pyflashinfer/cute_dsl/attention/warp_schedule.pyflashinfer/cute_dsl/attention/wrappers/__init__.pyflashinfer/cute_dsl/attention/wrappers/batch_prefill.pytests/test_blackwell_fmha_attention.py
| p_tmem_layout_staged = sm100_utils.make_smem_layout_a( | ||
| pv_tiled_mma, config.pv_mma_tiler, q_dtype, mainloop.acc_stage, | ||
| ) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check how p_tmem_layout_staged is used and if q_dtype is the correct type
rg -n "p_tmem_layout" --type py flashinfer/cute_dsl/attention/Repository: flashinfer-ai/flashinfer
Length of output: 824
🏁 Script executed:
#!/bin/bash
# Get the full function signature and beginning to see all dtype parameters
sed -n '50,120p' flashinfer/cute_dsl/attention/collective_builder.pyRepository: flashinfer-ai/flashinfer
Length of output: 2667
🏁 Script executed:
#!/bin/bash
# Check sm100_utils.make_smem_layout_a to see what dtype parameter does
rg -n "def make_smem_layout_a" --type py flashinfer/cute_dsl/attention/Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Check what dtypes are typically used - look for q_dtype, k_dtype, v_dtype, p_dtype
rg -n "q_dtype|k_dtype|v_dtype|p_dtype|intermediate_dtype" --type py flashinfer/cute_dsl/attention/collective_builder.py | head -30Repository: flashinfer-ai/flashinfer
Length of output: 900
P buffer TMEM layout should use v_dtype, not q_dtype.
The P buffer (attention weights) shares the same pv_tiled_mma with V, which is created with v_dtype. However, p_tmem_layout_staged is created with q_dtype instead, causing a dtype mismatch. Change line 97 to use v_dtype:
Before:
p_tmem_layout_staged = sm100_utils.make_smem_layout_a(
pv_tiled_mma, config.pv_mma_tiler, q_dtype, mainloop.acc_stage,
)After:
p_tmem_layout_staged = sm100_utils.make_smem_layout_a(
pv_tiled_mma, config.pv_mma_tiler, v_dtype, mainloop.acc_stage,
)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/cute_dsl/attention/collective_builder.py` around lines 96 - 98,
The p_tmem_layout_staged is being created with the wrong dtype (q_dtype) causing
a mismatch with pv_tiled_mma which was created for V; update the call to
sm100_utils.make_smem_layout_a in collective_builder so p_tmem_layout_staged
uses v_dtype instead of q_dtype (the call that takes pv_tiled_mma,
config.pv_mma_tiler, q_dtype, mainloop.acc_stage should pass v_dtype) to align
the P buffer TMEM layout with the V buffer.
| @cute.jit | ||
| def update_statistics(self, kv_tile_idx, qo_head_idx, m, d, scale): | ||
| log2_e = math.log2(math.exp(1.0)) | ||
| sink_raw = ( | ||
| self.params[qo_head_idx] * log2_e / scale | ||
| if kv_tile_idx == 0 | ||
| else -math.inf | ||
| ) | ||
| m_new = sink_raw if sink_raw > m else m | ||
| rescale = cute.arch.exp2((m - m_new) * scale) | ||
| d_new = cute.arch.exp2((sink_raw - m_new) * scale) + d * rescale | ||
| return m_new, d_new |
There was a problem hiding this comment.
🧩 Analysis chain
🌐 Web query:
CuTe DSL attention sink implementation softmax denominator
💡 Result:
No specific "attention sink implementation" in CuTe DSL was found. CuTe DSL (from NVIDIA CUTLASS) is used for high-performance FlashAttention kernels, including softmax computation via online softmax for numerical stability. The softmax denominator (row sums of exp(scores - row_max)) is computed using tiled reductions in shared memory across blocks/tiles of the attention matrix, leveraging CuTe's tensor layouts, TensorSSA for vectorized ops, and arch-specific reductions like cute.arch.warp_reduction_sum (as in RMSNorm examples, analogous for softmax rows). Official examples like CUTLASS's flash_attention_v2.py feature softmax_rescale_O for online softmax fusion with GEMM. Repos like luliyucoordinate/cute-flash-attention and HuyNguyen-hust/flash-attn-101 implement FlashAttention in CuTe, using standard online softmax without special "sink" handling—sinks are a training/model phenomenon, not a kernel feature. FlashAttention-3/4 in Dao-AILab/flash-attention use CuTe DSL for Hopper/Blackwell with optimized exp2/emulation and reductions for the denominator. Attention sinks refer to probability mass concentrating on fixed positions in softmax attention (e.g., arXiv:2603.11487), but kernel implementations compute general softmax denominators via parallel reductions, not sink-specific logic.
Citations:
- 1: https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py
- 2: https://veitner.bearblog.dev/simple-reduction-in-cutedsl/
- 3: https://github.com/luliyucoordinate/cute-flash-attention
- 4: https://github.com/HuyNguyen-hust/flash-attn-101
- 5: Rationale behind using ex2_emulation in FA4 Softmax implementation vs standard exp2 Dao-AILab/flash-attention#2081
- 6: https://arxiv.org/pdf/2312.11918
- 7: https://research.colfax-intl.com/a-users-guide-to-flexattention-in-flash-attention-cute-dsl/
🏁 Script executed:
find . -type f -name "variant.py" | grep -E "cute_dsl|attention"Repository: flashinfer-ai/flashinfer
Length of output: 117
🏁 Script executed:
if [ -f "flashinfer/cute_dsl/attention/fusion/variant.py" ]; then
wc -l flashinfer/cute_dsl/attention/fusion/variant.py
fiRepository: flashinfer-ai/flashinfer
Length of output: 119
🏁 Script executed:
if [ -f "flashinfer/cute_dsl/attention/fusion/variant.py" ]; then
sed -n '350,395p' flashinfer/cute_dsl/attention/fusion/variant.py
fiRepository: flashinfer-ai/flashinfer
Length of output: 1402
🏁 Script executed:
# Search for scale parameter definition and sink-related documentation
rg -n "scale\s*=" flashinfer/cute_dsl/attention/fusion/variant.py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 335
🏁 Script executed:
# Look for documentation about sink semantics
rg -n -B3 -A3 "sink" flashinfer/cute_dsl/attention/fusion/variant.py | head -40Repository: flashinfer-ai/flashinfer
Length of output: 1858
🏁 Script executed:
# Search for params initialization and documentation
rg -n -B5 "self.params" flashinfer/cute_dsl/attention/fusion/variant.py | head -50Repository: flashinfer-ai/flashinfer
Length of output: 2271
🏁 Script executed:
sed -n '340,410p' flashinfer/cute_dsl/attention/fusion/variant.pyRepository: flashinfer-ai/flashinfer
Length of output: 2239
🏁 Script executed:
# Search for AttentionWithSink usage and test cases
rg -n "AttentionWithSink" --type pyRepository: flashinfer-ai/flashinfer
Length of output: 839
🏁 Script executed:
# Look for documentation around the sink parameter
sed -n '340,365p' flashinfer/cute_dsl/attention/fusion/variant.pyRepository: flashinfer-ai/flashinfer
Length of output: 790
🏁 Script executed:
# Look at test cases to understand sink value semantics
sed -n '420,450p' tests/test_blackwell_fmha_attention.pyRepository: flashinfer-ai/flashinfer
Length of output: 1041
🏁 Script executed:
# Check more of the test context
sed -n '715,780p' tests/test_blackwell_fmha_attention.pyRepository: flashinfer-ai/flashinfer
Length of output: 2614
🏁 Script executed:
# Look for any documentation about sink value domain/units
rg -n -B5 -A5 "sink" flashinfer/cute_dsl/attention/fusion/variant.py | tail -80Repository: flashinfer-ai/flashinfer
Length of output: 2864
🏁 Script executed:
# Check if there's a reference paper or documentation file
find . -type f -name "*.md" -o -name "*.rst" | xargs grep -l "sink" 2>/dev/nullRepository: flashinfer-ai/flashinfer
Length of output: 52
🏁 Script executed:
# Find the reference implementation to understand sink semantics
rg -n "def attention_ref" tests/test_blackwell_fmha_attention.pyRepository: flashinfer-ai/flashinfer
Length of output: 89
🏁 Script executed:
# Look at the reference implementation
sed -n '1,120p' tests/test_blackwell_fmha_attention.pyRepository: flashinfer-ai/flashinfer
Length of output: 3643
🏁 Script executed:
# Search for attention_ref implementation across the repo
find . -type f -name "*.py" -exec grep -l "def attention_ref" {} \;Repository: flashinfer-ai/flashinfer
Length of output: 230
🏁 Script executed:
# Find the sink_softmax implementation
find . -type f -name "sink_attention_reference.py"Repository: flashinfer-ai/flashinfer
Length of output: 116
🏁 Script executed:
# Search for sink_softmax definition
rg -n "def sink_softmax" --type pyRepository: flashinfer-ai/flashinfer
Length of output: 149
🏁 Script executed:
cat tests/test_helpers/sink_attention_reference.pyRepository: flashinfer-ai/flashinfer
Length of output: 16156
🏁 Script executed:
# Look at how sink values are used in tests - what domain are they in?
sed -n '420,435p' tests/test_blackwell_fmha_attention.pyRepository: flashinfer-ai/flashinfer
Length of output: 710
🏁 Script executed:
# Check the documentation docstring more carefully
sed -n '337,356p' flashinfer/cute_dsl/attention/fusion/variant.pyRepository: flashinfer-ai/flashinfer
Length of output: 696
🏁 Script executed:
# Look for any comments about sink semantics in scale context
rg -n -B10 -A10 "log2_e\|log2.math" flashinfer/cute_dsl/attention/fusion/variant.py | head -60Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Search for how scale is set/passed to update_statistics
rg -n "update_statistics.*scale" --type py -B3 -A3 | head -60Repository: flashinfer-ai/flashinfer
Length of output: 2605
🏁 Script executed:
# Look for where scale is initialized/computed in the wrapper
rg -n "scale.*=" flashinfer/cute_dsl/attention/ --type py | grep -v test | head -30Repository: flashinfer-ai/flashinfer
Length of output: 3044
🏁 Script executed:
# Check if there's any comment about sink value domain in tests
rg -n -B5 -A5 "torch.randn.*sink" tests/ --type py | head -80Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Look at the softmax_math.py to understand how m and d work
sed -n '1,50p' flashinfer/cute_dsl/attention/roles/softmax_math.pyRepository: flashinfer-ai/flashinfer
Length of output: 1430
🏁 Script executed:
# Check what domain m and d are actually in
rg -n "m_new\|d_new" flashinfer/cute_dsl/attention/fusion/variant.py -B2 -A2Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Search for documentation about expected sink value domain/units
rg -n "sink" flashinfer/cute_dsl/attention/fusion/variant.py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 1018
🏁 Script executed:
# Let me check the documentation more carefully for scale parameter
sed -n '50,100p' flashinfer/cute_dsl/attention/fusion/variant.pyRepository: flashinfer-ai/flashinfer
Length of output: 1926
🏁 Script executed:
# Look at the complete context of the update_statistics docstring
sed -n '275,315p' flashinfer/cute_dsl/attention/fusion/variant.pyRepository: flashinfer-ai/flashinfer
Length of output: 1503
🏁 Script executed:
# Let me verify the mathematical simplification one more time
python3 << 'EOF'
import math
# Verify: if scale = log2(e) * sm_scale, then params * log2_e / scale = params / sm_scale
log2_e = math.log2(math.exp(1.0))
print(f"log2(e) = {log2_e}")
print(f"log2(e) ≈ 1.4426950408889634")
# For any sm_scale, verify the simplification
sm_scale = 1.0 / math.sqrt(128) # example: head_dim = 128
scale = log2_e * sm_scale
print(f"\nFor head_dim=128: sm_scale = {sm_scale}")
print(f"scale = log2(e) * sm_scale = {scale}")
# Test the simplification with a sample sink value
params = 2.5
sink_raw_original = params * log2_e / scale
sink_raw_simplified = params / sm_scale
print(f"\nFor params = {params}:")
print(f"sink_raw (original formula) = params * log2_e / scale = {sink_raw_original}")
print(f"sink_raw (simplified formula) = params / sm_scale = {sink_raw_simplified}")
print(f"Match: {abs(sink_raw_original - sink_raw_simplified) < 1e-10}")
EOFRepository: flashinfer-ai/flashinfer
Length of output: 401
🏁 Script executed:
# Check if sink value domain is documented in docstring or README
find . -type f \( -name "README.md" -o -name "*.rst" \) -exec grep -l "sink" {} \;Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Final check: Look at the complete example in the docstring
sed -n '131,150p' flashinfer/cute_dsl/attention/fusion/variant.pyRepository: flashinfer-ai/flashinfer
Length of output: 751
Clarify the expected domain of sink tensor values in the docstring.
The mathematical conversion sink_raw = params * log2_e / scale is correct and simplifies to params / sm_scale as stated. However, the class docstring does not document the domain in which sink values are expected. The reference implementation (sink_softmax in sink_attention_reference.py) concatenates sink values directly to logits, confirming they should be in the logit domain (raw QK dot-product units, unnormalized). Add this to the sink parameter documentation to clarify that users must provide sink values in the logit domain, not pre-converted to log2 scale.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/cute_dsl/attention/fusion/variant.py` around lines 367 - 378,
Update the class and relevant parameter docstrings to state that sink values are
expected in the logit domain (raw Q·K dot-product units, unnormalized), not
pre-scaled to log2; specifically mention this near the documentation for the
sink parameter(s) used by update_statistics and the self.params/sink_raw
conversion (which divides by scale/log2_e), and add a cross-reference to
sink_softmax in sink_attention_reference.py so callers know sinks are
concatenated to logits before any log2 scaling.
| # NOTE on CuTe DSL JIT limitations: | ||
| # - partition_output(): Returns tensor tuples — CuTe DSL JIT does not | ||
| # reliably handle returning tensors from @cute.jit methods. | ||
| # - store_tile(): SAFE — takes pre-sliced tensors as arguments, no | ||
| # runtime indexing or return values. Used in run() successfully. | ||
| # ========================================================================= | ||
|
|
||
| @cute.jit | ||
| def partition_output( | ||
| self, | ||
| tma_atom_o: cute.CopyAtom, | ||
| mO_qdl: cute.Tensor, | ||
| sO: cute.Tensor, | ||
| block_coord: tuple, | ||
| ): | ||
| """Partition output global tensor for TMA stores. Returns (tOsO, tOgO).""" | ||
| gO_qdl = cute.flat_divide( | ||
| mO_qdl, cute.select(self.pv_mma_tiler, mode=[0, 1]) | ||
| ) | ||
| gO = gO_qdl[None, None, None, 0, block_coord[2]] | ||
| tOsO, tOgO = cute.nvgpu.cpasync.tma_partition( | ||
| tma_atom_o, 0, cute.make_layout(1), | ||
| cute.group_modes(sO, 0, 2), | ||
| cute.group_modes(gO, 0, 2), | ||
| ) | ||
| return tOsO, tOgO |
There was a problem hiding this comment.
🧩 Analysis chain
🌐 Web query:
Does NVIDIA CUTLASS CuTe DSL reliably support returning tensor objects from @cute.jit functions or methods, or should helpers that return tensors stay non-jitted/inlined?
💡 Result:
No, NVIDIA CUTLASS CuTe DSL does not reliably support returning tensor objects from @cute.jit functions or methods. Helpers that return tensors should stay non-jitted/inlined.
Citations:
Remove @cute.jit from partition_output() or refactor to avoid returning tensors.
This method is decorated with @cute.jit but returns a tensor tuple, contradicting the documented limitation directly above it: "CuTe DSL JIT does not reliably handle returning tensors from @cute.jit methods." Since this is exported as a reusable primitive, it violates its own safety contract. Either remove the @cute.jit decorator or restructure to keep tensor returns outside JIT compilation per NVIDIA documentation.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/cute_dsl/attention/roles/epilogue.py` around lines 41 - 66,
partition_output is incorrectly decorated with `@cute.jit` while returning tensor
objects (tOsO, tOgO) which violates the CuTe JIT limitation; either remove the
`@cute.jit` decorator from partition_output so it runs as a normal Python method,
or refactor it to avoid returning tensors by (a) accepting preallocated output
containers/handles and writing into them, or (b) moving the
cute.nvgpu.cpasync.tma_partition call out of the `@cute.jit` function into a
non-jit wrapper (e.g., create partition_output_nonjit that calls
cute.nvgpu.cpasync.tma_partition and returns tensors or change partition_output
to populate passed-in tensor references); update references to partition_output
accordingly so no `@cute.jit` function returns tensors (symbols: partition_output,
tOsO, tOgO, tma_partition, tma_atom_o).
| @dataclass(frozen=True) | ||
| class WarpSchedule: | ||
| """Defines warp role assignment and register budgets for attention kernels. | ||
|
|
||
| Each field maps directly to C++ CUTLASS's KernelSchedule: | ||
| - Warp ID ranges for each role | ||
| - Register allocation per role (controls spill/occupancy tradeoff) | ||
| - Barrier IDs for CTA sync and TMEM allocation | ||
| """ | ||
|
|
||
| softmax0_warp_ids: Tuple[int, ...] = (0, 1, 2, 3) | ||
| softmax1_warp_ids: Tuple[int, ...] = (4, 5, 6, 7) | ||
| correction_warp_ids: Tuple[int, ...] = (8, 9, 10, 11) | ||
| mma_warp_id: int = 12 | ||
| load_warp_id: int = 13 | ||
| epilogue_warp_id: int = 14 | ||
| empty_warp_id: int = 15 | ||
|
|
||
| num_regs_softmax: int = 192 | ||
| num_regs_correction: int = 96 | ||
| num_regs_other: int = 32 | ||
| num_regs_empty: int = 24 | ||
|
|
||
| threads_per_warp: int = 32 | ||
| cta_sync_bar_id: int = 0 | ||
| tmem_alloc_sync_bar_id: int = 1 | ||
|
|
||
| @property | ||
| def all_warp_ids(self) -> Tuple[int, ...]: | ||
| return ( | ||
| *self.softmax0_warp_ids, | ||
| *self.softmax1_warp_ids, | ||
| *self.correction_warp_ids, | ||
| self.mma_warp_id, | ||
| self.load_warp_id, | ||
| self.epilogue_warp_id, | ||
| self.empty_warp_id, | ||
| ) | ||
|
|
||
| @property | ||
| def num_warps(self) -> int: | ||
| return len(self.all_warp_ids) | ||
|
|
||
| @property | ||
| def threads_per_cta(self) -> int: | ||
| return self.threads_per_warp * self.num_warps | ||
|
|
||
| @property | ||
| def num_warps_per_warpgroup(self) -> int: | ||
| return 4 | ||
|
|
||
| @property | ||
| def softmax_warpgroup_count(self) -> int: | ||
| total_softmax_warps = len(self.softmax0_warp_ids) + len(self.softmax1_warp_ids) | ||
| return total_softmax_warps // self.num_warps_per_warpgroup |
There was a problem hiding this comment.
Validate custom schedules before deriving CTA sizes.
num_warps, threads_per_cta, and softmax_warpgroup_count all assume the warp ids are unique, contiguous from 0, and that the softmax warps fill whole warpgroups. With a custom WarpSchedule, duplicate/gapped ids or a non-multiple-of-4 softmax set will silently produce the wrong CTA/barrier sizing.
🛠️ Suggested fail-fast validation
`@dataclass`(frozen=True)
class WarpSchedule:
@@
threads_per_warp: int = 32
cta_sync_bar_id: int = 0
tmem_alloc_sync_bar_id: int = 1
+
+ def __post_init__(self):
+ all_warp_ids = self.all_warp_ids
+ if len(set(all_warp_ids)) != len(all_warp_ids):
+ raise ValueError("warp ids must be unique across roles")
+ if tuple(sorted(all_warp_ids)) != tuple(range(len(all_warp_ids))):
+ raise ValueError("warp ids must form a contiguous range starting at 0")
+ total_softmax_warps = len(self.softmax0_warp_ids) + len(self.softmax1_warp_ids)
+ if total_softmax_warps % self.num_warps_per_warpgroup != 0:
+ raise ValueError("softmax warps must fill whole warpgroups")
+ if self.cta_sync_bar_id == self.tmem_alloc_sync_bar_id:
+ raise ValueError("barrier ids must be distinct")🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/cute_dsl/attention/warp_schedule.py` around lines 17 - 71, Add a
fail-fast validation in WarpSchedule (implement in a __post_init__ method) that
verifies: 1) all_warp_ids (built from softmax0_warp_ids, softmax1_warp_ids,
correction_warp_ids, mma_warp_id, load_warp_id, epilogue_warp_id, empty_warp_id)
contain unique values and form a contiguous range starting at 0 up to
len(all_warp_ids)-1, and 2) the total number of softmax warps
(len(softmax0_warp_ids)+len(softmax1_warp_ids)) is divisible by
num_warps_per_warpgroup; on violation raise ValueError with a clear message
referencing the failing condition so consumers of num_warps, threads_per_cta,
and softmax_warpgroup_count cannot silently compute incorrect sizes.
|
/bot run |
|
[FAILED] Pipeline #46896453: 10/20 passed |
There was a problem hiding this comment.
Actionable comments posted: 2
♻️ Duplicate comments (1)
flashinfer/cute_dsl/attention/roles/epilogue.py (1)
48-66:⚠️ Potential issue | 🟠 MajorRemove
@cute.jitfrompartition_output()or stop returning tensors.This helper still returns
(tOsO, tOgO), so it reintroduces the exact CuTe JIT limitation called out in the note immediately above it.run()already had to inline the same partitioning logic to avoid that path, which is a strong sign this exported primitive is still unsafe for reuse.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/roles/epilogue.py` around lines 48 - 66, The partition_output helper currently decorated with `@cute.jit` returns tensors (tOsO, tOgO) which reintroduces the CuTe JIT limitation; either remove the `@cute.jit` decorator from partition_output so it runs in Python (and call it from run()), or keep it JIT'd but change its API to not return cute.Tensor/CuTe objects (e.g., perform the tma_partition side-effects inside the function or write results into provided buffers/atoms), updating callers (notably run()) to use the new behavior; locate partition_output, its use of cute.nvgpu.cpasync.tma_partition and the tma_atom_o argument when making the change.
🧹 Nitpick comments (3)
flashinfer/cute_dsl/attention/fusion/variant.py (1)
357-382: Make the basetransform_output()honorscale.The docstring says this hook replaces
output *= scale_output / d, but the fallback implementation dropsscaleand returnsoutput * rcp_d. That makes the base API misleading for any custom variant that enableshas_output_transform=Trueand relies on the inherited behavior.♻️ Proposed fix
- return output * rcp_d + return output * scale * rcp_d🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/fusion/variant.py` around lines 357 - 382, The base transform_output method (transform_output) ignores the scale parameter, returning output * rcp_d which breaks the documented contract and any subclass relying on inherited behavior when has_output_transform=True; modify transform_output to apply the scale as the docstring describes by returning output multiplied by rcp_d and scale (i.e., incorporate the scale/output scaling factor), and ensure callers referring to scale_output/scale still get the correct behavior.tests/test_blackwell_fmha_attention.py (1)
4-13: Move this suite under a kernel-category test subdirectory. This is a new root-level test module, but the CuTe-DSL attention coverage will be easier to discover and maintain if it lives under a feature-specific
`tests/` subdirectory with the rest of the kernel-category tests. As per coding guidelines,
`tests/**/*.py`: Prefix test functions with `test_` and structure tests by feature in `tests/` subdirectories matching kernel categories.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_blackwell_fmha_attention.py` around lines 4 - 13, The test module tests/test_blackwell_fmha_attention.py should be moved into the kernel-category tests subdirectory and follow the test naming convention; relocate the file into the appropriate feature-specific folder under tests/ (e.g., tests/kernel_category/blackwell_fmha/) and ensure all test functions inside the module are prefixed with test_ (rename any non-prefixed functions), keeping the module name descriptive (blackwell_fmha_attention) and updating any imports or test discovery paths accordingly.
flashinfer/cute_dsl/attention/collective_builder.py (1)
137-155: Use `MainloopSpec.barrier_stage_counts()` as the source of truth for `SharedStorage`. This block still hard-codes every barrier array size even though the spec object now exposes those counts explicitly for shared-storage sizing. Keeping topology and storage slots in two places makes the next stage-count/topology tweak easy to desynchronize.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/cute_dsl/attention/collective_builder.py` around lines 137 - 155, The SharedStorage struct currently hard-codes barrier array sizes and conditional logic; instead call mainloop.barrier_stage_counts() once and use its returned counts to size every MemRange so storage and topology come from a single source of truth. Update the s0_corr_stages, mma_corr_stages, s0_epi_stages locals (or remove them) and replace the numeric expressions used in SharedStorage fields (load_q_mbar_ptr, load_kv_mbar_ptr, mma_s0_mbar_ptr, mma_s1_mbar_ptr, s0_corr_mbar_ptr, s1_corr_mbar_ptr, s0_s1_sequence_mbar_ptr, corr_epi_mbar_ptr, mma_corr_mbar_ptr, s0_epi_mbar_ptr, s1_epi_mbar_ptr, tmem_dealloc_mbar_ptr) with the appropriate entries from counts (e.g., counts['q'], counts['kv'], counts['mma_softmax'], counts['s0_corr'], counts['mma_corr'], counts['s0_epi'], sched.softmax_warpgroup_count, etc.), removing the has_logits_transform conditionals so all sizes derive from mainloop.barrier_stage_counts().
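One way to realize the single-source-of-truth suggestion is a spec method that returns every barrier stage count, so the shared-storage struct never re-derives them inline. The sketch below uses hypothetical field and key names, not the actual `MainloopSpec` API:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MainloopSpecSketch:
    """Hypothetical stand-in for MainloopSpec; field names are assumptions."""

    q_stages: int = 2
    kv_stages: int = 4
    softmax_warpgroups: int = 2
    has_logits_transform: bool = False

    def barrier_stage_counts(self):
        # SharedStorage would size each mbar MemRange from this dict instead
        # of repeating the numbers (and has_logits_transform branches) inline
        return {
            "q": self.q_stages,
            "kv": self.kv_stages,
            "mma_softmax": self.kv_stages,
            "s0_corr": 2 if self.has_logits_transform else 1,
            "mma_corr": self.softmax_warpgroups,
        }
```

With this shape, a topology tweak touches one method and the storage layout follows automatically.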
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/cute_dsl/attention/mainloop_spec.py`:
- Around line 21-22: The helper that builds a prefill topology must use the
transform warp schedule when has_logits_transform=True: update the logic in the
factory that constructs PipelineTopology (the code paths using
make_prefill_topology and make_prefill_topology_transform and selecting a
WarpSchedule) to choose WarpSchedule.PREFILL_TRANSFORM_SCHEDULE (or the
module-level PREFILL_TRANSFORM_SCHEDULE constant) instead of PREFILL_SCHEDULE
when has_logits_transform is true; apply the same change to the other occurrence
that mirrors this logic so both prefill/topology-transform code paths use
PREFILL_TRANSFORM_SCHEDULE for transform variants.
In `@flashinfer/cute_dsl/attention/prefill.py`:
- Around line 42-51: Remove the process-wide suppression that calls
warnings.filterwarnings("ignore", category=UserWarning) in prefill.py; keep only
the narrow message-based filter for the CUTLASS loop-unroll warning (the
existing warnings.filterwarnings(..., message="This loop is no longer
unrolled...")) and, if extra suppression is needed around a specific compile
call, wrap that call in a local warnings.catch_warnings() context or apply a
temporary filter immediately around the compile helper rather than mutating the
module/global filter. Target the two warnings.filterwarnings calls in this file
(the message-based and the category-based invocations) and delete or refactor
the category-based one.
---
Duplicate comments:
In `@flashinfer/cute_dsl/attention/roles/epilogue.py`:
- Around line 48-66: The partition_output helper currently decorated with
`@cute.jit` returns tensors (tOsO, tOgO) which reintroduces the CuTe JIT
limitation; either remove the `@cute.jit` decorator from partition_output so it
runs in Python (and call it from run()), or keep it JIT'd but change its API to
not return cute.Tensor/CuTe objects (e.g., perform the tma_partition
side-effects inside the function or write results into provided buffers/atoms),
updating callers (notably run()) to use the new behavior; locate
partition_output, its use of cute.nvgpu.cpasync.tma_partition and the tma_atom_o
argument when making the change.
---
Nitpick comments:
In `@flashinfer/cute_dsl/attention/collective_builder.py`:
- Around line 137-155: The SharedStorage struct currently hard-codes barrier
array sizes and conditional logic; instead call mainloop.barrier_stage_counts()
once and use its returned counts to size every MemRange so storage and topology
come from a single source of truth. Update the s0_corr_stages, mma_corr_stages,
s0_epi_stages locals (or remove them) and replace the numeric expressions used
in SharedStorage fields (load_q_mbar_ptr, load_kv_mbar_ptr, mma_s0_mbar_ptr,
mma_s1_mbar_ptr, s0_corr_mbar_ptr, s1_corr_mbar_ptr, s0_s1_sequence_mbar_ptr,
corr_epi_mbar_ptr, mma_corr_mbar_ptr, s0_epi_mbar_ptr, s1_epi_mbar_ptr,
tmem_dealloc_mbar_ptr) with the appropriate entries from counts (e.g.,
counts['q'], counts['kv'], counts['mma_softmax'], counts['s0_corr'],
counts['mma_corr'], counts['s0_epi'], sched.softmax_warpgroup_count, etc.),
removing the has_logits_transform conditionals so all sizes derive from
mainloop.barrier_stage_counts().
In `@flashinfer/cute_dsl/attention/fusion/variant.py`:
- Around line 357-382: The base transform_output method (transform_output)
ignores the scale parameter, returning output * rcp_d which breaks the
documented contract and any subclass relying on inherited behavior when
has_output_transform=True; modify transform_output to apply the scale as the
docstring describes by returning output multiplied by rcp_d and scale (i.e.,
incorporate the scale/output scaling factor), and ensure callers referring to
scale_output/scale still get the correct behavior.
In `@tests/test_blackwell_fmha_attention.py`:
- Around line 4-13: The test module tests/test_blackwell_fmha_attention.py
should be moved into the kernel-category tests subdirectory and follow the test
naming convention; relocate the file into the appropriate feature-specific
folder under tests/ (e.g., tests/kernel_category/blackwell_fmha/) and ensure all
test functions inside the module are prefixed with test_ (rename any
non-prefixed functions), keeping the module name descriptive
(blackwell_fmha_attention) and updating any imports or test discovery paths
accordingly.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 3433f39f-702a-4c43-999e-b1895a32a960
📒 Files selected for processing (12)
- flashinfer/cute_dsl/attention/__init__.py
- flashinfer/cute_dsl/attention/collective_builder.py
- flashinfer/cute_dsl/attention/fusion/__init__.py
- flashinfer/cute_dsl/attention/fusion/variant.py
- flashinfer/cute_dsl/attention/mainloop_spec.py
- flashinfer/cute_dsl/attention/pipeline_topology.py
- flashinfer/cute_dsl/attention/prefill.py
- flashinfer/cute_dsl/attention/roles/epilogue.py
- flashinfer/cute_dsl/attention/roles/mma.py
- flashinfer/cute_dsl/attention/roles/softmax.py
- flashinfer/cute_dsl/attention/warp_schedule.py
- tests/test_blackwell_fmha_attention.py
🚧 Files skipped from review as they are similar to previous changes (5)
- flashinfer/cute_dsl/attention/fusion/__init__.py
- flashinfer/cute_dsl/attention/__init__.py
- flashinfer/cute_dsl/attention/warp_schedule.py
- flashinfer/cute_dsl/attention/roles/mma.py
- flashinfer/cute_dsl/attention/pipeline_topology.py
from .warp_schedule import WarpSchedule, PREFILL_SCHEDULE
from .pipeline_topology import PipelineTopology, make_prefill_topology, make_prefill_topology_transform
Use the transform schedule when has_logits_transform=True.
This helper still falls back to PREFILL_SCHEDULE for transform variants. That disagrees with `BlackwellFusedMultiHeadAttentionForward.__init__`, which already switches to PREFILL_TRANSFORM_SCHEDULE, so direct callers of this public factory can build a transform topology with the wrong warp schedule.
♻️ Proposed fix
♻️ Proposed fix
-from .warp_schedule import WarpSchedule, PREFILL_SCHEDULE
+from .warp_schedule import (
+ WarpSchedule,
+ PREFILL_SCHEDULE,
+ PREFILL_TRANSFORM_SCHEDULE,
+)
@@
- sched = warp_schedule if warp_schedule is not None else PREFILL_SCHEDULE
+ if warp_schedule is None:
+ sched = (
+ PREFILL_TRANSFORM_SCHEDULE
+ if has_logits_transform
+ else PREFILL_SCHEDULE
+ )
+ else:
+        sched = warp_schedule
Also applies to: 105-110
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/cute_dsl/attention/mainloop_spec.py` around lines 21 - 22, The
helper that builds a prefill topology must use the transform warp schedule when
has_logits_transform=True: update the logic in the factory that constructs
PipelineTopology (the code paths using make_prefill_topology and
make_prefill_topology_transform and selecting a WarpSchedule) to choose
WarpSchedule.PREFILL_TRANSFORM_SCHEDULE (or the module-level
PREFILL_TRANSFORM_SCHEDULE constant) instead of PREFILL_SCHEDULE when
has_logits_transform is true; apply the same change to the other occurrence that
mirrors this logic so both prefill/topology-transform code paths use
PREFILL_TRANSFORM_SCHEDULE for transform variants.
Add CuTe DSL-based attention implementation:
- flashinfer/cute_dsl/attention/ - Modular attention package with composable roles (loader, softmax, MMA, epilogue), fusion points (logits transform, mask, output transform), and schedulers
- flashinfer/cute_dsl/prefill.py - Batch prefill wrapper
- flashinfer/cute_dsl/mla.py - MLA decode wrapper
- flashinfer/cute_dsl/patch/pipeline.py - Pipeline patching utilities
Tests and benchmarks (named to avoid conflicts with existing cutlass tests):
- tests/test_blackwell_fmha_cutedsl.py - FMHA tests (prefill)
- tests/test_blackwell_fmha_attention.py - Modular attention package tests
- tests/test_blackwell_mla_attention.py - MLA attention tests
- tests/test_deepseek_mla_cutedsl.py - DeepSeek MLA tests
- benchmarks/bench_blackwell_attention_cutedsl.py - Attention benchmarks
- docs/cutedsl_fmha_architecture.md - Architecture documentation
Made-with: Cursor
- Delete flashinfer/cute_dsl/prefill.py and mla.py (replaced by the modular flashinfer/cute_dsl/attention/ package)
- Delete tests/test_blackwell_fmha_cutedsl.py and tests/test_deepseek_mla_cutedsl.py (replaced by test_blackwell_fmha_attention.py and test_blackwell_mla_attention.py)
- Revert benchmarks/bench_deepseek_mla.py to upstream version
- Split benchmarks into prefill and decode: bench_blackwell_attention_cutedsl.py (FMHA prefill), bench_blackwell_mla_cutedsl.py (MLA decode)
Made-with: Cursor
…rnels
Two kernel correctness bugs fixed:
1. PV1(end) accumulate flag: The final PV GEMM for stage 1 used hardcoded accumulate=True, causing stale TMEM data corruption when the KV loop didn't execute (kv_len <= tile_size with multi-Q-tile batches). Fix: use pv_whether_acc instead of True.
2. Causal mask trip count: get_masked_trip_count used ceil_div(M, N) which doesn't account for non-zero causal_offset shifting the diagonal across extra KV tiles. When kv_len != qo_len, some tiles needing masking were processed as unmasked, leaking unmasked scores into softmax. Fix: compute masked tile count from actual diagonal boundary positions.
Both fixes required threading seqlen_q through the mask functions and passing causal_offset to apply_mask.
Test suite pruned to ~112 curated cases covering tile boundaries, GQA, varlen, causal, output/logits transforms, and attention sink.
AI-assisted (Claude)
Made-with: Cursor
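The trip-count fix can be illustrated with a host-side sketch (hypothetical helper names; the real kernel works in CuTe tile coordinates). The number of KV tiles that need masking for a Q tile is set by where the shifted diagonal enters and leaves that tile, not by `ceil_div(M, N)`:

```python
def masked_trip_count(q_tile_idx, tile_m, tile_n, seqlen_q, seqlen_k):
    # causal_offset shifts the diagonal when kv_len != qo_len
    offset = seqlen_k - seqlen_q
    q0 = q_tile_idx * tile_m                 # first Q row of this tile
    q1 = min(q0 + tile_m, seqlen_q) - 1      # last Q row of this tile
    first = (q0 + offset) // tile_n          # KV tile holding the first row's boundary
    last = (q1 + offset) // tile_n           # KV tile holding the last row's boundary
    return last - first + 1


def masked_trip_count_ref(q_tile_idx, tile_m, tile_n, seqlen_q, seqlen_k):
    # brute force: count the distinct KV tiles the diagonal actually crosses
    offset = seqlen_k - seqlen_q
    q0 = q_tile_idx * tile_m
    q1 = min(q0 + tile_m, seqlen_q) - 1
    return len({(q + offset) // tile_n for q in range(q0, q1 + 1)})
```

When `offset = 0` and `tile_m == tile_n` this collapses back to one masked tile per Q tile; with a non-zero offset the diagonal can straddle an extra KV tile, which the old `ceil_div` count missed.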
…ate domain conversion
The plan() method created a dummy sink tensor with hardcoded float16 dtype for JIT compilation regardless of input dtype. When bfloat16 inputs were used at runtime, the compiled kernel misinterpreted bf16 bits as fp16, garbling sink values (causal row-0 error: 1.75 -> 0.004).
Also fix the sink_M_D_update test helper to properly convert the sink value from scaled-logit space to raw-logit space by dividing by scale, and tighten the sink test tolerance from atol=2.0 to atol=0.01.
AI-assisted (Claude)
Made-with: Cursor
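The scaled-to-raw sink conversion can be sketched in NumPy (a hedged reference, not the kernel code; `softmax_with_sink` is a hypothetical helper name). A sink supplied in scaled-logit space must be divided by `scale` before entering the raw-logit path, so the kernel's internal multiplication by `scale` recovers the original value:

```python
import numpy as np


def softmax_with_sink(logits, sink_scaled, scale):
    # The kernel multiplies raw logits by `scale` internally, so a sink
    # given in scaled-logit space is converted to the raw domain first.
    sink_raw = sink_scaled / scale
    z = np.concatenate([logits, [sink_raw]]) * scale
    e = np.exp(z - z.max())
    p = e / e.sum()
    return p[:-1]  # the sink only contributes to the denominator
```

Skipping the division double-scales the sink, which is the garbled-value symptom the commit describes.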
…ve tests (AI-assisted)
Kernel fixes:
- Sliding window apply_mask: add missing left-bound check (|kv-q| > window) and seqlen_k bounds check
- Sliding window get_trip_count/get_kv_start_block_idx: compute correct symmetric window tile range instead of right-only approximation
- Softmax run(): add kv_start_offset to coordinate identity tensor so mask coordinates match actual KV positions loaded by the TMA loader
Test fixes:
- sink_M_D_update: add * scale to exp2 rescale terms for correctness (m is in RAW domain, exp2 needs domain conversion via * scale)
- Sink test: use SM_SCALE=1/sqrt(head_dim) instead of 1.0, which made the sink contribution negligible (~0) and the test vacuous
New test coverage:
- float16 dtype (3 shapes x 2 causal)
- Sliding window mask (4 window/shape combos)
- head_dim=64 (3 shapes x 2 causal)
- Variable-length + sigmoid logits transform (2 indptr patterns)
- Variable-length + attention sink (2 indptr patterns)
- Attention sink with MHA / num_kv_heads=32 (2 shapes x 2 causal)
All 118 tests pass, 18 skipped (qo>kv+causal), ~10 min runtime.
Made-with: Cursor
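The `* scale` domain-conversion point can be checked numerically with an online-softmax sketch (assumptions: the running max `m` is kept in raw-logit space, and the exp2 scale folds in `log2(e)`; this mirrors the described fix, not the kernel's exact code):

```python
import numpy as np

LOG2E = 1.4426950408889634


def online_denominator(scores, sm_scale, tile=4):
    # m is tracked in the RAW logit domain, so every exponent handed to
    # exp2 needs the "* scale" conversion, where scale = sm_scale * log2(e).
    scale = sm_scale * LOG2E
    m, d = -np.inf, 0.0
    for start in range(0, len(scores), tile):
        blk = scores[start:start + tile]
        m_new = max(m, blk.max())
        d = d * np.exp2((m - m_new) * scale) + np.exp2((blk - m_new) * scale).sum()
        m = m_new
    return m, d
```

Dropping the `* scale` in the rescale term silently mixes raw and scaled domains, which only shows up when `sm_scale != 1`.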
Strip out MLA decode kernel, config, warp schedule, roles, scheduler, wrapper, benchmark, test, and architecture doc to keep this PR focused on FMHA prefill only. Clean up MLA references in shared modules. AI-assisted Made-with: Cursor
The PipelineProducer/PipelineConsumer wrappers are now available in cutlass.pipeline (nvidia-cutlass-dsl >= 4.3). Use them directly instead of maintaining a local copy. Pipeline creation uses defer_sync=True since the kernel handles barrier init/sync separately. Verified: no perf regression (< 1% noise), all 118 tests pass. AI-assisted Made-with: Cursor
…ed tmem_utils.py
- Add missing @cute.jit decorator to get_trip_count for consistency with all other functions in mask.py
- Remove tmem_utils.py which was MLA-specific dead code after MLA removal
AI-assisted
Made-with: Cursor
…uards (AI-assisted)
Review feedback from yzh119:
- Switch FMHA prefill compile path to TVM-FFI: use make_fake_compact_tensor
for compile-time tracing and pass PyTorch tensors directly at runtime,
matching the pattern used by the MLA decode kernel. This replaces the
from_dlpack + .iterator approach which is incompatible with --enable-tvm-ffi.
- Change prefill.py __call__ from cute.Pointer to cute.Tensor parameters;
the kernel extracts .iterator internally.
- Add --enable-tvm-ffi --opt-level 2 to cute.compile() and use
make_fake_stream(use_tvm_ffi_env_stream=True) for stream handling.
- Add @flashinfer_api decorator to BatchPrefillCuteDSLWrapper __init__,
plan, and run for crash-safe API logging.
- Add is_sm100a_supported() guard and enable_cupti=True in benchmark.
- Remove broad warnings.filterwarnings("ignore", category=UserWarning)
from prefill.py, mla_decode.py, mla_decode_fp8.py; keep only the
targeted cutlass loop-unroll message filter.
- Fix NamedBarrier.wait() -> arrive_and_wait() in MLA roles to use the
correct API (they are identical at the hardware level, but wait()
triggers a deprecation warning from cutlass-dsl).
Made-with: Cursor
|
/bot run |
- Auto-fix 33 unused import warnings across attention package (ruff --fix)
- Auto-format 23 files to match ruff style
- Remove unused variable tOcO_custom_i in softmax.py
- Remove dead code qo_idx_offset in softmax.py
- Add strict=True to zip() in scheduler/persistent.py
- Add assert guards before .arrive_and_wait() calls on Optional barriers in mla_correction.py and mla_compute.py to satisfy mypy
Made-with: Cursor
Add Type[cutlass.Numeric] annotations to dtype attributes in mla_decode, mla_decode_fp8, mla_compute, softmax, and correction roles — matching the pattern from gemm_allreduce_two_shot.py. Add assert guards for Optional attributes accessed in type-checked code paths. Add params attribute to AttentionVariant base class. Use Union type for MLAMainloopSpec warp_schedule to accept both FP16 and FP8 schedules. Made-with: Cursor
…(AI-assisted)
- batch_mla.py: validate 4D kv_cache has shape[1]==1 before squeeze(1), and reject non-3D/4D inputs. Prevents silent wrong page_size detection.
- test_modular_fmha_prefill.py: document that sliding window with qo_len != kv_len is not yet supported (mask offset + coordinate identity tensor fixes needed).
Made-with: Cursor
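The shape check can be sketched with plain arrays (a hypothetical helper; the actual wrapper operates on torch tensors, and the field layout is assumed):

```python
import numpy as np


def normalize_kv_cache(kv_cache):
    # A 4D cache must have shape[1] == 1 (page dimension) before the
    # squeeze; anything else would silently mis-detect page_size downstream.
    if kv_cache.ndim == 4:
        if kv_cache.shape[1] != 1:
            raise ValueError(
                f"4D kv_cache must have shape[1] == 1, got {kv_cache.shape}"
            )
        return kv_cache.squeeze(1)
    if kv_cache.ndim == 3:
        return kv_cache
    raise ValueError(f"kv_cache must be 3D or 4D, got {kv_cache.ndim}D")
```

Failing loudly here is preferable to the previous behavior, where an unexpected layout produced a wrong page size without any error.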
The sliding window mask functions computed KV tile ranges and element masks using raw Q tile indices, ignoring the Q/K sequence length offset. When qo_len < kv_len (e.g. suffix-prefill, append-with-cache), the window was centered on the wrong KV region. Fix: add qk_offset = seqlen_k - seqlen_q to Q positions in get_trip_count, get_kv_start_block_idx, and apply_mask, matching the causal mask path which already uses this offset. Add 4 test cases with qo_len != kv_len to SLIDING_WINDOW_PARAMS and update the reference implementation to right-align Q positions to KV. Made-with: Cursor
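The corrected window predicate, sketched element-wise (hypothetical function; the kernel evaluates this per tile via coordinate identity tensors, and exact boundary inclusivity is an assumption here):

```python
def sliding_window_keep(q, kv, seqlen_q, seqlen_k, window):
    # Right-align Q positions to KV: row q sits at absolute position
    # q + (seqlen_k - seqlen_q), matching the causal-mask offset.
    offset = seqlen_k - seqlen_q
    pos = q + offset
    return kv < seqlen_k and kv <= pos and pos - kv <= window
```

Without the `offset` term, a suffix-prefill request (qo_len < kv_len) would center the window on the wrong KV region, which is exactly the bug this commit fixes.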
|
/bot run |
|
@pgera is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww |
Add crash-safe API logging decorator to __init__, plan, and run methods, matching the pattern used by BatchPrefillCuteDSLWrapper and all other FlashInfer wrapper classes. Made-with: Cursor
- Add _validate_run_inputs() to BatchMLADecodeCuteDSLWrapper, checking dtype/device/shape consistency (matching prefill wrapper pattern) - Replace bare assert with raise ValueError for D_qk and workspace size checks in MLA run() - Add docstring to prefill plan() method - Add comment explaining float_workspace_buffer naming convention Made-with: Cursor
Replace full 25-line BSD-3-Clause license text with the abbreviated 2-line SPDX header in MLA files ported from the monolithic codebase. Update copyright year from 2025-2026 to 2026 across all 11 files. Made-with: Cursor
…sted) Skip gracefully when nvidia-cutlass-dsl package is not installed, matching the pattern used by test_cute_dsl_mla_decode.py. Made-with: Cursor
|
cc @leejnau about cutedsl mla update |
…isted)
Wire AttentionVariant hooks (score_mod, update_statistics, transform_output) through both FP16 and FP8 MLA decode kernels, unify compilation caching under @functools.cache for prefill and MLA decode, and switch prefill to symbolic tensor dimensions for cross-batch kernel reuse.
Key changes:
- mla_decode.py / mla_decode_fp8.py: accept fusion + params_in, thread params and cta_m_offset through compute and correction roles
- mla_compute.py: integrate score_mod and update_statistics hooks with post-score_mod re-masking to prevent -inf→finite leakage (SoftCapping)
- mla_correction.py: integrate transform_output hook (AttentionWithSink)
- softmax.py: re-apply mask after score_mod for prefill (same fix)
- batch_mla.py: merge _get_compiled_mla_kernel into single _compile_mla_kernel with variant/params_shape cache keys; add logits_transform validation
- batch_prefill.py: extract _get_compiled_prefill_kernel with symbolic dims and @functools.cache for batch-independent compilation
- variant.py: fix NaN in AttentionWithSink.update_statistics when split-KV CTAs don't own tile 0 (-inf - (-inf) = NaN)
Tests: ALiBi, SoftCapping, AttentionWithSink, RPE for BF16 + FP8 MLA decode, plus SoftCapping regression tests for non-tile-aligned sequences.
Made-with: Cursor
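The `-inf - (-inf) = NaN` hazard and its guard can be shown in a few lines (a minimal sketch; the real fix lives in `AttentionWithSink.update_statistics`, and `rescale_factor` is a hypothetical name):

```python
import math


def rescale_factor(m_old, m_new):
    # A split-KV CTA that never owned tile 0 still has m_old == -inf; the
    # naive exp(m_old - m_new) would evaluate exp(-inf - (-inf)) = exp(nan).
    if math.isinf(m_old) and m_old < 0:
        return 0.0  # nothing accumulated yet, so the rescale contributes nothing
    return math.exp(m_old - m_new)
```

The guard is exact rather than approximate: when the running max is still `-inf`, the partial accumulator is empty, so a zero rescale factor is the correct limit.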
…-assisted)
Pre-allocate a padded output scratch buffer in plan() instead of allocating per run() call. The kernel's TMA varlen addressing requires front-padding on the output tensor (negative pointer offset trick shared with the CUTLASS backend), so run() writes into the scratch buffer and copies back to the caller's out tensor when provided. This honors the out= in-place contract that serving frameworks like vLLM rely on.
Other changes:
- Add jit_args guard for cute-dsl backend (skip unnecessary JIT compilation)
- Document why return_lse/FP8 scale checks must live at run() time
- Replace outdated standalone-script docstring in prefill.py
- Add test verifying the out= in-place contract
Made-with: Cursor
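The out= contract can be sketched as follows (a hypothetical helper with NumPy standing in for device tensors; the real kernel addresses the padded scratch through a TMA descriptor with a negative pointer offset):

```python
import numpy as np


def run_with_scratch(write_output, total_q, head_dim, pad, out=None):
    # scratch is front-padded so the kernel can address it with a negative
    # offset; results are copied back so the caller's `out` is filled in place.
    scratch = np.zeros((pad + total_q, head_dim), dtype=np.float32)
    view = scratch[pad:]
    write_output(view)  # stand-in for the kernel launch
    if out is None:
        return view
    np.copyto(out, view)
    return out
```

The copy-back is what preserves the in-place semantics: the object the caller passed as `out` is the object that ends up holding the results.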
|
/bot run |
Add early NotImplementedError in BatchPrefillWithPagedKVCacheWrapper when backend="cute-dsl" is passed, since paged KV cache is not yet supported by the cute-dsl kernel. Also add the jit_args guard (backend != "cute-dsl") for future-proofing when paged support is added. Made-with: Cursor
saltyminty
left a comment
Approved – internal CI fails on test_trtllm_fused_moe_autotuner_integration.py, which should be unrelated.
Summary
Modular rewrite of the CuTe DSL attention kernels, refactored into composable building blocks. Includes FMHA prefill (from #1549) and MLA decode (replacing the monolithic
flashinfer/mla/cute_dsl/kernels), with bug fixes, new attention variants, and comprehensive tests.
Architecture
FMHA prefill features
- `score_mod` hook: ALiBi (1-D params), RPE (2-D params), SoftCapping (compile-time only)
- `update_statistics`
- `BatchPrefillWithRaggedKVCacheWrapper` via `backend="cute-dsl"`
- `make_fake_compact_tensor` for compile-time tracing, PyTorch tensors at runtime
MLA decode features
Bug fixes vs #1549
Key files
- flashinfer/cute_dsl/attention/prefill.py
- flashinfer/cute_dsl/attention/mla_decode.py
- flashinfer/cute_dsl/attention/mla_decode_fp8.py
- flashinfer/cute_dsl/attention/roles/
- flashinfer/cute_dsl/attention/fusion/
- flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
- flashinfer/cute_dsl/attention/wrappers/batch_mla.py
- flashinfer/cute_dsl/attention/config.py
- flashinfer/cute_dsl/attention/mla_config.py
- flashinfer/cute_dsl/attention/pipeline_topology.py
- flashinfer/cute_dsl/attention/collective_builder.py
- tests/attention/test_modular_fmha_prefill.py
- tests/attention/test_cute_dsl_mla_decode.py
- benchmarks/bench_blackwell_attention_cutedsl.py
Test plan
cc: @yzh119