Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 126 additions & 94 deletions flash_attn/cute/flash_bwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def __init__(
deterministic: bool = False,
cluster_size: int = 1,
):
assert qhead_per_kvhead == 1, "GQA is not supported yet in FlashAttentionBackwardSm100"
# padding head_dim to a multiple of 16 as k_block_size
hdim_multiple_of = 16
self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
Expand Down Expand Up @@ -163,13 +162,15 @@ def _setup_attributes(self):
self.Q_stage = 2
self.dO_stage = 1
# LSE_stage = Q_stage and dPsum_stage = dO_stage
self.sdKVaccum_stage = 2
# self.sdKVaccum_stage = 2
# number of tma reduce adds per dQacc mma
self.dQ_reduce_ncol = 32
self.sdQaccum_stage = 64 // self.dQ_reduce_ncol
assert self.tile_hdim % self.dQ_reduce_ncol == 0
self.dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol
self.cluster_reduce_dQ = False and cute.size(self.cluster_shape_mn) > 1
# number of tma reduce adds for dKacc and dVacc epilogue
self.dK_reduce_ncol = 32

def _get_tiled_mma(self):
cta_group = tcgen05.CtaGroup.ONE
Expand Down Expand Up @@ -314,15 +315,23 @@ def _setup_smem_layout(self):
)
self.sdKV_epi_tile = (
self.tile_n,
128 // (self.dk_dtype.width // 8),
128 // (self.dk_dtype.width // 8), # 64 or 32
) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2]
self.num_epi_stages = (self.tile_hdim // 2) // self.sdKV_epi_tile[1]
self.sdKV_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages

# TODO: dK and dV could have different shapes
self.sdKV_layout = sm100_utils_basic.make_smem_layout_epi(
self.dk_dtype,
LayoutEnum.ROW_MAJOR,
self.sdKV_epi_tile,
self.sdKVaccum_stage,
)
if const_expr(self.qhead_per_kvhead == 1):
self.sdKV_layout = sm100_utils_basic.make_smem_layout_epi(
self.dk_dtype,
LayoutEnum.ROW_MAJOR,
self.sdKV_epi_tile,
2, # num compute wgs
)
else:
self.sdKV_layout = cute.make_layout(
(self.tile_n * self.dK_reduce_ncol, 2)
)

@cute.jit
def __call__(
Expand Down Expand Up @@ -380,14 +389,21 @@ def __call__(
]

layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b)
mQ, mK, mV, mdO, mdK, mdV = [
utils.select(t, mode=layout_transpose) for t in (mQ, mK, mV, mdO, mdK, mdV)
mQ, mK, mV, mdO = [
utils.select(t, mode=layout_transpose) for t in (mQ, mK, mV, mdO)
]
LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) --> (s, n, b)
mLSE, mdPsum, mdQaccum = [
utils.select(t, mode=LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum)
]
dO_transpose = [1, 0, 2, 3]
if const_expr(self.qhead_per_kvhead == 1):
layout_dKV_transpose = layout_transpose
else:
layout_dKV_transpose = LSE_dPsum_dQaccum_transpose
mdK, mdV = [
utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV)
]
dO_transpose = [1, 0, 2, 3] # (s, h, n, b) --> (h, s, n, h)
mdO = utils.select(mdO, mode=dO_transpose)

semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b)
Expand Down Expand Up @@ -426,21 +442,18 @@ def __call__(
self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
self.is_q_do_mcast = self.num_mcast_ctas_b > 1

self.mdK_layout_enum = LayoutEnum.from_tensor(mdK)
self.mdV_layout_enum = LayoutEnum.from_tensor(mdV)
dK_major_mode = self.mdK_layout_enum.mma_major_mode()
dV_major_mode = self.mdV_layout_enum.mma_major_mode()
if const_expr(dK_major_mode != tcgen05.OperandMajorMode.K):
raise RuntimeError("The layout of mdK is wrong")
if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K):
raise RuntimeError("The layout of mdV is wrong")

