def parse_int_k(s):
    """Convert a size string to an int, honouring an optional binary suffix.

    A trailing ``k`` multiplies the number by 1024 and a trailing ``m`` by
    1024 * 1024.  Matching is case-insensitive and surrounding whitespace is
    ignored, e.g. ``'8k'`` -> 8192 and ``'1m'`` -> 1048576.  Non-string
    arguments are stringified before parsing.

    Raises:
        ValueError: if the numeric part is not a valid integer literal.
    """
    text = str(s).strip().lower()
    # "m" is tested before "k", mirroring the precedence of the suffix checks.
    for suffix, multiplier in (("m", 1024 * 1024), ("k", 1024)):
        if text.endswith(suffix):
            return int(text[: -len(suffix)]) * multiplier
    return int(text)


def csv_ints(s):
    """Split a comma-separated string and parse each field with `parse_int_k`."""
    return list(map(parse_int_k, s.split(",")))
Returns (batch, seqlen).""" + b, sl = s.lower().split("x") + return int(b), parse_int_k(sl) + + +def parse_shapes(s): + return [parse_shape(x) for x in s.split(",")] + + +def _make_seqlens(batch, seqlen, pattern, seed): + g = torch.Generator(device="cpu").manual_seed(seed) + if pattern == "constant": + return [seqlen] * batch + if pattern == "uniform": + lo = max(1, seqlen // 2) + return torch.randint(lo, seqlen + 1, (batch,), generator=g).tolist() + if pattern == "wide": + return torch.randint(1, seqlen + 1, (batch,), generator=g).tolist() + if pattern == "longtail": + n_long = max(1, batch // 8) + out = torch.randint( + max(1, seqlen // 16), max(2, seqlen // 8), (batch,), generator=g + ).tolist() + for i in torch.randperm(batch, generator=g)[:n_long].tolist(): + out[i] = seqlen + return out + if pattern == "bimodal": + return [seqlen if i % 2 == 0 else max(1, seqlen // 8) for i in range(batch)] + if pattern == "skew": + return [max(1, int(seqlen * i / max(1, batch - 1))) for i in range(batch)] + if pattern == "skew_shuffled": + out = [max(1, int(seqlen * i / max(1, batch - 1))) for i in range(batch)] + return [out[i] for i in torch.randperm(batch, generator=g).tolist()] + raise ValueError(f"unknown pattern {pattern!r}") + + +def _causal_tiles(sq, sk, tile_m=128, tile_n=128): + if sq <= 0 or sk <= 0: + return 0 + nq = (sq + tile_m - 1) // tile_m + nk = (sk + tile_n - 1) // tile_n + if nq <= 1: + return nk + return nq * nk - (nq * (nq - 1)) // 2 + + +def _apply_sort(seqlens_q, seqlens_k, sort): + if sort == "none": + return seqlens_q, seqlens_k + pairs = list(zip(seqlens_q, seqlens_k)) + keyfn = { + "asc": lambda p: _causal_tiles(*p), + "desc": lambda p: -_causal_tiles(*p), + }.get(sort) + if keyfn is None: + raise ValueError(f"unknown sort {sort!r}") + pairs.sort(key=keyfn) + return [p[0] for p in pairs], [p[1] for p in pairs] + + +def _override_random_subset( + seqlens_q, seqlens_k, frac, seed_salt, sq_value, sk_value, seed +): + """Pick `frac` of batches at 
def build_ctx(
    args, batch, seqlen, pattern, sort, decode_frac, zero_frac, num_splits, seed
):
    """Assemble one benchmark workload as a dict of lengths, tensors and flags.

    Key lengths come from `_make_seqlens`; query lengths start as a copy of
    them.  A `decode_frac` share of sequences then gets seqlen_q forced to 1
    (keys untouched) and a `zero_frac` share gets both lengths forced to 0,
    after which the batch is reordered by `_apply_sort`.

    Head counts, head dim and pack_gqa are read from `args`; tensors are
    bfloat16 on the "cuda" device and the workload is always causal.
    """
    base_lens = _make_seqlens(batch, seqlen, pattern, seed)
    lens_q, lens_k = _override_random_subset(
        list(base_lens),
        base_lens,
        decode_frac,
        7919,
        sq_value=1,
        sk_value=None,
        seed=seed,
    )
    lens_q, lens_k = _override_random_subset(
        lens_q, lens_k, zero_frac, 31337, sq_value=0, sk_value=0, seed=seed
    )
    lens_q, lens_k = _apply_sort(lens_q, lens_k, sort)

    device = "cuda"
    dtype = torch.bfloat16
    nheads = args.nheads
    nheads_kv = args.nheads_kv
    headdim = args.headdim

    def cumulative_offsets(lengths):
        # int32 prefix sums with a leading 0, i.e. batch + 1 entries.
        offsets = torch.zeros(batch + 1, dtype=torch.int32, device=device)
        offsets[1:] = torch.tensor(
            lengths, dtype=torch.int32, device=device
        ).cumsum(0)
        return offsets

    def random_tokens(lengths, heads):
        # At least one row is allocated even when every sequence is empty.
        rows = max(sum(lengths), 1)
        return torch.randn(rows, heads, headdim, device=device, dtype=dtype)

    cu_q = cumulative_offsets(lens_q)
    cu_k = cumulative_offsets(lens_k)
    # Drawn in q, k, v order so the RNG stream is consumed consistently.
    q_unpad = random_tokens(lens_q, nheads)
    k_unpad = random_tokens(lens_k, nheads_kv)
    v_unpad = random_tokens(lens_k, nheads_kv)

    return {
        "batch": batch,
        "seqlen": seqlen,
        "pattern": pattern,
        "decode_frac": decode_frac,
        "zero_frac": zero_frac,
        "nheads": nheads,
        "nheads_kv": nheads_kv,
        "headdim": headdim,
        "seqlens_q": lens_q,
        "seqlens_k": lens_k,
        "q_unpad": q_unpad,
        "k_unpad": k_unpad,
        "v_unpad": v_unpad,
        "cu_q": cu_q,
        "cu_k": cu_k,
        "max_seqlen_q": max(lens_q) if lens_q else 0,
        "max_seqlen_k": max(lens_k) if lens_k else 0,
        "causal": True,
        "num_splits": num_splits,
        "pack_gqa": args.pack_gqa,
    }
def make_varlen_setup(*, clc: bool, prep: str, no_semaphore: bool = False):
    """Build a ``setup(ctx) -> fn`` factory for one varlen benchmark mode.

    Args:
        clc: value written to ``fa_utils._fa_clc_enabled`` on every call of
            the returned benchmark function.
        prep: how scheduler metadata is handled.
            ``'none'``: no metadata is passed and
            ``disable_scheduler_metadata`` is True.
            ``'precompute'``: metadata is built once in ``setup``, outside
            the timed function.
            ``'recompute'``: metadata is rebuilt inside every timed call.
        no_semaphore: build metadata with `_make_meta_no_semaphore` (which
            clears ``tile_count_semaphore``) instead of `_make_meta`.  Per the
            original note this is meant to make the kernel select
            SingleTileVarlenScheduler and exercise the binary-search hint
            path targeted by PR #2520.

    Returns:
        A ``setup`` callable that takes a workload dict from `build_ctx` and
        returns a zero-argument function running `flash_attn_varlen_func`.
    """
    assert prep in ("none", "precompute", "recompute")
    if no_semaphore:
        meta_fn = _make_meta_no_semaphore
    else:
        meta_fn = _make_meta
    precompute = prep == "precompute"
    recompute = prep == "recompute"
    skip_metadata = prep == "none"

    def setup(ctx):
        cached_meta = meta_fn(ctx) if precompute else None

        def fn():
            # Set on every call because modes with different CLC settings
            # are interleaved within one process.
            fa_utils._fa_clc_enabled = clc
            if recompute:
                meta = meta_fn(ctx)
            else:
                meta = cached_meta
            return flash_attn_varlen_func(
                ctx["q_unpad"],
                ctx["k_unpad"],
                ctx["v_unpad"],
                cu_seqlens_q=ctx["cu_q"],
                cu_seqlens_k=ctx["cu_k"],
                max_seqlen_q=ctx["max_seqlen_q"],
                max_seqlen_k=ctx["max_seqlen_k"],
                causal=ctx["causal"],
                num_splits=ctx["num_splits"],
                scheduler_metadata=meta,
                disable_scheduler_metadata=skip_metadata,
                pack_gqa=ctx["pack_gqa"],
            )

        return fn

    return setup
