diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py index f18d40a71a0..239dff46664 100644 --- a/benchmarks/benchmark_attn.py +++ b/benchmarks/benchmark_attn.py @@ -152,9 +152,9 @@ def setup_fa4(ctx): bwd_fn = None if ctx["has_backward"] and ctx["dtype"] != torch.float8_e4m3fn: if ctx["varlen"]: - qu, ku, vu = ctx["q_unpad"], ctx["k_unpad"], ctx["v_unpad"] + qu, ku, vu, gu = ctx["q_unpad"], ctx["k_unpad"], ctx["v_unpad"], ctx["g_unpad"] csq, csk = ctx["cu_seqlens_q"], ctx["cu_seqlens_k"] - bwd_fn = _make_bwd_fn(lambda: flash_attn_varlen_func_python(qu, ku, vu, csq, csk, causal=causal, softcap=softcap, deterministic=deterministic), g, [qu, ku, vu]) + bwd_fn = _make_bwd_fn(lambda: flash_attn_varlen_func_python(qu, ku, vu, csq, csk, causal=causal, softcap=softcap, deterministic=deterministic), gu, [qu, ku, vu]) else: bwd_fn = _make_bwd_fn(lambda: flash_attn_func_python(q, k, v, causal=causal, softcap=softcap, deterministic=deterministic), g, [q, k, v]) return fwd_fn, bwd_fn @@ -352,9 +352,10 @@ def main(): g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) # Varlen tensors - q_unpad = k_unpad = v_unpad = cu_seqlens_q = cu_seqlens_k = None + q_unpad = k_unpad = v_unpad = g_unpad = cu_seqlens_q = cu_seqlens_k = None if varlen: q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_(has_backward) for x in [q, k, v]] + g_unpad = rearrange(g.detach(), "b s h d -> (b s) h d") cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen if page_size is None else None @@ -374,7 +375,7 @@ def main(): q=q, k=k, v=v, g=g, causal=causal, headdim=headdim, headdim_v=headdim_v, dtype=dtype, has_backward=has_backward, - varlen=varlen, q_unpad=q_unpad, k_unpad=k_unpad, v_unpad=v_unpad, + varlen=varlen, q_unpad=q_unpad, k_unpad=k_unpad, v_unpad=v_unpad, g_unpad=g_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, seqlen_q=seqlen_q, seqlen=seqlen, page_size=page_size, k_paged=k_paged, v_paged=v_paged, page_table=page_table, @@ -387,12 +388,12 @@ def main(): fwd_fn, bwd_fn = setup_fn(ctx) if fwd_fn is not None and has_forward: time.sleep(1.0) - print(f"Benchmarking {display_name} fwd, hdim={headdim}, seqlen={seqlen}, causal={causal}") + print(f"Benchmarking {display_name} fwd, hdim={headdim}, seqlen={seqlen}, causal={causal}, {nheads=}, {nheads_kv=}") ms = do_bench(fwd_fn, warmup=warmup, rep=rep) * 1e-3 time_f[cfg, display_name] = ms if bwd_fn is not None and has_backward: time.sleep(1.0) - print(f"Benchmarking {display_name} bwd, hdim={headdim}, seqlen={seqlen}, causal={causal}") + print(f"Benchmarking {display_name} bwd, hdim={headdim}, seqlen={seqlen}, causal={causal}, {nheads=}, {nheads_kv=}, {deterministic=}") ms = do_bench(bwd_fn, warmup=warmup, rep=rep) * 1e-3 time_b[cfg, display_name] = ms diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py index 95481099b21..196c12a32ee 100644 --- a/flash_attn/cute/tile_scheduler.py +++ b/flash_attn/cute/tile_scheduler.py @@ -395,8 +395,8 @@ def create( ) -> "SingleTileLPTBwdScheduler.Params": size_l2 = 50 * 1024 * 1024 size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size - # size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4 - size_one_dqaccum_head = 0 + size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4 + # size_one_dqaccum_head = 0 size_one_head = size_one_qdo_head + size_one_dqaccum_head log2_floor = lambda n: 31 - clz(n) swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) @@ -521,9 +521,12 @@ def create( args: TileSchedulerArguments, *, loc=None, ip=None ) -> "SingleTileVarlenScheduler.Params": size_l2 = 50 * 1024 * 1024 # 50 MB for K & V - max_kvblock_in_l2 = size_l2 // ( - (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] - ) + # if backward, this is qdo block size + kv_block_size = (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] + # if backward, add dqaccum block size to calculate swizzle + if args.head_swizzle: + kv_block_size += args.headdim * 4 * args.tile_shape_mn[1] + max_kvblock_in_l2 = size_l2 // kv_block_size assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" ) @@ -654,6 +657,7 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: num_n_blocks = ( num_m_blocks * params.tile_shape_mn[0] + * params.cluster_shape_m // params.qhead_per_kvhead_packgqa // params.tile_shape_mn[1] )