diff --git a/flash_attn/cute/AUTHORS b/flash_attn/cute/AUTHORS index bc3991c676d..055e75b6670 100644 --- a/flash_attn/cute/AUTHORS +++ b/flash_attn/cute/AUTHORS @@ -1,5 +1,8 @@ -Tri Dao, tri@tridao.me +Tri Dao Jay Shah Ted Zadouri Markus Hoehnerbach -Vijay Thakkar \ No newline at end of file +Vijay Thakkar +Timmy Liu +Driss Guessous +Reuben Stern \ No newline at end of file diff --git a/flash_attn/cute/MANIFEST.in b/flash_attn/cute/MANIFEST.in new file mode 100644 index 00000000000..329d71b317a --- /dev/null +++ b/flash_attn/cute/MANIFEST.in @@ -0,0 +1,5 @@ +global-exclude *.egg-info/* +prune flash_attn_4.egg-info +prune flash_attn.egg-info +prune build +prune dist diff --git a/flash_attn/cute/README.md b/flash_attn/cute/README.md index 03f48654b51..653f7b1cee2 100644 --- a/flash_attn/cute/README.md +++ b/flash_attn/cute/README.md @@ -1,26 +1,32 @@ -# Flash Attention CUTE +# FlashAttention-4 (CuTeDSL) -## Development Installation +FlashAttention-4 is a CuTeDSL-based implementation of FlashAttention for Hopper and Blackwell GPUs. -1. Clone the repository (if you haven't already): - ```bash - git clone https://github.com/Dao-AILab/flash-attention.git - cd flash-attention/cute - ``` +## Installation -2. Install in editable mode with dev dependencies: - ```bash - pip install -e "./cute[dev]" - ``` +```sh +pip install flash-attn-4 +``` -## Running Tests +If you're on CUDA 13, install with the `cu13` extra for best performance: -```bash -pytest tests/cute/ +```sh +pip install "flash-attn-4[cu13]" ``` -## Linting +## Usage -```bash -ruff check flash_attn/cute/ +```python +from flash_attn.cute import flash_attn_func, flash_attn_varlen_func + +out = flash_attn_func(q, k, v, causal=True) +``` + +## Development + +```sh +git clone https://github.com/Dao-AILab/flash-attention.git +cd flash-attention +pip install -e "flash_attn/cute[dev]" +pytest tests/cute/ ``` diff --git a/flash_attn/cute/__init__.py b/flash_attn/cute/__init__.py index fbbfc14050e..1b84363b63d 100644 --- a/flash_attn/cute/__init__.py +++ b/flash_attn/cute/__init__.py @@ -1,6 +1,11 @@ """Flash Attention CUTE (CUDA Template Engine) implementation.""" -__version__ = "0.1.0" +from importlib.metadata import PackageNotFoundError, version + +try: + __version__ = version("fa4") +except PackageNotFoundError: + __version__ = "0.0.0" import cutlass.cute as cute diff --git a/flash_attn/cute/bench_utils.py b/flash_attn/cute/bench_utils.py new file mode 100644 index 00000000000..45cbcf1af36 --- /dev/null +++ b/flash_attn/cute/bench_utils.py @@ -0,0 +1,196 @@ +"""Shared benchmark utilities: attention_ref, cuDNN helpers, flops calculation.""" + +import math +import torch + +try: + import cudnn +except ImportError: + cudnn = None + + +# ── FLOPS calculation ──────────────────────────────────────────────────────── + + +def flops( + batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None) +): + if causal: + avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 + else: + if window_size == (None, None): + avg_seqlen = seqlen_k + else: + row_idx = torch.arange(seqlen_q, device="cuda") + col_left = ( + torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) + if window_size[0] is not None + else torch.zeros_like(row_idx) + ) + col_right = ( + torch.minimum( + row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1) + ) + if window_size[1] is not None + else torch.full_like(row_idx, seqlen_k - 1) + ) + avg_seqlen = (col_right - col_left + 1).float().mean().item() + return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) + + +# ── Reference attention ───────────────────────────────────────────────────── + +_attention_ref_mask_cache = {} + + +def attention_ref(q, k, v, causal=False): + """Standard attention reference implementation. + + Args: + q, k, v: (batch, seqlen, nheads, headdim) tensors. + causal: whether to apply causal mask. + """ + softmax_scale = 1.0 / math.sqrt(q.shape[-1]) + scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) + if causal: + if scores.shape[-2] not in _attention_ref_mask_cache: + mask = torch.tril( + torch.ones(scores.shape[-2:], device=scores.device, dtype=torch.bool), diagonal=0 + ) + _attention_ref_mask_cache[scores.shape[-2]] = mask + else: + mask = _attention_ref_mask_cache[scores.shape[-2]] + scores = scores.masked_fill(mask, float("-inf")) + attn = torch.softmax(scores, dim=-1) + return torch.einsum("bhts,bshd->bthd", attn, v) + + +# ── cuDNN graph helpers ───────────────────────────────────────────────────── + +_TORCH_TO_CUDNN_DTYPE = { + torch.float16: "HALF", + torch.bfloat16: "BFLOAT16", + torch.float32: "FLOAT", + torch.int32: "INT32", + torch.int64: "INT64", +} + + +def _build_cudnn_graph(io_dtype, tensors, build_fn): + """Build a cuDNN graph. Returns (graph, variant_pack, workspace).""" + assert cudnn is not None, "cuDNN is not available" + cudnn_dtype = getattr(cudnn.data_type, _TORCH_TO_CUDNN_DTYPE[io_dtype]) + graph = cudnn.pygraph( + io_data_type=cudnn_dtype, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + graph_tensors = {name: graph.tensor_like(t.detach()) for name, t in tensors.items()} + variant_pack = build_fn(graph, graph_tensors) + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + return graph, variant_pack, workspace + + +def cudnn_fwd_setup(q, k, v, causal=False, window_size_left=None): + """Build a cuDNN forward SDPA graph. + + Args: + q, k, v: (batch, nheads, seqlen, headdim) tensors (cuDNN layout). + causal: whether to apply causal mask. + window_size_left: sliding window size (None for no window). + + Returns: + (fwd_fn, o_gpu, stats_gpu) where fwd_fn is a zero-arg callable. + """ + b, nheads, seqlen_q, headdim = q.shape + headdim_v = v.shape[-1] + o_gpu = torch.empty(b, nheads, seqlen_q, headdim_v, dtype=q.dtype, device=q.device) + stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device) + + def build(graph, gt): + o, stats = graph.sdpa( + name="sdpa", + q=gt["q"], + k=gt["k"], + v=gt["v"], + is_inference=False, + attn_scale=1.0 / math.sqrt(headdim), + use_causal_mask=causal or window_size_left is not None, + sliding_window_length=window_size_left + if window_size_left is not None and not causal + else None, + ) + o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride()) + stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) + return {gt["q"]: q, gt["k"]: k, gt["v"]: v, o: o_gpu, stats: stats_gpu} + + graph, variant_pack, workspace = _build_cudnn_graph(q.dtype, {"q": q, "k": k, "v": v}, build) + + def fwd_fn(): + graph.execute(variant_pack, workspace) + return o_gpu + + return fwd_fn, o_gpu, stats_gpu + + +def cudnn_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=None): + """Build a cuDNN backward SDPA graph. + + Args: + q, k, v, o, g, lse: (batch, nheads, seqlen, dim) tensors (cuDNN layout). + causal: whether to apply causal mask. + window_size_left: sliding window size (None for no window). + + Returns: + bwd_fn: zero-arg callable that returns (dq, dk, dv). + """ + headdim = q.shape[-1] + dq_gpu, dk_gpu, dv_gpu = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + + def build(graph, gt): + dq, dk, dv = graph.sdpa_backward( + name="sdpa_backward", + q=gt["q"], + k=gt["k"], + v=gt["v"], + o=gt["o"], + dO=gt["g"], + stats=gt["lse"], + attn_scale=1.0 / math.sqrt(headdim), + use_causal_mask=causal or window_size_left is not None, + sliding_window_length=window_size_left + if window_size_left is not None and not causal + else None, + use_deterministic_algorithm=False, + ) + dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride()) + dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride()) + dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride()) + return { + gt["q"]: q, + gt["k"]: k, + gt["v"]: v, + gt["o"]: o, + gt["g"]: g, + gt["lse"]: lse, + dq: dq_gpu, + dk: dk_gpu, + dv: dv_gpu, + } + + graph, variant_pack, workspace = _build_cudnn_graph( + q.dtype, + {"q": q, "k": k, "v": v, "o": o, "g": g, "lse": lse}, + build, + ) + + def bwd_fn(): + graph.execute(variant_pack, workspace) + return dq_gpu, dk_gpu, dv_gpu + + return bwd_fn diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py index e2ff2ccc9ae..720778027b2 100644 --- a/flash_attn/cute/blackwell_helpers.py +++ b/flash_attn/cute/blackwell_helpers.py @@ -8,7 +8,6 @@ from cutlass._mlir.dialects import llvm import flash_attn.cute.mma_sm100_desc as sm100_desc -from flash_attn.cute.utils import parse_swizzle_from_pointer @cute.jit @@ -21,6 +20,7 @@ def gemm_w_idx( B_idx: Optional[Int32] = None, zero_init: bool | Boolean = False, swap_AB: bool = False, + num_unroll_groups: int = 1, ) -> None: if const_expr(swap_AB): return gemm_w_idx( @@ -29,8 +29,11 @@ def gemm_w_idx( else: rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + mma_atom = cute.make_mma_atom(tiled_mma.op) - for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + for k in cutlass.range( + cute.size(tCrA.shape[2]), unroll=cute.size(tCrA.shape[2]) // num_unroll_groups + ): mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc) @@ -46,6 +49,7 @@ def gemm_ptx_w_idx( A_idx: Optional[Int32] = None, B_idx: Optional[Int32] = None, zero_init: bool | Boolean = False, + cta_group: int = 1, **kwargs, ) -> None: rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] @@ -57,7 +61,15 @@ def gemm_ptx_w_idx( mma_atom = cute.make_mma_atom(tiled_mma.op) acc_tmem_addr = acc.iterator.toint() gemm_ptx_partial( - mma_atom.op, acc_tmem_addr, rA, rB, sA_cur, sB_cur, zero_init=zero_init, **kwargs + mma_atom.op, + acc_tmem_addr, + rA, + rB, + sA_cur, + sB_cur, + zero_init=zero_init, + cta_group=cta_group, + **kwargs, ) @@ -68,11 +80,11 @@ def gemm( tCrA: cute.Tensor, tCrB: cute.Tensor, zero_init: bool | Boolean = False, -) -> cute.TiledMma: +) -> None: + mma_atom = cute.make_mma_atom(tiled_mma.op) for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): - tiled_mma.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) - cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) - return tiled_mma + mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc) def i64_to_i32x2(i: int) -> Tuple[int, int]: @@ -97,7 +109,7 @@ def gemm_ptx( sB_layout = sB.layout idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) if const_expr(not is_ts): - sA_swizzle = parse_swizzle_from_pointer(sA.iterator) + sA_swizzle = sA.iterator.type.swizzle_type smem_desc_base_a: int = const_expr( sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), @@ -113,7 +125,7 @@ def gemm_ptx( else: smem_desc_base_a = None smem_desc_base_a_lo, smem_desc_a_hi = None, None - sB_swizzle = parse_swizzle_from_pointer(sB.iterator) + sB_swizzle = sB.iterator.type.swizzle_type smem_desc_base_b: int = const_expr( sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), @@ -212,7 +224,7 @@ def gemm_ptx_loop( sB_layout = sB.layout idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) if const_expr(not is_ts): - sA_swizzle = parse_swizzle_from_pointer(sA.iterator) + sA_swizzle = sA.iterator.type.swizzle_type smem_desc_base_a: int = const_expr( sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), @@ -228,7 +240,7 @@ def gemm_ptx_loop( else: smem_desc_base_a = None smem_desc_base_a_lo, smem_desc_a_hi = None, None - sB_swizzle = parse_swizzle_from_pointer(sB.iterator) + sB_swizzle = sB.iterator.type.swizzle_type smem_desc_base_b: int = const_expr( sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), @@ -368,10 +380,12 @@ def gemm_ptx_partial( sB: cute.Tensor, mbar_ptr: Optional[cutlass.Pointer] = None, mbar_phase: Optional[Int32] = None, + split_arrive: Optional[int] = None, zero_init: bool | Boolean = False, # sA_offset: Int32 = 0, # acc_offset: Int32 = 0, tA_addr: Optional[Int32] = None, + cta_group: int = 1, ) -> None: # acc_tmem_addr += acc_offset is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM @@ -381,7 +395,7 @@ def gemm_ptx_partial( sB_layout = sB.layout idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) if const_expr(not is_ts): - sA_swizzle = parse_swizzle_from_pointer(sA.iterator) + sA_swizzle = sA.iterator.type.swizzle_type smem_desc_base_a: int = const_expr( sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), @@ -397,7 +411,7 @@ def gemm_ptx_partial( else: smem_desc_base_a = None smem_desc_base_a_lo, smem_desc_a_hi = None, None - sB_swizzle = parse_swizzle_from_pointer(sB.iterator) + sB_swizzle = sB.iterator.type.swizzle_type smem_desc_base_b: int = const_expr( sm100_desc.make_smem_desc_base( cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), @@ -463,7 +477,7 @@ def gemm_ptx_partial( f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" @@ -472,7 +486,7 @@ def gemm_ptx_partial( f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" ) for k in range(1, cute.size(tCrA.shape[2])) ) @@ -496,6 +510,10 @@ def gemm_ptx_partial( ] if const_expr(mbar_ptr is not None): assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" + assert split_arrive is not None, ( + "split_arrive must be provided when mbar_ptr is not None" + ) + split_arrive_idx = split_arrive // op.shape_mnk[2] input_args.append(mbar_ptr.toint().ir_value()) input_args.append(Int32(mbar_phase).ir_value()) mbar_wait_str = ( @@ -536,7 +554,7 @@ def gemm_ptx_partial( f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" @@ -544,13 +562,11 @@ def gemm_ptx_partial( f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in range( 1, - cute.size(tCrA.shape[2]) - if const_expr(mbar_ptr is None) - else cute.size(tCrA.shape[2]) // 4 * 3, + cute.size(tCrA.shape[2]) if const_expr(mbar_ptr is None) else split_arrive_idx, ) ) + mbar_wait_str @@ -559,9 +575,9 @@ def gemm_ptx_partial( ( f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) - for k in range(cute.size(tCrA.shape[2]) // 4 * 3, cute.size(tCrA.shape[2])) + for k in range(split_arrive_idx, cute.size(tCrA.shape[2])) ) if const_expr(mbar_ptr is not None) else "" @@ -751,3 +767,323 @@ def gemm_ptx_partial1( is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, ) + + +@cute.jit +def gemm_ptx_precomputed( + acc_tmem_addr: Int32, + smem_desc_start_a: Int32, # If TS, then this is the tmem start address for A + smem_desc_start_b: Int32, + idesc: int, + smem_desc_base_a: Optional[int], + smem_desc_base_b: int, + tCrA_layout: cute.Layout, + tCrB_layout: cute.Layout, + mbar_ptr: Optional[cutlass.Pointer] = None, + mbar_phase: Optional[Int32] = None, + zero_init: bool | Boolean = False, + cta_group: int = 1, +) -> None: + # acc_tmem_addr += acc_offset + is_ts = const_expr(smem_desc_base_a is None) + num_k_tile = cute.size(tCrA_layout.shape[2]) + if const_expr(not is_ts): + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + else: + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + + tCrA_layout = ( + tCrA_layout + if const_expr(not is_ts) + # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout) + # currently hard-coding the width to 16 + else cute.recast_layout(32, 16, tCrA_layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, num_k_tile)] + offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, num_k_tile)] + + smem_desc_start_a_lo = None + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a) + # smem_desc_start_a_lo = smem_desc_start_a + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + "mov.b32 smem_desc_a_lo_start, $0;\n\t" + "mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.s32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t" + f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, num_k_tile) + ) + + "}\n", + # "r,r,r", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + input_args = [ + Int32(cute.arch.make_warp_uniform(smem_desc_start_a)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ] + if const_expr(mbar_ptr is not None): + assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" + input_args.append(mbar_ptr.toint().ir_value()) + input_args.append(Int32(mbar_phase).ir_value()) + mbar_wait_str = ( + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + ) + else: + mbar_wait_str = "" + llvm.inline_asm( + None, + # [ + # # acc.iterator.toint().ir_value(), + # Int32(tCrA_layout[None, None, 0].iterator.toint()).ir_value(), + # Int32(smem_desc_start_b_lo).ir_value(), + # Int32(not zero_init).ir_value(), + # ], + input_args, + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + f"mov.b32 tmem_a, $0;\n\t" + f"mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range( + 1, + num_k_tile if const_expr(mbar_ptr is None) else num_k_tile // 4 * 3, + ) + ) + + mbar_wait_str + + ( + "".join( + ( + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(num_k_tile // 4 * 3, num_k_tile) + ) + if const_expr(mbar_ptr is not None) + else "" + ) + + "}\n", + "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def declare_ptx_smem_desc( + smem_desc_start_a: Int32, # If TS, then this is the tmem start address for A + smem_desc_base_a: Optional[int], + tCrA_layout: cute.Layout, + var_name_prefix: str = "smem_desc", +) -> None: + is_ts = const_expr(smem_desc_base_a is None) + num_k_tile = cute.size(tCrA_layout.shape[2]) + smem_desc_base_a_lo, smem_desc_a_hi = None, None + if const_expr(not is_ts): + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + tCrA_layout = ( + tCrA_layout + if const_expr(not is_ts) + # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout) + # currently hard-coding the width to 16 + else cute.recast_layout(32, 16, tCrA_layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)] + smem_desc_start_a_lo = None + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a) + if const_expr(not is_ts): + llvm.inline_asm( + None, + [Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value()], + f".reg .b32 {var_name_prefix}_lo;\n\t" + f".reg .b64 {var_name_prefix}_<{num_k_tile}>;\n\t" + f"mov.b64 {var_name_prefix}_0, {{$0, {hex(smem_desc_a_hi)}}};\n\t" + + "".join( + ( + f"add.s32 {var_name_prefix}_lo, $0, {hex(offset_a[k])};\n\t" + f"mov.b64 {var_name_prefix}_{k}, {{{var_name_prefix}_lo, {hex(smem_desc_a_hi)}}};\n\t" + ) + for k in range(1, num_k_tile) + ), + "r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def declare_ptx_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp, var_name: str = "idesc") -> None: + idesc = const_expr(sm100_desc.mma_op_to_idesc(op)) + llvm.inline_asm( + None, + [], + f".reg .b32 {var_name};\n\t" # noqa + f"mov.b32 {var_name}, {hex(idesc)};\n\t", + constraints="", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_precomputed_varname( + acc_tmem_addr: Int32, + smem_desc_start_b: Int32, + # idesc: int, + smem_desc_base_b: int, + tCrB_layout: cute.Layout, + smem_var_name_prefix: str, + idesc_var_name: str, + smem_offset: int, + zero_init: bool | Boolean = False, + cta_group: int = 1, +) -> None: + is_ts = False + num_k_tile = cute.size(tCrB_layout.shape[2]) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)] + + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + # ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + # ".reg .b64 smem_desc_b;\n\t" + f".reg .b64 smem_desc_b_<{num_k_tile}>;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + # f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $2;\n\t" + "mov.b32 smem_desc_b_lo_start, $0;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_0;\n\t" + f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t" + f"mov.b64 {smem_var_name_prefix}_0, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b_0, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + + "".join( + ( + f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t" + f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t" + f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b_{k}, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + ) + for k in range(1, num_k_tile) + ) + + "setp.ne.b32 p, $1, 0;\n\t" + # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t" + + "".join( + ( + # f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t" + # f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t" + # f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + # f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + # f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, idesc, 1;\n\t" + # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, {idesc_var_name}, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t" + ) + for k in range(1, num_k_tile) + ) + + "}\n", + "r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index be13e70f892..f21013891b4 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -6,7 +6,7 @@ import cutlass.cute as cute from cutlass import Int32, const_expr -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK @dataclass(frozen=True) @@ -25,8 +25,8 @@ def get_n_block_min_max( self, seqlen_info: SeqlenInfoQK, m_block: Int32, - split_idx: cutlass.Int32 = 0, - num_splits: cutlass.Int32 = 1, + split_idx: Int32 = 0, + num_splits: Int32 = 1, ) -> Tuple[Int32, Int32]: n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)): @@ -46,7 +46,7 @@ def get_n_block_min_max( n_block_min = cutlass.max(n_idx_left // self.tile_n, 0) if cutlass.const_expr(self.is_split_kv): num_n_blocks_per_split = ( - cutlass.Int32(0) + Int32(0) if n_block_max <= n_block_min else (n_block_max - n_block_min + num_splits - 1) // num_splits ) @@ -70,6 +70,37 @@ def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tupl m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m)) return m_block_min, m_block_max + @cute.jit + def get_n_block_k_new_min_max( + self, + seqlen_info: SeqlenInfoQKNewK, + m_block: Int32, + split_idx: Int32 = 0, + num_splits: Int32 = 1, + ) -> Tuple[Int32, Int32]: + """Get the block range for new K tokens (append KV). + + First computes the full n_block range via get_n_block_min_max, then maps + those blocks into the new-K index space by subtracting seqlen_k_og. + """ + n_block_min, n_block_max = self.get_n_block_min_max( + seqlen_info, + m_block, + split_idx, + num_splits, + ) + idx_k_new_min = cutlass.max(n_block_min * self.tile_n - seqlen_info.seqlen_k_og, 0) + idx_k_new_max = cutlass.min( + n_block_max * self.tile_n - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new + ) + n_block_new_min = idx_k_new_min // self.tile_n + n_block_new_max = ( + cute.ceil_div(idx_k_new_max, self.tile_n) + if idx_k_new_max > idx_k_new_min + else n_block_new_min + ) + return n_block_new_min, n_block_new_max + @cute.jit def get_n_block_min_causal_local_mask( self, diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index 396aa5e1f70..52cb7e06044 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -12,34 +12,82 @@ import cutlass.cute as cute from cutlass import Float32, Int32, const_expr +from quack import copy_utils + # Import data structures from block_sparsity from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute import copy_utils from flash_attn.cute.named_barrier import NamedBarrierBwd +# NOTE [SM100 block-sparse empty tiles: mbarrier contract] +# +# For block-sparse SM100 forward, a given (m_block, stage) Q tile can have zero active +# KV blocks (total_block_cnt == 0). In that case there is no seqlen_kv iteration, so +# the softmax warp-group has no row stats to publish. +# +# The correction warp-group seeds fully-masked-row stats and runs the usual correction +# epilogue so output/LSE have well-defined values. Both warp-groups must still perform +# the softmax<->correction mbarrier handshake so phases advance correctly across +# empty->empty and empty->non-empty tile sequences. +# +# In the no-sink case, this corresponds to the usual fully-masked-row convention: +# output is zero and LSE is -inf. +# +# Barrier contract (each is `mbar_ptr + + stage`): +# +# Producer/consumer pairs: +# - `mbar_softmax_corr_full` : softmax arrive -> correction wait +# - `mbar_softmax_corr_empty` : correction arrive -> softmax wait +# - `mbar_P_full_O_rescaled` : softmax arrive (+ correction arrive) -> MMA wait +# - `mbar_P_full_2` : softmax arrive -> MMA wait +# - `mbar_corr_epi_full_/empty` : correction <-> epilogue (only when epilogue is separate) +# +# Empty tile (`total_block_cnt == 0`): +# - Softmax: skips the seqlen_kv softmax path entirely (no P stores, no `mbar_P_full_*`). +# It only arrives `mbar_softmax_corr_full` once per stage as a synthetic "no work" signal. +# At the `softmax_loop` level, softmax unconditionally waits `mbar_softmax_corr_empty` +# before each tile (when block-sparse) to drain a prior correction arrival and keep +# phases aligned across non-empty -> empty transitions. +# - Correction: waits `mbar_softmax_corr_full`, seeds stats + runs `correction_epilogue(scale=0)`, +# and arrives `mbar_softmax_corr_empty` (and `mbar_corr_epi_full_/empty` when applicable). +# - No `mbar_P_full_*` barriers are arrived (no P, no MMA O); only the softmax<->correction +# (and correction<->epilogue) handshakes advance phases. +# +# Non-empty tile: +# - Softmax: runs `softmax_step` (produces P) and uses `mbar_softmax_corr_full/empty` to +# publish row_max (during seqlen_kv) and final row stats (once per tile), and to advance phases; +# arrives `mbar_P_full_*` when P is stored. +# - Correction: waits `mbar_softmax_corr_full`, may rescale/release O, arrives `mbar_softmax_corr_empty` +# to ack/advance, and arrives `mbar_P_full_O_rescaled` when MMA can proceed. +# +# Backward (SM100): +# - Empty KV tile: for a given `n_block`, `total_m_block_cnt == 0` means no Q tiles contribute. +# - Both the load and compute loops guard all pipeline work on `process_tile`, so empty tiles +# skip producer/consumer operations entirely (no per-tile mbarrier phase handshake like forward). +# - In the `not dKV_postprocess` path, dK/dV for empty KV tiles are explicitly written as zeros +# even when `process_tile == False` (see `flash_bwd_sm100.py` `should_zero_dKV`). + + @cute.jit def load_block_list( block_indices: cute.Tensor, block_count, - load_q_with_first: cutlass.Constexpr, first_block_preloaded: cutlass.Constexpr, kv_producer_state, - load_Q, load_K, load_V, pipeline_k, pipeline_v, - use_tma_q: cutlass.Constexpr, - tma_q_bytes: cutlass.Constexpr, intra_wg_overlap: cutlass.Constexpr, ): - """Iterate over the sparse blocks and load K, V (and Q) into the pipeline. - for the intra_wg_overlap case, we overlap the loads of K and V. And this + """Iterate over the sparse blocks and load K, V into the pipeline. + For the intra_wg_overlap case, we overlap the loads of K and V. And this means we need to pipeline the last V load from the partial block case, with the loads for the full blocks. Set first_block_preloaded when the caller has already issued the first K load for the list. + Q is loaded separately on its own mbarrier before this function is called. + Note: we iterate along the block_n indices in reverse. @@ -49,21 +97,7 @@ def load_block_list( """ if block_count > 0: if const_expr(not intra_wg_overlap): - # Peel first iteration: the first block may need to load Q alongside K, - # Parameters are already Constexpr, so no need to wrap in const_expr() - n_block_first = block_indices[block_count - 1] - extra_tx = tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0 - pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx) - - if const_expr(load_q_with_first and use_tma_q): - load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) - - load_K(src_idx=n_block_first, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block_first, producer_state=kv_producer_state) - kv_producer_state.advance() - - for offset in cutlass.range(1, block_count): + for offset in cutlass.range(block_count): n_block = block_indices[block_count - 1 - offset] pipeline_k.producer_acquire(kv_producer_state) load_K(src_idx=n_block, producer_state=kv_producer_state) @@ -73,14 +107,7 @@ def load_block_list( else: n_block_first = block_indices[block_count - 1] if const_expr(not first_block_preloaded): - extra_tx = ( - tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0 - ) - pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx) - - if const_expr(load_q_with_first and use_tma_q): - load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) - + pipeline_k.producer_acquire(kv_producer_state) load_K(src_idx=n_block_first, producer_state=kv_producer_state) for idx in cutlass.range(block_count - 1, unroll=1): @@ -136,19 +163,18 @@ def produce_block_sparse_loads( head_idx, m_block, kv_producer_state, - load_Q, load_K, load_V, pipeline_k, pipeline_v, - use_tma_q: cutlass.Constexpr, - tma_q_bytes: cutlass.Constexpr, intra_wg_overlap: cutlass.Constexpr, qhead_per_kvhead: cutlass.Constexpr[int] = 1, q_subtile_factor: cutlass.Constexpr[int] = 1, ): """Iterate over the mask and full block lists for a single tile. + Q is loaded separately on its own mbarrier before this function is called. + The masked (partial) list may leave the last V load pending when intra-warp-group overlap is enabled. The first full block must consume that pending V while issuing its own K load on the next pipeline stage. @@ -180,20 +206,16 @@ def produce_block_sparse_loads( full_empty = curr_full_block_cnt == 0 if mask_empty: - # No masked blocks: the full list owns the initial Q+K load. + # No masked blocks: the full list owns the initial K load. kv_producer_state = load_block_list( curr_full_block_idx, curr_full_block_cnt, - load_q_with_first=True, first_block_preloaded=False, kv_producer_state=kv_producer_state, - load_Q=load_Q, load_K=load_K, load_V=load_V, pipeline_k=pipeline_k, pipeline_v=pipeline_v, - use_tma_q=use_tma_q, - tma_q_bytes=tma_q_bytes, intra_wg_overlap=intra_wg_overlap, ) @@ -206,21 +228,16 @@ def produce_block_sparse_loads( kv_producer_state, ) else: - # Masked blocks present: load Q together with the first masked K so consumers can - # start immediately. When overlap is disabled this fully drains the list. + # Masked blocks present. When overlap is disabled this fully drains the list. kv_producer_state = load_block_list( curr_mask_block_idx, curr_mask_block_cnt, - load_q_with_first=True, first_block_preloaded=False, kv_producer_state=kv_producer_state, - load_Q=load_Q, load_K=load_K, load_V=load_V, pipeline_k=pipeline_k, pipeline_v=pipeline_v, - use_tma_q=use_tma_q, - tma_q_bytes=tma_q_bytes, intra_wg_overlap=intra_wg_overlap, ) @@ -249,16 +266,12 @@ def produce_block_sparse_loads( kv_producer_state = load_block_list( curr_full_block_idx, curr_full_block_cnt, - load_q_with_first=False, first_block_preloaded=True, kv_producer_state=kv_producer_state, - load_Q=load_Q, load_K=load_K, load_V=load_V, pipeline_k=pipeline_k, pipeline_v=pipeline_v, - use_tma_q=use_tma_q, - tma_q_bytes=tma_q_bytes, intra_wg_overlap=intra_wg_overlap, ) @@ -270,21 +283,16 @@ def produce_block_sparse_loads( kv_producer_state, ) else: - # Non-overlap path with both lists: run the full list normally (skipping the Q - # reload because the masked list already issued it). + # Non-overlap path with both lists: run the full list normally. kv_producer_state = load_block_list( curr_full_block_idx, curr_full_block_cnt, - load_q_with_first=False, first_block_preloaded=False, kv_producer_state=kv_producer_state, - load_Q=load_Q, load_K=load_K, load_V=load_V, pipeline_k=pipeline_k, pipeline_v=pipeline_v, - use_tma_q=use_tma_q, - tma_q_bytes=tma_q_bytes, intra_wg_overlap=intra_wg_overlap, ) @@ -480,7 +488,6 @@ def load_block_list_sm100( block_indices: cute.Tensor, block_count, load_q_with_first: cutlass.Constexpr, - m_block, q_stage: cutlass.Constexpr, kv_producer_state, load_Q, @@ -495,9 +502,9 @@ def load_block_list_sm100( if const_expr(load_q_with_first): # SM100 loads Q0 and optionally Q1 - load_Q(block=q_stage * m_block + 0, stage=0) + load_Q(block=0, stage=0) if const_expr(q_stage == 2): - load_Q(block=q_stage * m_block + 1, stage=1) + load_Q(block=1, stage=1) # SM100 doesn't use producer_acquire for pipeline_kv in load path # The pipeline barriers are handled inside load_KV @@ -568,7 +575,6 @@ def produce_block_sparse_loads_sm100( curr_full_block_idx, curr_full_block_cnt, load_q_with_first=True, - m_block=m_block, q_stage=q_stage, kv_producer_state=kv_producer_state, load_Q=load_Q, @@ -583,7 +589,6 @@ def produce_block_sparse_loads_sm100( curr_mask_block_idx, curr_mask_block_cnt, load_q_with_first=True, - m_block=m_block, q_stage=q_stage, kv_producer_state=kv_producer_state, load_Q=load_Q, @@ -599,7 +604,6 @@ def produce_block_sparse_loads_sm100( curr_full_block_idx, curr_full_block_cnt, load_q_with_first=False, - m_block=m_block, q_stage=q_stage, kv_producer_state=kv_producer_state, load_Q=load_Q, @@ -654,16 +658,12 @@ def handle_block_sparse_empty_tile_correction_sm100( stats: list, correction_epilogue: Callable, thr_mma_pv: cute.core.ThrMma, - tOtOs: tuple[cute.Tensor], + tOtO: cute.Tensor, sO: cute.Tensor, - mbar_ptr, - mbar_softmax_corr_full_offset: Int32, - mbar_softmax_corr_empty_offset: Int32, - mbar_P_full_O_rescaled_offset: Int32, - mbar_P_full_2_offset: Int32, - mbar_corr_epi_full_offset: Int32, - mbar_corr_epi_empty_offset: Int32, - softmax_corr_consumer_phase: Int32, + pipeline_sm_stats: cutlass.pipeline.PipelineAsync, + sm_stats_barrier: cutlass.pipeline.NamedBarrier, + pipeline_o_epi: cutlass.pipeline.PipelineAsync, + sm_stats_consumer_phase: Int32, o_corr_consumer_phase: Int32, corr_epi_producer_phase: Int32, softmax_scale_log2: Float32, @@ -671,12 +671,23 @@ def handle_block_sparse_empty_tile_correction_sm100( gO: Optional[cute.Tensor] = None, gmem_tiled_copy_O: Optional[cute.TiledCopy] = None, ): - """Handle the block-sparse case where a tile is fully masked: - * zero staged results - * seed stats - * satisfy the usual barrier protocol so downstream warps continue to make progress. + """Handle SM100 forward block-sparse tiles with no active KV blocks. + + This path is taken when `total_block_cnt == 0`. The softmax warp-group still + arrives `mbar_softmax_corr_full` (synthetic "no work") so the correction + warp-group can: + + - seed fully-masked-row stats (row_sum=1; row_max=-inf when tracked) for LSE + - run `correction_epilogue` with `scale=0` so the output tile is written as zeros + (independent of any prior tmem contents) + - wait on `mbar_softmax_corr_full` and arrive `mbar_softmax_corr_empty` + (and `mbar_corr_epi_*` when applicable) so phases stay aligned across tiles + + This helper intentionally does not touch `mbar_P_full_*` since no P is produced. + See NOTE [SM100 block-sparse empty tiles: mbarrier contract]. """ LOG2_E = Float32(math.log2(math.e)) + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 for stage in cutlass.range_constexpr(q_stage): row_sum_value = Float32(1.0) @@ -708,20 +719,16 @@ def handle_block_sparse_empty_tile_correction_sm100( acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value stats[stage] = (row_sum_value, row_max_value, acc_flag) - cute.arch.mbarrier_wait( - mbar_ptr + mbar_softmax_corr_full_offset + stage, - softmax_corr_consumer_phase, - ) - cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage) + # See NOTE [SM100 block-sparse empty tiles: mbarrier contract]. + # pipeline_sm_stats.consumer_wait_w_index_phase(stage, sm_stats_consumer_phase) + sm_stats_barrier.arrive_and_wait_w_index(index=stage * 4 + warp_idx) + pipeline_sm_stats.consumer_release_w_index(stage) if const_expr(gmem_tiled_copy_O is None): - cute.arch.mbarrier_wait( - mbar_ptr + mbar_corr_epi_empty_offset + stage, - corr_epi_producer_phase, - ) + pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase) correction_epilogue( thr_mma_pv, - tOtOs[stage], + tOtO[None, None, None, stage], tidx, stage, m_block, @@ -729,20 +736,17 @@ def handle_block_sparse_empty_tile_correction_sm100( Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs sO[None, None, stage], mO_cur, - gO, + gO[None, None, stage], gmem_tiled_copy_O, ) if const_expr(gmem_tiled_copy_O is None): - cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage) - cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage) - cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage) + pipeline_o_epi.producer_commit_w_index(stage) - softmax_corr_consumer_phase ^= 1 - o_corr_consumer_phase ^= 1 + sm_stats_consumer_phase ^= 1 corr_epi_producer_phase ^= 1 return ( - softmax_corr_consumer_phase, + sm_stats_consumer_phase, o_corr_consumer_phase, corr_epi_producer_phase, ) @@ -760,17 +764,15 @@ def softmax_block_sparse_sm100( mma_si_consumer_phase: Int32, si_corr_producer_phase: Int32, s0_s1_sequence_phase: Int32, - mbar_ptr, - mbar_softmax_corr_full_offset: Int32, - mbar_softmax_corr_empty_offset: Int32, - mbar_P_full_O_rescaled_offset: Int32, - mbar_P_full_2_offset: Int32, + pipeline_sm_stats: cutlass.pipeline.PipelineAsync, + sm_stats_barrier: cutlass.pipeline.NamedBarrier, q_stage: cutlass.Constexpr, stage_idx: Int32, check_m_boundary: bool, qhead_per_kvhead: cutlass.Constexpr, q_subtile_factor: cutlass.Constexpr[int] = 1, ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors @@ -788,10 +790,9 @@ def softmax_block_sparse_sm100( total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt if total_block_cnt == 0: - cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_full_offset + stage_idx) - cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage_idx) - cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage_idx) - cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage_idx) + # See NOTE [SM100 block-sparse empty tiles: mbarrier contract]. + # pipeline_sm_stats.producer_commit_w_index(stage_idx) + sm_stats_barrier.arrive_w_index(index=stage_idx * 4 + warp_idx) else: if curr_mask_block_cnt > 0: mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] @@ -1347,12 +1348,18 @@ def _store_one_dQaccum_sm90( m_block, sdQaccum: cute.Tensor, gdQaccum: cute.Tensor, - num_mma_warp_groups: cutlass.Constexpr, + num_dQ_warp_groups: cutlass.Constexpr, num_threads_per_warp_group: cutlass.Constexpr, tma_copy_bytes_dQ, ): """Store dQaccum for a single m_block.""" - for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups): + for warp_group_idx in cutlass.range_constexpr(num_dQ_warp_groups): + cute.arch.cp_async_bulk_wait_group(num_dQ_warp_groups - 1 - warp_group_idx, read=True) + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, + number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + for warp_group_idx in cutlass.range_constexpr(num_dQ_warp_groups): cute.arch.barrier( barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE, @@ -1360,16 +1367,10 @@ def _store_one_dQaccum_sm90( with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( sdQaccum[None, warp_group_idx].iterator, - gdQaccum[None, warp_group_idx, m_block].iterator, + gdQaccum[(None, warp_group_idx), m_block].iterator, tma_copy_bytes_dQ, ) cute.arch.cp_async_bulk_commit_group() - for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups): - cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True) - cute.arch.barrier_arrive( - barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, - number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE, - ) @cute.jit @@ -1382,7 +1383,7 @@ def dQaccum_store_block_sparse_bwd_sm90( gdQaccum: cute.Tensor, subtile_factor: cutlass.Constexpr, m_block_max: int, - num_mma_warp_groups: cutlass.Constexpr, + num_dQ_warp_groups: cutlass.Constexpr, num_threads_per_warp_group: cutlass.Constexpr, tma_copy_bytes_dQ, ): @@ -1411,7 +1412,7 @@ def dQaccum_store_block_sparse_bwd_sm90( m_block, sdQaccum, gdQaccum, - num_mma_warp_groups, + num_dQ_warp_groups, num_threads_per_warp_group, tma_copy_bytes_dQ, ) @@ -1427,7 +1428,7 @@ def dQaccum_store_block_sparse_bwd_sm90( m_block, sdQaccum, gdQaccum, - num_mma_warp_groups, + num_dQ_warp_groups, num_threads_per_warp_group, tma_copy_bytes_dQ, ) diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index f19c8fb7f05..3fad8c9f491 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -34,6 +34,23 @@ class BlockSparseTensorsTorch(NamedTuple): block_size: tuple[int, int] | None = None +def get_sparse_q_block_size( + tensors: BlockSparseTensorsTorch | None, + seqlen_q: int, +) -> int | None: + """Return the Q sparse block size, or None when sparsity is unset or ambiguous.""" + if tensors is None: + return None + if tensors.block_size is not None: + return tensors.block_size[0] + num_m_blocks = tensors.mask_block_idx.shape[2] + min_block_size = ceildiv(seqlen_q, num_m_blocks) + max_block_size = seqlen_q if num_m_blocks == 1 else (seqlen_q - 1) // (num_m_blocks - 1) + if min_block_size != max_block_size: + return None + return min_block_size + + def _expand_sparsity_tensor( tensor: torch.Tensor, expected_shape: Tuple[int, ...], @@ -81,6 +98,12 @@ def _check_and_expand_block( expanded_cnt = _expand_sparsity_tensor( cnt, expected_count_shape, f"{name}_block_cnt", context, hint ) + # [Note] Allow Compact block sparse indices + # Allow the last dimension (n_blocks) of idx to be <= expected, since + # FA4 only accesses indices 0..cnt-1 per query tile. This enables compact + # index tensors that avoid O(N^2) memory at long sequence lengths. + if idx.ndim == 4 and idx.shape[3] <= expected_index_shape[3]: + expected_index_shape = (*expected_index_shape[:3], idx.shape[3]) expanded_idx = _expand_sparsity_tensor( idx, expected_index_shape, f"{name}_block_idx", context, hint ) @@ -140,17 +163,14 @@ def infer_block_sparse_expected_shapes( num_m_blocks = tensors.mask_block_idx.shape[2] if sparse_block_size_q is None: - min_block_size = ceildiv(seqlen_q, num_m_blocks) - if num_m_blocks == 1: - max_block_size = seqlen_q - else: - max_block_size = (seqlen_q - 1) // (num_m_blocks - 1) - if max_block_size != min_block_size and base_m_block != 1: + sparse_block_size_q = get_sparse_q_block_size(tensors, seqlen_q) + if sparse_block_size_q is None and base_m_block != 1: raise ValueError( f"Block sparse tensors{context} require explicit sparse_block_size[0] " f"to disambiguate block size for seqlen_q={seqlen_q} and num_m_blocks={num_m_blocks}." ) - sparse_block_size_q = min_block_size + if sparse_block_size_q is None: + sparse_block_size_q = ceildiv(seqlen_q, num_m_blocks) if sparse_block_size_q % base_m_block != 0: raise ValueError( @@ -186,9 +206,11 @@ def infer_block_sparse_expected_shapes( raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.") if mask_block_cnt.shape[2] != mask_block_idx.shape[2]: raise ValueError(f"Block sparse tensors{context} must share the same m-block dimension.") - if mask_block_idx.shape[3] != expected_n_blocks: + # [Note] Allow Compact block sparse indices: FA4 only accesses indices 0..cnt-1 + # per query tile, so idx.shape[3] can be <= expected_n_blocks. + if mask_block_idx.shape[3] > expected_n_blocks: raise ValueError( - f"Block sparse tensors{context} n-block dimension must be {expected_n_blocks}." + f"Block sparse tensors{context} n-block dimension must be <= {expected_n_blocks}." ) if expected_m_blocks != num_m_blocks: raise ValueError( @@ -314,7 +336,7 @@ def normalize_block_sparse_config( ) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None, int]: m_block_size, n_block_size = block_size if tensors.block_size is None: - sparse_block_size_q, sparse_block_size_kv = q_stage * m_block_size, n_block_size + sparse_block_size_q, sparse_block_size_kv = None, n_block_size else: sparse_block_size_q, sparse_block_size_kv = tensors.block_size if sparse_block_size_kv != n_block_size: @@ -401,6 +423,7 @@ def to_cute_block_sparse_tensors( """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi""" if not is_block_sparsity_enabled(tensors): return None + ( mask_block_cnt, mask_block_idx, diff --git a/flash_attn/cute/cache_utils.py b/flash_attn/cute/cache_utils.py new file mode 100644 index 00000000000..3fca0579d98 --- /dev/null +++ b/flash_attn/cute/cache_utils.py @@ -0,0 +1,292 @@ +# Manage Ahead-of-Time (AOT) compiled kernels +import fcntl +import hashlib +import logging +import os +import pickle +import sys +import tempfile +import time +from functools import lru_cache +from getpass import getuser +from pathlib import Path +from typing import Hashable, TypeAlias + +import ctypes + +import cutlass +import cutlass.cute as cute +import tvm_ffi +from cutlass.cutlass_dsl import JitCompiledFunction + +# Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols +# (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen. +# Upstream cute.runtime.load_module loads these without RTLD_GLOBAL, which causes +# "undefined symbol" errors when loading cached kernels from disk. +for _lib_path in cute.runtime.find_runtime_libraries(enable_tvm_ffi=False): + if Path(_lib_path).exists(): + ctypes.CDLL(_lib_path, mode=ctypes.RTLD_GLOBAL) + +CompileKeyType: TypeAlias = tuple[Hashable, ...] +CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function + +logger = logging.getLogger(__name__) +_handler = logging.StreamHandler() +_handler.setFormatter( + logging.Formatter( + "%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" + ) +) +logger.addHandler(_handler) +logger.setLevel(logging.DEBUG) + + +# Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1` +CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1" + + +# Customize cache dir via `FLASH_ATTENTION_CUTE_DSL_CACHE_DIR`, default is +# `/tmp/${USER}/flash_attention_cute_dsl_cache`` +CUTE_DSL_CACHE_DIR: str | None = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_DIR", None) + + +def get_cache_path() -> Path: + if CUTE_DSL_CACHE_DIR is not None: + cache_dir = Path(CUTE_DSL_CACHE_DIR) + else: + cache_dir = Path(tempfile.gettempdir()) / getuser() / "flash_attention_cute_dsl_cache" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir + + +@lru_cache(maxsize=1) +def _compute_source_fingerprint() -> str: + """ + Hash all CuTe Python sources plus runtime ABI stamps into a short fingerprint. + + The fingerprint changes whenever: + - Any .py file under flash_attn/cute is added, removed, renamed, or modified. + - The Python minor version changes (e.g. 3.13 -> 3.14). + - The cutlass or tvm_ffi package version changes. + + Computed once per process and cached. + """ + cute_root = Path(__file__).resolve().parent + h = hashlib.sha256() + + h.update(f"py{sys.version_info.major}.{sys.version_info.minor}".encode()) + h.update(f"cutlass={cutlass.__version__}".encode()) + h.update(f"tvm_ffi={tvm_ffi.__version__}".encode()) + + for src in sorted(cute_root.rglob("*.py")): + if not src.is_file(): + continue + h.update(src.relative_to(cute_root).as_posix().encode()) + content = src.read_bytes() + h.update(len(content).to_bytes(8, "little")) + h.update(content) + + return h.hexdigest() + + +class FileLock: + """Context manager for advisory file locks using fcntl.flock. + + Supports exclusive (write) and shared (read) locks. + Always blocks with polling until the lock is acquired or timeout is reached. + + Usage: + with FileLock(lock_path, exclusive=True, timeout=15, label="abc"): + # do work under lock + """ + + def __init__( + self, + lock_path: Path, + exclusive: bool, + timeout: float = 15, + label: str = "", + ): + """ + Args: + lock_path: Path to the lock file on disk. + exclusive: True for exclusive (write) lock, False for shared (read) lock. + timeout: Max seconds to wait for lock acquisition before raising RuntimeError. + label: Optional human-readable label for error messages. + """ + self.lock_path: Path = lock_path + self.exclusive: bool = exclusive + self.timeout: float = timeout + self.label: str = label + self._fd: int = -1 + + @property + def _lock_label(self) -> str: + kind = "exclusive" if self.exclusive else "shared" + return f"{kind} {self.label}" if self.label else kind + + def __enter__(self) -> "FileLock": + open_flags = os.O_WRONLY | os.O_CREAT if self.exclusive else os.O_RDONLY | os.O_CREAT + lock_type = fcntl.LOCK_EX if self.exclusive else fcntl.LOCK_SH + + self._fd = os.open(str(self.lock_path), open_flags) + + deadline = time.monotonic() + self.timeout + acquired = False + while time.monotonic() < deadline: + try: + fcntl.flock(self._fd, lock_type | fcntl.LOCK_NB) + acquired = True + break + except OSError: + time.sleep(0.1) + if not acquired: + os.close(self._fd) + self._fd = None + raise RuntimeError( + f"Timed out after {self.timeout}s waiting for " + f"{self._lock_label} lock: {self.lock_path}" + ) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + if self._fd is not None: + fcntl.flock(self._fd, fcntl.LOCK_UN) + os.close(self._fd) + self._fd = None + + +class JITCache: + """ + In-memory cache for compiled functions. + """ + + def __init__(self): + self.cache: dict[CompileKeyType, CallableFunction] = {} + + def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) -> None: + self.cache[key] = fn + + def __getitem__(self, key: CompileKeyType) -> CallableFunction: + return self.cache[key] + + def __contains__(self, key: CompileKeyType) -> bool: + return key in self.cache + + def clear(self) -> None: + """ + Clear in-memory cache of compiled functions + """ + self.cache.clear() + + +class JITPersistentCache(JITCache): + """ + In-memory cache for compiled functions, which is also backed by persistent storage. + Use cutedsl ahead-of-time (AOT) compilation, only supporting enable_tvm_ffi=True + """ + + EXPORT_FUNCTION_PREFIX = "func" + LOCK_TIMEOUT_SECONDS = 15 + + def __init__(self, cache_path: Path): + super().__init__() + cache_path.mkdir(parents=True, exist_ok=True) + self.cache_path: Path = cache_path + + def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) -> None: + JITCache.__setitem__(self, key, fn) + self._try_export_to_storage(key, fn) + + def __getitem__(self, key: CompileKeyType) -> CallableFunction: + # Use __contains__ to try populating in-memory cache with persistent storage + self.__contains__(key) + return JITCache.__getitem__(self, key) + + def __contains__(self, key: CompileKeyType) -> bool: + # Checks in-memory cache first, then tries loading from storage. + # When returning True, guarantees the in-memory cache is populated. + if JITCache.__contains__(self, key): + return True + return self._try_load_from_storage(key) + + def _try_load_from_storage(self, key: CompileKeyType) -> bool: + """ + Try to load a function from persistent storage into in-memory cache. + Returns True if loaded successfully, False if not found on disk. + Holds a shared lock during loading to prevent concurrent writes. + """ + sha256_hex = self._key_to_hash(key) + obj_path = self.cache_path / f"{sha256_hex}.o" + with FileLock( + self._lock_path(sha256_hex), + exclusive=False, + timeout=self.LOCK_TIMEOUT_SECONDS, + label=sha256_hex, + ): + if obj_path.exists(): + logger.debug("Loading compiled function from disk: %s", obj_path) + m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True) + fn = getattr(m, self.EXPORT_FUNCTION_PREFIX) + JITCache.__setitem__(self, key, fn) + return True + else: + logger.debug("Cache miss on disk for key hash %s", sha256_hex) + return False + + def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None: + """Export a compiled function to persistent storage under exclusive lock.""" + sha256_hex = self._key_to_hash(key) + with FileLock( + self._lock_path(sha256_hex), + exclusive=True, + timeout=self.LOCK_TIMEOUT_SECONDS, + label=sha256_hex, + ): + obj_path = self.cache_path / f"{sha256_hex}.o" + if obj_path.exists(): + # Another process already exported. + logger.debug("Skipping export, already on disk: %s", obj_path) + return + logger.debug("Exporting compiled function to disk: %s", obj_path) + fn.export_to_c( + object_file_path=str(obj_path), + function_name=self.EXPORT_FUNCTION_PREFIX, + ) + logger.debug("Successfully exported compiled function to disk: %s", obj_path) + + def _key_to_hash(self, key: CompileKeyType) -> str: + return hashlib.sha256(pickle.dumps(key)).hexdigest() + + def _lock_path(self, sha256_hex: str) -> Path: + return self.cache_path / f"{sha256_hex}.lock" + + def clear(self) -> None: + """ + Not only clear the in-memory cache. Also purge persistent compilation cache. + """ + logger.debug("Clearing persistent cache at %s", self.cache_path) + super().clear() + for child in self.cache_path.iterdir(): + child.unlink() + + +def get_jit_cache(name: str | None = None) -> JITCache: + """ + JIT cache factory. + `name` is an optional identifier to create subdirectories to manage cache. + + When persistent caching is enabled, artifacts are namespaced under a + source fingerprint directory so that code or dependency changes + automatically invalidate stale entries. + """ + if CUTE_DSL_CACHE_ENABLED: + path = get_cache_path() / _compute_source_fingerprint() + if name: + path = path / name + logger.debug("Creating persistent JIT cache at %s", path) + return JITPersistentCache(path) + else: + logger.debug("Persistent cache disabled, using in-memory JIT cache") + return JITCache() diff --git a/flash_attn/cute/copy_utils.py b/flash_attn/cute/copy_utils.py index cfdcbdb80a0..d8c6083c8cc 100644 --- a/flash_attn/cute/copy_utils.py +++ b/flash_attn/cute/copy_utils.py @@ -207,6 +207,38 @@ def store_shared_remote_fp32x4( ) +@dsl_user_op +def cpasync_bulk_s2cluster( + smem_src_ptr: cute.Pointer, + smem_dst_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + size: int | Int32, + peer_cta_rank_in_cluster: Int32, + *, + loc=None, + ip=None, +): + smem_src_ptr_i32 = smem_src_ptr.toint(loc=loc, ip=ip).ir_value() + smem_dst_ptr_i32 = set_block_rank( + smem_dst_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + mbar_ptr_i32 = set_block_rank(mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [ + smem_dst_ptr_i32, + smem_src_ptr_i32, + mbar_ptr_i32, + Int32(size).ir_value(loc=loc, ip=ip), + ], + "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [$0], [$1], $3, [$2];", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + @dsl_user_op def cpasync_bulk_g2s( gmem_ptr: cute.Pointer, diff --git a/flash_attn/cute/cute_dsl_utils.py b/flash_attn/cute/cute_dsl_utils.py index 9d2f7aa739b..79ebd9df6cf 100644 --- a/flash_attn/cute/cute_dsl_utils.py +++ b/flash_attn/cute/cute_dsl_utils.py @@ -4,7 +4,6 @@ import pathlib from typing import Tuple from functools import partial, lru_cache -from dataclasses import dataclass, fields import torch @@ -15,7 +14,6 @@ import cutlass import cutlass.cute as cute -from cutlass.base_dsl.typing import JitArgument from cutlass.cutlass_dsl import NumericMeta from cutlass.cute.runtime import from_dlpack @@ -43,66 +41,6 @@ def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: return torch.cuda.get_device_capability(device) -@dataclass -class ParamsBase: - def __extract_mlir_values__(self): - all_fields = [getattr(self, field.name) for field in fields(self)] - non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] - values, self._values_pos = [], [] - for obj in non_constexpr_fields: - obj_values = cutlass.extract_mlir_values(obj) - values += obj_values - self._values_pos.append(len(obj_values)) - return values - - def __new_from_mlir_values__(self, values): - all_fields = {field.name: getattr(self, field.name) for field in fields(self)} - constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} - non_constexpr_fields = { - n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) - } - for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): - non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) - values = values[n_items:] - return self.__class__(**non_constexpr_fields, **constexpr_fields) - - -@dataclass -class ArgumentsBase(JitArgument): - def __c_pointers__(self): - all_fields = [getattr(self, field.name) for field in fields(self)] - non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] - c_ptrs = [] - for obj in non_constexpr_fields: - if hasattr(obj, "__c_pointers__"): - c_ptrs.extend(obj.__c_pointers__()) - return c_ptrs - - def __get_mlir_types__(self): - all_fields = [getattr(self, field.name) for field in fields(self)] - non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] - types, self._values_pos = [], [] - for obj in non_constexpr_fields: - if hasattr(obj, "__get_mlir_types__"): - obj_types = obj.__get_mlir_types__() - types.extend(obj_types) - self._values_pos.append(len(obj_types)) - else: - self._values_pos.append(0) - return types - - def __new_from_mlir_values__(self, values): - all_fields = {field.name: getattr(self, field.name) for field in fields(self)} - constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} - non_constexpr_fields = { - n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) - } - for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): - non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) - values = values[n_items:] - return self.__class__(**non_constexpr_fields, **constexpr_fields) - - def load_cubin_module_data_patched(cubin_data, filepath): pathlib.Path(filepath).write_bytes(cubin_data) return load_cubin_module_data_og(cubin_data) @@ -152,6 +90,35 @@ def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, ena return tensor.mark_layout_dynamic(leading_dim=leading_dim) +def to_cute_aux_tensor(t, enable_tvm_ffi=True): + """Convert torch tensor to cute tensor for TVM FFI, tailored to FlexAttention aux tensors. + This allows the user to specify alignment and leading dimension for aux tensors used in + custom score_mod callables. + """ + assumed_align: int = getattr(t, "__assumed_align__", None) + leading_dim: int = getattr(t, "__leading_dim__", None) + fully_dynamic: bool = leading_dim is None + + return to_cute_tensor( + t, + assumed_align=assumed_align, + leading_dim=leading_dim, + fully_dynamic=fully_dynamic, + enable_tvm_ffi=enable_tvm_ffi, + ) + + +def get_aux_tensor_metadata(aux_tensors): + return tuple( + ( + getattr(t, "__assumed_align__", 0), + getattr(t, "__leading_dim__", -1), + hasattr(t, "__leading_dim__"), + ) + for t in aux_tensors + ) + + def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]: """Return tuple of bools indicating which dims have stride=0 (broadcast). diff --git a/flash_attn/cute/fa_logging.py b/flash_attn/cute/fa_logging.py new file mode 100644 index 00000000000..63189cd5d65 --- /dev/null +++ b/flash_attn/cute/fa_logging.py @@ -0,0 +1,97 @@ +# Copyright (c) 2025, Tri Dao. + +"""Unified FlashAttention logging controlled by a single ``FA_LOG_LEVEL`` env var. + +Host-side messages go through Python ``logging`` (logger name ``flash_attn``). +A default ``StreamHandler`` is attached automatically when ``FA_LOG_LEVEL >= 1`` +so that standalone scripts get output without extra setup; applications that +configure their own logging can remove or replace it via the standard API. + +FA_LOG_LEVEL mapping:: + + 0 off nothing logged + 1 host host-side summaries only (no kernel printf) + 2 kernel host + curated kernel traces + 3 max host + all kernel traces (noisy, perf hit) + +Set via environment variable:: + + FA_LOG_LEVEL=1 python train.py + +Device-side ``cute.printf`` calls are compile-time eliminated via +``cutlass.const_expr`` when the log level is below the callsite threshold, +so there is zero performance cost when device logging is off. +Changing the log level after kernel compilation requires a recompile +(the level participates in the forward compile key). +""" + +import logging +import os +import sys + +import cutlass.cute as cute +from cutlass import const_expr + +_LOG_LEVEL_NAMES = {"off": 0, "host": 1, "kernel": 2, "max": 3} + + +def _parse_log_level(raw: str) -> int: + if raw in _LOG_LEVEL_NAMES: + return _LOG_LEVEL_NAMES[raw] + try: + level = int(raw) + except ValueError: + return 0 + return max(0, min(level, 3)) + + +_fa_log_level: int = _parse_log_level(os.environ.get("FA_LOG_LEVEL", "0")) + +_logger = logging.getLogger("flash_attn") +_logger.addHandler(logging.NullHandler()) +_default_handler: logging.Handler | None = None + + +def _configure_default_handler() -> None: + global _default_handler + if _fa_log_level >= 1: + if _default_handler is None: + _default_handler = logging.StreamHandler(sys.stdout) + _default_handler.setFormatter(logging.Formatter("[FA] %(message)s")) + _logger.addHandler(_default_handler) + _logger.setLevel(logging.DEBUG) + else: + if _default_handler is not None: + _logger.removeHandler(_default_handler) + _default_handler = None + _logger.setLevel(logging.WARNING) + + +_configure_default_handler() + + +def get_fa_log_level() -> int: + return _fa_log_level + + +def set_fa_log_level(level: int | str) -> None: + """Set the FA log level programmatically. + + Host logging takes effect immediately. Device logging changes only + affect kernels compiled after this call (new compile-key selection). + """ + global _fa_log_level + if isinstance(level, str): + level = _parse_log_level(level) + _fa_log_level = max(0, min(int(level), 3)) + _configure_default_handler() + + +def fa_log(level: int, msg: str): + if _fa_log_level >= level: + _logger.info(msg) + + +def fa_printf(level: int, fmt, *args): + if const_expr(_fa_log_level >= level): + cute.printf(fmt, *args) diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 71f07e79edb..824abdda139 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -20,7 +20,9 @@ from flash_attn.cute import utils from flash_attn.cute.mask import AttentionMask from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.tile_scheduler import ParamsBase, SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from quack.cute_dsl_utils import ParamsBase +from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_attn.cute.block_sparsity import BlockSparseTensors class FlashAttentionBackwardSm80: @@ -371,7 +373,6 @@ def __call__( mdK: cute.Tensor, mdV: cute.Tensor, softmax_scale: cutlass.Float32, - stream: cuda.CUstream, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, @@ -380,8 +381,16 @@ def __call__( window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, mdQ_semaphore: Optional[cute.Tensor] = None, + mdK_semaphore: Optional[cute.Tensor] = None, + mdV_semaphore: Optional[cute.Tensor] = None, + aux_tensors: Optional[list] = None, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): - assert mdQ_semaphore is None, "semaphore not supported yet" + assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, ( + "determinism not supported yet for Sm80" + ) # Get the data type and check if it is fp16 or bf16 self._check_type(*(t.element_type if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK))) @@ -511,7 +520,17 @@ def kernel( n_block, head_idx, batch_idx, _ = work_tile.tile_idx if work_tile.is_valid_tile: - seqlen = SeqlenInfoQK.create(batch_idx, mQ.shape[1], mK.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK) + seqlen = SeqlenInfoQK.create( + batch_idx, + mQ.shape[1], + mK.shape[1], + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + tile_m=self.m_block_size, + tile_n=self.n_block_size, + ) m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size) m_block_min = 0 @@ -537,7 +556,7 @@ def kernel( mdPsum_cur = mdPsum[batch_idx, head_idx, None] mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] else: - padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size + padded_offset_q = seqlen.padded_offset_q mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, head_idx, None]) mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[head_idx, None]) mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None]) @@ -793,9 +812,10 @@ def kernel( # Mainloop # /////////////////////////////////////////////////////////////////////////////// # Start processing of the first n-block. - mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k) + mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen) mask_fn = partial( mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp, + batch_idx=batch_idx, head_idx=head_idx, mask_seqlen=True, mask_causal=self.is_causal ) smem_pipe_read_q = cutlass.Int32(0) @@ -967,7 +987,7 @@ def dQ_mma(hook_fn): # MMA dK if cutlass.const_expr(self.Mma_dKV_is_RS): - tdVrP = layout_utils.reshape_acc_to_frgA(rdS) + tdKrdS = layout_utils.reshape_acc_to_frgA(rdS) else: tdKrdS = mma_params.tdKrdS sm80_utils.gemm( diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 4567875519c..76c856221c5 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -2,7 +2,7 @@ # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h # from Cutlass C++ to Cute-DSL. import math -from typing import Callable, Optional, Type, Literal +from typing import Callable, Optional, Type import cuda.bindings.driver as cuda @@ -14,17 +14,17 @@ from cutlass import Float32, const_expr from cutlass.utils import LayoutEnum +from quack import copy_utils from quack import layout_utils from quack import sm90_utils from flash_attn.cute import utils from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import copy_utils from flash_attn.cute import ampere_helpers as sm80_utils from flash_attn.cute.seqlen_info import SeqlenInfoQK import cutlass.cute.nvgpu.tcgen05 as tcgen05 +from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.tile_scheduler import ( - ParamsBase, SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, @@ -36,11 +36,13 @@ def __init__( self, dtype: Type[cutlass.Numeric], head_dim: int, - arch: Literal[80, 90, 100], + arch: int, tile_m: int = 128, num_threads: int = 256, AtomLayoutMdQ: int = 1, dQ_swapAB: bool = False, + use_2cta_instrs: bool = False, + cluster_size: int = 1, # for varlen offsets ): """ :param head_dim: head dimension @@ -50,8 +52,8 @@ def __init__( """ self.dtype = dtype self.tile_m = tile_m - assert arch in [80, 90, 100], ( - "Only Ampere (80), Hopper (90), and Blackwell (100) are supported" + assert arch // 10 in [8, 9, 10, 11, 12], ( + "Only Ampere (8.x), Hopper (9.x), and Blackwell (10.x, 11.x, 12.x) are supported" ) self.arch = arch # padding head_dim to a multiple of 32 as k_block_size @@ -61,6 +63,8 @@ def __init__( self.num_threads = num_threads self.AtomLayoutMdQ = AtomLayoutMdQ self.dQ_swapAB = dQ_swapAB + self.use_2cta_instrs = use_2cta_instrs and arch // 10 == 10 and head_dim != 64 + self.cluster_size = cluster_size @staticmethod def can_implement(dtype, head_dim, tile_m, num_threads) -> bool: @@ -85,7 +89,7 @@ def can_implement(dtype, head_dim, tile_m, num_threads) -> bool: return True def _get_tiled_mma(self): - if const_expr(self.arch == 80): + if const_expr(self.arch // 10 in [8, 12]): num_mma_warps = self.num_threads // 32 atom_layout_dQ = ( (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) @@ -97,9 +101,9 @@ def _get_tiled_mma(self): atom_layout_dQ, permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16), ) - elif const_expr(self.arch == 90): - num_mma_warp_groups = self.num_threads // 128 - atom_layout_dQ = (self.AtomLayoutMdQ, num_mma_warp_groups // self.AtomLayoutMdQ) + elif const_expr(self.arch // 10 == 9): + num_wg_mma = self.num_threads // 128 + atom_layout_dQ = (self.AtomLayoutMdQ, num_wg_mma // self.AtomLayoutMdQ) tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1]) tiled_mma = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, @@ -121,7 +125,7 @@ def _get_tiled_mma(self): cta_group, (self.tile_m, self.tile_hdim), ) - if const_expr(self.arch in [80, 90]): + if const_expr(self.arch // 10 in [8, 9, 12]): assert self.num_threads == tiled_mma.size return tiled_mma @@ -144,22 +148,22 @@ def _setup_attributes(self): cute.make_layout(self.num_threads), cute.make_layout(async_copy_elems_accum), ) - num_s2r_copy_elems = 1 if const_expr(self.arch == 80) else 4 - if const_expr(self.arch == 80): + num_s2r_copy_elems = 1 if const_expr(self.arch // 10 in [8, 12]) else 4 + if const_expr(self.arch // 10 in [8, 12]): self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( Float32, self.num_threads, num_s2r_copy_elems ) self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim) - elif const_expr(self.arch == 90): + elif const_expr(self.arch // 10 == 9): num_threads_per_warp_group = 128 - num_mma_warp_groups = self.num_threads // 128 + num_wg_mma = self.num_threads // 128 self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv( cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128), - cute.make_layout((num_threads_per_warp_group, num_mma_warp_groups)), # thr_layout + cute.make_layout((num_threads_per_warp_group, num_wg_mma)), # thr_layout cute.make_layout(128 // Float32.width), # val_layout ) self.sdQaccum_layout = cute.make_layout( - (self.tile_m * self.tile_hdim // num_mma_warp_groups, num_mma_warp_groups) + (self.tile_m * self.tile_hdim // num_wg_mma, num_wg_mma) ) else: self.dQ_reduce_ncol = 32 @@ -172,8 +176,10 @@ def _setup_attributes(self): (self.tile_m * self.tile_hdim // dQaccum_reduce_stage, dQaccum_reduce_stage) ) + num_copy_elems = 128 // self.dtype.width + threads_per_row = math.gcd(128, self.tile_hdim) // num_copy_elems self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d( - self.dtype, self.tile_hdim, self.num_threads + self.dtype, threads_per_row, self.num_threads, num_copy_elems ) # /////////////////////////////////////////////////////////////////////////////// # Shared memory layout: dQ @@ -182,14 +188,18 @@ def _setup_attributes(self): # then setting kBlockKSmem to 32 will cause "Static shape_div failure". # We want to treat it as 64 x 48, so kBlockKSmem should be 16. mma_shape_n = self.tiled_mma.get_tile_size(1) - if const_expr(self.arch == 80): + if const_expr(self.arch // 10 in [8, 12]): sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n) self.sdQ_layout = cute.tile_to_shape( sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1) ) - elif const_expr(self.arch == 90): + elif const_expr(self.arch // 10 == 9): + wg_d_dQ = num_wg_mma // self.AtomLayoutMdQ self.sdQ_layout = sm90_utils.make_smem_layout( - self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim) + self.dtype, + LayoutEnum.ROW_MAJOR, + (self.tile_m, self.tile_hdim), + major_mode_size=self.tile_hdim // wg_d_dQ, ) else: # TODO: this is hard-coded for hdim 128 @@ -205,7 +215,8 @@ def __call__( scale: cutlass.Float32, mCuSeqlensQ: Optional[cute.Tensor], mSeqUsedQ: Optional[cute.Tensor], - stream: cuda.CUstream, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): # Get the data type and check if it is fp16 or bf16 if const_expr(mdQ.element_type not in [cutlass.Float16, cutlass.BFloat16]): @@ -299,7 +310,7 @@ def kernel( smem = cutlass.utils.SmemAllocator() sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024) sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum))) - if const_expr(self.arch in [80, 90]): + if const_expr(self.arch // 10 in [8, 9, 12]): sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) else: # extra stage dimension @@ -330,15 +341,14 @@ def kernel( mCuSeqlensK=None, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=None, + tile_m=self.tile_m * self.cluster_size, ) if const_expr(not seqlen.has_cu_seqlens_q): mdQ_cur = mdQ[batch_idx, None, head_idx, None] mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] head_dim = mdQ.shape[3] else: - padded_offset_q = seqlen.offset_q + batch_idx * self.tile_m - if cutlass.const_expr(self.arch >= 90): - padded_offset_q = padded_offset_q // self.tile_m * self.tile_m + padded_offset_q = seqlen.padded_offset_q mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None]) mdQaccum_cur = cute.domain_offset( (padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None] @@ -363,78 +373,197 @@ def kernel( seqlen_q = seqlen.seqlen_q seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) - # Step 1: load dQaccum from gmem to smem - g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx) - tdQgdQaccum = g2s_thr_copy_dQaccum.partition_S(gdQaccum) - tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum_flat) - cute.copy(g2s_tiled_copy_dQaccum, tdQgdQaccum, tdQsdQaccumg2s) - cute.arch.cp_async_commit_group() - cute.arch.cp_async_wait_group(0) - cute.arch.barrier() - - # Step 2: load dQ from smem to rmem - s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx) - tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum) - tile_shape = (self.tile_m, self.tile_hdim) - acc = None - tiled_copy_t2r = None - if const_expr(self.arch in [80, 90]): - acc_shape = tiled_mma.partition_shape_C( - tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1] - ) - acc = cute.make_fragment(acc_shape, cutlass.Float32) - assert cute.size(acc) == cute.size(tdQsdQaccum) - else: - thr_mma = tiled_mma.get_slice(0) # 1-CTA - dQacc_shape = tiled_mma.partition_shape_C((self.tile_m, self.tile_hdim)) - tdQtdQ = tiled_mma.make_fragment_C(dQacc_shape) - tdQcdQ = thr_mma.partition_C( - cute.make_identity_tensor((self.tile_m, self.tile_hdim)) - ) + if const_expr(self.arch // 10 == 10 and self.use_2cta_instrs): + # 2-CTA: remap dQaccum layout into TMEM view before writing sdQ + num_reduce_threads = self.num_threads + thr_mma_dsk = tiled_mma.get_slice(tidx) + dQacc_shape = thr_mma_dsk.partition_shape_C((self.tile_m, self.tile_hdim)) + tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape) + tdQtdQ = cute.make_tensor(tdQtdQ.iterator, tdQtdQ.layout) + tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32 ) - tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ) - thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) - tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape - acc = cute.make_fragment(tdQrdQ_t2r_shape, Float32) - tdQrdQaccum = cute.make_tensor(acc.iterator, cute.make_layout(tdQsdQaccum.shape)) - cute.autovec_copy(tdQsdQaccum, tdQrdQaccum) - # Convert tdQrdQaccum from fp32 to fp16/bf16 - rdQ = cute.make_fragment_like(acc, self.dtype) - rdQ.store((acc.load() * scale).to(self.dtype)) - - # Step 3: Copy dQ from register to smem - cute.arch.barrier() # make sure all threads have finished loading dQaccum - if const_expr(self.arch in [80, 90]): - copy_atom_r2s_dQ = utils.get_smem_store_atom( - self.arch, self.dtype, transpose=self.dQ_swapAB + tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ) + thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) + + cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) + tdQcdQ = thr_mma_dsk.partition_C(cdQ) + tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout) + tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor) + + tiled_copy_accum = s2r_tiled_copy_dQaccum + g2s_thr_copy = tiled_copy_accum.get_slice(tidx) + + # S -> R + tdQrdQ_fp32 = cute.make_fragment(tdQrdQ.shape, cutlass.Float32) + tdQrdQ_s2r = cute.make_tensor(tdQrdQ_fp32.iterator, tdQrdQ_fp32.shape) + + smem_copy_atom = sm100_utils_basic.get_smem_store_op( + LayoutEnum.ROW_MAJOR, self.dtype, cutlass.Float32, tiled_tmem_ld ) - tiled_copy_r2s_dQ = cute.make_tiled_copy_C(copy_atom_r2s_dQ, tiled_mma) - else: - # copy_atom_r2s_dQ = sm100_utils_basic.get_smem_store_op( - # LayoutEnum.ROW_MAJOR, self.dtype, Float32, tiled_copy_t2r, - # ) - # tiled_copy_r2s_dQ = cute.make_tiled_copy_D(copy_atom_r2s_dQ, tiled_copy_t2r) - thr_layout_r2s_dQ = cute.make_layout((self.num_threads, 1)) # 128 threads - val_layout_r2s_dQ = cute.make_layout((1, 128 // self.dtype.width)) - copy_atom_r2s_dQ = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - self.dtype, - num_bits_per_copy=128, + r2s_tiled_copy = cute.make_tiled_copy( + smem_copy_atom, + layout_tv=tiled_tmem_ld.layout_dst_tv_tiled, + tiler_mn=tiled_tmem_ld.tiler_mn, ) - tiled_copy_r2s_dQ = cute.make_tiled_copy_tv( - copy_atom_r2s_dQ, thr_layout_r2s_dQ, val_layout_r2s_dQ + tdQsdQ_r2s = thr_tmem_ld.partition_D(thr_mma_dsk.partition_C(sdQ)) + tdQrdQ_r2s = cute.make_fragment(tdQsdQ_r2s.shape, self.dtype) + + num_stages = cute.size(tdQrdQ_fp32, mode=[1]) + stage_stride = self.dQ_reduce_ncol + row_groups = 2 + assert num_stages % row_groups == 0 + assert num_reduce_threads % row_groups == 0 + stage_groups = num_stages // row_groups + threads_per_row_group = num_reduce_threads // row_groups + stage_loads = tuple((row_group, row_group) for row_group in range(row_groups)) + stage_iters = tuple( + (row_group, row_group * threads_per_row_group) + for row_group in range(row_groups) ) - thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx) - cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) - if const_expr(self.arch in [80, 90]): - taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ) + s2r_lane = tidx % threads_per_row_group + s2r_buf = tidx // threads_per_row_group + + gdQaccum_layout_g2s = cute.make_layout( + shape=(self.tile_m * self.dQ_reduce_ncol, 1), stride=(1, 0) + ) + sdQaccum_g2s = g2s_thr_copy.partition_D(sdQaccum) + + # G -> S + for stage_group in cutlass.range_constexpr(stage_groups): + for stage_offset, smem_buf in stage_loads: + stage_idx = stage_group + stage_offset * stage_groups + gdQaccum_stage = cute.local_tile( + gdQaccum, + (self.tile_m * self.dQ_reduce_ncol,), + (stage_idx,), + ) + gdQaccum_stage_g2s = cute.make_tensor( + gdQaccum_stage.iterator, + gdQaccum_layout_g2s, + ) + tdQgdQ = g2s_thr_copy.partition_S(gdQaccum_stage_g2s) + cute.copy( + g2s_thr_copy, + tdQgdQ[None, None, 0], + sdQaccum_g2s[None, None, smem_buf], + ) + + cute.arch.fence_view_async_shared() + cute.arch.barrier(barrier_id=6, number_of_threads=num_reduce_threads) + + # S -> R + for stage_offset, lane_offset in stage_iters: + stage_idx = stage_group + stage_offset * stage_groups + s2r_src_tidx = s2r_lane + lane_offset + s2r_thr_copy = tiled_copy_accum.get_slice(s2r_src_tidx) + sdQaccum_src = s2r_thr_copy.partition_S(sdQaccum)[None, None, s2r_buf] + + tdQrdQ_s2r_cpy = tdQrdQ_s2r[None, stage_idx, None, None] + tdQrdQ_r2s_cpy = cute.make_tensor( + tdQrdQ_s2r_cpy.iterator, cute.make_layout(sdQaccum_src.shape) + ) + cute.copy(s2r_thr_copy, sdQaccum_src, tdQrdQ_r2s_cpy) + cute.arch.fence_view_async_shared() + cute.arch.barrier(barrier_id=7, number_of_threads=num_reduce_threads) + + # R -> S + stage_lo = stage_idx % stage_stride + stage_hi = stage_idx // stage_stride + tdQrdQ_r2s_cpy = cute.make_tensor( + cute.recast_ptr(tdQrdQ_r2s_cpy.iterator), + tdQrdQ_r2s[((None, 0), (stage_lo, stage_hi), 0, 0)].shape, + ) + dQ_vec = tdQrdQ_r2s_cpy.load() * scale + tdQrdQ_r2s[((None, 0), (stage_lo, stage_hi), 0, 0)].store( + dQ_vec.to(self.dtype) + ) + + # R -> S + cute.copy( + r2s_tiled_copy, + tdQrdQ_r2s[None, None, None, 0], + tdQsdQ_r2s[None, None, None, 0], + ) + cute.arch.fence_view_async_shared() + cute.arch.barrier(barrier_id=8, number_of_threads=num_reduce_threads) else: - taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape - taccdQrdQ = cute.make_tensor(rdQ.iterator, taccdQcdQ_shape) - taccdQsdQ = thr_copy_r2s_dQ.partition_D(sdQ if const_expr(not self.dQ_swapAB) else sdQt) - cute.copy(thr_copy_r2s_dQ, taccdQrdQ, taccdQsdQ) + # Step 1: load dQaccum from gmem to smem + g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx) + tdQgdQaccum = g2s_thr_copy_dQaccum.partition_S(gdQaccum) + tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum_flat) + cute.copy(g2s_tiled_copy_dQaccum, tdQgdQaccum, tdQsdQaccumg2s) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + + # Step 2: load dQ from smem to rmem + s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx) + tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum) + tile_shape = (self.tile_m, self.tile_hdim) + acc = None + tiled_copy_t2r = None + if const_expr(self.arch // 10 in [8, 9, 12]): + acc_shape = tiled_mma.partition_shape_C( + tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1] + ) + acc = cute.make_fragment(acc_shape, cutlass.Float32) + assert cute.size(acc) == cute.size(tdQsdQaccum) + else: + thr_mma = tiled_mma.get_slice(0) # 1-CTA + dQacc_shape = tiled_mma.partition_shape_C((self.tile_m, self.tile_hdim)) + tdQtdQ = tiled_mma.make_fragment_C(dQacc_shape) + tdQcdQ = thr_mma.partition_C( + cute.make_identity_tensor((self.tile_m, self.tile_hdim)) + ) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), + Float32, + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape + acc = cute.make_fragment(tdQrdQ_t2r_shape, Float32) + tdQrdQaccum = cute.make_tensor(acc.iterator, cute.make_layout(tdQsdQaccum.shape)) + cute.autovec_copy(tdQsdQaccum, tdQrdQaccum) + # Convert tdQrdQaccum from fp32 to fp16/bf16 + rdQ = cute.make_fragment_like(acc, self.dtype) + rdQ.store((acc.load() * scale).to(self.dtype)) + + # Step 3: Copy dQ from register to smem + cute.arch.barrier() # make sure all threads have finished loading dQaccum + if const_expr(self.arch // 10 in [8, 9, 12]): + copy_atom_r2s_dQ = utils.get_smem_store_atom( + self.arch, self.dtype, transpose=self.dQ_swapAB + ) + tiled_copy_r2s_dQ = cute.make_tiled_copy_C(copy_atom_r2s_dQ, tiled_mma) + else: + # copy_atom_r2s_dQ = sm100_utils_basic.get_smem_store_op( + # LayoutEnum.ROW_MAJOR, self.dtype, Float32, tiled_copy_t2r, + # ) + # tiled_copy_r2s_dQ = cute.make_tiled_copy_D(copy_atom_r2s_dQ, tiled_copy_t2r) + thr_layout_r2s_dQ = cute.make_layout((self.num_threads, 1)) # 128 threads + val_layout_r2s_dQ = cute.make_layout((1, 128 // self.dtype.width)) + copy_atom_r2s_dQ = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=128, + ) + tiled_copy_r2s_dQ = cute.make_tiled_copy_tv( + copy_atom_r2s_dQ, thr_layout_r2s_dQ, val_layout_r2s_dQ + ) + thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx) + cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) + if const_expr(self.arch // 10 in [8, 9, 12]): + taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ) + else: + taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape + taccdQrdQ = cute.make_tensor(rdQ.iterator, taccdQcdQ_shape) + taccdQsdQ = thr_copy_r2s_dQ.partition_D( + sdQ if const_expr(not self.dQ_swapAB) else sdQt + ) + cute.copy(thr_copy_r2s_dQ, taccdQrdQ, taccdQsdQ) # Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem cute.arch.barrier() # make sure all smem stores are done diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index 794baebf4b4..d93ea5cc50b 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -1,22 +1,34 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_preprocess_kernel.h # from Cutlass C++ to Cute-DSL. +# +# Computes D_i = (dO_i * O_i).sum(dim=-1), optionally adjusted for LSE gradient: +# D'_i = D_i - dLSE_i +# This works because in the backward pass: +# dS_ij = P_ij * (dP_ij - D_i) [standard] +# When LSE is differentiable, d(loss)/d(S_ij) gets an extra term dLSE_i * P_ij +# (since d(LSE_i)/d(S_ij) = P_ij), giving: +# dS_ij = P_ij * (dP_ij - D_i) + dLSE_i * P_ij +# = P_ij * (dP_ij - (D_i - dLSE_i)) +# So the main backward kernel is unchanged; we just replace D with D' = D - dLSE here. import math import operator -from typing import Callable, Type, Optional, Literal +from functools import partial +from typing import Callable, Type, Optional import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute -from cutlass import Float32 +from cutlass import Float32, const_expr +from cutlass.cutlass_dsl import Arch, BaseDSL + +from quack import copy_utils, layout_utils from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import copy_utils -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.seqlen_info import SeqlenInfo +from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.tile_scheduler import ( - ParamsBase, SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, @@ -28,9 +40,9 @@ def __init__( self, dtype: Type[cutlass.Numeric], head_dim: int, - arch: Literal[80, 90, 100], - m_block_size: int = 128, - num_threads: int = 128, + head_dim_v: int, + tile_m: int = 128, + num_threads: int = 256, ): """ All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension @@ -38,30 +50,31 @@ def __init__( :param head_dim: head dimension :type head_dim: int - :param m_block_size: m block size - :type m_block_size: int + :param tile_m: m block size + :type tile_m: int :param num_threads: number of threads :type num_threads: int """ + self.use_pdl = BaseDSL._get_dsl().get_arch_enum() >= Arch.sm_90a self.dtype = dtype - self.m_block_size = m_block_size - self.arch = arch + self.tile_m = tile_m # padding head_dim to a multiple of 32 as k_block_size hdim_multiple_of = 32 self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) - self.check_hdim_oob = head_dim != self.head_dim_padded + self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded self.num_threads = num_threads @staticmethod - def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool: + def can_implement(dtype, head_dim, tile_m, num_threads) -> bool: """Check if the kernel can be implemented with the given parameters. :param dtype: data type :type dtype: cutlass.Numeric :param head_dim: head dimension :type head_dim: int - :param m_block_size: m block size - :type m_block_size: int + :param tile_m: m block size + :type tile_m: int :param num_threads: number of threads :type num_threads: int @@ -74,7 +87,7 @@ def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool: return False if num_threads % 32 != 0: return False - if num_threads < m_block_size: # For multiplying lse with log2 + if num_threads < tile_m: # For multiplying lse with log2 return False return True @@ -87,20 +100,22 @@ def _setup_attributes(self): # it's just between threads in the same warp gmem_k_block_size = ( 128 - if self.head_dim_padded % 128 == 0 + if self.head_dim_v_padded % 128 == 0 else ( 64 - if self.head_dim_padded % 64 == 0 - else (32 if self.head_dim_padded % 32 == 0 else 16) + if self.head_dim_v_padded % 64 == 0 + else (32 if self.head_dim_v_padded % 32 == 0 else 16) ) ) + num_copy_elems = 128 // self.dtype.width + threads_per_row = gmem_k_block_size // num_copy_elems self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d( - self.dtype, gmem_k_block_size, self.num_threads + self.dtype, threads_per_row, self.num_threads, num_copy_elems ) universal_copy_bits = 128 num_copy_elems_dQaccum = universal_copy_bits // Float32.width assert ( - self.m_block_size * self.head_dim_padded // num_copy_elems_dQaccum + self.tile_m * self.head_dim_padded // num_copy_elems_dQaccum ) % self.num_threads == 0 self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( Float32, self.num_threads, num_copy_elems_dQaccum @@ -109,38 +124,53 @@ def _setup_attributes(self): @cute.jit def __call__( self, - mO: cute.Tensor, - mdO: cute.Tensor, - mdPsum: cute.Tensor, - mLSE: Optional[cute.Tensor], - mLSElog2: Optional[cute.Tensor], + mO: cute.Tensor, # (batch, seqlen, nheads, head_dim_v) or (total_q, nheads, head_dim_v) + mdO: cute.Tensor, # same shape as mO + mPdPsum: cute.Tensor, # (batch, nheads, seqlen_padded) or (nheads, total_q_padded) + mLSE: Optional[cute.Tensor], # (batch, nheads, seqlen) or (nheads, total_q) + mLSElog2: Optional[cute.Tensor], # same shape as mPdPsum + # (batch, nheads, seqlen_padded * head_dim_v) or (nheads, total_q_padded * head_dim_v) mdQaccum: Optional[cute.Tensor], - mCuSeqlensQ: Optional[cute.Tensor], - mSeqUsedQ: Optional[cute.Tensor], - stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor], # (batch + 1,) + mSeqUsedQ: Optional[cute.Tensor], # (batch,) + mdLSE: Optional[cute.Tensor], # (batch, nheads, seqlen) or (nheads, total_q) + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): # Get the data type and check if it is fp16 or bf16 - if cutlass.const_expr(not (mO.element_type == mdO.element_type)): + if const_expr(not (mO.element_type == mdO.element_type)): raise TypeError("All tensors must have the same data type") - if cutlass.const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]): + if const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") - if cutlass.const_expr(mdPsum.element_type not in [Float32]): - raise TypeError("dPsum tensor must be Float32") - if cutlass.const_expr(mdQaccum is not None): - if cutlass.const_expr(mdQaccum.element_type not in [Float32]): + if const_expr(mPdPsum.element_type not in [Float32]): + raise TypeError("PdPsum tensor must be Float32") + if const_expr(mdQaccum is not None): + if const_expr(mdQaccum.element_type not in [Float32]): raise TypeError("dQaccum tensor must be Float32") - if cutlass.const_expr(mLSE is not None): + if const_expr(mLSE is not None): assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided" - if cutlass.const_expr(mLSE.element_type not in [Float32]): + if const_expr(mLSE.element_type not in [Float32]): raise TypeError("LSE tensor must be Float32") - if cutlass.const_expr(mLSElog2.element_type not in [Float32]): + if const_expr(mLSElog2.element_type not in [Float32]): raise TypeError("LSElog2 tensor must be Float32") - - mO, mdO, mdQaccum = [assume_tensor_aligned(t) for t in (mO, mdO, mdQaccum)] + if const_expr(mdLSE is not None): + if const_expr(mdLSE.element_type not in [Float32]): + raise TypeError("dLSE tensor must be Float32") self._setup_attributes() - if cutlass.const_expr(mCuSeqlensQ is not None): + # (batch, nheads, seqlen) -> (seqlen, nheads, batch) or (total_q, nheads) -> (nheads, total_q) + transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mPdPsum = layout_utils.select(mPdPsum, transpose) + if const_expr(mLSE is not None): + mLSE = layout_utils.select(mLSE, transpose) + mLSElog2 = layout_utils.select(mLSElog2, transpose) + if const_expr(mdLSE is not None): + mdLSE = layout_utils.select(mdLSE, transpose) + if const_expr(mdQaccum is not None): + mdQaccum = layout_utils.select(mdQaccum, transpose) + + if const_expr(mCuSeqlensQ is not None): TileScheduler = SingleTileVarlenScheduler num_head = mO.shape[1] num_batch = mCuSeqlensQ.shape[0] - 1 @@ -150,7 +180,7 @@ def __call__( num_batch = mO.shape[0] tile_sched_args = TileSchedulerArguments( - num_block=cute.ceil_div(mO.shape[1], self.m_block_size), + num_block=cute.ceil_div(mO.shape[1], self.tile_m), num_head=num_head, num_batch=num_batch, num_splits=1, @@ -158,7 +188,7 @@ def __call__( headdim=0, headdim_v=mO.shape[2], total_q=mO.shape[0], - tile_shape_mn=(self.m_block_size, 1), + tile_shape_mn=(self.tile_m, 1), mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, ) @@ -169,12 +199,13 @@ def __call__( self.kernel( mO, mdO, - mdPsum, + mPdPsum, mLSE, mLSElog2, mdQaccum, mCuSeqlensQ, mSeqUsedQ, + mdLSE, self.gmem_tiled_copy_O, self.gmem_tiled_copy_dQaccum, tile_sched_params, @@ -183,6 +214,7 @@ def __call__( grid=grid_dim, block=[self.num_threads, 1, 1], stream=stream, + use_pdl=self.use_pdl, ) @cute.kernel @@ -190,12 +222,13 @@ def kernel( self, mO: cute.Tensor, mdO: cute.Tensor, - mdPsum: cute.Tensor, + mPdPsum: cute.Tensor, mLSE: Optional[cute.Tensor], mLSElog2: Optional[cute.Tensor], mdQaccum: Optional[cute.Tensor], mCuSeqlensQ: Optional[cute.Tensor], mSeqUsedQ: Optional[cute.Tensor], + mdLSE: Optional[cute.Tensor], gmem_tiled_copy_O: cute.TiledCopy, gmem_tiled_copy_dQaccum: cute.TiledCopy, tile_sched_params: ParamsBase, @@ -212,145 +245,106 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. # /////////////////////////////////////////////////////////////////////////////// - seqlen = SeqlenInfoQK.create( - batch_idx, - mO.shape[1], - 0, - mCuSeqlensQ=mCuSeqlensQ, - mCuSeqlensK=None, - mSeqUsedQ=mSeqUsedQ, - mSeqUsedK=None, + seqlen = SeqlenInfo.create( + batch_idx, mO.shape[1], mCuSeqlensQ, mSeqUsedQ, tile=self.tile_m ) + mO_cur = seqlen.offset_batch(mO, batch_idx, dim=0)[None, head_idx, None] + mdO_cur = seqlen.offset_batch(mdO, batch_idx, dim=0)[None, head_idx, None] + mPdPsum_cur = seqlen.offset_batch(mPdPsum, batch_idx, dim=2, padded=True)[ + None, head_idx + ] + headdim_v = mO_cur.shape[cute.rank(mO_cur) - 1] + seqlen_q = seqlen.seqlen + seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) + seqlen_limit = seqlen_q - m_block * self.tile_m + + lse = None + if const_expr(mLSE is not None): + mLSE_cur = seqlen.offset_batch(mLSE, batch_idx, dim=2)[None, head_idx] + gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,)) + lse = Float32.inf + if tidx < seqlen_limit: + lse = gLSE[tidx] - if cutlass.const_expr(not seqlen.has_cu_seqlens_q): - mO_cur = mO[batch_idx, None, head_idx, None] - mdO_cur = mdO[batch_idx, None, head_idx, None] - mdPsum_cur = mdPsum[batch_idx, head_idx, None] - headdim_v = mO.shape[3] - else: - mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, head_idx, None]) - mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None]) - - padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size - if cutlass.const_expr(self.arch >= 90): - padded_offset_q = padded_offset_q // self.m_block_size * self.m_block_size - mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None]) - headdim_v = mO.shape[2] - - blkOdO_shape = (self.m_block_size, self.head_dim_padded) - # (m_block_size, head_dim) - gO = cute.local_tile(mO_cur, blkOdO_shape, (m_block, 0)) - gdO = cute.local_tile(mdO_cur, blkOdO_shape, (m_block, 0)) - + blk_shape = (self.tile_m, self.head_dim_v_padded) + gO = cute.local_tile(mO_cur, blk_shape, (m_block, 0)) + gdO = cute.local_tile(mdO_cur, blk_shape, (m_block, 0)) gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) # (CPY_Atom, CPY_M, CPY_K) tOgO = gmem_thr_copy_O.partition_S(gO) tOgdO = gmem_thr_copy_O.partition_S(gdO) - - # /////////////////////////////////////////////////////////////////////////////// - # Predicate: Mark indices that need to copy when problem_shape isn't a multiple - # of tile_shape - # /////////////////////////////////////////////////////////////////////////////// - # Construct identity layout for KV - cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + cO = cute.make_identity_tensor(blk_shape) tOcO = gmem_thr_copy_O.partition_S(cO) t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO) - tOpO = utils.predicate_k(tOcO, limit=headdim_v) - tOpdO = utils.predicate_k(tOcO, limit=headdim_v) - - seqlen_q = seqlen.seqlen_q - seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) - - if cutlass.const_expr(mLSE is not None): - if cutlass.const_expr(not seqlen.has_cu_seqlens_q): - mLSE_cur = mLSE[batch_idx, head_idx, None] - else: - mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[head_idx, None]) - - gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,)) - lse = Float32.inf - if tidx < seqlen_q - m_block * self.m_block_size: - lse = gLSE[tidx] - - tOrO = cute.make_fragment_like(tOgO) - tOrdO = cute.make_fragment_like(tOgdO) - assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0]) - assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1]) - assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2]) + tOpO = None + if const_expr(self.check_hdim_v_oob): + tOpO = copy_utils.predicate_k(tOcO, limit=headdim_v) + # Each copy will use the same predicate + copy = partial(copy_utils.copy, pred=tOpO) + + tOrO = cute.make_rmem_tensor_like(tOgO) + tOrdO = cute.make_rmem_tensor_like(tOgdO) + if const_expr(self.check_hdim_v_oob): + tOrO.fill(0.0) + tOrdO.fill(0.0) + assert tOgO.shape == tOgdO.shape for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True): - # Instead of using tOcO, we using t0OcO and subtract the offset from the limit - # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time. - if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]: - cute.copy( - gmem_thr_copy_O, - tOgO[None, m, None], - tOrO[None, m, None], - pred=tOpO[None, m, None] - if cutlass.const_expr(self.check_hdim_oob) - else None, - ) - cute.copy( - gmem_thr_copy_O, - tOgdO[None, m, None], - tOrdO[None, m, None], - pred=tOpdO[None, m, None] - if cutlass.const_expr(self.check_hdim_oob) - else None, - ) + # Instead of using tOcO, we using t0OcO and subtract the offset from the limit. + # This is bc the entries of t0OcO are known at compile time. + if t0OcO[0, m, 0][0] < seqlen_limit - tOcO[0][0]: + copy(tOgO[None, m, None], tOrO[None, m, None]) + copy(tOgdO[None, m, None], tOrdO[None, m, None]) + # O and dO loads are done; signal that the next kernel can start. + # Correctness is ensured by griddepcontrol_wait() in bwd_sm90 before it reads our outputs. + if const_expr(self.use_pdl): + cute.arch.griddepcontrol_launch_dependents() # Sum across the "k" dimension - dpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce( + pdpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce( cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1) ) threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0] assert cute.arch.WARP_SIZE % threads_per_row == 0 - dpsum = utils.warp_reduce(dpsum, operator.add, width=threads_per_row) - dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), Float32) - dP_sum.store(dpsum) - - # Write dPsum from rmem -> gmem - gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (m_block,)) - # Only the thread corresponding to column 0 writes out the dPsum to gmem + pdpsum = utils.warp_reduce(pdpsum, operator.add, width=threads_per_row) + PdP_sum = cute.make_rmem_tensor(cute.size(tOrO, mode=[1]), Float32) + PdP_sum.store(pdpsum) + + # If dLSE is provided, compute D' = D - dLSE (see module docstring for derivation). + gdLSE = None + if const_expr(mdLSE is not None): + mdLSE_cur = seqlen.offset_batch(mdLSE, batch_idx, dim=2)[None, head_idx] + gdLSE = cute.local_tile(mdLSE_cur, (self.tile_m,), (m_block,)) + + # Write PdPsum from rmem -> gmem + gPdPsum = cute.local_tile(mPdPsum_cur, (self.tile_m,), (m_block,)) + # Only the thread corresponding to column 0 writes out the PdPsum to gmem if tOcO[0, 0, 0][1] == 0: - for m in cutlass.range(cute.size(dP_sum), unroll_full=True): + for m in cutlass.range(cute.size(PdP_sum), unroll_full=True): row = tOcO[0, m, 0][0] - gdPsum[row] = dP_sum[m] if row < seqlen_q - m_block * self.m_block_size else 0.0 + PdPsum_val = 0.0 + if row < seqlen_limit: + PdPsum_val = PdP_sum[m] + if const_expr(mdLSE is not None): + PdPsum_val -= gdLSE[row] + gPdPsum[row] = PdPsum_val # Clear dQaccum - if cutlass.const_expr(mdQaccum is not None): - if cutlass.const_expr(not seqlen.has_cu_seqlens_q): - mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] - else: - mdQaccum_cur = cute.domain_offset( - (padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None] - ) - - # HACK: Compiler doesn't seem to recognize that padding - # by padded_offset_q * self.head_dim_padded keeps alignment - # since statically divisible by 4 - - mdQaccum_cur_ptr = cute.make_ptr( - dtype=mdQaccum_cur.element_type, - value=mdQaccum_cur.iterator.toint(), - mem_space=mdQaccum_cur.iterator.memspace, - assumed_align=mdQaccum.iterator.alignment, - ) - mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout) - - blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) + if const_expr(mdQaccum is not None): + mdQaccum_cur = seqlen.offset_batch( + mdQaccum, batch_idx, dim=2, padded=True, multiple=self.head_dim_padded + )[None, head_idx] + blkdQaccum_shape = (self.tile_m * self.head_dim_padded,) gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,)) gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) - zero = cute.make_fragment_like(tdQgdQaccum) + zero = cute.make_rmem_tensor_like(tdQgdQaccum) zero.fill(0.0) cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum) - if cutlass.const_expr(mLSE is not None): - if cutlass.const_expr(not seqlen.has_cu_seqlens_q): - mLSElog2_cur = mLSElog2[batch_idx, head_idx, None] - else: - mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[head_idx, None]) - - gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,)) + if const_expr(mLSE is not None): + mLSElog2_cur = seqlen.offset_batch(mLSElog2, batch_idx, dim=2, padded=True)[ + None, head_idx + ] + gLSElog2 = cute.local_tile(mLSElog2_cur, (self.tile_m,), (m_block,)) LOG2_E = math.log2(math.e) - if tidx < seqlen_q_rounded - m_block * self.m_block_size: + if tidx < seqlen_q_rounded - m_block * self.tile_m: gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0 diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 430bcf4f6c2..e06cd811fc6 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -8,11 +8,11 @@ import cutlass import cutlass.cute as cute from cutlass.cute import FastDivmodDivisor -from cutlass import Float32, Int32, const_expr +from cutlass import Float32, Int32, Int64, const_expr from cutlass.utils import LayoutEnum from cutlass.cute.nvgpu import cpasync, tcgen05 import cutlass.utils.blackwell_helpers as sm100_utils_basic -from cutlass.pipeline import PipelineAsync, PipelineConsumer +from cutlass.pipeline import PipelineAsync import quack.activation from quack import layout_utils @@ -24,12 +24,12 @@ from flash_attn.cute.mask import AttentionMask from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo +from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, # noqa SingleTileVarlenScheduler, - ParamsBase, ) from flash_attn.cute import barrier @@ -59,6 +59,7 @@ def __init__( is_persistent: bool = False, deterministic: bool = False, cluster_size: int = 1, + use_2cta_instrs: bool = False, score_mod: cutlass.Constexpr | None = None, score_mod_bwd: cutlass.Constexpr | None = None, mask_mod: cutlass.Constexpr | None = None, @@ -70,29 +71,40 @@ def __init__( self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) head_dim_v = head_dim_v if head_dim_v is not None else head_dim self.same_hdim_kv = head_dim == head_dim_v - assert head_dim == head_dim_v, "head_dim and head_dim_v must be the same for now" self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) - assert self.tile_hdim == self.tile_hdimv, ( - "tile_hdim and tile_hdimv must be the same for now" - ) self.check_hdim_oob = head_dim != self.tile_hdim self.check_hdim_v_oob = head_dim_v != self.tile_hdimv self.tile_m = tile_m self.tile_n = tile_n + assert self.tile_hdim <= 128 or (self.tile_hdim == 192 and self.tile_hdimv == 128) + assert self.tile_hdimv <= 128 + + self.use_2cta_instrs = bool( + use_2cta_instrs + and cluster_size == 2 + and score_mod is None + and score_mod_bwd is None + and mask_mod is None + ) + self.cta_group_size = 2 if self.use_2cta_instrs else 1 + + assert self.tile_hdim != 192 or self.use_2cta_instrs, "Must use 2CTA for hdim 192" + # CTA tiler self.cta_tiler = (tile_n, tile_m, self.tile_hdim) # S = K @ Q.T - self.mma_tiler_kq = (tile_n, tile_m, self.tile_hdim) + self.mma_tiler_kq = (self.cta_group_size * tile_n, tile_m, self.tile_hdim) # dP = V @ dO.T - self.mma_tiler_vdo = (tile_n, tile_m, self.tile_hdimv) + self.mma_tiler_vdo = (self.cta_group_size * tile_n, tile_m, self.tile_hdimv) # dV = P.T @ dO - self.mma_tiler_pdo = (tile_n, self.tile_hdimv, tile_m) - # dK = dS.T @ Q (N, M) (M, D) - self.mma_tiler_dsq = (tile_n, self.tile_hdimv, tile_m) + self.mma_tiler_pdo = (self.cta_group_size * tile_n, self.tile_hdimv, tile_m) + # dK = dS.T @ Q + self.mma_tiler_dsq = (self.cta_group_size * tile_n, self.tile_hdim, tile_m) # dQ = dS @ K - self.mma_tiler_dsk = (tile_m, self.tile_hdimv, tile_n) + # 2-CTA: reduction dim is cluster-wide (tile_n * cta_group_size). + self.mma_tiler_dsk = (tile_m, self.tile_hdim, tile_n * self.cta_group_size) self.acc_dtype = Float32 @@ -121,13 +133,14 @@ def __init__( # Speed optimizations, does not affect correctness self.shuffle_LSE = False self.shuffle_dPsum = False - self.use_smem_dS_for_mma_dK = self.deterministic and self.is_causal + # Generally slower to use store dS in smem for dK, and doesn't work for 2cta + self.use_smem_dS_for_mma_dK = False self.reduce_warp_ids = (0, 1, 2, 3) self.compute_warp_ids = (4, 5, 6, 7, 8, 9, 10, 11) self.mma_warp_id = 12 self.load_warp_id = 13 - self.epi_warp_id = 14 + self.relay_warp_id = 14 self.empty_warp_id = 15 # 16 warps -> 512 threads @@ -137,11 +150,10 @@ def __init__( *self.compute_warp_ids, self.mma_warp_id, self.load_warp_id, - self.epi_warp_id, + self.relay_warp_id, self.empty_warp_id, ) ) - # NamedBarrier self.compute_sync_barrier = cutlass.pipeline.NamedBarrier( barrier_id=int(NamedBarrierBwdSm100.Compute), @@ -155,11 +167,8 @@ def __init__( barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), num_threads=len(self.reduce_warp_ids) * cute.arch.WARP_SIZE, ) - # TMEM setup - SM100_TMEM_CAPACITY_COLUMNS = 512 - self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS - + self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") # self.tmem_dK_offset = 0 # self.tmem_dV_offset = self.tmem_dK_offset + self.tile_hdim # self.tmem_dQ_offset = self.tmem_dV_offset + self.tile_hdimv @@ -169,67 +178,117 @@ def __init__( # self.tmem_total = self.tmem_S_offset + self.tile_n # assert self.tmem_total <= self.tmem_alloc_cols - self.tmem_S_offset = 0 - self.tmem_P_offset = 0 # overlap with S - self.tmem_dV_offset = self.tmem_S_offset + self.tile_n - self.tmem_dP_offset = self.tmem_dV_offset + self.tile_hdimv - self.tmem_dQ_offset = self.tmem_dP_offset # overlap with dP - self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m - self.tmem_dS_offset = self.tmem_dP_offset # overlap with dP + if self.use_2cta_instrs and self.tile_hdim == 192 and self.tile_hdimv == 128: + assert self.tile_m == 128 + assert self.tile_n == 128 + self.tmem_dV_offset = 0 + self.tmem_dK_offset = self.tmem_dV_offset + self.tile_hdimv + self.tmem_S_offset = self.tmem_dK_offset + self.tile_hdim + self.tmem_P_offset = self.tmem_S_offset # overlap with S + self.tmem_dP_offset = 512 - self.tile_m + self.tmem_dS_offset = self.tmem_dP_offset # overlaps with dP + self.tmem_dQ_offset = 512 - self.tile_hdim // 2 + else: + self.tmem_S_offset = 0 + self.tmem_P_offset = 0 # overlap with S + self.tmem_dV_offset = self.tmem_S_offset + self.tile_n + self.tmem_dP_offset = self.tmem_dV_offset + self.tile_hdimv + self.tmem_dQ_offset = ( + (self.tmem_S_offset + (self.tile_hdim // 2)) + if self.use_2cta_instrs + else self.tmem_dP_offset + ) + self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m + self.tmem_dS_offset = self.tmem_dP_offset # overlap with dP if (not is_causal and not is_local) or deterministic: - self.num_regs_reduce = 152 + self.num_regs_reduce = 136 if self.use_2cta_instrs else 152 self.num_regs_compute = 136 + self.num_regs_load = 104 if self.use_2cta_instrs else 96 - 8 + self.num_regs_mma = 104 if self.use_2cta_instrs else self.num_regs_load else: - self.num_regs_reduce = 136 - self.num_regs_compute = 144 - self.num_regs_other = 96 - 8 + self.num_regs_reduce = 136 if self.use_2cta_instrs else 136 + self.num_regs_compute = 136 if self.use_2cta_instrs else 144 + self.num_regs_load = 104 if self.use_2cta_instrs else 96 - 8 + self.num_regs_mma = 104 if self.use_2cta_instrs else self.num_regs_load self.num_regs_empty = 24 - assert self.num_regs_reduce + self.num_regs_compute * 2 + self.num_regs_other <= 512 + if const_expr(self.tile_hdim == 192): + if not is_causal and not is_local: + self.num_regs_reduce = 128 + 8 + self.num_regs_compute = 128 + 8 + self.num_regs_load = 128 - 24 + self.num_regs_mma = self.num_regs_load + else: + self.num_regs_reduce = 128 + 8 + self.num_regs_compute = 128 + 8 + self.num_regs_load = 128 - 24 + self.num_regs_mma = self.num_regs_load + + assert ( + self.num_regs_reduce + + self.num_regs_compute * 2 + + max(self.num_regs_load, self.num_regs_mma) + <= 512 + ) self.buffer_align_bytes = 1024 def _setup_attributes(self): - self.Q_stage = 2 + self.Q_stage = 1 if self.use_2cta_instrs else 2 self.dO_stage = 1 + self.single_stage = 1 # LSE_stage = Q_stage and dPsum_stage = dO_stage - # self.sdKVaccum_stage = 2 + self.sdKVaccum_stage = 2 # number of tma reduce adds per dQacc mma - self.dQ_reduce_ncol = 32 - self.sdQaccum_stage = 64 // self.dQ_reduce_ncol - assert self.tile_hdim % self.dQ_reduce_ncol == 0 + # todo: try 32/1 or 48/2 for 2cta d=192 dv=128 + if self.use_2cta_instrs and self.tile_hdim == 192: + self.dQ_reduce_ncol_t2r = 32 + self.dQ_reduce_ncol = 24 if not self.is_causal else 32 + self.sdQaccum_stage = 2 if not self.is_causal else 1 + else: + if self.use_2cta_instrs: + self.dQ_reduce_ncol = 16 if self.deterministic else 8 + self.sdQaccum_stage = 2 if self.deterministic else 4 + self.dQ_reduce_ncol_t2r = 32 + else: + self.dQ_reduce_ncol = 32 + self.sdQaccum_stage = 64 // self.dQ_reduce_ncol + self.dQ_reduce_ncol_t2r = self.dQ_reduce_ncol + assert (self.tile_hdim // self.cta_group_size) % self.dQ_reduce_ncol == 0 self.dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol + self.dQaccum_reduce_stage_t2r = self.tile_hdim // self.dQ_reduce_ncol_t2r self.cluster_reduce_dQ = False and cute.size(self.cluster_shape_mn) > 1 - # number of tma reduce adds for dKacc and dVacc epilogue - self.dK_reduce_ncol = 32 + # number of tma reduce adds for dKacc and dVacc epilogue (must divide hdim_per_wg) + self.dK_reduce_ncol = math.gcd(32, self.tile_hdim // 2) + # CTA group for MMA operations + self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE def _get_tiled_mma(self): - cta_group = tcgen05.CtaGroup.ONE - # S = K @ Q.T + # S.T = K @ Q.T tiled_mma_S = sm100_utils_basic.make_trivial_tiled_mma( self.q_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, self.acc_dtype, - cta_group, + self.cta_group, self.mma_tiler_kq[:2], ) - # dP = V @ dO.T + # dP.T = V @ dO.T tiled_mma_dP = sm100_utils_basic.make_trivial_tiled_mma( self.do_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, self.acc_dtype, - cta_group, + self.cta_group, self.mma_tiler_vdo[:2], ) - # dV += P @ dO --> (K, MN) major + # dV += P.T @ dO --> (K, MN) major tiled_mma_dV = sm100_utils_basic.make_trivial_tiled_mma( self.do_dtype, tcgen05.OperandMajorMode.K, # P_major_mode tcgen05.OperandMajorMode.MN, # dO_major_mode self.acc_dtype, - cta_group, + self.cta_group, self.mma_tiler_pdo[:2], a_source=tcgen05.OperandSource.TMEM, ) @@ -243,7 +302,7 @@ def _get_tiled_mma(self): tcgen05.OperandMajorMode.K, # dS_major_mode tcgen05.OperandMajorMode.MN, # Q_major_mode self.acc_dtype, - cta_group, + self.cta_group, self.mma_tiler_dsq[:2], a_source=mma_dK_a_src, ) @@ -253,13 +312,13 @@ def _get_tiled_mma(self): tcgen05.OperandMajorMode.MN, # dS_major_mode tcgen05.OperandMajorMode.MN, # Kt_major_mode self.acc_dtype, - cta_group, + self.cta_group, self.mma_tiler_dsk[:2], ) return tiled_mma_S, tiled_mma_dP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ def _setup_smem_layout(self): - # S = K @ Q.T + # S.T = K @ Q.T sK_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_S, self.mma_tiler_kq, @@ -273,7 +332,7 @@ def _setup_smem_layout(self): self.q_dtype, self.Q_stage, ) - # dP = V @ dO.T + # dP.T = V @ dO.T sV_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dP, self.mma_tiler_vdo, @@ -287,7 +346,7 @@ def _setup_smem_layout(self): self.do_dtype, self.dO_stage, ) - # dV += P @ dO + # dV += P.T @ dO tP_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dV, self.mma_tiler_pdo, @@ -337,34 +396,48 @@ def _setup_smem_layout(self): 1, ) self.sKt_layout = cute.slice_(sKt_layout, (None, None, None, 0)) + self.sdS_xchg_layout = cute.make_layout(shape=(self.tile_n, self.tile_m // 2)) + self.sdQaccum_layout = cute.make_layout( (self.tile_m * self.dQ_reduce_ncol, self.sdQaccum_stage) ) self.sLSE_layout = cute.make_layout( - shape=(self.tile_m, self.Q_stage), - stride=(1, cute.round_up(self.tile_m, 64)), + shape=(self.tile_m, self.Q_stage), stride=(1, cute.round_up(self.tile_m, 64)) ) self.sdPsum_layout = cute.make_layout( shape=(self.tile_m, self.dO_stage), stride=(1, cute.round_up(self.tile_m, 64)), ) - self.sdKV_epi_tile = ( + self.sdK_epi_tile = ( self.tile_n, - min(128 // (self.dk_dtype.width // 8), self.tile_hdim // 2), # 64 or 32 + math.gcd(128 // (self.dk_dtype.width // 8), self.tile_hdim // 2), # 64 or 32 + ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] + self.sdV_epi_tile = ( + self.tile_n, + math.gcd(128 // (self.dk_dtype.width // 8), self.tile_hdimv // 2), # 64 or 32 ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] # headdim_64 gets 1 stage - self.num_epi_stages = max(1, (self.tile_hdim // 2) // self.sdKV_epi_tile[1]) - self.sdKV_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages - # TODO: dK and dV could have different shapes + self.num_epi_stages = max(1, (self.tile_hdim // 2) // self.sdK_epi_tile[1]) + self.num_epi_stages_v = max(1, (self.tile_hdimv // 2) // self.sdV_epi_tile[1]) + self.sdK_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages + self.sdV_flat_epi_tile = self.tile_n * (self.tile_hdimv // 2) // self.num_epi_stages_v if const_expr(not self.dKV_postprocess): - self.sdKV_layout = sm100_utils_basic.make_smem_layout_epi( + self.sdK_layout = sm100_utils_basic.make_smem_layout_epi( self.dk_dtype, LayoutEnum.ROW_MAJOR, - self.sdKV_epi_tile, + self.sdK_epi_tile, + 2, # num compute wgs + ) + self.sdV_layout = sm100_utils_basic.make_smem_layout_epi( + self.dv_dtype, + LayoutEnum.ROW_MAJOR, + self.sdV_epi_tile, 2, # num compute wgs ) else: - self.sdKV_layout = cute.make_layout((self.tile_n * self.dK_reduce_ncol, 2)) + self.sdK_layout = cute.make_layout((self.tile_n * self.dK_reduce_ncol, 2)) + # self.dK_reduce_ncol same for dV + self.sdV_layout = cute.make_layout((self.tile_n * self.dK_reduce_ncol, 2)) @cute.jit def __call__( @@ -379,7 +452,6 @@ def __call__( mdK: cute.Tensor, mdV: cute.Tensor, softmax_scale: Float32, - stream: cuda.CUstream, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, @@ -393,6 +465,8 @@ def __call__( aux_tensors: Optional[list] = None, # Block-sparse tensors (Q direction - for iterating m_blocks per n_block): blocksparse_tensors: Optional[BlockSparseTensors] = None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): self.q_dtype = mQ.element_type self.k_dtype = mK.element_type @@ -408,6 +482,7 @@ def __call__( self.is_varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None self.use_tma_store = not (self.qhead_per_kvhead == 1 and mCuSeqlensK is not None) + # self.use_tma_store = not self.qhead_per_kvhead == 1 self.dKV_postprocess = self.qhead_per_kvhead > 1 if const_expr(self.dKV_postprocess): @@ -439,6 +514,10 @@ def __call__( dO_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensQ is None) else [1, 0, 2] mdO = layout_utils.select(mdO, mode=dO_transpose) + # Transposes for 2-CTA K/Q paths (Q follows Q seqlens, K follows K seqlens) + transpose_sh_q = dO_transpose + transpose_sh_k = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2] + # (b, n, block, stage) -> (block, stage, n, b) semaphore_transpose = [2, 3, 1, 0] if const_expr(self.deterministic): @@ -466,8 +545,6 @@ def __call__( ) = self._get_tiled_mma() self._setup_smem_layout() - cta_group = tcgen05.CtaGroup.ONE - self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) self.cluster_layout_vmnk = cute.tiled_divide( cute.make_layout(self.cluster_shape_mnk), @@ -491,15 +568,15 @@ def __call__( tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom( tma_copy_op_dKV, mdK, - cute.select(self.sdKV_layout, mode=[0, 1]), - self.sdKV_epi_tile, + cute.select(self.sdK_layout, mode=[0, 1]), + self.sdK_epi_tile, 1, # no mcast ) tma_atom_dV, mdV_tma_tensor = cpasync.make_tiled_tma_atom( tma_copy_op_dKV, mdV, - cute.select(self.sdKV_layout, mode=[0, 1]), - self.sdKV_epi_tile, + cute.select(self.sdV_layout, mode=[0, 1]), + self.sdV_epi_tile, 1, # no mcast ) else: @@ -526,9 +603,7 @@ def __call__( Float32, 128, num_copy_elems=128 // Float32.width ) - tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) - tma_load_op_multicast = cpasync.CopyBulkTensorTileG2SMulticastOp(cta_group) - + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(self.cta_group) # S.T = K @ Q.T tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_A( tma_load_op, @@ -542,7 +617,6 @@ def __call__( self.cluster_shape_mnk, self.tiled_mma_S.thr_id ) tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B( - # tma_load_op if const_expr(self.cluster_shape_mnk[0] == 1) else tma_load_op_multicast, Q_tma_op, mQ, cute.select(self.sQ_layout, mode=[0, 1, 2]), @@ -559,11 +633,11 @@ def __call__( self.tiled_mma_dP, self.cluster_layout_vmnk.shape, ) + # dV = P.T @ dO dO_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B( self.cluster_shape_mnk, self.tiled_mma_dV.thr_id ) tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B( - # tma_load_op if const_expr(self.cluster_shape_mnk[0] == 1) else tma_load_op_multicast, dO_tma_op, mdO, cute.select(self.sdO_layout, mode=[0, 1, 2]), @@ -571,9 +645,46 @@ def __call__( self.tiled_mma_dV, self.cluster_layout_vmnk.shape, ) + # ------------------------------------------------------------ + # 2-CTA + # ------------------------------------------------------------ + tma_atom_dOt = tma_tensor_dOt = None + if const_expr(self.use_2cta_instrs): + tma_atom_dOt, tma_tensor_dOt = cute.nvgpu.make_tiled_tma_atom_B( + dO_tma_op, + layout_utils.select(mdO, mode=transpose_sh_q), + cute.select(self.sdOt_layout, mode=[0, 1, 2]), + self.mma_tiler_vdo, + self.tiled_mma_dP, + self.cluster_layout_vmnk.shape, + ) + tma_atom_Qt = tma_tensor_Qt = None + if const_expr(self.use_2cta_instrs): + tma_atom_Qt, tma_tensor_Qt = cute.nvgpu.make_tiled_tma_atom_B( + Q_tma_op, + layout_utils.select(mQ, mode=transpose_sh_q), + cute.select(self.sQt_layout, mode=[0, 1, 2]), + self.mma_tiler_dsq, + self.tiled_mma_dK, + self.cluster_layout_vmnk.shape, + ) + tma_atom_Kt = tma_tensor_Kt = None + if const_expr(self.use_2cta_instrs): + Kt_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B( + self.cluster_shape_mnk, self.tiled_mma_dQ.thr_id + ) + tma_atom_Kt, tma_tensor_Kt = cute.nvgpu.make_tiled_tma_atom_B( + Kt_tma_op, + layout_utils.select(mK, mode=transpose_sh_k), + cute.select(self.sKt_layout, mode=[0, 1, 2]), + self.mma_tiler_dsk, + self.tiled_mma_dQ, + self.cluster_layout_vmnk.shape, + ) self.tma_copy_bytes = { - name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) + name: self.cta_group_size + * cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) for name, mX, layout in [ ("Q", mQ, self.sQ_layout), ("K", mK, self.sK_layout), @@ -585,6 +696,8 @@ def __call__( self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8 self.tma_copy_bytes["dQ"] = self.tile_m * self.dQ_reduce_ncol * Float32.width // 8 self.tma_copy_bytes["dKacc"] = self.tile_n * self.dK_reduce_ncol * Float32.width // 8 + self.tma_copy_bytes["dS"] = cute.size_in_bytes(self.ds_dtype, self.sdS_layout) + self.tma_copy_bytes["sdS_xchg"] = self.tma_copy_bytes["dS"] // 2 # Half of dS for exchange # TileScheduler = SingleTileScheduler if const_expr(self.is_varlen_k): @@ -593,7 +706,6 @@ def __call__( TileScheduler = SingleTileLPTBwdScheduler else: TileScheduler = SingleTileScheduler - # reads n_blocks right-to-left self.spt = (self.is_causal or self.is_local) and self.deterministic tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), # num_blocks @@ -622,80 +734,167 @@ def __call__( tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) self.tile_scheduler_cls = TileScheduler grid_dim = TileScheduler.get_grid_shape(tile_sched_params) - # cute.printf("grid_dim = {}", grid_dim) # Compute allocation sizes for shared buffers that are reused # sQ is reused for sdK, sdO is reused for sdV sQ_alloc_bytes = max( cute.size_in_bytes(self.q_dtype, self.sQ_layout), - cute.size_in_bytes(self.dk_dtype, self.sdKV_layout), + cute.size_in_bytes(self.dk_dtype, self.sdK_layout), ) sdO_alloc_bytes = max( - cute.size_in_bytes(self.dv_dtype, self.sdKV_layout), + cute.size_in_bytes(self.dv_dtype, self.sdV_layout), cute.size_in_bytes(self.do_dtype, self.sdO_layout), ) - # Sanity check that layouts fit in allocation - sdV_bytes = cute.size_in_bytes(self.dv_dtype, self.sdKV_layout) - sdK_bytes = cute.size_in_bytes(self.dk_dtype, self.sdKV_layout) + + sdK_bytes = cute.size_in_bytes(self.dk_dtype, self.sdK_layout) + sdV_bytes = cute.size_in_bytes(self.dv_dtype, self.sdV_layout) assert sdV_bytes <= sdO_alloc_bytes, "sdV doesn't fit in sdO storage allocation" assert sdK_bytes <= sQ_alloc_bytes, "sdK doesn't fit in sQ storage allocation" + # 2-CTA: sdV reuses sV, sdK reuses sK + sV_bytes = cute.size_in_bytes(self.v_dtype, self.sV_layout) + sK_bytes = cute.size_in_bytes(self.k_dtype, self.sK_layout) + if const_expr(self.use_2cta_instrs): + assert sdV_bytes <= sV_bytes, "sdV doesn't fit in sV storage allocation (2-CTA)" + assert sdK_bytes <= sK_bytes, "sdK doesn't fit in sK storage allocation (2-CTA)" + + if const_expr(self.use_2cta_instrs): + sQt_size = cute.cosize(self.sQt_layout) if const_expr(self.tile_hdim <= 128) else 0 + sdOt_size = cute.cosize(self.sdOt_layout) if const_expr(self.tile_hdim <= 128) else 0 + sdS_xchg_size = ( + cute.cosize(self.sdS_xchg_layout) if const_expr(self.tile_hdim <= 128) else 0 + ) - @cute.struct - class SharedStorage: - Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] - dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] - LSE_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] - dPsum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] - S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] - dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] - dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] - dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 2] - dQ_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] - dQ_cluster_full_mbar_ptr: cute.struct.MemRange[ - cutlass.Int64, self.dQaccum_reduce_stage // 2 - ] - dQ_cluster_empty_mbar_ptr: cute.struct.MemRange[ - cutlass.Int64, self.dQaccum_reduce_stage // 2 - ] - tmem_holding_buf: Int32 - tmem_dealloc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1] + @cute.struct + class SharedStorage: + Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] + dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] + LSE_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] + dPsum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] + S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage] + dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage] + dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage] + dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.sdKVaccum_stage] + dQ_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + dQ_cluster_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.dQaccum_reduce_stage // 2 + ] + dQ_cluster_empty_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.dQaccum_reduce_stage // 2 + ] + tmem_holding_buf: Int32 + tmem_dealloc_mbar_ptr: cutlass.Int64 + + # 2-CTA + Qt_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] + Kt_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage] + dS_cluster_empty_mbar_ptr: cutlass.Int64 + dS_cluster_full_mbar_ptr: cutlass.Int64 + dS_cluster_leader_mbar_ptr: cutlass.Int64 + dQaccum_empty_mbar_ptr: cutlass.Int64 + + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(self.sQ_layout)], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(self.sK_layout)], + self.buffer_align_bytes, + ] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(self.sV_layout)], + self.buffer_align_bytes, + ] + sdO: cute.struct.Align[ + cute.struct.MemRange[self.do_dtype, cute.cosize(self.sdO_layout)], + self.buffer_align_bytes, + ] + sQt: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, sQt_size], + self.buffer_align_bytes, + ] + sdOt: cute.struct.Align[ + cute.struct.MemRange[self.do_dtype, sdOt_size], + self.buffer_align_bytes, + ] + sdS_xchg: cute.struct.Align[ + cute.struct.MemRange[self.ds_dtype, sdS_xchg_size], + self.buffer_align_bytes, + ] + sKt: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(self.sKt_layout)], + self.buffer_align_bytes, + ] + sdS: cute.struct.Align[ + cute.struct.MemRange[self.ds_dtype, cute.cosize(self.sdSt_layout)], + self.buffer_align_bytes, + ] + sLSE: cute.struct.Align[ + cute.struct.MemRange[self.lse_dtype, cute.cosize(self.sLSE_layout)], + 128, + ] + sdPsum: cute.struct.Align[ + cute.struct.MemRange[self.dpsum_dtype, cute.cosize(self.sdPsum_layout)], + 128, + ] + sdQaccum: cute.struct.Align[ + cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(self.sdQaccum_layout)], + self.buffer_align_bytes if sdS_xchg_size == 0 else 128, + ] - # Smem tensors + else: - # sQ is reused for sdK which in the non-MHA case needs float32 - sQ: cute.struct.Align[ - cute.struct.MemRange[cute.Uint8, sQ_alloc_bytes], - self.buffer_align_bytes, - ] - sK: cute.struct.Align[ - cute.struct.MemRange[self.k_dtype, cute.cosize(self.sK_layout)], - self.buffer_align_bytes, - ] - sV: cute.struct.Align[ - cute.struct.MemRange[self.v_dtype, cute.cosize(self.sV_layout)], - self.buffer_align_bytes, - ] - # sdO is reused for sdV which in the non-MHA case needs float32 - sdO: cute.struct.Align[ - cute.struct.MemRange[cute.Uint8, sdO_alloc_bytes], - self.buffer_align_bytes, - ] - sdS: cute.struct.Align[ - cute.struct.MemRange[self.ds_dtype, cute.cosize(self.sdSt_layout)], - 128, - ] - sLSE: cute.struct.Align[ - cute.struct.MemRange[self.lse_dtype, cute.cosize(self.sLSE_layout)], - 128, - ] - sdPsum: cute.struct.Align[ - cute.struct.MemRange[self.dpsum_dtype, cute.cosize(self.sdPsum_layout)], - 128, - ] - sdQaccum: cute.struct.Align[ - cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(self.sdQaccum_layout)], - self.buffer_align_bytes, - ] + @cute.struct + class SharedStorage: + Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] + dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] + LSE_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] + dPsum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] + S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage] + dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage] + dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage] + dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.sdKVaccum_stage] + dQ_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + dQ_cluster_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.dQaccum_reduce_stage // 2 + ] + dQ_cluster_empty_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.dQaccum_reduce_stage // 2 + ] + tmem_holding_buf: Int32 + tmem_dealloc_mbar_ptr: Int64 + + sQ: cute.struct.Align[ + cute.struct.MemRange[cute.Uint8, sQ_alloc_bytes], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(self.sK_layout)], + self.buffer_align_bytes, + ] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(self.sV_layout)], + self.buffer_align_bytes, + ] + sdO: cute.struct.Align[ + cute.struct.MemRange[cute.Uint8, sdO_alloc_bytes], + self.buffer_align_bytes, + ] + sdS: cute.struct.Align[ + cute.struct.MemRange[self.ds_dtype, cute.cosize(self.sdSt_layout)], + 128, + ] + sLSE: cute.struct.Align[ + cute.struct.MemRange[self.lse_dtype, cute.cosize(self.sLSE_layout)], + 128, + ] + sdPsum: cute.struct.Align[ + cute.struct.MemRange[self.dpsum_dtype, cute.cosize(self.sdPsum_layout)], + 128, + ] + sdQaccum: cute.struct.Align[ + cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(self.sdQaccum_layout)], + self.buffer_align_bytes, + ] self.shared_storage = SharedStorage @@ -723,6 +922,13 @@ class SharedStorage: fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + if const_expr(self.use_2cta_instrs): + assert blocksparse_tensors is None, ( + "2-CTA mode does not support block sparsity. " + "Please create kernel with use_2cta_instrs=False for block sparse attention." + ) + # 2-CTA: 231424 and 1-CTA: 232448 + # print("SMEM: ", self.shared_storage.size_in_bytes()) if const_expr(self.use_block_sparsity or aux_tensors is not None): assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), ( "Variable sequence length is not supported yet for blocksparse or aux tensors in bwd" @@ -730,11 +936,14 @@ class SharedStorage: self.kernel( tma_tensor_Q, + tma_tensor_Qt, tma_tensor_K, + tma_tensor_Kt, tma_tensor_V, mLSE, mdPsum, tma_tensor_dO, + tma_tensor_dOt, mdV, mdK, mdQaccum, @@ -748,14 +957,18 @@ class SharedStorage: mSeqUsedQ, mSeqUsedK, tma_atom_Q, + tma_atom_Qt, tma_atom_K, + tma_atom_Kt, tma_atom_V, tma_atom_dO, + tma_atom_dOt, tma_atom_dV, tma_atom_dK, self.sQ_layout, self.sQt_layout, self.sK_layout, + self.sKt_layout, self.sV_layout, self.sLSE_layout, self.sdPsum_layout, @@ -763,9 +976,10 @@ class SharedStorage: self.sdOt_layout, self.sdSt_layout, self.sdS_layout, - self.sKt_layout, + self.sdS_xchg_layout, self.sdQaccum_layout, - self.sdKV_layout, + self.sdK_layout, + self.sdV_layout, self.tP_layout, self.tdS_layout, self.tiled_mma_S, @@ -795,11 +1009,14 @@ class SharedStorage: def kernel( self, mQ: cute.Tensor, + mQt: Optional[cute.Tensor], mK: cute.Tensor, + mKt: Optional[cute.Tensor], mV: cute.Tensor, mLSE: cute.Tensor, mdPsum: cute.Tensor, mdO: cute.Tensor, + mdOt: Optional[cute.Tensor], mdV: cute.Tensor, mdK: cute.Tensor, mdQaccum: cute.Tensor, @@ -813,14 +1030,18 @@ def kernel( mSeqUsedQ: Optional[cute.Tensor], mSeqUsedK: Optional[cute.Tensor], tma_atom_Q: cute.CopyAtom, + tma_atom_Qt: Optional[cute.CopyAtom], tma_atom_K: cute.CopyAtom, + tma_atom_Kt: Optional[cute.CopyAtom], tma_atom_V: cute.CopyAtom, tma_atom_dO: cute.CopyAtom, + tma_atom_dOt: Optional[cute.CopyAtom], tma_atom_dV: Optional[cute.CopyAtom], tma_atom_dK: Optional[cute.CopyAtom], sQ_layout: cute.ComposedLayout, sQt_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, + sKt_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, sLSE_layout: cute.Layout, sdPsum_layout: cute.Layout, @@ -828,9 +1049,10 @@ def kernel( sdOt_layout: cute.ComposedLayout, sdSt_layout: cute.ComposedLayout, sdS_layout: cute.ComposedLayout, - sKt_layout: cute.ComposedLayout, + sdS_xchg_layout: cute.Layout, sdQaccum_layout: cute.Layout, - sdKV_layout: cute.ComposedLayout | cute.Layout, + sdK_layout: cute.ComposedLayout | cute.Layout, + sdV_layout: cute.ComposedLayout | cute.Layout, tP_layout: cute.ComposedLayout, tdS_layout: cute.ComposedLayout, tiled_mma_S: cute.TiledMma, @@ -849,13 +1071,23 @@ def kernel( blocksparse_tensors: Optional[BlockSparseTensors] = None, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + bidx, _, _ = cute.arch.block_idx() + mma_tile_coord_v = bidx % self.cta_group_size + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) # Prefetch tma descriptor if warp_idx == self.load_warp_id: with cute.arch.elect_one(): cpasync.prefetch_descriptor(tma_atom_Q) + if const_expr(tma_atom_Qt is not None): + cpasync.prefetch_descriptor(tma_atom_Qt) cpasync.prefetch_descriptor(tma_atom_K) + if const_expr(tma_atom_Kt is not None): + cpasync.prefetch_descriptor(tma_atom_Kt) cpasync.prefetch_descriptor(tma_atom_V) + if const_expr(tma_atom_dOt is not None): + cpasync.prefetch_descriptor(tma_atom_dOt) cpasync.prefetch_descriptor(tma_atom_dO) if const_expr(tma_atom_dV is not None): cpasync.prefetch_descriptor(tma_atom_dV) @@ -871,61 +1103,97 @@ def kernel( smem = cutlass.utils.SmemAllocator() storage = smem.allocate(self.shared_storage) - tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() dQ_cluster_full_mbar_ptr = storage.dQ_cluster_full_mbar_ptr.data_ptr() dQ_cluster_empty_mbar_ptr = storage.dQ_cluster_empty_mbar_ptr.data_ptr() - if warp_idx == 1: - cute.arch.mbarrier_init( - tmem_dealloc_mbar_ptr, cute.arch.WARP_SIZE * len(self.compute_warp_ids) - ) + if const_expr(self.use_2cta_instrs): + dS_cluster_full_mbar_ptr = storage.dS_cluster_full_mbar_ptr + dS_cluster_empty_mbar_ptr = storage.dS_cluster_empty_mbar_ptr + dS_cluster_leader_mbar_ptr = storage.dS_cluster_leader_mbar_ptr + dQaccum_empty_mbar_ptr = storage.dQaccum_empty_mbar_ptr + else: + dS_cluster_full_mbar_ptr = None + dS_cluster_empty_mbar_ptr = None + dS_cluster_leader_mbar_ptr = None + dQaccum_empty_mbar_ptr = None + + # Barrier initialization + if const_expr(self.use_2cta_instrs): + if const_expr(self.tile_hdim == 192): + if warp_idx == 2: + cute.arch.mbarrier_init( + dQaccum_empty_mbar_ptr, + len(self.reduce_warp_ids), + ) + if warp_idx == 4: + cute.arch.mbarrier_init(dS_cluster_full_mbar_ptr, 1) + cute.arch.mbarrier_init(dS_cluster_empty_mbar_ptr, 1) + cute.arch.mbarrier_init(dS_cluster_leader_mbar_ptr, 2) + if const_expr(self.cluster_reduce_dQ): if warp_idx == 4: for i in range(self.dQaccum_reduce_stage // 2): cute.arch.mbarrier_init(dQ_cluster_full_mbar_ptr + i, 1) cute.arch.mbarrier_init(dQ_cluster_empty_mbar_ptr + i, 1) + tmem_alloc_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierBwdSm100.TmemPtr), + num_threads=cute.arch.WARP_SIZE + * len((self.mma_warp_id, *self.compute_warp_ids, *self.reduce_warp_ids)), + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + is_two_cta=self.use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + # UMMA producers and AsyncThread consumers pipeline_producer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) ) - # Only 1 thread per warp will signal pipeline_consumer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup( - cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) + cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) * self.cta_group_size ) pipeline_S_P = cutlass.pipeline.PipelineUmmaAsync.create( num_stages=1, producer_group=pipeline_producer_group_MMA_AsyncThread, consumer_group=pipeline_consumer_group_MMA_AsyncThread, barrier_storage=storage.S_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, ) pipeline_dP = cutlass.pipeline.PipelineUmmaAsync.create( num_stages=1, producer_group=pipeline_producer_group_MMA_AsyncThread, consumer_group=pipeline_consumer_group_MMA_AsyncThread, barrier_storage=storage.dP_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, ) pipeline_dKV = cutlass.pipeline.PipelineUmmaAsync.create( num_stages=2, producer_group=pipeline_producer_group_MMA_AsyncThread, consumer_group=pipeline_consumer_group_MMA_AsyncThread, barrier_storage=storage.dKV_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, ) pipeline_consumer_group_MMA_AsyncThread_dQ = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, - len(self.reduce_warp_ids), + len(self.reduce_warp_ids) * self.cta_group_size, ) # Compute pipeline_dQ = cutlass.pipeline.PipelineUmmaAsync.create( num_stages=1, producer_group=pipeline_producer_group_MMA_AsyncThread, consumer_group=pipeline_consumer_group_MMA_AsyncThread_dQ, barrier_storage=storage.dQ_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, ) # AsyncThread producers and UMMA consumers # Only 1 thread per warp will signal pipeline_PdS_producer_group = cutlass.pipeline.CooperativeGroup( - cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) + cutlass.pipeline.Agent.Thread, + len(self.compute_warp_ids) * self.cta_group_size, ) # Compute pipeline_PdS_consumer_group = cutlass.pipeline.CooperativeGroup( cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) @@ -935,6 +1203,7 @@ def kernel( producer_group=pipeline_PdS_producer_group, consumer_group=pipeline_PdS_consumer_group, barrier_storage=storage.dS_mbar_ptr.data_ptr(), + cta_layout_vmnk=cluster_layout_vmnk, ) # TMA producer and UMMA consumers @@ -946,7 +1215,6 @@ def kernel( cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) * self.num_mcast_ctas_b ) pipeline_consumer_group_compute = cutlass.pipeline.CooperativeGroup( - # cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) * self.num_mcast_ctas_b cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) * 1, ) @@ -977,6 +1245,32 @@ def kernel( cta_layout_vmnk=cluster_layout_vmnk, defer_sync=True, ) + + if const_expr(self.use_2cta_instrs): + if const_expr(self.tile_hdim == 192): + pipeline_Qt = pipeline_Q + else: + pipeline_Qt = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.Qt_mbar_ptr.data_ptr(), + num_stages=self.Q_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_bytes["Q"], + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + pipeline_Kt = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.Kt_mbar_ptr.data_ptr(), + num_stages=self.single_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_bytes["K"], + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + else: + pipeline_Qt = pipeline_Kt = pipeline_Q + pipeline_dO = pipeline.PipelineTmaUmma.create( barrier_storage=storage.dO_mbar_ptr.data_ptr(), num_stages=self.dO_stage, @@ -988,36 +1282,69 @@ def kernel( ) sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner, dtype=self.q_dtype) - sQt = cute.make_tensor( - cute.recast_ptr(sQ.iterator, sQt_layout.inner, dtype=self.q_dtype), sQt_layout.outer - ) + if const_expr(self.use_2cta_instrs and self.tile_hdim <= 128): + sQt = storage.sQt.get_tensor( + sQt_layout.outer, swizzle=sQt_layout.inner, dtype=self.q_dtype + ) + else: + sQt = cute.make_tensor( + cute.recast_ptr(sQ.iterator, sQt_layout.inner, dtype=self.q_dtype), sQt_layout.outer + ) sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) - sKt = cute.make_tensor(cute.recast_ptr(sK.iterator, sKt_layout.inner), sKt_layout.outer) + if const_expr(self.use_2cta_instrs): + sKt = storage.sKt.get_tensor(sKt_layout.outer, swizzle=sKt_layout.inner) + else: + sKt = cute.make_tensor(cute.recast_ptr(sK.iterator, sKt_layout.inner), sKt_layout.outer) sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) sdSt = storage.sdS.get_tensor(sdSt_layout.outer, swizzle=sdSt_layout.inner) sdS = cute.make_tensor(cute.recast_ptr(sdSt.iterator, sdS_layout.inner), sdS_layout.outer) + if const_expr(self.use_2cta_instrs): + if const_expr(self.tile_hdim <= 128): + sdS_xchg = storage.sdS_xchg.get_tensor(sdS_xchg_layout) + else: + sdS_xchg = storage.sdQaccum.get_tensor(sdS_xchg_layout, dtype=self.ds_dtype) + else: + sdS_xchg = None + sdO = storage.sdO.get_tensor( sdO_layout.outer, swizzle=sdO_layout.inner, dtype=self.do_dtype ) - sdOt = cute.make_tensor( - cute.recast_ptr(sdO.iterator, sdOt_layout.inner, dtype=self.do_dtype), sdOt_layout.outer - ) + if const_expr(self.use_2cta_instrs and self.tile_hdim <= 128): + sdOt = storage.sdOt.get_tensor( + sdOt_layout.outer, swizzle=sdOt_layout.inner, dtype=self.do_dtype + ) + else: + sdOt = cute.make_tensor( + cute.recast_ptr(sdO.iterator, sdOt_layout.inner, dtype=self.do_dtype), + sdOt_layout.outer, + ) + sLSE = storage.sLSE.get_tensor(sLSE_layout) sdPsum = storage.sdPsum.get_tensor(sdPsum_layout) - if const_expr(not self.dKV_postprocess): + if const_expr(self.use_2cta_instrs): + if const_expr(not self.dKV_postprocess): + sdV = storage.sV.get_tensor( + sdV_layout.outer, swizzle=sdV_layout.inner, dtype=self.dv_dtype + ) + sdK = storage.sK.get_tensor( + sdK_layout.outer, swizzle=sdK_layout.inner, dtype=self.dk_dtype + ) + else: + sdV = storage.sV.get_tensor(sdV_layout, dtype=self.dv_dtype) + sdK = storage.sK.get_tensor(sdK_layout, dtype=self.dk_dtype) + elif const_expr(not self.dKV_postprocess): sdV = storage.sdO.get_tensor( - sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dv_dtype + sdV_layout.outer, swizzle=sdV_layout.inner, dtype=self.dv_dtype ) sdK = storage.sQ.get_tensor( - sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dk_dtype + sdK_layout.outer, swizzle=sdK_layout.inner, dtype=self.dk_dtype ) else: - sdV = storage.sdO.get_tensor(sdKV_layout, dtype=self.dv_dtype) - sdK = storage.sQ.get_tensor(sdKV_layout, dtype=self.dk_dtype) + sdV = storage.sdO.get_tensor(sdV_layout, dtype=self.dv_dtype) + sdK = storage.sQ.get_tensor(sdK_layout, dtype=self.dk_dtype) # Buffer sizing is guaranteed by max(...) in SharedStorage declarations # for both sQ (reused as sdK) and sdO (reused as sdV) - sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout) # TMEM @@ -1025,18 +1352,18 @@ def kernel( # request 512 columns of tmem, so we know that it starts at 0. tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16) # S - thr_mma_S = tiled_mma_S.get_slice(0) + thr_mma_S = tiled_mma_S.get_slice(mma_tile_coord_v) Sacc_shape = thr_mma_S.partition_shape_C(self.mma_tiler_kq[:2]) # (M, N) tStS = thr_mma_S.make_fragment_C(Sacc_shape) # (MMA, MMA_M, MMA_N) tStS = cute.make_tensor(tmem_ptr + self.tmem_S_offset, tStS.layout) # dP - thr_mma_dP = tiled_mma_dP.get_slice(0) + thr_mma_dP = tiled_mma_dP.get_slice(mma_tile_coord_v) dPacc_shape = thr_mma_dP.partition_shape_C(self.mma_tiler_vdo[:2]) tdPtdP = thr_mma_dP.make_fragment_C(dPacc_shape) tdPtdP = cute.make_tensor(tmem_ptr + self.tmem_dP_offset, tdPtdP.layout) # dV - thr_mma_dV = tiled_mma_dV.get_slice(0) + thr_mma_dV = tiled_mma_dV.get_slice(mma_tile_coord_v) dvacc_shape = thr_mma_dV.partition_shape_C(self.mma_tiler_pdo[:2]) tdVtdV = thr_mma_dV.make_fragment_C(dvacc_shape) tdVtdV = cute.make_tensor(tmem_ptr + self.tmem_dV_offset, tdVtdV.layout) @@ -1044,7 +1371,7 @@ def kernel( cute.recast_ptr(tmem_ptr + self.tmem_P_offset, dtype=self.do_dtype), tP_layout.outer ) # dK - thr_mma_dK = tiled_mma_dK.get_slice(0) + thr_mma_dK = tiled_mma_dK.get_slice(mma_tile_coord_v) dkacc_shape = thr_mma_dK.partition_shape_C(self.mma_tiler_dsq[:2]) tdKtdK = thr_mma_dK.make_fragment_C(dkacc_shape) tdKtdK = cute.make_tensor(tmem_ptr + self.tmem_dK_offset, tdKtdK.layout) @@ -1052,7 +1379,7 @@ def kernel( cute.recast_ptr(tmem_ptr + self.tmem_dS_offset, dtype=self.ds_dtype), tdS_layout.outer ) # dQ - thr_mma_dQ = tiled_mma_dQ.get_slice(0) + thr_mma_dQ = tiled_mma_dQ.get_slice(mma_tile_coord_v) dQacc_shape = thr_mma_dQ.partition_shape_C(self.mma_tiler_dsk[:2]) tdQtdQ = thr_mma_dQ.make_fragment_C(dQacc_shape) tdQtdQ = cute.make_tensor(tmem_ptr + self.tmem_dQ_offset, tdQtdQ.layout) @@ -1077,55 +1404,78 @@ def kernel( mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, tile_m=self.tile_m, - tile_n=self.tile_n, + tile_n=self.tile_n * self.cluster_shape_mnk[0], ) TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) AttentionMaskCls = partial( AttentionMask, self.tile_m, - self.tile_n, + self.tile_n * self.cta_group_size, swap_AB=True, window_size_left=window_size_left, window_size_right=window_size_right, ) - # EMPTY # (15) if warp_idx == self.empty_warp_id: cute.arch.setmaxregister_decrease(self.num_regs_empty) - # EPI + # RELAY # (14) - if warp_idx == self.epi_warp_id: - # currently no-op, could use for tma store/reduce - cute.arch.setmaxregister_decrease(self.num_regs_empty) + if warp_idx == self.relay_warp_id: + cute.arch.setmaxregister_decrease( + self.num_regs_mma if self.use_2cta_instrs else self.num_regs_empty + ) + if const_expr(self.use_2cta_instrs): + self.relay( + dS_cluster_full_mbar_ptr, + dS_cluster_empty_mbar_ptr, + dS_cluster_leader_mbar_ptr, + cluster_layout_vmnk, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) # LOAD # (13) if warp_idx == self.load_warp_id: - cute.arch.setmaxregister_decrease(self.num_regs_other) + cute.arch.setmaxregister_decrease(self.num_regs_load) self.load( thr_mma_S, thr_mma_dP, thr_mma_dV, + thr_mma_dK, + thr_mma_dQ, mQ, mK, + mKt, mV, + mdO, + mQt, + mdOt, mLSE, mdPsum, - mdO, sQ, sK, + sKt, sV, + sdO, + sQt, + sdOt, sLSE, sdPsum, - sdO, tma_atom_Q, tma_atom_K, + tma_atom_Kt, tma_atom_V, tma_atom_dO, + tma_atom_Qt, + tma_atom_dOt, pipeline_Q, + pipeline_Qt, + pipeline_Kt, pipeline_dO, pipeline_LSE, pipeline_dPsum, @@ -1141,12 +1491,12 @@ def kernel( # MMA # (12) if warp_idx == self.mma_warp_id: - cute.arch.setmaxregister_decrease(self.num_regs_other) + cute.arch.setmaxregister_decrease(self.num_regs_mma) # Alloc tmem buffer - tmem_alloc_cols = Int32(self.tmem_alloc_cols) - cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) - cute.arch.sync_warp() + tmem.allocate(self.tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(Float32) self.mma( tiled_mma_S, @@ -1157,20 +1507,25 @@ def kernel( sQ, sQt, sK, + sKt, sV, sdO, sdOt, + tP, sdSt, sdS, - sKt, - tP, tdS, tStS, tdPtdP, tdVtdV, tdKtdK, tdQtdQ, - pipeline_Q.make_consumer(), + dS_cluster_full_mbar_ptr, + dS_cluster_empty_mbar_ptr, + dS_cluster_leader_mbar_ptr, + pipeline_Q, + pipeline_Qt, + pipeline_Kt, pipeline_dO, pipeline_S_P, pipeline_dS, @@ -1180,41 +1535,44 @@ def kernel( block_info, SeqlenInfoCls, TileSchedulerCls, + is_leader_cta, blocksparse_tensors, ) - cute.arch.relinquish_tmem_alloc_permit() - tmem_ptr = cute.arch.retrieve_tmem_ptr( - Float32, alignment=16, ptr_to_buffer_holding_addr=storage.tmem_holding_buf - ) - - cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) - tmem_alloc_cols = Int32(self.tmem_alloc_cols) - cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols, is_two_cta=False) + # Dealloc the tensor memory buffer + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr) # Compute # (4, 5, 6, 7, 8, 9, 10, 11) --> 8 warps if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]: cute.arch.setmaxregister_increase(self.num_regs_compute) # 8 warps + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(Float32) self.compute_loop( thr_mma_S, thr_mma_dP, thr_mma_dV, thr_mma_dK, tStS, - sLSE, - sdPsum, + tdPtdP, tdVtdV, tdKtdK, + sLSE, + sdPsum, mdV, mdK, sdS, - tdPtdP, + sdS_xchg, pipeline_LSE, pipeline_dPsum, pipeline_S_P, pipeline_dS, pipeline_dKV, pipeline_dP, + dS_cluster_empty_mbar_ptr, + dS_cluster_full_mbar_ptr, + dQaccum_empty_mbar_ptr, softmax_scale, softmax_scale_log2, block_info, @@ -1234,50 +1592,111 @@ def kernel( fastdiv_mods, blocksparse_tensors, ) - cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) + tmem_alloc_barrier.arrive() # Reduce # (0, 1, 2, 3) - dQ if warp_idx >= self.reduce_warp_ids[0] and warp_idx <= self.reduce_warp_ids[-1]: cute.arch.setmaxregister_increase(self.num_regs_reduce) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(Float32) self.dQacc_reduce( mdQaccum, sdQaccum, thr_mma_dQ, tdQtdQ, pipeline_dQ, + dQaccum_empty_mbar_ptr, block_info, SeqlenInfoCls, TileSchedulerCls, mdQ_semaphore, blocksparse_tensors, ) + tmem_alloc_barrier.arrive() return + @cute.jit + def relay( + self, + dS_cluster_full_mbar_ptr: cute.Pointer, + dS_cluster_empty_mbar_ptr: cute.Pointer, + dS_cluster_leader_mbar_ptr: cute.Pointer, + cluster_layout_vmnk: cute.Layout, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + dS_cluster_phase = Int32(0) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + m_block_min, m_block_max = block_info.get_m_block_min_max( + seqlen, n_block // self.cluster_shape_mnk[0] + ) + head_idx_kv = head_idx // self.qhead_per_kvhead + + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) or m_block_min < m_block_max + ) + + if process_tile: + num_iters = m_block_max - m_block_min + for _ in cutlass.range(num_iters, unroll=1): + # Wait for dS_xchg from peer CTA + cute.arch.mbarrier_wait(dS_cluster_full_mbar_ptr, phase=dS_cluster_phase) + + # Arrive on MMA leader warp + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(dS_cluster_leader_mbar_ptr, Int32(0)) + + dS_cluster_phase ^= 1 + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + @cute.jit def load( self, thr_mma_S: cute.core.ThrMma, thr_mma_dP: cute.core.ThrMma, thr_mma_dV: cute.core.ThrMma, + thr_mma_dK: cute.core.ThrMma, + thr_mma_dQ: cute.core.ThrMma, mQ: cute.Tensor, mK: cute.Tensor, + mKt: Optional[cute.Tensor], mV: cute.Tensor, + mdO: cute.Tensor, + mQt: Optional[cute.Tensor], + mdOt: Optional[cute.Tensor], mLSE: cute.Tensor, mdPsum: cute.Tensor, - mdO: cute.Tensor, sQ: cute.Tensor, sK: cute.Tensor, + sKt: cute.Tensor, sV: cute.Tensor, + sdO: cute.Tensor, + sQt: cute.Tensor, + sdOt: cute.Tensor, sLSE: cute.Tensor, sdPsum: cute.Tensor, - sdO: cute.Tensor, tma_atom_Q: cute.CopyAtom, tma_atom_K: cute.CopyAtom, + tma_atom_Kt: Optional[cute.CopyAtom], tma_atom_V: cute.CopyAtom, tma_atom_dO: cute.CopyAtom, + tma_atom_Qt: Optional[cute.CopyAtom], + tma_atom_dOt: Optional[cute.CopyAtom], # 2-CTA only pipeline_Q: PipelineAsync, + pipeline_Qt: PipelineAsync, + pipeline_Kt: PipelineAsync, pipeline_dO: PipelineAsync, pipeline_LSE: PipelineAsync, pipeline_dPsum: PipelineAsync, @@ -1292,9 +1711,27 @@ def load( producer_state_Q_LSE = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.Q_stage ) + producer_state_Qt = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.Q_stage + ) + producer_state_Kt = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.single_stage + ) producer_state_dO_dPsum = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Producer, self.dO_stage ) + producer_state_Q_Qt = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.Q_stage + ) + producer_state_O_Ot = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.dO_stage + ) + producer_state_LSE = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.Q_stage + ) + producer_state_dPsum = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.dO_stage + ) # Compute multicast mask for Q & dO buffer full cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) @@ -1314,6 +1751,9 @@ def load( seqlen, n_block // self.cluster_shape_mnk[0] ) head_idx_kv = head_idx // self.qhead_per_kvhead + n_block_cta_group = n_block // self.cta_group_size + + # GMEM tensors (varlen-aware) mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv] @@ -1326,10 +1766,28 @@ def load( None, head_idx ] - gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block, 0)) + if const_expr(self.use_2cta_instrs): + if const_expr(not seqlen.has_cu_seqlens_q): + mQt_cur = mQt[None, None, head_idx, batch_idx] + mdOt_cur = mdOt[None, None, head_idx, batch_idx] + else: + mQt_cur = cute.domain_offset((0, seqlen.offset_q, 0), mQt)[None, None, head_idx] + mdOt_cur = cute.domain_offset((seqlen.offset_q, 0, 0), mdOt)[ + None, None, head_idx + ] + if const_expr(not seqlen.has_cu_seqlens_k): + mKt_cur = mKt[None, None, head_idx_kv, batch_idx] + else: + mKt_cur = cute.domain_offset((0, seqlen.offset_k, 0), mKt)[ + None, None, head_idx_kv + ] + + # (1) S.T = K @ Q.T + gK = cute.local_tile( + mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block_cta_group, 0) + ) tSgK = thr_mma_S.partition_A(gK) - gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_vdo, mode=[0, 2]), (n_block, 0)) - tdPgV = thr_mma_dP.partition_A(gV) + gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0)) tSgQ = thr_mma_S.partition_B(gQ) gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,)) @@ -1337,17 +1795,16 @@ def load( gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) tdPgdO = thr_mma_dV.partition_B(gdO) + a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape) load_K, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_K, 0, cute.make_layout(1), tSgK, sK, single_stage=True - ) - load_V, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_V, - 0, - cute.make_layout(1), - tdPgV, - sV, + tma_atom_K, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + tSgK, + sK, single_stage=True, ) + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) load_Q, _, _ = copy_utils.tma_get_copy_fn( tma_atom_Q, @@ -1358,15 +1815,82 @@ def load( mcast_mask=q_do_mcast_mask, ) load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q) + + # (2) dP = V @ dO.T + gV = cute.local_tile( + mV_cur, cute.select(self.mma_tiler_vdo, mode=[0, 2]), (n_block_cta_group, 0) + ) + tdPgV = thr_mma_dP.partition_A(gV) + + load_V, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, + 0, + cute.make_layout(1), + tdPgV, + sV, + single_stage=True, + ) + + if const_expr(tma_atom_dOt is not None): + gdOt = cute.local_tile( + mdOt_cur, cute.select(self.mma_tiler_vdo, mode=[1, 2]), (None, 0) + ) + tdPgdO = thr_mma_dP.partition_B(gdOt) + load_dOt, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_dOt, + cta_coord=block_in_cluster_coord_vmnk[1], + cta_layout=b_cta_layout, + src_tensor=tdPgdO, + dst_tensor=sdOt, + mcast_mask=q_do_mcast_mask, + ) + load_dOt = copy_utils.tma_producer_copy_fn(load_dOt, pipeline_dO) + + # (3) dV += P.T @ dO + gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) + tdVgdO = thr_mma_dV.partition_B(gdO) load_dO, _, _ = copy_utils.tma_get_copy_fn( tma_atom_dO, cta_coord=block_in_cluster_coord_vmnk[1], cta_layout=b_cta_layout, - src_tensor=tdPgdO, + src_tensor=tdVgdO, dst_tensor=sdO, mcast_mask=q_do_mcast_mask, ) load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO) + + # (4) dK += dS.T @ Q (2-CTA: needs separate Qt load) + if const_expr(tma_atom_Qt is not None): + gQt = cute.local_tile( + mQt_cur, cute.select(self.mma_tiler_dsq, mode=[1, 2]), (0, None) + ) + tdKgQt = thr_mma_dK.partition_B(gQt) + load_Qt, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Qt, + cta_coord=block_in_cluster_coord_vmnk[1], + cta_layout=b_cta_layout, + src_tensor=tdKgQt, + dst_tensor=sQt, + mcast_mask=q_do_mcast_mask, + ) + load_Qt = copy_utils.tma_producer_copy_fn(load_Qt, pipeline_Qt) + + # (5) dQ = dS @ K + if const_expr(self.use_2cta_instrs): + gKt = cute.local_tile( + mKt_cur, cute.select(self.mma_tiler_dsk, mode=[1, 2]), (0, n_block_cta_group) + ) + tdQgK = thr_mma_dQ.partition_B(gKt) + + load_Kt, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Kt, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + tdQgK, + sKt, + single_stage=True, + ) + copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32) copy_stats = partial(cute.copy, copy_atom_stats) # copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SMulticastOp(), Float32) @@ -1426,67 +1950,159 @@ def load( ) else: first_m_block = m_block_min - - # First iteration: load K together w Q & LSE, then V together w dO & dPsum - if const_expr(should_load_Q): + if const_expr(self.use_2cta_instrs and self.tile_hdim == 192): + #### Prologue #### + assert should_load_Q and should_load_dO + # K & Q (for S) pipeline_Q.producer_acquire( - producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"] + producer_state_Q_Qt, + extra_tx_count=self.tma_copy_bytes["K"], ) - load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE)) - load_Q(first_m_block, producer_state=producer_state_Q_LSE) - pipeline_Q.producer_commit(producer_state_Q_LSE) - pipeline_LSE.producer_acquire(producer_state_Q_LSE) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_Qt)) + load_Q(first_m_block, producer_state=producer_state_Q_Qt) + pipeline_Q.producer_commit(producer_state_Q_Qt) + producer_state_Q_Qt.advance() + # LSE + pipeline_LSE.producer_acquire(producer_state_LSE) with cute.arch.elect_one(): copy_stats( gLSE[None, first_m_block], - sLSE[None, producer_state_Q_LSE.index], - mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), + sLSE[None, producer_state_LSE.index], + mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_LSE), ) - producer_state_Q_LSE.advance() - if const_expr(should_load_dO): + producer_state_LSE.advance() + + # dOt + V, for dP.T = V @ dO.T pipeline_dO.producer_acquire( - producer_state_dO_dPsum, extra_tx_count=self.tma_copy_bytes["V"] - ) - load_V( - tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum) + producer_state_O_Ot, + extra_tx_count=self.tma_copy_bytes["V"], ) - load_dO(first_m_block, producer_state=producer_state_dO_dPsum) - pipeline_dO.producer_commit(producer_state_dO_dPsum) - pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_O_Ot)) + load_dOt(first_m_block, producer_state=producer_state_O_Ot) + pipeline_dO.producer_commit(producer_state_O_Ot) + producer_state_O_Ot.advance() + # dPsum + pipeline_dPsum.producer_acquire(producer_state_dPsum) with cute.arch.elect_one(): copy_stats( gdPsum[None, first_m_block], - sdPsum[None, producer_state_dO_dPsum.index], - mbar_ptr=pipeline_dPsum.producer_get_barrier( - producer_state_dO_dPsum - ), + sdPsum[None, producer_state_dPsum.index], + mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dPsum), ) - producer_state_dO_dPsum.advance() + producer_state_dPsum.advance() + + # Qt, for dK = dS.T @ Q + pipeline_Qt.producer_acquire( + producer_state_Q_Qt, + extra_tx_count=self.tma_copy_bytes["K"], + ) + load_Qt(first_m_block, producer_state=producer_state_Q_Qt) + load_Kt(tma_bar_ptr=pipeline_Qt.producer_get_barrier(producer_state_Q_Qt)) + pipeline_Qt.producer_commit(producer_state_Q_Qt) + producer_state_Q_Qt.advance() + + # dO, for dV = P.T @ dO + pipeline_dO.producer_acquire(producer_state_O_Ot) + load_dO(first_m_block, producer_state=producer_state_O_Ot) + pipeline_dO.producer_commit(producer_state_O_Ot) + producer_state_O_Ot.advance() + + #### Mainloop #### + # 2CTA: [lse | Q | dOt | dPsum | Qt | dO] + for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + # LSE + pipeline_LSE.producer_acquire(producer_state_LSE) + with cute.arch.elect_one(): + copy_stats( + gLSE[None, m_block], + sLSE[None, producer_state_LSE.index], + mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_LSE), + ) + producer_state_LSE.advance() + + # Q + pipeline_Q.producer_acquire(producer_state_Q_Qt) + load_Q(m_block, producer_state=producer_state_Q_Qt) + pipeline_Q.producer_commit(producer_state_Q_Qt) + producer_state_Q_Qt.advance() + + # dPsum + pipeline_dPsum.producer_acquire(producer_state_dPsum) + with cute.arch.elect_one(): + copy_stats( + gdPsum[None, m_block], + sdPsum[None, producer_state_dPsum.index], + mbar_ptr=pipeline_dPsum.producer_get_barrier( + producer_state_dPsum + ), + ) + producer_state_dPsum.advance() + + # dOt, for dP.T = V @ dO.T + pipeline_dO.producer_acquire(producer_state_O_Ot) + load_dOt(m_block, producer_state=producer_state_O_Ot) + pipeline_dO.producer_commit(producer_state_O_Ot) + producer_state_O_Ot.advance() + + # Qt, for dK = dS.T @ Q + pipeline_Qt.producer_acquire(producer_state_Q_Qt) + load_Qt(m_block, producer_state=producer_state_Q_Qt) + pipeline_Qt.producer_commit(producer_state_Q_Qt) + producer_state_Q_Qt.advance() + + # dO, for dV = P.T @ dO + pipeline_dO.producer_acquire(producer_state_O_Ot) + load_dO(m_block, producer_state=producer_state_O_Ot) + pipeline_dO.producer_commit(producer_state_O_Ot) + producer_state_O_Ot.advance() - # Dense path: iterate from m_block_min+1 to m_block_max - for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + else: + #### Prologue #### if const_expr(should_load_Q): - pipeline_Q.producer_acquire(producer_state_Q_LSE) - load_Q(m_block, producer_state=producer_state_Q_LSE) + # K & Q (for S) + pipeline_Q.producer_acquire( + producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"] + ) + load_K( + tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE) + ) + load_Q(first_m_block, producer_state=producer_state_Q_LSE) pipeline_Q.producer_commit(producer_state_Q_LSE) + + # LSE pipeline_LSE.producer_acquire(producer_state_Q_LSE) with cute.arch.elect_one(): copy_stats( - gLSE[None, m_block], + gLSE[None, first_m_block], sLSE[None, producer_state_Q_LSE.index], mbar_ptr=pipeline_LSE.producer_get_barrier( producer_state_Q_LSE ), ) producer_state_Q_LSE.advance() + if const_expr(should_load_dO): - pipeline_dO.producer_acquire(producer_state_dO_dPsum) - load_dO(m_block, producer_state=producer_state_dO_dPsum) + pipeline_dO.producer_acquire( + producer_state_dO_dPsum, + extra_tx_count=self.tma_copy_bytes["V"] + self.tma_copy_bytes["dO"] + if const_expr(tma_atom_dOt is not None) + else self.tma_copy_bytes["V"], + ) + load_V( + tma_bar_ptr=pipeline_dO.producer_get_barrier( + producer_state_dO_dPsum + ) + ) + load_dO(first_m_block, producer_state=producer_state_dO_dPsum) + if const_expr(tma_atom_dOt is not None): + load_dOt(first_m_block, producer_state=producer_state_dO_dPsum) pipeline_dO.producer_commit(producer_state_dO_dPsum) + + # dPsum pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) with cute.arch.elect_one(): copy_stats( - gdPsum[None, m_block], + gdPsum[None, first_m_block], sdPsum[None, producer_state_dO_dPsum.index], mbar_ptr=pipeline_dPsum.producer_get_barrier( producer_state_dO_dPsum @@ -1494,14 +2110,83 @@ def load( ) producer_state_dO_dPsum.advance() - if const_expr(should_load_Q): - pipeline_Q.producer_tail( - producer_state_Q_LSE.clone() - ) # will hang if we don't clone - pipeline_LSE.producer_tail(producer_state_Q_LSE) - if const_expr(should_load_dO): - pipeline_dO.producer_tail(producer_state_dO_dPsum.clone()) - pipeline_dPsum.producer_tail(producer_state_dO_dPsum) + if const_expr(self.use_2cta_instrs): + pipeline_Kt.producer_acquire(producer_state_Kt) + load_Kt(tma_bar_ptr=pipeline_Kt.producer_get_barrier(producer_state_Kt)) + pipeline_Kt.producer_commit(producer_state_Kt) + producer_state_Kt.advance() + #### Main Loop #### + for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + if const_expr(should_load_Q): + if const_expr(tma_atom_Qt is not None): + pipeline_Qt.producer_acquire(producer_state_Qt) + load_Qt(m_block - 1, producer_state=producer_state_Qt) + pipeline_Qt.producer_commit(producer_state_Qt) + producer_state_Qt.advance() + + # Q (for S) + pipeline_Q.producer_acquire(producer_state_Q_LSE) + load_Q(m_block, producer_state=producer_state_Q_LSE) + pipeline_Q.producer_commit(producer_state_Q_LSE) + + # LSE + pipeline_LSE.producer_acquire(producer_state_Q_LSE) + with cute.arch.elect_one(): + copy_stats( + gLSE[None, m_block], + sLSE[None, producer_state_Q_LSE.index], + mbar_ptr=pipeline_LSE.producer_get_barrier( + producer_state_Q_LSE + ), + ) + producer_state_Q_LSE.advance() + + if const_expr(should_load_dO): + pipeline_dO.producer_acquire( + producer_state_dO_dPsum, + extra_tx_count=self.tma_copy_bytes["dO"] + if const_expr(tma_atom_dOt is not None) + else 0, + ) + load_dO(m_block, producer_state=producer_state_dO_dPsum) + if const_expr(tma_atom_dOt is not None): + load_dOt(m_block, producer_state=producer_state_dO_dPsum) + pipeline_dO.producer_commit(producer_state_dO_dPsum) + + # dPsum + pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) + with cute.arch.elect_one(): + copy_stats( + gdPsum[None, m_block], + sdPsum[None, producer_state_dO_dPsum.index], + mbar_ptr=pipeline_dPsum.producer_get_barrier( + producer_state_dO_dPsum + ), + ) + producer_state_dO_dPsum.advance() + + #### Tail #### + if const_expr(should_load_Q): + if const_expr(tma_atom_Qt is not None): + pipeline_Qt.producer_acquire(producer_state_Qt) + load_Qt(m_block_max - 1, producer_state=producer_state_Qt) + pipeline_Qt.producer_commit(producer_state_Qt) + producer_state_Qt.advance() + + if const_expr(self.use_2cta_instrs and self.tile_hdim == 192): + pipeline_Q.producer_tail(producer_state_Q_Qt) + pipeline_LSE.producer_tail(producer_state_LSE) + pipeline_dO.producer_tail(producer_state_O_Ot) + pipeline_dPsum.producer_tail(producer_state_dPsum) + else: + if const_expr(should_load_Q): + pipeline_Q.producer_tail(producer_state_Q_LSE.clone()) + pipeline_LSE.producer_tail(producer_state_Q_LSE) + if const_expr(tma_atom_Qt is not None): + pipeline_Qt.producer_tail(producer_state_Qt) + if const_expr(should_load_dO): + pipeline_dO.producer_tail(producer_state_dO_dPsum.clone()) + pipeline_dPsum.producer_tail(producer_state_dO_dPsum) tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -1518,20 +2203,25 @@ def mma( sQ: cute.Tensor, sQt: cute.Tensor, sK: cute.Tensor, + sKt: cute.Tensor, sV: cute.Tensor, sdO: cute.Tensor, sdOt: cute.Tensor, + tP: cute.Tensor, sdSt: cute.Tensor, sdS: cute.Tensor, - sKt: cute.Tensor, - tP: cute.Tensor, tdS: cute.Tensor, tStS: cute.Tensor, tdPtdP: cute.Tensor, tdVtdV: cute.Tensor, tdKtdK: cute.Tensor, tdQtdQ: cute.Tensor, - pipeline_Q_consumer: PipelineConsumer, + dS_cluster_full_mbar_ptr: cute.Pointer, + dS_cluster_empty_mbar_ptr: cute.Pointer, + dS_cluster_leader_mbar_ptr: cute.Pointer, + pipeline_Q: PipelineAsync, + pipeline_Qt: PipelineAsync, + pipeline_Kt: PipelineAsync, pipeline_dO: PipelineAsync, pipeline_S_P: PipelineAsync, pipeline_dS: PipelineAsync, @@ -1541,6 +2231,7 @@ def mma( block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, + is_leader_cta: cutlass.Boolean, blocksparse_tensors: Optional[BlockSparseTensors] = None, ): # [2025-10-21] For reasons I don't understand, putting these partitioning in the main @@ -1549,14 +2240,16 @@ def mma( # S = K @ Q.T tSrK = tiled_mma_S.make_fragment_A(sK) tSrQ = tiled_mma_S.make_fragment_B(sQ) - # dP = V @ dO.T + # dP = V @ dOt.T tdPrV = tiled_mma_dP.make_fragment_A(sV) tdPrdOt = tiled_mma_dP.make_fragment_B(sdOt) # dK = dS.T @ Q - if const_expr(self.use_smem_dS_for_mma_dK): - tdKrdS = tiled_mma_dK.make_fragment_A(sdSt) + # For 2-CTA, dS (dK mma) MUST come from TMEM (cannot use SMEM) + if const_expr(self.use_smem_dS_for_mma_dK and not self.use_2cta_instrs): + tdKrdS = tiled_mma_dK.make_fragment_A(sdSt) # From SMEM else: - tdKrdS = tiled_mma_dK.make_fragment_A(tdS) + tdKrdS = tiled_mma_dK.make_fragment_A(tdS) # From TMEM + tdKrQ = tiled_mma_dK.make_fragment_B(sQt) # dQ = dS @ K tdQrdS = tiled_mma_dQ.make_fragment_A(sdS) @@ -1567,7 +2260,15 @@ def mma( # mma_qk_fn = partial(gemm_w_idx, tiled_mma_S, tStS, tSrK, tSrQ, zero_init=True) mma_qk_fn = partial( - gemm_ptx_w_idx, tiled_mma_S, tStS, tSrK, tSrQ, sA=sK, sB=sQ, zero_init=True + gemm_ptx_w_idx, + tiled_mma_S, + tStS, + tSrK, + tSrQ, + sA=sK, + sB=sQ, + zero_init=True, + cta_group=self.cta_group_size, ) # mma_dov_fn = partial(gemm_w_idx, tiled_mma_dP, tdPtdP, tdPrV, tdPrdOt, zero_init=True) mma_dov_fn = partial( @@ -1579,6 +2280,7 @@ def mma( sA=sV, sB=sdOt, zero_init=True, + cta_group=self.cta_group_size, ) # mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO) mma_pdo_fn = partial( @@ -1590,12 +2292,22 @@ def mma( sA=None, sB=sdO, tA_addr=self.tmem_P_offset, + cta_group=self.cta_group_size, + ) + num_unroll_groups = 2 if const_expr(self.use_2cta_instrs) else 1 + mma_dsk_fn = partial( + gemm_w_idx, + tiled_mma_dQ, + tdQtdQ, + tdQrdS, + tdQrK, + zero_init=True, + num_unroll_groups=num_unroll_groups, ) - mma_dsk_fn = partial(gemm_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, zero_init=True) # mma_dsk_fn = partial( # gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, zero_init=True # ) - if const_expr(self.use_smem_dS_for_mma_dK): + if const_expr(self.use_smem_dS_for_mma_dK and not self.use_2cta_instrs): mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ) else: # Need to explicitly pass in tA_addr for correctness @@ -1608,21 +2320,34 @@ def mma( sA=None, sB=sQt, tA_addr=self.tmem_dS_offset, + cta_group=self.cta_group_size, ) + pipeline_Q_consumer = pipeline_Q.make_consumer() + + consumer_state_Qt = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage + ) + consumer_state_Q = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage + ) + consumer_state_Kt = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.single_stage + ) consumer_state_dO = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage ) producer_phase_acc = Int32(1) # For S & P, dP, dQ + producer_phase_dQ = Int32(1) # 2-CTA: separate phase for dQ pipeline consumer_state_dS = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, 1 ) - # producer_state_dKV = cutlass.pipeline.make_pipeline_state( - # cutlass.pipeline.PipelineUserType.Producer, 2 - # ) producer_phase_dKV = Int32(1) cta_group = pipeline_S_P.cta_group + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + dS_cluster_phase = Int32(0) + tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: @@ -1649,140 +2374,321 @@ def mma( or m_block_min < m_block_max ) - if process_tile: - accumulate_dK = False - # ----------------------------------------------------------- - ###### Prologue - # ----------------------------------------------------------- - # 1. S = Q0 @ K.T - # 2. dP = V @ dO.T - # 3. dV = P @ dO - # 1) S = Q0 @ K.T - handle_Q = pipeline_Q_consumer.wait_and_advance() - pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) - mma_qk_fn(B_idx=handle_Q.index) - # Don't release Q yet - pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) - - # 2) dP = V @ dO.T - pipeline_dO.consumer_wait(consumer_state_dO) - pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) - # dQ uses the same tmem as dP - pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc) - mma_dov_fn(B_idx=consumer_state_dO.index) - # Don't release dO yet - pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) - - producer_phase_acc ^= 1 - # 3) dV = P.T @ dO - # wait for P to be ready, which uses the same tmem as S - pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) - mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True) - pipeline_dO.consumer_release(consumer_state_dO) - consumer_state_dO.advance() - # ----------------------------------------------------------- - ###### MAIN LOOP - # ----------------------------------------------------------- - # 1. S = K @ Q.T - # 2. dQ = dS @ K - # 3. dK = dS.T @ Q - # 4. dP = V @ dO.T - # 5. dV = P.T @ dO - - # For block sparsity, we use block_iter_count; for dense, use m_block range - # MMA doesn't need actual m_block indices, just the iteration count - main_loop_iters = ( - block_iter_count - 1 - if const_expr(self.use_block_sparsity) - else m_block_max - m_block_min - 1 - ) - for _ in cutlass.range(main_loop_iters, unroll=1): - # 1) S = K @ Q_i - handle_Q_next = pipeline_Q_consumer.wait_and_advance() - # Don't need to wait for S, as P must have been ready ealier, i.e., S is ready - mma_qk_fn(B_idx=handle_Q_next.index) - pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + if const_expr(self.use_2cta_instrs and self.tile_hdim == 192): + if is_leader_cta and process_tile: + accumulate_dK = False + accumulate_dV = False - # 2-3) - # Do dK = dS.T @ Q, then dQ = dS @ K if dS in tmem for first mma - # Otherwise, reverse order - pipeline_dS.consumer_wait(consumer_state_dS) + # ----------------------------------------------------------- + ###### MAIN LOOP + # ----------------------------------------------------------- + # 1. S.T = K @ Q.T + # 2. dP.T = V @ dO.T + # 3. dK = dS.T @ Q + # 4. dV = P.T @ dO + # 5. dQ = dS @ K + + main_loop_iters = m_block_max - m_block_min - if const_expr(self.use_smem_dS_for_mma_dK): + # empty waits + # pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) + # pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) + + for _ in cutlass.range(main_loop_iters, unroll=1): + # 1) S.T = K @ Q.T + pipeline_Q.consumer_wait(consumer_state_Q) + pipeline_dQ.sync_object_empty.wait( + 0, producer_phase_acc + ) # dQ tmem overlaps with S + mma_qk_fn(B_idx=consumer_state_Q.index) + pipeline_S_P.sync_object_full.arrive( + 0, pipeline_S_P.producer_mask, cta_group + ) + pipeline_Q.consumer_release(consumer_state_Q) + consumer_state_Q.advance() + + producer_phase_acc ^= 1 + + # 2) dP.T = V @ dO.T + pipeline_dO.consumer_wait(consumer_state_dO) + pipeline_S_P.sync_object_empty.wait( + 0, producer_phase_acc + ) # dP tmem overlaps with S + mma_dov_fn(B_idx=consumer_state_dO.index) + pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) + pipeline_dO.consumer_release(consumer_state_dO) + consumer_state_dO.advance() + + # 3) dK = dS.T @ Q + pipeline_Q.consumer_wait(consumer_state_Q) + pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) # dP -> dS + mma_dsq_fn(B_idx=consumer_state_Q.index, zero_init=not accumulate_dK) + pipeline_Q.consumer_release(consumer_state_Q) + consumer_state_Q.advance() + accumulate_dK = True + + # 4) dV = P.T @ dO + # Note: if dS is written to tmem, P must be written to tmem + pipeline_dO.consumer_wait(consumer_state_dO) + mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=not accumulate_dV) + pipeline_dO.consumer_release(consumer_state_dO) + consumer_state_dO.advance() + accumulate_dV = True + + # 5) dQ = dS @ K + pipeline_dS.consumer_wait(consumer_state_dS) + cute.arch.mbarrier_wait(dS_cluster_leader_mbar_ptr, phase=dS_cluster_phase) mma_dsk_fn() pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) - mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) - accumulate_dK = True - handle_Q.release() - else: - mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) + pipeline_dS.consumer_release(consumer_state_dS) + consumer_state_dS.advance() + dS_cluster_phase ^= 1 + + # signal to the epilogue that dV is ready + pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV) + pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group) + # signal to the epilogue that dK is ready + pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV) + pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group) + producer_phase_dKV ^= 1 + elif const_expr(self.use_2cta_instrs): + if is_leader_cta and process_tile: + accumulate_dK = False + # ----------------------------------------------------------- + ###### Prologue + # ----------------------------------------------------------- + # 1. S = Q0 @ K.T + # 2. dP = V @ dOt.T + # 3. dV = P @ dO + + # 1) S = K @ Q + pipeline_Q.consumer_wait(consumer_state_Q) + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) + mma_qk_fn(B_idx=consumer_state_Q.index) + pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + pipeline_Q.consumer_release(consumer_state_Q) + consumer_state_Q.advance() + + # 2) dP = V @ dOt.T + pipeline_dO.consumer_wait(consumer_state_dO) + pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) + mma_dov_fn(B_idx=consumer_state_dO.index) + pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) + + # 3) dV = P.T @ dO + producer_phase_acc ^= 1 + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) + mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True) + pipeline_dO.consumer_release(consumer_state_dO) + consumer_state_dO.advance() + + pipeline_Kt.consumer_wait(consumer_state_Kt) + # ----------------------------------------------------------- + ###### MAIN LOOP + # ----------------------------------------------------------- + # 1. S.T = K @ Q.T + # 2. dK = dS.T @ Q + # 3. dP.T = V @ dO.T + # 4. dQ = dS @ K + # 5. dV = P.T @ dO + + main_loop_iters = ( + block_iter_count - 1 + if const_expr(self.use_block_sparsity) + else m_block_max - m_block_min - 1 + ) + + for _ in cutlass.range(main_loop_iters, unroll=1): + # (1) S.T = K @ Q.T (next) + pipeline_Q.consumer_wait(consumer_state_Q) + pipeline_dQ.sync_object_empty.wait(0, producer_phase_dQ) + mma_qk_fn(B_idx=consumer_state_Q.index) + pipeline_S_P.sync_object_full.arrive( + 0, pipeline_S_P.producer_mask, cta_group + ) + pipeline_Q.consumer_release(consumer_state_Q) + consumer_state_Q.advance() + + # pipeline_dS.consumer_wait(consumer_state_dS) + # (2) dK += dS.T @ Q (cur) + pipeline_Qt.consumer_wait(consumer_state_Qt) + pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) # dP -> dS + mma_dsq_fn(B_idx=consumer_state_Qt.index, zero_init=not accumulate_dK) accumulate_dK = True - handle_Q.release() + pipeline_Qt.consumer_release(consumer_state_Qt) + consumer_state_Qt.advance() + + # (3) dP.T = V @ dO.T (next) + pipeline_dO.consumer_wait(consumer_state_dO) + mma_dov_fn(B_idx=consumer_state_dO.index) + pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) + + # (5) dQ = dS @ K (cur) + pipeline_dS.consumer_wait(consumer_state_dS) + cute.arch.mbarrier_wait(dS_cluster_leader_mbar_ptr, phase=dS_cluster_phase) mma_dsk_fn() pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + pipeline_dS.consumer_release(consumer_state_dS) + consumer_state_dS.advance() + dS_cluster_phase ^= 1 + producer_phase_dQ ^= 1 + + # (4) dV += P.T @ dO (next) + producer_phase_acc ^= 1 + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) # S -> P + mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False) + pipeline_dO.consumer_release(consumer_state_dO) + consumer_state_dO.advance() - # dP uses the same tmem as dQ - # However, if dS is ready, then dP must have been ready, - # so we don't need this wait before mma_dsk_fn() - # pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) + pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + # signal to the epilogue that dV is ready + pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV) + pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group) + pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV) + + # ----------------------------------------------------------- + # Tail: Remaining dK and dQ + # ----------------------------------------------------------- + # pipeline_dS.consumer_wait(consumer_state_dS) + # dK += dS.T @ Q + pipeline_Qt.consumer_wait(consumer_state_Qt) + pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) # dP -> dS + mma_dsq_fn(B_idx=consumer_state_Qt.index, zero_init=not accumulate_dK) + pipeline_Qt.consumer_release(consumer_state_Qt) + consumer_state_Qt.advance() + # signal to the epilogue that dK is ready + pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group) + producer_phase_dKV ^= 1 + + # dQ = dS @ K + pipeline_dS.consumer_wait(consumer_state_dS) + cute.arch.mbarrier_wait(dS_cluster_leader_mbar_ptr, phase=dS_cluster_phase) + pipeline_dQ.sync_object_empty.wait(0, producer_phase_dQ) + mma_dsk_fn() + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) pipeline_dS.consumer_release(consumer_state_dS) + pipeline_Kt.consumer_release(consumer_state_Kt) consumer_state_dS.advance() + consumer_state_Kt.advance() + dS_cluster_phase ^= 1 + producer_phase_dQ ^= 1 - # 4) dP = V @ dO.T + producer_phase_acc ^= 1 + else: + if is_leader_cta and process_tile: + accumulate_dK = False + # ----------------------------------------------------------- + ###### Prologue + # ----------------------------------------------------------- + # 1. S = Q0 @ K.T + # 2. dP = V @ dOt.T + # 3. dV = P @ dO + + # 1) S = K @ Q + handle_Q = pipeline_Q_consumer.wait_and_advance() + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) + mma_qk_fn(B_idx=handle_Q.index) + pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + + # 2) dP = V @ dOt.T pipeline_dO.consumer_wait(consumer_state_dO) - # dQ uses the same tmem as dP + pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc) mma_dov_fn(B_idx=consumer_state_dO.index) pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) producer_phase_acc ^= 1 - # 5) dV += P @ dO - # wait for P to be ready, which uses the same tmem as S + # 3) dV = P.T @ dO pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) - mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False) + mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True) pipeline_dO.consumer_release(consumer_state_dO) consumer_state_dO.advance() - handle_Q = handle_Q_next - - pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) - - # signal to the epilogue that dV is ready - # pipeline_dKV.producer_acquire(producer_state_dKV) - pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV) - # pipeline_dKV.producer_commit(producer_state_dKV) - pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group) - # producer_state_dKV.advance() - # pipeline_dKV.producer_acquire(producer_state_dKV) - pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV) - - # ----------------------------------------------------------- - ###### Remaining 2 - # ----------------------------------------------------------- - # 1) dK += dS.T @ Q - pipeline_dS.consumer_wait(consumer_state_dS) - mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) - # signal to the epilogue that dK is ready - # pipeline_dKV.producer_commit(producer_state_dKV) - pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group) - # producer_state_dKV.advance() - producer_phase_dKV ^= 1 - - # 2) dQ = dS @ K - # dS is done, so dP must have been ready, we don't need to wait - mma_dsk_fn() - pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) - # Wait until dQ is done before releasing Q, since K and Q0 uses the same mbarrier - handle_Q.release() - pipeline_dS.consumer_release(consumer_state_dS) - consumer_state_dS.advance() - - producer_phase_acc ^= 1 + # ----------------------------------------------------------- + ###### MAIN LOOP + # ----------------------------------------------------------- + # 1. S = K @ Q.T + # 2. dQ = dS @ K + # 3. dK = dS.T @ Q + # 4. dP = V @ dOt.T + # 5. dV = P.T @ dO + + # For block sparsity, we use block_iter_count; for dense, use m_block range + # MMA doesn't need actual m_block indices, just the iteration count + main_loop_iters = ( + block_iter_count - 1 + if const_expr(self.use_block_sparsity) + else m_block_max - m_block_min - 1 + ) + + handle_Q_next = handle_Q + for _ in cutlass.range(main_loop_iters, unroll=1): + # (1) S.T = K @ Q.T + handle_Q_next = pipeline_Q_consumer.wait_and_advance() + mma_qk_fn(B_idx=handle_Q_next.index) + pipeline_S_P.sync_object_full.arrive( + 0, pipeline_S_P.producer_mask, cta_group + ) + + # (2) dK += dS.T @ Q + pipeline_dS.consumer_wait(consumer_state_dS) + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) + accumulate_dK = True + handle_Q.release() + + # (3) dQ = dS @ K + mma_dsk_fn() + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + pipeline_dS.consumer_release(consumer_state_dS) + consumer_state_dS.advance() + + # (4) dP = V @ dO.T + pipeline_dO.consumer_wait(consumer_state_dO) + pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc) + mma_dov_fn(B_idx=consumer_state_dO.index) + pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) + + # (5) dV += P.T @ dO + producer_phase_acc ^= 1 + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) + mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False) + pipeline_dO.consumer_release(consumer_state_dO) + consumer_state_dO.advance() + handle_Q = handle_Q_next + + pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + + # signal to the epilogue that dV is ready + # pipeline_dKV.producer_acquire(producer_state_dKV) + pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV) + # pipeline_dKV.producer_commit(producer_state_dKV) + pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group) + # producer_state_dKV.advance() + # pipeline_dKV.producer_acquire(producer_state_dKV) + pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV) + + # ----------------------------------------------------------- + # Tail: Remaining dK and dQ + # ----------------------------------------------------------- + # 1) dK += dS.T @ Q + pipeline_dS.consumer_wait(consumer_state_dS) + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) + # signal to the epilogue that dK is ready + pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group) + producer_phase_dKV ^= 1 + + # 2) dQ = dS @ K + mma_dsk_fn() + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + handle_Q.release() + pipeline_dS.consumer_release(consumer_state_dS) + consumer_state_dS.advance() + + producer_phase_acc ^= 1 tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() - # Currently it hangs if we have this S_P.producer_tail, will need to understand why # pipeline_S_P.producer_tail(producer_state_S_P) # pipeline_dP.producer_tail(producer_state_dP) @@ -1910,20 +2816,24 @@ def compute_loop( thr_mma_dV: cute.core.ThrMma, thr_mma_dK: cute.core.ThrMma, tStS: cute.Tensor, - sLSE: cute.Tensor, - sdPsum: cute.Tensor, + tdPtdP: cute.Tensor, tdVtdV: cute.Tensor, tdKtdK: cute.Tensor, + sLSE: cute.Tensor, + sdPsum: cute.Tensor, mdV: cute.Tensor, mdK: cute.Tensor, sdS: cute.Tensor, - tdPtdP: cute.Tensor, + sdS_xchg: cute.Tensor, pipeline_LSE: PipelineAsync, pipeline_dPsum: PipelineAsync, pipeline_S_P: PipelineAsync, pipeline_dS: PipelineAsync, pipeline_dKV: PipelineAsync, pipeline_dP: PipelineAsync, + dS_cluster_empty_mbar_ptr: cute.Pointer, + dS_cluster_full_mbar_ptr: cute.Pointer, + dQaccum_empty_mbar_ptr: cute.Pointer, softmax_scale: cutlass.Float32, softmax_scale_log2: cutlass.Float32, block_info: BlockInfo, @@ -1972,7 +2882,7 @@ def compute_loop( # 0: [256...384] # 1: [128...256] - tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # 64 for tile_n = 128 + tileP_f32_like = self.cta_tiler[1] // 32 * self.v_dtype.width # tStS has shape ((128, 128), 1, 1), tStP has shape ((128, 64), 1, 1) # tP overlap with tS tStP = cute.composition(tStS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) @@ -1984,6 +2894,7 @@ def compute_loop( tdPcdP = thr_mma_dP.partition_C(cute.make_identity_tensor(self.mma_tiler_vdo[:2])) tdPcdS = cute.composition(tdPcdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) + # 2-CTA assumes: repetiton should always be 32 & 16 tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 ) @@ -2012,16 +2923,28 @@ def compute_loop( LayoutEnum.ROW_MAJOR, self.ds_dtype, Float32, thr_copy_t2r ) thr_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, thr_copy_t2r).get_slice(tidx) + # We assume the swizzle (i.e. layout.inner) stays the same - sdS_layout = sm100_utils_basic.make_smem_layout_epi( + sdS_epi_layout = sm100_utils_basic.make_smem_layout_epi( self.ds_dtype, LayoutEnum.ROW_MAJOR, (self.tile_n, self.tile_m), 1 - ).outer # ((8,16), (64,2), (1, 1)) - sdS_layout = cute.slice_(sdS_layout, (None, None, 0)) # ((8,16), (64,2)) + ) + sdS_layout = cute.slice_(sdS_epi_layout.outer, (None, None, 0)) # ((8,16), (64,2)) # Need to group into 1 mode to be compatible w thr_copy_r2s sdS_layout = cute.make_layout((sdS_layout.shape,), stride=(sdS_layout.stride,)) sdS_epi = cute.make_tensor(sdS.iterator, sdS_layout) tRS_sdS = thr_copy_r2s.partition_D(sdS_epi) + if const_expr(self.use_2cta_instrs): + sdS_xchg_epi = cute.make_tensor( + cute.recast_ptr(sdS_xchg.iterator, sdS_epi_layout.inner), sdS_layout + ) + tRS_sdS_xchg = thr_copy_r2s.partition_D(sdS_xchg_epi) + + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + dS_cluster_empty_phase = Int32(1) + # 2-CTA: CTA 0 exchanges stage 1 (bottom half), CTA 1 exchanges stage 0 (top half) + exchange_stage = cta_rank_in_cluster ^ 1 if const_expr(self.use_2cta_instrs) else Int32(0) + consumer_state_S_P_dP = pipeline.make_pipeline_state( # Our impl has shortcut for stage==1 cutlass.pipeline.PipelineUserType.Consumer, 1 ) @@ -2035,7 +2958,6 @@ def compute_loop( consumer_state_LSE = cutlass.pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage ) - # consumer_state_dPsum = cutlass.pipeline.make_pipeline_state( consumer_state_dPsum = pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage ) @@ -2049,12 +2971,13 @@ def compute_loop( seqlen, n_block // self.cluster_shape_mnk[0] ) mask = AttentionMaskCls(seqlen) + n_block_for_cluster = n_block // self.cta_group_size # TODO: condition mask_seqlen mask_fn = partial( mask.apply_mask_sm100_transposed, tScS_t2r=tScS_t2r, t0ScS_t2r=t0ScS_t2r, - n_block=n_block, + n_block=n_block_for_cluster, mask_seqlen=True, mask_causal=self.is_causal, mask_local=self.is_local, @@ -2067,7 +2990,6 @@ def compute_loop( # prefetch_LSE = not self.is_causal prefetch_LSE = False - # some tiles might be empty due to block sparsity if const_expr(self.use_block_sparsity): ( @@ -2122,6 +3044,22 @@ def compute_loop( #### TMEM->RMEM (Load S from TMEM) tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) cute.copy(thr_copy_t2r, tStS_t2r, tSrS_t2r) + + if const_expr(self.tile_hdim == 192): + # Signal S tmem load completion using pipeline_S_P when hdim 192 + # dP is overlapped with S + cute.arch.fence_view_async_tmem_load() + with cute.arch.elect_one(): + pipeline_S_P.consumer_release(consumer_state_S_P_dP) + elif const_expr(self.use_2cta_instrs and self.tile_hdim <= 128): + # Signal S tmem load completion using pipeline_dS when 2cta hdim 128 + # dQ is overlapped with S + if iter_idx > 0: + cute.arch.fence_view_async_tmem_load() + with cute.arch.elect_one(): + pipeline_dS.producer_commit(producer_state_dS) + producer_state_dS.advance() + if const_expr(self.score_mod_bwd is not None): tSrS_pre = cute.make_fragment_like(tSrS_t2r) cute.autovec_copy(tSrS_t2r, tSrS_pre) @@ -2150,9 +3088,7 @@ def compute_loop( is_full_block=is_full_block, check_m_boundary=check_m_boundary, ) - num_stages = cute.size(tScS_t2r, mode=[1]) - # --------------------------------------------- #### P = exp(S - LSE) # --------------------------------------------- @@ -2196,23 +3132,25 @@ def compute_loop( ) cute.arch.fence_view_async_tmem_store() + cute.arch.fence_view_async_shared() self.compute_sync_barrier.arrive_and_wait() - - with cute.arch.elect_one(): - pipeline_S_P.consumer_release(consumer_state_S_P_dP) - # pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask) + if const_expr(not self.tile_hdim == 192): + # Signal tmem store P completion with pipeline_S_P + with cute.arch.elect_one(): + pipeline_S_P.consumer_release(consumer_state_S_P_dP) + # pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask) + # Normally we'd need syncwarp here since only 1 thread will signal in + # consumer_release, but we already have the self.compute_sync_barrier before this pipeline_LSE.consumer_release(consumer_state_LSE) - # consumer_state_S_P_dP.advance() consumer_state_LSE.advance() - # --------------------------------------------- # dS.T = P.T * (dP.T - D) # --------------------------------------------- pipeline_dPsum.consumer_wait(consumer_state_dPsum) - pipeline_dP.consumer_wait(consumer_state_S_P_dP) # pipeline_dP.sync_object_full.wait(0, consumer_phase_S_P_dP) - consumer_state_S_P_dP.advance() + ### Now delayed to after loop + # consumer_state_S_P_dP.advance() # consumer_phase_S_P_dP ^= 1 ##### dS.T = P.T * (dP.T - Psum) @@ -2276,24 +3214,81 @@ def compute_loop( utils.cvt_f16(tdPrdP_cur, tdPrdS_cvt) if const_expr(stage == 0): pipeline_dS.producer_acquire(producer_state_dS) - cute.autovec_copy(tdPrdS_cvt, tRS_sdS[None, stage]) - if const_expr(not self.use_smem_dS_for_mma_dK): + if const_expr(self.use_2cta_instrs): + tdPrdS_xchg = cute.make_fragment_like(tdPrdS_cvt, self.ds_dtype) + + # RMEM->TMEM: always write to TMEM for MMA + if const_expr(not self.use_smem_dS_for_mma_dK or self.use_2cta_instrs): tdPrdS_r2t_f32 = cute.recast_tensor(tdPrdS_cvt, Float32) cute.copy(thr_copy_r2t, tdPrdS_r2t_f32, tdPtdS_r2t[None, stage, 0, 0]) + # RMEM->SMEM: For 2-CTA, keep exchange stage in registers, write non-exchange to sdS + if const_expr(self.use_2cta_instrs): + if exchange_stage == stage: + cute.autovec_copy(tdPrdS_cvt, tdPrdS_xchg) + else: + cute.autovec_copy(tdPrdS_cvt, tRS_sdS[None, stage]) + else: + cute.autovec_copy(tdPrdS_cvt, tRS_sdS[None, stage]) + if const_expr(not self.use_smem_dS_for_mma_dK): cute.arch.fence_view_async_tmem_store() + + if const_expr(self.use_2cta_instrs): + # use pipeline_dP to signal tmem store of dS + with cute.arch.elect_one(): + pipeline_dP.consumer_release(consumer_state_S_P_dP) + consumer_state_S_P_dP.advance() + + # After the loop: copy exchange registers to sdS_xchg buffer + if const_expr(self.use_2cta_instrs): + # when hdim 192, sdQaccum overlapped with sdS_xchg + if const_expr(self.tile_hdim == 192): + cute.arch.mbarrier_wait( + dQaccum_empty_mbar_ptr, phase=producer_state_dS.phase + ) + cute.autovec_copy(tdPrdS_xchg, tRS_sdS_xchg[None, 0]) + cute.arch.fence_view_async_shared() self.compute_sync_barrier.arrive_and_wait() - - # with cute.arch.elect_one(): - # The mma warp no longer waits for dP (it waits for dS), so we don't have to arrive - # pipeline_dP.sync_object_empty.arrive(0, pipeline_dP.consumer_mask) + # Normally we'd need syncwarp here since only 1 thread will signal in + # consumer_release, but we already have the self.compute_sync_barrier before this pipeline_dPsum.consumer_release(consumer_state_dPsum) consumer_state_dPsum.advance() - with cute.arch.elect_one(): - pipeline_dS.producer_commit(producer_state_dS) - producer_state_dS.advance() + # when 2cta hdim 128, pipeline_dS also signals S tmem load completion so is deferred + if const_expr(not (self.use_2cta_instrs and self.tile_hdim == 128)): + with cute.arch.elect_one(): + pipeline_dS.producer_commit(producer_state_dS) + producer_state_dS.advance() + + # 2-CTA: DSMEM copy from sdS_xchg to peer's sdS buffer + if const_expr(self.use_2cta_instrs): + stage_copy_bytes = const_expr(self.tma_copy_bytes["dS"] // 2) + stage_copy_elems = const_expr(stage_copy_bytes // (self.ds_dtype.width // 8)) + if tidx == 0: + peer_cta_rank_in_cluster = cta_rank_in_cluster ^ 1 + smem_src_ptr = sdS_xchg.iterator + # Destination is peer's sdS at our CTA's offset (exchange_stage position) + smem_dst_ptr = sdS.iterator + cta_rank_in_cluster * stage_copy_elems + cute.arch.mbarrier_arrive_and_expect_tx( + dS_cluster_full_mbar_ptr, + stage_copy_bytes, + peer_cta_rank_in_cluster=peer_cta_rank_in_cluster, + ) + copy_utils.cpasync_bulk_s2cluster( + smem_src_ptr, + smem_dst_ptr, + dS_cluster_full_mbar_ptr, + stage_copy_bytes, + peer_cta_rank_in_cluster=peer_cta_rank_in_cluster, + ) + + # Final signal for dS smem store completion + if const_expr(self.use_2cta_instrs and self.tile_hdim == 128): + if process_tile: + with cute.arch.elect_one(): + pipeline_dS.producer_commit(producer_state_dS) + producer_state_dS.advance() # Epilogue # Run epilogue if we processed any m_blocks for this n_block @@ -2336,6 +3331,7 @@ def compute_loop( None, # Don't scale int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id mdV_semaphore, + "V", ) #### STORE dK consumer_state_dKV = self.epilogue_dK_or_dV_tma( @@ -2355,6 +3351,7 @@ def compute_loop( softmax_scale if const_expr(not self.dKV_postprocess) else None, int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id mdK_semaphore, + "K", ) # Zero dK/dV for empty tiles (local attention or block sparsity) # When total_m_block_cnt == 0 for block sparsity, no Q tiles contribute to this KV tile @@ -2368,35 +3365,50 @@ def compute_loop( should_zero_dKV = True if should_zero_dKV: - # like other epis, currently assumes hdim == hdimv - gmem_tiled_copy_zero_dKV = copy_utils.tiled_copy_2d( + # For 2-CTA: use cluster-wide tile size (cta_group_size * tile_n) + cluster_tile_n = self.tile_n * self.cta_group_size + n_block_for_tile = n_block // self.cta_group_size + gmem_tiled_copy_zero_dK = copy_utils.tiled_copy_2d( self.dk_dtype, - self.tile_hdim, + math.gcd(64, self.tile_hdim), 128, # num_threads ) - gmem_thr_copy_zero_dKV = gmem_tiled_copy_zero_dKV.get_slice(dp_idx) + gmem_tiled_copy_zero_dV = copy_utils.tiled_copy_2d( + self.dv_dtype, + math.gcd(64, self.tile_hdimv), + 128, # num_threads + ) + gmem_thr_copy_zero_dK = gmem_tiled_copy_zero_dK.get_slice(dp_idx) + gmem_thr_copy_zero_dV = gmem_tiled_copy_zero_dV.get_slice(dp_idx) mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx] mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx] - gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) - gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) - tdKgdK = gmem_thr_copy_zero_dKV.partition_D(gdK) - tdVgdV = gmem_thr_copy_zero_dKV.partition_D(gdV) - assert tdKgdK.shape[2] == 1 - assert tdVgdV.shape[2] == 1 - cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) - tdKVcdKV = gmem_thr_copy_zero_dKV.partition_D(cdKV) + gdK = cute.local_tile( + mdK_cur, (cluster_tile_n, self.tile_hdim), (n_block_for_tile, 0) + ) + gdV = cute.local_tile( + mdV_cur, (cluster_tile_n, self.tile_hdimv), (n_block_for_tile, 0) + ) + tdKgdK = gmem_thr_copy_zero_dK.partition_D(gdK) + tdVgdV = gmem_thr_copy_zero_dV.partition_D(gdV) + cdK = cute.make_identity_tensor((cluster_tile_n, self.tile_hdim)) + cdV = cute.make_identity_tensor((cluster_tile_n, self.tile_hdimv)) + tdKcdK = gmem_thr_copy_zero_dK.partition_D(cdK) + tdVcdV = gmem_thr_copy_zero_dV.partition_D(cdV) + assert cute.size(tdKgdK[None, 0, 0]) == cute.size(tdVgdV[None, 0, 0]) zero = cute.make_fragment_like(tdKgdK[None, 0, 0]) zero.fill(0.0) if tidx < 128: for i in cutlass.range_constexpr(tdKgdK.shape[1]): - row_idx = tdKVcdKV[0, i, 0][0] - if row_idx < seqlen.seqlen_k - self.tile_n * n_block: - cute.copy(gmem_tiled_copy_zero_dKV, zero, tdKgdK[None, i, 0]) + row_idx = tdKcdK[0, i, 0][0] + if row_idx < seqlen.seqlen_k - cluster_tile_n * n_block_for_tile: + for j in cutlass.range_constexpr(tdKgdK.shape[2]): + cute.copy(gmem_tiled_copy_zero_dK, zero, tdKgdK[None, i, j]) else: for i in cutlass.range_constexpr(tdVgdV.shape[1]): - row_idx = tdKVcdKV[0, i, 0][0] - if row_idx < seqlen.seqlen_k - self.tile_n * n_block: - cute.copy(gmem_tiled_copy_zero_dKV, zero, tdVgdV[None, i, 0]) + row_idx = tdVcdV[0, i, 0][0] + if row_idx < seqlen.seqlen_k - cluster_tile_n * n_block_for_tile: + for j in cutlass.range_constexpr(tdVgdV.shape[2]): + cute.copy(gmem_tiled_copy_zero_dV, zero, tdVgdV[None, i, j]) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @@ -2409,6 +3421,7 @@ def dQacc_reduce( thr_mma_dQ: cute.core.ThrMma, tdQtdQ: cute.Tensor, pipeline_dQ: PipelineAsync, + dQaccum_empty_mbar_ptr: Optional[cute.Pointer], block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, @@ -2419,16 +3432,24 @@ def dQacc_reduce( tidx = cute.arch.thread_idx()[0] % num_reduce_threads warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx() % len(self.reduce_warp_ids)) is_tma_warp = warp_idx == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) # TMEM -> RMEM tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32 + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol_t2r)), Float32 ) thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ).get_slice(tidx) tdQtdQ_t2r = thr_copy_t2r.partition_S(tdQtdQ) tdQcdQ = thr_mma_dQ.partition_C(cute.make_identity_tensor(self.mma_tiler_dsk[:2])) tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape - assert cute.size(tdQrdQ_t2r_shape, mode=[1]) == self.dQaccum_reduce_stage, ( - "dQaccum reduce stage mismatch" + # For 2-CTA: reduce_stage = dQaccum_reduce_stage_t2r / cta_group_size + expected_reduce_stages_t2r = self.dQaccum_reduce_stage_t2r // self.cta_group_size + assert cute.size(tdQrdQ_t2r_shape, mode=[1]) == expected_reduce_stages_t2r, ( + "dQaccum t2r reduce stage mismatch" + ) + expected_reduce_stages = self.dQaccum_reduce_stage // self.cta_group_size + # 2-CTA: CTA 0 -> (M/2, D) (stage 0, 1) & CTA 1 -> (M/2, D) (stage 2, 3) + stage_offset = ( + expected_reduce_stages * cta_rank_in_cluster if const_expr(self.use_2cta_instrs) else 0 ) thr_copy_dQaccum_r2s = copy_utils.tiled_copy_1d( @@ -2448,10 +3469,9 @@ def dQacc_reduce( ) while work_tile.is_valid_tile: n_block, head_idx, batch_idx, _ = work_tile.tile_idx + n_block_cta_group = n_block // self.cta_group_size # for 2cta seqlen = SeqlenInfoCls(batch_idx) - m_block_min, m_block_max = block_info.get_m_block_min_max( - seqlen, n_block // self.cluster_shape_mnk[0] - ) + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block_cta_group) if const_expr(not seqlen.has_cu_seqlens_q): mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] else: @@ -2467,8 +3487,8 @@ def dQacc_reduce( if const_expr(self.deterministic): mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] - delay_semaphore_release = self.is_causal - n_block_global_max = cute.ceil_div(seqlen.seqlen_k, self.tile_n) + # delay_semaphore_release = self.is_causal and not self.tile_hdim == 192 + delay_semaphore_release = not self.tile_hdim == 192 # some tiles might be empty due to block sparsity if const_expr(self.use_block_sparsity): @@ -2524,37 +3544,33 @@ def dQacc_reduce( gdQaccum_cur = gdQaccum[None, None, m_block] - for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])): # 4 + tdQrdQ_shape = ( + self.dQ_reduce_ncol, + self.tile_hdim // self.cta_group_size // self.dQ_reduce_ncol, + ) + tdQrdQ = cute.make_tensor(tdQrdQ_t2r.iterator, tdQrdQ_shape) + + for stage in cutlass.range_constexpr(cute.size(tdQrdQ, mode=[1])): smem_idx = dQ_tma_store_producer_state.index tdQsdQ_r2s = tdQsdQ[None, None, smem_idx] - tdQrdQ_r2s = cute.make_tensor( - tdQrdQ_t2r[None, stage, None, None].iterator, tdQsdQ_r2s.shape - ) + tdQrdQ_r2s = cute.make_tensor(tdQrdQ[None, stage].iterator, tdQsdQ_r2s.shape) cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_view_async_shared() # semaphore acquire if const_expr(self.deterministic and stage == 0): if const_expr(self.spt): - if const_expr( - self.is_causal or block_info.window_size_right is not None - ): - n_idx_right = ( - (m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q - ) - if const_expr(block_info.window_size_right is not None): - n_idx_right += block_info.window_size_right - n_block_max_for_m_block = min( - n_block_global_max, - cute.ceil_div(n_idx_right, self.tile_n), - ) - else: - n_block_max_for_m_block = n_block_global_max - lock_value = n_block_max_for_m_block - 1 - n_block + _, n_block_max_for_m_block = block_info.get_n_block_min_max( + seqlen, m_block + ) + lock_value = n_block_max_for_m_block - 1 - n_block_cta_group else: - lock_value = n_block + lock_value = n_block_cta_group barrier.wait_eq( - mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, lock_value + mdQ_semaphore_cur[(m_block, None)].iterator, + tidx, + cta_rank_in_cluster, + lock_value, ) self.reduce_sync_barrier.arrive_and_wait() # Copy from shared memory to global memory @@ -2562,7 +3578,7 @@ def dQacc_reduce( with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( sdQaccum[None, smem_idx].iterator, - gdQaccum_cur[None, stage].iterator, + gdQaccum_cur[None, stage + stage_offset].iterator, self.tma_copy_bytes["dQ"] // 1, ) cute.arch.cp_async_bulk_commit_group() @@ -2584,25 +3600,42 @@ def dQacc_reduce( if const_expr(self.deterministic and stage == 0 and delay_semaphore_release): if m_block > m_block_min: barrier.arrive_inc( - mdQ_semaphore_cur[(m_block - 1, None)].iterator, tidx, 0, 1 + mdQ_semaphore_cur[(m_block - 1, None)].iterator, + tidx, + cta_rank_in_cluster, + 1, ) + if const_expr(self.tile_hdim == 192): + if const_expr(self.sdQaccum_stage > 1): + if is_tma_warp: + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + self.reduce_sync_barrier.arrive_and_wait() + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(dQaccum_empty_mbar_ptr) + # semaphore release # NOTE: arrive_inc calls red_release which issues membar if const_expr(self.deterministic and not delay_semaphore_release): - if is_tma_warp: - cute.arch.cp_async_bulk_wait_group(0, read=read_flag) - self.reduce_sync_barrier.arrive_and_wait() - barrier.arrive_inc(mdQ_semaphore_cur[m_block, None].iterator, tidx, 0, 1) + if const_expr(self.sdQaccum_stage > 1 and not self.tile_hdim == 192): + if is_tma_warp: + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + self.reduce_sync_barrier.arrive_and_wait() + barrier.arrive_inc( + mdQ_semaphore_cur[m_block, None].iterator, tidx, cta_rank_in_cluster, 1 + ) - if const_expr(not self.is_local) or m_block_min < m_block_max: + if process_tile: if is_tma_warp: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) self.reduce_sync_barrier.arrive_and_wait() # final semaphore release if const_expr(self.deterministic and delay_semaphore_release): barrier.arrive_inc( - mdQ_semaphore_cur[(m_block_max - 1, None)].iterator, tidx, 0, 1 + mdQ_semaphore_cur[(m_block_max - 1, None)].iterator, + tidx, + cta_rank_in_cluster, + 1, ) if const_expr( @@ -2610,11 +3643,16 @@ def dQacc_reduce( ): m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m) for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1): - barrier.arrive_inc(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, 1) + barrier.arrive_inc( + mdQ_semaphore_cur[(m_block, None)].iterator, tidx, cta_rank_in_cluster, 1 + ) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() + if const_expr(not self.deterministic): + cute.arch.cp_async_bulk_wait_group(0, read=True) + @cute.jit def epilogue_dKV( self, @@ -2646,7 +3684,6 @@ def epilogue_dKV( tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)), Float32 ) - # dV pipeline_dKV.consumer_wait(consumer_state_dKV) @@ -2684,8 +3721,8 @@ def epilogue_dKV( dV_vec = tdVrdV_t2r[(None, i, 0, 0)].load() tdVrdV_r2s[(None, i, 0, 0)].store(dV_vec.to(self.dv_dtype)) - gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (None, 0)) - gdV_tile = gdV[None, None, n_block] + gdV = cute.local_tile(mdV_cur, (self.mma_tiler_pdo[0], self.tile_hdimv), (None, 0)) + gdV_tile = gdV[None, None, n_block // self.cta_group_size] tdVgdV = thr_mma_dV.partition_C(gdV_tile) tdVgdV_r2g_p = thr_tmem_ld_dV.partition_D(tdVgdV) @@ -2738,8 +3775,8 @@ def epilogue_dKV( dK_vec = tdKrdK_t2r[(None, i, 0, 0)].load() * softmax_scale tdKrdK_r2s[(None, i, 0, 0)].store(dK_vec.to(self.dk_dtype)) - gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdimv), (None, 0)) - gdK_tile = gdK[None, None, n_block] + gdK = cute.local_tile(mdK_cur, (self.mma_tiler_dsq[0], self.tile_hdim), (None, 0)) + gdK_tile = gdK[None, None, n_block // self.cta_group_size] tdKgdK = thr_mma_dK.partition_C(gdK_tile) tdKgdK_r2g_p = thr_tmem_ld_dK.partition_D(tdKgdK) @@ -2751,7 +3788,6 @@ def epilogue_dKV( cute.arch.sync_warp() with cute.arch.elect_one(): pipeline_dKV.consumer_release(consumer_state_dKV) - consumer_state_dKV.advance() return consumer_state_dKV @cute.jit @@ -2773,14 +3809,22 @@ def epilogue_dK_or_dV_tma( scale: Optional[Float32], barrier_id: Int32, mdKV_semaphore: Optional[cute.Tensor], + K_or_V: cutlass.Constexpr[str], ) -> cutlass.pipeline.PipelineState: - # assumes mma_tiler_pdo = mma_tiler_dsq = (tile_n, head_dim) - # head_dim = head_dim_v, dk_dtype = dv_dtype + assert K_or_V in ("K", "V") + tile_hdim = self.tile_hdim if const_expr(K_or_V == "K") else self.tile_hdimv + dtype = self.dk_dtype if const_expr(K_or_V == "K") else self.dv_dtype + epi_tile = self.sdK_epi_tile if const_expr(K_or_V == "K") else self.sdV_epi_tile + flat_epi_tile = ( + self.sdK_flat_epi_tile if const_expr(K_or_V == "K") else self.sdV_flat_epi_tile + ) num_compute_threads = cute.arch.WARP_SIZE * len(self.compute_warp_ids) wg_idx = (cute.arch.thread_idx()[0] % num_compute_threads) // 128 num_wg = num_compute_threads // 128 leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0 + cta_group_tile_n = const_expr(self.tile_n * self.cta_group_size) + if const_expr(not self.dKV_postprocess): sdKV = sdKV[None, None, wg_idx] # (tile_n, 64) for bf16 else: @@ -2794,27 +3838,28 @@ def epilogue_dK_or_dV_tma( assert not seqlen.has_cu_seqlens_k, "varlen uses non tma store path" mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] # (seqlen, hdim) gdKV_p = cute.local_tile( - mdKV_cur, (self.tile_n, self.tile_hdim), (n_block, 0) - ) # (tile_n, hdim) + mdKV_cur, (self.tile_n, tile_hdim), (n_block, 0) + ) # (tile_n, hdim) - per CTA gdKV = self.split_wg(gdKV_p, wg_idx, num_wg) # (tile_n, hdim / 2) gdKV_epi = cute.local_tile( - gdKV, self.sdKV_epi_tile, (0, None) + gdKV, epi_tile, (0, None) ) # (tile_n, 64, epi_stage = (hdim / 2) / 64) else: + # n_block_group = n_block // self.cta_group_size if const_expr(not seqlen.has_cu_seqlens_k): mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim) else: mdKV_cur = cute.domain_offset( - (seqlen.padded_offset_k * self.tile_hdim,), mdKV[None, head_idx_kv] + (seqlen.padded_offset_k * tile_hdim,), mdKV[None, head_idx_kv] ) gdKV_p = cute.local_tile( - mdKV_cur, (self.tile_n * self.tile_hdim,), (n_block,) + mdKV_cur, (self.tile_n * tile_hdim,), (n_block,) ) # (tile_n * hdim) - gdKV = cute.logical_divide(gdKV_p, (self.tile_n * self.tile_hdim // num_wg,))[ + gdKV = cute.logical_divide(gdKV_p, (self.tile_n * tile_hdim // num_wg,))[ ((None, wg_idx),) ] # (tile_n * hdim / 2) gdKV_epi = cute.flat_divide( - gdKV, (self.sdKV_flat_epi_tile,) + gdKV, (flat_epi_tile,) ) # (tile_n * hdim / 2 / epi_stage, epi_stage) deterministic_KV = self.deterministic and self.qhead_per_kvhead > 1 @@ -2832,12 +3877,17 @@ def epilogue_dK_or_dV_tma( assert len(tdKVsdKV.shape) == 1, "Wrong rank for SMEM fragment tdKVsdKV" assert len(tdKVgdKV.shape) == 2, "Wrong rank for GMEM fragment tdKVgdKV" num_epi_stages = cute.size(tdKVgdKV.shape[1]) - assert num_epi_stages == self.num_epi_stages, "Epi stage calculation is wrong" + if const_expr(K_or_V == "K"): + assert num_epi_stages == self.num_epi_stages, "Epi stage calculation is wrong (K)" + else: + assert num_epi_stages == self.num_epi_stages_v, "Epi stage calculation is wrong (V)" else: - num_epi_stages = self.num_epi_stages + num_epi_stages = ( + self.num_epi_stages if const_expr(K_or_V == "K") else self.num_epi_stages_v + ) tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dK_reduce_ncol)), Float32 ) read_flag = const_expr(not deterministic_KV) @@ -2859,7 +3909,7 @@ def epilogue_dK_or_dV_tma( if const_expr(num_epi_stages > 1): tdKVtdKV_t2r = tdKVtdKV_t2r[None, epi_stage] - cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) + cdKV = cute.make_identity_tensor((cta_group_tile_n, tile_hdim)) tdKVcdKV = thr_mma.partition_C(cdKV) tdKVcdKV_t2r_p = thr_copy_t2r.partition_D(tdKVcdKV) tdKVcdKV_t2r = self.split_wg(tdKVcdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] @@ -2882,8 +3932,8 @@ def epilogue_dK_or_dV_tma( tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1] = cute.arch.mul_packed_f32x2( (tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1]), (scale, scale) ) - tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype) # (32 columns) - tdKVrdKV.store(tdKVrdKV_t2r.load().to(self.dv_dtype)) + tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, dtype) # (32 columns) + tdKVrdKV.store(tdKVrdKV_t2r.load().to(dtype)) # RMEM -> SMEM -- copy, fence and barrier tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVsdKV_r2s.shape) diff --git a/flash_attn/cute/flash_bwd_sm120.py b/flash_attn/cute/flash_bwd_sm120.py new file mode 100644 index 00000000000..556c59e384a --- /dev/null +++ b/flash_attn/cute/flash_bwd_sm120.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# SM120 (Blackwell GeForce / DGX Spark) backward pass. +# +# SM120 uses the same SM80-era MMA instructions (mma.sync.aligned.m16n8k16) but has +# a smaller shared memory capacity (99 KB vs 163 KB on SM80). This module subclasses +# FlashAttentionBackwardSm80 and overrides the SMEM capacity check accordingly. + +import cutlass +import cutlass.utils as utils_basic + +from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 + + +class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80): + @staticmethod + def can_implement( + dtype, + head_dim, + head_dim_v, + m_block_size, + n_block_size, + num_stages_Q, + num_stages_dO, + num_threads, + is_causal, + V_in_regs=False, + ) -> bool: + """Check if the kernel can be implemented on SM120. + + Same logic as SM80 but uses SM120's shared memory capacity (99 KB). + """ + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if head_dim_v % 8 != 0: + return False + if n_block_size % 16 != 0: + return False + if num_threads % 32 != 0: + return False + # Shared memory usage: Q tile + dO tile + K tile + V tile + smem_usage_Q = m_block_size * head_dim * num_stages_Q * 2 + smem_usage_dO = m_block_size * head_dim_v * num_stages_dO * 2 + smem_usage_K = n_block_size * head_dim * 2 + smem_usage_V = n_block_size * head_dim_v * 2 + smem_usage_QV = ( + (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V) + ) + smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K + # SM120 has 99 KB shared memory (vs 163 KB on SM80) + smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_120") + if smem_usage > smem_capacity: + return False + return True diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index 7234296641a..f724b5a11e3 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -23,8 +23,15 @@ from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute import pipeline -from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, ParamsBase -from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd +from quack.cute_dsl_utils import ParamsBase +from flash_attn.cute.tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + SingleTileLPTBwdScheduler, + SingleTileVarlenScheduler, +) +from flash_attn.cute import barrier +from flash_attn.cute.named_barrier import NamedBarrierBwd from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner from flash_attn.cute.block_sparsity import BlockSparseTensors from flash_attn.cute.block_sparse_utils import ( @@ -45,6 +52,8 @@ def __init__( head_dim_v: Optional[int] = None, qhead_per_kvhead: int = 1, is_causal: bool = False, + is_local: bool = False, + deterministic: bool = False, tile_m: int = 64, tile_n: int = 128, Q_stage: int = 2, @@ -63,6 +72,7 @@ def __init__( mask_mod: cutlass.Constexpr | None = None, has_aux_tensors: cutlass.Constexpr = False, subtile_factor: cutlass.Constexpr[int] = 1, + dQ_single_wg: bool = False, ): self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size @@ -76,7 +86,8 @@ def __init__( self.check_hdim_v_oob = head_dim_v != self.tile_hdimv self.qhead_per_kvhead = qhead_per_kvhead self.is_causal = is_causal - self.is_local = False + self.is_local = is_local + self.deterministic = deterministic self.tile_m = tile_m self.tile_n = tile_n self.num_threads = num_threads @@ -91,26 +102,28 @@ def __init__( self.AtomLayoutMSdP = AtomLayoutMSdP self.AtomLayoutNdKV = AtomLayoutNdKV self.AtomLayoutMdQ = AtomLayoutMdQ - self.num_mma_warp_groups = (self.num_threads // 128) - 1 + self.num_wg_mma = (self.num_threads // 128) - 1 self.mma_dkv_is_rs = ( AtomLayoutMSdP == 1 - and AtomLayoutNdKV == self.num_mma_warp_groups + and AtomLayoutNdKV == self.num_wg_mma and SdP_swapAB and not dKV_swapAB ) self.V_in_regs = V_in_regs + # May be overridden in __call__ for varlen inputs. if qhead_per_kvhead > 1: assert self.same_hdim_kv, "GQA backward requires head_dim == head_dim_v" - assert self.num_mma_warp_groups == 2, "GQA backward assumes 2 warp groups" + assert self.num_wg_mma == 2, "GQA backward assumes 2 warp groups" # These are tuned for speed # Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share # them and then shuffle to get the value whenever we need? This can reduce register # pressure when SdP_swapAB, where each thread needs to keep statistics for (kBlockM / 4) # rows. If !SdP_swapAB, each thread only needs to keep statistics for 2 rows. - # TODO: impl these for hdim 64 self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64 self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64 + self.buffer_align_bytes = 1024 + self.score_mod = score_mod self.score_mod_bwd = score_mod_bwd self.mask_mod = mask_mod @@ -121,6 +134,12 @@ def __init__( else: self.vec_size: cutlass.Constexpr = 4 self.qk_acc_dtype = Float32 + # dQ_single_wg: WG0 computes the full dQ GEMM, WG1 skips it. + # Only valid for 2 MMA warp groups. + # Credit: Ben Spector + if dQ_single_wg: + assert self.num_wg_mma == 2, "dQ_single_wg only supports 2 warp groups" + self.num_wg_dQ = 1 if dQ_single_wg else self.num_wg_mma @staticmethod def can_implement( @@ -179,40 +198,58 @@ def _check_type( assert mQ_type == self.dtype def _setup_attributes(self): - self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout, self.sPdS_layout = [ - sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage) - for shape, stage in [ - ((self.tile_m, self.tile_hdim), self.Q_stage), - ((self.tile_n, self.tile_hdim), None), - ((self.tile_n, self.tile_hdimv), None), - ((self.tile_m, self.tile_hdimv), self.dO_stage), - ((self.tile_m, self.tile_n), self.PdS_stage), + # We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory. + # Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma. + # The M dimension (tile_m) doesn't matter for the layout, only the K dimension + wg_d_dKV = self.num_wg_mma // self.AtomLayoutNdKV + self.sQ_layout, self.sdO_layout = [ + # Need to set major_mode_size (mms) to accommodate Q and Q.T + sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage, mms) + for shape, stage, mms in [ + ((self.tile_m, self.tile_hdim), self.Q_stage, self.tile_hdim // wg_d_dKV), + ((self.tile_m, self.tile_hdimv), self.dO_stage, self.tile_hdim // wg_d_dKV), ] ] + wg_d_dQ = self.num_wg_dQ // self.AtomLayoutMdQ + # Accomodate both K and K.T + self.sK_layout = sm90_utils.make_smem_layout( + self.dtype, + LayoutEnum.ROW_MAJOR, + (self.tile_n, self.tile_hdim), + stage=None, + major_mode_size=self.tile_hdim // wg_d_dQ, + ) + # There's only V, no V.T, so layout is normal + self.sV_layout = sm90_utils.make_smem_layout( + self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_n, self.tile_hdimv), None + ) + # Accomodate both S and S.T + wg_n_SdP = self.num_wg_mma // self.AtomLayoutMSdP + wg_n_dKV = self.AtomLayoutNdKV + self.sPdS_layout = sm90_utils.make_smem_layout( + self.dtype, + LayoutEnum.ROW_MAJOR, + (self.tile_m, self.tile_n), + stage=self.PdS_stage, + major_mode_size=math.gcd(self.tile_n // wg_n_SdP, self.tile_n // wg_n_dKV), + ) self.sdQaccum_layout = cute.make_layout( - (self.tile_m * self.tile_hdim // self.num_mma_warp_groups, self.num_mma_warp_groups) + (self.tile_m * self.tile_hdim // self.num_wg_dQ, self.num_wg_dQ) ) # dQaccum R->S self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv( cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128), # thr_layout - cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)), + cute.make_layout((self.num_threads_per_warp_group, self.num_wg_dQ)), cute.make_layout(128 // Float32.width), # val_layout ) # dKVaccum for GQA epilogue - reuses sV+sK memory recast as f32 - self.sdKVaccum_layout = cute.make_layout( - (self.tile_n * self.tile_hdim // self.num_mma_warp_groups, self.num_mma_warp_groups) - ) - # dKVaccum R->S (same pattern as dQaccum but sized for tile_n) - self.r2s_tiled_copy_dKVaccum = cute.make_tiled_copy_tv( - cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128), - cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)), - cute.make_layout(128 // Float32.width), - ) + # TODO: assert that sVaccum and sKaccum don't overflow smem def _get_tiled_mma(self): + maybe_swap_mn = lambda shape, swap: (shape[1], shape[0], *shape[2:]) if swap else shape # S = Q @ K.T, dP = dO @ V.T - atom_layout_SdP = (self.AtomLayoutMSdP, self.num_mma_warp_groups // self.AtomLayoutMSdP) + atom_layout_SdP = (self.AtomLayoutMSdP, self.num_wg_mma // self.AtomLayoutMSdP, 1) tiler_mn_SdP = (self.tile_m // atom_layout_SdP[0], self.tile_n // atom_layout_SdP[1]) tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, @@ -220,12 +257,11 @@ def _get_tiled_mma(self): warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.K, Float32, - atom_layout_mnk=(atom_layout_SdP if not self.SdP_swapAB else atom_layout_SdP[::-1]) - + (1,), - tiler_mn=tiler_mn_SdP if not self.SdP_swapAB else tiler_mn_SdP[::-1], + atom_layout_mnk=maybe_swap_mn(atom_layout_SdP, self.SdP_swapAB), + tiler_mn=(64, tiler_mn_SdP[1] if not self.SdP_swapAB else tiler_mn_SdP[0]), ) # dV = P.T @ dO, dK = dS.T @ Q - atom_layout_dKV = (self.AtomLayoutNdKV, self.num_mma_warp_groups // self.AtomLayoutNdKV) + atom_layout_dKV = (self.AtomLayoutNdKV, self.num_wg_mma // self.AtomLayoutNdKV, 1) tiler_mn_dK = (self.tile_n // atom_layout_dKV[0], self.tile_hdim // atom_layout_dKV[1]) tiler_mn_dV = (self.tile_n // atom_layout_dKV[0], self.tile_hdimv // atom_layout_dKV[1]) tiled_mma_dK, tiled_mma_dV = [ @@ -237,9 +273,8 @@ def _get_tiled_mma(self): else warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.MN, Float32, - atom_layout_mnk=(atom_layout_dKV if not self.dKV_swapAB else atom_layout_dKV[::-1]) - + (1,), - tiler_mn=tiler_mn_d if not self.dKV_swapAB else tiler_mn_d[::-1], + atom_layout_mnk=maybe_swap_mn(atom_layout_dKV, self.dKV_swapAB), + tiler_mn=(64, tiler_mn_d[1] if not self.dKV_swapAB else tiler_mn_d[0]), a_source=warpgroup.OperandSource.RMEM if self.mma_dkv_is_rs else warpgroup.OperandSource.SMEM, @@ -247,7 +282,8 @@ def _get_tiled_mma(self): for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV) ] # dQ = dS @ K - atom_layout_dQ = (self.AtomLayoutMdQ, self.num_mma_warp_groups // self.AtomLayoutMdQ) + assert self.num_wg_dQ % self.AtomLayoutMdQ == 0 + atom_layout_dQ = (self.AtomLayoutMdQ, self.num_wg_dQ // self.AtomLayoutMdQ, 1) tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1]) tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, @@ -255,22 +291,20 @@ def _get_tiled_mma(self): warpgroup.OperandMajorMode.K if not self.dQ_swapAB else warpgroup.OperandMajorMode.MN, warpgroup.OperandMajorMode.MN if not self.dQ_swapAB else warpgroup.OperandMajorMode.K, Float32, - atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1]) + (1,), - tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1], + atom_layout_mnk=maybe_swap_mn(atom_layout_dQ, self.dQ_swapAB), + tiler_mn=(64, tiler_mn_dQ[1] if not self.dQ_swapAB else tiler_mn_dQ[0]), ) return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ def _get_shared_storage_cls(self): - sQ_alignment = sK_alignment = sV_alighment = sdQaccum_alignment = sdO_alignment = 1024 - sQ_struct, sK_struct, sV_struct, sdO_struct, sdQaccum_struct = [ - cute.struct.Align[cute.struct.MemRange[type, cute.cosize(layout)], alignment] - for (layout, type, alignment) in [ - (self.sQ_layout, self.dtype, sQ_alignment), - (self.sK_layout, self.dtype, sK_alignment), - (self.sV_layout, self.dtype, sV_alighment), - (self.sdO_layout, self.dtype, sdO_alignment), - (self.sdQaccum_layout, Float32, sdQaccum_alignment), + cute.struct.Align[cute.struct.MemRange[t, cute.cosize(layout)], self.buffer_align_bytes] + for (layout, t) in [ + (self.sQ_layout, self.dtype), + (self.sK_layout, self.dtype), + (self.sV_layout, self.dtype), + (self.sdO_layout, self.dtype), + (self.sdQaccum_layout, Float32), ] ] @@ -312,7 +346,6 @@ def __call__( mdK: cute.Tensor, mdV: cute.Tensor, softmax_scale: Float32, - stream: cuda.CUstream, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, @@ -325,10 +358,13 @@ def __call__( mdV_semaphore: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): - assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, ( - "determinism not supported yet for Sm90" - ) + # For GQA (qhead_per_kvhead > 1), multiple Q heads accumulate into the same dK/dV, + # so we need the float32 accum path + postprocess. + # For varlen_k with qhead_per_kvhead == 1, we use ragged TMA tensors. + self.varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None self._check_type( *( @@ -337,23 +373,36 @@ def __call__( ) ) + self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None + mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [ assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV) ] - layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) - mQ, mK, mV, mdO = [layout_utils.select(t, layout_transpose) for t in (mQ, mK, mV, mdO)] + # Non-varlen inputs are (b, s, n, h), varlen inputs are (s, n, h). + # We convert both to a seqlen-major view with head-dim second. + # Each tensor may have different rank when Q is padded (seqused_q) but K/V are unpadded (cu_seqlens_k). + def _qkv_transpose(t): + return layout_utils.select(t, [1, 3, 2, 0] if cute.rank(t.shape) == 4 else [0, 2, 1]) + + mQ, mK, mV, mdO = [_qkv_transpose(t) for t in (mQ, mK, mV, mdO)] if const_expr(self.qhead_per_kvhead == 1): - mdK, mdV = [layout_utils.select(t, layout_transpose) for t in (mdK, mdV)] + mdK, mdV = [_qkv_transpose(t) for t in (mdK, mdV)] else: - accum_transpose = [2, 1, 0] # (b, n, s*h) -> (s*h, n, b) + # Accum tensors are (b, n, s*h) for non-varlen and (n, s*h) for varlen. + accum_transpose = [2, 1, 0] if cute.rank(mdK.shape) == 3 else [1, 0] mdK, mdV = [layout_utils.select(t, accum_transpose) for t in (mdK, mdV)] - LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) -> (s, n, b) + # Non-varlen stats are (b, n, s), varlen stats are (n, s). + LSE_dPsum_dQaccum_transpose = [2, 1, 0] if cute.rank(mLSE.shape) == 3 else [1, 0] mLSE, mdPsum, mdQaccum = [ layout_utils.select(t, LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum) ] tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ = self._get_tiled_mma() + # (batch, num_head, num_m_blocks, cluster_size) -> (num_m_blocks, cluster_size, num_head, batch) + if const_expr(self.deterministic): + assert mdQ_semaphore is not None + mdQ_semaphore = layout_utils.select(mdQ_semaphore, mode=[2, 3, 1, 0]) self.num_mma_threads = tiled_mma_SdP.size assert self.num_mma_threads + 128 == self.num_threads @@ -361,10 +410,25 @@ def __call__( self.num_threads_per_warp_group = 128 self.num_producer_threads = 32 - self.num_mma_regs = 240 - self.num_producer_regs = 24 - # self.num_mma_regs = 232 - # self.num_producer_regs = 40 + REG_LIMIT = 504 if self.num_wg_mma == 2 else 512 + if const_expr(self.num_wg_mma == 2): + if const_expr(self.num_wg_dQ == 1): + self.num_mma_regs_wg0 = 256 + self.num_mma_regs_wg1 = 224 + else: + self.num_mma_regs_wg0 = 240 + self.num_mma_regs_wg1 = 240 + self.num_mma_regs = self.num_mma_regs_wg0 # for backward compat + self.num_producer_regs = 24 + assert ( + self.num_mma_regs_wg0 + self.num_mma_regs_wg1 + self.num_producer_regs <= REG_LIMIT + ) + else: # 3 warp groups + self.num_mma_regs_wg0 = 160 + self.num_mma_regs_wg1 = 160 + self.num_mma_regs = 160 + self.num_producer_regs = 32 + assert self.num_mma_regs_wg0 * self.num_wg_mma + self.num_producer_regs <= REG_LIMIT self._setup_attributes() SharedStorage = self._get_shared_storage_cls() @@ -381,7 +445,7 @@ def __call__( self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8 self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8 self.tma_copy_bytes["dQ"] = ( - self.tile_m * self.tile_hdim * Float32.width // 8 // self.num_mma_warp_groups + self.tile_m * self.tile_hdim * Float32.width // 8 // self.num_wg_dQ ) self.tma_copy_bytes["dKacc"] = self.tile_n * self.tile_hdim * Float32.width // 8 self.tma_copy_bytes["dVacc"] = self.tile_n * self.tile_hdimv * Float32.width // 8 @@ -411,38 +475,59 @@ def __call__( (self.tile_m, self.tile_hdimv), ) if const_expr(self.qhead_per_kvhead == 1): + mdK_tma = ( + copy_utils.create_ragged_tensor_for_tma(mdK, ragged_dim=0, ptr_shift=True) + if self.varlen_k + else mdK + ) + mdV_tma = ( + copy_utils.create_ragged_tensor_for_tma(mdV, ragged_dim=0, ptr_shift=True) + if self.varlen_k + else mdV + ) tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileS2GOp(), - mdK, + mdK_tma, cute.select(self.sK_layout, mode=[0, 1]), (self.tile_n, self.tile_hdim), ) tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileS2GOp(), - mdV, + mdV_tma, cute.select(self.sV_layout, mode=[0, 1]), (self.tile_n, self.tile_hdimv), ) else: tma_atom_dK = tma_atom_dV = tma_tensor_dK = tma_tensor_dV = None - TileScheduler = SingleTileScheduler + if const_expr(mCuSeqlensK is not None or mSeqUsedK is not None): + TileScheduler = SingleTileVarlenScheduler + elif const_expr(self.deterministic): + TileScheduler = SingleTileLPTBwdScheduler + else: + TileScheduler = SingleTileScheduler + self.spt = (self.is_causal or self.is_local) and self.deterministic tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.tile_n), cute.size(mQ.shape[2]), - cute.size(mQ.shape[3]), + cute.size(mK.shape[3]) + if const_expr(mCuSeqlensK is None) + else cute.size(mCuSeqlensK.shape[0] - 1), # num_batch 1, # num_splits - cute.size(mK.shape[0]), - mQ.shape[1], - mV.shape[1], - total_q=cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), - tile_shape_mn=(self.tile_m, self.tile_n), - mCuSeqlensQ=None, - mSeqUsedQ=None, + cute.size(mQ.shape[0]), # pass seqlen_q or total_q for seqlen_k + mQ.shape[1], # headdim + mV.shape[1], # headdim_v + total_q=cute.size(mK.shape[0]) + if const_expr(mCuSeqlensK is not None) + else cute.size(mK.shape[0]) * cute.size(mK.shape[3]), + tile_shape_mn=(self.tile_n, self.tile_m), # Swapping the role of Q & K + mCuSeqlensQ=mCuSeqlensK, + mSeqUsedQ=mSeqUsedK, qhead_per_kvhead_packgqa=1, element_size=self.dtype.width // 8, is_persistent=False, - lpt=False, + lpt=self.spt, + head_swizzle=self.deterministic, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) @@ -468,6 +553,11 @@ def __call__( self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + if const_expr(window_size_left is not None): + window_size_left = Int32(window_size_left) + if const_expr(window_size_right is not None): + window_size_right = Int32(window_size_right) + self.kernel( tma_tensor_Q, tma_tensor_K, @@ -484,15 +574,17 @@ def __call__( mLSE, mdPsum, mdQaccum, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, self.sQ_layout, self.sK_layout, self.sV_layout, self.sPdS_layout, self.sdO_layout, self.sdQaccum_layout, - self.sdKVaccum_layout, self.r2s_tiled_copy_dQaccum, - self.r2s_tiled_copy_dKVaccum, tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, @@ -506,12 +598,15 @@ def __call__( fastdiv_mods, blocksparse_tensors, qhead_per_kvhead_divmod, + mdQ_semaphore, + window_size_left, + window_size_right, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], - smem=SharedStorage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, + use_pdl=True, ) @cute.kernel @@ -532,15 +627,17 @@ def kernel( mLSE: cute.Tensor, mdPsum: cute.Tensor, mdQaccum: cute.Tensor, + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, sPdS_layout: cute.ComposedLayout, sdO_layout: cute.ComposedLayout, sdQaccum_layout: cute.Layout, - sdKVaccum_layout: cute.Layout, r2s_tiled_copy_dQaccum: cute.TiledCopy, - r2s_tiled_copy_dKVaccum: cute.TiledCopy, tiled_mma_SdP: cute.TiledMma, tiled_mma_dK: cute.TiledMma, tiled_mma_dV: cute.TiledMma, @@ -554,15 +651,17 @@ def kernel( fastdiv_mods=(None, None), blocksparse_tensors: Optional[BlockSparseTensors] = None, qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None, + mdQ_semaphore: Optional[cute.Tensor] = None, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # prefetch TMA descriptors if warp_idx == 0: - cpasync.prefetch_descriptor(tma_atom_Q) - cpasync.prefetch_descriptor(tma_atom_K) - cpasync.prefetch_descriptor(tma_atom_V) - cpasync.prefetch_descriptor(tma_atom_dO) + for atom in [tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_dO, tma_atom_dK, tma_atom_dV]: + if const_expr(atom is not None): + cpasync.prefetch_descriptor(atom) smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) @@ -616,25 +715,27 @@ def kernel( self.is_causal, self.is_local, False, # is_split_kv - None, - None, + window_size_left, + window_size_right, qhead_per_kvhead_packgqa=1, ) SeqlenInfoCls = partial( SeqlenInfoQK.create, seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0], - mCuSeqlensQ=None, - mCuSeqlensK=None, - mSeqUsedQ=None, - mSeqUsedK=None, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + tile_m=self.tile_m, + tile_n=self.tile_n, ) AttentionMaskCls = partial( AttentionMask, self.tile_m, self.tile_n, - window_size_left=None, - window_size_right=None, + window_size_left=window_size_left, + window_size_right=window_size_right, swap_AB=self.SdP_swapAB, ) TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) @@ -668,11 +769,6 @@ def kernel( qhead_per_kvhead_divmod, ) if warp_idx == 1: - for warp_group_idx in cutlass.range(self.num_mma_warp_groups): - cute.arch.barrier_arrive( - barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, - number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, - ) self.dQaccum_store( mdQaccum, sdQaccum, @@ -680,12 +776,12 @@ def kernel( TileSchedulerCls, SeqlenInfoCls, blocksparse_tensors, + mdQ_semaphore, ) else: - cute.arch.setmaxregister_increase(self.num_mma_regs) tidx, _, _ = cute.arch.thread_idx() tidx = tidx - 128 - self.mma( + mma_args = ( tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, @@ -708,8 +804,6 @@ def kernel( tma_atom_dK, tma_atom_dV, r2s_tiled_copy_dQaccum, - r2s_tiled_copy_dKVaccum, - sdKVaccum_layout, softmax_scale_log2, softmax_scale, block_info, @@ -721,6 +815,19 @@ def kernel( blocksparse_tensors, qhead_per_kvhead_divmod, ) + if const_expr(self.num_wg_dQ == self.num_wg_mma): + # Both WGs compute dQ + cute.arch.setmaxregister_increase(self.num_mma_regs_wg0) + self.mma(*mma_args, is_dQ_wg=True) + else: + # WG0 computes dQ, WG1 skips it + warp_idx_in_mma = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - 4 + if warp_idx_in_mma < 4: + cute.arch.setmaxregister_increase(self.num_mma_regs_wg0) + self.mma(*mma_args, is_dQ_wg=True) + else: + cute.arch.setmaxregister_increase(self.num_mma_regs_wg1) + self.mma(*mma_args, is_dQ_wg=False) @cute.jit def load( @@ -768,18 +875,22 @@ def load( if const_expr(self.qhead_per_kvhead == 1) else head_idx // qhead_per_kvhead_divmod ) - mK_cur = mK[None, None, head_idx_kv, batch_idx] + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] + mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv] gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) - mV_cur = mV[None, None, head_idx_kv, batch_idx] gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) - mQ_cur = mQ[None, None, head_idx, batch_idx] + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2, padded=True)[ + None, head_idx + ] + mdO_cur = seqlen.offset_batch_Q(mdO, batch_idx, dim=3)[None, None, head_idx] + mdPsum_cur = seqlen.offset_batch_Q(mdPsum, batch_idx, dim=2, padded=True)[ + None, head_idx + ] gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (None, 0)) - mdO_cur = mdO[None, None, head_idx, batch_idx] gdO = cute.local_tile(mdO_cur, (self.tile_m, self.tile_hdimv), (None, 0)) - mLSE_cur = mLSE[None, head_idx, batch_idx] gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,)) - mdPsum_cur = mdPsum[None, head_idx, batch_idx] gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,)) load_K, _, _ = copy_utils.tma_get_copy_fn( @@ -805,7 +916,10 @@ def load( if const_expr(not self.use_block_sparsity): total_m_block_cnt = m_block_max - m_block_min - process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) + or m_block_min < m_block_max + ) else: total_m_block_cnt = get_total_q_block_count_bwd( blocksparse_tensors, @@ -825,6 +939,8 @@ def load( ) load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q)) load_Q(first_m_block, producer_state=producer_state_Q) + # Wait for bwd preprocess to finish writing LSE and dPsum + cute.arch.griddepcontrol_wait() load_LSE(first_m_block, producer_state=producer_state_Q) producer_state_dO_cur = ( producer_state_dO @@ -993,8 +1109,6 @@ def mma( tma_atom_dK: cute.CopyAtom, tma_atom_dV: cute.CopyAtom, r2s_tiled_copy_dQaccum: cute.TiledCopy, - r2s_tiled_copy_dKVaccum: cute.TiledCopy, - sdKVaccum_layout: cute.Layout, softmax_scale_log2: Float32, softmax_scale: Float32, block_info: BlockInfo, @@ -1005,16 +1119,20 @@ def mma( fastdiv_mods=(None, None), blocksparse_tensors: Optional[BlockSparseTensors] = None, qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None, + is_dQ_wg: cutlass.Constexpr[bool] = True, ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) warp_group_thread_layout = cute.make_layout( - self.num_mma_warp_groups, stride=self.num_threads_per_warp_group + self.num_wg_mma, stride=self.num_threads_per_warp_group ) thr_mma_SdP = tiled_mma_SdP.get_slice(tidx) wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx)) wg_mma_dK = tiled_mma_dK.get_slice(warp_group_thread_layout(warp_group_idx)) wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx)) - wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_dQ = None + if const_expr(is_dQ_wg): + wg_idx_dQ = warp_group_idx if const_expr(self.num_wg_dQ > 1) else 0 + wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(wg_idx_dQ)) # S = Q @ K.T shape_mnk_S = (self.tile_m, self.tile_n, self.tile_hdim) _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC( @@ -1060,24 +1178,44 @@ def mma( # dQ = dS @ K sKt = layout_utils.transpose_view(sK) shape_mnk_dQ = (self.tile_m, self.tile_hdim, self.tile_n) - _, tdQrdS, tdQrKt = sm90_utils.partition_fragment_ABC( - wg_mma_dQ, shape_mnk_dQ, sdS, sKt, swap_AB=self.dQ_swapAB - ) - mma_dsk_fn = partial( - gemm_zero_init, tiled_mma_dQ, shape_mnk_dQ[:2], tdQrdS, tdQrKt, swap_AB=self.dQ_swapAB - ) + mma_dsk_fn = None + if const_expr(is_dQ_wg): + _, tdQrdS, tdQrKt = sm90_utils.partition_fragment_ABC( + wg_mma_dQ, shape_mnk_dQ, sdS, sKt, swap_AB=self.dQ_swapAB + ) + mma_dsk_fn = partial( + gemm_zero_init, + tiled_mma_dQ, + shape_mnk_dQ[:2], + tdQrdS, + tdQrKt, + swap_AB=self.dQ_swapAB, + ) - # Smem copy atom tiling - smem_copy_atom_PdS = copy_utils.get_smem_store_atom( - self.arch, self.dtype, transpose=self.SdP_swapAB - ) - smem_thr_copy_PdS = cute.make_tiled_copy_C(smem_copy_atom_PdS, tiled_mma_SdP).get_slice( - tidx - ) - tPsP = None + # Smem copy atom tiling for P/dS R2S + copy_P_r2s = None + mms_PdS = self.tile_n // (self.num_wg_mma // self.AtomLayoutMSdP) if const_expr(sP is not None): - tPsP = smem_thr_copy_PdS.partition_D(sP if const_expr(not self.SdP_swapAB) else sPt) - tdSsdS = smem_thr_copy_PdS.partition_D(sdS if const_expr(not self.SdP_swapAB) else sdSt) + sP_cpy = sP if const_expr(not self.SdP_swapAB) else sPt + copy_P_r2s, _, _ = copy_utils.get_smem_store_C( + tiled_mma_SdP, + sP_cpy, + tidx, + self.arch, + transpose=self.SdP_swapAB, + position_independent=True, + major_mode_size=mms_PdS, + ) + sdS_cpy = sdS if const_expr(not self.SdP_swapAB) else sdSt + copy_dS_r2s, _, _ = copy_utils.get_smem_store_C( + tiled_mma_SdP, + sdS_cpy, + tidx, + self.arch, + transpose=self.SdP_swapAB, + position_independent=True, + major_mode_size=mms_PdS, + ) tLSEsLSE = layout_utils.mma_partition_C_vec( sLSE, thr_mma_SdP, expand_shape=self.tile_n, is_colvec=not self.SdP_swapAB @@ -1085,9 +1223,21 @@ def mma( tLSEsdPsum = layout_utils.mma_partition_C_vec( sdPsum, thr_mma_SdP, expand_shape=self.tile_n, is_colvec=not self.SdP_swapAB ) - - smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx) - tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum) + # When shuffle=True, rows are distributed across 8 quads (4 threads each) within a warp. + # Each thread loads only ceil(num_rows/8) values; + shfl_copy = copy_utils.tiled_copy_1d(sLSE.element_type, num_threads=8, num_copy_elems=2) + if const_expr(self.shuffle_LSE): + tLSEsLSE = shfl_copy.get_slice(cute.arch.lane_idx() // 4).partition_S(tLSEsLSE) + # ((2, 1), 1, 2) -> (((2, 1), 1), 2) + tLSEsLSE = cute.group_modes(tLSEsLSE, 0, 2) + if const_expr(self.shuffle_dPsum): + tLSEsdPsum = shfl_copy.get_slice(cute.arch.lane_idx() // 4).partition_S(tLSEsdPsum) + tLSEsdPsum = cute.group_modes(tLSEsdPsum, 0, 2) + + tdQsdQaccum = None + if const_expr(is_dQ_wg): + smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx) + tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum) PdS_barrier = cutlass.pipeline.NamedBarrier( barrier_id=int(NamedBarrierBwd.PdS), num_threads=self.num_mma_threads @@ -1115,19 +1265,18 @@ def mma( mma_pdo_fn=mma_pdo_fn, mma_dsq_fn=mma_dsq_fn, mma_dsk_fn=mma_dsk_fn, + copy_P_r2s=copy_P_r2s, + copy_dS_r2s=copy_dS_r2s, pipeline_Q=pipeline_Q, pipeline_dO=pipeline_dO, tLSEsLSE=tLSEsLSE, tLSEsdPsum=tLSEsdPsum, - tPsP=tPsP, - tdSsdS=tdSsdS, tdQsdQaccum=tdQsdQaccum, - smem_thr_copy_PdS=smem_thr_copy_PdS, - smem_thr_copy_dQaccum=smem_thr_copy_dQaccum, softmax_scale_log2=softmax_scale_log2, PdS_barrier=PdS_barrier, # acc_dV=acc_dV, # acc_dK=acc_dK, + is_dQ_wg=is_dQ_wg, ) consumer_state_Q = cutlass.pipeline.make_pipeline_state( @@ -1159,7 +1308,10 @@ def mma( m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) if const_expr(not self.use_block_sparsity): - process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) + or m_block_min < m_block_max + ) else: total_m_block_cnt = get_total_q_block_count_bwd( blocksparse_tensors, @@ -1234,8 +1386,6 @@ def mma( tma_atom_dV, tiled_mma_dK, tiled_mma_dV, - r2s_tiled_copy_dKVaccum, - sdKVaccum_layout, tidx, n_block, head_idx, @@ -1243,8 +1393,8 @@ def mma( qhead_per_kvhead_divmod, ) else: - # Block sparsity: KV tile with zero Q blocks produces no dK/dV; write zeros. - if const_expr(self.use_block_sparsity): + # KV tile with zero Q blocks produces no dK/dV; write zeros. + if const_expr(self.use_block_sparsity or self.is_local or self.is_varlen_q): acc_dK.fill(0.0) acc_dV.fill(0.0) self.epilogue_dKV( @@ -1259,8 +1409,6 @@ def mma( tma_atom_dV, tiled_mma_dK, tiled_mma_dV, - r2s_tiled_copy_dKVaccum, - sdKVaccum_layout, tidx, n_block, head_idx, @@ -1271,6 +1419,26 @@ def mma( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + if warp_idx == 4: + cute.arch.cp_async_bulk_wait_group(0, read=True) + + @staticmethod + @cute.jit + def _get_stat(tSrS: cute.Tensor, row: Int32, lane: Int32, shuffle: bool) -> Float32: + """Retrieve the statistic for a given accumulator row. + + When shuffle=False, direct register indexing. + When shuffle=True, warp shuffle from the thread group that holds the value. + """ + if const_expr(not shuffle): + return tSrS[row] + # tSrS: (((2, 1), 1), 1)), distributed across 8 threads in the warp + vecsize = cute.size(tSrS, mode=[0, 0]) # 2 + idx0, off, idx1 = cute.idx2crd(row, (vecsize, 8, cute.shape(tSrS, mode=[0, 1]))) + # register index: 0, 1, 0, 1, ..., 2, 3, 2, 3, ... + return utils.shuffle_sync(tSrS[idx0 + idx1 * vecsize], offset=off * 4 + (lane % 4)) + @cute.jit def mma_one_m_block( self, @@ -1283,24 +1451,23 @@ def mma_one_m_block( mma_pdo_fn: Callable, mma_dsq_fn: Callable, mma_dsk_fn: Callable, + copy_P_r2s: Optional[Callable], + copy_dS_r2s: Callable, pipeline_Q: cutlass.pipeline.PipelineAsync, pipeline_dO: cutlass.pipeline.PipelineAsync, tLSEsLSE: cute.Tensor, tLSEsdPsum: cute.Tensor, - tPsP: Optional[cute.Tensor], - tdSsdS: Optional[cute.Tensor], - tdQsdQaccum: cute.Tensor, - smem_thr_copy_PdS: cute.TiledCopy, - smem_thr_copy_dQaccum: cute.TiledCopy, + tdQsdQaccum: Optional[cute.Tensor], softmax_scale_log2: Float32, PdS_barrier: cutlass.pipeline.NamedBarrier, + is_dQ_wg: cutlass.Constexpr[bool] = True, mask_fn: Optional[Callable] = None, score_mod_fn: Optional[Callable] = None, score_mod_bwd_fn: Optional[Callable] = None, dKV_accumulate: Boolean = True, ): consumer_state_dO_cur = ( - consumer_state_dO if const_expr(self.Q_stage == self.dO_stage) else consumer_state_Q + consumer_state_Q if const_expr(self.Q_stage == self.dO_stage) else consumer_state_dO ) smem_idx_Q = consumer_state_Q.index smem_idx_dO = consumer_state_dO_cur.index if const_expr(self.dO_stage > 1) else 0 @@ -1308,6 +1475,7 @@ def mma_one_m_block( # (1) [GEMM 1] S = Q @ K^T pipeline_Q.consumer_wait(consumer_state_Q, pipeline_Q.consumer_try_wait(consumer_state_Q)) acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1) + # If shuffle_LSE, OOB reads are OK since sLSE is already padded tLSErLSE = copy_utils.load_s2r(tLSEsLSE[None, smem_idx_Q]) # (2) [GEMM 2] dP = dO @ V.T pipeline_dO.consumer_wait( @@ -1326,10 +1494,12 @@ def mma_one_m_block( if cutlass.const_expr(mask_fn is not None): mask_fn(acc_S, m_block=m_block) acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.SdP_swapAB) + lane_idx = cute.arch.lane_idx() for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])): + lse_val = self._get_stat(tLSErLSE, r, lane_idx, shuffle=self.shuffle_LSE) for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True): acc_S_mn[r, c] = cute.math.exp2( - acc_S_mn[r, c] * softmax_scale_log2 - tLSErLSE[r], fastmath=True + acc_S_mn[r, c] * softmax_scale_log2 - lse_val, fastmath=True ) tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO]) @@ -1340,15 +1510,15 @@ def mma_one_m_block( # sync to ensure P has already been used in the previous iteration before overwriting if const_expr(self.PdS_stage == 1): PdS_barrier.arrive_and_wait() - tPrP = smem_thr_copy_PdS.retile(tdVrP) - cute.copy(smem_thr_copy_PdS, tPrP, tPsP[None, None, None, smem_idx_PdS]) + copy_P_r2s(tdVrP, dst_idx=smem_idx_PdS) # (4) [Pointwise 2] dS = P*(dP-dPsum) warpgroup.wait_group(0) acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP, transpose=self.SdP_swapAB) for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])): + dpsum_val = self._get_stat(tLSErdPsum, r, lane_idx, shuffle=self.shuffle_dPsum) for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True): - acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r]) + acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - dpsum_val) if const_expr(self.score_mod_bwd is not None): score_mod_bwd_fn(acc_dP, acc_S_pre, m_block=m_block) @@ -1367,8 +1537,7 @@ def mma_one_m_block( PdS_barrier.arrive_and_wait() # R2S for dS - tdSrdS = smem_thr_copy_PdS.retile(tdKrdS) - cute.copy(smem_thr_copy_PdS, tdSrdS, tdSsdS[None, None, None, smem_idx_PdS]) + copy_dS_r2s(tdKrdS, dst_idx=smem_idx_PdS) # (5) [GEMM 3] dV += P.T @ dO if const_expr(not self.mma_dkv_is_rs): @@ -1381,36 +1550,50 @@ def mma_one_m_block( # smem fence to make sure sdS is written before it's read by WGMMA cute.arch.fence_view_async_shared() PdS_barrier.arrive_and_wait() - # (6) [GEMM 4] dQ = dS @ K - acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV) - pipeline_dO.consumer_release(consumer_state_dO_cur) # release dO as dV mma is done - # (7) [GEMM 5] dK += dS.T @ Q - if const_expr(not self.mma_dkv_is_rs): - mma_dsq_fn( - A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1 - ) - else: - mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dQ) + if const_expr(is_dQ_wg): + # (6) [GEMM 4] dQ = dS @ K + acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1) + pipeline_dO.consumer_release(consumer_state_dO_cur) # release dO as dV mma is done - cute.arch.barrier( - barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, - number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, - ) - tdQrdQaccum_flat = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape)) - cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum) - cute.arch.fence_view_async_shared() - cute.arch.barrier_arrive( - barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, - number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, - ) + # (7) [GEMM 5] dK += dS.T @ Q + if const_expr(not self.mma_dkv_is_rs): + mma_dsq_fn( + A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1 + ) + else: + mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1) - warpgroup.wait_group(0) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dK) - pipeline_Q.consumer_release(consumer_state_Q) - # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block = {}, after pipeline_Q consumer release", cute.arch.thread_idx()[0], m_block) + # dQ R2S: wait for dQaccum_store to free the smem buffer, then write dQ to smem + # When dQ_single_wg, only WG0 enters here so warp_group_idx == 0 + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + tdQrdQaccum_flat = cute.make_tensor( + acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape) + ) + cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum) + cute.arch.fence_view_async_shared() + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + + warpgroup.wait_group(0) + pipeline_Q.consumer_release(consumer_state_Q) + else: + # dQ_single_wg: WG1 skips dQ, only does dV wait + dK + # (7) [GEMM 5] dK += dS.T @ Q + if const_expr(not self.mma_dkv_is_rs): + mma_dsq_fn( + A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1 + ) + else: + mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1) + pipeline_dO.consumer_release(consumer_state_dO_cur) + warpgroup.wait_group(0) + pipeline_Q.consumer_release(consumer_state_Q) consumer_state_Q.advance() consumer_state_dO.advance() @@ -1430,37 +1613,24 @@ def epilogue_dKV( tma_atom_dV: cute.CopyAtom, tiled_mma_dK: cute.TiledMma, tiled_mma_dV: cute.TiledMma, - r2s_tiled_copy_dKVaccum: cute.TiledCopy, - sdKVaccum_layout: cute.Layout, tidx: Int32, n_block: Int32, head_idx: Int32, batch_idx: Int32, qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None, ): + epi_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierBwd.Epilogue), num_threads=self.num_mma_threads + ) warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) if const_expr(self.qhead_per_kvhead == 1): - rdV = cute.make_fragment_like(acc_dV, self.dtype) - rdV.store(acc_dV.load().to(self.dtype)) - rdK = utils.cvt_f16(acc_dK, self.dtype) - - cute.arch.barrier( - barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads - ) - - smem_copy_atom_dKV = cute.make_copy_atom( - cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=self.dKV_swapAB, num_matrices=4), - self.dtype, - ) - smem_thr_copy_dK = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dK).get_slice( - tidx - ) - smem_thr_copy_dV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dV).get_slice( - tidx - ) - mdV_cur = mdV[None, None, head_idx, batch_idx] - mdK_cur = mdK[None, None, head_idx, batch_idx] + mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3, ragged=self.varlen_k)[ + None, None, head_idx + ] + mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3, ragged=self.varlen_k)[ + None, None, head_idx + ] gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) store_dK, _, _ = copy_utils.tma_get_copy_fn( @@ -1469,99 +1639,100 @@ def epilogue_dKV( store_dV, _, _ = copy_utils.tma_get_copy_fn( tma_atom_dV, 0, cute.make_layout(1), sV, gdV, single_stage=True ) - - taccdVrdV = smem_thr_copy_dV.retile(rdV) sdV = sV if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sV) - taccdVsdV = smem_thr_copy_dV.partition_D(sdV) - cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) - cute.arch.fence_view_async_shared() - cute.arch.barrier( - barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads + sdK = sK if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sK) + copy_dV_r2s, _, _ = copy_utils.get_smem_store_C( + tiled_mma_dV, + sdV, + tidx, + self.arch, + transpose=self.dKV_swapAB, + position_independent=True, ) + copy_dK_r2s, _, _ = copy_utils.get_smem_store_C( + tiled_mma_dK, + sdK, + tidx, + self.arch, + transpose=self.dKV_swapAB, + position_independent=True, + ) + cute.arch.cp_async_bulk_wait_group(1, read=True) + epi_barrier.arrive_and_wait() + copy_dV_r2s(acc_dV, dst_idx=None) + cute.arch.fence_view_async_shared() + epi_barrier.arrive_and_wait() if warp_idx == 4: store_dV() - taccdKrdK = smem_thr_copy_dK.retile(rdK) - sdK = sK if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sK) - taccdKsdK = smem_thr_copy_dK.partition_D(sdK) - cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(1, read=True) + epi_barrier.arrive_and_wait() + copy_dK_r2s(acc_dK, dst_idx=None) cute.arch.fence_view_async_shared() - cute.arch.barrier( - barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads - ) + epi_barrier.arrive_and_wait() if warp_idx == 4: store_dK() cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) else: + sdKaccum_shape0 = self.tile_n * self.tile_hdim // self.num_wg_mma + sdVaccum_shape0 = self.tile_n * self.tile_hdimv // self.num_wg_mma + sdKaccum_layout = cute.make_layout((sdKaccum_shape0, self.num_wg_mma)) + sdVaccum_layout = cute.make_layout((sdVaccum_shape0, self.num_wg_mma)) head_idx_kv = head_idx // qhead_per_kvhead_divmod - - mdKaccum_cur = mdK[None, head_idx_kv, batch_idx] + mdKaccum_cur = seqlen.offset_batch_K( + mdK, batch_idx, dim=2, padded=True, multiple=self.tile_hdim + )[None, head_idx_kv] + mdVaccum_cur = seqlen.offset_batch_K( + mdV, batch_idx, dim=2, padded=True, multiple=self.tile_hdimv + )[None, head_idx_kv] gdKaccum_ = cute.local_tile(mdKaccum_cur, (self.tile_n * self.tile_hdim,), (n_block,)) - gdKaccum = cute.flat_divide( - gdKaccum_, (self.tile_n * self.tile_hdim // self.num_mma_warp_groups,) - ) - - mdVaccum_cur = mdV[None, head_idx_kv, batch_idx] + gdKaccum = cute.flat_divide(gdKaccum_, (sdKaccum_shape0,)) gdVaccum_ = cute.local_tile(mdVaccum_cur, (self.tile_n * self.tile_hdimv,), (n_block,)) - gdVaccum = cute.flat_divide( - gdVaccum_, (self.tile_n * self.tile_hdimv // self.num_mma_warp_groups,) + gdVaccum = cute.flat_divide(gdVaccum_, (sdVaccum_shape0,)) + # These two overlap each other + sVaccum_ptr = cute.recast_ptr(sV.iterator, dtype=Float32) + sdKaccum = cute.make_tensor(sVaccum_ptr, sdKaccum_layout) + sdVaccum = cute.make_tensor(sVaccum_ptr, sdVaccum_layout) + tiled_copy_dKVaccum_r2s = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128), + cute.make_layout((self.num_threads_per_warp_group, self.num_wg_mma)), + cute.make_layout(128 // Float32.width), ) - - sdKVaccum = cute.make_tensor( - cute.recast_ptr(sV.iterator, dtype=Float32), - sdKVaccum_layout, - ) - - smem_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_slice(tidx) - tdKsdKVaccum = smem_thr_copy_dKVaccum.partition_D(sdKVaccum) - - cute.arch.barrier( - barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads - ) - - tdKrdKaccum_flat = cute.make_tensor( - acc_dK.iterator, cute.make_layout(tdKsdKVaccum.shape) - ) - cute.autovec_copy(tdKrdKaccum_flat, tdKsdKVaccum) + thr_copy_dKVaccum_r2s = tiled_copy_dKVaccum_r2s.get_slice(tidx) + tdKsdKaccum = thr_copy_dKVaccum_r2s.partition_D(sdKaccum) + tdVsdVaccum = thr_copy_dKVaccum_r2s.partition_D(sdVaccum) + + cute.arch.cp_async_bulk_wait_group(0, read=True) + epi_barrier.arrive_and_wait() + tdKrdKaccum_flat = cute.make_tensor(acc_dK.iterator, tdKsdKaccum.shape) + cute.autovec_copy(tdKrdKaccum_flat, tdKsdKaccum) cute.arch.fence_view_async_shared() - cute.arch.barrier( - barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads - ) - + epi_barrier.arrive_and_wait() if warp_idx == 4: with cute.arch.elect_one(): - for wg_idx in cutlass.range_constexpr(self.num_mma_warp_groups): + for wg_idx in cutlass.range_constexpr(self.num_wg_mma): copy_utils.cpasync_reduce_bulk_add_f32( - sdKVaccum[None, wg_idx].iterator, + sdKaccum[None, wg_idx].iterator, gdKaccum[None, wg_idx].iterator, - self.tma_copy_bytes["dKacc"] // self.num_mma_warp_groups, + self.tma_copy_bytes["dKacc"] // self.num_wg_mma, ) cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) - cute.arch.barrier( - barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads - ) - - tdVrdVaccum_flat = cute.make_tensor( - acc_dV.iterator, cute.make_layout(tdKsdKVaccum.shape) - ) - cute.autovec_copy(tdVrdVaccum_flat, tdKsdKVaccum) + cute.arch.cp_async_bulk_wait_group(0, read=True) + epi_barrier.arrive_and_wait() + tdVrdVaccum_flat = cute.make_tensor(acc_dV.iterator, tdVsdVaccum.shape) + cute.autovec_copy(tdVrdVaccum_flat, tdVsdVaccum) cute.arch.fence_view_async_shared() - cute.arch.barrier( - barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads - ) - + epi_barrier.arrive_and_wait() if warp_idx == 4: with cute.arch.elect_one(): - for wg_idx in cutlass.range_constexpr(self.num_mma_warp_groups): + for wg_idx in cutlass.range_constexpr(self.num_wg_mma): copy_utils.cpasync_reduce_bulk_add_f32( - sdKVaccum[None, wg_idx].iterator, + sdVaccum[None, wg_idx].iterator, gdVaccum[None, wg_idx].iterator, - self.tma_copy_bytes["dVacc"] // self.num_mma_warp_groups, + self.tma_copy_bytes["dVacc"] // self.num_wg_mma, ) cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) @cute.jit def dQaccum_store( @@ -1572,21 +1743,45 @@ def dQaccum_store( TileSchedulerCls: cutlass.Constexpr[Callable], SeqlenInfoCls: cutlass.Constexpr[Callable], blocksparse_tensors: Optional[BlockSparseTensors] = None, + mdQ_semaphore: Optional[cute.Tensor] = None, ): + tidx, _, _ = cute.arch.thread_idx() + # warp-local thread index (dQaccum_store runs on warp 1, global tidx 32-63) + warp_local_tidx = tidx % cute.arch.WARP_SIZE + read_flag = const_expr(not self.deterministic) + tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: n_block, head_idx, batch_idx, _ = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] - gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,)) - # (M * K / WG, WG, _) - gdQaccum = cute.flat_divide( - gdQaccum_, (self.tile_m * self.tile_hdim // self.num_mma_warp_groups,) + if const_expr(not seqlen.has_cu_seqlens_q): + mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] + else: + mdQaccum_cur = cute.domain_offset( + (seqlen.padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx] + ) + # ((M * K / num_wg_dQ, num_wg_dQ), num_m_blocks) + gdQaccum = cute.local_tile( + mdQaccum_cur, + ( + cute.make_layout( + (self.tile_m * self.tile_hdim // self.num_wg_dQ, self.num_wg_dQ) + ), + ), + (None,), ) + + if const_expr(mdQ_semaphore is not None): + # mdQ_semaphore is (num_m_blocks, cluster_size, num_head, batch) after transpose + mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) if const_expr(not self.use_block_sparsity): - process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) + or m_block_min < m_block_max + ) loop_count = m_block_max - m_block_min else: total_block_cnt = get_total_q_block_count_bwd( @@ -1605,7 +1800,36 @@ def dQaccum_store( m_block = m_block_min + iter_idx m_block_safe = m_block - for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups): + num_dQ_chunks = self.num_wg_dQ + for warp_group_idx in cutlass.range_constexpr(num_dQ_chunks): + if const_expr(not self.deterministic): + # If deterministic, we already waited at the end of the prev iter + cute.arch.cp_async_bulk_wait_group( + num_dQ_chunks - 1 - warp_group_idx, read=read_flag + ) + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + + cute.arch.WARP_SIZE, + ) + + # Semaphore acquire: wait for prior n_blocks to finish writing this m_block + if const_expr(self.deterministic): + if const_expr(self.spt): + _, n_block_max_for_m_block = block_info.get_n_block_min_max( + seqlen, m_block_safe + ) + lock_value = n_block_max_for_m_block - 1 - n_block + else: + lock_value = n_block + barrier.wait_eq( + mdQ_semaphore_cur[(m_block_safe, None)].iterator, + warp_local_tidx, + 0, # flag_offset + lock_value, + ) + + for warp_group_idx in cutlass.range_constexpr(num_dQ_chunks): cute.arch.barrier( barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, number_of_threads=self.num_threads_per_warp_group @@ -1614,20 +1838,24 @@ def dQaccum_store( with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( sdQaccum[None, warp_group_idx].iterator, - gdQaccum[None, warp_group_idx, m_block_safe].iterator, + gdQaccum[(None, warp_group_idx), m_block_safe].iterator, self.tma_copy_bytes["dQ"], ) cute.arch.cp_async_bulk_commit_group() - for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups): - cute.arch.cp_async_bulk_wait_group( - self.num_mma_warp_groups - 1 - warp_group_idx, read=True - ) - cute.arch.barrier_arrive( - barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, - number_of_threads=self.num_threads_per_warp_group - + cute.arch.WARP_SIZE, + + # Semaphore release: signal that this n_block is done with this m_block + if const_expr(self.deterministic): + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + barrier.arrive_inc( + mdQ_semaphore_cur[(m_block_safe, None)].iterator, + warp_local_tidx, + 0, # flag_offset + 1, ) else: + assert not self.deterministic, ( + "Deterministic not implemented for block-sparse backward" + ) dQaccum_store_block_sparse_bwd_sm90( blocksparse_tensors, batch_idx, @@ -1637,9 +1865,27 @@ def dQaccum_store( gdQaccum, subtile_factor=self.subtile_factor, m_block_max=m_block_max, - num_mma_warp_groups=self.num_mma_warp_groups, + num_dQ_warp_groups=self.num_wg_dQ, num_threads_per_warp_group=self.num_threads_per_warp_group, tma_copy_bytes_dQ=self.tma_copy_bytes["dQ"], ) + + # For local masking + deterministic (non-spt): signal remaining m_blocks + # that this n_block won't visit, so they don't deadlock waiting. + if const_expr( + self.deterministic and not self.spt and block_info.window_size_left is not None + ): + m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m) + for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1): + barrier.arrive_inc( + mdQ_semaphore_cur[(m_block, None)].iterator, + warp_local_tidx, + 0, # flag_offset + 1, + ) + tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() + + if const_expr(not self.deterministic): + cute.arch.cp_async_bulk_wait_group(0, read=True) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 9eaccda41bc..4d47fab109f 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -15,42 +15,28 @@ import cutlass import cutlass.cute as cute from cutlass import Constexpr, Float32, Int32, const_expr, Boolean -from cutlass.cute.nvgpu import cpasync, warp, warpgroup +from cutlass.cute.nvgpu import cpasync, warp import cutlass.utils as utils_basic -from cutlass.utils import LayoutEnum -import cutlass.utils.hopper_helpers as sm90_utils_basic +from cutlass.base_dsl.arch import Arch +from cutlass.cutlass_dsl import BaseDSL from quack import copy_utils from quack import layout_utils -from quack import sm90_utils from flash_attn.cute import ampere_helpers as sm80_utils from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned from flash_attn.cute import utils from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import Softmax, apply_score_mod_inner +from flash_attn.cute.softmax import Softmax from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( - produce_block_sparse_loads, - consume_block_sparse_loads, -) -from flash_attn.cute import pipeline from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute.named_barrier import NamedBarrierFwd -from flash_attn.cute.tile_scheduler import ( - TileSchedulerArguments, - SingleTileScheduler, - SingleTileLPTScheduler, - SingleTileVarlenScheduler, - ParamsBase, -) -from cutlass.cute import FastDivmodDivisor +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments class FlashAttentionForwardBase: - arch: int = 80 def __init__( self, @@ -113,10 +99,15 @@ def __init__( self.score_mod = score_mod self.mask_mod = mask_mod self.qk_acc_dtype = Float32 - if const_expr(has_aux_tensors): - self.vec_size: cutlass.Constexpr = 1 - else: - self.vec_size: cutlass.Constexpr = 2 + self.vec_size: cutlass.Constexpr = getattr( + score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 + ) + if self.vec_size > 2: + raise ValueError( + f"score_mod vec_size {self.vec_size} not supported on Sm80/90/120 " + "due to accumulator thread ownership pattern." + ) + self.arch = BaseDSL._get_dsl().get_arch_enum() @staticmethod def can_implement( @@ -319,7 +310,8 @@ def __call__( mO: cute.Tensor, mLSE: Optional[cute.Tensor], softmax_scale: Float32, - stream: cuda.CUstream, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): """Configures and launches the flash attention kernel. @@ -352,7 +344,7 @@ def epilogue( cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads ) - smem_copy_atom_O = utils.get_smem_store_atom(self.arch, self.dtype) + smem_copy_atom_O = utils.get_smem_store_atom(self.arch.major * 10 + self.arch.minor, self.dtype) smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx) taccOrO = smem_thr_copy_O.retile(rO) taccOsO = smem_thr_copy_O.partition_D(sO) @@ -367,11 +359,7 @@ def epilogue( # Write LSE from rmem -> gmem if const_expr(mLSE is not None): - if const_expr(not seqlen.has_cu_seqlens_q): - mLSE_cur = mLSE[None, head_idx, batch_idx] - else: - offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) - mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) + mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx] if const_expr(not self.pack_gqa): gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,)) gLSE_expanded_layout = cute.append( @@ -385,7 +373,7 @@ def epilogue( t0accOcO = layout_utils.reshape_acc_to_mn(thr_mma.get_slice(0).partition_C(cO)) # Only the thread corresponding to column 0 writes out the lse to gmem if taccOcO[0][1] == 0: - for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])): + for m in cutlass.range(cute.size(taccOgLSE.shape[1]), unroll_full=True): if ( t0accOcO[m, 0][0] < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0] @@ -394,11 +382,8 @@ def epilogue( else: pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q) - if const_expr(not seqlen.has_cu_seqlens_q): - mO_cur = mO[None, None, head_idx, batch_idx] - else: - offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) - mO_cur = cute.domain_offset((offset, 0), mO[None, None, head_idx]) + ragged = self.use_tma_O and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q) + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=ragged)[None, None, head_idx] # thr_mma = tiled_mma.get_slice(tidx) # taccOgO = thr_mma.partition_C(gO) # cute.autovec_copy(rO, taccOgO) @@ -635,12 +620,19 @@ def __call__( mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], - stream: cuda.CUstream, - softmax_scale: Optional[Float32] = None, + softmax_scale: Float32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, window_size_left: Optional[Int32] = None, window_size_right: Optional[Int32] = None, learnable_sink: Optional[cute.Tensor] = None, + blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors=None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): """Configures and launches the flash attention kernel. @@ -649,7 +641,7 @@ def __call__( """ assert learnable_sink is None, "Learnable sink is not supported in this kernel" self._check_type( - *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE)) + *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)) ) tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() self.num_mma_threads = tiled_mma_pv.size @@ -657,41 +649,54 @@ def __call__( self.num_Q_load_threads = self.num_threads self.num_epilogue_threads = self.num_threads # self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None - self.use_tma_O = self.arch >= 90 + self.use_tma_O = self.arch >= Arch.sm_90 self._setup_attributes() SharedStorage = self._get_shared_storage_cls() mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)] - mQ, mK, mV, mO = [ - cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0])) - for t in (mQ, mK, mV, mO) + # Layout permutation: 4D non-varlen vs 3D varlen + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mQ, mO = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) + for t in (mQ, mO) ] - mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=[2, 1, 0])) - # grid_dim: (m_block, num_head, batch_size) - grid_dim = ( - cute.ceil_div(mQ.shape[0], self.tile_m), - cute.size(mQ.shape[2]), - cute.size(mQ.shape[3]), - ) - LOG2_E = math.log2(math.e) - if const_expr(self.score_mod is None): - softmax_scale_log2 = Float32(softmax_scale * LOG2_E) - softmax_scale = None + mK, mV = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose)) + for t in (mK, mV) + ] + if const_expr(mLSE is not None): + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) + # TileScheduler for varlen, simple grid for non-varlen + if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + TileScheduler = SingleTileVarlenScheduler else: - # NB: If a user passes in a score mod, we want to apply the score-mod in the sm_scaled qk - # But in the original base 10. We hijack softmax_scale_log2 to just be the change of base - # and correctly apply the softmax_scale prior to score_mod in the softmax step - softmax_scale_log2 = Float32(LOG2_E) - softmax_scale = Float32(softmax_scale) - - fastdiv_mods = None - if const_expr(aux_tensors is not None): - seqlen_q = cute.size(mQ.shape[0]) // ( - self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 - ) - seqlen_k = cute.size(mK.shape[0]) - seqlen_q_divmod = FastDivmodDivisor(seqlen_q) - seqlen_k_divmod = FastDivmodDivisor(seqlen_k) - fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + TileScheduler = SingleTileScheduler + num_batch = ( + mCuSeqlensQ.shape[0] - 1 + if const_expr(mCuSeqlensQ is not None) + else mQ.shape[3] + ) + tile_sched_args = TileSchedulerArguments( + num_block=cute.ceil_div(mQ.shape[0], self.tile_m), + num_head=cute.size(mQ.shape[2]), + num_batch=num_batch, + num_splits=1, + seqlen_k=0, + headdim=mQ.shape[1], + headdim_v=mV.shape[1], + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=(self.tile_m, self.tile_n), + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + ) + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod) + fastdiv_mods = utils.compute_fastdiv_mods(mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors) self.kernel( mQ, @@ -699,6 +704,10 @@ def __call__( mV, mO, mLSE, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, softmax_scale_log2, softmax_scale, window_size_left, @@ -715,6 +724,8 @@ def __call__( tiled_mma_qk, tiled_mma_pv, SharedStorage, + tile_sched_params, + TileScheduler, aux_tensors, fastdiv_mods, ).launch( @@ -732,6 +743,10 @@ def kernel( mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], softmax_scale_log2: Float32, softmax_scale: Optional[Float32], window_size_left: Optional[Int32], @@ -748,12 +763,17 @@ def kernel( tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, SharedStorage: cutlass.Constexpr, + tile_sched_params, + TileScheduler: cutlass.Constexpr[Callable], aux_tensors=None, fastdiv_mods=None, ): # Thread index, block index tidx, _, _ = cute.arch.thread_idx() - m_block, num_head, batch_size = cute.arch.block_idx() + + tile_scheduler = TileScheduler.create(tile_sched_params) + work_tile = tile_scheduler.initial_work_tile_info() + m_block, num_head, batch_size, _ = work_tile.tile_idx block_info = BlockInfo( self.tile_m, @@ -765,13 +785,21 @@ def kernel( window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) - seqlen = SeqlenInfoQK.create(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0]) + seqlen = SeqlenInfoQK.create( + batch_idx=batch_size, + seqlen_q_static=mQ.shape[0], + seqlen_k_static=mK.shape[0], + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + ) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - # TODO: return early if n_block_max == 0 - # if self.is_causal: - # if n_block_max <= 0: - # return - n_block = n_block_max - 1 + # For varlen, wasted grid tiles (where batch_idx >= num_batch) will have + # seqlen_q=seqlen_k=0 and n_block_max=0. Clamp to 0 so we don't use a + # negative block index for K/V loads; the load/store predicates already + # guard all memory accesses when seqlen is 0. + n_block = cutlass.max(n_block_max - 1, 0) # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. @@ -779,10 +807,20 @@ def kernel( blkQ_shape = (self.tile_m, self.tile_hdim) blkK_shape = (self.tile_n, self.tile_hdim) blkV_shape = (self.tile_n, self.tile_hdimv) - gQ = cute.local_tile(mQ[None, None, num_head, batch_size], blkQ_shape, (m_block, 0)) num_head_kv = num_head // self.qhead_per_kvhead - gK = cute.local_tile(mK[None, None, num_head_kv, batch_size], blkK_shape, (None, 0)) - gV = cute.local_tile(mV[None, None, num_head_kv, batch_size], blkV_shape, (None, 0)) + if const_expr(not seqlen.has_cu_seqlens_q): + mQ_cur = mQ[None, None, num_head, batch_size] + else: + mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, num_head]) + if const_expr(not seqlen.has_cu_seqlens_k): + mK_cur = mK[None, None, num_head_kv, batch_size] + mV_cur = mV[None, None, num_head_kv, batch_size] + else: + mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, num_head_kv]) + mV_cur = cute.domain_offset((seqlen.offset_k, 0), mV[None, None, num_head_kv]) + gQ = cute.local_tile(mQ_cur, blkQ_shape, (m_block, 0)) + gK = cute.local_tile(mK_cur, blkK_shape, (None, 0)) + gV = cute.local_tile(mV_cur, blkV_shape, (None, 0)) # /////////////////////////////////////////////////////////////////////////////// # Get shared memory buffer @@ -954,18 +992,20 @@ def preprocess_Q(): mask = AttentionMask( self.tile_m, self.tile_n, - seqlen.seqlen_q, - seqlen.seqlen_k, + seqlen, window_size_left, window_size_right, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) mask_fn = partial( mask.apply_mask, + batch_idx=batch_size, + head_idx=num_head, m_block=m_block, thr_mma=thr_mma_qk, mask_causal=self.is_causal, mask_local=self.is_local, + aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods if const_expr(self.mask_mod is not None) else None, ) @@ -977,8 +1017,8 @@ def preprocess_Q(): smem_pipe_read, smem_pipe_write, is_first_n_block=True, - check_inf=True, - mask_fn=partial(mask_fn, mask_seqlen=True), + seqlen=seqlen, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) @@ -993,15 +1033,17 @@ def preprocess_Q(): n_block, smem_pipe_read, smem_pipe_write, - check_inf=True, - mask_fn=partial(mask_fn, mask_seqlen=False), + seqlen=seqlen, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # The remaining iterations have no masking for n_tile in cutlass.range(n_block, unroll=1): compute_one_n_block( - n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True + n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, + seqlen=seqlen, is_first_n_block=False, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False) ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) @@ -1145,1282 +1187,9 @@ def load_K_next(): # load_K_next() -class FlashAttentionForwardSm90(FlashAttentionForwardBase): - arch = 90 - - def __init__( - self, - *args, - intra_wg_overlap: bool = True, - mma_pv_is_rs: bool = True, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.intra_wg_overlap = intra_wg_overlap - self.mma_pv_is_rs = mma_pv_is_rs - self.buffer_align_bytes = 1024 - - def _get_smem_layout_atom(self): - sQ_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim), - self.dtype, - ) - sK_layout_atom = sQ_layout_atom - sV_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_basic.get_smem_layout_atom( - LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv - ), - self.dtype, - ) - sO_layout_atom = sV_layout_atom - if not self.mma_pv_is_rs: - sP_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_basic.get_smem_layout_atom( - LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n - ), - self.dtype, - ) - else: - sP_layout_atom = None - return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom - - def _get_tiled_mma(self): - tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma( - self.dtype, - self.dtype, - warpgroup.OperandMajorMode.K, - warpgroup.OperandMajorMode.K, - Float32, - atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 - tiler_mn=(64, self.tile_n), - ) - tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma( - self.dtype, - self.dtype, - warpgroup.OperandMajorMode.K, - warpgroup.OperandMajorMode.MN, - Float32, - atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 - tiler_mn=(64, self.tile_hdimv), - a_source=warpgroup.OperandSource.RMEM - if self.mma_pv_is_rs - else warpgroup.OperandSource.SMEM, - ) - return tiled_mma_qk, tiled_mma_pv - - def _get_shared_storage_cls(self): - sQ_struct, sK_struct, sV_struct = [ - cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes] - for layout in (self.sQ_layout, self.sK_layout, self.sV_layout) - - ] - cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) - sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] - cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0 - sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024] - # 1 for Q, 1 for O, self.num_stages*2 for K, self.num_stages*2 for V, - mbar_ptr_QO_struct = cute.struct.MemRange[cutlass.Int64, 2] - mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] - mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] - - @cute.struct - class SharedStorageQKV: - mbar_ptr: mbar_ptr_QO_struct - mbar_ptr_K: mbar_ptr_K_struct - mbar_ptr_V: mbar_ptr_V_struct - sV: sV_struct - sQ: sQ_struct - sK: sK_struct - sP: sP_struct - - @cute.struct - class SharedStorageSharedQV: - mbar_ptr: mbar_ptr_QO_struct - mbar_ptr_K: mbar_ptr_K_struct - mbar_ptr_V: mbar_ptr_V_struct - sQ: sQV_struct - sK: sK_struct - sP: sP_struct - - return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV - - @cute.jit - def __call__( - self, - mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table - mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table - mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q - mLSE: Optional[cute.Tensor], - softmax_scale: Float32, - stream: cuda.CUstream, - mCuSeqlensQ: Optional[cute.Tensor] = None, - mCuSeqlensK: Optional[cute.Tensor] = None, - mSeqUsedQ: Optional[cute.Tensor] = None, - mSeqUsedK: Optional[cute.Tensor] = None, - mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) - window_size_left: Int32 | int | None = None, - window_size_right: Int32 | int | None = None, - learnable_sink: Optional[cute.Tensor] = None, - blocksparse_tensors: Optional[BlockSparseTensors] = None, - aux_tensors: Optional[list] = None, - ): - """Configures and launches the flash attention kernel. - - mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: - (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) - """ - - self._check_type( - *( - t.element_type if t is not None else None - for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK) - ) - ) - - mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)] - QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] - mQ, mO = [layout_utils.select(t, QO_layout_transpose) for t in (mQ, mO)] - KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] - mK, mV = [layout_utils.select(t, KV_layout_transpose) for t in (mK, mV)] - LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] - mLSE = layout_utils.select(mLSE, LSE_layout_transpose) if const_expr(mLSE is not None) else None - - tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() - self.num_mma_threads = tiled_mma_qk.size - self.num_threads_per_warp_group = 128 - self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group - self.num_threads = self.num_threads_per_warp_group * (self.num_mma_warp_groups + 1) - self.num_producer_threads = 32 - self.num_Q_load_threads = self.num_mma_threads # If not TMA_Q, MMA threads load Q - self.num_epilogue_threads = self.num_mma_threads - self.num_mma_regs = ( - 256 - if self.num_mma_warp_groups == 1 - else (240 if self.num_mma_warp_groups == 2 else 160) - ) - self.num_producer_regs = ( - 56 if self.num_mma_warp_groups == 1 else (24 if self.num_mma_warp_groups == 2 else 32) - ) - # self.num_mma_regs = 232 - # self.num_producer_regs = 40 - self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) - - self.use_scheduler_barrier = ( - (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128) - if const_expr(self.intra_wg_overlap) - else (self.num_mma_warp_groups == 2) - ) - self.use_tma_Q = self.arch >= 90 and not ( - self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0 - ) - self.use_tma_O = ( - self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa - ) - # TODO: rescale_O_before_gemm - self._setup_attributes() - # TODO: we prob don't need most of what's in _setup_attributes - self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [ - sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage) - for mX, shape, stage in [ - (mQ, (self.tile_m, self.tile_hdim), None), - (mK, (self.tile_n, self.tile_hdim), self.num_stages), - (mV, (self.tile_n, self.tile_hdimv), self.num_stages), - (mO, (self.tile_m, self.tile_hdimv), None), - ] - ] - self.sP_layout = None - if const_expr(not self.mma_pv_is_rs): - self.sP_layout = sm90_utils.make_smem_layout( - mV.element_type, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n) - ) - - SharedStorage = self._get_shared_storage_cls() - - if const_expr(self.pack_gqa): - shape_Q_packed = ( - (self.qhead_per_kvhead, mQ.shape[0]), - mQ.shape[1], - mK.shape[2], - *mQ.shape[3:], - ) - stride_Q_packed = ( - (mQ.stride[2], mQ.stride[0]), - mQ.stride[1], - mQ.stride[2] * self.qhead_per_kvhead, - *mQ.stride[3:], - ) - mQ = cute.make_tensor( - mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed) - ) - shape_O_packed = ( - (self.qhead_per_kvhead, mO.shape[0]), - mK.shape[1], - mK.shape[2], - *mO.shape[3:], - ) - stride_O_packed = ( - (mO.stride[2], mO.stride[0]), - mO.stride[1], - mO.stride[2] * self.qhead_per_kvhead, - *mO.stride[3:], - ) - mO = cute.make_tensor( - mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed) - ) - if const_expr(mLSE is not None): - shape_LSE_packed = ( - (self.qhead_per_kvhead, mLSE.shape[0]), - mK.shape[2], - *mLSE.shape[2:], - ) - stride_LSE_packed = ( - (mLSE.stride[1], mLSE.stride[0]), - mLSE.stride[1] * self.qhead_per_kvhead, - *mLSE.stride[2:], - ) - mLSE = cute.make_tensor( - mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) - ) - - # TMA - gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp() - gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast - gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp() - self.tma_copy_bytes = { - name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1])) - for name, mX, layout in [ - ("Q", mQ, self.sQ_layout), - ("K", mK, self.sK_layout), - ("V", mV, self.sV_layout), - ] - } - tma_atom_Q, tma_tensor_Q = None, None - if const_expr(self.use_tma_Q): - tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( - gmem_tiled_copy_Q, - mQ, - self.sQ_layout, - (self.tile_m, self.tile_hdim), # No mcast - ) - tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( - gmem_tiled_copy_KV, - mK, - cute.select(self.sK_layout, mode=[0, 1]), - (self.tile_n, self.tile_hdim), - 1, # No mcast for now - ) - tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( - gmem_tiled_copy_KV, - mV, - cute.select(self.sV_layout, mode=[0, 1]), - (self.tile_n, self.tile_hdimv), - 1, # No mcast for now - ) - tma_atom_O, tma_tensor_O = None, None - if const_expr(self.use_tma_O): - tma_atom_O, tma_tensor_O = cpasync.make_tiled_tma_atom( - gmem_tiled_copy_O, - mO, - self.sO_layout, - (self.tile_m, self.tile_hdimv), # No mcast - ) - if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): - TileScheduler = SingleTileVarlenScheduler - else: - TileScheduler = ( - SingleTileScheduler - if const_expr(not self.is_causal or self.is_local) - else SingleTileLPTScheduler - ) - tile_sched_args = TileSchedulerArguments( - cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m), - cute.size(mQ.shape[2]), - cute.size(mQ.shape[3]) - if const_expr(mCuSeqlensQ is None) - else cute.size(mCuSeqlensQ.shape[0] - 1), - 1, # num_splits - cute.size(mK.shape[0]), - mQ.shape[1], - mV.shape[1], - total_q=cute.size(mQ.shape[0]) - if const_expr(mCuSeqlensQ is not None) - else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), - tile_shape_mn=(self.tile_m, self.tile_n), - mCuSeqlensQ=mCuSeqlensQ, - mSeqUsedQ=mSeqUsedQ, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, - element_size=self.dtype.width // 8, - is_persistent=False, - lpt=self.is_causal or self.is_local, - ) - tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) - grid_dim = TileScheduler.get_grid_shape(tile_sched_params) - LOG2_E = math.log2(math.e) - if const_expr(self.score_mod is None): - softmax_scale_log2 = softmax_scale * LOG2_E - softmax_scale = None - else: - # NB: If a user passes in a score mod, we want to apply the score-mod in the sm_scaled qk - # But in the original base 10. We hijack softmax_scale_log2 to just be the change of base - # and correctly apply the softmax_scale prior to score_mod in the softmax step - softmax_scale_log2 = LOG2_E - softmax_scale = softmax_scale - if const_expr(window_size_left is not None): - window_size_left = Int32(window_size_left) - if const_expr(window_size_right is not None): - window_size_right = Int32(window_size_right) - - fastdiv_mods = None - if const_expr(aux_tensors is not None): - seqlen_q = cute.size(mQ.shape[0]) // ( - self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 - ) - seqlen_k = ( - cute.size(mK.shape[0]) - if const_expr(mPageTable is None) - else mK.shape[0] * mPageTable.shape[1] - ) - seqlen_q_divmod = FastDivmodDivisor(seqlen_q) - seqlen_k_divmod = FastDivmodDivisor(seqlen_k) - fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) - - self.kernel( - tma_tensor_Q if const_expr(self.use_tma_Q) else mQ, - tma_tensor_K, - tma_tensor_V, - tma_tensor_O if const_expr(self.use_tma_O) else mO, - mLSE, - mCuSeqlensQ, - mCuSeqlensK, - mSeqUsedQ, - mSeqUsedK, - tma_atom_Q, - tma_atom_K, - tma_atom_V, - tma_atom_O, - softmax_scale_log2, - softmax_scale, - window_size_left, - window_size_right, - learnable_sink, - blocksparse_tensors, - self.sQ_layout, - self.sK_layout, - self.sV_layout, - self.sO_layout, - self.sP_layout, - self.gmem_tiled_copy_Q, - self.gmem_tiled_copy_K, - self.gmem_tiled_copy_V, - self.gmem_tiled_copy_O, - tiled_mma_qk, - tiled_mma_pv, - tile_sched_params, - TileScheduler, - SharedStorage, - aux_tensors, - fastdiv_mods, - ).launch( - grid=grid_dim, - block=[self.num_threads, 1, 1], - stream=stream, - min_blocks_per_mp=1, - ) - - @cute.kernel - def kernel( - self, - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - mO: cute.Tensor, - mLSE: Optional[cute.Tensor], - mCuSeqlensQ: Optional[cute.Tensor], - mCuSeqlensK: Optional[cute.Tensor], - mSeqUsedQ: Optional[cute.Tensor], - mSeqUsedK: Optional[cute.Tensor], - tma_atom_Q: Optional[cute.CopyAtom], - tma_atom_K: Optional[cute.CopyAtom], - tma_atom_V: Optional[cute.CopyAtom], - tma_atom_O: Optional[cute.CopyAtom], - softmax_scale_log2: Float32, - softmax_scale: Optional[Float32], - window_size_left: Optional[Int32], - window_size_right: Optional[Int32], - learnable_sink: Optional[cute.Tensor], - blocksparse_tensors: Optional[BlockSparseTensors], - sQ_layout: cute.ComposedLayout, - sK_layout: cute.ComposedLayout, - sV_layout: cute.ComposedLayout, - sO_layout: cute.ComposedLayout, - sP_layout: cute.ComposedLayout | None, - gmem_tiled_copy_Q: cute.TiledCopy, - gmem_tiled_copy_K: cute.TiledCopy, - gmem_tiled_copy_V: cute.TiledCopy, - gmem_tiled_copy_O: cute.TiledCopy, - tiled_mma_qk: cute.TiledMma, - tiled_mma_pv: cute.TiledMma, - tile_sched_params: ParamsBase, - TileScheduler: cutlass.Constexpr[Callable], - SharedStorage: cutlass.Constexpr[Callable], - aux_tensors=Optional[list[cute.Tensor]], - fastdiv_mods=None, - ): - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - # Prefetch tma descriptor - if warp_idx == 0: - for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O): - if const_expr(tma_atom is not None): - cpasync.prefetch_descriptor(tma_atom) - - smem = cutlass.utils.SmemAllocator() - storage = smem.allocate(SharedStorage) - - # Mbarrier init - mbar_ptr_Q = storage.mbar_ptr.data_ptr() - if warp_idx == 1: - # if tidx < 2: - # # barrierO num threads should be self.num_mma_threads - # cute.arch.mbarrier_init(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads) - if const_expr(not self.use_tma_Q): - cute.arch.mbarrier_init(mbar_ptr_Q, self.num_Q_load_threads) - # cute.arch.mbarrier_init(mbar_ptr_Q + 1, self.num_mma_threads) - # We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync - pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup( - cutlass.pipeline.Agent.Thread - ) - pipeline_kv_consumer_group = cutlass.pipeline.CooperativeGroup( - cutlass.pipeline.Agent.Thread, self.num_mma_threads // cute.arch.WARP_SIZE - ) - pipeline_k = pipeline.PipelineTmaAsync.create( - barrier_storage=storage.mbar_ptr_K.data_ptr(), - num_stages=self.num_stages, - producer_group=pipeline_kv_producer_group, - consumer_group=pipeline_kv_consumer_group, - tx_count=self.tma_copy_bytes["K"], - defer_sync=True, - ) - pipeline_v = pipeline.PipelineTmaAsync.create( - barrier_storage=storage.mbar_ptr_V.data_ptr(), - num_stages=self.num_stages, - producer_group=pipeline_kv_producer_group, - consumer_group=pipeline_kv_consumer_group, - tx_count=self.tma_copy_bytes["V"], - defer_sync=False - ) - - # /////////////////////////////////////////////////////////////////////////////// - # Get shared memory buffer - # /////////////////////////////////////////////////////////////////////////////// - sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) - sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) - if const_expr(not self.Q_in_regs): - sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) - else: - sV = storage.sQ.get_tensor( - sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type - ) - # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma - sVt = layout_utils.transpose_view(sV) - sP = None - if const_expr(sP_layout is not None): - sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) - # reuse sQ's data iterator - sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype) - - block_info = BlockInfo( - self.tile_m, - self.tile_n, - self.is_causal, - self.is_local, - False, # is_split_kv - window_size_left, - window_size_right, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, - ) - SeqlenInfoCls = partial( - SeqlenInfoQK.create, - seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], - seqlen_k_static=mK.shape[0], - mCuSeqlensQ=mCuSeqlensQ, - mCuSeqlensK=mCuSeqlensK, - mSeqUsedQ=mSeqUsedQ, - mSeqUsedK=mSeqUsedK, - ) - AttentionMaskCls = partial( - AttentionMask, - self.tile_m, - self.tile_n, - window_size_left=window_size_left, - window_size_right=window_size_right, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, - ) - TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) - - if warp_idx < 4: # Producer - cute.arch.setmaxregister_decrease(self.num_producer_regs) - self.load( - mQ, - mK, - mV, - sQ, - sK, - sV, - tma_atom_Q, - tma_atom_K, - tma_atom_V, - pipeline_k, - pipeline_v, - mbar_ptr_Q, - blocksparse_tensors, - block_info, - SeqlenInfoCls, - TileSchedulerCls, - ) - - else: # Consumer - cute.arch.setmaxregister_increase(self.num_mma_regs) - # /////////////////////////////////////////////////////////////////////////////// - # Tile MMA compute thread partitions and allocate accumulators - # /////////////////////////////////////////////////////////////////////////////// - tidx, _, _ = cute.arch.thread_idx() - tidx = tidx - 128 - self.mma( - tiled_mma_qk, - tiled_mma_pv, - mQ, - mO, - mLSE, - sQ, - sK, - sVt, - sP, - sO, - learnable_sink, - pipeline_k, - pipeline_v, - mbar_ptr_Q, - gmem_tiled_copy_Q, - gmem_tiled_copy_O, - tma_atom_O, - tidx, - softmax_scale_log2, - softmax_scale, - block_info, - SeqlenInfoCls, - AttentionMaskCls, - TileSchedulerCls, - blocksparse_tensors, - aux_tensors, - fastdiv_mods, - ) - - @cute.jit - def load( - self, - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - sQ: cute.Tensor, - sK: cute.Tensor, - sV: cute.Tensor, - tma_atom_Q: cute.CopyAtom, - tma_atom_K: cute.CopyAtom, - tma_atom_V: cute.CopyAtom, - pipeline_k: cutlass.pipeline.PipelineAsync, - pipeline_v: cutlass.pipeline.PipelineAsync, - mbar_ptr_Q: cutlass.Pointer, - blocksparse_tensors: Optional[BlockSparseTensors], - block_info: BlockInfo, - SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, - ): - warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 - if warp_idx_in_wg == 0: - q_producer_phase = Int32(1) - kv_producer_state = pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.num_stages - ) - tile_scheduler = TileSchedulerCls() - work_tile = tile_scheduler.initial_work_tile_info() - while work_tile.is_valid_tile: - # if work_tile.is_valid_tile: - m_block, head_idx, batch_idx, _ = work_tile.tile_idx - seqlen = SeqlenInfoCls(batch_idx) - mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] - head_idx_kv = ( - head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx - ) - mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] - mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv] - gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0)) - gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0)) - if const_expr(self.use_tma_Q): - gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) - load_Q, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True - ) - # TODO: mcast - # TODO check warp_idx if we have 128 producer threads - load_K, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_K, 0, cute.make_layout(1), gK, sK - ) - load_K = copy_utils.tma_producer_copy_fn(load_K, pipeline_k) - load_V, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_V, 0, cute.make_layout(1), gV, sV - ) - load_V = copy_utils.tma_producer_copy_fn(load_V, pipeline_v) - - if const_expr(not self.use_block_sparsity): - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - # if cute.arch.thread_idx()[0] == 0: - # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) - # First iteration: load both Q & K with the same mbarrier - n_block = n_block_max - 1 - pipeline_k.producer_acquire( - kv_producer_state, - extra_tx_count=self.tma_copy_bytes["Q"] - if const_expr(self.use_tma_Q) - else 0, - ) - if const_expr(self.use_tma_Q): - load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) - load_K(src_idx=n_block, producer_state=kv_producer_state) - - if const_expr(not self.intra_wg_overlap): - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block, producer_state=kv_producer_state) - kv_producer_state.advance() - for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): - n_block = n_block_max - 1 - i - 1 - pipeline_k.producer_acquire(kv_producer_state) - load_K(src_idx=n_block, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block, producer_state=kv_producer_state) - kv_producer_state.advance() - else: - for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): - n_block_prev = n_block_max - i - 1 - n_block = n_block_prev - 1 - kv_producer_state_prev = kv_producer_state.clone() - kv_producer_state.advance() - pipeline_k.producer_acquire(kv_producer_state) - load_K(src_idx=n_block, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state_prev) - load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev) - n_block = n_block_min - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block, producer_state=kv_producer_state) - kv_producer_state.advance() - else: - kv_producer_state = produce_block_sparse_loads( - blocksparse_tensors, - batch_idx, - head_idx, - m_block, - kv_producer_state, - load_Q, - load_K, - load_V, - pipeline_k, - pipeline_v, - self.use_tma_Q, - self.tma_copy_bytes["Q"], - self.intra_wg_overlap, - self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, - self.q_subtile_factor if self.q_subtile_factor is not None else 1, - ) - - tile_scheduler.prefetch_next_work() - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() - # End of persistent scheduler loop - - @cute.jit - def mma( - self, - tiled_mma_qk: cute.TiledMma, - tiled_mma_pv: cute.TiledMma, - # softmax: Softmax, - # acc_O: cute.Tensor, - mQ: cute.Tensor, - mO: cute.Tensor, - mLSE: Optional[cute.Tensor], - sQ: cute.Tensor, - sK: cute.Tensor, - sVt: cute.Tensor, - sP: Optional[cute.Tensor], - sO: cute.Tensor, - learnable_sink: Optional[cute.Tensor], - pipeline_k: cutlass.pipeline.PipelineAsync, - pipeline_v: cutlass.pipeline.PipelineAsync, - mbar_ptr_Q: cutlass.Pointer, - gmem_tiled_copy_Q: cute.TiledCopy, - gmem_tiled_copy_O: cute.TiledCopy, - tma_atom_O: Optional[cute.CopyAtom], - tidx: Int32, - softmax_scale_log2: Float32, - softmax_scale: Optional[Float32], - block_info: BlockInfo, - SeqlenInfoCls: Callable, - AttentionMaskCls: Callable, - TileSchedulerCls: Callable, - blocksparse_tensors: Optional[BlockSparseTensors], - aux_tensors: Optional[list], - fastdiv_mods=None, - ): - warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) - warp_group_thread_layout = cute.make_layout( - self.num_mma_warp_groups, stride=self.num_threads_per_warp_group - ) - thr_mma_qk = tiled_mma_qk.get_slice(tidx) - wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)) - wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)) - _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC( - wg_mma_qk, (self.tile_m, self.tile_n, self.tile_hdim), sQ, sK - ) - mma_qk_fn = partial( - sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK - ) - acc_O, tOrP, tOrVt = sm90_utils.partition_fragment_ABC( - wg_mma_pv, (self.tile_m, self.tile_hdimv, self.tile_n), sP, sVt - ) - mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt) - - # /////////////////////////////////////////////////////////////////////////////// - # Smem copy atom tiling - # /////////////////////////////////////////////////////////////////////////////// - smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype) - smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx) - tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None - smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) - - self.mma_init() - - mma_one_n_block_all = partial( - self.mma_one_n_block_intrawg_overlap - if const_expr(self.intra_wg_overlap) - else self.mma_one_n_block, - mma_qk_fn=mma_qk_fn, - pipeline_k=pipeline_k, - pipeline_v=pipeline_v, - acc_O=acc_O, - tOrP=tOrP, - smem_copy_params=smem_copy_params, - check_inf=True, - ) - - q_consumer_phase = Int32(0) - kv_consumer_state = pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.num_stages - ) - - tile_scheduler = TileSchedulerCls() - work_tile = tile_scheduler.initial_work_tile_info() - softmax = Softmax.create( - softmax_scale_log2, - num_rows=acc_O.shape[0][0] * acc_O.shape[1], - softmax_scale=softmax_scale, - ) - - process_first_half_block = partial( - self.first_half_block_overlap, - mma_qk_fn=mma_qk_fn, - pipeline_k=pipeline_k, - tOrP=tOrP, - smem_copy_params=smem_copy_params, - softmax=softmax, - ) - process_last_half_block = partial( - self.last_half_block_overlap, - pipeline_v=pipeline_v, - mma_pv_fn=mma_pv_fn, - ) - while work_tile.is_valid_tile: - # if work_tile.is_valid_tile: - - # shape: (atom_v_m * rest_m) - m_block, head_idx, batch_idx, _ = work_tile.tile_idx - seqlen = SeqlenInfoCls(batch_idx) - - # Recompute fastdiv_mods if necessary for varlen with aux_tensors - recompute_fastdiv_mods_q = cutlass.const_expr( - aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q) - ) - recompute_fastdiv_mods_k = cutlass.const_expr( - aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k) - ) - if cutlass.const_expr(fastdiv_mods is not None): - seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods - fastdiv_mods = ( - seqlen_q_divmod - if not recompute_fastdiv_mods_q - else FastDivmodDivisor(seqlen.seqlen_q), - seqlen_k_divmod - if not recompute_fastdiv_mods_k - else FastDivmodDivisor(seqlen.seqlen_k), - ) - - mask = AttentionMaskCls(seqlen) - mask_fn = partial( - mask.apply_mask, - batch_idx=batch_idx, - head_idx=head_idx, - m_block=m_block, - thr_mma=thr_mma_qk, - mask_causal=self.is_causal, - mask_local=self.is_local, - aux_tensors=aux_tensors, - fastdiv_mods=fastdiv_mods, - ) - score_mod_fn = None - if const_expr(self.score_mod is not None): - score_mod_fn = partial( - self.apply_score_mod, - thr_mma_qk, - batch_idx, - head_idx, - m_block, - softmax_scale=softmax_scale, - aux_tensors=aux_tensors, - fastdiv_mods=fastdiv_mods, - ) - mma_one_n_block = partial( - mma_one_n_block_all, - seqlen=seqlen, - softmax=softmax, - score_mod_fn=score_mod_fn, - ) - # Load Q if not TMA_Q - if const_expr(not self.use_tma_Q): - pack_gqa = PackGQA( - self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead - ) - mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] - # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) - # gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) - # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, - # headdim=mQ.shape[1]) - pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q) - cute.arch.cp_async_mbarrier_arrive_noinc(mbar_ptr_Q) - - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - if const_expr(not self.use_tma_Q): - cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase) - q_consumer_phase ^= 1 - # For performance reason, we separate out two kinds of iterations: - # those that need masking on S, and those that don't. - # We need masking on S for the very last block when K and V has length not multiple of tile_n. - # We also need masking on S if it's causal, for the last several blocks. - # softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True - O_should_accumulate = False - - # ========================================== - # MAINLOOP - # ========================================== - if const_expr(not self.use_block_sparsity): - # ========================================== - # No block-sparsity (original path) - # ========================================== - # First iteration with seqlen masking - if const_expr(self.intra_wg_overlap): - kv_consumer_state = process_first_half_block( - n_block=n_block_max - 1, - seqlen=seqlen, - kv_consumer_state=kv_consumer_state, - mask_fn=partial(mask_fn, mask_mod=self.mask_mod), - score_mod_fn=score_mod_fn, - is_first_block=True, - ) - # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter - # acc_O.fill(0.0) - else: - self.warp_scheduler_barrier_sync() - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=n_block_max - 1, - seqlen=seqlen, - mma_pv_fn=partial(mma_pv_fn, zero_init=True), - is_first_n_block=True, - mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), - ) - O_should_accumulate = True - # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) - n_block_max -= 1 - # Next couple of iterations with causal masking - if const_expr(self.is_causal or self.is_local): - n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( - seqlen, m_block, n_block_min - ) - # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) - for n_tile in cutlass.range( - n_block_max - n_block_min_causal_local_mask, unroll=1 - ): - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=n_block_max - 1 - n_tile, - seqlen=seqlen, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), - ) - O_should_accumulate = True - n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) - # The remaining iterations have no masking - n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( - seqlen, m_block, n_block_min - ) - # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) - for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=n_block_max - 1 - n_tile, - seqlen=seqlen, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), - ) - O_should_accumulate = True - # Separate iterations with local masking on the left - if const_expr(self.is_local and block_info.window_size_left is not None): - n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) - for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=n_block_max - 1 - n_tile, - seqlen=seqlen, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), - ) - O_should_accumulate = True - # Last "half" iteration - if const_expr(self.intra_wg_overlap): - kv_consumer_state = process_last_half_block( - kv_consumer_state=kv_consumer_state, - zero_init=not O_should_accumulate, - ) - O_should_accumulate = True - else: - self.warp_scheduler_barrier_arrive() - - else: - # ========================================== - # Block sparsity - # ========================================== - kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads( - blocksparse_tensors, - batch_idx, - head_idx, - m_block, - seqlen, - kv_consumer_state, - mma_pv_fn, - mma_one_n_block, - process_first_half_block, - process_last_half_block, - mask_fn, - score_mod_fn, - O_should_accumulate, - self.mask_mod, - fastdiv_mods, - self.intra_wg_overlap, - self.warp_scheduler_barrier_sync, - self.warp_scheduler_barrier_arrive, - self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, - self.q_subtile_factor if self.q_subtile_factor is not None else 1, - ) - - # Handle empty case (when no blocks to process) - if not processed_any: - softmax.reset() - acc_O.fill(0.0) - - sink_val = None - if const_expr(learnable_sink is not None): - if const_expr(not self.pack_gqa): - sink_val = Float32(learnable_sink[head_idx]) - else: # Each thread might have a different sink value due to different q_head - sink_val = cute.make_fragment_like(softmax.row_max, Float32) - cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) - tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma_qk.partition_C(cS)) - for r in cutlass.range(cute.size(sink_val), unroll_full=True): - row = m_block * self.tile_m + tScS_mn[r][0] - q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead - sink_val[r] = Float32(learnable_sink[q_head_idx]) - - # normalize acc_O by row_sum and calculate the lse - row_scale = softmax.finalize(sink_val=sink_val) - softmax.rescale_O(acc_O, row_scale) - - # /////////////////////////////////////////////////////////////////////////////// - # Epilogue - # /////////////////////////////////////////////////////////////////////////////// - self.epilogue( - acc_O, - softmax.row_sum, - mO, - mLSE, - sO, - seqlen, - gmem_tiled_copy_O, - tma_atom_O, - tiled_mma_pv, - tidx, - m_block, - head_idx, - batch_idx, - ) - - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() - - - @cute.jit - def first_half_block_overlap( - self, - n_block: Int32, - mma_qk_fn: Callable, - kv_consumer_state, - pipeline_k, - tOrP: cute.Tensor, - smem_copy_params: SimpleNamespace, - softmax: Softmax, - seqlen: SeqlenInfoQK, - mask_fn: Callable = None, - score_mod_fn: Optional[Callable] = None, - is_first_block: bool = False, - ): - """Processes the first half block when using intra-warpgroup-overlap""" - - pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state)) - acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0) - pipeline_k.consumer_release(kv_consumer_state) - - # Apply score modification if present - if const_expr(score_mod_fn is not None): - score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) - - # Apply mask; mask_seqlen always True for first block - # Caveat: if full block further right than mask block, seqlen masking is redundant; - # however, masking is being applied anyway, so essentially no perf hit - mask_fn(acc_S, n_block=n_block, mask_seqlen=True) - - softmax.online_softmax(acc_S, is_first=is_first_block) - - tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) - tOrP_cur = ( - tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) - ) - tOrP_cur.store(tOrP_acc.load().to(self.dtype)) - - # if pv gemm not rs - if const_expr(not self.mma_pv_is_rs): - tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) - cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) - # Fence and barrier to make smem store visible to WGMMA - cute.arch.fence_view_async_shared() - cute.arch.sync_warp() - - return kv_consumer_state - - @cute.jit - def last_half_block_overlap( - self, - kv_consumer_state, - pipeline_v, - mma_pv_fn: Callable, - zero_init: bool, - ): - """Processes the final PV GEMM when using intra-warpgroup-overlap""" - - pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) - mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0) - pipeline_v.consumer_release(kv_consumer_state) - kv_consumer_state.advance() - return kv_consumer_state - - @cute.jit - def mma_one_n_block( - self, - smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, - n_block: Int32, - mma_qk_fn: Callable, - mma_pv_fn: Callable, - pipeline_k: cutlass.pipeline.PipelineAsync, - pipeline_v: cutlass.pipeline.PipelineAsync, - acc_O: cute.Tensor, - tOrP: cute.Tensor, - smem_copy_params: SimpleNamespace, - softmax: Softmax, - seqlen: SeqlenInfoQK, - score_mod_fn: Optional[Callable] = None, - mask_fn: Optional[Callable] = None, - is_first_n_block: cutlass.Constexpr = False, - check_inf: cutlass.Constexpr = True, - ): - pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) - # S = Q @ K.T - acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) - self.warp_scheduler_barrier_arrive() - warpgroup.wait_group(0) - pipeline_k.consumer_release(smem_pipe_read) - - # handle score mods and masking - if const_expr(score_mod_fn is not None): - score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) - if const_expr(mask_fn is not None): - mask_fn(acc_S=acc_S, n_block=n_block) - - row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) - # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S)) - tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) - tOrP_cur = ( - tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) - ) - # tOrP.store(tOrP_acc.load().to(self.dtype)) - # the "to(self.dtype)" conversion fails to vectorize for block sizes other - # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of - # 2 elements. So we just call ptx directly. - utils.cvt_f16(tOrP_acc, tOrP_cur) - if const_expr(not self.mma_pv_is_rs): - tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) - cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) - softmax.rescale_O(acc_O, row_scale) - if const_expr(not self.mma_pv_is_rs): - # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_view_async_shared() - cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV - pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) - self.warp_scheduler_barrier_sync() - # O += P @ V - mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0) - pipeline_v.consumer_release(smem_pipe_read) - smem_pipe_read.advance() - return smem_pipe_read - - @cute.jit - def mma_one_n_block_intrawg_overlap( - self, - smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, - n_block: Int32, - mma_qk_fn: Callable, - mma_pv_fn: Callable, - pipeline_k: cutlass.pipeline.PipelineAsync, - pipeline_v: cutlass.pipeline.PipelineAsync, - acc_O: cute.Tensor, - tOrP: cute.Tensor, - smem_copy_params: SimpleNamespace, - softmax: Softmax, - seqlen: SeqlenInfoQK, - score_mod_fn: Optional[Callable] = None, - mask_fn: Optional[Callable] = None, - check_inf: cutlass.Constexpr = True, - ): - smem_pipe_read_v = smem_pipe_read.clone() - smem_pipe_read.advance() - pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) - self.warp_scheduler_barrier_sync() - # S = Q @ K.T - acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) - pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v)) - # O += P @ V - mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1) - self.warp_scheduler_barrier_arrive() - warpgroup.wait_group(1) - pipeline_k.consumer_release(smem_pipe_read) - - # handle score mods and masking - if const_expr(score_mod_fn is not None): - score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) - if const_expr(mask_fn is not None): - mask_fn(acc_S=acc_S, n_block=n_block) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S)) - - row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) - warpgroup.wait_group(0) - pipeline_v.consumer_release(smem_pipe_read_v) - tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) - tOrP_cur = ( - tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) - ) - # tOrP_cur.store(tOrP_acc.load().to(self.dtype)) - # the "to(self.dtype)" conversion fails to vectorize for block sizes other - # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of - # 2 elements. So we just call ptx directly. - utils.cvt_f16(tOrP_acc, tOrP_cur) - if const_expr(not self.mma_pv_is_rs): - tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) - cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) - softmax.rescale_O(acc_O, row_scale) - if const_expr(not self.mma_pv_is_rs): - # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_view_async_shared() - cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV - return smem_pipe_read - - @cute.jit - def mma_init(self): - warp_group_idx = utils.canonical_warp_group_idx(sync=False) - if const_expr(self.use_scheduler_barrier): - if warp_group_idx == 1: - cute.arch.barrier_arrive( - barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), - number_of_threads=2 * self.num_threads_per_warp_group, - ) - - @cute.jit - def apply_score_mod( - self, - thr_mma_qk, - batch_idx, - head_idx, - m_block, - acc_S, - n_block, - softmax_scale, - seqlen, - aux_tensors: Optional[list] = None, - fastdiv_mods=None, - ): - # Prepare index tensor - cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) - cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS) - tScS = thr_mma_qk.partition_C(cS) - - apply_score_mod_inner( - acc_S, - tScS, - self.score_mod, - batch_idx, - head_idx, - softmax_scale, - self.vec_size, - self.qk_acc_dtype, - aux_tensors, - fastdiv_mods, - seqlen_info=seqlen, - constant_q_idx=None, - qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, - ) - - def warp_scheduler_barrier_sync(self): - if const_expr(self.use_scheduler_barrier): - cute.arch.barrier( - barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) - - 1 - + utils.canonical_warp_group_idx(sync=False), - number_of_threads=2 * self.num_threads_per_warp_group, - ) - - def warp_scheduler_barrier_arrive(self): - if const_expr(self.use_scheduler_barrier): - assert self.num_mma_warp_groups in [2, 3] - cur_wg = utils.canonical_warp_group_idx(sync=False) - 1 - if const_expr(self.num_mma_warp_groups == 2): - next_wg = 1 - cur_wg - else: - t = cur_wg + 1 - next_wg = t % self.num_mma_warp_groups - cute.arch.barrier_arrive( - barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, - number_of_threads=2 * self.num_threads_per_warp_group, - ) +# SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility +def __getattr__(name): + if name == "FlashAttentionForwardSm90": + from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 + return FlashAttentionForwardSm90 + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py index 4ec277ab842..493620235ec 100644 --- a/flash_attn/cute/flash_fwd_combine.py +++ b/flash_attn/cute/flash_fwd_combine.py @@ -10,7 +10,7 @@ import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync -from cutlass import Float32, Int32, const_expr +from cutlass import Float32, Int32, Boolean, const_expr from flash_attn.cute import utils from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned @@ -24,7 +24,7 @@ def __init__( dtype: Type[cutlass.Numeric], dtype_partial: Type[cutlass.Numeric], head_dim: int, - m_block_size: int = 8, + tile_m: int = 8, k_block_size: int = 64, log_max_splits: int = 4, num_threads: int = 256, @@ -36,7 +36,7 @@ def __init__( :param dtype: output data type :param dtype_partial: partial accumulation data type :param head_dim: head dimension - :param m_block_size: m block size + :param tile_m: m block size :param k_block_size: k block size :param log_max_splits: log2 of maximum splits :param num_threads: number of threads @@ -46,7 +46,7 @@ def __init__( self.dtype = dtype self.dtype_partial = dtype_partial self.head_dim = head_dim - self.m_block_size = m_block_size + self.tile_m = tile_m self.k_block_size = k_block_size self.max_splits = 1 << log_max_splits self.num_threads = num_threads @@ -58,7 +58,7 @@ def can_implement( dtype, dtype_partial, head_dim, - m_block_size, + tile_m, k_block_size, log_max_splits, num_threads, @@ -72,12 +72,12 @@ def can_implement( return False if num_threads % 32 != 0: return False - if m_block_size % 8 != 0: + if tile_m % 8 != 0: return False max_splits = 1 << log_max_splits if max_splits > 256: return False - if (m_block_size * max_splits) % num_threads != 0: + if (tile_m * max_splits) % num_threads != 0: return False return True @@ -124,15 +124,11 @@ def _setup_attributes(self): lse_copy_bits = Float32.width # 1 element per copy, width is in bits m_block_smem = ( 128 - if self.m_block_size % 128 == 0 + if self.tile_m % 128 == 0 else ( 64 - if self.m_block_size % 64 == 0 - else ( - 32 - if self.m_block_size % 32 == 0 - else (16 if self.m_block_size % 16 == 0 else 8) - ) + if self.tile_m % 64 == 0 + else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8)) ) ) gmem_threads_per_row_lse = m_block_smem @@ -183,12 +179,12 @@ def _setup_attributes(self): smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0)) ) self.smem_layout_lse = cute.tile_to_shape( - smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1) + smem_layout_atom_lse, (self.max_splits, self.tile_m), (0, 1) ) # O partial shared memory layout (simple layout for pipeline stages) self.smem_layout_o = cute.make_ordered_layout( - (self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2) + (self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2) ) @cute.jit @@ -201,7 +197,9 @@ def __call__( cu_seqlens: Optional[cute.Tensor] = None, seqused: Optional[cute.Tensor] = None, num_splits_dynamic_ptr: Optional[cute.Tensor] = None, + varlen_batch_idx: Optional[cute.Tensor] = None, semaphore_to_reset: Optional[cute.Tensor] = None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): # Type checking @@ -269,7 +267,7 @@ class SharedStorage: sLSE: cute.struct.Align[ cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 ] - sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128] + sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.tile_m], 128] sO: cute.struct.Align[ cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 ] @@ -290,7 +288,7 @@ class SharedStorage: head_divmod = FastDivmodDivisor(num_head) grid_dim = ( - cute.ceil_div(seqlen * num_head, self.m_block_size), + cute.ceil_div(seqlen * num_head, self.tile_m), cute.ceil_div(self.head_dim, self.k_block_size), batch_size, ) @@ -303,6 +301,7 @@ class SharedStorage: cu_seqlens, seqused, num_splits_dynamic_ptr, + varlen_batch_idx, semaphore_to_reset, SharedStorage, self.smem_layout_lse, @@ -331,6 +330,7 @@ def kernel( cu_seqlens: Optional[cute.Tensor], seqused: Optional[cute.Tensor], num_splits_dynamic_ptr: Optional[cute.Tensor], + varlen_batch_idx: Optional[cute.Tensor], semaphore_to_reset: Optional[cute.Tensor], SharedStorage: cutlass.Constexpr, smem_layout_lse: cute.Layout | cute.ComposedLayout, @@ -345,7 +345,14 @@ def kernel( ): # Thread and block indices tidx, _, _ = cute.arch.thread_idx() - m_block, k_block, batch_idx = cute.arch.block_idx() + m_block, k_block, maybe_virtual_batch = cute.arch.block_idx() + + # Map virtual batch index to real batch index (for persistent tile schedulers) + batch_idx = ( + varlen_batch_idx[maybe_virtual_batch] + if const_expr(varlen_batch_idx is not None) + else maybe_virtual_batch + ) # /////////////////////////////////////////////////////////////////////////////// # Get shared memory buffer @@ -353,22 +360,23 @@ def kernel( smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) sLSE = storage.sLSE.get_tensor(smem_layout_lse) - sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.m_block_size,)) + sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,)) sO = storage.sO.get_tensor(smem_layout_o) - # Handle semaphore reset + # Handle semaphore reset — wait for dependent grids first if const_expr(semaphore_to_reset is not None): if ( tidx == 0 and m_block == cute.arch.grid_dim()[0] - 1 and k_block == cute.arch.grid_dim()[1] - 1 - and batch_idx == cute.arch.grid_dim()[2] - 1 + and maybe_virtual_batch == cute.arch.grid_dim()[2] - 1 ): + cute.arch.griddepcontrol_wait() semaphore_to_reset[0] = 0 - # Get number of splits + # Get number of splits (use maybe_virtual_batch for per-batch-slot splits) num_splits = ( - num_splits_dynamic_ptr[batch_idx] + num_splits_dynamic_ptr[maybe_virtual_batch] if const_expr(num_splits_dynamic_ptr is not None) else mLSE_partial.shape[1] ) @@ -378,6 +386,7 @@ def kernel( seqlen_static=mO_partial.shape[0], cu_seqlens=cu_seqlens, seqused=seqused, + # Don't need to pass in tile size since we won't use offset_padded ) seqlen, offset = seqlen_info.seqlen, seqlen_info.offset @@ -387,29 +396,27 @@ def kernel( # Early exit for single split if dynamic if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and ( - const_expr(not varlen) or m_block * self.m_block_size < max_idx + const_expr(not varlen) or m_block * self.tile_m < max_idx ): + # Wait for dependent grids (e.g., the main attention kernel that produces O_partial/LSE_partial) + cute.arch.griddepcontrol_wait() + # =============================== # Step 1: Load LSE_partial from gmem to shared memory # =============================== - if const_expr(cu_seqlens is None): - mLSE_partial_cur = mLSE_partial[None, None, None, batch_idx] - else: - mLSE_partial_cur = cute.domain_offset((offset, 0, 0), mLSE_partial) + mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3) mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,)) - gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx) tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE) - # Create identity tensor for coordinate tracking - cLSE = cute.make_identity_tensor((self.max_splits, self.m_block_size)) + cLSE = cute.make_identity_tensor((self.max_splits, self.tile_m)) tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE) # Load LSE partial values for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): mi = tLSEcLSE[0, 0, m][1] # Get m coordinate - idx = m_block * self.m_block_size + mi + idx = m_block * self.tile_m + mi if idx < max_idx: # Calculate actual sequence position and head using FastDivmodDivisor if const_expr(not varlen): @@ -436,22 +443,19 @@ def kernel( # =============================== gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx) - cO = cute.make_identity_tensor((self.m_block_size, self.k_block_size)) + cO = cute.make_identity_tensor((self.tile_m, self.k_block_size)) tOcO = gmem_thr_copy_O_partial.partition_D(cO) tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO) - if const_expr(cu_seqlens is None): - mO_partial_cur = mO_partial[None, None, None, None, batch_idx] - else: - mO_partial_cur = cute.domain_offset((offset, 0, 0, 0), mO_partial) + mO_partial_cur = seqlen_info.offset_batch(mO_partial, batch_idx, dim=4) # Precompute these values to avoid recomputing them in the loop num_rows = const_expr(cute.size(tOcO, mode=[1])) - tOmidx = cute.make_fragment(num_rows, cutlass.Int32) - tOhidx = cute.make_fragment(num_rows, cutlass.Int32) - tOrOptr = cute.make_fragment(num_rows, cutlass.Int64) + tOmidx = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOhidx = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOrOptr = cute.make_rmem_tensor(num_rows, cutlass.Int64) for m in cutlass.range(num_rows, unroll_full=True): mi = tOcO[0, m, 0][0] # m coordinate - idx = m_block * self.m_block_size + mi + idx = m_block * self.tile_m + mi if const_expr(not varlen): tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod) else: @@ -463,11 +467,12 @@ def kernel( if idx >= max_idx: tOhidx[m] = -1 - tOpO = cute.make_fragment(cute.size(tOcO, [2]), cutlass.Boolean) + tOpO = None if const_expr(not self.is_even_k): + tOpO = cute.make_rmem_tensor(cute.size(tOcO, mode=[2]), Boolean) for k in cutlass.range(cute.size(tOpO), unroll_full=True): tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size - # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO) + # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO) load_O_partial = partial( self.load_O_partial, @@ -501,17 +506,17 @@ def kernel( s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx) ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE) - ts2rrLSE = cute.make_fragment_like(ts2rsLSE) + ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE) cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE) # =============================== # Step 4: Compute final LSE along split dimension # =============================== - lse_sum = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Float32) + lse_sum = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32) ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE) # We compute the max valid split for each row to short-circuit the computation later - max_valid_split = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Int32) + max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32) assert cute.size(ts2rrLSE, mode=[0]) == 1 # Compute max, scales, and final LSE for each row for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): @@ -561,7 +566,7 @@ def kernel( for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes mi = ts2rcLSE[0, 0, m][1] - if mi < self.m_block_size: + if mi < self.tile_m: sMaxValidSplit[mi] = max_valid_split[m] # =============================== @@ -577,7 +582,7 @@ def kernel( for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes mi = ts2rcLSE[0, 0, m][1] - idx = m_block * self.m_block_size + mi + idx = m_block * self.tile_m + mi if idx < max_idx: if const_expr(not varlen): head_idx, m_idx = divmod(idx, seqlen_divmod) @@ -594,11 +599,11 @@ def kernel( # Get max valid split for this thread thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]] - for m in cutlass.range(1, cute.size(tOcO, mode=[1])): + for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True): thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]]) - tOrO_partial = cute.make_fragment_like(tOsO_partial[None, None, None, 0]) - tOrO = cute.make_fragment_like(tOrO_partial, Float32) + tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0]) + tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32) tOrO.fill(0.0) stage_load = self.stages - 1 @@ -607,7 +612,7 @@ def kernel( # Main accumulation loop for s in cutlass.range(thr_max_valid_split + 1, unroll=4): # Get scales for this split - scale = cute.make_fragment(num_rows, Float32) + scale = cute.make_rmem_tensor(num_rows, Float32) for m in cutlass.range(num_rows, unroll_full=True): scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem @@ -637,8 +642,9 @@ def kernel( # Step 7: Write final O to gmem # =============================== - rO = cute.make_fragment_like(tOrO, self.dtype) + rO = cute.make_rmem_tensor_like(tOrO, self.dtype) rO.store(tOrO.load().to(self.dtype)) + mO_cur = seqlen_info.offset_batch(mO, batch_idx, dim=3) if const_expr(cu_seqlens is None): mO_cur = mO[None, None, None, batch_idx] else: @@ -665,7 +671,7 @@ def load_O_partial( tOrOptr: cute.Tensor, tOsO_partial: cute.Tensor, tOhidx: cute.Tensor, - tOpO: cute.Tensor, + tOpO: Optional[cute.Tensor], tOcO: cute.Tensor, mO_cur_partial_layout: cute.Layout, split: Int32, @@ -684,7 +690,7 @@ def load_O_partial( mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,)) for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): k_idx = tOcO[0, 0, k][1] // elems_per_load - if const_expr(self.is_even_k) or tOpO[k]: + if const_expr(tOpO is None) or tOpO[k]: cute.copy( gmem_tiled_copy_O_partial, mO_partial_cur_copy[None, k_idx, split], diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 0c27439d845..6c9c20d0b76 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -13,25 +13,31 @@ # https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py -import enum import math -from typing import Type, Tuple, Callable, Optional, Literal +from typing import Tuple, Callable, Optional, Literal from functools import partial import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute -from cutlass import Float32, Int32, const_expr +from cutlass import Float32, Int32, Int64, Boolean, const_expr from cutlass.cute.nvgpu import cpasync import cutlass.cute.nvgpu.tcgen05 as tcgen05 import cutlass.utils.blackwell_helpers as sm100_utils_basic +from cutlass import pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.utils import ClcDynamicPersistentTileScheduler +from cutlass.base_dsl.arch import Arch +from cutlass.cutlass_dsl import BaseDSL + +from quack import copy_utils, layout_utils from flash_attn.cute.paged_kv import PagedKVManager -import flash_attn.cute.utils as utils from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import copy_utils -import flash_attn.cute.pipeline as pipeline +from flash_attn.cute import utils +import flash_attn.cute.pipeline as pipeline_custom +import cutlass.pipeline as cutlass_pipeline from flash_attn.cute.mask import AttentionMask from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK @@ -43,31 +49,48 @@ softmax_block_sparse_sm100, handle_block_sparse_empty_tile_correction_sm100, ) -from flash_attn.cute.pack_gqa import PackGQA +from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils +from flash_attn.cute.named_barrier import NamedBarrierFwdSm100 from cutlass.cute import FastDivmodDivisor +from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.tile_scheduler import ( + ClcState, + SchedulingMode, TileSchedulerArguments, + TileSchedulerProtocol, SingleTileScheduler, StaticPersistentTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, - ParamsBase, ) - - -class NamedBarrierFwd(enum.IntEnum): - Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() -# WarpSchedulerWG1 = enum.auto() -# WarpSchedulerWG2 = enum.auto() -# WarpSchedulerWG3 = enum.auto() -# PFull = enum.auto() -# PEmpty = enum.auto() +from flash_attn.cute.fa_logging import fa_log, fa_printf +from flash_attn.cute.utils import smid + +# === TUNING KNOBS (agent-editable) === +# Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool) +# Values: +# ex2_emu_freq: int — how often to use emulated exp2 (0=all hardware exp2, higher=more emulation). +# SM103 has fast native exp2, so set freq=0 there. +# ex2_emu_start_frg: int — fragment index to start emulation from +# num_regs_softmax: int — register count for softmax warps (multiple of 8) +# num_regs_correction: int — register count for correction warps (multiple of 8) +# num_regs_other is derived: 512 - num_regs_softmax * 2 - num_regs_correction +_TUNING_CONFIG = { + (True, False, 128, False): {'ex2_emu_freq': 10, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 176, 'num_regs_correction': 88}, + (False, True, 128, False): {'ex2_emu_freq': 16, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 192, 'num_regs_correction': 72}, + (True, False, 192, False): {"ex2_emu_freq": 16, "ex2_emu_start_frg": 0, "num_regs_softmax": 184, "num_regs_correction": 80}, + (False, True, 192, False): {"ex2_emu_freq": 32, "ex2_emu_start_frg": 1, "num_regs_softmax": 192, "num_regs_correction": 72}, + (True, False, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 80}, + (False, True, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, + (True, False, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, + (False, True, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 72}, +} +# === END TUNING KNOBS === class FlashAttentionForwardSm100: - arch = 100 def __init__( self, @@ -89,6 +112,8 @@ def __init__( has_aux_tensors: cutlass.Constexpr = False, paged_kv_non_tma: bool = False, is_varlen_q: bool = False, + use_2cta_instrs: bool = False, + use_clc_scheduler: bool = False, ): self.use_tma_KV = not paged_kv_non_tma # self.dtype = dtype @@ -105,14 +130,27 @@ def __init__( self.n_block_size = n_block_size self.q_stage = q_stage assert self.q_stage in [1, 2] - - # 2 Q tile per CTA + self.use_2cta_instrs = use_2cta_instrs + # If split_P_arrive, the softmax warps write some columns of P first, signal to the MMA warp + # to being the P @ V MMA, then write the rest of P and signal again. This allows some overlap + # between compute the last couple columns of P and the P @ V MMA. + self.split_P_arrive = n_block_size // 4 * 3 + self.split_P_arrive = int(self.split_P_arrive / 32) * 32 # multiple of 32 + assert self.split_P_arrive % 32 == 0 + assert self.split_P_arrive < self.n_block_size + self.arch = BaseDSL._get_dsl().get_arch_enum() + assert self.arch >= Arch.sm_100 and self.arch <= Arch.sm_110f, "Only SM 10.x and 11.x are supported" + + self.cta_group_size = 2 if self.use_2cta_instrs else 1 + # cta_tiler M includes only 1 CTA, the scheduler will take into account the cluster shape self.cta_tiler = (self.q_stage * m_block_size, n_block_size, self.head_dim_padded) - self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim_padded) - self.mma_tiler_pv = (m_block_size, self.head_dim_v_padded, n_block_size) + # With 2CTA, the MMA tiler M covers both CTAs, so it's cta_group_size * m_block_size. + # Each CTA owns m_block_size rows; the 2CTA MMA instruction spans both. + self.mma_tiler_qk = (self.cta_group_size * m_block_size, n_block_size, self.head_dim_padded) + self.mma_tiler_pv = (self.cta_group_size * m_block_size, self.head_dim_v_padded, n_block_size) self.qk_acc_dtype = Float32 self.pv_acc_dtype = Float32 - self.cluster_shape_mn = (1, 1) + self.cluster_shape_mn = (2, 1) if self.use_2cta_instrs else (1, 1) self.is_persistent = is_persistent self.is_causal = is_causal self.is_local = is_local @@ -122,21 +160,21 @@ def __init__( self.is_split_kv = is_split_kv self.pack_gqa = pack_gqa self.q_subtile_factor = q_subtile_factor - if pack_gqa: - assert m_block_size % self.qhead_per_kvhead == 0, ( - "For PackGQA, m_block_size must be divisible by qhead_per_kvhead" - ) assert not (self.is_split_kv and self.head_dim_v_padded >= 192), ( "SplitKV is not supported for hdim >= 192" ) self.score_mod = score_mod self.mask_mod = mask_mod - if cutlass.const_expr(has_aux_tensors): - self.vec_size: cutlass.Constexpr = 1 - else: - self.vec_size: cutlass.Constexpr = 2 + self.vec_size: cutlass.Constexpr = getattr( + score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 + ) # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) + is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f + self.is_sm103 = is_sm103 + # enable_ex2_emu is derived: True if tuning config has freq > 0, else fallback to default logic + _default_enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103 + self.enable_ex2_emu = _default_enable_ex2_emu self.s0_s1_barrier = False self.overlap_sO_sQ = ( (self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or @@ -149,6 +187,32 @@ def __init__( "Paged KV does not support irregular head dim" ) + self.use_clc_scheduler = ( + use_clc_scheduler + and self.use_tma_KV + and not self.overlap_sO_sQ + ) + self.sched_stages = 1 + if self.use_clc_scheduler: + assert self.cluster_shape_mn[1] == 1, f"CLC requires cluster N == 1: {self.cluster_shape_mn}" + assert self.cluster_shape_mn[0] in (1, 2), f"bad CLC cluster M: {self.cluster_shape_mn}" + assert self.cluster_shape_mn[0] == self.cta_group_size, ( + f"CLC cluster M != cta_group_size: {self.cluster_shape_mn}, {self.cta_group_size}" + ) + + self.scheduling_mode = SchedulingMode.CLC if self.use_clc_scheduler else SchedulingMode.STATIC + + if is_varlen_q: + self.TileScheduler = SingleTileVarlenScheduler + elif self.is_causal or self.is_local or self.use_clc_scheduler: + self.TileScheduler = SingleTileLPTScheduler + elif self.is_persistent: + self.TileScheduler = StaticPersistentTileScheduler + else: + self.TileScheduler = SingleTileScheduler + + fa_log(1, f"TileScheduler={self.TileScheduler.__name__}, scheduling_mode={self.scheduling_mode.name}, USE_2CTA={self.use_2cta_instrs}") + self.softmax0_warp_ids = (0, 1, 2, 3) self.softmax1_warp_ids = (4, 5, 6, 7) self.correction_warp_ids = (8, 9, 10, 11) @@ -156,8 +220,7 @@ def __init__( self.epilogue_warp_ids = (13,) self.load_warp_ids = (14,) self.empty_warp_ids = (15,) - SM100_TMEM_CAPACITY_COLUMNS = 512 - self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") self.threads_per_cta = cute.arch.WARP_SIZE * len( ( @@ -171,8 +234,10 @@ def __init__( ) ) + self.use_tma_Q = not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0) + if self.q_stage == 1: - if not self.use_tma_KV: + if not self.use_tma_KV or not self.use_tma_Q: self.empty_warp_ids = self.empty_warp_ids + self.load_warp_ids self.load_warp_ids = self.softmax1_warp_ids else: @@ -188,13 +253,15 @@ def __init__( elif self.is_varlen_q: # fallback self.epilogue_warp_ids = (13, 14) + self.clc_scheduler_warp_id = self.empty_warp_ids[0] if self.use_clc_scheduler else None + self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 self.tmem_o_offset = [ self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded for i in range(self.q_stage) ] # e.g., 256, 384 self.tmem_total = self.tmem_o_offset[-1] + self.head_dim_v_padded - assert self.tmem_total <= SM100_TMEM_CAPACITY_COLUMNS + assert self.tmem_total <= self.tmem_alloc_cols self.tmem_s_to_p_offset = self.n_block_size // 2 self.tmem_p_offset = [ self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2) @@ -203,25 +270,26 @@ def __init__( # vec buffer for row_max & row_sum self.tmem_vec_offset = self.tmem_s_offset + # Look up tuning config for register counts and ex2_emu params + _tune_key = (self.use_2cta_instrs, self.is_causal, self.head_dim_padded, self.is_sm103) + self._tune = _TUNING_CONFIG.get(_tune_key, {}) + if "ex2_emu_freq" in self._tune: + self.enable_ex2_emu = self._tune["ex2_emu_freq"] > 0 if self.head_dim_padded < 96: self.num_regs_softmax = 200 if not paged_kv_non_tma else 184 self.num_regs_correction = 64 self.num_regs_other = 48 if not paged_kv_non_tma else 80 else: - # self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184 - self.num_regs_softmax = 200 if not paged_kv_non_tma else 184 - # self.num_regs_softmax = 176 - # self.num_regs_correction = 96 - # self.num_regs_correction = 80 - # self.num_regs_correction = 64 if self.is_causal or self.is_local else 80 - self.num_regs_correction = 64 - # self.num_regs_other = 32 - # self.num_regs_other = 64 - # self.num_regs_other = 80 - self.num_regs_other = 48 if not paged_kv_non_tma else 80 - # self.num_regs_other = 96 if self.is_causal or self.is_local else 80 - # self.num_regs_other = 64 if self.is_causal or self.is_local else 80 - self.num_regs_empty = 24 + if not paged_kv_non_tma and "num_regs_softmax" in self._tune: + self.num_regs_softmax = self._tune["num_regs_softmax"] + self.num_regs_correction = self._tune["num_regs_correction"] + elif not paged_kv_non_tma: + self.num_regs_softmax = 192 + self.num_regs_correction = 80 + else: + self.num_regs_softmax = 184 + self.num_regs_correction = 64 + self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_correction self.buffer_align_bytes = 1024 @@ -235,15 +303,21 @@ def _setup_attributes(self): - Configures pipeline stages for softmax, correction, and epilogue operations """ - self.kv_stage = ( - 4 - if (self.q_dtype.width == 8 or self.q_stage == 1) - and self.head_dim_padded <= 128 - and self.head_dim_v_padded <= 128 - else 3 - ) - self.acc_stage = 1 - # For hdim 192,128, we don't have enough smem to store all 3 stages of KV: + smem_size_q = self.q_stage * self.m_block_size * self.head_dim_padded * self.q_dtype.width // 8 + smem_size_o = self.q_stage * self.m_block_size * self.head_dim_v_padded * self.o_dtype.width // 8 + smem_size_q_o = smem_size_q + smem_size_o if not self.overlap_sO_sQ else max(smem_size_q, smem_size_o) + smem_size_k_per_stage = self.n_block_size * self.head_dim_padded * self.k_dtype.width // 8 + smem_size_v_per_stage = self.n_block_size * self.head_dim_v_padded * self.v_dtype.width // 8 + smem_size_kv_per_stage = max(smem_size_k_per_stage, smem_size_v_per_stage) // self.cta_group_size + kv_stage = (224 * 1024 - smem_size_q_o) // smem_size_kv_per_stage + if self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and kv_stage == 2: + # For hdim 192,128, we can fit 3 stages if we use uneven_kv_smem + kv_stage = 3 + self.kv_stage = kv_stage + # print("kv_stage", self.kv_stage) + self.s_stage = 2 + assert self.s_stage >= self.q_stage + # For hdim 192,128 1CTA, we don't have enough smem to store all 3 stages of KV: # 128 x 192 x 2 bytes x 3 stages = 144KB, and we need 96KB for Q. # Instead we store smem as [smem_large, smem_small, smem_large], where smem_large is # 128 x 192 and smem_small is 128 x 128. We set the stride between the stages to be @@ -268,7 +342,6 @@ def __call__( mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q mLSE: Optional[cute.Tensor], softmax_scale: Float32, - stream: cuda.CUstream, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, @@ -279,6 +352,8 @@ def __call__( learnable_sink: Optional[cute.Tensor] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors: Optional[list] = None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): """Execute the Fused Multi-Head Attention operation on the provided tensors. @@ -325,40 +400,40 @@ def __call__( V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2] mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose)) - self.q_major_mode = cutlass.utils.LayoutEnum.from_tensor(mQ).mma_major_mode() - self.k_major_mode = cutlass.utils.LayoutEnum.from_tensor(mK).mma_major_mode() - self.v_major_mode = cutlass.utils.LayoutEnum.from_tensor(mV).mma_major_mode() - self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO) - - if const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): - raise RuntimeError("The layout of mQ is not supported") - if const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): - raise RuntimeError("The layout of mK is not supported") - if const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN): - raise RuntimeError("The layout of mV is not supported") - # check type consistency if const_expr(self.q_dtype != self.k_dtype): raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") if const_expr(self.q_dtype != self.v_dtype): raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") self._setup_attributes() - self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None - # This can be tuned - self.e2e_freq = 16 - if const_expr( - self.head_dim_padded > 64 and not self.is_causal and not self.is_local and self.pack_gqa - ): - self.e2e_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10 - - cta_group = tcgen05.CtaGroup.ONE + self.use_tma_O = ( + self.arch >= Arch.sm_90 + and mCuSeqlensQ is None + and mSeqUsedQ is None + and not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0) + and not (self.pack_gqa and self.is_split_kv) + ) + self.ex2_emu_freq = 0 + self.ex2_emu_start_frg = self._tune.get("ex2_emu_start_frg", 1) + if const_expr(self.enable_ex2_emu): + self.ex2_emu_freq = self._tune.get("ex2_emu_freq", 16) + if const_expr( + self.pack_gqa and self.head_dim_padded > 64 and not self.is_causal and not self.is_local + ): + self.ex2_emu_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else self._tune.get("ex2_emu_freq", 10) + + cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + q_major_mode = tcgen05.OperandMajorMode.K + k_major_mode = tcgen05.OperandMajorMode.K + v_major_mode = tcgen05.OperandMajorMode.MN + self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO) # the intermediate tensor p is from tmem & mK-major p_source = tcgen05.OperandSource.TMEM p_major_mode = tcgen05.OperandMajorMode.K tiled_mma_qk = sm100_utils_basic.make_trivial_tiled_mma( self.q_dtype, - self.q_major_mode, - self.k_major_mode, + q_major_mode, + k_major_mode, self.qk_acc_dtype, cta_group, self.mma_tiler_qk[:2], @@ -366,7 +441,7 @@ def __call__( tiled_mma_pv = sm100_utils_basic.make_trivial_tiled_mma( self.v_dtype, p_major_mode, - self.v_major_mode, + v_major_mode, self.pv_acc_dtype, cta_group, self.mma_tiler_pv[:2], @@ -374,42 +449,27 @@ def __call__( ) self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) - self.cluster_layout_vmnk = cute.tiled_divide( - cute.make_layout(self.cluster_shape_mnk), - (tiled_mma_qk.thr_id.shape,), + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,) ) - self.epi_tile = self.mma_tiler_pv[:2] + # epi_tile is per-CTA (not full 2CTA) since each CTA writes its own O portion + self.epi_tile = (self.m_block_size, self.head_dim_v_padded) sQ_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_qk, - self.mma_tiler_qk, - self.q_dtype, - self.q_stage, + tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, self.q_stage ) sK_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_qk, - self.mma_tiler_qk, - self.k_dtype, - self.kv_stage, + tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage ) tP_layout = sm100_utils_basic.make_smem_layout_a( - tiled_mma_pv, - self.mma_tiler_pv, - self.q_dtype, - self.acc_stage, + tiled_mma_pv, self.mma_tiler_pv, self.q_dtype, self.s_stage ) sV_layout = sm100_utils_basic.make_smem_layout_b( - tiled_mma_pv, - self.mma_tiler_pv, - self.v_dtype, - self.kv_stage, + tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage ) sO_layout = sm100_utils_basic.make_smem_layout_epi( - self.o_dtype, - self.o_layout, - self.epi_tile, - self.q_stage, + self.o_dtype, self.o_layout, self.epi_tile, self.q_stage ) if const_expr(not self.same_hdim_kv_padded): # sK and sV are using the same physical smem so we need to adjust the stride so that they line up @@ -440,50 +500,11 @@ def __call__( ) if const_expr(self.pack_gqa): - shape_Q_packed = ( - (self.qhead_per_kvhead, mQ.shape[0]), - mQ.shape[1], - mK.shape[2], - *mQ.shape[3:], - ) - stride_Q_packed = ( - (mQ.stride[2], mQ.stride[0]), - mQ.stride[1], - mQ.stride[2] * self.qhead_per_kvhead, - *mQ.stride[3:], - ) - mQ = cute.make_tensor( - mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed) - ) - shape_O_packed = ( - (self.qhead_per_kvhead, mO.shape[0]), - mO.shape[1], - mK.shape[2], - *mO.shape[3:], - ) - stride_O_packed = ( - (mO.stride[2], mO.stride[0]), - mO.stride[1], - mO.stride[2] * self.qhead_per_kvhead, - *mO.stride[3:], - ) - mO = cute.make_tensor( - mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed) - ) + nheads_kv = mK.shape[2] + mQ = pack_gqa_layout(mQ, self.qhead_per_kvhead, nheads_kv, head_idx=2) + mO = pack_gqa_layout(mO, self.qhead_per_kvhead, nheads_kv, head_idx=2) if const_expr(mLSE is not None): - shape_LSE_packed = ( - (self.qhead_per_kvhead, mLSE.shape[0]), - mK.shape[2], - *mLSE.shape[2:], - ) - stride_LSE_packed = ( - (mLSE.stride[1], mLSE.stride[0]), - mLSE.stride[1] * self.qhead_per_kvhead, - *mLSE.stride[2:], - ) - mLSE = cute.make_tensor( - mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) - ) + mLSE = pack_gqa_layout(mLSE, self.qhead_per_kvhead, nheads_kv, head_idx=1) self.tma_copy_bytes = { name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) @@ -493,20 +514,34 @@ def __call__( ("V", mV, sV_layout), ] } + for name in ("Q", "K", "V"): + self.tma_copy_bytes[name] *= self.cta_group_size # TMA load for Q tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) tma_store_op = cpasync.CopyBulkTensorTileS2GOp() - tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A( - tma_load_op, - mQ, - cute.select(sQ_layout, mode=[0, 1, 2]), - self.mma_tiler_qk, - tiled_mma_qk, - self.cluster_layout_vmnk.shape, - ) + if const_expr(self.use_tma_Q): + tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + mQ, + cute.select(sQ_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, + tiled_mma_qk, + cta_layout_vmnk.shape, + ) + gmem_tiled_copy_Q = None + else: + tma_atom_Q = None + async_copy_elems = 128 // self.q_dtype.width + num_load_threads = cute.arch.WARP_SIZE * len(self.load_warp_ids) + threads_per_row = math.gcd(self.head_dim_padded // async_copy_elems, num_load_threads) + gmem_tiled_copy_Q = copy_utils.tiled_copy_2d( + self.q_dtype, threads_per_row, num_load_threads, async_copy_elems, is_async=True + ) + tma_atom_K = None + tma_atom_V = None if const_expr(self.use_tma_KV): # TMA load for K tma_atom_K, mK = cute.nvgpu.make_tiled_tma_atom_B( @@ -515,7 +550,7 @@ def __call__( cute.select(sK_layout, mode=[0, 1, 2]), self.mma_tiler_qk, tiled_mma_qk, - self.cluster_layout_vmnk.shape, + cta_layout_vmnk.shape, ) # TMA load for V tma_atom_V, mV = cute.nvgpu.make_tiled_tma_atom_B( @@ -524,21 +559,13 @@ def __call__( cute.select(sV_layout, mode=[0, 1, 2]), self.mma_tiler_pv, tiled_mma_pv, - self.cluster_layout_vmnk.shape, + cta_layout_vmnk.shape, ) - else: - tma_atom_K = None - tma_atom_V = None - - o_cta_v_layout = cute.composition(cute.make_identity_layout(mO.shape), self.epi_tile) self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) if const_expr(self.use_tma_O): tma_atom_O, mO = cpasync.make_tiled_tma_atom( - tma_store_op, - mO, - cute.select(sO_layout, mode=[0, 1]), - o_cta_v_layout, + tma_store_op, mO, cute.select(sO_layout, mode=[0, 1]), self.epi_tile ) gmem_tiled_copy_O = None else: @@ -560,19 +587,10 @@ def __call__( vO_layout = cute.make_layout((1, async_copy_elems)) gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) - if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): - TileScheduler = SingleTileVarlenScheduler - else: - if const_expr(self.is_causal or self.is_local): - TileScheduler = SingleTileLPTScheduler - else: - TileScheduler = ( - SingleTileScheduler - if const_expr(not self.is_persistent) - else StaticPersistentTileScheduler - ) + TileScheduler = self.TileScheduler + _num_block_divisor = self.cta_tiler[0] * (self.cta_group_size if not self.is_persistent and self.cta_group_size > 1 else 1) tile_sched_args = TileSchedulerArguments( - cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]), + cute.ceil_div(cute.size(mQ.shape[0]), _num_block_divisor), cute.size(mQ.shape[2]), cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) @@ -594,49 +612,55 @@ def __call__( is_persistent=self.is_persistent, lpt=self.is_causal or self.is_local, is_split_kv=self.is_split_kv, + cluster_shape_mn=self.cluster_shape_mn, + use_cluster_idx=not self.is_persistent and self.cta_group_size > 1, + ) + tile_sched_params = TileScheduler.to_underlying_arguments( + tile_sched_args, scheduling_mode=self.scheduling_mode ) - tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) self.tile_scheduler_cls = TileScheduler grid_dim = TileScheduler.get_grid_shape(tile_sched_params) - self.mbar_load_q_full_offset = 0 - self.mbar_load_q_empty_offset = self.mbar_load_q_full_offset + self.q_stage - self.mbar_load_kv_full_offset = self.mbar_load_q_empty_offset + self.q_stage - self.mbar_load_kv_empty_offset = self.mbar_load_kv_full_offset + self.kv_stage - self.mbar_P_full_O_rescaled_offset = self.mbar_load_kv_empty_offset + self.kv_stage - self.mbar_S_full_offset = self.mbar_P_full_O_rescaled_offset + self.q_stage - self.mbar_O_full_offset = self.mbar_S_full_offset + self.q_stage - self.mbar_softmax_corr_full_offset = self.mbar_O_full_offset + self.q_stage - self.mbar_softmax_corr_empty_offset = self.mbar_softmax_corr_full_offset + self.q_stage - self.mbar_corr_epi_full_offset = self.mbar_softmax_corr_empty_offset + self.q_stage - self.mbar_corr_epi_empty_offset = self.mbar_corr_epi_full_offset + self.q_stage - self.mbar_s0_s1_sequence_offset = self.mbar_corr_epi_empty_offset + self.q_stage - self.mbar_tmem_dealloc_offset = self.mbar_s0_s1_sequence_offset + 8 - self.mbar_P_full_2_offset = self.mbar_tmem_dealloc_offset + 1 - self.mbar_total = self.mbar_P_full_2_offset + self.q_stage - sO_size = cute.cosize(sO_layout) if const_expr(not self.overlap_sO_sQ) else 0 sQ_size = ( cute.cosize(sQ_layout) if const_expr(not self.overlap_sO_sQ) else cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width) ) + clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0 + clc_mbar_size = self.sched_stages * 2 if self.use_clc_scheduler else 0 + @cute.struct class SharedStorage: # m_barriers for pipelines - mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mbar_total] + mbar_load_Q: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_load_KV: cute.struct.MemRange[Int64, self.kv_stage * 2] + mbar_S_full_P_full_O_rescaled: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_P_full_lastsplit: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_O_full: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_softmax_stats: cute.struct.MemRange[Int64, self.q_stage * 2] + # mbar_softmax_stats: cute.struct.MemRange[Int64, self.q_stage * 4 * 2] + mbar_O_epi: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_s0_s1_sequence: cute.struct.MemRange[Int64, 2 * 2] + # Tmem dealloc cluster barrier + tmem_dealloc_mbar_ptr: Int64 # Tmem holding buffer tmem_holding_buf: Int32 # Smem tensors # store row max and row sum sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2] + # CLC buffers placed here to utilize padding before sO's 1024-byte alignment. + # This avoids adding bytes at the end when we're at the smem limit. + # PipelineClcFetchAsync expects 2 * sched_stages mbarriers (full + empty). + clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size] + # CLC response storage (16 bytes per stage, stored as 4 Int32s). + clc_response: cute.struct.MemRange[Int32, clc_response_size] + # Large TMA buffers with 1024-byte alignment sO: cute.struct.Align[ - cute.struct.MemRange[self.o_dtype, sO_size], - self.buffer_align_bytes, + cute.struct.MemRange[self.o_dtype, sO_size], self.buffer_align_bytes ] sQ: cute.struct.Align[ - cute.struct.MemRange[self.q_dtype, sQ_size], - self.buffer_align_bytes, + cute.struct.MemRange[self.q_dtype, sQ_size], self.buffer_align_bytes ] sK: cute.struct.Align[ # cute.cosize(sK_layout) is correct even in the case of self.uneven_kv_smem @@ -646,35 +670,10 @@ class SharedStorage: self.shared_storage = SharedStorage - LOG2_E = math.log2(math.e) - if const_expr(self.score_mod is None): - softmax_scale_log2 = softmax_scale * LOG2_E - softmax_scale = None - else: - # NB: If a users passes in a score mod, we want to apply the score-mod in the sm_scaled qk - # But in the original base 10. We hijack softmax_scale_log2 to just be the change of base - # and correctly apply the softmax_scale prior to score_mod in the softmax step - softmax_scale_log2 = LOG2_E - softmax_scale = softmax_scale - - if const_expr(window_size_left is not None): - window_size_left = Int32(window_size_left) - if const_expr(window_size_right is not None): - window_size_right = Int32(window_size_right) - - fastdiv_mods = None - if cutlass.const_expr(aux_tensors is not None): - seqlen_q = cute.size(mQ.shape[0]) // ( - self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 - ) - seqlen_k = ( - cute.size(mK.shape[0]) - if const_expr(mPageTable is None) - else mK.shape[0] * mPageTable.shape[1] - ) - seqlen_q_divmod = FastDivmodDivisor(seqlen_q) - seqlen_k_divmod = FastDivmodDivisor(seqlen_k) - fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod) + window_size_left = Int32(window_size_left) if window_size_left is not None else None + window_size_right = Int32(window_size_right) if window_size_right is not None else None + fastdiv_mods = utils.compute_fastdiv_mods(mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors, mPageTable) head_divmod = None if cutlass.const_expr(self.pack_gqa): @@ -711,6 +710,7 @@ class SharedStorage: tP_layout, sV_layout, sO_layout, + gmem_tiled_copy_Q, gmem_tiled_copy_O, tiled_mma_qk, tiled_mma_pv, @@ -722,8 +722,7 @@ class SharedStorage: ).launch( grid=grid_dim, block=[self.threads_per_cta, 1, 1], - cluster=self.cluster_shape_mnk, - smem=self.shared_storage.size_in_bytes(), + cluster=self.cluster_shape_mnk if cute.size(self.cluster_shape_mnk) > 1 else None, stream=stream, min_blocks_per_mp=1, ) @@ -742,7 +741,7 @@ def kernel( mSeqUsedQ: Optional[cute.Tensor], mSeqUsedK: Optional[cute.Tensor], mPageTable: Optional[cute.Tensor], - tma_atom_Q: cute.CopyAtom, + tma_atom_Q: Optional[cute.CopyAtom], tma_atom_K: Optional[cute.CopyAtom], tma_atom_V: Optional[cute.CopyAtom], tma_atom_O: Optional[cute.CopyAtom], @@ -757,6 +756,7 @@ def kernel( tP_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, sO_layout: cute.ComposedLayout, + gmem_tiled_copy_Q: Optional[cute.TiledCopy], gmem_tiled_copy_O: Optional[cute.TiledCopy], tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, @@ -784,86 +784,171 @@ def kernel( # Prefetch tma descriptor if warp_idx == 0: - cpasync.prefetch_descriptor(tma_atom_Q) - if const_expr(tma_atom_K is not None): - cpasync.prefetch_descriptor(tma_atom_K) - if const_expr(tma_atom_V is not None): - cpasync.prefetch_descriptor(tma_atom_V) - if const_expr(tma_atom_O is not None): - cpasync.prefetch_descriptor(tma_atom_O) + for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O): + if const_expr(tma_atom is not None): + cpasync.prefetch_descriptor(tma_atom) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,) + ) + # Setup cta/thread coordinates + bidx, _, _ = cute.arch.block_idx() + if const_expr(cute.size(tiled_mma_qk.thr_id.shape) == 1): + mma_tile_coord_v = 0 + else: + mma_tile_coord_v = bidx % cute.size(tiled_mma_qk.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 # Alloc smem = cutlass.utils.SmemAllocator() storage = smem.allocate(self.shared_storage) - mbar_ptr = storage.mbar_ptr.data_ptr() - # Use the first N warps to initialize barriers - if warp_idx == 1: - # Init "full" barrier with number of producers, "empty" barrier with number of consumers - for i in cutlass.range_constexpr(self.q_stage): - cute.arch.mbarrier_init( - mbar_ptr + self.mbar_load_q_full_offset + i, 1 - ) - cute.arch.mbarrier_init( - mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id]) - ) - if warp_idx == 2: - for i in cutlass.range_constexpr(self.q_stage): - cute.arch.mbarrier_init( - mbar_ptr + self.mbar_softmax_corr_empty_offset + i, cute.arch.WARP_SIZE * 4 - ) - cute.arch.mbarrier_init( - mbar_ptr + self.mbar_softmax_corr_full_offset + i, cute.arch.WARP_SIZE * 4 - ) - if warp_idx == 3: - if const_expr(self.s0_s1_barrier): - for i in cutlass.range_constexpr(8): - cute.arch.mbarrier_init( - mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE - ) - if const_expr(not self.use_correction_warps_for_epi) and warp_idx == 4: - for i in cutlass.range_constexpr(self.q_stage): - cute.arch.mbarrier_init( - mbar_ptr + self.mbar_corr_epi_full_offset + i, - cute.arch.WARP_SIZE * len(self.correction_warp_ids), - ) - cute.arch.mbarrier_init( - mbar_ptr + self.mbar_corr_epi_empty_offset + i, - cute.arch.WARP_SIZE * len(self.epilogue_warp_ids), - ) - if warp_idx == 5: - for i in cutlass.range_constexpr(self.q_stage): - cute.arch.mbarrier_init( - mbar_ptr + self.mbar_P_full_O_rescaled_offset + i, - cute.arch.WARP_SIZE - * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids)), - ) - cute.arch.mbarrier_init( - mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id]) - ) - cute.arch.mbarrier_init( - mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id]) - ) - if warp_idx == 6: - for i in cutlass.range_constexpr(self.q_stage): - cute.arch.mbarrier_init( - mbar_ptr + self.mbar_P_full_2_offset + i, - cute.arch.WARP_SIZE * len(self.softmax0_warp_ids), - ) - if warp_idx == 7: - cute.arch.mbarrier_init( - mbar_ptr + self.mbar_tmem_dealloc_offset, - cute.arch.WARP_SIZE - * len( - ( - *self.softmax0_warp_ids, - *self.softmax1_warp_ids, - *self.correction_warp_ids, - ) - ), + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.TmemPtr), + num_threads=cute.arch.WARP_SIZE * len( + (self.mma_warp_id, + *self.softmax0_warp_ids, + *self.softmax1_warp_ids, + *self.correction_warp_ids) + ), + ) + # Tensor memory dealloc barrier init + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + is_two_cta=self.use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread) + mma_warp = ThreadCooperativeGroup(len([self.mma_warp_id])) + tma_warp = ThreadCooperativeGroup(1) + load_threads = ThreadCooperativeGroup(len(self.load_warp_ids) * cute.arch.WARP_SIZE) + softmax_warps = ThreadCooperativeGroup(len(self.softmax0_warp_ids)) + softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.softmax0_warp_ids)) + # softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE) + correction_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * len(self.correction_warp_ids) + ) + # correction_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE) + softmax_correction_threads = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * len(self.softmax0_warp_ids + self.correction_warp_ids) + ) + epilogue_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)) + # For UMMA-bridging pipelines: the non-MMA side spans both CTAs in the cluster, + # so the thread count must include warps from both CTAs. + softmax_warps_cluster = ThreadCooperativeGroup( + len(self.softmax0_warp_ids) * self.cta_group_size + ) + correction_threads_cluster = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * len(self.correction_warp_ids) * self.cta_group_size + ) + softmax_correction_threads_cluster = ThreadCooperativeGroup( + cute.arch.WARP_SIZE * len(self.softmax0_warp_ids + self.correction_warp_ids) * self.cta_group_size + ) + if const_expr(self.use_tma_Q): + pipeline_q = pipeline_custom.PipelineTmaUmma.create( + barrier_storage=storage.mbar_load_Q.data_ptr(), + num_stages=self.q_stage, + producer_group=tma_warp, + consumer_group=mma_warp, + tx_count=self.tma_copy_bytes["Q"], + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + else: + pipeline_q = pipeline_custom.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_load_Q.data_ptr(), + num_stages=self.q_stage, + producer_group=load_threads, + consumer_group=mma_warp, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + if const_expr(self.use_tma_KV): + pipeline_kv = pipeline_custom.PipelineTmaUmma.create( + barrier_storage=storage.mbar_load_KV.data_ptr(), + num_stages=self.kv_stage, + producer_group=tma_warp, + consumer_group=mma_warp, + tx_count=self.tma_copy_bytes["K"], + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + else: + pipeline_kv = pipeline.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_load_KV.data_ptr(), + num_stages=self.kv_stage, + producer_group=load_threads, + consumer_group=mma_warp, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + # This pipeline is not the typical producer-consumer pipeline. The "producer" mma warp + # uses it to signal that S is ready, and the softmax threads wait for S to be ready. + # When softmax threads write P to tmem and the correction threads have rescaled O, they + # signal as "consumer". The mma warp then waits for that signal to do the P @ V gemm. + pipeline_s_p_o = pipeline_custom.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_S_full_P_full_O_rescaled.data_ptr(), + num_stages=self.q_stage, + producer_group=mma_warp, + consumer_group=softmax_correction_threads_cluster, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_p_lastsplit = pipeline_custom.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_P_full_lastsplit.data_ptr(), + num_stages=self.q_stage, + producer_group=softmax_warps_cluster, + consumer_group=mma_warp, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + # MMA warp uses this to signal to the correction warps that O is ready. + pipeline_o_acc = pipeline_custom.PipelineUmmaAsync.create( + barrier_storage=storage.mbar_O_full.data_ptr(), + num_stages=self.q_stage, + producer_group=mma_warp, + consumer_group=correction_threads_cluster, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + pipeline_s0_s1_sequence = None + if const_expr(self.s0_s1_barrier and self.q_stage > 1): + # This is not a typical producer-consumer pipeline. We will directly use + # pipeline_s0_s1_sequence.sync_object_full and will not use + # pipeline_s0_s1_sequence.sync_object_empty. + pipeline_s0_s1_sequence = pipeline_custom.PipelineAsync.create( + barrier_storage=storage.mbar_s0_s1_sequence.data_ptr(), + num_stages=2, + producer_group=softmax_threads, + consumer_group=softmax_threads, + defer_sync=True, ) - # Relying on pipeline_kv constructor to call mbarrier_init_fence and sync - pipeline_kv = self.make_and_init_load_kv_pipeline(mbar_ptr + self.mbar_load_kv_full_offset) + pipeline_sm_stats = pipeline_custom.PipelineAsync.create( + barrier_storage=storage.mbar_softmax_stats.data_ptr(), + num_stages=self.q_stage, + producer_group=softmax_threads, + consumer_group=correction_threads, + defer_sync=True, + ) + # Should put the NamedBarrier inside the pipeline class so we'll just have pipeline_sm_stats + sm_stats_barrier = pipeline_custom.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100.SoftmaxStatsW0), num_threads=cute.arch.WARP_SIZE * 2 + ) + pipeline_o_epi = None + if const_expr(not self.use_correction_warps_for_epi): + pipeline_o_epi = pipeline_custom.PipelineAsync.create( + barrier_storage=storage.mbar_O_epi.data_ptr(), + num_stages=self.q_stage, + producer_group=correction_threads, + consumer_group=epilogue_threads, + defer_sync=True, + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) # Generate smem tensor Q/K/V/O # (MMA, MMA_Q, MMA_D, PIPE) @@ -880,39 +965,26 @@ def kernel( sScale = storage.sScale.get_tensor(cute.make_layout(self.q_stage * self.m_block_size * 2)) - thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM - thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM + thr_mma_qk = tiled_mma_qk.get_slice(mma_tile_coord_v) + thr_mma_pv = tiled_mma_pv.get_slice(mma_tile_coord_v) qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) - tStS_fake = thr_mma_qk.make_fragment_C(qk_acc_shape) - # This is a fake tensor, by right need to retrieve tmem_ptr. But we know that we always + # This is a fake tensor, by right we need to retrieve tmem_ptr. But we know that we always # request 512 columns of tmem, so we know that it starts at 0. - tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16) - tStS = cute.make_tensor(tmem_ptr, tStS_fake.layout) - + tStS = thr_mma_qk.make_fragment_C(cute.append(qk_acc_shape, self.s_stage)) pv_acc_shape = thr_mma_pv.partition_shape_C(self.mma_tiler_pv[:2]) - tOtO = thr_mma_pv.make_fragment_C(pv_acc_shape) - - tStSs = tuple( - cute.make_tensor(tStS.iterator + self.tmem_s_offset[stage], tStS.layout) - for stage in range(self.q_stage) - ) - tOtOs = tuple( - cute.make_tensor(tOtO.iterator + self.tmem_o_offset[stage], tOtO.layout) - for stage in range(self.q_stage) - ) - + tOtO = thr_mma_pv.make_fragment_C(cute.append(pv_acc_shape, self.q_stage)) + tOtO = cute.make_tensor(tOtO.iterator + self.tmem_o_offset[0], tOtO.layout) tP = cute.make_tensor(tStS.iterator, tP_layout.outer) tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] - - tOrPs = [ - cute.make_tensor( - tOrP.iterator - + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p_offset[stage], - tOrP.layout, - ) - for stage in range(self.q_stage) - ] + # Need to multiply by width ratio bc tP is in v_dtype but tmem offsets are in FP32 + tP_width_ratio = Float32.width // self.v_dtype.width + # Need to adjust the stage stride manually since the two stages aren't contiguous in tmem + tP_stage_stride = (self.tmem_p_offset[1] - self.tmem_p_offset[0]) * tP_width_ratio + tOrP = cute.make_tensor( + tOrP.iterator + self.tmem_p_offset[0] * tP_width_ratio, + cute.append(tOrP.layout, cute.make_layout((self.s_stage,), stride=(tP_stage_stride,))) + ) block_info = BlockInfo( # This is cta_tiler, not mma_tiler_qk, since we move by block by (2 * mma_tiler[0], mma_tiler[1]) @@ -944,14 +1016,69 @@ def kernel( window_size_right=window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) - TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) + # Cluster wait before tensor memory alloc + pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + if const_expr(self.use_clc_scheduler): + clc_response_ptr = storage.clc_response.data_ptr() + clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr() + + clc_pipeline_producer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread + ) + num_clc_consumer_warps_per_cta = self.threads_per_cta // cute.arch.WARP_SIZE + # NB on CTA0 warp15 == scheduler on CTA1 == empty but still both consume + num_clc_consumer_warps = num_clc_consumer_warps_per_cta * self.cta_group_size + clc_pipeline_consumer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread, cute.arch.WARP_SIZE * num_clc_consumer_warps + ) + + block_idx = cute.arch.block_idx() + clc = ClcState.create( + hw_scheduler=ClcDynamicPersistentTileScheduler.create( + self.tile_scheduler_cls.clc_problem_shape(tile_sched_params), + block_idx, + cute.arch.grid_dim(), + clc_response_ptr, + ), + pipeline=cutlass_pipeline.PipelineClcFetchAsync.create( + barrier_storage=clc_mbar_ptr, + num_stages=self.sched_stages, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=16, + cta_layout_vmnk=cta_layout_vmnk, + ), + consumer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, self.sched_stages + ), + producer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Producer, self.sched_stages + ), + ) + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params, clc=clc) + else: + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) + assert isinstance(tile_scheduler, TileSchedulerProtocol), f"tile_scheduler is not a TileSchedulerProtocol: {type(tile_scheduler)}" # /////////////////////////////////////////////////////////////////////////////// - # EMPTY + # EMPTY / CLC SCHEDULER WARP # /////////////////////////////////////////////////////////////////////////////// - for i in cutlass.range_constexpr(len(self.empty_warp_ids)): - if warp_idx == self.empty_warp_ids[i]: - cute.arch.setmaxregister_decrease(self.num_regs_empty) + if const_expr(self.use_clc_scheduler): + if warp_idx == self.clc_scheduler_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_other) + if is_leader_cta: + self.clc_scheduler_warp(tile_scheduler) + else: + self.empty_warp(tile_scheduler) + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i] and warp_idx != self.clc_scheduler_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_other) + self.empty_warp(tile_scheduler) + else: + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i]: + cute.arch.setmaxregister_decrease(self.num_regs_other) # /////////////////////////////////////////////////////////////////////////////// # LOAD @@ -971,57 +1098,50 @@ def kernel( tma_atom_Q, tma_atom_K, tma_atom_V, + gmem_tiled_copy_Q, + pipeline_q, pipeline_kv, - mbar_ptr, block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, blocksparse_tensors, + tile_scheduler=tile_scheduler, ) # /////////////////////////////////////////////////////////////////////////////// # MMA # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.mma_warp_id: - # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids: cute.arch.setmaxregister_decrease(self.num_regs_other) - # Alloc tmem buffer - tmem_alloc_cols = Int32(self.tmem_alloc_cols) - if warp_idx == self.mma_warp_id: - cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) - cute.arch.sync_warp() - + # Alloc tensor memory buffer + tmem.allocate(cute.arch.get_max_tmem_alloc_cols("sm_100")) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) self.mma( tiled_mma_qk, tiled_mma_pv, sQ, sK, sV, - tStSs, - tOtOs, - tOrPs, + tStS, + tOtO, + tOrP, + pipeline_q, pipeline_kv, - mbar_ptr, + pipeline_s_p_o, + pipeline_p_lastsplit, + pipeline_o_acc, + is_leader_cta, block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, blocksparse_tensors, + tile_scheduler=tile_scheduler, ) - - # if warp_idx == self.mma_warp_id: - # dealloc tmem buffer - cute.arch.relinquish_tmem_alloc_permit() - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_tmem_dealloc_offset, 0) - tmem_alloc_cols = Int32(self.tmem_alloc_cols) - # Retrieving tmem ptr and make acc - tmem_ptr = cute.arch.retrieve_tmem_ptr( - Float32, - alignment=16, - ptr_to_buffer_holding_addr=storage.tmem_holding_buf, - ) - cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols) + # Dealloc the tensor memory buffer + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr) # /////////////////////////////////////////////////////////////////////////////// # Epilogue @@ -1034,11 +1154,12 @@ def kernel( sO, gmem_tiled_copy_O, tma_atom_O, - mbar_ptr, + pipeline_o_epi, block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, + mma_tile_coord_v, + tile_scheduler=tile_scheduler, ) # /////////////////////////////////////////////////////////////////////////////// @@ -1050,6 +1171,9 @@ def kernel( ): # increase register after decreasing cute.arch.setmaxregister_increase(self.num_regs_softmax) + # sync with mma warp before retrieving tmem ptr + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) softmax_loop = partial( self.softmax_loop, softmax_scale_log2=softmax_scale_log2, @@ -1057,67 +1181,68 @@ def kernel( thr_mma_qk=thr_mma_qk, sScale=sScale, mLSE=mLSE, + pipeline_s_p_o=pipeline_s_p_o, + pipeline_p_lastsplit=pipeline_p_lastsplit, + pipeline_sm_stats=pipeline_sm_stats, + sm_stats_barrier=sm_stats_barrier, + pipeline_s0_s1_sequence=pipeline_s0_s1_sequence, learnable_sink=learnable_sink, - mbar_ptr=mbar_ptr, block_info=block_info, num_splits=num_splits, SeqlenInfoCls=SeqlenInfoCls, AttentionMaskCls=AttentionMaskCls, - TileSchedulerCls=TileSchedulerCls, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, head_divmod=head_divmod, blocksparse_tensors=blocksparse_tensors, + tile_scheduler=tile_scheduler, ) if const_expr(not self.s0_s1_barrier): stage = Int32(0 if const_expr(self.q_stage == 1) or warp_idx < self.softmax1_warp_ids[0] else 1) - softmax_loop( - stage=stage, - tStSi=cute.make_tensor( - tStS.iterator - + (self.tmem_s_offset[0] if stage == 0 else self.tmem_s_offset[1]), - tStS.layout, - ), - ) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + softmax_loop(stage=stage, tStS=tStS) else: # If there's s0_s1_barrier, it's faster to have 2 WGs having different code if warp_idx < self.softmax1_warp_ids[0]: - tStSi = cute.make_tensor(tStS.iterator + self.tmem_s_offset[0], tStS.layout) - softmax_loop(stage=0, tStSi=tStSi) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + softmax_loop(stage=0, tStS=tStS) if warp_idx < self.correction_warp_ids[0] and warp_idx >= self.softmax1_warp_ids[0]: - tStSi = cute.make_tensor(tStS.iterator + self.tmem_s_offset[1], tStS.layout) - softmax_loop(stage=1, tStSi=tStSi) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + softmax_loop(stage=1, tStS=tStS) + + tmem_alloc_barrier.arrive() # /////////////////////////////////////////////////////////////////////////////// # Correction # /////////////////////////////////////////////////////////////////////////////// if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id: cute.arch.setmaxregister_decrease(self.num_regs_correction) + # sync with mma warp before retrieving tmem ptr + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) self.correction_loop( thr_mma_qk, thr_mma_pv, tStS, - tOtOs, + tOtO, sScale, mO, mLSE, sO, + pipeline_s_p_o, + pipeline_o_acc, + pipeline_sm_stats, + sm_stats_barrier, + pipeline_o_epi, learnable_sink, gmem_tiled_copy_O, tma_atom_O, - mbar_ptr, softmax_scale_log2, block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, blocksparse_tensors, + tile_scheduler=tile_scheduler, ) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + tmem_alloc_barrier.arrive() return @@ -1133,30 +1258,38 @@ def load( sK: cute.Tensor, sV: cute.Tensor, mPageTable: Optional[cute.Tensor], - tma_atom_Q: cute.CopyAtom, + tma_atom_Q: Optional[cute.CopyAtom], tma_atom_K: Optional[cute.CopyAtom], tma_atom_V: Optional[cute.CopyAtom], - pipeline_kv: cutlass.pipeline.PipelineAsync, - mbar_ptr: cute.Pointer, + gmem_tiled_copy_Q: Optional[cute.TiledCopy], + pipeline_q: pipeline.PipelineAsync, + pipeline_kv: pipeline.PipelineAsync, block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors], + tile_scheduler: TileSchedulerProtocol, ): num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE tidx = cute.arch.thread_idx()[0] % num_load_threads + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + issue_kv_for_this_warp = ( + const_expr(not self.use_tma_KV or len(self.load_warp_ids) == 1) or + warp_idx == self.load_warp_ids[0] + ) + issue_q_for_this_warp = ( + const_expr(not self.use_tma_Q or len(self.load_warp_ids) == 1) or + warp_idx == self.load_warp_ids[0] + ) q_producer_phase = Int32(1) - kv_producer_state = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.kv_stage + kv_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.kv_stage ) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] - gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0)) head_idx_kv = ( head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx @@ -1178,12 +1311,32 @@ def load( gV = cute.local_tile( mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None) ) - tSgQ = thr_mma_qk.partition_A(gQ) tSgK = thr_mma_qk.partition_B(gK) tOgV = thr_mma_pv.partition_B(gV) - load_Q_fn, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ - ) + if const_expr(self.use_tma_Q): + tiler_gQ = ((self.mma_tiler_qk[0] * self.q_stage), self.head_dim_padded) + gQ = cute.local_tile(mQ_cur, tiler_gQ, (m_block, 0)) # (128 * 2, 128) + gQ = layout_utils.select( + cute.flat_divide(gQ, (self.mma_tiler_qk[0],)), mode=[0, 2, 1] + ) # (128, 128, 2) + tSgQ = thr_mma_qk.partition_A(gQ) + load_Q_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ + ) + load_Q = partial(self.load_Q, load_Q_fn, pipeline_q=pipeline_q, phase=q_producer_phase) + else: + assert gmem_tiled_copy_Q is not None + load_Q = partial( + self.load_Q_non_tma, + mQ_cur, + sQ, + gmem_tiled_copy_Q, + pipeline_q, + tidx, + seqlen.seqlen_q, + m_block, + phase=q_producer_phase, + ) if const_expr(self.use_tma_KV): tKsK, tKgK = cpasync.tma_partition( @@ -1222,15 +1375,6 @@ def load( tKsK, tKgK = None, None tVsV, tVgV = None, None - load_Q = partial( - self.load_Q, - load_Q_fn, - mbar_ptr + self.mbar_load_q_full_offset, - mbar_ptr + self.mbar_load_q_empty_offset, - phase=q_producer_phase, - ) - # We have to use mbarrier directly in the load for KV instead of replying on - # pipeline_kv, because we could have different number of TMA bytes for K and V load_K = partial( self.load_KV, tma_atom_K, @@ -1238,8 +1382,7 @@ def load( tKsK, paged_kv_manager, sK, - mbar_ptr + self.mbar_load_kv_full_offset, - mbar_ptr + self.mbar_load_kv_empty_offset, + pipeline_kv=pipeline_kv, K_or_V="K", ) load_V = partial( @@ -1249,8 +1392,7 @@ def load( tVsV, paged_kv_manager, sV, - mbar_ptr + self.mbar_load_kv_full_offset, - mbar_ptr + self.mbar_load_kv_empty_offset, + pipeline_kv=pipeline_kv, K_or_V="V", ) @@ -1259,8 +1401,6 @@ def load( seqlen, m_block, split_idx, num_splits ) if const_expr(not self.is_split_kv) or n_block_min < n_block_max: - if const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE: - load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 n_block_first = n_block_max - 1 if n_block_max > 0 else 0 page_idx = ( mPageTable[batch_idx, n_block_first] @@ -1269,13 +1409,19 @@ def load( ) if const_expr(not self.use_tma_KV): paged_kv_manager.load_page_table(n_block_first) - load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 - kv_producer_state.advance() - if const_expr(self.q_stage == 2) and (const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE): - load_Q(block=self.q_stage * m_block + 1, stage=1) # Q1 + if issue_kv_for_this_warp: + load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 + # load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx, extra_tx_count=self.tma_copy_bytes["Q"]) # K0 + if issue_q_for_this_warp: + load_Q(block=0, stage=0) + if issue_kv_for_this_warp: + kv_producer_state.advance() + if const_expr(self.q_stage == 2) and issue_q_for_this_warp: + load_Q(block=1, stage=1) q_producer_phase ^= 1 - load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0 - kv_producer_state.advance() + if issue_kv_for_this_warp: + load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0 + kv_producer_state.advance() for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): n_block = n_block_max - 2 - i page_idx = ( @@ -1286,10 +1432,11 @@ def load( if const_expr(not self.use_tma_KV): paged_kv_manager.load_page_table(n_block) # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx) - load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki - kv_producer_state.advance() - load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi - kv_producer_state.advance() + if issue_kv_for_this_warp: + load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki + kv_producer_state.advance() + load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi + kv_producer_state.advance() else: kv_producer_state, q_producer_phase = produce_block_sparse_loads_sm100( @@ -1309,11 +1456,15 @@ def load( ) - tile_scheduler.prefetch_next_work() - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop + if issue_kv_for_this_warp: + pipeline_kv.producer_tail(kv_producer_state) + # This is equivalent to pipeline_q.producer_tail for the TMA-Q producer warp. + if issue_q_for_this_warp: + pipeline_q.producer_acquire_w_index_phase(self.q_stage - 1, q_producer_phase) + @cute.jit def mma( self, @@ -1322,16 +1473,20 @@ def mma( sQ: cute.Tensor, sK: cute.Tensor, sV: cute.Tensor, - tStSs: Tuple[cute.Tensor, cute.Tensor], - tOtOs: tuple[cute.Tensor], - tOrPs: Tuple[cute.Tensor, cute.Tensor], - pipeline_kv: cutlass.pipeline.PipelineAsync, - mbar_ptr: cute.Pointer, + tStS: cute.Tensor, + tOtO: cute.Tensor, + tOrP: cute.Tensor, + pipeline_q: pipeline.PipelineAsync, + pipeline_kv: pipeline.PipelineAsync, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_p_lastsplit: pipeline.PipelineAsync, + pipeline_o_acc: pipeline.PipelineAsync, + is_leader_cta: Boolean, block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors], + tile_scheduler=None, ): tSrQ = tiled_mma_qk.make_fragment_A(sQ) tSrK = tiled_mma_qk.make_fragment_B(sK) @@ -1342,36 +1497,84 @@ def mma( tSrQs = (tSrQ[None, None, None, 0],) qk_mma_op, pv_mma_op = tiled_mma_qk.op, tiled_mma_pv.op - + qk_mma_idesc, pv_mma_idesc = sm100_desc.mma_op_to_idesc(qk_mma_op), sm100_desc.mma_op_to_idesc(pv_mma_op) + q_smem_base = sm100_desc.smem_desc_base_from_tensor(sQ, sm100_desc.Major.K) + k_smem_base = sm100_desc.smem_desc_base_from_tensor(sK, sm100_desc.Major.K) + v_smem_base = sm100_desc.smem_desc_base_from_tensor(sV, sm100_desc.Major.MN) + q_smem_start = [sm100_desc.make_smem_desc_start_addr(sQ[None, None, None, stage].iterator) for stage in range(self.q_stage)] + + sm100_utils.declare_ptx_smem_desc(q_smem_start[self.q_stage - 1], q_smem_base, tSrQ[None, None, None, 0].layout, var_name_prefix="fa_fwd_q_smem_desc") + sm100_utils.declare_ptx_idesc(qk_mma_op, var_name="fa_fwd_qk_mma_idesc") + sm100_utils.declare_ptx_idesc(pv_mma_op, var_name="fa_fwd_pv_mma_idesc") + + sQ_stage_stride = (sQ.layout.stride[-1] * sQ.element_type.width // 8) >> 4 + if const_expr(self.q_stage == 1): + sQ_stage_stride = 0 gemm_Si = [ partial( - sm100_utils.gemm_ptx_partial, - qk_mma_op, + # sm100_utils.gemm_ptx_precomputed, + # self.tmem_s_offset[stage], + # smem_desc_start_a=q_smem_start[stage], + # idesc=qk_mma_idesc, + # smem_desc_base_a=q_smem_base, + # smem_desc_base_b=k_smem_base, + # tCrA_layout=tSrQ[None, None, None, 0].layout, + sm100_utils.gemm_ptx_precomputed_varname, self.tmem_s_offset[stage], - tSrQs[stage], - sA=sQ[None, None, None, stage], + # idesc=qk_mma_idesc, + smem_desc_base_b=k_smem_base, + tCrB_layout=tSrK[None, None, None, 0].layout, + smem_var_name_prefix=f"fa_fwd_q_smem_desc", + idesc_var_name=f"fa_fwd_qk_mma_idesc", + smem_offset=-sQ_stage_stride if stage == 0 else sQ_stage_stride, zero_init=True, + cta_group=self.cta_group_size, ) for stage in range(self.q_stage) ] + # gemm_Si = [ + # partial( + # sm100_utils.gemm, + # tiled_mma_qk, + # tStS[None, None, None, stage], + # tCrA=tSrQ[None, None, None, stage], + # zero_init=True, + # ) + # for stage in range(self.q_stage) + # ] gemm_Pi = [ partial( + # sm100_utils.gemm_ptx_precomputed, sm100_utils.gemm_ptx_partial, pv_mma_op, self.tmem_o_offset[stage], - tOrPs[stage], + tOrP[None, None, None, stage], sA=None, + split_arrive=self.split_P_arrive if self.split_P_arrive > 0 else None, + # smem_desc_start_a=tOrP[None, None, None, stage].iterator.toint(), + # smem_desc_start_a=self.tmem_p_offset[stage], + # idesc=pv_mma_idesc, + # smem_desc_base_a=None, + # smem_desc_base_b=v_smem_base, + # tCrA_layout=tOrP[None, None, None, 0].layout, + # tCrB_layout=tOrV[None, None, None, 0].layout + cta_group=self.cta_group_size, ) for stage in range(self.q_stage) ] + # gemm_Pi = [ + # partial( + # sm100_utils.gemm, tOtO[None, None, None, stage], tCrA=tOrP[None, None, None, stage] + # ) + # for stage in range(self.q_stage) + # ] mma_q_consumer_phase = Int32(0) - mma_kv_consumer_state = cutlass.pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.kv_stage + mma_kv_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.kv_stage ) P_full_O_rescaled_phase = Int32(0) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -1398,32 +1601,32 @@ def mma( else: process_tile = n_block_min < n_block_max - if process_tile: + if process_tile and is_leader_cta: for stage in cutlass.range_constexpr(self.q_stage): # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) # 1. wait for Q0 / Q1 - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase - ) + pipeline_q.consumer_wait_w_index_phase(stage, mma_q_consumer_phase) # 2. wait for K0 if const_expr(stage == 0): pipeline_kv.consumer_wait(mma_kv_consumer_state) - tSrKi = tSrK[None, None, None, mma_kv_consumer_state.index] + Ki_index, Ki_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase + tSrKi = tSrK[None, None, None, Ki_index] # We don't need to acquire empty S0 / S1. # For the first iteration, we don't need to wait as we're guaranteed S0 / S1 # are empty. For subsequent iterations, the wait happened at the end # of the while loop. # 3. gemm - # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True) - sK_cur = sK[None, None, None, mma_kv_consumer_state.index] + # sm100_utils.gemm(tiled_mma_qk, tStS[None, None, None, stage], tSrQ[None, None, None, stage], tSrKi, zero_init=True) + sK_cur = sK[None, None, None, Ki_index] if const_expr(self.uneven_kv_smem): - sK_cur = self.offset_kv_smem( - sK_cur, mma_kv_consumer_state.index, mma_kv_consumer_state.phase - ) - gemm_Si[stage](tCrB=tSrKi, sB=sK_cur) + sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase) + # gemm_Si[stage](tCrB=tSrKi, sB=sK_cur) + gemm_Si[stage]( + smem_desc_start_b=sm100_desc.make_smem_desc_start_addr(sK_cur.iterator) + ) + # gemm_Si[stage](tCrB=tSrKi) # 4. release S0 / S1 - with cute.arch.elect_one(): - tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) + pipeline_s_p_o.producer_commit_w_index(stage) mma_q_consumer_phase ^= 1 # 5. release K0 pipeline_kv.consumer_release(mma_kv_consumer_state) @@ -1446,11 +1649,8 @@ def mma( # 2. acquire corrected O0/O1_partial and P0 / P1 # For the first iteration in this work tile, waiting for O0/O1_partial # means that the correction warps has finished reading tO during - # the last iteration of the previous work tile has finished. - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, - P_full_O_rescaled_phase, - ) + # the last iteration of the previous work tile. + pipeline_s_p_o.producer_acquire_w_index_phase(stage, P_full_O_rescaled_phase) # 3. gemm # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) @@ -1460,18 +1660,17 @@ def mma( gemm_Pi[stage]( tCrB=tOrVi, sB=sV_cur, + # smem_desc_start_b=sm100_desc.make_smem_desc_start_addr(sV_cur.iterator), zero_init=not O_should_accumulate, - mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, + mbar_ptr=pipeline_p_lastsplit.sync_object_full.get_barrier(stage) if self.split_P_arrive > 0 else None, mbar_phase=P_full_O_rescaled_phase, ) - # 4. release accumulated O0_partial / O1_partial - # Don't need to signal O_full to the correction warps anymore since the + # Don't need to signal O_full to the correction warps since the # correction warps wait for the softmax warps anyway. By the time the softmax # warps finished, S_i for the next iteration must have been done, so O_i-1 # must have been done as well. - # with cute.arch.elect_one(): - # tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) - # 5. release V(i-1) + # pipeline_o_acc.producer_commit_w_index(stage) + # 4. release V(i-1) if const_expr(stage == self.q_stage - 1): pipeline_kv.consumer_release(mma_kv_release_state) mma_kv_release_state.advance() @@ -1487,14 +1686,17 @@ def mma( # Don't need to wait for the softmax warp to have finished reading the previous # Si, since this gemm is scheduled after the PV gemm, which guaranteed that Si # has been read and Pi has been written. - # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrK[None, None, None, Ki_index], zero_init=True) + # sm100_utils.gemm(tiled_mma_qk, tStS[None, None, None, stage], tSrQ[None, None, None, stage], tSrK[None, None, None, Ki_index], zero_init=True) sK_cur = sK[None, None, None, Ki_index] if const_expr(self.uneven_kv_smem): sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase) - gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index], sB=sK_cur) - # 3. release S0 - with cute.arch.elect_one(): - tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) + # gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index], sB=sK_cur) + gemm_Si[stage]( + smem_desc_start_b=sm100_desc.make_smem_desc_start_addr(sK_cur.iterator) + ) + # gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index]) + # 3. release S0 / S1 + pipeline_s_p_o.producer_commit_w_index(stage) # End of GEMM_QK0i (Q0 * Ki -> S0) # 4. release Ki pipeline_kv.consumer_release(mma_kv_consumer_state) @@ -1504,9 +1706,8 @@ def mma( # End of seqlen_kv loop # release Q0 & Q1 - with cute.arch.elect_one(): - for stage in cutlass.range_constexpr(self.q_stage): - tcgen05.commit(mbar_ptr + self.mbar_load_q_empty_offset + stage) + for stage in cutlass.range(self.q_stage): + pipeline_q.consumer_release_w_index(stage) # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop # 1. wait for V0 @@ -1515,9 +1716,7 @@ def mma( tOrVi = tOrV[None, None, None, Vi_index] for stage in cutlass.range_constexpr(self.q_stage): # 2. acquire corrected Oi_partial and Pi - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase - ) + pipeline_s_p_o.producer_acquire_w_index_phase(stage, P_full_O_rescaled_phase) # 3. gemm # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) @@ -1527,17 +1726,17 @@ def mma( gemm_Pi[stage]( tCrB=tOrVi, sB=sV_cur, + # smem_desc_start_b=sm100_desc.make_smem_desc_start_addr(sV_cur.iterator), zero_init=not O_should_accumulate, - mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, + mbar_ptr=pipeline_p_lastsplit.sync_object_full.get_barrier(stage) if self.split_P_arrive > 0 else None, mbar_phase=P_full_O_rescaled_phase, ) # 4. release accumulated O0_partial # We do need O_full here since for the last tile, by the time the softmax warp - # has signaled to the correction warps, the softmax warp has just finished compute - # the row sum of the current tile. It does not guarantee that the 1st tile - # of the next work tile has been computed yet. - with cute.arch.elect_one(): - tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) + # has signaled to the correction warps, the softmax warp has just finished + # computing the row sum of the current tile. It does not guarantee that the 1st + # tile of the next work tile has been computed yet. + pipeline_o_acc.producer_commit_w_index(stage) # End of GEMM_PV00 (P0 * V0 -> O0_partial) P_full_O_rescaled_phase ^= 1 # 5. release Vi_end @@ -1546,10 +1745,13 @@ def mma( # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop + # We don't need pipeline_s_p_o.producer_tail() since there's no dangling mbarrier at the end + # pipeline_s_p_o.producer_acquire_w_index_phase(self.q_stage - 1, P_full_O_rescaled_phase) + # We don't need pipeline_o_acc.producer_tail() since we don't call + # pipeline_o_acc.producer_acquire() inside the loop. # for both softmax0 and softmax1 warp group @cute.jit @@ -1559,20 +1761,24 @@ def softmax_loop( softmax_scale_log2: Float32, softmax_scale: Float32, thr_mma_qk: cute.core.ThrMma, - tStSi: cute.Tensor, + tStS: cute.Tensor, # ((TILE_M, TILE_N), 1, 1, q_stage) sScale: cute.Tensor, mLSE: Optional[cute.Tensor], + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_p_lastsplit: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + sm_stats_barrier: pipeline.NamedBarrier, + pipeline_s0_s1_sequence: Optional[pipeline.PipelineAsync], learnable_sink: Optional[cute.Tensor], - mbar_ptr: cute.Pointer, block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, AttentionMaskCls: Callable, - TileSchedulerCls: Callable, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), head_divmod=None, blocksparse_tensors: Optional[BlockSparseTensors] = None, + tile_scheduler=None, ): """Compute softmax on attention scores from QK matrix multiplication. @@ -1590,50 +1796,48 @@ def softmax_loop( # * (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids) * (len(self.softmax0_warp_ids)) ) + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 - tStScale = cute.composition(tStSi, cute.make_layout((self.m_block_size, 1))) + cta_qk_tiler = (self.mma_tiler_qk[0] // thr_mma_qk.thr_id.shape, self.mma_tiler_qk[1]) + tSAcc = tStS[(None, None), 0, 0, stage] # (128, 128) + tStScale = cute.composition(tSAcc, cute.make_layout((self.m_block_size, 1))) tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScS = tScS[(None, None), 0, 0] # (128, 128) tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) - tilePlikeFP32 = self.mma_tiler_qk[1] // 32 * self.v_dtype.width + tilePlikeFP32 = self.mma_tiler_qk[1] // Float32.width * self.v_dtype.width tStP_layout = cute.composition( - tStSi.layout, cute.make_layout((self.m_block_size, tilePlikeFP32)) + tSAcc.layout, cute.make_layout((self.m_block_size, tilePlikeFP32)) ) - tStP = cute.make_tensor(tStSi.iterator + self.tmem_s_to_p_offset, tStP_layout) + tStP = cute.make_tensor(tSAcc.iterator + self.tmem_s_to_p_offset, tStP_layout) tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), - Float32, + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype ) - thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi).get_slice(tidx) - tStS_t2r = thr_tmem_load.partition_S(tStSi) + thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tSAcc).get_slice(tidx) + tStS_t2r = thr_tmem_load.partition_S(tSAcc) # (((32,32),1),1,4) tmem_store_scale_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), - Float32, + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), Float32 ) thr_tmem_store_scale = tcgen05.make_tmem_copy(tmem_store_scale_atom, tStScale).get_slice( tidx ) - tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale) tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), - Float32, + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32 ) thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx) - tStP_r2t = thr_tmem_store.partition_D(tStP) + tStP_r2t = thr_tmem_store.partition_D(tStP) # (((16,32),1),1,4) mma_si_consumer_phase = Int32(0) - si_corr_producer_phase = Int32(1) + sm_stats_producer_phase = Int32(1) s0_s1_sequence_phase = Int32(1 if stage == 0 else 0) # self.warp_scheduler_barrier_init() warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 - mbar_s0_s1_sequence_offset = self.mbar_s0_s1_sequence_offset + warp_idx_in_wg - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -1642,7 +1846,7 @@ def softmax_loop( mask = AttentionMaskCls(seqlen) shared_mask_kwargs = dict( - m_block=self.q_stage * m_block + stage, + m_block=(self.q_stage * m_block + stage) * self.cta_group_size, thr_mma=thr_mma_qk, thr_tmem_load=thr_tmem_load, mask_causal=self.is_causal, @@ -1715,9 +1919,12 @@ def softmax_loop( softmax_step = partial( self.softmax_step, softmax=softmax, - mbar_ptr=mbar_ptr, - mbar_s0_s1_sequence_offset=mbar_s0_s1_sequence_offset, thr_mma_qk=thr_mma_qk, + pipeline_s_p_o=pipeline_s_p_o, + pipeline_p_lastsplit=pipeline_p_lastsplit, + pipeline_sm_stats=pipeline_sm_stats, + sm_stats_barrier=sm_stats_barrier, + pipeline_s0_s1_sequence=pipeline_s0_s1_sequence, thr_tmem_load=thr_tmem_load, thr_tmem_store=thr_tmem_store, thr_tmem_store_scale=thr_tmem_store_scale, @@ -1728,32 +1935,30 @@ def softmax_loop( stage=stage, batch_idx=batch_idx, head_idx=head_idx, - m_block=self.q_stage * m_block + stage, + m_block=(self.q_stage * m_block + stage) * self.cta_group_size, seqlen=seqlen, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, head_divmod=head_divmod, ) - if has_work: - # Softmax acts as the producer: wait until correction signals the stage is empty - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase - ) - si_corr_producer_phase ^= 1 + if const_expr(self.use_block_sparsity) or has_work: + # See block_sparse_utils.py NOTE [SM100 block-sparse empty tiles: mbarrier contract]. + pipeline_sm_stats.producer_acquire_w_index_phase(stage, sm_stats_producer_phase) + sm_stats_producer_phase ^= 1 # Block sparse or dense iteration if const_expr(self.use_block_sparsity): # When aux_tensors exist, Q indices beyond seqlen_q must be wrapped to avoid # OOB aux_tensor access. Only edge tiles (where m_tile_end > seqlen_q) need this. if const_expr(aux_tensors is not None): - m_tile_end = (self.q_stage * m_block + stage + 1) * self.m_block_size + m_tile_end = ((self.q_stage * m_block + stage + 1) * self.cta_group_size) * self.m_block_size check_m_boundary = m_tile_end > seqlen.seqlen_q else: check_m_boundary = False ( mma_si_consumer_phase, - si_corr_producer_phase, + sm_stats_producer_phase, s0_s1_sequence_phase, empty_tile, ) = softmax_block_sparse_sm100( @@ -1765,13 +1970,10 @@ def softmax_loop( mask_fn, mask_fn_none, mma_si_consumer_phase, - si_corr_producer_phase, + sm_stats_producer_phase, s0_s1_sequence_phase, - mbar_ptr, - self.mbar_softmax_corr_full_offset, - self.mbar_softmax_corr_empty_offset, - self.mbar_P_full_O_rescaled_offset, - self.mbar_P_full_2_offset, + pipeline_sm_stats, + sm_stats_barrier, self.q_stage, Int32(stage), check_m_boundary, @@ -1786,13 +1988,15 @@ def softmax_loop( ] = softmax.row_max[0] # if tidx == 0: # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + # See block_sparse_utils.py NOTE [SM100 block-sparse empty tiles: mbarrier contract]. + # pipeline_sm_stats.producer_commit_w_index(stage) + sm_stats_barrier.arrive_w_index(index=stage * 4 + warp_idx) # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) else: if const_expr(not self.is_split_kv) or tile_block_count > Int32(0): - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = softmax_step( mma_si_consumer_phase, - si_corr_producer_phase, + sm_stats_producer_phase, s0_s1_sequence_phase, n_block_max - 1, is_first=True, @@ -1806,10 +2010,10 @@ def softmax_loop( ) for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = ( softmax_step( mma_si_consumer_phase, - si_corr_producer_phase, + sm_stats_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False), @@ -1823,23 +2027,23 @@ def softmax_loop( for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): n_block = n_block_max - n_tile - 1 if const_expr(self.mask_mod is not None): - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, + mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False), ) else: - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, + mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase, n_block, ) # Separate iterations with local masking on the left if const_expr(self.is_local and block_info.window_size_left is not None): n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = ( softmax_step( mma_si_consumer_phase, - si_corr_producer_phase, + sm_stats_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False), @@ -1853,7 +2057,8 @@ def softmax_loop( sScale[ tidx + stage * self.m_block_size + self.q_stage * self.m_block_size ] = softmax.row_max[0] - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + # pipeline_sm_stats.producer_commit_w_index(stage) + sm_stats_barrier.arrive_w_index(index=stage * 4 + warp_idx) # # Write LSE to gmem # if const_expr(mLSE is not None): @@ -1875,21 +2080,30 @@ def softmax_loop( # gLSE[tidx] = lse # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop + # This is equivalent to pipeline_sm_stats.producer_tail + pipeline_sm_stats.producer_acquire_w_index_phase(stage, sm_stats_producer_phase) + # This is equivalent to pipeline_s0_s1.producer_tail + if const_expr(self.s0_s1_barrier): + if stage == 0: + pipeline_s0_s1_sequence.sync_object_full.wait(stage, s0_s1_sequence_phase) + @cute.jit def softmax_step( self, mma_si_consumer_phase: Int32, - si_corr_producer_phase: Int32, + sm_stats_producer_phase: Int32, s0_s1_sequence_phase: Int32, n_block: Int32, softmax: SoftmaxSm100, - mbar_ptr: cute.Pointer, - mbar_s0_s1_sequence_offset: Int32, thr_mma_qk: cute.core.ThrMma, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_p_lastsplit: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + sm_stats_barrier: pipeline.NamedBarrier, + pipeline_s0_s1_sequence: Optional[pipeline.PipelineAsync], thr_tmem_load: cute.CopyAtom, thr_tmem_store: cute.CopyAtom, thr_tmem_store_scale: cute.CopyAtom, @@ -1923,15 +2137,20 @@ def softmax_step( 5. Computing row sums for normalization 6. Coordinating pipeline synchronization between different processing stages """ + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 tilePlikeFP32 = self.mma_tiler_qk[1] // Float32.width * self.v_dtype.width tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) - tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) - tScP = cute.composition(tScS, cute.make_layout((self.m_block_size, tilePlikeFP32))) + tScS = tScS[(None, None), 0, 0] # (128, 128) + # tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) + cta_qk_tiler = (self.mma_tiler_qk[0] // thr_mma_qk.thr_id.shape, self.mma_tiler_qk[1]) + tScS_shape = cta_qk_tiler # (128, 128) + tScP_shape = (tScS_shape[0], tilePlikeFP32) # (128, 64) # Wait for Si - cute.arch.mbarrier_wait(mbar_ptr + self.mbar_S_full_offset + stage, mma_si_consumer_phase) + pipeline_s_p_o.consumer_wait_w_index_phase(stage, mma_si_consumer_phase) tSrS_t2r = cute.make_fragment(thr_tmem_load.partition_D(tScS).shape, self.qk_acc_dtype) cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + # tSrS_t2r = copy_utils.load_t2r(thr_tmem_load, tScS_shape, tStS_t2r) if cutlass.const_expr(self.score_mod is not None): self.apply_score_mod( tSrS_t2r, @@ -1961,51 +2180,52 @@ def softmax_step( sScale[thread_idx + stage * self.m_block_size] = acc_scale # if thread_idx == 0: cute.printf("softmax acc_scale stage %d: %f, row_max = %f\n", stage, acc_scale, row_max) # Notify correction wg that row_max is ready - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + # pipeline_sm_stats.producer_commit_w_index(stage) + sm_stats_barrier.arrive_w_index(index=stage * 4 + warp_idx) # if thread_idx == 0 and stage == 0: cute.print_tensor(tSrS_t2r) - # print(tSrS_t2r) softmax.scale_subtract_rowmax(tSrS_t2r, row_max) # Sequence barrier wait if const_expr(self.s0_s1_barrier): - cute.arch.mbarrier_wait( - mbar_ptr + mbar_s0_s1_sequence_offset + stage * 4, s0_s1_sequence_phase - ) - tSrP_r2t_f32 = cute.make_fragment(thr_tmem_store.partition_S(tScP).shape, Float32) + pipeline_s0_s1_sequence.sync_object_full.wait(stage, s0_s1_sequence_phase) + tSrP_r2t_f32 = cute.make_fragment( + thr_tmem_store.partition_S(cute.make_identity_tensor(tScP_shape)).shape, Float32 + ) tSrP_r2t = cute.make_tensor( - cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), - tSrS_t2r.layout, + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), tSrS_t2r.layout ) # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) softmax.apply_exp2_convert( tSrS_t2r, tSrP_r2t, - e2e=mask_fn is None and self.head_dim_padded <= 128, - e2e_freq=self.e2e_freq, + ex2_emu_freq=self.ex2_emu_freq if const_expr(mask_fn is None) else 0, + ex2_emu_start_frg=self.ex2_emu_start_frg, ) # Sequence barrier arrive if const_expr(self.s0_s1_barrier): - cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) + pipeline_s0_s1_sequence.sync_object_full.arrive(1 - stage, dst=None) # print(tSrP_r2t_f32, tStP_r2t) # cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t) - for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 4 * 3): + for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2])): cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) - cute.arch.fence_view_async_tmem_store() - # Notify mma warp that P is ready - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) - for i in cutlass.range_constexpr( - cute.size(tStP_r2t.shape[2]) // 4 * 3, cute.size(tStP_r2t.shape[2]) - ): - cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) - cute.arch.fence_view_async_tmem_store() + if const_expr(self.split_P_arrive > 0): + split_P_arrive_idx = cute.size(tStP_r2t.shape[2]) * self.split_P_arrive // self.n_block_size + if const_expr(i + 1 == split_P_arrive_idx): + # Notify mma warp that the 1st half of P is ready + cute.arch.fence_view_async_tmem_store() + pipeline_s_p_o.consumer_release_w_index(stage) # Notify mma warp that the 2nd half of P is ready - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_2_offset + stage) - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase - ) + cute.arch.fence_view_async_tmem_store() + if const_expr(self.split_P_arrive > 0): + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_p_lastsplit.producer_commit_w_index(stage) + else: + pipeline_s_p_o.consumer_release_w_index(stage) + pipeline_sm_stats.producer_acquire_w_index_phase(stage, sm_stats_producer_phase) softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) # acc_scale = cute.math.exp2(acc_scale_, fastmath=True) - return mma_si_consumer_phase ^ 1, si_corr_producer_phase ^ 1, s0_s1_sequence_phase ^ 1 + return mma_si_consumer_phase ^ 1, sm_stats_producer_phase ^ 1, s0_s1_sequence_phase ^ 1 @cute.jit def correction_loop( @@ -2013,23 +2233,30 @@ def correction_loop( thr_mma_qk: cute.core.ThrMma, thr_mma_pv: cute.core.ThrMma, tStS: cute.Tensor, - tOtOs: tuple[cute.Tensor], + tOtO: cute.Tensor, sScale: cute.Tensor, mO: cute.Tensor, mLSE: cute.Tensor, sO: cute.Tensor, + pipeline_s_p_o: pipeline.PipelineAsync, + pipeline_o_acc: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + sm_stats_barrier: pipeline.NamedBarrier, + pipeline_o_epi: pipeline.PipelineAsync, learnable_sink: Optional[cute.Tensor], gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: cute.CopyAtom, - mbar_ptr: cute.Pointer, softmax_scale_log2: Float32, block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors] = None, + tile_scheduler=None, ): tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + mma_tile_coord_v = thr_mma_qk.thr_idx + tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) tStScale_layout = cute.composition(tStS.layout, cute.make_layout((self.m_block_size, 1))) tStScales = tuple( @@ -2038,8 +2265,7 @@ def correction_loop( ) tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) tmem_load_v_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), - self.qk_acc_dtype, + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), self.qk_acc_dtype ) thr_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScales[0]).get_slice(tidx) @@ -2047,14 +2273,14 @@ def correction_loop( tSrScale_t2r_shape = thr_tmem_load_vec.partition_D(tScScale).shape # First iter: no correction is required - for stage in cutlass.range_constexpr(self.q_stage): - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + # Notify mma warp that O has been rescaled + for stage in cutlass.range(self.q_stage): + pipeline_s_p_o.consumer_release_w_index(stage) - softmax_corr_consumer_phase = Int32(0) + sm_stats_consumer_phase = Int32(0) o_corr_consumer_phase = Int32(0) corr_epi_producer_phase = Int32(1) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -2065,7 +2291,14 @@ def correction_loop( mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] else: mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx] - gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) + gO = None + if const_expr(self.use_tma_O or not self.pack_gqa): + tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded) + gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0)) # (128 * 2, 128) + gO = layout_utils.select( + cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1] + ) # (128, 128, 2) + gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None] # Default LSE to -inf for invalid split_idx tiles stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage @@ -2086,24 +2319,20 @@ def correction_loop( if has_work: # Ignore first signal from softmax as no correction is required - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase - ) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 0) + # pipeline_sm_stats.consumer_wait_w_index_phase(0, sm_stats_consumer_phase) + sm_stats_barrier.arrive_and_wait_w_index(index=0 * 4 + warp_idx) + pipeline_sm_stats.consumer_release_w_index(0) if const_expr(self.q_stage == 2): - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_softmax_corr_full_offset + 1, softmax_corr_consumer_phase - ) - softmax_corr_consumer_phase ^= 1 + # pipeline_sm_stats.consumer_wait_w_index_phase(1, sm_stats_consumer_phase) + sm_stats_barrier.arrive_and_wait_w_index(index=1 * 4 + warp_idx) + sm_stats_consumer_phase ^= 1 tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32) for i in cutlass.range(total_block_count - 1, unroll=1): for stage in cutlass.range_constexpr(self.q_stage): # wait for S0 / S1 - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_softmax_corr_full_offset + stage, - softmax_corr_consumer_phase, - ) + # pipeline_sm_stats.consumer_wait_w_index_phase(stage, sm_stats_consumer_phase) + sm_stats_barrier.arrive_and_wait_w_index(index=stage * 4 + warp_idx) # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[0] @@ -2113,24 +2342,16 @@ def correction_loop( # if tidx == 0: cute.printf("Correction scale i = %d, for stage %d: %f, should_rescale = %d\n", i, stage, scale, should_rescale) # Don't need O_full anymore, since by the time softmax has signaled the correction # warps, S_i must have been done, so O_i-1 must have been done as well. - # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) + # pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase) if should_rescale: - self.correction_rescale( - thr_mma_pv, tOtOs[stage], tidx, scale - ) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) - if const_expr(self.q_stage == 2): - cute.arch.mbarrier_arrive( - mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage) - ) - else: - cute.arch.mbarrier_arrive( - mbar_ptr + self.mbar_softmax_corr_empty_offset + stage - ) - softmax_corr_consumer_phase ^= 1 + self.correction_rescale(thr_mma_pv, tOtO[None, None, None, stage], tidx, scale) + # Notify mma warp that O has been rescaled + pipeline_s_p_o.consumer_release_w_index(stage) + pipeline_sm_stats.consumer_release_w_index(self.q_stage - 1 - stage) + sm_stats_consumer_phase ^= 1 # o_corr_consumer_phase ^= 1 if const_expr(self.q_stage == 2): - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 1) + pipeline_sm_stats.consumer_release_w_index(1) # End of seqlen_corr_loop_steps # Even in the case of self.overlap_sO_sQ, we can write to stage 0 of sO without @@ -2144,14 +2365,12 @@ def correction_loop( else: # Each thread might have a different sink value due to different q_head for stage in cutlass.range_constexpr(self.q_stage): q_head_idx = ( - (self.q_stage * m_block + stage) * self.m_block_size + tidx + ((m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v) * self.m_block_size + tidx ) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead learnable_sink_val[stage] = Float32(learnable_sink[q_head_idx]) for stage in cutlass.range_constexpr(self.q_stage): - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_softmax_corr_full_offset + stage, - softmax_corr_consumer_phase, - ) + # pipeline_sm_stats.consumer_wait_w_index_phase(stage, sm_stats_consumer_phase) + sm_stats_barrier.arrive_and_wait_w_index(index=stage * 4 + warp_idx) # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[0] @@ -2160,7 +2379,7 @@ def correction_loop( row_max = sScale[tidx + stage * self.m_block_size + self.q_stage * self.m_block_size] else: row_max = None - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage) + pipeline_sm_stats.consumer_release_w_index(stage) if const_expr(learnable_sink is not None): LOG2_E = math.log2(math.e) sink_val = learnable_sink_val[stage] @@ -2176,16 +2395,14 @@ def correction_loop( acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase - ) + # Wait for the last O to be ready from the MMA warp + pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase) if const_expr(not self.use_correction_warps_for_epi): - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase - ) + pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase) + gO_stage = gO[None, None, stage] if const_expr(gO is not None) else None self.correction_epilogue( thr_mma_pv, - tOtOs[stage], + tOtO[None, None, None, stage], tidx, stage, m_block, @@ -2193,28 +2410,26 @@ def correction_loop( scale, sO[None, None, stage], mO_cur, - gO, + gO_stage, gmem_tiled_copy_O, ) - if const_expr(not self.use_correction_warps_for_epi): - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage) # Signal for the next work tile that O buffers in tmem are already read, so # mma warp can write to them - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + pipeline_s_p_o.consumer_release_w_index(stage) + if const_expr(not self.use_correction_warps_for_epi): + pipeline_o_epi.producer_commit_w_index(stage) # if tidx == 0: cute.printf("Correction final scale for stage %d: %f\n", stage, scale) o_corr_consumer_phase ^= 1 - softmax_corr_consumer_phase ^= 1 + sm_stats_consumer_phase ^= 1 corr_epi_producer_phase ^= 1 else: - # WARNING: we need some code before the const_expr, see https://github.com/NVIDIA/cutlass/issues/2781 + gmem_tiled_copy_O_for_empty_tile = None if const_expr(self.use_correction_warps_for_epi): gmem_tiled_copy_O_for_empty_tile = gmem_tiled_copy_O - else: - gmem_tiled_copy_O_for_empty_tile = None if const_expr(self.use_block_sparsity): ( - softmax_corr_consumer_phase, + sm_stats_consumer_phase, o_corr_consumer_phase, corr_epi_producer_phase, ) = handle_block_sparse_empty_tile_correction_sm100( @@ -2235,16 +2450,12 @@ def correction_loop( stats, self.correction_epilogue, thr_mma_pv, - tOtOs, + tOtO, sO, - mbar_ptr, - self.mbar_softmax_corr_full_offset, - self.mbar_softmax_corr_empty_offset, - self.mbar_P_full_O_rescaled_offset, - self.mbar_P_full_2_offset, - self.mbar_corr_epi_full_offset, - self.mbar_corr_epi_empty_offset, - softmax_corr_consumer_phase, + pipeline_sm_stats, + sm_stats_barrier, + pipeline_o_epi, + sm_stats_consumer_phase, o_corr_consumer_phase, corr_epi_producer_phase, softmax_scale_log2, @@ -2268,9 +2479,7 @@ def correction_loop( else: mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) for stage in cutlass.range_constexpr(self.q_stage): - gLSE = cute.local_tile( - mLSE_cur, (self.m_block_size,), (self.q_stage * m_block + stage,) - ) + m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage] # if tidx == 0 and stage <= 1: # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) @@ -2285,15 +2494,30 @@ def correction_loop( if const_expr(not self.pack_gqa) else seqlen.seqlen_q * self.qhead_per_kvhead ) - if tidx < seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size: - # This actually just works with PackGQA too - gLSE[tidx] = lse + if const_expr(not self.pack_gqa or self.m_block_size % self.qhead_per_kvhead == 0): + gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_tile_idx,)) + if tidx < seqlen_q - m_tile_idx * self.m_block_size: + # This actually just works with PackGQA too + gLSE[tidx] = lse + else: + idx = m_tile_idx * self.m_block_size + tidx + if idx < seqlen_q: + m_idx = idx // self.qhead_per_kvhead + h_idx = idx - m_idx * self.qhead_per_kvhead + lse_ptr_i64 = utils.elem_pointer(mLSE_cur, ((h_idx, m_idx),)).toint() + lse_gmem_ptr = cute.make_ptr( + mLSE_cur.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4 + ) + cute.make_tensor(lse_gmem_ptr, (1,))[0] = lse # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop + # This is equivalent to pipeline_o_epi.consumer_tail() for the correction warps + if const_expr(not self.use_correction_warps_for_epi): + pipeline_o_epi.producer_acquire_w_index_phase(self.q_stage - 1, corr_epi_producer_phase) + @cute.jit def correction_rescale( self, @@ -2317,8 +2541,7 @@ def correction_rescale( tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2])) corr_tile_size = 16 # tuneable parameter tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.pv_acc_dtype, + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.pv_acc_dtype ) tmem_store_atom = cute.make_copy_atom( tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), @@ -2340,8 +2563,7 @@ def correction_rescale( cute.copy(thr_tmem_load, tOtO_t2r_i, tOrO_frg) for j in cutlass.range(0, cute.size(tOrO_frg), 2, unroll_full=True): tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2( - (tOrO_frg[j], tOrO_frg[j + 1]), - (scale, scale), + (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale) ) tOtO_r2t_i = cute.make_tensor(tOtO_r2t.iterator + i * corr_tile_size, tOtO_r2t.layout) cute.copy(thr_tmem_store, tOrO_frg, tOtO_r2t_i) @@ -2385,8 +2607,9 @@ def correction_epilogue( :type sO: cute.Tensor """ - corr_tile_size = 32 * 8 // self.o_dtype.width - tOsO = thr_mma.partition_C(sO) + corr_tile_size = 8 * 32 // self.o_dtype.width + # Use CTA 0 mapping for smem partitioning since sO is per-CTA sized + tOsO = thr_mma.get_slice(0).partition_C(sO) tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2])) tOtO_i = cute.logical_divide(tOtO, cute.make_layout((self.m_block_size, corr_tile_size))) @@ -2400,11 +2623,9 @@ def correction_epilogue( self.o_dtype, self.pv_acc_dtype, epi_subtile, - use_2cta_instrs=False, - ) - tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0]).get_slice( - tidx + use_2cta_instrs=self.use_2cta_instrs, ) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0]) thr_tmem_load = tiled_tmem_load.get_slice(tidx) smem_copy_atom = sm100_utils_basic.get_smem_store_op( self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load @@ -2412,72 +2633,79 @@ def correction_epilogue( tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load) tOtO_t2r = thr_tmem_load.partition_S(tOtO_i[(None, None), None]) - tOsO_s2r = thr_tmem_load.partition_D(tOsO_i[(None, None), None]) + tOsO_s2r = copy_utils.partition_D_position_independent(thr_tmem_load, tOsO_i[(None, None), None]) tOcO_t2r = thr_tmem_load.partition_D(tOcO_i[(None, None), None]) - for i in cutlass.range_constexpr(self.head_dim_v_padded // corr_tile_size): + for i in cutlass.range(self.head_dim_v_padded // corr_tile_size, unroll_full=True): tOtO_t2r_i = tOtO_t2r[None, 0, 0, i] tOsO_r2s_i = tOsO_s2r[None, 0, 0, i] tOrO_frg = cute.make_fragment(tOcO_t2r[None, 0, 0, i].shape, self.pv_acc_dtype) cute.copy(tiled_tmem_load, tOtO_t2r_i, tOrO_frg) - for j in cutlass.range_constexpr(0, cute.size(tOrO_frg), 2): + for j in cutlass.range(0, cute.size(tOrO_frg), 2, unroll_full=True): tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2( - (tOrO_frg[j], tOrO_frg[j + 1]), - (scale, scale), + (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale) ) - tOrO_frg_cvt = cute.make_fragment(tOrO_frg.shape, self.o_dtype) - tOrO_frg_cvt.store(tOrO_frg.load().to(self.o_dtype)) - cute.copy(tiled_smem_store, tOrO_frg_cvt, tOsO_r2s_i) - # fence view async shared + copy_utils.cvt_copy(tiled_smem_store, tOrO_frg, tOsO_r2s_i) cute.arch.fence_view_async_shared() if const_expr(self.use_correction_warps_for_epi): assert(not self.use_tma_O) assert(gmem_tiled_copy_O is not None) - cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), + cute.arch.barrier(barrier_id=int(NamedBarrierFwdSm100.Epilogue), number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE) - gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) - tOsO = gmem_thr_copy_O.partition_S(sO) - cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) - tOgO = gmem_thr_copy_O.partition_D(gO) - tOcO = gmem_thr_copy_O.partition_S(cO) - t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) - tOpO = utils.predicate_k(tOcO, limit=mO_cur.shape[1]) - pack_gqa = PackGQA( - self.m_block_size, - self.head_dim_v_padded, - self.check_hdim_v_oob, - self.qhead_per_kvhead, + mma_tile_coord_v = thr_mma.thr_idx + m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v + self._store_O_to_gmem( + sO, gO, mO_cur, gmem_tiled_copy_O, tidx, seqlen_q, m_tile_idx ) - # load acc O from smem to rmem for wider vectorization - tOrO = cute.make_fragment_like(tOsO, self.o_dtype) - cute.autovec_copy(tOsO, tOrO) - # copy acc O from rmem to gmem - if const_expr(not self.pack_gqa): - for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): - if ( - t0OcO[0, rest_m, 0][0] - < seqlen_q - - (self.q_stage * m_block + stage) * self.m_block_size - - tOcO[0][0] - ): - cute.copy( - gmem_tiled_copy_O, - tOrO[None, rest_m, None], - tOgO[None, rest_m, None, self.q_stage * m_block + stage], - pred=tOpO[None, rest_m, None] - if const_expr(self.check_hdim_v_oob) - else None, - ) - else: - pack_gqa.store_O( - mO_cur, - tOrO, - gmem_tiled_copy_O, - tidx, - self.q_stage * m_block + stage, - seqlen_q, - ) + @cute.jit + def _store_O_to_gmem( + self, + sO_stage: cute.Tensor, + gO: Optional[cute.Tensor], + mO_cur: cute.Tensor, + gmem_tiled_copy_O: cute.TiledCopy, + tidx: Int32, + seqlen_q: Int32, + m_tile_idx: Int32, + ): + """Copy a single stage of O from smem to gmem via registers.""" + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO_stage) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) + tOpO = copy_utils.predicate_k(tOcO, limit=mO_cur.shape[1]) + pack_gqa = PackGQA( + self.m_block_size, + self.head_dim_v_padded, + self.check_hdim_v_oob, + self.qhead_per_kvhead, + ) + + # load acc O from smem to rmem for wider vectorization + tOrO = cute.make_fragment_like(tOsO, self.o_dtype) + cute.autovec_copy(tOsO, tOrO) + # copy acc O from rmem to gmem + if const_expr(not self.pack_gqa): + assert gO is not None + tOgO = gmem_thr_copy_O.partition_D(gO) + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if ( + t0OcO[0, rest_m, 0][0] < seqlen_q - m_tile_idx * self.m_block_size - tOcO[0][0] + ): + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None], + pred=tOpO[None, rest_m, None] + if const_expr(self.check_hdim_v_oob) + else None, + ) + else: + pack_gqa.store_O( + mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_tile_idx, seqlen_q + ) @cute.jit def epilogue_s2g( @@ -2486,14 +2714,14 @@ def epilogue_s2g( sO: cute.Tensor, gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: Optional[cute.CopyAtom], - mbar_ptr: cute.Pointer, + pipeline_o_epi: pipeline.PipelineAsync, block_info: BlockInfo, num_splits: int, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, + mma_tile_coord_v: Int32 = 0, + tile_scheduler=None, ): epi_consumer_phase = Int32(0) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -2505,101 +2733,127 @@ def epilogue_s2g( mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] else: mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx] - gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) + gO = None + if const_expr(self.use_tma_O or not self.pack_gqa): + tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded) + gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0)) # (128 * 2, 128) + gO = layout_utils.select( + cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1] + ) # (128, 128, 2) + gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None] + if const_expr(self.use_tma_O): store_O, _, _ = copy_utils.tma_get_copy_fn( tma_atom_O, 0, cute.make_layout(1), sO, gO ) - for stage in cutlass.range_constexpr(self.q_stage): + for stage in cutlass.range(self.q_stage, unroll_full=True): # wait from corr, issue tma store on smem # 1. wait for O0 / O1 final - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase - ) + pipeline_o_epi.consumer_wait_w_index_phase(stage, epi_consumer_phase) # 2. copy O0 / O1 to gmem - store_O(src_idx=stage, dst_idx=self.q_stage * m_block + stage) + store_O(src_idx=stage, dst_idx=stage) cute.arch.cp_async_bulk_commit_group() for stage in cutlass.range_constexpr(self.q_stage): # Ensure O0 / O1 buffer is ready to be released - if const_expr(self.q_stage == 2): - cute.arch.cp_async_bulk_wait_group(1 - stage, read=True) - else: - cute.arch.cp_async_bulk_wait_group(0, read=True) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + cute.arch.cp_async_bulk_wait_group(self.q_stage - 1 - stage, read=True) + pipeline_o_epi.consumer_release_w_index(stage) else: tidx = cute.arch.thread_idx()[0] % ( cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) ) - gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) - tOsO = gmem_thr_copy_O.partition_S(sO) - cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) - tOgO = gmem_thr_copy_O.partition_D(gO) - tOcO = gmem_thr_copy_O.partition_S(cO) - t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) - tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) - pack_gqa = PackGQA( - self.m_block_size, - self.head_dim_v_padded, - self.check_hdim_v_oob, - self.qhead_per_kvhead, - ) for stage in cutlass.range_constexpr(self.q_stage): # wait from corr, issue tma store on smem # 1. wait for O0 / O1 final - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase - ) + pipeline_o_epi.consumer_wait_w_index_phase(stage, epi_consumer_phase) # 2. copy O0 / O1 to gmem - # load acc O from smem to rmem for wider vectorization - tOrO = cute.make_fragment_like(tOsO[None, None, None, 0], self.o_dtype) - cute.autovec_copy(tOsO[None, None, None, stage], tOrO) - # copy acc O from rmem to gmem - if const_expr(not self.pack_gqa): - for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): - if ( - t0OcO[0, rest_m, 0][0] - < seqlen.seqlen_q - - (self.q_stage * m_block + stage) * self.m_block_size - - tOcO[0][0] - ): - cute.copy( - gmem_tiled_copy_O, - tOrO[None, rest_m, None], - tOgO[None, rest_m, None, self.q_stage * m_block + stage], - pred=tOpO[None, rest_m, None] - if const_expr(self.check_hdim_v_oob) - else None, - ) - else: - pack_gqa.store_O( - mO_cur, - tOrO, - gmem_tiled_copy_O, - tidx, - self.q_stage * m_block + stage, - seqlen.seqlen_q, - ) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v + gO_stage = gO[None, None, stage] if const_expr(gO is not None) else None + self._store_O_to_gmem( + sO[None, None, stage], gO_stage, mO_cur, gmem_tiled_copy_O, + tidx, seqlen.seqlen_q, m_tile_idx, + ) + pipeline_o_epi.consumer_release_w_index(stage) epi_consumer_phase ^= 1 # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() + + @cute.jit + def clc_scheduler_warp( + self, + tile_scheduler: TileSchedulerProtocol, + ): + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + tile_scheduler.prefetch_next_work() + work_tile = tile_scheduler.advance_to_next_work() + if cute.arch.thread_idx()[0] == self.clc_scheduler_warp_id * cute.arch.WARP_SIZE: + fa_printf( + 3, + "[CLC] query sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n", + smid(), + cute.arch.block_idx()[0], + work_tile.tile_idx[0], + work_tile.tile_idx[1], + work_tile.tile_idx[2], + work_tile.tile_idx[3], + work_tile.is_valid_tile, + ) + tile_scheduler.producer_tail() + + @cute.jit + def empty_warp( + self, + tile_scheduler: TileSchedulerProtocol, + ): + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_tile = tile_scheduler.advance_to_next_work() def load_Q( self, load_Q_fn: Callable, - mbar_full_ptr: cute.Pointer, - mbar_empty_ptr: cute.Pointer, + pipeline_q: pipeline.PipelineAsync, + block: Int32, + stage: int, + phase: Int32, + ): + pipeline_q.producer_acquire_w_index_phase(stage, phase) + load_Q_fn(src_idx=block, dst_idx=stage, tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(stage)) + + def load_Q_non_tma( + self, + mQ: cute.Tensor, + sQ: cute.Tensor, + gmem_tiled_copy_Q: cute.TiledCopy, + pipeline_q: pipeline.PipelineAsync, + tidx: Int32, + seqlen_q: Int32, + m_block: Int32, block: Int32, stage: int, phase: Int32, ): - cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) - with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, self.tma_copy_bytes["Q"]) - load_Q_fn(src_idx=block, dst_idx=stage, tma_bar_ptr=mbar_full_ptr + stage) + assert self.cta_group_size == 1, "cta_group_size must be 1 for non-tma Q load" + pipeline_q.producer_acquire_w_index_phase(stage, phase) + pack_gqa = PackGQA( + self.m_block_size, + self.head_dim_padded, + self.check_hdim_oob, + self.qhead_per_kvhead, + ) + sQ_stage = sQ[None, None, None, stage] + sQ_pi = cute.make_tensor( + sQ_stage.iterator, + cute.make_layout( + (sQ_stage.shape[0][0], (sQ_stage.shape[0][1], sQ_stage.shape[2])), + stride=(sQ_stage.stride[0][0], (sQ_stage.stride[0][1], sQ_stage.stride[2])), + ), + ) + pack_gqa.load_Q(mQ, sQ_pi, gmem_tiled_copy_Q, tidx, m_block * self.q_stage + block, seqlen_q) + cute.arch.cp_async_commit_group() + pipeline_q.sync_object_full.arrive_cp_async_mbarrier(stage) @cute.jit def load_KV( @@ -2609,44 +2863,46 @@ def load_KV( tXsX: Optional[cute.Tensor], paged_kv_manager: Optional[PagedKVManager], sX: cute.Tensor, - mbar_full_ptr: cute.Pointer, - mbar_empty_ptr: cute.Pointer, block: Int32, - producer_state: cutlass.pipeline.PipelineState, + pipeline_kv: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, K_or_V: Literal["K", "V"], page_idx: Optional[Int32] = None, + extra_tx_count: Optional[Int32] = None, ): assert K_or_V in ("K", "V") stage, phase = producer_state.index, producer_state.phase - cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) + extra_tx_count_kv = self.tma_copy_bytes[K_or_V] - self.tma_copy_bytes["K"] + extra_tx_count = ( + extra_tx_count_kv + (extra_tx_count if extra_tx_count is not None else 0) if const_expr(self.use_tma_KV) + else None + ) + extra_kwargs = {"extra_tx_count": extra_tx_count} if const_expr(self.use_tma_KV) else {} + pipeline_kv.producer_acquire(producer_state, **extra_kwargs) if const_expr(K_or_V == "K" and self.uneven_kv_smem): # Before this round, the smem location was occupied by V, which is smaller than # K. So we need to wait for the stage after that (stage 1) to be empty as well. if stage == 0: - cute.arch.mbarrier_wait(mbar_empty_ptr + 1, phase) + pipeline_kv.sync_object_empty.wait(1, phase) if const_expr(self.use_tma_KV): - assert ( - tXgX is not None and - tXsX is not None and - tma_atom is not None - ) - with cute.arch.elect_one(): - cute.arch.mbarrier_arrive_and_expect_tx( - mbar_full_ptr + stage, self.tma_copy_bytes[K_or_V], - ) + assert tXgX is not None and tXsX is not None and tma_atom is not None tXsX_cur = tXsX[None, stage] if const_expr(self.uneven_kv_smem): # Since this is the producer_state, the phase starts at 1, so we have to invert it tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase ^ 1) # Currently we assume that page_size == n_block_size so we index into tXgX with block = 0 tXgX_cur = tXgX[None, block] if const_expr(page_idx is None) else tXgX[None, 0, page_idx] - cute.copy(tma_atom, tXgX_cur, tXsX_cur, tma_bar_ptr=mbar_full_ptr + stage) + cute.copy(tma_atom, tXgX_cur, tXsX_cur, tma_bar_ptr=pipeline_kv.producer_get_barrier(producer_state)) else: assert paged_kv_manager is not None - paged_kv_manager.load_KV(block, sX[None, None, None, stage], K_or_V) + assert extra_tx_count is None + sX_cur = sX[None, None, None, stage] + if const_expr(self.uneven_kv_smem): + sX_cur = self.offset_kv_smem(sX_cur, stage, phase ^ 1) + paged_kv_manager.load_KV(block, sX_cur, K_or_V) cute.arch.cp_async_commit_group() - cute.arch.cp_async_mbarrier_arrive_noinc(mbar_full_ptr + stage) + pipeline_kv.sync_object_full.arrive_cp_async_mbarrier(stage) @cute.jit def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): @@ -2655,47 +2911,24 @@ def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): # (smem_large + smem_small) // 2. So for stage == 1, move right by offset if # phase == 0, or left by offset if phase == 1. offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase) + # Hint that the offset is 128-bit aligned so that + # ptr + offset preserves the alignment needed by cp.async. + offset = cute.assume(offset, divby=128 // self.k_dtype.width) return cute.make_tensor(sX.iterator + offset, sX.layout) else: return sX - def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): - load_kv_consumer_group = cutlass.pipeline.CooperativeGroup( - cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) - ) - if self.use_tma_KV: - load_kv_producer_group = cutlass.pipeline.CooperativeGroup( - cutlass.pipeline.Agent.Thread, len(self.load_warp_ids) - ) - return cutlass.pipeline.PipelineTmaUmma.create( - barrier_storage=load_kv_mbar_ptr, - num_stages=self.kv_stage, - producer_group=load_kv_producer_group, - consumer_group=load_kv_consumer_group, - tx_count=self.tma_copy_bytes["K"], - ) - else: - load_kv_producer_group = cutlass.pipeline.CooperativeGroup( - cutlass.pipeline.Agent.Thread, len(self.load_warp_ids) * cute.arch.WARP_SIZE - ) - return cutlass.pipeline.PipelineAsyncUmma.create( - num_stages=self.kv_stage, - producer_group=load_kv_producer_group, - consumer_group=load_kv_consumer_group, - barrier_storage=load_kv_mbar_ptr, - ) - # @cute.jit # def warp_scheduler_barrier_init(self): # warp_group_idx = utils.canonical_warp_group_idx(sync=False) # if warp_group_idx == 0: # cute.arch.barrier_arrive( - # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), number_of_threads=2 * 128, + # barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1), number_of_threads=2 * 128, # ) # def warp_scheduler_barrier_sync(self): # cute.arch.barrier( - # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False), + # barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False), # number_of_threads=2 * 128 # ) @@ -2703,7 +2936,7 @@ def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): # cur_wg = utils.canonical_warp_group_idx(sync=False) # next_wg = 1 - cur_wg # cute.arch.barrier_arrive( - # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128, + # barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128, # ) @cute.jit @@ -2727,6 +2960,7 @@ def apply_score_mod( cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) cS = cute.domain_offset((m_block * self.m_block_size, n_block * self.n_block_size), cS) tScS = thr_mma_qk.partition_C(cS) + tScS = tScS[(None, None), 0, 0] tScS_t2r = thr_tmem_load.partition_D(tScS) # Shared q_idx for all scores diff --git a/flash_attn/cute/flash_fwd_sm120.py b/flash_attn/cute/flash_fwd_sm120.py new file mode 100644 index 00000000000..08d219acfa8 --- /dev/null +++ b/flash_attn/cute/flash_fwd_sm120.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# SM120 (Blackwell GeForce / DGX Spark) forward pass. +# +# SM120 uses the same SM80-era MMA instructions (mma.sync.aligned.m16n8k16) but has +# a smaller shared memory capacity (99 KB vs 163 KB on SM80). This module subclasses +# FlashAttentionForwardSm80 and overrides the SMEM capacity check accordingly. + +import cutlass +import cutlass.utils as utils_basic + +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 + + +class FlashAttentionForwardSm120(FlashAttentionForwardSm80): + # Keep arch = 80 to use CpAsync code paths (no TMA for output). + # The compilation target is determined by the GPU at compile time, not this field. + arch = 80 + + @staticmethod + def can_implement( + dtype, + head_dim, + head_dim_v, + tile_m, + tile_n, + num_stages, + num_threads, + is_causal, + Q_in_regs=False, + ) -> bool: + """Check if the kernel can be implemented on SM120. + + Same logic as SM80 but uses SM120's shared memory capacity (99 KB). + """ + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if head_dim_v % 8 != 0: + return False + if tile_n % 16 != 0: + return False + if num_threads % 32 != 0: + return False + # Shared memory usage: Q tile + (K tile + V tile) + smem_usage_Q = tile_m * head_dim * 2 + smem_usage_K = tile_n * head_dim * num_stages * 2 + smem_usage_V = tile_n * head_dim_v * num_stages * 2 + smem_usage_QV = ( + (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) + ) + smem_usage = smem_usage_QV + smem_usage_K + # SM120 has 99 KB shared memory (vs 163 KB on SM80) + smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_120") + if smem_usage > smem_capacity: + return False + if (tile_m * 2) % num_threads != 0: + return False + return True diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py new file mode 100644 index 00000000000..4108ce451ff --- /dev/null +++ b/flash_attn/cute/flash_fwd_sm90.py @@ -0,0 +1,1534 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# SM90 (Hopper) forward pass for flash attention, extracted from flash_fwd.py. + +from types import SimpleNamespace +from typing import Callable, Literal, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync, warpgroup +from cutlass.utils import LayoutEnum +import cutlass.utils.hopper_helpers as sm90_utils_basic +from cutlass import pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.base_dsl.arch import Arch + +from quack import copy_utils +from quack import layout_utils +from quack import sm90_utils + +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import Softmax, apply_score_mod_inner +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( + produce_block_sparse_loads, + consume_block_sparse_loads, +) +from flash_attn.cute import pipeline as pipeline_custom +from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_attn.cute.paged_kv import PagedKVManager +from flash_attn.cute.named_barrier import NamedBarrierFwd +from quack.cute_dsl_utils import ParamsBase +from flash_attn.cute.tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, +) +from cutlass.cute import FastDivmodDivisor + +from flash_attn.cute.flash_fwd import FlashAttentionForwardBase + + +class FlashAttentionForwardSm90(FlashAttentionForwardBase): + def __init__( + self, + *args, + intra_wg_overlap: bool = True, + mma_pv_is_rs: bool = True, + paged_kv_non_tma: bool = False, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.intra_wg_overlap = intra_wg_overlap + self.mma_pv_is_rs = mma_pv_is_rs + self.buffer_align_bytes = 1024 + self.use_tma_KV = not paged_kv_non_tma + assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), ( + "Paged KV does not support irregular head dim" + ) + self.cluster_shape_mn = (1, 1) + assert self.arch >= Arch.sm_90 and self.arch <= Arch.sm_90a, "Only SM 9.x is supported" + + def _get_smem_layout_atom(self): + sQ_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim), + self.dtype, + ) + sK_layout_atom = sQ_layout_atom + sV_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv + ), + self.dtype, + ) + sO_layout_atom = sV_layout_atom + if not self.mma_pv_is_rs: + sP_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n + ), + self.dtype, + ) + else: + sP_layout_atom = None + return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom + + def _get_tiled_mma(self): + tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.K, + Float32, + atom_layout_mnk=(self.tile_m // 64, 1, 1), + tiler_mn=(64, self.tile_n), + ) + tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.MN, + Float32, + atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.tile_hdimv), + a_source=warpgroup.OperandSource.RMEM + if self.mma_pv_is_rs + else warpgroup.OperandSource.SMEM, + ) + return tiled_mma_qk, tiled_mma_pv + + def _get_shared_storage_cls(self): + sQ_struct, sK_struct, sV_struct = [ + cute.struct.Align[ + cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes + ] + for layout in (self.sQ_layout, self.sK_layout, self.sV_layout) + ] + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] + cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0 + sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024] + # 1 stage * 2 for Q pipeline (full + empty), self.num_stages*2 for K, self.num_stages*2 for V, + mbar_ptr_Q_struct = cute.struct.MemRange[cutlass.Int64, 1 * 2] + mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + + @cute.struct + class SharedStorageQKV: + mbar_ptr_Q: mbar_ptr_Q_struct + mbar_ptr_K: mbar_ptr_K_struct + mbar_ptr_V: mbar_ptr_V_struct + sV: sV_struct + sQ: sQ_struct + sK: sK_struct + sP: sP_struct + + @cute.struct + class SharedStorageSharedQV: + mbar_ptr_Q: mbar_ptr_Q_struct + mbar_ptr_K: mbar_ptr_K_struct + mbar_ptr_V: mbar_ptr_V_struct + sQ: sQV_struct + sK: sK_struct + sP: sP_struct + + return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + mLSE: Optional[cute.Tensor], + softmax_scale: Float32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + learnable_sink: Optional[cute.Tensor] = None, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + aux_tensors: Optional[list] = None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, + ): + """Configures and launches the flash attention kernel. + + mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: + (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) + """ + + self._check_type( + *( + t.element_type if t is not None else None + for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK) + ) + ) + + self.varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None + + mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)] + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + mQ, mO = [layout_utils.select(t, QO_layout_transpose) for t in (mQ, mO)] + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mK, mV = [layout_utils.select(t, KV_layout_transpose) for t in (mK, mV)] + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE = ( + layout_utils.select(mLSE, LSE_layout_transpose) + if const_expr(mLSE is not None) + else None + ) + + tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() + self.num_mma_threads = tiled_mma_qk.size + self.num_threads_per_warp_group = 128 + self.num_wg_mma = self.num_mma_threads // self.num_threads_per_warp_group + assert self.num_wg_mma in [1, 2, 3] + self.num_threads = self.num_threads_per_warp_group * (self.num_wg_mma + 1) + self.num_producer_threads = 32 + self.num_Q_load_threads = self.num_threads_per_warp_group # If not TMA_Q + self.num_epilogue_threads = self.num_mma_threads + self.num_mma_regs, self.num_producer_regs = {1: (256, 56), 2: (240, 24), 3: (160, 32)}[ + self.num_wg_mma + ] + self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + + self.use_scheduler_barrier = ( + (self.num_wg_mma >= 2 and self.tile_hdim <= 128) + if const_expr(self.intra_wg_overlap) + else (self.num_wg_mma == 2) + ) + self.use_tma_Q = self.arch >= Arch.sm_90 and not ( + self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0 + ) + self.use_tma_O = self.use_tma_Q + # Producer needs more registers when doing cp.async Q or KV loads + if const_expr(self.num_wg_mma == 2 and (not self.use_tma_Q or not self.use_tma_KV)): + self.num_mma_regs, self.num_producer_regs = 224, 40 + self.rescale_O_before_gemm = self.tile_hdimv > 128 and self.intra_wg_overlap + self._setup_attributes() + # TODO: we prob don't need most of what's in _setup_attributes + self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [ + sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage) + for mX, shape, stage in [ + (mQ, (self.tile_m, self.tile_hdim), None), + (mK, (self.tile_n, self.tile_hdim), self.num_stages), + (mV, (self.tile_n, self.tile_hdimv), self.num_stages), + (mO, (self.tile_m, self.tile_hdimv), None), + ] + ] + self.sP_layout = None + if const_expr(not self.mma_pv_is_rs): + self.sP_layout = sm90_utils.make_smem_layout( + mV.element_type, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n) + ) + + SharedStorage = self._get_shared_storage_cls() + + mQ_og, mO_og = mQ, mO + if const_expr(self.pack_gqa): + nheads_kv = mK.shape[2] + mQ = pack_gqa_layout(mQ, self.qhead_per_kvhead, nheads_kv, head_idx=2) + mO = pack_gqa_layout(mO, self.qhead_per_kvhead, nheads_kv, head_idx=2) + if const_expr(mLSE is not None): + mLSE = pack_gqa_layout(mLSE, self.qhead_per_kvhead, nheads_kv, head_idx=1) + + # TMA + gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp() + gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast + gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp() + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1])) + for name, mX, layout in [ + ("Q", mQ, self.sQ_layout), + ("K", mK, self.sK_layout), + ("V", mV, self.sV_layout), + ] + } + make_tiled_tma_atom_fn = ( + partial(make_packgqa_tiled_tma_atom, qhead_per_kvhead=self.qhead_per_kvhead, head_idx=2) + if const_expr(self.pack_gqa) + else cpasync.make_tiled_tma_atom + ) + tma_atom_Q, tma_tensor_Q = None, None + if const_expr(self.use_tma_Q): + tma_atom_Q, tma_tensor_Q = make_tiled_tma_atom_fn( + gmem_tiled_copy_Q, + mQ_og if const_expr(self.pack_gqa) else mQ, + self.sQ_layout, + (self.tile_m, self.tile_hdim), # No mcast + ) + tma_atom_K, tma_tensor_K = None, None + tma_atom_V, tma_tensor_V = None, None + if const_expr(self.use_tma_KV): + tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_KV, + mK, + cute.select(self.sK_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdim), + 1, # No mcast for now + ) + tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_KV, + mV, + cute.select(self.sV_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdimv), + 1, # No mcast for now + ) + tma_atom_O, tma_tensor_O = None, None + if const_expr(self.use_tma_O): + mO_tma = mO_og if const_expr(self.pack_gqa) else mO + if const_expr(self.varlen_q): + mO_tma = copy_utils.create_ragged_tensor_for_tma( + mO_tma, ragged_dim=0, ptr_shift=True + ) + tma_atom_O, tma_tensor_O = make_tiled_tma_atom_fn( + gmem_tiled_copy_O, + mO_tma, + self.sO_layout, + (self.tile_m, self.tile_hdimv), # No mcast + ) + if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + TileScheduler = SingleTileVarlenScheduler + else: + TileScheduler = ( + SingleTileScheduler + if const_expr(not self.is_causal or self.is_local) + else SingleTileLPTScheduler + ) + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m), + cute.size(mQ.shape[2]), + cute.size(mQ.shape[3]) + if const_expr(mCuSeqlensQ is None) + else cute.size(mCuSeqlensQ.shape[0] - 1), + 1, # num_splits + cute.size(mK.shape[0]) + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1], + mQ.shape[1], + mV.shape[1], + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=(self.tile_m, self.tile_n), + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + element_size=self.dtype.width // 8, + is_persistent=False, + lpt=self.is_causal or self.is_local, + ) + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2( + softmax_scale, self.score_mod + ) + window_size_left = Int32(window_size_left) if window_size_left is not None else None + window_size_right = Int32(window_size_right) if window_size_right is not None else None + fastdiv_mods = utils.compute_fastdiv_mods( + mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors, mPageTable + ) + + self.kernel( + tma_tensor_Q if const_expr(self.use_tma_Q) else mQ, + tma_tensor_K if const_expr(self.use_tma_KV) else mK, + tma_tensor_V if const_expr(self.use_tma_KV) else mV, + tma_tensor_O if const_expr(self.use_tma_O) else mO, + mLSE, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + mPageTable, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_O, + softmax_scale_log2, + softmax_scale, + window_size_left, + window_size_right, + learnable_sink, + blocksparse_tensors, + self.sQ_layout, + self.sK_layout, + self.sV_layout, + self.sO_layout, + self.sP_layout, + self.gmem_tiled_copy_Q, + self.gmem_tiled_copy_K, + self.gmem_tiled_copy_V, + self.gmem_tiled_copy_O, + tiled_mma_qk, + tiled_mma_pv, + tile_sched_params, + TileScheduler, + SharedStorage, + aux_tensors, + fastdiv_mods, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + mPageTable: Optional[cute.Tensor], + tma_atom_Q: Optional[cute.CopyAtom], + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V: Optional[cute.CopyAtom], + tma_atom_O: Optional[cute.CopyAtom], + softmax_scale_log2: Float32, + softmax_scale: Optional[Float32], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + learnable_sink: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + sP_layout: cute.ComposedLayout | None, + gmem_tiled_copy_Q: cute.TiledCopy, + gmem_tiled_copy_K: cute.TiledCopy, + gmem_tiled_copy_V: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], + SharedStorage: cutlass.Constexpr[Callable], + aux_tensors=Optional[list[cute.Tensor]], + fastdiv_mods=None, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + # Prefetch tma descriptor + if warp_idx == 0: + for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O): + if const_expr(tma_atom is not None): + cpasync.prefetch_descriptor(tma_atom) + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # Mbarrier / pipeline init + mbar_ptr_Q = storage.mbar_ptr_Q.data_ptr() + + ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread) + tma_warp = ThreadCooperativeGroup(1) + load_threads = ThreadCooperativeGroup(self.num_threads_per_warp_group) + mma_warps = ThreadCooperativeGroup(self.num_mma_threads // cute.arch.WARP_SIZE) + if const_expr(self.use_tma_Q): + pipeline_q = pipeline_custom.PipelineTmaAsync.create( + barrier_storage=mbar_ptr_Q, + num_stages=1, + producer_group=tma_warp, + consumer_group=mma_warps, + tx_count=self.tma_copy_bytes["Q"], + defer_sync=True, + ) + else: + pipeline_q = pipeline_custom.PipelineCpAsync.create( + barrier_storage=mbar_ptr_Q, + num_stages=1, + producer_group=load_threads, + consumer_group=mma_warps, + defer_sync=True, + elect_one_release=True, + syncwarp_before_release=False, + ) + + if const_expr(self.use_tma_KV): + pipeline_k = pipeline_custom.PipelineTmaAsync.create( + barrier_storage=storage.mbar_ptr_K.data_ptr(), + num_stages=self.num_stages, + producer_group=tma_warp, + consumer_group=mma_warps, + tx_count=self.tma_copy_bytes["K"], + defer_sync=True, + ) + pipeline_v = pipeline_custom.PipelineTmaAsync.create( + barrier_storage=storage.mbar_ptr_V.data_ptr(), + num_stages=self.num_stages, + producer_group=tma_warp, + consumer_group=mma_warps, + tx_count=self.tma_copy_bytes["V"], + defer_sync=True, + ) + else: + pipeline_k = pipeline_custom.PipelineCpAsync.create( + barrier_storage=storage.mbar_ptr_K.data_ptr(), + num_stages=self.num_stages, + producer_group=load_threads, + consumer_group=mma_warps, + defer_sync=True, + elect_one_release=True, + syncwarp_before_release=False, + ) + pipeline_v = pipeline_custom.PipelineCpAsync.create( + barrier_storage=storage.mbar_ptr_V.data_ptr(), + num_stages=self.num_stages, + producer_group=load_threads, + consumer_group=mma_warps, + defer_sync=True, + elect_one_release=True, + syncwarp_before_release=False, + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + if const_expr(not self.Q_in_regs): + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + else: + sV = storage.sQ.get_tensor( + sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type + ) + # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma + sVt = layout_utils.transpose_view(sV) + sP = None + if const_expr(sP_layout is not None): + sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) + # reuse sQ's data iterator + sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype) + + block_info = BlockInfo( + self.tile_m, + self.tile_n, + self.is_causal, + self.is_local, + False, # is_split_kv + window_size_left, + window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + seqlen_k_static=mK.shape[0] + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1], + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + # Don't need to pass in tile_mn because we won't access offset_padded + ) + AttentionMaskCls = partial( + AttentionMask, + self.tile_m, + self.tile_n, + window_size_left=window_size_left, + window_size_right=window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) + + # Cluster wait before starting + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + if warp_idx < 4: # Producer + cute.arch.setmaxregister_decrease(self.num_producer_regs) + self.load( + mQ, + mK, + mV, + sQ, + sK, + sV, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_k, + pipeline_v, + pipeline_q, + gmem_tiled_copy_Q, + mPageTable, + blocksparse_tensors, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + + else: # Consumer + cute.arch.setmaxregister_increase(self.num_mma_regs) + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + tidx, _, _ = cute.arch.thread_idx() + tidx = tidx - 128 + self.mma( + tiled_mma_qk, + tiled_mma_pv, + mO, + mLSE, + sQ, + sK, + sVt, + sP, + sO, + learnable_sink, + pipeline_k, + pipeline_v, + pipeline_q, + gmem_tiled_copy_O, + tma_atom_O, + tidx, + softmax_scale_log2, + softmax_scale, + block_info, + SeqlenInfoCls, + AttentionMaskCls, + TileSchedulerCls, + blocksparse_tensors, + aux_tensors, + fastdiv_mods, + ) + + @cute.jit + def load( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tma_atom_Q: Optional[cute.CopyAtom], + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V: Optional[cute.CopyAtom], + pipeline_k: pipeline.PipelineAsync, + pipeline_v: pipeline.PipelineAsync, + pipeline_q: pipeline.PipelineAsync, + gmem_tiled_copy_Q: cute.TiledCopy, + mPageTable: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + tidx, _, _ = cute.arch.thread_idx() + + # TMA: only warp 0 loads. cp_async: all warps load. + # When not use_tma_Q, all 128 producer threads participate in Q loading. + is_load_warp = warp_idx_in_wg == 0 or const_expr(not self.use_tma_KV or not self.use_tma_Q) + # KV loading restricted to warp 0 for TMA, all warps for non-TMA KV + is_kv_load_warp = warp_idx_in_wg == 0 or const_expr(not self.use_tma_KV) + + if is_load_warp: + q_producer_phase = Int32(1) + kv_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_stages + ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + # if work_tile.is_valid_tile: + m_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) + + load_Q = None + if const_expr(self.use_tma_Q): + gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) + load_Q, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True + ) + + paged_kv_manager = None + tma_load_K_fn = None + tma_load_V_fn = None + if const_expr(self.use_tma_KV): + # === TMA path (non-paged and paged with page_size == n_block_size) === + if const_expr(mPageTable is not None): + # Paged TMA: keep page dimension indexable + mK_cur = mK[None, None, head_idx_kv, None] + mV_cur = mV[None, None, head_idx_kv, None] + gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (0, 0, None)) + gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (0, 0, None)) + else: + # Non-paged TMA + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[ + None, None, head_idx_kv + ] + mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[ + None, None, head_idx_kv + ] + gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0)) + gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0)) + # TODO: mcast + tma_load_K_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, 0, cute.make_layout(1), gK, sK + ) + tma_load_K_fn = copy_utils.tma_producer_copy_fn(tma_load_K_fn, pipeline_k) + tma_load_V_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, 0, cute.make_layout(1), gV, sV + ) + tma_load_V_fn = copy_utils.tma_producer_copy_fn(tma_load_V_fn, pipeline_v) + else: + # === cp_async path (paged KV with page_size != n_block_size) === + paged_kv_manager = PagedKVManager.create( + mPageTable, + mK, + mV, + FastDivmodDivisor(mK.shape[0]), + batch_idx, + head_idx_kv, + tidx, + seqlen.seqlen_k, + 0, # leftpad_k + self.tile_n, + self.tile_hdim, + self.tile_hdimv, + self.num_threads_per_warp_group, + mK.element_type, + arch=self.arch.major * 10 + self.arch.minor, + ) + + load_K = partial( + self.load_KV, + tma_load_K_fn, + paged_kv_manager, + sK, + pipeline_kv=pipeline_k, + K_or_V="K", + ) + load_V = partial( + self.load_KV, + tma_load_V_fn, + paged_kv_manager, + sV, + pipeline_kv=pipeline_v, + K_or_V="V", + ) + + pack_gqa = None + if const_expr(not self.use_tma_Q): + pack_gqa = PackGQA( + self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead + ) + + if const_expr(not self.use_block_sparsity): + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + # if cute.arch.thread_idx()[0] == 0: + # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) + # Clamp n_block to 0 when n_block_max == 0 (can happen with causal + # + pack_gqa when seqlen_k < tile_n). TMA handles n_block=-1 + # gracefully (fills zeros), but cp.async would crash on + # out-of-bounds page table access. + n_block = ( + n_block_max - 1 + if const_expr(self.use_tma_KV) + else cutlass.max(n_block_max - 1, 0) + ) + page_idx = ( + mPageTable[batch_idx, n_block] + if const_expr(mPageTable is not None and self.use_tma_KV) + else None + ) + + # First iteration: load K on pipeline_k, Q on pipeline_q + if is_kv_load_warp: + pipeline_k.producer_acquire(kv_producer_state) + if const_expr(not self.use_tma_KV): + paged_kv_manager.load_page_table(n_block) + load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) + if const_expr(self.use_tma_Q): + if warp_idx_in_wg == 0: + pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase) + load_Q(tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(0)) + q_producer_phase ^= 1 + else: + pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase) + pack_gqa.load_Q( + mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q + ) + cute.arch.cp_async_commit_group() + pipeline_q.producer_commit_w_index(0) + q_producer_phase ^= 1 + + if is_kv_load_warp: + if const_expr(not self.intra_wg_overlap or not self.use_tma_KV): + pipeline_v.producer_acquire(kv_producer_state) + load_V( + block=n_block, producer_state=kv_producer_state, page_idx=page_idx + ) + kv_producer_state.advance() + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 1 - i - 1 + page_idx = ( + mPageTable[batch_idx, n_block] + if const_expr(mPageTable is not None and self.use_tma_KV) + else None + ) + if const_expr(not self.use_tma_KV): + paged_kv_manager.load_page_table(n_block) + pipeline_k.producer_acquire(kv_producer_state) + load_K( + block=n_block, + producer_state=kv_producer_state, + page_idx=page_idx, + ) + pipeline_v.producer_acquire(kv_producer_state) + load_V( + block=n_block, + producer_state=kv_producer_state, + page_idx=page_idx, + ) + kv_producer_state.advance() + else: + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block_prev = n_block_max - i - 1 + n_block = n_block_prev - 1 + page_idx = ( + mPageTable[batch_idx, n_block] + if const_expr(mPageTable is not None) + else None + ) + page_idx_prev = ( + mPageTable[batch_idx, n_block_prev] + if const_expr(mPageTable is not None) + else None + ) + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K( + block=n_block, + producer_state=kv_producer_state, + page_idx=page_idx, + ) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V( + block=n_block_prev, + producer_state=kv_producer_state_prev, + page_idx=page_idx_prev, + ) + n_block = n_block_min + page_idx = ( + mPageTable[batch_idx, n_block] + if const_expr(mPageTable is not None) + else None + ) + pipeline_v.producer_acquire(kv_producer_state) + load_V( + block=n_block, producer_state=kv_producer_state, page_idx=page_idx + ) + kv_producer_state.advance() + else: + # Block sparsity: use TMA closures directly (not paged) + # Load Q on pipeline_q, separate from K/V pipeline + if const_expr(self.use_tma_Q): + if warp_idx_in_wg == 0: + pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase) + load_Q(tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(0)) + q_producer_phase ^= 1 + else: + pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase) + pack_gqa.load_Q( + mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q + ) + cute.arch.cp_async_commit_group() + pipeline_q.producer_commit_w_index(0) + q_producer_phase ^= 1 + if is_kv_load_warp: + kv_producer_state = produce_block_sparse_loads( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + tma_load_K_fn, + tma_load_V_fn, + pipeline_k, + pipeline_v, + self.intra_wg_overlap, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + self.q_subtile_factor if self.q_subtile_factor is not None else 1, + ) + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + # Producer tail is only useful for cluster to avoid early exit of blocks. + # We only need producer_tail on V since that's the last that's loaded, we don't + # need it for Q (no cluster) and K. + if is_kv_load_warp: + pipeline_v.producer_tail(kv_producer_state) + + @cute.jit + def load_KV( + self, + tma_load_fn: Optional[Callable], + paged_kv_manager: Optional[PagedKVManager], + sX: cute.Tensor, + block: Int32, + pipeline_kv: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, + K_or_V: Literal["K", "V"], + page_idx: Optional[Int32] = None, + ): + if const_expr(self.use_tma_KV): + src_idx = block if const_expr(page_idx is None) else page_idx + tma_load_fn(src_idx=src_idx, producer_state=producer_state) + else: + paged_kv_manager.load_KV(block, sX[None, None, producer_state.index], K_or_V) + cute.arch.cp_async_commit_group() + pipeline_kv.producer_commit(producer_state) + + @cute.jit + def mma( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + sQ: cute.Tensor, + sK: cute.Tensor, + sVt: cute.Tensor, + sP: Optional[cute.Tensor], + sO: cute.Tensor, + learnable_sink: Optional[cute.Tensor], + pipeline_k: pipeline.PipelineAsync, + pipeline_v: pipeline.PipelineAsync, + pipeline_q: pipeline.PipelineAsync, + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: Optional[cute.CopyAtom], + tidx: Int32, + softmax_scale_log2: Float32, + softmax_scale: Optional[Float32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors], + aux_tensors: Optional[list], + fastdiv_mods=None, + ): + warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) + warp_group_thread_layout = cute.make_layout( + self.num_wg_mma, stride=self.num_threads_per_warp_group + ) + thr_mma_qk = tiled_mma_qk.get_slice(tidx) + wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)) + _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC( + wg_mma_qk, (self.tile_m, self.tile_n, self.tile_hdim), sQ, sK + ) + mma_qk_fn = partial( + sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK + ) + acc_O, tOrP, tOrVt = sm90_utils.partition_fragment_ABC( + wg_mma_pv, (self.tile_m, self.tile_hdimv, self.tile_n), sP, sVt + ) + mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt) + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom_P = utils.get_smem_store_atom( + self.arch.major * 10 + self.arch.minor, self.dtype + ) + smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx) + tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None + smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) + + self.mma_init() + + q_consumer_phase = Int32(0) + kv_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_stages + ) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + softmax = Softmax.create( + softmax_scale_log2, + num_rows=acc_O.shape[0][0] * acc_O.shape[1], + softmax_scale=softmax_scale, + ) + + # For RescaleOBeforeGemm: persistent scores_scale across iterations + scores_scale = None + if const_expr(self.rescale_O_before_gemm): + scores_scale = cute.make_rmem_tensor_like(softmax.row_max, Float32) + + mma_one_n_block_all = partial( + self.mma_one_n_block_intrawg_overlap + if const_expr(self.intra_wg_overlap) + else self.mma_one_n_block, + mma_qk_fn=mma_qk_fn, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + acc_O=acc_O, + tOrP=tOrP, + smem_copy_params=smem_copy_params, + check_inf=True, + scores_scale=scores_scale, + ) + + process_first_half_block = partial( + self.first_half_block_overlap, + mma_qk_fn=mma_qk_fn, + pipeline_k=pipeline_k, + tOrP=tOrP, + smem_copy_params=smem_copy_params, + scores_scale=scores_scale, + softmax=softmax, + acc_O=acc_O, + ) + process_last_half_block = partial( + self.last_half_block_overlap, + pipeline_v=pipeline_v, + mma_pv_fn=mma_pv_fn, + scores_scale=scores_scale, + softmax=softmax, + acc_O=acc_O, + ) + while work_tile.is_valid_tile: + # if work_tile.is_valid_tile: + + # shape: (atom_v_m * rest_m) + m_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + + # Recompute fastdiv_mods if necessary for varlen with aux_tensors + recompute_fastdiv_mods_q = cutlass.const_expr( + aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q) + ) + recompute_fastdiv_mods_k = cutlass.const_expr( + aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k) + ) + if cutlass.const_expr(fastdiv_mods is not None): + seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods + fastdiv_mods = ( + seqlen_q_divmod + if not recompute_fastdiv_mods_q + else FastDivmodDivisor(seqlen.seqlen_q), + seqlen_k_divmod + if not recompute_fastdiv_mods_k + else FastDivmodDivisor(seqlen.seqlen_k), + ) + + mask = AttentionMaskCls(seqlen) + mask_fn = partial( + mask.apply_mask, + batch_idx=batch_idx, + head_idx=head_idx, + m_block=m_block, + thr_mma=thr_mma_qk, + mask_causal=self.is_causal, + mask_local=self.is_local, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + score_mod_fn = None + if const_expr(self.score_mod is not None): + score_mod_fn = partial( + self.apply_score_mod, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + softmax_scale=softmax_scale, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + mma_one_n_block = partial( + mma_one_n_block_all, seqlen=seqlen, softmax=softmax, score_mod_fn=score_mod_fn + ) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + pipeline_q.consumer_wait_w_index_phase(0, q_consumer_phase) + # For performance reason, we separate out two kinds of iterations: + # those that need masking on S, and those that don't. + # We need masking on S for the very last block when K and V has length not multiple of tile_n. + # We also need masking on S if it's causal, for the last several blocks. + # softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True + O_should_accumulate = False + + # ========================================== + # MAINLOOP + # ========================================== + if const_expr(not self.use_block_sparsity): + # ========================================== + # No block-sparsity (original path) + # ========================================== + # First iteration with seqlen masking + if const_expr(self.intra_wg_overlap): + kv_consumer_state = process_first_half_block( + n_block=n_block_max - 1, + seqlen=seqlen, + kv_consumer_state=kv_consumer_state, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod), + score_mod_fn=score_mod_fn, + is_first_block=True, + ) + else: + self.warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=True), + is_first_n_block=True, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), + ) + O_should_accumulate = True + # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) + for n_tile in cutlass.range( + n_block_max - n_block_min_causal_local_mask, unroll=1 + ): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) + # The remaining iterations have no masking + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + # Separate iterations with local masking on the left + if const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + # Release Q pipeline so the producer can load the next tile's Q + pipeline_q.consumer_release_w_index(0) + # Last "half" iteration + if const_expr(self.intra_wg_overlap): + kv_consumer_state = process_last_half_block( + kv_consumer_state=kv_consumer_state, + zero_init=not O_should_accumulate, + ) + O_should_accumulate = True + else: + self.warp_scheduler_barrier_arrive() + + else: + # ========================================== + # Block sparsity + # ========================================== + kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + seqlen, + kv_consumer_state, + mma_pv_fn, + mma_one_n_block, + process_first_half_block, + process_last_half_block, + mask_fn, + score_mod_fn, + O_should_accumulate, + self.mask_mod, + fastdiv_mods, + self.intra_wg_overlap, + self.warp_scheduler_barrier_sync, + self.warp_scheduler_barrier_arrive, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + self.q_subtile_factor if self.q_subtile_factor is not None else 1, + ) + + # Release Q pipeline so the producer can load the next tile's Q + pipeline_q.consumer_release_w_index(0) + + # Handle empty case (when no blocks to process) + if not processed_any: + softmax.reset() + acc_O.fill(0.0) + + q_consumer_phase ^= 1 + + sink_val = None + if const_expr(learnable_sink is not None): + if const_expr(not self.pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + else: # Each thread might have a different sink value due to different q_head + sink_val = cute.make_rmem_tensor_like(softmax.row_max, Float32) + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) + tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma_qk.partition_C(cS)) + for r in cutlass.range(cute.size(sink_val), unroll_full=True): + row = m_block * self.tile_m + tScS_mn[r][0] + q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead + sink_val[r] = Float32(learnable_sink[q_head_idx]) + + # normalize acc_O by row_sum and calculate the lse + row_scale = softmax.finalize(sink_val=sink_val) + softmax.rescale_O(acc_O, row_scale) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + self.epilogue( + acc_O, + softmax.row_sum, + mO, + mLSE, + sO, + seqlen, + gmem_tiled_copy_O, + tma_atom_O, + tiled_mma_pv, + tidx, + m_block, + head_idx, + batch_idx, + ) + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def first_half_block_overlap( + self, + n_block: Int32, + mma_qk_fn: Callable, + kv_consumer_state, + pipeline_k, + tOrP: cute.Tensor, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + seqlen: SeqlenInfoQK, + scores_scale: Optional[cute.Tensor] = None, + acc_O: Optional[cute.Tensor] = None, + mask_fn: Callable = None, + score_mod_fn: Optional[Callable] = None, + is_first_block: bool = False, + ): + """Processes the first half block when using intra-warpgroup-overlap""" + + pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state)) + acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0) + pipeline_k.consumer_release(kv_consumer_state) + + # Apply score modification if present + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) + + # Apply mask; mask_seqlen always True for first block + # Caveat: if full block further right than mask block, seqlen masking is redundant; + # however, masking is being applied anyway, so essentially no perf hit + mask_fn(acc_S, n_block=n_block, mask_seqlen=True) + + row_scale = softmax.online_softmax(acc_S, is_first=is_first_block) + + tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) + tOrP_cur = ( + tOrP + if const_expr(self.mma_pv_is_rs) + else cute.make_rmem_tensor_like(tOrP_acc, self.dtype) + ) + tOrP_cur.store(tOrP_acc.load().to(self.dtype)) + + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + # Fence and barrier to make smem store visible to WGMMA + cute.arch.fence_view_async_shared() + cute.arch.sync_warp() + + # For RescaleOBeforeGemm: initialize acc_O + if const_expr(self.rescale_O_before_gemm): + acc_O.fill(0.0) + scores_scale.store(row_scale.load()) + + return kv_consumer_state + + @cute.jit + def last_half_block_overlap( + self, + kv_consumer_state, + pipeline_v, + mma_pv_fn: Callable, + zero_init: bool, + scores_scale: Optional[cute.Tensor] = None, + softmax: Optional[Softmax] = None, + acc_O: Optional[cute.Tensor] = None, + ): + """Processes the final PV GEMM when using intra-warpgroup-overlap""" + + # For RescaleOBeforeGemm: rescale O before the final PV GEMM + if const_expr(self.rescale_O_before_gemm): + softmax.rescale_O(acc_O, scores_scale) + + pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) + mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0) + pipeline_v.consumer_release(kv_consumer_state) + kv_consumer_state.advance() + return kv_consumer_state + + @cute.jit + def mma_one_n_block( + self, + smem_pipe_read: pipeline.PipelineState | pipeline_custom.PipelineStateSimple, + n_block: Int32, + mma_qk_fn: Callable, + mma_pv_fn: Callable, + pipeline_k: pipeline.PipelineAsync, + pipeline_v: pipeline.PipelineAsync, + acc_O: cute.Tensor, + tOrP: cute.Tensor, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + seqlen: SeqlenInfoQK, + scores_scale: Optional[cute.Tensor] = None, # not used + score_mod_fn: Optional[Callable] = None, + mask_fn: Optional[Callable] = None, + is_first_n_block: cutlass.Constexpr = False, + check_inf: cutlass.Constexpr = True, + ): + pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) + # S = Q @ K.T + acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) + self.warp_scheduler_barrier_arrive() + warpgroup.wait_group(0) + pipeline_k.consumer_release(smem_pipe_read) + + # handle score mods and masking + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) + if const_expr(mask_fn is not None): + mask_fn(acc_S=acc_S, n_block=n_block) + + row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) + # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S)) + tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) + tOrP_cur = ( + tOrP + if const_expr(self.mma_pv_is_rs) + else cute.make_rmem_tensor_like(tOrP_acc, self.dtype) + ) + # tOrP.store(tOrP_acc.load().to(self.dtype)) + # the "to(self.dtype)" conversion fails to vectorize for block sizes other + # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of + # 2 elements. So we just call ptx directly. + utils.cvt_f16(tOrP_acc, tOrP_cur) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + softmax.rescale_O(acc_O, row_scale) + if const_expr(not self.mma_pv_is_rs): + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_view_async_shared() + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) + self.warp_scheduler_barrier_sync() + # O += P @ V + mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0) + pipeline_v.consumer_release(smem_pipe_read) + smem_pipe_read.advance() + return smem_pipe_read + + @cute.jit + def mma_one_n_block_intrawg_overlap( + self, + smem_pipe_read: pipeline.PipelineState | pipeline_custom.PipelineStateSimple, + n_block: Int32, + mma_qk_fn: Callable, + mma_pv_fn: Callable, + pipeline_k: pipeline.PipelineAsync, + pipeline_v: pipeline.PipelineAsync, + acc_O: cute.Tensor, + tOrP: cute.Tensor, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + seqlen: SeqlenInfoQK, + scores_scale: Optional[cute.Tensor] = None, + score_mod_fn: Optional[Callable] = None, + mask_fn: Optional[Callable] = None, + check_inf: cutlass.Constexpr = True, + ): + smem_pipe_read_v = smem_pipe_read.clone() + smem_pipe_read.advance() + pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) + self.warp_scheduler_barrier_sync() + # S = Q @ K.T + acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) + # RescaleOBeforeGemm: rescale O while QK GEMM is in flight, before PV GEMM + if const_expr(self.rescale_O_before_gemm): + softmax.rescale_O(acc_O, scores_scale) + pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v)) + # O += P @ V + mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1) + self.warp_scheduler_barrier_arrive() + warpgroup.wait_group(1) + pipeline_k.consumer_release(smem_pipe_read) + + # handle score mods and masking + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) + if const_expr(mask_fn is not None): + mask_fn(acc_S=acc_S, n_block=n_block) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S)) + + row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) + warpgroup.wait_group(0) + pipeline_v.consumer_release(smem_pipe_read_v) + tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) + tOrP_cur = ( + tOrP + if const_expr(self.mma_pv_is_rs) + else cute.make_rmem_tensor_like(tOrP_acc, self.dtype) + ) + # tOrP_cur.store(tOrP_acc.load().to(self.dtype)) + # the "to(self.dtype)" conversion fails to vectorize for block sizes other + # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of + # 2 elements. So we just call ptx directly. + utils.cvt_f16(tOrP_acc, tOrP_cur) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + if const_expr(not self.rescale_O_before_gemm): + softmax.rescale_O(acc_O, row_scale) + if const_expr(self.rescale_O_before_gemm): + scores_scale.store(row_scale.load()) + if const_expr(not self.mma_pv_is_rs): + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_view_async_shared() + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + return smem_pipe_read + + @cute.jit + def mma_init(self): + warp_group_idx = utils.canonical_warp_group_idx(sync=False) + if const_expr(self.use_scheduler_barrier): + if warp_group_idx == 1: + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), + number_of_threads=2 * self.num_threads_per_warp_group, + ) + + @cute.jit + def apply_score_mod( + self, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + acc_S, + n_block, + softmax_scale, + seqlen, + aux_tensors: Optional[list] = None, + fastdiv_mods=None, + ): + # Prepare index tensor + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) + cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS) + tScS = thr_mma_qk.partition_C(cS) + + apply_score_mod_inner( + acc_S, + tScS, + self.score_mod, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info=seqlen, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + + def warp_scheduler_barrier_sync(self): + if const_expr(self.use_scheduler_barrier): + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + - 1 + + utils.canonical_warp_group_idx(sync=False), + number_of_threads=2 * self.num_threads_per_warp_group, + ) + + def warp_scheduler_barrier_arrive(self): + if const_expr(self.use_scheduler_barrier): + assert self.num_wg_mma in [2, 3] + cur_wg = utils.canonical_warp_group_idx(sync=False) - 1 + if const_expr(self.num_wg_mma == 2): + next_wg = 1 - cur_wg + else: + t = cur_wg + 1 + next_wg = t % self.num_wg_mma + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, + number_of_threads=2 * self.num_threads_per_warp_group, + ) diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py deleted file mode 100644 index c6a1c301904..00000000000 --- a/flash_attn/cute/hopper_helpers.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) 2025, Tri Dao. -from typing import Type, Union, Optional -import cutlass -import cutlass.cute as cute -from cutlass import Int32, Float32, Boolean, const_expr -from cutlass.cute.nvgpu import warpgroup -from cutlass.cutlass_dsl import Numeric, dsl_user_op -from cutlass.utils import LayoutEnum -import cutlass.utils.hopper_helpers as sm90_utils_og - - -@cute.jit -def gemm( - tiled_mma: cute.TiledMma, - acc: cute.Tensor, - tCrA: cute.Tensor, - tCrB: cute.Tensor, - zero_init: cutlass.Constexpr[bool] = False, - wg_wait: cutlass.Constexpr[int] = 0, - # A_in_regs: cutlass.Constexpr[bool] = False, - swap_AB: cutlass.Constexpr[bool] = False, -) -> None: - if const_expr(swap_AB): - gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False) - else: - warpgroup.fence() - # We make a new mma_atom since we'll be modifying its attribute (accumulate). - # Otherwise the compiler complains "operand #0 does not dominate this use" - mma_atom = cute.make_mma_atom(tiled_mma.op) - mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init) - for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): - cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc) - mma_atom.set(warpgroup.Field.ACCUMULATE, True) - warpgroup.commit_group() - if const_expr(wg_wait >= 0): - warpgroup.wait_group(wg_wait) - - -def gemm_zero_init( - tiled_mma: cute.TiledMma, - shape: cute.Shape, - tCrA: cute.Tensor, - tCrB: cute.Tensor, - A_idx: Optional[Int32] = None, - B_idx: Optional[Int32] = None, - wg_wait: int = -1, - swap_AB: bool = False, -) -> cute.Tensor: - if const_expr(swap_AB): - return gemm_zero_init( - tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False - ) - else: - acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32) - rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] - rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] - gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait) - return acc - - -def gemm_w_idx( - tiled_mma: cute.TiledMma, - acc: cute.Tensor, - tCrA: cute.Tensor, - tCrB: cute.Tensor, - zero_init: Boolean, - A_idx: Optional[Int32] = None, - B_idx: Optional[Int32] = None, - wg_wait: int = -1, - swap_AB: bool = False, -) -> None: - if const_expr(swap_AB): - gemm_w_idx(tiled_mma, acc, tCrB, tCrA, zero_init, B_idx, A_idx, wg_wait, swap_AB=False) - else: - rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] - rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] - gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait) - - -@dsl_user_op -def make_smem_layout( - dtype: Type[Numeric], - layout: LayoutEnum, - shape: cute.Shape, - stage: Optional[int] = None, - *, - loc=None, - ip=None, -) -> Union[cute.Layout, cute.ComposedLayout]: - major_mode_size = shape[1] if layout.is_n_major_c() else shape[0] - smem_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size), - dtype, - ) - order = (1, 0, 2) if const_expr(layout.is_m_major_c()) else (0, 1, 2) - smem_layout_staged = cute.tile_to_shape( - smem_layout_atom, - cute.append(shape, stage) if const_expr(stage is not None) else shape, - order=order if const_expr(stage is not None) else order[:2], - ) - return smem_layout_staged diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index f01a6921ffd..ef624677f01 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -21,6 +21,7 @@ import os import math +from dataclasses import dataclass from functools import lru_cache from typing import Optional, Tuple, Callable @@ -31,6 +32,10 @@ import cutlass import cutlass.cute as cute +from cutlass import Int32, Float32 +from quack.compile_utils import make_fake_tensor as fake_tensor +from flash_attn.cute.cache_utils import get_jit_cache +from flash_attn.cute.testing import is_fake_mode if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: @@ -41,27 +46,201 @@ from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import to_cute_tensor -from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90 +from flash_attn.cute import fa_logging +from flash_attn.cute.cute_dsl_utils import ( + to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims, +) +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 +from flash_attn.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 from flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 from flash_attn.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 +from flash_attn.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine from flash_attn.cute.block_sparsity import ( BlockSparseTensorsTorch, + get_sparse_q_block_size, to_cute_block_sparse_tensors, normalize_block_sparse_config, normalize_block_sparse_config_bwd, ) +def _parse_arch_str(arch_str): + """Parse arch string (e.g. 'sm_80', 'sm_90a', '80', '100') to int (e.g. 80, 90, 100).""" + import re + match = re.match(r"^(?:sm_?|SM_?)?(\d+)(\d)([af]?)$", arch_str) + if not match: + raise ValueError(f"Invalid arch format: {arch_str}") + major, minor, _ = match.groups() + return int(major) * 10 + int(minor) + + @lru_cache(maxsize=None) -def _get_device_capability(): - """Cached device capability check.""" - return torch.cuda.get_device_capability()[0] +def _get_device_arch(): + """Cached device arch check. + + Override with FLASH_ATTENTION_ARCH (e.g. 'sm_80' or '80') to select which + kernel path to use (SM80/SM90/SM100/SM120) independently of the compilation + target (CUTE_DSL_ARCH). + + For CPU-only compilation (no GPU), set both: + FLASH_ATTENTION_ARCH=sm_80 (kernel selection) + CUTE_DSL_ARCH=sm_80 (compilation target) + """ + arch_override = os.environ.get("FLASH_ATTENTION_ARCH", None) + if arch_override is not None: + return _parse_arch_str(arch_override) + major, minor = torch.cuda.get_device_capability() + return major * 10 + int(minor) + + +def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, alignment: int) -> None: + """Validate head dimension constraints based on compute capability.""" + is_deepseek_shape = head_dim == 192 and head_dim_v == 128 + is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128 + + is_sm90_range = 8 <= head_dim <= 256 and 8 <= head_dim_v <= 256 + if compute_capability == 9: + assert is_sm90_range and head_dim % alignment == 0 and head_dim_v % alignment == 0, ( + f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM90. " + f"head_dim and head_dim_v must be between 8 and 256 and divisible by {alignment}." + ) + elif compute_capability in [10, 11]: + assert (is_standard_range or is_deepseek_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, ( + f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. " + f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, or (192, 128) for DeepSeek." + ) + + +@dataclass(frozen=True) +class FwdConfig: + m_block_size: int + n_block_size: int + mma_pv_is_rs: bool + intra_wg_overlap: bool + + +def _tile_size_fwd_sm90(head_dim, head_dim_v, is_causal, is_local, sparse_block_size_q=None): + """Return FwdConfig for SM90 forward. + + Tile sizes and flags based on tile_size_fwd_sm90 in hopper/tile_size.h, adjusted + for the Python kernel's different register/smem tradeoffs (benchmarked on H100 SXM). + + When sparse_block_size_q is set, tile_m must divide it. For head_dim <= 96 the + optimal tile_m=192 is used when compatible, otherwise we fall back to 128. + """ + if head_dim <= 64: + # C++: 192×192 non-causal, 192×128 causal/local. + # Python: 192×128 RS+OL is consistently best across seqlens. + if sparse_block_size_q is not None and sparse_block_size_q % 192 != 0: + return FwdConfig(128, 128, True, True) + return FwdConfig(192, 128, True, True) + elif head_dim <= 96: + # C++: 192×144 noRS+OL for all cases. + # Python: RS is catastrophic with 192× tiles (~300 vs ~600 TFLOPS). + # noRS+OL is always required. Causal: 192×128 slightly better short seqlen. + if sparse_block_size_q is not None and sparse_block_size_q % 192 != 0: + return FwdConfig(128, 128, False, True) + if is_causal or is_local: + return FwdConfig(192, 128, False, True) + else: + return FwdConfig(192, 144, False, True) + elif head_dim <= 128: + return FwdConfig(128, 128, True, True) + elif head_dim <= 192: + tile_n = 96 if is_local else (128 if head_dim_v <= 128 else 112) + return FwdConfig(128, tile_n, True, True) + else: # hdim 256 + tile_n = 64 if is_local else 80 + return FwdConfig(128, tile_n, True, True) + +@dataclass(frozen=True) +class BwdConfig: + m_block_size: int + n_block_size: int + num_stages_Q: int + num_stages_dO: int + num_stages_PdS: int + SdP_swapAB: bool + dKV_swapAB: bool + dQ_swapAB: bool + AtomLayoutMSdP: int + AtomLayoutNdKV: int + AtomLayoutMdQ: int + num_wg: int = 2 # MMA warp groups (total threads = (num_wg + 1) * 128) + dQ_single_wg: bool = False + + +def _tile_size_bwd_sm90(head_dim, head_dim_v, causal, local, sparse_block_size_q=None): + """Return BwdConfig for SM90. + + Configs based on C++ FA3 hopper/flash_bwd_launch_template.h, + benchmarked on H100 SXM. + """ + if head_dim <= 64: + # C++ FA3: 128, 128, 64, ..., 2, 2, true, false, false, 2, 1, 2, 2 + return BwdConfig( + m_block_size=128, n_block_size=128, + num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2, + SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=False, + AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=2, + ) + elif head_dim <= 96: + # C++ FA3: 64, 128, 96, dQ_swapAB=False + return BwdConfig( + m_block_size=64, n_block_size=128, + num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2, + SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=False, + AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, + dQ_single_wg=True, + ) + elif head_dim <= 128: + # C++ FA3: causal/local: 64, 128; non-causal: 80, 128 with dQ_swapAB + is_causal_or_local = causal or local + m_block_size = 64 if is_causal_or_local else 80 + if sparse_block_size_q is not None and sparse_block_size_q % m_block_size != 0: + m_block_size = 64 + return BwdConfig( + m_block_size=m_block_size, + n_block_size=128, + num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2, + SdP_swapAB=True, dKV_swapAB=False, + dQ_swapAB=m_block_size % 64 != 0, + AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, + ) + elif head_dim <= 192: + hdimv128 = head_dim_v <= 128 + if hdimv128: + return BwdConfig( + m_block_size=64, n_block_size=96, + num_stages_Q=2, num_stages_dO=2, num_stages_PdS=1, + SdP_swapAB=False, dKV_swapAB=True, dQ_swapAB=False, + AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, + num_wg=2, + ) + else: + return BwdConfig( + m_block_size=64, n_block_size=96, + num_stages_Q=2, num_stages_dO=1, num_stages_PdS=1, + SdP_swapAB=False, dKV_swapAB=True, dQ_swapAB=False, + AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, + num_wg=2, + ) + else: + # hdim 256 + return BwdConfig( + m_block_size=64, n_block_size=64, + num_stages_Q=1, num_stages_dO=1, num_stages_PdS=1, + SdP_swapAB=False, dKV_swapAB=False, dQ_swapAB=False, + AtomLayoutMSdP=1, AtomLayoutNdKV=1, AtomLayoutMdQ=1, + ) + + def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x @@ -71,7 +250,8 @@ def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device): assert t.shape == expected_shape, f"{name} shape {t.shape} != expected {expected_shape}" assert t.dtype == expected_dtype, f"{name} dtype {t.dtype} != expected {expected_dtype}" assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}" - assert t.is_cuda, f"{name} must be on CUDA" + if not is_fake_mode(): + assert t.is_cuda, f"{name} must be on CUDA" torch2cute_dtype_map = { @@ -91,6 +271,29 @@ def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits): return min(num_SMs // total_mblocks, max_splits, num_n_blocks) +def _resolve_causal_local_window(causal, window_size_left, window_size_right, mask_mod=None): + """Resolve causal/local/window settings into canonical form. + + Returns (causal, local, window_size_left, window_size_right). + """ + if mask_mod is not None: + return False, False, window_size_left, window_size_right + if causal: + window_size_right = 0 + if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0: + window_size_left = None + window_size_right = None + if window_size_left is not None or window_size_right is not None: + if window_size_left is None and window_size_right == 0: + causal, local = True, False + window_size_right = None + else: + causal, local = False, True + else: + local = False + return causal, local, window_size_left, window_size_right + + def _flash_attn_fwd( q: torch.Tensor, k: torch.Tensor, @@ -108,15 +311,13 @@ def _flash_attn_fwd( window_size_left: Optional[int] = None, window_size_right: Optional[int] = None, learnable_sink: Optional[torch.Tensor] = None, - # m_block_size: int = 128, - # n_block_size: int = 64, - # num_threads: int = 128, - m_block_size: int = 128, - n_block_size: int = 128, + tile_mn: Optional[Tuple[int, int]] = None, + mma_pv_is_rs: Optional[bool] = None, + intra_wg_overlap: Optional[bool] = None, num_threads: int = 384, num_splits: int = 1, pack_gqa: Optional[bool] = None, - _compute_capability: Optional[int] = None, + _arch: Optional[int] = None, score_mod: Optional[Callable] = None, mask_mod: Optional[Callable] = None, block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, @@ -133,6 +334,7 @@ def _flash_attn_fwd( mask_mod: A callable that takes token position information and selectively masks block_sparse_tensors: A tuple of tensors used for block sparsity. return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate + The returned LSE supports taking gradient. out: Optional pre-allocated output tensor. If None, will be allocated internally. lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed. aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. @@ -197,25 +399,27 @@ def _flash_attn_fwd( assert learnable_sink.shape == (num_head,) assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" - assert all( - t is None or t.is_cuda - for t in ( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - page_table, - learnable_sink, - ) - ), "inputs must be on CUDA device" + if not is_fake_mode(): + assert all( + t is None or t.is_cuda + for t in ( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + page_table, + learnable_sink, + ) + ), "inputs must be on CUDA device" + arch = _get_device_arch() if _arch is None else _arch + assert arch // 10 in [8, 9, 10, 11, 12], "Unsupported compute capability. Supported: 8.x, 9.x, 10.x, 11.x, 12.x" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" - assert head_dim <= 256, "head_dim must be less than or equal to 256" alignment = 16 // q.element_size() - assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" - assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" + if arch // 10 not in [8, 12]: + _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment) if softmax_scale is None: softmax_scale = 1.0 / math.sqrt(head_dim) if softcap == 0.0: @@ -247,70 +451,73 @@ def _flash_attn_fwd( _validate_tensor(lse, "lse", lse_shape, torch.float32, device) dtype = torch2cute_dtype_map[q.dtype] - compute_capability = ( - _get_device_capability() - if _compute_capability is None - else _compute_capability + use_block_sparsity = block_sparse_tensors is not None + + causal, local, window_size_left, window_size_right = _resolve_causal_local_window( + causal, window_size_left, window_size_right, mask_mod ) - assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x" + requested_use_clc_scheduler = utils._get_use_clc_scheduler_default() + requested_disable_2cta = utils._get_disable_2cta_default() - use_block_sparsity = block_sparse_tensors is not None + current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + # SM80/SM120: uses SM80 MMA, 128 threads (4 warps) + if arch // 10 in [8, 12]: + num_threads = 128 - if mask_mod is None: - if causal: - window_size_right = 0 - local = window_size_left is not None or window_size_right is not None - if window_size_left is not None or window_size_right is not None: - if window_size_left is None and window_size_right == 0: - causal, local = True, False - window_size_right = None + fwd_cfg = FwdConfig(128, 128, True, True) # default + if tile_mn is None: + if arch // 10 == 12: + # SM120 tile sizes tuned for 99 KB SMEM capacity: + # D<=64: 128x128 → 48 KB (good occupancy) + # D>64: 128x64 → 64 KB (128x128 would use 96 KB, hurting occupancy) + if head_dim <= 64: + fwd_cfg = FwdConfig(128, 128, True, True) else: - causal, local = False, True + fwd_cfg = FwdConfig(128, 64, True, True) + elif arch // 10 == 8: + fwd_cfg = FwdConfig(128, 64, True, True) # SM80, should tune + elif arch // 10 == 9: + sparse_q = get_sparse_q_block_size(block_sparse_tensors, seqlen_q) + fwd_cfg = _tile_size_fwd_sm90(head_dim, head_dim_v, causal, local, sparse_block_size_q=sparse_q) else: - causal, local = False, False + fwd_cfg = FwdConfig(tile_mn[0], tile_mn[1], fwd_cfg.mma_pv_is_rs, fwd_cfg.intra_wg_overlap) + tile_m, tile_n = fwd_cfg.m_block_size, fwd_cfg.n_block_size + if mma_pv_is_rs is None: + mma_pv_is_rs = fwd_cfg.mma_pv_is_rs + if intra_wg_overlap is None: + intra_wg_overlap = fwd_cfg.intra_wg_overlap - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - if compute_capability == 9: # TODO: tune block size according to hdim. - if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity: - n_block_size = 192 - - if compute_capability in [10, 11]: - if ( - pack_gqa - and (128 % qhead_per_kvhead != 0) - ): - pack_gqa = False - # TODO: fix GQA + SplitKV + non-varlen - if pack_gqa and num_splits != 1 and cu_seqlens_q is None: - pack_gqa = False + # TODO: fix GQA + SplitKV + non-varlen + if pack_gqa and num_splits != 1 and cu_seqlens_q is None: + pack_gqa = False if max_seqlen_q is None: max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q if max_seqlen_k is None: max_seqlen_k = seqlen_k seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead - if compute_capability == 10: - q_stage = 2 if seqlen_q_packgqa > m_block_size else 1 + if arch // 10 == 10: + q_stage = 2 if seqlen_q_packgqa > tile_m else 1 else: q_stage = 1 - m_block_size_effective = q_stage * m_block_size - seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, window_size_right + window_size_left + 1 + m_block_size)) + m_block_size_effective = q_stage * tile_m + seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, (window_size_right or max_seqlen_k) + (window_size_left or max_seqlen_k) + 1 + tile_m)) num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective total_mblocks = batch_size * num_head_kv * num_m_blocks - num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size - num_SMs = torch.cuda.get_device_properties(device).multi_processor_count + num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n + num_SMs = 132 if is_fake_mode() else torch.cuda.get_device_properties(device).multi_processor_count if num_splits < 1: num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) # SplitKV uses float32 partial output, which doubles the O buffer size # in shared memory, causing OOM for diff-headdim (192, 128) - if compute_capability in [10, 11] and head_dim != head_dim_v and num_splits > 1: + if arch // 10 in [10, 11] and head_dim != head_dim_v and num_splits > 1: if num_n_blocks >= 64: - n_block_size = 64 - num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size + tile_n = 64 + num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) else: num_splits = 1 @@ -320,6 +527,22 @@ def _flash_attn_fwd( out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device) lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device) + use_2cta_instrs = ( + arch // 10 in [10, 11] + and not requested_disable_2cta + and not causal + and not local + and not is_split_kv + and cu_seqlens_q is None + and seqused_q is None + and not use_block_sparsity + and page_size in [None, 128] + and int(math.ceil(head_dim / 16) * 16) in [128, 192] + and int(math.ceil(head_dim_v / 16) * 16) == 128 + and seqlen_q_packgqa > 2 * tile_m + and (tile_m % qhead_per_kvhead == 0 or not pack_gqa) + ) + # hash score and mask mods for compile cache score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False @@ -371,9 +594,13 @@ def _flash_attn_fwd( num_head=num_head, seqlen_q=seqlen_q, seqlen_k=seqlen_k, - block_size=(m_block_size, n_block_size), + block_size=(tile_m, tile_n), q_stage=q_stage, ) + if aux_tensors is not None: + aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors) + else: + aux_tensor_metadata = None compile_key = ( dtype, @@ -385,7 +612,7 @@ def _flash_attn_fwd( mask_mod_hash, use_block_sparsity, block_sparse_broadcast_pattern, - len(aux_tensors) if aux_tensors is not None else 0, + aux_tensor_metadata, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, @@ -395,15 +622,20 @@ def _flash_attn_fwd( window_size_left is not None, window_size_right is not None, learnable_sink is not None, - m_block_size, - n_block_size, + tile_m, + tile_n, q_stage, num_threads, is_split_kv, pack_gqa, - compute_capability, - page_size not in [None, 128], # paged KV non-TMA + arch, + page_size not in [None, tile_n], # paged KV non-TMA + use_2cta_instrs, q_subtile_factor, + mma_pv_is_rs, + intra_wg_overlap, + requested_use_clc_scheduler, + fa_logging.get_fa_log_level(), ) if compile_key not in _flash_attn_fwd.compile_cache: ( @@ -438,13 +670,32 @@ def _flash_attn_fwd( sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors) cute_aux_tensors = None + aux_tensor_metadata = None if aux_tensors is not None: - cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors] + cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] - if compute_capability == 9: - assert page_table is None, "paged KV not supported on SM 9.0" + if arch // 10 == 8: + assert page_table is None, "paged KV not supported on SM 8.0" + assert not is_split_kv, "SplitKV not supported on SM 8.0" + fa_fwd = FlashAttentionForwardSm80( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + is_causal=causal, + is_local=local, + pack_gqa=pack_gqa, + tile_m=tile_m, + tile_n=tile_n, + num_stages=1, + num_threads=num_threads, + Q_in_regs=False, + score_mod=score_mod, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None, + ) + elif arch // 10 == 9: assert not is_split_kv, "SplitKV not supported on SM 9.0" - # fa_fwd = FlashAttentionForwardSm80( fa_fwd = FlashAttentionForwardSm90( dtype, head_dim, @@ -453,20 +704,21 @@ def _flash_attn_fwd( is_causal=causal, is_local=local, pack_gqa=pack_gqa, - tile_m=m_block_size, - tile_n=n_block_size, + tile_m=tile_m, + tile_n=tile_n, # num_stages=1, num_stages=2, num_threads=num_threads, Q_in_regs=False, - intra_wg_overlap=True, - mma_pv_is_rs=True, + intra_wg_overlap=intra_wg_overlap, + mma_pv_is_rs=mma_pv_is_rs, mask_mod=mask_mod, score_mod=score_mod, has_aux_tensors=aux_tensors is not None, q_subtile_factor=q_subtile_factor, + paged_kv_non_tma=page_size not in [None, tile_n], ) - elif compute_capability in [10, 11]: + elif arch // 10 in [10, 11]: fa_fwd = FlashAttentionForwardSm100( head_dim, head_dim_v, @@ -475,8 +727,8 @@ def _flash_attn_fwd( is_local=local, is_split_kv=is_split_kv, pack_gqa=pack_gqa, - m_block_size=m_block_size, - n_block_size=n_block_size, + m_block_size=tile_m, + n_block_size=tile_n, q_stage=q_stage, is_persistent=not causal and not local @@ -486,14 +738,37 @@ def _flash_attn_fwd( score_mod=score_mod, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, - paged_kv_non_tma=page_size not in [None, 128], - is_varlen_q=cu_seqlens_q is not None - or seqused_q is not None, + paged_kv_non_tma=page_size not in [None, tile_n], + is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, q_subtile_factor=q_subtile_factor, + use_2cta_instrs=use_2cta_instrs, + use_clc_scheduler=requested_use_clc_scheduler, + ) + elif arch // 10 == 12: + # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity + assert not use_block_sparsity, "Block sparsity not supported on SM 12.0" + assert page_table is None, "Paged KV not supported on SM 12.0 in this PR" + assert not is_split_kv, "SplitKV not supported on SM 12.0 in this PR" + fa_fwd = FlashAttentionForwardSm120( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + is_causal=causal, + is_local=local, + pack_gqa=pack_gqa, + tile_m=tile_m, + tile_n=tile_n, + num_stages=1, + num_threads=num_threads, + Q_in_regs=False, + score_mod=score_mod, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None, ) else: raise ValueError( - f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x, 11.x" + f"Unsupported compute capability: {arch}. Supported: 8.x, 9.x, 10.x, 11.x, 12.x" ) # TODO: check @can_implement _flash_attn_fwd.compile_cache[compile_key] = cute.compile( @@ -504,7 +779,6 @@ def _flash_attn_fwd( o_tensor, lse_tensor, softmax_scale, - current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, @@ -515,28 +789,33 @@ def _flash_attn_fwd( learnable_sink_tensor, sparse_tensors, cute_aux_tensors, + current_stream, options="--enable-tvm-ffi", ) - _flash_attn_fwd.compile_cache[compile_key]( - q.detach(), - k.detach(), - v.detach(), - out.detach() if not is_split_kv else out_partial, - lse_partial if is_split_kv else lse, - softmax_scale, - current_stream, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - page_table, - window_size_left, - window_size_right, - learnable_sink, - normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, - aux_tensors, - ) + # In "fake mode", we will take torch fake tensors as input and the expected behaviors are: + # - Use those fake metadata to populate compilation cache + # - Return "fake" output tensors, which could be needed in follow-up fake operations + # Thus, we skip the actual kernel invocation here. + if not is_fake_mode(): + _flash_attn_fwd.compile_cache[compile_key]( + q.detach(), + k.detach(), + v.detach(), + out.detach() if not is_split_kv else out_partial, + lse_partial if is_split_kv else lse, + softmax_scale, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + page_table, + window_size_left, + window_size_right, + learnable_sink, + normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, + aux_tensors, + ) if is_split_kv: _flash_attn_fwd_combine( out_partial, @@ -549,7 +828,141 @@ def _flash_attn_fwd( return out, lse -_flash_attn_fwd.compile_cache = {} +_flash_attn_fwd.compile_cache = get_jit_cache("fwd") + + +def make_fake_bwd_tensors(dtype, has_gqa, varlen_q, varlen_k): + sym = cute.sym_int + # divisibility in elements: assumed_align_bytes = divisibility * dtype.width // 8 + # For 16-byte align: fp16/bf16 → divisibility=8, float32 → divisibility=4 + div = 128 // dtype.width # 8 for fp16/bf16 + # Shared sym_ints for dimensions that must match across tensors + b, seqlen_q, seqlen_k, h_q, d, d_v = sym(), sym(), sym(), sym(), sym(), sym() + h_kv = h_q if not has_gqa else sym() + seqlen_q_rounded, seqlen_k_rounded = sym(), sym() + seqlen_q_d_rounded, seqlen_k_d_rounded, seqlen_k_dv_rounded = sym(), sym(), sym() + total_q, total_k, total_q_rounded, total_k_rounded = sym(), sym(), sym(), sym() + total_q_d_rounded, total_k_d_rounded, total_k_dv_rounded = sym(), sym(), sym() + b_seqlenq = (b, seqlen_q) if not varlen_q else (total_q,) + b_seqlenk = (b, seqlen_k) if not varlen_k else (total_k,) + mQ = fake_tensor(dtype, (*b_seqlenq, h_q, d), divisibility=div) + mO = fake_tensor(dtype, (*b_seqlenq, h_q, d_v), divisibility=div) + mdO = fake_tensor(dtype, (*b_seqlenq, h_q, d_v), divisibility=div) + mK = fake_tensor(dtype, (*b_seqlenk, h_kv, d), divisibility=div) + mV = fake_tensor(dtype, (*b_seqlenk, h_kv, d_v), divisibility=div) + mdQ = fake_tensor(dtype, (*b_seqlenq, h_q, d), divisibility=div) + mdK = fake_tensor(dtype, (*b_seqlenk, h_kv, d), divisibility=div) + mdV = fake_tensor(dtype, (*b_seqlenk, h_kv, d_v), divisibility=div) + if not varlen_q: + mLSE = fake_tensor(Float32, (b, h_q, seqlen_q), divisibility=1) + mLSElog2 = fake_tensor(Float32, (b, h_q, seqlen_q_rounded), divisibility=4) + mPdPsum = fake_tensor(Float32, (b, h_q, seqlen_q_rounded), divisibility=4) + dQaccum = fake_tensor(Float32, (b, h_q, seqlen_q_d_rounded), divisibility=4) + else: + mLSE = fake_tensor(Float32, (h_q, total_q), divisibility=1) + mLSElog2 = fake_tensor(Float32, (h_q, total_q_rounded), divisibility=4) + mPdPsum = fake_tensor(Float32, (h_q, total_q_rounded), divisibility=4) + dQaccum = fake_tensor(Float32, (h_q, total_q_d_rounded), divisibility=4) + if not has_gqa: + mdKaccum, mdVaccum = None, None + else: + if not varlen_k: + mdKaccum = fake_tensor(Float32, (b, h_kv, seqlen_k_rounded), divisibility=4) + mdVaccum = fake_tensor(Float32, (b, h_kv, seqlen_k_dv_rounded), divisibility=4) + else: + mdKaccum = fake_tensor(Float32, (h_kv, total_k_rounded), divisibility=4) + mdVaccum = fake_tensor(Float32, (h_kv, total_k_dv_rounded), divisibility=4) + return mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, dQaccum, mdKaccum, mdVaccum + + +def _compile_bwd_preprocess( + dtype, head_dim, head_dim_v, m_block_size, has_cuseqlens_q, has_seqused_q, has_dlse, +): + """Compile bwd preprocess kernel using cute fake tensors (no real GPU tensors needed).""" + mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors( + dtype, has_gqa=True, varlen_q=has_cuseqlens_q, varlen_k=False + ) + batch = mQ.shape[0] if not has_cuseqlens_q else cute.sym_int() + batchp1 = cute.sym_int() + mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None + mSequsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None + mdLSE = fake_tensor(Float32, mLSE.shape, divisibility=1) if has_dlse else None + fa_bwd_pre = FlashAttentionBackwardPreprocess(dtype, head_dim, head_dim_v, m_block_size) + return cute.compile( + fa_bwd_pre, mO, mdO, mPdPsum, mLSE, mLSElog2, mdQaccum, mCuSeqlensQ, mSequsedQ, mdLSE, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + +def _bwd_preprocess( + out, dout, dpsum, lse, lse_log2, dq_accum, + cu_seqlens_q, seqused_q, dlse, + dtype, head_dim, head_dim_v, m_block_size, +): + """Backward preprocess: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum.""" + is_varlen = cu_seqlens_q is not None + compile_key = ( + dtype, head_dim, head_dim_v, m_block_size, is_varlen, seqused_q is not None, dlse is not None, + ) + if compile_key not in _bwd_preprocess.compile_cache: + _bwd_preprocess.compile_cache[compile_key] = _compile_bwd_preprocess(*compile_key) + if not is_fake_mode(): + _bwd_preprocess.compile_cache[compile_key]( + out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse + ) + + +_bwd_preprocess.compile_cache = get_jit_cache("bwd_pre") + + +def _compile_bwd_postprocess( + dtype, hdim, block_size, num_threads, atom_layout, swap_ab, + has_cuseqlens_q, has_seqused_q, + use_2cta_instrs, cluster_size, arch, +): + """Compile bwd postprocess kernel using cute fake tensors.""" + mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors( + dtype, has_gqa=True, varlen_q=has_cuseqlens_q, varlen_k=False + ) + batch = mQ.shape[0] if not has_cuseqlens_q else cute.sym_int() + batchp1 = cute.sym_int() + mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None + mSeqUsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None + fa_bwd_post = FlashAttentionBackwardPostprocess( + dtype, hdim, arch, block_size, num_threads, atom_layout, swap_ab, + use_2cta_instrs=use_2cta_instrs, + cluster_size=cluster_size, + ) + return cute.compile( + fa_bwd_post, mdQaccum, mdQ, Float32(0.0), mCuSeqlensQ, mSeqUsedQ, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + +def _bwd_postprocess_convert( + accum, output, scale, + cu_seqlens, seqused, + arch, dtype, hdim, block_size, num_threads, + atom_layout, swap_ab, + use_2cta_instrs=False, cluster_size=1, +): + """Backward postprocess: convert float32 accumulator to bf16/fp16 output.""" + compile_key = ( + dtype, hdim, block_size, num_threads, atom_layout, swap_ab, + cu_seqlens is not None, seqused is not None, + use_2cta_instrs, cluster_size, arch, + ) + if compile_key not in _bwd_postprocess_convert.compile_cache: + _bwd_postprocess_convert.compile_cache[compile_key] = _compile_bwd_postprocess(*compile_key) + if not is_fake_mode(): + _bwd_postprocess_convert.compile_cache[compile_key]( + accum, output, scale, cu_seqlens, seqused, + ) + + +_bwd_postprocess_convert.compile_cache = get_jit_cache("bwd_post") def _flash_attn_bwd( @@ -592,31 +1005,74 @@ def _flash_attn_bwd( mask_mod: Optional[Callable] = None, aux_tensors: Optional[list[torch.Tensor]] = None, block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, + dlse: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - compute_capability = _get_device_capability() - assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x" + arch = _get_device_arch() + assert arch // 10 in [9, 10, 11, 12], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x, 12.x" + sparse_q = None + if block_sparse_tensors is not None and arch // 10 == 9: + sparse_q = block_sparse_tensors.block_size[0] if block_sparse_tensors.block_size is not None else 128 - if compute_capability == 9: - m_block_size = 80 if not causal else 64 - n_block_size = 128 - num_stages_Q = 2 - num_stages_dO = 2 - num_stages_PdS = 2 - SdP_swapAB = True + num_head, head_dim = q.shape[-2:] + head_dim_v = v.shape[-1] + + causal, local, window_size_left, window_size_right = _resolve_causal_local_window( + causal, window_size_left, window_size_right + ) + + if arch // 10 == 12: + # SM120: uses SM80 MMA with 99 KB SMEM, 128 threads (4 warps). + m_block_size = 64 + n_block_size = 64 + if head_dim <= 64: + num_stages_Q = 2 + num_stages_dO = 2 + else: + num_stages_Q = 1 + num_stages_dO = 1 + SdP_swapAB = False dKV_swapAB = False - dQ_swapAB = not causal - AtomLayoutMSdP = 1 - AtomLayoutNdKV = 2 - AtomLayoutMdQ = 1 + dQ_swapAB = False + AtomLayoutMSdP = 4 + AtomLayoutNdKV = 4 + AtomLayoutMdQ = 4 + V_in_regs = False + cluster_size = 1 + use_2cta_instrs = False + num_threads = 128 + assert not (block_sparse_tensors is not None), "Block sparsity backward not supported on SM 12.0" + assert score_mod is None and score_mod_bwd is None, "score_mod backward not supported on SM 12.0" + assert mask_mod is None, "mask_mod backward not supported on SM 12.0" + assert deterministic is False, "deterministic backward not supported on SM 12.0" + elif arch // 10 == 9: + cfg = _tile_size_bwd_sm90( + head_dim, + head_dim_v, + causal, + local, + sparse_block_size_q=sparse_q, + ) + m_block_size = cfg.m_block_size + n_block_size = cfg.n_block_size + num_stages_Q = cfg.num_stages_Q + num_stages_dO = cfg.num_stages_dO + num_stages_PdS = cfg.num_stages_PdS + SdP_swapAB = cfg.SdP_swapAB + dKV_swapAB = cfg.dKV_swapAB + dQ_swapAB = cfg.dQ_swapAB + AtomLayoutMSdP = cfg.AtomLayoutMSdP + AtomLayoutNdKV = cfg.AtomLayoutNdKV + AtomLayoutMdQ = cfg.AtomLayoutMdQ + num_threads = (cfg.num_wg + 1) * 128 + dQ_single_wg = cfg.dQ_single_wg cluster_size = 1 - assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x" + use_2cta_instrs = False is_varlen = ( cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None ) - assert not is_varlen, "varlen backward is not yet supported on sm90" else: m_block_size = 128 n_block_size = 128 @@ -624,13 +1080,20 @@ def _flash_attn_bwd( dKV_swapAB = False AtomLayoutMdQ = 1 AtomLayoutNdKV = 1 - # TODO: support cluster size 2 - cluster_size = 1 + requested_disable_2cta = utils._get_disable_2cta_default() + disable_2cta = ( + requested_disable_2cta + or score_mod is not None + or score_mod_bwd is not None + or mask_mod is not None + ) + cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1 + use_2cta_instrs = cluster_size==2 + q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ maybe_contiguous(t) for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ] - num_head, head_dim = q.shape[-2:] if cu_seqlens_q is None: batch_size, seqlen_q = q.shape[:2] total_q = batch_size * seqlen_q @@ -648,31 +1111,14 @@ def _flash_attn_bwd( seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k num_head_kv = k.shape[-2] - head_dim_v = v.shape[-1] - - if causal: - window_size_right = 0 - local = window_size_left is not None or window_size_right is not None - if local: - if window_size_left is None and window_size_right == 0: - causal, local = True, False - window_size_right = None - else: - causal, local = False, True use_block_sparsity = block_sparse_tensors is not None - - # SM90 block-sparse backward: tile_m=64 is the GCD between a m_block_size that fits, - # the base block_m of 128 from forward, and block-sparse size for subtiling. - if compute_capability == 9 and use_block_sparsity: - m_block_size = 64 - # dQ_swapAB tuning: use False when m_block_size=64 (same as causal case) - dQ_swapAB = False - - # NB: this could be derived from the block_sparse_tensors but for now we hardcode it to 2 - subtile_factor = 2 + subtile_factor = sparse_q // m_block_size if sparse_q is not None else 2 seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size + num_n_blocks = seqlen_k_rounded // n_block_size + if cluster_size == 2 and num_n_blocks % cluster_size != 0: + seqlen_k_rounded = seqlen_k_rounded + n_block_size if cu_seqlens_k is None: assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) @@ -707,14 +1153,16 @@ def _flash_attn_bwd( if t is not None: assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32" assert lse.dtype == torch.float32, "lse must be float32" - assert all( - t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k) - ), "inputs must be on CUDA device" + if dlse is not None: + dlse = maybe_contiguous(dlse) + if not is_fake_mode(): + assert all( + t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k) + ), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" - assert head_dim <= 256, "head_dim must be less than or equal to 256" alignment = 16 // q.element_size() - assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" - assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" + if arch // 10 != 12: + _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment) if softmax_scale is None: softmax_scale = 1.0 / math.sqrt(head_dim) qhead_per_kvhead = num_head // num_head_kv @@ -722,9 +1170,6 @@ def _flash_attn_bwd( pack_gqa = qhead_per_kvhead > 1 # pack_gqa backward not yet supported in bwd pack_gqa = False - if compute_capability not in [10, 11]: - assert deterministic is False, "bwd deterministic only supported for sm100/sm110 for now" - if score_mod is not None: assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided" assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)" @@ -776,13 +1221,13 @@ def _flash_attn_bwd( dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) + # GQA (qhead_per_kvhead > 1) needs dK/dV accum+postprocess since multiple Q heads + # accumulate into the same dK/dV. SM90 varlen_k with qhead_per_kvhead==1 now uses + # ragged TMA tensors for direct store, so no longer needs accum+postprocess. dKV_postprocess = qhead_per_kvhead > 1 if dKV_postprocess: head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 if cu_seqlens_k is None: - num_n_blocks = seqlen_k_rounded // n_block_size - if cluster_size == 2 and num_n_blocks % cluster_size != 0: - seqlen_k_rounded = seqlen_k_rounded + n_block_size dk_accum = torch.zeros( batch_size, num_head_kv, @@ -798,12 +1243,10 @@ def _flash_attn_bwd( device=device, ) else: + cluster_tile_n = cluster_size * n_block_size total_k_rounded_padded = ( - (total_k + cu_seqlens_k.shape[0] * n_block_size - 1) // n_block_size * n_block_size + (total_k + cu_seqlens_k.shape[0] * cluster_tile_n - 1) // cluster_tile_n * cluster_tile_n ) - num_n_blocks = total_k_rounded_padded // n_block_size - if cluster_size == 2 and num_n_blocks % cluster_size != 0: - total_k_rounded_padded = total_k_rounded_padded + n_block_size dk_accum = torch.zeros( num_head_kv, total_k_rounded_padded * head_dim_rounded, @@ -818,79 +1261,30 @@ def _flash_attn_bwd( ) dtype = torch2cute_dtype_map[q.dtype] - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) if deterministic: - dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, 1, dtype=torch.int32, device="cuda") + dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, cluster_size, dtype=torch.int32, device=device) else: dQ_semaphore = None if deterministic and qhead_per_kvhead > 1: - dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") - dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") + dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=device) + dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=device) else: dK_semaphore = None dV_semaphore = None - # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum. - compile_key_pre = ( - compute_capability, - dtype, - head_dim_v, - m_block_size, - num_threads, - cu_seqlens_q is None, - seqused_q is None, + # Preprocess kernel: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum. + _bwd_preprocess( + out, dout, dpsum, lse, lse_log2, dq_accum, + cu_seqlens_q, seqused_q, dlse, + dtype, head_dim, head_dim_v, m_block_size, ) - if compile_key_pre not in _flash_attn_bwd.compile_cache_pre: - o_tensor, do_tensor = [to_cute_tensor(t) for t in (out, dout)] - dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ - to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2) - ] - lse_tensor = to_cute_tensor(lse, assumed_align=4) - cu_seqlens_q_tensor, seqused_q_tensor = [ - to_cute_tensor(t, assumed_align=4) if t is not None else None - for t in (cu_seqlens_q, seqused_q) - ] - arch = compute_capability * 10 - fa_bwd_pre = FlashAttentionBackwardPreprocess( - dtype, - head_dim_v, - arch, - m_block_size, - num_threads=num_threads, - ) - # TODO: check @can_implement - _flash_attn_bwd.compile_cache_pre[compile_key_pre] = cute.compile( - fa_bwd_pre, - o_tensor, - do_tensor, - dpsum_tensor, - lse_tensor, - lse_log2_tensor, - dq_accum_tensor, - cu_seqlens_q_tensor, - seqused_q_tensor, - current_stream, - options="--enable-tvm-ffi", - ) - _flash_attn_bwd.compile_cache_pre[compile_key_pre]( - out, - dout, - dpsum, - lse, - lse_log2, - dq_accum, - cu_seqlens_q, - seqused_q, - current_stream, - ) - - # NB num_threads application for 3 kernels - # There are pre, main, post processing kernels, currenlty num_threads is only actually - # used for the pre proc, and then we hard code to 384 for the main and post proc, and we do - # before cache key gen - num_threads = 384 + # num_threads: SM90 derives from BwdConfig.num_wg, SM120 is set to 128 above, + # SM100/SM110 uses default from function signature (384). + if arch // 10 not in [9, 12]: + num_threads = 384 # Backward kernel: compute dk, dv, dq_accum. score_mod_hash = utils.hash_callable(score_mod) if score_mod else False @@ -917,14 +1311,16 @@ def _flash_attn_bwd( subtile_factor=subtile_factor, ) - if compute_capability == 9: + if arch // 10 in [8, 9, 12]: compile_key = ( - compute_capability, + arch, dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, + window_size_left is not None, + window_size_right is not None, softcap != 0.0, m_block_size, n_block_size, @@ -939,6 +1335,8 @@ def _flash_attn_bwd( AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs, + dQ_single_wg, + deterministic, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, @@ -949,10 +1347,14 @@ def _flash_attn_bwd( num_aux_tensors, use_block_sparsity, block_sparse_broadcast_pattern, + get_broadcast_dims(q), + get_broadcast_dims(k), + get_broadcast_dims(v), + get_broadcast_dims(dout), ) else: compile_key = ( - compute_capability, + arch, dtype, head_dim, head_dim_v, @@ -966,6 +1368,7 @@ def _flash_attn_bwd( num_threads, pack_gqa, cluster_size, + use_2cta_instrs, deterministic, score_mod_hash, score_mod_bwd_hash, @@ -977,6 +1380,10 @@ def _flash_attn_bwd( cu_seqlens_k is None, seqused_q is None, seqused_k is None, + get_broadcast_dims(q), + get_broadcast_dims(k), + get_broadcast_dims(v), + get_broadcast_dims(dout), ) if compile_key not in _flash_attn_bwd.compile_cache: q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ @@ -998,51 +1405,56 @@ def _flash_attn_bwd( if t is not None else None for t in (dQ_semaphore, dK_semaphore, dV_semaphore) ] - fa_bwd_sm80 = FlashAttentionBackwardSm80( - dtype, - head_dim, - head_dim_v, - qhead_per_kvhead, - m_block_size, - n_block_size, - num_stages_Q, - num_stages_dO, - num_threads, - pack_gqa, - causal, - SdP_swapAB, - dKV_swapAB, - dQ_swapAB, - AtomLayoutMSdP, - AtomLayoutNdKV, - AtomLayoutMdQ, - V_in_regs=V_in_regs, - ) - if compute_capability == 9: - fa_bwd_obj = FlashAttentionBackwardSm90( + if arch // 10 in [8, 12]: + flash_bwd_obj_cls = FlashAttentionBackwardSm120 if arch // 10 == 12 else FlashAttentionBackwardSm80 + fa_bwd_obj = flash_bwd_obj_cls( dtype, head_dim, head_dim_v, qhead_per_kvhead, - causal, m_block_size, n_block_size, num_stages_Q, num_stages_dO, - num_stages_PdS, + num_threads, + pack_gqa, + causal, SdP_swapAB, dKV_swapAB, dQ_swapAB, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, - num_threads, + V_in_regs=V_in_regs, + ) + elif arch // 10 == 9: + fa_bwd_obj = FlashAttentionBackwardSm90( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + causal, + is_local=local, + deterministic=deterministic, + tile_m=m_block_size, + tile_n=n_block_size, + Q_stage=num_stages_Q, + dO_stage=num_stages_dO, + PdS_stage=num_stages_PdS, + SdP_swapAB=SdP_swapAB, + dKV_swapAB=dKV_swapAB, + dQ_swapAB=dQ_swapAB, + AtomLayoutMSdP=AtomLayoutMSdP, + AtomLayoutNdKV=AtomLayoutNdKV, + AtomLayoutMdQ=AtomLayoutMdQ, + num_threads=num_threads, V_in_regs=V_in_regs, score_mod=score_mod, score_mod_bwd=score_mod_bwd, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, subtile_factor=subtile_factor, + dQ_single_wg=dQ_single_wg, ) else: fa_bwd_obj = FlashAttentionBackwardSm100( @@ -1051,10 +1463,10 @@ def _flash_attn_bwd( is_causal=causal, is_local=local, qhead_per_kvhead=qhead_per_kvhead, - # tile_m=m_block_size, - # tile_n=n_block_size, + tile_m=m_block_size, + tile_n=n_block_size, cluster_size=cluster_size, - # cluster_size=1, + use_2cta_instrs=use_2cta_instrs, deterministic=deterministic, score_mod=score_mod, score_mod_bwd=score_mod_bwd, @@ -1081,7 +1493,6 @@ def _flash_attn_bwd( dk_tensor if not dKV_postprocess else dk_accum_tensor, dv_tensor if not dKV_postprocess else dv_accum_tensor, softmax_scale, - current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, @@ -1094,169 +1505,74 @@ def _flash_attn_bwd( dV_semaphore_tensor, cute_aux_tensors, sparse_tensors_compile, - options="--enable-tvm-ffi", - ) - _flash_attn_bwd.compile_cache[compile_key]( - q.detach(), - k.detach(), - v.detach(), - dout, - lse_log2, - dpsum, - dq_accum, - dk if not dKV_postprocess else dk_accum, - dv if not dKV_postprocess else dv_accum, - softmax_scale, - current_stream, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - None, # softcap - not yet supported in backward - window_size_left, - window_size_right, - dQ_semaphore, - dK_semaphore, - dV_semaphore, - aux_tensors, - normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, - ) - - num_threads = 256 if compute_capability == 9 else 128 - arch = compute_capability * 10 - # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16 - compile_key_post = ( - compute_capability, - dtype, - head_dim, - m_block_size, - num_threads, - AtomLayoutMdQ, - dQ_swapAB, - cu_seqlens_q is None, - seqused_q is None, - ) - if compile_key_post not in _flash_attn_bwd.compile_cache_post: - dq_accum_tensor = to_cute_tensor(dq_accum) - dq_tensor = to_cute_tensor(dq) - cu_seqlens_q_tensor, seqused_q_tensor = [ - to_cute_tensor(t, assumed_align=4) if t is not None else None - for t in (cu_seqlens_q, seqused_q) - ] - fa_bwd_post = FlashAttentionBackwardPostprocess( - dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB - ) - # TODO: check @can_implement - _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, - dq_accum_tensor, - dq_tensor, - softmax_scale, - cu_seqlens_q_tensor, - seqused_q_tensor, current_stream, options="--enable-tvm-ffi", ) - _flash_attn_bwd.compile_cache_post[compile_key_post]( - dq_accum, - dq, - softmax_scale, - cu_seqlens_q, - seqused_q, - current_stream, - ) - - if dKV_postprocess: - # Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16 - compile_key_post = ( - compute_capability, - dtype, - head_dim, - n_block_size, - num_threads, - AtomLayoutNdKV, - dKV_swapAB, - cu_seqlens_k is None, - seqused_k is None, - ) - if compile_key_post not in _flash_attn_bwd.compile_cache_post: - dk_accum_tensor = to_cute_tensor(dk_accum) - dk_tensor = to_cute_tensor(dk) - cu_seqlens_k_tensor, seqused_k_tensor = [ - to_cute_tensor(t, assumed_align=4) if t is not None else None - for t in (cu_seqlens_k, seqused_k) - ] - arch = compute_capability * 10 - fa_bwd_post = FlashAttentionBackwardPostprocess( - dtype, head_dim, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB - ) - # TODO: check @can_implement - _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, - dk_accum_tensor, - dk_tensor, - softmax_scale, - cu_seqlens_k_tensor, - seqused_k_tensor, - current_stream, - options="--enable-tvm-ffi", - ) - _flash_attn_bwd.compile_cache_post[compile_key_post]( - dk_accum, - dk, + if not is_fake_mode(): + _flash_attn_bwd.compile_cache[compile_key]( + q.detach(), + k.detach(), + v.detach(), + dout, + lse_log2, + dpsum, + dq_accum, + dk if not dKV_postprocess else dk_accum, + dv if not dKV_postprocess else dv_accum, softmax_scale, + cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, - current_stream, + None, # softcap - not yet supported in backward + window_size_left, + window_size_right, + dQ_semaphore, + dK_semaphore, + dV_semaphore, + aux_tensors, + normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, ) - compile_key_post = ( - compute_capability, - dtype, - head_dim_v, - n_block_size, - num_threads, - AtomLayoutNdKV, - dKV_swapAB, - cu_seqlens_k is None, - seqused_k is None, + + if arch // 10 == 9: + # dQ postprocess: match main kernel's MMA WG count, unless dQ_single_wg + num_threads_post_dQ = 128 if dQ_single_wg else cfg.num_wg * 128 + num_threads_post_dKV = cfg.num_wg * 128 + else: + num_threads_post_dQ = 128 + num_threads_post_dKV = 128 + + # Postprocess: convert dq_accum from float32 to dq in bf16/fp16 + _bwd_postprocess_convert( + dq_accum, dq, softmax_scale, + cu_seqlens_q, seqused_q, + arch, dtype, head_dim, m_block_size, num_threads_post_dQ, + AtomLayoutMdQ, dQ_swapAB, + use_2cta_instrs=use_2cta_instrs, cluster_size=1, + ) + + if dKV_postprocess: + # Postprocess: convert dk_accum from float32 to dk in bf16/fp16 + _bwd_postprocess_convert( + dk_accum, dk, softmax_scale, + cu_seqlens_k, seqused_k, + arch, dtype, head_dim, n_block_size, num_threads_post_dKV, + AtomLayoutNdKV, dKV_swapAB, + cluster_size=cluster_size, ) - if compile_key_post not in _flash_attn_bwd.compile_cache_post: - dv_accum_tensor = to_cute_tensor(dv_accum) - dv_tensor = to_cute_tensor(dv) - cu_seqlens_k_tensor, seqused_k_tensor = [ - to_cute_tensor(t, assumed_align=4) if t is not None else None - for t in (cu_seqlens_k, seqused_k) - ] - arch = compute_capability * 10 - fa_bwd_post = FlashAttentionBackwardPostprocess( - dtype, head_dim_v, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB - ) - # TODO: check @can_implement - _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, - dv_accum_tensor, - dv_tensor, - cutlass.Float32(1.0), - cu_seqlens_k_tensor, - seqused_k_tensor, - current_stream, - options="--enable-tvm-ffi", - ) - _flash_attn_bwd.compile_cache_post[compile_key_post]( - dv_accum, - dv, - 1.0, - cu_seqlens_k, - seqused_k, - current_stream, + # Postprocess: convert dv_accum from float32 to dv in bf16/fp16 + _bwd_postprocess_convert( + dv_accum, dv, 1.0, + cu_seqlens_k, seqused_k, + arch, dtype, head_dim_v, n_block_size, num_threads_post_dKV, + AtomLayoutNdKV, dKV_swapAB, + cluster_size=cluster_size, ) return dq, dk, dv -_flash_attn_bwd.compile_cache_pre = {} -_flash_attn_bwd.compile_cache = {} -_flash_attn_bwd.compile_cache_post = {} +_flash_attn_bwd.compile_cache = get_jit_cache("bwd") class FlashAttnFunc(torch.autograd.Function): @@ -1280,6 +1596,7 @@ def forward( mask_block_cnt: Optional[torch.Tensor] = None, mask_block_idx: Optional[torch.Tensor] = None, block_size: Optional[Tuple[int, int]] = None, + return_lse: bool = False, ): # Only create block sparse tensors if at least one block sparse parameter is provided block_sparse_tensors = None @@ -1304,7 +1621,8 @@ def forward( num_splits=num_splits, pack_gqa=pack_gqa, mask_mod=mask_mod, - block_sparse_tensors=block_sparse_tensors + block_sparse_tensors=block_sparse_tensors, + return_lse=return_lse, ) ctx.save_for_backward(q, k, v, out, lse) ctx.softmax_scale = softmax_scale @@ -1312,11 +1630,17 @@ def forward( ctx.window_size = window_size ctx.softcap = softcap ctx.deterministic = deterministic + ctx.return_lse = return_lse + ctx.set_materialize_grads(False) return out, lse @staticmethod - def backward(ctx, dout, *args): + def backward(ctx, dout, dlse): q, k, v, out, lse = ctx.saved_tensors + if not ctx.return_lse: + dlse = None + if dout is None: + dout = torch.zeros_like(out) dq, dk, dv = _flash_attn_bwd( q, k, @@ -1330,6 +1654,7 @@ def backward(ctx, dout, *args): window_size_left=ctx.window_size[0], window_size_right=ctx.window_size[1], deterministic=ctx.deterministic, + dlse=dlse, ) return dq, dk, dv, *((None,) * 20) # Extra Nones is fine @@ -1358,6 +1683,7 @@ def forward( deterministic: bool = False, score_mod: Optional[Callable] = None, aux_tensors: Optional[list] = None, + return_lse: bool = False, ): out, lse = _flash_attn_fwd( q, @@ -1380,6 +1706,7 @@ def forward( pack_gqa=pack_gqa, score_mod=score_mod, aux_tensors=aux_tensors, + return_lse=return_lse, ) ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ctx.softmax_scale = softmax_scale @@ -1389,12 +1716,18 @@ def forward( ctx.deterministic = deterministic ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k + ctx.return_lse = return_lse + ctx.set_materialize_grads(False) return out, lse @staticmethod - def backward(ctx, dout, *args): + def backward(ctx, dout, dlse): q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors assert ctx.softcap == 0.0 + if not ctx.return_lse: + dlse = None + if dout is None: + dout = torch.zeros_like(out) dq, dk, dv = _flash_attn_bwd( q, k, @@ -1414,6 +1747,7 @@ def backward(ctx, dout, *args): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_k=ctx.max_seqlen_k, deterministic=ctx.deterministic, + dlse=dlse, ) return dq, dk, dv, *((None,) * 20) @@ -1437,6 +1771,7 @@ def flash_attn_func( mask_block_cnt: Optional[torch.Tensor] = None, mask_block_idx: Optional[torch.Tensor] = None, block_size: Optional[Tuple[int, int]] = None, + return_lse: bool = False, ): return FlashAttnFunc.apply( q, @@ -1456,6 +1791,7 @@ def flash_attn_func( mask_block_cnt, mask_block_idx, block_size, + return_lse, ) @@ -1480,6 +1816,7 @@ def flash_attn_varlen_func( deterministic: bool = False, score_mod: Optional[Callable] = None, aux_tensors: Optional[list] = None, + return_lse: bool = False, ): return FlashAttnVarlenFunc.apply( q, @@ -1502,6 +1839,64 @@ def flash_attn_varlen_func( deterministic, score_mod, aux_tensors, + return_lse, + ) + + +def _compile_fwd_combine( + dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits, + has_cu_seqlens, has_seqused, has_lse, has_varlen_batch_idx, +): + """Compile fwd combine kernel using cute fake tensors (no real GPU tensors needed).""" + sym = cute.sym_int + div = 128 // dtype_partial.width # 16-byte alignment in elements + + fa_combine = FlashAttentionForwardCombine( + dtype=dtype, + dtype_partial=dtype_partial, + head_dim=head_dim, + tile_m=tile_m, + k_block_size=k_block_size, + log_max_splits=log_max_splits, + ) + if not fa_combine.can_implement( + dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits, + num_threads=256, + ): + raise RuntimeError( + "FlashAttention combine kernel cannot be implemented with given parameters" + ) + + if has_cu_seqlens: + # Varlen: (num_splits, total_q, nheads, headdim) + num_splits, total_q, nheads = sym(), sym(), sym() + mO_partial = fake_tensor(dtype_partial, (num_splits, total_q, nheads, head_dim), divisibility=div) + mLSE_partial = fake_tensor(Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=1) + mO = fake_tensor(dtype, (total_q, nheads, head_dim), divisibility=div) + mLSE = fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=0) if has_lse else None + else: + # Batched: (num_splits, batch, seqlen, nheads, headdim) + num_splits, batch, seqlen, nheads = sym(), sym(), sym(), sym() + mO_partial = fake_tensor(dtype_partial, (num_splits, batch, seqlen, nheads, head_dim), divisibility=div) + mLSE_partial = fake_tensor(Float32, (num_splits, batch, seqlen, nheads), divisibility=1, leading_dim=2) + mO = fake_tensor(dtype, (batch, seqlen, nheads, head_dim), divisibility=div) + mLSE = fake_tensor(Float32, (batch, seqlen, nheads), divisibility=1, leading_dim=1) if has_lse else None + batch = mO_partial.shape[1] + + batch_for_1d = batch if not has_cu_seqlens else sym() + batchp1 = sym() + mCuSeqlens = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cu_seqlens else None + mSeqused = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_seqused else None + mNumSplitsDynamic = None # Not parametrized in compile_key + mVarlenBatchIdx = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_varlen_batch_idx else None + mSemaphore = None # Not parametrized in compile_key + + return cute.compile( + fa_combine, + mO_partial, mLSE_partial, mO, mLSE, + mCuSeqlens, mSeqused, mNumSplitsDynamic, mVarlenBatchIdx, mSemaphore, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", ) @@ -1513,6 +1908,7 @@ def _flash_attn_fwd_combine( cu_seqlens: Optional[torch.Tensor] = None, seqused: Optional[torch.Tensor] = None, num_splits_dynamic_ptr: Optional[torch.Tensor] = None, + varlen_batch_idx: Optional[torch.Tensor] = None, semaphore_to_reset: Optional[torch.Tensor] = None, ) -> None: """Forward combine kernel for split attention computation. @@ -1536,27 +1932,13 @@ def _flash_attn_fwd_combine( Returns: None """ - # Input validation - assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" - assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions" assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], ( "out_partial must be fp16, bf16, or fp32" ) - assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" - assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device" - assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension" - assert lse_partial.stride(-2) == 1, "lse_partial must be contiguous in the seqlen dimension" - assert lse_partial.shape == out_partial.shape[:-1] - + if not is_fake_mode(): + assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device" # Determine if this is variable length based on dimensions is_varlen = out_partial.dim() == 4 - - # Validate output tensor shapes and types - assert out.shape == out_partial.shape[1:], "out shape mismatch" - if lse is not None: - assert lse.shape == lse_partial.shape[1:], "lse shape mismatch" - assert lse.dtype == torch.float32, "lse must be fp32" - # Validate optional tensors for t, name in [ (cu_seqlens, "cu_seqlens"), @@ -1564,10 +1946,9 @@ def _flash_attn_fwd_combine( (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"), ]: if t is not None: - assert t.dtype == torch.int32, f"{name} must be int32" - assert t.is_cuda, f"{name} must be on CUDA device" + if not is_fake_mode(): + assert t.is_cuda, f"{name} must be on CUDA device" assert t.is_contiguous(), f"{name} must be contiguous" - head_dim = out_partial.shape[-1] num_splits = out_partial.shape[0] assert num_splits <= 256 @@ -1576,104 +1957,41 @@ def _flash_attn_fwd_combine( k_block_size = 64 if head_dim <= 64 else 128 # We want kBlockM to be as small as possible to maximize parallelism. # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). - m_block_size = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32) + tile_m = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32) log_max_splits = max(math.ceil(math.log2(num_splits)), 4) - if m_block_size == 8: + if tile_m == 8: # If kBlockM == 8 then the minimum number of splits is 32. # TODO: we can deal w this by using 128 threads instead log_max_splits = max(log_max_splits, 5) - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - # Create combine kernel configuration dtype = torch2cute_dtype_map[out.dtype] dtype_partial = torch2cute_dtype_map[out_partial.dtype] - compile_key = ( dtype, dtype_partial, head_dim, - m_block_size, + tile_m, k_block_size, log_max_splits, cu_seqlens is not None, seqused is not None, lse is not None, + varlen_batch_idx is not None, ) - if compile_key not in _flash_attn_fwd_combine.compile_cache: - out_partial_tensor = to_cute_tensor( - out_partial, leading_dim=4 if not is_varlen else 3 + _flash_attn_fwd_combine.compile_cache[compile_key] = _compile_fwd_combine( + *compile_key ) - lse_partial_tensor = to_cute_tensor( - lse_partial, assumed_align=4, leading_dim=lse_partial.ndim - 2 - ) - out_tensor = to_cute_tensor(out, leading_dim=3 if not is_varlen else 2) - lse_tensor = ( - to_cute_tensor(lse, assumed_align=4, leading_dim=lse.ndim - 2) - if lse is not None - else None - ) - - optional_tensors = [ - to_cute_tensor(t, assumed_align=4, leading_dim=0) - if t is not None - else None - for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset) - ] - cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = ( - optional_tensors - ) - fa_combine = FlashAttentionForwardCombine( - dtype=dtype, - dtype_partial=dtype_partial, - head_dim=head_dim, - m_block_size=m_block_size, - k_block_size=k_block_size, - log_max_splits=log_max_splits, - ) - - # Check if implementation is supported - if not fa_combine.can_implement( - dtype, - dtype_partial, - head_dim, - m_block_size, - k_block_size, - log_max_splits, - num_threads=256, - ): - raise RuntimeError( - "FlashAttention combine kernel cannot be implemented with given parameters" - ) - - _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile( - fa_combine, - out_partial_tensor, - lse_partial_tensor, - out_tensor, - lse_tensor, - cu_seqlens_tensor, - seqused_tensor, - num_splits_dynamic_tensor, - semaphore_tensor, - current_stream, - options="--enable-tvm-ffi", + if not is_fake_mode(): + _flash_attn_fwd_combine.compile_cache[compile_key]( + out_partial, lse_partial, out, lse, + cu_seqlens, seqused, num_splits_dynamic_ptr, varlen_batch_idx, + semaphore_to_reset, ) - _flash_attn_fwd_combine.compile_cache[compile_key]( - out_partial, - lse_partial, - out, - lse, - cu_seqlens, - seqused, - num_splits_dynamic_ptr, - semaphore_to_reset, - current_stream, - ) -_flash_attn_fwd_combine.compile_cache = {} +_flash_attn_fwd_combine.compile_cache = get_jit_cache("fwd_combine") def flash_attn_combine( @@ -1683,6 +2001,7 @@ def flash_attn_combine( out_dtype: Optional[torch.dtype] = None, cu_seqlens: Optional[torch.Tensor] = None, seqused: Optional[torch.Tensor] = None, + varlen_batch_idx: Optional[torch.Tensor] = None, return_lse: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Flash Attention combine function for split attention computation. @@ -1702,6 +2021,9 @@ def flash_attn_combine( out_dtype: Optional output dtype. If None, will use fp16/bf16 based on input. cu_seqlens: Cumulative sequence lengths for variable length sequences seqused: Used sequence lengths for each batch + varlen_batch_idx: Optional mapping from virtual batch index to real batch index + (int32 tensor of shape (batch_size,)). Used by persistent tile schedulers + that reorder batch processing for load balancing. return_lse: Whether to return the combined LSE tensor. Default is True. Returns: @@ -1718,32 +2040,19 @@ def flash_attn_combine( """ # Input validation assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" - assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions" - assert out_partial.dtype == torch.float32, "out_partial must be fp32 (from accumulation)" - assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" - # Determine if this is variable length based on dimensions is_varlen = out_partial.dim() == 4 - if is_varlen: # Variable length: (num_splits, total_q, num_heads, head_size) num_splits, total_q, num_heads, head_size = out_partial.shape - assert lse_partial.shape == (num_splits, total_q, num_heads), ( - "lse_partial shape mismatch for varlen" - ) batch_size = 1 # Treat as single batch for varlen seqlen = total_q else: # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size) num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape - assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), ( - "lse_partial shape mismatch" - ) - # Determine output dtype if out_dtype is None: out_dtype = out_partial.dtype - # Create output if not provided device = out_partial.device if out is None: @@ -1753,20 +2062,15 @@ def flash_attn_combine( out = torch.empty( batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device ) - # Create lse output only if requested if return_lse: if is_varlen: - lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device).transpose( - 0, 1 - ) + lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device) else: - lse = torch.empty( - batch_size, num_heads, seqlen, dtype=torch.float32, device=device - ).transpose(1, 2) + lse = torch.empty(batch_size, num_heads, seqlen, dtype=torch.float32, device=device) + lse = lse.transpose(-1, -2) else: lse = None - _flash_attn_fwd_combine( out_partial, lse_partial, @@ -1774,5 +2078,6 @@ def flash_attn_combine( lse, cu_seqlens, seqused, + varlen_batch_idx=varlen_batch_idx, ) return out, lse diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index f5e3c5f46f3..6b5ca16c6f5 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -1,109 +1,102 @@ # Copyright (c) 2025, Tri Dao. -from typing import Optional, Callable +from typing import Optional, Callable, TypeAlias from dataclasses import dataclass import cutlass import cutlass.cute as cute -from cutlass import Float32, Int32, const_expr +from cutlass import Float32, Int32, Uint32, const_expr from quack import layout_utils import flash_attn.cute.utils as utils from flash_attn.cute.seqlen_info import SeqlenInfoQK +MaskGenFn: TypeAlias = Callable[[int], Uint32] +MASK_R2P_CHUNK_SIZE: int = 32 + @cute.jit -def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = False) -> None: - # Bit manipulation, compiles down to the R2P instruction - # For sm100: we know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using. - # For sm90: instead of comparing limit to 0, 1, 8, 9, 16, 17, ..., - # we compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ... - if const_expr(arch == 90): - col_limit_transformed = col_limit // 8 * 2 + min(col_limit % 8, 2) - else: - col_limit_transformed = col_limit - ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape)) - # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31 - for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): - # Don't need to clamp to 32 since the shr.u32 instruction does that already - col_limit_right_s = max(col_limit_transformed - s * 24, 0) - # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 - mask = (1 << col_limit_right_s) - 1 - # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction - for i in cutlass.range_constexpr(min(24, ncol - s * 24)): - in_bound = cutlass.Boolean(mask & (1 << i)) - c = s * 24 + i - if const_expr(rank1): - X[c] = X[c] if in_bound else -Float32.inf - # This is the equivalent of: - # X[s * 24 + i] = X[s * 24 + i] if col_limit_right_s <= i else -Float32.inf - else: - for r in cutlass.range_constexpr(cute.size(X.shape[0])): - X[r, c] = X[r, c] if in_bound else -Float32.inf +def r2p_bitmask_below(limit: Int32, s: int) -> Uint32: + """32-bit R2P bitmask keeping positions < limit (exclusive upper bound). + + Positions 0..limit-1 in chunk `s` get bit=1 (keep), the rest bit=0 (mask). + Uses inline PTX to avoid shift-by-type-width UB. + """ + m = max((s + 1) * MASK_R2P_CHUNK_SIZE - limit, 0) + return utils.shr_u32(Uint32(0xFFFFFFFF), Uint32(m)) @cute.jit -def mask_r2p_transposed(X: cute.Tensor, row_limit_top: Int32, num_rep: int) -> None: - # Bit manipulation, compiles down to the R2P instruction - # For sm100: we know that tScS_t2r[i][0] has the form 0, 1, ..., 31, 64, ..., 127 - # or 0, 1, ..., 15, 32, ..., 47, 64, ... - # We compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ... - # Here we hardcode for the case of 2 warp groups. - num_wg = 2 - row_limit_top_transformed = row_limit_top // (num_rep * num_wg) * num_rep + min( - row_limit_top % (num_rep * num_wg), num_rep - ) - ncol = cute.size(X.shape) - # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31 - for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): - row_limit_top_s = max(row_limit_top_transformed - s * 24, 0) - # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 - mask = (1 << row_limit_top_s) - 1 - # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction - for i in cutlass.range_constexpr(min(24, ncol - s * 24)): - out_bound = cutlass.Boolean(mask & (1 << i)) - c = s * 24 + i - X[c] = -Float32.inf if out_bound else X[c] - # tidx = cute.arch.thread_idx()[0] % 256 - # if tidx == 128: - # cute.printf("tidx = {}, s = {}, i = {}, row_limit_top = {}, row_limit_top_s = {}, mask = {}, out_bound = {}", tidx, s, i, row_limit_top, row_limit_top_s, mask, out_bound) +def r2p_bitmask_above(limit: Int32, s: int) -> Uint32: + """32-bit R2P bitmask keeping positions >= limit (inclusive lower bound). + + Positions limit..31 in chunk `s` get bit=1 (keep), the rest bit=0 (mask). + Uses inline PTX to avoid shift-by-type-width UB. + """ + n = max(limit - s * MASK_R2P_CHUNK_SIZE, 0) + return utils.shl_u32(Uint32(0xFFFFFFFF), Uint32(n)) @cute.jit -def mask_r2p_dual_bound( +def mask_r2p_lambda( X: cute.Tensor, - col_limit_left: Int32, # Inclusive lower bound - col_limit_right: Int32, # Exclusive upper bound + mask_gen_fn: cutlass.Constexpr[MaskGenFn], + rank1: bool = False, ) -> None: - """ - Dual-bound masking using two bitmasks for SM100, following mask_r2p. - Masks elements where: NOT (col_limit_left <= col < col_limit_right) + """Apply R2P masking with a custom bitmask generator. - Uses bit manipulation to create a range mask: - mask_right = (1 << right) - 1 -> bits (right-1)..0 are 1 - mask_left = (1 << left) - 1 -> bits (left-1)..0 are 1 - mask_range = mask_range = mask_right & ~ mask_left -> bits (right-1)..left are 1 + mask_gen_fn(chunk_idx: constexpr int) -> Uint32: + Returns a 32-bit bitmask for the chunk. Bit i set means column + chunk_idx * chunk_size + i is KEPT; bit i clear means masked to -inf. """ - ncol = const_expr(cute.size(X.shape)) + ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape)) + # 32-column chunks. The mask_gen_fn returns a Uint32 bitmask (1=keep). + CHUNK_SIZE = MASK_R2P_CHUNK_SIZE + for s in cutlass.range_constexpr(cute.ceil_div(ncol, CHUNK_SIZE)): + mask = mask_gen_fn(s) + # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction + for i in cutlass.range_constexpr(min(CHUNK_SIZE, ncol - s * CHUNK_SIZE)): + in_bound = cutlass.Boolean(mask & (Uint32(1) << i)) + c = s * CHUNK_SIZE + i + if const_expr(rank1): + X[c] = X[c] if in_bound else -Float32.inf + else: + for r in cutlass.range_constexpr(cute.size(X.shape[0])): + X[r, c] = X[r, c] if in_bound else -Float32.inf - for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): - right_s = max(col_limit_right - s * 24, 0) - left_s = max(col_limit_left - s * 24, 0) - # otherwise cute dsl complains about python int too large to convert into c long - right_s = min(right_s, 24) - left_s = min(left_s, 24) +@cute.jit +def sm90_col_to_r2p_idx(col_limit: Int32) -> Int32: + """Transform SM90 MMA column coordinate to R2P element index. - # bits (right-1)..left are 1 - mask_right = (1 << right_s) - 1 - mask_left = (1 << left_s) - 1 - mask_range = mask_right & ~mask_left + SM90 MMA accumulator column indices are non-contiguous: 0, 1, 8, 9, 16, 17, ... + Element indices are contiguous: 0, 1, 2, 3, 4, 5, ... + This converts a column-space threshold to element-space for r2p_bitmask_below/above. + """ + return col_limit // 8 * 2 + min(col_limit % 8, 2) - # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction - for i in cutlass.range_constexpr(min(24, ncol - s * 24)): - in_bound = cutlass.Boolean(mask_range & (1 << i)) - c = s * 24 + i - X[c] = X[c] if in_bound else -Float32.inf + +@cute.jit +def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32: + """Convert a row coordinate to an R2P element index in the warp-group interleaved layout. + + In the SM100 backward pass, 2 warp groups share TMEM. The TMEM load atom + distributes rows in an interleaved pattern: elements 0..num_rep-1 map to + rows 0..num_rep-1 (warp group 0), elements num_rep..2*num_rep-1 map to + rows num_rep*num_wg..num_rep*num_wg+num_rep-1 (warp group 1), and so on. + Row-coordinate thresholds (causal limits, window bounds, uih_len) must be + converted to element indices before use with r2p_bitmask_above/below. + + Rows not owned by this thread (in the gap between warp groups) are clamped + to the boundary element index, which is safe because R2P thresholds are + monotonic. + + Example with num_rep=16, num_wg=2: + row 0 -> elem 0, row 15 -> elem 15, + row 16 -> elem 16 (clamped), row 31 -> elem 16 (clamped), + row 32 -> elem 16, row 33 -> elem 17, row 47 -> elem 31. + """ + return x // (num_rep * num_wg) * num_rep + min(x % (num_rep * num_wg), num_rep) @dataclass(frozen=True) @@ -161,8 +154,7 @@ def apply_mask( seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset if const_expr(not mask_causal and not mask_local and mask_mod is None): if const_expr(mask_seqlen): - # The compiler now choses not to use R2P - r2p = const_expr(False and not self.swap_AB) + r2p = const_expr(not self.swap_AB) if const_expr(not r2p): # traverse column index. for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): @@ -170,7 +162,8 @@ def apply_mask( for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c] else: - mask_r2p(acc_S_mn, seqlenk_col_limit, arch=90) + seqlenk_col_limit_r2p = sm90_col_to_r2p_idx(seqlenk_col_limit) + mask_r2p_lambda(acc_S_mn, lambda s: r2p_bitmask_below(seqlenk_col_limit_r2p, s)) elif const_expr( not mask_causal and not mask_local and mask_mod is not None @@ -272,7 +265,12 @@ def apply_mask( else acc_S_mn[r, c] ) else: - mask_r2p(acc_S_mn[r, None], col_limit_right, arch=90, rank1=True) + col_limit_r2p = sm90_col_to_r2p_idx(col_limit_right) + mask_r2p_lambda( + acc_S_mn[r, None], + lambda s: r2p_bitmask_below(col_limit_r2p, s), + rank1=True, + ) else: # Local local_row_offset_right = ( causal_row_offset + self.window_size_right @@ -284,6 +282,7 @@ def apply_mask( if const_expr(self.window_size_left is not None) else None ) + r2p_local = const_expr(not self.swap_AB) for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): if const_expr(self.qhead_per_kvhead_packgqa == 1): row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m @@ -302,13 +301,22 @@ def apply_mask( if const_expr(self.window_size_left is not None) else 0 ) - # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left) - # traverse column index. - for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): - col_idx = t0ScS_mn[0, c][1] - # only consider the column index, so the row index sets to 0. - if col_idx >= col_limit_right or col_idx < col_limit_left: - acc_S_mn[r, c] = -Float32.inf + if const_expr(not r2p_local): + # traverse column index. + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + col_idx = t0ScS_mn[0, c][1] + if col_idx >= col_limit_right or col_idx < col_limit_left: + acc_S_mn[r, c] = -Float32.inf + else: + col_limit_right_r2p = sm90_col_to_r2p_idx(col_limit_right) + col_limit_left_r2p = sm90_col_to_r2p_idx(col_limit_left) + + def mask_gen_fn(s: int) -> Uint32: + return r2p_bitmask_below( + col_limit_right_r2p, s + ) & r2p_bitmask_above(col_limit_left_r2p, s) + + mask_r2p_lambda(acc_S_mn[r, None], mask_gen_fn, rank1=True) else: # swap_AB assert self.qhead_per_kvhead_packgqa == 1 thr_row_offset = tScS_mn[0][ROW] @@ -338,11 +346,18 @@ def apply_mask( # column, by setting row limit to be self.tile_m. row_limit_top = ( self.tile_m - if col0 >= seqlenk_col_limit - else col0 - causal_row_offset - self.window_size_right + if col0 >= seqlenk_col_limit and mask_seqlen + else ( + col0 - causal_row_offset - self.window_size_right + if const_expr(self.window_size_right is not None) + else 0 + ) + ) + row_limit_bot = ( + col0 - causal_row_offset + self.window_size_left + if const_expr(self.window_size_left is not None) + else self.tile_m ) - # TODO: do we need col_limit_sink? - row_limit_bot = col0 - causal_row_offset + self.window_size_left for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): row_idx = t0ScS_mn[r, 0][ROW] acc_S_mn[r, c] = ( @@ -374,6 +389,7 @@ def apply_mask_sm100( acc_shape = (self.tile_m, self.tile_n) cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1]) tScS = thr_mma.partition_C(cS) + tScS = tScS[(None, None), 0, 0] tScS_t2r = thr_tmem_load.partition_D(tScS) # To handle edge cases of completely masked out rows where n_block_max = 0, # we treat negative n_blocks as 0th n_block @@ -391,7 +407,11 @@ def apply_mask_sm100( # For some reason the 2 lines above generate really bad SASS acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i] else: - mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True) + mask_r2p_lambda( + acc_S, + lambda s: r2p_bitmask_below(seqlenk_col_limit, s), + rank1=True, + ) elif const_expr(not mask_causal and not mask_local and mask_mod is not None): # Block sparse case w/ mask_mod @@ -444,12 +464,12 @@ def apply_mask_sm100( acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i] else: # Causal or local - causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q + causal_row_offset = self.seqlen_k - n_block * self.tile_n - self.seqlen_q row_idx = tScS_t2r[0][0] + m_block * self.tile_m if const_expr(self.qhead_per_kvhead_packgqa != 1): row_idx = row_idx // self.qhead_per_kvhead_packgqa if const_expr(mask_causal): - col_limit_right = row_idx + causal_row_offset + col_limit_right = row_idx + causal_row_offset + 1 if const_expr(mask_seqlen): col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) # if cute.arch.thread_idx()[0] % 32 == 0: @@ -459,15 +479,19 @@ def apply_mask_sm100( for i in cutlass.range(ncol, unroll_full=True): acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] else: - mask_r2p(acc_S, col_limit_right, arch=100, rank1=True) + mask_r2p_lambda( + acc_S, + lambda s: r2p_bitmask_below(col_limit_right, s), + rank1=True, + ) else: local_row_offset_right = ( - causal_row_offset + self.window_size_right + causal_row_offset + 1 + self.window_size_right if const_expr(self.window_size_right is not None) else None ) local_row_offset_left = ( - causal_row_offset - 1 - self.window_size_left + causal_row_offset - self.window_size_left if const_expr(self.window_size_left is not None) else None ) @@ -492,8 +516,15 @@ def apply_mask_sm100( else acc_S[i] ) else: - # XOR-based R2P dual bound masking - mask_r2p_dual_bound(acc_S, col_limit_left, col_limit_right) + # Dual-bound R2P masking for SM100. + # Masks elements where: NOT (col_limit_left <= col < col_limit_right) + + def mask_gen_fn(s: int) -> Uint32: + return r2p_bitmask_below(col_limit_right, s) & r2p_bitmask_above( + col_limit_left, s + ) + + mask_r2p_lambda(acc_S, mask_gen_fn, rank1=True) @cute.jit def apply_mask_sm100_transposed( @@ -528,7 +559,7 @@ def apply_mask_sm100_transposed( assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" ROW = 0 if const_expr(not self.swap_AB) else 1 COL = 1 if const_expr(not self.swap_AB) else 0 - assert t0ScS_t2r[0][COL] == 0, "col0 == 0" + # assert t0ScS_t2r[0][COL] == 0, "col0 == 0" # tmp comment for 2-cta bwd thr_col_offset = tScS_t2r[0][COL] seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset @@ -633,7 +664,13 @@ def apply_mask_sm100_transposed( ) else: num_rep = cute.size(tScS_t2r, mode=[0]) # 16 or 32 - mask_r2p_transposed(acc_S, row_limit_top, num_rep) + num_wg = 2 + row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg) + mask_r2p_lambda( + acc_S, + lambda s: r2p_bitmask_above(row_limit, s), + rank1=True, + ) else: if const_expr(self.window_size_right is not None): row_limit_top = causal_offset - self.window_size_right @@ -644,9 +681,31 @@ def apply_mask_sm100_transposed( if const_expr(mask_seqlen): if seqlenk_col_limit <= 0: row_limit_top = self.tile_m - for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): - row_idx = t0ScS_t2r[i][ROW] - local_mask = row_idx < row_limit_top - if const_expr(self.window_size_left is not None): - local_mask |= row_idx > row_limit_bot - acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i] + r2p = True + if const_expr(not r2p): + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + row_idx = t0ScS_t2r[i][ROW] + local_mask = row_idx < row_limit_top + if const_expr(self.window_size_left is not None): + local_mask |= row_idx > row_limit_bot + acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i] + else: + + def mask_gen_fn(s: int) -> Uint32: + num_rep = cute.size(tScS_t2r, mode=[0]) + num_wg = 2 + + row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg) + mask = r2p_bitmask_above(row_limit, s) + + if const_expr(self.window_size_left is not None): + row_limit_bottom = row_to_r2p_idx(row_limit_bot + 1, num_rep, num_wg) + mask = mask & r2p_bitmask_below(row_limit_bottom, s) + + return mask + + mask_r2p_lambda( + acc_S, + mask_gen_fn, + rank1=True, + ) diff --git a/flash_attn/cute/mma_sm100_desc.py b/flash_attn/cute/mma_sm100_desc.py index 16336c34686..ab8dd098b92 100644 --- a/flash_attn/cute/mma_sm100_desc.py +++ b/flash_attn/cute/mma_sm100_desc.py @@ -189,11 +189,7 @@ class LayoutType(IntEnum): # occupies the top-3 bits [61:64) def _layout_type(swizzle: cute.Swizzle) -> LayoutType: - # No idea what the right way to get B, M, S is – so we're just parsing it from the __str__ - # Swizzle string has the form "S" - swz_str = str(swizzle) - inside = swz_str[swz_str.index("<") + 1 : swz_str.index(">")] # '3,4,3' - B, M, S = [int(x) for x in inside.split(",")] # [3, 4, 3] + B, M, S = swizzle.num_bits, swizzle.num_base, swizzle.num_shift if M == 4: # Swizzle<*,4,3> if S != 3: @@ -289,3 +285,12 @@ def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32: # 14 bits, remove 4 LSB (bits 0-13 in desc) return (start_addr.toint() & 0x3FFFF) >> 4 + + +def smem_desc_base_from_tensor(sA: cute.Tensor, major: Major) -> int: + sA_swizzle = sA.iterator.type.swizzle_type + return make_smem_desc_base( + cute.recast_layout(128, sA.element_type.width, sA.layout[0]), + sA_swizzle, + major, + ) diff --git a/flash_attn/cute/named_barrier.py b/flash_attn/cute/named_barrier.py index 777c44079a0..dd0d1988960 100644 --- a/flash_attn/cute/named_barrier.py +++ b/flash_attn/cute/named_barrier.py @@ -12,6 +12,19 @@ class NamedBarrierFwd(enum.IntEnum): PEmpty = enum.auto() +class NamedBarrierFwdSm100(enum.IntEnum): + Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() + TmemPtr = enum.auto() + SoftmaxStatsW0 = enum.auto() + SoftmaxStatsW1 = enum.auto() + SoftmaxStatsW2 = enum.auto() + SoftmaxStatsW3 = enum.auto() + SoftmaxStatsW4 = enum.auto() + SoftmaxStatsW5 = enum.auto() + SoftmaxStatsW6 = enum.auto() + SoftmaxStatsW7 = enum.auto() + + class NamedBarrierBwd(enum.IntEnum): Epilogue = enum.auto() WarpSchedulerWG1 = enum.auto() @@ -20,8 +33,10 @@ class NamedBarrierBwd(enum.IntEnum): PdS = enum.auto() dQFullWG0 = enum.auto() dQFullWG1 = enum.auto() + dQFullWG2 = enum.auto() dQEmptyWG0 = enum.auto() dQEmptyWG1 = enum.auto() + dQEmptyWG2 = enum.auto() class NamedBarrierBwdSm100(enum.IntEnum): @@ -29,3 +44,4 @@ class NamedBarrierBwdSm100(enum.IntEnum): EpilogueWG2 = enum.auto() Compute = enum.auto() dQaccReduce = enum.auto() + TmemPtr = enum.auto() diff --git a/flash_attn/cute/pack_gqa.py b/flash_attn/cute/pack_gqa.py index 8bedc37c075..e87df018671 100644 --- a/flash_attn/cute/pack_gqa.py +++ b/flash_attn/cute/pack_gqa.py @@ -1,25 +1,123 @@ # Copyright (c) 2025, Tri Dao. +from dataclasses import dataclass +from typing import Union, Tuple import cutlass import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync + from quack import layout_utils import flash_attn.cute.utils as utils +def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): + """Reshape a tensor to fold qhead_per_kvhead into the seqlen dimension (mode 0). + + The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1) + are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept + as-is (e.g. batch). + + For Q/O tensors (head_idx=2): + (seqlen_q, headdim, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) + For LSE tensors (head_idx=1): + (seqlen_q, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) + """ + head_stride = T.stride[head_idx] + shape_packed = ( + (qhead_per_kvhead, T.shape[0]), + *[T.shape[i] for i in range(1, head_idx)], + nheads_kv, + *[T.shape[i] for i in range(head_idx + 1, len(T.shape))], + ) + stride_packed = ( + (head_stride, T.stride[0]), + *[T.stride[i] for i in range(1, head_idx)], + head_stride * qhead_per_kvhead, + *[T.stride[i] for i in range(head_idx + 1, len(T.shape))], + ) + return cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed)) + + +def make_packgqa_tiled_tma_atom( + op: cute.atom.CopyOp, + gmem_tensor: cute.Tensor, + smem_layout: Union[cute.Layout, cute.ComposedLayout], + cta_tiler: Tuple[int, int], + qhead_per_kvhead: int, + head_idx: int, +): + # This packing and unpacking of the layout is so that we keep the same TMA dimension as usual. + # e.g. for (seqlen, d, nheads, b) layout, we still have 4D TMA after packing to + # ((nheads, seqlen), d, b). + # If we instead pack directly to ((qhead_per_kvhead, seqlen), d, nheads_kv, b) we'd have 5D TMA. + # Pack headdim and seqlen dim into 1: (seqlen, d, nheads, b) -> ((nheads, seqlen), d, b) + gmem_tensor = layout_utils.select( + gmem_tensor, [head_idx, *range(head_idx), *range(head_idx + 1, cute.rank(gmem_tensor))] + ) + gmem_tensor = cute.group_modes(gmem_tensor, 0, 2) + assert cta_tiler[0] % qhead_per_kvhead == 0, ( + "CTA tile size in the seqlen dimension must be divisible by qhead_per_kvhead" + ) + tma_atom, tma_tensor = cpasync.make_tiled_tma_atom( + op, + gmem_tensor, + smem_layout, + ((qhead_per_kvhead, cta_tiler[0] // qhead_per_kvhead), cta_tiler[1]), # No mcast + ) + # Unpack from ((nheads, seqlen), d, b) -> ((qhead_per_kvhead, seqlen), d, nheads_kv, b) + T = tma_tensor + shape_packed = ( + (qhead_per_kvhead, T.shape[0][1]), + *[T.shape[i] for i in range(1, head_idx)], + T.shape[0][0] // qhead_per_kvhead, + *[T.shape[i] for i in range(head_idx, len(T.shape))], + ) + stride_packed = ( + *[T.stride[i] for i in range(head_idx)], + T.stride[0][0] * qhead_per_kvhead, + *[T.stride[i] for i in range(head_idx, len(T.shape))], + ) + tma_tensor = cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed)) + return tma_atom, tma_tensor + + +def unpack_gqa_layout(T, qhead_per_kvhead, head_idx): + """Reverse of pack_gqa_layout: unfold qhead_per_kvhead from the seqlen dimension (mode 0). + + The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1) + are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept + as-is (e.g. batch). + + For Q/O tensors (head_idx=2): + ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) -> (seqlen_q, headdim, nheads, batch, ...) + For LSE tensors (head_idx=1): + ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) -> (seqlen_q, nheads, batch, ...) + """ + seqlen_stride = T.stride[0][1] + head_stride = T.stride[0][0] + shape_unpacked = ( + T.shape[0][1], + *[T.shape[i] for i in range(1, head_idx)], + T.shape[head_idx] * qhead_per_kvhead, + *[T.shape[i] for i in range(head_idx + 1, len(T.shape))], + ) + stride_unpacked = ( + seqlen_stride, + *[T.stride[i] for i in range(1, head_idx)], + head_stride, + *[T.stride[i] for i in range(head_idx + 1, len(T.shape))], + ) + return cute.make_tensor(T.iterator, cute.make_layout(shape_unpacked, stride=stride_unpacked)) + + +@dataclass class PackGQA: - def __init__( - self, - m_block_size: cutlass.Constexpr[int], - head_dim_padded: cutlass.Constexpr[int], - check_hdim_oob: cutlass.Constexpr[bool], - qhead_per_kvhead: cutlass.Constexpr[bool], - ): - self.m_block_size = m_block_size - self.head_dim_padded = head_dim_padded - self.check_hdim_oob = check_hdim_oob - self.qhead_per_kvhead = qhead_per_kvhead + m_block_size: cutlass.Constexpr[int] + head_dim_padded: cutlass.Constexpr[int] + check_hdim_oob: cutlass.Constexpr[bool] + qhead_per_kvhead: cutlass.Constexpr[bool] @cute.jit def compute_ptr( diff --git a/flash_attn/cute/paged_kv.py b/flash_attn/cute/paged_kv.py index e2d2d84433d..bf11acbc24e 100644 --- a/flash_attn/cute/paged_kv.py +++ b/flash_attn/cute/paged_kv.py @@ -7,7 +7,7 @@ from cutlass import Int32, const_expr from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import ParamsBase +from quack.cute_dsl_utils import ParamsBase from cutlass.cute import FastDivmodDivisor import math @@ -28,6 +28,9 @@ class PagedKVManager(ParamsBase): head_dim_padded: cutlass.Constexpr[Int32] head_dim_v_padded: cutlass.Constexpr[Int32] + arch: cutlass.Constexpr[Int32] + v_gmem_transposed: cutlass.Constexpr[bool] + gmem_threads_per_row: cutlass.Constexpr[Int32] page_entry_per_thread: Int32 async_copy_elems: Int32 @@ -55,7 +58,11 @@ def create( head_dim_v_padded: cutlass.Constexpr[Int32], num_threads: cutlass.Constexpr[Int32], dtype: Type[cutlass.Numeric], + arch: cutlass.Constexpr[int] = 100, ): + # SM100 transposes V in gmem to (dv, page_size, num_pages); + # SM90 keeps V as (page_size, dv, num_pages), same layout as K. + v_gmem_transposed = arch != 90 universal_copy_bits = 128 async_copy_elems = universal_copy_bits // dtype.width dtype_bytes = dtype.width // 8 @@ -97,7 +104,8 @@ def create( else: cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded)) tVcV = gmem_thr_copy_KV.partition_S(cV) - tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0]) + # When V is transposed in gmem, dv is shape[0]; otherwise dv is shape[1] (same as K) + tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0 if v_gmem_transposed else 1]) return PagedKVManager( mPageTable, @@ -111,6 +119,8 @@ def create( num_threads, head_dim_padded, head_dim_v_padded, + arch, + v_gmem_transposed, gmem_threads_per_row, page_entry_per_thread, async_copy_elems, @@ -146,13 +156,17 @@ def load_page_table(self, n_block: Int32): @cute.jit def compute_X_ptr(self, K_or_V: str): tPrXPtr = cute.make_rmem_tensor((self.page_entry_per_thread,), cutlass.Int64) + mX = self.mK_paged if const_expr(K_or_V == "K") else self.mV_paged + # K is always (page_size, d, num_pages). V matches K when not transposed, + # but is (dv, page_size, num_pages) when transposed (SM100). + transposed = const_expr(K_or_V == "V" and self.v_gmem_transposed) for i in cutlass.range(self.page_entry_per_thread, unroll=1): page = self.tPrPage[i] page_offset = self.tPrPageOffset[i] - if const_expr(K_or_V == "K"): - tPrXPtr[i] = utils.elem_pointer(self.mK_paged, (page_offset, 0, page)).toint() + if const_expr(transposed): + tPrXPtr[i] = utils.elem_pointer(mX, (0, page_offset, page)).toint() else: - tPrXPtr[i] = utils.elem_pointer(self.mV_paged, (0, page_offset, page)).toint() + tPrXPtr[i] = utils.elem_pointer(mX, (page_offset, 0, page)).toint() return tPrXPtr @cute.jit @@ -161,18 +175,24 @@ def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str): tPrXPtr = self.compute_X_ptr(K_or_V) - # Finesse sX layout to be (M, N). - sX_pi = cute.make_tensor( - sX.iterator, - cute.make_layout( - (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])), - stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])), - ), - ) + if const_expr(self.arch == 90): + # SM90: sX is already stage-sliced by caller (sK[None, None, stage]). + # Flatten hierarchical modes to get (n_block_size, head_dim). + sX_pi = cute.group_modes(sX, 0, 1) + # SM90 does NOT transpose V here (it's transposed via utils.transpose_view before MMA) + else: + # SM100: Finesse sX layout to be (M, N). + sX_pi = cute.make_tensor( + sX.iterator, + cute.make_layout( + (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])), + stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])), + ), + ) - if const_expr(K_or_V == "V"): - # Need to transpose V - sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0])) + if const_expr(K_or_V == "V"): + # Transpose smem V to match transposed gmem layout + sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0])) head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded cX = cute.make_identity_tensor((self.n_block_size, head_dim)) diff --git a/flash_attn/cute/pipeline.py b/flash_attn/cute/pipeline.py index 32ac02b88b7..f8fdc1e8028 100644 --- a/flash_attn/cute/pipeline.py +++ b/flash_attn/cute/pipeline.py @@ -1,15 +1,38 @@ # Copyright (c) 2025, Tri Dao. -# import math from typing import Optional from dataclasses import dataclass +import cutlass.cute as cute from cutlass import Boolean, Int32, const_expr -from cutlass.cutlass_dsl import if_generate +from cutlass.cutlass_dsl import if_generate, dsl_user_op from cutlass.pipeline import PipelineState from cutlass.pipeline import PipelineUserType +from cutlass.pipeline import NamedBarrier as NamedBarrierOg +from cutlass.pipeline import PipelineAsync as PipelineAsyncOg +from cutlass.pipeline import PipelineCpAsync as PipelineCpAsyncOg from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg +from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg +from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg + + +def _override_create(parent_cls, child_cls): + """Create a static factory that constructs parent_cls then re-classes to child_cls.""" + + @staticmethod + def create(*args, **kwargs): + obj = parent_cls.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", child_cls) + return obj + + return create + + +def _make_state(index: Int32, phase: Int32) -> PipelineState: + """Construct a PipelineState from index and phase (count/stages unused by callers).""" + return PipelineState(stages=0, count=Int32(0), index=index, phase=phase) class PipelineStateSimple: @@ -20,9 +43,6 @@ class PipelineStateSimple: """ def __init__(self, stages: int, phase_index: Int32): - # assert stages < 2**16 - # self._log_stages = int(math.log2(stages)) - # assert 1 << self._log_stages == stages, "Number of stages must be a power of 2." self._stages = stages self._phase_index = phase_index @@ -31,13 +51,10 @@ def clone(self) -> "PipelineStateSimple": @property def stages(self) -> int: - # return 1 << self._log_stages return self._stages @property def index(self) -> Int32: - # return self._phase_index & 0xFFFF - # return self._phase_index & ((1 << self._log_stages) - 1) if const_expr(self._stages == 1): return Int32(0) else: @@ -45,11 +62,8 @@ def index(self) -> Int32: @property def phase(self) -> Int32: - # return self._phase_index >> 16 # PTX docs say that the phase parity needs to be 0 or 1, so by right we need to # take modulo 2. But in practice just passing the phase in without modulo works fine. - # return (self._phase_index >> self._log_stages) % 2 - # return self._phase_index >> self._log_stages if const_expr(self._stages == 1): return self._phase_index else: @@ -61,21 +75,6 @@ def advance(self): else: self._phase_index += 1 - # def then_body(phase_index): - # # XOR the phase bit and set the index to 0 - # return (phase_index & 0xFFFF0000) ^ (1 << 16) - - # def else_body(phase_index): - # return phase_index - - # self._phase_index = if_generate( - # (self._phase_index & 0xFFFF) == self.stages, - # then_body, - # else_body, - # [self._phase_index], - # [Int32], - # ) - def __extract_mlir_values__(self): phase_index = self._phase_index return [phase_index.ir_value()] @@ -89,7 +88,6 @@ def make_pipeline_state(type: PipelineUserType, stages: int): Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1. """ if type is PipelineUserType.Producer: - # return PipelineStateSimple(stages, Int32(1 << 16)) return PipelineStateSimple(stages, Int32(stages)) elif type is PipelineUserType.Consumer: return PipelineStateSimple(stages, Int32(0)) @@ -97,20 +95,213 @@ def make_pipeline_state(type: PipelineUserType, stages: int): assert False, "Error: invalid PipelineUserType specified for make_pipeline_state." +# ── Shared helpers ─────────────────────────────────────────────────────────── + + +def _call_with_elect_one(parent_method, self, state, elect_one, syncwarp, loc, ip): + """Optionally wrap a parent pipeline method call in sync_warp + elect_one.""" + if const_expr(elect_one): + if const_expr(syncwarp): + cute.arch.sync_warp() + with cute.arch.elect_one(): + parent_method(self, state, loc=loc, ip=ip) + else: + parent_method(self, state, loc=loc, ip=ip) + + +# ── Mixin: _w_index / _w_index_phase variants that delegate to parent ─────── +# Each parent class has PipelineState-based methods (producer_acquire, producer_commit, +# consumer_wait, consumer_release). The _w_index_phase variants just construct a +# PipelineState from (index, phase) and delegate. + + +class _PipelineIndexPhaseMixin: + """Mixin providing _w_index_phase / _w_index methods that delegate to PipelineState-based parents.""" + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + state = _make_state(index, phase) + # Call the parent's producer_acquire (which takes PipelineState) + self.producer_acquire(state, try_acquire_token, loc=loc, ip=ip) + + @dsl_user_op + def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): + state = _make_state(index, Int32(0)) + self.producer_commit(state, loc=loc, ip=ip) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + state = _make_state(index, phase) + self.consumer_wait(state, try_wait_token, loc=loc, ip=ip) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + state = _make_state(index, Int32(0)) + self.consumer_release(state, loc=loc, ip=ip) + + +# ── NamedBarrier ───────────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class NamedBarrier(NamedBarrierOg): + create = _override_create(NamedBarrierOg, None) # patched below + + @dsl_user_op + def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None: + """ + The aligned flavor of arrive is used when all threads in the CTA will execute the + same instruction. See PTX documentation. + """ + cute.arch.barrier_arrive( + barrier_id=self.barrier_id + index, + number_of_threads=self.num_threads, + loc=loc, + ip=ip, + ) + + @dsl_user_op + def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) -> None: + cute.arch.barrier( + barrier_id=self.barrier_id + index, + number_of_threads=self.num_threads, + loc=loc, + ip=ip, + ) + + +NamedBarrier.create = _override_create(NamedBarrierOg, NamedBarrier) + + +# ── PipelineAsync ──────────────────────────────────────────────────────────── + + @dataclass(frozen=True) -class PipelineTmaAsync(PipelineTmaAsyncOg): +class PipelineAsync(_PipelineIndexPhaseMixin, PipelineAsyncOg): """ - Override producer_acquire to take in extra_tx_count parameter. + PipelineAsync with optional elect_one for producer_commit and consumer_release. + + When elect_one_*=True (set at create time), only one elected thread per warp + signals the barrier arrive. This is useful when the mask count is set to 1 per warp. + + Args (to create): + elect_one_commit: If True, only elected thread signals producer_commit. + syncwarp_before_commit: If True (default), issue syncwarp before elect_one. + elect_one_release: If True, only elected thread signals consumer_release. + syncwarp_before_release: If True (default), issue syncwarp before elect_one. + Set syncwarp to False when threads are already converged (e.g. after wgmma wait_group). """ + _elect_one_commit: bool = False + _syncwarp_before_commit: bool = True + _elect_one_release: bool = False + _syncwarp_before_release: bool = True + @staticmethod - def create(*args, **kwargs): - obj = PipelineTmaAsyncOg.create(*args, **kwargs) - # Can't assign to __class__ directly since the dataclass is frozen - # obj.__class__ = PipelineTmaAsync - object.__setattr__(obj, "__class__", PipelineTmaAsync) + def create( + *args, + elect_one_commit: bool = False, + syncwarp_before_commit: bool = True, + elect_one_release: bool = False, + syncwarp_before_release: bool = True, + **kwargs, + ): + obj = PipelineAsyncOg.create(*args, **kwargs) + object.__setattr__(obj, "__class__", PipelineAsync) + object.__setattr__(obj, "_elect_one_commit", elect_one_commit) + object.__setattr__(obj, "_syncwarp_before_commit", syncwarp_before_commit) + object.__setattr__(obj, "_elect_one_release", elect_one_release) + object.__setattr__(obj, "_syncwarp_before_release", syncwarp_before_release) + return obj + + @dsl_user_op + def producer_commit(self, state: PipelineState, *, loc=None, ip=None): + _call_with_elect_one( + PipelineAsyncOg.producer_commit, + self, + state, + self._elect_one_commit, + self._syncwarp_before_commit, + loc, + ip, + ) + + @dsl_user_op + def consumer_release(self, state: PipelineState, *, loc=None, ip=None): + _call_with_elect_one( + PipelineAsyncOg.consumer_release, + self, + state, + self._elect_one_release, + self._syncwarp_before_release, + loc, + ip, + ) + + # _w_index variants inherited from _PipelineIndexPhaseMixin, which delegate + # to producer_commit / consumer_release above. + + +# ── PipelineCpAsync ────────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class PipelineCpAsync(_PipelineIndexPhaseMixin, PipelineCpAsyncOg): + _elect_one_release: bool = False + _syncwarp_before_release: bool = True + + @staticmethod + def create( + *args, + elect_one_release: bool = False, + syncwarp_before_release: bool = True, + **kwargs, + ): + obj = PipelineCpAsyncOg.create(*args, **kwargs) + object.__setattr__(obj, "__class__", PipelineCpAsync) + object.__setattr__(obj, "_elect_one_release", elect_one_release) + object.__setattr__(obj, "_syncwarp_before_release", syncwarp_before_release) return obj + @dsl_user_op + def consumer_release(self, state: PipelineState, *, loc=None, ip=None): + _call_with_elect_one( + PipelineCpAsyncOg.consumer_release, + self, + state, + self._elect_one_release, + self._syncwarp_before_release, + loc, + ip, + ) + + # _w_index variants inherited from _PipelineIndexPhaseMixin. + + +# ── PipelineTmaAsync ──────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class PipelineTmaAsync(_PipelineIndexPhaseMixin, PipelineTmaAsyncOg): + """Override producer_acquire to take in extra_tx_count parameter.""" + + @dsl_user_op def producer_acquire( self, state: PipelineState, @@ -136,20 +327,17 @@ def producer_acquire( self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip) -@dataclass(frozen=True) -class PipelineTmaUmma(PipelineTmaUmmaOg): - """ - Override producer_acquire to take in extra_tx_count parameter. - """ +PipelineTmaAsync.create = _override_create(PipelineTmaAsyncOg, PipelineTmaAsync) - @staticmethod - def create(*args, **kwargs): - obj = PipelineTmaUmmaOg.create(*args, **kwargs) - # Can't assign to __class__ directly since the dataclass is frozen - # obj.__class__ = PipelineTmaUmma - object.__setattr__(obj, "__class__", PipelineTmaUmma) - return obj +# ── PipelineTmaUmma ───────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class PipelineTmaUmma(_PipelineIndexPhaseMixin, PipelineTmaUmmaOg): + """Override producer_acquire to take in extra_tx_count parameter.""" + + @dsl_user_op def producer_acquire( self, state: PipelineState, @@ -187,3 +375,28 @@ def producer_acquire( loc=loc, ip=ip, ) + + +PipelineTmaUmma.create = _override_create(PipelineTmaUmmaOg, PipelineTmaUmma) + + +# ── PipelineUmmaAsync ─────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class PipelineUmmaAsync(_PipelineIndexPhaseMixin, PipelineUmmaAsyncOg): + pass + + +PipelineUmmaAsync.create = _override_create(PipelineUmmaAsyncOg, PipelineUmmaAsync) + + +# ── PipelineAsyncUmma ─────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class PipelineAsyncUmma(_PipelineIndexPhaseMixin, PipelineAsyncUmmaOg): + pass + + +PipelineAsyncUmma.create = _override_create(PipelineAsyncUmmaOg, PipelineAsyncUmma) diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml index 9fc294d8940..2b0b60b42f1 100644 --- a/flash_attn/cute/pyproject.toml +++ b/flash_attn/cute/pyproject.toml @@ -1,10 +1,10 @@ [build-system] -requires = ["setuptools"] +requires = ["setuptools>=75", "setuptools-scm>=8"] build-backend = "setuptools.build_meta" [project] -name = "flash-attn-cute" -version = "0.1.0" +name = "flash-attn-4" +dynamic = ["version"] description = "Flash Attention CUTE (CUDA Template Engine) implementation" readme = "README.md" requires-python = ">=3.10" @@ -22,16 +22,17 @@ classifiers = [ ] dependencies = [ - "nvidia-cutlass-dsl>=4.4.0.dev1", + "nvidia-cutlass-dsl>=4.4.2", "torch", "einops", "typing_extensions", "apache-tvm-ffi>=0.1.5,<0.2", "torch-c-dlpack-ext", - "quack-kernels>=0.2.8", + "quack-kernels>=0.3.3", ] [project.optional-dependencies] +cu13 = ["nvidia-cutlass-dsl[cu13]>=4.4.2"] dev = [ "pytest", "ruff", @@ -45,6 +46,22 @@ Repository = "https://github.com/Dao-AILab/flash-attention" packages = ["flash_attn.cute"] package-dir = {"flash_attn.cute" = "."} +[tool.setuptools_scm] +root = "../.." +tag_regex = "^fa4-v(?P.+)$" +git_describe_command = "git describe --dirty --tags --long --match 'fa4-v*'" +fallback_version = "0.0.0" + +[[tool.uv.index]] +name = "pytorch-cu130" +url = "https://download.pytorch.org/whl/cu130" +explicit = true + +[tool.uv.sources] +torch = [ + { index = "pytorch-cu130", marker = "extra == 'cu13'" }, +] + [tool.ruff] line-length = 100 @@ -53,4 +70,5 @@ ignore = [ "E731", # do not assign a lambda expression, use a def "E741", # Do not use variables named 'I', 'O', or 'l' "F841", # local variable is assigned to but never used + "D102", # Missing docstring in public methods ] diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py index 6d8c6feb279..8e8fdf69ddc 100644 --- a/flash_attn/cute/seqlen_info.py +++ b/flash_attn/cute/seqlen_info.py @@ -5,6 +5,8 @@ import cutlass.cute as cute from cutlass import Int32, const_expr +from quack import copy_utils + """ This consolidates all the info related to sequence length. This is so that we can do all the gmem reads once at the beginning of each tile, rather than having to repeat these reads @@ -14,34 +16,61 @@ @dataclass(frozen=True) class SeqlenInfo: - offset: cutlass.Int32 - seqlen: cutlass.Int32 + offset: Int32 + offset_padded: Int32 + seqlen: Int32 + has_cu_seqlens: cutlass.Constexpr[bool] = False @staticmethod def create( - batch_idx: cutlass.Int32, - seqlen_static: cutlass.Int32, + batch_idx: Int32, + seqlen_static: Int32, cu_seqlens: Optional[cute.Tensor] = None, seqused: Optional[cute.Tensor] = None, + tile: cutlass.Constexpr[int] = 128, ): offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx] + offset_padded = ( + 0 + if const_expr(cu_seqlens is None) + # Add divby so that the compiler knows the alignment when moving by offset_padded + else cute.assume((offset + batch_idx * tile) // tile * tile, divby=tile) + ) if const_expr(seqused is not None): seqlen = seqused[batch_idx] elif const_expr(cu_seqlens is not None): seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] else: seqlen = seqlen_static - return SeqlenInfo(offset, seqlen) + return SeqlenInfo(offset, offset_padded, seqlen, has_cu_seqlens=cu_seqlens is not None) + + def offset_batch( + self, + mT: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + multiple: int = 1, + ) -> cute.Tensor: + """Offset a tensor by batch index. batch dim is at position `dim`, seqlen is at dim=0.""" + if const_expr(not self.has_cu_seqlens): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mT) - 1 - dim) + return mT[idx] + else: + off = multiple * (self.offset if const_expr(not padded) else self.offset_padded) + offset = off if const_expr(cute.rank(mT.shape[0]) == 1) else (0, off) + idx = (offset,) + (None,) * (cute.rank(mT) - 1) + return cute.domain_offset(idx, mT) @dataclass(frozen=True) class SeqlenInfoQK: - offset_q: cutlass.Int32 - offset_k: cutlass.Int32 - padded_offset_q: cutlass.Int32 - padded_offset_k: cutlass.Int32 - seqlen_q: cutlass.Int32 - seqlen_k: cutlass.Int32 + offset_q: Int32 + offset_k: Int32 + padded_offset_q: Int32 + padded_offset_k: Int32 + seqlen_q: Int32 + seqlen_k: Int32 has_cu_seqlens_q: cutlass.Constexpr[bool] has_cu_seqlens_k: cutlass.Constexpr[bool] has_seqused_q: cutlass.Constexpr[bool] @@ -49,27 +78,27 @@ class SeqlenInfoQK: @staticmethod def create( - batch_idx: cutlass.Int32, - seqlen_q_static: cutlass.Int32, - seqlen_k_static: cutlass.Int32, + batch_idx: Int32, + seqlen_q_static: Int32, + seqlen_k_static: Int32, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, - tile_m: cutlass.Constexpr[cutlass.Int32] = 128, - tile_n: cutlass.Constexpr[cutlass.Int32] = 128, + tile_m: cutlass.Constexpr[Int32] = 128, + tile_n: cutlass.Constexpr[Int32] = 128, ): offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] padded_offset_q = ( 0 if const_expr(mCuSeqlensQ is None) - else (offset_q + batch_idx * tile_m) // tile_m * tile_m + else cute.assume((offset_q + batch_idx * tile_m) // tile_m * tile_m, divby=tile_m) ) padded_offset_k = ( 0 if const_expr(mCuSeqlensK is None) - else (offset_k + batch_idx * tile_n) // tile_n * tile_n + else cute.assume((offset_k + batch_idx * tile_n) // tile_n * tile_n, divby=tile_n) ) if const_expr(mSeqUsedQ is not None): seqlen_q = mSeqUsedQ[batch_idx] @@ -87,10 +116,6 @@ def create( if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx + 1] - offset_k ) - has_cu_seqlens_q: int = mCuSeqlensQ is not None - has_cu_seqlens_k: int = mCuSeqlensK is not None - has_seqused_q: int = mSeqUsedQ is not None - has_seqused_k: int = mSeqUsedK is not None return SeqlenInfoQK( offset_q, offset_k, @@ -98,10 +123,10 @@ def create( padded_offset_k, seqlen_q, seqlen_k, - has_cu_seqlens_q, - has_cu_seqlens_k, - has_seqused_q, - has_seqused_k, + has_cu_seqlens_q=mCuSeqlensQ is not None, + has_cu_seqlens_k=mCuSeqlensK is not None, + has_seqused_q=mSeqUsedQ is not None, + has_seqused_k=mSeqUsedK is not None, ) def offset_batch_Q( @@ -110,16 +135,38 @@ def offset_batch_Q( batch_idx: Int32, dim: int, padded: cutlass.Constexpr[bool] = False, + ragged: cutlass.Constexpr[bool] = False, ) -> cute.Tensor: """Seqlen must be the first dimension of mQ""" - if const_expr(not self.has_cu_seqlens_q): - idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) - return mQ[idx] + if const_expr(not ragged): + if const_expr(not self.has_cu_seqlens_q): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) + return mQ[idx] + else: + offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q + offset_q = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (None, offset_q) + idx = (offset_q,) + (None,) * (cute.rank(mQ) - 1) + return cute.domain_offset(idx, mQ) else: - offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q - offset = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, offset_q) - idx = (offset,) + (0,) * (cute.rank(mQ) - 1) - return cute.domain_offset(idx, mQ) + if const_expr(not self.has_cu_seqlens_q): + offset_q = 0 + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) + mQ = mQ[idx] + else: + offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q + if const_expr(cute.rank(mQ.shape[0]) == 1): + return copy_utils.offset_ragged_tensor( + mQ, offset_q, self.seqlen_q, ragged_dim=0, ptr_shift=True + ) + else: # PackGQA + assert cute.rank(mQ.shape[0]) == 2 + # Unpack before calling offset_ragged_tensor, then pack + idx = ((None, None),) + (None,) * (cute.rank(mQ) - 1) + mQ = mQ[idx] + mQ = copy_utils.offset_ragged_tensor( + mQ, offset_q, self.seqlen_q, ragged_dim=1, ptr_shift=True + ) + return cute.group_modes(mQ, 0, 2) def offset_batch_K( self, @@ -127,12 +174,114 @@ def offset_batch_K( batch_idx: Int32, dim: int, padded: cutlass.Constexpr[bool] = False, + ragged: cutlass.Constexpr[bool] = False, + multiple: int = 1, ) -> cute.Tensor: """Seqlen must be the first dimension of mK""" - if const_expr(not self.has_cu_seqlens_k): - idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) - return mK[idx] + if const_expr(not ragged): + if const_expr(not self.has_cu_seqlens_k): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) + return mK[idx] + else: + offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k + offset_k *= multiple + idx = (offset_k,) + (None,) * (cute.rank(mK) - 1) + return cute.domain_offset(idx, mK) else: - offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k - idx = (offset_k,) + (0,) * (cute.rank(mK) - 1) - return cute.domain_offset(idx, mK) + if const_expr(not self.has_cu_seqlens_k): + offset_k = 0 + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) + mK = mK[idx] + else: + offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k + offset_k *= multiple + return copy_utils.offset_ragged_tensor( + mK, offset_k, self.seqlen_k, ragged_dim=0, ptr_shift=True + ) + + +@dataclass(frozen=True) +class SeqlenInfoQKNewK: + """Sequence length info for append-KV with left-padding and new K support. + + Extends SeqlenInfoQK with: + - leftpad_k: left padding for K (tokens to skip at the start of the KV cache) + - offset_k_new: offset into the new K tensor + - seqlen_k_og: original K length (before appending new K), excluding leftpad + - seqlen_k_new: length of new K to append + - seqlen_k: total K length (seqlen_k_og + seqlen_k_new) + - seqlen_rotary: position for rotary embedding computation + """ + + leftpad_k: Int32 + offset_q: Int32 + offset_k: Int32 + offset_k_new: Int32 + seqlen_q: Int32 + seqlen_k_og: Int32 + seqlen_k_new: Int32 + seqlen_k: Int32 + seqlen_rotary: Int32 + + @staticmethod + def create( + batch_idx: Int32, + seqlen_q_static: Int32, + seqlen_k_static: Int32, + shape_K_new_0: Int32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mCuSeqlensKNew: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mLeftpadK: Optional[cute.Tensor] = None, + mSeqlensRotary: Optional[cute.Tensor] = None, + ): + leftpad_k = 0 if const_expr(mLeftpadK is None) else mLeftpadK[batch_idx] + offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] + if const_expr(mCuSeqlensK is not None): + offset_k = mCuSeqlensK[batch_idx] + leftpad_k + else: + offset_k = leftpad_k if const_expr(mCuSeqlensQ is not None) else 0 + offset_k_new = 0 if const_expr(mCuSeqlensKNew is None) else mCuSeqlensKNew[batch_idx] + # seqlen_q + if const_expr(mSeqUsedQ is not None): + seqlen_q = mSeqUsedQ[batch_idx] + elif const_expr(mCuSeqlensQ is not None): + seqlen_q = mCuSeqlensQ[batch_idx + 1] - mCuSeqlensQ[batch_idx] + else: + seqlen_q = seqlen_q_static + # seqlen_k_og: original K length (excluding leftpad) + if const_expr(mSeqUsedK is not None): + seqlen_k_og = mSeqUsedK[batch_idx] - leftpad_k + elif const_expr(mCuSeqlensK is not None): + seqlen_k_og = mCuSeqlensK[batch_idx + 1] - mCuSeqlensK[batch_idx] - leftpad_k + else: + seqlen_k_og = ( + seqlen_k_static - leftpad_k + if const_expr(mCuSeqlensQ is not None) + else seqlen_k_static + ) + # seqlen_k_new + if const_expr(mCuSeqlensKNew is None): + seqlen_k_new = 0 if const_expr(mCuSeqlensQ is None) else shape_K_new_0 + else: + seqlen_k_new = mCuSeqlensKNew[batch_idx + 1] - mCuSeqlensKNew[batch_idx] + seqlen_k = seqlen_k_og if const_expr(mCuSeqlensQ is None) else seqlen_k_og + seqlen_k_new + + # seqlen_rotary: defaults to seqlen_k_og + leftpad_k unless explicitly provided + if const_expr(mSeqlensRotary is not None): + seqlen_rotary = mSeqlensRotary[batch_idx] + else: + seqlen_rotary = seqlen_k_og + leftpad_k + return SeqlenInfoQKNewK( + leftpad_k, + offset_q, + offset_k, + offset_k_new, + seqlen_q, + seqlen_k_og, + seqlen_k_new, + seqlen_k, + seqlen_rotary, + ) diff --git a/flash_attn/cute/sm90_config_search.py b/flash_attn/cute/sm90_config_search.py new file mode 100644 index 00000000000..6c9584ea364 --- /dev/null +++ b/flash_attn/cute/sm90_config_search.py @@ -0,0 +1,402 @@ +"""Search feasible SM90 fwd/bwd attention configs for given (head_dim, head_dim_v). + +Enumerates tile sizes, swap modes, atom layouts, and staging options. +Checks GMMA divisibility, register budget, and shared memory budget. + +Usage: + python flash_attn/cute/sm90_config_search.py --headdim 128 + python flash_attn/cute/sm90_config_search.py --mode fwd --headdim 192-128 + python flash_attn/cute/sm90_config_search.py --mode bwd --headdim 192 --tile-n 64,96 +""" + +import math + +# H100 hardware limits +SMEM_LIMIT = 224 * 1024 # 228 KB minus ~3 KB for LSE, dPsum, mbarriers +REG_LIMITS = {2: 216, 3: 128} # per-WG budget: 2WG=240-24, 3WG=160-32 +THREADS_PER_WG = 128 + + +def _divisors(n): + return [d for d in range(1, n + 1) if n % d == 0] + + +def _acc_regs(M, N, num_wg): + """Accumulator registers per thread per WG.""" + return M * N // (num_wg * THREADS_PER_WG) + + +def _check_mma(M, N, num_wg, atom_layout_m, swap_AB): + """Check MMA feasibility. Returns regs per WG, or None if infeasible. + + GMMA atom M=64. Swap exchanges (M, N) and atom layout. + Requires: M divisible by (atom_layout_m * 64), N by (atom_layout_n * 8). + """ + if swap_AB: + M, N = N, M + atom_layout_m = num_wg // atom_layout_m + atom_layout_n = num_wg // atom_layout_m + if M % (atom_layout_m * 64) != 0 or N % (atom_layout_n * 8) != 0: + return None + return _acc_regs(M, N, num_wg) + + +def _mma_traffic(M_eff, N_eff, K_red, num_wg, wg_n, is_rs=False): + """Total SMEM read traffic for one MMA (all WGs combined). + + num_instr = (M_eff / 64) * wg_n instructions total. + Each reads A(64, K_red) and B(N_eff/wg_n, K_red) from smem (bf16). + """ + num_instr = (M_eff // 64) * wg_n + A_per = 64 * K_red * 2 if not is_rs else 0 + B_per = (N_eff // wg_n) * K_red * 2 + return num_instr * (A_per + B_per) + + +# ============================================================================ +# Backward +# ============================================================================ + + +def _check_bwd_config( + hdim, + hdimv, + tile_m, + tile_n, + num_wg, + SdP_swapAB, + dKV_swapAB, + dQ_swapAB, + AtomLayoutMSdP, + AtomLayoutNdKV, + AtomLayoutMdQ, +): + reg_limit = REG_LIMITS[num_wg] + + # MMA feasibility + regs_SdP = _check_mma(tile_m, tile_n, num_wg, AtomLayoutMSdP, SdP_swapAB) + regs_dK = _check_mma(tile_n, hdim, num_wg, AtomLayoutNdKV, dKV_swapAB) + regs_dV = _check_mma(tile_n, hdimv, num_wg, AtomLayoutNdKV, dKV_swapAB) + regs_dQ = _check_mma(tile_m, hdim, num_wg, AtomLayoutMdQ, dQ_swapAB) + if any(r is None for r in (regs_SdP, regs_dK, regs_dV, regs_dQ)): + return None + + # Peak regs: max(S+dP, dQ) + dK + dV + total_regs = max(2 * regs_SdP, regs_dQ) + regs_dK + regs_dV + if total_regs > reg_limit: + return None + + # SMEM + mma_dkv_is_rs = ( + AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_wg and SdP_swapAB and not dKV_swapAB + ) + Q_stage, PdS_stage = 2, 1 + + for dO_stage in (2, 1): + sQ = tile_m * hdim * 2 * Q_stage + sK = tile_n * hdim * 2 + sV = tile_n * hdimv * 2 + sdO = tile_m * hdimv * 2 * dO_stage + sPdS = tile_m * tile_n * 2 * PdS_stage + sP = sPdS if not mma_dkv_is_rs else 0 + sdQaccum = tile_m * hdim * 4 + smem = sQ + sK + sV + sdO + sP + sPdS + sdQaccum + if smem <= SMEM_LIMIT: + break + else: + return None + + # SMEM traffic + def _swap(a, b, s): + return (b, a) if s else (a, b) + + def _wg_n(al_m, s): + return al_m if s else num_wg // al_m + + M_s, N_s = _swap(tile_m, tile_n, SdP_swapAB) + wn_SdP = _wg_n(AtomLayoutMSdP, SdP_swapAB) + traffic_S = _mma_traffic(M_s, N_s, hdim, num_wg, wn_SdP) + traffic_dP = _mma_traffic(M_s, N_s, hdimv, num_wg, wn_SdP) + + wn_dKV = _wg_n(AtomLayoutNdKV, dKV_swapAB) + M_dv, N_dv = _swap(tile_n, hdimv, dKV_swapAB) + traffic_dV = _mma_traffic(M_dv, N_dv, tile_m, num_wg, wn_dKV, is_rs=mma_dkv_is_rs) + M_dk, N_dk = _swap(tile_n, hdim, dKV_swapAB) + traffic_dK = _mma_traffic(M_dk, N_dk, tile_m, num_wg, wn_dKV, is_rs=mma_dkv_is_rs) + + M_dq, N_dq = _swap(tile_m, hdim, dQ_swapAB) + wn_dQ = _wg_n(AtomLayoutMdQ, dQ_swapAB) + traffic_dQ = _mma_traffic(M_dq, N_dq, tile_n, num_wg, wn_dQ) + + traffic_P_store = tile_m * tile_n * 2 if not mma_dkv_is_rs else 0 + traffic_dS_store = tile_m * tile_n * 2 + traffic_dQ_smem = tile_m * hdim * 4 * 2 # store + TMA load + + smem_traffic = ( + traffic_S + + traffic_dP + + traffic_dV + + traffic_dK + + traffic_dQ + + traffic_P_store + + traffic_dS_store + + traffic_dQ_smem + ) + + return dict( + tile_m=tile_m, + tile_n=tile_n, + num_wg=num_wg, + Q_stage=Q_stage, + dO_stage=dO_stage, + PdS_stage=PdS_stage, + SdP_swapAB=SdP_swapAB, + dKV_swapAB=dKV_swapAB, + dQ_swapAB=dQ_swapAB, + AtomLayoutMSdP=AtomLayoutMSdP, + AtomLayoutNdKV=AtomLayoutNdKV, + AtomLayoutMdQ=AtomLayoutMdQ, + mma_dkv_is_rs=mma_dkv_is_rs, + regs_SdP=regs_SdP, + regs_dK=regs_dK, + regs_dV=regs_dV, + regs_dQ=regs_dQ, + total_regs=total_regs, + reg_limit=reg_limit, + smem_bytes=smem, + smem_kb=smem / 1024, + smem_traffic=smem_traffic, + smem_traffic_kb=smem_traffic / 1024, + smem_traffic_per_block=smem_traffic / (tile_m * tile_n), + ) + + +def find_feasible_bwd_configs( + head_dim, + head_dim_v=None, + tile_m_choices=(64, 80, 96, 112, 128), + tile_n_choices=(64, 80, 96, 112, 128), +): + if head_dim_v is None: + head_dim_v = head_dim + hdim = int(math.ceil(head_dim / 32) * 32) + hdimv = int(math.ceil(head_dim_v / 32) * 32) + + results = [] + for num_wg in (2, 3): + divs = _divisors(num_wg) + for tile_m in tile_m_choices: + for tile_n in tile_n_choices: + for SdP_swap in (False, True): + if (tile_n if SdP_swap else tile_m) % 64 != 0: + continue + for dKV_swap in (False, True): + if not dKV_swap and tile_n % 64 != 0: + continue + if dKV_swap and (hdim % 64 != 0 or hdimv % 64 != 0): + continue + for dQ_swap in (False, True): + if (hdim if dQ_swap else tile_m) % 64 != 0: + continue + for a1 in divs: + for a2 in divs: + for a3 in divs: + cfg = _check_bwd_config( + hdim, + hdimv, + tile_m, + tile_n, + num_wg, + SdP_swap, + dKV_swap, + dQ_swap, + a1, + a2, + a3, + ) + if cfg is not None: + results.append(cfg) + + results.sort(key=lambda c: (-c["tile_n"], -c["tile_m"], c["smem_traffic_per_block"])) + return results + + +def print_bwd_configs(configs, max_results=20): + if not configs: + print("No feasible configs found!") + return + n = min(len(configs), max_results) + print(f"Found {len(configs)} feasible configs (showing top {n}):\n") + hdr = ( + f"{'wg':>2} {'tm':>3} {'tn':>3} " + f"{'SdP':>3} {'dKV':>3} {'dQ':>3} " + f"{'aSdP':>4} {'adKV':>4} {'adQ':>4} " + f"{'Qs':>2} {'dOs':>3} " + f"{'rS':>3} {'rdK':>3} {'rdV':>3} {'rdQ':>3} {'tot':>4}/{'':<3} " + f"{'smem':>5} {'traffic':>7} {'tr/blk':>6}" + ) + print(hdr) + print("-" * len(hdr)) + B = lambda b: "T" if b else "F" + for c in configs[:max_results]: + print( + f"{c['num_wg']:>2} {c['tile_m']:>3} {c['tile_n']:>3} " + f"{B(c['SdP_swapAB']):>3} {B(c['dKV_swapAB']):>3} {B(c['dQ_swapAB']):>3} " + f"{c['AtomLayoutMSdP']:>4} {c['AtomLayoutNdKV']:>4} {c['AtomLayoutMdQ']:>4} " + f"{c['Q_stage']:>2} {c['dO_stage']:>3} " + f"{c['regs_SdP']:>3} {c['regs_dK']:>3} {c['regs_dV']:>3} {c['regs_dQ']:>3} " + f"{c['total_regs']:>4}/{c['reg_limit']:<3} " + f"{c['smem_kb']:>4.0f}K " + f"{c['smem_traffic_kb']:>6.0f}K " + f"{c['smem_traffic_per_block']:>6.1f}" + ) + + +# ============================================================================ +# Forward +# ============================================================================ + + +def _check_fwd_config(hdim, hdimv, tile_n, num_wg, pv_is_rs, overlap_wg): + reg_limit = REG_LIMITS[num_wg] + tile_m = num_wg * 64 + + if tile_n % 8 != 0: + return None + + regs_S = _acc_regs(tile_m, tile_n, num_wg) + regs_O = _acc_regs(tile_m, hdimv, num_wg) + regs_P = regs_S // 2 # bf16 = half of f32 + + if overlap_wg: + total_regs = regs_S + regs_P + regs_O + else: + total_regs = regs_S + regs_O + + if total_regs > reg_limit: + return None + + # SMEM: 1 stage Q, 2 stages K/V, O overlaps Q, sP if not RS + sQ = tile_m * hdim * 2 + sK = tile_n * hdim * 2 * 2 + sV = tile_n * hdimv * 2 * 2 + sO = tile_m * hdimv * 2 + sP = tile_m * tile_n * 2 if not pv_is_rs else 0 + smem = max(sQ, sO) + sK + sV + sP + if smem > SMEM_LIMIT: + return None + + # SMEM traffic: num_instr = num_wg (all WGs in M, wg_n=1) + traffic_S = num_wg * (64 * hdim * 2 + tile_n * hdim * 2) + A_pv = 64 * tile_n * 2 if not pv_is_rs else 0 + traffic_O = num_wg * (A_pv + hdimv * tile_n * 2) + traffic_P_store = tile_m * tile_n * 2 if not pv_is_rs else 0 + smem_traffic = traffic_S + traffic_O + traffic_P_store + + return dict( + tile_m=tile_m, + tile_n=tile_n, + num_wg=num_wg, + pv_is_rs=pv_is_rs, + overlap_wg=overlap_wg, + regs_S=regs_S, + regs_O=regs_O, + regs_P=regs_P, + total_regs=total_regs, + reg_limit=reg_limit, + smem_bytes=smem, + smem_kb=smem / 1024, + smem_traffic=smem_traffic, + smem_traffic_kb=smem_traffic / 1024, + smem_traffic_per_block=smem_traffic / (tile_m * tile_n), + ) + + +def find_feasible_fwd_configs( + head_dim, head_dim_v=None, tile_n_choices=(64, 80, 96, 112, 128, 144, 160, 176, 192) +): + if head_dim_v is None: + head_dim_v = head_dim + hdim = int(math.ceil(head_dim / 32) * 32) + hdimv = int(math.ceil(head_dim_v / 32) * 32) + + results = [] + for num_wg in (2, 3): + for tile_n in tile_n_choices: + for pv_is_rs in (True, False): + for overlap_wg in (True, False): + cfg = _check_fwd_config(hdim, hdimv, tile_n, num_wg, pv_is_rs, overlap_wg) + if cfg is not None: + results.append(cfg) + + results.sort(key=lambda c: (-c["tile_n"], c["smem_traffic_per_block"])) + return results + + +def print_fwd_configs(configs, max_results=20): + if not configs: + print("No feasible configs found!") + return + n = min(len(configs), max_results) + print(f"Found {len(configs)} feasible configs (showing top {n}):\n") + hdr = ( + f"{'wg':>2} {'tm':>3} {'tn':>3} " + f"{'RS':>2} {'olap':>4} " + f"{'rS':>3} {'rP':>3} {'rO':>3} {'tot':>4}/{'':<3} " + f"{'smem':>5} {'traffic':>7} {'tr/blk':>6}" + ) + print(hdr) + print("-" * len(hdr)) + B = lambda b: "T" if b else "F" + for c in configs[:max_results]: + print( + f"{c['num_wg']:>2} {c['tile_m']:>3} {c['tile_n']:>3} " + f"{B(c['pv_is_rs']):>2} {B(c['overlap_wg']):>4} " + f"{c['regs_S']:>3} {c['regs_P']:>3} {c['regs_O']:>3} " + f"{c['total_regs']:>4}/{c['reg_limit']:<3} " + f"{c['smem_kb']:>4.0f}K " + f"{c['smem_traffic_kb']:>6.0f}K " + f"{c['smem_traffic_per_block']:>6.1f}" + ) + + +# ============================================================================ +# CLI +# ============================================================================ + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Search feasible SM90 MMA configs") + parser.add_argument("--mode", choices=["fwd", "bwd", "both"], default="both") + parser.add_argument( + "--headdim", type=str, default="128", help="Head dim, or hdim-hdimv (e.g. 192-128)" + ) + parser.add_argument("--tile-m", type=str, default="64,80,96,112,128", help="Bwd tile_m choices") + parser.add_argument( + "--tile-n", + type=str, + default=None, + help="tile_n choices (default: fwd up to 192, bwd up to 128)", + ) + parser.add_argument("-n", "--num-results", type=int, default=30) + args = parser.parse_args() + + parts = args.headdim.split("-") + hdim = int(parts[0]) + hdimv = int(parts[1]) if len(parts) > 1 else hdim + + TN_FWD = "64,80,96,112,128,144,160,176,192" + TN_BWD = "64,80,96,112,128" + + if args.mode in ("fwd", "both"): + tn = tuple(int(x) for x in (args.tile_n or TN_FWD).split(",")) + print(f"=== FWD configs: hdim={hdim}, hdimv={hdimv} ===\n") + print_fwd_configs(find_feasible_fwd_configs(hdim, hdimv, tn), args.num_results) + print() + + if args.mode in ("bwd", "both"): + tm = tuple(int(x) for x in args.tile_m.split(",")) + tn = tuple(int(x) for x in (args.tile_n or TN_BWD).split(",")) + print(f"=== BWD configs: hdim={hdim}, hdimv={hdimv} ===\n") + print_bwd_configs(find_feasible_bwd_configs(hdim, hdimv, tm, tn), args.num_results) diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index 354a2097cbe..eed55a0b721 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -11,7 +11,7 @@ from quack import layout_utils import flash_attn.cute.utils as utils -from flash_attn.cute.cute_dsl_utils import ParamsBase +from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.seqlen_info import SeqlenInfoQK @@ -239,10 +239,9 @@ def apply_exp2_convert( self, acc_S_row: cute.Tensor, acc_S_row_converted: cute.Tensor, - e2e: cutlass.Constexpr[bool] = False, - e2e_freq: cutlass.Constexpr[int] = 16, - e2e_res: cutlass.Constexpr[int] = 4, - e2e_frg_limit: cutlass.Constexpr[int] = 1, + ex2_emu_freq: cutlass.Constexpr[int] = 0, + ex2_emu_res: cutlass.Constexpr[int] = 4, + ex2_emu_start_frg: cutlass.Constexpr[int] = 0, ): assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" frg_tile = 32 @@ -257,12 +256,14 @@ def apply_exp2_convert( for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): # acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) # acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True) - if cutlass.const_expr(not e2e): + if cutlass.const_expr(ex2_emu_freq == 0): acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True) else: if cutlass.const_expr( - k % e2e_freq < e2e_freq - e2e_res or j >= frg_cnt - e2e_frg_limit + k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res + or j >= frg_cnt - 1 + or j < ex2_emu_start_frg ): acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) acc_S_row_frg[k + 1, j] = cute.math.exp2( diff --git a/flash_attn/cute/testing.py b/flash_attn/cute/testing.py index 2897e64fc3d..6e3c40eb451 100644 --- a/flash_attn/cute/testing.py +++ b/flash_attn/cute/testing.py @@ -1,9 +1,13 @@ import math +from contextlib import nullcontext +from functools import wraps from typing import Optional import torch import torch.nn.functional as F from einops import rearrange, repeat +from torch._guards import active_fake_mode +from torch._subclasses.fake_tensor import FakeTensorMode class IndexFirstAxis(torch.autograd.Function): @@ -63,8 +67,15 @@ def unpad_input(hidden_states, attention_mask, unused_mask=None): all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() + in_fake_mode = active_fake_mode() is not None + if not in_fake_mode: + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + else: + # torch.nonzero and .item() are not supported in FakeTensorMode + batch_size, seqlen = attention_mask.shape + indices = torch.arange(batch_size * seqlen, device=hidden_states.device) + max_seqlen_in_batch = seqlen cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), @@ -421,3 +432,25 @@ def attention_ref( if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +def maybe_fake_tensor_mode(fake: bool = True): + """ + One way to populate/pre-compile cache is to use torch fake tensor mode, + which does not allocate actual GPU tensors but retains tensor shape/dtype + metadata for cute.compile. + """ + + def decorator(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + with FakeTensorMode() if fake else nullcontext(): + return fn(*args, **kwargs) + + return wrapper + + return decorator + + +def is_fake_mode() -> bool: + return active_fake_mode() is not None diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 36a5c6b75ec..3ee4bc8bab1 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -1,7 +1,8 @@ # Copyright (c) 2025, Tri Dao. -from typing import Optional, Tuple -from dataclasses import dataclass, fields +from enum import IntEnum, auto +from typing import Optional, Tuple, Protocol, runtime_checkable +from dataclasses import dataclass try: from typing import override @@ -9,13 +10,78 @@ from typing_extensions import override import cutlass +from cutlass.pipeline import PipelineClcFetchAsync, PipelineState from cutlass._mlir import ir import cutlass.cute as cute from cutlass import Int32, const_expr +from cutlass.cute import FastDivmodDivisor +from cutlass.utils import ClcDynamicPersistentTileScheduler, ClcDynamicPersistentTileSchedulerParams + +from quack.cute_dsl_utils import ParamsBase import flash_attn.cute.utils as utils from flash_attn.cute.fast_math import clz -from cutlass.cute import FastDivmodDivisor + + +class SchedulingMode(IntEnum): + NONE = auto() + STATIC = auto() + DYNAMIC = auto() + CLC = auto() + + +@dataclass +class ClcState(ParamsBase): + """Owns the runtime state shared by CLC-capable tile schedulers. + + `FlashAttentionForwardSm100` constructs this state because it owns the CLC + response buffer, mbarrier storage, and launch geometry needed to initialize + the hardware scheduler and async pipeline. Individual tile schedulers then + consume this state and map the returned hardware work tiles into their own + logical `WorkTileInfo` coordinates. + + To add CLC support to a scheduler: + - implement `clc_problem_shape(params)` so the kernel can create the hardware scheduler + - accept `clc: ClcState | None` in `create(...)` / `__init__` + - map `clc.initial_work_tile_info()` and `clc.get_current_work()` into scheduler coordinates + """ + + _hw_scheduler: ClcDynamicPersistentTileScheduler + _pipeline: PipelineClcFetchAsync + _consumer_state: PipelineState + _producer_state: PipelineState + + @staticmethod + def create( + *, + hw_scheduler: ClcDynamicPersistentTileScheduler, + pipeline: PipelineClcFetchAsync, + consumer_state: PipelineState, + producer_state: PipelineState, + ) -> "ClcState": + return ClcState(hw_scheduler, pipeline, consumer_state, producer_state) + + def initial_work_tile_info(self): + return self._hw_scheduler.initial_work_tile_info() + + def get_current_work(self): + return self._hw_scheduler.get_current_work() + + def prefetch_next_work(self, *, loc=None, ip=None): + self._pipeline.producer_acquire(self._producer_state, loc=loc, ip=ip) + mbarrier_addr = self._pipeline.producer_get_barrier(self._producer_state, loc=loc, ip=ip) + self._hw_scheduler.advance_to_next_work(mbarrier_addr, loc=loc, ip=ip) + self._producer_state.advance(loc=loc, ip=ip) + + def consumer_wait(self, *, loc=None, ip=None): + self._pipeline.consumer_wait(self._consumer_state, loc=loc, ip=ip) + + def consumer_release(self, *, loc=None, ip=None): + self._pipeline.consumer_release(self._consumer_state, loc=loc, ip=ip) + self._consumer_state.advance(loc=loc, ip=ip) + + def producer_tail(self, *, loc=None, ip=None): + self._pipeline.producer_tail(self._producer_state, loc=loc, ip=ip) class WorkTileInfo(cutlass.utils.WorkTileInfo): @@ -29,28 +95,45 @@ def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": return WorkTileInfo(new_tile_idx, new_is_valid_tile) -@dataclass -class ParamsBase: - def __extract_mlir_values__(self): - all_fields = [getattr(self, field.name) for field in fields(self)] - non_constexpr_fields = [f for f in all_fields if not isinstance(f, cutlass.Constexpr)] - values, self._values_pos = [], [] - for obj in non_constexpr_fields: - obj_values = cutlass.extract_mlir_values(obj) - values += obj_values - self._values_pos.append(len(obj_values)) - return values +@runtime_checkable +class TileSchedulerProtocol(Protocol): + """Protocol defining the interface all tile schedulers must implement. - def __new_from_mlir_values__(self, values): - all_fields = {field.name: getattr(self, field.name) for field in fields(self)} - constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, cutlass.Constexpr)} - non_constexpr_fields = { - n: f for n, f in all_fields.items() if not isinstance(f, cutlass.Constexpr) - } - for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): - non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) - values = values[n_items:] - return self.__class__(**non_constexpr_fields, **constexpr_fields) + Schedulers are responsible for: + 1. Coordinate mapping: linear tile index -> (m_block, head, batch, split) + 2. Work distribution: how to get the next tile (static grid-stride vs CLC dynamic) + """ + + def get_current_work(self) -> WorkTileInfo: + """Get the current work tile coordinates.""" + ... + + def initial_work_tile_info(self) -> WorkTileInfo: + """Get the initial work tile for this CTA.""" + ... + + def advance_to_next_work(self, *, loc=None, ip=None): + """Consumer-side advance: move to next tile and return it. + + For static schedulers: grid-stride increment + get_current_work. + For CLC schedulers: consumer wait + get_current_work + consumer release + state advance. + """ + ... + + def prefetch_next_work(self, *, loc=None, ip=None) -> None: + """Producer-side prefetch of next work tile (no-op for static schedulers). + + For CLC schedulers: producer acquire + issue CLC query + producer state advance. + Only called by the scheduler warp. + """ + ... + + def producer_tail(self, *, loc=None, ip=None) -> None: + """Producer-side cleanup after the last tile. + + No-op for static schedulers. For CLC schedulers: pipeline producer_tail. + """ + ... @dataclass @@ -73,6 +156,7 @@ class TileSchedulerArguments(ParamsBase): lpt: cutlass.Constexpr[bool] = False is_split_kv: cutlass.Constexpr[bool] = False head_swizzle: cutlass.Constexpr[bool] = False + use_cluster_idx: cutlass.Constexpr[bool] = False class SingleTileScheduler: @@ -85,6 +169,7 @@ class Params(ParamsBase): num_splits_divmod: FastDivmodDivisor is_split_kv: cutlass.Constexpr[bool] = False cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + use_cluster_idx: cutlass.Constexpr[bool] = False @staticmethod def create( @@ -98,6 +183,7 @@ def create( FastDivmodDivisor(args.num_splits), args.is_split_kv, args.cluster_shape_mn, + args.use_cluster_idx, ) def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): @@ -108,12 +194,26 @@ def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"SingleTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) return SingleTileScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod - def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler": - blk_coord = cute.arch.block_idx() + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileScheduler": + if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx): + blk_coord = cute.arch.block_idx() + else: + blk_coord = cute.arch.cluster_idx() return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip) # called by host @@ -126,8 +226,13 @@ def get_grid_shape( ) -> Tuple[Int32, Int32, Int32]: # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1) assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + if const_expr(params.use_cluster_idx): + # Grid must have num_block * cluster_m physical blocks so that there are num_block clusters + grid_x = params.num_block * params.cluster_shape_mn[0] + else: + grid_x = cute.round_up(params.num_block, params.cluster_shape_mn[0]) return ( - cute.round_up(params.num_block, params.cluster_shape_mn[0]), + grid_x, params.num_head * params.num_splits, params.num_batch, ) @@ -151,6 +256,10 @@ def prefetch_next_work(self, *, loc=None, ip=None): def advance_to_next_work(self, *, loc=None, ip=None): self._is_first_block = False + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass def __extract_mlir_values__(self): values, self._values_pos = [], [] @@ -171,17 +280,22 @@ def __new_from_mlir_values__(self, values): class StaticPersistentTileScheduler: @dataclass class Params(ParamsBase): - num_block_divmod: FastDivmodDivisor + num_block_cluster_divmod: FastDivmodDivisor num_head_divmod: FastDivmodDivisor - total_blocks: Int32 + total_blocks_cluster: Int32 + cluster_shape_m: cutlass.Constexpr[int] = 1 @staticmethod def create( args: TileSchedulerArguments, *, loc=None, ip=None ) -> "StaticPersistentTileScheduler.Params": - total_blocks = args.num_block * args.num_head * args.num_batch + num_block_cluster = cute.ceil_div(args.num_block, cute.size(args.cluster_shape_mn)) + total_blocks_cluster = num_block_cluster * args.num_head * args.num_batch return StaticPersistentTileScheduler.Params( - FastDivmodDivisor(args.num_block), FastDivmodDivisor(args.num_head), total_blocks + FastDivmodDivisor(num_block_cluster), + FastDivmodDivisor(args.num_head), + total_blocks_cluster, + cluster_shape_m=args.cluster_shape_mn[0], ) def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): @@ -191,15 +305,28 @@ def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"StaticPersistentTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod - def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler": - tile_idx = cute.arch.block_idx()[0] + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler": + if const_expr(cute.size(params.cluster_shape_m) == 1): + tile_idx = cute.arch.block_idx()[0] + else: + tile_idx = cute.arch.cluster_idx()[0] return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip) - # called by host @staticmethod def get_grid_shape( params: Params, @@ -209,15 +336,14 @@ def get_grid_shape( ) -> Tuple[Int32, Int32, Int32]: hardware_info = cutlass.utils.HardwareInfo() sm_count = hardware_info.get_device_multiprocessor_count() - return (cutlass.min(sm_count, params.total_blocks), Int32(1), Int32(1)) + max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m + grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * params.cluster_shape_m) + return (grid_x, Int32(1), Int32(1)) - # @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: - hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_divmod) + hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod) batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod) - is_valid = self._tile_idx < self.params.total_blocks - # if cute.arch.thread_idx()[0] == 0: - # cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid) + is_valid = self._tile_idx < self.params.total_blocks_cluster return WorkTileInfo( (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid ) @@ -229,7 +355,14 @@ def prefetch_next_work(self, *, loc=None, ip=None): pass def advance_to_next_work(self, *, loc=None, ip=None): - self._tile_idx += cute.arch.grid_dim()[0] + if const_expr(self.params.cluster_shape_m == 1): + self._tile_idx += cute.arch.grid_dim()[0] + else: + self._tile_idx += cute.arch.cluster_dim()[0] + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass def __extract_mlir_values__(self): values, self._values_pos = [], [] @@ -256,32 +389,41 @@ class Params(ParamsBase): total_blocks: Int32 num_splits: Int32 num_block: Int32 + num_head: Int32 + num_batch: Int32 l2_minor: Int32 - num_block_divmod: FastDivmodDivisor num_head_divmod: FastDivmodDivisor l2_minor_divmod: FastDivmodDivisor l2_major_divmod: FastDivmodDivisor l2_minor_residual_divmod: FastDivmodDivisor num_hb_quotient: Int32 + num_splits_divmod: FastDivmodDivisor is_split_kv: cutlass.Constexpr[bool] = False + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + lpt: cutlass.Constexpr[bool] = True @staticmethod @cute.jit def create( - args: TileSchedulerArguments, *, loc=None, ip=None + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, ) -> "SingleTileLPTScheduler.Params": - # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size) + assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( + f"Only STATIC and CLC are supported, got {scheduling_mode!r}" + ) size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size size_one_head = size_one_kv_head size_l2 = 50 * 1024 * 1024 # 40 MB for K & V # Swizzle is the size of each "section". Round swizzle to a power of 2 # Need to be careful about the case where only one head will fit # swizzle is how many heads can fit in L2 - # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) - # Seems faster if swizzle if a power of 2 + # Seems faster if swizzle is a power of 2 log2_floor = lambda n: 31 - clz(n) swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) - # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) # If we're in the last section (called residual), we don't want to divide by # swizzle. Instead we want to divide by the remainder. num_hb_quotient = (args.num_head * args.num_batch) // swizzle @@ -289,37 +431,84 @@ def create( return SingleTileLPTScheduler.Params( total_blocks=args.num_block * args.num_head * args.num_batch, num_block=args.num_block, + num_head=args.num_head, + num_batch=args.num_batch, l2_minor=Int32(swizzle), - num_block_divmod=FastDivmodDivisor(args.num_block), num_head_divmod=FastDivmodDivisor(args.num_head), l2_minor_divmod=FastDivmodDivisor(swizzle), l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block), - l2_minor_residual_divmod=FastDivmodDivisor( - max(num_hb_remainder, 1) - ), # don't divide by 0 + l2_minor_residual_divmod=FastDivmodDivisor(max(num_hb_remainder, 1)), num_hb_quotient=Int32(num_hb_quotient), num_splits=args.num_splits, + num_splits_divmod=FastDivmodDivisor(args.num_splits), is_split_kv=args.is_split_kv, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + lpt=args.lpt, ) - def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None): + def __init__( + self, + params: Params, + tile_idx: Int32, + split_idx: Int32, + clc: ClcState | None = None, + *, + loc=None, + ip=None, + ): self.params = params self._tile_idx = tile_idx self._split_idx = split_idx + self.clc = clc self._loc = loc self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: - return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip) + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + return SingleTileLPTScheduler.Params.create( + args, scheduling_mode=scheduling_mode, loc=loc, ip=ip + ) + + @staticmethod + def _clc_grid_shape(params: Params): + num_batch_splits = ( + params.num_batch * params.num_splits + if const_expr(params.is_split_kv) + else params.num_batch + ) + return ( + cute.round_up(params.num_block, params.cluster_shape_m), + params.num_head, + num_batch_splits, + ) + + @staticmethod + @cute.jit + def clc_problem_shape(params: Params): + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=SingleTileLPTScheduler._clc_grid_shape(params), + cluster_shape_mnk=(params.cluster_shape_m, 1, 1), + ) @staticmethod @cute.jit - def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler": + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileLPTScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler( + params, cute.arch.block_idx()[0], Int32(0), clc, loc=loc, ip=ip + ) tile_idx, split_idx, _ = cute.arch.block_idx() return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) - # called by host @staticmethod def get_grid_shape( params: Params, @@ -327,10 +516,40 @@ def get_grid_shape( loc=None, ip=None, ) -> Tuple[Int32, Int32, Int32]: + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler._clc_grid_shape(params) return (params.total_blocks, params.num_splits, Int32(1)) + @cute.jit + def clc_work_to_coords(self, work) -> WorkTileInfo: + """Convert CLC response (block, head, batch_split) to WorkTileInfo. + + CLC returns raw grid coordinates — no L2 swizzle (hardware decides order). + We only apply cluster division, optional LPT block reversal, and split_kv unpacking. + """ + block_idx = work.tile_idx[0] + if const_expr(self.params.cluster_shape_m > 1): + block_idx = block_idx // self.params.cluster_shape_m + if const_expr(self.params.lpt): + # Longest-processing-time-first: reverse block order + block_idx = self.params.num_block - 1 - block_idx + split_idx = Int32(0) + if const_expr(self.params.is_split_kv): + batch_idx, split_idx = divmod(work.tile_idx[2], self.params.num_splits_divmod) + else: + batch_idx = work.tile_idx[2] + return WorkTileInfo( + (Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)), + work.is_valid_tile, + ) + @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.get_current_work() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) + # Static path: L2-swizzled coordinate mapping params = self.params # Implement LPT scheduling coordinate calculation bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod) @@ -344,25 +563,45 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: bidhb_actual = bidhb * params.l2_minor + bidhb_residual batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) # Longest-processing-time-first - block = params.num_block - 1 - block + if const_expr(params.lpt): + block = params.num_block - 1 - block is_valid = self._tile_idx < params.total_blocks return WorkTileInfo( (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid ) + @cute.jit def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.initial_work_tile_info() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) return self.get_current_work(loc=loc, ip=ip) def prefetch_next_work(self, *, loc=None, ip=None): - pass + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.prefetch_next_work(loc=loc, ip=ip) def advance_to_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work() + self.clc.consumer_release(loc=loc, ip=ip) + return work # Single tile scheduler - set to invalid tile_idx to indicate no more work self._tile_idx = self.params.total_blocks + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.producer_tail(loc=loc, ip=ip) def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [self.params, self._tile_idx, self._split_idx]: + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj in objs: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -370,10 +609,13 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos): + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj, n_items in zip(objs, self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] - return self.__class__(*(tuple(obj_list)), loc=self._loc) + return self.__class__(*obj_list, loc=self._loc) class SingleTileLPTBwdScheduler: @@ -397,8 +639,8 @@ def create( ) -> "SingleTileLPTBwdScheduler.Params": size_l2 = 50 * 1024 * 1024 size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size - # size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4 - size_one_dqaccum_head = 0 + size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4 + # size_one_dqaccum_head = 0 size_one_head = size_one_qdo_head + size_one_dqaccum_head log2_floor = lambda n: 31 - clz(n) swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) @@ -432,7 +674,16 @@ def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"SingleTileLPTBwdScheduler only supports STATIC, got {scheduling_mode!r}" + ) return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod @@ -466,11 +717,12 @@ def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod) bidhb_actual = bidhb * params.l2_minor + bidhb_residual batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) - is_valid = self._tile_idx < params.total_blocks - bidx_in_cluster = cute.arch.block_in_cluster_idx() - block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0] if cutlass.const_expr(params.spt): block = params.num_block - 1 - block + if cutlass.const_expr(params.cluster_shape_mn[0] > 1): + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0] + is_valid = self._tile_idx < params.total_blocks return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid) def initial_work_tile_info(self, *, loc=None, ip=None): @@ -482,6 +734,7 @@ def prefetch_next_work(self, *, loc=None, ip=None): def advance_to_next_work(self, *, loc=None, ip=None): # Single tile scheduler - set to invalid tile_idx to indicate no more work self._tile_idx = self.params.total_blocks + return self.get_current_work() def __extract_mlir_values__(self): values, self._values_pos = [], [] @@ -514,19 +767,39 @@ class Params(ParamsBase): lpt: cutlass.Constexpr[bool] = False is_split_kv: cutlass.Constexpr[bool] = False head_swizzle: cutlass.Constexpr[bool] = False + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC @staticmethod @cute.jit def create( - args: TileSchedulerArguments, *, loc=None, ip=None + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, ) -> "SingleTileVarlenScheduler.Params": + assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( + f"Only STATIC and CLC are supported, got {scheduling_mode!r}" + ) size_l2 = 50 * 1024 * 1024 # 50 MB for K & V - max_kvblock_in_l2 = size_l2 // ( + # if backward, this is qdo block size + kv_block_size = ( (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] ) + # if backward, add dqaccum block size to calculate swizzle + if args.head_swizzle: + kv_block_size += args.headdim * 4 * args.tile_shape_mn[1] + max_kvblock_in_l2 = size_l2 // kv_block_size assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" ) + assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + # TODO: Support varlen CLC with cluster_shape_m > 1 by refactoring the + # flattened-tile decode so cluster unpacking semantics are explicit. + assert scheduling_mode != SchedulingMode.CLC or args.cluster_shape_mn[0] == 1, ( + "Varlen CLC currently requires cluster_shape_mn[0] == 1" + ) return SingleTileVarlenScheduler.Params( num_head=args.num_head, num_batch=args.num_batch, @@ -540,22 +813,66 @@ def create( lpt=args.lpt, is_split_kv=args.is_split_kv, head_swizzle=args.head_swizzle, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, ) - def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None): + def __init__( + self, + params: Params, + tile_idx: Int32, + split_idx: Int32, + clc: ClcState | None = None, + *, + loc=None, + ip=None, + ): self.params = params self._tile_idx = tile_idx self._split_idx = split_idx self._is_first_block = True + self.clc = clc self._loc = loc self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: - return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip) + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + return SingleTileVarlenScheduler.Params.create( + args, scheduling_mode=scheduling_mode, loc=loc, ip=ip + ) @staticmethod - def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler": + @cute.jit + def clc_problem_shape(params: Params): + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=SingleTileVarlenScheduler.get_grid_shape(params), + cluster_shape_mnk=(1, 1, 1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileVarlenScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + block_idx = cute.arch.block_idx() + split_idx = Int32(0) + if const_expr(params.is_split_kv): + split_idx = block_idx[1] + return SingleTileVarlenScheduler( + params, + block_idx[0], + split_idx, + clc, + loc=loc, + ip=ip, + ) tile_idx, split_idx, _ = cute.arch.block_idx() return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) @@ -568,8 +885,11 @@ def get_grid_shape( ip=None, ) -> Tuple[Int32, Int32, Int32]: total_blocks_max = ( - params.total_q + params.num_batch * (params.tile_shape_mn[0] - 1) + params.total_q + + params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1) ) // params.tile_shape_mn[0] + # Round down to nearest multiple of cluster since odd excess is always padding. + total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m return (total_blocks_max * params.num_head, params.num_splits, Int32(1)) @cute.jit @@ -590,13 +910,14 @@ def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): seqlen *= params.qhead_per_kvhead_packgqa return ( - cute.ceil_div(seqlen, params.tile_shape_mn[0]) + cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), params.cluster_shape_m) if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 else Int32(0) ) @cute.jit - def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + def _varlen_coord_map(self) -> WorkTileInfo: + """Map self._tile_idx to (block, head, batch) via warp-level prefix sums.""" params = self.params lane_idx = cute.arch.lane_idx() num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0) @@ -607,7 +928,7 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: group_end_tile = m_blocks_in_group * params.num_head # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group) block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0) - next_tile_idx = self._tile_idx + next_tile_idx = self._tile_idx // params.cluster_shape_m while group_end_tile <= next_tile_idx: batch_idx += cute.arch.WARP_SIZE - 1 if batch_idx >= params.num_batch: @@ -649,6 +970,7 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: num_n_blocks = ( num_m_blocks * params.tile_shape_mn[0] + * params.cluster_shape_m // params.qhead_per_kvhead_packgqa // params.tile_shape_mn[1] ) @@ -686,23 +1008,69 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: head_idx = mh_block // num_m_blocks block = mh_block - head_idx * num_m_blocks is_valid = self._is_first_block and batch_idx < params.num_batch + if cutlass.const_expr(params.cluster_shape_m > 1): + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block = block * params.cluster_shape_m + bidx_in_cluster[0] # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0) return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid) + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + clc_work = self.clc.get_current_work() + # Default to grid_dim (one past last valid flat index) so _varlen_coord_map + # returns is_valid=False when CLC is exhausted. CLC tile_idx is garbage when + # invalid, so we can't trust it. Local-then-assign avoids CuTe DSL structural + # mismatch on self inside the runtime if. + new_tile_idx = cute.arch.grid_dim()[0] + new_split_idx = Int32(0) + if clc_work.is_valid_tile: + new_tile_idx = clc_work.tile_idx[0] + if const_expr(self.params.is_split_kv): + new_split_idx = clc_work.tile_idx[1] + self._tile_idx = new_tile_idx + self._split_idx = new_split_idx + return self._varlen_coord_map() + + @cute.jit def initial_work_tile_info(self, *, loc=None, ip=None): - return self.get_current_work(loc=loc, ip=ip) + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + clc_work = self.clc.initial_work_tile_info() + # See get_current_work for why grid_dim and local-then-assign. + new_tile_idx = cute.arch.grid_dim()[0] + new_split_idx = Int32(0) + if clc_work.is_valid_tile: + new_tile_idx = clc_work.tile_idx[0] + if const_expr(self.params.is_split_kv): + new_split_idx = clc_work.tile_idx[1] + self._tile_idx = new_tile_idx + self._split_idx = new_split_idx + return self._varlen_coord_map() def prefetch_next_work(self, *, loc=None, ip=None): - pass + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.prefetch_next_work(loc=loc, ip=ip) def advance_to_next_work(self, *, loc=None, ip=None): - # Single tile scheduler - set to invalid tile_idx to indicate no more work + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work() + self.clc.consumer_release(loc=loc, ip=ip) + return work self._is_first_block = False + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.producer_tail(loc=loc, ip=ip) def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [self.params, self._tile_idx, self._split_idx]: + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj in objs: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -710,10 +1078,10 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip( - [self.params, self._tile_idx, self._split_idx], - self._values_pos, - ): + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj, n_items in zip(objs, self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] - return SingleTileVarlenScheduler(*(tuple(obj_list)), loc=self._loc) + return self.__class__(*obj_list, loc=self._loc) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py index f2383e89415..31186618569 100644 --- a/flash_attn/cute/utils.py +++ b/flash_attn/cute/utils.py @@ -3,13 +3,14 @@ import math import hashlib import inspect -import re +import os from typing import Type, Callable, Optional, Tuple, overload import cutlass import cutlass.cute as cute -from cutlass import Float32, const_expr +from cutlass import Float32, Int32, const_expr +from cutlass.cute import FastDivmodDivisor from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir.dialects import nvvm, llvm from cutlass.cute.runtime import from_dlpack @@ -17,26 +18,58 @@ import quack.activation +_MIXER_ATTRS = ("__vec_size__",) -def hash_callable(func: Callable, set_cute_hash=True) -> str: - """Hash a callable based on the source code or bytecode and closure values. +# Obtained from sollya: +# fpminimax(exp(x * log(2.0)), 1, [|1,24...|],[0;1],relative); +POLY_EX2 = { + 0: (1.0), + 1: ( + 1.0, + 0.922497093677520751953125, + ), + 2: ( + 1.0, + 0.6657850742340087890625, + 0.330107033252716064453125, + ), + 3: ( + 1.0, + 0.695146143436431884765625, + 0.227564394474029541015625, + 0.077119089663028717041015625, + ), + 4: ( + 1.0, + 0.693042695522308349609375, + 0.2412912547588348388671875, + 5.2225358784198760986328125e-2, + 1.3434938155114650726318359375e-2, + ), + 5: ( + 1.0, + 0.693151414394378662109375, + 0.24016360938549041748046875, + 5.5802188813686370849609375e-2, + 9.01452265679836273193359375e-3, + 1.86810153536498546600341796875e-3, + ), +} - Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__`` - attribute, that value is returned immediately. Code-generation backends such - as Inductor can set this attribute to avoid expensive runtime hashing. +_fa_clc_enabled: bool = os.environ.get("FA_CLC", "0") == "1" +_fa_disable_2cta_enabled: bool = os.environ.get("FA_DISABLE_2CTA", "0") == "1" - set_cute_hash: whether or not to set func.__cute_hash__ if not present - """ - if hasattr(func, "__cute_hash__"): - return func.__cute_hash__ - # Unwrap decorated functions (e.g., cute.jit wrappers). - if hasattr(func, "__wrapped__"): - base_func = func.__wrapped__ - if hasattr(base_func, "__cute_hash__"): - return base_func.__cute_hash__ - func = base_func +def _get_use_clc_scheduler_default() -> bool: + return _fa_clc_enabled + + +def _get_disable_2cta_default() -> bool: + return _fa_disable_2cta_enabled + +def _compute_base_hash(func: Callable) -> str: + """Compute hash from source code or bytecode and closure values.""" try: data = inspect.getsource(func).encode() except (OSError, TypeError): @@ -48,16 +81,48 @@ def hash_callable(func: Callable, set_cute_hash=True) -> str: hasher = hashlib.sha256(data) if hasattr(func, "__closure__") and func.__closure__ is not None: - for idx, cell in enumerate(func.__closure__): - cell_value = cell.cell_contents - hasher.update(repr(cell_value).encode()) + for cell in func.__closure__: + hasher.update(repr(cell.cell_contents).encode()) + + return hasher.hexdigest() + + +def hash_callable( + func: Callable, mixer_attrs: Tuple[str] = _MIXER_ATTRS, set_cute_hash: bool = True +) -> str: + """Hash a callable based on the source code or bytecode and closure values. + Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__`` + attribute, that value is returned immediately as the base hash, then + metadata dunders are mixed in to produce the final dict-key hash. + set_cute_hash: whether or not to set func.__cute_hash__ + """ + # Resolve base hash + if hasattr(func, "__cute_hash__"): + base_hash = func.__cute_hash__ + else: + # Unwrap decorated functions (e.g., cute.jit wrappers). + base_func = getattr(func, "__wrapped__", func) - hash = hasher.hexdigest() + if hasattr(base_func, "__cute_hash__"): + base_hash = base_func.__cute_hash__ + else: + base_hash = _compute_base_hash(base_func) + + if set_cute_hash: + base_func.__cute_hash__ = base_hash + + # Mix in mutable metadata dunders + mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs) + + if all(v is None for v in mixer_values): + return base_hash + + hasher = hashlib.sha256(base_hash.encode()) - if set_cute_hash: - func.__cute_hash__ = hash + for attr, val in zip(_MIXER_ATTRS, mixer_values): + hasher.update(f"{attr}={val!r}".encode()) - return hash + return hasher.hexdigest() def create_softcap_scoremod(softcap_val): @@ -71,6 +136,40 @@ def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tenso return scoremod_premask_fn +LOG2_E = math.log2(math.e) + + +def compute_softmax_scale_log2(softmax_scale, score_mod): + """Compute softmax_scale_log2 and adjusted softmax_scale based on whether score_mod is used. + + When score_mod is None, fold the log2(e) factor into softmax_scale_log2 and set softmax_scale + to None. When score_mod is present, keep softmax_scale separate so it can be applied before + the score_mod, and set softmax_scale_log2 to just the change-of-base constant. + + Returns (softmax_scale_log2, softmax_scale). + """ + if const_expr(score_mod is None): + return softmax_scale * LOG2_E, None + else: + return LOG2_E, softmax_scale + + +def compute_fastdiv_mods(mQ, mK, qhead_per_kvhead, pack_gqa, aux_tensors, mPageTable=None): + """Compute FastDivmodDivisor pairs for aux_tensors index computation. + + Returns a (seqlen_q_divmod, seqlen_k_divmod) tuple, or None if aux_tensors is None. + """ + if const_expr(aux_tensors is None): + return None + seqlen_q = cute.size(mQ.shape[0]) // (qhead_per_kvhead if const_expr(pack_gqa) else 1) + seqlen_k = ( + cute.size(mK.shape[0]) + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1] + ) + return (FastDivmodDivisor(seqlen_q), FastDivmodDivisor(seqlen_k)) + + def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor: return ( from_dlpack(x, assumed_align=alignment) @@ -163,44 +262,51 @@ def warp_reduce( return val -def parse_swizzle_from_pointer(ptr: cute.Pointer) -> cute.Swizzle: - """Extract swizzle parameters from a pointer's swizzle_type. - - The swizzle_type string has the form '!cute.swizzle<"S">' where - b, m, s are the swizzle parameters (bits, base, shift). - - Returns: - A cute.Swizzle object constructed from the extracted parameters - - Raises: - ValueError: If the swizzle_type string cannot be parsed - """ - # Ideally there should be a better API to get swizzle parameters, but we'll just parse - # the string here. - swizzle_str = str(ptr.type.swizzle_type) - # Extract the inner part "S" - match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str) - if match: - b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3)) - return cute.make_swizzle(b, m, s) - else: - raise ValueError(f"Could not parse swizzle_type: {swizzle_str}") +@dsl_user_op +def smid(*, loc=None, ip=None) -> Int32: + return Int32( + llvm.inline_asm( + T.i32(), + [], + "mov.u32 $0, %smid;", + "=r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) @dsl_user_op def fmax( a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None ) -> Float32: - return Float32( - nvvm.fmax( - T.f32(), - Float32(a).ir_value(loc=loc, ip=ip), - Float32(b).ir_value(loc=loc, ip=ip), - c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, - loc=loc, - ip=ip, + from cutlass import CUDA_VERSION + + # * NVVM call based on nvvm version + if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9: + # Old API: requires explicit result type as first positional argument + return Float32( + nvvm.fmax( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + else: + # New API: infers result type automatically + return Float32( + nvvm.fmax( + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) ) - ) @cute.jit @@ -385,8 +491,48 @@ def shuffle_sync( return val[0] +@dsl_user_op +def shl_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + """ + Left-shift val by shift bits using PTX shl.b32 (sign-agnostic). + + Named ``shl_u32`` (not ``shl_b32``) because python type annotations + distinguish signed/unsigned. + + PTX semantics (§9.7.8.8): "Shift amounts greater than the register width N + are clamped to N." So ``shl.b32 d, a, 32`` is well-defined and yields 0. + + This differs from C/C++ and LLVM IR, where shifting by >= the type width is + undefined behavior. CuTeDSL compiles through MLIR -> LLVM IR, so a plain + Python-level ``Uint32(x) << Uint32(n)`` inherits LLVM's UB: the optimizer + may treat the result as poison and eliminate dependent code. Inline PTX + bypasses the LLVM IR shift entirely — the instruction is emitted verbatim + into PTX where clamping makes it safe for all shift amounts. + """ + return cutlass.Uint32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Uint32(val).ir_value(loc=loc, ip=ip), + cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), + ], + "shl.b32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + """ + Unsigned right-shift val by shift bits using PTX shr.u32 (zero-fills). + + See ``shl_u32`` docstring for why inline PTX is used instead of plain + CuTeDSL shift operators (LLVM shift-by-type-width UB). + """ return cutlass.Uint32( llvm.inline_asm( T.i32(), @@ -394,7 +540,7 @@ def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32(val).ir_value(loc=loc, ip=ip), cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), ], - "shr.s32 $0, $1, $2;", + "shr.u32 $0, $1, $2;", "=r,r,r", has_side_effects=False, is_align_stack=False, @@ -542,14 +688,9 @@ def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip= @dsl_user_op -def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32: +def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32: + assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported" # We assume x <= 127.0 - poly_ex2_deg3 = ( - 1.0, - 0.695146143436431884765625, - 0.227564394474029541015625, - 0.077119089663028717041015625, - ) fp32_round_int = float(2**23 + 2**22) x_clamped = cute.arch.fmax(x, -127.0) # We want to round down here, so that the fractional part is in [0, 1) @@ -558,20 +699,16 @@ def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32: # We assume the next 2 ops round to nearest even. The rounding mode is important. x_rounded_back = x_rounded - fp32_round_int x_frac = x_clamped - x_rounded_back - x_frac_ex2 = evaluate_polynomial(x_frac, poly_ex2_deg3, loc=loc, ip=ip) + x_frac_ex2 = evaluate_polynomial(x_frac, POLY_EX2[poly_degree], loc=loc, ip=ip) return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip) # TODO: check that the ex2_emulation_2 produces the same SASS as the ptx version @dsl_user_op -def ex2_emulation_2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: +def ex2_emulation_2( + x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None +) -> Tuple[Float32, Float32]: # We assume x <= 127.0 and y <= 127.0 - poly_ex2_deg3 = ( - 1.0, - 0.695146143436431884765625, - 0.227564394474029541015625, - 0.077119089663028717041015625, - ) fp32_round_int = float(2**23 + 2**22) xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0)) # We want to round down here, so that the fractional part is in [0, 1) @@ -582,7 +719,7 @@ def ex2_emulation_2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float xy_rounded, (fp32_round_int, fp32_round_int) ) xy_frac = quack.activation.sub_packed_f32x2(xy_clamped, xy_rounded_back) - xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, poly_ex2_deg3, loc=loc, ip=ip) + xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, POLY_EX2[poly_degree], loc=loc, ip=ip) x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip) y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip) return x_out, y_out diff --git a/tests/cute/benchmark_mask_mod.py b/tests/cute/benchmark_mask_mod.py index 0da0ddcfbd0..88b967b3645 100644 --- a/tests/cute/benchmark_mask_mod.py +++ b/tests/cute/benchmark_mask_mod.py @@ -14,7 +14,7 @@ import numpy as np import torch -from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90 +from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 from mask_mod_definitions import ( get_mask_pair, random_doc_id_tensor, diff --git a/tests/cute/conftest.py b/tests/cute/conftest.py index 6ee05e9a3a4..d2162255775 100644 --- a/tests/cute/conftest.py +++ b/tests/cute/conftest.py @@ -1,5 +1,11 @@ import os import subprocess +import logging +import tempfile +import json +import time +from pathlib import Path +from getpass import getuser def _get_gpu_ids(): @@ -16,16 +22,50 @@ def _get_gpu_ids(): ) if result.returncode == 0: return result.stdout.strip().splitlines() - except (FileNotFoundError, subprocess.TimeoutExpired): + except (FileNotFoundError,): pass + logging.warning("Failed to get gpu ids, use default '0'") return ["0"] def pytest_configure(config): + tmp = Path(tempfile.gettempdir()) / getuser() / "flash_attention_tests" + tmp.mkdir(parents=True, exist_ok=True) + worker_id = os.environ.get("PYTEST_XDIST_WORKER") + logging.basicConfig( + format=config.getini("log_file_format"), + filename=str(tmp / f"tests_{worker_id}.log"), + level=config.getini("log_file_level"), + ) if not worker_id: return worker_num = int(worker_id.replace("gw", "")) - gpu_ids = _get_gpu_ids() + + # cache gpu_ids, because nvidia-smi is expensive when we launch many workers doing torch initialization + # Always elect worker_0 to get gpu_ids. + cached_gpu_ids = tmp / "gpu_ids.json" + if worker_num == 0: + gpu_ids = _get_gpu_ids() + with cached_gpu_ids.open(mode="w") as f: + json.dump(gpu_ids, f) + else: + while not cached_gpu_ids.exists(): + time.sleep(1) + with cached_gpu_ids.open() as f: + gpu_ids = json.load(f) + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[worker_num % len(gpu_ids)] + +def pytest_collection_finish(session): + if not session.config.option.collectonly: + return + + # file_name -> test_name -> counter + test_counts: dict[str, dict[str, int]] = {} + for item in session.items: + funcname = item.function.__name__ + parent = test_counts.setdefault(item.parent.name, {}) + parent[funcname] = parent.setdefault(funcname, 0) + 1 + print(json.dumps(test_counts, indent=2)) diff --git a/tests/cute/score_mod_definitions.py b/tests/cute/score_mod_definitions.py index be6333a6448..aaa3664abf0 100644 --- a/tests/cute/score_mod_definitions.py +++ b/tests/cute/score_mod_definitions.py @@ -15,12 +15,28 @@ def score_mod_identity(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_t return tSrS_ssa +@cute.jit +def score_mod_identity_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + return tSrS_ssa + + @cute.jit def score_mod_causal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): mask = operator.ge(q_idx, kv_idx) return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) +@cute.jit +def score_mod_causal_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + mask = cute.make_rmem_tensor(kv_idx.shape, dtype=cutlass.Boolean) + kv_idx0 = kv_idx[0] + q_idx0 = q_idx[0] + for i in cutlass.range_constexpr(cute.size(mask.shape)): + mask[i] = q_idx0 >= kv_idx0 + i + mask_ssa = mask.load() + return cute.where(mask_ssa, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) + + @cute.jit def score_mod_rel_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): diff = q_idx - kv_idx @@ -28,6 +44,18 @@ def score_mod_rel_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_t return tSrS_ssa + abs_diff.to(cutlass.Float32) +@cute.jit +def score_mod_rel_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + q_idx0 = q_idx[0] + kv_idx0 = kv_idx[0] + diff0 = q_idx0 - kv_idx0 + abs_diff = cute.make_rmem_tensor(kv_idx.shape, dtype=diff0.dtype) + for i in cutlass.range_constexpr(cute.size(kv_idx.shape)): + diffi = diff0 - i + abs_diff[i] = mlir_math.absi(diffi) + return tSrS_ssa + abs_diff.load().to(cutlass.Float32) + + @cute.jit def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): diff = q_idx - kv_idx @@ -36,10 +64,25 @@ def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, au return tSrS_ssa + scaled.to(cutlass.Float32) +@cute.jit +def score_mod_rel_bias_x2_vectorized( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + q_idx0 = q_idx[0] + kv_idx0 = kv_idx[0] + diff0 = q_idx0 - kv_idx0 + abs_diff_x2 = cute.make_rmem_tensor(kv_idx.shape, dtype=diff0.dtype) + for i in cutlass.range_constexpr(cute.size(kv_idx.shape)): + diffi = diff0 - i + abs_diff_x2[i] = mlir_math.absi(diffi) * 2 + return tSrS_ssa + abs_diff_x2.load().to(cutlass.Float32) + + @cute.jit def score_mod_times_two(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): return tSrS_ssa * cute.full_like(tSrS_ssa, 2) +score_mod_times_two_vectorized = score_mod_times_two @cute.jit def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): @@ -53,6 +96,21 @@ def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tens abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype).to(cutlass.Float32) return score - slope * abs_diff +@cute.jit +def score_mod_alibi_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + score = tSrS_ssa.to(cutlass.Float32) + slope_exp = (h_idx + cute.full_like(h_idx, 1)) * cute.full_like(h_idx, -8) + slope = cute.math.exp2( + slope_exp.to(cutlass.Float32) + * cute.full_like(score, 0.125 * 0.6931471805599453 * 1.4426950408889634) + ) + diff0 = q_idx[0] - kv_idx[0] + abs_diff = cute.make_rmem_tensor(kv_idx.shape, diff0.dtype) + for i in cutlass.range_constexpr(cute.size(abs_diff.shape)): + diffi = diff0 - i + abs_diff[i] = mlir_math.absi(diffi) + return score - slope * abs_diff.load().to(cutlass.Float32) + @cute.jit def score_mod_sliding_window(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): @@ -88,6 +146,16 @@ def score_mod_batch_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux bias_val = (bias_frag.load()).to(cutlass.Float32) return tSrS_ssa + bias_val +@cute.jit +def score_mod_batch_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + batch_bias = aux_tensors[0] + dtype = batch_bias.element_type + b_idx0 = b_idx[0] + bias_frag = cute.make_rmem_tensor(1, dtype) + bias_frag[0] = batch_bias[b_idx0] + bias_val = (bias_frag.load()).to(cutlass.Float32) + return tSrS_ssa + bias_val + @cute.jit def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): @@ -109,6 +177,22 @@ def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, au return tSrS_ssa + head_val + pos_val +@cute.jit +def score_mod_dual_buffer_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + head_bias = aux_tensors[0] + pos_bias = aux_tensors[1] + dtype = head_bias.element_type + + head_val_frag = cute.make_fragment(1, dtype) + head_val_frag[0] = head_bias[h_idx[0]] + head_val = (head_val_frag.load()).to(cutlass.Float32) + + pos_val_frag = cute.make_fragment(1, dtype) + pos_val_frag[0] = pos_bias[q_idx[0]] + pos_val = (pos_val_frag.load()).to(cutlass.Float32) + + return tSrS_ssa + head_val + pos_val + # ============================================================================= # Score_mod functions that use global indices diff --git a/tests/cute/test_clc_fuzz.py b/tests/cute/test_clc_fuzz.py new file mode 100644 index 00000000000..022276d3281 --- /dev/null +++ b/tests/cute/test_clc_fuzz.py @@ -0,0 +1,576 @@ +"""Adversarial regression tests for CLC tile scheduling. + +These cases intentionally target scheduler-sensitive shapes: mismatched +sequence lengths, non-aligned tiles, GQA ratios, minimal problems, and +larger persistent workloads. This is deterministic adversarial coverage, +not randomized fuzzing. +""" + +from contextlib import contextmanager +import os +from unittest import mock + +import pytest +import torch + +from flash_attn.cute import utils as cute_utils +from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 +from flash_attn.cute.interface import flash_attn_func, flash_attn_varlen_func +from flash_attn.cute.testing import attention_ref +from flash_attn.cute.tile_scheduler import SchedulingMode, SingleTileLPTScheduler, SingleTileVarlenScheduler + + +if torch.cuda.is_available(): + COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] + SM_COUNT = torch.cuda.get_device_properties("cuda").multi_processor_count +else: + COMPUTE_CAPABILITY = 0 + SM_COUNT = 0 +pytestmark = pytest.mark.skipif( + COMPUTE_CAPABILITY not in (10, 11), + reason="CLC adversarial tests require SM100/SM110 persistent forward", +) + +_captured_schedulers: list[tuple[type, SchedulingMode, bool]] = [] +_orig_init = FlashAttentionForwardSm100.__init__ + + +def _spy_init(self_inner, *a, **kw): + _orig_init(self_inner, *a, **kw) + _captured_schedulers.append(( + self_inner.TileScheduler, + self_inner.scheduling_mode, + self_inner.use_2cta_instrs, + )) + + +@contextmanager +def clc_scheduler_enabled(): + with ( + mock.patch.dict(os.environ, {"FA_CLC": "1"}, clear=False), + mock.patch.object(cute_utils, "_fa_clc_enabled", True), + mock.patch.object(FlashAttentionForwardSm100, "__init__", _spy_init), + ): + yield + + +def check_output(q, k, v, *, causal=False, window_size=(None, None), num_splits=1, assert_clc=True, assert_2cta=False): + _captured_schedulers.clear() + out, _ = flash_attn_func(q, k, v, causal=causal, window_size=window_size, num_splits=num_splits) + torch.cuda.synchronize() + if assert_clc and _captured_schedulers: + sched_cls, sched_mode, use_2cta = _captured_schedulers[-1] + assert sched_cls is SingleTileLPTScheduler, f"Expected SingleTileLPTScheduler, got {sched_cls.__name__}" + assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + if assert_2cta: + assert use_2cta, "Expected use_2cta_instrs=True but got False" + out_ref, _ = attention_ref(q, k, v, causal=causal, window_size=window_size) + out_pt, _ = attention_ref(q, k, v, causal=causal, window_size=window_size, upcast=False, reorder_ops=True) + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + assert (out - out_ref).abs().max().item() <= 2 * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol, ( + f"max_diff={(out - out_ref).abs().max().item()}, " + f"pt_max_diff={(out_pt - out_ref).abs().max().item()}, " + f"fwd_atol={fwd_atol}, " + f"q={list(q.shape)} k={list(k.shape)} v={list(v.shape)} " + f"causal={causal} window_size={window_size} num_splits={num_splits}" + ) + + +def randn(b, s, h, d): + return torch.randn(b, s, h, d, device="cuda", dtype=torch.bfloat16) + + +def expected_total_tiles_mha(batch, seqlen_q, heads): + q_stage = 2 if COMPUTE_CAPABILITY == 10 and seqlen_q > 128 else 1 + num_block = (seqlen_q + q_stage * 128 - 1) // (q_stage * 128) + return num_block * heads * batch + + +@pytest.fixture(autouse=True) +def seed(): + torch.random.manual_seed(42) + + +@pytest.fixture(autouse=True) +def enable_clc_scheduler(): + with clc_scheduler_enabled(): + yield + + +class TestCLCMismatchedSeqlens: + + @pytest.mark.parametrize("sq,sk", [ + (128, 512), + (128, 1024), + (128, 2048), + (256, 64), + (256, 128), + (512, 127), + (512, 129), + (64, 4096), + (1, 128), + (1, 512), + (1, 1024), + ]) + def test_qk_mismatch(self, sq, sk): + check_output(randn(4, sq, 4, 128), randn(4, sk, 4, 128), randn(4, sk, 4, 128)) + + @pytest.mark.parametrize("sq,sk", [ + (128, 513), + (256, 1023), + (64, 257), + (192, 383), + (1, 255), + ]) + def test_qk_mismatch_nonaligned_k(self, sq, sk): + check_output(randn(4, sq, 4, 128), randn(4, sk, 4, 128), randn(4, sk, 4, 128)) + + @pytest.mark.parametrize("sq,sk", [ + (1, 128), + (1, 256), + (1, 1024), + (2, 128), + (3, 512), + ]) + def test_tiny_q_long_k(self, sq, sk): + check_output(randn(2, sq, 4, 128), randn(2, sk, 4, 128), randn(2, sk, 4, 128)) + + +class TestCLCNonAlignedShapes: + @pytest.mark.parametrize("sq", [1, 3, 7, 15, 31, 33, 63, 65, 127, 129, 191, 193, 255, 257]) + def test_nonaligned_q(self, sq): + check_output(randn(2, sq, 4, 128), randn(2, 256, 4, 128), randn(2, 256, 4, 128)) + + @pytest.mark.parametrize("sk", [1, 7, 31, 33, 63, 65, 127, 129, 255, 257, 511, 513]) + def test_nonaligned_k(self, sk): + check_output(randn(2, 256, 4, 128), randn(2, sk, 4, 128), randn(2, sk, 4, 128)) + + +class TestCLCPrimes: + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (1, 1, 127, 131), + (3, 5, 131, 127), + (7, 3, 257, 251), + (11, 7, 67, 509), + (13, 1, 191, 193), + (5, 11, 61, 67), + (2, 3, 509, 127), + ]) + def test_all_prime(self, batch, heads, sq, sk): + check_output( + randn(batch, sq, heads, 128), + randn(batch, sk, heads, 128), + randn(batch, sk, heads, 128), + ) + + +class TestCLC2CTA: + @pytest.mark.parametrize("sq,sk", [ + (512, 512), + (512, 127), + (512, 129), + (512, 2048), + (1024, 64), + (768, 1024), + (512, 64), + ]) + def test_2cta_qk_mismatch(self, sq, sk): + check_output(randn(4, sq, 4, 128), randn(4, sk, 4, 128), randn(4, sk, 4, 128), assert_2cta=True) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (1, 1, 512, 128), + (1, 1, 512, 512), + (3, 5, 768, 1024), + (7, 3, 512, 127), + (9, 7, 1024, 257), + (13, 1, 512, 64), + ]) + def test_2cta_adversarial_combos(self, batch, heads, sq, sk): + check_output( + randn(batch, sq, heads, 128), + randn(batch, sk, heads, 128), + randn(batch, sk, heads, 128), + assert_2cta=True, + ) + + +class TestCLCGQA: + @pytest.mark.parametrize("q_heads,kv_heads,sq,sk", [ + (4, 1, 128, 512), + (4, 1, 256, 127), + (8, 1, 64, 1024), + (8, 2, 512, 129), + (8, 4, 1, 256), + (6, 2, 192, 383), + (6, 3, 128, 1), + (12, 4, 257, 511), + ]) + def test_gqa_mismatch(self, q_heads, kv_heads, sq, sk): + check_output( + randn(4, sq, q_heads, 128), + randn(4, sk, kv_heads, 128), + randn(4, sk, kv_heads, 128), + ) + + @pytest.mark.parametrize("q_heads,kv_heads", [ + (4, 1), (4, 2), (8, 1), (8, 2), (8, 4), (6, 2), (6, 3), (12, 4), + ]) + def test_gqa_ratios(self, q_heads, kv_heads): + check_output( + randn(4, 512, q_heads, 128), + randn(4, 512, kv_heads, 128), + randn(4, 512, kv_heads, 128), + ) + + +class TestCLCHeadDim: + @pytest.mark.parametrize("d,dv,sq,sk", [ + (64, 64, 128, 512), + (64, 64, 1, 256), + (96, 96, 255, 127), + (128, 64, 192, 384), + (128, 64, 1, 1024), + ]) + def test_head_dims_adversarial(self, d, dv, sq, sk): + check_output(randn(4, sq, 4, d), randn(4, sk, 4, d), randn(4, sk, 4, dv)) + + def test_overlap_sO_sQ_fallback(self): + from flash_attn.cute.tile_scheduler import SingleTileScheduler + + _captured_schedulers.clear() + check_output(randn(4, 128, 4, 192), randn(4, 257, 4, 192), randn(4, 257, 4, 128), assert_clc=False) + assert _captured_schedulers, "No scheduler was captured" + sched_cls, sched_mode, *_ = _captured_schedulers[-1] + assert sched_cls is SingleTileScheduler, f"Expected SingleTileScheduler fallback, got {sched_cls.__name__}" + assert sched_mode == SchedulingMode.STATIC, f"Expected STATIC fallback, got {sched_mode!r}" + + +class TestCLCFallback: + + def test_varlen_uses_clc(self): + _captured_schedulers.clear() + batch, seqlen, heads, d = 4, 256, 4, 128 + lens = torch.tensor([64, 128, 32, 32], dtype=torch.int32) + cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), lens.cumsum(0)]).to(device="cuda", dtype=torch.int32) + total = int(cu_seqlens[-1]) + q = torch.randn(total, heads, d, device="cuda", dtype=torch.bfloat16) + k = torch.randn(total, heads, d, device="cuda", dtype=torch.bfloat16) + v = torch.randn(total, heads, d, device="cuda", dtype=torch.bfloat16) + out, _ = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=int(lens.max()), + max_seqlen_k=int(lens.max()), + ) + torch.cuda.synchronize() + assert _captured_schedulers, "No scheduler was captured" + sched_cls, sched_mode, *_ = _captured_schedulers[-1] + assert sched_cls is SingleTileVarlenScheduler, ( + f"Expected SingleTileVarlenScheduler for varlen, got {sched_cls.__name__}" + ) + assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + + @pytest.mark.parametrize("sq,sk,wl,wr", [ + (512, 512, 128, 128), + (256, 1024, 64, 64), + (512, 512, 255, 0), + (128, 2048, 32, 512), + ]) + def test_local_window_with_clc(self, sq, sk, wl, wr): + check_output( + randn(4, sq, 4, 128), + randn(4, sk, 4, 128), + randn(4, sk, 4, 128), + window_size=(wl, wr), + ) + + +def check_varlen_output(seqlens, heads, d, *, causal=False, kv_heads=None, num_splits=1): + kv_heads = kv_heads or heads + cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), torch.tensor(seqlens, dtype=torch.int32).cumsum(0)]).to(device="cuda", dtype=torch.int32) + total = int(cu_seqlens[-1]) + max_s = max(seqlens) + q = torch.randn(total, heads, d, device="cuda", dtype=torch.bfloat16) + k = torch.randn(total, kv_heads, d, device="cuda", dtype=torch.bfloat16) + v = torch.randn(total, kv_heads, d, device="cuda", dtype=torch.bfloat16) + + _captured_schedulers.clear() + out, _ = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_s, + max_seqlen_k=max_s, + causal=causal, + num_splits=num_splits, + ) + torch.cuda.synchronize() + if _captured_schedulers: + sched_cls, sched_mode, *_ = _captured_schedulers[-1] + assert sched_cls is SingleTileVarlenScheduler, f"Expected SingleTileVarlenScheduler, got {sched_cls.__name__}" + assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + + for i in range(len(seqlens)): + s = slice(cu_seqlens[i], cu_seqlens[i + 1]) + qi, ki, vi, oi = q[s].unsqueeze(0), k[s].unsqueeze(0), v[s].unsqueeze(0), out[s].unsqueeze(0) + out_ref_i, _ = attention_ref(qi, ki, vi, causal=causal) + out_pt_i, _ = attention_ref(qi, ki, vi, causal=causal, upcast=False, reorder_ops=True) + fwd_atol = 2 * (out_ref_i + 0.3 - 0.3 - out_ref_i).abs().max().item() + assert (oi - out_ref_i).abs().max().item() <= 2 * ( + out_pt_i - out_ref_i + ).abs().max().item() + fwd_atol, ( + f"batch={i} max_diff={(oi - out_ref_i).abs().max().item()}, " + f"pt_max_diff={(out_pt_i - out_ref_i).abs().max().item()}, " + f"seqlens={seqlens} heads={heads} d={d} causal={causal} num_splits={num_splits}" + ) + + +def check_varlen_output_seqused(seqlens, heads, d, *, causal=False, kv_heads=None, num_splits=1): + kv_heads = kv_heads or heads + batch = len(seqlens) + max_s = max(seqlens) + seqused = torch.tensor(seqlens, device="cuda", dtype=torch.int32) + q = torch.randn(batch, max_s, heads, d, device="cuda", dtype=torch.bfloat16) + k = torch.randn(batch, max_s, kv_heads, d, device="cuda", dtype=torch.bfloat16) + v = torch.randn(batch, max_s, kv_heads, d, device="cuda", dtype=torch.bfloat16) + q_mask = torch.arange(max_s, device="cuda")[None, :] < seqused[:, None] + k_mask = q_mask + + _captured_schedulers.clear() + out, _ = flash_attn_varlen_func( + q, + k, + v, + max_seqlen_q=max_s, + max_seqlen_k=max_s, + seqused_q=seqused, + seqused_k=seqused, + causal=causal, + num_splits=num_splits, + ) + torch.cuda.synchronize() + if _captured_schedulers: + sched_cls, sched_mode, *_ = _captured_schedulers[-1] + assert sched_cls is SingleTileVarlenScheduler, f"Expected SingleTileVarlenScheduler, got {sched_cls.__name__}" + assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + + out_ref, _ = attention_ref(q, k, v, q_mask, k_mask, causal=causal) + out_pt, _ = attention_ref(q, k, v, q_mask, k_mask, causal=causal, upcast=False, reorder_ops=True) + q_mask_4d = q_mask.unsqueeze(-1).unsqueeze(-1) + out_masked = out.clone().masked_fill_(~q_mask_4d, 0.0) + out_ref_masked = out_ref.clone().masked_fill_(~q_mask_4d, 0.0) + out_pt_masked = out_pt.clone().masked_fill_(~q_mask_4d, 0.0) + fwd_atol = 2 * (out_ref_masked + 0.3 - 0.3 - out_ref_masked).abs().max().item() + assert (out_masked - out_ref_masked).abs().max().item() <= 2 * ( + out_pt_masked - out_ref_masked + ).abs().max().item() + fwd_atol, ( + f"max_diff={(out_masked - out_ref_masked).abs().max().item()}, " + f"pt_max_diff={(out_pt_masked - out_ref_masked).abs().max().item()}, " + f"seqlens={seqlens} heads={heads} d={d} causal={causal} num_splits={num_splits}" + ) + + +class TestCLCVarlen: + + @pytest.mark.parametrize("seqlens", [ + [64, 128, 32, 32], + [256, 64, 128, 256], + [1, 512, 1, 1], + [128, 128, 128, 128], + [512, 256, 128, 64], + [1, 1, 1, 1], + [255, 129, 63, 193], + ]) + def test_varlen_basic(self, seqlens): + check_varlen_output(seqlens, heads=4, d=128) + + @pytest.mark.parametrize("seqlens", [ + [64, 128, 32, 32], + [256, 64, 128, 256], + [512, 256, 128, 64], + [255, 129, 63, 193], + ]) + def test_varlen_causal(self, seqlens): + check_varlen_output(seqlens, heads=4, d=128, causal=True) + + @pytest.mark.parametrize("seqlens", [ + [64, 128, 32, 32], + [1, 512, 1, 1], + [255, 129, 63, 193], + ]) + def test_varlen_gqa(self, seqlens): + check_varlen_output(seqlens, heads=8, d=128, kv_heads=2) + + @pytest.mark.parametrize("seqlens,heads", [ + pytest.param([512], 4, id="single_batch"), + pytest.param([256, 128], 8, id="two_batch"), + pytest.param([64] * 32, 4, id="many_batches"), + pytest.param([1, 1, 1, 1024, 1, 1, 1, 1], 4, id="imbalanced"), + ]) + def test_varlen_edge_cases(self, seqlens, heads): + check_varlen_output(seqlens, heads=heads, d=128) + + @pytest.mark.parametrize("seqlens", [ + [127, 131, 251, 193], + [1, 3, 7, 13, 31, 61], + [509, 127, 251, 67], + ]) + def test_varlen_primes(self, seqlens): + check_varlen_output(seqlens, heads=4, d=128) + + @pytest.mark.parametrize("d", [64, 96, 128]) + def test_varlen_head_dims(self, d): + check_varlen_output([128, 256, 64, 192], heads=4, d=d) + + @pytest.mark.parametrize("trial", range(3)) + def test_varlen_repeatability(self, trial): + torch.random.manual_seed(trial) + check_varlen_output([64, 128, 32, 256, 1, 512], heads=4, d=128) + + @pytest.mark.parametrize("seqlens", [ + [64, 128, 32, 256], + [255, 129, 63, 193], + ]) + @pytest.mark.parametrize("num_splits", [2, 3]) + def test_varlen_splitkv(self, seqlens, num_splits): + check_varlen_output(seqlens, heads=4, d=64, num_splits=num_splits) + + @pytest.mark.parametrize("seqlens", [ + [64, 128, 32, 256], + [255, 129, 63, 193], + ]) + @pytest.mark.parametrize("num_splits", [2, 3]) + def test_varlen_seqused_splitkv(self, seqlens, num_splits): + check_varlen_output_seqused(seqlens, heads=4, d=64, num_splits=num_splits) + + @pytest.mark.parametrize("seqlens", [ + [64, 128, 32, 256], + [255, 129, 63, 193], + ]) + @pytest.mark.parametrize("num_splits", [2, 3]) + def test_varlen_splitkv_gqa(self, seqlens, num_splits): + check_varlen_output(seqlens, heads=8, kv_heads=2, d=64, num_splits=num_splits) + + @pytest.mark.parametrize("seqlens", [ + [64, 128, 32, 256], + [255, 129, 63, 193], + ]) + @pytest.mark.parametrize("num_splits", [2, 3]) + def test_varlen_seqused_splitkv_gqa(self, seqlens, num_splits): + check_varlen_output_seqused(seqlens, heads=8, kv_heads=2, d=64, num_splits=num_splits) + + +class TestCLCMinimal: + @pytest.mark.parametrize("sq,sk", [(1, 1), (1, 2), (2, 1), (1, 128), (128, 1)]) + def test_minimal(self, sq, sk): + check_output(randn(1, sq, 1, 128), randn(1, sk, 1, 128), randn(1, sk, 1, 128)) + + def test_single_element(self): + check_output(randn(1, 1, 1, 64), randn(1, 1, 1, 64), randn(1, 1, 1, 64)) + + +class TestCLCCausal: + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (3, 5, 259, 259), + (7, 3, 513, 513), + (1, 7, 1023, 1023), + (5, 11, 2049, 2049), + (2, 3, 4097, 4097), + ]) + def test_causal_square(self, batch, heads, sq, sk): + check_output(randn(batch, sq, heads, 128), randn(batch, sk, heads, 128), randn(batch, sk, heads, 128), causal=True) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (3, 7, 127, 513), + (5, 3, 259, 1023), + (7, 5, 63, 2049), + (11, 1, 1, 511), + (2, 9, 1, 1025), + (9, 3, 33, 4097), + ]) + def test_causal_qk_mismatch(self, batch, heads, sq, sk): + check_output(randn(batch, sq, heads, 128), randn(batch, sk, heads, 128), randn(batch, sk, heads, 128), causal=True) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (3, 7, 191, 191), + (7, 5, 193, 193), + (5, 3, 383, 383), + (11, 1, 129, 509), + (2, 13, 1, 131), + (9, 3, 67, 251), + ]) + def test_causal_nonaligned(self, batch, heads, sq, sk): + check_output(randn(batch, sq, heads, 128), randn(batch, sk, heads, 128), randn(batch, sk, heads, 128), causal=True) + + @pytest.mark.parametrize("batch,q_heads,kv_heads,sq", [ + (3, 6, 2, 513), + (7, 8, 1, 259), + (5, 12, 4, 1023), + (2, 8, 2, 2049), + (11, 4, 1, 191), + ]) + def test_causal_gqa(self, batch, q_heads, kv_heads, sq): + check_output( + randn(batch, sq, q_heads, 128), + randn(batch, sq, kv_heads, 128), + randn(batch, sq, kv_heads, 128), + causal=True, + ) + + def test_causal_large(self): + check_output(randn(3, 4097, 13, 128), randn(3, 4097, 13, 128), randn(3, 4097, 13, 128), causal=True) + + +class TestCLCLargeScale: + def test_large_batch(self): + check_output(randn(32, 512, 8, 128), randn(32, 512, 8, 128), randn(32, 512, 8, 128)) + + def test_long_seq(self): + check_output(randn(2, 4096, 4, 128), randn(2, 4096, 4, 128), randn(2, 4096, 4, 128)) + + def test_many_heads(self): + check_output(randn(4, 512, 32, 128), randn(4, 512, 32, 128), randn(4, 512, 32, 128)) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (24, 8, 768, 2048), + (16, 8, 1536, 4096), + (12, 8, 2305, 4096), + ]) + def test_work_stealing_pressure(self, batch, heads, sq, sk): + total_tiles = expected_total_tiles_mha(batch, sq, heads) + assert total_tiles > SM_COUNT, f"expected total_tiles={total_tiles} > sm_count={SM_COUNT}" + check_output( + randn(batch, sq, heads, 128), + randn(batch, sk, heads, 128), + randn(batch, sk, heads, 128), + ) + + def test_long_k_short_q(self): + check_output(randn(8, 64, 8, 128), randn(8, 8192, 8, 128), randn(8, 8192, 8, 128)) + + def test_long_q_short_k(self): + check_output(randn(4, 4096, 4, 128), randn(4, 64, 4, 128), randn(4, 64, 4, 128)) + + +class TestCLCRepeatability: + @pytest.mark.parametrize("trial", range(5)) + def test_repeat_mismatch(self, trial): + torch.random.manual_seed(trial) + check_output(randn(7, 192, 5, 128), randn(7, 513, 5, 128), randn(7, 513, 5, 128)) + + @pytest.mark.parametrize("trial", range(5)) + def test_repeat_2cta(self, trial): + torch.random.manual_seed(trial) + check_output(randn(9, 257, 3, 128), randn(9, 511, 3, 128), randn(9, 511, 3, 128)) + + @pytest.mark.parametrize("trial", range(5)) + def test_repeat_gqa_mismatch(self, trial): + torch.random.manual_seed(trial) + check_output(randn(5, 128, 8, 128), randn(5, 1024, 2, 128), randn(5, 1024, 2, 128)) + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index c1f227d7400..69e6308fb60 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -3,6 +3,8 @@ import math import itertools import os +import random +import re import pytest import torch @@ -20,18 +22,21 @@ generate_random_padding_mask, pad_input, unpad_input, + maybe_fake_tensor_mode, + is_fake_mode, ) from flash_attn.cute.interface import ( flash_attn_func, flash_attn_varlen_func, - flash_attn_combine, - _get_device_capability, ) - +# torch FakeTensorMode would enable fast cutedsl kernel compilation without allocating the actual GPU memory or running the kernel +# When operating fake tensors, we cannot perform data-dependent operations (e.g., `tensor.max()`). +USE_FAKE_TENSOR = int(os.getenv("FLASH_ATTENTION_FAKE_TENSOR", 0)) == 1 DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" -# SplitKV and paged KV are not supported on SM90 -IS_SM90 = _get_device_capability() == 9 +# SplitKV is not supported on SM90 +IS_SM90 = torch.cuda.get_device_capability()[0] == 9 +IS_SM100 = torch.cuda.get_device_capability()[0] == 10 TEST_BWD_ONLY = False VERBOSE = True @@ -43,8 +48,8 @@ # @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) -# @pytest.mark.parametrize("deterministic", [False, True]) -@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("deterministic", [False, True]) +# @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) @@ -59,7 +64,7 @@ # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) # @pytest.mark.parametrize("d", [128, 192]) -@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256]) # @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", @@ -91,6 +96,7 @@ ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_output( seqlen_q, seqlen_k, @@ -109,7 +115,9 @@ def test_flash_attn_output( pytest.skip() device = "cuda" # set seed - torch.random.manual_seed(0) + seed = 0 + random.seed(seed) + torch.random.manual_seed(seed) torch.cuda.empty_cache() torch.cuda.synchronize() batch_size = 9 if seqlen_k <= 2048 else 2 @@ -120,7 +128,7 @@ def test_flash_attn_output( dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) - if dtype == torch.float8_e4m3fn or TEST_BWD_ONLY: + if dtype == torch.float8_e4m3fn: dv_vals = [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] attention_chunk_vals = [0] @@ -160,7 +168,7 @@ def test_flash_attn_output( qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = ( - (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + (None, None) if not local else tuple(random.randrange(0, seqlen_k) for _ in range(2)) ) if local_enum == 2: window_size = (None, -window_size[1]) @@ -230,11 +238,12 @@ def test_flash_attn_output( # # lse_ref = torch.logsumexp(qk, dim=-1) # Numerical error if we just do any arithmetic on out_ref - fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() - rtol = 2 if softcap == 0.0 else 3 + if not is_fake_mode(): + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # num_splits_vals = [1, 3] pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False] # SplitKV is not supported for hdim >= 192 @@ -244,6 +253,8 @@ def test_flash_attn_output( # SplitKV not supported on SM90 - skip this iteration if IS_SM90 and num_splits > 1: continue + if IS_SM100 and (d >= 192 and dv >= 192): # hdim 192 and 256 not support on SM100 + continue out, lse = flash_attn_func( q, k, @@ -259,6 +270,10 @@ def test_flash_attn_output( num_splits=num_splits, deterministic=deterministic, ) + if is_fake_mode(): + # no more flash_attn cutedsl calls for the rest of the loop + # skip data-dependent postprocessing + continue print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") # if not causal: @@ -277,22 +292,22 @@ def test_flash_attn_output( and not dv > 256 and not attention_chunk != 0 and softcap == 0.0 - and dv == d + and ((dv == d and d <= 128) or (d == 192 and dv == 128)) and learnable_sink is None # and False and not ((causal or local) and seqlen_k < seqlen_q) ): - # TODO: SM90 backward pass has invalid MMA tile config for d=64 + non-causal - # The m_block_size=80 (non-causal) with head_dim=64 creates an invalid tile. - # Fix requires adjusting m_block_size or MMA config in flash_bwd_sm90.py - if IS_SM90 and d == 64 and not causal: - pytest.xfail("SM90 backward: d=64 + non-causal has invalid MMA tile config (m_block=80)") - # TODO: SM90 backward pass does not support local attention yet - if IS_SM90 and local: - pytest.xfail("SM90 backward: local attention not supported yet") + if d > 192 and IS_SM90: + pytest.xfail("hdim > 192 backward: SM90 not supported yet") + if d != dv and mha_type != "mha" and IS_SM90: + pytest.xfail("SM90 GQA bwd currently requires headdim == headdim_v") g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + if is_fake_mode(): + # no more flash_attn cutedsl calls for the rest of the loop + # skip data-dependent postprocessing + continue # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") # assert (softmax_d - do_o).abs().max().item() <= 1e-5 # assert dq_accum.abs().max().item() == 0.0 @@ -371,8 +386,8 @@ def test_flash_attn_output( @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) -# @pytest.mark.parametrize("deterministic", [False, True]) -@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("deterministic", [False, True]) +# @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) @@ -388,7 +403,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) # @pytest.mark.parametrize("d", [128, 192]) -@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("d", [64, 128, 192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -434,6 +449,7 @@ def test_flash_attn_output( (False, True), ], ) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_varlen_output( seqlen_q, seqlen_k, @@ -462,15 +478,17 @@ def test_flash_attn_varlen_output( seqlen_k = seqlen_q device = "cuda" # set seed - torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) - batch_size = 49 if seqlen_q <= 1024 else 7 + seed = seqlen_q + seqlen_k + d + int(causal) * 2 + int(local) + random.seed(seed) + torch.random.manual_seed(seed) + batch_size = 49 if seqlen_q <= 512 else 7 nheads = 6 # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) - if dtype == torch.float8_e4m3fn or TEST_BWD_ONLY: + if dtype == torch.float8_e4m3fn: dv_vals = [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] attention_chunk_vals = [0] @@ -510,7 +528,7 @@ def test_flash_attn_varlen_output( qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = ( - (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + (None, None) if not local else tuple(random.randrange(0, seqlen_k) for _ in range(2)) ) if local_enum == 2: window_size = (None, window_size[1]) @@ -610,6 +628,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): q_unpad, k_unpad, v_unpad = [ x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) ] + out_ref, attn_ref = attention_ref( q_ref, k_ref, @@ -646,15 +665,16 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + if not is_fake_mode(): + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - if query_unused_mask is not None: - q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") - # Numerical error if we just do any arithmetic on out_ref - fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() - rtol = 2 if softcap == 0.0 else 3 + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False] # pack_gqa_vals = [False] @@ -688,18 +708,32 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): deterministic=deterministic, ) out = output_pad_fn(out_unpad) if unpad_q else out_unpad + if is_fake_mode(): + # no more flash_attn cutedsl calls for the rest of the loop + # skip data-dependent postprocessing + continue if query_unused_mask is not None: out.masked_fill_(q_zero_masking, 0.0) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # When unpad_q=False with seqused_q, the kernel doesn't write positions + # beyond seqused_q, so those contain uninitialized values. Mask them out + # before comparing. + out_cmp, out_ref_cmp, out_pt_cmp = out, out_ref, out_pt + if not unpad_q and seqused_q is not None: + seqused_mask = torch.arange(seqlen_q, device=device)[None, :] < seqused_q[:, None] + seqused_mask = rearrange(seqused_mask, "b s -> b s 1 1") + out_cmp = out.clone().masked_fill_(~seqused_mask, 0.0) + out_ref_cmp = out_ref.clone().masked_fill_(~seqused_mask, 0.0) + out_pt_cmp = out_pt.clone().masked_fill_(~seqused_mask, 0.0) + print(f"Output max diff: {(out_cmp - out_ref_cmp).abs().max().item()}") + print(f"Output mean diff: {(out_cmp - out_ref_cmp).abs().mean().item()}") # if not causal: # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") # breakpoint() # Check that FlashAttention's numerical error is at most 3x the numerical error # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * ( - out_pt - out_ref + assert (out_cmp - out_ref_cmp).abs().max().item() <= rtol * ( + out_pt_cmp - out_ref_cmp ).abs().max().item() + fwd_atol if ( @@ -707,11 +741,14 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): and not has_qv and not dv > 256 and not attention_chunk != 0 - and dv == d + and ((dv == d and d <= 128) or (d == 192 and dv == 128)) and not has_learnable_sink - and not IS_SM90 # and False ): + if d > 192 and IS_SM90: + pytest.xfail("hdim > 192 backward: SM90 not supported yet") + if d != dv and mha_type != "mha" and IS_SM90: + pytest.xfail("SM90 GQA bwd currently requires headdim == headdim_v") g_unpad = torch.randn_like(out_unpad) # do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) # import flash_attn_3_cuda @@ -746,6 +783,10 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): ), g_unpad ) + if is_fake_mode(): + # no more flash_attn cutedsl calls for the rest of the loop + # skip data-dependent postprocessing + continue dq = dq_pad_fn(dq_unpad) if unpad_q else dq_unpad dk = dk_pad_fn(dk_unpad) if unpad_kv else dk_unpad dv = dk_pad_fn(dv_unpad) if unpad_kv else dv_unpad @@ -886,6 +927,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_kvcache( seqlen_q, seqlen_k, @@ -907,8 +949,6 @@ def test_flash_attn_kvcache( ): if page_size is not None and seqlen_k % page_size != 0: pytest.skip() - if page_size is not None and IS_SM90: - pytest.xfail("paged KV not supported on SM90") if seqlen_q > seqlen_k and new_kv: pytest.skip() if not new_kv and rotary_fraction > 0.0: @@ -917,7 +957,9 @@ def test_flash_attn_kvcache( pytest.skip() device = "cuda" # set seed - torch.random.manual_seed(0) + seed = 0 + random.seed(seed) + torch.random.manual_seed(seed) batch_size = 5 # batch_size = 1 batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 @@ -972,7 +1014,7 @@ def test_flash_attn_kvcache( cu_seqlens_q, max_seqlen_q = None, None # Put window_size after QKV randn so that window_size changes from test to test window_size = ( - (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + (None, None) if not local else tuple(random.randrange(0, seqlen_k) for _ in range(2)) ) if has_learnable_sink: learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) @@ -982,7 +1024,7 @@ def test_flash_attn_kvcache( seqlen_new = ( seqlen_q if seqlen_new_eq_seqlen_q - else torch.randint(1, seqlen_q + 1, (1,)).item() + else random.randrange(1, seqlen_q + 1) ) cu_seqlens_k_new = None key_new_padding_mask = None @@ -1058,43 +1100,58 @@ def test_flash_attn_kvcache( dtype, dtype_ref, ) - cache_seqlens = torch.randint( - 0 if new_kv else 1, - # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough - ( + if not is_fake_mode(): + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough ( - seqlen_k - - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) - + 1 - ) - if new_kv - else (seqlen_k + 1) - ), - (batch_size,), - dtype=torch.int32, - device=device, - ) - if has_leftpad: - cache_leftpad = torch.cat( - [ - torch.randint( - 0, - cache_seqlens[i].item(), - (1,), - dtype=torch.int32, - device=device, + ( + seqlen_k + - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + + 1 ) - if cache_seqlens[i].item() > 0 - else torch.zeros(1, dtype=torch.int32, device=device) - for i in range(batch_size) - ] + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, ) + else: + cache_seqlens = torch.ones( + batch_size, + dtype=torch.int32, + device=device, + ) + if has_leftpad: + if not is_fake_mode(): + cache_leftpad = torch.cat( + [ + torch.randint( + 0, + cache_seqlens[i].item(), + (1,), + dtype=torch.int32, + device=device, + ) + if cache_seqlens[i].item() > 0 + else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size) + ] + ) + else: + cache_leftpad = torch.zeros(batch_size, dtype=torch.int32, device=device) else: cache_leftpad = None if has_batch_idx: - cache_batch_idx = torch.randperm( - batch_size_cache, dtype=torch.int32, device=device - )[:batch_size] + if not is_fake_mode(): + cache_batch_idx = torch.randperm( + batch_size_cache, dtype=torch.int32, device=device + )[:batch_size] + else: + cache_batch_idx = torch.arange( + batch_size, dtype=torch.int32, device=device + ) else: cache_batch_idx = None arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") @@ -1285,6 +1342,10 @@ def test_flash_attn_kvcache( ) if varlen_q: out = output_pad_fn(out) + if is_fake_mode(): + # no more flash_attn cutedsl calls for the rest of the loop + # skip data-dependent postprocessing + continue # out = flash_attn_with_kvcache( # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size # ) @@ -1375,10 +1436,8 @@ def test_flash_attn_kvcache( @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("d", [64, 128]) @pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128), (256, 256)]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_bwd_preallocated_outputs(seqlen_q, seqlen_k, d, causal, dtype): - if IS_SM90 and d == 64 and not causal: - pytest.xfail("SM90 backward: d=64 + non-causal has invalid MMA tile config (m_block=80)") - from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd device = "cuda" @@ -1402,6 +1461,8 @@ def test_flash_attn_bwd_preallocated_outputs(seqlen_q, seqlen_k, d, causal, dtyp q, k, v, out, dout, lse, causal=causal, dq=dq, dk=dk, dv=dv ) + if is_fake_mode(): + return assert dq_out is dq assert dk_out is dk assert dv_out is dv @@ -1410,6 +1471,141 @@ def test_flash_attn_bwd_preallocated_outputs(seqlen_q, seqlen_k, d, causal, dtyp assert torch.allclose(dv, dv_ref, atol=1e-5, rtol=1e-5) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128), (256, 256)]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_lse_grad(seqlen_q, seqlen_k, d, causal, dtype): + """Test that gradient flows through the returned LSE tensor.""" + device = "cuda" + torch.random.manual_seed(42) + batch_size = 2 + nheads = 4 + + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + + out, lse = flash_attn_func(q, k, v, causal=causal, return_lse=True) + + if is_fake_mode(): + return + + assert lse is not None + assert lse.requires_grad + + # Compute loss = sum(out * g) + sum(lse * dlse_weight) to test gradient flows through both + g = torch.randn_like(out) + dlse_weight = torch.randn_like(lse) + loss = (out * g).sum() + (lse * dlse_weight).sum() + dq, dk, dv = torch.autograd.grad(loss, (q, k, v)) + + # Compare against reference: manually compute what the gradients should be + # Reference: standard attention in float + q_ref = q.detach().float().requires_grad_() + k_ref = k.detach().float().requires_grad_() + v_ref = v.detach().float().requires_grad_() + # (batch, seqlen_q, nheads, d) -> (batch, nheads, seqlen_q, d) + qk = torch.einsum("bshd,bthd->bhst", q_ref, k_ref) / (d ** 0.5) + if causal: + mask = torch.triu(torch.ones(seqlen_q, seqlen_k, device=device, dtype=torch.bool), diagonal=seqlen_k - seqlen_q + 1) + qk = qk.masked_fill(mask, float("-inf")) + lse_ref = torch.logsumexp(qk, dim=-1) # (batch, nheads, seqlen_q) + p = torch.softmax(qk, dim=-1) + # v_ref: (batch, seqlen_k, nheads, d) + out_ref = torch.einsum("bhst,bthd->bshd", p, v_ref) + loss_ref = (out_ref * g.float()).sum() + (lse_ref * dlse_weight.float()).sum() + dq_ref, dk_ref, dv_ref = torch.autograd.grad(loss_ref, (q_ref, k_ref, v_ref)) + + # Use relaxed tolerances since flash_attn operates in bf16 while reference is float32. + # The reference is also not a perfect bf16 simulation (it doesn't reorder ops), so + # we use a generous tolerance. + print(f"dQ max diff: {(dq.float() - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk.float() - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv.float() - dv_ref).abs().max().item()}") + # Absolute tolerance: bf16 has ~0.004-0.02 error for these sizes + atol = 0.02 + assert (dq.float() - dq_ref).abs().max().item() <= atol, f"dQ error too large" + assert (dk.float() - dk_ref).abs().max().item() <= atol, f"dK error too large" + assert (dv.float() - dv_ref).abs().max().item() <= atol, f"dV error too large" + + # Also test: gradient with only dLSE (no dO) + out2, lse2 = flash_attn_func(q, k, v, causal=causal, return_lse=True) + loss_lse_only = (lse2 * dlse_weight).sum() + dq2, dk2, dv2 = torch.autograd.grad(loss_lse_only, (q, k, v)) + + q_ref2 = q.detach().float().requires_grad_() + k_ref2 = k.detach().float().requires_grad_() + qk2 = torch.einsum("bshd,bthd->bhst", q_ref2, k_ref2) / (d ** 0.5) + if causal: + qk2 = qk2.masked_fill(mask, float("-inf")) + lse_ref2 = torch.logsumexp(qk2, dim=-1) + loss_ref2 = (lse_ref2 * dlse_weight.float()).sum() + dq_ref2, dk_ref2 = torch.autograd.grad(loss_ref2, (q_ref2, k_ref2)) + + print(f"LSE-only dQ max diff: {(dq2.float() - dq_ref2).abs().max().item()}") + print(f"LSE-only dK max diff: {(dk2.float() - dk_ref2).abs().max().item()}") + # dV should be zero when only LSE gradient flows (LSE doesn't depend on V) + print(f"LSE-only dV max: {dv2.abs().max().item()}") + assert dv2.abs().max().item() == 0.0, "dV should be zero when loss depends only on LSE" + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128)]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_lse_grad_unused(seqlen_q, seqlen_k, d, causal, dtype): + """Test return_lse=True when LSE is returned but not used in the loss. + + With set_materialize_grads(False), dlse should be None (not a zero tensor), + so no extra zeroing kernel is launched. Gradients should match the standard + backward (without return_lse). + """ + device = "cuda" + torch.random.manual_seed(42) + batch_size = 2 + nheads = 4 + + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + g = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + + # Case 1: return_lse=False (standard path, lse marked non-differentiable) + out1, lse1 = flash_attn_func(q, k, v, causal=causal, return_lse=False) + if is_fake_mode(): + return + dq1, dk1, dv1 = torch.autograd.grad(out1, (q, k, v), g) + + # Case 2: return_lse=True but lse NOT used in loss (dlse should be None) + out2, lse2 = flash_attn_func(q, k, v, causal=causal, return_lse=True) + dq2, dk2, dv2 = torch.autograd.grad(out2, (q, k, v), g) + + # Case 3: return_lse=True and lse IS used in loss + out3, lse3 = flash_attn_func(q, k, v, causal=causal, return_lse=True) + dlse_weight = torch.randn_like(lse3) + loss3 = (out3 * g).sum() + (lse3 * dlse_weight).sum() + dq3, dk3, dv3 = torch.autograd.grad(loss3, (q, k, v)) + + # Cases 1 and 2 should produce identical gradients + assert torch.equal(dq1, dq2), "dQ should be identical when LSE is unused" + assert torch.equal(dk1, dk2), "dK should be identical when LSE is unused" + assert torch.equal(dv1, dv2), "dV should be identical when LSE is unused" + + # Case 3 should differ from case 1 (LSE gradient adds extra contribution to dQ, dK) + assert not torch.equal(dq1, dq3), "dQ should differ when LSE gradient is included" + assert not torch.equal(dk1, dk3), "dK should differ when LSE gradient is included" + # dV should be the same since LSE doesn't depend on V + assert torch.equal(dv1, dv3), "dV should be identical since LSE doesn't depend on V" + + print("Case 1 vs 2 (unused LSE): dQ diff =", (dq1 - dq2).abs().max().item()) + print("Case 1 vs 3 (used LSE): dQ diff =", (dq1 - dq3).abs().max().item()) + print("Case 1 vs 3 (used LSE): dK diff =", (dk1 - dk3).abs().max().item()) + print("Case 1 vs 3 (used LSE): dV diff =", (dv1 - dv3).abs().max().item()) + + def _generate_block_kvcache( seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref ): @@ -1442,82 +1638,67 @@ def _generate_block_kvcache( return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks -def attention_combine_ref(out_partial, lse_partial): +@pytest.mark.parametrize("page_size", [16, 64, 256]) +@pytest.mark.parametrize("seqlen_q", [64, 128, 256]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_paged_deepseek(seqlen_q, page_size): + """Regression test: paged non-TMA with DeepSeek MLA shape (d=192, dv=128). + seqlen_q<=128 triggers q_stage=1, seqlen_q>128 triggers q_stage=2. """ - out_partial: (num_splits, batch_size, seqlen, nheads, d) - lse_partial: (num_splits, batch_size, seqlen, nheads) - """ - lse = torch.logsumexp(lse_partial, dim=0) - scale = torch.exp(lse_partial - lse) - scale = torch.where( - torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale + if IS_SM90: + pytest.skip("paged KV not supported on SM90") + device = "cuda" + dtype = torch.bfloat16 + d, dv = 192, 128 + nheads = 16 + nheads_kv = 16 + + torch.random.manual_seed(0) + q = torch.randn(seqlen_q, nheads, d, device=device, dtype=dtype) + k = torch.randn(seqlen_q, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(seqlen_q, nheads_kv, dv, device=device, dtype=dtype) + cu_seqlens = torch.tensor([0, seqlen_q], dtype=torch.int32, device=device) + + # Non-paged reference + out_ref, _ = flash_attn_varlen_func( + q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, causal=True, ) - out = (scale.unsqueeze(-1) * out_partial).sum(0) - return out, lse + # Paged + num_pages = (seqlen_q + page_size - 1) // page_size + k_cache_paged = torch.zeros(num_pages, page_size, nheads_kv, d, device=device, dtype=dtype) + v_cache_paged = torch.zeros(num_pages, page_size, nheads_kv, dv, device=device, dtype=dtype) + for i in range(seqlen_q): + k_cache_paged[i // page_size, i % page_size] = k[i] + v_cache_paged[i // page_size, i % page_size] = v[i] + page_table = torch.arange(num_pages, dtype=torch.int32, device=device).unsqueeze(0) + cache_seqlens = torch.tensor([seqlen_q], dtype=torch.int32, device=device) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize("dtype", [torch.float32]) -# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) -@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) -# @pytest.mark.parametrize("d", [128]) -@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) -# @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) -# @pytest.mark.parametrize("seqlen", [15]) -@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) -# @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) -# @pytest.mark.parametrize("num_splits", [11]) -def test_flash_attn_combine(num_splits, seqlen, d, dtype): - device = "cuda" - # set seed - torch.random.manual_seed(1) - batch_size = 5 - nheads = 16 - # batch_size = 1 - # nheads = 1 - # Create tensors in the expected format: (num_splits, batch_size, seqlen, nheads, d) and (num_splits, batch_size, seqlen, nheads) - out_partial = torch.randn( - num_splits * 2, - batch_size, - nheads, - seqlen, - d, - device=device, - dtype=torch.float32, - ).transpose(2, 3)[:num_splits] # To test non-contiguous tensor - lse_partial = torch.randn( - num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32 - ).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor - # To test short-circuiting based on num_splits - lse_partial[num_splits // 2 :, : batch_size // 3] = -float("inf") - - # Test with LSE returned (default behavior) - out, lse = flash_attn_combine( - out_partial, lse_partial, out_dtype=dtype, return_lse=True + out, _ = flash_attn_varlen_func( + q, k_cache_paged, v_cache_paged, + cu_seqlens_q=cu_seqlens, cu_seqlens_k=None, + max_seqlen_q=seqlen_q, max_seqlen_k=None, + seqused_k=cache_seqlens, page_table=page_table, causal=True, ) - out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) - out_pt = out_ref.to(dtype) - print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - print(f"LSE mean diff: {(lse - lse_ref).abs().mean().item()}") + if is_fake_mode(): + return + print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # breakpoint() - - assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) - multiple = 2 - assert ( - (out - out_ref).abs().max().item() - <= multiple * (out_pt - out_ref).abs().max().item() - ) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) - - # Test with LSE not returned - out_no_lse, lse_no_lse = flash_attn_combine( - out_partial, lse_partial, out_dtype=dtype, return_lse=False - ) - assert lse_no_lse is None, "LSE should be None when return_lse=False" - assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), ( - "Output should be the same regardless of return_lse" - ) + assert torch.equal(out, out_ref) + + +@pytest.mark.parametrize("head_dim", [4, 148, 288]) +def test_flash_attn_invalid_head_dim(head_dim): + device = "cuda" + dtype = torch.bfloat16 + batch_size, seqlen, nheads = 1, 64, 4 + + q = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype) + k = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype) + + with pytest.raises(AssertionError, match=re.escape(f"(head_dim, head_dim_v)=({head_dim}, {head_dim}) is not supported on SM")): + flash_attn_func(q, k, v) diff --git a/tests/cute/test_flash_attn_combine.py b/tests/cute/test_flash_attn_combine.py new file mode 100644 index 00000000000..6344f96ab4b --- /dev/null +++ b/tests/cute/test_flash_attn_combine.py @@ -0,0 +1,286 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + +import os + +import pytest +import torch + +from flash_attn.cute.testing import ( + maybe_fake_tensor_mode, + is_fake_mode, +) +from flash_attn.cute.interface import ( + flash_attn_combine, +) + +USE_FAKE_TENSOR = int(os.getenv("FLASH_ATTENTION_FAKE_TENSOR", 0)) == 1 + + +def attention_combine_ref(out_partial, lse_partial): + """ + out_partial: (num_splits, batch_size, seqlen, nheads, d) + lse_partial: (num_splits, batch_size, seqlen, nheads) + """ + lse = torch.logsumexp(lse_partial, dim=0) + scale = torch.exp(lse_partial - lse) + scale = torch.where( + torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale + ) + out = (scale.unsqueeze(-1) * out_partial).sum(0) + return out, lse + + +def check_combine_results(out, lse, out_ref, lse_ref, dtype): + """Check combine kernel output against reference for a single (seqlen, nheads, d) chunk.""" + out_pt = out_ref.to(dtype) + print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}, " + f"Output max diff: {(out - out_ref).abs().max().item()}, " + f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) + assert ( + (out - out_ref).abs().max().item() + <= 2 * (out_pt - out_ref).abs().max().item() + ) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float32]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) +# @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) +# @pytest.mark.parametrize("seqlen", [15]) +@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) +# @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) +# @pytest.mark.parametrize("num_splits", [11]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_combine(num_splits, seqlen, d, dtype): + device = "cuda" + # set seed + torch.random.manual_seed(1) + batch_size = 5 + nheads = 16 + # batch_size = 1 + # nheads = 1 + # Create tensors in the expected format: (num_splits, batch_size, seqlen, nheads, d) and (num_splits, batch_size, seqlen, nheads) + out_partial = torch.randn( + num_splits * 2, + batch_size, + nheads, + seqlen, + d, + device=device, + dtype=torch.float32, + ).transpose(2, 3)[:num_splits] # To test non-contiguous tensor + lse_partial = torch.randn( + num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32 + ).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor + # To test short-circuiting based on num_splits + lse_partial[num_splits // 2 :, : batch_size // 3] = -float("inf") + + # Test with LSE returned (default behavior) + out, lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, return_lse=True + ) + if is_fake_mode(): + return + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + check_combine_results(out, lse, out_ref, lse_ref, dtype) + + # Test with LSE not returned + out_no_lse, lse_no_lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, return_lse=False + ) + assert lse_no_lse is None, "LSE should be None when return_lse=False" + assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), ( + "Output should be the same regardless of return_lse" + ) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("d", [64, 96, 128, 256]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen", [1, 32, 113, 256, 1024]) +# @pytest.mark.parametrize("seqlen", [113]) +@pytest.mark.parametrize("num_splits", [2, 5, 17, 55]) +# @pytest.mark.parametrize("num_splits", [5]) +@pytest.mark.parametrize( + "varlen_mode", + ["cu_seqlens", "seqused", "cu_seqlens_seqused"], +) +# @pytest.mark.parametrize("varlen_mode", ["cu_seqlens"]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_combine_varlen(varlen_mode, num_splits, seqlen, d, dtype): + device = "cuda" + torch.random.manual_seed(1) + batch_size = 3 + nheads = 8 + use_cu_seqlens = "cu_seqlens" in varlen_mode + use_seqused = "seqused" in varlen_mode + + # Generate variable-length sequences + seqlens = torch.randint(1, seqlen + 1, (batch_size,), device=device, dtype=torch.int32) + # For cu_seqlens+seqused mode, seqused < seqlen (kernel processes fewer tokens) + seqused_vals = ( + torch.clamp( + seqlens - torch.randint(0, max(1, seqlen // 4), (batch_size,), device=device, dtype=torch.int32), + min=1, + ) + if use_cu_seqlens and use_seqused + else seqlens + ) + + if use_cu_seqlens: + # Packed varlen layout: (num_splits, total_q, nheads, d) + total_q = seqlens.sum().item() + cu_seqlens_q = torch.zeros(batch_size + 1, device=device, dtype=torch.int32) + cu_seqlens_q[1:] = torch.cumsum(seqlens, dim=0) + + out_partial = torch.randn( + num_splits * 2, total_q, nheads, d, device=device, dtype=torch.float32, + )[:num_splits] # Non-contiguous in splits dim + # lse_partial needs stride(-2)==1 (seqlen dim contiguous) + lse_partial = torch.randn( + num_splits, nheads, total_q, device=device, dtype=torch.float32 + ).transpose(-1, -2) + lse_partial[num_splits // 2:, :total_q // 3] = -float("inf") + + out, lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, + cu_seqlens=cu_seqlens_q, + seqused=seqused_vals if use_seqused else None, + return_lse=True, + ) + if is_fake_mode(): + return + + # Reference on full packed tensor + out_ref, lse_ref = attention_combine_ref( + out_partial.unsqueeze(1), lse_partial.unsqueeze(1) + ) + out_ref = out_ref.squeeze(0) + lse_ref = lse_ref.squeeze(0) + + # Validate per-batch (only seqused_vals tokens are guaranteed correct) + for i in range(batch_size): + start = cu_seqlens_q[i].item() + sl = seqused_vals[i].item() + check_combine_results( + out[start:start + sl], lse[start:start + sl], + out_ref[start:start + sl], lse_ref[start:start + sl], dtype, + ) + + # Also test return_lse=False + out_no_lse, lse_no_lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, + cu_seqlens=cu_seqlens_q, + seqused=seqused_vals if use_seqused else None, + return_lse=False, + ) + assert lse_no_lse is None + # Only compare valid positions (beyond seqused, output is undefined) + for i in range(batch_size): + start = cu_seqlens_q[i].item() + sl = seqused_vals[i].item() + assert torch.allclose(out_no_lse[start:start + sl], out[start:start + sl], atol=1e-5, rtol=1e-5) + + else: + # seqused only — batched layout: (num_splits, batch, max_seqlen, nheads, d) + max_seqlen = seqlens.max().item() + out_partial = torch.randn( + num_splits, batch_size, max_seqlen, nheads, d, device=device, dtype=torch.float32, + ) + # lse_partial needs stride(-2)==1 (seqlen dim contiguous) + lse_partial = torch.randn( + num_splits, batch_size, nheads, max_seqlen, device=device, dtype=torch.float32, + ).transpose(-1, -2) + lse_partial[num_splits // 2:, :batch_size // 2] = -float("inf") + # Zero out / -inf beyond seqused so reference matches kernel + for i in range(batch_size): + out_partial[:, i, seqlens[i]:] = 0 + lse_partial[:, i, seqlens[i]:] = -float("inf") + + out, lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, seqused=seqlens, return_lse=True, + ) + if is_fake_mode(): + return + + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + + # Validate per-batch (only seqused tokens) + for i in range(batch_size): + sl = seqlens[i].item() + check_combine_results( + out[i, :sl], lse[i, :sl], + out_ref[i, :sl], lse_ref[i, :sl], dtype, + ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("d", [64, 128, 256]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen", [32, 113, 256]) +# @pytest.mark.parametrize("seqlen", [113]) +@pytest.mark.parametrize("num_splits", [2, 5, 17]) +# @pytest.mark.parametrize("num_splits", [5]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_combine_varlen_batch_idx(num_splits, seqlen, d, dtype): + """Test that varlen_batch_idx correctly remaps virtual batch indices to real batch indices. + + varlen_batch_idx maps blockIdx.z (virtual batch) -> real batch index. The kernel + reads AND writes using the remapped batch_idx, so with a permutation the output + should match running without varlen_batch_idx (each real batch is processed once). + + We also test with seqused to verify interaction with variable-length sequences. + """ + device = "cuda" + torch.random.manual_seed(42) + batch_size = 4 + nheads = 8 + + # Create batched input data + out_partial = torch.randn( + num_splits, batch_size, seqlen, nheads, d, device=device, dtype=torch.float32, + ) + lse_partial = torch.randn( + num_splits, batch_size, nheads, seqlen, device=device, dtype=torch.float32, + ).transpose(-1, -2) # stride(-2)==1 + lse_partial[num_splits // 2:, :batch_size // 2] = -float("inf") + + # Create a permuted batch index mapping: virtual batch -> real batch + perm = torch.tensor([2, 0, 3, 1], device=device, dtype=torch.int32) + assert perm.shape[0] == batch_size + + # Also test with seqused to verify interaction with varlen_batch_idx + seqused = torch.randint(1, seqlen + 1, (batch_size,), device=device, dtype=torch.int32) + # Zero out / -inf beyond seqused so reference matches kernel + for i in range(batch_size): + out_partial[:, i, seqused[i]:] = 0 + lse_partial[:, i, seqused[i]:] = -float("inf") + + # Run with varlen_batch_idx and seqused via public API + out, lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, + seqused=seqused, + varlen_batch_idx=perm, + return_lse=True, + ) + if is_fake_mode(): + return + + # Reference: standard combine (no remapping needed since perm is a bijection + # and both reads and writes use the remapped batch_idx) + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + + # The kernel reads from input[perm[v]] and writes to output[perm[v]], + # so the net result is output[b] = combine(input[b]) for all b. + for b in range(batch_size): + sl = seqused[b].item() + check_combine_results( + out[b, :sl], lse[b, :sl], + out_ref[b, :sl], lse_ref[b, :sl], dtype, + ) diff --git a/tests/cute/test_flash_attn_fast.py b/tests/cute/test_flash_attn_fast.py new file mode 100644 index 00000000000..433859d94d8 --- /dev/null +++ b/tests/cute/test_flash_attn_fast.py @@ -0,0 +1,331 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# Fast subset of test_flash_attn.py for quick iteration. +# Covers: causal/noncausal, varlen/not varlen, MHA/GQA, split/not split, fwd+bwd. + +import os +import random + +import pytest +import torch + +from einops import rearrange + +from flash_attn.cute.testing import ( + attention_ref, + generate_random_padding_mask, + generate_qkv, + maybe_fake_tensor_mode, + is_fake_mode, +) +from flash_attn.cute.interface import ( + flash_attn_func, + flash_attn_varlen_func, + flash_attn_combine, +) + +USE_FAKE_TENSOR = int(os.getenv("FLASH_ATTENTION_FAKE_TENSOR", 0)) == 1 +IS_SM90 = torch.cuda.get_device_capability()[0] == 9 + + +# --------------------------------------------------------------------------- +# Forward + backward (non-varlen) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "gqa", "mqa"]) +@pytest.mark.parametrize("num_splits", [1, 3]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (128, 128), + (256, 256), + (113, 203), + (1024, 1024), + ], +) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_output(seqlen_q, seqlen_k, d, causal, num_splits, mha_type, dtype): + if IS_SM90 and num_splits > 1: + pytest.skip("SM90 fwd doens't support num_splits > 1") + device = "cuda" + torch.random.manual_seed(0) + random.seed(0) + torch.cuda.empty_cache() + batch_size = 4 + nheads = 6 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype).to(dtype).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_() + + q = q_ref.detach().to(dtype).requires_grad_() + k = k_ref.detach().to(dtype).requires_grad_() + v = v_ref.detach().to(dtype).requires_grad_() + + out_ref, _ = attention_ref(q_ref, k_ref, v_ref, None, None, causal=causal) + out_pt, _ = attention_ref( + q_ref, k_ref, v_ref, None, None, causal=causal, upcast=False, reorder_ops=True, + ) + + out, lse = flash_attn_func(q, k, v, causal=causal, num_splits=num_splits) + + if is_fake_mode(): + return + + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + fwd_atol + + # Backward (only for non-split, matching d) + can_bwd = ( + num_splits == 1 + and d <= 128 + and not (causal and seqlen_k < seqlen_q) + ) + if IS_SM90 and d == 64 and not causal: + can_bwd = False # SM90 d=64 non-causal xfail + if not can_bwd: + return + + g = torch.randn_like(out) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + dq_atol + assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + dk_atol + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + dv_atol + + +# --------------------------------------------------------------------------- +# Forward + backward (varlen with cu_seqlens) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "gqa", "mqa"]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("seqlen", [128, 256, 1024]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_varlen_output(seqlen, d, causal, mha_type, dtype): + """Varlen test with cu_seqlens (packed): equal seqlens so we can compare with non-varlen ref.""" + device = "cuda" + seed = seqlen + d + int(causal) * 2 + torch.random.manual_seed(seed) + random.seed(seed) + batch_size = 9 + nheads = 6 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + + q_ref = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype).to(dtype).requires_grad_() + k_ref = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_() + v_ref = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_() + + out_ref, _ = attention_ref(q_ref, k_ref, v_ref, None, None, causal=causal) + out_pt, _ = attention_ref( + q_ref, k_ref, v_ref, None, None, causal=causal, upcast=False, reorder_ops=True, + ) + + cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, seqlen, device=device, dtype=torch.int32) + q_varlen = rearrange(q_ref.detach(), "b s h d -> (b s) h d").requires_grad_() + k_varlen = rearrange(k_ref.detach(), "b s h d -> (b s) h d").requires_grad_() + v_varlen = rearrange(v_ref.detach(), "b s h d -> (b s) h d").requires_grad_() + + out_varlen, lse = flash_attn_varlen_func( + q_varlen, k_varlen, v_varlen, + cu_seqlens, cu_seqlens, + seqlen, seqlen, + causal=causal, + ) + + if is_fake_mode(): + return + + out_reshaped = rearrange(out_varlen, "(b s) h d -> b s h d", b=batch_size) + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + assert (out_reshaped - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + fwd_atol + + # Backward + can_bwd = d <= 128 + if not can_bwd: + return + + g = torch.randn_like(out_varlen) + dq_varlen, dk_varlen, dv_varlen = torch.autograd.grad(out_varlen, (q_varlen, k_varlen, v_varlen), g) + + assert dq_varlen.isfinite().all(), "dq contains non-finite values" + assert dk_varlen.isfinite().all(), "dk contains non-finite values" + assert dv_varlen.isfinite().all(), "dv contains non-finite values" + assert dq_varlen.abs().max().item() > 0, "dq is all zeros" + assert dk_varlen.abs().max().item() > 0, "dk is all zeros" + assert dv_varlen.abs().max().item() > 0, "dv is all zeros" + + +# --------------------------------------------------------------------------- +# Forward + backward (varlen with padding masks — all unpad combinations) +# Covers 4 compile-key-distinct paths: +# (unpad_q, unpad_kv) = (T,T): cu_seqlens for both Q and K +# (unpad_q, unpad_kv) = (F,F): seqused for both Q and K +# (unpad_q, unpad_kv) = (T,F): cu_seqlens_q + seqused_k +# (unpad_q, unpad_kv) = (F,T): seqused_q + cu_seqlens_k +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "gqa", "mqa"]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("seqlen", [128, 256]) +@pytest.mark.parametrize( + "unpad_q,unpad_kv", + [(True, True), (False, False), (True, False), (False, True)], +) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_varlen_unpad_output(seqlen, d, causal, mha_type, unpad_q, unpad_kv, dtype): + """Varlen test with all 4 (unpad_q, unpad_kv) combos: cu_seqlens vs seqused.""" + device = "cuda" + seed = seqlen + d + int(causal) * 2 + int(unpad_q) * 7 + int(unpad_kv) * 13 + torch.random.manual_seed(seed) + random.seed(seed) + batch_size = 9 + nheads = 6 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + + q = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) + k = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype) + q_ref = q.detach().to(dtype).requires_grad_() + k_ref = k.detach().to(dtype).requires_grad_() + v_ref = v.detach().to(dtype).requires_grad_() + + query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random") + key_padding_mask = query_padding_mask if causal else generate_random_padding_mask( + seqlen, batch_size, device, mode="random" + ) + + ( + q_unpad_t, k_unpad_t, v_unpad_t, _qv_unpad, + cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, + max_seqlen_q, max_seqlen_k, + q_padded, k_padded, v_padded, _qv_padded, + output_pad_fn, dq_pad_fn, dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask) + + out_ref, _ = attention_ref( + q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal, + ) + out_pt, _ = attention_ref( + q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal, + upcast=False, reorder_ops=True, + ) + + # Select Q input: packed (unpad) or padded (seqused) + if unpad_q: + q_in = q_unpad_t.detach().to(dtype).requires_grad_() + else: + q_in = q.detach().to(dtype).requires_grad_() + # Select KV input: packed (unpad) or padded (seqused) + if unpad_kv: + k_in = k_unpad_t.detach().to(dtype).requires_grad_() + v_in = v_unpad_t.detach().to(dtype).requires_grad_() + else: + k_in = k.detach().to(dtype).requires_grad_() + v_in = v.detach().to(dtype).requires_grad_() + + out_unpad, lse = flash_attn_varlen_func( + q_in, k_in, v_in, + cu_seqlens_q=cu_seqlens_q if unpad_q else None, + cu_seqlens_k=cu_seqlens_k if unpad_kv else None, + max_seqlen_q=seqlen, + max_seqlen_k=seqlen, + seqused_q=seqused_q if not unpad_q else None, + seqused_k=seqused_k if not unpad_kv else None, + causal=causal, + ) + + if is_fake_mode(): + return + + # Reshape output to (batch, seqlen, nheads, d) for comparison + out = output_pad_fn(out_unpad) if unpad_q else out_unpad + + # Mask out padding positions — kernel output at padding positions is undefined + q_mask = rearrange(query_padding_mask, "b s -> b s 1 1") + out_masked = out.clone().masked_fill_(~q_mask, 0.0) + out_ref_masked = out_ref.clone().masked_fill_(~q_mask, 0.0) + out_pt_masked = out_pt.clone().masked_fill_(~q_mask, 0.0) + + fwd_atol = 2 * (out_ref_masked + 0.3 - 0.3 - out_ref_masked).abs().max().item() + assert (out_masked - out_ref_masked).abs().max().item() <= 2 * (out_pt_masked - out_ref_masked).abs().max().item() + fwd_atol + + # Backward (original test skips all SM90 varlen backward) + can_bwd = d <= 128 and not IS_SM90 + if not can_bwd: + return + + g = torch.randn_like(out_unpad) + dq_in, dk_in, dv_in = torch.autograd.grad(out_unpad, (q_in, k_in, v_in), g) + + # Mask out padding positions again + k_mask = rearrange(key_padding_mask, "b s -> b s 1 1") + if not unpad_q: + dq_in = dq_in.clone().masked_fill_(~q_mask, 0.0) + if not unpad_kv: + dk_in = dk_in.clone().masked_fill_(~k_mask, 0.0) + dv_in = dv_in.clone().masked_fill_(~k_mask, 0.0) + + assert dq_in.isfinite().all(), "dq contains non-finite values" + assert dk_in.isfinite().all(), "dk contains non-finite values" + assert dv_in.isfinite().all(), "dv contains non-finite values" + assert dq_in.abs().max().item() > 0, "dq is all zeros" + assert dk_in.abs().max().item() > 0, "dk is all zeros" + assert dv_in.abs().max().item() > 0, "dv is all zeros" + + +# --------------------------------------------------------------------------- +# Combine kernel +# --------------------------------------------------------------------------- + +def attention_combine_ref(out_partial, lse_partial): + lse = torch.logsumexp(lse_partial, dim=0) + scale = torch.exp(lse_partial - lse) + scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale) + out = (scale.unsqueeze(-1) * out_partial).sum(0) + return out, lse + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("seqlen", [32, 256]) +@pytest.mark.parametrize("num_splits", [2, 5, 17]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_combine(num_splits, seqlen, d, dtype): + device = "cuda" + torch.random.manual_seed(1) + batch_size = 3 + nheads = 8 + + # out_partial: (num_splits, batch, seqlen, nheads, d) with stride(-1)==1 + # lse_partial: (num_splits, batch, seqlen, nheads) with stride(-2)==1 (seqlen contiguous) + out_partial = torch.randn( + num_splits, batch_size, seqlen, nheads, d, device=device, dtype=torch.float32, + ) + lse_partial = torch.randn( + num_splits, batch_size, nheads, seqlen, device=device, dtype=torch.float32, + ).transpose(-1, -2) + lse_partial[num_splits // 2 :, : batch_size // 3] = -float("inf") + + out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype, return_lse=True) + if is_fake_mode(): + return + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + out_pt = out_ref.to(dtype) + + assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py index cadb4a91501..a9b8799f4c1 100644 --- a/tests/cute/test_flash_attn_race_condition.py +++ b/tests/cute/test_flash_attn_race_condition.py @@ -26,7 +26,6 @@ flash_attn_varlen_func, flash_attn_combine, _flash_attn_bwd, - _get_device_capability, ) @@ -36,8 +35,8 @@ # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -@pytest.mark.parametrize("mha_type", ["gqa"]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["gqa"]) # @pytest.mark.parametrize("has_learnable_sink", [False, True]) @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @@ -46,12 +45,13 @@ @pytest.mark.parametrize("deterministic", [True]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) -# @pytest.mark.parametrize("local_enum", [0]) -# @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("causal", [False]) -@pytest.mark.parametrize("d", [64, 128]) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +@pytest.mark.parametrize("local_enum", [0, 1]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [128, 192]) +@pytest.mark.parametrize("d", [64, 128, 192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -76,6 +76,9 @@ def test_flash_attn_output( local = local_enum > 0 if local and causal: pytest.skip() + is_sm90 = torch.cuda.get_device_capability()[0] == 9 + if is_sm90 and d == 192: + pytest.xfail("headdim 192 not supported on sm90") device = "cuda" # set seed torch.random.manual_seed(0) @@ -88,10 +91,9 @@ def test_flash_attn_output( nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) - dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) + dv_vals = [128] if d == 192 else [d] if dtype == torch.float8_e4m3fn: dv_vals = [d] - dv_vals = [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): @@ -245,7 +247,7 @@ def test_flash_attn_output( and not dv > 256 and not attention_chunk != 0 and softcap == 0.0 - and dv == d + and ((dv == d and d <= 128) or (d == 192 and dv == 128)) and learnable_sink is None # and False ): @@ -356,8 +358,8 @@ def test_flash_attn_output( @pytest.mark.parametrize("deterministic", [True]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) -# @pytest.mark.parametrize("local_enum", [0, 1]) +# @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +@pytest.mark.parametrize("local_enum", [0, 1]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("add_unused_qkv", [False, True]) @@ -368,8 +370,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) -# @pytest.mark.parametrize("d", [128, 192]) -@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("d", [64, 128, 192]) +# @pytest.mark.parametrize("d", [192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -408,11 +410,11 @@ def test_flash_attn_varlen_output( local = local_enum > 0 if local and causal: pytest.skip() - is_sm90 = _get_device_capability() == 9 + is_sm90 = torch.cuda.get_device_capability()[0] == 9 if is_sm90 and local: pytest.xfail("bwd local attention not supported on sm90") - if is_sm90 and deterministic: - pytest.xfail("bwd deterministic not supported on sm90") + if is_sm90 and d == 192: + pytest.xfail("headdim 192 not supported on sm90") if ( causal or local ): # Right now reference only supports causal attention with seqlen_k == seqlen_q @@ -426,8 +428,8 @@ def test_flash_attn_varlen_output( nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) - dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) - dv_vals = [d] # override + # dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) + dv_vals = [128] if d == 192 else [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): @@ -649,7 +651,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): and not has_qv and not dv > 256 and not attention_chunk != 0 - and dv == d + and ((dv == d and d <= 128) or (d == 192 and dv == 128)) and not has_learnable_sink and not is_sm90 # and False @@ -780,4 +782,4 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): assert torch.equal(dv_unpad, dv_unpad2) if i % 100 == 0: - print(f"✅ Iteration {i} passed!") \ No newline at end of file + print(f"✅ Iteration {i} passed!") diff --git a/tests/cute/test_flash_attn_varlen.py b/tests/cute/test_flash_attn_varlen.py index 1666a08fb00..24e55315671 100644 --- a/tests/cute/test_flash_attn_varlen.py +++ b/tests/cute/test_flash_attn_varlen.py @@ -1,15 +1,10 @@ -import itertools from typing import Optional -from einops import rearrange import pytest import torch import torch.nn.functional as F from flash_attn.cute import flash_attn_varlen_func -IS_SM90 = torch.cuda.get_device_capability()[0] == 9 - - @pytest.mark.parametrize("B", [1, 7, 20]) @pytest.mark.parametrize("H", [1, 4, 6]) @pytest.mark.parametrize("D", [64, 128]) @@ -43,9 +38,6 @@ def test_varlen( dtype=dtype ) - # SM90 backward pass doesn't support varlen yet - skip_backward = IS_SM90 - ok = check_varlen_vs_torch_flash( q, k, v, cu_seqlens_q, cu_seqlens_k, @@ -53,7 +45,6 @@ def test_varlen( softmax_scale=softmax_scale, causal=causal, mha_type=mha_type, - skip_backward=skip_backward, ) assert ok @@ -71,7 +62,6 @@ def check_varlen_vs_torch_flash( softcap=0.0, atol=3e-2, rtol=3e-2, - skip_backward=False, ): assert q.requires_grad and k.requires_grad and v.requires_grad, "Set requires_grad=True on inputs" @@ -128,10 +118,6 @@ def clone_like(t): if not ok_fwd: return False - # Skip backward if not supported (e.g., SM90 varlen) - if skip_backward: - return True - # Use the same upstream gradient to compare backward paths grad_out = torch.randn_like(out_fa) @@ -312,4 +298,4 @@ def _stats(name, a, b, atol, rtol): mean_abs = diff.abs().mean().item() mean_rel = (diff.abs().mean() / b.abs().clamp_min(1e-6).mean().item()) print(f"{name}: mean_abs={mean_abs:.4e}, mean_rel={mean_rel:.4e}, sum_fa={a.sum()}, sum_ref={b.sum()}") - return mean_abs < atol and mean_rel < rtol \ No newline at end of file + return mean_abs < atol and mean_rel < rtol diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 438ac8aeecd..26e0a5e1353 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -28,6 +28,7 @@ fast_sampling, normalize_block_sparse_config, ) +from flash_attn.cute.cache_utils import get_jit_cache from flash_attn.cute import utils from mask_mod_definitions import get_mask_pair, random_doc_id_tensor COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] @@ -107,6 +108,76 @@ def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: tuple[int, i return out_ref.transpose(1, 2).contiguous() +def assert_fwd_matches_reference(out_cute, out_ref_fp32, out_pt, test_desc: str | None = None): + assert out_cute.shape == out_ref_fp32.shape == out_pt.shape + assert not torch.isnan(out_cute).any() + assert not torch.isnan(out_ref_fp32).any() + assert torch.isfinite(out_cute).all() + assert torch.isfinite(out_ref_fp32).all() + + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + if test_desc is not None: + print(f"\n{test_desc}") + print(" Reference implementation: FlexAttention") + print(f" PyTorch vs FP32: {pt_error:.2e}") + print(f" Kernel vs FP32: {cute_error:.2e}") + print(f" Tolerance: rtol={rtol} * {pt_error:.2e} + {fwd_atol:.2e}") + + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + +def assert_bwd_matches_reference( + dq_cute, + dk_cute, + dv_cute, + dq_ref_fp32, + dk_ref_fp32, + dv_ref_fp32, + dq_pt, + dk_pt, + dv_pt, + dtype, + min_seqlen: int, +): + assert not torch.isnan(dq_cute).any(), "dQ contains NaN" + assert not torch.isnan(dk_cute).any(), "dK contains NaN" + assert not torch.isnan(dv_cute).any(), "dV contains NaN" + + bwd_rtol = 2 + bwd_atol_floor = 1e-5 if min_seqlen >= 64 else 3e-5 + dq_atol = max(bwd_atol_floor, 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item()) + dk_atol = max(bwd_atol_floor, 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item()) + dv_atol = max(bwd_atol_floor, 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item()) + + dq_ref = dq_ref_fp32.to(dtype) + dk_ref = dk_ref_fp32.to(dtype) + dv_ref = dv_ref_fp32.to(dtype) + + pt_dq_err = (dq_pt - dq_ref).abs().max().item() + pt_dk_err = (dk_pt - dk_ref).abs().max().item() + pt_dv_err = (dv_pt - dv_ref).abs().max().item() + + cute_dq_err = (dq_cute - dq_ref).abs().max().item() + cute_dk_err = (dk_cute - dk_ref).abs().max().item() + cute_dv_err = (dv_cute - dv_ref).abs().max().item() + + print(" Backward comparison:") + print(f" dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}") + print(f" dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}") + print(f" dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}") + + assert cute_dq_err <= bwd_rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}" + assert cute_dk_err <= bwd_rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}" + assert cute_dv_err <= bwd_rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}" + + def get_coarse_block_mask_pair(sparse_tile_m: int, tile_n: int, last_block: int): @fast_sampling @cute.jit @@ -348,10 +419,9 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): window_size_left=window_left, window_size_right=window_right, learnable_sink=None, - m_block_size=tile_m, - n_block_size=tile_n, + tile_mn=(tile_m, tile_n), pack_gqa=pack_gqa, - _compute_capability=None, + _arch=None, score_mod=None, mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_mask_fwd, @@ -370,18 +440,8 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): out_ref_fp32 = compute_reference_flex_attn(tensors_fp32, mask_mod_flex, block_size) out_ref = compute_reference_flex_attn(tensors, mask_mod_flex, block_size) out_pt = out_ref.clone() - - # Check for invalid values - assert out_cute.shape == out_ref_fp32.shape == out_ref.shape - assert not torch.isnan(out_cute).any() - assert not torch.isnan(out_ref_fp32).any() - assert torch.isfinite(out_cute).all() - assert torch.isfinite(out_ref_fp32).all() - - # Compute numerical tolerance (matching flash attention tests) fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() rtol = 2 - ref_error = (out_ref - out_ref_fp32).abs().max().item() pt_error = (out_pt - out_ref_fp32).abs().max().item() cute_error = (out_cute - out_ref_fp32).abs().max().item() @@ -412,10 +472,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): print(f" DEBUG: Kernel value: {out_cute[max_diff_coords]:.6f}") print(f" DEBUG: Reference value: {out_ref_fp32[max_diff_coords]:.6f}") - # Use the same assertion logic as FlashAttention tests - assert cute_error <= rtol * pt_error + fwd_atol, ( - f"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" - ) + assert_fwd_matches_reference(out_cute, out_ref_fp32, out_pt, mask_desc) if needs_backward: q = tensors["q"] @@ -443,38 +500,19 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): q, k, v, flex_block_mask, grad_out ) - # Check for invalid values - assert not torch.isnan(dq_cute).any(), "dQ contains NaN" - assert not torch.isnan(dk_cute).any(), "dK contains NaN" - assert not torch.isnan(dv_cute).any(), "dV contains NaN" - - bwd_rtol = 2 - min_seqlen = min(seqlen_q, seqlen_k) - bwd_atol_floor = 1e-5 if min_seqlen >= 64 else 3e-5 - dq_atol = max(bwd_atol_floor, 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item()) - dk_atol = max(bwd_atol_floor, 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item()) - dv_atol = max(bwd_atol_floor, 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item()) - - dq_ref = dq_ref_fp32.to(dtype) - dk_ref = dk_ref_fp32.to(dtype) - dv_ref = dv_ref_fp32.to(dtype) - - pt_dq_err = (dq_pt - dq_ref).abs().max().item() - pt_dk_err = (dk_pt - dk_ref).abs().max().item() - pt_dv_err = (dv_pt - dv_ref).abs().max().item() - - cute_dq_err = (dq_cute - dq_ref).abs().max().item() - cute_dk_err = (dk_cute - dk_ref).abs().max().item() - cute_dv_err = (dv_cute - dv_ref).abs().max().item() - - print(" Backward comparison:") - print(f" dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}") - print(f" dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}") - print(f" dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}") - - assert cute_dq_err <= bwd_rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}" - assert cute_dk_err <= bwd_rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}" - assert cute_dv_err <= bwd_rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}" + assert_bwd_matches_reference( + dq_cute, + dk_cute, + dv_cute, + dq_ref_fp32, + dk_ref_fp32, + dv_ref_fp32, + dq_pt, + dk_pt, + dv_pt, + dtype, + min(seqlen_q, seqlen_k), + ) def test_mask_mod_ima_partial_block(): @@ -621,8 +659,8 @@ def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): seqused_q=None, seqused_k=None, page_table=None, causal=False, softcap=None, window_size_left=-1, window_size_right=-1, - m_block_size=tile_m, n_block_size=tile_n, pack_gqa=False, - _compute_capability=None, score_mod=None, + tile_mn=(tile_m, tile_n), pack_gqa=False, + _arch=None, score_mod=None, mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_mask_fwd, return_lse=True, aux_tensors=aux_tensors_arg, @@ -789,8 +827,7 @@ def test_sm100_block_sparse_sink_all_masked(): window_size_left=None, window_size_right=None, learnable_sink=learnable_sink, - m_block_size=128, - n_block_size=128, + tile_mn=(128, 128), num_threads=384, pack_gqa=False, block_sparse_tensors=sparse, @@ -818,8 +855,8 @@ def wrapped_init(self, *args, **kwargs): "__init__", wrapped_init, ): - compile_cache = dict(_flash_attn_fwd.compile_cache) - _flash_attn_fwd.compile_cache.clear() + compile_cache = _flash_attn_fwd.compile_cache + _flash_attn_fwd.compile_cache = get_jit_cache("test_mask_mod.fwd") try: _run_mask_test( seqlen_q=128, @@ -839,7 +876,7 @@ def wrapped_init(self, *args, **kwargs): ) finally: _flash_attn_fwd.compile_cache.clear() - _flash_attn_fwd.compile_cache.update(compile_cache) + _flash_attn_fwd.compile_cache = compile_cache assert observed.get("q_stage") == 1 @@ -907,10 +944,9 @@ def test_sm100_block_sparse_coarse_blocks(): window_size_left=None, window_size_right=None, learnable_sink=None, - m_block_size=tile_m, - n_block_size=tile_n, + tile_mn=(tile_m, tile_n), pack_gqa=False, - _compute_capability=None, + _arch=None, score_mod=None, mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_mask_fwd, @@ -1014,10 +1050,9 @@ def wrapped_normalize(*args, **kwargs): window_size_left=None, window_size_right=None, learnable_sink=None, - m_block_size=tile_m, - n_block_size=tile_n, + tile_mn=(tile_m, tile_n), pack_gqa=False, - _compute_capability=None, + _arch=None, score_mod=None, mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_mask_fwd, @@ -1143,6 +1178,9 @@ def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message(): tensors = create_tensors(batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim, dtype) mask_mod_cute, mask_mod_flex = get_mask_pair("block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k) + + # Use a block_size whose Q dimension doesn't divide m_block_size (100 % 80 != 0) + bad_block_size_q = 100 bm = create_block_mask( mask_mod_flex, batch_size, @@ -1150,7 +1188,7 @@ def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message(): seqlen_q, seqlen_k, device="cuda", - BLOCK_SIZE=(tile_m, tile_n), + BLOCK_SIZE=(bad_block_size_q, tile_n), ) ( _seq_q, @@ -1171,7 +1209,7 @@ def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message(): mask_block_idx=q_mask_idx, full_block_cnt=full_q_cnt, full_block_idx=full_q_idx, - block_size=(tile_m, tile_n), + block_size=(bad_block_size_q, tile_n), ) softmax_scale = 1.0 / math.sqrt(headdim) @@ -1181,7 +1219,7 @@ def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message(): with pytest.raises( ValueError, - match=r"Block sparsity expects sparse_block_size_q=128 for subtile_factor=2\.", + match=r"Block sparsity expects sparse_block_size_q=", ): _flash_attn_bwd( q=tensors["q"], @@ -1199,6 +1237,209 @@ def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message(): ) +@pytest.mark.skipif(COMPUTE_CAPABILITY != 9, reason="SM90-only test") +def test_sm90_block_sparse_infers_block_size(): + torch.manual_seed(0) + + batch_size = 1 + nheads = 4 + seqlen_q = 128 + seqlen_k = 128 + headdim = 64 + tile_m = 128 + tile_n = 128 + dtype = torch.bfloat16 + softmax_scale = 1.0 / math.sqrt(headdim) + + q = torch.randn(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + + def block_causal(batch, head, q_idx, kv_idx): + return kv_idx // tile_n <= q_idx // tile_m + + bm = create_block_mask( + block_causal, + batch_size, + nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _seq_q, + _seq_k, + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = bm.as_tuple() + + block_sparse_mask_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + block_size=None, + ) + block_sparse_mask_bwd = BlockSparseTensorsTorch( + mask_block_cnt=q_mask_cnt, + mask_block_idx=q_mask_idx, + full_block_cnt=full_q_cnt, + full_block_idx=full_q_idx, + block_size=None, + ) + + out, lse = _flash_attn_fwd( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + causal=False, + block_sparse_tensors=block_sparse_mask_fwd, + return_lse=True, + ) + grad_out = torch.randn_like(out) + dq, dk, dv = run_cute_mask_bwd( + q, + k, + v, + out, + lse, + grad_out, + None, + block_sparse_mask_bwd=block_sparse_mask_bwd, + tile_m=tile_m, + tile_n=tile_n, + ) + + out_ref, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd( + q, k, v, bm, grad_out, dtype=torch.float32 + ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, bm, grad_out) + assert_fwd_matches_reference(out, out_ref, out_pt) + assert_bwd_matches_reference( + dq, + dk, + dv, + dq_ref, + dk_ref, + dv_ref, + dq_pt, + dk_pt, + dv_pt, + dtype, + min(seqlen_q, seqlen_k), + ) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY != 9, reason="SM90-only test") +def test_sm90_block_sparse_explicit_192_block_size(): + torch.manual_seed(0) + + batch_size = 1 + nheads = 4 + seqlen_q = 384 + seqlen_k = 384 + headdim = 96 + block_size_q = 192 + block_size_kv = 128 + dtype = torch.bfloat16 + softmax_scale = 1.0 / math.sqrt(headdim) + + q = torch.randn(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + + def block_causal(batch, head, q_idx, kv_idx): + return (q_idx >= block_size_q) & (kv_idx < block_size_kv) + + bm = create_block_mask( + block_causal, + batch_size, + nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(block_size_q, block_size_kv), + ) + ( + _seq_q, + _seq_k, + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = bm.as_tuple() + + block_sparse_mask_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + block_size=(block_size_q, block_size_kv), + ) + block_sparse_mask_bwd = BlockSparseTensorsTorch( + mask_block_cnt=q_mask_cnt, + mask_block_idx=q_mask_idx, + full_block_cnt=full_q_cnt, + full_block_idx=full_q_idx, + block_size=(block_size_q, block_size_kv), + ) + + out, lse = _flash_attn_fwd( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + causal=True, + block_sparse_tensors=block_sparse_mask_fwd, + return_lse=True, + ) + grad_out = torch.randn_like(out) + dq, dk, dv = _flash_attn_bwd( + q=q, + k=k, + v=v, + out=out, + dout=grad_out, + lse=lse, + softmax_scale=softmax_scale, + causal=True, + block_sparse_tensors=block_sparse_mask_bwd, + ) + + out_ref, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd( + q, k, v, bm, grad_out, dtype=torch.float32 + ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, bm, grad_out) + assert_fwd_matches_reference(out, out_ref, out_pt) + assert_bwd_matches_reference( + dq, + dk, + dv, + dq_ref, + dk_ref, + dv_ref, + dq_pt, + dk_pt, + dv_pt, + dtype, + min(seqlen_q, seqlen_k), + ) + + def test_gqa_block_sparse_broadcast_pattern_recompilation(): """Test that different block sparse broadcast patterns trigger recompilation. @@ -1268,7 +1509,7 @@ def run_with_block_mask_nheads(block_mask_nheads: int) -> tuple[torch.Tensor, to q=q, k=k, v=v, out=out, lse=lse, softmax_scale=softmax_scale, causal=False, window_size_left=-1, window_size_right=-1, - m_block_size=tile_m, n_block_size=tile_n, pack_gqa=False, + tile_mn=(tile_m, tile_n), pack_gqa=False, mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_fwd, return_lse=True, ) @@ -1343,7 +1584,7 @@ def test_gqa_expand_stride_zero_bug(): q=q, k=k, v=v, out=out, lse=lse, softmax_scale=softmax_scale, causal=True, - m_block_size=128, n_block_size=128, + tile_mn=(128, 128), return_lse=True, ) out_fwd, lse_fwd = out_tuple[0], out_tuple[1] @@ -1415,5 +1656,147 @@ def causal_mask(b, h, q_idx, kv_idx): assert cute_dv_err <= bwd_rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}" +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="SM100/SM110 persistent forward only") +def test_persistent_blocksparse_empty_tiles(): + """Regression test for persistent forward deadlock with highly-sparse block masks. + + When most Q-tiles are empty (no active KV blocks), the persistent kernel + deadlocked due to barrier phase desync in the empty-tile paths of both the + softmax and correction warp groups. + """ + torch.manual_seed(5) + batch_size, nheads_q, nheads_kv = 2, 16, 1 + seqlen_q, seqlen_k, headdim = 8192, 128, 128 + tile_m, tile_n = 128, 128 + dtype = torch.bfloat16 + + sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m + window_size = 64 + mask_mod_cute, mask_mod_flex = get_mask_pair( + "sliding_window", seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size, + ) + + bm = create_block_mask( + mask_mod_flex, batch_size, nheads_q, seqlen_q, seqlen_k, + device="cuda", BLOCK_SIZE=(sparse_tile_m, tile_n), + ) + (_, _, kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, *_) = bm.as_tuple() + block_sparse_mask_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, full_block_idx=full_kv_idx, + block_size=(sparse_tile_m, tile_n), + ) + + q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, device="cuda", dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device="cuda", dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device="cuda", dtype=dtype) + + out, lse = _flash_attn_fwd( + q=q, k=k, v=v, + out=torch.empty(batch_size, seqlen_q, nheads_q, headdim, device="cuda", dtype=dtype), + lse=torch.empty(batch_size, nheads_q, seqlen_q, device="cuda", dtype=torch.float32), + cu_seqlens_q=None, cu_seqlens_k=None, seqused_q=None, seqused_k=None, + page_table=None, softmax_scale=1.0 / math.sqrt(headdim), + causal=False, softcap=None, + window_size_left=None, window_size_right=None, + learnable_sink=None, + tile_mn=(tile_m, tile_n), + pack_gqa=False, _arch=None, + score_mod=None, mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_mask_fwd, + return_lse=True, aux_tensors=None, + ) + torch.cuda.synchronize() + assert out.shape == (batch_size, seqlen_q, nheads_q, headdim) + assert not out.isnan().any() + + + +def test_compact_block_sparse_indices(): + """Test that compact block sparse index tensors (idx.shape[3] < n_blocks) work correctly. + + FA4 only accesses indices 0..cnt-1 per query tile, so the index tensor's last + dimension does not need to be as large as ceil(seqlen_k / block_size_n). This + test verifies that truncated (compact) index tensors produce identical output + to full-sized ones. + """ + torch.manual_seed(42) + batch_size = 1 + nheads = 4 + seqlen_q = 1024 + seqlen_k = 1024 + headdim = 128 + tile_m = 128 + tile_n = 128 + dtype = torch.bfloat16 + + sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m + + mask_mod_cute, mask_mod_flex = get_mask_pair( + "block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=None + ) + tensors = create_tensors( + batch_size, seqlen_q, seqlen_k, nheads, nheads, headdim, headdim, dtype + ) + + bm = create_block_mask( + mask_mod_flex, batch_size, nheads, seqlen_q, seqlen_k, + device="cuda", BLOCK_SIZE=(sparse_tile_m, tile_n), + ) + (_, _, kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, *_) = bm.as_tuple() + + # Determine the max count across all query tiles — this is the compact last dim + max_mask_k = kv_mask_cnt.max().item() if kv_mask_cnt is not None else 0 + max_full_k = full_kv_cnt.max().item() if full_kv_cnt is not None else 0 + max_k = max(max_mask_k, max_full_k, 1) + + # Truncate index tensors to compact size + kv_mask_idx_compact = kv_mask_idx[:, :, :, :max_k].contiguous() + full_kv_idx_compact = full_kv_idx[:, :, :, :max_k].contiguous() if full_kv_idx is not None else None + + block_sparse_compact = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx_compact, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx_compact, + block_size=(sparse_tile_m, tile_n), + ) + + out_compact, _ = _flash_attn_fwd( + q=tensors["q"], k=tensors["k"], v=tensors["v"], + out=tensors["out"].clone(), lse=tensors["lse"].clone(), + softmax_scale=1.0 / math.sqrt(headdim), + causal=False, mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_compact, + return_lse=True, + ) + + # Reference: use full-sized index tensors + block_sparse_full = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + block_size=(sparse_tile_m, tile_n), + ) + + out_full, _ = _flash_attn_fwd( + q=tensors["q"], k=tensors["k"], v=tensors["v"], + out=tensors["out"].clone(), lse=tensors["lse"].clone(), + softmax_scale=1.0 / math.sqrt(headdim), + causal=False, mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_full, + return_lse=True, + ) + + assert not torch.isnan(out_compact).any(), "Compact output has NaN" + assert torch.isfinite(out_compact).all(), "Compact output has Inf" + # Compact and full should produce bit-identical results + assert torch.equal(out_compact, out_full), ( + f"Compact and full block sparse outputs differ: " + f"max diff = {(out_compact - out_full).abs().max().item():.2e}" + ) + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py index 11efcc8cdbc..43bf62e7d54 100644 --- a/tests/cute/test_score_mod.py +++ b/tests/cute/test_score_mod.py @@ -4,8 +4,9 @@ import cutlass.cute as cute from cutlass._mlir.dialects import math as mlir_math import operator -from torch.nn.attention.flex_attention import flex_attention -from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd +from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd, _tile_size_bwd_sm90 +from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] @@ -23,6 +24,16 @@ score_mod_batch_bias as score_mod_10, score_mod_dual_buffer as score_mod_11, ) # isort: split +from score_mod_definitions import ( + score_mod_identity_vectorized as score_mod_1_vectorized, + score_mod_causal_vectorized as score_mod_2_vectorized, + score_mod_rel_bias as score_mod_3_vectorized, + score_mod_rel_bias_x2_vectorized as score_mod_4_vectorized, + score_mod_times_two_vectorized as score_mod_5_vectorized, + score_mod_alibi_vectorized as score_mod_6_vectorized, + score_mod_batch_bias_vectorized as score_mod_10_vectorized, + score_mod_dual_buffer_vectorized as score_mod_11_vectorized, +) # isort: split from score_mod_definitions import ( # Eager (torch) reference score mods identity_eager, @@ -59,6 +70,21 @@ (score_mod_11, dual_buffer_bias), ] +# Test pairs to compare vectorized score_mods: (cute_jit_function, cute_jit_function_vectorized) +TEST_PAIRS_VECTORIZED = [ + (score_mod_1, score_mod_1_vectorized), + (score_mod_2, score_mod_2_vectorized), + (score_mod_3, score_mod_3_vectorized), + (score_mod_4, score_mod_4_vectorized), + (score_mod_5, score_mod_5_vectorized), + (score_mod_6, score_mod_6_vectorized), +] + +TEST_PAIRS_WITH_AUX_TENSORS_VECTORIZED = [ + (score_mod_10, score_mod_10_vectorized), + (score_mod_11, score_mod_11_vectorized), +] + SEQLEN_CONFIGS = [ (1, 1), (64, 128), @@ -82,6 +108,8 @@ (4224, 4224), ] +VEC_SIZES_TO_CHECK_EQUALITY = [1, 2, 4] if COMPUTE_CAPABILITY == 10 else [1, 2] + def create_tensors( batch_size=2, num_heads=4, seqlen_q=64, seqlen_kv=64, dim=128, dtype=torch.bfloat16 @@ -92,12 +120,8 @@ def create_tensors( return q, k, v -def run_cute_flash( - q, k, v, cute_score_mod, aux_tensors=None, pack_gqa=False -) -> torch.Tensor: - q_transposed, k_transposed, v_transposed = map( - lambda x: x.transpose(1, 2), (q, k, v) - ) +def run_cute_flash(q, k, v, cute_score_mod, aux_tensors=None, pack_gqa=False) -> torch.Tensor: + q_transposed, k_transposed, v_transposed = map(lambda x: x.transpose(1, 2), (q, k, v)) out = torch.empty_like(q_transposed) _flash_attn_fwd( q_transposed, @@ -116,9 +140,7 @@ def run_cute_flash( def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Tensor: if dtype is not None: q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) - return flex_attention( - q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1] - ) + return flex_attention(q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1]) @pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) @@ -174,6 +196,39 @@ def test_cute_vs_flex_attention( ) +@pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("score_mod_vec_pair", TEST_PAIRS_VECTORIZED) +def test_cute_score_mod_vectorized( + seqlen_q, + seqlen_kv, + qhead_per_kvhead, + num_kv_heads, + dtype, + score_mod_vec_pair, +): + """Tests equality between original and vectorized versions of score mods""" + torch.random.manual_seed(42) + cute_score_mod, cute_vectorized_score_mod = score_mod_vec_pair + + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + q, k, v = create_tensors( + seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dtype=dtype + ) + if pack_gqa: + k = k[:, :num_kv_heads, :, :].clone() + v = v[:, :num_kv_heads, :, :].clone() + + out_ref = run_cute_flash(q, k, v, cute_score_mod, pack_gqa=pack_gqa) + + for vec_size in VEC_SIZES_TO_CHECK_EQUALITY: + cute_vectorized_score_mod.__vec_size__ = vec_size + out = run_cute_flash(q, k, v, cute_vectorized_score_mod, pack_gqa=pack_gqa) + assert torch.equal(out, out_ref) + + @pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -214,9 +269,7 @@ def test_cute_vs_flex_attention_with_aux_tensors( out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32) out_pt = run_flex_reference(q, k, v, eager_score_mod) - out_cute = run_cute_flash( - q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa - ) + out_cute = run_cute_flash(q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa) # Basic shape and NaN checks assert out_cute.shape == out_ref_fp32.shape == out_pt.shape @@ -247,19 +300,65 @@ def test_cute_vs_flex_attention_with_aux_tensors( ) -def _generate_block_kvcache( - seqlen_k, page_size, batch_size, nheads_k, d, device, dtype +@pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("score_mod_vec_pair", TEST_PAIRS_WITH_AUX_TENSORS_VECTORIZED) +def test_cute_score_mod_with_aux_tensors_vectorized( + seqlen_q, + seqlen_kv, + qhead_per_kvhead, + num_kv_heads, + dtype, + score_mod_vec_pair, ): + """Tests equality between original and vectorized versions of score mods""" + torch.random.manual_seed(42) + cute_score_mod, cute_vectorized_score_mod = score_mod_vec_pair + batch_size = 2 + + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + q, k, v = create_tensors( + seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dtype=dtype + ) + if pack_gqa: + k = k[:, :num_kv_heads, :, :].clone() + v = v[:, :num_kv_heads, :, :].clone() + + if cute_score_mod == score_mod_10: + buffer = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [buffer] + assert buffer.shape == (batch_size,) + elif cute_score_mod == score_mod_11: + head_bias = torch.randn(num_q_heads, device="cuda", dtype=dtype) * 0.2 + pos_scale = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 + aux_tensors = [head_bias, pos_scale] + assert head_bias.shape == (num_q_heads,) + assert pos_scale.shape == (seqlen_q,) + + out_ref = run_cute_flash(q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa) + + for vec_size in VEC_SIZES_TO_CHECK_EQUALITY: + cute_vectorized_score_mod.__vec_size__ = vec_size + out = run_cute_flash( + q, + k, + v, + cute_vectorized_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + assert torch.equal(out, out_ref) + + +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, device, dtype): import math from einops import rearrange num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 - k_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype - ) - v_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype - ) + k_cache_paged = torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype) + v_cache_paged = torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype) page_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), "(b nblocks) -> b nblocks", @@ -321,12 +420,8 @@ def test_score_mod_with_paged_kvcache( q = torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype) if page_size is None: - k_cache = torch.randn( - batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype - ) - v_cache = torch.randn( - batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype - ) + k_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype) + v_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype) page_table = None k_cache_paged = None v_cache_paged = None @@ -342,9 +437,7 @@ def test_score_mod_with_paged_kvcache( seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype ) - cache_seqlens = torch.randint( - 1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device - ) + cache_seqlens = torch.randint(1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device) from einops import rearrange @@ -426,9 +519,7 @@ def masked_score_mod(score, b, h, q_idx, kv_idx): pt_error = (out_pt - out_ref_fp32).abs().max().item() cute_error = (out_cute - out_ref_fp32).abs().max().item() - print( - f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):" - ) + print(f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):") print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") @@ -478,12 +569,8 @@ def test_score_mod_with_paged_kvcache_aux_tensors( q = torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype) if page_size is None: - k_cache = torch.randn( - batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype - ) - v_cache = torch.randn( - batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype - ) + k_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype) + v_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype) page_table = None k_cache_paged = None v_cache_paged = None @@ -499,9 +586,7 @@ def test_score_mod_with_paged_kvcache_aux_tensors( seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype ) - cache_seqlens = torch.randint( - 1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device - ) + cache_seqlens = torch.randint(1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device) if cute_score_mod == score_mod_10: buffer = torch.randn(batch_size, device=device, dtype=dtype) * 0.1 @@ -595,9 +680,7 @@ def masked_score_mod(score, b, h, q_idx, kv_idx): pt_error = (out_pt - out_ref_fp32).abs().max().item() cute_error = (out_cute - out_ref_fp32).abs().max().item() - print( - f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):" - ) + print(f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):") print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") @@ -628,7 +711,7 @@ def score_mod_bwd_identity(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info @cute.jit def score_mod_bwd_causal(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): """Backward for causal masking: d(where(mask, score, -inf))/d(score) = where(mask, 1, 0). - + At unmasked positions (q_idx >= kv_idx), grad passes through. At masked positions (q_idx < kv_idx), the kernel already zeros grad because P=0. """ @@ -678,7 +761,9 @@ def run_cute_flash_bwd( v_t = v.transpose(1, 2) out, lse = _flash_attn_fwd( - q_t, k_t, v_t, + q_t, + k_t, + v_t, return_lse=True, score_mod=cute_score_mod, aux_tensors=aux_tensors, @@ -688,8 +773,12 @@ def run_cute_flash_bwd( grad_out = torch.randn_like(out) dq, dk, dv = _flash_attn_bwd( - q_t, k_t, v_t, - out, grad_out, lse, + q_t, + k_t, + v_t, + out, + grad_out, + lse, score_mod=cute_score_mod, score_mod_bwd=cute_score_mod_bwd, aux_tensors=aux_tensors, @@ -718,14 +807,164 @@ def run_flex_reference_bwd(q, k, v, eager_score_mod, grad_out, dtype=None): v = v.requires_grad_(True) compiled_flex = torch.compile(flex_attention) - out = compiled_flex( - q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1] - ) + out = compiled_flex(q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1]) dq, dk, dv = torch.autograd.grad(out, (q, k, v), grad_out) return out, dq, dk, dv +@pytest.mark.skipif(COMPUTE_CAPABILITY != 9, reason="SM90-only test") +def test_sm90_block_sparse_score_mod_backward_with_dq_swapab(): + torch.random.manual_seed(42) + + batch_size = 1 + num_heads = 4 + seqlen_q = 640 + seqlen_kv = 640 + dim = 128 + block_size_q = 640 + block_size_kv = 128 + dtype = torch.bfloat16 + + cfg = _tile_size_bwd_sm90( + dim, + dim, + causal=False, + local=False, + sparse_block_size_q=block_size_q, + ) + assert cfg.m_block_size == 80 + assert cfg.dQ_swapAB + + q, k, v = create_tensors( + batch_size=batch_size, + num_heads=num_heads, + seqlen_q=seqlen_q, + seqlen_kv=seqlen_kv, + dim=dim, + dtype=dtype, + ) + + def prefix_visible(batch, head, q_idx, kv_idx): + return kv_idx < 3 * block_size_kv + + block_mask = create_block_mask( + prefix_visible, + B=batch_size, + H=num_heads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_kv, + device=q.device, + BLOCK_SIZE=(block_size_q, block_size_kv), + ) + ( + _seq_q, + _seq_k, + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = block_mask.as_tuple() + + block_sparse_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + block_size=(block_size_q, block_size_kv), + ) + block_sparse_bwd = BlockSparseTensorsTorch( + mask_block_cnt=q_mask_cnt, + mask_block_idx=q_mask_idx, + full_block_cnt=full_q_cnt, + full_block_idx=full_q_idx, + block_size=(block_size_q, block_size_kv), + ) + + q_t = q.transpose(1, 2) + k_t = k.transpose(1, 2) + v_t = v.transpose(1, 2) + out, lse = _flash_attn_fwd( + q_t, + k_t, + v_t, + return_lse=True, + score_mod=score_mod_squared, + block_sparse_tensors=block_sparse_fwd, + ) + grad_out = torch.randn_like(out) + dq, dk, dv = _flash_attn_bwd( + q_t, + k_t, + v_t, + out, + grad_out, + lse, + score_mod=score_mod_squared, + score_mod_bwd=score_mod_bwd_squared, + block_sparse_tensors=block_sparse_bwd, + ) + + def run_flex_block_sparse_score_mod_ref(q_ref, k_ref, v_ref, grad_out_ref, ref_dtype=None): + if ref_dtype is not None: + q_ref = q_ref.to(ref_dtype).requires_grad_(True) + k_ref = k_ref.to(ref_dtype).requires_grad_(True) + v_ref = v_ref.to(ref_dtype).requires_grad_(True) + grad_out_ref = grad_out_ref.to(ref_dtype) + else: + q_ref = q_ref.requires_grad_(True) + k_ref = k_ref.requires_grad_(True) + v_ref = v_ref.requires_grad_(True) + + compiled_flex = torch.compile(flex_attention) + out_ref = compiled_flex( + q_ref, + k_ref, + v_ref, + block_mask=block_mask, + score_mod=score_squared_eager, + enable_gqa=False, + ) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), grad_out_ref) + return out_ref, dq_ref, dk_ref, dv_ref + + out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_block_sparse_score_mod_ref( + q, k, v, grad_out.transpose(1, 2), ref_dtype=torch.float32 + ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_block_sparse_score_mod_ref( + q, k, v, grad_out.transpose(1, 2) + ) + + rtol = 2 + out_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + dq_atol = 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item() + dk_atol = 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item() + dv_atol = 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item() + + out_ref = out_ref_fp32.to(dtype) + dq_ref = dq_ref_fp32.to(dtype) + dk_ref = dk_ref_fp32.to(dtype) + dv_ref = dv_ref_fp32.to(dtype) + + assert (out.transpose(1, 2) - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + out_atol + assert (dq.transpose(1, 2) - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + assert (dk.transpose(1, 2) - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + assert (dv.transpose(1, 2) - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + + @pytest.mark.parametrize( "seqlen_q,seqlen_kv", [ @@ -755,15 +994,11 @@ def test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype, score_ seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=4, dim=dim, dtype=dtype ) - out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd( - q, k, v, cute_fwd, cute_bwd - ) + out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd(q, k, v, cute_fwd, cute_bwd) out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( q, k, v, eager_ref, grad_out, dtype=torch.float32 ) - out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd( - q, k, v, eager_ref, grad_out - ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, eager_ref, grad_out) assert not torch.isnan(dq_cute).any(), "dQ contains NaN" assert not torch.isnan(dk_cute).any(), "dK contains NaN" @@ -839,9 +1074,7 @@ def test_cute_vs_flex_attention_backward_with_aux( out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( q, k, v, eager_ref, grad_out, dtype=torch.float32 ) - out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd( - q, k, v, eager_ref, grad_out - ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, eager_ref, grad_out) assert not torch.isnan(dq_cute).any() assert not torch.isnan(dk_cute).any() @@ -901,9 +1134,7 @@ def test_cute_vs_flex_attention_backward_pack_gqa( out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( q, k, v, eager_ref, grad_out, dtype=torch.float32 ) - out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd( - q, k, v, eager_ref, grad_out - ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, eager_ref, grad_out) assert not torch.isnan(dq_cute).any() assert not torch.isnan(dk_cute).any() diff --git a/tests/cute/test_score_mod_varlen.py b/tests/cute/test_score_mod_varlen.py index 7cca7f2aa0a..c8092228a51 100644 --- a/tests/cute/test_score_mod_varlen.py +++ b/tests/cute/test_score_mod_varlen.py @@ -28,6 +28,16 @@ score_mod_stress_xor_pattern, score_mod_times_two, ) # isort: split +from score_mod_definitions import ( + score_mod_identity_vectorized, + score_mod_causal_vectorized, + score_mod_rel_bias as score_mod_rel_bias_vectorized, + score_mod_rel_bias_x2_vectorized, + score_mod_times_two_vectorized, + score_mod_alibi_vectorized, + score_mod_batch_bias_vectorized, + score_mod_dual_buffer_vectorized, +) # isort: split from score_mod_definitions import ( # Eager (torch) reference score mods identity_eager, @@ -55,6 +65,7 @@ ) IS_SM90 = torch.cuda.get_device_capability()[0] == 9 +IS_SM100 = torch.cuda.get_device_capability()[0] == 10 # ============================================================================= # Test pairs @@ -77,6 +88,17 @@ (score_mod_dual_buffer, dual_buffer_factory, "dual_buffer"), ] +# Test pairs to compare vectorized score_mods: (cute_jit_function, cute_jit_function_vectorized) +TEST_PAIRS_VECTORIZED_NO_GLOBAL = [ + (score_mod_identity, score_mod_identity_vectorized, None), + (score_mod_causal, score_mod_causal_vectorized, None), + (score_mod_rel_bias, score_mod_rel_bias_vectorized, None), + (score_mod_rel_bias_x2, score_mod_rel_bias_x2_vectorized, None), + (score_mod_times_two, score_mod_times_two_vectorized, None), + (score_mod_alibi, score_mod_alibi_vectorized, None), + (score_mod_batch_bias, score_mod_batch_bias_vectorized, "batch"), + (score_mod_dual_buffer, score_mod_dual_buffer_vectorized, "dual_buffer"), +] # (cute_score_mod, eager_factory, aux_type, requires_global) # aux_type: "kv", "q", "q_and_kv", "q_concat", "kv_with_cu", "multi_buffer" # requires_global: "q" (needs varlen_q), "kv" (needs varlen_k), "both" (needs both) @@ -151,6 +173,8 @@ ([1, 1, 1], [256 * 1024] * 3), ] +VEC_SIZES_TO_CHECK_EQUALITY = [1, 2, 4] if IS_SM100 else [1, 2] + # ============================================================================= # Helper functions # ============================================================================= @@ -488,6 +512,86 @@ def test_varlen_with_score_mod( cu_seqlens_q=cu_seqlens_q if varlen_q else None, ) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("varlen_q", [True, False]) +@pytest.mark.parametrize("varlen_k", [True, False]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)]) +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("score_mod_vec_tuple", TEST_PAIRS_VECTORIZED_NO_GLOBAL) +def test_varlen_with_score_mod_vectorized( + seqlens_q, + seqlens_k, + varlen_q, + varlen_k, + qhead_per_kvhead, + num_kv_heads, + dtype, + score_mod_vec_tuple, +): + """Tests equality between original and vectorized versions of score mods""" + if not varlen_q and not varlen_k: + pytest.skip( + "At least one of varlen_q or varlen_k must be True for varlen tests" + ) + + # For non-varlen dimension, all sequences must have same length + if not varlen_q: + seqlens_q = [seqlens_q[0]] * len(seqlens_q) + if not varlen_k: + seqlens_k = [seqlens_k[0]] * len(seqlens_k) + torch.random.manual_seed(42) + cute_score_mod, cute_vectorized_score_mod, aux_type = score_mod_vec_tuple + + num_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + head_dim = 128 + batch_size = len(seqlens_q) + + q, k, v, cu_seqlens_q, cu_seqlens_k = setup_tensors( + seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype + ) + aux_tensors = None + if aux_type == "batch": + bias = torch.zeros(batch_size, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [bias] + elif aux_type == "dual_buffer": + seqlen_q = seqlens_q[0] if not varlen_q else max(seqlens_q) + head_bias = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.2 + pos_bias = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 + aux_tensors = [head_bias, pos_bias] + + if pack_gqa: + if varlen_k: + k = k[:, :num_kv_heads, :].clone() + v = v[:, :num_kv_heads, :].clone() + else: + k = k[:, :, :num_kv_heads, :].clone() + v = v[:, :, :num_kv_heads, :].clone() + + out_ref = run_cute_flash( + q, + k, + v, + cute_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + + for vec_size in VEC_SIZES_TO_CHECK_EQUALITY: + cute_vectorized_score_mod.__vec_size__ = vec_size + out = run_cute_flash( + q, + k, + v, + cute_vectorized_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + assert torch.equal(out, out_ref) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("varlen_q", [True, False])