Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions flash_attn/cute/flash_fwd_sm90.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,14 @@ def _get_smem_layout_atom(self):
return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom

def _get_tiled_mma(self):
atom_layout_n = 2 if self.tile_hdim > 256 or self.tile_hdimv > 256 else 1
tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma(
self.dtype,
self.dtype,
warpgroup.OperandMajorMode.K,
warpgroup.OperandMajorMode.K,
Float32,
atom_layout_mnk=(self.tile_m // 64, 1, 1),
atom_layout_mnk=(self.tile_m // 64, atom_layout_n, 1),
tiler_mn=(64, self.tile_n),
)
tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma(
Expand All @@ -108,8 +109,8 @@ def _get_tiled_mma(self):
warpgroup.OperandMajorMode.K,
warpgroup.OperandMajorMode.MN,
Float32,
atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512
tiler_mn=(64, self.tile_hdimv),
atom_layout_mnk=(self.tile_m // 64, atom_layout_n, 1), # Might need (1, 2, 1) for hdim 512
tiler_mn=(64, min(256, self.tile_hdimv)),
a_source=warpgroup.OperandSource.RMEM
if self.mma_pv_is_rs
else warpgroup.OperandSource.SMEM,
Expand Down
10 changes: 6 additions & 4 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int,
is_deepseek_shape = head_dim == 192 and head_dim_v == 128
is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128

is_sm90_range = 8 <= head_dim <= 256 and 8 <= head_dim_v <= 256
is_sm90_range = 8 <= head_dim <= 512 and 8 <= head_dim_v <= 512
if compute_capability == 9:
assert is_sm90_range and head_dim % alignment == 0 and head_dim_v % alignment == 0, (
f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM90. "
Expand Down Expand Up @@ -155,9 +155,12 @@ def _tile_size_fwd_sm90(head_dim, head_dim_v, is_causal, is_local, sparse_block_
elif head_dim <= 192:
tile_n = 96 if is_local else (128 if head_dim_v <= 128 else 112)
return FwdConfig(128, tile_n, True, True)
else: # hdim 256
elif head_dim <= 256:
tile_n = 64 if is_local else 80
return FwdConfig(128, tile_n, True, True)
else: # 512
tile_n = 64
return FwdConfig(64, tile_n, False, True)

@dataclass(frozen=True)
class BwdConfig:
Expand Down Expand Up @@ -706,8 +709,7 @@ def _flash_attn_fwd(
pack_gqa=pack_gqa,
tile_m=tile_m,
tile_n=tile_n,
# num_stages=1,
num_stages=2,
num_stages=1 if any(d > 256 for d in [head_dim, head_dim_v]) else 2,
num_threads=num_threads,
Q_in_regs=False,
intra_wg_overlap=intra_wg_overlap,
Expand Down
2 changes: 1 addition & 1 deletion flash_attn/cute/paged_kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def create(
val_layout = cute.make_layout((1, async_copy_elems))
gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout)
gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx)
page_entry_per_thread = n_block_size // num_threads
page_entry_per_thread = max(1, n_block_size // num_threads)

tPrPage = cute.make_rmem_tensor((page_entry_per_thread,), Int32)
tPrPageOffset = cute.make_rmem_tensor((page_entry_per_thread,), Int32)
Expand Down