Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flash_attn/cute/flash_bwd_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
self.num_threads = num_threads
self.AtomLayoutMdQ = AtomLayoutMdQ
self.dQ_swapAB = dQ_swapAB
self.use_2cta_instrs = use_2cta_instrs and arch // 10 == 10 and head_dim != 64
self.use_2cta_instrs = use_2cta_instrs and arch // 10 in [10, 11] and head_dim != 64
self.cluster_size = cluster_size

@staticmethod
Expand Down Expand Up @@ -373,7 +373,7 @@ def kernel(
seqlen_q = seqlen.seqlen_q
seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m)

if const_expr(self.arch // 10 == 10 and self.use_2cta_instrs):
if const_expr(self.arch // 10 in [10, 11] and self.use_2cta_instrs):
# 2-CTA: remap dQaccum layout into TMEM view before writing sdQ
num_reduce_threads = self.num_threads
thr_mma_dsk = tiled_mma.get_slice(tidx)
Expand Down
4 changes: 1 addition & 3 deletions flash_attn/cute/flash_bwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,9 +1423,7 @@ def kernel(
)
TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params)

AttentionMaskCls = self._generate_attention_mask_cls(
window_size_left, window_size_right
)
AttentionMaskCls = self._generate_attention_mask_cls(window_size_left, window_size_right)
# EMPTY
# (15)
if warp_idx == self.empty_warp_id:
Expand Down
8 changes: 4 additions & 4 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def _flash_attn_fwd(
if cu_seqlens_k is None and seqused_k is None:
min_seqlen_k = seqlen_k
seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead
if arch // 10 == 10:
if arch // 10 in [10, 11]:
q_stage = 2 if seqlen_q_packgqa > tile_m else 1
else:
q_stage = 1
Expand Down Expand Up @@ -575,8 +575,8 @@ def _flash_attn_fwd(
and (tile_m % qhead_per_kvhead == 0 or not pack_gqa)
)

# hd=256 2CTA forward uses dedicated kernel (SM100 only)
use_dedicated_hd256_kernel = arch // 10 == 10 and head_dim == 256 and head_dim_v == 256
# hd=256 2CTA forward uses dedicated kernel (Blackwell family)
use_dedicated_hd256_kernel = arch // 10 in [10, 11] and head_dim == 256 and head_dim_v == 256
use_2cta_instrs = use_2cta_instrs or use_dedicated_hd256_kernel

if softcap is not None:
Expand Down Expand Up @@ -1332,7 +1332,7 @@ def _flash_attn_bwd(
cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1
use_2cta_instrs = cluster_size==2

use_dedicated_hd256_kernel = arch // 10 == 10 and head_dim == 256 and head_dim_v == 256
use_dedicated_hd256_kernel = arch // 10 in [10, 11] and head_dim == 256 and head_dim_v == 256
use_2cta_instrs = use_2cta_instrs or use_dedicated_hd256_kernel

q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [
Expand Down
Loading