
[CuTe DSL] Add modular FMHA prefill and MLA decode attention kernels#2805

Open
pgera wants to merge 41 commits into flashinfer-ai:main from pgera:cutedsl-fmha-prefill

Conversation

@pgera pgera commented Mar 17, 2026

Summary

Modular rewrite of the CuTe DSL attention kernels, refactored into composable building blocks. Includes FMHA prefill (from #1549) and MLA decode (replacing the monolithic flashinfer/mla/cute_dsl/ kernels), with bug fixes, new attention variants, and comprehensive tests.

Architecture

  • Composable "roles" (Loader, MMA, Softmax, Correction, Epilogue) connected by declarative pipeline topologies
  • Shared infrastructure across FMHA prefill and MLA decode (pipeline topology, collective builder, mainloop spec)
  • MLA-specific roles for paged KV, page-table loading, and split-KV reduction
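The role/topology split described above can be sketched in plain Python. This is a hypothetical illustration of what a declarative pipeline topology might look like — the class and method names (`Edge`, `Topology`, `connect`) are invented for this sketch, not the PR's actual API:

```python
from dataclasses import dataclass, field

# Hypothetical sketch: roles are named stages, and edges declare
# producer/consumer pairs together with a pipeline depth (buffer count).
@dataclass(frozen=True)
class Edge:
    producer: str
    consumer: str
    stages: int  # number of in-flight buffers between the two roles

@dataclass
class Topology:
    roles: list = field(default_factory=list)
    edges: list = field(default_factory=list)

    def connect(self, producer, consumer, stages=2):
        for role in (producer, consumer):
            if role not in self.roles:
                self.roles.append(role)
        self.edges.append(Edge(producer, consumer, stages))
        return self

    def consumers_of(self, role):
        return [e.consumer for e in self.edges if e.producer == role]

# FMHA-style wiring: Loader feeds MMA, MMA feeds Softmax, and so on.
topo = (
    Topology()
    .connect("Loader", "MMA", stages=3)
    .connect("MMA", "Softmax")
    .connect("Softmax", "Correction")
    .connect("Correction", "Epilogue")
)
```

The payoff of declaring the graph rather than hard-coding it is that a collective builder can derive barrier counts and SMEM staging from the edges, and MLA-specific roles can be wired in without touching the FMHA pipeline.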

FMHA prefill features

  • bf16/fp16, head_dim 64/128, GQA/MHA, causal/sliding-window masks
  • Custom logits transforms (sigmoid, sigmoid-tanh via MUFU.TANH), output transforms
  • Attention variants via score_mod hook: ALiBi (1-D params), RPE (2-D params), SoftCapping (compile-time only)
  • Attention sink with custom update_statistics
  • Variable-length sequences (ragged layout)
  • Wired into BatchPrefillWithRaggedKVCacheWrapper via backend="cute-dsl"
  • TVM-FFI compile path: make_fake_compact_tensor for compile-time tracing, PyTorch tensors at runtime

MLA decode features

  • FP16/BF16 and FP8 (separate kernel class + roles)
  • Paged KV cache with configurable page sizes
  • Persistent tile scheduler with split-KV reduction
  • Replaces ~9,700 lines of monolithic code with modular roles
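The split-KV reduction mentioned above has a simple mathematical core: each split produces a normalized partial output plus its online-softmax statistics (row max m_i, row sum d_i), and the reduction rescales each split by exp(m_i − m) before renormalizing. A NumPy sketch with hypothetical helper names:

```python
import numpy as np

# Illustrative split-KV reduction, not the kernel's actual reduction role.
def merge_splits(partials):
    # partials: list of (o, m, d); o: (rows, head_dim), m and d: (rows,)
    m = np.max(np.stack([p[1] for p in partials]), axis=0)
    d = np.zeros_like(m)
    o = np.zeros_like(partials[0][0])
    for o_i, m_i, d_i in partials:
        scale = np.exp(m_i - m)       # rescale each split to the global max
        d += d_i * scale
        o += o_i * (d_i * scale)[:, None]
    return o / d[:, None]

def attend(scores, values):
    # Plain softmax attention over one KV chunk, returning (o, m, d).
    m = scores.max(axis=-1)
    e = np.exp(scores - m[:, None])
    d = e.sum(axis=-1)
    return e @ values / d[:, None], m, d

rng = np.random.default_rng(0)
scores = rng.normal(size=(2, 8))
values = rng.normal(size=(8, 4))
full, _, _ = attend(scores, values)
merged = merge_splits([attend(scores[:, :4], values[:4]),
                       attend(scores[:, 4:], values[4:])])
assert np.allclose(merged, full)
```

Because the merge is exact (not an approximation), the persistent scheduler is free to split long KV sequences across CTAs purely for load balance.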

Bug fixes vs #1549

  • Causal mask boundary + PV accumulate: fixed off-by-one in causal mask and accumulation bug
  • Attention sink dtype mismatch: wrapper hardcoded fp16 for sink tensor regardless of input dtype
  • Attention sink M_D_update domain: corrected domain conversion and exp2 scaling in online softmax
  • Sliding window mask (4 issues): missing left-bound check, incorrect trip count/start block for symmetric windows, KV coordinate offset mismatch, and Q/K offset for unequal sequence lengths
  • SigmoidAttention bias: bias not converted to log-base-2 domain
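The last fix above hinges on the base-2 domain the online softmax runs in: the kernels evaluate exponentials with exp2 rather than exp, so any additive bias expressed in natural-log units must be scaled by log2(e) first. A scalar sketch of the identity (the bias value is hypothetical):

```python
import math

# exp(x) == 2 ** (x * log2(e)) — the conversion behind the base-2 softmax.
LOG2_E = math.log2(math.e)
x = 1.7
assert math.isclose(math.exp(x), 2.0 ** (x * LOG2_E))

# The SigmoidAttention fix applies the same scaling to the bias term
# before it is added to the base-2 logits (illustrative value).
bias_natural = -0.5
bias_log2 = bias_natural * LOG2_E
```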

Key files

Path Description
flashinfer/cute_dsl/attention/prefill.py FMHA prefill kernel
flashinfer/cute_dsl/attention/mla_decode.py MLA decode kernel (FP16/BF16)
flashinfer/cute_dsl/attention/mla_decode_fp8.py MLA decode kernel (FP8)
flashinfer/cute_dsl/attention/roles/ Warp role implementations (FMHA + MLA)
flashinfer/cute_dsl/attention/fusion/ Mask, logits transform, attention variants
flashinfer/cute_dsl/attention/wrappers/batch_prefill.py FMHA prefill PyTorch wrapper (TVM-FFI)
flashinfer/cute_dsl/attention/wrappers/batch_mla.py MLA decode PyTorch wrapper
flashinfer/cute_dsl/attention/config.py FMHA kernel configuration
flashinfer/cute_dsl/attention/mla_config.py MLA kernel configuration
flashinfer/cute_dsl/attention/pipeline_topology.py Declarative pipeline graph
flashinfer/cute_dsl/attention/collective_builder.py MMA atoms, TMA, SharedStorage factory
tests/attention/test_modular_fmha_prefill.py FMHA test suite (152 cases)
tests/attention/test_cute_dsl_mla_decode.py MLA decode test suite (252 cases)
benchmarks/bench_blackwell_attention_cutedsl.py FMHA benchmark (CUPTI timing)

Test plan

  • FMHA prefill: 152 tests (bf16/fp16, causal, sliding window, head_dim 64/128, GQA/MHA, sigmoid/sigmoid-tanh, ALiBi, RPE, attention sink, varlen)
  • MLA decode FP16: 232 tests (bf16/fp16, page_size 32/128, batch 1-32, seq_len 128-8192, variable-seq)
  • MLA decode FP8: 12 tests
  • MLA decode vs trtllm-gen: 8 cross-validation tests
  • Sliding window with qo_len != kv_len: 4 tests
  • Benchmark with CUPTI timing on SM100
  • Pre-commit: mypy, ruff check, ruff format all pass

cc: @yzh119

Contributor

coderabbitai bot commented Mar 17, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Adds a full CuTe-DSL Blackwell attention implementation: new FMHA prefill and modular MLA decode kernels, many CuTe role primitives, pipeline/topology and scheduler abstractions, configuration/fusion/masking APIs, PyTorch wrappers (prefill & MLA), FP8 variants, benchmarking script, and extensive CUDA tests.

Changes

Cohort / File(s) Summary
Benchmark
benchmarks/bench_blackwell_attention_cutedsl.py
New CUDA benchmarking script for batched prefill attention using FlashInfer backends (cutlass/CuteDSL), reporting median latency, TFLOPs/s, and GB/s.
Public package entry & re-exports
flashinfer/cute_dsl/attention/__init__.py, flashinfer/cute_dsl/attention/wrappers/__init__.py, flashinfer/cute_dsl/attention/roles/__init__.py, flashinfer/cute_dsl/attention/scheduler/__init__.py
New top-level initializers that re-export kernels, pipeline/scheduler types, fusion utilities, role classes, and wrapper APIs.
Configuration / layouts / compat
flashinfer/cute_dsl/attention/config.py, .../tmem_layout.py, .../warp_schedule.py, .../compat.py, .../mainloop_spec.py
Add AttentionConfig/AttentionFusion, TileBounds/HeadMapping, deterministic TmemLayout, WarpSchedule presets, mainloop spec builders/resolvers, and cutlass-dsl compatibility shims.
Fusion / masking / variants
flashinfer/cute_dsl/attention/fusion/... (__init__.py, mask.py, variant.py)
Introduce MaskType enum and JIT masking helpers, AttentionVariant base and multiple concrete variants (Standard, Sink, Sigmoid, SigmoidTanh, ALiBi, RPE, SoftCapping) with JIT hooks.
Pipeline topology & launch param builder
flashinfer/cute_dsl/attention/pipeline_topology.py, .../collective_builder.py
Declarative pipeline-graph types (PipelineType/Edge/Topology) and factories; builder functions deriving MMA/TMA descriptors, staged SMEM/TMA atoms, SharedStorage structs and launch params for FMHA and MLA.
Kernel implementations (FMHA & MLA, FP8)
flashinfer/cute_dsl/attention/prefill.py, .../mla_decode.py, .../mla_decode_fp8.py
BlackwellFusedMultiHeadAttentionForward and BlackwellMultiLatentAttentionForward(/FP8): mainloop kernels, grid/SMEM computation, role orchestration, split-KV reduction, can_implement and workspace helpers.
Roles (per-stage kernels)
flashinfer/cute_dsl/attention/roles/... (many files)
New role modules and JIT entrypoints: LoaderRole (TMA loads), MmaRole (GEMMs, TMEM alloc/dealloc), SoftmaxRole (+ math), CorrectionRole, EpilogueRole, and MLA-specific roles (page-table loader, mla_loader, mla_mma, mla_compute, mla_correction) plus FP8 loader/mma roles.
Softmax math helpers
flashinfer/cute_dsl/attention/roles/softmax_math.py
Packed exp2 and packed row-sum JIT helpers used by SoftmaxRole.
Schedulers
flashinfer/cute_dsl/attention/scheduler/persistent.py, .../scheduler/mla_persistent.py
FmhaStaticTileScheduler and MLAStaticTileScheduler implementations with params, MLIR (de)serialization hooks, WorkTileInfo, and factory constructors.
MLA config & schedules
flashinfer/cute_dsl/attention/mla_config.py, .../mla_warp_schedule.py
Immutable MLAConfig dataclass with derived tiler properties and FP8 variant checks; MLA warp schedules and named-barrier helpers.
Wrappers / PyTorch integration
flashinfer/cute_dsl/attention/wrappers/batch_prefill.py, .../wrappers/batch_mla.py
BatchPrefillCuteDSLWrapper (plan/run, qkv padding/helpers) and BatchMLADecodeCuteDSLWrapper + cute_dsl_mla_decode with split-KV/workspace handling and compilation caching.
Tests
tests/test_blackwell_fmha_attention.py, tests/attention/test_modular_mla_decode.py
Extensive CUDA pytest suites and reference implementations covering FMHA variants, masking/modes, variable-length sequences, MLA modular decode correctness and comparisons vs monolithic kernels.

Sequence Diagram(s)

sequenceDiagram
    participant Host as Host (PyTorch)
    participant Wrapper as BatchPrefillWrapper
    participant Kernel as Compiled Kernel
    participant Loader as LoaderRole (TMA)
    participant MMA as MmaRole
    participant Softmax as SoftmaxRole
    participant Corr as Correction/Epilogue
    participant GMEM as Global Memory / TMEM

    Host->>Wrapper: plan()/run(q, k, v, indptr...)
    Wrapper->>Kernel: launch with DLPack tensors & params
    Kernel->>Loader: request Q/K/V tiles (TMA producers)
    Loader->>GMEM: TMA load Q/K/V -> TMEM/SMEM
    Loader-->>MMA: hand off Q/K/P/V fragments
    MMA->>Softmax: produce QK scores / write logits to TMEM
    Softmax->>GMEM: apply mask/variant, write exponentials / row stats
    Softmax->>MMA: provide P fragments for PV GEMM
    MMA->>Corr: commit partial outputs (handles)
    Corr->>GMEM: rescale and store final outputs (epilogue)
    Kernel-->>Host: return out tensor

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

cute-dsl, ready

Suggested reviewers

  • sricketts
  • aleozlx
  • yzh119
  • yongwww
  • cyx-6
  • jimmyzho
  • bkryu
  • kahyunnam
  • nv-yunzheq
  • samuellees
  • saltyminty
  • yyihuang
  • jiahanc

Poem

🐰 In CuTe fields where kernels hop and play,
I stitch softmax threads through night and day.
Warps leap in order, tiles hum in tune,
From loader to epilogue under Blackwell’s moon.
A rabbit applauds — benchmarks bloom, hooray!


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a refactored and improved FMHA prefill kernel using a modular design. The new implementation addresses several bugs present in the previous version and offers a more flexible and testable architecture. By composing roles and using declarative pipeline topologies, the kernel supports a wide range of configurations and customizations, making it suitable for various attention mechanisms.

Highlights

  • Modular FMHA Prefill Kernel: This PR introduces a modular rewrite of the FMHA prefill kernel, enhancing composability and flexibility.
  • Bug Fixes: Addresses several bugs in the previous implementation, including causal mask boundary issues, dtype mismatches, and sliding window mask errors.
  • Comprehensive Testing: Includes a comprehensive test suite with 8 test cases covering various configurations (bf16/fp16, causal, sliding window, head_dim 64/128, GQA/MHA, sigmoid logits, attention sink, varlen).
  • Composable Architecture: The architecture is designed with composable "roles" connected by declarative pipeline topologies, facilitating easy swapping of warp schedules and fusion strategies.
  • Extensive Support: Supports bf16/fp16, head_dim 64/128, GQA/MHA, causal/sliding-window masks, custom logits transforms, output transforms, attention sink with custom M_D_update, and variable-length sequences.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Activity
  • The PR introduces a new modular FMHA prefill kernel.
  • It includes bug fixes for causal mask boundary, attention sink dtype mismatch, attention sink M_D_update domain, and sliding window mask issues.
  • The PR adds a comprehensive test suite with 8 test cases.
  • It refactors the kernel into composable building blocks with declarative pipeline topologies.
  • The PR supports various configurations such as bf16/fp16, head_dim 64/128, GQA/MHA, causal/sliding-window masks, custom logits transforms, output transforms, attention sink, and variable-length sequences.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant and well-structured modular rewrite of the FMHA prefill kernel using CuTe DSL. The new architecture, based on composable roles and declarative pipeline topologies, is a major improvement for maintainability and extensibility. The comprehensive test suite covers a wide range of configurations, ensuring the robustness of the new implementation. My review includes a few suggestions for code cleanup, improving maintainability by adding documentation, and addressing potential issues like unused code and a missing JIT decorator.

Comment on lines +1 to +101
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

"""Shared TMEM utilities for compute roles.

Provides tmem_load_partition() — partitions TMEM output accumulator for
load/store by the rescale and epilogue roles.
"""

from types import SimpleNamespace

import cutlass
import cutlass.cute as cute
import cutlass.cute.nvgpu.tcgen05 as tcgen05


@cute.jit
def tmem_load_partition(
    tmem_ptr: cutlass.Int32,
    tmem_o_offset: int,
    acc_dtype: cutlass.Constexpr,
    mma_pv_tiler: cutlass.Constexpr,
    cluster_shape_mnk: cutlass.Constexpr,
    warps_in_n: int,
    num_compute_warps: int,
    threads_per_warp: int,
    common_params: SimpleNamespace,
    tiled_mma_pv: cute.TiledMma,
    iter_n: int,
) -> tuple[
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
    cute.TiledMma,
]:
    tOtO_shape = tiled_mma_pv.partition_shape_C(
        cute.select(mma_pv_tiler, mode=[0, 1])
    )
    tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape)
    tOtO_layout = cute.append(
        tOtO.layout,
        cute.make_layout(
            common_params.L // mma_pv_tiler[1],
            stride=mma_pv_tiler[1] // warps_in_n,
        ),
    )
    tOtO = cute.make_tensor(tmem_ptr + tmem_o_offset, tOtO_layout)
    tOtO = tOtO[None, None, None, iter_n]

    tAcc = tOtO[(None, None), 0, 0]

    tmem_load_atom = cute.make_copy_atom(
        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), acc_dtype
    )
    tmem_load_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc)
    tmem_load_thr_copy = tmem_load_tiled_copy.get_slice(
        common_params.tidx % (num_compute_warps * threads_per_warp)
    )

    cta_pv_tiler = (
        mma_pv_tiler[0] // cluster_shape_mnk[0],
        mma_pv_tiler[1],
        mma_pv_tiler[2],
    )
    cta_pv_tiler_mn = cute.select(cta_pv_tiler, mode=[0, 1])

    gO = None
    if cutlass.const_expr(common_params.mAccO is not None):
        gO = cute.local_tile(
            common_params.mAccO[None, common_params.blk_coord[3], None, None],
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )
        cO = cute.local_tile(
            cute.make_identity_tensor(
                common_params.mAccO[
                    None, common_params.blk_coord[3], None, None
                ].shape
            ),
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )
    else:
        gO = cute.local_tile(
            common_params.mO,
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )
        cO = cute.local_tile(
            cute.make_identity_tensor(common_params.mO.shape),
            cta_pv_tiler_mn,
            (common_params.blk_coord[0], iter_n, common_params.blk_coord[2]),
        )

    tTR_tAcc = tmem_load_thr_copy.partition_S(tAcc)
    tTR_gO = tmem_load_thr_copy.partition_D(gO)
    tTR_cO = tmem_load_thr_copy.partition_D(cO)
    tTR_rAcc = cute.make_fragment_like(tTR_gO, acc_dtype)
    return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc

medium

This file tmem_utils.py and the function tmem_load_partition within it do not appear to be used anywhere in the codebase. Unused code can lead to confusion and maintenance overhead. Please either integrate it into the kernel or remove it if it's obsolete.

Comment on lines +171 to +177
if s_k.shape[0] > 1:
    for i in range(len(s_k)):
        if s_k[i] % self._mma_tiler_mn[1] != 0:
            self._mask_type = MaskType.RESIDUAL_MASK
else:
    if s_k % self._mma_tiler_mn[1] != 0:
        self._mask_type = MaskType.RESIDUAL_MASK

medium

The logic to determine if RESIDUAL_MASK is needed can be simplified. The current implementation iterates over the s_k tensor and has a branch for s_k.shape[0] > 1 which is always taken since s_k is derived from kv_indptr and will be a 1D tensor. You can use torch.any for a more concise and efficient check.

            if torch.any(s_k % self._mma_tiler_mn[1] != 0):
                self._mask_type = MaskType.RESIDUAL_MASK

@@ -0,0 +1,419 @@
from typing import Optional, Type, Tuple

medium

This file provides a custom implementation of pipeline participants, which appears to be a patch on top of cutlass.pipeline. However, it's missing a file-level docstring explaining why this custom implementation is necessary and what it changes compared to the original. Adding a docstring would greatly improve maintainability and make it easier for other developers to understand the purpose of this module.

def sink_M_D_update(params, kv_tile_idx, qo_head_idx, m, d, scale):
    # m is in the RAW (unscaled) domain; convert sink from scaled-logit to RAW
    log2_e = math.log2(math.exp(1.0))
    sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf

medium

The condition qo_head_idx < NUM_QO_HEADS is redundant because qo_head_idx is a grid coordinate over the heads dimension, which is sized to NUM_QO_HEADS. Therefore, qo_head_idx will always be less than NUM_QO_HEADS. You can simplify the expression.

Suggested change
sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf
sink_raw = params.sink[qo_head_idx] * log2_e / scale if kv_tile_idx == 0 else -math.inf

@cute.jit
def sink_M_D_update(params, kv_tile_idx, qo_head_idx, m, d, scale):
    log2_e = math.log2(math.exp(1.0))
    sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf

medium

The condition qo_head_idx < NUM_QO_HEADS is redundant because qo_head_idx is a grid coordinate over the heads dimension, which is sized to NUM_QO_HEADS. Therefore, qo_head_idx will always be less than NUM_QO_HEADS. You can simplify the expression.

Suggested change
sink_raw = params.sink[qo_head_idx] * log2_e / scale if (kv_tile_idx == 0 and qo_head_idx < NUM_QO_HEADS) else -math.inf
sink_raw = params.sink[qo_head_idx] * log2_e / scale if kv_tile_idx == 0 else -math.inf

@pgera pgera marked this pull request as ready for review March 20, 2026 20:03

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 9

🧹 Nitpick comments (13)
flashinfer/cute_dsl/attention/tmem_layout.py (1)

35-49: Consider extracting SM100_TMEM_CAPACITY_COLUMNS as a module-level constant.

The SM100 TMEM capacity is a hardware characteristic that may be referenced elsewhere. Extracting it improves discoverability and avoids magic numbers.

Proposed refactor
+SM100_TMEM_CAPACITY_COLUMNS = 512
+
+
 @dataclass(frozen=True)
 class TmemLayout:
     ...

     @staticmethod
     def from_config(config: AttentionConfig) -> TmemLayout:
         tile_m = config.mma_tiler[0]
-        SM100_TMEM_CAPACITY_COLUMNS = 512
         return TmemLayout(
             ...
             alloc_cols=SM100_TMEM_CAPACITY_COLUMNS,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/tmem_layout.py` around lines 35 - 49, Extract
the literal 512 used for TMEM capacity into a module-level constant (e.g.,
SM100_TMEM_CAPACITY_COLUMNS = 512) and replace the local variable in
TmemLayout.from_config so the function uses that constant instead of a magic
number; update the top of the module with the constant and ensure
TmemLayout.from_config (which takes AttentionConfig and reads
config.mma_tiler[0]) references the new constant for alloc_cols so other code
can reuse the hardware-capacity value.
flashinfer/cute_dsl/attention/scheduler/persistent.py (2)

38-45: Add strict=True to zip() for safer MLIR value reconstruction.

In __new_from_mlir_values__, the zip() call iterates over [self.is_persistent, self.problem_shape_mbh] and self._values_pos. If these lists have mismatched lengths (e.g., due to a maintenance error), zip() will silently truncate, potentially causing subtle bugs during MLIR reconstruction.

Also, the ip parameter is not forwarded to the new FmhaStaticTileSchedulerParams instance on line 45.

Proposed fix
     def __new_from_mlir_values__(self, values):
         obj_list = []
         for obj, n_items in zip(
-            [self.is_persistent, self.problem_shape_mbh], self._values_pos
+            [self.is_persistent, self.problem_shape_mbh], self._values_pos, strict=True
         ):
             obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
             values = values[n_items:]
-        return FmhaStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc)
+        return FmhaStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc, ip=self._ip)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/scheduler/persistent.py` around lines 38 - 45,
In __new_from_mlir_values__ update the zip over [self.is_persistent,
self.problem_shape_mbh] and self._values_pos to use zip(..., strict=True) to
fail loudly on length mismatches, and when returning the
FmhaStaticTileSchedulerParams instance forward the current object's ip parameter
(pass loc=self._loc, ip=self.ip) so the new instance receives ip as well; this
touches the __new_from_mlir_values__ method, the attributes self.is_persistent,
self.problem_shape_mbh, self._values_pos, and the FmhaStaticTileSchedulerParams
constructor call.

148-158: Hardcoded MLIR value count is fragile.

The assertion assert len(values) == 10 couples the implementation to a specific MLIR representation. If any constituent object's MLIR value count changes, this will fail without a clear message.

Consider deriving the expected count dynamically or providing a descriptive error message.

Proposed improvement
     def __new_from_mlir_values__(self, values):
-        assert len(values) == 10
+        expected = 3 + 1 + 3 + 3  # params(3) + work_idx(1) + blk_coord(3) + grid_shape(3)
+        assert len(values) == expected, f"Expected {expected} MLIR values, got {len(values)}"
         new_params = cutlass.new_from_mlir_values(self._params, values[0:3])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/scheduler/persistent.py` around lines 148 -
158, The hardcoded assertion in __new_from_mlir_values__ (assert len(values) ==
10) is fragile; change it to compute the expected MLIR value count by summing
the MLIR-value counts of the constituent objects (self._params,
self._current_work_linear_idx, self._blk_coord, self._grid_shape) using whatever
helper/attribute your cutlass layer exposes (e.g., a mlir value count helper or
by querying each object's MLIR representation), then compare len(values) to that
computed total and raise a ValueError with a descriptive message if mismatched;
update the slicing logic that builds new_params, new_current_work_linear_idx,
new_blk_coord, and new_grid_shape to use those computed per-object counts
instead of fixed indices so FmhaStaticTileScheduler construction remains
correct.
flashinfer/cute_dsl/attention/collective_builder.py (1)

163-186: Consider using a typed dataclass instead of SimpleNamespace for better IDE support.

The returned SimpleNamespace contains 20+ fields. A typed dataclass or NamedTuple would provide autocompletion and type checking for consumers.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/collective_builder.py` around lines 163 - 186,
Replace the anonymous SimpleNamespace return with a typed dataclass (e.g.,
define `@dataclass` class AttentionCollective or AttentionCollectiveConfig) that
declares typed fields for each symbol currently passed (qk_tiled_mma,
pv_tiled_mma, tma_atom_q, tma_tensor_q, tma_atom_k, tma_tensor_k, tma_atom_v,
tma_tensor_v, tma_atom_o, tma_tensor_o, q_smem_layout_staged,
k_smem_layout_staged, p_tmem_layout_staged, v_smem_layout_staged,
o_smem_layout_staged, SharedStorage, tma_copy_q_bytes, tma_copy_kv_bytes,
cluster_shape_mnk, cluster_layout_vmnk, epi_tile, o_layout), add appropriate
type hints (use typing.Any or more specific types if known), import dataclasses
and typing, instantiate and return that dataclass instead of SimpleNamespace,
and update any consumers to accept the new dataclass type for improved IDE
autocompletion and type checking.
benchmarks/bench_blackwell_attention_cutedsl.py (1)

7-8: Use the public flashinfer.testing benchmark helper.

This benchmark already relies on the standard timing helper, but it pulls it from flashinfer.testing.utils, which couples the script to a private module path.

♻️ Suggested change
-from flashinfer.testing.utils import bench_gpu_time
+from flashinfer.testing import bench_gpu_time

Based on learnings: use flashinfer.testing.bench_gpu_time() for benchmarking kernels, preferring CUPTI timing with auto-fallback to CUDA events.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_blackwell_attention_cutedsl.py` around lines 7 - 8, The
benchmark imports bench_gpu_time from a private path (flashinfer.testing.utils);
update the import to use the public helper by replacing references to
flashinfer.testing.utils with the public module flashinfer.testing and import
bench_gpu_time from flashinfer.testing (i.e., use
flashinfer.testing.bench_gpu_time) so the benchmark relies on the supported
public API rather than a private module.
tests/test_blackwell_fmha_attention.py (1)

1-13: Please move this suite under a feature-specific tests subdirectory.

This is kernel-specific CuTe DSL attention coverage, but the new module sits at tests/ root. Putting it under a matching subdirectory keeps the test surface organized with the rest of the kernel-category suites.

As per coding guidelines tests/**/*.py: Prefix test functions with test_ and structure tests by feature in tests/ subdirectories matching kernel categories.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_blackwell_fmha_attention.py` around lines 1 - 13, The test module
test_blackwell_fmha_attention.py is at the tests/ root but belongs in the
attention-specific kernel tests; move this suite into a feature-specific tests
subdirectory matching the kernel category (e.g., an attention/ or
blackwell_fmha/ tests folder), update any relative imports inside the module to
the new location, and ensure all test callables in the file are properly
prefixed with test_ so pytest discovers them (check function names and any
parametrized fixtures used by functions in this module).
flashinfer/cute_dsl/attention/wrappers/batch_prefill.py (3)

393-396: Add strict=True to zip() for early shape-mismatch detection.

Using strict=True catches mismatched lengths between padding and shape_ early, improving debuggability.

Suggested fix
-        slices = tuple(slice(s, e) for s, e in zip(padding, shape_))
+        slices = tuple(slice(s, e) for s, e in zip(padding, shape_, strict=True))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py` around lines 393 -
396, The slice construction using zip(padding, shape_) can silently ignore
length mismatches; update the tuple comprehension that defines slices (used to
create torch_tensor from torch_tensor_full and assigned to torch_tensor) to call
zip with strict=True (i.e., zip(padding, shape_, strict=True)) so any mismatch
between padding and shape_ raises immediately and makes debugging easier.

129-157: Prefix unused unpacked variables with underscore.

The variables q_ref, q_torch, k_ref, k_torch, v_ref, v_torch, and o_torch from create_and_pad_tensor() are intentionally unused (they're dummy tensors for CuTe JIT tracing). Prefix them with _ to indicate intent and silence linter warnings.

Suggested fix
-        q_ref, q_cute, q_torch = create_and_pad_tensor(
+        _q_ref, q_cute, _q_torch = create_and_pad_tensor(
             qo_shape,
             ...
         )
-        k_ref, k_cute, k_torch = create_and_pad_tensor(
+        _k_ref, k_cute, _k_torch = create_and_pad_tensor(
             kv_shape,
             ...
         )
-        v_ref, v_cute, v_torch = create_and_pad_tensor(
+        _v_ref, v_cute, _v_torch = create_and_pad_tensor(
             kv_shape,
             ...
         )

-        _, o_cute, o_torch = create_and_pad_tensor(
+        _, o_cute, _o_torch = create_and_pad_tensor(
             qo_shape,
             ...
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py` around lines 129 -
157, The unpacked dummy tensors returned by create_and_pad_tensor (q_ref,
q_torch, k_ref, k_torch, v_ref, v_torch, o_torch) are unused and should be
prefixed with an underscore to indicate intentional unused variables and silence
linters; update the unpacking lines where create_and_pad_tensor is called (for
q_, k_, v_, and o_) to rename those specific variables to _q_ref/_q_torch,
_k_ref/_k_torch, _v_ref/_v_torch, and _o_torch (or similar underscore-prefixed
names) while keeping the used names q_cute/k_cute/v_cute/o_cute unchanged.

318-319: Minor: device=q.device is redundant with torch.empty_like.

torch.empty_like(q, ...) already inherits q's device by default.

Suggested fix
-            out = torch.empty_like(q, device=q.device)
+            out = torch.empty_like(q)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py` around lines 318 -
319, In batch_prefill.py replace the redundant explicit device argument when
creating the empty tensor so that out is created with torch.empty_like(q)
instead of torch.empty_like(q, device=q.device); locate the assignment that sets
out when out is None (the one referencing variables out and q) and remove the
device=q.device parameter to rely on torch.empty_like inheriting q's device.
flashinfer/cute_dsl/attention/roles/softmax.py (1)

336-344: Redundant thread_idx computation.

thread_idx is computed identically at lines 337-344 and again at lines 366-373. The second computation overwrites the first with the same value.

Remove duplicate computation
         thread_idx = tidx % (
             self.threads_per_warp
             * (
                 len(self.softmax0_warp_ids)
                 if stage == 0
                 else len(self.softmax1_warp_ids)
             )
         )
         ...
         tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi)
-        thread_idx = tidx % (
-            self.threads_per_warp
-            * (
-                len(self.softmax0_warp_ids)
-                if stage == 0
-                else len(self.softmax1_warp_ids)
-            )
-        )
         thr_tmem_load = tiled_tmem_load.get_slice(thread_idx)

Also applies to: 366-373

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/roles/softmax.py` around lines 336 - 344, The
duplicated computation of thread_idx (calling cute.arch.thread_idx(), taking
tidx and computing tidx % (self.threads_per_warp * (len(self.softmax0_warp_ids)
if stage == 0 else len(self.softmax1_warp_ids)))) appears twice; remove the
redundant second block (the one at lines 366-373) so thread_idx remains computed
once and subsequent code uses the already-computed thread_idx from the first
occurrence; ensure any references after the removed block still rely on the
existing thread_idx variable and that no logic dependent on re-calling
cute.arch.thread_idx() is lost.
flashinfer/cute_dsl/attention/prefill.py (3)

155-156: Prefix unused s_k with underscore.

s_k is unpacked but never used. Prefix with _ to indicate intent.

-        b, s_q, s_k, h_q, h_k, d = problem_size
+        b, s_q, _s_k, h_q, h_k, d = problem_size
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/prefill.py` around lines 155 - 156, The tuple
unpacking of problem_size currently binds an unused variable s_k; change the
unpacking to use _s_k (or simply _ ) instead of s_k to signal it's intentionally
unused (e.g., replace "b, s_q, s_k, h_q, h_k, d = problem_size" with an
unpacking that prefixes s_k with an underscore) in the prefill logic where
variables b, s_q, h_q, h_k, d are used and h_r is computed from h_q and h_k.

45-51: Overly broad warning suppression may hide legitimate issues.

Suppressing all UserWarning messages (line 51) could mask important warnings from other parts of the codebase or dependencies. Consider scoping the suppression more narrowly, or applying it only within the specific context where the unrolling warning occurs.

Alternative: use a context manager at call sites
# Remove the global filter at module level
# warnings.filterwarnings("ignore", category=UserWarning)

# Instead, wrap specific calls that generate the warning:
import contextlib

@contextlib.contextmanager
def suppress_loop_unroll_warning():
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="This loop is no longer unrolled and may cause performance regression",
        )
        yield
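
To see the scoping in action, here is a self-contained run of that pattern: the targeted message is suppressed inside the context manager, while an unrelated UserWarning still gets through.

```python
import contextlib
import warnings

@contextlib.contextmanager
def suppress_loop_unroll_warning():
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="This loop is no longer unrolled and may cause performance regression",
        )
        yield

with warnings.catch_warnings(record=True) as rec:
    warnings.simplefilter("always")
    with suppress_loop_unroll_warning():
        # matched message: suppressed
        warnings.warn(
            "This loop is no longer unrolled and may cause performance regression"
        )
    # unrelated UserWarning: still visible
    warnings.warn("something else worth seeing")
```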
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/prefill.py` around lines 45 - 51, The module
currently suppresses all UserWarning globally by calling
warnings.filterwarnings("ignore", category=UserWarning); remove that broad
module-level filter and instead scope suppression to only the specific unroll
warning by introducing a context manager (e.g., suppress_loop_unroll_warning
using warnings.catch_warnings and warnings.filterwarnings with message="This
loop is no longer unrolled and may cause performance regression") and use that
context manager at the specific call sites in prefill.py where the unrolling
warning is raised so other UserWarnings remain visible.

385-386: Prefix unused tidx with underscore.

tidx from thread_idx() is unpacked but unused in the kernel entry. The variable is only used by roles that call thread_idx() themselves.

-        tidx, _, _ = cute.arch.thread_idx()
+        _tidx, _, _ = cute.arch.thread_idx()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/prefill.py` around lines 385 - 386, The
unpacked thread index variable tidx from cute.arch.thread_idx() is unused in the
kernel entry; change its name to _tidx to mark it as intentionally unused (i.e.,
replace "tidx, _, _ = cute.arch.thread_idx()" with "_tidx, _, _ =
cute.arch.thread_idx()") so linters/readers know it's unused while keeping the
other unpacked values and the existing warp_idx assignment (warp_idx =
cute.arch.make_warp_uniform(cute.arch.warp_idx())) intact.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_blackwell_attention_cutedsl.py`:
- Around line 153-161: The script currently unconditionally runs an SM100-only
kernel in the __main__ block (calls to bench_fmha_cutedsl), which will
JIT/launch-fail on non-SM100 GPUs; add a GPU capability check before running the
default sweep: use torch.cuda.is_available() and
torch.cuda.get_device_capability() or
torch.cuda.get_device_properties(device).major/minor (or device name) to detect
whether the current GPU supports SM100, and if not, skip the default
bench_fmha_cutedsl(...) calls and exit or print a clear message; update the
__main__ section so the SM100-only sweep only runs when the capability check
passes.
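
A sketch of the guard (the helper name `supports_sm100` is hypothetical; the torch calls are shown only as comments for context):

```python
def supports_sm100(major: int, minor: int) -> bool:
    # Blackwell (SM100) reports CUDA compute capability 10.x
    return major == 10

# In the __main__ block, something along the lines of:
#     import torch
#     if not torch.cuda.is_available() or not supports_sm100(*torch.cuda.get_device_capability()):
#         print("Skipping SM100-only sweep: no supported GPU detected")
#         raise SystemExit(0)
```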

In `@flashinfer/cute_dsl/attention/collective_builder.py`:
- Around line 96-98: The p_tmem_layout_staged is being created with the wrong
dtype (q_dtype) causing a mismatch with pv_tiled_mma which was created for V;
update the call to sm100_utils.make_smem_layout_a in collective_builder so
p_tmem_layout_staged uses v_dtype instead of q_dtype (the call that takes
pv_tiled_mma, config.pv_mma_tiler, q_dtype, mainloop.acc_stage should pass
v_dtype) to align the P buffer TMEM layout with the V buffer.

In `@flashinfer/cute_dsl/attention/fusion/mask.py`:
- Around line 45-53: Sliding-window masking currently centers the KV window on
raw Q indices (see MaskType.SLIDING_WINDOW_MASK handling using blk_coord,
tile_shape, window_left, seqlen_k) and ignores the Q/K length offset used by the
causal path; compute q_k_offset = seqlen_k - seqlen_q and add it to first_q and
last_q (or otherwise shift Q indices into KV space) before calculating min_kv,
max_kv, start_block, end_block, and any element masks; apply the same fix to the
other sliding-window blocks noted (around the other occurrences at the given
ranges) so all sliding-window computations use shifted Q indices into KV
coordinate space.
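
A scalar sketch of the shifted-window arithmetic described above (function name and clamping details are illustrative, not the kernel's actual helpers):

```python
def sliding_window_kv_bounds(first_q, last_q, seqlen_q, seqlen_k, window_left):
    # Shift Q indices into KV coordinate space so the window lines up
    # with the causal diagonal when seqlen_q != seqlen_k.
    q_k_offset = seqlen_k - seqlen_q
    min_kv = max(0, first_q + q_k_offset - window_left)
    max_kv = min(seqlen_k - 1, last_q + q_k_offset)
    return min_kv, max_kv
```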

In `@flashinfer/cute_dsl/attention/fusion/variant.py`:
- Around line 551-554: SoftCappingAttention.score_mod calls non-existent
cute.arch.tanh; replace it with a local tanh approximation implemented using
available primitives (e.g., cute.arch.exp2 and cute.arch.rcp_approx) or a cheap
rational polynomial and call that from score_mod. Add a helper function (e.g.,
_tanh_approx(x)) in the same class or module and use it in
SoftCappingAttention.score_mod (referencing self.cap and self.rcp_cap as
before), implementing tanh(x) via exp2 by computing exp(-2*abs(x)) with
exp2(-2*abs(x)/ln2) plus sign handling or by a stable rational approximation
(polynomial numerator/denominator) and ensure the helper uses
cute.jit-compatible operations only.
- Around line 367-378: Update the class and relevant parameter docstrings to
state that sink values are expected in the logit domain (raw Q·K dot-product
units, unnormalized), not pre-scaled to log2; specifically mention this near the
documentation for the sink parameter(s) used by update_statistics and the
self.params/sink_raw conversion (which divides by scale/log2_e), and add a
cross-reference to sink_softmax in sink_attention_reference.py so callers know
sinks are concatenated to logits before any log2 scaling.

In `@flashinfer/cute_dsl/attention/pipeline_topology.py`:
- Around line 68-79: The Pipeline.dataclass field cluster_scale is ignored by
create_pipelines(), causing incorrect participant and barrier arrive counts;
either (preferred) honor it by multiplying the all-thread side's participant
counts when constructing producer/consumer groups and computing barrier arrive
counts for PipelineType values UMMA_ASYNC and ASYNC_UMMA (but leave TMA_UMMA
unchanged), i.e., when building groups from producer_warp_ids/consumer_warp_ids
in create_pipelines() multiply the thread counts by pipeline.cluster_scale and
use that scaled value when setting arrive counts for barriers/tx_count_key, or
fail fast by adding a check in create_pipelines() that raises a clear exception
if pipeline.cluster_scale != 1 so callers must handle scaling explicitly.
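
The fail-fast alternative can be as small as this sketch (Pipeline objects are stand-ins; only the cluster_scale field matters here):

```python
from types import SimpleNamespace

def check_cluster_scale(pipelines):
    # Refuse cluster_scale != 1 until create_pipelines() actually
    # scales participant and barrier arrive counts.
    for p in pipelines:
        if getattr(p, "cluster_scale", 1) != 1:
            raise ValueError(
                f"cluster_scale={p.cluster_scale} is not honored by "
                "create_pipelines(); handle scaling explicitly"
            )

check_cluster_scale([SimpleNamespace(cluster_scale=1)])  # passes silently

try:
    check_cluster_scale([SimpleNamespace(cluster_scale=2)])
    failed_fast = False
except ValueError:
    failed_fast = True
```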

In `@flashinfer/cute_dsl/attention/roles/epilogue.py`:
- Around line 41-66: partition_output is incorrectly decorated with `@cute.jit`
while returning tensor objects (tOsO, tOgO) which violates the CuTe JIT
limitation; either remove the `@cute.jit` decorator from partition_output so it
runs as a normal Python method, or refactor it to avoid returning tensors by (a)
accepting preallocated output containers/handles and writing into them, or (b)
moving the cute.nvgpu.cpasync.tma_partition call out of the `@cute.jit` function
into a non-jit wrapper (e.g., create partition_output_nonjit that calls
cute.nvgpu.cpasync.tma_partition and returns tensors or change partition_output
to populate passed-in tensor references); update references to partition_output
accordingly so no `@cute.jit` function returns tensors (symbols: partition_output,
tOsO, tOgO, tma_partition, tma_atom_o).

In `@flashinfer/cute_dsl/attention/warp_schedule.py`:
- Around line 17-71: Add a fail-fast validation in WarpSchedule (implement in a
__post_init__ method) that verifies: 1) all_warp_ids (built from
softmax0_warp_ids, softmax1_warp_ids, correction_warp_ids, mma_warp_id,
load_warp_id, epilogue_warp_id, empty_warp_id) contain unique values and form a
contiguous range starting at 0 up to len(all_warp_ids)-1, and 2) the total
number of softmax warps (len(softmax0_warp_ids)+len(softmax1_warp_ids)) is
divisible by num_warps_per_warpgroup; on violation raise ValueError with a clear
message referencing the failing condition so consumers of num_warps,
threads_per_cta, and softmax_warpgroup_count cannot silently compute incorrect
sizes.
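
A sketch of that validation as a `__post_init__` (the class here is a hypothetical mirror of WarpSchedule's fields, for illustration only):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class WarpScheduleSketch:
    softmax0_warp_ids: tuple
    softmax1_warp_ids: tuple
    correction_warp_ids: tuple
    mma_warp_id: int
    load_warp_id: int
    epilogue_warp_id: int
    empty_warp_id: int
    num_warps_per_warpgroup: int = 4

    def __post_init__(self):
        all_ids = (
            *self.softmax0_warp_ids,
            *self.softmax1_warp_ids,
            *self.correction_warp_ids,
            self.mma_warp_id,
            self.load_warp_id,
            self.epilogue_warp_id,
            self.empty_warp_id,
        )
        if sorted(all_ids) != list(range(len(all_ids))):
            raise ValueError(
                f"warp ids must be unique and contiguous from 0, got {all_ids}"
            )
        n_softmax = len(self.softmax0_warp_ids) + len(self.softmax1_warp_ids)
        if n_softmax % self.num_warps_per_warpgroup != 0:
            raise ValueError(
                f"{n_softmax} softmax warps not divisible by "
                f"warpgroup size {self.num_warps_per_warpgroup}"
            )

# A valid schedule constructs silently
ok = WarpScheduleSketch((0, 1, 2, 3), (4, 5, 6, 7), (8,), 9, 10, 11, 12)
```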

In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py`:
- Around line 159-169: The NameError risk comes from params_cute being defined
only inside the if self._has_params block yet referenced later; fix by defining
params_cute = None before the if and only assigning it inside the block (where
you call from_dlpack) so later code can safely use the conditional expression
(params_cute.iterator if self._has_params else None); update references
involving self._has_params, _params_torch, and from_dlpack accordingly to rely
on the initialized params_cute variable.
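
The shape of the fix, reduced to a runnable toy (names and the `list(...)` stand-in for from_dlpack are illustrative):

```python
def resolve_params(has_params: bool, raw=None):
    # Define up front so the later conditional can never hit a NameError.
    params_cute = None
    if has_params:
        params_cute = list(raw)  # stand-in for from_dlpack(raw)
    return iter(params_cute) if params_cute is not None else None
```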

---

Nitpick comments:
In `@benchmarks/bench_blackwell_attention_cutedsl.py`:
- Around line 7-8: The benchmark imports bench_gpu_time from a private path
(flashinfer.testing.utils); update the import to use the public helper by
replacing references to flashinfer.testing.utils with the public module
flashinfer.testing and import bench_gpu_time from flashinfer.testing (i.e., use
flashinfer.testing.bench_gpu_time) so the benchmark relies on the supported
public API rather than a private module.

In `@flashinfer/cute_dsl/attention/collective_builder.py`:
- Around line 163-186: Replace the anonymous SimpleNamespace return with a typed
dataclass (e.g., define `@dataclass` class AttentionCollective or
AttentionCollectiveConfig) that declares typed fields for each symbol currently
passed (qk_tiled_mma, pv_tiled_mma, tma_atom_q, tma_tensor_q, tma_atom_k,
tma_tensor_k, tma_atom_v, tma_tensor_v, tma_atom_o, tma_tensor_o,
q_smem_layout_staged, k_smem_layout_staged, p_tmem_layout_staged,
v_smem_layout_staged, o_smem_layout_staged, SharedStorage, tma_copy_q_bytes,
tma_copy_kv_bytes, cluster_shape_mnk, cluster_layout_vmnk, epi_tile, o_layout),
add appropriate type hints (use typing.Any or more specific types if known),
import dataclasses and typing, instantiate and return that dataclass instead of
SimpleNamespace, and update any consumers to accept the new dataclass type for
improved IDE autocompletion and type checking.
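
A sketch of the typed replacement, showing only an illustrative subset of the fields above (real types would replace Any once the CuTe object types are pinned down):

```python
from dataclasses import dataclass
from typing import Any

@dataclass(frozen=True)
class AttentionCollective:
    qk_tiled_mma: Any
    pv_tiled_mma: Any
    tma_atom_q: Any
    tma_tensor_q: Any
    tma_copy_q_bytes: int
    tma_copy_kv_bytes: int

# Consumers get attribute completion and a type checker can verify field names:
c = AttentionCollective("qk", "pv", "atom_q", "tensor_q", 128, 256)
```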

In `@flashinfer/cute_dsl/attention/prefill.py`:
- Around line 155-156: The tuple unpacking of problem_size currently binds an
unused variable s_k; change the unpacking to use _s_k (or simply _ ) instead of
s_k to signal it's intentionally unused (e.g., replace "b, s_q, s_k, h_q, h_k, d
= problem_size" with an unpacking that prefixes s_k with an underscore) in the
prefill logic where variables b, s_q, h_q, h_k, d are used and h_r is computed
from h_q and h_k.
- Around line 45-51: The module currently suppresses all UserWarning globally by
calling warnings.filterwarnings("ignore", category=UserWarning); remove that
broad module-level filter and instead scope suppression to only the specific
unroll warning by introducing a context manager (e.g.,
suppress_loop_unroll_warning using warnings.catch_warnings and
warnings.filterwarnings with message="This loop is no longer unrolled and may
cause performance regression") and use that context manager at the specific call
sites in prefill.py where the unrolling warning is raised so other UserWarnings
remain visible.
- Around line 385-386: The unpacked thread index variable tidx from
cute.arch.thread_idx() is unused in the kernel entry; change its name to _tidx
to mark it as intentionally unused (i.e., replace "tidx, _, _ =
cute.arch.thread_idx()" with "_tidx, _, _ = cute.arch.thread_idx()") so
linters/readers know it's unused while keeping the other unpacked values and the
existing warp_idx assignment (warp_idx =
cute.arch.make_warp_uniform(cute.arch.warp_idx())) intact.

In `@flashinfer/cute_dsl/attention/roles/softmax.py`:
- Around line 336-344: The duplicated computation of thread_idx (calling
cute.arch.thread_idx(), taking tidx and computing tidx % (self.threads_per_warp
* (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids))))
appears twice; remove the redundant second block (the one at lines 366-373) so
thread_idx remains computed once and subsequent code uses the already-computed
thread_idx from the first occurrence; ensure any references after the removed
block still rely on the existing thread_idx variable and that no logic dependent
on re-calling cute.arch.thread_idx() is lost.

In `@flashinfer/cute_dsl/attention/scheduler/persistent.py`:
- Around line 38-45: In __new_from_mlir_values__ update the zip over
[self.is_persistent, self.problem_shape_mbh] and self._values_pos to use
zip(..., strict=True) to fail loudly on length mismatches, and when returning
the FmhaStaticTileSchedulerParams instance forward the current object's ip
parameter (pass loc=self._loc, ip=self.ip) so the new instance receives ip as
well; this touches the __new_from_mlir_values__ method, the attributes
self.is_persistent, self.problem_shape_mbh, self._values_pos, and the
FmhaStaticTileSchedulerParams constructor call.
- Around line 148-158: The hardcoded assertion in __new_from_mlir_values__
(assert len(values) == 10) is fragile; change it to compute the expected MLIR
value count by summing the MLIR-value counts of the constituent objects
(self._params, self._current_work_linear_idx, self._blk_coord, self._grid_shape)
using whatever helper/attribute your cutlass layer exposes (e.g., a mlir value
count helper or by querying each object's MLIR representation), then compare
len(values) to that computed total and raise a ValueError with a descriptive
message if mismatched; update the slicing logic that builds new_params,
new_current_work_linear_idx, new_blk_coord, and new_grid_shape to use those
computed per-object counts instead of fixed indices so FmhaStaticTileScheduler
construction remains correct.

In `@flashinfer/cute_dsl/attention/tmem_layout.py`:
- Around line 35-49: Extract the literal 512 used for TMEM capacity into a
module-level constant (e.g., SM100_TMEM_CAPACITY_COLUMNS = 512) and replace the
local variable in TmemLayout.from_config so the function uses that constant
instead of a magic number; update the top of the module with the constant and
ensure TmemLayout.from_config (which takes AttentionConfig and reads
config.mma_tiler[0]) references the new constant for alloc_cols so other code
can reuse the hardware-capacity value.

In `@flashinfer/cute_dsl/attention/wrappers/batch_prefill.py`:
- Around line 393-396: The slice construction using zip(padding, shape_) can
silently ignore length mismatches; update the tuple comprehension that defines
slices (used to create torch_tensor from torch_tensor_full and assigned to
torch_tensor) to call zip with strict=True (i.e., zip(padding, shape_,
strict=True)) so any mismatch between padding and shape_ raises immediately and
makes debugging easier.
- Around line 129-157: The unpacked dummy tensors returned by
create_and_pad_tensor (q_ref, q_torch, k_ref, k_torch, v_ref, v_torch, o_torch)
are unused and should be prefixed with an underscore to indicate intentional
unused variables and silence linters; update the unpacking lines where
create_and_pad_tensor is called (for q_, k_, v_, and o_) to rename those
specific variables to _q_ref/_q_torch, _k_ref/_k_torch, _v_ref/_v_torch, and
_o_torch (or similar underscore-prefixed names) while keeping the used names
q_cute/k_cute/v_cute/o_cute unchanged.
- Around line 318-319: In batch_prefill.py replace the redundant explicit device
argument when creating the empty tensor so that out is created with
torch.empty_like(q) instead of torch.empty_like(q, device=q.device); locate the
assignment that sets out when out is None (the one referencing variables out and
q) and remove the device=q.device parameter to rely on torch.empty_like
inheriting q's device.

In `@tests/test_blackwell_fmha_attention.py`:
- Around line 1-13: The test module test_blackwell_fmha_attention.py is at the
tests/ root but belongs in the attention-specific kernel tests; move this suite
into a feature-specific tests subdirectory matching the kernel category (e.g.,
an attention/ or blackwell_fmha/ tests folder), update any relative imports
inside the module to the new location, and ensure all test callables in the file
are properly prefixed with test_ so pytest discovers them (check function names
and any parametrized fixtures used by functions in this module).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f4d60aef-cbff-41af-8cdd-f4914f9854d7

📥 Commits

Reviewing files that changed from the base of the PR and between e4dc66f and 9f0ba5e.

📒 Files selected for processing (24)
  • benchmarks/bench_blackwell_attention_cutedsl.py
  • flashinfer/cute_dsl/attention/__init__.py
  • flashinfer/cute_dsl/attention/collective_builder.py
  • flashinfer/cute_dsl/attention/config.py
  • flashinfer/cute_dsl/attention/fusion/__init__.py
  • flashinfer/cute_dsl/attention/fusion/mask.py
  • flashinfer/cute_dsl/attention/fusion/variant.py
  • flashinfer/cute_dsl/attention/mainloop_spec.py
  • flashinfer/cute_dsl/attention/pipeline_topology.py
  • flashinfer/cute_dsl/attention/prefill.py
  • flashinfer/cute_dsl/attention/roles/__init__.py
  • flashinfer/cute_dsl/attention/roles/correction.py
  • flashinfer/cute_dsl/attention/roles/epilogue.py
  • flashinfer/cute_dsl/attention/roles/loader_tma.py
  • flashinfer/cute_dsl/attention/roles/mma.py
  • flashinfer/cute_dsl/attention/roles/softmax.py
  • flashinfer/cute_dsl/attention/roles/softmax_math.py
  • flashinfer/cute_dsl/attention/scheduler/__init__.py
  • flashinfer/cute_dsl/attention/scheduler/persistent.py
  • flashinfer/cute_dsl/attention/tmem_layout.py
  • flashinfer/cute_dsl/attention/warp_schedule.py
  • flashinfer/cute_dsl/attention/wrappers/__init__.py
  • flashinfer/cute_dsl/attention/wrappers/batch_prefill.py
  • tests/test_blackwell_fmha_attention.py

Comment on lines +96 to +98
p_tmem_layout_staged = sm100_utils.make_smem_layout_a(
    pv_tiled_mma, config.pv_mma_tiler, q_dtype, mainloop.acc_stage,
)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check how p_tmem_layout_staged is used and if q_dtype is the correct type
rg -n "p_tmem_layout" --type py flashinfer/cute_dsl/attention/

Repository: flashinfer-ai/flashinfer

Length of output: 824


🏁 Script executed:

#!/bin/bash
# Get the full function signature and beginning to see all dtype parameters
sed -n '50,120p' flashinfer/cute_dsl/attention/collective_builder.py

Repository: flashinfer-ai/flashinfer

Length of output: 2667


🏁 Script executed:

#!/bin/bash
# Check sm100_utils.make_smem_layout_a to see what dtype parameter does
rg -n "def make_smem_layout_a" --type py flashinfer/cute_dsl/attention/

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Check what dtypes are typically used - look for q_dtype, k_dtype, v_dtype, p_dtype
rg -n "q_dtype|k_dtype|v_dtype|p_dtype|intermediate_dtype" --type py flashinfer/cute_dsl/attention/collective_builder.py | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 900


P buffer TMEM layout should use v_dtype, not q_dtype.

The P buffer (attention weights) shares the same pv_tiled_mma with V, which is created with v_dtype. However, p_tmem_layout_staged is created with q_dtype instead, causing a dtype mismatch. Change line 97 to use v_dtype:

Before:
p_tmem_layout_staged = sm100_utils.make_smem_layout_a(
    pv_tiled_mma, config.pv_mma_tiler, q_dtype, mainloop.acc_stage,
)
After:
p_tmem_layout_staged = sm100_utils.make_smem_layout_a(
    pv_tiled_mma, config.pv_mma_tiler, v_dtype, mainloop.acc_stage,
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/collective_builder.py` around lines 96 - 98,
The p_tmem_layout_staged is being created with the wrong dtype (q_dtype) causing
a mismatch with pv_tiled_mma which was created for V; update the call to
sm100_utils.make_smem_layout_a in collective_builder so p_tmem_layout_staged
uses v_dtype instead of q_dtype (the call that takes pv_tiled_mma,
config.pv_mma_tiler, q_dtype, mainloop.acc_stage should pass v_dtype) to align
the P buffer TMEM layout with the V buffer.

Comment on lines +367 to +378
    @cute.jit
    def update_statistics(self, kv_tile_idx, qo_head_idx, m, d, scale):
        log2_e = math.log2(math.exp(1.0))
        sink_raw = (
            self.params[qo_head_idx] * log2_e / scale
            if kv_tile_idx == 0
            else -math.inf
        )
        m_new = sink_raw if sink_raw > m else m
        rescale = cute.arch.exp2((m - m_new) * scale)
        d_new = cute.arch.exp2((sink_raw - m_new) * scale) + d * rescale
        return m_new, d_new
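
A plain-Python transcription of that online update (with `2.0 ** x` standing in for `cute.arch.exp2`), which makes the two boundary cases easy to check: the sink logit joins the running (max, denominator) pair only on the first KV tile, and a `-inf` sink leaves both statistics unchanged.

```python
def update_statistics_ref(m, d, sink_raw, scale):
    # Online-softmax update when one extra logit (the sink) joins
    # the running max m and denominator d, in the log2 domain.
    m_new = max(m, sink_raw)
    rescale = 2.0 ** ((m - m_new) * scale)
    d_new = 2.0 ** ((sink_raw - m_new) * scale) + d * rescale
    return m_new, d_new
```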

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🌐 Web query:

CuTe DSL attention sink implementation softmax denominator

💡 Result:

No specific "attention sink implementation" in CuTe DSL was found. CuTe DSL (from NVIDIA CUTLASS) is used for high-performance FlashAttention kernels, including softmax computation via online softmax for numerical stability. The softmax denominator (row sums of exp(scores - row_max)) is computed using tiled reductions in shared memory across blocks/tiles of the attention matrix, leveraging CuTe's tensor layouts, TensorSSA for vectorized ops, and arch-specific reductions like cute.arch.warp_reduction_sum (as in RMSNorm examples, analogous for softmax rows). Official examples like CUTLASS's flash_attention_v2.py feature softmax_rescale_O for online softmax fusion with GEMM. Repos like luliyucoordinate/cute-flash-attention and HuyNguyen-hust/flash-attn-101 implement FlashAttention in CuTe, using standard online softmax without special "sink" handling—sinks are a training/model phenomenon, not a kernel feature. FlashAttention-3/4 in Dao-AILab/flash-attention use CuTe DSL for Hopper/Blackwell with optimized exp2/emulation and reductions for the denominator. Attention sinks refer to probability mass concentrating on fixed positions in softmax attention (e.g., arXiv:2603.11487), but kernel implementations compute general softmax denominators via parallel reductions, not sink-specific logic.


🏁 Script executed:

find . -type f -name "variant.py" | grep -E "cute_dsl|attention"

Repository: flashinfer-ai/flashinfer

Length of output: 117


🏁 Script executed:

if [ -f "flashinfer/cute_dsl/attention/fusion/variant.py" ]; then
  wc -l flashinfer/cute_dsl/attention/fusion/variant.py
fi

Repository: flashinfer-ai/flashinfer

Length of output: 119


🏁 Script executed:

if [ -f "flashinfer/cute_dsl/attention/fusion/variant.py" ]; then
  sed -n '350,395p' flashinfer/cute_dsl/attention/fusion/variant.py
fi

Repository: flashinfer-ai/flashinfer

Length of output: 1402


🏁 Script executed:

# Search for scale parameter definition and sink-related documentation
rg -n "scale\s*=" flashinfer/cute_dsl/attention/fusion/variant.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 335


🏁 Script executed:

# Look for documentation about sink semantics
rg -n -B3 -A3 "sink" flashinfer/cute_dsl/attention/fusion/variant.py | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 1858


🏁 Script executed:

# Search for params initialization and documentation
rg -n -B5 "self.params" flashinfer/cute_dsl/attention/fusion/variant.py | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 2271


🏁 Script executed:

sed -n '340,410p' flashinfer/cute_dsl/attention/fusion/variant.py

Repository: flashinfer-ai/flashinfer

Length of output: 2239


🏁 Script executed:

# Search for AttentionWithSink usage and test cases
rg -n "AttentionWithSink" --type py

Repository: flashinfer-ai/flashinfer

Length of output: 839


🏁 Script executed:

# Look for documentation around the sink parameter
sed -n '340,365p' flashinfer/cute_dsl/attention/fusion/variant.py

Repository: flashinfer-ai/flashinfer

Length of output: 790


🏁 Script executed:

# Look at test cases to understand sink value semantics
sed -n '420,450p' tests/test_blackwell_fmha_attention.py

Repository: flashinfer-ai/flashinfer

Length of output: 1041


🏁 Script executed:

# Check more of the test context
sed -n '715,780p' tests/test_blackwell_fmha_attention.py

Repository: flashinfer-ai/flashinfer

Length of output: 2614


🏁 Script executed:

# Look for any documentation about sink value domain/units
rg -n -B5 -A5 "sink" flashinfer/cute_dsl/attention/fusion/variant.py | tail -80

Repository: flashinfer-ai/flashinfer

Length of output: 2864


🏁 Script executed:

# Check if there's a reference paper or documentation file
find . -type f -name "*.md" -o -name "*.rst" | xargs grep -l "sink" 2>/dev/null

Repository: flashinfer-ai/flashinfer

Length of output: 52


🏁 Script executed:

# Find the reference implementation to understand sink semantics
rg -n "def attention_ref" tests/test_blackwell_fmha_attention.py

Repository: flashinfer-ai/flashinfer

Length of output: 89


🏁 Script executed:

# Look at the reference implementation
sed -n '1,120p' tests/test_blackwell_fmha_attention.py

Repository: flashinfer-ai/flashinfer

Length of output: 3643


🏁 Script executed:

# Search for attention_ref implementation across the repo
find . -type f -name "*.py" -exec grep -l "def attention_ref" {} \;

Repository: flashinfer-ai/flashinfer

Length of output: 230


🏁 Script executed:

# Find the sink_softmax implementation
find . -type f -name "sink_attention_reference.py"

Repository: flashinfer-ai/flashinfer

🏁 Scripts executed during verification (repository: flashinfer-ai/flashinfer):

```shell
# Search for sink_softmax definition
rg -n "def sink_softmax" --type py

cat tests/test_helpers/sink_attention_reference.py

# Look at how sink values are used in tests - what domain are they in?
sed -n '420,435p' tests/test_blackwell_fmha_attention.py

# Check the documentation docstring more carefully
sed -n '337,356p' flashinfer/cute_dsl/attention/fusion/variant.py

# Look for any comments about sink semantics in scale context
rg -n -B10 -A10 "log2_e\|log2.math" flashinfer/cute_dsl/attention/fusion/variant.py | head -60

# Search for how scale is set/passed to update_statistics
rg -n "update_statistics.*scale" --type py -B3 -A3 | head -60

# Look for where scale is initialized/computed in the wrapper
rg -n "scale.*=" flashinfer/cute_dsl/attention/ --type py | grep -v test | head -30

# Check if there's any comment about sink value domain in tests
rg -n -B5 -A5 "torch.randn.*sink" tests/ --type py | head -80

# Look at the softmax_math.py to understand how m and d work
sed -n '1,50p' flashinfer/cute_dsl/attention/roles/softmax_math.py

# Check what domain m and d are actually in
rg -n "m_new\|d_new" flashinfer/cute_dsl/attention/fusion/variant.py -B2 -A2

# Search for documentation about expected sink value domain/units
rg -n "sink" flashinfer/cute_dsl/attention/fusion/variant.py | head -20

# Let me check the documentation more carefully for scale parameter
sed -n '50,100p' flashinfer/cute_dsl/attention/fusion/variant.py

# Look at the complete context of the update_statistics docstring
sed -n '275,315p' flashinfer/cute_dsl/attention/fusion/variant.py

# Check if sink value domain is documented in docstring or README
find . -type f \( -name "README.md" -o -name "*.rst" \) -exec grep -l "sink" {} \;

# Final check: Look at the complete example in the docstring
sed -n '131,150p' flashinfer/cute_dsl/attention/fusion/variant.py
```

🏁 Script executed (mathematical verification):

```shell
python3 << 'EOF'
import math

# Verify: if scale = log2(e) * sm_scale, then params * log2_e / scale = params / sm_scale

log2_e = math.log2(math.exp(1.0))
print(f"log2(e) = {log2_e}")
print(f"log2(e) ≈ 1.4426950408889634")

# For any sm_scale, verify the simplification
sm_scale = 1.0 / math.sqrt(128)  # example: head_dim = 128
scale = log2_e * sm_scale

print(f"\nFor head_dim=128: sm_scale = {sm_scale}")
print(f"scale = log2(e) * sm_scale = {scale}")

# Test the simplification with a sample sink value
params = 2.5
sink_raw_original = params * log2_e / scale
sink_raw_simplified = params / sm_scale

print(f"\nFor params = {params}:")
print(f"sink_raw (original formula) = params * log2_e / scale = {sink_raw_original}")
print(f"sink_raw (simplified formula) = params / sm_scale = {sink_raw_simplified}")
print(f"Match: {abs(sink_raw_original - sink_raw_simplified) < 1e-10}")
EOF
```


Clarify the expected domain of sink tensor values in the docstring.

The mathematical conversion sink_raw = params * log2_e / scale is correct and simplifies to params / sm_scale as stated. However, the class docstring does not document the domain in which sink values are expected. The reference implementation (sink_softmax in sink_attention_reference.py) concatenates sink values directly to logits, confirming they should be in the logit domain (raw QK dot-product units, unnormalized). Add this to the sink parameter documentation to clarify that users must provide sink values in the logit domain, not pre-converted to log2 scale.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/fusion/variant.py` around lines 367 - 378,
Update the class and relevant parameter docstrings to state that sink values are
expected in the logit domain (raw Q·K dot-product units, unnormalized), not
pre-scaled to log2; specifically mention this near the documentation for the
sink parameter(s) used by update_statistics and the self.params/sink_raw
conversion (which divides by scale/log2_e), and add a cross-reference to
sink_softmax in sink_attention_reference.py so callers know sinks are
concatenated to logits before any log2 scaling.
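For intuition, here is a minimal NumPy sketch of the logit-domain semantics the finding describes (names and exact reference behavior are assumptions; the real helper is sink_softmax in sink_attention_reference.py): the sink joins the score row in the same domain as the other logits, absorbs probability mass, and contributes no value vector.

```python
import numpy as np

def sink_softmax_reference(scores, sink):
    # scores: (q_len, kv_len) attention logits, already in whatever domain
    # the kernel softmaxes over; sink: scalar in the SAME domain
    # (no log2 pre-conversion, per the finding above).
    row = np.concatenate(
        [scores, np.full((scores.shape[0], 1), sink)], axis=1
    )
    row = row - row.max(axis=1, keepdims=True)   # numerical stability
    p = np.exp(row)
    p = p / p.sum(axis=1, keepdims=True)
    return p[:, :-1]   # sink absorbs mass but emits no value vector
```

With two zero logits and a zero sink, each row's kept probabilities sum to 2/3, making the mass the sink absorbed visible.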

Comment on lines +41 to +66
```python
# NOTE on CuTe DSL JIT limitations:
# - partition_output(): Returns tensor tuples — CuTe DSL JIT does not
#   reliably handle returning tensors from @cute.jit methods.
# - store_tile(): SAFE — takes pre-sliced tensors as arguments, no
#   runtime indexing or return values. Used in run() successfully.
# =========================================================================

@cute.jit
def partition_output(
    self,
    tma_atom_o: cute.CopyAtom,
    mO_qdl: cute.Tensor,
    sO: cute.Tensor,
    block_coord: tuple,
):
    """Partition output global tensor for TMA stores. Returns (tOsO, tOgO)."""
    gO_qdl = cute.flat_divide(
        mO_qdl, cute.select(self.pv_mma_tiler, mode=[0, 1])
    )
    gO = gO_qdl[None, None, None, 0, block_coord[2]]
    tOsO, tOgO = cute.nvgpu.cpasync.tma_partition(
        tma_atom_o, 0, cute.make_layout(1),
        cute.group_modes(sO, 0, 2),
        cute.group_modes(gO, 0, 2),
    )
    return tOsO, tOgO
```
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

Does NVIDIA CUTLASS CuTe DSL reliably support returning tensor objects from @cute.jit functions or methods, or should helpers that return tensors stay non-jitted/inlined?

💡 Result:

No, NVIDIA CUTLASS CuTe DSL does not reliably support returning tensor objects from @cute.jit functions or methods. Helpers that return tensors should stay non-jitted/inlined.



Remove @cute.jit from partition_output() or refactor to avoid returning tensors.

This method is decorated with @cute.jit but returns a tensor tuple, contradicting the documented limitation directly above it: "CuTe DSL JIT does not reliably handle returning tensors from @cute.jit methods." Since this is exported as a reusable primitive, it violates its own safety contract. Either remove the @cute.jit decorator or restructure to keep tensor returns outside JIT compilation per NVIDIA documentation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/roles/epilogue.py` around lines 41 - 66,
partition_output is incorrectly decorated with `@cute.jit` while returning tensor
objects (tOsO, tOgO) which violates the CuTe JIT limitation; either remove the
`@cute.jit` decorator from partition_output so it runs as a normal Python method,
or refactor it to avoid returning tensors by (a) accepting preallocated output
containers/handles and writing into them, or (b) moving the
cute.nvgpu.cpasync.tma_partition call out of the `@cute.jit` function into a
non-jit wrapper (e.g., create partition_output_nonjit that calls
cute.nvgpu.cpasync.tma_partition and returns tensors or change partition_output
to populate passed-in tensor references); update references to partition_output
accordingly so no `@cute.jit` function returns tensors (symbols: partition_output,
tOsO, tOgO, tma_partition, tma_atom_o).

Comment on lines +17 to +71
```python
@dataclass(frozen=True)
class WarpSchedule:
    """Defines warp role assignment and register budgets for attention kernels.

    Each field maps directly to C++ CUTLASS's KernelSchedule:
    - Warp ID ranges for each role
    - Register allocation per role (controls spill/occupancy tradeoff)
    - Barrier IDs for CTA sync and TMEM allocation
    """

    softmax0_warp_ids: Tuple[int, ...] = (0, 1, 2, 3)
    softmax1_warp_ids: Tuple[int, ...] = (4, 5, 6, 7)
    correction_warp_ids: Tuple[int, ...] = (8, 9, 10, 11)
    mma_warp_id: int = 12
    load_warp_id: int = 13
    epilogue_warp_id: int = 14
    empty_warp_id: int = 15

    num_regs_softmax: int = 192
    num_regs_correction: int = 96
    num_regs_other: int = 32
    num_regs_empty: int = 24

    threads_per_warp: int = 32
    cta_sync_bar_id: int = 0
    tmem_alloc_sync_bar_id: int = 1

    @property
    def all_warp_ids(self) -> Tuple[int, ...]:
        return (
            *self.softmax0_warp_ids,
            *self.softmax1_warp_ids,
            *self.correction_warp_ids,
            self.mma_warp_id,
            self.load_warp_id,
            self.epilogue_warp_id,
            self.empty_warp_id,
        )

    @property
    def num_warps(self) -> int:
        return len(self.all_warp_ids)

    @property
    def threads_per_cta(self) -> int:
        return self.threads_per_warp * self.num_warps

    @property
    def num_warps_per_warpgroup(self) -> int:
        return 4

    @property
    def softmax_warpgroup_count(self) -> int:
        total_softmax_warps = len(self.softmax0_warp_ids) + len(self.softmax1_warp_ids)
        return total_softmax_warps // self.num_warps_per_warpgroup
```
Contributor

⚠️ Potential issue | 🟠 Major

Validate custom schedules before deriving CTA sizes.

num_warps, threads_per_cta, and softmax_warpgroup_count all assume the warp ids are unique, contiguous from 0, and that the softmax warps fill whole warpgroups. With a custom WarpSchedule, duplicate/gapped ids or a non-multiple-of-4 softmax set will silently produce the wrong CTA/barrier sizing.

🛠️ Suggested fail-fast validation

```diff
 @dataclass(frozen=True)
 class WarpSchedule:
 @@
     threads_per_warp: int = 32
     cta_sync_bar_id: int = 0
     tmem_alloc_sync_bar_id: int = 1
+
+    def __post_init__(self):
+        all_warp_ids = self.all_warp_ids
+        if len(set(all_warp_ids)) != len(all_warp_ids):
+            raise ValueError("warp ids must be unique across roles")
+        if tuple(sorted(all_warp_ids)) != tuple(range(len(all_warp_ids))):
+            raise ValueError("warp ids must form a contiguous range starting at 0")
+        total_softmax_warps = len(self.softmax0_warp_ids) + len(self.softmax1_warp_ids)
+        if total_softmax_warps % self.num_warps_per_warpgroup != 0:
+            raise ValueError("softmax warps must fill whole warpgroups")
+        if self.cta_sync_bar_id == self.tmem_alloc_sync_bar_id:
+            raise ValueError("barrier ids must be distinct")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/warp_schedule.py` around lines 17 - 71, Add a
fail-fast validation in WarpSchedule (implement in a __post_init__ method) that
verifies: 1) all_warp_ids (built from softmax0_warp_ids, softmax1_warp_ids,
correction_warp_ids, mma_warp_id, load_warp_id, epilogue_warp_id, empty_warp_id)
contain unique values and form a contiguous range starting at 0 up to
len(all_warp_ids)-1, and 2) the total number of softmax warps
(len(softmax0_warp_ids)+len(softmax1_warp_ids)) is divisible by
num_warps_per_warpgroup; on violation raise ValueError with a clear message
referencing the failing condition so consumers of num_warps, threads_per_cta,
and softmax_warpgroup_count cannot silently compute incorrect sizes.
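A runnable trim of the suggested validation (fields reduced to the ones the checks touch; this mirrors the suggested diff, not the actual module), showing it rejecting a duplicate and a gapped warp-id assignment:

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class WarpScheduleSketch:
    """Trimmed stand-in for WarpSchedule: only the validated fields."""
    softmax0_warp_ids: Tuple[int, ...] = (0, 1, 2, 3)
    softmax1_warp_ids: Tuple[int, ...] = (4, 5, 6, 7)
    correction_warp_ids: Tuple[int, ...] = (8, 9, 10, 11)
    mma_warp_id: int = 12
    load_warp_id: int = 13
    epilogue_warp_id: int = 14
    empty_warp_id: int = 15
    num_warps_per_warpgroup: int = 4

    @property
    def all_warp_ids(self) -> Tuple[int, ...]:
        return (*self.softmax0_warp_ids, *self.softmax1_warp_ids,
                *self.correction_warp_ids, self.mma_warp_id,
                self.load_warp_id, self.epilogue_warp_id, self.empty_warp_id)

    def __post_init__(self):
        ids = self.all_warp_ids
        if len(set(ids)) != len(ids):
            raise ValueError("warp ids must be unique across roles")
        if tuple(sorted(ids)) != tuple(range(len(ids))):
            raise ValueError("warp ids must form a contiguous range starting at 0")
        n_softmax = len(self.softmax0_warp_ids) + len(self.softmax1_warp_ids)
        if n_softmax % self.num_warps_per_warpgroup != 0:
            raise ValueError("softmax warps must fill whole warpgroups")
```

With this in place, `WarpScheduleSketch(mma_warp_id=0)` (a duplicate id) and `WarpScheduleSketch(empty_warp_id=16)` (a gap) both fail fast instead of silently mis-sizing the CTA.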

@yzh119
Collaborator

yzh119 commented Mar 24, 2026

/bot run

@yzh119 yzh119 added the run-ci label Mar 24, 2026
@flashinfer-bot
Collaborator

GitLab MR !457 has been created, and the CI pipeline #46896453 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46896453: 10/20 passed

@pgera pgera requested a review from samuellees as a code owner April 1, 2026 23:12
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
flashinfer/cute_dsl/attention/roles/epilogue.py (1)

48-66: ⚠️ Potential issue | 🟠 Major

Remove @cute.jit from partition_output() or stop returning tensors.

This helper still returns (tOsO, tOgO), so it reintroduces the exact CuTe JIT limitation called out in the note immediately above it. run() already had to inline the same partitioning logic to avoid that path, which is a strong sign this exported primitive is still unsafe for reuse.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/roles/epilogue.py` around lines 48 - 66, The
partition_output helper currently decorated with `@cute.jit` returns tensors
(tOsO, tOgO) which reintroduces the CuTe JIT limitation; either remove the
`@cute.jit` decorator from partition_output so it runs in Python (and call it from
run()), or keep it JIT'd but change its API to not return cute.Tensor/CuTe
objects (e.g., perform the tma_partition side-effects inside the function or
write results into provided buffers/atoms), updating callers (notably run()) to
use the new behavior; locate partition_output, its use of
cute.nvgpu.cpasync.tma_partition and the tma_atom_o argument when making the
change.
🧹 Nitpick comments (3)
flashinfer/cute_dsl/attention/fusion/variant.py (1)

357-382: Make the base transform_output() honor scale.

The docstring says this hook replaces output *= scale_output / d, but the fallback implementation drops scale and returns output * rcp_d. That makes the base API misleading for any custom variant that enables has_output_transform=True and relies on the inherited behavior.

♻️ Proposed fix

```diff
-        return output * rcp_d
+        return output * scale * rcp_d
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/fusion/variant.py` around lines 357 - 382, The
base transform_output method (transform_output) ignores the scale parameter,
returning output * rcp_d which breaks the documented contract and any subclass
relying on inherited behavior when has_output_transform=True; modify
transform_output to apply the scale as the docstring describes by returning
output multiplied by rcp_d and scale (i.e., incorporate the scale/output scaling
factor), and ensure callers referring to scale_output/scale still get the
correct behavior.
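Behaviorally, the proposed default reduces to a one-liner; a stand-in sketch (the class and argument names follow the review comment, not the real variant.py code):

```python
class VariantBaseSketch:
    """Stand-in for the AttentionVariant base; only the output hook."""

    def transform_output(self, output, scale, rcp_d):
        # honor the documented contract: output *= scale / d,
        # i.e. multiply by BOTH scale and the reciprocal of d
        return output * scale * rcp_d
```

A subclass that enables `has_output_transform=True` and inherits this default then gets the documented `scale_output / d` behavior instead of silently dropping `scale`.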
tests/test_blackwell_fmha_attention.py (1)

4-13: Move this suite under a kernel-category test subdirectory.

This is a new root-level test module, but the CuTe-DSL attention coverage will be easier to discover and maintain if it lives under a feature-specific tests/ subdirectory with the rest of the kernel-category tests.

As per coding guidelines, tests/**/*.py: Prefix test functions with test_ and structure tests by feature in tests/ subdirectories matching kernel categories.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_blackwell_fmha_attention.py` around lines 4 - 13, The test module
tests/test_blackwell_fmha_attention.py should be moved into the kernel-category
tests subdirectory and follow the test naming convention; relocate the file into
the appropriate feature-specific folder under tests/ (e.g.,
tests/kernel_category/blackwell_fmha/) and ensure all test functions inside the
module are prefixed with test_ (rename any non-prefixed functions), keeping the
module name descriptive (blackwell_fmha_attention) and updating any imports or
test discovery paths accordingly.
flashinfer/cute_dsl/attention/collective_builder.py (1)

137-155: Use MainloopSpec.barrier_stage_counts() as the source of truth for SharedStorage.

This block still hard-codes every barrier array size even though the spec object now exposes those counts explicitly for shared-storage sizing. Keeping topology and storage slots in two places makes the next stage-count/topology tweak easy to desynchronize.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/collective_builder.py` around lines 137 - 155,
The SharedStorage struct currently hard-codes barrier array sizes and
conditional logic; instead call mainloop.barrier_stage_counts() once and use its
returned counts to size every MemRange so storage and topology come from a
single source of truth. Update the s0_corr_stages, mma_corr_stages,
s0_epi_stages locals (or remove them) and replace the numeric expressions used
in SharedStorage fields (load_q_mbar_ptr, load_kv_mbar_ptr, mma_s0_mbar_ptr,
mma_s1_mbar_ptr, s0_corr_mbar_ptr, s1_corr_mbar_ptr, s0_s1_sequence_mbar_ptr,
corr_epi_mbar_ptr, mma_corr_mbar_ptr, s0_epi_mbar_ptr, s1_epi_mbar_ptr,
tmem_dealloc_mbar_ptr) with the appropriate entries from counts (e.g.,
counts['q'], counts['kv'], counts['mma_softmax'], counts['s0_corr'],
counts['mma_corr'], counts['s0_epi'], sched.softmax_warpgroup_count, etc.),
removing the has_logits_transform conditionals so all sizes derive from
mainloop.barrier_stage_counts().
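The shape of that refactor, sketched with hypothetical keys and stage numbers (the real contract is MainloopSpec.barrier_stage_counts(); nothing here is the actual API):

```python
# Single-source-of-truth pattern: size every shared-storage slot from one
# counts dict returned by the spec, instead of re-deriving each number inline.

def barrier_stage_counts(kv_stages, has_logits_transform, softmax_wgs):
    # hypothetical stand-in for MainloopSpec.barrier_stage_counts()
    return {
        "q": 1,
        "kv": kv_stages,
        "mma_softmax": 2,
        "s0_corr": 2 if has_logits_transform else 1,
        "mma_corr": 2,
        "s0_epi": softmax_wgs,
    }

def shared_storage_sizes(counts):
    # every barrier array reads from `counts`: two mbarriers per stage
    # (full/empty), with no inline conditionals left to drift
    return {f"{name}_mbar": 2 * n for name, n in counts.items()}
```

Any later stage-count or topology tweak then lands in one place, and the storage sizing follows automatically.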
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/cute_dsl/attention/mainloop_spec.py`:
- Around line 21-22: The helper that builds a prefill topology must use the
transform warp schedule when has_logits_transform=True: update the logic in the
factory that constructs PipelineTopology (the code paths using
make_prefill_topology and make_prefill_topology_transform and selecting a
WarpSchedule) to choose WarpSchedule.PREFILL_TRANSFORM_SCHEDULE (or the
module-level PREFILL_TRANSFORM_SCHEDULE constant) instead of PREFILL_SCHEDULE
when has_logits_transform is true; apply the same change to the other occurrence
that mirrors this logic so both prefill/topology-transform code paths use
PREFILL_TRANSFORM_SCHEDULE for transform variants.

In `@flashinfer/cute_dsl/attention/prefill.py`:
- Around line 42-51: Remove the process-wide suppression that calls
warnings.filterwarnings("ignore", category=UserWarning) in prefill.py; keep only
the narrow message-based filter for the CUTLASS loop-unroll warning (the
existing warnings.filterwarnings(..., message="This loop is no longer
unrolled...")) and, if extra suppression is needed around a specific compile
call, wrap that call in a local warnings.catch_warnings() context or apply a
temporary filter immediately around the compile helper rather than mutating the
module/global filter. Target the two warnings.filterwarnings calls in this file
(the message-based and the category-based invocations) and delete or refactor
the category-based one.

---

Duplicate comments:
In `@flashinfer/cute_dsl/attention/roles/epilogue.py`:
- Around line 48-66: The partition_output helper currently decorated with
`@cute.jit` returns tensors (tOsO, tOgO) which reintroduces the CuTe JIT
limitation; either remove the `@cute.jit` decorator from partition_output so it
runs in Python (and call it from run()), or keep it JIT'd but change its API to
not return cute.Tensor/CuTe objects (e.g., perform the tma_partition
side-effects inside the function or write results into provided buffers/atoms),
updating callers (notably run()) to use the new behavior; locate
partition_output, its use of cute.nvgpu.cpasync.tma_partition and the tma_atom_o
argument when making the change.

---

Nitpick comments:
In `@flashinfer/cute_dsl/attention/collective_builder.py`:
- Around line 137-155: The SharedStorage struct currently hard-codes barrier
array sizes and conditional logic; instead call mainloop.barrier_stage_counts()
once and use its returned counts to size every MemRange so storage and topology
come from a single source of truth. Update the s0_corr_stages, mma_corr_stages,
s0_epi_stages locals (or remove them) and replace the numeric expressions used
in SharedStorage fields (load_q_mbar_ptr, load_kv_mbar_ptr, mma_s0_mbar_ptr,
mma_s1_mbar_ptr, s0_corr_mbar_ptr, s1_corr_mbar_ptr, s0_s1_sequence_mbar_ptr,
corr_epi_mbar_ptr, mma_corr_mbar_ptr, s0_epi_mbar_ptr, s1_epi_mbar_ptr,
tmem_dealloc_mbar_ptr) with the appropriate entries from counts (e.g.,
counts['q'], counts['kv'], counts['mma_softmax'], counts['s0_corr'],
counts['mma_corr'], counts['s0_epi'], sched.softmax_warpgroup_count, etc.),
removing the has_logits_transform conditionals so all sizes derive from
mainloop.barrier_stage_counts().

In `@flashinfer/cute_dsl/attention/fusion/variant.py`:
- Around line 357-382: The base transform_output method (transform_output)
ignores the scale parameter, returning output * rcp_d which breaks the
documented contract and any subclass relying on inherited behavior when
has_output_transform=True; modify transform_output to apply the scale as the
docstring describes by returning output multiplied by rcp_d and scale (i.e.,
incorporate the scale/output scaling factor), and ensure callers referring to
scale_output/scale still get the correct behavior.

In `@tests/test_blackwell_fmha_attention.py`:
- Around line 4-13: The test module tests/test_blackwell_fmha_attention.py
should be moved into the kernel-category tests subdirectory and follow the test
naming convention; relocate the file into the appropriate feature-specific
folder under tests/ (e.g., tests/kernel_category/blackwell_fmha/) and ensure all
test functions inside the module are prefixed with test_ (rename any
non-prefixed functions), keeping the module name descriptive
(blackwell_fmha_attention) and updating any imports or test discovery paths
accordingly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 3433f39f-702a-4c43-999e-b1895a32a960

📥 Commits

Reviewing files that changed from the base of the PR and between 9f0ba5e and ada3c0f.

📒 Files selected for processing (12)
  • flashinfer/cute_dsl/attention/__init__.py
  • flashinfer/cute_dsl/attention/collective_builder.py
  • flashinfer/cute_dsl/attention/fusion/__init__.py
  • flashinfer/cute_dsl/attention/fusion/variant.py
  • flashinfer/cute_dsl/attention/mainloop_spec.py
  • flashinfer/cute_dsl/attention/pipeline_topology.py
  • flashinfer/cute_dsl/attention/prefill.py
  • flashinfer/cute_dsl/attention/roles/epilogue.py
  • flashinfer/cute_dsl/attention/roles/mma.py
  • flashinfer/cute_dsl/attention/roles/softmax.py
  • flashinfer/cute_dsl/attention/warp_schedule.py
  • tests/test_blackwell_fmha_attention.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • flashinfer/cute_dsl/attention/fusion/__init__.py
  • flashinfer/cute_dsl/attention/__init__.py
  • flashinfer/cute_dsl/attention/warp_schedule.py
  • flashinfer/cute_dsl/attention/roles/mma.py
  • flashinfer/cute_dsl/attention/pipeline_topology.py

Comment on lines +21 to +22

```python
from .warp_schedule import WarpSchedule, PREFILL_SCHEDULE
from .pipeline_topology import PipelineTopology, make_prefill_topology, make_prefill_topology_transform
```
Contributor

⚠️ Potential issue | 🟠 Major

Use the transform schedule when has_logits_transform=True.

This helper still falls back to PREFILL_SCHEDULE for transform variants. That disagrees with BlackwellFusedMultiHeadAttentionForward.__init__, which already switches to PREFILL_TRANSFORM_SCHEDULE, so direct callers of this public factory can build a transform topology with the wrong warp schedule.

♻️ Proposed fix

```diff
-from .warp_schedule import WarpSchedule, PREFILL_SCHEDULE
+from .warp_schedule import (
+    WarpSchedule,
+    PREFILL_SCHEDULE,
+    PREFILL_TRANSFORM_SCHEDULE,
+)
 @@
-    sched = warp_schedule if warp_schedule is not None else PREFILL_SCHEDULE
+    if warp_schedule is None:
+        sched = (
+            PREFILL_TRANSFORM_SCHEDULE
+            if has_logits_transform
+            else PREFILL_SCHEDULE
+        )
+    else:
+        sched = warp_schedule
```

Also applies to: 105-110

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/attention/mainloop_spec.py` around lines 21 - 22, The
helper that builds a prefill topology must use the transform warp schedule when
has_logits_transform=True: update the logic in the factory that constructs
PipelineTopology (the code paths using make_prefill_topology and
make_prefill_topology_transform and selecting a WarpSchedule) to choose
WarpSchedule.PREFILL_TRANSFORM_SCHEDULE (or the module-level
PREFILL_TRANSFORM_SCHEDULE constant) instead of PREFILL_SCHEDULE when
has_logits_transform is true; apply the same change to the other occurrence that
mirrors this logic so both prefill/topology-transform code paths use
PREFILL_TRANSFORM_SCHEDULE for transform variants.
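A toy model of the proposed default, matching what BlackwellFusedMultiHeadAttentionForward.__init__ already does (string constants stand in for the real WarpSchedule instances):

```python
PREFILL_SCHEDULE = "prefill"                      # placeholder constants,
PREFILL_TRANSFORM_SCHEDULE = "prefill-transform"  # not the real schedules

def pick_schedule(warp_schedule, has_logits_transform):
    if warp_schedule is not None:
        return warp_schedule          # explicit override always wins
    return PREFILL_TRANSFORM_SCHEDULE if has_logits_transform else PREFILL_SCHEDULE
```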

pgera added 8 commits April 1, 2026 16:39
Add CuTe DSL-based attention implementation:
- flashinfer/cute_dsl/attention/ - Modular attention package with composable
  roles (loader, softmax, MMA, epilogue), fusion points (logits transform,
  mask, output transform), and schedulers
- flashinfer/cute_dsl/prefill.py - Batch prefill wrapper
- flashinfer/cute_dsl/mla.py - MLA decode wrapper
- flashinfer/cute_dsl/patch/pipeline.py - Pipeline patching utilities

Tests and benchmarks (named to avoid conflicts with existing cutlass tests):
- tests/test_blackwell_fmha_cutedsl.py - FMHA tests (prefill)
- tests/test_blackwell_fmha_attention.py - Modular attention package tests
- tests/test_blackwell_mla_attention.py - MLA attention tests
- tests/test_deepseek_mla_cutedsl.py - DeepSeek MLA tests
- benchmarks/bench_blackwell_attention_cutedsl.py - Attention benchmarks
- docs/cutedsl_fmha_architecture.md - Architecture documentation

Made-with: Cursor
- Delete flashinfer/cute_dsl/prefill.py and mla.py (replaced by
  the modular flashinfer/cute_dsl/attention/ package)
- Delete tests/test_blackwell_fmha_cutedsl.py and
  tests/test_deepseek_mla_cutedsl.py (replaced by
  test_blackwell_fmha_attention.py and test_blackwell_mla_attention.py)
- Revert benchmarks/bench_deepseek_mla.py to upstream version
- Split benchmarks into prefill and decode:
  bench_blackwell_attention_cutedsl.py (FMHA prefill)
  bench_blackwell_mla_cutedsl.py (MLA decode)

Made-with: Cursor
…rnels

Two kernel correctness bugs fixed:

1. PV1(end) accumulate flag: The final PV GEMM for stage 1 used hardcoded
   accumulate=True, causing stale TMEM data corruption when the KV loop
   didn't execute (kv_len <= tile_size with multi-Q-tile batches).
   Fix: use pv_whether_acc instead of True.

2. Causal mask trip count: get_masked_trip_count used ceil_div(M, N) which
   doesn't account for non-zero causal_offset shifting the diagonal across
   extra KV tiles. When kv_len != qo_len, some tiles needing masking were
   processed as unmasked, leaking unmasked scores into softmax.
   Fix: compute masked tile count from actual diagonal boundary positions.

Both fixes required threading seqlen_q through the mask functions and
passing causal_offset to apply_mask.

Test suite pruned to ~112 curated cases covering tile boundaries, GQA,
varlen, causal, output/logits transforms, and attention sink.

AI-assisted (Claude)

Made-with: Cursor
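The trip-count bug above can be seen with a back-of-envelope model: with causal_offset = kv_len - qo_len, the diagonal for a Q tile ends at KV position q_end + causal_offset, so the masked-tile count depends on that boundary, not ceil_div(M, N). Tile sizes and function names below are illustrative, not the kernel's.

```python
def ceil_div(a, b):
    return (a + b - 1) // b

def masked_kv_tiles(q_tile_idx, tile_m, tile_n, qo_len, kv_len):
    # KV tiles intersected by the causal diagonal band for this Q tile
    causal_offset = kv_len - qo_len
    q_lo = q_tile_idx * tile_m                   # first Q row in the tile
    q_hi = min(q_lo + tile_m, qo_len) - 1        # last Q row in the tile
    first = (q_lo + causal_offset) // tile_n     # KV tile holding the first diagonal
    last = (q_hi + causal_offset) // tile_n      # KV tile holding the last diagonal
    return last - first + 1
```

For qo_len=128, kv_len=160, tile_m=128, tile_n=64, the diagonal band spans 3 KV tiles, while the old ceil_div(128, 64) formula yields 2, leaving one tile's scores unmasked.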
…ate domain conversion

The plan() method created a dummy sink tensor with hardcoded float16 dtype
for JIT compilation regardless of input dtype. When bfloat16 inputs were
used at runtime, the compiled kernel misinterpreted bf16 bits as fp16,
garbling sink values (causal row-0 error: 1.75 -> 0.004).

Also fix the sink_M_D_update test helper to properly convert the sink value
from scaled-logit space to raw-logit space by dividing by scale, and tighten
the sink test tolerance from atol=2.0 to atol=0.01.

AI-assisted (Claude)

Made-with: Cursor
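The exp/exp2 domain bookkeeping behind this fix can be checked numerically; a sketch assuming scale = sm_scale * log2(e), the convention described in the review thread:

```python
import math

# The running max m is kept in the RAW logit domain, so the log2-domain
# exponentials must fold sm_scale (and log2(e)) back in.
sm_scale = 1.0 / math.sqrt(128)
scale = sm_scale * math.log2(math.e)

x = 2.5                                    # a raw-domain logit
via_exp = math.exp(x * sm_scale)           # reference softmax numerator
via_exp2 = 2.0 ** (x * scale)              # kernel-style exp2 form
assert abs(via_exp - via_exp2) < 1e-12

# Sink conversion: a value stored in scaled (log2) space returns to the
# raw domain by dividing by scale -- equivalently params / sm_scale once
# the log2(e) factor cancels.
assert abs((x * scale) / scale - x) < 1e-12
```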
…ve tests (AI-assisted)

Kernel fixes:
- Sliding window apply_mask: add missing left-bound check (|kv-q| > window)
  and seqlen_k bounds check
- Sliding window get_trip_count/get_kv_start_block_idx: compute correct
  symmetric window tile range instead of right-only approximation
- Softmax run(): add kv_start_offset to coordinate identity tensor so mask
  coordinates match actual KV positions loaded by the TMA loader

Test fixes:
- sink_M_D_update: add * scale to exp2 rescale terms for correctness
  (m is in RAW domain, exp2 needs domain conversion via * scale)
- Sink test: use SM_SCALE=1/sqrt(head_dim) instead of 1.0, which made the
  sink contribution negligible (~0) and the test vacuous

New test coverage:
- float16 dtype (3 shapes x 2 causal)
- Sliding window mask (4 window/shape combos)
- head_dim=64 (3 shapes x 2 causal)
- Variable-length + sigmoid logits transform (2 indptr patterns)
- Variable-length + attention sink (2 indptr patterns)
- Attention sink with MHA / num_kv_heads=32 (2 shapes x 2 causal)

All 118 tests pass, 18 skipped (qo>kv+causal), ~10 min runtime.

Made-with: Cursor
Strip out MLA decode kernel, config, warp schedule, roles, scheduler,
wrapper, benchmark, test, and architecture doc to keep this PR focused
on FMHA prefill only. Clean up MLA references in shared modules.

AI-assisted

Made-with: Cursor
The PipelineProducer/PipelineConsumer wrappers are now available in
cutlass.pipeline (nvidia-cutlass-dsl >= 4.3). Use them directly
instead of maintaining a local copy. Pipeline creation uses
defer_sync=True since the kernel handles barrier init/sync separately.

Verified: no perf regression (< 1% noise), all 118 tests pass.

AI-assisted

Made-with: Cursor
…ed tmem_utils.py

- Add missing @cute.jit decorator to get_trip_count for consistency
  with all other functions in mask.py
- Remove tmem_utils.py which was MLA-specific dead code after MLA removal

AI-assisted

Made-with: Cursor
pgera and others added 2 commits April 7, 2026 15:01
…uards (AI-assisted)

Review feedback from yzh119:
- Switch FMHA prefill compile path to TVM-FFI: use make_fake_compact_tensor
  for compile-time tracing and pass PyTorch tensors directly at runtime,
  matching the pattern used by the MLA decode kernel. This replaces the
  from_dlpack + .iterator approach which is incompatible with --enable-tvm-ffi.
- Change prefill.py __call__ from cute.Pointer to cute.Tensor parameters;
  the kernel extracts .iterator internally.
- Add --enable-tvm-ffi --opt-level 2 to cute.compile() and use
  make_fake_stream(use_tvm_ffi_env_stream=True) for stream handling.
- Add @flashinfer_api decorator to BatchPrefillCuteDSLWrapper __init__,
  plan, and run for crash-safe API logging.
- Add is_sm100a_supported() guard and enable_cupti=True in benchmark.
- Remove broad warnings.filterwarnings("ignore", category=UserWarning)
  from prefill.py, mla_decode.py, mla_decode_fp8.py; keep only the
  targeted cutlass loop-unroll message filter.
- Fix NamedBarrier.wait() -> arrive_and_wait() in MLA roles to use the
  correct API (they are identical at the hardware level, but wait()
  triggers a deprecation warning from cutlass-dsl).

Made-with: Cursor
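The narrow-filter pattern that replaces the global suppression might look like this (the exact warning text and wrapper name are assumptions; only the targeted message is muted, and only around the compile call):

```python
import warnings

def compile_with_quiet_unroll(compile_fn, *args, **kwargs):
    # Scope the suppression to this call instead of mutating the
    # process-wide filter; other UserWarnings still propagate.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="This loop is no longer unrolled.*",  # assumed text
            category=UserWarning,
        )
        return compile_fn(*args, **kwargs)
```

Once the `with` block exits, the process-wide filter state is restored, so unrelated warnings elsewhere in the program are unaffected.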
@yzh119
Collaborator

yzh119 commented Apr 7, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !457 has been updated with latest changes, and the CI pipeline #47962773 is currently running. I'll report back once the pipeline job completes.

pgera added 5 commits April 8, 2026 02:21
- Auto-fix 33 unused import warnings across attention package (ruff --fix)
- Auto-format 23 files to match ruff style
- Remove unused variable tOcO_custom_i in softmax.py
- Remove dead code qo_idx_offset in softmax.py
- Add strict=True to zip() in scheduler/persistent.py
- Add assert guards before .arrive_and_wait() calls on Optional barriers
  in mla_correction.py and mla_compute.py to satisfy mypy

Made-with: Cursor
Add Type[cutlass.Numeric] annotations to dtype attributes in mla_decode,
mla_decode_fp8, mla_compute, softmax, and correction roles — matching
the pattern from gemm_allreduce_two_shot.py. Add assert guards for
Optional attributes accessed in type-checked code paths. Add params
attribute to AttentionVariant base class. Use Union type for
MLAMainloopSpec warp_schedule to accept both FP16 and FP8 schedules.

Made-with: Cursor
…(AI-assisted)

- batch_mla.py: validate 4D kv_cache has shape[1]==1 before squeeze(1),
  and reject non-3D/4D inputs. Prevents silent wrong page_size detection.
- test_modular_fmha_prefill.py: document that sliding window with
  qo_len != kv_len is not yet supported (mask offset + coordinate
  identity tensor fixes needed).

Made-with: Cursor
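A shape-only sketch of that validation (the helper name is hypothetical; the real check lives in batch_mla.py and operates on torch tensors): accept 3-D paged KV, or 4-D only when dim 1 is exactly 1 and therefore safe to squeeze.

```python
def normalize_kv_cache_shape(shape):
    """Hypothetical helper mirroring the batch_mla.py check."""
    if len(shape) == 4:
        if shape[1] != 1:
            raise ValueError(f"4D kv_cache requires shape[1] == 1, got {shape[1]}")
        shape = (shape[0],) + tuple(shape[2:])   # equivalent of squeeze(1)
    elif len(shape) != 3:
        raise ValueError(f"kv_cache must be 3D or 4D, got {len(shape)}D")
    num_pages, page_size, head_dim = shape
    return num_pages, page_size, head_dim
```

Without the `shape[1] != 1` check, a layout like `(pages, 2, page_size, head_dim)` would be squeezed into nonsense and page_size silently mis-detected.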
The sliding window mask functions computed KV tile ranges and element
masks using raw Q tile indices, ignoring the Q/K sequence length offset.
When qo_len < kv_len (e.g. suffix-prefill, append-with-cache), the
window was centered on the wrong KV region.

Fix: add qk_offset = seqlen_k - seqlen_q to Q positions in
get_trip_count, get_kv_start_block_idx, and apply_mask, matching the
causal mask path which already uses this offset.

Add 4 test cases with qo_len != kv_len to SLIDING_WINDOW_PARAMS and
update the reference implementation to right-align Q positions to KV.

Made-with: Cursor
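The right-alignment logic in this fix can be sketched as a scalar predicate; this is a hypothetical illustration (the real kernel works on tiles via get_trip_count / get_kv_start_block_idx / apply_mask), with `window_left` as the assumed sliding-window parameter:

```python
def in_sliding_window(q_idx, k_idx, seqlen_q, seqlen_k, window_left):
    """True if key k_idx lies inside the sliding window of query q_idx,
    with Q positions right-aligned to the end of the KV sequence."""
    qk_offset = seqlen_k - seqlen_q  # the offset this fix introduces
    q_pos = q_idx + qk_offset        # absolute query position in KV space
    # causal upper bound plus left window bound
    return q_pos - window_left <= k_idx <= q_pos
```

Without `qk_offset`, `q_pos` would equal the raw `q_idx`, centering the window on the start of KV instead of its tail whenever qo_len < kv_len.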
@pgera
Author

pgera commented Apr 8, 2026

/bot run

@flashinfer-bot
Collaborator

@pgera is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

pgera added 4 commits April 8, 2026 05:08
Add crash-safe API logging decorator to __init__, plan, and run methods,
matching the pattern used by BatchPrefillCuteDSLWrapper and all other
FlashInfer wrapper classes.

Made-with: Cursor
- Add _validate_run_inputs() to BatchMLADecodeCuteDSLWrapper, checking
  dtype/device/shape consistency (matching prefill wrapper pattern)
- Replace bare assert with raise ValueError for D_qk and workspace size
  checks in MLA run()
- Add docstring to prefill plan() method
- Add comment explaining float_workspace_buffer naming convention

Made-with: Cursor
Replace full 25-line BSD-3-Clause license text with the abbreviated
2-line SPDX header in MLA files ported from the monolithic codebase.
Update copyright year from 2025-2026 to 2026 across all 11 files.

Made-with: Cursor
…sted)

Skip gracefully when nvidia-cutlass-dsl package is not installed,
matching the pattern used by test_cute_dsl_mla_decode.py.

Made-with: Cursor
@pgera pgera changed the title [CuTe DSL] Add modular FMHA prefill attention kernel [CuTe DSL] Add modular FMHA prefill and MLA decode attention kernels Apr 8, 2026
@nvpohanh
Contributor

nvpohanh commented Apr 9, 2026

cc @leejnau about cutedsl mla update

pgera added 3 commits April 9, 2026 13:50
…isted)

Wire AttentionVariant hooks (score_mod, update_statistics, transform_output)
through both FP16 and FP8 MLA decode kernels, unify compilation caching
under @functools.cache for prefill and MLA decode, and switch prefill to
symbolic tensor dimensions for cross-batch kernel reuse.

Key changes:
- mla_decode.py / mla_decode_fp8.py: accept fusion + params_in, thread
  params and cta_m_offset through compute and correction roles
- mla_compute.py: integrate score_mod and update_statistics hooks with
  post-score_mod re-masking to prevent -inf→finite leakage (SoftCapping)
- mla_correction.py: integrate transform_output hook (AttentionWithSink)
- softmax.py: re-apply mask after score_mod for prefill (same fix)
- batch_mla.py: merge _get_compiled_mla_kernel into single _compile_mla_kernel
  with variant/params_shape cache keys; add logits_transform validation
- batch_prefill.py: extract _get_compiled_prefill_kernel with symbolic dims
  and @functools.cache for batch-independent compilation
- variant.py: fix NaN in AttentionWithSink.update_statistics when split-KV
  CTAs don't own tile 0 (-inf - (-inf) = NaN)

Tests: ALiBi, SoftCapping, AttentionWithSink, RPE for BF16 + FP8 MLA decode,
plus SoftCapping regression tests for non-tile-aligned sequences.

Made-with: Cursor
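The variant.py NaN fix can be illustrated with a small sketch of the online-softmax correction factor; the function name is hypothetical and the guard value follows from the fact that a CTA with running max -inf has zero partial sums:

```python
import math

def correction_factor(old_max, new_max):
    """Online-softmax rescale factor exp(old_max - new_max), guarded
    against the -inf - (-inf) = NaN case described above: a split-KV
    CTA that does not own tile 0 may still have old_max == -inf, and
    its partial sums are all zero, so returning 0.0 keeps them zero
    instead of poisoning the reduction with NaN."""
    if old_max == float("-inf"):
        return 0.0
    return math.exp(old_max - new_max)
```

In the unguarded version, `math.exp(float("-inf") - float("-inf"))` evaluates `-inf - (-inf) = nan` and the NaN then propagates through the split-KV reduction.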
…-assisted)

Pre-allocate a padded output scratch buffer in plan() instead of
allocating one per run() call. The kernel's TMA varlen addressing
requires front-padding on the output tensor (the negative-pointer-offset
trick shared with the CUTLASS backend), so run() writes into the scratch
buffer and copies back to the caller's out tensor when provided. This
honors the out= in-place contract that serving frameworks like vLLM
rely on.

Other changes:
- Add jit_args guard for cute-dsl backend (skip unnecessary JIT compilation)
- Document why return_lse/FP8 scale checks must live at run() time
- Replace outdated standalone-script docstring in prefill.py
- Add test verifying the out= in-place contract

Made-with: Cursor
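The plan-time scratch allocation and run-time copy-back can be sketched as follows; this is a hypothetical illustration using NumPy stand-ins for device tensors, with `PAD`, the class name, and `compute_fn` all invented for the example:

```python
import numpy as np

PAD = 8  # hypothetical front-padding required by TMA varlen addressing

class PrefillWrapperSketch:
    """Hypothetical sketch: pre-allocate a front-padded scratch buffer
    once in plan(), then have run() write into it and copy back into
    the caller's out= tensor to honor the in-place contract."""

    def plan(self, total_qo_len, head_dim):
        # PAD front rows plus the real output region; the kernel sees a
        # pointer offset so the logical output starts at row PAD.
        self._scratch = np.zeros((PAD + total_qo_len, head_dim), dtype=np.float32)

    def run(self, compute_fn, out=None):
        view = self._scratch[PAD:]  # logical output region
        compute_fn(view)            # kernel writes here
        if out is not None:
            out[...] = view         # copy back: out= in-place contract
            return out
        return view.copy()
```

Allocating the scratch in plan() moves the cost out of the hot path, while the copy-back keeps callers that pass out= (e.g. serving frameworks) seeing their own buffer filled in place.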
@pgera
Author

pgera commented Apr 9, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !457 has been updated with the latest changes, and CI pipeline #48159085 is currently running. I'll report back once the pipeline job completes.

Add early NotImplementedError in BatchPrefillWithPagedKVCacheWrapper
when backend="cute-dsl" is passed, since paged KV cache is not yet
supported by the cute-dsl kernel. Also add the jit_args guard
(backend != "cute-dsl") for future-proofing when paged support is added.

Made-with: Cursor
Collaborator

@saltyminty saltyminty left a comment


Approved – internal CI fails on test_trtllm_fused_moe_autotuner_integration.py, which should be unrelated.

Collaborator

@saltyminty saltyminty left a comment


(removing approval, pending further internal discussion)