if const_expr(self.use_tma_store):
if const_expr(self.dk_dtype.width == 32):
tma_copy_op_dKV = cpasync.CopyReduceBulkTensorTileS2GOp()
else:
tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp()

if const_expr(self.qhead_per_kvhead == 1):
self.mdK_layout_enum = LayoutEnum.from_tensor(mdK)
self.mdV_layout_enum = LayoutEnum.from_tensor(mdV)
dK_major_mode = self.mdK_layout_enum.mma_major_mode()
dV_major_mode = self.mdV_layout_enum.mma_major_mode()
if const_expr(dK_major_mode != tcgen05.OperandMajorMode.K):
raise RuntimeError("The layout of mdK is wrong")
if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K):
raise RuntimeError("The layout of mdV is wrong")

if const_expr(self.use_tma_store and self.qhead_per_kvhead == 1):
tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp()
tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom(
tma_copy_op_dKV,
mdK,
Expand All @@ -456,24 +469,28 @@ def __call__(
1, # no mcast
)
else:
assert self.qhead_per_kvhead == 1, "Must use TMA reduce add for GQA"
mdV_tma_tensor = mdV
mdK_tma_tensor = mdK
tma_atom_dV = None
tma_atom_dK = None

thr_layout_r2s_dKV = cute.make_ordered_layout((self.tile_n, 1), order=(1, 0)) # 128 threads
val_layout_r2s_dKV = cute.make_ordered_layout(
(1, 128 // self.dk_dtype.width), order=(1, 0)
) # 4 or 8 vals for 16 byte store
copy_atom_r2s_dKV = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
self.dk_dtype,
num_bits_per_copy=128,
)
tiled_copy_r2s_dKV = cute.make_tiled_copy_tv(
copy_atom_r2s_dKV, thr_layout_r2s_dKV, val_layout_r2s_dKV
)
if const_expr(self.qhead_per_kvhead == 1):
thr_layout_r2s_dKV = cute.make_ordered_layout((128, 1), order=(1, 0)) # 128 threads
val_layout_r2s_dKV = cute.make_ordered_layout(
(1, 128 // self.dk_dtype.width), order=(1, 0)
) # 4 or 8 vals for 16 byte store
copy_atom_r2s_dKV = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
self.dk_dtype,
num_bits_per_copy=128,
)
tiled_copy_r2s_dKV = cute.make_tiled_copy_tv(
copy_atom_r2s_dKV, thr_layout_r2s_dKV, val_layout_r2s_dKV
)
else:
tiled_copy_r2s_dKV = copy_utils.tiled_copy_1d(
Float32, 128, num_copy_elems=128 // Float32.width
)

tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group)
tma_load_op_multicast = cpasync.CopyBulkTensorTileG2SMulticastOp(cta_group)
Expand Down Expand Up @@ -533,6 +550,7 @@ def __call__(
self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8
self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8
self.tma_copy_bytes["dQ"] = self.tile_m * self.dQ_reduce_ncol * Float32.width // 8
self.tma_copy_bytes["dKacc"] = self.tile_n * self.dK_reduce_ncol * Float32.width // 8

# TileScheduler = SingleTileScheduler if not self.is_causal else SingleTileLPTBwdScheduler
TileScheduler = SingleTileScheduler
Expand Down Expand Up @@ -708,7 +726,7 @@ def kernel(
sdS_layout: cute.ComposedLayout,
sKt_layout: cute.ComposedLayout,
sdQaccum_layout: cute.Layout,
sdKV_layout: cute.ComposedLayout,
sdKV_layout: cute.ComposedLayout | cute.Layout,
tP_layout: cute.ComposedLayout,
tdS_layout: cute.ComposedLayout,
tiled_mma_S: cute.TiledMma,
Expand Down Expand Up @@ -871,12 +889,16 @@ def kernel(
sdOt = cute.make_tensor(cute.recast_ptr(sdO.iterator, sdOt_layout.inner), sdOt_layout.outer)
sLSE = storage.sLSE.get_tensor(sLSE_layout)
sdPsum = storage.sdPsum.get_tensor(sdPsum_layout)
sdV = storage.sdO.get_tensor(
sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dv_dtype
)
sdK = storage.sQ.get_tensor(
sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dk_dtype
)
if const_expr(self.qhead_per_kvhead == 1):
sdV = storage.sdO.get_tensor(
sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dv_dtype
)
sdK = storage.sQ.get_tensor(
sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dk_dtype
)
else:
sdV = storage.sdO.get_tensor(sdKV_layout, dtype=self.dv_dtype)
sdK = storage.sQ.get_tensor(sdKV_layout, dtype=self.dk_dtype)
assert cute.size_in_bytes(self.do_dtype, sdO_layout) >= cute.size_in_bytes(
self.dv_dtype, sdKV_layout
), "Not enough space for sdV"
Expand Down Expand Up @@ -1930,7 +1952,7 @@ def compute_loop(
thr_copy_r2s_dKV,
pipeline_dKV,
consumer_state_dKV,
softmax_scale,
softmax_scale if const_expr(self.qhead_per_kvhead == 1) else None,
int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id
mdK_semaphore,
)
Expand Down Expand Up @@ -2228,32 +2250,53 @@ def epilogue_dK_or_dV_tma(
num_wg = num_compute_threads // 128
leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0

sdKV = sdKV[None, None, wg_idx]
if const_expr(self.qhead_per_kvhead == 1):
sdKV = sdKV[None, None, wg_idx] # (tile_n, 64) for bf16
else:
sdKV = sdKV[None, wg_idx] # (tile_n * 32) for fp32

# (8, tile_n / 128, 64 / 8) = (8, 1, 8) or (4, tile_n * 32 / (128 * 4)) = (4, 8)
tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV)

head_idx_kv = head_idx // self.qhead_per_kvhead
mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx]

gdKV_p = cute.local_tile(mdKV_cur, (self.tile_m, self.tile_hdimv), (n_block, 0))
gdKV = self.split_wg(gdKV_p, wg_idx, num_wg)
gdKV_epi = cute.local_tile(gdKV, self.sdKV_epi_tile, (0, None))
if const_expr(self.qhead_per_kvhead == 1):
mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] # (seqlen, hdim)
gdKV_p = cute.local_tile(
mdKV_cur, (self.tile_n, self.tile_hdim), (n_block, 0)
) # (tile_n, hdim)
gdKV = self.split_wg(gdKV_p, wg_idx, num_wg) # (tile_n, hdim / 2)
gdKV_epi = cute.local_tile(
gdKV, self.sdKV_epi_tile, (0, None)
) # (tile_n, 64, epi_stage = (hdim / 2) / 64)
else:
mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim)
gdKV_p = cute.local_tile(
mdKV_cur, (self.tile_n * self.tile_hdim, ), (n_block, )
) # (tile_n * hdim)
gdKV = cute.logical_divide(
gdKV_p, (self.tile_n * self.tile_hdim // num_wg, )
)[((None, wg_idx), )] # (tile_n * hdim / 2)
gdKV_epi = cute.flat_divide(
gdKV, (self.sdKV_flat_epi_tile, )
) # (tile_n * hdim / 2 / epi_stage, epi_stage)

if const_expr(self.deterministic and self.qhead_per_kvhead > 1):
mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx]

