Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions flash_attn/cute/block_sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads.
"""

from typing import Callable
from typing import Callable, Optional
from functools import partial
import math
import cutlass
Expand Down Expand Up @@ -606,6 +606,9 @@ def handle_block_sparse_empty_tile_correction_sm100(
o_corr_consumer_phase: Int32,
corr_epi_producer_phase: Int32,
softmax_scale_log2: Float32,
mO_cur: Optional[cute.Tensor] = None,
gO: Optional[cute.Tensor] = None,
gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
):
"""Handle the block-sparse case where a tile is fully masked:
* zero staged results
Expand Down Expand Up @@ -650,18 +653,26 @@ def handle_block_sparse_empty_tile_correction_sm100(
)
cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage)

cute.arch.mbarrier_wait(
mbar_ptr + mbar_corr_epi_empty_offset + stage,
corr_epi_producer_phase,
)
if const_expr(gmem_tiled_copy_O is None):
cute.arch.mbarrier_wait(
mbar_ptr + mbar_corr_epi_empty_offset + stage,
corr_epi_producer_phase,
)
correction_epilogue(
thr_mma_pv,
tOtOs[stage],
tidx,
stage,
m_block,
seqlen.seqlen_q,
Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs
sO[None, None, stage],
mO_cur,
gO,
gmem_tiled_copy_O,
)
cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage)
if const_expr(gmem_tiled_copy_O is None):
cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage)
cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage)
cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage)

Expand Down
155 changes: 123 additions & 32 deletions flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@
)


