9 changes: 0 additions & 9 deletions .pre-commit-config.yaml
@@ -7,19 +7,10 @@ repos:
files: ^flash_attn/cute/.*\.py$
exclude: &cute_exclude |
(?x)^flash_attn/cute/(
__init__|
copy_utils|
cute_dsl_utils|
fast_math|
flash_bwd|
flash_fwd|
flash_fwd_combine|
flash_fwd_sm100|
hopper_helpers|
interface|
pack_gqa|
testing|
utils
)\.py$
- id: ruff-format
files: ^flash_attn/cute/.*\.py$
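
For reference, the exclude value above is one verbose regex: the (?x) flag makes the matcher ignore the newlines and indentation inside the YAML block scalar, so the entries collapse into a single alternation. A minimal Python sketch of how such a pattern is evaluated, using a trimmed entry list for illustration (not the exact set this PR keeps):

import re

# (?x) turns on verbose mode: unescaped whitespace and newlines in the
# pattern are ignored, so the multi-line block reads as one alternation.
cute_exclude = re.compile(
    r"""(?x)^flash_attn/cute/(
        __init__|
        interface|
        utils
    )\.py$"""
)

assert cute_exclude.match("flash_attn/cute/utils.py")  # skipped by the hook
assert cute_exclude.match("flash_attn/cute/flash_fwd_combine.py") is None  # linted
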
6 changes: 3 additions & 3 deletions flash_attn/cute/copy_utils.py
@@ -1,11 +1,11 @@
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

import math
-from typing import Optional, Type, Tuple, Callable
+from typing import Optional, Type, Callable

import cutlass
import cutlass.cute as cute
-from cutlass import Float32, Int32, Boolean, const_expr
+from cutlass import Float32, Int32, const_expr
from cutlass.cute.nvgpu import cpasync
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cutlass_dsl import T, dsl_user_op
@@ -279,7 +279,7 @@ def copy_bulk(src_idx, dst_idx, **new_kwargs):
dst[None, dst_idx].iterator,
size=size,
**new_kwargs,
-**kwargs
+**kwargs,
)

def copy_bulk_single_stage(**new_kwargs):
154 changes: 107 additions & 47 deletions flash_attn/cute/flash_fwd_combine.py
@@ -55,8 +55,13 @@ def __init__(

@staticmethod
def can_implement(
-dtype, dtype_partial, head_dim, m_block_size, k_block_size,
-log_max_splits, num_threads,
+dtype,
+dtype_partial,
+head_dim,
+m_block_size,
+k_block_size,
+log_max_splits,
+num_threads,
) -> bool:
"""Check if the kernel can be implemented with the given parameters."""
if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]:
@@ -83,8 +88,7 @@ def _setup_attributes(self):
assert self.k_block_size % async_copy_elems == 0

k_block_gmem = (
-128 if self.k_block_size % 128 == 0 else
-(64 if self.k_block_size % 64 == 0 else 32)
+128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32)
)
gmem_threads_per_row = k_block_gmem // async_copy_elems
assert self.num_threads % gmem_threads_per_row == 0
@@ -111,16 +115,25 @@ def _setup_attributes(self):
num_bits_per_copy=async_copy_elems * self.dtype.width,
)
self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(
-atom_universal_copy, tOpartial_layout, vOpartial_layout # 4 vals per store
+atom_universal_copy,
+tOpartial_layout,
+vOpartial_layout, # 4 vals per store
)

# LSE copy setup with async copy (alignment = 1)
lse_copy_bits = Float32.width # 1 element per copy, width is in bits
m_block_smem = (
-128 if self.m_block_size % 128 == 0 else
-(64 if self.m_block_size % 64 == 0 else
-(32 if self.m_block_size % 32 == 0 else
-(16 if self.m_block_size % 16 == 0 else 8)))
+128
+if self.m_block_size % 128 == 0
+else (
+64
+if self.m_block_size % 64 == 0
+else (
+32
+if self.m_block_size % 32 == 0
+else (16 if self.m_block_size % 16 == 0 else 8)
+)
+)
)
gmem_threads_per_row_lse = m_block_smem
assert self.num_threads % gmem_threads_per_row_lse == 0
@@ -167,21 +180,17 @@ def _setup_attributes(self):
else:
smem_lse_swizzle = cute.make_swizzle(3, 2, 3)
smem_layout_atom_lse = cute.make_composed_layout(
-smem_lse_swizzle,
-0,
-cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
+smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
)
self.smem_layout_lse = cute.tile_to_shape(
smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1)
)

# O partial shared memory layout (simple layout for pipeline stages)
self.smem_layout_o = cute.make_ordered_layout(
-(self.m_block_size, self.k_block_size, self.stages),
-order=(1, 0, 2)
+(self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2)
)

-
@cute.jit
def __call__(
self,
@@ -200,38 +209,63 @@ def __call__(
raise TypeError("O partial tensor must match dtype_partial")
if const_expr(not (mO.element_type == self.dtype)):
raise TypeError("O tensor must match dtype")
-if const_expr(not mLSE_partial.element_type in [Float32]):
+if const_expr(mLSE_partial.element_type not in [Float32]):
raise TypeError("LSE partial tensor must be Float32")
-if const_expr(mLSE is not None and not mLSE.element_type in [Float32]):
+if const_expr(mLSE is not None and mLSE.element_type not in [Float32]):
raise TypeError("LSE tensor must be Float32")

# Shape validation - input tensors are in user format, need to be converted to kernel format
if const_expr(len(mO_partial.shape) not in [4, 5]):
raise ValueError("O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)")
raise ValueError(
"O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)"
)
if const_expr(len(mLSE_partial.shape) not in [3, 4]):
raise ValueError("LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)")
raise ValueError(
"LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)"
)
if const_expr(len(mO.shape) not in [3, 4]):
raise ValueError("O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)")
raise ValueError(
"O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)"
)
if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]):
raise ValueError("LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)")
raise ValueError(
"LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)"
)

# Assume all strides are divisible by 128 bits except the last stride
-new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1])
-mO_partial, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mO_partial, mO)]
+new_stride = lambda t: (
+*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
+t.stride[-1],
+)
+mO_partial, mO = [
+cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+for t in (mO_partial, mO)
+]
# (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b)
# or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h)
-O_partial_layout_transpose = [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
+O_partial_layout_transpose = (
+[2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
+)
# (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h)
-mO_partial = cute.make_tensor(mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose))
+mO_partial = cute.make_tensor(
+mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose)
+)
O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1]
mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose))
# (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b)
# or (num_splits, total_q, h) -> (total_q, num_splits, h)
LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2]
-mLSE_partial = cute.make_tensor(mLSE_partial.iterator, cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose))
+mLSE_partial = cute.make_tensor(
+mLSE_partial.iterator,
+cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose),
+)
# (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h)
LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1]
-mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if mLSE is not None else None
+mLSE = (
+cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))
+if mLSE is not None
+else None
+)

# Determine if we have variable length sequences
varlen = const_expr(cu_seqlens is not None or seqused is not None)
@@ -243,9 +277,7 @@ class SharedStorage:
sLSE: cute.struct.Align[
cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
]
-sMaxValidSplit: cute.struct.Align[
-cute.struct.MemRange[Int32, self.m_block_size], 128
-]
+sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128]
sO: cute.struct.Align[
cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
]
@@ -255,7 +287,11 @@ class SharedStorage:
# Grid dimensions: (ceil_div(seqlen, m_block), ceil_div(head_dim, k_block), num_head * batch)
seqlen = mO_partial.shape[0]
num_head = mO_partial.shape[3]
-batch_size = mO_partial.shape[4] if const_expr(cu_seqlens is None) else Int32(cu_seqlens.shape[0] - 1)
+batch_size = (
+mO_partial.shape[4]
+if const_expr(cu_seqlens is None)
+else Int32(cu_seqlens.shape[0] - 1)
+)

# Create FastDivmodDivisor objects for efficient division
seqlen_divmod = FastDivmodDivisor(seqlen)
@@ -330,22 +366,26 @@ def kernel(

# Handle semaphore reset
if const_expr(semaphore_to_reset is not None):
-if (tidx == 0 and m_block == cute.arch.grid_dim()[0] - 1 and
-k_block == cute.arch.grid_dim()[1] - 1 and
-batch_idx == cute.arch.grid_dim()[2] - 1):
+if (
+tidx == 0
+and m_block == cute.arch.grid_dim()[0] - 1
+and k_block == cute.arch.grid_dim()[1] - 1
+and batch_idx == cute.arch.grid_dim()[2] - 1
+):
semaphore_to_reset[0] = 0

# Get number of splits
num_splits = (
-num_splits_dynamic_ptr[batch_idx] if const_expr(num_splits_dynamic_ptr is not None)
+num_splits_dynamic_ptr[batch_idx]
+if const_expr(num_splits_dynamic_ptr is not None)
else mLSE_partial.shape[1]
)
# Handle variable length sequences using SeqlenInfo
seqlen_info = SeqlenInfo.create(
batch_idx=batch_idx,
seqlen_static=mO_partial.shape[0],
cu_seqlens=cu_seqlens,
-seqused=seqused
+seqused=seqused,
)
seqlen, offset = seqlen_info.seqlen, seqlen_info.offset

@@ -354,8 +394,9 @@ def kernel(
max_idx = seqlen * num_head

# Early exit for single split if dynamic
-if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (const_expr(not varlen) or m_block * self.m_block_size < max_idx):
-
+if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (
+const_expr(not varlen) or m_block * self.m_block_size < max_idx
+):
# ===============================
# Step 1: Load LSE_partial from gmem to shared memory
# ===============================
@@ -390,7 +431,11 @@ def kernel(
for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
si = tLSEcLSE[0, s, 0][0] # Get split coordinate
if si < num_splits:
-cute.copy(gmem_thr_copy_LSE, mLSE_partial_cur_copy[None, si], tLSEsLSE[None, s, m])
+cute.copy(
+gmem_thr_copy_LSE,
+mLSE_partial_cur_copy[None, si],
+tLSEsLSE[None, s, m],
+)
else:
tLSEsLSE[None, s, m].fill(-Float32.inf)
# Don't need to zero out the rest of the LSEs, as we will not write the output to gmem
@@ -424,7 +469,9 @@ def kernel(
else:
tOhidx[m] = idx // seqlen
tOmidx[m] = idx - tOhidx[m] * seqlen
-tOrOptr[m] = utils.elem_pointer_i64(mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])).toint()
+tOrOptr[m] = utils.elem_pointer_i64(
+mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])
+).toint()
if idx >= max_idx:
tOhidx[m] = -1

@@ -483,7 +530,9 @@ def kernel(
# Find max LSE value across splits
threads_per_col = const_expr(self.smem_threads_per_col_lse)
lse_max = utils.warp_reduce(
-ts2rrLSE[None, None, m].load().reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
+ts2rrLSE[None, None, m]
+.load()
+.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
op=cute.arch.fmax,
width=threads_per_col,
)
@@ -496,7 +545,9 @@ def kernel(
# if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx)
max_valid_split[m] = utils.warp_reduce(max_valid_idx, max, width=threads_per_col)
# Compute exp scales and sum
-lse_max_cur = 0.0 if lse_max == -Float32.inf else lse_max # In case all local LSEs are -inf
+lse_max_cur = (
+0.0 if lse_max == -Float32.inf else lse_max
+) # In case all local LSEs are -inf
LOG2_E = math.log2(math.e)
lse_sum_cur = 0.0
for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
@@ -506,7 +557,9 @@ def kernel(
lse_sum_cur = utils.warp_reduce(lse_sum_cur, operator.add, width=threads_per_col)
lse_sum[m] = utils.logf(lse_sum_cur) + lse_max
# Normalize scales
-inv_sum = 0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
+inv_sum = (
+0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
+)
ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum)
# Store the scales exp(lse - lse_logsum) back to smem
cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE)
@@ -584,7 +637,10 @@ def kernel(
# Accumulate scaled partial results
for m in cutlass.range(num_rows, unroll_full=True):
if tOhidx[m] >= 0 and scale[m] > 0.0:
-tOrO[None, m, None].store(tOrO[None, m, None].load() + scale[m] * tOrO_partial[None, m, None].load().to(Float32))
+tOrO[None, m, None].store(
+tOrO[None, m, None].load()
++ scale[m] * tOrO_partial[None, m, None].load().to(Float32)
+)

# ===============================
# Step 7: Write final O to gmem
@@ -605,7 +661,9 @@ def kernel(
# Write final results
for m in cutlass.range(num_rows, unroll_full=True):
if tOhidx[m] >= 0:
-mO_cur_copy = cute.tiled_divide(mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,))
+mO_cur_copy = cute.tiled_divide(
+mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,)
+)
for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
k_idx = tOcO[0, 0, k][1] // elems_per_store
if const_expr(self.is_even_k) or tOpO[k]:
@@ -631,7 +689,9 @@ def load_O_partial(
o_gmem_ptr = cute.make_ptr(
tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16
)
-mO_partial_cur = cute.make_tensor(o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0)))
+mO_partial_cur = cute.make_tensor(
+o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))
+)
mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,))
for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
k_idx = tOcO[0, 0, k][1] // elems_per_load
@@ -640,5 +700,5 @@ def load_O_partial(
gmem_tiled_copy_O_partial,
# mO_partial_cur_copy[None, k_idx, split],
utils.coord_offset_i64(mO_partial_cur_copy, split, dim=2)[None, k_idx],
-tOsO_partial_cur[None, m, k]
+tOsO_partial_cur[None, m, k],
)
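
For context on the file above: flash_fwd_combine.py implements the split-KV combine, where each split s produces a partial output O_s and a partial log-sum-exp lse_s, and the kernel reduces them as lse = log(sum_s exp(lse_s - lse_max)) + lse_max and O = sum_s exp(lse_s - lse) * O_s. A minimal NumPy sketch of that reduction (function and variable names are illustrative assumptions, not the kernel's API; the guards mirror the lse_max_cur and inv_sum handling visible in the diff):

import numpy as np

def combine_splits(o_partial, lse_partial):
    # o_partial: (num_splits, seqlen, headdim) float32 partial outputs
    # lse_partial: (num_splits, seqlen) float32 partial LSEs, -inf where a
    # split contributed nothing
    lse_max = lse_partial.max(axis=0)
    # If every split is -inf for a row, substitute 0 to avoid inf - inf.
    lse_max = np.where(lse_max == -np.inf, 0.0, lse_max)
    p = np.exp(lse_partial - lse_max)          # exp(lse_s - lse_max); 0 for -inf entries
    sumexp = p.sum(axis=0)
    valid = sumexp > 0.0                       # False for 0 and for nan, like inv_sum
    lse = np.where(valid, np.log(np.maximum(sumexp, 1e-38)) + lse_max, -np.inf)
    inv_sum = np.where(valid, 1.0 / np.maximum(sumexp, 1e-38), 0.0)
    scale = p * inv_sum                        # per-split weight exp(lse_s - lse)
    o = (scale[..., None] * o_partial).sum(axis=0)
    return o, lse
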
1 change: 0 additions & 1 deletion flash_attn/cute/hopper_helpers.py
@@ -4,7 +4,6 @@
import cutlass.cute as cute
from cutlass import Int32, Float32, Boolean, const_expr
from cutlass.cute.nvgpu import warpgroup
-from cutlass._mlir.dialects import llvm
from cutlass.cutlass_dsl import Numeric, dsl_user_op
from cutlass.utils import LayoutEnum
import cutlass.utils.hopper_helpers as sm90_utils_og
2 changes: 0 additions & 2 deletions flash_attn/cute/pack_gqa.py
@@ -1,7 +1,5 @@
# Copyright (c) 2025, Tri Dao.

-import math
-import operator

import cutlass
import cutlass.cute as cute