# (TMA) and (TMA, EPI_STAGE)
tdKVsdKV, tdKVgdKV = cpasync.tma_partition(
tma_atom_dKV,
0, # no multicast
cute.make_layout(1),
cute.group_modes(sdKV, 0, 2),
cute.group_modes(gdKV_epi, 0, 2),
)

assert len(tdKVsdKV.shape) == 1, "Wrong rank for SMEM fragment tdKVsdKV"
assert len(tdKVgdKV.shape) == 2, "Wrong rank for GMEM fragment tdKVgdKV"

num_epi_stages = cute.size(tdKVgdKV.shape[1])
assert num_epi_stages == 1 or num_epi_stages == 2, "Wrong number of epi stages"
if const_expr(self.qhead_per_kvhead == 1):
tdKVsdKV, tdKVgdKV = cpasync.tma_partition(
tma_atom_dKV,
0, # no multicast
cute.make_layout(1),
cute.group_modes(sdKV, 0, 2),
cute.group_modes(gdKV_epi, 0, 2),
) # (TMA) and (TMA, EPI_STAGE)
assert len(tdKVsdKV.shape) == 1, "Wrong rank for SMEM fragment tdKVsdKV"
assert len(tdKVgdKV.shape) == 2, "Wrong rank for GMEM fragment tdKVgdKV"
num_epi_stages = cute.size(tdKVgdKV.shape[1])
assert num_epi_stages == self.num_epi_stages, "Epi stage calculation is wrong"
else:
num_epi_stages = self.num_epi_stages

tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
Expand All @@ -2270,20 +2313,20 @@ def epilogue_dK_or_dV_tma(
)
cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128)

for s in cutlass.range_constexpr(num_epi_stages):
for epi_stage in cutlass.range_constexpr(num_epi_stages):
# TMEM -> RMEM -- setup
thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdKVtdKV).get_slice(tidx)
tdKVtdKV_t2r_p = thr_copy_t2r.partition_S(tdKVtdKV)
tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0]
if const_expr(num_epi_stages > 1):
tdKVtdKV_t2r = tdKVtdKV_t2r[None, s]
tdKVtdKV_t2r = tdKVtdKV_t2r[None, epi_stage]

cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim))
tdKVcdKV = thr_mma.partition_C(cdKV)
tdKVcdKV_t2r_p = thr_copy_t2r.partition_D(tdKVcdKV)
tdKVcdKV_t2r = self.split_wg(tdKVcdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0]
if const_expr(num_epi_stages > 1):
tdKVcdKV_t2r = tdKVcdKV_t2r[None, s]
tdKVcdKV_t2r = tdKVcdKV_t2r[None, epi_stage]

tdKVrdKV_t2r = cute.make_fragment(tdKVcdKV_t2r.shape, Float32)

Expand All @@ -2301,30 +2344,11 @@ def epilogue_dK_or_dV_tma(
tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1] = utils.mul_packed_f32x2(
(tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1]), (scale, scale)
)
tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype)
tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype) # (32 columns)
tdKVrdKV.store(tdKVrdKV_t2r.load().to(self.dv_dtype))

# RMEM -> SMEM -- setup
tdKVcdKV_r2s_p = thr_copy_r2s_dKV.partition_S(cdKV)
tdKVcdKV_r2s = self.split_wg(tdKVcdKV_r2s_p, wg_idx, num_wg)
tdKVcdKV_r2s = cute.logical_divide(
tdKVcdKV_r2s,
(
tdKVcdKV_r2s.shape[0],
tdKVcdKV_r2s.shape[1],
tdKVcdKV_r2s.shape[2] // num_epi_stages,
),
)[((None, 0), (None, 0), (None, s))]

tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVcdKV_r2s.shape)

tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV)

assert cute.size(tdKVrdKV_r2s) == cute.size(tdKVsdKV_r2s), (
"RMEM<->SMEM fragment size mismatch"
)

# RMEM -> SMEM -- copy, fence and barrier
tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVsdKV_r2s.shape)
cute.copy(thr_copy_r2s_dKV, tdKVrdKV_r2s, tdKVsdKV_r2s)
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
Expand All @@ -2333,8 +2357,16 @@ def epilogue_dK_or_dV_tma(

# SMEM -> GMEM
if leader_warp:
cute.copy(tma_atom_dKV, tdKVsdKV, tdKVgdKV[None, s])
if s < num_epi_stages - 1:
if const_expr(self.qhead_per_kvhead == 1):
cute.copy(tma_atom_dKV, tdKVsdKV, tdKVgdKV[None, epi_stage])
else:
with cute.arch.elect_one():
copy_utils.cpasync_reduce_bulk_add_f32(
sdKV.iterator,
gdKV_epi[None, epi_stage].iterator,
self.tma_copy_bytes["dKacc"],
)
if const_expr(epi_stage < num_epi_stages - 1):
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
cute.arch.barrier_arrive(
Expand Down
Loading