# class NamedBarrierFwd(enum.IntEnum):
# Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
class NamedBarrierFwd(enum.IntEnum):
Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
# WarpSchedulerWG1 = enum.auto()
# WarpSchedulerWG2 = enum.auto()
# WarpSchedulerWG3 = enum.auto()
Expand Down Expand Up @@ -85,6 +85,7 @@ def __init__(
mask_mod: cutlass.Constexpr | None = None,
has_aux_tensors: cutlass.Constexpr = False,
paged_kv_non_tma: bool = False,
is_varlen_q: bool = False,
):
self.use_tma_KV = not paged_kv_non_tma
# self.dtype = dtype
Expand Down Expand Up @@ -112,6 +113,8 @@ def __init__(
self.is_persistent = is_persistent
self.is_causal = is_causal
self.is_local = is_local
self.is_varlen_q = is_varlen_q
self.use_correction_warps_for_epi = is_varlen_q
self.qhead_per_kvhead = qhead_per_kvhead
self.is_split_kv = is_split_kv
self.pack_gqa = pack_gqa
Expand Down Expand Up @@ -146,8 +149,8 @@ def __init__(
self.softmax1_warp_ids = (4, 5, 6, 7)
self.correction_warp_ids = (8, 9, 10, 11)
self.mma_warp_id = 12
self.load_warp_ids = (13,)
self.epilogue_warp_ids = (14,)
self.epilogue_warp_ids = (13,)
self.load_warp_ids = (14,)
self.empty_warp_ids = (15,)
SM100_TMEM_CAPACITY_COLUMNS = 512
self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
Expand All @@ -164,6 +167,15 @@ def __init__(
)
)

if not self.use_tma_KV:
self.load_warp_ids = (14, 15)
self.empty_warp_ids = ()
if self.use_correction_warps_for_epi:
self.empty_warp_ids = self.empty_warp_ids + self.epilogue_warp_ids
self.epilogue_warp_ids = self.correction_warp_ids
elif self.is_varlen_q: # fallback
self.epilogue_warp_ids = (13, 14)

self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128
self.tmem_o_offset = [
self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded
Expand Down Expand Up @@ -506,19 +518,11 @@ def __call__(
self.cluster_layout_vmnk.shape,
)
else:
assert self.use_tma_O, "Loading O and K/V will contend for the empty warp."
self.epilogue_warp_ids = (13,)
self.load_warp_ids = (14, 15)
self.empty_warp_ids = ()
tma_atom_K = None
tma_atom_V = None

o_cta_v_layout = cute.composition(cute.make_identity_layout(mO.shape), self.epi_tile)

# print(sO_layout.outer)
if const_expr(not self.use_tma_O):
self.epilogue_warp_ids = (14, 15)
self.empty_warp_ids = ()
self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)
if const_expr(self.use_tma_O):
tma_atom_O, mO = cpasync.make_tiled_tma_atom(
Expand Down Expand Up @@ -546,7 +550,6 @@ def __call__(
assert self.m_block_size % tO_layout.shape[0] == 0
vO_layout = cute.make_layout((1, async_copy_elems))
gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout)
print("gmem_tiled_copy_O: ", gmem_tiled_copy_O)

if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
TileScheduler = SingleTileVarlenScheduler
Expand Down Expand Up @@ -799,7 +802,7 @@ def kernel(
cute.arch.mbarrier_init(
mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE
)
if warp_idx == 4:
if const_expr(not self.use_correction_warps_for_epi) and warp_idx == 4:
for i in cutlass.range_constexpr(self.q_stage):
cute.arch.mbarrier_init(
mbar_ptr + self.mbar_corr_epi_full_offset + i,
Expand Down Expand Up @@ -931,6 +934,12 @@ def kernel(
if warp_idx == self.empty_warp_ids[0]:
cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)

if const_expr(len(self.empty_warp_ids) > 1):
if warp_idx == self.empty_warp_ids[1]:
cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)

assert len(self.empty_warp_ids) <= 2

# ///////////////////////////////////////////////////////////////////////////////
# LOAD
# ///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1004,19 +1013,20 @@ def kernel(
# ///////////////////////////////////////////////////////////////////////////////
# Epilogue
# ///////////////////////////////////////////////////////////////////////////////
if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]:
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
self.epilogue_s2g(
mO,
sO,
gmem_tiled_copy_O,
tma_atom_O,
mbar_ptr,
block_info,
num_splits,
SeqlenInfoCls,
TileSchedulerCls,
)
if const_expr(not self.use_correction_warps_for_epi):
if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]:
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
self.epilogue_s2g(
mO,
sO,
gmem_tiled_copy_O,
tma_atom_O,
mbar_ptr,
block_info,
num_splits,
SeqlenInfoCls,
TileSchedulerCls,
)

# ///////////////////////////////////////////////////////////////////////////////
# Softmax
Expand Down Expand Up @@ -1080,6 +1090,7 @@ def kernel(
mLSE,
sO,
learnable_sink,
gmem_tiled_copy_O,
tma_atom_O,
mbar_ptr,
softmax_scale_log2,
Expand Down Expand Up @@ -1931,6 +1942,7 @@ def correction_loop(
mLSE: cute.Tensor,
sO: cute.Tensor,
learnable_sink: Optional[cute.Tensor],
gmem_tiled_copy_O: cute.TiledCopy,
tma_atom_O: cute.CopyAtom,
mbar_ptr: cute.Pointer,
softmax_scale_log2: Float32,
Expand Down Expand Up @@ -1972,6 +1984,12 @@ def correction_loop(
seqlen = SeqlenInfoCls(batch_idx)
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)

if const_expr(self.is_split_kv):
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
else:
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0))

# Default LSE to -inf for invalid split_idx tiles
stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage

Expand Down Expand Up @@ -2070,17 +2088,25 @@ def correction_loop(
cute.arch.mbarrier_wait(
mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase
)
cute.arch.mbarrier_wait(
mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase
)
if const_expr(not self.use_correction_warps_for_epi):
cute.arch.mbarrier_wait(
mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase
)
self.correction_epilogue(
thr_mma_pv,
tOtOs[stage],
tidx,
stage,
m_block,
seqlen.seqlen_q,
scale,
sO[None, None, stage],
mO_cur,
gO,
gmem_tiled_copy_O,
)
cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage)
if const_expr(not self.use_correction_warps_for_epi):
cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage)
# Signal for the next work tile that O buffers in tmem are already read, so
# mma warp can write to them
cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage)
Expand All @@ -2090,6 +2116,11 @@ def correction_loop(
softmax_corr_consumer_phase ^= 1
corr_epi_producer_phase ^= 1
else:
# WARNING: we need some code before the const_expr, see https://github.com/NVIDIA/cutlass/issues/2781
if const_expr(self.use_correction_warps_for_epi):
gmem_tiled_copy_O_for_empty_tile = gmem_tiled_copy_O
else:
gmem_tiled_copy_O_for_empty_tile = None
if const_expr(self.use_block_sparsity):
(
softmax_corr_consumer_phase,
Expand Down Expand Up @@ -2126,6 +2157,9 @@ def correction_loop(
o_corr_consumer_phase,
corr_epi_producer_phase,
softmax_scale_log2,
mO_cur,
gO,
gmem_tiled_copy_O_for_empty_tile,
)

if const_expr(mLSE is not None):
Expand Down Expand Up @@ -2228,8 +2262,14 @@ def correction_epilogue(
thr_mma: cute.core.ThrMma,
tOtO: cute.Tensor,
tidx: Int32,
stage: Int32,
m_block: Int32,
seqlen_q: Int32,
scale: Float32,
sO: cute.Tensor,
mO_cur: Optional[cute.Tensor] = None,
gO: Optional[cute.Tensor] = None,
gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
):
"""Apply final scaling and transformation to attention output before writing to global memory.

Expand Down Expand Up @@ -2302,6 +2342,57 @@ def correction_epilogue(
space=cute.arch.SharedSpace.shared_cta,
)

if const_expr(self.use_correction_warps_for_epi):
assert(not self.use_tma_O)
assert(gmem_tiled_copy_O is not None)
cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue),
number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE)
gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
tOsO = gmem_thr_copy_O.partition_S(sO)
cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
tOgO = gmem_thr_copy_O.partition_D(gO)
tOcO = gmem_thr_copy_O.partition_S(cO)
t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
tOpO = utils.predicate_k(tOcO, limit=mO_cur.shape[1])
# TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it
assert not self.pack_gqa
pack_gqa = PackGQA(
self.m_block_size,
self.head_dim_v_padded,
self.check_hdim_v_oob,
self.qhead_per_kvhead,
)

# load acc O from smem to rmem for wider vectorization
tOrO = cute.make_fragment_like(tOsO, self.o_dtype)
cute.autovec_copy(tOsO, tOrO)
# copy acc O from rmem to gmem
if const_expr(not self.pack_gqa):
for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
if (
t0OcO[0, rest_m, 0][0]
< seqlen_q
- (self.q_stage * m_block + stage) * self.m_block_size
- tOcO[0][0]
):
cute.copy(
gmem_tiled_copy_O,
tOrO[None, rest_m, None],
tOgO[None, rest_m, None, self.q_stage * m_block + stage],
pred=tOpO[None, rest_m, None]
if const_expr(self.check_hdim_v_oob)
else None,
)
else:
pack_gqa.store_O(
mO_cur,
tOrO,
gmem_tiled_copy_O,
tidx,
self.q_stage * m_block + stage,
seqlen_q,
)

@cute.jit
def epilogue_s2g(
self,
Expand Down Expand Up @@ -2389,7 +2480,7 @@ def epilogue_s2g(
tOrO[None, rest_m, None],
tOgO[None, rest_m, None, self.q_stage * m_block + stage],
pred=tOpO[None, rest_m, None]
if self.check_hdim_v_oob
if const_expr(self.check_hdim_v_oob)
else None,
)
else:
Expand Down
10 changes: 6 additions & 4 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,14 +464,16 @@ def _flash_attn_fwd(
m_block_size=m_block_size,
n_block_size=n_block_size,
is_persistent=not causal
and not local
and cu_seqlens_q is None
and seqused_q is None
and not is_split_kv,
and not local
and cu_seqlens_q is None
and seqused_q is None
and not is_split_kv,
score_mod=score_mod,
mask_mod=mask_mod,
has_aux_tensors=aux_tensors is not None,
paged_kv_non_tma=page_size not in [None, 128],
is_varlen_q=cu_seqlens_q is not None
or seqused_q is not None,
)
else:
raise ValueError(
Expand Down
Loading