" + "If unset, derived from --total-tokens by sweeping a default isoline.", + ) + p.add_argument( + "--patterns", + nargs="+", + default=["constant", "longtail", "bimodal", "uniform"], + help="Length distributions: constant, uniform, wide, longtail, bimodal, skew, skew_shuffled", + ) + p.add_argument( + "--sorts", + nargs="+", + default=["none"], + help="Batch ordering by tile count: none, asc, desc", + ) + p.add_argument( + "--decode-fracs", + nargs="+", + type=float, + default=[0.0], + help="Fraction(s) of batches to force seqlen_q=1 (mixed prefill/decode)", + ) + p.add_argument( + "--zero-fracs", + nargs="+", + type=float, + default=[0.0], + help="Fraction(s) of batches to force seqlen=0", + ) + p.add_argument( + "--num-splits", + nargs="+", + type=int, + default=[1], + help="num_splits values; >1 enables SplitKV", + ) + p.add_argument("--modes", nargs="+", default=[cli for cli, _ in MODES]) + p.add_argument("--headdim", type=int, default=128) + p.add_argument("--nheads", type=int, default=16) + p.add_argument("--nheads-kv", type=int, default=2) + p.add_argument( + "--pack-gqa", + action="store_true", + default=True, + help="Force pack_gqa=True (default). 
--no-pack-gqa to disable.", + ) + p.add_argument("--no-pack-gqa", dest="pack_gqa", action="store_false") + p.add_argument("--seeds", type=int, default=3) + p.add_argument("--warmup", type=int, default=2) + p.add_argument("--rep", type=int, default=20) + p.add_argument( + "--sleep", + type=float, + default=0.5, + help="Sleep between modes to dodge clock throttling (seconds)", + ) + p.add_argument("--device", type=int, default=0) + p.add_argument( + "--csv", action="store_true", help="Emit CSV rows instead of the pretty table" + ) + return p.parse_args() + + +def _default_isoline(total_tokens): + """(batch, seqlen) pairs where batch * seqlen == total_tokens, doubling seqlen from 256.""" + return [ + (total_tokens // s, s) + for s in (1 << b for b in range(8, total_tokens.bit_length())) + if total_tokens // s >= 1 + ] + + +def _format_row(cells, csv, widths): + if csv: + return ",".join(str(c) for c in cells) + return " ".join(f"{str(c):<{w}}" for c, w in zip(cells, widths)) + + +def main(): + args = parse_args() + torch.cuda.set_device(args.device) + torch.manual_seed(0) + + if args.shapes is not None: + shapes = args.shapes + else: + shapes = [s for t in args.total_tokens for s in _default_isoline(t)] + + selected_modes = [(cli, fn) for cli, fn in MODES if cli in args.modes] + if not _supports_clc(args.device): + dropped = [cli for cli, _ in selected_modes if cli in _CLC_MODES] + if dropped: + print(f"# skipping CLC modes: {', '.join(dropped)}") + selected_modes = [ + (cli, fn) for cli, fn in selected_modes if cli not in _CLC_MODES + ] + + print(f"# device {args.device}: {torch.cuda.get_device_name(args.device)}") + print( + f"# headdim={args.headdim} nheads={args.nheads} nheads_kv={args.nheads_kv} " + f"(qhead_per_kvhead={args.nheads // args.nheads_kv})" + ) + cols = [ + ("pattern", 14), + ("decode", 8), + ("zero", 6), + ("shape", 10), + ("splits", 8), + ("mode", 14), + ("mean_us", 10), + ("tok/us", 9), + ("tflops", 8), + ("rel_st", 7), + ("rel_clc", 9), + ] + 
widths = [w for _, w in cols] + print(_format_row([h for h, _ in cols], args.csv, widths)) + + for shape, pattern, sort, decode_frac, zero_frac, num_splits in product( + shapes, + args.patterns, + args.sorts, + args.decode_fracs, + args.zero_fracs, + args.num_splits, + ): + batch, seqlen = shape + results = {} + # Workload is identical across modes; build once to get total_q for the report. + ref_ctx = build_ctx( + args, + batch, + seqlen, + pattern, + sort, + decode_frac, + zero_frac, + num_splits, + seed=0, + ) + total_q = sum(ref_ctx["seqlens_q"]) + # Causal varlen attention FLOPs per batch: + # per (head, query q in [0, sq)): 4 * d * effective_k where + # effective_k = max(0, sk - sq + q + 1). + # sum_q effective_k = sq*sk - sq*(sq-1)/2 (for sk >= sq; otherwise clamped). + total_flops = 0 + for sq, sk in zip(ref_ctx["seqlens_q"], ref_ctx["seqlens_k"]): + if sq == 0 or sk == 0: + continue + if ref_ctx["causal"]: + # sum_{q=0}^{sq-1} max(0, sk - sq + q + 1) + shift = sk - sq + if shift >= 0: + eff = sq * sk - sq * (sq - 1) // 2 + else: + # clamp to non-negative for queries near 0 + first_visible_q = ( + -shift + ) # smallest q with sk - sq + q + 1 > 0 is q = sq - sk + visible = sq - first_visible_q + eff = visible * sk - visible * (visible - 1) // 2 + eff = max(0, eff) + else: + eff = sq * sk + total_flops += 4 * args.headdim * args.nheads * eff + + for cli, setup in selected_modes: + samples = [] + for s in range(args.seeds): + ctx = build_ctx( + args, + batch, + seqlen, + pattern, + sort, + decode_frac, + zero_frac, + num_splits, + seed=s, + ) + fn = setup(ctx) + if fn is None: + samples = None + break + fn() + torch.cuda.synchronize() + time.sleep(args.sleep) + samples.append(do_bench(fn, warmup=args.warmup, rep=args.rep)) + results[cli] = ( + None if samples is None else sum(samples) / len(samples) * 1e3 + ) + + single_tile_us = results.get("single-tile") + clc_us = results.get("clc") + for cli, _ in selected_modes: + us = results.get(cli) + if us is None: + 
continue + tok_per_us = (total_q / us) if us > 0 else 0.0 + tflops = (total_flops / (us * 1e6)) if us > 0 else 0.0 + rel_st = f"{single_tile_us / us:.3f}" if single_tile_us else "-" + rel_cl = f"{clc_us / us:.3f}" if clc_us else "-" + print( + _format_row( + [ + pattern, + f"{decode_frac:.2f}", + f"{zero_frac:.2f}", + f"{batch}x{seqlen}", + num_splits, + cli, + f"{us:.2f}", + f"{tok_per_us:.2f}", + f"{tflops:.2f}", + rel_st, + rel_cl, + ], + args.csv, + widths, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index 35bb4365ff6..9cadfd38651 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -19,6 +19,9 @@ class BlockInfo: window_size_left: Optional[Int32] = None window_size_right: Optional[Int32] = None qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + num_splits: Int32 = 1 + num_splits_dynamic_ptr: Optional[cute.Tensor] = None + num_n_blocks_per_split: Optional[cutlass.Constexpr[Int32]] = None @cute.jit def get_n_block_min_max( @@ -26,6 +29,7 @@ def get_n_block_min_max( seqlen_info: SeqlenInfoQK, m_block: Int32, split_idx: Int32 = 0, + batch_idx: Int32 = 0, num_splits: Int32 = 1, ) -> Tuple[Int32, Int32]: n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) @@ -45,11 +49,20 @@ def get_n_block_min_max( n_idx_left = n_idx - self.window_size_left n_block_min = cutlass.max(n_idx_left // self.tile_n, 0) if cutlass.const_expr(self.is_split_kv): - num_n_blocks_per_split = ( - Int32(0) - if n_block_max <= n_block_min - else (n_block_max - n_block_min + num_splits - 1) // num_splits - ) + if const_expr(self.num_splits_dynamic_ptr is not None): + # Unpack num_splits from top 16 bits of split_idx (packed by scheduler) + num_splits = split_idx >> 16 + split_idx = split_idx & 0xFFFF + else: + num_splits = self.num_splits + if const_expr(self.num_n_blocks_per_split is not None): + num_n_blocks_per_split = self.num_n_blocks_per_split + else: + 
num_n_blocks_per_split = ( + Int32(0) + if n_block_max <= n_block_min + else (n_block_max - n_block_min + num_splits - 1) // num_splits + ) n_block_min = n_block_min + split_idx * num_n_blocks_per_split n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max) return n_block_min, n_block_max diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py index 0eb0eddf976..5bd6d514022 100644 --- a/flash_attn/cute/flash_bwd.py +++ b/flash_attn/cute/flash_bwd.py @@ -389,6 +389,7 @@ def __call__( mdV_semaphore: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): @@ -429,6 +430,7 @@ def __call__( qhead_per_kvhead_packgqa=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1, mCuSeqlensQ=mCuSeqlensK, mSeqUsedQ=mSeqUsedK, + cu_total_m_blocks_ptr=mCuTotalMBlocks, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py index 913b43d377b..ccd4c143969 100644 --- a/flash_attn/cute/flash_bwd_postprocess.py +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -215,6 +215,7 @@ def __call__( scale: cutlass.Float32, mCuSeqlensQ: Optional[cute.Tensor], mSeqUsedQ: Optional[cute.Tensor], + mCuTotalMBlocks: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). 
stream: cuda.CUstream = None, ): @@ -258,6 +259,7 @@ def __call__( tile_shape_mn=(self.tile_m, 1), mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, + cu_total_m_blocks_ptr=mCuTotalMBlocks, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py index 8142def5ebb..d3b9b5d2974 100644 --- a/flash_attn/cute/flash_bwd_preprocess.py +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -136,6 +136,7 @@ def __call__( mCuSeqlensQ: Optional[cute.Tensor], # (batch + 1,) mSeqUsedQ: Optional[cute.Tensor], # (batch,) mdLSE: Optional[cute.Tensor], # (batch, nheads, seqlen) or (nheads, total_q) + mCuTotalMBlocks: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): @@ -193,6 +194,7 @@ def __call__( tile_shape_mn=(self.tile_m, 1), mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, + cu_total_m_blocks_ptr=mCuTotalMBlocks, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 174ac0ed9eb..38d2e112e86 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -466,6 +466,7 @@ def __call__( aux_tensors: Optional[list] = None, # Block-sparse tensors (Q direction - for iterating m_blocks per n_block): blocksparse_tensors: Optional[BlockSparseTensors] = None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). 
stream: cuda.CUstream = None, ): @@ -732,6 +733,7 @@ def __call__( qhead_per_kvhead_packgqa=1, # pack_gqa disabled for bwd element_size=self.k_dtype.width // 8, is_persistent=self.is_persistent, # persistent mode not tested + cu_total_m_blocks_ptr=mCuTotalMBlocks, lpt=self.spt, head_swizzle=self.deterministic, ) diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py index bb5798df2cc..5d6be85d0a7 100644 --- a/flash_attn/cute/flash_bwd_sm90.py +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -357,6 +357,7 @@ def __call__( mdV_semaphore: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): @@ -536,6 +537,7 @@ def _qkv_transpose(t): is_persistent=False, lpt=self.spt, head_swizzle=self.deterministic, + cu_total_m_blocks_ptr=mCuTotalMBlocks, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 143128b3afe..2d2aa087c58 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -637,6 +637,8 @@ def __call__( learnable_sink: Optional[cute.Tensor] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors=None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, + mCuTotalSplitsMBlocks: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). 
stream: cuda.CUstream = None, ): @@ -698,6 +700,8 @@ def __call__( qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, + cu_total_m_blocks_ptr=mCuTotalMBlocks, + cu_total_splits_m_blocks_ptr=mCuTotalSplitsMBlocks, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py index 493620235ec..97164b8f533 100644 --- a/flash_attn/cute/flash_fwd_combine.py +++ b/flash_attn/cute/flash_fwd_combine.py @@ -197,7 +197,7 @@ def __call__( cu_seqlens: Optional[cute.Tensor] = None, seqused: Optional[cute.Tensor] = None, num_splits_dynamic_ptr: Optional[cute.Tensor] = None, - varlen_batch_idx: Optional[cute.Tensor] = None, + virtual_batch_idx: Optional[cute.Tensor] = None, semaphore_to_reset: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). 
stream: cuda.CUstream = None, @@ -301,7 +301,7 @@ class SharedStorage: cu_seqlens, seqused, num_splits_dynamic_ptr, - varlen_batch_idx, + virtual_batch_idx, semaphore_to_reset, SharedStorage, self.smem_layout_lse, @@ -330,7 +330,7 @@ def kernel( cu_seqlens: Optional[cute.Tensor], seqused: Optional[cute.Tensor], num_splits_dynamic_ptr: Optional[cute.Tensor], - varlen_batch_idx: Optional[cute.Tensor], + virtual_batch_idx: Optional[cute.Tensor], semaphore_to_reset: Optional[cute.Tensor], SharedStorage: cutlass.Constexpr, smem_layout_lse: cute.Layout | cute.ComposedLayout, @@ -349,8 +349,8 @@ def kernel( # Map virtual batch index to real batch index (for persistent tile schedulers) batch_idx = ( - varlen_batch_idx[maybe_virtual_batch] - if const_expr(varlen_batch_idx is not None) + virtual_batch_idx[maybe_virtual_batch] + if const_expr(virtual_batch_idx is not None) else maybe_virtual_batch ) @@ -394,8 +394,9 @@ def kernel( num_head = mO_partial.shape[3] max_idx = seqlen * num_head - # Early exit for single split if dynamic - if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and ( + # TODO: early exit for single split if dynamic — for now always merge so the + # num_splits_dynamic == 1 case still writes mO from mO_partial[0]. 
+ if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 0) and ( const_expr(not varlen) or m_block * self.tile_m < max_idx ): # Wait for dependent grids (e.g., the main attention kernel that produces O_partial/LSE_partial) diff --git a/flash_attn/cute/flash_fwd_mla_sm100.py b/flash_attn/cute/flash_fwd_mla_sm100.py index 84c349c5e3a..f52688cf817 100644 --- a/flash_attn/cute/flash_fwd_mla_sm100.py +++ b/flash_attn/cute/flash_fwd_mla_sm100.py @@ -29,7 +29,7 @@ import flash_attn.cute.blackwell_helpers as fa_sm100_utils from flash_attn.cute.softmax import SoftmaxSm100 from flash_attn.cute.tile_scheduler import ( - ClcState, + SchedulerState, SchedulingMode, TileSchedulerArguments, TileSchedulerProtocol, @@ -993,7 +993,7 @@ def make_pipeline(cls, mbar_ptr, num_stages, producer, consumer, tx_count=None): clc_pipeline_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, cute.arch.WARP_SIZE * num_clc_consumer_warps ) - clc = ClcState.create( + clc = SchedulerState.create_clc( hw_scheduler=ClcDynamicPersistentTileScheduler.create( self.tile_scheduler_cls.clc_problem_shape(tile_sched_params), cute.arch.block_idx(), diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index abddd200751..f74a823f8af 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -56,7 +56,7 @@ from cutlass.cute import FastDivmodDivisor from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.tile_scheduler import ( - ClcState, + SchedulerState, SchedulingMode, TileSchedulerArguments, TileSchedulerProtocol, @@ -64,6 +64,7 @@ StaticPersistentTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, + DynamicPersistentVarlenScheduler, ) from flash_attn.cute.fa_logging import fa_log, fa_printf from flash_attn.cute.utils import smid @@ -126,7 +127,7 @@ def __init__( m_block_size: int = 128, n_block_size: int = 128, q_stage: cutlass.Constexpr[int] = 2, - is_persistent: bool = True, + 
is_static_persistent: bool = True, score_mod: cutlass.Constexpr | None = None, mask_mod: cutlass.Constexpr | None = None, has_aux_tensors: cutlass.Constexpr = False, @@ -134,6 +135,8 @@ def __init__( is_varlen_q: bool = False, use_2cta_instrs: bool = False, use_clc_scheduler: bool = False, + has_tile_count_semaphore: bool = False, + seqlen_k_per_split: Optional[int] = None, ): self.use_tma_KV = not paged_kv_non_tma # self.dtype = dtype @@ -158,6 +161,10 @@ def __init__( self.split_P_arrive = int(self.split_P_arrive / 32) * 32 # multiple of 32 assert self.split_P_arrive % 32 == 0 assert self.split_P_arrive < self.n_block_size + assert seqlen_k_per_split is None or seqlen_k_per_split % n_block_size == 0 + self.num_n_blocks_per_split = ( + seqlen_k_per_split // n_block_size if seqlen_k_per_split is not None else None + ) self.arch = BaseDSL._get_dsl().get_arch_enum() assert self.arch.is_family_of(Arch.sm_100f) or self.arch.is_family_of(Arch.sm_110f), \ "Only SM 10.x and 11.x are supported" @@ -172,7 +179,7 @@ def __init__( self.qk_acc_dtype = Float32 self.pv_acc_dtype = Float32 self.cluster_shape_mn = (2, 1) if self.use_2cta_instrs else (1, 1) - self.is_persistent = is_persistent + self.is_static_persistent = is_static_persistent self.is_causal = is_causal self.is_local = is_local self.is_varlen_q = is_varlen_q @@ -206,19 +213,19 @@ def __init__( (self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or (self.head_dim_v_padded >= 128 and self.is_split_kv) ) - if self.overlap_sO_sQ: - self.is_persistent = False assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), ( "Paged KV does not support irregular head dim" ) + self.use_clc_scheduler = use_clc_scheduler # ClC does not compose with these other features, so disable even if requested self.use_clc_scheduler = ( use_clc_scheduler and self.use_tma_KV - and not self.overlap_sO_sQ ) + self.dynamic_persistent = (has_tile_count_semaphore and is_varlen_q) or use_clc_scheduler + 
self.is_persistent = self.dynamic_persistent or self.is_static_persistent self.sched_stages = 1 if self.use_clc_scheduler: assert self.cluster_shape_mn[1] == 1, f"CLC requires cluster N == 1: {self.cluster_shape_mn}" @@ -227,13 +234,25 @@ def __init__( f"CLC cluster M != cta_group_size: {self.cluster_shape_mn}, {self.cta_group_size}" ) - self.scheduling_mode = SchedulingMode.CLC if self.use_clc_scheduler else SchedulingMode.STATIC + self.scheduling_mode = ( + SchedulingMode.CLC if self.use_clc_scheduler + else SchedulingMode.DYNAMIC if self.dynamic_persistent + else SchedulingMode.STATIC + ) + self.use_varlen_scheduler = False if is_varlen_q: - self.TileScheduler = SingleTileVarlenScheduler + if self.dynamic_persistent and not self.use_clc_scheduler: + self.use_varlen_scheduler = True + self.TileScheduler = DynamicPersistentVarlenScheduler + elif self.is_static_persistent and not self.use_clc_scheduler: + self.TileScheduler = StaticPersistentTileScheduler + else: + self.use_varlen_scheduler = True + self.TileScheduler = SingleTileVarlenScheduler elif self.is_causal or self.is_local or self.use_clc_scheduler: self.TileScheduler = SingleTileLPTScheduler - elif self.is_persistent: + elif self.is_static_persistent: self.TileScheduler = StaticPersistentTileScheduler else: self.TileScheduler = SingleTileScheduler @@ -280,7 +299,7 @@ def __init__( elif self.is_varlen_q: # fallback self.epilogue_warp_ids = (13, 14) - self.clc_scheduler_warp_id = self.empty_warp_ids[0] if self.use_clc_scheduler else None + self.scheduler_warp_id = self.empty_warp_ids[0] if self.dynamic_persistent else None self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 self.tmem_o_offset = [ @@ -383,6 +402,13 @@ def __call__( descale_tensors: Optional[DescaleTensors] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors: Optional[list] = None, + num_splits_dynamic_ptr: Optional[cute.Tensor] = None, + tile_count_semaphore: Optional[cute.Tensor] = None, + 
virtual_batch_idx_ptr: Optional[cute.Tensor] = None, + num_nheads_in_l2_ptr: Optional[cute.Tensor] = None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, + mCuTotalSplitsMBlocks: Optional[cute.Tensor] = None, + max_seqlen_q: Int32 | int | None = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): @@ -636,10 +662,14 @@ def __call__( vO_layout = cute.make_layout((1, async_copy_elems)) gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) + if const_expr(max_seqlen_q is None): + eff_seqlen_q = cute.size(mQ.shape[0]) + else: + eff_seqlen_q = max_seqlen_q if const_expr(not self.pack_gqa) else max_seqlen_q * self.qhead_per_kvhead TileScheduler = self.TileScheduler _num_block_divisor = self.cta_tiler[0] * (self.cta_group_size if not self.is_persistent and self.cta_group_size > 1 else 1) tile_sched_args = TileSchedulerArguments( - cute.ceil_div(cute.size(mQ.shape[0]), _num_block_divisor), + cute.ceil_div(eff_seqlen_q, _num_block_divisor), cute.size(mQ.shape[2]), cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) @@ -663,6 +693,12 @@ def __call__( is_split_kv=self.is_split_kv, cluster_shape_mn=self.cluster_shape_mn, use_cluster_idx=not self.is_persistent and self.cta_group_size > 1, + num_splits_dynamic_ptr=num_splits_dynamic_ptr, + virtual_batch_idx_ptr=virtual_batch_idx_ptr, + num_nheads_in_l2_ptr=num_nheads_in_l2_ptr, + cu_total_m_blocks_ptr=mCuTotalMBlocks, + cu_total_splits_m_blocks_ptr=mCuTotalSplitsMBlocks, + tile_count_semaphore=tile_count_semaphore.iterator if tile_count_semaphore is not None else None, ) tile_sched_params = TileScheduler.to_underlying_arguments( tile_sched_args, scheduling_mode=self.scheduling_mode @@ -676,8 +712,9 @@ def __call__( cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width) ) - clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0 - clc_mbar_size = 
self.sched_stages * 2 if self.use_clc_scheduler else 0 + sched_response_size = self.sched_stages * 4 if self.dynamic_persistent else 0 + sched_mbar_size = self.sched_stages * 2 if self.dynamic_persistent else 0 + load_epi_mbar_size = 2 if const_expr(self.overlap_sO_sQ) else 0 @cute.struct class SharedStorage: @@ -690,6 +727,7 @@ class SharedStorage: mbar_softmax_stats: cute.struct.MemRange[Int64, self.q_stage * 2] # mbar_softmax_stats: cute.struct.MemRange[Int64, self.q_stage * 4 * 2] mbar_O_epi: cute.struct.MemRange[Int64, self.q_stage * 2] + mbar_load_epi: cute.struct.MemRange[Int64, load_epi_mbar_size] mbar_s0_s1_sequence: cute.struct.MemRange[Int64, 2 * 2] # Tmem dealloc cluster barrier tmem_dealloc_mbar_ptr: Int64 @@ -698,12 +736,13 @@ class SharedStorage: # Smem tensors # store row max and row sum sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2] - # CLC buffers placed here to utilize padding before sO's 1024-byte alignment. - # This avoids adding bytes at the end when we're at the smem limit. - # PipelineClcFetchAsync expects 2 * sched_stages mbarriers (full + empty). - clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size] - # CLC response storage (16 bytes per stage, stored as 4 Int32s). - clc_response: cute.struct.MemRange[Int32, clc_response_size] + # Scheduler buffers placed here to utilize padding before sO's 1024-byte + # alignment. This avoids adding bytes at the end when we're at the smem limit. + # PipelineClcFetchAsync / PipelineAsync both expect + # 2 * sched_stages mbarriers (full + empty). Response is 4 Int32 per stage + # (CLC HW response, or work_info written by dynamic persistent producer). 
+ sched_mbar_ptr: cute.struct.MemRange[Int64, sched_mbar_size] + sched_response: cute.struct.MemRange[Int32, sched_response_size] # Large TMA buffers with 1024-byte alignment sO: cute.struct.Align[ cute.struct.MemRange[self.o_dtype, sO_size], self.buffer_align_bytes @@ -770,6 +809,10 @@ class SharedStorage: tiled_mma_pv, tile_sched_params, num_splits, + num_splits_dynamic_ptr, + tile_count_semaphore, + virtual_batch_idx_ptr, + num_nheads_in_l2_ptr, aux_tensors, fastdiv_mods, head_divmod, @@ -829,6 +872,10 @@ def kernel( tiled_mma_pv: cute.TiledMma, tile_sched_params: ParamsBase, num_splits: Int32, + num_splits_dynamic_ptr: Optional[cute.Tensor] = None, + tile_count_semaphore: Optional[cute.Tensor] = None, + virtual_batch_idx_ptr: Optional[cute.Tensor] = None, + num_nheads_in_l2_ptr: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), head_divmod=None, @@ -891,6 +938,7 @@ def kernel( ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread) mma_warp = ThreadCooperativeGroup(len([self.mma_warp_id])) tma_warp = ThreadCooperativeGroup(1) + load_warps = ThreadCooperativeGroup(len(self.load_warp_ids)) load_threads = ThreadCooperativeGroup(len(self.load_warp_ids) * cute.arch.WARP_SIZE) softmax_warps = ThreadCooperativeGroup(len(self.softmax0_warp_ids)) softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.softmax0_warp_ids)) @@ -902,6 +950,7 @@ def kernel( softmax_correction_threads = ThreadCooperativeGroup( cute.arch.WARP_SIZE * len(self.softmax0_warp_ids + self.correction_warp_ids) ) + epilogue_warps = ThreadCooperativeGroup(len(self.epilogue_warp_ids)) epilogue_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)) # For UMMA-bridging pipelines: the non-MMA side spans both CTAs in the cluster, # so the thread count must include warps from both CTAs. 
@@ -1014,6 +1063,25 @@ def kernel( defer_sync=True, ) + pipeline_load_epi = None + if const_expr(self.overlap_sO_sQ and self.is_persistent): + # when overlapping sO and sQ with a persistent kernel, we need this + # additional pipeline to ensure sO from the previous work tile is + # free for use by sQ in the current one. + epi_warps_for_release = ( + ThreadCooperativeGroup(len(self.correction_warp_ids)) + if self.use_correction_warps_for_epi + else epilogue_warps + ) + pipeline_load_epi = pipeline_custom.PipelineAsync.create( + barrier_storage=storage.mbar_load_epi.data_ptr(), + num_stages=1, + producer_group=epi_warps_for_release, + consumer_group=load_warps, + defer_sync=True, + ) + + # Cluster arrive after barrier init pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) @@ -1063,6 +1131,9 @@ def kernel( window_size_left, window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + num_splits=num_splits, + num_splits_dynamic_ptr=num_splits_dynamic_ptr, + num_n_blocks_per_split=self.num_n_blocks_per_split, ) SeqlenInfoCls = partial( SeqlenInfoQK.create, @@ -1087,60 +1158,80 @@ def kernel( # Cluster wait before tensor memory alloc pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) - if const_expr(self.use_clc_scheduler): - clc_response_ptr = storage.clc_response.data_ptr() - clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr() - - clc_pipeline_producer_group = cutlass_pipeline.CooperativeGroup( + sched_ctx = None + if const_expr(self.use_clc_scheduler or self.dynamic_persistent): + sched_response_ptr = storage.sched_response.data_ptr() + sched_mbar_ptr = storage.sched_mbar_ptr.data_ptr() + sched_producer_group = cutlass_pipeline.CooperativeGroup( cutlass_pipeline.Agent.Thread ) - num_clc_consumer_warps_per_cta = self.threads_per_cta // cute.arch.WARP_SIZE + num_sched_consumer_warps_per_cta = self.threads_per_cta // cute.arch.WARP_SIZE # NB on CTA0 warp15 == scheduler on CTA1 == empty but still both consume 
- num_clc_consumer_warps = num_clc_consumer_warps_per_cta * self.cta_group_size - clc_pipeline_consumer_group = cutlass_pipeline.CooperativeGroup( - cutlass_pipeline.Agent.Thread, cute.arch.WARP_SIZE * num_clc_consumer_warps - ) - - block_idx = cute.arch.block_idx() - clc = ClcState.create( - hw_scheduler=ClcDynamicPersistentTileScheduler.create( - self.tile_scheduler_cls.clc_problem_shape(tile_sched_params), - block_idx, - cute.arch.grid_dim(), - clc_response_ptr, - ), - pipeline=cutlass_pipeline.PipelineClcFetchAsync.create( - barrier_storage=clc_mbar_ptr, - num_stages=self.sched_stages, - producer_group=clc_pipeline_producer_group, - consumer_group=clc_pipeline_consumer_group, - tx_count=16, - cta_layout_vmnk=cta_layout_vmnk, - ), - consumer_state=cutlass_pipeline.make_pipeline_state( - cutlass_pipeline.PipelineUserType.Consumer, self.sched_stages - ), - producer_state=cutlass_pipeline.make_pipeline_state( - cutlass_pipeline.PipelineUserType.Producer, self.sched_stages - ), + num_sched_consumer_warps = num_sched_consumer_warps_per_cta * self.cta_group_size + sched_consumer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread, + cute.arch.WARP_SIZE * num_sched_consumer_warps, ) - tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params, clc=clc) + if const_expr(self.use_clc_scheduler): + _block_idx = cute.arch.block_idx() + sched_ctx = SchedulerState.create_clc( + hw_scheduler=ClcDynamicPersistentTileScheduler.create( + self.tile_scheduler_cls.clc_problem_shape(tile_sched_params), + _block_idx, + cute.arch.grid_dim(), + sched_response_ptr, + ), + pipeline=cutlass_pipeline.PipelineClcFetchAsync.create( + barrier_storage=sched_mbar_ptr, + num_stages=self.sched_stages, + producer_group=sched_producer_group, + consumer_group=sched_consumer_group, + tx_count=16, + cta_layout_vmnk=cta_layout_vmnk, + ), + consumer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, self.sched_stages + ), + 
producer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Producer, self.sched_stages + ), + ) + else: + assert tile_count_semaphore is not None + sched_ctx = SchedulerState.create_dynamic_persistent( + work_info=storage.sched_response.get_tensor((4,)), + pipeline=cutlass_pipeline.PipelineAsync.create( + barrier_storage=sched_mbar_ptr, + num_stages=self.sched_stages, + producer_group=sched_producer_group, + consumer_group=sched_consumer_group, + ), + consumer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, self.sched_stages + ), + producer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Producer, self.sched_stages + ), + ) + if const_expr(self.use_clc_scheduler or self.dynamic_persistent): + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params, ctx=sched_ctx) else: tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) assert isinstance(tile_scheduler, TileSchedulerProtocol), f"tile_scheduler is not a TileSchedulerProtocol: {type(tile_scheduler)}" # /////////////////////////////////////////////////////////////////////////////// - # EMPTY / CLC SCHEDULER WARP + # EMPTY / SCHEDULER WARP # /////////////////////////////////////////////////////////////////////////////// - if const_expr(self.use_clc_scheduler): - if warp_idx == self.clc_scheduler_warp_id: + if const_expr(self.use_clc_scheduler or self.dynamic_persistent): + if warp_idx == self.scheduler_warp_id: cute.arch.setmaxregister_decrease(self.num_regs_other) - if is_leader_cta: - self.clc_scheduler_warp(tile_scheduler) + # CLC: only leader CTA produces. 
+ if const_expr(self.dynamic_persistent) or is_leader_cta: + self.scheduler_warp(tile_scheduler) else: self.empty_warp(tile_scheduler) for i in cutlass.range_constexpr(len(self.empty_warp_ids)): - if warp_idx == self.empty_warp_ids[i] and warp_idx != self.clc_scheduler_warp_id: + if warp_idx == self.empty_warp_ids[i] and warp_idx != self.scheduler_warp_id: cute.arch.setmaxregister_decrease(self.num_regs_other) self.empty_warp(tile_scheduler) else: @@ -1169,6 +1260,7 @@ def kernel( gmem_tiled_copy_Q, pipeline_q, pipeline_kv, + pipeline_load_epi, block_info, num_splits, SeqlenInfoCls, @@ -1223,6 +1315,7 @@ def kernel( gmem_tiled_copy_O, tma_atom_O, pipeline_o_epi, + pipeline_load_epi, block_info, num_splits, SeqlenInfoCls, @@ -1302,6 +1395,7 @@ def kernel( pipeline_sm_stats, sm_stats_barrier, pipeline_o_epi, + pipeline_load_epi, learnable_sink, descale_tensors, gmem_tiled_copy_O, @@ -1335,6 +1429,7 @@ def load( gmem_tiled_copy_Q: Optional[cute.TiledCopy], pipeline_q: pipeline.PipelineAsync, pipeline_kv: pipeline.PipelineAsync, + pipeline_load_epi: Optional[pipeline.PipelineAsync], block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, @@ -1344,6 +1439,9 @@ def load( num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE tidx = cute.arch.thread_idx()[0] % num_load_threads warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + load_epi_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, 1 + ) issue_kv_for_this_warp = ( const_expr(not self.use_tma_KV or len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0] @@ -1469,9 +1567,15 @@ def load( if const_expr(not self.use_block_sparsity): n_block_min, n_block_max = block_info.get_n_block_min_max( - seqlen, m_block, split_idx, num_splits + seqlen, + m_block, + split_idx=split_idx, + batch_idx=batch_idx, + num_splits=num_splits, ) - if const_expr(not self.is_split_kv) or n_block_min < n_block_max: + if const_expr(self.is_split_kv and 
block_info.num_splits_dynamic_ptr is not None): + split_idx = split_idx & 0xFFFF + if self.process_work_tile(seqlen, n_block_min, n_block_max): n_block_first = n_block_max - 1 if n_block_max > 0 else 0 page_idx = ( mPageTable[batch_idx, n_block_first] @@ -1531,7 +1635,11 @@ def load( work_tile = tile_scheduler.advance_to_next_work() - # End of persistent scheduler loop + if const_expr(pipeline_load_epi is not None): + pipeline_load_epi.consumer_wait(load_epi_consumer_state) + with cute.arch.elect_one(): + pipeline_load_epi.consumer_release(load_epi_consumer_state) + load_epi_consumer_state.advance() if issue_kv_for_this_warp: pipeline_kv.producer_tail(kv_producer_state) @@ -1655,7 +1763,6 @@ def mma( while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - block_iter_count = Int32(0) process_tile = False @@ -1673,12 +1780,15 @@ def mma( ) process_tile = block_iter_count > Int32(0) else: - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, + m_block, + split_idx=split_idx, + batch_idx=batch_idx, + num_splits=num_splits, + ) block_iter_count = n_block_max - n_block_min - if const_expr(not self.is_split_kv): - process_tile = True - else: - process_tile = n_block_min < n_block_max + process_tile = self.process_work_tile(seqlen, n_block_min, n_block_max) if process_tile and is_leader_cta: for stage in cutlass.range_constexpr(self.q_stage): @@ -1952,7 +2062,11 @@ def softmax_loop( m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx kv_head_idx = self._kv_head_idx(head_idx) seqlen = SeqlenInfoCls(batch_idx) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, m_block, split_idx=split_idx, batch_idx=batch_idx, num_splits=num_splits, + ) + if 
const_expr(self.is_split_kv and block_info.num_splits_dynamic_ptr is not None): + split_idx = split_idx & 0xFFFF mask = AttentionMaskCls(seqlen) shared_mask_kwargs = dict( @@ -2044,7 +2158,7 @@ def softmax_loop( has_work = tile_block_count > Int32(0) else: tile_block_count = n_block_max - n_block_min - has_work = const_expr(not self.is_split_kv) or tile_block_count > Int32(0) + has_work = self.process_work_tile(seqlen, n_block_min, n_block_max) softmax_step = partial( self.softmax_step, @@ -2125,7 +2239,7 @@ def softmax_loop( sm_stats_barrier.arrive_w_index(index=stage * 4 + warp_idx) # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) else: - if const_expr(not self.is_split_kv) or tile_block_count > Int32(0): + if has_work: mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = softmax_step( mma_si_consumer_phase, sm_stats_producer_phase, @@ -2375,6 +2489,7 @@ def correction_loop( pipeline_sm_stats: pipeline.PipelineAsync, sm_stats_barrier: pipeline.NamedBarrier, pipeline_o_epi: pipeline.PipelineAsync, + pipeline_load_epi: Optional[pipeline.PipelineAsync], learnable_sink: Optional[cute.Tensor], descale_tensors: Optional[DescaleTensors], gmem_tiled_copy_O: cute.TiledCopy, @@ -2413,6 +2528,9 @@ def correction_loop( sm_stats_consumer_phase = Int32(0) o_corr_consumer_phase = Int32(0) corr_epi_producer_phase = Int32(1) + load_epi_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, 1 + ) work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: @@ -2429,7 +2547,11 @@ def correction_loop( Float32(256.0) if cutlass.const_expr(self.q_dtype.width == 8) else Float32(1.0) ) seqlen = SeqlenInfoCls(batch_idx) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, m_block, split_idx=split_idx, batch_idx=batch_idx, num_splits=num_splits, + ) + if 
const_expr(self.is_split_kv and block_info.num_splits_dynamic_ptr is not None): + split_idx = split_idx & 0xFFFF if const_expr(self.is_split_kv): mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] @@ -2462,7 +2584,7 @@ def correction_loop( has_work = total_block_count > Int32(0) else: total_block_count = n_block_max - n_block_min - has_work = const_expr(not self.is_split_kv) or total_block_count > Int32(0) + has_work = self.process_work_tile(seqlen, n_block_min, n_block_max) if has_work: # Ignore first signal from softmax as no correction is required @@ -2660,6 +2782,12 @@ def correction_loop( ) cute.make_tensor(lse_gmem_ptr, (1,))[0] = lse + if const_expr(pipeline_load_epi is not None and self.use_correction_warps_for_epi): + pipeline_load_epi.producer_acquire(load_epi_producer_state) + with cute.arch.elect_one(): + pipeline_load_epi.producer_commit(load_epi_producer_state) + load_epi_producer_state.advance() + # Advance to next tile work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop @@ -2865,6 +2993,7 @@ def epilogue_s2g( gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: Optional[cute.CopyAtom], pipeline_o_epi: pipeline.PipelineAsync, + pipeline_load_epi: Optional[pipeline.PipelineAsync], block_info: BlockInfo, num_splits: int, SeqlenInfoCls: Callable, @@ -2873,14 +3002,20 @@ def epilogue_s2g( tile_scheduler=None, ): epi_consumer_phase = Int32(0) + load_epi_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, 1 + ) work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - has_work = const_expr(self.use_block_sparsity or not self.is_split_kv) or n_block_min < n_block_max + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, m_block, 
split_idx=split_idx, batch_idx=batch_idx, num_splits=num_splits, + ) + if const_expr(self.is_split_kv and block_info.num_splits_dynamic_ptr is not None): + split_idx = split_idx & 0xFFFF - if has_work: + if self.process_work_tile(seqlen, n_block_min, n_block_max): if const_expr(self.is_split_kv): mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] else: @@ -2928,11 +3063,17 @@ def epilogue_s2g( epi_consumer_phase ^= 1 + if const_expr(pipeline_load_epi is not None): + pipeline_load_epi.producer_acquire(load_epi_producer_state) + with cute.arch.elect_one(): + pipeline_load_epi.producer_commit(load_epi_producer_state) + load_epi_producer_state.advance() + # Advance to next tile work_tile = tile_scheduler.advance_to_next_work() @cute.jit - def clc_scheduler_warp( + def scheduler_warp( self, tile_scheduler: TileSchedulerProtocol, ): @@ -2940,18 +3081,20 @@ def clc_scheduler_warp( while work_tile.is_valid_tile: tile_scheduler.prefetch_next_work() work_tile = tile_scheduler.advance_to_next_work() - if cute.arch.thread_idx()[0] == self.clc_scheduler_warp_id * cute.arch.WARP_SIZE: - fa_printf( - 3, - "[CLC] query sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n", - smid(), - cute.arch.block_idx()[0], - work_tile.tile_idx[0], - work_tile.tile_idx[1], - work_tile.tile_idx[2], - work_tile.tile_idx[3], - work_tile.is_valid_tile, - ) + if const_expr(self.dynamic_persistent): + if cute.arch.thread_idx()[0] == self.scheduler_warp_id * cute.arch.WARP_SIZE: + prefix_str = "[CLC] query " if const_expr(self.use_clc_scheduler) else "[DYNAMIC] info " + fa_printf( + 3, + prefix_str + "sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n", + smid(), + cute.arch.block_idx()[0], + work_tile.tile_idx[0], + work_tile.tile_idx[1], + work_tile.tile_idx[2], + work_tile.tile_idx[3], + work_tile.is_valid_tile, + ) tile_scheduler.producer_tail() @cute.jit @@ -3145,3 +3288,18 @@ def apply_score_mod( constant_q_idx=q_idx_logical, qhead_per_kvhead=self.qhead_per_kvhead 
if cutlass.const_expr(self.pack_gqa) else 1, ) + + @cute.jit + def process_work_tile( + self, + seqlen_info: SeqlenInfoQK, + n_block_min: Int32, + n_block_max: Int32, + ): + is_varlen_q = seqlen_info.has_cu_seqlens_q or seqlen_info.has_seqused_q + process_work_tile_k = const_expr(not self.is_split_kv) or n_block_min < n_block_max + if const_expr(is_varlen_q and not self.use_varlen_scheduler): + process_work_tile_q = seqlen_info.seqlen_q > 0 + else: + process_work_tile_q = True + return process_work_tile_k and process_work_tile_q diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py index 93bccfa715b..171e76ca2ca 100644 --- a/flash_attn/cute/flash_fwd_sm90.py +++ b/flash_attn/cute/flash_fwd_sm90.py @@ -172,6 +172,8 @@ def __call__( learnable_sink: Optional[cute.Tensor] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors: Optional[list] = None, + mCuTotalMBlocks: Optional[cute.Tensor] = None, + mCuTotalSplitsMBlocks: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). 
stream: cuda.CUstream = None, ): @@ -312,6 +314,7 @@ def __call__( (self.tile_m, self.tile_hdimv), # No mcast ) if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + # TODO: dispatch to DynamicPersistentVarlenScheduler when appropriate TileScheduler = SingleTileVarlenScheduler else: TileScheduler = ( @@ -341,6 +344,8 @@ def __call__( element_size=self.dtype.width // 8, is_persistent=False, lpt=self.is_causal or self.is_local, + cu_total_m_blocks_ptr=mCuTotalMBlocks, + cu_total_splits_m_blocks_ptr=mCuTotalSplitsMBlocks, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index b88bc50543c..ec976b12223 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -44,6 +44,7 @@ from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine from flash_attn.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 +from flash_attn.cute.prepare_scheduler import FlashPrepareScheduler, SchedulerMetadataTensorsTorch # SM100 head_dim=256 2CTA kernel imports from flash_attn.cute.sm100_hd256_2cta_fmha_forward import BlackwellFusedMultiHeadAttentionForward @@ -58,6 +59,8 @@ get_block_sparse_broadcast_pattern, ) +BIN_BATCH_SEARCH_THRESH = 512 # SingleTileVarlenScheduler uses binary search to find batch above this + def _parse_arch_str(arch_str): """Parse arch string (e.g. 'sm_80', 'sm_90a', '80', '100') to int (e.g. 
80, 90, 100).""" import re @@ -113,6 +116,8 @@ class FwdConfig: n_block_size: int mma_pv_is_rs: bool intra_wg_overlap: bool + q_stage: int = 1 + num_splits: int = 1 def _tile_size_fwd_sm90(head_dim, head_dim_v, is_causal, is_local, sparse_block_size_q=None): @@ -267,6 +272,100 @@ def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits): return min(num_SMs // total_mblocks, max_splits, num_n_blocks) +def _get_fwd_config( + *, + arch: int, + head_dim: int, + head_dim_v: int, + max_seqlen_q: int, + max_seqlen_k: int, + num_head_kv: int, + qhead_per_kvhead: int, + pack_gqa: bool, + batch_size: int, + causal: bool, + local: bool, + window_size_left: Optional[int], + window_size_right: Optional[int], + num_splits: int, + device, + seqlen_q: Optional[int] = None, + tile_mn: Optional[Tuple[int, int]] = None, + block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, + mma_pv_is_rs: Optional[bool] = None, + intra_wg_overlap: Optional[bool] = None, +) -> FwdConfig: + if seqlen_q is None: + seqlen_q = max_seqlen_q + + # Base tile sizes and flags: explicit override, else per-arch heuristic. 
+ cfg = FwdConfig(128, 128, True, True) + if tile_mn is None: + if arch // 10 == 12: + # SM120 tile sizes tuned for 99 KB SMEM capacity: + # D<=64: 128x128 → 48 KB (good occupancy) + # D>64: 128x64 → 64 KB (128x128 would use 96 KB, hurting occupancy) + if head_dim > 64: + cfg = FwdConfig(128, 64, True, True) + elif arch // 10 == 8: + cfg = FwdConfig(128, 64, True, True) # SM80, should tune + elif arch // 10 == 9: + sparse_q = get_sparse_q_block_size(block_sparse_tensors, seqlen_q) + cfg = _tile_size_fwd_sm90( + head_dim, head_dim_v, causal, local, sparse_block_size_q=sparse_q + ) + else: + cfg = FwdConfig(tile_mn[0], tile_mn[1], cfg.mma_pv_is_rs, cfg.intra_wg_overlap) + + tile_m, tile_n = cfg.m_block_size, cfg.n_block_size + if mma_pv_is_rs is None: + mma_pv_is_rs = cfg.mma_pv_is_rs + if intra_wg_overlap is None: + intra_wg_overlap = cfg.intra_wg_overlap + + seqlen_q_packgqa = max_seqlen_q * (qhead_per_kvhead if pack_gqa else 1) + if arch // 10 in [10, 11]: + q_stage = 2 if seqlen_q_packgqa > tile_m else 1 + else: + q_stage = 1 + + m_block_size_effective = q_stage * tile_m + seqlen_k_loaded = ( + max_seqlen_k + if not local + else max( + 0, + min( + max_seqlen_k, + (window_size_right or max_seqlen_k) + + (window_size_left or max_seqlen_k) + + 1 + + tile_m, + ), + ) + ) + num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective + total_mblocks = batch_size * num_head_kv * num_m_blocks + num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n + num_SMs = ( + 132 if is_fake_mode() else torch.cuda.get_device_properties(device).multi_processor_count + ) + if num_splits < 1: + num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) + + # SplitKV uses float32 partial output, which doubles the O buffer size + # in shared memory, causing OOM for diff-headdim (192, 128) + if arch // 10 in [10, 11] and head_dim != head_dim_v and num_splits > 1: + if num_n_blocks >= 64 and head_dim_v != 512: + tile_n = 64 + num_n_blocks = 
(seqlen_k_loaded + tile_n - 1) // tile_n + num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) + else: + num_splits = 1 + + return FwdConfig(tile_m, tile_n, mma_pv_is_rs, intra_wg_overlap, q_stage, num_splits) + + def _resolve_causal_local_window(causal, window_size_left, window_size_right, mask_mod=None): """Resolve causal/local/window settings into canonical form. @@ -289,6 +388,44 @@ def _resolve_causal_local_window(causal, window_size_left, window_size_right, ma local = False return causal, local, window_size_left, window_size_right + +def _compute_tile_cumsum( + *, + num_m_blocks: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + seqused: Optional[torch.Tensor] = None, + num_splits_dynamic: Optional[torch.Tensor] = None, + virtual_batch_idx: Optional[torch.Tensor] = None, + tile_size: int = 1, + q_stage: int = 1, + cluster_shape_m: int = 1, + qhead_per_kvhead: int = 1, + pack_gqa: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """(cu_total_m_blocks, cu_total_splits_m_blocks), int32, (num_batch + 1,). + + cu_total_splits_m_blocks is None when num_splits_dynamic is None. 
+ """ + if num_m_blocks is None: + seqlens = seqused if seqused is not None else (cu_seqlens[1:] - cu_seqlens[:-1]) + if pack_gqa and qhead_per_kvhead > 1: + seqlens = seqlens * qhead_per_kvhead + num_m_blocks = (seqlens + tile_size - 1) // tile_size + num_m_blocks_eff = (num_m_blocks + q_stage - 1) // q_stage + num_m_blocks_eff = (num_m_blocks_eff + cluster_shape_m - 1) // cluster_shape_m + order = virtual_batch_idx.long() if virtual_batch_idx is not None else None + if order is not None: + num_m_blocks_eff = num_m_blocks_eff[order] + if num_splits_dynamic is None: + cum = torch.cumsum(num_m_blocks_eff, dim=0, dtype=torch.int32) + return torch.nn.functional.pad(cum, (1, 0)), None + splits = num_splits_dynamic[order] if order is not None else num_splits_dynamic + stacked = torch.stack([num_m_blocks_eff, num_m_blocks_eff * splits], dim=0) + cum = torch.cumsum(stacked, dim=1, dtype=torch.int32) + padded = torch.nn.functional.pad(cum, (1, 0)) + return padded[0], padded[1] + + def _flash_attn_fwd( q: Optional[torch.Tensor], k: Optional[torch.Tensor], @@ -326,6 +463,9 @@ def _flash_attn_fwd( k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, gather_kv_indices: Optional[torch.Tensor] = None, + scheduler_metadata: Optional[SchedulerMetadataTensorsTorch] = None, + seqlen_k_per_split: Optional[int] = None, + disable_scheduler_metadata: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for FlashAttention. 
@@ -515,33 +655,10 @@ def _flash_attn_fwd( if arch // 10 in [8, 12]: num_threads = 128 - fwd_cfg = FwdConfig(128, 128, True, True) # default - if tile_mn is None: - if arch // 10 == 12: - # SM120 tile sizes tuned for 99 KB SMEM capacity: - # D<=64: 128x128 → 48 KB (good occupancy) - # D>64: 128x64 → 64 KB (128x128 would use 96 KB, hurting occupancy) - if head_dim <= 64: - fwd_cfg = FwdConfig(128, 128, True, True) - else: - fwd_cfg = FwdConfig(128, 64, True, True) - elif arch // 10 == 8: - fwd_cfg = FwdConfig(128, 64, True, True) # SM80, should tune - elif arch // 10 == 9: - sparse_q = get_sparse_q_block_size(block_sparse_tensors, seqlen_q) - fwd_cfg = _tile_size_fwd_sm90(head_dim, head_dim_v, causal, local, sparse_block_size_q=sparse_q) - else: - fwd_cfg = FwdConfig(tile_mn[0], tile_mn[1], fwd_cfg.mma_pv_is_rs, fwd_cfg.intra_wg_overlap) - tile_m, tile_n = fwd_cfg.m_block_size, fwd_cfg.n_block_size - if mma_pv_is_rs is None: - mma_pv_is_rs = fwd_cfg.mma_pv_is_rs - if intra_wg_overlap is None: - intra_wg_overlap = fwd_cfg.intra_wg_overlap - # TODO: fix GQA + SplitKV + non-varlen if pack_gqa and num_splits != 1 and cu_seqlens_q is None: pack_gqa = False - + if pack_gqa and qv is not None and 128 % qhead_per_kvhead != 0: pack_gqa = False @@ -550,31 +667,38 @@ def _flash_attn_fwd( if max_seqlen_k is None: max_seqlen_k = seqlen_k if cu_seqlens_k is None and seqused_k is None: - min_seqlen_k = seqlen_k - seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead - if arch // 10 in [10, 11]: - q_stage = 2 if seqlen_q_packgqa > tile_m else 1 - else: - q_stage = 1 + min_seqlen_k = seqlen_k - m_block_size_effective = q_stage * tile_m - seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, (window_size_right or max_seqlen_k) + (window_size_left or max_seqlen_k) + 1 + tile_m)) - num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective - total_mblocks = batch_size * num_head_kv * num_m_blocks - num_n_blocks = (seqlen_k_loaded + tile_n 
- 1) // tile_n - num_SMs = 132 if is_fake_mode() else torch.cuda.get_device_properties(device).multi_processor_count - if num_splits < 1: - num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) + fwd_cfg = _get_fwd_config( + arch=arch, + head_dim=head_dim, + head_dim_v=head_dim_v, + causal=causal, + local=local, + window_size_left=window_size_left, + window_size_right=window_size_right, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + qhead_per_kvhead=qhead_per_kvhead, + pack_gqa=pack_gqa, + batch_size=batch_size, + num_head_kv=num_head_kv, + num_splits=num_splits, + device=device, + seqlen_q=seqlen_q, + tile_mn=tile_mn, + block_sparse_tensors=block_sparse_tensors, + mma_pv_is_rs=mma_pv_is_rs, + intra_wg_overlap=intra_wg_overlap, + ) + tile_m, tile_n = fwd_cfg.m_block_size, fwd_cfg.n_block_size + q_stage = fwd_cfg.q_stage + num_splits = fwd_cfg.num_splits + mma_pv_is_rs = fwd_cfg.mma_pv_is_rs + intra_wg_overlap = fwd_cfg.intra_wg_overlap - # SplitKV uses float32 partial output, which doubles the O buffer size - # in shared memory, causing OOM for diff-headdim (192, 128) - if arch // 10 in [10, 11] and head_dim != head_dim_v and num_splits > 1: - if num_n_blocks >= 64 and head_dim_v != 512: - tile_n = 64 - num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n - num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) - else: - num_splits = 1 + seqlen_q_packgqa = max_seqlen_q * (qhead_per_kvhead if pack_gqa else 1) + max_m_blocks_leq_one = seqlen_q_packgqa <= q_stage * tile_m is_split_kv = num_splits > 1 if is_split_kv: @@ -704,6 +828,110 @@ def _flash_attn_fwd( disable_sparse_kv_bitmask = None p = row_max = None + + reuse_scheduler_metadata = scheduler_metadata is not None + is_varlen_q = cu_seqlens_q is not None or seqused_q is not None + if use_dedicated_hd256_kernel: + # The hd=256 2CTA fwd kernel does not support the dynamic-persistent scheduler. 
+ scheduler_metadata = None + reuse_scheduler_metadata = False + if ( + is_split_kv + and is_varlen_q + and scheduler_metadata is None + and not disable_scheduler_metadata + and not use_dedicated_hd256_kernel + ): + scheduler_metadata = _get_scheduler_metadata( + num_batch=batch_size, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + nheads=num_head, + nheads_kv=num_head_kv, + headdim=head_dim, + num_splits=num_splits, + tile_m=tile_m, + tile_n=tile_n, + headdim_v=head_dim_v, + pack_gqa=pack_gqa, + causal=causal, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + seqlen_k_per_split=seqlen_k_per_split, + q_stage=q_stage, + cluster_shape_m=2 if use_2cta_instrs else 1, + ) + + has_scheduler_metadata = scheduler_metadata is not None and not disable_scheduler_metadata + if has_scheduler_metadata: + num_m_blocks = scheduler_metadata.num_m_blocks_ptr + num_splits_dynamic = scheduler_metadata.num_splits_dynamic_ptr + virtual_batch_idx = scheduler_metadata.virtual_batch_idx_ptr + num_nheads_in_l2 = scheduler_metadata.num_nheads_in_l2_ptr + tile_count_semaphore = scheduler_metadata.tile_count_semaphore + assert all( + t is None or t.is_cuda + for t in scheduler_metadata + ), "scheduler metadata must be on CUDA device" + assert all( + t is None or t.shape == (batch_size,) + for t in ( + num_m_blocks, + num_splits_dynamic, + virtual_batch_idx, + num_nheads_in_l2, + ) + ), "these scheduler metadata tensors must have shape (batch_size,)" + if tile_count_semaphore is not None: + assert tile_count_semaphore.shape == (1,), "semaphore must have size 1" + else: + num_m_blocks = None + num_splits_dynamic = None + virtual_batch_idx = None + num_nheads_in_l2 = None + tile_count_semaphore = None + + # use binary batch search in SingleTileVarlenScheduler to avoid + # O(N^2) lookup; observed to be faster only for batch_size > BIN_BATCH_SEARCH_THRESH; this is tunable + cu_total_m_blocks = None + cu_total_splits_m_blocks = None 
+ use_single_tile_varlen_scheduler = use_clc_scheduler or tile_count_semaphore is None + use_cu_hint = ( + is_varlen + and use_single_tile_varlen_scheduler + and batch_size > BIN_BATCH_SEARCH_THRESH + and not use_dedicated_hd256_kernel + ) + if ( + use_cu_hint + and has_scheduler_metadata + and scheduler_metadata.cu_total_m_blocks is not None + ): + cu_total_m_blocks = scheduler_metadata.cu_total_m_blocks + cu_total_splits_m_blocks = scheduler_metadata.cu_total_splits_m_blocks + elif use_cu_hint: + cu_total_m_blocks, cu_total_splits_m_blocks = _compute_tile_cumsum( + num_m_blocks=num_m_blocks, + cu_seqlens=cu_seqlens_q, + seqused=seqused_q, + num_splits_dynamic=num_splits_dynamic, + virtual_batch_idx=virtual_batch_idx, + tile_size=tile_m, + q_stage=q_stage, + qhead_per_kvhead=qhead_per_kvhead, + pack_gqa=pack_gqa, + ) + + is_static_persistent = ( + not causal + and not local + and cu_seqlens_q is None + and seqused_q is None + and not is_split_kv + ) or (max_m_blocks_leq_one and not is_split_kv) + compile_key = ( dtype, head_dim, @@ -742,6 +970,13 @@ def _flash_attn_fwd( mma_pv_is_rs, intra_wg_overlap, use_clc_scheduler, + num_splits_dynamic is not None, + virtual_batch_idx is not None, + num_nheads_in_l2 is not None, + tile_count_semaphore is not None, + cu_total_m_blocks is not None, + cu_total_splits_m_blocks is not None, + is_static_persistent, q is not None, qv is not None, p is not None, @@ -816,6 +1051,25 @@ def _flash_attn_fwd( if aux_tensors is not None: cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] + ( + num_splits_dynamic_tensor, + tile_count_semaphore_tensor, + virtual_batch_idx_tensor, + num_nheads_in_l2_tensor, + cu_total_m_blocks_tensor, + cu_total_splits_m_blocks_tensor, + ) = [ + to_cute_tensor(t, assumed_align=4, leading_dim=0) + for t in ( + num_splits_dynamic, + tile_count_semaphore, + virtual_batch_idx, + num_nheads_in_l2, + cu_total_m_blocks, + cu_total_splits_m_blocks, + ) + ] + qv_tensor = to_cute_tensor(qv) if qv is 
not None else None gather_kv_indices_tensor = to_cute_tensor(gather_kv_indices) if gather_kv_indices is not None else None p_tensor = to_cute_tensor(p) if p is not None else None @@ -917,9 +1171,7 @@ def _flash_attn_fwd( else FlashAttentionForwardSm100 ) - fa_fwd = flash_fwd_obj_cls( - head_dim, - head_dim_v, + fa_fwd_kwargs = dict( qhead_per_kvhead=qhead_per_kvhead, is_causal=causal, is_local=local, @@ -928,11 +1180,7 @@ def _flash_attn_fwd( m_block_size=tile_m, n_block_size=tile_n, q_stage=q_stage, - is_persistent=not causal - and not local - and cu_seqlens_q is None - and seqused_q is None - and not is_split_kv, + is_static_persistent=is_static_persistent, score_mod=score_mod, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, @@ -942,6 +1190,9 @@ def _flash_attn_fwd( use_2cta_instrs=use_2cta_instrs, use_clc_scheduler=use_clc_scheduler, ) + if not use_dedicated_hd256_kernel: + fa_fwd_kwargs["has_tile_count_semaphore"] = tile_count_semaphore is not None + fa_fwd = flash_fwd_obj_cls(head_dim, head_dim_v, **fa_fwd_kwargs) elif arch // 10 == 12: # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity assert not use_block_sparsity, "Block sparsity not supported on SM 12.0" @@ -1016,10 +1267,23 @@ def _flash_attn_fwd( sparse_tensors, cute_aux_tensors, ]) + if arch // 10 in [10, 11] and not use_dedicated_hd256_kernel: + compile_args.extend([ + num_splits_dynamic_tensor, + tile_count_semaphore_tensor, + virtual_batch_idx_tensor, + num_nheads_in_l2_tensor, + cu_total_m_blocks_tensor, + cu_total_splits_m_blocks_tensor, + max_seqlen_q, + ]) + elif arch // 10 in [8, 9, 12]: + compile_args.extend([ + cu_total_m_blocks_tensor, + cu_total_splits_m_blocks_tensor, + ]) compile_args.append(current_stream) - _flash_attn_fwd.compile_cache[compile_key] = cute.compile( - *compile_args, options="--enable-tvm-ffi" - ) + _flash_attn_fwd.compile_cache[compile_key] = cute.compile(*compile_args, options="--enable-tvm-ffi") if not is_fake_mode(): q_call, 
k_call, v_call, qv_call = [ @@ -1091,6 +1355,21 @@ def _flash_attn_fwd( else None, aux_tensors, ]) + if arch // 10 in [10, 11] and not use_dedicated_hd256_kernel: + call_args.extend([ + num_splits_dynamic, + tile_count_semaphore, + virtual_batch_idx, + num_nheads_in_l2, + cu_total_m_blocks, + cu_total_splits_m_blocks, + max_seqlen_q, + ]) + elif arch // 10 in [8, 9, 12]: + call_args.extend([ + cu_total_m_blocks, + cu_total_splits_m_blocks, + ]) _flash_attn_fwd.compile_cache[compile_key](*call_args) if is_split_kv: _flash_attn_fwd_combine( @@ -1100,7 +1379,12 @@ def _flash_attn_fwd( lse.transpose(-1, -2) if lse is not None else None, cu_seqlens_q, seqused_q, + num_splits_dynamic_ptr=num_splits_dynamic if has_scheduler_metadata else None, + virtual_batch_idx=virtual_batch_idx if has_scheduler_metadata else None, ) + if reuse_scheduler_metadata and tile_count_semaphore is not None: + # combine kernel does this for us + tile_count_semaphore.zero_() return out, lse @@ -1153,7 +1437,7 @@ def make_fake_bwd_tensors(dtype, has_gqa, varlen_q, varlen_k): def _compile_bwd_preprocess( dtype, head_dim, head_dim_v, m_block_size, has_cuseqlens_q, has_seqused_q, has_dlse, has_dq_accum, - use_padded_offsets, + use_padded_offsets, has_cu_total_m_blocks, ): """Compile bwd preprocess kernel using cute fake tensors (no real GPU tensors needed).""" mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors( @@ -1165,11 +1449,13 @@ def _compile_bwd_preprocess( mSequsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None mdLSE = fake_tensor(Float32, mLSE.shape, divisibility=1) if has_dlse else None mdQaccum = mdQaccum if has_dq_accum else None + mCuTotalMBlocks = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cu_total_m_blocks else None fa_bwd_pre = FlashAttentionBackwardPreprocess( dtype, head_dim, head_dim_v, m_block_size, use_padded_offsets=use_padded_offsets ) return cute.compile( fa_bwd_pre, mO, mdO, 
mPdPsum, mLSE, mLSElog2, mdQaccum, mCuSeqlensQ, mSequsedQ, mdLSE, + mCuTotalMBlocks, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), options="--enable-tvm-ffi", ) @@ -1180,18 +1466,27 @@ def _bwd_preprocess( cu_seqlens_q, seqused_q, dlse, dtype, head_dim, head_dim_v, m_block_size, use_padded_offsets=True, + cu_total_m_blocks=None, ): """Backward preprocess: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum.""" is_varlen = cu_seqlens_q is not None + batch_size = (cu_seqlens_q.shape[0] - 1) if is_varlen else 0 + if cu_total_m_blocks is None and batch_size > BIN_BATCH_SEARCH_THRESH: + cu_total_m_blocks, _ = _compute_tile_cumsum( + cu_seqlens=cu_seqlens_q, + seqused=seqused_q, + tile_size=m_block_size, + ) compile_key = ( dtype, head_dim, head_dim_v, m_block_size, is_varlen, seqused_q is not None, dlse is not None, dq_accum is not None, - use_padded_offsets, + use_padded_offsets, cu_total_m_blocks is not None, ) if compile_key not in _bwd_preprocess.compile_cache: _bwd_preprocess.compile_cache[compile_key] = _compile_bwd_preprocess(*compile_key) if not is_fake_mode(): _bwd_preprocess.compile_cache[compile_key]( - out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse + out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse, + cu_total_m_blocks, ) @@ -1201,7 +1496,7 @@ def _bwd_preprocess( def _compile_bwd_postprocess( dtype, hdim, block_size, num_threads, atom_layout, swap_ab, has_cuseqlens_q, has_seqused_q, - use_2cta_instrs, cluster_size, arch, + use_2cta_instrs, cluster_size, arch, has_cu_total_m_blocks, ): """Compile bwd postprocess kernel using cute fake tensors.""" mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors( @@ -1211,6 +1506,7 @@ def _compile_bwd_postprocess( batchp1 = cute.sym_int() mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None mSeqUsedQ = fake_tensor(Int32, (batch,), 
divisibility=1) if has_seqused_q else None + mCuTotalMBlocks = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cu_total_m_blocks else None fa_bwd_post = FlashAttentionBackwardPostprocess( dtype, hdim, arch, block_size, num_threads, atom_layout, swap_ab, use_2cta_instrs=use_2cta_instrs, @@ -1218,6 +1514,7 @@ def _compile_bwd_postprocess( ) return cute.compile( fa_bwd_post, mdQaccum, mdQ, Float32(0.0), mCuSeqlensQ, mSeqUsedQ, + mCuTotalMBlocks, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), options="--enable-tvm-ffi", ) @@ -1229,18 +1526,28 @@ def _bwd_postprocess_convert( arch, dtype, hdim, block_size, num_threads, atom_layout, swap_ab, use_2cta_instrs=False, cluster_size=1, + cu_total_m_blocks=None, ): """Backward postprocess: convert float32 accumulator to bf16/fp16 output.""" + is_varlen = cu_seqlens is not None + batch_size = (cu_seqlens.shape[0] - 1) if is_varlen else 0 + if cu_total_m_blocks is None and is_varlen and batch_size > BIN_BATCH_SEARCH_THRESH: + cu_total_m_blocks, _ = _compute_tile_cumsum( + cu_seqlens=cu_seqlens, + seqused=seqused, + tile_size=block_size, + ) compile_key = ( dtype, hdim, block_size, num_threads, atom_layout, swap_ab, cu_seqlens is not None, seqused is not None, - use_2cta_instrs, cluster_size, arch, + use_2cta_instrs, cluster_size, arch, cu_total_m_blocks is not None, ) if compile_key not in _bwd_postprocess_convert.compile_cache: _bwd_postprocess_convert.compile_cache[compile_key] = _compile_bwd_postprocess(*compile_key) if not is_fake_mode(): _bwd_postprocess_convert.compile_cache[compile_key]( accum, output, scale, cu_seqlens, seqused, + cu_total_m_blocks, ) @@ -1350,12 +1657,6 @@ def _flash_attn_bwd( dQ_single_wg = cfg.dQ_single_wg cluster_size = 1 use_2cta_instrs = False - is_varlen = ( - cu_seqlens_q is not None - or cu_seqlens_k is not None - or seqused_q is not None - or seqused_k is not None - ) else: m_block_size = 128 n_block_size = 128 @@ -1376,6 +1677,12 @@ def _flash_attn_bwd( 
use_dedicated_hd256_kernel = arch // 10 in [10, 11] and head_dim == 256 and head_dim_v == 256 use_2cta_instrs = use_2cta_instrs or use_dedicated_hd256_kernel + is_varlen = ( + cu_seqlens_q is not None + or cu_seqlens_k is not None + or seqused_q is not None + or seqused_k is not None + ) q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ maybe_contiguous(t) @@ -1579,6 +1886,22 @@ def _flash_attn_bwd( dK_semaphore = None dV_semaphore = None + # SingleTileVarlenScheduler uses binary search to find batch idx with > 512 batch size + # shared across preprocess, main bwd, and the three postprocess calls. + cu_total_m_blocks_q = None + cu_total_m_blocks_k = None + if is_varlen and batch_size > BIN_BATCH_SEARCH_THRESH and not use_dedicated_hd256_kernel: + cu_total_m_blocks_q, _ = _compute_tile_cumsum( + cu_seqlens=cu_seqlens_q, + seqused=seqused_q, + tile_size=m_block_size, + ) + cu_total_m_blocks_k, _ = _compute_tile_cumsum( + cu_seqlens=cu_seqlens_k, + seqused=seqused_k, + tile_size=n_block_size, + ) + # Preprocess kernel: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum. # For hd=256 dedicated path, dq_accum is None so preprocess only fills dpsum/lse_log2. _bwd_preprocess( @@ -1586,6 +1909,7 @@ def _flash_attn_bwd( cu_seqlens_q, seqused_q, dlse, dtype, head_dim, head_dim_v, m_block_size, use_padded_offsets=use_dedicated_hd256_kernel, + cu_total_m_blocks=cu_total_m_blocks_q, ) # num_threads: SM90 derives from BwdConfig.num_wg, SM120 is set to 128 above, # SM100/SM110 uses default from function signature (384). @@ -1683,6 +2007,7 @@ def _flash_attn_bwd( # Prevent TVM stride poisoning when only one block is present. (seqlen_q_rounded // m_block_size == 1), (seqlen_k_rounded // n_block_size == 1), + cu_total_m_blocks_k is not None, ) else: compile_key = ( @@ -1719,6 +2044,7 @@ def _flash_attn_bwd( # Prevent TVM stride poisoning when only one block is present. 
(seqlen_q_rounded // m_block_size == 1), (seqlen_k_rounded // n_block_size == 1), + cu_total_m_blocks_k is not None, ) if compile_key not in _flash_attn_bwd.compile_cache: @@ -1731,9 +2057,9 @@ def _flash_attn_bwd( dk_accum_tensor, dv_accum_tensor = [ to_cute_tensor(t) for t in (dk_accum, dv_accum) ] - cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [ + cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, cu_total_m_blocks_k_tensor = [ to_cute_tensor(t, assumed_align=4) if t is not None else None - for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) + for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, cu_total_m_blocks_k) ] dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [ utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order()) @@ -1875,6 +2201,7 @@ def _flash_attn_bwd( dV_semaphore_tensor, cute_aux_tensors, sparse_tensors_compile, + cu_total_m_blocks_k_tensor, current_stream, options="--enable-tvm-ffi", ) @@ -1913,6 +2240,7 @@ def _flash_attn_bwd( ) if normalized_block_sparse_tensors is not None else None, + cu_total_m_blocks_k, ) # Postprocess: convert dq_accum from float32 to dq in bf16/fp16 # hd=256 2CTA backward has its own internal postprocess, skip here. 
@@ -1931,6 +2259,7 @@ def _flash_attn_bwd( arch, dtype, head_dim, m_block_size, num_threads_post_dQ, AtomLayoutMdQ, dQ_swapAB, use_2cta_instrs=use_2cta_instrs, cluster_size=1, + cu_total_m_blocks=cu_total_m_blocks_q, ) if dKV_postprocess: @@ -1941,6 +2270,7 @@ def _flash_attn_bwd( arch, dtype, head_dim, n_block_size, num_threads_post_dKV, AtomLayoutNdKV, dKV_swapAB, cluster_size=cluster_size, + cu_total_m_blocks=cu_total_m_blocks_k, ) # Postprocess: convert dv_accum from float32 to dv in bf16/fp16 _bwd_postprocess_convert( @@ -2085,6 +2415,9 @@ def forward( block_sparse_tensors: Optional[list] = None, aux_tensors: Optional[list] = None, return_lse: bool = False, + scheduler_metadata: Optional["SchedulerMetadataTensorsTorch"] = None, + seqlen_k_per_split: Optional[int] = None, + disable_scheduler_metadata: bool = False, ): shared_kv = k is v if shared_kv and v.shape[-1] == 512: @@ -2120,6 +2453,9 @@ def forward( aux_tensors=aux_tensors, return_lse=return_lse, gather_kv_indices=gather_kv_indices, + scheduler_metadata=scheduler_metadata, + seqlen_k_per_split=seqlen_k_per_split, + disable_scheduler_metadata=disable_scheduler_metadata, ) ctx.save_for_backward( q, @@ -2256,6 +2592,9 @@ def flash_attn_varlen_func( block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, aux_tensors: Optional[list] = None, return_lse: bool = False, + scheduler_metadata: Optional[SchedulerMetadataTensorsTorch] = None, + seqlen_k_per_split: Optional[int] = None, + disable_scheduler_metadata: bool = False, ): """ Tensor arguments: @@ -2287,6 +2626,15 @@ def flash_attn_varlen_func( min_seqlen_k: for varlen, specifies the minimum kv sequence length for any batch. Used with gather_kv_indices to determine if we need oob masking. + + scheduler_metadata: optional tensors used by certain tile schedulers, for optimization + and functionality. computed in get_scheduler_metadata. 
+ + seqlen_k_per_split: when using dynamic (per-batch) num_splits, can set a fixed seqlen_k to be + covered per split for bitwise reproducibility. + + disable_scheduler_metadata: if True, ignores scheduler_metadata if it is passed and skips + computing metadata fresh. """ return FlashAttnVarlenFunc.apply( q, @@ -2316,12 +2664,16 @@ def flash_attn_varlen_func( block_sparse_tensors, aux_tensors, return_lse, + scheduler_metadata, + seqlen_k_per_split, + disable_scheduler_metadata, ) def _compile_fwd_combine( dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits, - has_cu_seqlens, has_seqused, has_lse, has_varlen_batch_idx, + has_cu_seqlens, has_seqused, has_lse, has_virtual_batch_idx, + has_num_splits_dynamic, has_semaphore_to_reset, ): """Compile fwd combine kernel using cute fake tensors (no real GPU tensors needed).""" sym = cute.sym_int @@ -2363,14 +2715,14 @@ def _compile_fwd_combine( batchp1 = sym() mCuSeqlens = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cu_seqlens else None mSeqused = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_seqused else None - mNumSplitsDynamic = None # Not parametrized in compile_key - mVarlenBatchIdx = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_varlen_batch_idx else None - mSemaphore = None # Not parametrized in compile_key + mNumSplitsDynamic = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_num_splits_dynamic else None + mVirtualBatchIdx = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_virtual_batch_idx else None + mSemaphore = fake_tensor(Int32, (1,), divisibility=1) if has_semaphore_to_reset else None return cute.compile( fa_combine, mO_partial, mLSE_partial, mO, mLSE, - mCuSeqlens, mSeqused, mNumSplitsDynamic, mVarlenBatchIdx, mSemaphore, + mCuSeqlens, mSeqused, mNumSplitsDynamic, mVirtualBatchIdx, mSemaphore, cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), options="--enable-tvm-ffi", ) @@ -2384,7 +2736,7 @@ def _flash_attn_fwd_combine( 
cu_seqlens: Optional[torch.Tensor] = None, seqused: Optional[torch.Tensor] = None, num_splits_dynamic_ptr: Optional[torch.Tensor] = None, - varlen_batch_idx: Optional[torch.Tensor] = None, + virtual_batch_idx: Optional[torch.Tensor] = None, semaphore_to_reset: Optional[torch.Tensor] = None, ) -> None: """Forward combine kernel for split attention computation. @@ -2453,7 +2805,9 @@ def _flash_attn_fwd_combine( cu_seqlens is not None, seqused is not None, lse is not None, - varlen_batch_idx is not None, + virtual_batch_idx is not None, + num_splits_dynamic_ptr is not None, + semaphore_to_reset is not None, ) if compile_key not in _flash_attn_fwd_combine.compile_cache: _flash_attn_fwd_combine.compile_cache[compile_key] = _compile_fwd_combine( @@ -2462,7 +2816,7 @@ def _flash_attn_fwd_combine( if not is_fake_mode(): _flash_attn_fwd_combine.compile_cache[compile_key]( out_partial, lse_partial, out, lse, - cu_seqlens, seqused, num_splits_dynamic_ptr, varlen_batch_idx, + cu_seqlens, seqused, num_splits_dynamic_ptr, virtual_batch_idx, semaphore_to_reset, ) @@ -2477,7 +2831,7 @@ def flash_attn_combine( out_dtype: Optional[torch.dtype] = None, cu_seqlens: Optional[torch.Tensor] = None, seqused: Optional[torch.Tensor] = None, - varlen_batch_idx: Optional[torch.Tensor] = None, + virtual_batch_idx: Optional[torch.Tensor] = None, return_lse: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Flash Attention combine function for split attention computation. @@ -2497,7 +2851,7 @@ def flash_attn_combine( out_dtype: Optional output dtype. If None, will use fp16/bf16 based on input. cu_seqlens: Cumulative sequence lengths for variable length sequences seqused: Used sequence lengths for each batch - varlen_batch_idx: Optional mapping from virtual batch index to real batch index + virtual_batch_idx: Optional mapping from virtual batch index to real batch index (int32 tensor of shape (batch_size,)). 
Used by persistent tile schedulers that reorder batch processing for load balancing. return_lse: Whether to return the combined LSE tensor. Default is True. @@ -2554,6 +2908,345 @@ def flash_attn_combine( lse, cu_seqlens, seqused, - varlen_batch_idx=varlen_batch_idx, + virtual_batch_idx=virtual_batch_idx, ) return out, lse + + +def _get_scheduler_metadata( + num_batch: int, + max_seqlen_q: int, + max_seqlen_k: int, + nheads: int, + nheads_kv: int, + headdim: int, + num_splits: int, + tile_m: int, + tile_n: int, + headdim_v: Optional[int] = None, + pack_gqa: Optional[bool] = False, + q_stage: int = 1, + cluster_shape_m: int = 1, + causal: bool = False, + enable_pdl: bool = False, + sort: bool = False, + seqlen_k_new: int = 0, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + seqlen_k_per_split: Optional[int] = None, + zfill_padded_output: bool = True, +) -> SchedulerMetadataTensorsTorch: + device = None + for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: + if t is not None: + device = t.device + break + if device is None: + raise ValueError( + "At least one of cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be provided on device" + ) + if headdim_v is None: + headdim_v = headdim + + # Override enable_pdl (not supported yet) + enable_pdl = False + + # Override sort (not supported yet) + sort = False + + if seqlen_k_per_split is not None: + assert seqlen_k_per_split % tile_n == 0, "seqlen per split must be divisible by tile_n" + n_blocks_per_split = seqlen_k_per_split // tile_n + else: + n_blocks_per_split = None + + is_split_kv = num_splits > 1 + needs_prepare_kernel = is_split_kv or causal or sort + + if needs_prepare_kernel: + num_m_blocks = torch.empty(num_batch, dtype=torch.int32, device=device) + num_splits_dynamic = 
torch.empty(num_batch, dtype=torch.int32, device=device) + virtual_batch_idx = ( + torch.empty(num_batch, dtype=torch.int32, device=device) if sort else None + ) + num_nheads_in_l2 = ( + torch.empty(num_batch, dtype=torch.int32, device=device) if causal else None + ) + tile_count_semaphore = torch.empty(1, dtype=torch.int32, device=device) + + num_warps = min((num_batch + 30) // 31, 32) + num_warps = 1 << (num_warps - 1).bit_length() + + cache_key = ( + num_warps, + tile_m, + tile_n, + nheads, + nheads_kv, + headdim, + headdim_v, + causal, + pack_gqa, + enable_pdl, + sort, + cu_seqlens_q is not None, + cu_seqlens_k is not None, + cu_seqlens_k_new is not None, + seqused_q is not None, + seqused_k is not None, + leftpad_k is not None, + num_m_blocks is not None, + num_splits_dynamic is not None, + virtual_batch_idx is not None, + num_nheads_in_l2 is not None, + tile_count_semaphore is not None, + n_blocks_per_split is not None, + zfill_padded_output, + ) + + if cache_key not in _get_scheduler_metadata.compile_cache: + ( + num_m_blocks_cute, + num_splits_dynamic_cute, + virtual_batch_idx_cute, + num_nheads_in_l2_cute, + tile_count_semaphore_cute, + cu_seqlens_q_cute, + cu_seqlens_k_cute, + cu_seqlens_k_new_cute, + seqused_q_cute, + seqused_k_cute, + leftpad_k_cute, + ) = [ + to_cute_tensor(t, assumed_align=4) if t is not None else None + for t in ( + num_m_blocks, + num_splits_dynamic, + virtual_batch_idx, + num_nheads_in_l2, + tile_count_semaphore, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_k_new, + seqused_q, + seqused_k, + leftpad_k, + ) + ] + scheduler = FlashPrepareScheduler( + num_warps, + tile_m, + tile_n, + nheads, + nheads_kv, + headdim, + headdim_v, + causal, + packgqa=pack_gqa, + sort=sort, + zfill_padded_output=zfill_padded_output, + ) + _get_scheduler_metadata.compile_cache[cache_key] = cute.compile( + scheduler, + max_seqlen_q, + max_seqlen_k, + seqlen_k_new, + cu_seqlens_q_cute, + cu_seqlens_k_cute, + cu_seqlens_k_new_cute, + seqused_q_cute, + 
seqused_k_cute, + leftpad_k_cute, + num_batch, + num_splits, + tile_count_semaphore_cute, + num_m_blocks_cute, + num_splits_dynamic_cute, + virtual_batch_idx_cute, + num_nheads_in_l2_cute, + n_blocks_per_split, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + if not is_fake_mode(): + _get_scheduler_metadata.compile_cache[cache_key]( + max_seqlen_q, + max_seqlen_k, + seqlen_k_new, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_k_new, + seqused_q, + seqused_k, + leftpad_k, + num_batch, + num_splits, + tile_count_semaphore, + num_m_blocks, + num_splits_dynamic, + virtual_batch_idx, + num_nheads_in_l2, + n_blocks_per_split, + ) + else: + num_m_blocks = None + num_splits_dynamic = None + virtual_batch_idx = None + num_nheads_in_l2 = None + tile_count_semaphore = None + + if is_fake_mode(): + return + + qhead_per_kvhead = nheads // nheads_kv + cu_total_m_blocks, cu_total_splits_m_blocks = _compute_tile_cumsum( + num_m_blocks=num_m_blocks, + cu_seqlens=cu_seqlens_q, + seqused=seqused_q, + num_splits_dynamic=num_splits_dynamic, + virtual_batch_idx=virtual_batch_idx, + tile_size=tile_m, + q_stage=q_stage, + cluster_shape_m=cluster_shape_m, + qhead_per_kvhead=qhead_per_kvhead, + pack_gqa=bool(pack_gqa), + ) + + return SchedulerMetadataTensorsTorch( + num_m_blocks_ptr=num_m_blocks, + num_splits_dynamic_ptr=num_splits_dynamic, + virtual_batch_idx_ptr=virtual_batch_idx, + num_nheads_in_l2_ptr=num_nheads_in_l2, + tile_count_semaphore=tile_count_semaphore, + cu_total_m_blocks=cu_total_m_blocks, + cu_total_splits_m_blocks=cu_total_splits_m_blocks, + ) + + +_get_scheduler_metadata.compile_cache = get_jit_cache("scheduler_metadata") + + +def get_scheduler_metadata( + max_seqlen_q: int, + max_seqlen_k: int, + nheads: int, + nheads_kv: int, + headdim: int, + num_splits: int, + headdim_v: Optional[int] = None, + pack_gqa: Optional[int] = None, + causal: bool = False, + window_size_left: Optional[int] = None, + window_size_right: 
Optional[int] = None, + seqlen_k_new: int = 0, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + seqlen_k_per_split: Optional[int] = None, + _arch: Optional[int] = None, +) -> SchedulerMetadataTensorsTorch: + """Prepares metadata tensors used by varlen tile schedulers (SingleTileVarlenScheduler + and DynamicPersistentVarlenScheduler) + + Explanation of selected args: + num_splits: maximum number of splits per batch entry that the prepare kernel can emit + seqlen_k_per_split: for bitwise reproducibility between forward and backward, can fix + an exact seqlen_k per split; num_splits is calculated accordingly. + + Returns + SchedulerMetadataTensorsTorch, a named tuple including: + - num_splits_dynamic_ptr: per-batch num_splits + - num_nheads_in_l2_ptr: used for head swizzle to avoid l2 cache thrashing + - tile_count_semaphore: the global semaphore used by DynamicPersistentVarlenScheduler atomic incrementation + - cu_total_m_blocks: cumsum tensor counting total m_blocks, used for binary batch search with large batch_size + - cu_total_splits_m_blocks: complementary cumsum tensor used for binary batch search and to + extract dynamic num splits in the absense of num_splits_dynamic_ptr + """ + arch = _get_device_arch() if _arch is None else _arch + if headdim_v is None: + headdim_v = headdim + + batch_sizes = {} + if cu_seqlens_q is not None: + batch_sizes["cu_seqlens_q"] = cu_seqlens_q.shape[0] - 1 + if cu_seqlens_k is not None: + batch_sizes["cu_seqlens_k"] = cu_seqlens_k.shape[0] - 1 + if seqused_q is not None: + batch_sizes["seqused_q"] = seqused_q.shape[0] + if seqused_k is not None: + batch_sizes["seqused_k"] = seqused_k.shape[0] + assert batch_sizes, ( + "get_scheduler_metadata requires at least one of " + 
"cu_seqlens_q/cu_seqlens_k/seqused_q/seqused_k" + ) + num_batch = next(iter(batch_sizes.values())) + assert all(b == num_batch for b in batch_sizes.values()), ( + f"inconsistent batch size across inputs: {batch_sizes}" + ) + device = next( + t.device for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) if t is not None + ) + + causal, local, window_size_left, window_size_right = _resolve_causal_local_window( + causal, window_size_left, window_size_right + ) + + qhead_per_kvhead = nheads // nheads_kv + if pack_gqa is None: + pack_gqa = qhead_per_kvhead > 1 + + fwd_cfg = _get_fwd_config( + arch=arch, + head_dim=headdim, + head_dim_v=headdim_v, + causal=causal, + local=local, + window_size_left=window_size_left, + window_size_right=window_size_right, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + qhead_per_kvhead=qhead_per_kvhead, + pack_gqa=pack_gqa, + batch_size=num_batch, + num_head_kv=nheads_kv, + num_splits=num_splits, + device=device, + ) + tile_m, tile_n = fwd_cfg.m_block_size, fwd_cfg.n_block_size + q_stage = fwd_cfg.q_stage + num_splits = fwd_cfg.num_splits + + return _get_scheduler_metadata( + num_batch, + max_seqlen_q, + max_seqlen_k, + nheads, + nheads_kv, + headdim, + num_splits, + tile_m, + tile_n, + headdim_v=headdim_v, + pack_gqa=pack_gqa, + q_stage=q_stage, + causal=causal, + enable_pdl=False, # pdl not yet enabled + sort=False, # LPT batch sort not yet enabled + seqlen_k_new=seqlen_k_new, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + cu_seqlens_k_new=cu_seqlens_k_new, + seqused_q=seqused_q, + seqused_k=seqused_k, + leftpad_k=leftpad_k, + seqlen_k_per_split=seqlen_k_per_split, + zfill_padded_output=True, + ) diff --git a/flash_attn/cute/prepare_scheduler.py b/flash_attn/cute/prepare_scheduler.py new file mode 100644 index 00000000000..4532d5e3d70 --- /dev/null +++ b/flash_attn/cute/prepare_scheduler.py @@ -0,0 +1,397 @@ +# A reimplementation of 
https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_prepare_scheduler.cu +# from CUTLASS C++ to Cute-DSL. + +from typing import Tuple, Optional, NamedTuple +import operator +import torch +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr, Float32 +from cutlass.cute import FastDivmodDivisor +import flash_attn.cute.utils as utils + + +class SchedulerMetadataTensorsTorch(NamedTuple): + """Class to store scheduler metadata for varlen""" + + # tensors of shape (batch) + num_m_blocks_ptr: Optional[torch.Tensor] + num_splits_dynamic_ptr: Optional[torch.Tensor] + virtual_batch_idx_ptr: Optional[torch.Tensor] + num_nheads_in_l2_ptr: Optional[torch.Tensor] + # tensor of shape (1) + tile_count_semaphore: Optional[torch.Tensor] + # tensors of shape (batch + 1) + # cu_total_m_blocks[b+1] = sum_{i<=b} num_m_blocks[i] + # cu_total_splits_m_blocks[b+1] = sum_{i<=b} num_m_blocks[i] * num_splits_dynamic[i] + cu_total_m_blocks: Optional[torch.Tensor] = None + cu_total_splits_m_blocks: Optional[torch.Tensor] = None + + +class FlashPrepareScheduler: + def __init__( + self, + num_warps: int, + tile_m: int, + tile_n: int, + nheads: int, + nheads_kv: int, + headdim: int, + headdim_v: Optional[int] = None, + is_causal: bool = False, + packgqa: bool = False, + sort: bool = False, + zfill_padded_output: bool = False, + ): + self.num_warps = num_warps + self.is_causal = is_causal + self.packgqa = packgqa + # TODO: Implement batch sort for LPT. 
+ self.sort = False + self.num_threads_per_warp = 32 + self.tile_m = tile_m + self.tile_n = tile_n + self.d = headdim + self.dv = headdim_v if headdim_v is not None else headdim + self.k_num_batch_per_warp = 31 + self.k_smem_size = 1 + self.zfill_padded_output = zfill_padded_output + + # for pack gqa, query heads per kv head is combined with seqlen_q + self.nheads_computed = nheads if not self.packgqa else nheads_kv + + # L2 cache calculations + self.qhead_per_khead = nheads // nheads_kv + self.size_l2_divisor = ( + 1 + if self.qhead_per_khead == 1 + else ( + 2 + if self.qhead_per_khead <= 2 + else (4 if self.qhead_per_khead <= 4 else (8 if self.qhead_per_khead <= 8 else 16)) + ) + ) + self.size_l2 = (32 * 1024 * 1024) // self.size_l2_divisor + element_size = 2 + self.size_one_kvblock = self.tile_n * (self.d + self.dv) * element_size + self.max_kvblocks_in_l2 = ( + self.size_l2 + self.size_one_kvblock - 1 + ) // self.size_one_kvblock + + @staticmethod + def get_grid_shape(num_batch: int) -> Tuple[int, int, int]: + num_ctas = (num_batch + (31 * 32 - 1)) // (31 * 32) + return (num_ctas, 1, 1) + + @cute.jit + def __call__( + self, + seqlen_q_static: int, + seqlen_k_static: int, + seqlen_k_new_static: int, + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mCuSeqlensKNew: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + mLeftPadK: Optional[cute.Tensor], + num_batch: int, + num_splits_static: int, + tile_count_semaphore: Optional[cute.Tensor], + num_m_blocks_ptr: Optional[cute.Tensor], + num_splits_dynamic_ptr: Optional[cute.Tensor], + virtual_batch_idx_ptr: Optional[cute.Tensor], + num_nheads_in_l2_ptr: Optional[cute.Tensor], + n_blocks_per_split: Optional[int], # overrides heuristic + stream: cuda.CUstream, + ): + tile_m_divmod = FastDivmodDivisor(self.tile_m) + tile_n_divmod = FastDivmodDivisor(self.tile_n) + + @cute.struct + class SharedStorage: + total_blocks_smem: 
cute.struct.MemRange[Int32, self.k_smem_size] + + self.shared_storage = SharedStorage + + block = (32 * self.num_warps, 1, 1) + grid = self.get_grid_shape(num_batch) + + hardware_info = cutlass.utils.HardwareInfo() + num_sm = hardware_info.get_device_multiprocessor_count() + + self.kernel( + seqlen_q_static, + seqlen_k_static, + seqlen_k_new_static, + mCuSeqlensQ, + mCuSeqlensK, + mCuSeqlensKNew, + mSeqUsedQ, + mSeqUsedK, + mLeftPadK, + num_batch, + num_sm, + num_splits_static, + tile_m_divmod, + tile_n_divmod, + tile_count_semaphore, + num_m_blocks_ptr, + num_splits_dynamic_ptr, + virtual_batch_idx_ptr, + num_nheads_in_l2_ptr, + n_blocks_per_split, + ).launch( + grid=grid, + block=block, + stream=stream, + smem=self.shared_storage.size_in_bytes(), + ) + + @cute.kernel + def kernel( + self, + seqlen_q_static: Int32, + seqlen_k_static: Int32, + seqlen_k_new_static: Int32, + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mCuSeqlensKNew: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + mLeftPadK: Optional[cute.Tensor], + num_batch: Int32, + num_sm: Int32, + num_splits_static: Int32, + tile_m_divmod: FastDivmodDivisor, + tile_n_divmod: FastDivmodDivisor, + tile_count_semaphore: Optional[cute.Tensor], + num_m_blocks_ptr: Optional[cute.Tensor], + num_splits_dynamic_ptr: Optional[cute.Tensor], + virtual_batch_idx_ptr: Optional[cute.Tensor], + num_nheads_in_l2_ptr: Optional[cute.Tensor], + n_blocks_per_split: Optional[Int32], + ): + bidx, _, _ = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + grid_dimx, _, _ = cute.arch.grid_dim() + warp_idx = cute.arch.warp_idx() + lane_idx = cute.arch.lane_idx() + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + total_blocks_smem = storage.total_blocks_smem.get_tensor((1,)) + + if tidx == 0: + total_blocks_smem[0] = Int32(0) + cute.arch.sync_threads() + + if const_expr(tile_count_semaphore is not None): + if 
tidx == 0: + tile_count_semaphore[0] = Int32(0) + + batch_cta_idx_offset = bidx * 992 + bidb_start = batch_cta_idx_offset + self.k_num_batch_per_warp * warp_idx + batch_idx = lane_idx + bidb_start + + num_m_blocks, seqlen_q = self.get_num_m_blocks_and_seqlen( + lane_idx, + batch_idx, + mSeqUsedQ, + mCuSeqlensQ, + seqlen_q_static, + tile_m_divmod, + num_batch, + ) + + num_n_blocks = self.get_num_n_blocks( + lane_idx, + batch_idx, + mSeqUsedK, + mCuSeqlensK, + mCuSeqlensKNew, + seqlen_k_static, + seqlen_k_new_static, + mLeftPadK, + tile_n_divmod, + num_batch, + ) + + num_splits_dynamic = Int32(1) + if const_expr(n_blocks_per_split is not None): + # print("n_blocks_per_splits = ", n_blocks_per_split) + num_splits_dynamic = cutlass.min( + cute.ceil_div(num_n_blocks, n_blocks_per_split), num_splits_static + ) + if const_expr(self.zfill_padded_output): + num_splits_dynamic = cutlass.max(num_splits_dynamic, Int32(1)) + if num_splits_dynamic > 0: + num_n_blocks = cute.ceil_div(num_n_blocks, num_splits_dynamic) + else: + if grid_dimx > 1 or num_splits_static == 1: + num_splits_dynamic = Int32(1) + else: + total_blocks = num_m_blocks * num_n_blocks + total_blocks = utils.warp_reduce(total_blocks, operator.add) + if lane_idx == 0: + utils.atomic_add_i32(total_blocks, total_blocks_smem.iterator) + + cute.arch.sync_threads() + + total_blocks = total_blocks_smem[0] + + sm_margin = max(Float32(num_sm) / 128 + 0.001, 1.1) # e.g. 
@cute.jit
def get_num_m_blocks_and_seqlen(
    self,
    lane_idx: Int32,
    batch_idx: Int32,
    mSeqUsedQ: Optional[cute.Tensor],
    mCuSeqlensQ: Optional[cute.Tensor],
    seqlen_q_static: Int32,
    tile_m_divmod: FastDivmodDivisor,
    num_batch: Int32,
):
    """Compute this lane's query length and its number of M (query) tiles.

    The query length is taken from the first available source, in order:
    ``mSeqUsedQ[batch_idx]``, the difference of two adjacent entries of
    ``mCuSeqlensQ``, or the static ``seqlen_q_static``.

    Args:
        lane_idx: Lane index within the warp.
        batch_idx: Batch handled by this lane (the kernel passes
            ``lane_idx + bidb_start``).
        mSeqUsedQ: Optional per-batch query lengths.
        mCuSeqlensQ: Optional cumulative query offsets
            (assumes ``num_batch + 1`` entries -- TODO confirm).
        seqlen_q_static: Fallback length when neither tensor is given.
        tile_m_divmod: Fast divisor for ``self.tile_m``.
        num_batch: Total number of batches.

    Returns:
        ``(num_m_blocks, seqlen)``. ``num_m_blocks`` is forced to 0 for lanes
        with ``batch_idx >= num_batch`` or ``lane_idx >= k_num_batch_per_warp``.
        NOTE(review): ``seqlen`` itself is returned unmasked on the
        cu_seqlens path, so it is only meaningful for lanes that own a
        valid batch.
    """
    seqlen = Int32(0)
    if const_expr(mSeqUsedQ is not None):
        seqlen = mSeqUsedQ[batch_idx] if batch_idx < num_batch else Int32(0)
    elif const_expr(mCuSeqlensQ is not None):
        # Every lane (including lane 31) loads cu_seqlens[batch_idx]; the
        # bound is inclusive (`<=`) because entry `num_batch` is the end
        # offset of the last batch. Lanes past that bound hold 0.
        # Each lane then reads its next-higher lane's value via a shuffle
        # down by 1, so seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx].
        # With k_num_batch_per_warp = 31, lane 31 owns no batch: it only
        # supplies the end offset for lane 30, and its own result is masked
        # out when num_m_blocks is computed below.
        cur_cu_seqlen = Int32(0)
        if batch_idx <= num_batch:
            cur_cu_seqlen = mCuSeqlensQ[batch_idx]
        next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)
        seqlen = next_cu_seqlen - cur_cu_seqlen
    else:
        seqlen = seqlen_q_static

    # With packed GQA the query heads sharing one KV head are folded into the
    # query length, so the tile count is computed on the scaled length.
    seqlen_for_blocks = seqlen
    if const_expr(self.packgqa):
        seqlen_for_blocks = seqlen * self.qhead_per_khead
    # Ceil-divide by tile_m; lanes without a valid batch contribute 0 tiles.
    num_m_blocks = (
        (seqlen_for_blocks + self.tile_m - 1) // tile_m_divmod
        if batch_idx < num_batch and lane_idx < self.k_num_batch_per_warp
        else Int32(0)
    )
    return (num_m_blocks, seqlen)
@cute.jit
def get_num_nheads_in_l2(
    self,
    num_n_blocks: Int32,
):
    """Pick how many heads to schedule together for a given KV tile count.

    Chooses the largest value in {16, 8, 4, 2} such that
    ``num_n_blocks * value <= self.max_kvblocks_in_l2``, falling back to 1.
    When pack-GQA is disabled the result is multiplied by
    ``self.qhead_per_khead``; it is finally clamped to
    ``self.nheads_computed``.

    Args:
        num_n_blocks: Number of KV tiles for the batch (the kernel passes the
            per-split count after dividing by ``num_splits_dynamic``).

    Returns:
        The number of heads, as an ``Int32``.
    """
    max_kvblocks_in_l2 = self.max_kvblocks_in_l2
    qhead_per_khead = self.qhead_per_khead
    # Defined before the dynamic if-chain; presumably required so the traced
    # variable exists on every path -- TODO confirm before removing.
    nheads_in_l2 = Int32(16)
    # Only powers of two are considered, largest first.
    if num_n_blocks * Int32(16) <= max_kvblocks_in_l2:
        nheads_in_l2 = Int32(16)
    elif num_n_blocks * Int32(8) <= max_kvblocks_in_l2:
        nheads_in_l2 = Int32(8)
    elif num_n_blocks * Int32(4) <= max_kvblocks_in_l2:
        nheads_in_l2 = Int32(4)
    elif num_n_blocks * Int32(2) <= max_kvblocks_in_l2:
        nheads_in_l2 = Int32(2)
    else:
        nheads_in_l2 = Int32(1)
    # Without pack-GQA the count is scaled by the number of query heads per
    # KV head (with pack-GQA, nheads_computed already counts KV heads).
    if const_expr(not self.packgqa):
        nheads_in_l2 *= qhead_per_khead
    return cutlass.min(nheads_in_l2, self.nheads_computed)
pipeline.Agent.Thread, num_clc_consumer_threads ) clc_response_ptr = storage.clc_response.data_ptr() - clc = ClcState.create( + clc = SchedulerState.create_clc( hw_scheduler=ClcDynamicPersistentTileScheduler.create( self.tile_sched_params.clc_hw_params(), cute.arch.block_idx(), diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py index 25d6a91de70..5b666083132 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_backward_dqkernel.py @@ -16,7 +16,7 @@ from cutlass.utils import ClcDynamicPersistentTileScheduler from flash_attn.cute.tile_scheduler import ( - ClcState, + SchedulerState, compute_sm100_fmha_grid as compute_grid, compute_sm100_fmha_grid_clc as compute_grid_clc, make_sm100_thread_cooperative_group as make_thread_cooperative_group, @@ -779,7 +779,7 @@ def kernel( pipeline.Agent.Thread, num_clc_consumer_threads ) clc_response_ptr = storage.clc_response.data_ptr() - clc = ClcState.create( + clc = SchedulerState.create_clc( hw_scheduler=ClcDynamicPersistentTileScheduler.create( self.tile_sched_params.clc_hw_params(), cute.arch.block_idx(), diff --git a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py index 379cebc1905..6cafc6da30e 100644 --- a/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py +++ b/flash_attn/cute/sm100_hd256_2cta_fmha_forward.py @@ -15,7 +15,7 @@ from cutlass.utils import ClcDynamicPersistentTileScheduler from flash_attn.cute.tile_scheduler import ( - ClcState, + SchedulerState, compute_sm100_fmha_grid as compute_grid, compute_sm100_fmha_grid_clc as compute_grid_clc, make_sm100_thread_cooperative_group as make_thread_cooperative_group, @@ -46,7 +46,7 @@ def __init__( m_block_size: int = 128, n_block_size: int = 128, q_stage: int = 2, - is_persistent: bool = True, + is_static_persistent: bool = True, score_mod=None, mask_mod=None, has_aux_tensors: bool = 
@dataclass
class SchedulerState(ParamsBase):
    """Pipeline plus producer/consumer pipeline states used by the CLC and
    dynamic persistent tile schedulers.

    Kernels build a concrete state through `create_clc` or
    `create_dynamic_persistent` (yielding `ClcSchedulerState` or
    `DynamicPersistentSchedulerState` respectively) and hand it to a
    scheduler via its `ctx: SchedulerState | None` parameter.
    """

    _pipeline: cutlass.pipeline.PipelineAsync
    _consumer_state: PipelineState
    _producer_state: PipelineState

    @staticmethod
    def create_clc(
        *,
        hw_scheduler: ClcDynamicPersistentTileScheduler,
        pipeline: PipelineClcFetchAsync,
        consumer_state: PipelineState,
        producer_state: PipelineState,
    ) -> "ClcSchedulerState":
        """Build the CLC flavour of the state around a hardware scheduler."""
        return ClcSchedulerState(
            _pipeline=pipeline,
            _consumer_state=consumer_state,
            _producer_state=producer_state,
            _hw_scheduler=hw_scheduler,
        )

    @staticmethod
    def create_dynamic_persistent(
        *,
        work_info: cute.Tensor,
        pipeline: cutlass.pipeline.PipelineAsync,
        consumer_state: PipelineState,
        producer_state: PipelineState,
    ) -> "DynamicPersistentSchedulerState":
        """Build the dynamic persistent flavour of the state around `work_info`."""
        return DynamicPersistentSchedulerState(
            _pipeline=pipeline,
            _consumer_state=consumer_state,
            _producer_state=producer_state,
            _work_info=work_info,
        )

    def consumer_wait(self, *, loc=None, ip=None):
        """Wait on the pipeline at the current consumer state."""
        pipe, state = self._pipeline, self._consumer_state
        pipe.consumer_wait(state, loc=loc, ip=ip)

    def consumer_release(self, *, loc=None, ip=None):
        """Release the current consumer stage, then advance the consumer state."""
        state = self._consumer_state
        self._pipeline.consumer_release(state, loc=loc, ip=ip)
        state.advance(loc=loc, ip=ip)

    def advance_consumer_state(self, *, loc=None, ip=None):
        """Advance the consumer state without touching the pipeline."""
        state = self._consumer_state
        state.advance(loc=loc, ip=ip)

    def producer_tail(self, *, loc=None, ip=None):
        """Run the pipeline's producer tail at the current producer state."""
        pipe, state = self._pipeline, self._producer_state
        pipe.producer_tail(state, loc=loc, ip=ip)
@dataclass
class DynamicPersistentSchedulerState(SchedulerState):
    """Scheduler state for the semaphore-backed dynamic persistent path.

    The scheduler class itself performs the atomicAdd and warp-prefix-sum
    search; this state only exposes the producer side of the pipeline and the
    `write_work_info` hook used to publish the resolved work tile.
    """

    # Destination for the resolved tile; slots 0..3 are written by
    # `write_work_info`.
    _work_info: cute.Tensor

    def producer_acquire(self, *, loc=None, ip=None):
        """Acquire the pipeline stage at the current producer state."""
        pipe, state = self._pipeline, self._producer_state
        pipe.producer_acquire(state, loc=loc, ip=ip)

    def producer_commit(self, *, loc=None, ip=None):
        """Commit the pipeline stage at the current producer state."""
        pipe, state = self._pipeline, self._producer_state
        pipe.producer_commit(state, loc=loc, ip=ip)

    def advance_producer_state(self, *, loc=None, ip=None):
        """Advance the producer state without touching the pipeline."""
        state = self._producer_state
        state.advance(loc=loc, ip=ip)

    def write_work_info(self, block: Int32, head: Int32, batch: Int32, split: Int32):
        """Store the tile as (block, head, batch, split) in slots 0-3."""
        work = self._work_info
        work[0] = block
        work[1] = head
        work[2] = batch
        work[3] = split
For CLC schedulers: pipeline producer_tail. + No-op for static schedulers. For dynamic schedulers: pipeline producer_tail. """ ... @@ -164,6 +211,14 @@ class TileSchedulerArguments(ParamsBase): is_split_kv: cutlass.Constexpr[bool] = False head_swizzle: cutlass.Constexpr[bool] = False use_cluster_idx: cutlass.Constexpr[bool] = False + num_splits_dynamic_ptr: Optional[cute.Tensor] = None + num_m_blocks_ptr: Optional[cute.Tensor] = None + virtual_batch_idx_ptr: Optional[cute.Tensor] = None + num_nheads_in_l2_ptr: Optional[cute.Tensor] = None + cu_total_m_blocks_ptr: Optional[cute.Tensor] = None + cu_total_splits_m_blocks_ptr: Optional[cute.Tensor] = None + tile_count_semaphore: Optional[cute.Pointer] = None + persistent_cta_multiplier: cutlass.Constexpr[int] = 1 class SingleTileScheduler: @@ -177,6 +232,7 @@ class Params(ParamsBase): is_split_kv: cutlass.Constexpr[bool] = False cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) use_cluster_idx: cutlass.Constexpr[bool] = False + num_splits_dynamic_ptr: Optional[cute.Tensor] = None @staticmethod def create( @@ -191,6 +247,7 @@ def create( args.is_split_kv, args.cluster_shape_mn, args.use_cluster_idx, + args.num_splits_dynamic_ptr, ) def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): @@ -215,7 +272,7 @@ def to_underlying_arguments( @staticmethod def create( - params: Params, clc: ClcState | None = None, *, loc=None, ip=None + params: Params, ctx: SchedulerState | None = None, *, loc=None, ip=None ) -> "SingleTileScheduler": if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx): blk_coord = cute.arch.block_idx() @@ -246,13 +303,19 @@ def get_grid_shape( def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: block_idx, head_idx, batch_idx = self._blk_coord + is_valid = self._is_first_block if const_expr(self.params.is_split_kv): head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod) else: split_idx = Int32(0) + # Pack dynamic 
per-batch num_splits into high 16 bits of split_idx + if const_expr(self.params.is_split_kv and self.params.num_splits_dynamic_ptr is not None): + if is_valid: + num_splits = Int32(self.params.num_splits_dynamic_ptr[batch_idx]) + split_idx = split_idx | (num_splits << 16) return WorkTileInfo( (block_idx, head_idx, batch_idx, split_idx), - self._is_first_block, + is_valid, ) def initial_work_tile_info(self, *, loc=None, ip=None): @@ -326,7 +389,7 @@ def to_underlying_arguments( @staticmethod def create( - params: Params, clc: ClcState | None = None, *, loc=None, ip=None + params: Params, ctx: SchedulerState | None = None, *, loc=None, ip=None ) -> "StaticPersistentTileScheduler": if const_expr(cute.size(params.cluster_shape_m) == 1): tile_idx = cute.arch.block_idx()[0] @@ -410,6 +473,7 @@ class Params(ParamsBase): scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC lpt: cutlass.Constexpr[bool] = True use_cluster_idx: cutlass.Constexpr[bool] = True + num_splits_dynamic_ptr: Optional[cute.Tensor] = None @staticmethod @cute.jit @@ -454,6 +518,7 @@ def create( scheduling_mode=scheduling_mode, lpt=args.lpt, use_cluster_idx=args.use_cluster_idx, + num_splits_dynamic_ptr=args.num_splits_dynamic_ptr, ) def __init__( @@ -461,7 +526,7 @@ def __init__( params: Params, tile_idx: Int32, split_idx: Int32, - clc: ClcState | None = None, + ctx: SchedulerState | None = None, *, loc=None, ip=None, @@ -469,7 +534,7 @@ def __init__( self.params = params self._tile_idx = tile_idx self._split_idx = split_idx - self.clc = clc + self._ctx = ctx self._loc = loc self._ip = ip @@ -509,11 +574,11 @@ def clc_problem_shape(params: Params): @staticmethod @cute.jit def create( - params: Params, clc: ClcState | None = None, *, loc=None, ip=None + params: Params, ctx: SchedulerState | None = None, *, loc=None, ip=None ) -> "SingleTileLPTScheduler": if const_expr(params.scheduling_mode == SchedulingMode.CLC): return SingleTileLPTScheduler( - params, cute.arch.block_idx()[0], 
Int32(0), clc, loc=loc, ip=ip + params, cute.arch.block_idx()[0], Int32(0), ctx, loc=loc, ip=ip ) tile_idx, split_idx, _ = cute.arch.block_idx() return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) @@ -554,6 +619,11 @@ def clc_work_to_coords(self, work) -> WorkTileInfo: if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): bidx_in_cluster = cute.arch.block_in_cluster_idx() block_idx = block_idx * self.params.cluster_shape_m + bidx_in_cluster[0] + # Pack dynamic per-batch num_splits into high 16 bits of split_idx + if const_expr(self.params.is_split_kv and self.params.num_splits_dynamic_ptr is not None): + if work.is_valid_tile: + num_splits = Int32(self.params.num_splits_dynamic_ptr[batch_idx]) + split_idx = split_idx | (num_splits << 16) return WorkTileInfo( (Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)), work.is_valid_tile, @@ -562,7 +632,7 @@ def clc_work_to_coords(self, work) -> WorkTileInfo: @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - work = self.clc.get_current_work() + work = self._ctx.get_current_work() self._tile_idx = work.tile_idx[0] return self.clc_work_to_coords(work) # Static path: L2-swizzled coordinate mapping @@ -582,27 +652,33 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: if const_expr(params.lpt): block = params.num_block - 1 - block is_valid = self._tile_idx < params.total_blocks + split_idx = self._split_idx + # Pack dynamic per-batch num_splits into high 16 bits of split_idx + if const_expr(params.is_split_kv and params.num_splits_dynamic_ptr is not None): + if is_valid: + num_splits = Int32(params.num_splits_dynamic_ptr[batch_idx]) + split_idx = split_idx | (num_splits << 16) return WorkTileInfo( - (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid + (Int32(block), Int32(head_idx), Int32(batch_idx), 
Int32(split_idx)), is_valid ) @cute.jit def initial_work_tile_info(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - work = self.clc.initial_work_tile_info() + work = self._ctx.initial_work_tile_info() self._tile_idx = work.tile_idx[0] return self.clc_work_to_coords(work) return self.get_current_work(loc=loc, ip=ip) def prefetch_next_work(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.clc.prefetch_next_work(loc=loc, ip=ip) + self._ctx.prefetch_next_work(loc=loc, ip=ip) def advance_to_next_work(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.clc.consumer_wait(loc=loc, ip=ip) + self._ctx.consumer_wait(loc=loc, ip=ip) work = self.get_current_work() - self.clc.consumer_release(loc=loc, ip=ip) + self._ctx.consumer_release(loc=loc, ip=ip) return work # Single tile scheduler - set to invalid tile_idx to indicate no more work self._tile_idx = self.params.total_blocks @@ -610,13 +686,13 @@ def advance_to_next_work(self, *, loc=None, ip=None): def producer_tail(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.clc.producer_tail(loc=loc, ip=ip) + self._ctx.producer_tail(loc=loc, ip=ip) def __extract_mlir_values__(self): values, self._values_pos = [], [] objs = [self.params, self._tile_idx, self._split_idx] if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - objs += [self.clc] + objs += [self._ctx] for obj in objs: obj_values = cutlass.extract_mlir_values(obj) values += obj_values @@ -627,7 +703,7 @@ def __new_from_mlir_values__(self, values): obj_list = [] objs = [self.params, self._tile_idx, self._split_idx] if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - objs += [self.clc] + objs += [self._ctx] for obj, n_items in zip(objs, self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = 
values[n_items:] @@ -768,23 +844,305 @@ def __new_from_mlir_values__(self, values): return self.__class__(*(tuple(obj_list)), loc=self._loc) +@dataclass +class VarlenDecoder(ParamsBase): + """Per-batch m-block lookup + warp-prefix-sum search-and-decode of the + varlen work tile. Composed into both `SingleTileVarlenScheduler.Params` + and `DynamicPersistentVarlenScheduler.Params`. + + `fold_splits_into_scan` controls whether the prefix-sum scan folds per-batch + `num_splits` into the per-batch tile count (DynamicPersistent) or always + counts only m_blocks (SingleTileVarlen, where splits are dispatched at the + grid level and resolved post-scan). + """ + + num_head: Int32 + num_batch: Int32 + num_splits: Int32 + max_kvblock_in_l2: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + is_split_kv: cutlass.Constexpr[bool] = False + lpt: cutlass.Constexpr[bool] = False + head_swizzle: cutlass.Constexpr[bool] = False + cluster_shape_m: cutlass.Constexpr[int] = 1 + fold_splits_into_scan: cutlass.Constexpr[bool] = False + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + num_m_blocks_ptr: Optional[cute.Tensor] = None + num_splits_dynamic_ptr: Optional[cute.Tensor] = None + virtual_batch_idx_ptr: Optional[cute.Tensor] = None + num_nheads_in_l2_ptr: Optional[cute.Tensor] = None + cu_total_m_blocks_ptr: Optional[cute.Tensor] = None + cu_total_splits_m_blocks_ptr: Optional[cute.Tensor] = None + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, + *, + fold_splits_into_scan: bool, + head_swizzle: bool = False, + cluster_shape_m: int = 1, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> "VarlenDecoder": + size_l2 = 50 * 1024 * 1024 # 50 MB for K & V + # if backward, this is qdo block size + kv_block_size = (args.headdim + args.headdim_v) * 
args.element_size * args.tile_shape_mn[1] + # if backward, add dqaccum block size to calculate swizzle + if head_swizzle: + kv_block_size += args.headdim * 4 * args.tile_shape_mn[1] + max_kvblock_in_l2 = size_l2 // kv_block_size + return VarlenDecoder( + num_head=args.num_head, + num_batch=args.num_batch, + num_splits=args.num_splits, + max_kvblock_in_l2=max_kvblock_in_l2, + tile_shape_mn=args.tile_shape_mn, + qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, + is_split_kv=args.is_split_kv, + lpt=args.lpt, + head_swizzle=head_swizzle, + cluster_shape_m=cluster_shape_m, + fold_splits_into_scan=fold_splits_into_scan, + scheduling_mode=scheduling_mode, + mCuSeqlensQ=args.mCuSeqlensQ, + mSeqUsedQ=args.mSeqUsedQ, + num_m_blocks_ptr=args.num_m_blocks_ptr, + num_splits_dynamic_ptr=args.num_splits_dynamic_ptr, + virtual_batch_idx_ptr=args.virtual_batch_idx_ptr, + num_nheads_in_l2_ptr=args.num_nheads_in_l2_ptr, + cu_total_m_blocks_ptr=args.cu_total_m_blocks_ptr, + cu_total_splits_m_blocks_ptr=args.cu_total_splits_m_blocks_ptr, + ) + + @cute.jit + def _num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: + """Per-batch m-block count""" + batch_idx = lane + bidb_start + is_valid = batch_idx < self.num_batch and lane < cute.arch.WARP_SIZE - 1 + if cutlass.const_expr(self.num_m_blocks_ptr is not None): + num_m_blocks_raw = Int32(0) + if is_valid: + if cutlass.const_expr(self.virtual_batch_idx_ptr is not None): + real_batch_idx = self.virtual_batch_idx_ptr[batch_idx] + else: + real_batch_idx = batch_idx + num_m_blocks_raw = Int32(self.num_m_blocks_ptr[real_batch_idx]) + return cute.ceil_div(num_m_blocks_raw, self.cluster_shape_m) if is_valid else Int32(0) + if cutlass.const_expr(self.virtual_batch_idx_ptr is not None): + seqlen = Int32(0) + if is_valid: + real_batch_idx = self.virtual_batch_idx_ptr[batch_idx] + seqlen = self.mCuSeqlensQ[real_batch_idx + 1] - self.mCuSeqlensQ[real_batch_idx] + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + seqlen *= 
self.qhead_per_kvhead_packgqa + return ( + cute.ceil_div(cute.ceil_div(seqlen, self.tile_shape_mn[0]), self.cluster_shape_m) + if is_valid + else Int32(0) + ) + if cutlass.const_expr(self.mSeqUsedQ is not None): + seqlen = Int32(0) + if batch_idx < self.num_batch: + seqlen = self.mSeqUsedQ[batch_idx] + else: + assert self.mCuSeqlensQ is not None + cur_cu_seqlen = Int32(0) + if batch_idx <= self.num_batch: + cur_cu_seqlen = self.mCuSeqlensQ[batch_idx] + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) + seqlen = next_cu_seqlen - cur_cu_seqlen + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + seqlen *= self.qhead_per_kvhead_packgqa + return ( + cute.ceil_div(cute.ceil_div(seqlen, self.tile_shape_mn[0]), self.cluster_shape_m) + if is_valid + else Int32(0) + ) + + @cute.jit + def _num_splits(self, lane: Int32, bidb_start: Int32) -> Int32: + if cutlass.const_expr(not self.fold_splits_into_scan): + return Int32(1) + batch_idx = lane + bidb_start + is_valid = batch_idx < self.num_batch and lane < cute.arch.WARP_SIZE - 1 + if cutlass.const_expr(not self.is_split_kv): + return Int32(1) + elif cutlass.const_expr(self.num_splits_dynamic_ptr is not None): + num_splits = Int32(0) + if is_valid: + if cutlass.const_expr(self.virtual_batch_idx_ptr is not None): + batch_idx = self.virtual_batch_idx_ptr[batch_idx] + num_splits = self.num_splits_dynamic_ptr[batch_idx] + return num_splits + else: + return Int32(0) if not is_valid else self.num_splits + + @cute.jit + def decode( + self, + next_tile_idx: Int32, + bidb_start: Int32, + group_start_tile: Int32, + ) -> Tuple[Int32, Int32, Int32, Int32, Int32, Int32, Boolean]: + """Search varlen batches via warp-level prefix sums and decode the work tile. 
+ + Returns + - block + - head_idx + - batch_idx + - split_idx + - num_splits + - group_start_tile + - is_valid + """ + if const_expr(self.is_split_kv): + cu_hint_ptr = self.cu_total_splits_m_blocks_ptr + else: + cu_hint_ptr = self.cu_total_m_blocks_ptr + # Both SingleTileVarlen STATIC and CLC; not DynamicPersistent (where + # warp-scan's _bidb_start resumption already amortizes per-call cost). + use_cumsum_hint = const_expr( + cu_hint_ptr is not None + and ( + self.scheduling_mode == SchedulingMode.STATIC + or self.scheduling_mode == SchedulingMode.CLC + ) + ) + if const_expr(use_cumsum_hint): + target = next_tile_idx // self.num_head + lo = utils.get_batch_from_cu_tensor(target, cu_hint_ptr) + group_size = Int32(cute.arch.WARP_SIZE - 1) + bidb_start = (lo // group_size) * group_size + group_start_tile = cu_hint_ptr[bidb_start] * self.num_head + + lane_idx = cute.arch.lane_idx() + num_m_blocks = self._num_m_blocks(lane_idx, bidb_start=bidb_start) + num_splits = self._num_splits(lane_idx, bidb_start=bidb_start) + per_batch = num_m_blocks * num_splits if const_expr(self.is_split_kv) else num_m_blocks + cumulative = utils.warp_prefix_sum(per_batch, lane_idx) + m_blocks_in_group = cute.arch.shuffle_sync(cumulative, cute.arch.WARP_SIZE - 1) + group_end_tile = m_blocks_in_group * self.num_head + group_start_tile + + batch_idx = bidb_start + while group_end_tile <= next_tile_idx: + batch_idx += cute.arch.WARP_SIZE - 1 + if batch_idx >= self.num_batch: + batch_idx = Int32(self.num_batch) + group_end_tile = next_tile_idx + 1 + else: + num_m_blocks = self._num_m_blocks(lane_idx, bidb_start=batch_idx) + num_splits = self._num_splits(lane_idx, bidb_start=batch_idx) + per_batch = ( + num_m_blocks * num_splits if const_expr(self.is_split_kv) else num_m_blocks + ) + cumulative = utils.warp_prefix_sum(per_batch, lane_idx) + m_blocks_in_group = cute.arch.shuffle_sync(cumulative, cute.arch.WARP_SIZE - 1) + group_end_tile += m_blocks_in_group * self.num_head + + is_valid = batch_idx 
< self.num_batch + if is_valid: + group_start_tile = group_end_tile - m_blocks_in_group * self.num_head + batch_idx_in_group = cute.arch.popc( + cute.arch.vote_ballot_sync( + group_start_tile + cumulative * self.num_head <= next_tile_idx + ) + ) + batch_idx += batch_idx_in_group + num_m_blocks_prev_lane = ( + Int32(0) + if batch_idx_in_group == 0 + else cute.arch.shuffle_sync(cumulative, batch_idx_in_group - 1) + ) + group_start_tile += num_m_blocks_prev_lane * self.num_head + num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) + if const_expr(self.is_split_kv): + num_splits = cute.arch.shuffle_sync(num_splits, batch_idx_in_group) + + block, head_idx, split_idx = Int32(0), Int32(0), Int32(0) + if is_valid: + mh_block = next_tile_idx - group_start_tile + + if const_expr(self.lpt or self.head_swizzle): + # This is a version of the SingleTileLPTScheduler, complicated by the fact that + # the seqlen can vary per batch. + # TODO: is there any case where num_m_blocks is 0? 
+ if const_expr(not self.is_split_kv) or num_splits == 1: + if const_expr(self.num_nheads_in_l2_ptr is not None): + if const_expr(self.virtual_batch_idx_ptr is not None): + nheads_in_l2 = Int32( + self.num_nheads_in_l2_ptr[self.virtual_batch_idx_ptr[batch_idx]] + ) + else: + nheads_in_l2 = Int32(self.num_nheads_in_l2_ptr[batch_idx]) + else: + # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here + num_n_blocks = ( + num_m_blocks + * self.tile_shape_mn[0] + * self.cluster_shape_m + // self.qhead_per_kvhead_packgqa + // self.tile_shape_mn[1] + ) + # Seems faster to have nheads_in_l2 be a power of 2 + nheads_in_l2 = ( + 16 + if num_n_blocks * 16 <= self.max_kvblock_in_l2 + else ( + 8 + if num_n_blocks * 8 <= self.max_kvblock_in_l2 + else ( + 4 + if num_n_blocks * 4 <= self.max_kvblock_in_l2 + else (2 if num_n_blocks * 2 <= self.max_kvblock_in_l2 else 1) + ) + ) + ) + nheads_in_l2 = min(nheads_in_l2, self.num_head) + mh_in_l2 = nheads_in_l2 * num_m_blocks + section_idx = mh_block // mh_in_l2 + l2_mod = mh_block - section_idx * mh_in_l2 + nheads_in_this_section = ( + nheads_in_l2 + if nheads_in_l2 * (section_idx + 1) <= self.num_head + else self.num_head - section_idx * nheads_in_l2 + ) + block = l2_mod // nheads_in_this_section + head_idx_residual = l2_mod - block * nheads_in_this_section + head_idx = section_idx * nheads_in_l2 + head_idx_residual + else: + head_split_idx = mh_block // num_m_blocks + block = mh_block - head_split_idx * num_m_blocks + head_idx = head_split_idx // num_splits + split_idx = head_split_idx - head_idx * num_splits + if const_expr(self.lpt): + block = num_m_blocks - 1 - block + else: + head_split_idx = mh_block // num_m_blocks + block = mh_block - head_split_idx * num_m_blocks + if const_expr(self.is_split_kv): + head_idx = head_split_idx // num_splits + split_idx = head_split_idx - head_idx * num_splits + else: + head_idx = head_split_idx + + if const_expr(self.cluster_shape_m > 1): + bidx_in_cluster = 
cute.arch.block_in_cluster_idx() + block = block * self.cluster_shape_m + bidx_in_cluster[0] + + return block, head_idx, batch_idx, split_idx, num_splits, group_start_tile, is_valid + + class SingleTileVarlenScheduler: @dataclass class Params(ParamsBase): - num_head: Int32 - num_batch: Int32 total_q: Int32 - num_splits: Int32 - max_kvblock_in_l2: Int32 - tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] - mCuSeqlensQ: Optional[cute.Tensor] = None - mSeqUsedQ: Optional[cute.Tensor] = None - qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 - lpt: cutlass.Constexpr[bool] = False - is_split_kv: cutlass.Constexpr[bool] = False - head_swizzle: cutlass.Constexpr[bool] = False - cluster_shape_m: cutlass.Constexpr[int] = 1 - scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + scheduling_mode: cutlass.Constexpr[SchedulingMode] + decoder: VarlenDecoder @staticmethod @cute.jit @@ -798,15 +1156,6 @@ def create( assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( f"Only STATIC and CLC are supported, got {scheduling_mode!r}" ) - size_l2 = 50 * 1024 * 1024 # 50 MB for K & V - # if backward, this is qdo block size - kv_block_size = ( - (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] - ) - # if backward, add dqaccum block size to calculate swizzle - if args.head_swizzle: - kv_block_size += args.headdim * 4 * args.tile_shape_mn[1] - max_kvblock_in_l2 = size_l2 // kv_block_size assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" ) @@ -816,21 +1165,19 @@ def create( assert scheduling_mode != SchedulingMode.CLC or args.cluster_shape_mn[0] == 1, ( "Varlen CLC currently requires cluster_shape_mn[0] == 1" ) - return SingleTileVarlenScheduler.Params( - num_head=args.num_head, - num_batch=args.num_batch, - total_q=args.total_q, - num_splits=args.num_splits, - max_kvblock_in_l2=max_kvblock_in_l2, - tile_shape_mn=args.tile_shape_mn, - 
mCuSeqlensQ=args.mCuSeqlensQ, - mSeqUsedQ=args.mSeqUsedQ, - qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, - lpt=args.lpt, - is_split_kv=args.is_split_kv, + decoder = VarlenDecoder.create( + args, + fold_splits_into_scan=False, head_swizzle=args.head_swizzle, cluster_shape_m=args.cluster_shape_mn[0], scheduling_mode=scheduling_mode, + loc=loc, + ip=ip, + ) + return SingleTileVarlenScheduler.Params( + total_q=args.total_q, + scheduling_mode=scheduling_mode, + decoder=decoder, ) def __init__( @@ -838,7 +1185,7 @@ def __init__( params: Params, tile_idx: Int32, split_idx: Int32, - clc: ClcState | None = None, + ctx: SchedulerState | None = None, *, loc=None, ip=None, @@ -847,7 +1194,7 @@ def __init__( self._tile_idx = tile_idx self._split_idx = split_idx self._is_first_block = True - self.clc = clc + self._ctx = ctx self._loc = loc self._ip = ip @@ -874,18 +1221,18 @@ def clc_problem_shape(params: Params): @staticmethod @cute.jit def create( - params: Params, clc: ClcState | None = None, *, loc=None, ip=None + params: Params, ctx: SchedulerState | None = None, *, loc=None, ip=None ) -> "SingleTileVarlenScheduler": if const_expr(params.scheduling_mode == SchedulingMode.CLC): block_idx = cute.arch.block_idx() split_idx = Int32(0) - if const_expr(params.is_split_kv): + if const_expr(params.decoder.is_split_kv): split_idx = block_idx[1] return SingleTileVarlenScheduler( params, block_idx[0], split_idx, - clc, + ctx, loc=loc, ip=ip, ) @@ -900,142 +1247,40 @@ def get_grid_shape( loc=None, ip=None, ) -> Tuple[Int32, Int32, Int32]: + d = params.decoder total_blocks_max = ( - params.total_q - + params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1) - ) // params.tile_shape_mn[0] + params.total_q + d.num_batch * (d.cluster_shape_m * d.tile_shape_mn[0] - 1) + ) // d.tile_shape_mn[0] # Round down to nearest multiple of cluster since odd excess is always padding. 
@cute.jit
def _decode_work_tile(self) -> WorkTileInfo:
    """Map self._tile_idx to (block, head, batch, split) via warp-level prefix sums.

    The flat tile index is first divided by the decoder's ``cluster_shape_m``
    and then handed to ``VarlenDecoder.decode``. The scan always starts from
    batch 0 / tile 0 (both resume arguments are ``Int32(0)``); the split,
    num_splits and group_start_tile values returned by ``decode`` are
    discarded here.

    Returns:
        A ``WorkTileInfo`` whose coordinate is ``(block, head_idx, batch_idx,
        split_idx)``. The tile is only reported valid when ``decode`` found it
        in range *and* ``self._is_first_block`` is still set.
    """
    d = self.params.decoder
    next_tile_idx = self._tile_idx // d.cluster_shape_m
    block, head_idx, batch_idx, _, _, _, is_valid = d.decode(next_tile_idx, Int32(0), Int32(0))
    # A tile is valid only while this scheduler instance is on its first block.
    is_valid = is_valid and self._is_first_block
    # The split index is taken from the scheduler's own state, not from decode().
    split_idx = self._split_idx if const_expr(d.is_split_kv) else Int32(0)
    # Translate the decoded batch index through the virtual-batch table when
    # one is provided. Guarded by is_valid so the table is not read for an
    # out-of-range tile.
    if const_expr(d.virtual_batch_idx_ptr is not None):
        if is_valid:
            batch_idx = d.virtual_batch_idx_ptr[batch_idx]
    # Pack dynamic per-batch num_splits into high 16 bits of split_idx.
    # Note the lookup uses batch_idx *after* the virtual-batch translation above.
    if const_expr(d.is_split_kv and d.num_splits_dynamic_ptr is not None):
        if is_valid:
            num_splits = Int32(d.num_splits_dynamic_ptr[batch_idx])
            split_idx = split_idx | (num_splits << 16)
    return WorkTileInfo(
        (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(split_idx)),
        is_valid,
    )
is the first one that does not have ending tile position - # that is greater than or equal to tile index. - batch_idx_in_group = cute.arch.popc( - cute.arch.vote_ballot_sync( - group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx - ) - ) - batch_idx += batch_idx_in_group - num_m_blocks_prev_lane = ( - 0 - if batch_idx_in_group == 0 - else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1) - ) - num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) - mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head - if cutlass.const_expr(params.lpt or params.head_swizzle): - # This is a version of the SingleTileLPTScheduler, complicated by the fact that - # the seqlen can vary per batch. - # TODO: is there any case where num_m_blocks is 0? - # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here - num_n_blocks = ( - num_m_blocks - * params.tile_shape_mn[0] - * params.cluster_shape_m - // params.qhead_per_kvhead_packgqa - // params.tile_shape_mn[1] - ) - # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head) - # Seems faster to have this be a power of 2 - nheads_in_l2 = ( - 16 - if num_n_blocks * 16 <= params.max_kvblock_in_l2 - else ( - 8 - if num_n_blocks * 8 <= params.max_kvblock_in_l2 - else ( - 4 - if num_n_blocks * 4 <= params.max_kvblock_in_l2 - else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1) - ) - ) - ) - nheads_in_l2 = min(nheads_in_l2, params.num_head) - mh_in_l2 = nheads_in_l2 * num_m_blocks - section_idx = mh_block // mh_in_l2 - l2_mod = mh_block - section_idx * mh_in_l2 - # Deal with tail section - nheads_in_this_section = ( - nheads_in_l2 - if nheads_in_l2 * (section_idx + 1) <= params.num_head - else params.num_head - section_idx * nheads_in_l2 - ) - block = l2_mod // nheads_in_this_section - head_idx_residual = l2_mod - block * nheads_in_this_section - head_idx = section_idx * 
nheads_in_l2 + head_idx_residual - if cutlass.const_expr(params.lpt): - block = num_m_blocks - 1 - block - else: - head_idx = mh_block // num_m_blocks - block = mh_block - head_idx * num_m_blocks - is_valid = self._is_first_block and batch_idx < params.num_batch - if cutlass.const_expr(params.cluster_shape_m > 1): - bidx_in_cluster = cute.arch.block_in_cluster_idx() - block = block * params.cluster_shape_m + bidx_in_cluster[0] - # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) - split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0) - return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid) - @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - clc_work = self.clc.get_current_work() - # Default to grid_dim (one past last valid flat index) so _varlen_coord_map + clc_work = self._ctx.get_current_work() + # Default to grid_dim (one past last valid flat index) so _decode_work_tile # returns is_valid=False when CLC is exhausted. CLC tile_idx is garbage when # invalid, so we can't trust it. Local-then-assign avoids CuTe DSL structural # mismatch on self inside the runtime if. 
@@ -1043,49 +1288,49 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: new_split_idx = Int32(0) if clc_work.is_valid_tile: new_tile_idx = clc_work.tile_idx[0] - if const_expr(self.params.is_split_kv): + if const_expr(self.params.decoder.is_split_kv): new_split_idx = clc_work.tile_idx[1] self._tile_idx = new_tile_idx self._split_idx = new_split_idx - return self._varlen_coord_map() + return self._decode_work_tile() @cute.jit def initial_work_tile_info(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - clc_work = self.clc.initial_work_tile_info() + clc_work = self._ctx.initial_work_tile_info() # See get_current_work for why grid_dim and local-then-assign. new_tile_idx = cute.arch.grid_dim()[0] new_split_idx = Int32(0) if clc_work.is_valid_tile: new_tile_idx = clc_work.tile_idx[0] - if const_expr(self.params.is_split_kv): + if const_expr(self.params.decoder.is_split_kv): new_split_idx = clc_work.tile_idx[1] self._tile_idx = new_tile_idx self._split_idx = new_split_idx - return self._varlen_coord_map() + return self._decode_work_tile() def prefetch_next_work(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.clc.prefetch_next_work(loc=loc, ip=ip) + self._ctx.prefetch_next_work(loc=loc, ip=ip) def advance_to_next_work(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.clc.consumer_wait(loc=loc, ip=ip) + self._ctx.consumer_wait(loc=loc, ip=ip) work = self.get_current_work() - self.clc.consumer_release(loc=loc, ip=ip) + self._ctx.consumer_release(loc=loc, ip=ip) return work self._is_first_block = False return self.get_current_work() def producer_tail(self, *, loc=None, ip=None): if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): - self.clc.producer_tail(loc=loc, ip=ip) + self._ctx.producer_tail(loc=loc, ip=ip) def __extract_mlir_values__(self): values, self._values_pos = [], [] objs = 
class DynamicPersistentVarlenScheduler:
    """Persistent tile scheduler for variable-length batches with dynamic tile hand-out.

    Each CTA takes its own x block index as its first tile
    (``initial_work_tile_info``). Every later tile index is
    ``grid_dim.x + utils.atomic_add_i32(1, tile_count_semaphore)``, computed by
    lane 0 and broadcast to the warp (``prefetch_next_work``). Flat tile
    indices are turned into (block, head, batch, split) coordinates by
    ``VarlenDecoder.decode``; the scheduler keeps ``_bidb_start`` /
    ``_group_start_tile`` so that each decode resumes from the previous
    result instead of rescanning from batch 0.

    Decoded work is passed from the producer side (``prefetch_next_work``) to
    the consumer side (``advance_to_next_work``) through the ``SchedulerState``
    object ``ctx``.
    """

    @dataclass
    class Params(ParamsBase):
        # Total number of query rows across the batch; used only for grid sizing.
        total_q: Int32
        # Shared tile-index -> coordinate decoder (built with fold_splits_into_scan=True).
        decoder: VarlenDecoder
        # Counter that is atomically incremented to obtain the next tile index.
        tile_count_semaphore: Optional[cute.Pointer] = None
        # Multiplies the SM count when sizing the persistent grid.
        persistent_cta_multiplier: cutlass.Constexpr[int] = 1

        @staticmethod
        @cute.jit
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "DynamicPersistentVarlenScheduler.Params":
            """Build ``Params`` from scheduler arguments.

            Requires at least one of ``args.mCuSeqlensQ`` / ``args.mSeqUsedQ``.
            The decoder is always created for ``SchedulingMode.DYNAMIC`` with
            splits folded into the scan.
            """
            assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (
                "At least one of mCuSeqlensQ or mSeqUsedQ must be provided"
            )
            decoder = VarlenDecoder.create(
                args,
                fold_splits_into_scan=True,
                scheduling_mode=SchedulingMode.DYNAMIC,
                loc=loc,
                ip=ip,
            )
            return DynamicPersistentVarlenScheduler.Params(
                total_q=args.total_q,
                decoder=decoder,
                tile_count_semaphore=args.tile_count_semaphore,
                persistent_cta_multiplier=args.persistent_cta_multiplier,
            )

    def __init__(
        self,
        params: Params,
        ctx: SchedulerState,
        bidb_start: Int32,
        group_start_tile: Int32,
        *,
        loc=None,
        ip=None,
    ):
        """Store parameters, the producer/consumer state and the scan resume point.

        ``bidb_start`` and ``group_start_tile`` are the resume arguments that
        will be passed to ``VarlenDecoder.decode`` on the next prefetch.
        """
        self.params = params
        self._ctx = ctx
        self._bidb_start = bidb_start
        self._group_start_tile = group_start_tile
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(
        args: TileSchedulerArguments,
        *,
        scheduling_mode: SchedulingMode = SchedulingMode.DYNAMIC,
        loc=None,
        ip=None,
    ) -> Params:
        """Convert ``TileSchedulerArguments`` to ``Params``; only DYNAMIC mode is accepted."""
        assert scheduling_mode == SchedulingMode.DYNAMIC, (
            f"DynamicPersistentVarlenScheduler only supports DYNAMIC, got {scheduling_mode!r}"
        )
        return DynamicPersistentVarlenScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    @cute.jit
    def create(
        params: Params,
        ctx: SchedulerState,
        *,
        loc=None,
        ip=None,
    ) -> "DynamicPersistentVarlenScheduler":
        """Create a scheduler whose scan state starts at batch 0 / tile 0."""
        return DynamicPersistentVarlenScheduler(params, ctx, Int32(0), Int32(0), loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        """Return the launch grid ``(num_ctas, 1, 1)``.

        The tile count is an upper bound: every batch may waste up to
        ``tile_m - 1`` rows of padding, so ``num_batch * (tile_m - 1)`` is
        added to ``total_q`` before dividing by ``tile_m``. The grid is the
        smaller of that bound (times heads and splits) and
        ``SM count * persistent_cta_multiplier``.
        """
        d = params.decoder
        total_blocks_max = (
            params.total_q + d.num_batch * (d.tile_shape_mn[0] - 1)
        ) // d.tile_shape_mn[0]
        total_blocks = total_blocks_max * d.num_head * d.num_splits
        hardware_info = HardwareInfo()
        sm_count = (
            hardware_info.get_device_multiprocessor_count() * params.persistent_cta_multiplier
        )
        return (cutlass.min(sm_count, total_blocks), Int32(1), Int32(1))

    @cute.jit
    def get_current_work(
        self,
        next_tile_idx: Int32,
        bidb_start: Int32,
        group_start_tile: Int32,
        *,
        loc=None,
        ip=None,
    ) -> WorkTileInfo:
        """Decode ``next_tile_idx`` into a work tile, resuming the scan from the given state.

        NOTE(review): despite the ``WorkTileInfo`` annotation this returns a
        2-tuple ``(WorkTileInfo, group_start_tile)``; both callers in this
        class unpack it that way.

        When split-KV with a dynamic num_splits table is enabled, the decoded
        ``num_splits`` is packed into bits 16 and above of ``split_idx``.
        When a virtual-batch table is present, the batch index stored in the
        returned ``WorkTileInfo`` is the table lookup result, not the index
        produced by the scan.
        """
        d = self.params.decoder
        block, head_idx, batch_idx, split_idx, num_splits, group_start_tile, is_valid = d.decode(
            next_tile_idx, bidb_start, group_start_tile
        )
        if const_expr(d.is_split_kv and d.num_splits_dynamic_ptr is not None):
            if is_valid:
                split_idx = split_idx | (num_splits << 16)
        if const_expr(d.virtual_batch_idx_ptr is not None):
            if is_valid:
                batch_idx = d.virtual_batch_idx_ptr[batch_idx]
        return (
            WorkTileInfo(
                (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(split_idx)),
                is_valid,
            ),
            group_start_tile,
        )

    @cute.jit
    def prefetch_next_work(self, *, loc=None, ip=None):
        """Producer side: claim the next tile index, decode it and publish it via ``ctx``."""
        ctx = self._ctx
        next_tile_idx = Int32(0)
        # Only lane 0 touches the semaphore; the result is broadcast to the warp below.
        # The grid_dim offset skips the indices consumed as each CTA's initial tile.
        if cute.arch.lane_idx() == 0:
            next_tile_idx = cute.arch.grid_dim()[0] + utils.atomic_add_i32(
                1,
                self.params.tile_count_semaphore,
            )
        next_tile_idx = cute.arch.shuffle_sync(next_tile_idx, 0)
        work_info, new_group_start_tile = self.get_current_work(
            next_tile_idx, self._bidb_start, self._group_start_tile
        )
        # Advance scan state so the next prefetch resumes from this tile's batch
        # group instead of restarting at batch 0.
        # NOTE(review): tile_idx[2] has already been remapped through
        # virtual_batch_idx_ptr inside get_current_work, whereas
        # new_group_start_tile comes straight from decode(). If the mapping is
        # not the identity these two may no longer describe the same scan
        # position -- confirm against VarlenDecoder.decode.
        self._bidb_start = Int32(work_info.tile_idx[2])
        self._group_start_tile = new_group_start_tile
        ctx.producer_acquire()
        with cute.arch.elect_one():
            block, head_idx, batch_idx, split_idx = work_info.tile_idx
            ctx.write_work_info(block, head_idx, batch_idx, split_idx)
        ctx.producer_commit()
        ctx.advance_producer_state()

    @cute.jit
    def advance_to_next_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Consumer side: wait for published work, read it from ``ctx`` and release the slot.

        Validity is not transmitted explicitly; it is re-derived as
        ``batch_idx < num_batch``.
        """
        ctx = self._ctx
        ctx.consumer_wait()
        block = ctx._work_info[0]
        head_idx = ctx._work_info[1]
        batch_idx = ctx._work_info[2]
        split_idx = ctx._work_info[3]
        is_valid = batch_idx < self.params.decoder.num_batch
        work_info = WorkTileInfo((block, head_idx, batch_idx, split_idx), is_valid)
        ctx.consumer_release()
        return work_info

    @cute.jit
    def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo:
        """Decode this CTA's first tile (its x block index) and seed the scan resume state."""
        cta_tile_idx, _, _ = cute.arch.block_idx()
        work_info, new_group_start_tile = self.get_current_work(cta_tile_idx, Int32(0), Int32(0))
        self._bidb_start = Int32(work_info.tile_idx[2])
        self._group_start_tile = new_group_start_tile
        return work_info

    def producer_tail(self, *, loc=None, ip=None):
        """Forward the producer-tail call to ``ctx``."""
        self._ctx.producer_tail(loc=loc, ip=ip)

    def __extract_mlir_values__(self):
        """Flatten params, ctx and the scan state into one list of MLIR values.

        The number of values contributed by each object is recorded in
        ``self._values_pos`` so ``__new_from_mlir_values__`` can split the
        list again.
        """
        values, self._values_pos = [], []
        for obj in [self.params, self._ctx, self._bidb_start, self._group_start_tile]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        """Rebuild a scheduler from the flat list produced by ``__extract_mlir_values__``.

        Objects are reconstructed in the same order they were flattened and
        passed positionally to the constructor; only ``loc`` is carried over.
        """
        obj_list = []
        for obj, n_items in zip(
            [self.params, self._ctx, self._bidb_start, self._group_start_tile],
            self._values_pos,
        ):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return self.__class__(*obj_list, loc=self._loc)
is_fake_mode(): + continue # TODO(wangsiyu): SM100 head_dim=256 2CTA kernel does not support pack_gqa yet. # pack_gqa=None means auto-enable for GQA/MQA (qhead_per_kvhead > 1) # Remove this when support is added. @@ -818,56 +824,76 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): continue if pack_gqa is None and mha_type != "mha": continue - out_unpad, lse = flash_attn_varlen_func( - q_unpad if unpad_q else q, - k_unpad if unpad_kv else k, - v_unpad if unpad_kv else v, - cu_seqlens_q=cu_seqlens_q if unpad_q else None, - cu_seqlens_k=cu_seqlens_k if unpad_kv else None, - max_seqlen_q=seqlen_q, - max_seqlen_k=seqlen_k, - seqused_q=seqused_q if not unpad_q else None, - seqused_k=seqused_k if not unpad_kv else None, - causal=causal, - # qv=qv_unpad, - # q_descale=q_descale, - # k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - # attention_chunk=attention_chunk, - learnable_sink=learnable_sink, - softcap=softcap, - num_splits=num_splits, - pack_gqa=pack_gqa, - deterministic=deterministic, - ) - out = output_pad_fn(out_unpad) if unpad_q else out_unpad - if is_fake_mode(): - # no more flash_attn cutedsl calls for the rest of the loop - # skip data-dependent postprocessing - continue - if query_unused_mask is not None: - out.masked_fill_(q_zero_masking, 0.0) - # When unpad_q=False with seqused_q, the kernel doesn't write positions - # beyond seqused_q, so those contain uninitialized values. Mask them out - # before comparing. 
- out_cmp, out_ref_cmp, out_pt_cmp = out, out_ref, out_pt - if not unpad_q and seqused_q is not None: - seqused_mask = torch.arange(seqlen_q, device=device)[None, :] < seqused_q[:, None] - seqused_mask = rearrange(seqused_mask, "b s -> b s 1 1") - out_cmp = out.clone().masked_fill_(~seqused_mask, 0.0) - out_ref_cmp = out_ref.clone().masked_fill_(~seqused_mask, 0.0) - out_pt_cmp = out_pt.clone().masked_fill_(~seqused_mask, 0.0) - print(f"Output max diff: {(out_cmp - out_ref_cmp).abs().max().item()}") - print(f"Output mean diff: {(out_cmp - out_ref_cmp).abs().mean().item()}") - # if not causal: - # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - # breakpoint() + if precompute_metadata: + scheduler_metadata = get_scheduler_metadata( + max_seqlen_q=seqlen_q, + max_seqlen_k=seqlen_k, + nheads=nheads, + nheads_kv=nheads_kv, + headdim=d, + headdim_v=dv, + num_splits=num_splits, + causal=causal, + cu_seqlens_q=cu_seqlens_q if unpad_q else None, + cu_seqlens_k=cu_seqlens_k if unpad_kv else None, + seqused_q=seqused_q if not unpad_q else None, + seqused_k=seqused_k if not unpad_kv else None, + ) + else: + scheduler_metadata = None + # Repeat to exercise metadata reuse across calls. 
+ for _ in range(1 if not precompute_metadata else 2): + out_unpad, lse = flash_attn_varlen_func( + q_unpad if unpad_q else q, + k_unpad if unpad_kv else k, + v_unpad if unpad_kv else v, + cu_seqlens_q=cu_seqlens_q if unpad_q else None, + cu_seqlens_k=cu_seqlens_k if unpad_kv else None, + max_seqlen_q=seqlen_q, + max_seqlen_k=seqlen_k, + seqused_q=seqused_q if not unpad_q else None, + seqused_k=seqused_k if not unpad_kv else None, + causal=causal, + # qv=qv_unpad, + # q_descale=q_descale, + # k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + # attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + pack_gqa=pack_gqa, + deterministic=deterministic, + ) + out = output_pad_fn(out_unpad) if unpad_q else out_unpad + if is_fake_mode(): + # no more flash_attn cutedsl calls for the rest of the loop + # skip data-dependent postprocessing + continue + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + # When unpad_q=False with seqused_q, the kernel doesn't write positions + # beyond seqused_q, so those contain uninitialized values. Mask them out + # before comparing. 
@pytest.mark.parametrize(
    "cumsum_mode", ["jit_cumsum", "metadata_cumsum_only", "metadata_full"]
)
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("qhead_per_kvhead", [1, 4])
@retry_on_oom
@maybe_fake_tensor_mode(USE_FAKE_TENSOR)
def test_flash_attn_varlen_cumsum_metadata_paths(causal, cumsum_mode, qhead_per_kvhead):
    """Exercise the cu_total_m_blocks fast paths end-to-end.

    - "jit_cumsum": batch_size > 512 varlen, no scheduler_metadata. Triggers
      the just-in-time host cumsum in _flash_attn_fwd and the hoisted Q/K
      cumsum in _flash_attn_bwd.
    - "metadata_cumsum_only": scheduler_metadata from get_scheduler_metadata
      with num_splits=1 — skips the FlashPrepareScheduler kernel and returns
      only cu_total_m_blocks. Fwd reads it from scheduler_metadata.
    - "metadata_full": scheduler_metadata with num_splits>1 (SM100 only).
      Runs the full prepare kernel and populates both cu_total tensors.

    The forward output (and, except for "metadata_full", the gradients) are
    compared against ``attention_ref``: the error must be at most twice the
    error of the non-upcast PyTorch reference plus a small absolute tolerance.
    """
    if cumsum_mode == "metadata_full":
        # Report the real reason for skipping: the two conditions are unrelated.
        if IS_SM90:
            pytest.skip("split-kv not yet implemented on SM90")
        if DISABLE_SPLIT:
            pytest.skip("split-kv tests are disabled (DISABLE_SPLIT)")
    device = "cuda"
    torch.manual_seed(0)
    random.seed(0)

    if cumsum_mode == "jit_cumsum":
        batch_size = 600
    else:
        batch_size = 64
    seqlen_q = seqlen_k = 64
    nheads_kv = 4
    nheads = nheads_kv * qhead_per_kvhead
    d = dv = 128
    dtype = torch.bfloat16
    num_splits = 4 if cumsum_mode == "metadata_full" else 1

    q_ref = torch.randn(
        batch_size, seqlen_q, nheads, d, device=device, dtype=dtype
    ).requires_grad_()
    k_ref = torch.randn(
        batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype
    ).requires_grad_()
    v_ref = torch.randn(
        batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype
    ).requires_grad_()
    # Independent leaf copies so the kernel's gradients do not alias the reference's.
    q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]

    query_padding_mask = generate_random_padding_mask(
        seqlen_q, batch_size, device, mode="third"
    )
    key_padding_mask = generate_random_padding_mask(
        seqlen_k, batch_size, device, mode="third"
    )
    (
        q_unpad,
        k_unpad,
        v_unpad,
        _qv_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        _seqused_q,
        _seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        _q,
        _k,
        _v,
        _qv,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    q_unpad = q_unpad.detach().requires_grad_()
    k_unpad = k_unpad.detach().requires_grad_()
    v_unpad = v_unpad.detach().requires_grad_()

    scheduler_metadata = None
    if cumsum_mode != "jit_cumsum":
        scheduler_metadata = get_scheduler_metadata(
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            nheads=nheads,
            nheads_kv=nheads_kv,
            headdim=d,
            headdim_v=dv,
            num_splits=num_splits,
            causal=causal,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
        )
        if is_fake_mode():
            return
        assert scheduler_metadata.cu_total_m_blocks is not None
        if cumsum_mode == "metadata_cumsum_only" and not causal:
            # FlashPrepareScheduler is skipped only when num_splits == 1 and not causal and not sort.
            assert scheduler_metadata.num_m_blocks_ptr is None
            assert scheduler_metadata.tile_count_semaphore is None
        if cumsum_mode == "metadata_full":
            assert scheduler_metadata.num_m_blocks_ptr is not None
            assert scheduler_metadata.cu_total_splits_m_blocks is not None

    out_ref, _ = attention_ref(
        q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal
    )
    out_pt, _ = attention_ref(
        q_ref,
        k_ref,
        v_ref,
        query_padding_mask,
        key_padding_mask,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )

    out_unpad, _ = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        causal=causal,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
    )
    if is_fake_mode():
        return
    out = output_pad_fn(out_unpad)

    # (x + 0.3 - 0.3 - x) measures the rounding noise of the dtype at the
    # magnitude of the reference; twice that is used as absolute tolerance.
    fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
    assert (out - out_ref).abs().max().item() <= 2 * (
        out_pt - out_ref
    ).abs().max().item() + fwd_atol

    if cumsum_mode == "metadata_full":
        return  # split-kv bwd not supported

    g_unpad = torch.randn_like(out_unpad)
    dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
        out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad
    )
    # Named *_pad so the padded dV gradient does not rebind `dv` (the V head dim).
    dq_pad = dq_pad_fn(dq_unpad)
    dk_pad = dk_pad_fn(dk_unpad)
    # generate_qkv returns no separate dV pad function; dV is padded with the
    # K pad function, exactly as before.
    dv_pad = dk_pad_fn(dv_unpad)
    dq_pad.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    dk_pad.masked_fill_(rearrange(~key_padding_mask, "b s -> b s 1 1"), 0.0)
    dv_pad.masked_fill_(rearrange(~key_padding_mask, "b s -> b s 1 1"), 0.0)

    g = output_pad_fn(g_unpad)
    dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)
    dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)

    for name, x, x_ref, x_pt in [
        ("dq", dq_pad, dq_ref, dq_pt),
        ("dk", dk_pad, dk_ref, dk_pt),
        ("dv", dv_pad, dv_ref, dv_pt),
    ]:
        atol = 2 * (x_ref + 0.3 - 0.3 - x_ref).abs().max().item()
        assert (x - x_ref).abs().max().item() <= 2 * (
            x_pt - x_ref
        ).abs().max().item() + atol, name
nheads=nheads, + nheads_kv=nheads_k, + headdim=d, + headdim_v=dv, + num_splits=num_splits, + causal=causal, + sort=True, + cu_seqlens_q=cu_seqlens_q, + seqused_k=cache_seqlens, + ) + else: + scheduler_metadata = None # Repeat to test metadata reuse for _ in range(1 if not precompute_metadata else 2): if page_size is None: @@ -1475,7 +1672,7 @@ def test_flash_attn_kvcache( learnable_sink=learnable_sink, # attention_chunk=attention_chunk, # rotary_interleaved=rotary_interleaved, - # scheduler_metadata=scheduler_metadata, + scheduler_metadata=scheduler_metadata, num_splits=num_splits, # return_softmax_lse=True ) diff --git a/tests/cute/test_flash_attn_combine.py b/tests/cute/test_flash_attn_combine.py index 6344f96ab4b..ea77ded67ed 100644 --- a/tests/cute/test_flash_attn_combine.py +++ b/tests/cute/test_flash_attn_combine.py @@ -228,12 +228,12 @@ def test_flash_attn_combine_varlen(varlen_mode, num_splits, seqlen, d, dtype): @pytest.mark.parametrize("num_splits", [2, 5, 17]) # @pytest.mark.parametrize("num_splits", [5]) @maybe_fake_tensor_mode(USE_FAKE_TENSOR) -def test_flash_attn_combine_varlen_batch_idx(num_splits, seqlen, d, dtype): - """Test that varlen_batch_idx correctly remaps virtual batch indices to real batch indices. +def test_flash_attn_combine_virtual_batch_idx(num_splits, seqlen, d, dtype): + """Test that virtual_batch_idx correctly remaps virtual batch indices to real batch indices. - varlen_batch_idx maps blockIdx.z (virtual batch) -> real batch index. The kernel + virtual_batch_idx maps blockIdx.z (virtual batch) -> real batch index. The kernel reads AND writes using the remapped batch_idx, so with a permutation the output - should match running without varlen_batch_idx (each real batch is processed once). + should match running without virtual_batch_idx (each real batch is processed once). We also test with seqused to verify interaction with variable-length sequences. 
""" @@ -255,18 +255,18 @@ def test_flash_attn_combine_varlen_batch_idx(num_splits, seqlen, d, dtype): perm = torch.tensor([2, 0, 3, 1], device=device, dtype=torch.int32) assert perm.shape[0] == batch_size - # Also test with seqused to verify interaction with varlen_batch_idx + # Also test with seqused to verify interaction with virtual_batch_idx seqused = torch.randint(1, seqlen + 1, (batch_size,), device=device, dtype=torch.int32) # Zero out / -inf beyond seqused so reference matches kernel for i in range(batch_size): out_partial[:, i, seqused[i]:] = 0 lse_partial[:, i, seqused[i]:] = -float("inf") - # Run with varlen_batch_idx and seqused via public API + # Run with virtual_batch_idx and seqused via public API out, lse = flash_attn_combine( out_partial, lse_partial, out_dtype=dtype, seqused=seqused, - varlen_batch_idx=perm, + virtual_batch_idx=perm, return_lse=True, ) if is_fake_mode():