Merged
13 changes: 7 additions & 6 deletions benchmarks/benchmark_attn.py
@@ -152,9 +152,9 @@ def setup_fa4(ctx):
     bwd_fn = None
     if ctx["has_backward"] and ctx["dtype"] != torch.float8_e4m3fn:
         if ctx["varlen"]:
-            qu, ku, vu = ctx["q_unpad"], ctx["k_unpad"], ctx["v_unpad"]
+            qu, ku, vu, gu = ctx["q_unpad"], ctx["k_unpad"], ctx["v_unpad"], ctx["g_unpad"]
             csq, csk = ctx["cu_seqlens_q"], ctx["cu_seqlens_k"]
-            bwd_fn = _make_bwd_fn(lambda: flash_attn_varlen_func_python(qu, ku, vu, csq, csk, causal=causal, softcap=softcap, deterministic=deterministic), g, [qu, ku, vu])
+            bwd_fn = _make_bwd_fn(lambda: flash_attn_varlen_func_python(qu, ku, vu, csq, csk, causal=causal, softcap=softcap, deterministic=deterministic), gu, [qu, ku, vu])
         else:
             bwd_fn = _make_bwd_fn(lambda: flash_attn_func_python(q, k, v, causal=causal, softcap=softcap, deterministic=deterministic), g, [q, k, v])
     return fwd_fn, bwd_fn
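
For context, a minimal sketch (with assumed toy shapes) of why the varlen backward needs `gu` rather than the padded `g`: the varlen kernel's output lives in the unpadded `(total_tokens, nheads, headdim)` layout, so the gradient handed to autograd must match that shape.

```python
import torch
from einops import rearrange

# Toy shapes, assumed purely for illustration.
batch_size, seqlen_q, nheads, headdim_v = 2, 4, 3, 8
g = torch.randn(batch_size, seqlen_q, nheads, headdim_v)   # padded gradient, (b, s, h, d)
g_unpad = rearrange(g.detach(), "b s h d -> (b s) h d")    # unpadded, (total_tokens, h, d)

# The varlen output has the unpadded layout, so only g_unpad matches it.
assert g_unpad.shape == (batch_size * seqlen_q, nheads, headdim_v)
```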
@@ -352,9 +352,10 @@ def main():
     g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen)
 
     # Varlen tensors
-    q_unpad = k_unpad = v_unpad = cu_seqlens_q = cu_seqlens_k = None
+    q_unpad = k_unpad = v_unpad = g_unpad = cu_seqlens_q = cu_seqlens_k = None
     if varlen:
         q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_(has_backward) for x in [q, k, v]]
+        g_unpad = rearrange(g.detach(), "b s h d -> (b s) h d")
         cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q
         cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen if page_size is None else None
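
As a side note, a small sketch (assuming equal-length sequences, as in this benchmark) of what the cu_seqlens tensors built above contain: cumulative sequence-start offsets into the packed `(total_tokens, ...)` tensors.

```python
import torch

batch_size, seqlen_q = 4, 8  # assumed toy values
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32) * seqlen_q
print(cu_seqlens_q)  # tensor([ 0,  8, 16, 24, 32]); sequence i spans [cu[i], cu[i+1])
```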

@@ -374,7 +375,7 @@ def main():
         q=q, k=k, v=v, g=g, causal=causal,
         headdim=headdim, headdim_v=headdim_v, dtype=dtype,
         has_backward=has_backward,
-        varlen=varlen, q_unpad=q_unpad, k_unpad=k_unpad, v_unpad=v_unpad,
+        varlen=varlen, q_unpad=q_unpad, k_unpad=k_unpad, v_unpad=v_unpad, g_unpad=g_unpad,
         cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
         seqlen_q=seqlen_q, seqlen=seqlen,
         page_size=page_size, k_paged=k_paged, v_paged=v_paged, page_table=page_table,
@@ -387,12 +388,12 @@ def main():
     fwd_fn, bwd_fn = setup_fn(ctx)
     if fwd_fn is not None and has_forward:
         time.sleep(1.0)
-        print(f"Benchmarking {display_name} fwd, hdim={headdim}, seqlen={seqlen}, causal={causal}")
+        print(f"Benchmarking {display_name} fwd, hdim={headdim}, seqlen={seqlen}, causal={causal}, {nheads=}, {nheads_kv=}")
         ms = do_bench(fwd_fn, warmup=warmup, rep=rep) * 1e-3
         time_f[cfg, display_name] = ms
     if bwd_fn is not None and has_backward:
         time.sleep(1.0)
-        print(f"Benchmarking {display_name} bwd, hdim={headdim}, seqlen={seqlen}, causal={causal}")
+        print(f"Benchmarking {display_name} bwd, hdim={headdim}, seqlen={seqlen}, causal={causal}, {nheads=}, {nheads_kv=}, {deterministic=}")
         ms = do_bench(bwd_fn, warmup=warmup, rep=rep) * 1e-3
         time_b[cfg, display_name] = ms
14 changes: 9 additions & 5 deletions flash_attn/cute/tile_scheduler.py
@@ -395,8 +395,8 @@ def create(
     ) -> "SingleTileLPTBwdScheduler.Params":
         size_l2 = 50 * 1024 * 1024
         size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
-        # size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4
-        size_one_dqaccum_head = 0
+        size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4
+        # size_one_dqaccum_head = 0
         size_one_head = size_one_qdo_head + size_one_dqaccum_head
         log2_floor = lambda n: 31 - clz(n)
         swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
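
To see what re-enabling the dQaccum term does to the heuristic, here is a standalone sketch with assumed values (bf16, seqlen_k=8192, headdim=headdim_v=128, and a plain-Python `clz` stand-in): counting the fp32 dQ accumulator doubles the per-head footprint in this example and so halves the swizzle width.

```python
# Assumed illustrative values; clz is a plain-Python stand-in for the
# count-leading-zeros primitive used in the real scheduler.
def clz(n: int) -> int:
    return 32 - n.bit_length()

size_l2 = 50 * 1024 * 1024                                       # 50 MB L2 budget
seqlen_k, headdim, headdim_v, element_size = 8192, 128, 128, 2   # bf16 example

size_one_qdo_head = seqlen_k * (headdim + headdim_v) * element_size
size_one_dqaccum_head = seqlen_k * headdim * 4                   # fp32 dQ accumulator, now counted
size_one_head = size_one_qdo_head + size_one_dqaccum_head

log2_floor = lambda n: 31 - clz(n)
swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
print(swizzle)  # 4 here; 8 if size_one_dqaccum_head were 0
```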
@@ -521,9 +521,12 @@ def create(
         args: TileSchedulerArguments, *, loc=None, ip=None
     ) -> "SingleTileVarlenScheduler.Params":
         size_l2 = 50 * 1024 * 1024  # 50 MB for K & V
-        max_kvblock_in_l2 = size_l2 // (
-            (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
-        )
+        # if backward, this is qdo block size
+        kv_block_size = (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
+        # if backward, add dqaccum block size to calculate swizzle
+        if args.head_swizzle:
+            kv_block_size += args.headdim * 4 * args.tile_shape_mn[1]
+        max_kvblock_in_l2 = size_l2 // kv_block_size
         assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (
             "At least one of mCuSeqlensQ or mSeqUsedQ must be provided"
         )
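
Similarly, a short sketch (assumed values) of the per-block budget the varlen scheduler now uses: with head_swizzle set for the backward pass, the fp32 dQaccum block is added on top of the Q/dO block before dividing it into the 50 MB budget.

```python
# Assumed illustrative values.
headdim, headdim_v, element_size = 128, 128, 2   # bf16
tile_n = 128                                     # args.tile_shape_mn[1]
head_swizzle = True                              # True for the backward pass

# Forward: K/V block size; backward: Q/dO block size (per the comment in the diff).
kv_block_size = (headdim + headdim_v) * element_size * tile_n
if head_swizzle:
    kv_block_size += headdim * 4 * tile_n        # fp32 dQaccum block
max_kvblock_in_l2 = (50 * 1024 * 1024) // kv_block_size
print(max_kvblock_in_l2)  # 400 here; 800 without the dQaccum term
```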
@@ -654,6 +657,7 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
         num_n_blocks = (
             num_m_blocks
             * params.tile_shape_mn[0]
+            * params.cluster_shape_m
             // params.qhead_per_kvhead_packgqa
             // params.tile_shape_mn[1]
         )

Review thread on the `* params.cluster_shape_m` line:

Member:
does this affect any of the 2cta bwd code?

Collaborator Author (@jayhshah, Mar 25, 2026):
This change is meant to get the right head swizzle heuristic for 2cta bwd, by accounting for num_m_blocks being defined with respect to the tiler divided by the cluster shape.

num_n_blocks here is only used to derive nheads_in_l2, so it doesn't affect correctness.

Member:
For tile_shape_mn[0], do we pass the CTA's tile shape or the cluster tile shape? I think we had this discussion and realized we have not been consistent. In any case, if it doesn't affect correctness it's fine with me.

Collaborator Author:
Currently for the tile scheduler args we pass cta_tiler[:2] as tile_shape_mn and cluster_shape_mn as a separate parameter; this makes the most sense to me since the cluster shape is in principle separate from use of the 2cta MMA.
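
To illustrate the author's point with assumed numbers: num_m_blocks is counted per cluster (tile_m * cluster_m rows each), so multiplying by cluster_shape_m as well as tile_shape_mn[0] recovers the full row extent before it is converted into an n-block count. A rough sketch, not the scheduler's actual code:

```python
# Assumed illustrative values; a rough model of the quantities in the hunk above.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

seqlen, tile_m, tile_n = 8192, 128, 128
cluster_m = 2                      # 2cta backward
qhead_per_kvhead_packgqa = 1

num_m_blocks = ceil_div(seqlen, tile_m * cluster_m)  # 32 m-blocks, counted per cluster
num_n_blocks = num_m_blocks * tile_m * cluster_m // qhead_per_kvhead_packgqa // tile_n
print(num_n_blocks)  # 64 == seqlen // tile_n; without the cluster_m factor it would be 32
```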