174 changes: 104 additions & 70 deletions flash_attn/cute/flash_fwd.py
@@ -178,10 +178,16 @@ def _check_type(
mCuSeqlensK_type: Type[cutlass.Numeric] | None,
mSeqUsedQ_type: Type[cutlass.Numeric] | None,
mSeqUsedK_type: Type[cutlass.Numeric] | None,
is_split_kv: bool = False,
):
# Get the data type and check if it is fp16 or bf16
if const_expr(not (mQ_type == mK_type == mV_type == mO_type)):
raise TypeError("All tensors must have the same data type")
if is_split_kv:
if const_expr(not (mQ_type == mK_type == mV_type)):
raise TypeError("Q, K, V tensors must have the same data type")
if const_expr(mO_type != Float32):
raise TypeError("O tensor must be Float32 for split_kv")
else:
if const_expr(not (mQ_type == mK_type == mV_type == mO_type)):
raise TypeError("All tensors must have the same data type")
if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]):
raise TypeError("Only Float16 or BFloat16 is supported")
if const_expr(mLSE_type not in [None, Float32]):
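The new `is_split_kv` branch splits the dtype check: Q, K, V must still share a 16-bit dtype, but O is required to be Float32, since each split writes a partial output that a later combine step reduces in full precision. A minimal plain-Python sketch of the same rules (illustrative names, not the CuTe DSL API):

```python
def check_types(q, k, v, o, is_split_kv=False):
    """Dtype rules mirroring _check_type; dtypes passed as strings for brevity."""
    if not (q == k == v):
        raise TypeError("Q, K, V tensors must have the same data type")
    if is_split_kv:
        # Per-split partial outputs are combined later, so O stays in float32
        # regardless of the input dtype.
        if o != "float32":
            raise TypeError("O tensor must be Float32 for split_kv")
    elif o != q:
        raise TypeError("All tensors must have the same data type")
    if q not in ("float16", "bfloat16"):
        raise TypeError("Only Float16 or BFloat16 is supported")
```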
@@ -336,30 +342,33 @@ def epilogue(
m_block: Int32,
head_idx: Int32,
batch_idx: Int32,
split_idx: Int32 = Int32(0),
):
# store acc_O
rO = cute.make_fragment_like(acc_O, self.dtype)
rO.store(acc_O.load().to(self.dtype))
# Make sure all threads have finished reading V
cute.arch.barrier(
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads
)
smem_copy_atom_O = utils.get_smem_store_atom(self.arch.major * 10 + self.arch.minor, self.dtype)
smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx)
taccOrO = smem_thr_copy_O.retile(rO)
taccOsO = smem_thr_copy_O.partition_D(sO)
# taccOsO = copy_utils.partition_D_position_independent(smem_thr_copy_O, sO)
# copy acc O from rmem to smem with the smem copy atom
cute.copy(smem_copy_atom_O, taccOrO, taccOsO)

cO = cute.make_identity_tensor((self.tile_m, self.tile_hdimv))
pack_gqa = PackGQA(
self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead
)

if const_expr(not self.is_split_kv):
rO = cute.make_fragment_like(acc_O, self.dtype)
rO.store(acc_O.load().to(self.dtype))
# Make sure all threads have finished reading V
cute.arch.barrier(
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads
)
smem_copy_atom_O = utils.get_smem_store_atom(self.arch.major * 10 + self.arch.minor, self.dtype)
smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx)
taccOrO = smem_thr_copy_O.retile(rO)
taccOsO = smem_thr_copy_O.partition_D(sO)
# copy acc O from rmem to smem with the smem copy atom
cute.copy(smem_copy_atom_O, taccOrO, taccOsO)

# Write LSE from rmem -> gmem
if const_expr(mLSE is not None):
mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx]
if const_expr(self.is_split_kv):
mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx, split_idx]
else:
mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx]
if const_expr(not self.pack_gqa):
gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,))
gLSE_expanded_layout = cute.append(
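In the split-KV path `mLSE` carries an extra split dimension, so the slice is indexed by `split_idx` as well: each split records its own log-sum-exp, which the combine step needs to reweight the partial outputs. A hedged NumPy sketch of that combine (assuming each partial output is already normalized within its split; shapes and names are illustrative):

```python
import numpy as np

def combine_splits(o_parts, lse_parts):
    # o_parts:   (num_splits, seqlen_q, head_dim) float32 partial outputs
    # lse_parts: (num_splits, seqlen_q) per-split log-sum-exp
    lse_max = lse_parts.max(axis=0)                       # (seqlen_q,)
    scale = np.exp(lse_parts - lse_max)                   # (num_splits, seqlen_q)
    denom = scale.sum(axis=0)                             # (seqlen_q,)
    out = (o_parts * scale[..., None]).sum(axis=0) / denom[..., None]
    lse = lse_max + np.log(denom)                         # combined log-sum-exp
    return out, lse
```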
@@ -383,63 +392,88 @@ def epilogue(
pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q)

ragged = self.use_tma_O and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=ragged)[None, None, head_idx]
# thr_mma = tiled_mma.get_slice(tidx)
# taccOgO = thr_mma.partition_C(gO)
# cute.autovec_copy(rO, taccOgO)
# sync to make sure all smem stores are done
if const_expr(self.use_tma_O):
# ensure smem writes are visible to TMA
cute.arch.fence_view_async_shared()
cute.arch.barrier_arrive(
barrier_id=int(NamedBarrierFwd.Epilogue),
number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
)
gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
store_O, _, _ = copy_utils.tma_get_copy_fn(
tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True
)
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
if warp_idx == 4:
cute.arch.barrier(
barrier_id=int(NamedBarrierFwd.Epilogue),
number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
)
store_O()
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=True)
if const_expr(self.is_split_kv):
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
else:
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=ragged)[None, None, head_idx]

if const_expr(self.is_split_kv):
cute.arch.barrier(
barrier_id=int(NamedBarrierFwd.Epilogue),
number_of_threads=self.num_epilogue_threads,
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads
)
gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
tOsO = gmem_thr_copy_O.partition_S(sO)
tOrO = cute.make_fragment_like(tOsO, self.dtype)
# load acc O from smem to rmem for wider vectorization
cute.autovec_copy(tOsO, tOrO)
if const_expr(not self.pack_gqa):
gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
tOgO = gmem_thr_copy_O.partition_D(gO)
tOcO = gmem_thr_copy_O.partition_S(cO)
t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
# copy acc O from rmem to gmem
for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
if (
t0OcO[0, rest_m, 0][0]
< seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0]
):
cute.copy(
gmem_tiled_copy_O,
tOrO[None, rest_m, None],
tOgO[None, rest_m, None],
pred=tOpO[None, rest_m, None]
if const_expr(self.check_hdim_v_oob)
else None,
)
thr_mma = tiled_mma.get_slice(tidx)
taccOgO = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(gO))
taccOcO = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cO))
taccOrO = layout_utils.reshape_acc_to_mn(acc_O)
seqlen_q_limit = seqlen.seqlen_q - m_block * self.tile_m
for k in cutlass.range_constexpr(cute.size(taccOrO.shape[0])):
if taccOcO[k, 0][0] < seqlen_q_limit:
for m in cutlass.range_constexpr(cute.size(taccOrO.shape[1])):
if const_expr(not self.check_hdim_v_oob) or taccOcO[k, m][1] < mO.shape[1]:
taccOgO[k, m] = taccOrO[k, m]
else:
pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q)
# mO_gqa is ((qheads_per_kvhead, seqlen_q), d, h_kv)
if const_expr(not seqlen.has_cu_seqlens_q):
mO_gqa = mO[None, None, None, batch_idx, split_idx]
else:
offset = (0, seqlen.offset_q)
mO_gqa = cute.domain_offset((offset, 0, 0), mO[None, None, None, split_idx])
pack_gqa.store_O_splitkv(mO_gqa, acc_O, tiled_mma, tidx, m_block, seqlen.seqlen_q, head_idx)
else:
if const_expr(self.use_tma_O):
# ensure smem writes are visible to TMA
cute.arch.fence_view_async_shared()
cute.arch.barrier_arrive(
barrier_id=int(NamedBarrierFwd.Epilogue),
number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
)
gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
store_O, _, _ = copy_utils.tma_get_copy_fn(
tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True
)
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
if warp_idx == 4:
cute.arch.barrier(
barrier_id=int(NamedBarrierFwd.Epilogue),
number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
)
store_O()
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=True)
else:
cute.arch.barrier(
barrier_id=int(NamedBarrierFwd.Epilogue),
number_of_threads=self.num_epilogue_threads,
)
gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
tOsO = gmem_thr_copy_O.partition_S(sO)
tOrO = cute.make_fragment_like(tOsO, self.dtype)
# load acc O from smem to rmem for wider vectorization
cute.autovec_copy(tOsO, tOrO)
if const_expr(not self.pack_gqa):
gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
tOgO = gmem_thr_copy_O.partition_D(gO)
tOcO = gmem_thr_copy_O.partition_S(cO)
t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
# copy acc O from rmem to gmem
for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
if (
t0OcO[0, rest_m, 0][0]
< seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0]
):
cute.copy(
gmem_tiled_copy_O,
tOrO[None, rest_m, None],
tOgO[None, rest_m, None],
pred=tOpO[None, rest_m, None]
if const_expr(self.check_hdim_v_oob)
else None,
)
else:
pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q)

@cute.jit
def advance_pipeline(self, pipeline_index):
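In the split-KV branch of the epilogue, the accumulator is written straight from registers to global memory in Float32, skipping the smem staging and TMA path, with predication on both the remaining `seqlen_q` rows and (when `check_hdim_v_oob`) the head dimension. A hedged NumPy reference of that predicated store for one tile (tile and index names are illustrative, not the kernel's partitioning):

```python
import numpy as np

def store_partial_O(gO, acc_O, m_block, tile_m, seqlen_q, hdim_v):
    # gO:    (seqlen_q_padded, hdim_v_padded) float32 output slice for one
    #        (batch, head, split) combination
    # acc_O: (tile_m, hdim_v_padded) float32 accumulator tile
    row_limit = seqlen_q - m_block * tile_m          # rows of this tile in bounds
    for m in range(acc_O.shape[0]):
        if m >= row_limit:                           # seqlen_q predicate
            continue
        for k in range(acc_O.shape[1]):
            if k >= hdim_v:                          # head-dim predicate
                continue
            gO[m_block * tile_m + m, k] = acc_O[m, k]
```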