148 changes: 94 additions & 54 deletions flash_attn/cute/flash_bwd_sm100.py
@@ -91,6 +91,7 @@ def __init__(
# Speed optimizations, do not affect correctness
self.shuffle_LSE = False
self.shuffle_dPsum = False
self.use_smem_dS_for_mma_dK = self.deterministic and self.is_causal

self.reduce_warp_ids = (0, 1, 2, 3)
self.compute_warp_ids = (4, 5, 6, 7, 8, 9, 10, 11)
@@ -146,7 +147,7 @@ def __init__(
self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m
self.tmem_dS_offset = self.tmem_dP_offset # overlap with dP

if not is_causal and not is_local:
if (not is_causal and not is_local) or deterministic:
self.num_regs_reduce = 152
self.num_regs_compute = 136
else:
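The higher register split (152/136) now also applies whenever `deterministic` is set, not only to the non-causal, non-local path. As a rough sanity check that these per-thread counts fit, here is a back-of-the-envelope sketch; the 64K-registers-per-SM figure and the exclusion of the load/MMA warps are assumptions, not taken from this diff, and the else-branch values are folded out of this excerpt.

```python
# Rough register-budget check for the 152/136 split above. Assumes a 64K
# 32-bit register file per SM and 32 threads per warp; the remaining warps
# (load/MMA) are ignored here.
REGFILE_PER_SM = 64 * 1024
THREADS_PER_WARP = 32

def regs_used(num_regs_reduce: int, num_regs_compute: int) -> int:
    reduce_warps = 4   # self.reduce_warp_ids = (0, 1, 2, 3)
    compute_warps = 8  # self.compute_warp_ids = (4, ..., 11)
    return THREADS_PER_WARP * (
        reduce_warps * num_regs_reduce + compute_warps * num_regs_compute
    )

print(regs_used(152, 136))  # 54272, under the assumed 65536 budget
```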
@@ -203,14 +204,18 @@ def _get_tiled_mma(self):
a_source=tcgen05.OperandSource.TMEM,
)
# dK += dS.T @ Q
if const_expr(self.use_smem_dS_for_mma_dK):
mma_dK_a_src = tcgen05.OperandSource.SMEM
else:
mma_dK_a_src = tcgen05.OperandSource.TMEM
tiled_mma_dK = sm100_utils_basic.make_trivial_tiled_mma(
self.do_dtype,
tcgen05.OperandMajorMode.K, # dS_major_mode
tcgen05.OperandMajorMode.MN, # Q_major_mode
self.acc_dtype,
cta_group,
self.mma_tiler_dsq[:2],
a_source=tcgen05.OperandSource.TMEM,
a_source=mma_dK_a_src,
)
# dQ = dS @ K
tiled_mma_dQ = sm100_utils_basic.make_trivial_tiled_mma(
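The only functional change in this hunk is the A-operand source for the dK MMA: the deterministic causal path reads dS from shared memory, everything else keeps it in tensor memory. A minimal stand-in for `tcgen05.OperandSource`, just to pin down the predicate:

```python
from enum import Enum, auto

class OperandSource(Enum):  # stand-in for tcgen05.OperandSource
    SMEM = auto()
    TMEM = auto()

def mma_dK_a_source(deterministic: bool, is_causal: bool) -> OperandSource:
    # use_smem_dS_for_mma_dK = deterministic and is_causal (from the diff)
    if deterministic and is_causal:
        return OperandSource.SMEM
    return OperandSource.TMEM

assert mma_dK_a_source(True, True) is OperandSource.SMEM
assert mma_dK_a_source(False, True) is OperandSource.TMEM
```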
@@ -403,13 +408,13 @@ def __call__(
semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b)
if const_expr(self.deterministic):
assert mdQ_semaphore is not None
mdQ_semaphore = utils.select(mdQ_semaphore.layout, mode=semaphore_transpose)
mdQ_semaphore = utils.select(mdQ_semaphore, mode=semaphore_transpose)

if const_expr(self.deterministic and self.qhead_per_kvhead > 1):
assert mdK_semaphore is not None
assert mdV_semaphore is not None
mdK_semaphore, mdV_semaphore = [
utils.select(t.layout, mode=semaphore_transpose)
utils.select(t, mode=semaphore_transpose)
for t in (mdK_semaphore, mdV_semaphore)
]
else:
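`utils.select` now takes the tensor rather than its `.layout`, but the permutation itself is unchanged: modes (b, n, block, stage) become (block, stage, n, b). A shape-only toy of that mode selection; real CuTe layouts also carry strides, which this sketch ignores.

```python
def select_modes(shape, mode):
    # pick out modes of `shape` in the order given by `mode`
    return tuple(shape[m] for m in mode)

sem_modes = ("b", "n", "block", "stage")
print(select_modes(sem_modes, [2, 3, 1, 0]))  # ('block', 'stage', 'n', 'b')
```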
@@ -546,15 +551,18 @@ def __call__(
self.tma_copy_bytes["dQ"] = self.tile_m * self.dQ_reduce_ncol * Float32.width // 8
self.tma_copy_bytes["dKacc"] = self.tile_n * self.dK_reduce_ncol * Float32.width // 8

# TileScheduler = SingleTileScheduler if not self.is_causal else SingleTileLPTBwdScheduler
TileScheduler = SingleTileScheduler
# TODO -- optimize scheduler for causal
# TileScheduler = SingleTileScheduler
if const_expr(self.deterministic):
TileScheduler = SingleTileLPTBwdScheduler
else:
TileScheduler = SingleTileScheduler
self.spt = self.is_causal and self.deterministic
tile_sched_args = TileSchedulerArguments(
cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]),
cute.size(mQ.shape[2]), # num_heads = num_query_heads
cute.size(mK.shape[3]),
1, # num_splits
cute.size(mK.shape[0]),
cute.size(mQ.shape[0]), # pass seqlen_q for seqlen_k
mQ.shape[1],
mV.shape[1],
total_q=cute.size(mQ.shape[0]),
@@ -565,7 +573,7 @@ def __call__(
qhead_per_kvhead_packgqa=1,
element_size=self.k_dtype.width // 8,
is_persistent=self.is_persistent,
lpt=False,
lpt=self.spt,
)

tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
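Deterministic runs now use `SingleTileLPTBwdScheduler`, with `lpt=self.spt` set only for causal. The intuition: in causal backward, small `n_block` indices cover the most query blocks, so launching them first (longest-processing-time order) shortens the tail. A toy illustration with made-up tile sizes; the real scheduler lives elsewhere in the repo.

```python
import math

def m_blocks_touched(n_block: int, tile_m: int, tile_n: int, seqlen: int) -> int:
    # causal: key block n_block only meets query rows at or past its diagonal
    first_m = (n_block * tile_n) // tile_m
    return math.ceil(seqlen / tile_m) - first_m

tile_m = tile_n = 128
seqlen = 1024
work = {n: m_blocks_touched(n, tile_m, tile_n, seqlen)
        for n in range(seqlen // tile_n)}
print(sorted(work, key=work.get, reverse=True))  # heaviest n_blocks first
```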
@@ -1364,8 +1372,10 @@ def mma(
tdPrV = tiled_mma_dP.make_fragment_A(sV)
tdPrdOt = tiled_mma_dP.make_fragment_B(sdOt)
# dK = dS.T @ Q
# tdKrdS = tiled_mma_dK.make_fragment_A(sdSt)
tdKrdS = tiled_mma_dK.make_fragment_A(tdS)
if const_expr(self.use_smem_dS_for_mma_dK):
tdKrdS = tiled_mma_dK.make_fragment_A(sdSt)
else:
tdKrdS = tiled_mma_dK.make_fragment_A(tdS)
tdKrQ = tiled_mma_dK.make_fragment_B(sQt)
# dQ = dS @ K
tdQrdS = tiled_mma_dQ.make_fragment_A(sdS)
@@ -1404,18 +1414,20 @@ def mma(
# mma_dsk_fn = partial(
# gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, zero_init=True
# )
# mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ)
# Need to explicitly pass in tA_addr for correctness
mma_dsq_fn = partial(
gemm_ptx_w_idx,
tiled_mma_dK,
tdKtdK,
tdKrdS,
tdKrQ,
sA=None,
sB=sQt,
tA_addr=self.tmem_dS_offset,
)
if const_expr(self.use_smem_dS_for_mma_dK):
mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ)
else:
# Need to explicitly pass in tA_addr for correctness
mma_dsq_fn = partial(
gemm_ptx_w_idx,
tiled_mma_dK,
tdKtdK,
tdKrdS,
tdKrQ,
sA=None,
sB=sQt,
tA_addr=self.tmem_dS_offset,
)

consumer_state_dO = cutlass.pipeline.make_pipeline_state(
cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage
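Both branches produce an `mma_dsq_fn` with the fixed operands bound up front and only `B_idx`/`zero_init` left for the loop; the TMEM branch must additionally pin `tA_addr` to `tmem_dS_offset`. A stripped-down analogue of the binding pattern, with a stub GEMM and string stand-ins for the real tensors:

```python
from functools import partial

def gemm_ptx_w_idx(mma, acc, a, b, sA=None, sB=None, tA_addr=None,
                   B_idx=0, zero_init=False):
    # stub: the real function issues a tcgen05 MMA
    print(f"gemm B_idx={B_idx} zero_init={zero_init} tA_addr={tA_addr}")

mma_dsq_fn = partial(gemm_ptx_w_idx, "tiled_mma_dK", "tdKtdK", "tdKrdS",
                     "tdKrQ", sA=None, sB="sQt", tA_addr="tmem_dS_offset")
mma_dsq_fn(B_idx=0, zero_init=True)   # first tile: overwrite dK
mma_dsq_fn(B_idx=1, zero_init=False)  # later tiles: accumulate
```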
@@ -1486,18 +1498,29 @@ def mma(
mma_qk_fn(B_idx=handle_Q_next.index)
pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)

# 2) dK = dS.T @ Q
# 2-3)
# If dS is in TMEM for the first MMA, do dK = dS.T @ Q, then dQ = dS @ K;
# otherwise, reverse the order
pipeline_dS.consumer_wait(consumer_state_dS)
mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
accumulate_dK = True
handle_Q.release()

# 3) dQ = dS @ K
if const_expr(self.use_smem_dS_for_mma_dK):
mma_dsk_fn()
pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
accumulate_dK = True
handle_Q.release()
else:
mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
accumulate_dK = True
handle_Q.release()
mma_dsk_fn()
pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)

# dP uses the same tmem as dQ
# However, if dS is ready, then dP must have been ready, so we don't need to wait
# However, if dS is ready, then dP must have been ready,
# so we don't need this wait before mma_dsk_fn()
# pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)
mma_dsk_fn()
pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)

pipeline_dS.consumer_release(consumer_state_dS)
consumer_state_dS.advance()

@@ -1823,8 +1846,8 @@ def compute_loop(
)

cute.arch.fence_view_async_tmem_store()
self.compute_sync_barrier.arrive_and_wait()

cute.arch.sync_warp()
with cute.arch.elect_one():
pipeline_S_P.consumer_release(consumer_state_S_P_dP)
# pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask)
@@ -1847,6 +1870,7 @@ def compute_loop(
tdPrdP_t2r = cute.make_fragment(tScS_t2r[None, 0, None, None].shape, Float32)
cute.copy(thr_copy_t2r, tdPtdP_t2r[None, stage, None, None], tdPrdP_t2r)
cute.arch.fence_view_async_tmem_load()
self.compute_sync_barrier.arrive_and_wait()
tdPrdP_cur = tdPrdP_t2r[None, 0, 0]
tSrS_cur = tSrS_t2r[None, stage, 0, 0]
tSsdPsum_cur = tSsdPsum[None, stage, 0, 0, consumer_state_dPsum.index]
@@ -1875,22 +1899,20 @@ def compute_loop(
if const_expr(stage == 0):
pipeline_dS.producer_acquire(producer_state_dS)
cute.autovec_copy(tdPrdS_cvt, tRS_sdS[None, stage])
tdPrdS_r2t_f32 = cute.recast_tensor(tdPrdS_cvt, Float32)
cute.copy(thr_copy_r2t, tdPrdS_r2t_f32, tdPtdS_r2t[None, stage, 0, 0])
if const_expr(not self.use_smem_dS_for_mma_dK):
tdPrdS_r2t_f32 = cute.recast_tensor(tdPrdS_cvt, Float32)
cute.copy(thr_copy_r2t, tdPrdS_r2t_f32, tdPtdS_r2t[None, stage, 0, 0])

cute.arch.fence_view_async_tmem_store()
if const_expr(not self.use_smem_dS_for_mma_dK):
cute.arch.fence_view_async_tmem_store()
cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta)
self.compute_sync_barrier.arrive_and_wait()

cute.arch.sync_warp()
# with cute.arch.elect_one():
# The mma warp no longer waits for dP (it waits for dS), so we don't have to arrive
# pipeline_dP.sync_object_empty.arrive(0, pipeline_dP.consumer_mask)
pipeline_dPsum.consumer_release(consumer_state_dPsum)
consumer_state_dPsum.advance()

cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
)
cute.arch.sync_warp()
with cute.arch.elect_one():
pipeline_dS.producer_commit(producer_state_dS)
producer_state_dS.advance()
@@ -2010,10 +2032,13 @@ def dQacc_reduce(
gdQaccum = cute.flat_divide(
gdQaccum_, (self.tile_m * self.tile_hdim // self.dQaccum_reduce_stage,)
)
mdQ_semaphore_cur = None

if const_expr(self.deterministic):
mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx]

delay_semaphore_release = self.is_causal
n_block_global_max = cute.ceil_div(seqlen.seqlen_k, self.tile_n)

for m_block in cutlass.range(m_block_min, m_block_max, unroll=1):
pipeline_dQ.consumer_wait(dQ_consumer_state)
# TMEM -> RMEM
@@ -2025,11 +2050,6 @@ def dQacc_reduce(
pipeline_dQ.consumer_release(dQ_consumer_state)
dQ_consumer_state.advance()

# semaphore acquire
if const_expr(self.deterministic):
barrier.wait_eq(mdQ_semaphore_cur[m_block, None].iterator, tidx, 0, n_block)
self.reduce_sync_barrier.arrive_and_wait()

gdQaccum_cur = gdQaccum[None, None, m_block]

for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])): # 4
@@ -2043,6 +2063,17 @@ def dQacc_reduce(
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
)
# semaphore acquire
if const_expr(self.deterministic and stage == 0):
if const_expr(self.spt):
n_block_max_for_m_block = min(
n_block_global_max,
cute.ceil_div((m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q, self.tile_n)
)
lock_value = n_block_max_for_m_block - 1 - n_block
else:
lock_value = n_block
barrier.wait_eq(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, lock_value)
self.reduce_sync_barrier.arrive_and_wait()
# Copy from shared memory to global memory
if is_tma_warp:
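The acquire now happens inside stage 0 of the copy loop rather than before it, and in `spt` mode the expected ticket is reversed: the contributor with the largest `n_block` for this `m_block` goes first. A sketch of the ticket computation, with made-up tile sizes and sequence lengths:

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

tile_m = tile_n = 128
seqlen_q = seqlen_k = 1024
n_block_global_max = ceil_div(seqlen_k, tile_n)

def lock_value(m_block: int, n_block: int, spt: bool) -> int:
    if not spt:
        return n_block  # ascending n_block order
    n_block_max_for_m_block = min(
        n_block_global_max,
        ceil_div((m_block + 1) * tile_m + seqlen_k - seqlen_q, tile_n),
    )
    return n_block_max_for_m_block - 1 - n_block  # descending order

m_block = 3  # causal: only n_block 0..3 contribute to this dQ row block
print([(n, lock_value(m_block, n, spt=True)) for n in range(4)])
# [(0, 3), (1, 2), (2, 1), (3, 0)] -- n_block 3 holds ticket 0, goes first
```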
@@ -2067,17 +2098,25 @@ def dQacc_reduce(
# tdQrdQ_r2s[4 * i + 3],
# utils.elem_pointer(tdQgdQ, 4 * i),
# )
# semaphore release for prior m_block
if const_expr(self.deterministic and stage == 0 and delay_semaphore_release):
if m_block > m_block_min:
barrier.arrive_inc(mdQ_semaphore_cur[(m_block - 1, None)].iterator, tidx, 0, 1)

# semaphore release
# NOTE: arrive_inc calls red_release which issues membar
if const_expr(self.deterministic):
if tidx == 0:
if const_expr(self.deterministic and not delay_semaphore_release):
if is_tma_warp:
cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
self.reduce_sync_barrier.arrive_and_wait()
barrier.arrive_inc(mdQ_semaphore_cur[m_block, None].iterator, tidx, 0, 1)

if warp_idx == 0:
if is_tma_warp:
cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
self.reduce_sync_barrier.arrive_and_wait()
# final semaphore release
if const_expr(self.deterministic and delay_semaphore_release):
barrier.arrive_inc(mdQ_semaphore_cur[(m_block_max - 1, None)].iterator, tidx, 0, 1)

tile_scheduler.advance_to_next_work()
work_tile = tile_scheduler.get_current_work()
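Taken together, `wait_eq` before the store and `arrive_inc` after it serialize the dQ contributions to each row block in a fixed order, which is what makes the floating-point accumulation deterministic; the causal path merely delays each release to the next iteration (and the last one to after the loop), presumably to overlap the semaphore traffic with the TMA store. A CPU analogue of the protocol, assuming `wait_eq` spins until the semaphore equals the ticket and `arrive_inc` is an atomic increment (the delayed-release variant is omitted):

```python
import threading

class Semaphore:  # toy model of the dQ semaphore
    def __init__(self):
        self.value = 0
        self.cond = threading.Condition()

    def wait_eq(self, expected: int) -> None:
        with self.cond:
            self.cond.wait_for(lambda: self.value == expected)

    def arrive_inc(self) -> None:
        with self.cond:
            self.value += 1
            self.cond.notify_all()

sem, acc, order = Semaphore(), [0.0], []

def cta(n_block: int, ticket: int) -> None:
    sem.wait_eq(ticket)        # semaphore acquire
    acc[0] += 0.1 * n_block    # add this block's dQ contribution
    order.append(n_block)
    sem.arrive_inc()           # semaphore release

threads = [threading.Thread(target=cta, args=(n, n)) for n in range(4)]
for t in reversed(threads):    # deliberately scrambled start order
    t.start()
for t in threads:
    t.join()
print(order)  # always [0, 1, 2, 3]
```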
@@ -2274,7 +2313,8 @@ def epilogue_dK_or_dV_tma(
gdKV, (self.sdKV_flat_epi_tile,)
) # (tile_n * hdim / 2 / epi_stage, epi_stage)

if const_expr(self.deterministic and self.qhead_per_kvhead > 1):
deterministic_KV = self.deterministic and self.qhead_per_kvhead > 1
if const_expr(deterministic_KV):
mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx]

if const_expr(self.qhead_per_kvhead == 1):
@@ -2296,12 +2336,12 @@ def epilogue_dK_or_dV_tma(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
)

read_flag = const_expr(not self.deterministic)
read_flag = const_expr(not deterministic_KV)

pipeline_dKV.consumer_wait(consumer_state_dKV)

# semaphore acquire
if const_expr(self.deterministic):
if const_expr(deterministic_KV):
barrier.wait_eq(
mdKV_semaphore_cur.iterator, tidx, wg_idx, head_idx % self.qhead_per_kvhead
)
@@ -2377,7 +2417,7 @@ def epilogue_dK_or_dV_tma(

# semaphore release
# NOTE: arrive_inc calls red_release which issues membar
if const_expr(self.deterministic):
if const_expr(deterministic_KV):
if leader_warp:
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
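For dK/dV the semaphore is only needed when several query heads fold into one KV head (`deterministic_KV`), and the ticket is each head's position within its group, so the accumulation order across heads is fixed. A small bookkeeping sketch with made-up head counts:

```python
qhead_per_kvhead = 4
num_q_heads = 8

for head_idx in range(num_q_heads):
    head_idx_kv = head_idx // qhead_per_kvhead  # which dK/dV buffer
    ticket = head_idx % qhead_per_kvhead        # turn within that buffer
    print(f"q head {head_idx}: kv head {head_idx_kv}, waits for sem == {ticket}")
```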