diff --git a/AI/CLC_TRACE_DEBUG.md b/AI/CLC_TRACE_DEBUG.md
new file mode 100644
index 00000000000..9f1502aa57a
--- /dev/null
+++ b/AI/CLC_TRACE_DEBUG.md
@@ -0,0 +1,82 @@
+# CLC Trace Debugging
+
+Use this when you suspect the CLC work scheduler is making surprising tile assignment decisions and you want a raw scheduler trace from the current kernel.
+
+## Current trace format
+
+SM100 forward kernels emit one trace line per scheduler-warp query at `FA_LOG_LEVEL=3`:
+
+```text
+[CLC] query sm=<sm> cta=<cta> (m_blk=<m_blk>,h=<h>,b=<b>,s=<s>) valid=<0|1>
+```
+
+Current emit sites:
+- `flash_attn/cute/flash_fwd_sm100.py`
+- `flash_attn/cute/flash_fwd_mla_sm100.py`
+
+## How to capture a trace
+
+Important:
+- `FA_LOG_LEVEL=3` is needed for the `[CLC] query ...` device-side prints.
+- `FA_CLC=1` only requests CLC; the kernel may still fall back if the shape/features disable it.
+
+Minimal repro pattern:
+
+```bash
+FA_LOG_LEVEL=3 FA_CLC=1 CUDA_VISIBLE_DEVICES=0 python - <<'PY' \
+  > agent_space/clc_trace.log 2>&1
+import torch
+from flash_attn.cute.interface import flash_attn_func
+
+torch.manual_seed(0)
+q = torch.randn(1, 512, 16, 128, device='cuda', dtype=torch.bfloat16)
+k = torch.randn(1, 512, 1, 128, device='cuda', dtype=torch.bfloat16)
+v = torch.randn(1, 512, 1, 128, device='cuda', dtype=torch.bfloat16)
+flash_attn_func(q, k, v, causal=True)
+torch.cuda.synchronize()
+PY
+```
+
+If you want the run to say explicitly whether CLC was selected, keep the host log prefix too:
+
+```text
+[FA] TileScheduler=SingleTileLPTScheduler, scheduling_mode=CLC, USE_2CTA=False
+```
+
+## What to look for
+
+- `scheduling_mode=CLC` in host logs confirms the shape actually used the CLC path.
+- `valid=1` means the returned work tile is valid.
+- `valid=0` means the scheduler is exhausted for that CTA/scheduler warp query.
+- `m_blk`, `h`, `b`, `s` are the logical work coordinates after the scheduler mapping.
+- `cta` is the physical `blockIdx.x`; for clustered launches multiple CTAs may participate in the same logical tile. + +## Parse the trace + +A lightweight parser lives in `AI/parse_clc_log.py`. + +Text summary: + +```bash +python AI/parse_clc_log.py agent_space/clc_trace.log +``` + +HTML view: + +```bash +python AI/parse_clc_log.py agent_space/clc_trace.log --html -o agent_space/clc_trace.html +``` + +## Suggested workflow + +1. Reproduce the surprising case with `FA_LOG_LEVEL=3 FA_CLC=1`. +2. Save stdout/stderr to `agent_space/clc_trace.log`. +3. Run `AI/parse_clc_log.py` on that log to get a compact per-SM / per-CTA summary. +4. If the trace still looks suspicious, attach or paste that log in the investigation thread / agent notes. +5. Compare against the relevant mapping logic in `flash_attn/cute/tile_scheduler.py`. + +## Caveats + +- The trace is noisy and expensive; use a single small shape first. +- Because the print happens on scheduler queries, many lines may be terminal `valid=0` queries after work is exhausted. +- Dense noncausal and varlen MHA may intentionally fall back away from CLC depending on the current heuristic in `flash_attn/cute/interface.py`. 
diff --git a/AI/parse_clc_log.py b/AI/parse_clc_log.py
new file mode 100644
index 00000000000..c1b94543bf4
--- /dev/null
+++ b/AI/parse_clc_log.py
@@ -0,0 +1,396 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import argparse
+import html
+import json
+import re
+import sys
+from collections import Counter, defaultdict
+from dataclasses import asdict, dataclass
+from pathlib import Path
+
+TRACE_RE = re.compile(
+    r"\[CLC\]\s+query\s+sm=(?P<sm>\d+)\s+cta=(?P<cta>\d+)\s+"
+    r"\(m_blk=(?P<m_blk>-?\d+),h=(?P<h>-?\d+),b=(?P<b>-?\d+),s=(?P<s>-?\d+)\)\s+"
+    r"valid=(?P<valid>[01])"
+)
+
+
+@dataclass(frozen=True)
+class TraceRow:
+    sm: int
+    cta: int
+    m_blk: int
+    h: int
+    b: int
+    s: int
+    valid: int
+
+
+def parse_rows(text: str) -> list[TraceRow]:
+    rows: list[TraceRow] = []
+    for line in text.splitlines():
+        match = TRACE_RE.search(line)
+        if match is None:
+            continue
+        rows.append(TraceRow(**{key: int(value) for key, value in match.groupdict().items()}))
+    return rows
+
+
+def summarize(rows: list[TraceRow]) -> dict:
+    by_sm: dict[int, list[TraceRow]] = defaultdict(list)
+    by_cta: dict[int, list[TraceRow]] = defaultdict(list)
+    tile_counter: Counter[tuple[int, int, int, int, int]] = Counter()
+    for row in rows:
+        by_sm[row.sm].append(row)
+        by_cta[row.cta].append(row)
+        tile_counter[(row.m_blk, row.h, row.b, row.s, row.valid)] += 1
+
+    def encode_group(grouped: dict[int, list[TraceRow]]) -> dict[str, dict]:
+        out: dict[str, dict] = {}
+        for key, group_rows in sorted(grouped.items()):
+            out[str(key)] = {
+                "count": len(group_rows),
+                "valid_count": sum(row.valid for row in group_rows),
+                "invalid_count": sum(1 - row.valid for row in group_rows),
+                "first": asdict(group_rows[0]),
+                "last": asdict(group_rows[-1]),
+                "unique_tiles": len({(r.m_blk, r.h, r.b, r.s, r.valid) for r in group_rows}),
+            }
+        return out
+
+    top_tiles = [
+        {
+            "tile": {
+                "m_blk": tile[0],
+                "h": tile[1],
+                "b": tile[2],
+                "s": tile[3],
+                "valid": tile[4],
+            },
+            "count": count,
+        }
+        for tile, count in
tile_counter.most_common(20) + ] + + return { + "rows": len(rows), + "valid_rows": sum(row.valid for row in rows), + "invalid_rows": sum(1 - row.valid for row in rows), + "unique_sms": len(by_sm), + "unique_ctas": len(by_cta), + "by_sm": encode_group(by_sm), + "by_cta": encode_group(by_cta), + "top_tiles": top_tiles, + } + + +def format_summary(summary: dict) -> str: + lines = [ + f"rows={summary['rows']} valid={summary['valid_rows']} invalid={summary['invalid_rows']}", + f"unique_sms={summary['unique_sms']} unique_ctas={summary['unique_ctas']}", + "top_tiles:", + ] + for entry in summary["top_tiles"][:10]: + tile = entry["tile"] + lines.append( + f" count={entry['count']:>4} tile=(m_blk={tile['m_blk']}, h={tile['h']}, b={tile['b']}, s={tile['s']}, valid={tile['valid']})" + ) + lines.append("by_sm:") + for sm, sm_summary in summary["by_sm"].items(): + first = sm_summary["first"] + last = sm_summary["last"] + lines.append( + f" sm={sm:>3} count={sm_summary['count']:>4} valid={sm_summary['valid_count']:>4} invalid={sm_summary['invalid_count']:>4} " + f"first=(cta={first['cta']},m_blk={first['m_blk']},h={first['h']},b={first['b']},s={first['s']},v={first['valid']}) " + f"last=(cta={last['cta']},m_blk={last['m_blk']},h={last['h']},b={last['b']},s={last['s']},v={last['valid']})" + ) + return "\n".join(lines) + + +def visualize_html(rows: list[TraceRow], summary: dict) -> str: + by_sm: dict[int, list[TraceRow]] = defaultdict(list) + for row in rows: + by_sm[row.sm].append(row) + + data = [ + { + "sm": sm, + "tiles": [ + { + "id": r.m_blk, + "type": "INIT" if idx == 0 else "PULL", + "valid": bool(r.valid), + "m": r.m_blk, + "h": r.h, + "b": r.b, + "s": r.s, + "cta": r.cta, + } + for idx, r in enumerate(chain) + ], + } + for sm, chain in sorted(by_sm.items()) + ] + + total_tiles = sum(len(d["tiles"]) for d in data) + valid_pulls = sum(1 for d in data for t in d["tiles"] if t["type"] == "PULL" and t["valid"]) + work_per_sm = [sum(1 for t in d["tiles"] if t["valid"]) for d 
in data] + histogram = defaultdict(int) + for work in work_per_sm: + histogram[work] += 1 + histogram_data = [{"work": k, "count": v} for k, v in sorted(histogram.items())] + work_stats = { + "min": min(work_per_sm) if work_per_sm else 0, + "max": max(work_per_sm) if work_per_sm else 0, + "mean": (sum(work_per_sm) / len(work_per_sm)) if work_per_sm else 0.0, + "std": ( + sum((w - sum(work_per_sm) / len(work_per_sm)) ** 2 for w in work_per_sm) / len(work_per_sm) + ) ** 0.5 if work_per_sm else 0.0, + } + + return f""" + + + + +CLC Work Distribution Viewer + + + +

CLC Work Distribution Viewer

+
+ query-trace mode + SMs: {len(data)} + Total queries: {total_tiles} + Valid pulls: {valid_pulls} + Invalid queries: {summary['invalid_rows']} +
+
+ + + Press Esc to clear +
+
+
First query on SM
+
Later query / pull
+
Invalid / exhausted
+
+
+

Work Distribution Histogram min={work_stats['min']}, max={work_stats['max']}, mean={work_stats['mean']:.1f}, std={work_stats['std']:.2f}

+
+
+
+
+
+

SM

+
+
+
+ + + +""" + + +def read_text(path: str | None) -> str: + if path is None or path == "-": + return sys.stdin.read() + return Path(path).read_text() + + +def main() -> None: + parser = argparse.ArgumentParser(description="Parse FlashAttention CLC trace lines.") + parser.add_argument("logfile", nargs="?", default="-", help="Trace log path or - for stdin") + parser.add_argument("--json", action="store_true", help="Emit JSON summary") + parser.add_argument("--rows", action="store_true", help="Emit parsed rows as JSON") + parser.add_argument("--html", action="store_true", help="Emit HTML view") + parser.add_argument("-o", "--output", help="Output path for --html") + args = parser.parse_args() + + rows = parse_rows(read_text(args.logfile)) + if args.rows: + print(json.dumps([asdict(row) for row in rows], indent=2)) + return + + summary = summarize(rows) + if args.html: + html_text = visualize_html(rows, summary) + if args.output is not None: + Path(args.output).write_text(html_text) + else: + print(html_text) + return + if args.json: + print(json.dumps(summary, indent=2)) + else: + print(format_summary(summary)) + + +if __name__ == "__main__": + main() diff --git a/CLAUDE.md b/CLAUDE.md index f170541d482..35d96195fd8 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -128,7 +128,7 @@ Env vars: `CUTE_CUBIN_PATH` (dump CUBIN/SASS), `CUTE_DSL_KEEP_PTX=1` (inspect PT ## Debugging GPU Kernels -See `AI/DEBUG_2CTA.md` for kernel hang/deadlock debugging (printf bisection, pipeline barrier analysis, 2CTA pitfalls). See `AI/RACECHECK_TMA_HAZARD.md` for `compute-sanitizer` false positives with `cp.async.bulk`. +See `AI/DEBUG_2CTA.md` for kernel hang/deadlock debugging (printf bisection, pipeline barrier analysis, 2CTA pitfalls). See `AI/RACECHECK_TMA_HAZARD.md` for `compute-sanitizer` false positives with `cp.async.bulk`. See `AI/CLC_TRACE_DEBUG.md` for visualization of CLC scheduling. 
Key tools: - `cute.printf` with thread guards (`tidx % 32 == 0`, `elect_one()`) for targeted output diff --git a/benchmarks/clc_bench.py b/benchmarks/clc_bench.py index 18e7358a6d7..46ee55980eb 100644 --- a/benchmarks/clc_bench.py +++ b/benchmarks/clc_bench.py @@ -50,9 +50,9 @@ class DenseSweep: enabled: bool = True batches: list[int] = field(default_factory=lambda: [1, 4, 8, 16, 32]) seqlen_pairs: list[list[int]] = field( - default_factory=lambda: [[1024, 1024], [2048, 2048], [4096, 4096], [8192, 8192], [16384, 16384]] + default_factory=lambda: [[32, 8192], [1024, 1024], [2048, 2048], [4096, 4096], [8192, 8192], [16384, 16384]] ) - head_dims: list[int] = field(default_factory=lambda: [64, 96, 128]) + head_dims: list[int | list[int]] = field(default_factory=lambda: [64, 96, 128, [192, 128]]) head_pairs: list[list[int]] = field(default_factory=lambda: [[16, 16], [16, 8], [16, 4], [16, 2], [16, 1]]) causal: bool | list[bool] = True @@ -64,7 +64,7 @@ class VarlenSweep: max_kv_tokens: list[int] = field(default_factory=lambda: [2048, 4096, 8192, 16384, 32768]) batches: list[int] = field(default_factory=lambda: [4, 8, 16, 32]) patterns: list[str] = field(default_factory=lambda: ["uniform", "longtail"]) - head_dims: list[int] = field(default_factory=lambda: [64, 96, 128]) + head_dims: list[int | list[int]] = field(default_factory=lambda: [64, 96, 128, [192, 128]]) head_pairs: list[list[int]] = field(default_factory=lambda: [[16, 8], [16, 4], [16, 2], [16, 1]]) causal: bool | list[bool] = False @@ -76,7 +76,7 @@ class BlockSparseSweep: seqlen_pairs: list[list[int]] = field( default_factory=lambda: [[1024, 1024], [2048, 2048], [4096, 4096], [4096, 8192]] ) - head_dims: list[int] = field(default_factory=lambda: [64, 128]) + head_dims: list[int | list[int]] = field(default_factory=lambda: [64, 128, [192, 128]]) head_pairs: list[list[int]] = field(default_factory=lambda: [[16, 16], [16, 4], [16, 1]]) mask_names: list[str] = field(default_factory=lambda: ["block_diagonal"]) 
sliding_window_sizes: list[int] = field(default_factory=lambda: [2048]) @@ -89,6 +89,7 @@ class Case: q_heads: int kv_heads: int d: int + dv: int causal: bool batch: int | None = None seqlen_q: int | None = None @@ -116,12 +117,36 @@ def token_label(value: int) -> str: return f"{value // 1024}k" if value >= 1024 and value % 1024 == 0 else str(value) -def dense_case_name(q_heads: int, kv_heads: int, causal: bool, d: int, batch: int, seqlen_q: int, seqlen_k: int) -> str: +def head_dim_label(d: int, dv: int) -> str: + return f"h{d}" if d == dv else f"h{d}_dv{dv}" + + +def head_dim_pairs(head_dims: list[int | list[int]]) -> list[tuple[int, int]]: + pairs: list[tuple[int, int]] = [] + invalid_pairs: list[int | list[int]] = [] + for dims in head_dims: + if isinstance(dims, int): + pairs.append((dims, dims)) + continue + if len(dims) == 1: + pairs.append((dims[0], dims[0])) + continue + if len(dims) == 2: + pairs.append((dims[0], dims[1])) + continue + invalid_pairs.append(dims) + if invalid_pairs: + raise ValueError(f"Expected d or [d] or [d, dv], got {invalid_pairs}") + return pairs + + +def dense_case_name(q_heads: int, kv_heads: int, causal: bool, d: int, dv: int, batch: int, seqlen_q: int, seqlen_k: int) -> str: causal_name = "causal" if causal else "noncausal" pair = head_pair_label(q_heads, kv_heads) + dims = head_dim_label(d, dv) if seqlen_q == seqlen_k: - return f"{pair}_{causal_name}_h{d}_{token_label(seqlen_q)}_b{batch}" - return f"{pair}_{causal_name}_q{seqlen_q}_k{seqlen_k}_h{d}_b{batch}" + return f"{pair}_{causal_name}_{dims}_{token_label(seqlen_q)}_b{batch}" + return f"{pair}_{causal_name}_q{seqlen_q}_k{seqlen_k}_{dims}_b{batch}" def varlen_case_name( @@ -130,14 +155,16 @@ def varlen_case_name( kv_heads: int, causal: bool, d: int, + dv: int, batch: int, max_q_tokens: int, max_kv_tokens: int, ) -> str: causal_name = "causal" if causal else "noncausal" pair = head_pair_label(q_heads, kv_heads) + dims = head_dim_label(d, dv) return ( - 
f"varlen_{pattern}_{pair}_{causal_name}_h{d}_" + f"varlen_{pattern}_{pair}_{causal_name}_{dims}_" f"b{batch}_q{token_label(max_q_tokens)}_kv{token_label(max_kv_tokens)}" ) @@ -200,21 +227,22 @@ def generate_cases( ) -> list[Case]: cases: list[Case] = [] if dense.enabled: - for batch, seqlen_pair, d, (q_heads, kv_heads), causal in product( + for batch, seqlen_pair, (d, dv), (q_heads, kv_heads), causal in product( dense.batches, dense.seqlen_pairs, - dense.head_dims, + head_dim_pairs(dense.head_dims), dense.head_pairs, bool_values(dense.causal), ): seqlen_q, seqlen_k = seqlen_pair cases.append( Case( - name=dense_case_name(q_heads, kv_heads, causal, d, batch, seqlen_q, seqlen_k), + name=dense_case_name(q_heads, kv_heads, causal, d, dv, batch, seqlen_q, seqlen_k), mode="dense", q_heads=q_heads, kv_heads=kv_heads, d=d, + dv=dv, causal=causal, batch=batch, seqlen_q=seqlen_q, @@ -222,12 +250,12 @@ def generate_cases( ) ) if varlen.enabled: - for max_q_tokens, max_kv_tokens, batch, pattern, d, (q_heads, kv_heads), causal in product( + for max_q_tokens, max_kv_tokens, batch, pattern, (d, dv), (q_heads, kv_heads), causal in product( varlen.max_q_tokens, varlen.max_kv_tokens, varlen.batches, varlen.patterns, - varlen.head_dims, + head_dim_pairs(varlen.head_dims), varlen.head_pairs, bool_values(varlen.causal), ): @@ -236,11 +264,12 @@ def generate_cases( lengths_k = normalize_lengths(weights, max(batch, max_kv_tokens)) cases.append( Case( - name=varlen_case_name(pattern, q_heads, kv_heads, causal, d, batch, max_q_tokens, max_kv_tokens), + name=varlen_case_name(pattern, q_heads, kv_heads, causal, d, dv, batch, max_q_tokens, max_kv_tokens), mode="varlen", q_heads=q_heads, kv_heads=kv_heads, d=d, + dv=dv, causal=causal, batch=batch, seqlens_q=lengths_q, @@ -249,10 +278,10 @@ def generate_cases( ) ) if block_sparse.enabled: - for batch, seqlen_pair, d, (q_heads, kv_heads), mask_name in product( + for batch, seqlen_pair, (d, dv), (q_heads, kv_heads), mask_name in product( 
block_sparse.batches, block_sparse.seqlen_pairs, - block_sparse.head_dims, + head_dim_pairs(block_sparse.head_dims), block_sparse.head_pairs, block_sparse.mask_names, ): @@ -263,16 +292,18 @@ def generate_cases( for window_size in window_sizes: window_label = f"_w{window_size}" if window_size is not None else "" pair = head_pair_label(q_heads, kv_heads) + dims = head_dim_label(d, dv) cases.append( Case( name=( f"block_sparse_{mask_name}{window_label}_{pair}_" - f"h{d}_q{seqlen_q}_k{seqlen_k}_b{batch}" + f"{dims}_q{seqlen_q}_k{seqlen_k}_b{batch}" ), mode="block_sparse", q_heads=q_heads, kv_heads=kv_heads, d=d, + dv=dv, causal=False, batch=batch, seqlen_q=seqlen_q, @@ -302,11 +333,12 @@ def compile_signature(case: Case) -> tuple: case.q_heads, case.kv_heads, case.d, + case.dv, case.mask_name, case.window_size, q_stage, ) - return case.mode, case.q_heads, case.kv_heads, case.d, case.causal, q_stage + return case.mode, case.q_heads, case.kv_heads, case.d, case.dv, case.causal, q_stage def select_compile_cases(cases: list[Case]) -> list[Case]: @@ -369,7 +401,7 @@ def build_cu_seqlens(torch_mod, lengths: list[int]) -> torch_mod.Tensor: def build_dense_inputs(torch_mod, flash_attn_func, case: Case, dtype, factory): q = factory(case.batch, case.seqlen_q, case.q_heads, case.d, device="cuda", dtype=dtype) k = factory(case.batch, case.seqlen_k, case.kv_heads, case.d, device="cuda", dtype=dtype) - v = factory(case.batch, case.seqlen_k, case.kv_heads, case.d, device="cuda", dtype=dtype) + v = factory(case.batch, case.seqlen_k, case.kv_heads, case.dv, device="cuda", dtype=dtype) return flash_attn_func, dict(q=q, k=k, v=v, causal=case.causal) @@ -380,7 +412,7 @@ def build_varlen_inputs(torch_mod, flash_attn_varlen_func, case: Case, dtype, fa total_k = sum(lengths_k) q = factory(total_q, case.q_heads, case.d, device="cuda", dtype=dtype) k = factory(total_k, case.kv_heads, case.d, device="cuda", dtype=dtype) - v = factory(total_k, case.kv_heads, case.d, device="cuda", dtype=dtype) 
+ v = factory(total_k, case.kv_heads, case.dv, device="cuda", dtype=dtype) return flash_attn_varlen_func, dict( q=q, k=k, @@ -412,7 +444,7 @@ def build_block_sparse_inputs(torch_mod, flash_attn_func, case: Case, dtype, ten raise ValueError(f"Aux-backed block-sparse masks are not supported by clc_bench.py: {case.mask_name}") q = tensor_factory(case.batch, case.seqlen_q, case.q_heads, case.d, device="cuda", dtype=dtype) k = tensor_factory(case.batch, case.seqlen_k, case.kv_heads, case.d, device="cuda", dtype=dtype) - v = tensor_factory(case.batch, case.seqlen_k, case.kv_heads, case.d, device="cuda", dtype=dtype) + v = tensor_factory(case.batch, case.seqlen_k, case.kv_heads, case.dv, device="cuda", dtype=dtype) cute_mask, _ = get_mask_pair( case.mask_name, seqlen_q=case.seqlen_q, @@ -468,11 +500,40 @@ def build_inputs(case: Case, dtype_name: DTypeName, fake_tensor: bool): def attended_pairs(seqlen_q: int, seqlen_k: int, causal: bool) -> float: + """Lower-right aligned causal: last query aligns with last key. + When M > N, only the bottom N query rows attend (triangle of size N), + so valid pairs = N*(N+1)/2, not the upper-left formula M*N - N*(N-1)/2. 
+ """ if not causal: return float(seqlen_q * seqlen_k) if seqlen_q <= seqlen_k: return float(seqlen_q * (2 * seqlen_k - seqlen_q + 1) / 2) - return float(seqlen_q * seqlen_k - seqlen_k * (seqlen_k - 1) / 2) + return float(seqlen_k * (seqlen_k + 1) / 2) + + +def block_sparse_pairs(case: Case) -> float: + seqlen_q = case.seqlen_q or 0 + seqlen_k = case.seqlen_k or 0 + match case.mask_name: + case "block_diagonal": + total = 0 + for q_idx in range(seqlen_q): + block_start = (q_idx // BLOCK_SIZE_K) * BLOCK_SIZE_K + block_end = min(block_start + BLOCK_SIZE_K, seqlen_k) + total += max(0, block_end - block_start) + return float(total) + case "sliding_window": + window = case.window_size or 0 + offset = seqlen_k - seqlen_q + total = 0 + for q_idx in range(seqlen_q): + center = q_idx + offset + lower = max(0, center - window) + upper = min(seqlen_k - 1, center + window) + total += max(0, upper - lower + 1) + return float(total) + case _: + raise ValueError(f"Unsupported block-sparse FLOP mask: {case.mask_name}") def fwd_flops(case: Case, kwargs: dict | None = None) -> float: @@ -481,19 +542,15 @@ def fwd_flops(case: Case, kwargs: dict | None = None) -> float: case.seqlen_q or 0, case.seqlen_k or 0, case.causal, - ) * (case.d + case.d) + ) * (case.d + case.dv) if case.mode == "block_sparse": - if kwargs is None: - return 0.0 - total_blocks = kwargs["mask_block_cnt"].sum().item() - if kwargs["full_block_cnt"] is not None: - total_blocks += kwargs["full_block_cnt"].sum().item() - return float(total_blocks * BLOCK_SIZE_Q * BLOCK_SIZE_K * case.q_heads * 2 * (case.d + case.d)) + num_pairs = (case.batch or 0) * block_sparse_pairs(case) + return case.q_heads * 2 * num_pairs * (case.d + case.dv) lengths_q = case.seqlens_q or [] lengths_k = case.seqlens_k or lengths_q total = 0.0 for seqlen_q, seqlen_k in zip(lengths_q, lengths_k): - total += case.q_heads * 2 * attended_pairs(seqlen_q, seqlen_k, case.causal) * (case.d + case.d) + total += case.q_heads * 2 * attended_pairs(seqlen_q, 
seqlen_k, case.causal) * (case.d + case.dv) return total @@ -531,6 +588,7 @@ def case_metadata(case: Case) -> dict: "q_heads": case.q_heads, "kv_heads": case.kv_heads, "d": case.d, + "dv": case.dv, "causal": case.causal, "pattern": case.pattern, "mask_name": case.mask_name, diff --git a/benchmarks/configs/clc.yaml b/benchmarks/configs/clc.yaml index b7bc5d4a949..94daf11770d 100644 --- a/benchmarks/configs/clc.yaml +++ b/benchmarks/configs/clc.yaml @@ -7,9 +7,9 @@ bench_iters: 256 dense: enabled: true batches: [1, 4, 8, 16, 32] - seqlen_pairs: [[1024, 1024], [2048, 2048], [4096, 4096], [8192, 8192], [16384, 16384]] - head_dims: [64, 96, 128] - head_pairs: [[16, 16], [16, 8], [16, 4], [16, 2], [16, 1]] + seqlen_pairs: [[32, 8192], [2048, 2048], [4096, 4096], [8192, 8192], [16384, 16384]] + head_dims: [64, 96, 128, [192, 128]] + head_pairs: [[16, 16], [16, 8], [16, 4], [16, 1]] causal: [true] varlen: @@ -20,15 +20,15 @@ varlen: # uniform: all sequences in the batch are similar length # longtail: a few long sequences plus many shorter ones patterns: [uniform, longtail] - head_dims: [64, 96, 128] - head_pairs: [[16, 8], [16, 4], [16, 2], [16, 1]] + head_dims: [64, 128, [192, 128]] + head_pairs: [[16, 8], [16, 4], [16, 1]] causal: [false] block_sparse: enabled: false batches: [1, 4, 8, 16, 32] seqlen_pairs: [[1024, 1024], [2048, 2048], [4096, 4096], [4096, 8192]] - head_dims: [64, 128] + head_dims: [64, 128, [192, 128]] head_pairs: [[16, 16], [16, 4], [16, 1]] # supported mask_names: block_diagonal, sliding_window mask_names: [block_diagonal]