Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions flash_attn/cute/flash_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def load_K(
else:
seqlen_limit = cutlass.min(seqlen - block * self.n_block_size, self.n_block_size)
seqlen_limit -= tKcK[0][0]
for n in cutlass.range_constepxr(cute.size(tKsK.shape[1])):
for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):
if t0KcK[0, n, 0][0] < seqlen_limit:
cute.copy(
gmem_tiled_copy,
Expand Down Expand Up @@ -468,16 +468,16 @@ def load_V(
# Do we need to check if we overshoot kBlockN when we load V?
is_even_n_smem_v = self.n_block_size % gmem_tiled_copy.tiler_mn[0].shape == 0
if const_expr(need_predicates or not is_even_n_smem_v):
for n in cutlass.range_constepxr(cute.size(tVsV.shape[1])):
for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])):
# If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked
if is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size:
predicate = tVpV[None, n, None] if const_expr(self.check_hdim_v_oob) else None
if const_expr(need_predicates):
seqlen_limit = seqlen - block * self.n_block_size - tVcV[0][0]
predicate_n = t0VcV[0, n, 0][0] < seqlen_limit
predicate = cute.make_fragment_like(tVpV[None, 0, None])
for k in cutlass.range_constepxr(cute.size(predicate.shape[1])):
for i in cutlass.range_constepxr(cute.size(predicate.shape[0])):
for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
predicate[i, k] = (tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True) and predicate_n
cute.copy(
gmem_tiled_copy,
Expand Down Expand Up @@ -586,12 +586,13 @@ def __call__(
# (assigning it to softcap_val) and pre-multiply softcap_val * log2(e)
# (assigning it to softmax_scale_log2).
LOG2_E = math.log2(math.e)
if const_expr(softcap is not None):
if const_expr(softcap is None):
softmax_scale_log2 = softmax_scale * LOG2_E
softcap_val = None
else:
softmax_scale_log2 = softcap * LOG2_E
softcap_val = Float32(softmax_scale / softcap)

self.kernel(
mQ,
mK,
Expand Down Expand Up @@ -631,8 +632,8 @@ def kernel(
mLSE: Optional[cute.Tensor],
softmax_scale_log2: Float32,
softcap_val: Optional[Float32],
window_size_left: Int32,
window_size_right: Int32,
window_size_left: Optional[Int32],
window_size_right: Optional[Int32],
sQ_layout: cute.ComposedLayout,
sK_layout: cute.ComposedLayout,
sV_layout: cute.ComposedLayout,
Expand All @@ -655,7 +656,7 @@ def kernel(
window_size_left, window_size_right,
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
)
seqlen = SeqlenInfoQK(seqlen_q=mQ.shape[0], seqlen_k=mK.shape[0])
seqlen = SeqlenInfoQK(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0])
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
# TODO: return early if n_block_max == 0
# if self.is_causal:
Expand Down Expand Up @@ -802,7 +803,7 @@ def preprocess_Q():
preprocess_Q()
cute.arch.barrier() # Make sure all threads have read smem_q before loading V

for stage in cutlass.range_constepxr(self.num_stages):
for stage in cutlass.range_constexpr(self.num_stages):
if const_expr(not self.Q_in_regs or stage > 0):
if stage == 0 or n_block - stage >= 0:
load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage==0)
Expand Down Expand Up @@ -867,7 +868,7 @@ def preprocess_Q():
# reuse sQ's data iterator
sO = cute.make_tensor(sQ.iterator, sO_layout)
self.epilogue(
acc_O, softmax.row_sum, mO, mLSE, sO,
acc_O, softmax.row_sum, mO, mLSE, sO, seqlen,
gmem_tiled_copy_O, None, tiled_mma_pv, tidx, m_block, num_head, batch_size
)

Expand Down