diff --git a/.github/workflows/nvidia-4090.yml b/.github/workflows/nvidia-4090.yml index 7992b01d3e..984ba74f3a 100644 --- a/.github/workflows/nvidia-4090.yml +++ b/.github/workflows/nvidia-4090.yml @@ -4,6 +4,10 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: ${{ github.event_name == 'pull_request' }} +permissions: + contents: read + pull-requests: write + on: pull_request: branches: [ '*' ] diff --git a/.github/workflows/nvidia-a100.yml b/.github/workflows/nvidia-a100.yml index e14cb95578..de65dbc898 100644 --- a/.github/workflows/nvidia-a100.yml +++ b/.github/workflows/nvidia-a100.yml @@ -4,6 +4,10 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: ${{ github.event_name == 'pull_request' }} +permissions: + contents: read + pull-requests: write + on: pull_request: branches: [ '*' ] diff --git a/.github/workflows/nvidia-h100.yml b/.github/workflows/nvidia-h100.yml index 4e20d11adf..5ee2f2382c 100644 --- a/.github/workflows/nvidia-h100.yml +++ b/.github/workflows/nvidia-h100.yml @@ -4,6 +4,10 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: ${{ github.event_name == 'pull_request' }} +permissions: + contents: read + pull-requests: write + on: pull_request: branches: [ '*' ] diff --git a/.github/workflows/reusable-ci-benchmarks.yml b/.github/workflows/reusable-ci-benchmarks.yml index b0c11a7c04..83754e2b32 100644 --- a/.github/workflows/reusable-ci-benchmarks.yml +++ b/.github/workflows/reusable-ci-benchmarks.yml @@ -43,6 +43,11 @@ jobs: runs-on: ${{ inputs.runner }} env: FLA_CI_ENV: 1 + # Longer Triton timing windows + pause between commits reduce clock/thermal variance. + FLA_BENCH_WARMUP_MS: 40 + FLA_BENCH_REP_MS: 200 + FLA_BENCH_OP_WARMUP_ITERS: 6 + FLA_BENCH_COOLDOWN_SEC: 3 steps: - name: Check out repo (full history for cross-commit comparison) @@ -112,12 +117,14 @@ jobs: - name: Run benchmark comparison if: steps.check_skip.outputs.skip_bench == 'false' + id: benchmark shell: bash run: | $CONDA_BIN_PATH/python scripts/run_benchmark_compare.py \ --base ${{ inputs.base_ref }} \ --head HEAD \ --threshold ${{ inputs.threshold }} \ + --no-fail-on-regression \ --output benchmark_results.json - name: Upload benchmark results @@ -128,3 +135,107 @@ jobs: path: benchmark_results.json if-no-files-found: ignore retention-days: 30 + + - name: Post benchmark results to PR + if: steps.check_skip.outputs.skip_bench == 'false' && github.event_name == 'pull_request' + uses: actions/github-script@v7 + with: + script: | + const fs = require('fs'); + const path = 'benchmark_results.json'; + + if (!fs.existsSync(path)) { + console.log('No benchmark results file found'); + return; + } + + const data = JSON.parse(fs.readFileSync(path, 'utf8')); + const { base_sha, head_sha, machine_info, regressions, speedups, has_regression, has_speedup } = data; + + const gpuName = machine_info?.gpu_name || 'Unknown'; + const cudaVersion = machine_info?.cuda_version || 'Unknown'; + const pytorchVersion = machine_info?.pytorch_version || 'Unknown'; + const runnerName = '${{ inputs.runner }}'; + + // Format the summary table + let summary = ''; + if (has_regression || has_speedup) { + summary = '| Op | Mode | B | T | H | D | Base (ms) | Head (ms) | Change |\n'; + summary += '|:---|:---:|---:|---:|---:|---:|---:|---:|---:|\n'; + + // Add regressions + if (regressions && regressions.length > 0) { + for (const r of regressions) { + const change = `+${r.change_pct.toFixed(1)}% 🔴`; + summary += `| ${r.op} | ${r.mode} | ${r.B} | ${r.T} | ${r.H} | ${r.D} | ${r.base_ms.toFixed(3)} | ${r.head_ms.toFixed(3)} | ${change} |\n`; + } + } + + // Add speedups + if (speedups && speedups.length > 0) { + for (const s of speedups) { + const change = `${s.change_pct.toFixed(1)}% 🟢`; + summary += `| ${s.op} | ${s.mode} | ${s.B} | ${s.T} | ${s.H} | ${s.D} | ${s.base_ms.toFixed(3)} | ${s.head_ms.toFixed(3)} | ${change} |\n`; + } + } + } + + // Build the comment body + const statusEmoji = has_regression ? '⚠️' : '✅'; + const statusText = has_regression + ? `${regressions.length} regression(s) detected` + : 'No significant performance regressions detected'; + + let body = `## ${statusEmoji} Benchmark Results (${runnerName.toUpperCase()})\n\n`; + body += `**Status:** ${statusText}\n\n`; + body += `| | |\n`; + body += `|:---|:---|\n`; + body += `| **GPU** | ${gpuName} |\n`; + body += `| **CUDA** | ${cudaVersion} |\n`; + body += `| **PyTorch** | ${pytorchVersion} |\n`; + body += `| **Base** | \`${base_sha}\` |\n`; + body += `| **Head** | \`${head_sha}\` |\n`; + body += `| **Threshold** | ${{ inputs.threshold }}% |\n\n`; + + if (has_regression || has_speedup) { + body += `
\n📊 View Details (${(regressions?.length || 0) + (speedups?.length || 0)} significant changes)\n\n`; + body += summary; + body += '\n
\n'; + } else { + body += '> All benchmarked operations are within the performance threshold.\n'; + } + + body += '\n---\n'; + body += '*This comment is automatically updated with the latest benchmark results.*'; + + // Find existing comment + const { data: comments } = await github.rest.issues.listComments({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + }); + + const botComment = comments.find(comment => + comment.user.type === 'Bot' && + comment.body.includes(`Benchmark Results (${runnerName.toUpperCase()})`) + ); + + if (botComment) { + // Update existing comment + await github.rest.issues.updateComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: botComment.id, + body: body, + }); + console.log(`Updated existing benchmark comment for ${runnerName}`); + } else { + // Create new comment + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: body, + }); + console.log(`Created new benchmark comment for ${runnerName}`); + } diff --git a/benchmarks/cp/benchmark_chunk_delta_h_kernels.py b/benchmarks/cp/benchmark_chunk_delta_h_kernels.py deleted file mode 100644 index 8dd302a315..0000000000 --- a/benchmarks/cp/benchmark_chunk_delta_h_kernels.py +++ /dev/null @@ -1,467 +0,0 @@ -#!/usr/bin/env python -""" -Benchmark script for chunk_delta_h kernels with cp=8 using triton.testing.do_bench. - -This script benchmarks all 6 kernels from generate_chunk_delta_h_cache(): -1. pre_process_fwd_kernel_stage1 (FWD) -2. pre_process_fwd_bwd_kernel_stage2 (FWD) -3. merge_fwd_bwd_kernel (FWD) -4. pre_process_bwd_kernel_stage1 (BWD) -5. pre_process_fwd_bwd_kernel_stage2 (BWD) -6. merge_fwd_bwd_kernel (BWD) - -Usage: - # Run benchmark with default settings (batch=1, head=32, headdim=128, cp=8, seqlen=32k per rank) - python benchmark_chunk_delta_h_kernels.py - - # Run with custom settings - python benchmark_chunk_delta_h_kernels.py --batch 1 --heads 32 --headdim 128 --seqlen 32768 -""" - -import argparse - -import torch -import triton -import triton.testing - -from fla.ops.cp.chunk_delta_h import ( - merge_fwd_bwd_kernel, - pre_process_bwd_kernel_merged, - pre_process_bwd_kernel_stage1, - pre_process_fwd_bwd_kernel_stage2, - pre_process_fwd_kernel_merged, - pre_process_fwd_kernel_stage1, -) -from fla.utils import device - -DTYPE = torch.bfloat16 - - -def get_args(): - parser = argparse.ArgumentParser( - description="Benchmark chunk_delta_h kernels using triton.testing.do_bench" - ) - parser.add_argument( - "--batch", type=int, default=1, help="Batch size (default: 1)" - ) - parser.add_argument( - "--heads", type=int, default=32, help="Number of heads (default: 32)" - ) - parser.add_argument( - "--headdim", type=int, default=128, help="Head dimension (default: 128)" - ) - parser.add_argument( - "--seqlen", - type=int, - default=32768, - help="Sequence length per CP rank (default: 32768 = 32k)", - ) - return parser.parse_args() - - -def create_tensors(B, T, H, K, V, device, dtype): - """Create input tensors for benchmarking.""" - # Tensors for forward kernels - k = torch.randn(B, T, H, K, device=device, dtype=dtype) - v = torch.randn(B, T, H, V, device=device, dtype=dtype) - w = torch.randn(B, T, H, K, device=device, dtype=dtype) - gk = torch.randn(B, T, H, K, device=device, dtype=torch.float32) - - # cu_seqlens for 1 chunk - cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32) - - # hm tensor (zero-initialized) - hm = torch.zeros(H, K, V + K, device=device, dtype=torch.float32) - - # Tensors for backward kernels - q = k.clone() - do = torch.randn(B, T, H, V, device=device, dtype=dtype) - dv = torch.zeros_like(v) - dhm = torch.zeros_like(hm) - - # h_out and dht for merge kernels - h_out = torch.zeros(B, H, K, V, device=device, dtype=torch.float32) - dht = torch.randn(B, H, K, V, device=device, dtype=torch.float32) - - # ag_hm and ag_dhm for merge kernels (simulating cp=8) - num_ranks = 8 # Simulating cp=8 - stride = H * K * (K + V) - ag_hm = torch.zeros(num_ranks * stride, device=device, dtype=torch.float32) - ag_dhm = torch.zeros(num_ranks * stride, device=device, dtype=torch.float32) - - return { - "k": k, - "v": v, - "w": w, - "gk": gk, - "hm": hm, - "cu_seqlens": cu_seqlens, - "q": q, - "do": do, - "dv": dv, - "dhm": dhm, - "h_out": h_out, - "dht": dht, - "ag_hm": ag_hm, - "ag_dhm": ag_dhm, - } - - -def benchmark_all_kernels(B, T, H, K, V): - """Benchmark all 6 kernels from generate_chunk_delta_h_cache().""" - print(f"\n{'=' * 80}") - print("Benchmarking chunk_delta_h kernels (using triton.testing.do_bench)") - print(f"{'=' * 80}") - print("Configuration:") - print(f" Batch size: {B}") - print(f" Heads: {H}") - print(f" Head dim (K): {K}") - print(f" Head dim (V): {V}") - print(f" SeqLen per rank: {T}") - print(f" Dtype: {DTYPE}") - print(f" Device: {device}") - print(f"{'=' * 80}\n") - - # Create tensors - tensors = create_tensors(B, T, H, K, V, device, DTYPE) - BT = 64 - BK = triton.next_power_of_2(K) - - results = {} - quantiles = [0.5, 0.2, 0.8] # median, 20th percentile, 80th percentile - - # ======================================== - # 1. pre_process_fwd_kernel_stage1 (FWD) - # ======================================== - print("[1/6] Benchmarking pre_process_fwd_kernel_stage1 (FWD)...") - - def grid_stage1(meta): - return (triton.cdiv(V, meta["BV"]), H) - - def kernel_fwd_stage1(): - pre_process_fwd_kernel_stage1[grid_stage1]( - k=tensors["k"], - v=tensors["v"], - w=tensors["w"], - g=None, - gk=tensors["gk"], - hm=tensors["hm"], - cu_seqlens=tensors["cu_seqlens"], - T=T, - H=H, - K=K, - V=V, - BT=BT, - USE_G=False, - USE_GK=True, - USE_EXP2=True, - IS_VARLEN=True, - ) - - ms, min_ms, max_ms = triton.testing.do_bench(kernel_fwd_stage1, quantiles=quantiles) - results["pre_process_fwd_kernel_stage1"] = ms - print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)") - - # ======================================== - # 2. pre_process_fwd_bwd_kernel_stage2 (FWD) - # ======================================== - print("[2/6] Benchmarking pre_process_fwd_bwd_kernel_stage2 (FWD)...") - - def grid_stage2(meta): - return (triton.cdiv(K, meta["BK2"]), H) - - def kernel_fwd_stage2(): - pre_process_fwd_bwd_kernel_stage2[grid_stage2]( - k=tensors["k"], - w=tensors["w"], - g=None, - gk=tensors["gk"], - hm=tensors["hm"], - cu_seqlens=tensors["cu_seqlens"], - T=T, - H=H, - K=K, - V=V, - BT=BT, - BK1=BK, - USE_G=False, - USE_GK=True, - USE_EXP2=True, - IS_VARLEN=True, - FORWARD=True, - ) - - ms, min_ms, max_ms = triton.testing.do_bench(kernel_fwd_stage2, quantiles=quantiles) - results["pre_process_fwd_bwd_kernel_stage2 (FWD)"] = ms - print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)") - - # ======================================== - # 2b. pre_process_fwd_kernel_merged (FWD) - # ======================================== - print("[2b/6] Benchmarking pre_process_fwd_kernel_merged (FWD)...") - - BLOCK_SIZE = 32 if K <= 64 else 64 - grid_merged = (triton.cdiv(V, BLOCK_SIZE) + triton.cdiv(K, BLOCK_SIZE), H) - - def kernel_fwd_merged(): - pre_process_fwd_kernel_merged[grid_merged]( - k=tensors["k"], - v=tensors["v"], - w=tensors["w"], - g=None, - gk=tensors["gk"], - hm=tensors["hm"], - cu_seqlens=tensors["cu_seqlens"], - T=T, - H=H, - K=K, - V=V, - BT=BT, - BK1=BK, - USE_G=False, - USE_GK=True, - USE_EXP2=True, - IS_VARLEN=True, - BLOCK_SIZE=BLOCK_SIZE, - MULTI_SEQS=False, - ) - - ms, min_ms, max_ms = triton.testing.do_bench(kernel_fwd_merged, quantiles=quantiles) - results["pre_process_fwd_kernel_merged (FWD)"] = ms - print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)") - - # Speedup comparison - stage1_time = results["pre_process_fwd_kernel_stage1"] - stage2_time = results["pre_process_fwd_bwd_kernel_stage2 (FWD)"] - merged_time = results["pre_process_fwd_kernel_merged (FWD)"] - original_total = stage1_time + stage2_time - speedup = original_total / merged_time if merged_time > 0 else 0 - print(f"\n [Speedup] Merged kernel vs Split kernels: {speedup:.2f}x") - print(f" Original (stage1+stage2): {original_total * 1000:.2f} us") - print(f" Merged: {merged_time * 1000:.2f} us") - - # ======================================== - # 3. merge_fwd_bwd_kernel (FWD) - # ======================================== - print("[3/6] Benchmarking merge_fwd_bwd_kernel (FWD)...") - - def grid_merge(meta): - return (triton.cdiv(V, meta["BV"]), H) - - def kernel_merge_fwd(): - merge_fwd_bwd_kernel[grid_merge]( - h=tensors["h_out"], - ag_hm=tensors["ag_hm"], - pre_or_post_num_ranks=1, - rank=1, - seq_offsets=None, - init_offsets=None, - h0_seq_ids=None, - h0=None, - H=H, - K=K, - V=V, - BK=BK, - FORWARD=True, - INTRACARD_MODE=False, - NUM_SEQ_ENTRIES=0, - ) - - ms, min_ms, max_ms = triton.testing.do_bench(kernel_merge_fwd, quantiles=quantiles) - results["merge_fwd_bwd_kernel (FWD)"] = ms - print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)") - - # ======================================== - # 4. pre_process_bwd_kernel_stage1 (BWD) - # ======================================== - print("[4/6] Benchmarking pre_process_bwd_kernel_stage1 (BWD)...") - - def kernel_bwd_stage1(): - pre_process_bwd_kernel_stage1[grid_stage1]( - q=tensors["q"], - k=tensors["k"], - w=tensors["w"], - g=None, - gk=tensors["gk"], - do=tensors["do"], - dhm=tensors["dhm"], - dv=tensors["dv"], - cu_seqlens=tensors["cu_seqlens"], - scale=1.0, - T=T, - H=H, - K=K, - V=V, - BT=BT, - USE_G=False, - USE_GK=True, - IS_VARLEN=True, - ) - - ms, min_ms, max_ms = triton.testing.do_bench(kernel_bwd_stage1, quantiles=quantiles) - results["pre_process_bwd_kernel_stage1"] = ms - print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)") - - # ======================================== - # 5. pre_process_fwd_bwd_kernel_stage2 (BWD) - # ======================================== - print("[5/6] Benchmarking pre_process_fwd_bwd_kernel_stage2 (BWD)...") - - def kernel_bwd_stage2(): - pre_process_fwd_bwd_kernel_stage2[grid_stage2]( - k=tensors["k"], - w=tensors["w"], - g=None, - gk=tensors["gk"], - hm=tensors["dhm"], - cu_seqlens=tensors["cu_seqlens"], - T=T, - H=H, - K=K, - V=V, - BT=BT, - BK1=BK, - USE_G=False, - USE_GK=True, - USE_EXP2=True, - IS_VARLEN=True, - FORWARD=False, - ) - - ms, min_ms, max_ms = triton.testing.do_bench(kernel_bwd_stage2, quantiles=quantiles) - results["pre_process_fwd_bwd_kernel_stage2 (BWD)"] = ms - print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)") - - # ======================================== - # 5b. pre_process_bwd_kernel_merged (BWD) - # ======================================== - print("[5b/6] Benchmarking pre_process_bwd_kernel_merged (BWD)...") - - BLOCK_SIZE_BWD = 32 if K <= 64 else 64 - grid_bwd_merged = (triton.cdiv(V + K, BLOCK_SIZE_BWD), H) - - def kernel_bwd_merged(): - pre_process_bwd_kernel_merged[grid_bwd_merged]( - q=tensors["q"], - k=tensors["k"], - w=tensors["w"], - g=None, - gk=tensors["gk"], - do=tensors["do"], - dhm=tensors["dhm"], - dv=tensors["dv"], - cu_seqlens=tensors["cu_seqlens"], - scale=1.0, - T=T, - H=H, - K=K, - V=V, - BT=BT, - BK1=BK, - USE_G=False, - USE_GK=True, - USE_EXP2=True, - IS_VARLEN=True, - BLOCK_SIZE=BLOCK_SIZE_BWD, - ) - - ms, min_ms, max_ms = triton.testing.do_bench(kernel_bwd_merged, quantiles=quantiles) - results["pre_process_bwd_kernel_merged (BWD)"] = ms - print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)") - - # Speedup comparison for backward - stage1_bwd_time = results["pre_process_bwd_kernel_stage1"] - stage2_bwd_time = results["pre_process_fwd_bwd_kernel_stage2 (BWD)"] - merged_bwd_time = results["pre_process_bwd_kernel_merged (BWD)"] - original_bwd_total = stage1_bwd_time + stage2_bwd_time - bwd_speedup = original_bwd_total / merged_bwd_time if merged_bwd_time > 0 else 0 - print(f"\n [Speedup] BWD Merged kernel vs Split kernels: {bwd_speedup:.2f}x") - print(f" Original (stage1+stage2): {original_bwd_total * 1000:.2f} us") - print(f" Merged: {merged_bwd_time * 1000:.2f} us") - - # ======================================== - # 6. merge_fwd_bwd_kernel (BWD) - # ======================================== - print("[6/6] Benchmarking merge_fwd_bwd_kernel (BWD)...") - - def kernel_merge_bwd(): - merge_fwd_bwd_kernel[grid_merge]( - h=tensors["dht"], - ag_hm=tensors["ag_dhm"], - pre_or_post_num_ranks=1, - rank=1, - seq_offsets=None, - init_offsets=None, - h0_seq_ids=None, - h0=None, - H=H, - K=K, - V=V, - BK=BK, - FORWARD=False, - INTRACARD_MODE=False, - NUM_SEQ_ENTRIES=0, - ) - - ms, min_ms, max_ms = triton.testing.do_bench(kernel_merge_bwd, quantiles=quantiles) - results["merge_fwd_bwd_kernel (BWD)"] = ms - print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)") - - # Print summary - print(f"\n{'=' * 80}") - print("Benchmark Summary") - print(f"{'=' * 80}") - print(f"{'Kernel Name':<50} {'Time (us)':<15} {'% of Total':<12}") - print("-" * 80) - - total_time = sum(results.values()) - for name, t in results.items(): - percentage = (t / total_time * 100) if total_time > 0 else 0 - print(f"{name:<50} {t * 1000:<15.2f} {percentage:<12.1f}") - - print("-" * 80) - print(f"{'Total':<50} {total_time * 1000:<15.2f} {'100.0':<12}") - print(f"{'=' * 80}\n") - - # Breakdown by direction (using split kernels for fair comparison) - fwd_time = ( - results["pre_process_fwd_kernel_stage1"] - + results["pre_process_fwd_bwd_kernel_stage2 (FWD)"] - + results["merge_fwd_bwd_kernel (FWD)"] - ) - bwd_time = ( - results["pre_process_bwd_kernel_stage1"] - + results["pre_process_fwd_bwd_kernel_stage2 (BWD)"] - + results["merge_fwd_bwd_kernel (BWD)"] - ) - - print(f"Forward pass total: {fwd_time * 1000:.2f} us ({fwd_time/total_time*100:.1f}%)") - print(f"Backward pass total: {bwd_time * 1000:.2f} us ({bwd_time/total_time*100:.1f}%)") - print(f"FWD/BWD ratio: {fwd_time/bwd_time:.2f}") - print() - - return results - - -def main(): - args = get_args() - - B = args.batch - H = args.heads - K = args.headdim - V = args.headdim # Assuming V = K - T = args.seqlen - - # Print environment info - print(f"\nPyTorch version: {torch.__version__}") - print(f"CUDA available: {torch.cuda.is_available()}") - if torch.cuda.is_available(): - print(f"CUDA device: {torch.cuda.get_device_name(0)}") - print(f"CUDA capability: {torch.cuda.get_device_capability(0)}") - - # Run benchmark - benchmark_all_kernels(B, T, H, K, V) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/ops/registry.py b/benchmarks/ops/registry.py index 69c5a49d76..879afc45da 100644 --- a/benchmarks/ops/registry.py +++ b/benchmarks/ops/registry.py @@ -58,6 +58,14 @@ def logsigmoid_clamp(t): return F.logsigmoid(t).clamp_min(-5) +RWKV7_W_MIN = -0.6065306597126334 + + +def rwkv7_w_transform(t): + w = RWKV7_W_MIN * t.sigmoid() + return w.clamp(min=RWKV7_W_MIN, max=-1e-6) + + # --------------------------------------------------------------------------- # TensorSpec: describes how to create one input tensor # --------------------------------------------------------------------------- @@ -325,12 +333,13 @@ def _rwkv7_post_init(inputs, B, T, H, D, **kw): import_path='fla.ops.rwkv7', inputs={ 'r': TensorSpec(shape_BTHD), - 'w': TensorSpec(shape_BTHD, transform=logsigmoid), + 'w': TensorSpec(shape_BTHD, transform=rwkv7_w_transform), 'k': TensorSpec(shape_BTHD), 'v': TensorSpec(shape_BTHD), 'a': TensorSpec(shape_BTHD), 'b': TensorSpec(shape_BTHD), }, + extra_kwargs={'safe_gate': True, 'chunk_size': 64}, post_init=_rwkv7_post_init, category='rwkv', )) diff --git a/benchmarks/ops/run.py b/benchmarks/ops/run.py index 60d91a96d2..214fc78b88 100644 --- a/benchmarks/ops/run.py +++ b/benchmarks/ops/run.py @@ -8,6 +8,9 @@ # Benchmark one op (uses all default shape configs) python -m benchmarks.ops.run --op chunk_gla + # Ops touched by git diff (same rules as scripts/run_benchmark_compare.py) + python -m benchmarks.ops.run --from-diff --diff-base main --diff-head HEAD + # Multiple ops python -m benchmarks.ops.run --op chunk_gla chunk_kda @@ -104,10 +107,12 @@ Benchmark methodology ===================== -1. **Warmup**: For each (op, shape), run fwd+bwd 5 times to trigger all - triton autotuning. All shapes are warmed up before any timing begins. -2. **Timing**: ``triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])`` - gives median, p20, p80 in milliseconds. +1. **Warmup**: For each (op, shape), run fwd+bwd several times (default 5; + override with ``FLA_BENCH_OP_WARMUP_ITERS``) to trigger triton autotuning. + All shapes are warmed up before any timing begins. +2. **Timing**: ``triton.testing.do_bench`` with quantiles ``[0.5, 0.2, 0.8]``, + ``warmup``/``rep`` in milliseconds (defaults 25 / 100; set + ``FLA_BENCH_WARMUP_MS`` / ``FLA_BENCH_REP_MS`` for noisier machines / CI). 3. Input tensors (including gate transforms like logsigmoid) are prepared **before** timing — only the op call itself is measured. """ @@ -198,8 +203,22 @@ def _get_machine_info() -> dict: return info -def _warmup_autotune(fn, n=5): +def _warmup_iters() -> int: + """Extra per-shape forward+backward iterations before timing (not Triton do_bench warmup).""" + return max(1, int(os.environ.get('FLA_BENCH_OP_WARMUP_ITERS', '5'))) + + +def _do_bench_kw(): + """Triton ``do_bench`` uses warmup/rep in *milliseconds* of timed execution (see Triton docs).""" + warmup_ms = int(os.environ.get('FLA_BENCH_WARMUP_MS', '25')) + rep_ms = int(os.environ.get('FLA_BENCH_REP_MS', '100')) + return {'warmup': max(1, warmup_ms), 'rep': max(1, rep_ms)} + + +def _warmup_autotune(fn, n: int | None = None): """Run *fn* multiple times so triton autotuning is fully cached.""" + if n is None: + n = _warmup_iters() for _ in range(n): fn() torch.cuda.synchronize() @@ -299,7 +318,9 @@ def fn(inputs=inputs, do=do): t.backward(do) try: - ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8]) + ms = triton.testing.do_bench( + fn, quantiles=[0.5, 0.2, 0.8], **_do_bench_kw() + ) except Exception as e: logger.warning(f"Bench failed for {op_name} {mode} @ {shape_name}: {e}") continue @@ -534,6 +555,18 @@ def main(): '--list', action='store_true', help='List all registered ops and exit', ) + parser.add_argument( + '--from-diff', action='store_true', + help='Select ops from git diff (use with --diff-base / --diff-head, not with --op)', + ) + parser.add_argument( + '--diff-base', default='main', + help='Base ref for --from-diff (default: main)', + ) + parser.add_argument( + '--diff-head', default='HEAD', + help='Head ref for --from-diff (default: HEAD)', + ) args = parser.parse_args() if args.list: @@ -544,10 +577,24 @@ def main(): print(f" {name:30s} [{cfg.category}] {cfg.import_path}") return - if args.op is None: - parser.error("--op is required (use --list to see available ops)") - - op_names = list_ops() if args.op == ['all'] else args.op + if args.from_diff: + if args.op is not None: + parser.error('--from-diff cannot be used with --op') + project_root = _find_project_root() + scripts_dir = os.path.join(project_root, 'scripts') + if scripts_dir not in sys.path: + sys.path.insert(0, scripts_dir) + import run_benchmark_compare as _diff + changed = _diff.get_changed_files(args.diff_base, args.diff_head) + op_names = _diff.find_affected_op_names(changed) + if not op_names: + print('No affected ops for this diff.', file=sys.stderr) + return + elif args.op is None: + parser.error("--op is required unless --from-diff (use --list to see available ops)") + + if not args.from_diff: + op_names = list_ops() if args.op == ['all'] else args.op shape_configs = json.loads(args.custom_shapes) if args.custom_shapes else SHAPE_CONFIGS machine_info = _get_machine_info() diff --git a/fla/ops/common/chunk_delta_h.py b/fla/ops/common/chunk_delta_h.py index a574a5653b..3e3edb1b93 100644 --- a/fla/ops/common/chunk_delta_h.py +++ b/fla/ops/common/chunk_delta_h.py @@ -46,6 +46,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( chunk_offsets, T, H: tl.constexpr, + Hq: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, @@ -91,7 +92,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( # calculate offset h += (boh * H + i_h).to(tl.int64) * K*V v += (bos * H + i_h).to(tl.int64) * V - k += (bos * H + i_h).to(tl.int64) * K + k += (bos * Hq + i_h // (H // Hq)).to(tl.int64) * K w += (bos * H + i_h).to(tl.int64) * K if SAVE_NEW_VALUE: v_new += (bos * H + i_h).to(tl.int64) * V @@ -263,28 +264,28 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( b_v = b_v.to(k.dtype.element_ty) - p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_t * BT), (64, BT), (0, 1)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hq*K), (0, i_t * BT), (64, BT), (0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) if TRANSPOSE_STATE: b_h1 += tl.trans(tl.dot(b_k, b_v)) else: b_h1 += tl.dot(b_k, b_v) if K > 64: - p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (64, i_t * BT), (64, BT), (0, 1)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hq*K), (64, i_t * BT), (64, BT), (0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) if TRANSPOSE_STATE: b_h2 += tl.trans(tl.dot(b_k, b_v)) else: b_h2 += tl.dot(b_k, b_v) if K > 128: - p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (128, i_t * BT), (64, BT), (0, 1)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hq*K), (128, i_t * BT), (64, BT), (0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) if TRANSPOSE_STATE: b_h3 += tl.trans(tl.dot(b_k, b_v)) else: b_h3 += tl.dot(b_k, b_v) if K > 192: - p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (192, i_t * BT), (64, BT), (0, 1)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hq*K), (192, i_t * BT), (64, BT), (0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) if TRANSPOSE_STATE: b_h4 += tl.trans(tl.dot(b_k, b_v)) @@ -353,6 +354,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( scale, T, H: tl.constexpr, + Hq: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, @@ -395,8 +397,8 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( b_dh4 = tl.zeros([64, BV], dtype=tl.float32) # calculate offset - q += (bos * H + i_h).to(tl.int64) * K - k += (bos * H + i_h).to(tl.int64) * K + q += (bos * Hq + i_h // (H // Hq)).to(tl.int64) * K + k += (bos * Hq + i_h // (H // Hq)).to(tl.int64) * K w += (bos * H + i_h).to(tl.int64) * K do += (bos * H + i_h).to(tl.int64) * V dv += (bos * H + i_h).to(tl.int64) * V @@ -480,7 +482,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( b_do = tl.load(p_do, boundary_check=(0, 1)) # Update dv - p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_t * BT, 0), (BT, 64), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) if USE_GK: o_k1 = tl.arange(0, 64) @@ -491,7 +493,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype)) if K > 64: - p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_t * BT, 64), (BT, 64), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) if USE_GK: o_k2 = 64 + o_k1 @@ -502,7 +504,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype)) if K > 128: - p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_t * BT, 128), (BT, 64), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) if USE_GK: o_k3 = 128 + o_k1 @@ -513,7 +515,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype)) if K > 192: - p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_t * BT, 192), (BT, 64), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) if USE_GK: o_k4 = 192 + o_k1 @@ -534,7 +536,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) # Update dh p_w = tl.make_block_ptr(w, (K, T), (1, H*K), (0, i_t * BT), (64, BT), (0, 1)) - p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (0, i_t * BT), (64, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (K, T), (1, Hq*K), (0, i_t * BT), (64, BT), (0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) if USE_G: @@ -556,7 +558,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( else: b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) if K > 64: - p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (64, i_t * BT), (64, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (K, T), (1, Hq*K), (64, i_t * BT), (64, BT), (0, 1)) p_w = tl.make_block_ptr(w, (K, T), (1, H*K), (64, i_t * BT), (64, BT), (0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1)) @@ -579,7 +581,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( else: b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) if K > 128: - p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (128, i_t * BT), (64, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (K, T), (1, Hq*K), (128, i_t * BT), (64, BT), (0, 1)) p_w = tl.make_block_ptr(w, (K, T), (1, H*K), (128, i_t * BT), (64, BT), (0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1)) @@ -602,7 +604,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( else: b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) if K > 192: - p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (192, i_t * BT), (64, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (K, T), (1, Hq*K), (192, i_t * BT), (64, BT), (0, 1)) p_w = tl.make_block_ptr(w, (K, T), (1, H*K), (192, i_t * BT), (64, BT), (0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1)) @@ -668,7 +670,9 @@ def chunk_gated_delta_rule_fwd_h( use_exp2: bool = False, transpose_state_layout: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: - B, T, H, K, V = *k.shape, u.shape[-1] + B, T, Hq, K = k.shape + V = u.shape[-1] + H = u.shape[2] BT = chunk_size if chunk_indices is None and cu_seqlens is not None: @@ -703,6 +707,7 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), N*H) chunk_offsets=chunk_offsets, T=T, H=H, + Hq=Hq, K=K, V=V, BT=BT, @@ -729,7 +734,9 @@ def chunk_gated_delta_rule_bwd_dhu( use_exp2: bool = False, transpose_state_layout: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - B, T, H, K, V = *q.shape, do.shape[-1] + B, T, Hq, K = q.shape + V = do.shape[-1] + H = do.shape[2] # N: the actual number of sequences in the batch with either equal or variable lengths BT = 64 assert K <= 256, "current kernel does not support head dimension being larger than 256." @@ -766,6 +773,7 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), N*H) scale=scale, T=T, H=H, + Hq=Hq, K=K, V=V, BT=BT, diff --git a/fla/ops/common/chunk_o.py b/fla/ops/common/chunk_o.py index 3e1fdae2e0..161696daea 100644 --- a/fla/ops/common/chunk_o.py +++ b/fla/ops/common/chunk_o.py @@ -40,6 +40,7 @@ def chunk_fwd_kernel_o( scale, T, H: tl.constexpr, + Hq: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, @@ -66,8 +67,8 @@ def chunk_fwd_kernel_o( bos, eos = i_b * T, i_b * T + T # offset calculation - q += (bos * H + i_h) * K - k += (bos * H + i_h) * K + q += (bos * Hq + i_h // (H // Hq)) * K + k += (bos * Hq + i_h // (H // Hq)) * K v += (bos * H + i_h) * V o += (bos * H + i_h) * V h += (i_tg * H + i_h).to(tl.int64) * K*V @@ -76,8 +77,8 @@ def chunk_fwd_kernel_o( b_A = tl.zeros([BT, BT], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): - p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hq*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) if TRANSPOSE_STATE: p_h = tl.make_block_ptr(h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)) else: @@ -168,6 +169,7 @@ def chunk_bwd_kernel_dqkwg( B: tl.constexpr, T, H: tl.constexpr, + Hq: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, @@ -200,8 +202,8 @@ def chunk_bwd_kernel_dqkwg( do += (bos * H + i_h) * V h += (i_tg * H + i_h).to(tl.int64) * K*V dh += (i_tg * H + i_h).to(tl.int64) * K*V - q += (bos * H + i_h) * K - k += (bos * H + i_h) * K + q += (bos * Hq + i_h // (H // Hq)) * K + k += (bos * Hq + i_h // (H // Hq)) * K dq += (bos * H + i_h) * K dk += (bos * H + i_h) * K @@ -255,8 +257,8 @@ def chunk_bwd_kernel_dqkwg( tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) tl.debug_barrier() - p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) @@ -362,6 +364,7 @@ def chunk_bwd_kernel_dv( scale, T, H: tl.constexpr, + Hq: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, @@ -388,16 +391,16 @@ def chunk_bwd_kernel_dv( b_dv = tl.zeros([BT, BV], dtype=tl.float32) # offset calculation - q += (bos * H + i_h) * K - k += (bos * H + i_h) * K + q += (bos * Hq + i_h // (H // Hq)) * K + k += (bos * Hq + i_h // (H // Hq)) * K do += (bos * H + i_h) * V dv += (bos * H + i_h) * V dh += (i_tg * H + i_h).to(tl.int64) * K*V b_A = tl.zeros([BT, BT], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): - p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (K, T), (1, Hq*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_A += tl.dot(b_k, b_q) @@ -463,6 +466,7 @@ def chunk_bwd_kernel_dv_local( scale, T, H: tl.constexpr, + Hq: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, @@ -484,8 +488,8 @@ def chunk_bwd_kernel_dv_local( bos, eos = i_b * T, i_b * T + T # offset calculation - q += (bos * H + i_h) * K - k += (bos * H + i_h) * K + q += (bos * Hq + i_h // (H // Hq)) * K + k += (bos * Hq + i_h // (H // Hq)) * K do += (bos * H + i_h) * V dv += (bos * H + i_h) * V @@ -503,8 +507,8 @@ def chunk_bwd_kernel_dv_local( b_A = tl.zeros([BT, BT], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): - p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (K, T), (1, Hq*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) @@ -542,7 +546,9 @@ def chunk_fwd_o( use_exp2: bool = False, transpose_state_layout: bool = False, ) -> torch.Tensor: - B, T, H, K, V = *q.shape, v.shape[-1] + B, T, Hq, K = q.shape + V = v.shape[-1] + H = v.shape[2] BT = chunk_size if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) @@ -565,6 +571,7 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) scale=scale, T=T, H=H, + Hq=Hq, K=K, V=V, BT=BT, @@ -587,7 +594,9 @@ def chunk_bwd_dv( chunk_indices: torch.LongTensor | None = None, use_exp2: bool = False, ) -> torch.Tensor: - B, T, H, K, V = *k.shape, do.shape[-1] + B, T, Hq, K = k.shape + V = do.shape[-1] + H = do.shape[2] BT = chunk_size if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) @@ -620,6 +629,7 @@ def chunk_bwd_dv( scale=scale, T=T, H=H, + Hq=Hq, K=K, V=V, BT=BT, @@ -643,7 +653,9 @@ def chunk_bwd_dv_local( chunk_indices: torch.LongTensor | None = None, use_exp2: bool = False, ) -> torch.Tensor: - B, T, H, K, V = *k.shape, do.shape[-1] + B, T, Hq, K = k.shape + V = do.shape[-1] + H = do.shape[2] BT = chunk_size if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) @@ -673,6 +685,7 @@ def chunk_bwd_dv_local( scale=scale, T=T, H=H, + Hq=Hq, K=K, V=V, BT=BT, @@ -702,7 +715,9 @@ def chunk_bwd_dqkwg( transpose_state_layout: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - B, T, H, K, V = *k.shape, v.shape[-1] + B, T, Hq, K = k.shape + V = v.shape[-1] + H = v.shape[2] BT = chunk_size if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) @@ -717,8 +732,8 @@ def chunk_bwd_dqkwg( BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING) BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING) NK = triton.cdiv(K, BK) - dq = torch.empty_like(q) - dk = torch.empty_like(k) + dq = q.new_empty(B, T, H, K) + dk = k.new_empty(B, T, H, K) dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None dw = torch.empty_like(w) if w is not None else None @@ -743,6 +758,7 @@ def chunk_bwd_dqkwg( B=B, T=T, H=H, + Hq=Hq, K=K, V=V, BT=BT, @@ -752,6 +768,9 @@ def chunk_bwd_dqkwg( TRANSPOSE_STATE=transpose_state_layout, ) + if Hq != H: + dq = dq.view(B, T, Hq, H // Hq, K).sum(3) + dk = dk.view(B, T, Hq, H // Hq, K).sum(3) if dg is not None: dg = dg.sum(0) return dq, dk, dw, dg diff --git a/fla/ops/common/chunk_scaled_dot_kkt.py b/fla/ops/common/chunk_scaled_dot_kkt.py index 2c33a12db4..51713bb64e 100644 --- a/fla/ops/common/chunk_scaled_dot_kkt.py +++ b/fla/ops/common/chunk_scaled_dot_kkt.py @@ -34,6 +34,7 @@ def chunk_scaled_dot_kkt_fwd_kernel( chunk_indices, T, H: tl.constexpr, + Hq: tl.constexpr, K: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, @@ -56,7 +57,7 @@ def chunk_scaled_dot_kkt_fwd_kernel( b_A = tl.zeros([BT, BT], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): - p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*Hq + i_h // (H // Hq)) * K, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_A += tl.dot(b_k, tl.trans(b_k)) @@ -87,9 +88,9 @@ def chunk_scaled_dot_kkt_fwd( Args: k (torch.Tensor): - The key tensor of shape `[B, T, H, K]`. + The key tensor of shape `[B, T, Hq, K]` where `Hq` is the number of query/key heads. beta (torch.Tensor): - The beta tensor of shape `[B, T, H]`. + The beta tensor of shape `[B, T, H]` where `H` is the number of value/output heads. g (torch.Tensor): The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`. gk (torch.Tensor): @@ -104,8 +105,10 @@ def chunk_scaled_dot_kkt_fwd( Returns: beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. + For GQA, Hq < H and H % Hq == 0. For standard attention, Hq == H. """ - B, T, H, K = k.shape + B, T, Hq, K = k.shape + H = beta.shape[2] BT = chunk_size if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) @@ -120,6 +123,7 @@ def chunk_scaled_dot_kkt_fwd( chunk_indices=chunk_indices, T=T, H=H, + Hq=Hq, K=K, BT=BT, ) diff --git a/fla/ops/common/intracard_cp.py b/fla/ops/common/intracard_cp.py index e73a308f96..35d3964939 100644 --- a/fla/ops/common/intracard_cp.py +++ b/fla/ops/common/intracard_cp.py @@ -86,7 +86,9 @@ def _raw_chunk_gated_delta_rule_fwd_h( use_exp2: bool = False, transpose_state_layout: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: - B, T, H, K, V = *k.shape, u.shape[-1] + B, T, Hq, K = k.shape + V = u.shape[-1] + H = u.shape[2] BT = chunk_size if chunk_indices is None and cu_seqlens is not None: @@ -111,7 +113,7 @@ def grid(meta): k=k, v=u, w=w, v_new=v_new, g=g, gk=gk, h=h, h0=initial_state, ht=final_state, cu_seqlens=cu_seqlens, chunk_offsets=chunk_offsets, - T=T, H=H, K=K, V=V, BT=BT, USE_EXP2=use_exp2, + T=T, H=H, Hq=Hq, K=K, V=V, BT=BT, USE_EXP2=use_exp2, TRANSPOSE_STATE=transpose_state_layout, ) return h, v_new, final_state @@ -235,13 +237,15 @@ def intracard_pre_scan( kg: torch.Tensor, w: torch.Tensor, u: torch.Tensor, - gk: torch.Tensor, + g: torch.Tensor | None, + gk: torch.Tensor | None, cu_seqlens_subseq_split: torch.Tensor, S_split: int, chunk_size: int = 64, use_exp2: bool = True, ): - H, K, V = kg.shape[2], kg.shape[3], u.shape[3] + Hq, K, V = kg.shape[2], kg.shape[3], u.shape[3] + H = u.shape[2] BK = triton.next_power_of_2(K) BLOCK_SIZE = 32 if K <= 64 else 64 @@ -252,12 +256,13 @@ def intracard_pre_scan( k=kg, v=u, w=w, - g=None, + g=g, gk=gk, hm=hm, cu_seqlens=cu_seqlens_subseq_split, T=0, H=H, + Hq=Hq, K=K, V=V, BT=chunk_size, @@ -430,7 +435,9 @@ def intracard_fwd_h( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: assert cu_seqlens is not None, "intracard_fwd_h requires cu_seqlens" - _, _, H, K, V = *k.shape, u.shape[-1] + _, _, _Hq, K = k.shape + V = u.shape[-1] + H = u.shape[2] device = k.device if cu_seqlens_cpu is None: @@ -542,7 +549,7 @@ def intracard_fwd_h( _intracard_cache.popitem(last=False) hm = intracard_pre_scan( - kg=k, w=w, u=u, gk=gk, + kg=k, w=w, u=u, g=g, gk=gk, cu_seqlens_subseq_split=cu_seqlens_split_flat, S_split=S_split_total, chunk_size=chunk_size, diff --git a/fla/ops/cp/chunk_delta_h.py b/fla/ops/cp/chunk_delta_h.py index 6cd2a12718..80acfa9316 100644 --- a/fla/ops/cp/chunk_delta_h.py +++ b/fla/ops/cp/chunk_delta_h.py @@ -15,265 +15,6 @@ from fla.ops.cp.context import FLACPContext -@triton.heuristics({ - 'USE_G': lambda args: args['g'] is not None, - 'USE_GK': lambda args: args['gk'] is not None, - 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, -}) -@triton.autotune( - configs=[ - triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [2, 4] - for num_stages in [2, 3, 4] - for BV in [32, 64] - ], - key=['H', 'K', 'V', 'BT', 'USE_EXP2', "STAGE"], - use_cuda_graph=USE_CUDA_GRAPH, - **autotune_cache_kwargs, -) -@triton.jit(do_not_specialize=['T']) -def pre_process_fwd_kernel_stage1( - k, - v, - w, - g, - gk, - hm, - cu_seqlens, - T, - H: tl.constexpr, - K: tl.constexpr, - V: tl.constexpr, - BT: tl.constexpr, - BV: tl.constexpr, - USE_G: tl.constexpr, - USE_GK: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_VARLEN: tl.constexpr, -): - i_v, i_h = tl.program_id(0), tl.program_id(1) - i_n = 0 - if IS_VARLEN: - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) - T = (eos - bos).to(tl.int32) - NT = tl.cdiv(T, BT) - else: - bos, eos = (i_n * T).to(tl.int64), (i_n * T + T).to(tl.int64) - NT = tl.cdiv(T, BT) - - # calculate offset - hm += i_h * K * (K + V) - v += ((bos * H + i_h) * V).to(tl.int64) - k += ((bos * H + i_h) * K).to(tl.int64) - w += ((bos * H + i_h) * K).to(tl.int64) - stride_v = H*V - stride_k = H*K - - b_h1 = tl.zeros([64, BV], dtype=tl.float32) - if K > 64: - b_h2 = tl.zeros([64, BV], dtype=tl.float32) - if K > 128: - b_h3 = tl.zeros([64, BV], dtype=tl.float32) - if K > 192: - b_h4 = tl.zeros([64, BV], dtype=tl.float32) - # main recurrence - for i_t in range(NT): - p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0)) - b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v = tl.dot(b_w, b_h1.to(b_w.dtype)) - if K > 64: - p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0)) - b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v += tl.dot(b_w, b_h2.to(b_w.dtype)) - if K > 128: - p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0)) - b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v += tl.dot(b_w, b_h3.to(b_w.dtype)) - if K > 192: - p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0)) - b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v += tl.dot(b_w, b_h4.to(b_w.dtype)) - p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v - - last_idx = min((i_t + 1) * BT, T) - 1 - if USE_G: - m_t = (i_t * BT + tl.arange(0, BT)) < T - b_g_last = tl.load(g + bos * H + last_idx * H + i_h).to(tl.float32) - p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) - b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32) - if USE_EXP2: - b_v = b_v * tl.where(m_t, exp2(b_g_last - b_g), 0)[:, None] - b_g_last = exp2(b_g_last) - else: - b_v = b_v * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] - b_g_last = exp(b_g_last) - b_h1 *= b_g_last - if K > 64: - b_h2 *= b_g_last - if K > 128: - b_h3 *= b_g_last - if K > 192: - b_h4 *= b_g_last - - if USE_GK: - o_k1 = tl.arange(0, 64) - b_gk_last1 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k1, mask=(o_k1 < K), other=0.).to(tl.float32) - if USE_EXP2: - b_h1 *= exp2(b_gk_last1)[:, None] - else: - b_h1 *= exp(b_gk_last1)[:, None] - if K > 64: - o_k2 = 64 + o_k1 - b_gk_last2 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k2, mask=(o_k2 < K), other=0.).to(tl.float32) - if USE_EXP2: - b_h2 *= exp2(b_gk_last2)[:, None] - else: - b_h2 *= exp(b_gk_last2)[:, None] - if K > 128: - o_k3 = 128 + o_k1 - b_gk_last3 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k3, mask=(o_k3 < K), other=0.).to(tl.float32) - if USE_EXP2: - b_h3 *= exp2(b_gk_last3)[:, None] - else: - b_h3 *= exp(b_gk_last3)[:, None] - if K > 192: - o_k4 = 192 + o_k1 - b_gk_last4 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k4, mask=(o_k4 < K), other=0.).to(tl.float32) - if USE_EXP2: - b_h4 *= exp2(b_gk_last4)[:, None] - else: - b_h4 *= exp(b_gk_last4)[:, None] - - b_v = b_v.to(k.dtype.element_ty) - - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h1 += tl.dot(b_k, b_v) - if K > 64: - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h2 += tl.dot(b_k, b_v) - if K > 128: - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h3 += tl.dot(b_k, b_v) - if K > 192: - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h4 += tl.dot(b_k, b_v) - - p_h1 = tl.make_block_ptr(hm, (K, V), (K+V, 1), (0, i_v * BV), (64, BV), (1, 0)) - tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) - if K > 64: - p_h2 = tl.make_block_ptr(hm, (K, V), (K+V, 1), (64, i_v * BV), (64, BV), (1, 0)) - tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) - if K > 128: - p_h3 = tl.make_block_ptr(hm, (K, V), (K+V, 1), (128, i_v * BV), (64, BV), (1, 0)) - tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) - if K > 192: - p_h4 = tl.make_block_ptr(hm, (K, V), (K+V, 1), (192, i_v * BV), (64, BV), (1, 0)) - tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) - - -@triton.heuristics({ - 'USE_G': lambda args: args['g'] is not None, - 'USE_GK': lambda args: args['gk'] is not None, - 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, -}) -@triton.autotune( - configs=[ - triton.Config({'BK2': BK2}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [2, 4] - for num_stages in [2, 3, 4] - for BK2 in [32] - ], - key=['H', 'BT', 'USE_EXP2', 'FORWARD'], - use_cuda_graph=USE_CUDA_GRAPH, - **autotune_cache_kwargs, -) -@triton.jit(do_not_specialize=['T']) -def pre_process_fwd_bwd_kernel_stage2( - k, - w, - g, - gk, - hm, - cu_seqlens, - T, - H: tl.constexpr, - K: tl.constexpr, - V: tl.constexpr, - BT: tl.constexpr, - USE_G: tl.constexpr, - USE_GK: tl.constexpr, - USE_EXP2: tl.constexpr, - IS_VARLEN: tl.constexpr, - BK1: tl.constexpr, - BK2: tl.constexpr, - FORWARD: tl.constexpr = True, -): - i_k_col, i_h = tl.program_id(0), tl.program_id(1) - i_n = 0 - if IS_VARLEN: - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) - T = eos - bos - NT = tl.cdiv(T, BT) - else: - bos, eos = i_n * T, i_n * T + T - NT = tl.cdiv(T, BT) - - # calculate offset - hm += i_h * K * (K + V) - k += ((bos * H + i_h) * K).to(tl.int64) - w += ((bos * H + i_h) * K).to(tl.int64) - stride_k = H*K - - row = tl.arange(0, BK1) - col = tl.arange(0, BK2) + i_k_col * BK2 - - b_m = tl.where(row[:, None] == col[None, :], 1.0, 0.0) - for _i_t in range(NT): - if FORWARD: - i_t = _i_t - else: - i_t = NT - 1 - _i_t - p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, BK1), (1, 0)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, BK1), (1, 0)) - b_w = tl.load(p_w, boundary_check=(0, 1)) - last_idx = min((i_t + 1) * BT, T) - 1 - if USE_G: - m_t = (i_t * BT + tl.arange(0, BT)) < T - b_g_last = tl.load(g + bos * H + last_idx * H + i_h).to(tl.float32) - p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) - b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32) - if USE_EXP2: - b_k = b_k * tl.where(m_t, exp2(b_g_last - b_g), 0)[:, None] - b_g_last = exp2(b_g_last) - else: - b_k = b_k * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] - b_g_last = exp(b_g_last) - b_diag = tl.where(row[:, None] == row[None, :], b_g_last, 0.0) - elif USE_GK: - b_gk_last = tl.load(gk + (bos + last_idx) * H*K + i_h * K + row, mask=(row < K), other=0.).to(tl.float32) - if USE_EXP2: - b_gk_last = exp2(b_gk_last) - else: - b_gk_last = exp(b_gk_last) - b_diag = tl.where(row[:, None] == row[None, :], b_gk_last[:, None], 0.0) - else: - b_diag = tl.where(row[:, None] == row[None, :], 1., 0.0) - if FORWARD: - b_kw = tl.dot(tl.trans(b_k.to(b_w.dtype)), b_w) - else: - b_kw = tl.dot(tl.trans(b_w), b_k.to(b_w.dtype)) - b_m_i = b_diag - b_kw - b_m = tl.dot(b_m_i.to(b_w.dtype), b_m.to(b_w.dtype)) - p_m = tl.make_block_ptr(hm + V, (K, K), (K+V, 1), (0, i_k_col * BK2), (BK1, BK2), (1, 0)) - tl.store(p_m, b_m.to(p_m.dtype.element_ty), boundary_check=(0, 1)) - - @triton.heuristics({ 'USE_G': lambda args: args['g'] is not None, 'USE_GK': lambda args: args['gk'] is not None, @@ -300,6 +41,7 @@ def pre_process_fwd_kernel_merged( cu_seqlens, T, H: tl.constexpr, + Hq: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, @@ -331,9 +73,10 @@ def pre_process_fwd_kernel_merged( # i_col is in range [0, cdiv(V + K, BLOCK_SIZE)) # Columns [0, V) are for h, columns [V, V+K) are for m is_h_part = i_col * BLOCK_SIZE < V - k += ((bos * H + i_h) * K).to(tl.int64) + k += ((bos * Hq + i_h // (H // Hq)) * K).to(tl.int64) w += ((bos * H + i_h) * K).to(tl.int64) - stride_k = H * K + stride_k = Hq * K + stride_w = H * K if is_h_part: # ====== Stage 1: Compute h (K x V) ====== @@ -353,19 +96,19 @@ def pre_process_fwd_kernel_merged( # Main recurrence for h for i_t in range(NT): # Compute decayed v - p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)) b_w = tl.load(p_w, boundary_check=(0, 1)) b_v_decay = tl.dot(b_w, b_h1.to(b_w.dtype)) if K > 64: - p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)) b_w = tl.load(p_w, boundary_check=(0, 1)) b_v_decay += tl.dot(b_w, b_h2.to(b_w.dtype)) if K > 128: - p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)) b_w = tl.load(p_w, boundary_check=(0, 1)) b_v_decay += tl.dot(b_w, b_h3.to(b_w.dtype)) if K > 192: - p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)) b_w = tl.load(p_w, boundary_check=(0, 1)) b_v_decay += tl.dot(b_w, b_h4.to(b_w.dtype)) @@ -480,7 +223,7 @@ def pre_process_fwd_kernel_merged( # Load k and w with full BK1 rows p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, BK1), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) - p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, BK1), (1, 0)) + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, BK1), (1, 0)) b_w = tl.load(p_w, boundary_check=(0, 1)) last_idx = min((i_t + 1) * BT, T) - 1 @@ -672,182 +415,6 @@ def merge_fwd_bwd_kernel( tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) -@triton.heuristics({ - 'USE_G': lambda args: args['g'] is not None, - 'USE_GK': lambda args: args['gk'] is not None, - 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, -}) -@triton.autotune( - configs=[ - triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [2, 4] - for num_stages in ([4, 3, 2] if check_shared_mem('ampere') else [1]) - for BV in [64, 32] - ], - key=['H', 'K', 'V', 'BT', 'BV', 'USE_G'], - use_cuda_graph=USE_CUDA_GRAPH, - **autotune_cache_kwargs, -) -@triton.jit(do_not_specialize=['T']) -def pre_process_bwd_kernel_stage1( - q, - k, - w, - g, - gk, - do, - dhm, - dv, - cu_seqlens, - scale, - T, - H: tl.constexpr, - K: tl.constexpr, - V: tl.constexpr, - BT: tl.constexpr, - BV: tl.constexpr, - USE_G: tl.constexpr, - USE_GK: tl.constexpr, - IS_VARLEN: tl.constexpr, -): - i_v, i_h = tl.program_id(0), tl.program_id(1) - i_n = 0 - if IS_VARLEN: - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) - T = (eos - bos).to(tl.int32) - NT = tl.cdiv(T, BT) - else: - bos, eos = (i_n * T).to(tl.int64), (i_n * T + T).to(tl.int64) - NT = tl.cdiv(T, BT) - - # [BK, BV] - b_dh1 = tl.zeros([64, BV], dtype=tl.float32) - if K > 64: - b_dh2 = tl.zeros([64, BV], dtype=tl.float32) - if K > 128: - b_dh3 = tl.zeros([64, BV], dtype=tl.float32) - if K > 192: - b_dh4 = tl.zeros([64, BV], dtype=tl.float32) - - # calculate offset - q += ((bos * H + i_h) * K).to(tl.int64) - k += ((bos * H + i_h) * K).to(tl.int64) - w += ((bos * H + i_h) * K).to(tl.int64) - do += ((bos * H + i_h) * V).to(tl.int64) - dv += ((bos * H + i_h) * V).to(tl.int64) - dhm += i_h * K * (V + K) - - stride_v = H*V - stride_k = H*K - - for i_t in range(NT - 1, -1, -1): - last_idx = min((i_t + 1) * BT, T) - 1 - if USE_G: - bg_last = tl.load(g + (bos + last_idx) * H + i_h).to(tl.float32) - bg_last_exp = exp(bg_last) - p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) - b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32) - b_g_exp = exp(b_g) - - p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - b_do = tl.load(p_do, boundary_check=(0, 1)) - - # Update dv - p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - if USE_GK: - o_k1 = tl.arange(0, 64) - b_gk_last1 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k1, mask=(o_k1 < K), other=0.).to(tl.float32) - b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype)) - - if K > 64: - p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - if USE_GK: - o_k2 = 64 + o_k1 - b_gk_last2 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k2, mask=(o_k2 < K), other=0.).to(tl.float32) - b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype)) - - if K > 128: - p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - if USE_GK: - o_k3 = 128 + o_k1 - b_gk_last3 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k3, mask=(o_k3 < K), other=0.).to(tl.float32) - b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype)) - - if K > 192: - p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - if USE_GK: - o_k4 = 192 + o_k1 - b_gk_last4 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k4, mask=(o_k4 < K), other=0.).to(tl.float32) - b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype)) - - if USE_G: - m_t = (i_t * BT + tl.arange(0, BT)) < T - b_dv *= tl.where(m_t, exp(bg_last - b_g), 0)[:, None] - b_dv += tl.load(p_dv, boundary_check=(0, 1)) - - # Update dh - p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) - p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) - b_w = tl.load(p_w, boundary_check=(0, 1)) - b_q = tl.load(p_q, boundary_check=(0, 1)) - if USE_G: - b_dh1 *= bg_last_exp - b_q = b_q * b_g_exp[None, :] - if USE_GK: - b_dh1 *= exp(b_gk_last1[:, None]) - b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) - if K > 64: - p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) - p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_w = tl.load(p_w, boundary_check=(0, 1)) - if USE_G: - b_dh2 *= bg_last_exp - b_q = b_q * b_g_exp[None, :] - if USE_GK: - b_dh2 *= exp(b_gk_last2[:, None]) - b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) - if K > 128: - p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) - p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_w = tl.load(p_w, boundary_check=(0, 1)) - if USE_G: - b_dh3 *= bg_last_exp - b_q = b_q * b_g_exp[None, :] - if USE_GK: - b_dh3 *= exp(b_gk_last3[:, None]) - b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) - if K > 192: - p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) - p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_w = tl.load(p_w, boundary_check=(0, 1)) - if USE_G: - b_dh4 *= bg_last_exp - b_q = b_q * b_g_exp[None, :] - if USE_GK: - b_dh4 *= exp(b_gk_last4[:, None]) - b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) - - p_dh1 = tl.make_block_ptr(dhm, (K, V), (V + K, 1), (0, i_v * BV), (64, BV), (1, 0)) - tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) - if K > 64: - p_dh2 = tl.make_block_ptr(dhm, (K, V), (V + K, 1), (64, i_v * BV), (64, BV), (1, 0)) - tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) - if K > 128: - p_dh3 = tl.make_block_ptr(dhm, (K, V), (V + K, 1), (128, i_v * BV), (64, BV), (1, 0)) - tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) - if K > 192: - p_dh4 = tl.make_block_ptr(dhm, (K, V), (V + K, 1), (192, i_v * BV), (64, BV), (1, 0)) - tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1)) - - @triton.heuristics({ 'USE_G': lambda args: args['g'] is not None, 'USE_GK': lambda args: args['gk'] is not None, @@ -877,6 +444,7 @@ def pre_process_bwd_kernel_merged( scale, T, H: tl.constexpr, + Hq: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, @@ -908,11 +476,12 @@ def pre_process_bwd_kernel_merged( is_dh_part = i_col * BLOCK_SIZE < V # Calculate offsets - q += ((bos * H + i_h) * K).to(tl.int64) - k += ((bos * H + i_h) * K).to(tl.int64) + q += ((bos * Hq + i_h // (H // Hq)) * K).to(tl.int64) + k += ((bos * Hq + i_h // (H // Hq)) * K).to(tl.int64) w += ((bos * H + i_h) * K).to(tl.int64) dhm += i_h * K * (V + K) - stride_k = H * K + stride_qk = Hq * K + stride_w = H * K if is_dh_part: # ====== Stage 1: Compute dh (K x V) ====== @@ -948,7 +517,7 @@ def pre_process_bwd_kernel_merged( b_do = tl.load(p_do, boundary_check=(0, 1)) # Update dv - p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, 0), (BT, 64), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) if USE_GK: o_k1 = tl.arange(0, 64) @@ -961,7 +530,7 @@ def pre_process_bwd_kernel_merged( b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype)) if K > 64: - p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, 64), (BT, 64), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) if USE_GK: o_k2 = 64 + o_k1 @@ -974,7 +543,7 @@ def pre_process_bwd_kernel_merged( b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype)) if K > 128: - p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, 128), (BT, 64), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) if USE_GK: o_k3 = 128 + o_k1 @@ -987,7 +556,7 @@ def pre_process_bwd_kernel_merged( b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype)) if K > 192: - p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, 192), (BT, 64), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) if USE_GK: o_k4 = 192 + o_k1 @@ -1006,8 +575,8 @@ def pre_process_bwd_kernel_merged( b_dv += tl.load(p_dv, boundary_check=(0, 1)) # Update dh - p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) - p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_w), (0, i_t * BT), (64, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (K, T), (1, stride_qk), (0, i_t * BT), (64, BT), (0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) if USE_G: @@ -1021,8 +590,8 @@ def pre_process_bwd_kernel_merged( b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) if K > 64: - p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) - p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (K, T), (1, stride_qk), (64, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_w), (64, i_t * BT), (64, BT), (0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1)) if USE_G: @@ -1036,8 +605,8 @@ def pre_process_bwd_kernel_merged( b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) if K > 128: - p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) - p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (K, T), (1, stride_qk), (128, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_w), (128, i_t * BT), (64, BT), (0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1)) if USE_G: @@ -1051,8 +620,8 @@ def pre_process_bwd_kernel_merged( b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) if K > 192: - p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) - p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (K, T), (1, stride_qk), (192, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_w), (192, i_t * BT), (64, BT), (0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1)) if USE_G: @@ -1096,9 +665,9 @@ def pre_process_bwd_kernel_merged( i_t = NT - 1 - _i_t # Load k and w with full BK1 rows - p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, BK1), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, 0), (BT, BK1), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) - p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, BK1), (1, 0)) + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, BK1), (1, 0)) b_w = tl.load(p_w, boundary_check=(0, 1)) last_idx = min((i_t + 1) * BT, T) - 1 @@ -1155,7 +724,9 @@ def chunk_gated_delta_rule_fwd_h_pre_process( assert initial_state is None, "When enable CP, the provided initial_state must be None." rank = dist.get_rank(group=context.group) - B, T, H, K, V = *k.shape, u.shape[-1] + B, T, Hq, K = k.shape + V = u.shape[-1] + H = u.shape[2] BT = chunk_size BK = triton.next_power_of_2(K) @@ -1184,6 +755,7 @@ def chunk_gated_delta_rule_fwd_h_pre_process( cu_seqlens=cu_seqlens[-2:], T=T, H=H, + Hq=Hq, K=K, V=V, BT=BT, @@ -1237,7 +809,9 @@ def chunk_gated_delta_rule_bwd_dhu_pre_process( assert dht is None, "When enable CP, the provided dht must be None." rank = dist.get_rank(context.group) - B, T, H, K, V = *q.shape, do.shape[-1] + B, T, Hq, K = q.shape + H = do.shape[2] + V = do.shape[-1] # N: the actual number of sequences in the batch with either equal or variable lengths BT = 64 assert K <= 256, "current kernel does not support head dimension being larger than 256." @@ -1270,6 +844,7 @@ def chunk_gated_delta_rule_bwd_dhu_pre_process( scale=scale, T=T, H=H, + Hq=Hq, K=K, V=V, BT=BT, diff --git a/fla/ops/gated_delta_rule/chunk.py b/fla/ops/gated_delta_rule/chunk.py index 9980f5917f..429232e9bd 100644 --- a/fla/ops/gated_delta_rule/chunk.py +++ b/fla/ops/gated_delta_rule/chunk.py @@ -322,11 +322,12 @@ def chunk_gated_delta_rule( r""" Args: q (torch.Tensor): - queries of shape `[B, T, H, K]`. + queries of shape `[B, T, Hq, K]` where `Hq` is the number of query/key heads. k (torch.Tensor): - keys of shape `[B, T, H, K]`. + keys of shape `[B, T, Hq, K]` where `Hq` is the number of query/key heads. v (torch.Tensor): - values of shape `[B, T, H, V]`. + values of shape `[B, T, H, V]` where `H` is the number of value/output heads. + For standard attention, `Hq == H`. For GQA, `H % Hq == 0`. g (torch.Tensor): (forget) gating tensor (in log space!) of shape `[B, T, H]`. beta (torch.Tensor): @@ -388,6 +389,14 @@ def chunk_gated_delta_rule( cu_seqlens=cu_seqlens ) """ + # Validate GQA head divisibility + Hq, H = q.shape[2], v.shape[2] + if H % Hq != 0: + raise ValueError( + f"For GQA, num_heads (H={H}) must be evenly divisible by " + f"num_kv_heads (Hq={Hq}), but got H % Hq = {H % Hq}" + ) + if 'head_first' in kwargs: warnings.warn( "head_first is deprecated and will be removed in a future version. " diff --git a/fla/ops/gated_delta_rule/chunk_fwd.py b/fla/ops/gated_delta_rule/chunk_fwd.py index 7c6c764f77..e86833ff08 100644 --- a/fla/ops/gated_delta_rule/chunk_fwd.py +++ b/fla/ops/gated_delta_rule/chunk_fwd.py @@ -38,6 +38,7 @@ def chunk_gated_delta_rule_fwd_kkt_solve_kernel( chunk_indices, T, H: tl.constexpr, + Hq: tl.constexpr, K: tl.constexpr, BT: tl.constexpr, BC: tl.constexpr, @@ -77,7 +78,7 @@ def chunk_gated_delta_rule_fwd_kkt_solve_kernel( i_tc2 = i_t * BT + 2 * BC i_tc3 = i_t * BT + 3 * BC - k += (bos * H + i_h) * K + k += (bos * Hq + i_h // (H // Hq)) * K A += (bos * H + i_h) * BT o_i = tl.arange(0, BC) @@ -127,13 +128,13 @@ def chunk_gated_delta_rule_fwd_kkt_solve_kernel( b_A32 = tl.zeros([BC, BC], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): - p_k0 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0)) + p_k0 = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0)) b_k0 = tl.load(p_k0, boundary_check=(0, 1)) # diagonal block 0 b_A00 += tl.dot(b_k0, tl.trans(b_k0)) if i_tc1 < T: - p_k1 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + p_k1 = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) b_k1 = tl.load(p_k1, boundary_check=(0, 1)) # diagonal block 1 b_A11 += tl.dot(b_k1, tl.trans(b_k1)) @@ -141,7 +142,7 @@ def chunk_gated_delta_rule_fwd_kkt_solve_kernel( b_A10 += tl.dot(b_k1, tl.trans(b_k0)) if i_tc2 < T: - p_k2 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + p_k2 = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) b_k2 = tl.load(p_k2, boundary_check=(0, 1)) # diagonal block 2 b_A22 += tl.dot(b_k2, tl.trans(b_k2)) @@ -150,7 +151,7 @@ def chunk_gated_delta_rule_fwd_kkt_solve_kernel( b_A21 += tl.dot(b_k2, tl.trans(b_k1)) if i_tc3 < T: - p_k3 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + p_k3 = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) b_k3 = tl.load(p_k3, boundary_check=(0, 1)) # diagonal block 3 b_A33 += tl.dot(b_k3, tl.trans(b_k3)) @@ -363,7 +364,8 @@ def chunk_gated_delta_rule_fwd_intra( u (torch.Tensor): shape `[B, T, H, V]` A (torch.Tensor): shape `[B, T, H, BT]`, the solved (I+A)^{-1} matrix """ - B, T, H, K = k.shape + B, T, Hq, K = k.shape + H = beta.shape[2] BT = chunk_size BC = 16 @@ -382,6 +384,7 @@ def chunk_gated_delta_rule_fwd_intra( chunk_indices=chunk_indices, T=T, H=H, + Hq=Hq, K=K, BT=BT, BC=BC, diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py index 52b2c92491..ef445f63d1 100644 --- a/fla/ops/gated_delta_rule/wy_fast.py +++ b/fla/ops/gated_delta_rule/wy_fast.py @@ -61,6 +61,7 @@ def recompute_w_u_fwd_kernel( chunk_indices, T, H: tl.constexpr, + Hq: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, @@ -100,7 +101,7 @@ def recompute_w_u_fwd_kernel( b_g = exp(tl.load(p_g, boundary_check=(0,))) for i_k in range(tl.cdiv(K, BK)): - p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*Hq + i_h // (H // Hq)) * K, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_kb = b_k * b_b[:, None] @@ -140,6 +141,7 @@ def prepare_wy_repr_bwd_kernel( chunk_indices, T, H: tl.constexpr, + Hq: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BT: tl.constexpr, @@ -177,7 +179,7 @@ def prepare_wy_repr_bwd_kernel( b_dg = tl.zeros([BT], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): - p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*Hq + i_h // (H // Hq)) * K, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) # [BT, BK] @@ -227,7 +229,7 @@ def prepare_wy_repr_bwd_kernel( tl.debug_barrier() for i_k in range(tl.cdiv(K, BK)): - p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*Hq + i_h // (H // Hq)) * K, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_kt = tl.trans(b_k) @@ -260,7 +262,9 @@ def recompute_w_u_fwd( chunk_indices: torch.LongTensor | None = None, use_exp2: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: - B, T, H, K, V = *k.shape, v.shape[-1] + B, T, Hq, K = k.shape + V = v.shape[-1] + H = v.shape[2] BT = A.shape[-1] BK = 64 BV = 64 @@ -269,7 +273,7 @@ def recompute_w_u_fwd( chunk_indices = prepare_chunk_indices(cu_seqlens, BT) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) - w = torch.empty_like(k) + w = k.new_empty(B, T, H, K) u = torch.empty_like(v) recompute_w_u_fwd_kernel[(NT, B*H)]( k=k, @@ -283,6 +287,7 @@ def recompute_w_u_fwd( chunk_indices=chunk_indices, T=T, H=H, + Hq=Hq, K=K, V=V, BT=BT, @@ -305,7 +310,9 @@ def prepare_wy_repr_bwd( chunk_indices: torch.LongTensor | None = None, use_exp2: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - B, T, H, K, V = *k.shape, v.shape[-1] + B, T, Hq, K = k.shape + V = v.shape[-1] + H = v.shape[2] BT = 64 if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) @@ -314,7 +321,7 @@ def prepare_wy_repr_bwd( BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING) BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING) - dk = torch.empty_like(k) + dk = k.new_empty(B, T, H, K) dv = torch.empty_like(v) dg = torch.empty_like(g) if g is not None else None db = torch.empty_like(beta) @@ -334,6 +341,7 @@ def prepare_wy_repr_bwd( chunk_indices=chunk_indices, T=T, H=H, + Hq=Hq, K=K, V=V, BT=BT, @@ -341,6 +349,8 @@ def prepare_wy_repr_bwd( BV=BV, USE_EXP2=use_exp2, ) + if Hq != H: + dk = dk.view(B, T, Hq, H // Hq, K).sum(3) return dk, dv, db, dg diff --git a/scripts/run_benchmark_compare.py b/scripts/run_benchmark_compare.py index f258990914..6ae0023a0d 100644 --- a/scripts/run_benchmark_compare.py +++ b/scripts/run_benchmark_compare.py @@ -22,6 +22,7 @@ """ import argparse +import ast import json import os import re @@ -29,6 +30,7 @@ import subprocess import sys import tempfile +import time from pathlib import Path PROJECT_ROOT = Path(__file__).parent.parent.resolve() @@ -68,12 +70,35 @@ def get_changed_files(base: str, head: str) -> list[str]: return [f for f in out.split('\n') if f] -def find_affected_op_names(changed_files: list[str]) -> list[str]: - """Map changed file paths to registered op names. +def _ops_subdirs_from_flapy_imports(relpath: str) -> set[str]: + path = PROJECT_ROOT / relpath + if not path.is_file(): + return set() + try: + text = path.read_text(encoding='utf-8') + tree = ast.parse(text, filename=relpath) + except (OSError, SyntaxError, UnicodeDecodeError): + return set() + + subdirs: set[str] = set() + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom): + mod = node.module or '' + if mod == 'fla.ops' or mod.startswith('fla.ops.'): + parts = mod.split('.') + if len(parts) >= 3 and parts[0] == 'fla' and parts[1] == 'ops': + subdirs.add(parts[2]) + elif isinstance(node, ast.Import): + for alias in node.names: + name = alias.name + if name.startswith('fla.ops.'): + parts = name.split('.') + if len(parts) >= 3 and parts[0] == 'fla' and parts[1] == 'ops': + subdirs.add(parts[2]) + return subdirs - Parses fla/ops// from each path and looks up the registry. - Changes in fla/ops/common/ or fla/ops/utils/ affect ALL ops. - """ + +def find_affected_op_names(changed_files: list[str]) -> list[str]: try: sys.path.insert(0, str(PROJECT_ROOT / 'benchmarks' / 'ops')) from registry import _REGISTRY @@ -100,6 +125,9 @@ def find_affected_op_names(changed_files: list[str]) -> list[str]: if m: affected_dirs.add(m.group(1)) + if fpath.startswith('fla/') and not m: + affected_dirs.update(_ops_subdirs_from_flapy_imports(fpath)) + if common_changed: return sorted(_REGISTRY.keys()) @@ -174,10 +202,13 @@ def _truncate(s: str, max_len: int = 8) -> str: def print_comparison(base_results: list[dict], head_results: list[dict], base_sha: str, head_sha: str, threshold: float, - machine_info: dict | None = None): + machine_info: dict | None = None) -> tuple[int, list, list]: """Print a comparison table with speedup ratios and detect regressions. Rows are grouped by mode with a blank line between fwd and fwdbwd. + + Returns: + tuple of (exit_code, regressions, speedups) """ def make_key(r): @@ -214,6 +245,7 @@ def make_key(r): f" {'-' * col_w} {'-' * col_w} {'-' * 8} {'-' * 8}") regressions = [] + speedups = [] prev_mode = None for key in all_keys: op, mode, B, T, H, D = key @@ -237,11 +269,15 @@ def make_key(r): if change_pct > threshold: marker = ' <<< REGRESSION' elif change_pct < -threshold: - marker = ' SPEEDUP' + marker = ' <<< SPEEDUP' + else: + marker = '' print(f"{prefix} {base_ms:>{col_w}.3f} {head_ms:>{col_w}.3f} {speedup:>7.2f}x " f"{sign}{change_pct:>6.1f}%{marker}") if change_pct > threshold: regressions.append((key, base_ms, head_ms, change_pct)) + elif change_pct < -threshold: + speedups.append((key, base_ms, head_ms, change_pct)) elif head_r: print(f"{prefix} {'-':>{col_w}s} {head_r['median_ms']:>{col_w}.3f} {'':>8s} {'new':>8s}") elif base_r: @@ -253,10 +289,16 @@ def make_key(r): print(f"\n WARNING: {len(regressions)} regression(s) detected (>{threshold}% slower):") for key, base_ms, head_ms, pct in regressions: print(f" {key}: {base_ms:.3f} -> {head_ms:.3f} ms (+{pct:.1f}%)") - return 1 - else: + + if speedups: + print(f"\n INFO: {len(speedups)} speedup(s) detected (<-{threshold}% faster):") + for key, base_ms, head_ms, pct in speedups: + print(f" {key}: {base_ms:.3f} -> {head_ms:.3f} ms ({pct:.1f}%)") + + if not regressions: print(f"\n No regressions detected (threshold: {threshold}%).") - return 0 + + return (1 if regressions else 0, regressions, speedups) def main(): @@ -268,6 +310,8 @@ def main(): parser.add_argument("--threshold", type=float, default=5.0, help="Regression threshold in percent (default: 5.0)") parser.add_argument("--output", help="Save comparison results to JSON file") + parser.add_argument("--no-fail-on-regression", action="store_true", + help="Do not exit with non-zero code on performance regression") args = parser.parse_args() # Resolve commit SHAs @@ -325,6 +369,13 @@ def main(): head_results = data.get('results', []) machine_info = data.get('machine_info') + cooldown = int(os.environ.get('FLA_BENCH_COOLDOWN_SEC', '0')) + if cooldown > 0: + print( + f"\n Waiting {cooldown}s before base benchmark " + f"(FLA_BENCH_COOLDOWN_SEC; reduces thermal bias between runs).\n") + time.sleep(cooldown) + # Step 2: BASE print(f"\n{'=' * 60}") print(f" Running benchmarks at BASE ({base_sha})") @@ -349,7 +400,7 @@ def main(): print("Warning: missing results from one or both commits.", file=sys.stderr) sys.exit(1) - exit_code = print_comparison( + exit_code, regressions, speedups = print_comparison( base_results, head_results, base_sha, head_sha, args.threshold, machine_info=machine_info, @@ -363,11 +414,32 @@ def main(): 'machine_info': machine_info, 'base_results': base_results, 'head_results': head_results, + 'regressions': [ + { + 'op': key[0], 'mode': key[1], 'B': key[2], 'T': key[3], + 'H': key[4], 'D': key[5], + 'base_ms': base_ms, 'head_ms': head_ms, 'change_pct': change_pct + } + for key, base_ms, head_ms, change_pct in regressions + ], + 'speedups': [ + { + 'op': key[0], 'mode': key[1], 'B': key[2], 'T': key[3], + 'H': key[4], 'D': key[5], + 'base_ms': base_ms, 'head_ms': head_ms, 'change_pct': change_pct + } + for key, base_ms, head_ms, change_pct in speedups + ], + 'has_regression': len(regressions) > 0, + 'has_speedup': len(speedups) > 0, } with open(args.output, 'w') as f: json.dump(comparison, f, indent=2) print(f"\nFull results saved to {args.output}") + if args.no_fail_on_regression: + print("\n --no-fail-on-regression set, exiting with code 0 despite regressions.") + sys.exit(0) sys.exit(exit_code) finally: diff --git a/tests/context_parallel/test_cp_bwd_gk_offset.py b/tests/context_parallel/test_cp_bwd_gk_offset.py index 3ec15a19c6..6e0512b9a6 100644 --- a/tests/context_parallel/test_cp_bwd_gk_offset.py +++ b/tests/context_parallel/test_cp_bwd_gk_offset.py @@ -109,7 +109,7 @@ def test_stage1_gk_per_head_sensitivity(self, T: int, H: int, K: int, V: int): dhm_a = torch.zeros(H, K, V + K, dtype=torch.float32, device=device) pre_process_bwd_kernel_merged[grid]( q=q, k=k, w=w, g=None, gk=gk_zero, do=do, dhm=dhm_a, dv=dv, - cu_seqlens=cu_seqlens, scale=1.0, T=T, H=H, K=K, V=V, + cu_seqlens=cu_seqlens, scale=1.0, T=T, H=H, Hq=H, K=K, V=V, BT=BT, BK1=BK, USE_EXP2=True, BLOCK_SIZE=BLOCK_SIZE, ) @@ -117,7 +117,7 @@ def test_stage1_gk_per_head_sensitivity(self, T: int, H: int, K: int, V: int): dhm_b = torch.zeros(H, K, V + K, dtype=torch.float32, device=device) pre_process_bwd_kernel_merged[grid]( q=q, k=k, w=w, g=None, gk=gk_diff, do=do, dhm=dhm_b, dv=dv, - cu_seqlens=cu_seqlens, scale=1.0, T=T, H=H, K=K, V=V, + cu_seqlens=cu_seqlens, scale=1.0, T=T, H=H, Hq=H, K=K, V=V, BT=BT, BK1=BK, USE_EXP2=True, BLOCK_SIZE=BLOCK_SIZE, ) diff --git a/tests/context_parallel/test_cp_gdn.py b/tests/context_parallel/test_cp_gdn.py index b536fdb3a3..873fdee9e5 100644 --- a/tests/context_parallel/test_cp_gdn.py +++ b/tests/context_parallel/test_cp_gdn.py @@ -100,6 +100,7 @@ def run_cp_gdn_test_worker( lengths: list[int], dtype, transpose_state_layout: bool = False, + Hq: int | None = None, ): """ Worker function for CP GDN test. @@ -115,14 +116,15 @@ def run_cp_gdn_test_worker( if rank == 0: print(f"\n{'='*60}") print(f"Test: {test_name}") - print(f"Config: T={T}, H={H}, D={D}, world_size={world_size}") + print(f"Config: T={T}, H={H}, D={D}, Hq={Hq}, world_size={world_size}") print(f"Sequence lengths: {lengths}") print(f"{'='*60}") # Step 1: Prepare Global Data (all generated on rank 0, broadcast to all) B = 1 - q_global = torch.empty(B, T, H, D, device=device, dtype=dtype) - k_global = torch.empty(B, T, H, D, device=device, dtype=dtype) + Hq_actual = Hq if Hq is not None else H + q_global = torch.empty(B, T, Hq_actual, D, device=device, dtype=dtype) + k_global = torch.empty(B, T, Hq_actual, D, device=device, dtype=dtype) v_global = torch.empty(B, T, H, D, device=device, dtype=dtype) g_global = torch.empty(B, T, H, device=device, dtype=dtype) beta_global = torch.empty(B, T, H, device=device, dtype=torch.float32) @@ -130,8 +132,10 @@ def run_cp_gdn_test_worker( if rank == 0: torch.manual_seed(42) - q_global.copy_(F.normalize(torch.randn(B, T, H, D, device=device, dtype=torch.float32), p=2, dim=-1).to(dtype)) - k_global.copy_(F.normalize(torch.randn(B, T, H, D, device=device, dtype=torch.float32), p=2, dim=-1).to(dtype)) + q_global.copy_(F.normalize(torch.randn(B, T, Hq_actual, D, device=device, + dtype=torch.float32), p=2, dim=-1).to(dtype)) + k_global.copy_(F.normalize(torch.randn(B, T, Hq_actual, D, device=device, + dtype=torch.float32), p=2, dim=-1).to(dtype)) v_global.copy_(torch.randn(B, T, H, D, device=device, dtype=dtype)) g_global.copy_(F.logsigmoid(torch.randn(B, T, H, device=device, dtype=dtype))) beta_global.copy_(torch.randn(B, T, H, device=device, dtype=torch.float32).sigmoid()) @@ -255,7 +259,7 @@ def run_cp_gdn_test_worker( try: for name, ref, cp in tensors_to_verify: - assert_close(name, ref, cp, ratio=2e-3, warning=False) + assert_close(name, ref, cp, ratio=3e-3, warning=False) print(f"[{test_name}] Test Passed!\n") except AssertionError as e: print(f"[{test_name}] Test Failed: {e}\n") @@ -281,6 +285,7 @@ def run_cp_test_with_spawn( lengths: list[int], dtype=torch.bfloat16, transpose_state_layout: bool = False, + Hq: int | None = None, ): """ Run CP test using torch.multiprocessing.spawn. @@ -288,7 +293,7 @@ def run_cp_test_with_spawn( """ mp.start_processes( run_cp_gdn_test_worker, - args=(world_size, test_name, T, H, D, lengths, dtype, transpose_state_layout), + args=(world_size, test_name, T, H, D, lengths, dtype, transpose_state_layout, Hq), nprocs=world_size, join=True, start_method='spawn', @@ -383,6 +388,34 @@ def test_cp2_many_short_sequences(): ) +def test_cp2_gqa_sequence_cut(): + """CP2 GQA: sequences cut across rank boundary, Hq < H.""" + if torch.cuda.device_count() < 2: + pytest.skip("At least 2 GPUs required") + + run_cp_test_with_spawn( + world_size=2, + test_name="CP2_GQA_SequenceCut", + T=10240, H=4, D=64, Hq=2, + lengths=[3000, 4000, 3240], + dtype=torch.bfloat16, + ) + + +def test_cp2_gqa_single_sequence(): + """CP2 GQA: single long sequence with Hq < H.""" + if torch.cuda.device_count() < 2: + pytest.skip("At least 2 GPUs required") + + run_cp_test_with_spawn( + world_size=2, + test_name="CP2_GQA_SingleSequence", + T=10240, H=8, D=64, Hq=2, + lengths=[10240], + dtype=torch.bfloat16, + ) + + # ============================================================ # Transpose State Layout Tests # ============================================================ diff --git a/tests/ops/test_gated_delta.py b/tests/ops/test_gated_delta.py index f9cf9e50c6..526794ddc7 100644 --- a/tests/ops/test_gated_delta.py +++ b/tests/ops/test_gated_delta.py @@ -156,6 +156,87 @@ def test_chunk( assert_close('dh0', ref_dh0, tri_dh0, 0.008) +@pytest.mark.parametrize( + ('B', 'T', 'Hq', 'H', 'D', 'scale', 'gate_logit_normalizer', 'use_qk_l2norm_in_kernel', 'dtype'), + [ + pytest.param( + *test, + id="B{}-T{}-Hq{}-H{}-D{}-scale{}-gate_logit_normalizer{}-use_qk_l2norm_in_kernel{}-{}".format(*test), + ) + for test in [ + (2, 256, 2, 4, 64, 1, 1, False, torch.float16), + (2, 512, 1, 4, 64, 0.1, 1, False, torch.float16), + (2, 512, 2, 8, 64, 1, 0.1, True, torch.float16), + (2, 1024, 4, 8, 128, 0.1, 1, False, torch.float16), + ] + ], +) +def test_chunk_gqa( + B: int, + T: int, + Hq: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + use_qk_l2norm_in_kernel: bool, + dtype: torch.dtype, +): + torch.manual_seed(42) + if IS_INTEL_ALCHEMIST and D > 128: + pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128') + assert H % Hq == 0 + G = H // Hq + + q = torch.rand(B, T, Hq, D, dtype=dtype) + k = torch.rand(B, T, Hq, D, dtype=dtype) + v = torch.rand(B, T, H, D, dtype=dtype) + beta = torch.rand(B, T, H, dtype=torch.float).sigmoid() + g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.float32)) + g = g / gate_logit_normalizer + h0 = torch.zeros(B, H, D, D, dtype=torch.float32) + q, k, v, beta, g, h0 = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, beta, g, h0)) + + tri, tri_ht = chunk_gated_delta_rule( + q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + do = torch.randn_like(v) + dht = torch.randn_like(h0) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = q.grad, k.grad, v.grad, beta.grad, g.grad, h0.grad + q.grad = k.grad = v.grad = beta.grad = g.grad = h0.grad = None + + ref, ref_ht = naive_recurrent_gated_delta_rule( + q=F.normalize(repeat(q.clone(), 'b t h d -> b t (h g) d', g=G), p=2, dim=-1), + k=F.normalize(repeat(k.clone(), 'b t h d -> b t (h g) d', g=G), p=2, dim=-1), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + scale=scale, + output_final_state=True, + initial_state=h0.clone(), + ) + + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dg, ref_dh0 = q.grad, k.grad, v.grad, beta.grad, g.grad, h0.grad + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.02) + assert_close('dg', ref_dg, tri_dg, 0.02) + assert_close('dh0', ref_dh0, tri_dh0, 0.008) + + @pytest.mark.parametrize( ('B', 'T', 'H', 'D', 'scale', 'gate_logit_normalizer', 'dtype'), [ diff --git a/tests/ops/test_intracard_cache.py b/tests/ops/test_intracard_cache.py index a531b1ba2e..4b106088f6 100644 --- a/tests/ops/test_intracard_cache.py +++ b/tests/ops/test_intracard_cache.py @@ -114,3 +114,59 @@ def test_intracard_backend_enabled_when_env_var_is_one(monkeypatch): monkeypatch.setenv("FLA_INTRACARD_CP", "1") assert IntraCardCPBackend.is_enabled() is True + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_chunk_gdn_intracard_gqa(monkeypatch): + """E2E: chunk_gated_delta_rule intracard path produces correct results with GQA (Hq < H). + + Uses a long varlen sequence to exercise the intracard split path, + with Hq=2 key/query heads and H=4 value/output heads. + """ + import torch.nn.functional as F + + from fla.ops.gated_delta_rule import chunk_gated_delta_rule + + torch.manual_seed(0) + dtype = torch.bfloat16 + + # T must be large enough to bypass early_return in intracard_fwd_h. + B, T, Hq, H, D = 1, 32768, 2, 4, 64 + + q = F.normalize(torch.randn(B, T, Hq, D, device=device, dtype=torch.float32), p=2, dim=-1).to(dtype) + k = F.normalize(torch.randn(B, T, Hq, D, device=device, dtype=torch.float32), p=2, dim=-1).to(dtype) + v = torch.randn(B, T, H, D, device=device, dtype=dtype) + g = F.logsigmoid(torch.randn(B, T, H, device=device, dtype=torch.float32)) + beta = torch.randn(B, T, H, device=device, dtype=torch.float32).sigmoid() + + cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32) + cu_seqlens_cpu = cu_seqlens.cpu() + + # Run with intracard path (inference_mode triggers it) + with torch.inference_mode(): + o_intra, ht_intra = chunk_gated_delta_rule( + q=q, k=k, v=v, g=g, beta=beta, + cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens_cpu, + output_final_state=True, + ) + + # Run without intracard: disable the backend temporarily + from fla.ops.common.backends import common_registry + saved_backends = common_registry._backends.copy() + common_registry._backends.clear() + try: + with torch.inference_mode(): + o_ref, ht_ref = chunk_gated_delta_rule( + q=q, k=k, v=v, g=g, beta=beta, + cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens_cpu, + output_final_state=True, + ) + finally: + common_registry._backends = saved_backends + + assert torch.allclose(o_intra, o_ref, atol=1e-2, rtol=1e-2), \ + f"Output mismatch: max diff={(o_intra - o_ref).abs().max().item()}" + assert torch.allclose(ht_intra, ht_ref, atol=1e-2, rtol=1e-2), \ + f"Final state mismatch: max diff={(ht_intra - ht_ref).abs().max().item()}"