diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py
index de5fea43b99..783e76866c5 100644
--- a/flash_attn/cute/flash_fwd.py
+++ b/flash_attn/cute/flash_fwd.py
@@ -434,7 +434,7 @@ def load_K(
             else:
                 seqlen_limit = cutlass.min(seqlen - block * self.n_block_size, self.n_block_size)
             seqlen_limit -= tKcK[0][0]
-            for n in cutlass.range_constepxr(cute.size(tKsK.shape[1])):
+            for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):
                 if t0KcK[0, n, 0][0] < seqlen_limit:
                     cute.copy(
                         gmem_tiled_copy,
@@ -468,7 +468,7 @@ def load_V(
             # Do we need to check if we overshoot kBlockN when we load V?
             is_even_n_smem_v = self.n_block_size % gmem_tiled_copy.tiler_mn[0].shape == 0
             if const_expr(need_predicates or not is_even_n_smem_v):
-                for n in cutlass.range_constepxr(cute.size(tVsV.shape[1])):
+                for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])):
                     # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked
                     if is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size:
                         predicate = tVpV[None, n, None] if const_expr(self.check_hdim_v_oob) else None
@@ -476,8 +476,8 @@ def load_V(
                             seqlen_limit = seqlen - block * self.n_block_size - tVcV[0][0]
                             predicate_n = t0VcV[0, n, 0][0] < seqlen_limit
                             predicate = cute.make_fragment_like(tVpV[None, 0, None])
-                            for k in cutlass.range_constepxr(cute.size(predicate.shape[1])):
-                                for i in cutlass.range_constepxr(cute.size(predicate.shape[0])):
+                            for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
+                                for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                                     predicate[i, k] = (tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True) and predicate_n
                         cute.copy(
                             gmem_tiled_copy,
@@ -586,12 +586,13 @@ def __call__(
         # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e)
         # (assigning it to softmax_scale_log2).
         LOG2_E = math.log2(math.e)
-        if const_expr(softcap is not None):
+        if const_expr(softcap is None):
            softmax_scale_log2 = softmax_scale * LOG2_E
            softcap_val = None
         else:
            softmax_scale_log2 = softcap * LOG2_E
            softcap_val = Float32(softmax_scale / softcap)
+
         self.kernel(
             mQ,
             mK,
@@ -631,8 +632,8 @@ def kernel(
         mLSE: Optional[cute.Tensor],
         softmax_scale_log2: Float32,
         softcap_val: Optional[Float32],
-        window_size_left: Int32,
-        window_size_right: Int32,
+        window_size_left: Optional[Int32],
+        window_size_right: Optional[Int32],
         sQ_layout: cute.ComposedLayout,
         sK_layout: cute.ComposedLayout,
         sV_layout: cute.ComposedLayout,
@@ -655,7 +656,7 @@ def kernel(
             window_size_left, window_size_right,
             qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
         )
-        seqlen = SeqlenInfoQK(seqlen_q=mQ.shape[0], seqlen_k=mK.shape[0])
+        seqlen = SeqlenInfoQK(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0])
         n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
         # TODO: return early if n_block_max == 0
         # if self.is_causal:
@@ -802,7 +803,7 @@ def preprocess_Q():
             preprocess_Q()
             cute.arch.barrier()  # Make sure all threads have read smem_q before loading V

-        for stage in cutlass.range_constepxr(self.num_stages):
+        for stage in cutlass.range_constexpr(self.num_stages):
             if const_expr(not self.Q_in_regs or stage > 0):
                 if stage == 0 or n_block - stage >= 0:
                     load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage==0)
@@ -867,7 +868,7 @@ def preprocess_Q():
         # reuse sQ's data iterator
         sO = cute.make_tensor(sQ.iterator, sO_layout)
         self.epilogue(
-            acc_O, softmax.row_sum, mO, mLSE, sO,
+            acc_O, softmax.row_sum, mO, mLSE, sO, seqlen,
             gmem_tiled_copy_O, None, tiled_mma_pv, tidx, m_block, num_head, batch_size
         )

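As a sanity check on the corrected softcap branch (not part of the patch), here is a minimal NumPy sketch of the identity it relies on: softcap_val is the factor applied to the raw QK scores before tanh, and softmax_scale_log2 is the factor applied afterwards to move into base-2 exponent space. The function names below are illustrative only.

    import math
    import numpy as np

    LOG2_E = math.log2(math.e)

    def softcap_logits_ref(qk, softmax_scale, softcap):
        # Direct form: softcap * tanh(qk * softmax_scale / softcap),
        # then multiply by log2(e) for an exp2-based softmax.
        return softcap * np.tanh(qk * softmax_scale / softcap) * LOG2_E

    def softcap_logits_fused(qk, softmax_scale, softcap):
        # Pre-folded constants, mirroring the corrected branch in __call__.
        if softcap is None:
            return qk * (softmax_scale * LOG2_E)
        softcap_val = softmax_scale / softcap       # applied before tanh
        softmax_scale_log2 = softcap * LOG2_E       # applied after tanh
        return np.tanh(qk * softcap_val) * softmax_scale_log2

    qk = np.random.randn(4, 4).astype(np.float32)
    assert np.allclose(softcap_logits_ref(qk, 0.125, 30.0),
                       softcap_logits_fused(qk, 0.125, 30.0))

With the old condition (`softcap is not None`), the two constants were swapped between the capped and uncapped paths, which is what the one-line sign flip in the diff addresses.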