diff --git a/.github/workflows/nvidia-4090.yml b/.github/workflows/nvidia-4090.yml
index 7992b01d3e..984ba74f3a 100644
--- a/.github/workflows/nvidia-4090.yml
+++ b/.github/workflows/nvidia-4090.yml
@@ -4,6 +4,10 @@ concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: ${{ github.event_name == 'pull_request' }}
+permissions:
+ contents: read
+ pull-requests: write
+
on:
pull_request:
branches: [ '*' ]
diff --git a/.github/workflows/nvidia-a100.yml b/.github/workflows/nvidia-a100.yml
index e14cb95578..de65dbc898 100644
--- a/.github/workflows/nvidia-a100.yml
+++ b/.github/workflows/nvidia-a100.yml
@@ -4,6 +4,10 @@ concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: ${{ github.event_name == 'pull_request' }}
+permissions:
+ contents: read
+ pull-requests: write
+
on:
pull_request:
branches: [ '*' ]
diff --git a/.github/workflows/nvidia-h100.yml b/.github/workflows/nvidia-h100.yml
index 4e20d11adf..5ee2f2382c 100644
--- a/.github/workflows/nvidia-h100.yml
+++ b/.github/workflows/nvidia-h100.yml
@@ -4,6 +4,10 @@ concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: ${{ github.event_name == 'pull_request' }}
+permissions:
+ contents: read
+ pull-requests: write
+
on:
pull_request:
branches: [ '*' ]
diff --git a/.github/workflows/reusable-ci-benchmarks.yml b/.github/workflows/reusable-ci-benchmarks.yml
index b0c11a7c04..83754e2b32 100644
--- a/.github/workflows/reusable-ci-benchmarks.yml
+++ b/.github/workflows/reusable-ci-benchmarks.yml
@@ -43,6 +43,11 @@ jobs:
runs-on: ${{ inputs.runner }}
env:
FLA_CI_ENV: 1
+ # Longer Triton timing windows + pause between commits reduce clock/thermal variance.
+ FLA_BENCH_WARMUP_MS: 40
+ FLA_BENCH_REP_MS: 200
+ FLA_BENCH_OP_WARMUP_ITERS: 6
+ FLA_BENCH_COOLDOWN_SEC: 3
steps:
- name: Check out repo (full history for cross-commit comparison)
@@ -112,12 +117,14 @@ jobs:
- name: Run benchmark comparison
if: steps.check_skip.outputs.skip_bench == 'false'
+ id: benchmark
shell: bash
run: |
$CONDA_BIN_PATH/python scripts/run_benchmark_compare.py \
--base ${{ inputs.base_ref }} \
--head HEAD \
--threshold ${{ inputs.threshold }} \
+ --no-fail-on-regression \
--output benchmark_results.json
- name: Upload benchmark results
@@ -128,3 +135,107 @@ jobs:
path: benchmark_results.json
if-no-files-found: ignore
retention-days: 30
+
+ - name: Post benchmark results to PR
+ if: steps.check_skip.outputs.skip_bench == 'false' && github.event_name == 'pull_request'
+ uses: actions/github-script@v7
+ with:
+ script: |
+ const fs = require('fs');
+ const path = 'benchmark_results.json';
+
+ if (!fs.existsSync(path)) {
+ console.log('No benchmark results file found');
+ return;
+ }
+
+ const data = JSON.parse(fs.readFileSync(path, 'utf8'));
+ const { base_sha, head_sha, machine_info, regressions, speedups, has_regression, has_speedup } = data;
+
+ const gpuName = machine_info?.gpu_name || 'Unknown';
+ const cudaVersion = machine_info?.cuda_version || 'Unknown';
+ const pytorchVersion = machine_info?.pytorch_version || 'Unknown';
+ const runnerName = '${{ inputs.runner }}';
+
+ // Format the summary table
+ let summary = '';
+ if (has_regression || has_speedup) {
+ summary = '| Op | Mode | B | T | H | D | Base (ms) | Head (ms) | Change |\n';
+ summary += '|:---|:---:|---:|---:|---:|---:|---:|---:|---:|\n';
+
+ // Add regressions
+ if (regressions && regressions.length > 0) {
+ for (const r of regressions) {
+ const change = `+${r.change_pct.toFixed(1)}% 🔴`;
+ summary += `| ${r.op} | ${r.mode} | ${r.B} | ${r.T} | ${r.H} | ${r.D} | ${r.base_ms.toFixed(3)} | ${r.head_ms.toFixed(3)} | ${change} |\n`;
+ }
+ }
+
+ // Add speedups
+ if (speedups && speedups.length > 0) {
+ for (const s of speedups) {
+ const change = `${s.change_pct.toFixed(1)}% 🟢`;
+ summary += `| ${s.op} | ${s.mode} | ${s.B} | ${s.T} | ${s.H} | ${s.D} | ${s.base_ms.toFixed(3)} | ${s.head_ms.toFixed(3)} | ${change} |\n`;
+ }
+ }
+ }
+
+ // Build the comment body
+ const statusEmoji = has_regression ? '⚠️' : '✅';
+ const statusText = has_regression
+ ? `${regressions.length} regression(s) detected`
+ : 'No significant performance regressions detected';
+
+ let body = `## ${statusEmoji} Benchmark Results (${runnerName.toUpperCase()})\n\n`;
+ body += `**Status:** ${statusText}\n\n`;
+ body += `| | |\n`;
+ body += `|:---|:---|\n`;
+ body += `| **GPU** | ${gpuName} |\n`;
+ body += `| **CUDA** | ${cudaVersion} |\n`;
+ body += `| **PyTorch** | ${pytorchVersion} |\n`;
+ body += `| **Base** | \`${base_sha}\` |\n`;
+ body += `| **Head** | \`${head_sha}\` |\n`;
+ body += `| **Threshold** | ${{ inputs.threshold }}% |\n\n`;
+
+ if (has_regression || has_speedup) {
+ body += `\n<details>\n<summary>📊 View Details (${(regressions?.length || 0) + (speedups?.length || 0)} significant changes)</summary>\n\n`;
+ body += summary;
+ body += '\n</details>\n';
+ } else {
+ body += '> All benchmarked operations are within the performance threshold.\n';
+ }
+
+ body += '\n---\n';
+ body += '*This comment is automatically updated with the latest benchmark results.*';
+
+ // Find existing comment
+ const { data: comments } = await github.rest.issues.listComments({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ issue_number: context.issue.number,
+ });
+
+ const botComment = comments.find(comment =>
+ comment.user.type === 'Bot' &&
+ comment.body.includes(`Benchmark Results (${runnerName.toUpperCase()})`)
+ );
+
+ if (botComment) {
+ // Update existing comment
+ await github.rest.issues.updateComment({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ comment_id: botComment.id,
+ body: body,
+ });
+ console.log(`Updated existing benchmark comment for ${runnerName}`);
+ } else {
+ // Create new comment
+ await github.rest.issues.createComment({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ issue_number: context.issue.number,
+ body: body,
+ });
+ console.log(`Created new benchmark comment for ${runnerName}`);
+ }
diff --git a/benchmarks/cp/benchmark_chunk_delta_h_kernels.py b/benchmarks/cp/benchmark_chunk_delta_h_kernels.py
deleted file mode 100644
index 8dd302a315..0000000000
--- a/benchmarks/cp/benchmark_chunk_delta_h_kernels.py
+++ /dev/null
@@ -1,467 +0,0 @@
-#!/usr/bin/env python
-"""
-Benchmark script for chunk_delta_h kernels with cp=8 using triton.testing.do_bench.
-
-This script benchmarks all 6 kernels from generate_chunk_delta_h_cache():
-1. pre_process_fwd_kernel_stage1 (FWD)
-2. pre_process_fwd_bwd_kernel_stage2 (FWD)
-3. merge_fwd_bwd_kernel (FWD)
-4. pre_process_bwd_kernel_stage1 (BWD)
-5. pre_process_fwd_bwd_kernel_stage2 (BWD)
-6. merge_fwd_bwd_kernel (BWD)
-
-Usage:
- # Run benchmark with default settings (batch=1, head=32, headdim=128, cp=8, seqlen=32k per rank)
- python benchmark_chunk_delta_h_kernels.py
-
- # Run with custom settings
- python benchmark_chunk_delta_h_kernels.py --batch 1 --heads 32 --headdim 128 --seqlen 32768
-"""
-
-import argparse
-
-import torch
-import triton
-import triton.testing
-
-from fla.ops.cp.chunk_delta_h import (
- merge_fwd_bwd_kernel,
- pre_process_bwd_kernel_merged,
- pre_process_bwd_kernel_stage1,
- pre_process_fwd_bwd_kernel_stage2,
- pre_process_fwd_kernel_merged,
- pre_process_fwd_kernel_stage1,
-)
-from fla.utils import device
-
-DTYPE = torch.bfloat16
-
-
-def get_args():
- parser = argparse.ArgumentParser(
- description="Benchmark chunk_delta_h kernels using triton.testing.do_bench"
- )
- parser.add_argument(
- "--batch", type=int, default=1, help="Batch size (default: 1)"
- )
- parser.add_argument(
- "--heads", type=int, default=32, help="Number of heads (default: 32)"
- )
- parser.add_argument(
- "--headdim", type=int, default=128, help="Head dimension (default: 128)"
- )
- parser.add_argument(
- "--seqlen",
- type=int,
- default=32768,
- help="Sequence length per CP rank (default: 32768 = 32k)",
- )
- return parser.parse_args()
-
-
-def create_tensors(B, T, H, K, V, device, dtype):
- """Create input tensors for benchmarking."""
- # Tensors for forward kernels
- k = torch.randn(B, T, H, K, device=device, dtype=dtype)
- v = torch.randn(B, T, H, V, device=device, dtype=dtype)
- w = torch.randn(B, T, H, K, device=device, dtype=dtype)
- gk = torch.randn(B, T, H, K, device=device, dtype=torch.float32)
-
- # cu_seqlens for 1 chunk
- cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32)
-
- # hm tensor (zero-initialized)
- hm = torch.zeros(H, K, V + K, device=device, dtype=torch.float32)
-
- # Tensors for backward kernels
- q = k.clone()
- do = torch.randn(B, T, H, V, device=device, dtype=dtype)
- dv = torch.zeros_like(v)
- dhm = torch.zeros_like(hm)
-
- # h_out and dht for merge kernels
- h_out = torch.zeros(B, H, K, V, device=device, dtype=torch.float32)
- dht = torch.randn(B, H, K, V, device=device, dtype=torch.float32)
-
- # ag_hm and ag_dhm for merge kernels (simulating cp=8)
- num_ranks = 8 # Simulating cp=8
- stride = H * K * (K + V)
- ag_hm = torch.zeros(num_ranks * stride, device=device, dtype=torch.float32)
- ag_dhm = torch.zeros(num_ranks * stride, device=device, dtype=torch.float32)
-
- return {
- "k": k,
- "v": v,
- "w": w,
- "gk": gk,
- "hm": hm,
- "cu_seqlens": cu_seqlens,
- "q": q,
- "do": do,
- "dv": dv,
- "dhm": dhm,
- "h_out": h_out,
- "dht": dht,
- "ag_hm": ag_hm,
- "ag_dhm": ag_dhm,
- }
-
-
-def benchmark_all_kernels(B, T, H, K, V):
- """Benchmark all 6 kernels from generate_chunk_delta_h_cache()."""
- print(f"\n{'=' * 80}")
- print("Benchmarking chunk_delta_h kernels (using triton.testing.do_bench)")
- print(f"{'=' * 80}")
- print("Configuration:")
- print(f" Batch size: {B}")
- print(f" Heads: {H}")
- print(f" Head dim (K): {K}")
- print(f" Head dim (V): {V}")
- print(f" SeqLen per rank: {T}")
- print(f" Dtype: {DTYPE}")
- print(f" Device: {device}")
- print(f"{'=' * 80}\n")
-
- # Create tensors
- tensors = create_tensors(B, T, H, K, V, device, DTYPE)
- BT = 64
- BK = triton.next_power_of_2(K)
-
- results = {}
- quantiles = [0.5, 0.2, 0.8] # median, 20th percentile, 80th percentile
-
- # ========================================
- # 1. pre_process_fwd_kernel_stage1 (FWD)
- # ========================================
- print("[1/6] Benchmarking pre_process_fwd_kernel_stage1 (FWD)...")
-
- def grid_stage1(meta):
- return (triton.cdiv(V, meta["BV"]), H)
-
- def kernel_fwd_stage1():
- pre_process_fwd_kernel_stage1[grid_stage1](
- k=tensors["k"],
- v=tensors["v"],
- w=tensors["w"],
- g=None,
- gk=tensors["gk"],
- hm=tensors["hm"],
- cu_seqlens=tensors["cu_seqlens"],
- T=T,
- H=H,
- K=K,
- V=V,
- BT=BT,
- USE_G=False,
- USE_GK=True,
- USE_EXP2=True,
- IS_VARLEN=True,
- )
-
- ms, min_ms, max_ms = triton.testing.do_bench(kernel_fwd_stage1, quantiles=quantiles)
- results["pre_process_fwd_kernel_stage1"] = ms
- print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)")
-
- # ========================================
- # 2. pre_process_fwd_bwd_kernel_stage2 (FWD)
- # ========================================
- print("[2/6] Benchmarking pre_process_fwd_bwd_kernel_stage2 (FWD)...")
-
- def grid_stage2(meta):
- return (triton.cdiv(K, meta["BK2"]), H)
-
- def kernel_fwd_stage2():
- pre_process_fwd_bwd_kernel_stage2[grid_stage2](
- k=tensors["k"],
- w=tensors["w"],
- g=None,
- gk=tensors["gk"],
- hm=tensors["hm"],
- cu_seqlens=tensors["cu_seqlens"],
- T=T,
- H=H,
- K=K,
- V=V,
- BT=BT,
- BK1=BK,
- USE_G=False,
- USE_GK=True,
- USE_EXP2=True,
- IS_VARLEN=True,
- FORWARD=True,
- )
-
- ms, min_ms, max_ms = triton.testing.do_bench(kernel_fwd_stage2, quantiles=quantiles)
- results["pre_process_fwd_bwd_kernel_stage2 (FWD)"] = ms
- print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)")
-
- # ========================================
- # 2b. pre_process_fwd_kernel_merged (FWD)
- # ========================================
- print("[2b/6] Benchmarking pre_process_fwd_kernel_merged (FWD)...")
-
- BLOCK_SIZE = 32 if K <= 64 else 64
- grid_merged = (triton.cdiv(V, BLOCK_SIZE) + triton.cdiv(K, BLOCK_SIZE), H)
-
- def kernel_fwd_merged():
- pre_process_fwd_kernel_merged[grid_merged](
- k=tensors["k"],
- v=tensors["v"],
- w=tensors["w"],
- g=None,
- gk=tensors["gk"],
- hm=tensors["hm"],
- cu_seqlens=tensors["cu_seqlens"],
- T=T,
- H=H,
- K=K,
- V=V,
- BT=BT,
- BK1=BK,
- USE_G=False,
- USE_GK=True,
- USE_EXP2=True,
- IS_VARLEN=True,
- BLOCK_SIZE=BLOCK_SIZE,
- MULTI_SEQS=False,
- )
-
- ms, min_ms, max_ms = triton.testing.do_bench(kernel_fwd_merged, quantiles=quantiles)
- results["pre_process_fwd_kernel_merged (FWD)"] = ms
- print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)")
-
- # Speedup comparison
- stage1_time = results["pre_process_fwd_kernel_stage1"]
- stage2_time = results["pre_process_fwd_bwd_kernel_stage2 (FWD)"]
- merged_time = results["pre_process_fwd_kernel_merged (FWD)"]
- original_total = stage1_time + stage2_time
- speedup = original_total / merged_time if merged_time > 0 else 0
- print(f"\n [Speedup] Merged kernel vs Split kernels: {speedup:.2f}x")
- print(f" Original (stage1+stage2): {original_total * 1000:.2f} us")
- print(f" Merged: {merged_time * 1000:.2f} us")
-
- # ========================================
- # 3. merge_fwd_bwd_kernel (FWD)
- # ========================================
- print("[3/6] Benchmarking merge_fwd_bwd_kernel (FWD)...")
-
- def grid_merge(meta):
- return (triton.cdiv(V, meta["BV"]), H)
-
- def kernel_merge_fwd():
- merge_fwd_bwd_kernel[grid_merge](
- h=tensors["h_out"],
- ag_hm=tensors["ag_hm"],
- pre_or_post_num_ranks=1,
- rank=1,
- seq_offsets=None,
- init_offsets=None,
- h0_seq_ids=None,
- h0=None,
- H=H,
- K=K,
- V=V,
- BK=BK,
- FORWARD=True,
- INTRACARD_MODE=False,
- NUM_SEQ_ENTRIES=0,
- )
-
- ms, min_ms, max_ms = triton.testing.do_bench(kernel_merge_fwd, quantiles=quantiles)
- results["merge_fwd_bwd_kernel (FWD)"] = ms
- print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)")
-
- # ========================================
- # 4. pre_process_bwd_kernel_stage1 (BWD)
- # ========================================
- print("[4/6] Benchmarking pre_process_bwd_kernel_stage1 (BWD)...")
-
- def kernel_bwd_stage1():
- pre_process_bwd_kernel_stage1[grid_stage1](
- q=tensors["q"],
- k=tensors["k"],
- w=tensors["w"],
- g=None,
- gk=tensors["gk"],
- do=tensors["do"],
- dhm=tensors["dhm"],
- dv=tensors["dv"],
- cu_seqlens=tensors["cu_seqlens"],
- scale=1.0,
- T=T,
- H=H,
- K=K,
- V=V,
- BT=BT,
- USE_G=False,
- USE_GK=True,
- IS_VARLEN=True,
- )
-
- ms, min_ms, max_ms = triton.testing.do_bench(kernel_bwd_stage1, quantiles=quantiles)
- results["pre_process_bwd_kernel_stage1"] = ms
- print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)")
-
- # ========================================
- # 5. pre_process_fwd_bwd_kernel_stage2 (BWD)
- # ========================================
- print("[5/6] Benchmarking pre_process_fwd_bwd_kernel_stage2 (BWD)...")
-
- def kernel_bwd_stage2():
- pre_process_fwd_bwd_kernel_stage2[grid_stage2](
- k=tensors["k"],
- w=tensors["w"],
- g=None,
- gk=tensors["gk"],
- hm=tensors["dhm"],
- cu_seqlens=tensors["cu_seqlens"],
- T=T,
- H=H,
- K=K,
- V=V,
- BT=BT,
- BK1=BK,
- USE_G=False,
- USE_GK=True,
- USE_EXP2=True,
- IS_VARLEN=True,
- FORWARD=False,
- )
-
- ms, min_ms, max_ms = triton.testing.do_bench(kernel_bwd_stage2, quantiles=quantiles)
- results["pre_process_fwd_bwd_kernel_stage2 (BWD)"] = ms
- print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)")
-
- # ========================================
- # 5b. pre_process_bwd_kernel_merged (BWD)
- # ========================================
- print("[5b/6] Benchmarking pre_process_bwd_kernel_merged (BWD)...")
-
- BLOCK_SIZE_BWD = 32 if K <= 64 else 64
- grid_bwd_merged = (triton.cdiv(V + K, BLOCK_SIZE_BWD), H)
-
- def kernel_bwd_merged():
- pre_process_bwd_kernel_merged[grid_bwd_merged](
- q=tensors["q"],
- k=tensors["k"],
- w=tensors["w"],
- g=None,
- gk=tensors["gk"],
- do=tensors["do"],
- dhm=tensors["dhm"],
- dv=tensors["dv"],
- cu_seqlens=tensors["cu_seqlens"],
- scale=1.0,
- T=T,
- H=H,
- K=K,
- V=V,
- BT=BT,
- BK1=BK,
- USE_G=False,
- USE_GK=True,
- USE_EXP2=True,
- IS_VARLEN=True,
- BLOCK_SIZE=BLOCK_SIZE_BWD,
- )
-
- ms, min_ms, max_ms = triton.testing.do_bench(kernel_bwd_merged, quantiles=quantiles)
- results["pre_process_bwd_kernel_merged (BWD)"] = ms
- print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)")
-
- # Speedup comparison for backward
- stage1_bwd_time = results["pre_process_bwd_kernel_stage1"]
- stage2_bwd_time = results["pre_process_fwd_bwd_kernel_stage2 (BWD)"]
- merged_bwd_time = results["pre_process_bwd_kernel_merged (BWD)"]
- original_bwd_total = stage1_bwd_time + stage2_bwd_time
- bwd_speedup = original_bwd_total / merged_bwd_time if merged_bwd_time > 0 else 0
- print(f"\n [Speedup] BWD Merged kernel vs Split kernels: {bwd_speedup:.2f}x")
- print(f" Original (stage1+stage2): {original_bwd_total * 1000:.2f} us")
- print(f" Merged: {merged_bwd_time * 1000:.2f} us")
-
- # ========================================
- # 6. merge_fwd_bwd_kernel (BWD)
- # ========================================
- print("[6/6] Benchmarking merge_fwd_bwd_kernel (BWD)...")
-
- def kernel_merge_bwd():
- merge_fwd_bwd_kernel[grid_merge](
- h=tensors["dht"],
- ag_hm=tensors["ag_dhm"],
- pre_or_post_num_ranks=1,
- rank=1,
- seq_offsets=None,
- init_offsets=None,
- h0_seq_ids=None,
- h0=None,
- H=H,
- K=K,
- V=V,
- BK=BK,
- FORWARD=False,
- INTRACARD_MODE=False,
- NUM_SEQ_ENTRIES=0,
- )
-
- ms, min_ms, max_ms = triton.testing.do_bench(kernel_merge_bwd, quantiles=quantiles)
- results["merge_fwd_bwd_kernel (BWD)"] = ms
- print(f" Median: {ms * 1000:.2f} us (min: {min_ms * 1000:.2f} us, max: {max_ms * 1000:.2f} us)")
-
- # Print summary
- print(f"\n{'=' * 80}")
- print("Benchmark Summary")
- print(f"{'=' * 80}")
- print(f"{'Kernel Name':<50} {'Time (us)':<15} {'% of Total':<12}")
- print("-" * 80)
-
- total_time = sum(results.values())
- for name, t in results.items():
- percentage = (t / total_time * 100) if total_time > 0 else 0
- print(f"{name:<50} {t * 1000:<15.2f} {percentage:<12.1f}")
-
- print("-" * 80)
- print(f"{'Total':<50} {total_time * 1000:<15.2f} {'100.0':<12}")
- print(f"{'=' * 80}\n")
-
- # Breakdown by direction (using split kernels for fair comparison)
- fwd_time = (
- results["pre_process_fwd_kernel_stage1"]
- + results["pre_process_fwd_bwd_kernel_stage2 (FWD)"]
- + results["merge_fwd_bwd_kernel (FWD)"]
- )
- bwd_time = (
- results["pre_process_bwd_kernel_stage1"]
- + results["pre_process_fwd_bwd_kernel_stage2 (BWD)"]
- + results["merge_fwd_bwd_kernel (BWD)"]
- )
-
- print(f"Forward pass total: {fwd_time * 1000:.2f} us ({fwd_time/total_time*100:.1f}%)")
- print(f"Backward pass total: {bwd_time * 1000:.2f} us ({bwd_time/total_time*100:.1f}%)")
- print(f"FWD/BWD ratio: {fwd_time/bwd_time:.2f}")
- print()
-
- return results
-
-
-def main():
- args = get_args()
-
- B = args.batch
- H = args.heads
- K = args.headdim
- V = args.headdim # Assuming V = K
- T = args.seqlen
-
- # Print environment info
- print(f"\nPyTorch version: {torch.__version__}")
- print(f"CUDA available: {torch.cuda.is_available()}")
- if torch.cuda.is_available():
- print(f"CUDA device: {torch.cuda.get_device_name(0)}")
- print(f"CUDA capability: {torch.cuda.get_device_capability(0)}")
-
- # Run benchmark
- benchmark_all_kernels(B, T, H, K, V)
-
-
-if __name__ == "__main__":
- main()
diff --git a/benchmarks/ops/registry.py b/benchmarks/ops/registry.py
index 69c5a49d76..879afc45da 100644
--- a/benchmarks/ops/registry.py
+++ b/benchmarks/ops/registry.py
@@ -58,6 +58,14 @@ def logsigmoid_clamp(t):
return F.logsigmoid(t).clamp_min(-5)
+RWKV7_W_MIN = -0.6065306597126334
+
+
+def rwkv7_w_transform(t):
+ w = RWKV7_W_MIN * t.sigmoid()
+ return w.clamp(min=RWKV7_W_MIN, max=-1e-6)
+
+
# ---------------------------------------------------------------------------
# TensorSpec: describes how to create one input tensor
# ---------------------------------------------------------------------------
@@ -325,12 +333,13 @@ def _rwkv7_post_init(inputs, B, T, H, D, **kw):
import_path='fla.ops.rwkv7',
inputs={
'r': TensorSpec(shape_BTHD),
- 'w': TensorSpec(shape_BTHD, transform=logsigmoid),
+ 'w': TensorSpec(shape_BTHD, transform=rwkv7_w_transform),
'k': TensorSpec(shape_BTHD),
'v': TensorSpec(shape_BTHD),
'a': TensorSpec(shape_BTHD),
'b': TensorSpec(shape_BTHD),
},
+ extra_kwargs={'safe_gate': True, 'chunk_size': 64},
post_init=_rwkv7_post_init,
category='rwkv',
))
diff --git a/benchmarks/ops/run.py b/benchmarks/ops/run.py
index 60d91a96d2..214fc78b88 100644
--- a/benchmarks/ops/run.py
+++ b/benchmarks/ops/run.py
@@ -8,6 +8,9 @@
# Benchmark one op (uses all default shape configs)
python -m benchmarks.ops.run --op chunk_gla
+ # Ops touched by git diff (same rules as scripts/run_benchmark_compare.py)
+ python -m benchmarks.ops.run --from-diff --diff-base main --diff-head HEAD
+
# Multiple ops
python -m benchmarks.ops.run --op chunk_gla chunk_kda
@@ -104,10 +107,12 @@
Benchmark methodology
=====================
-1. **Warmup**: For each (op, shape), run fwd+bwd 5 times to trigger all
- triton autotuning. All shapes are warmed up before any timing begins.
-2. **Timing**: ``triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])``
- gives median, p20, p80 in milliseconds.
+1. **Warmup**: For each (op, shape), run fwd+bwd several times (default 5;
+ override with ``FLA_BENCH_OP_WARMUP_ITERS``) to trigger triton autotuning.
+ All shapes are warmed up before any timing begins.
+2. **Timing**: ``triton.testing.do_bench`` with quantiles ``[0.5, 0.2, 0.8]``,
+ ``warmup``/``rep`` in milliseconds (defaults 25 / 100; set
+ ``FLA_BENCH_WARMUP_MS`` / ``FLA_BENCH_REP_MS`` for noisier machines / CI).
3. Input tensors (including gate transforms like logsigmoid) are prepared
**before** timing — only the op call itself is measured.
"""
@@ -198,8 +203,22 @@ def _get_machine_info() -> dict:
return info
-def _warmup_autotune(fn, n=5):
+def _warmup_iters() -> int:
+ """Extra per-shape forward+backward iterations before timing (not Triton do_bench warmup)."""
+ return max(1, int(os.environ.get('FLA_BENCH_OP_WARMUP_ITERS', '5')))
+
+
+def _do_bench_kw():
+ """Triton ``do_bench`` uses warmup/rep in *milliseconds* of timed execution (see Triton docs)."""
+ warmup_ms = int(os.environ.get('FLA_BENCH_WARMUP_MS', '25'))
+ rep_ms = int(os.environ.get('FLA_BENCH_REP_MS', '100'))
+ return {'warmup': max(1, warmup_ms), 'rep': max(1, rep_ms)}
+
+
+def _warmup_autotune(fn, n: int | None = None):
"""Run *fn* multiple times so triton autotuning is fully cached."""
+ if n is None:
+ n = _warmup_iters()
for _ in range(n):
fn()
torch.cuda.synchronize()
@@ -299,7 +318,9 @@ def fn(inputs=inputs, do=do):
t.backward(do)
try:
- ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
+ ms = triton.testing.do_bench(
+ fn, quantiles=[0.5, 0.2, 0.8], **_do_bench_kw()
+ )
except Exception as e:
logger.warning(f"Bench failed for {op_name} {mode} @ {shape_name}: {e}")
continue
@@ -534,6 +555,18 @@ def main():
'--list', action='store_true',
help='List all registered ops and exit',
)
+ parser.add_argument(
+ '--from-diff', action='store_true',
+ help='Select ops from git diff (use with --diff-base / --diff-head, not with --op)',
+ )
+ parser.add_argument(
+ '--diff-base', default='main',
+ help='Base ref for --from-diff (default: main)',
+ )
+ parser.add_argument(
+ '--diff-head', default='HEAD',
+ help='Head ref for --from-diff (default: HEAD)',
+ )
args = parser.parse_args()
if args.list:
@@ -544,10 +577,24 @@ def main():
print(f" {name:30s} [{cfg.category}] {cfg.import_path}")
return
- if args.op is None:
- parser.error("--op is required (use --list to see available ops)")
-
- op_names = list_ops() if args.op == ['all'] else args.op
+ if args.from_diff:
+ if args.op is not None:
+ parser.error('--from-diff cannot be used with --op')
+ project_root = _find_project_root()
+ scripts_dir = os.path.join(project_root, 'scripts')
+ if scripts_dir not in sys.path:
+ sys.path.insert(0, scripts_dir)
+ import run_benchmark_compare as _diff
+ changed = _diff.get_changed_files(args.diff_base, args.diff_head)
+ op_names = _diff.find_affected_op_names(changed)
+ if not op_names:
+ print('No affected ops for this diff.', file=sys.stderr)
+ return
+ elif args.op is None:
+ parser.error("--op is required unless --from-diff (use --list to see available ops)")
+
+ if not args.from_diff:
+ op_names = list_ops() if args.op == ['all'] else args.op
shape_configs = json.loads(args.custom_shapes) if args.custom_shapes else SHAPE_CONFIGS
machine_info = _get_machine_info()
diff --git a/fla/ops/common/chunk_delta_h.py b/fla/ops/common/chunk_delta_h.py
index a574a5653b..3e3edb1b93 100644
--- a/fla/ops/common/chunk_delta_h.py
+++ b/fla/ops/common/chunk_delta_h.py
@@ -46,6 +46,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
chunk_offsets,
T,
H: tl.constexpr,
+ Hq: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
@@ -91,7 +92,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
# calculate offset
h += (boh * H + i_h).to(tl.int64) * K*V
v += (bos * H + i_h).to(tl.int64) * V
- k += (bos * H + i_h).to(tl.int64) * K
+ k += (bos * Hq + i_h // (H // Hq)).to(tl.int64) * K
w += (bos * H + i_h).to(tl.int64) * K
if SAVE_NEW_VALUE:
v_new += (bos * H + i_h).to(tl.int64) * V
@@ -263,28 +264,28 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_v = b_v.to(k.dtype.element_ty)
- p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_t * BT), (64, BT), (0, 1))
+ p_k = tl.make_block_ptr(k, (K, T), (1, Hq*K), (0, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if TRANSPOSE_STATE:
b_h1 += tl.trans(tl.dot(b_k, b_v))
else:
b_h1 += tl.dot(b_k, b_v)
if K > 64:
- p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (64, i_t * BT), (64, BT), (0, 1))
+ p_k = tl.make_block_ptr(k, (K, T), (1, Hq*K), (64, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if TRANSPOSE_STATE:
b_h2 += tl.trans(tl.dot(b_k, b_v))
else:
b_h2 += tl.dot(b_k, b_v)
if K > 128:
- p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (128, i_t * BT), (64, BT), (0, 1))
+ p_k = tl.make_block_ptr(k, (K, T), (1, Hq*K), (128, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if TRANSPOSE_STATE:
b_h3 += tl.trans(tl.dot(b_k, b_v))
else:
b_h3 += tl.dot(b_k, b_v)
if K > 192:
- p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (192, i_t * BT), (64, BT), (0, 1))
+ p_k = tl.make_block_ptr(k, (K, T), (1, Hq*K), (192, i_t * BT), (64, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
if TRANSPOSE_STATE:
b_h4 += tl.trans(tl.dot(b_k, b_v))
@@ -353,6 +354,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
scale,
T,
H: tl.constexpr,
+ Hq: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
@@ -395,8 +397,8 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
b_dh4 = tl.zeros([64, BV], dtype=tl.float32)
# calculate offset
- q += (bos * H + i_h).to(tl.int64) * K
- k += (bos * H + i_h).to(tl.int64) * K
+ q += (bos * Hq + i_h // (H // Hq)).to(tl.int64) * K
+ k += (bos * Hq + i_h // (H // Hq)).to(tl.int64) * K
w += (bos * H + i_h).to(tl.int64) * K
do += (bos * H + i_h).to(tl.int64) * V
dv += (bos * H + i_h).to(tl.int64) * V
@@ -480,7 +482,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
b_do = tl.load(p_do, boundary_check=(0, 1))
# Update dv
- p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 0), (BT, 64), (1, 0))
+ p_k = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_t * BT, 0), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k1 = tl.arange(0, 64)
@@ -491,7 +493,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype))
if K > 64:
- p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 64), (BT, 64), (1, 0))
+ p_k = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k2 = 64 + o_k1
@@ -502,7 +504,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype))
if K > 128:
- p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 128), (BT, 64), (1, 0))
+ p_k = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k3 = 128 + o_k1
@@ -513,7 +515,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype))
if K > 192:
- p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 192), (BT, 64), (1, 0))
+ p_k = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k4 = 192 + o_k1
@@ -534,7 +536,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# Update dh
p_w = tl.make_block_ptr(w, (K, T), (1, H*K), (0, i_t * BT), (64, BT), (0, 1))
- p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (0, i_t * BT), (64, BT), (0, 1))
+ p_q = tl.make_block_ptr(q, (K, T), (1, Hq*K), (0, i_t * BT), (64, BT), (0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
if USE_G:
@@ -556,7 +558,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
else:
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 64:
- p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (64, i_t * BT), (64, BT), (0, 1))
+ p_q = tl.make_block_ptr(q, (K, T), (1, Hq*K), (64, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, H*K), (64, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
@@ -579,7 +581,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
else:
b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 128:
- p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (128, i_t * BT), (64, BT), (0, 1))
+ p_q = tl.make_block_ptr(q, (K, T), (1, Hq*K), (128, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, H*K), (128, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
@@ -602,7 +604,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
else:
b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 192:
- p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (192, i_t * BT), (64, BT), (0, 1))
+ p_q = tl.make_block_ptr(q, (K, T), (1, Hq*K), (192, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, H*K), (192, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
@@ -668,7 +670,9 @@ def chunk_gated_delta_rule_fwd_h(
use_exp2: bool = False,
transpose_state_layout: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
- B, T, H, K, V = *k.shape, u.shape[-1]
+ B, T, Hq, K = k.shape
+ V = u.shape[-1]
+ H = u.shape[2]
BT = chunk_size
if chunk_indices is None and cu_seqlens is not None:
@@ -703,6 +707,7 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), N*H)
chunk_offsets=chunk_offsets,
T=T,
H=H,
+ Hq=Hq,
K=K,
V=V,
BT=BT,
@@ -729,7 +734,9 @@ def chunk_gated_delta_rule_bwd_dhu(
use_exp2: bool = False,
transpose_state_layout: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- B, T, H, K, V = *q.shape, do.shape[-1]
+ B, T, Hq, K = q.shape
+ V = do.shape[-1]
+ H = do.shape[2]
# N: the actual number of sequences in the batch with either equal or variable lengths
BT = 64
assert K <= 256, "current kernel does not support head dimension being larger than 256."
@@ -766,6 +773,7 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), N*H)
scale=scale,
T=T,
H=H,
+ Hq=Hq,
K=K,
V=V,
BT=BT,
diff --git a/fla/ops/common/chunk_o.py b/fla/ops/common/chunk_o.py
index 3e1fdae2e0..161696daea 100644
--- a/fla/ops/common/chunk_o.py
+++ b/fla/ops/common/chunk_o.py
@@ -40,6 +40,7 @@ def chunk_fwd_kernel_o(
scale,
T,
H: tl.constexpr,
+ Hq: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
@@ -66,8 +67,8 @@ def chunk_fwd_kernel_o(
bos, eos = i_b * T, i_b * T + T
# offset calculation
- q += (bos * H + i_h) * K
- k += (bos * H + i_h) * K
+ q += (bos * Hq + i_h // (H // Hq)) * K
+ k += (bos * Hq + i_h // (H // Hq)) * K
v += (bos * H + i_h) * V
o += (bos * H + i_h) * V
h += (i_tg * H + i_h).to(tl.int64) * K*V
@@ -76,8 +77,8 @@ def chunk_fwd_kernel_o(
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
- p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
- p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+ p_q = tl.make_block_ptr(q, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+ p_k = tl.make_block_ptr(k, (K, T), (1, Hq*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
if TRANSPOSE_STATE:
p_h = tl.make_block_ptr(h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0))
else:
@@ -168,6 +169,7 @@ def chunk_bwd_kernel_dqkwg(
B: tl.constexpr,
T,
H: tl.constexpr,
+ Hq: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
@@ -200,8 +202,8 @@ def chunk_bwd_kernel_dqkwg(
do += (bos * H + i_h) * V
h += (i_tg * H + i_h).to(tl.int64) * K*V
dh += (i_tg * H + i_h).to(tl.int64) * K*V
- q += (bos * H + i_h) * K
- k += (bos * H + i_h) * K
+ q += (bos * Hq + i_h // (H // Hq)) * K
+ k += (bos * Hq + i_h // (H // Hq)) * K
dq += (bos * H + i_h) * K
dk += (bos * H + i_h) * K
@@ -255,8 +257,8 @@ def chunk_bwd_kernel_dqkwg(
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
tl.debug_barrier()
- p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
- p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+ p_q = tl.make_block_ptr(q, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+ p_k = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
@@ -362,6 +364,7 @@ def chunk_bwd_kernel_dv(
scale,
T,
H: tl.constexpr,
+ Hq: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
@@ -388,16 +391,16 @@ def chunk_bwd_kernel_dv(
b_dv = tl.zeros([BT, BV], dtype=tl.float32)
# offset calculation
- q += (bos * H + i_h) * K
- k += (bos * H + i_h) * K
+ q += (bos * Hq + i_h // (H // Hq)) * K
+ k += (bos * Hq + i_h // (H // Hq)) * K
do += (bos * H + i_h) * V
dv += (bos * H + i_h) * V
dh += (i_tg * H + i_h).to(tl.int64) * K*V
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
- p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
- p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+ p_k = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+ p_q = tl.make_block_ptr(q, (K, T), (1, Hq*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_A += tl.dot(b_k, b_q)
@@ -463,6 +466,7 @@ def chunk_bwd_kernel_dv_local(
scale,
T,
H: tl.constexpr,
+ Hq: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
@@ -484,8 +488,8 @@ def chunk_bwd_kernel_dv_local(
bos, eos = i_b * T, i_b * T + T
# offset calculation
- q += (bos * H + i_h) * K
- k += (bos * H + i_h) * K
+ q += (bos * Hq + i_h // (H // Hq)) * K
+ k += (bos * Hq + i_h // (H // Hq)) * K
do += (bos * H + i_h) * V
dv += (bos * H + i_h) * V
@@ -503,8 +507,8 @@ def chunk_bwd_kernel_dv_local(
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
- p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
- p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+ p_k = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+ p_q = tl.make_block_ptr(q, (K, T), (1, Hq*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
@@ -542,7 +546,9 @@ def chunk_fwd_o(
use_exp2: bool = False,
transpose_state_layout: bool = False,
) -> torch.Tensor:
- B, T, H, K, V = *q.shape, v.shape[-1]
+ B, T, Hq, K = q.shape
+ V = v.shape[-1]
+ H = v.shape[2]
BT = chunk_size
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
@@ -565,6 +571,7 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
scale=scale,
T=T,
H=H,
+ Hq=Hq,
K=K,
V=V,
BT=BT,
@@ -587,7 +594,9 @@ def chunk_bwd_dv(
chunk_indices: torch.LongTensor | None = None,
use_exp2: bool = False,
) -> torch.Tensor:
- B, T, H, K, V = *k.shape, do.shape[-1]
+ B, T, Hq, K = k.shape
+ V = do.shape[-1]
+ H = do.shape[2]
BT = chunk_size
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
@@ -620,6 +629,7 @@ def chunk_bwd_dv(
scale=scale,
T=T,
H=H,
+ Hq=Hq,
K=K,
V=V,
BT=BT,
@@ -643,7 +653,9 @@ def chunk_bwd_dv_local(
chunk_indices: torch.LongTensor | None = None,
use_exp2: bool = False,
) -> torch.Tensor:
- B, T, H, K, V = *k.shape, do.shape[-1]
+ B, T, Hq, K = k.shape
+ V = do.shape[-1]
+ H = do.shape[2]
BT = chunk_size
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
@@ -673,6 +685,7 @@ def chunk_bwd_dv_local(
scale=scale,
T=T,
H=H,
+ Hq=Hq,
K=K,
V=V,
BT=BT,
@@ -702,7 +715,9 @@ def chunk_bwd_dqkwg(
transpose_state_layout: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- B, T, H, K, V = *k.shape, v.shape[-1]
+ B, T, Hq, K = k.shape
+ V = v.shape[-1]
+ H = v.shape[2]
BT = chunk_size
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
@@ -717,8 +732,8 @@ def chunk_bwd_dqkwg(
BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)
NK = triton.cdiv(K, BK)
- dq = torch.empty_like(q)
- dk = torch.empty_like(k)
+ dq = q.new_empty(B, T, H, K)
+ dk = k.new_empty(B, T, H, K)
dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None
dw = torch.empty_like(w) if w is not None else None
@@ -743,6 +758,7 @@ def chunk_bwd_dqkwg(
B=B,
T=T,
H=H,
+ Hq=Hq,
K=K,
V=V,
BT=BT,
@@ -752,6 +768,9 @@ def chunk_bwd_dqkwg(
TRANSPOSE_STATE=transpose_state_layout,
)
+ if Hq != H:
+ dq = dq.view(B, T, Hq, H // Hq, K).sum(3)
+ dk = dk.view(B, T, Hq, H // Hq, K).sum(3)
if dg is not None:
dg = dg.sum(0)
return dq, dk, dw, dg
diff --git a/fla/ops/common/chunk_scaled_dot_kkt.py b/fla/ops/common/chunk_scaled_dot_kkt.py
index 2c33a12db4..51713bb64e 100644
--- a/fla/ops/common/chunk_scaled_dot_kkt.py
+++ b/fla/ops/common/chunk_scaled_dot_kkt.py
@@ -34,6 +34,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
chunk_indices,
T,
H: tl.constexpr,
+ Hq: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
@@ -56,7 +57,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
- p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+ p_k = tl.make_block_ptr(k + (bos*Hq + i_h // (H // Hq)) * K, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_A += tl.dot(b_k, tl.trans(b_k))
@@ -87,9 +88,9 @@ def chunk_scaled_dot_kkt_fwd(
Args:
k (torch.Tensor):
- The key tensor of shape `[B, T, H, K]`.
+ The key tensor of shape `[B, T, Hq, K]` where `Hq` is the number of query/key heads.
beta (torch.Tensor):
- The beta tensor of shape `[B, T, H]`.
+ The beta tensor of shape `[B, T, H]` where `H` is the number of value/output heads.
g (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
gk (torch.Tensor):
@@ -104,8 +105,10 @@ def chunk_scaled_dot_kkt_fwd(
Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
+ For GQA, Hq < H and H % Hq == 0. For standard attention, Hq == H.
"""
- B, T, H, K = k.shape
+ B, T, Hq, K = k.shape
+ H = beta.shape[2]
BT = chunk_size
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
@@ -120,6 +123,7 @@ def chunk_scaled_dot_kkt_fwd(
chunk_indices=chunk_indices,
T=T,
H=H,
+ Hq=Hq,
K=K,
BT=BT,
)
diff --git a/fla/ops/common/intracard_cp.py b/fla/ops/common/intracard_cp.py
index e73a308f96..35d3964939 100644
--- a/fla/ops/common/intracard_cp.py
+++ b/fla/ops/common/intracard_cp.py
@@ -86,7 +86,9 @@ def _raw_chunk_gated_delta_rule_fwd_h(
use_exp2: bool = False,
transpose_state_layout: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
- B, T, H, K, V = *k.shape, u.shape[-1]
+ B, T, Hq, K = k.shape
+ V = u.shape[-1]
+ H = u.shape[2]
BT = chunk_size
if chunk_indices is None and cu_seqlens is not None:
@@ -111,7 +113,7 @@ def grid(meta):
k=k, v=u, w=w, v_new=v_new,
g=g, gk=gk, h=h, h0=initial_state, ht=final_state,
cu_seqlens=cu_seqlens, chunk_offsets=chunk_offsets,
- T=T, H=H, K=K, V=V, BT=BT, USE_EXP2=use_exp2,
+ T=T, H=H, Hq=Hq, K=K, V=V, BT=BT, USE_EXP2=use_exp2,
TRANSPOSE_STATE=transpose_state_layout,
)
return h, v_new, final_state
@@ -235,13 +237,15 @@ def intracard_pre_scan(
kg: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
- gk: torch.Tensor,
+ g: torch.Tensor | None,
+ gk: torch.Tensor | None,
cu_seqlens_subseq_split: torch.Tensor,
S_split: int,
chunk_size: int = 64,
use_exp2: bool = True,
):
- H, K, V = kg.shape[2], kg.shape[3], u.shape[3]
+ Hq, K, V = kg.shape[2], kg.shape[3], u.shape[3]
+ H = u.shape[2]
BK = triton.next_power_of_2(K)
BLOCK_SIZE = 32 if K <= 64 else 64
@@ -252,12 +256,13 @@ def intracard_pre_scan(
k=kg,
v=u,
w=w,
- g=None,
+ g=g,
gk=gk,
hm=hm,
cu_seqlens=cu_seqlens_subseq_split,
T=0,
H=H,
+ Hq=Hq,
K=K,
V=V,
BT=chunk_size,
@@ -430,7 +435,9 @@ def intracard_fwd_h(
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
assert cu_seqlens is not None, "intracard_fwd_h requires cu_seqlens"
- _, _, H, K, V = *k.shape, u.shape[-1]
+ _, _, _Hq, K = k.shape
+ V = u.shape[-1]
+ H = u.shape[2]
device = k.device
if cu_seqlens_cpu is None:
@@ -542,7 +549,7 @@ def intracard_fwd_h(
_intracard_cache.popitem(last=False)
hm = intracard_pre_scan(
- kg=k, w=w, u=u, gk=gk,
+ kg=k, w=w, u=u, g=g, gk=gk,
cu_seqlens_subseq_split=cu_seqlens_split_flat,
S_split=S_split_total,
chunk_size=chunk_size,
diff --git a/fla/ops/cp/chunk_delta_h.py b/fla/ops/cp/chunk_delta_h.py
index 6cd2a12718..80acfa9316 100644
--- a/fla/ops/cp/chunk_delta_h.py
+++ b/fla/ops/cp/chunk_delta_h.py
@@ -15,265 +15,6 @@
from fla.ops.cp.context import FLACPContext
-@triton.heuristics({
- 'USE_G': lambda args: args['g'] is not None,
- 'USE_GK': lambda args: args['gk'] is not None,
- 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
-})
-@triton.autotune(
- configs=[
- triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
- for num_warps in [2, 4]
- for num_stages in [2, 3, 4]
- for BV in [32, 64]
- ],
- key=['H', 'K', 'V', 'BT', 'USE_EXP2', "STAGE"],
- use_cuda_graph=USE_CUDA_GRAPH,
- **autotune_cache_kwargs,
-)
-@triton.jit(do_not_specialize=['T'])
-def pre_process_fwd_kernel_stage1(
- k,
- v,
- w,
- g,
- gk,
- hm,
- cu_seqlens,
- T,
- H: tl.constexpr,
- K: tl.constexpr,
- V: tl.constexpr,
- BT: tl.constexpr,
- BV: tl.constexpr,
- USE_G: tl.constexpr,
- USE_GK: tl.constexpr,
- USE_EXP2: tl.constexpr,
- IS_VARLEN: tl.constexpr,
-):
- i_v, i_h = tl.program_id(0), tl.program_id(1)
- i_n = 0
- if IS_VARLEN:
- bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
- T = (eos - bos).to(tl.int32)
- NT = tl.cdiv(T, BT)
- else:
- bos, eos = (i_n * T).to(tl.int64), (i_n * T + T).to(tl.int64)
- NT = tl.cdiv(T, BT)
-
- # calculate offset
- hm += i_h * K * (K + V)
- v += ((bos * H + i_h) * V).to(tl.int64)
- k += ((bos * H + i_h) * K).to(tl.int64)
- w += ((bos * H + i_h) * K).to(tl.int64)
- stride_v = H*V
- stride_k = H*K
-
- b_h1 = tl.zeros([64, BV], dtype=tl.float32)
- if K > 64:
- b_h2 = tl.zeros([64, BV], dtype=tl.float32)
- if K > 128:
- b_h3 = tl.zeros([64, BV], dtype=tl.float32)
- if K > 192:
- b_h4 = tl.zeros([64, BV], dtype=tl.float32)
- # main recurrence
- for i_t in range(NT):
- p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0))
- b_w = tl.load(p_w, boundary_check=(0, 1))
- b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
- if K > 64:
- p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0))
- b_w = tl.load(p_w, boundary_check=(0, 1))
- b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
- if K > 128:
- p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0))
- b_w = tl.load(p_w, boundary_check=(0, 1))
- b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
- if K > 192:
- p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0))
- b_w = tl.load(p_w, boundary_check=(0, 1))
- b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
- p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
- b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
-
- last_idx = min((i_t + 1) * BT, T) - 1
- if USE_G:
- m_t = (i_t * BT + tl.arange(0, BT)) < T
- b_g_last = tl.load(g + bos * H + last_idx * H + i_h).to(tl.float32)
- p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
- b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
- if USE_EXP2:
- b_v = b_v * tl.where(m_t, exp2(b_g_last - b_g), 0)[:, None]
- b_g_last = exp2(b_g_last)
- else:
- b_v = b_v * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
- b_g_last = exp(b_g_last)
- b_h1 *= b_g_last
- if K > 64:
- b_h2 *= b_g_last
- if K > 128:
- b_h3 *= b_g_last
- if K > 192:
- b_h4 *= b_g_last
-
- if USE_GK:
- o_k1 = tl.arange(0, 64)
- b_gk_last1 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k1, mask=(o_k1 < K), other=0.).to(tl.float32)
- if USE_EXP2:
- b_h1 *= exp2(b_gk_last1)[:, None]
- else:
- b_h1 *= exp(b_gk_last1)[:, None]
- if K > 64:
- o_k2 = 64 + o_k1
- b_gk_last2 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k2, mask=(o_k2 < K), other=0.).to(tl.float32)
- if USE_EXP2:
- b_h2 *= exp2(b_gk_last2)[:, None]
- else:
- b_h2 *= exp(b_gk_last2)[:, None]
- if K > 128:
- o_k3 = 128 + o_k1
- b_gk_last3 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k3, mask=(o_k3 < K), other=0.).to(tl.float32)
- if USE_EXP2:
- b_h3 *= exp2(b_gk_last3)[:, None]
- else:
- b_h3 *= exp(b_gk_last3)[:, None]
- if K > 192:
- o_k4 = 192 + o_k1
- b_gk_last4 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k4, mask=(o_k4 < K), other=0.).to(tl.float32)
- if USE_EXP2:
- b_h4 *= exp2(b_gk_last4)[:, None]
- else:
- b_h4 *= exp(b_gk_last4)[:, None]
-
- b_v = b_v.to(k.dtype.element_ty)
-
- p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
- b_k = tl.load(p_k, boundary_check=(0, 1))
- b_h1 += tl.dot(b_k, b_v)
- if K > 64:
- p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
- b_k = tl.load(p_k, boundary_check=(0, 1))
- b_h2 += tl.dot(b_k, b_v)
- if K > 128:
- p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
- b_k = tl.load(p_k, boundary_check=(0, 1))
- b_h3 += tl.dot(b_k, b_v)
- if K > 192:
- p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
- b_k = tl.load(p_k, boundary_check=(0, 1))
- b_h4 += tl.dot(b_k, b_v)
-
- p_h1 = tl.make_block_ptr(hm, (K, V), (K+V, 1), (0, i_v * BV), (64, BV), (1, 0))
- tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
- if K > 64:
- p_h2 = tl.make_block_ptr(hm, (K, V), (K+V, 1), (64, i_v * BV), (64, BV), (1, 0))
- tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
- if K > 128:
- p_h3 = tl.make_block_ptr(hm, (K, V), (K+V, 1), (128, i_v * BV), (64, BV), (1, 0))
- tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
- if K > 192:
- p_h4 = tl.make_block_ptr(hm, (K, V), (K+V, 1), (192, i_v * BV), (64, BV), (1, 0))
- tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
-
-
-@triton.heuristics({
- 'USE_G': lambda args: args['g'] is not None,
- 'USE_GK': lambda args: args['gk'] is not None,
- 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
-})
-@triton.autotune(
- configs=[
- triton.Config({'BK2': BK2}, num_warps=num_warps, num_stages=num_stages)
- for num_warps in [2, 4]
- for num_stages in [2, 3, 4]
- for BK2 in [32]
- ],
- key=['H', 'BT', 'USE_EXP2', 'FORWARD'],
- use_cuda_graph=USE_CUDA_GRAPH,
- **autotune_cache_kwargs,
-)
-@triton.jit(do_not_specialize=['T'])
-def pre_process_fwd_bwd_kernel_stage2(
- k,
- w,
- g,
- gk,
- hm,
- cu_seqlens,
- T,
- H: tl.constexpr,
- K: tl.constexpr,
- V: tl.constexpr,
- BT: tl.constexpr,
- USE_G: tl.constexpr,
- USE_GK: tl.constexpr,
- USE_EXP2: tl.constexpr,
- IS_VARLEN: tl.constexpr,
- BK1: tl.constexpr,
- BK2: tl.constexpr,
- FORWARD: tl.constexpr = True,
-):
- i_k_col, i_h = tl.program_id(0), tl.program_id(1)
- i_n = 0
- if IS_VARLEN:
- bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
- T = eos - bos
- NT = tl.cdiv(T, BT)
- else:
- bos, eos = i_n * T, i_n * T + T
- NT = tl.cdiv(T, BT)
-
- # calculate offset
- hm += i_h * K * (K + V)
- k += ((bos * H + i_h) * K).to(tl.int64)
- w += ((bos * H + i_h) * K).to(tl.int64)
- stride_k = H*K
-
- row = tl.arange(0, BK1)
- col = tl.arange(0, BK2) + i_k_col * BK2
-
- b_m = tl.where(row[:, None] == col[None, :], 1.0, 0.0)
- for _i_t in range(NT):
- if FORWARD:
- i_t = _i_t
- else:
- i_t = NT - 1 - _i_t
- p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, BK1), (1, 0))
- b_k = tl.load(p_k, boundary_check=(0, 1))
- p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, BK1), (1, 0))
- b_w = tl.load(p_w, boundary_check=(0, 1))
- last_idx = min((i_t + 1) * BT, T) - 1
- if USE_G:
- m_t = (i_t * BT + tl.arange(0, BT)) < T
- b_g_last = tl.load(g + bos * H + last_idx * H + i_h).to(tl.float32)
- p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
- b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
- if USE_EXP2:
- b_k = b_k * tl.where(m_t, exp2(b_g_last - b_g), 0)[:, None]
- b_g_last = exp2(b_g_last)
- else:
- b_k = b_k * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
- b_g_last = exp(b_g_last)
- b_diag = tl.where(row[:, None] == row[None, :], b_g_last, 0.0)
- elif USE_GK:
- b_gk_last = tl.load(gk + (bos + last_idx) * H*K + i_h * K + row, mask=(row < K), other=0.).to(tl.float32)
- if USE_EXP2:
- b_gk_last = exp2(b_gk_last)
- else:
- b_gk_last = exp(b_gk_last)
- b_diag = tl.where(row[:, None] == row[None, :], b_gk_last[:, None], 0.0)
- else:
- b_diag = tl.where(row[:, None] == row[None, :], 1., 0.0)
- if FORWARD:
- b_kw = tl.dot(tl.trans(b_k.to(b_w.dtype)), b_w)
- else:
- b_kw = tl.dot(tl.trans(b_w), b_k.to(b_w.dtype))
- b_m_i = b_diag - b_kw
- b_m = tl.dot(b_m_i.to(b_w.dtype), b_m.to(b_w.dtype))
- p_m = tl.make_block_ptr(hm + V, (K, K), (K+V, 1), (0, i_k_col * BK2), (BK1, BK2), (1, 0))
- tl.store(p_m, b_m.to(p_m.dtype.element_ty), boundary_check=(0, 1))
-
-
@triton.heuristics({
'USE_G': lambda args: args['g'] is not None,
'USE_GK': lambda args: args['gk'] is not None,
@@ -300,6 +41,7 @@ def pre_process_fwd_kernel_merged(
cu_seqlens,
T,
H: tl.constexpr,
+ Hq: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
@@ -331,9 +73,10 @@ def pre_process_fwd_kernel_merged(
# i_col is in range [0, cdiv(V + K, BLOCK_SIZE))
# Columns [0, V) are for h, columns [V, V+K) are for m
is_h_part = i_col * BLOCK_SIZE < V
- k += ((bos * H + i_h) * K).to(tl.int64)
+ k += ((bos * Hq + i_h // (H // Hq)) * K).to(tl.int64)
w += ((bos * H + i_h) * K).to(tl.int64)
- stride_k = H * K
+ stride_k = Hq * K
+ stride_w = H * K
if is_h_part:
# ====== Stage 1: Compute h (K x V) ======
@@ -353,19 +96,19 @@ def pre_process_fwd_kernel_merged(
# Main recurrence for h
for i_t in range(NT):
# Compute decayed v
- p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0))
+ p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_decay = tl.dot(b_w, b_h1.to(b_w.dtype))
if K > 64:
- p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0))
+ p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_decay += tl.dot(b_w, b_h2.to(b_w.dtype))
if K > 128:
- p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0))
+ p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_decay += tl.dot(b_w, b_h3.to(b_w.dtype))
if K > 192:
- p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0))
+ p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_decay += tl.dot(b_w, b_h4.to(b_w.dtype))
@@ -480,7 +223,7 @@ def pre_process_fwd_kernel_merged(
# Load k and w with full BK1 rows
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, BK1), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
- p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, BK1), (1, 0))
+ p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, BK1), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
last_idx = min((i_t + 1) * BT, T) - 1
@@ -672,182 +415,6 @@ def merge_fwd_bwd_kernel(
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
-@triton.heuristics({
- 'USE_G': lambda args: args['g'] is not None,
- 'USE_GK': lambda args: args['gk'] is not None,
- 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
-})
-@triton.autotune(
- configs=[
- triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
- for num_warps in [2, 4]
- for num_stages in ([4, 3, 2] if check_shared_mem('ampere') else [1])
- for BV in [64, 32]
- ],
- key=['H', 'K', 'V', 'BT', 'BV', 'USE_G'],
- use_cuda_graph=USE_CUDA_GRAPH,
- **autotune_cache_kwargs,
-)
-@triton.jit(do_not_specialize=['T'])
-def pre_process_bwd_kernel_stage1(
- q,
- k,
- w,
- g,
- gk,
- do,
- dhm,
- dv,
- cu_seqlens,
- scale,
- T,
- H: tl.constexpr,
- K: tl.constexpr,
- V: tl.constexpr,
- BT: tl.constexpr,
- BV: tl.constexpr,
- USE_G: tl.constexpr,
- USE_GK: tl.constexpr,
- IS_VARLEN: tl.constexpr,
-):
- i_v, i_h = tl.program_id(0), tl.program_id(1)
- i_n = 0
- if IS_VARLEN:
- bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
- T = (eos - bos).to(tl.int32)
- NT = tl.cdiv(T, BT)
- else:
- bos, eos = (i_n * T).to(tl.int64), (i_n * T + T).to(tl.int64)
- NT = tl.cdiv(T, BT)
-
- # [BK, BV]
- b_dh1 = tl.zeros([64, BV], dtype=tl.float32)
- if K > 64:
- b_dh2 = tl.zeros([64, BV], dtype=tl.float32)
- if K > 128:
- b_dh3 = tl.zeros([64, BV], dtype=tl.float32)
- if K > 192:
- b_dh4 = tl.zeros([64, BV], dtype=tl.float32)
-
- # calculate offset
- q += ((bos * H + i_h) * K).to(tl.int64)
- k += ((bos * H + i_h) * K).to(tl.int64)
- w += ((bos * H + i_h) * K).to(tl.int64)
- do += ((bos * H + i_h) * V).to(tl.int64)
- dv += ((bos * H + i_h) * V).to(tl.int64)
- dhm += i_h * K * (V + K)
-
- stride_v = H*V
- stride_k = H*K
-
- for i_t in range(NT - 1, -1, -1):
- last_idx = min((i_t + 1) * BT, T) - 1
- if USE_G:
- bg_last = tl.load(g + (bos + last_idx) * H + i_h).to(tl.float32)
- bg_last_exp = exp(bg_last)
- p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
- b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
- b_g_exp = exp(b_g)
-
- p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
- p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
- b_do = tl.load(p_do, boundary_check=(0, 1))
-
- # Update dv
- p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0))
- b_k = tl.load(p_k, boundary_check=(0, 1))
- if USE_GK:
- o_k1 = tl.arange(0, 64)
- b_gk_last1 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k1, mask=(o_k1 < K), other=0.).to(tl.float32)
- b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype))
-
- if K > 64:
- p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0))
- b_k = tl.load(p_k, boundary_check=(0, 1))
- if USE_GK:
- o_k2 = 64 + o_k1
- b_gk_last2 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k2, mask=(o_k2 < K), other=0.).to(tl.float32)
- b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype))
-
- if K > 128:
- p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0))
- b_k = tl.load(p_k, boundary_check=(0, 1))
- if USE_GK:
- o_k3 = 128 + o_k1
- b_gk_last3 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k3, mask=(o_k3 < K), other=0.).to(tl.float32)
- b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype))
-
- if K > 192:
- p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0))
- b_k = tl.load(p_k, boundary_check=(0, 1))
- if USE_GK:
- o_k4 = 192 + o_k1
- b_gk_last4 = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k4, mask=(o_k4 < K), other=0.).to(tl.float32)
- b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype))
-
- if USE_G:
- m_t = (i_t * BT + tl.arange(0, BT)) < T
- b_dv *= tl.where(m_t, exp(bg_last - b_g), 0)[:, None]
- b_dv += tl.load(p_dv, boundary_check=(0, 1))
-
- # Update dh
- p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
- p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
- b_w = tl.load(p_w, boundary_check=(0, 1))
- b_q = tl.load(p_q, boundary_check=(0, 1))
- if USE_G:
- b_dh1 *= bg_last_exp
- b_q = b_q * b_g_exp[None, :]
- if USE_GK:
- b_dh1 *= exp(b_gk_last1[:, None])
- b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
- if K > 64:
- p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
- p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
- b_q = tl.load(p_q, boundary_check=(0, 1))
- b_w = tl.load(p_w, boundary_check=(0, 1))
- if USE_G:
- b_dh2 *= bg_last_exp
- b_q = b_q * b_g_exp[None, :]
- if USE_GK:
- b_dh2 *= exp(b_gk_last2[:, None])
- b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
- if K > 128:
- p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
- p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
- b_q = tl.load(p_q, boundary_check=(0, 1))
- b_w = tl.load(p_w, boundary_check=(0, 1))
- if USE_G:
- b_dh3 *= bg_last_exp
- b_q = b_q * b_g_exp[None, :]
- if USE_GK:
- b_dh3 *= exp(b_gk_last3[:, None])
- b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
- if K > 192:
- p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
- p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
- b_q = tl.load(p_q, boundary_check=(0, 1))
- b_w = tl.load(p_w, boundary_check=(0, 1))
- if USE_G:
- b_dh4 *= bg_last_exp
- b_q = b_q * b_g_exp[None, :]
- if USE_GK:
- b_dh4 *= exp(b_gk_last4[:, None])
- b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
-
- p_dh1 = tl.make_block_ptr(dhm, (K, V), (V + K, 1), (0, i_v * BV), (64, BV), (1, 0))
- tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1))
- if K > 64:
- p_dh2 = tl.make_block_ptr(dhm, (K, V), (V + K, 1), (64, i_v * BV), (64, BV), (1, 0))
- tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1))
- if K > 128:
- p_dh3 = tl.make_block_ptr(dhm, (K, V), (V + K, 1), (128, i_v * BV), (64, BV), (1, 0))
- tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1))
- if K > 192:
- p_dh4 = tl.make_block_ptr(dhm, (K, V), (V + K, 1), (192, i_v * BV), (64, BV), (1, 0))
- tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1))
-
-
@triton.heuristics({
'USE_G': lambda args: args['g'] is not None,
'USE_GK': lambda args: args['gk'] is not None,
@@ -877,6 +444,7 @@ def pre_process_bwd_kernel_merged(
scale,
T,
H: tl.constexpr,
+ Hq: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
@@ -908,11 +476,12 @@ def pre_process_bwd_kernel_merged(
is_dh_part = i_col * BLOCK_SIZE < V
# Calculate offsets
- q += ((bos * H + i_h) * K).to(tl.int64)
- k += ((bos * H + i_h) * K).to(tl.int64)
+ q += ((bos * Hq + i_h // (H // Hq)) * K).to(tl.int64)
+ k += ((bos * Hq + i_h // (H // Hq)) * K).to(tl.int64)
w += ((bos * H + i_h) * K).to(tl.int64)
dhm += i_h * K * (V + K)
- stride_k = H * K
+ stride_qk = Hq * K
+ stride_w = H * K
if is_dh_part:
# ====== Stage 1: Compute dh (K x V) ======
@@ -948,7 +517,7 @@ def pre_process_bwd_kernel_merged(
b_do = tl.load(p_do, boundary_check=(0, 1))
# Update dv
- p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0))
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, 0), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k1 = tl.arange(0, 64)
@@ -961,7 +530,7 @@ def pre_process_bwd_kernel_merged(
b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype))
if K > 64:
- p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0))
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k2 = 64 + o_k1
@@ -974,7 +543,7 @@ def pre_process_bwd_kernel_merged(
b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype))
if K > 128:
- p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0))
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k3 = 128 + o_k1
@@ -987,7 +556,7 @@ def pre_process_bwd_kernel_merged(
b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype))
if K > 192:
- p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0))
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
if USE_GK:
o_k4 = 192 + o_k1
@@ -1006,8 +575,8 @@ def pre_process_bwd_kernel_merged(
b_dv += tl.load(p_dv, boundary_check=(0, 1))
# Update dh
- p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
- p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
+ p_w = tl.make_block_ptr(w, (K, T), (1, stride_w), (0, i_t * BT), (64, BT), (0, 1))
+ p_q = tl.make_block_ptr(q, (K, T), (1, stride_qk), (0, i_t * BT), (64, BT), (0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
if USE_G:
@@ -1021,8 +590,8 @@ def pre_process_bwd_kernel_merged(
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 64:
- p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
- p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
+ p_q = tl.make_block_ptr(q, (K, T), (1, stride_qk), (64, i_t * BT), (64, BT), (0, 1))
+ p_w = tl.make_block_ptr(w, (K, T), (1, stride_w), (64, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_G:
@@ -1036,8 +605,8 @@ def pre_process_bwd_kernel_merged(
b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 128:
- p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
- p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
+ p_q = tl.make_block_ptr(q, (K, T), (1, stride_qk), (128, i_t * BT), (64, BT), (0, 1))
+ p_w = tl.make_block_ptr(w, (K, T), (1, stride_w), (128, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_G:
@@ -1051,8 +620,8 @@ def pre_process_bwd_kernel_merged(
b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
if K > 192:
- p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
- p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
+ p_q = tl.make_block_ptr(q, (K, T), (1, stride_qk), (192, i_t * BT), (64, BT), (0, 1))
+ p_w = tl.make_block_ptr(w, (K, T), (1, stride_w), (192, i_t * BT), (64, BT), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_G:
@@ -1096,9 +665,9 @@ def pre_process_bwd_kernel_merged(
i_t = NT - 1 - _i_t
# Load k and w with full BK1 rows
- p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, BK1), (1, 0))
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, 0), (BT, BK1), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
- p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, BK1), (1, 0))
+ p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, BK1), (1, 0))
b_w = tl.load(p_w, boundary_check=(0, 1))
last_idx = min((i_t + 1) * BT, T) - 1
@@ -1155,7 +724,9 @@ def chunk_gated_delta_rule_fwd_h_pre_process(
assert initial_state is None, "When enable CP, the provided initial_state must be None."
rank = dist.get_rank(group=context.group)
- B, T, H, K, V = *k.shape, u.shape[-1]
+ B, T, Hq, K = k.shape
+ V = u.shape[-1]
+ H = u.shape[2]
BT = chunk_size
BK = triton.next_power_of_2(K)
@@ -1184,6 +755,7 @@ def chunk_gated_delta_rule_fwd_h_pre_process(
cu_seqlens=cu_seqlens[-2:],
T=T,
H=H,
+ Hq=Hq,
K=K,
V=V,
BT=BT,
@@ -1237,7 +809,9 @@ def chunk_gated_delta_rule_bwd_dhu_pre_process(
assert dht is None, "When enable CP, the provided dht must be None."
rank = dist.get_rank(context.group)
- B, T, H, K, V = *q.shape, do.shape[-1]
+ B, T, Hq, K = q.shape
+ H = do.shape[2]
+ V = do.shape[-1]
# N: the actual number of sequences in the batch with either equal or variable lengths
BT = 64
assert K <= 256, "current kernel does not support head dimension being larger than 256."
@@ -1270,6 +844,7 @@ def chunk_gated_delta_rule_bwd_dhu_pre_process(
scale=scale,
T=T,
H=H,
+ Hq=Hq,
K=K,
V=V,
BT=BT,
diff --git a/fla/ops/gated_delta_rule/chunk.py b/fla/ops/gated_delta_rule/chunk.py
index 9980f5917f..429232e9bd 100644
--- a/fla/ops/gated_delta_rule/chunk.py
+++ b/fla/ops/gated_delta_rule/chunk.py
@@ -322,11 +322,12 @@ def chunk_gated_delta_rule(
r"""
Args:
q (torch.Tensor):
- queries of shape `[B, T, H, K]`.
+ queries of shape `[B, T, Hq, K]` where `Hq` is the number of query/key heads.
k (torch.Tensor):
- keys of shape `[B, T, H, K]`.
+ keys of shape `[B, T, Hq, K]` where `Hq` is the number of query/key heads.
v (torch.Tensor):
- values of shape `[B, T, H, V]`.
+ values of shape `[B, T, H, V]` where `H` is the number of value/output heads.
+ For standard attention, `Hq == H`. For GQA, `H % Hq == 0`.
g (torch.Tensor):
(forget) gating tensor (in log space!) of shape `[B, T, H]`.
beta (torch.Tensor):
@@ -388,6 +389,14 @@ def chunk_gated_delta_rule(
cu_seqlens=cu_seqlens
)
"""
+ # Validate GQA head divisibility
+ Hq, H = q.shape[2], v.shape[2]
+ if H % Hq != 0:
+ raise ValueError(
+ f"For GQA, num_heads (H={H}) must be evenly divisible by "
+ f"num_kv_heads (Hq={Hq}), but got H % Hq = {H % Hq}"
+ )
+
if 'head_first' in kwargs:
warnings.warn(
"head_first is deprecated and will be removed in a future version. "
diff --git a/fla/ops/gated_delta_rule/chunk_fwd.py b/fla/ops/gated_delta_rule/chunk_fwd.py
index 7c6c764f77..e86833ff08 100644
--- a/fla/ops/gated_delta_rule/chunk_fwd.py
+++ b/fla/ops/gated_delta_rule/chunk_fwd.py
@@ -38,6 +38,7 @@ def chunk_gated_delta_rule_fwd_kkt_solve_kernel(
chunk_indices,
T,
H: tl.constexpr,
+ Hq: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
@@ -77,7 +78,7 @@ def chunk_gated_delta_rule_fwd_kkt_solve_kernel(
i_tc2 = i_t * BT + 2 * BC
i_tc3 = i_t * BT + 3 * BC
- k += (bos * H + i_h) * K
+ k += (bos * Hq + i_h // (H // Hq)) * K
A += (bos * H + i_h) * BT
o_i = tl.arange(0, BC)
@@ -127,13 +128,13 @@ def chunk_gated_delta_rule_fwd_kkt_solve_kernel(
b_A32 = tl.zeros([BC, BC], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
- p_k0 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0))
+ p_k0 = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0))
b_k0 = tl.load(p_k0, boundary_check=(0, 1))
# diagonal block 0
b_A00 += tl.dot(b_k0, tl.trans(b_k0))
if i_tc1 < T:
- p_k1 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0))
+ p_k1 = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0))
b_k1 = tl.load(p_k1, boundary_check=(0, 1))
# diagonal block 1
b_A11 += tl.dot(b_k1, tl.trans(b_k1))
@@ -141,7 +142,7 @@ def chunk_gated_delta_rule_fwd_kkt_solve_kernel(
b_A10 += tl.dot(b_k1, tl.trans(b_k0))
if i_tc2 < T:
- p_k2 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0))
+ p_k2 = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0))
b_k2 = tl.load(p_k2, boundary_check=(0, 1))
# diagonal block 2
b_A22 += tl.dot(b_k2, tl.trans(b_k2))
@@ -150,7 +151,7 @@ def chunk_gated_delta_rule_fwd_kkt_solve_kernel(
b_A21 += tl.dot(b_k2, tl.trans(b_k1))
if i_tc3 < T:
- p_k3 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0))
+ p_k3 = tl.make_block_ptr(k, (T, K), (Hq*K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0))
b_k3 = tl.load(p_k3, boundary_check=(0, 1))
# diagonal block 3
b_A33 += tl.dot(b_k3, tl.trans(b_k3))
@@ -363,7 +364,8 @@ def chunk_gated_delta_rule_fwd_intra(
u (torch.Tensor): shape `[B, T, H, V]`
A (torch.Tensor): shape `[B, T, H, BT]`, the solved (I+A)^{-1} matrix
"""
- B, T, H, K = k.shape
+ B, T, Hq, K = k.shape
+ H = beta.shape[2]
BT = chunk_size
BC = 16
@@ -382,6 +384,7 @@ def chunk_gated_delta_rule_fwd_intra(
chunk_indices=chunk_indices,
T=T,
H=H,
+ Hq=Hq,
K=K,
BT=BT,
BC=BC,
diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py
index 52b2c92491..ef445f63d1 100644
--- a/fla/ops/gated_delta_rule/wy_fast.py
+++ b/fla/ops/gated_delta_rule/wy_fast.py
@@ -61,6 +61,7 @@ def recompute_w_u_fwd_kernel(
chunk_indices,
T,
H: tl.constexpr,
+ Hq: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
@@ -100,7 +101,7 @@ def recompute_w_u_fwd_kernel(
b_g = exp(tl.load(p_g, boundary_check=(0,)))
for i_k in range(tl.cdiv(K, BK)):
- p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+ p_k = tl.make_block_ptr(k + (bos*Hq + i_h // (H // Hq)) * K, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = b_k * b_b[:, None]
@@ -140,6 +141,7 @@ def prepare_wy_repr_bwd_kernel(
chunk_indices,
T,
H: tl.constexpr,
+ Hq: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
@@ -177,7 +179,7 @@ def prepare_wy_repr_bwd_kernel(
b_dg = tl.zeros([BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
- p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+ p_k = tl.make_block_ptr(k + (bos*Hq + i_h // (H // Hq)) * K, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
# [BT, BK]
@@ -227,7 +229,7 @@ def prepare_wy_repr_bwd_kernel(
tl.debug_barrier()
for i_k in range(tl.cdiv(K, BK)):
- p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+ p_k = tl.make_block_ptr(k + (bos*Hq + i_h // (H // Hq)) * K, (T, K), (Hq*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kt = tl.trans(b_k)
@@ -260,7 +262,9 @@ def recompute_w_u_fwd(
chunk_indices: torch.LongTensor | None = None,
use_exp2: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
- B, T, H, K, V = *k.shape, v.shape[-1]
+ B, T, Hq, K = k.shape
+ V = v.shape[-1]
+ H = v.shape[2]
BT = A.shape[-1]
BK = 64
BV = 64
@@ -269,7 +273,7 @@ def recompute_w_u_fwd(
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
- w = torch.empty_like(k)
+ w = k.new_empty(B, T, H, K)
u = torch.empty_like(v)
recompute_w_u_fwd_kernel[(NT, B*H)](
k=k,
@@ -283,6 +287,7 @@ def recompute_w_u_fwd(
chunk_indices=chunk_indices,
T=T,
H=H,
+ Hq=Hq,
K=K,
V=V,
BT=BT,
@@ -305,7 +310,9 @@ def prepare_wy_repr_bwd(
chunk_indices: torch.LongTensor | None = None,
use_exp2: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- B, T, H, K, V = *k.shape, v.shape[-1]
+ B, T, Hq, K = k.shape
+ V = v.shape[-1]
+ H = v.shape[2]
BT = 64
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
@@ -314,7 +321,7 @@ def prepare_wy_repr_bwd(
BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)
- dk = torch.empty_like(k)
+ dk = k.new_empty(B, T, H, K)
dv = torch.empty_like(v)
dg = torch.empty_like(g) if g is not None else None
db = torch.empty_like(beta)
@@ -334,6 +341,7 @@ def prepare_wy_repr_bwd(
chunk_indices=chunk_indices,
T=T,
H=H,
+ Hq=Hq,
K=K,
V=V,
BT=BT,
@@ -341,6 +349,8 @@ def prepare_wy_repr_bwd(
BV=BV,
USE_EXP2=use_exp2,
)
+ if Hq != H:
+ dk = dk.view(B, T, Hq, H // Hq, K).sum(3)
return dk, dv, db, dg
diff --git a/scripts/run_benchmark_compare.py b/scripts/run_benchmark_compare.py
index f258990914..6ae0023a0d 100644
--- a/scripts/run_benchmark_compare.py
+++ b/scripts/run_benchmark_compare.py
@@ -22,6 +22,7 @@
"""
import argparse
+import ast
import json
import os
import re
@@ -29,6 +30,7 @@
import subprocess
import sys
import tempfile
+import time
from pathlib import Path
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
@@ -68,12 +70,35 @@ def get_changed_files(base: str, head: str) -> list[str]:
return [f for f in out.split('\n') if f]
-def find_affected_op_names(changed_files: list[str]) -> list[str]:
- """Map changed file paths to registered op names.
+def _ops_subdirs_from_flapy_imports(relpath: str) -> set[str]:
+ path = PROJECT_ROOT / relpath
+ if not path.is_file():
+ return set()
+ try:
+ text = path.read_text(encoding='utf-8')
+ tree = ast.parse(text, filename=relpath)
+ except (OSError, SyntaxError, UnicodeDecodeError):
+ return set()
+
+ subdirs: set[str] = set()
+ for node in ast.walk(tree):
+ if isinstance(node, ast.ImportFrom):
+ mod = node.module or ''
+ if mod == 'fla.ops' or mod.startswith('fla.ops.'):
+ parts = mod.split('.')
+ if len(parts) >= 3 and parts[0] == 'fla' and parts[1] == 'ops':
+ subdirs.add(parts[2])
+ elif isinstance(node, ast.Import):
+ for alias in node.names:
+ name = alias.name
+ if name.startswith('fla.ops.'):
+ parts = name.split('.')
+ if len(parts) >= 3 and parts[0] == 'fla' and parts[1] == 'ops':
+ subdirs.add(parts[2])
+ return subdirs
- Parses fla/ops/<subdir>/ from each path and looks up the registry.
- Changes in fla/ops/common/ or fla/ops/utils/ affect ALL ops.
- """
+
+def find_affected_op_names(changed_files: list[str]) -> list[str]:
try:
sys.path.insert(0, str(PROJECT_ROOT / 'benchmarks' / 'ops'))
from registry import _REGISTRY
@@ -100,6 +125,9 @@ def find_affected_op_names(changed_files: list[str]) -> list[str]:
if m:
affected_dirs.add(m.group(1))
+ if fpath.startswith('fla/') and not m:
+ affected_dirs.update(_ops_subdirs_from_flapy_imports(fpath))
+
if common_changed:
return sorted(_REGISTRY.keys())
@@ -174,10 +202,13 @@ def _truncate(s: str, max_len: int = 8) -> str:
def print_comparison(base_results: list[dict], head_results: list[dict],
base_sha: str, head_sha: str, threshold: float,
- machine_info: dict | None = None):
+ machine_info: dict | None = None) -> tuple[int, list, list]:
"""Print a comparison table with speedup ratios and detect regressions.
Rows are grouped by mode with a blank line between fwd and fwdbwd.
+
+ Returns:
+ tuple of (exit_code, regressions, speedups)
"""
def make_key(r):
@@ -214,6 +245,7 @@ def make_key(r):
f" {'-' * col_w} {'-' * col_w} {'-' * 8} {'-' * 8}")
regressions = []
+ speedups = []
prev_mode = None
for key in all_keys:
op, mode, B, T, H, D = key
@@ -237,11 +269,15 @@ def make_key(r):
if change_pct > threshold:
marker = ' <<< REGRESSION'
elif change_pct < -threshold:
- marker = ' SPEEDUP'
+ marker = ' <<< SPEEDUP'
+ else:
+ marker = ''
print(f"{prefix} {base_ms:>{col_w}.3f} {head_ms:>{col_w}.3f} {speedup:>7.2f}x "
f"{sign}{change_pct:>6.1f}%{marker}")
if change_pct > threshold:
regressions.append((key, base_ms, head_ms, change_pct))
+ elif change_pct < -threshold:
+ speedups.append((key, base_ms, head_ms, change_pct))
elif head_r:
print(f"{prefix} {'-':>{col_w}s} {head_r['median_ms']:>{col_w}.3f} {'':>8s} {'new':>8s}")
elif base_r:
@@ -253,10 +289,16 @@ def make_key(r):
print(f"\n WARNING: {len(regressions)} regression(s) detected (>{threshold}% slower):")
for key, base_ms, head_ms, pct in regressions:
print(f" {key}: {base_ms:.3f} -> {head_ms:.3f} ms (+{pct:.1f}%)")
- return 1
- else:
+
+ if speedups:
+ print(f"\n INFO: {len(speedups)} speedup(s) detected (more than {threshold}% faster):")
+ for key, base_ms, head_ms, pct in speedups:
+ print(f" {key}: {base_ms:.3f} -> {head_ms:.3f} ms ({pct:.1f}%)")
+
+ if not regressions:
print(f"\n No regressions detected (threshold: {threshold}%).")
- return 0
+
+ return (1 if regressions else 0, regressions, speedups)
def main():
@@ -268,6 +310,8 @@ def main():
parser.add_argument("--threshold", type=float, default=5.0,
help="Regression threshold in percent (default: 5.0)")
parser.add_argument("--output", help="Save comparison results to JSON file")
+ parser.add_argument("--no-fail-on-regression", action="store_true",
+ help="Do not exit with non-zero code on performance regression")
args = parser.parse_args()
# Resolve commit SHAs
@@ -325,6 +369,13 @@ def main():
head_results = data.get('results', [])
machine_info = data.get('machine_info')
+ cooldown = int(os.environ.get('FLA_BENCH_COOLDOWN_SEC', '0'))
+ if cooldown > 0:
+ print(
+ f"\n Waiting {cooldown}s before base benchmark "
+ f"(FLA_BENCH_COOLDOWN_SEC; reduces thermal bias between runs).\n")
+ time.sleep(cooldown)
+
# Step 2: BASE
print(f"\n{'=' * 60}")
print(f" Running benchmarks at BASE ({base_sha})")
@@ -349,7 +400,7 @@ def main():
print("Warning: missing results from one or both commits.", file=sys.stderr)
sys.exit(1)
- exit_code = print_comparison(
+ exit_code, regressions, speedups = print_comparison(
base_results, head_results,
base_sha, head_sha, args.threshold,
machine_info=machine_info,
@@ -363,11 +414,32 @@ def main():
'machine_info': machine_info,
'base_results': base_results,
'head_results': head_results,
+ 'regressions': [
+ {
+ 'op': key[0], 'mode': key[1], 'B': key[2], 'T': key[3],
+ 'H': key[4], 'D': key[5],
+ 'base_ms': base_ms, 'head_ms': head_ms, 'change_pct': change_pct
+ }
+ for key, base_ms, head_ms, change_pct in regressions
+ ],
+ 'speedups': [
+ {
+ 'op': key[0], 'mode': key[1], 'B': key[2], 'T': key[3],
+ 'H': key[4], 'D': key[5],
+ 'base_ms': base_ms, 'head_ms': head_ms, 'change_pct': change_pct
+ }
+ for key, base_ms, head_ms, change_pct in speedups
+ ],
+ 'has_regression': len(regressions) > 0,
+ 'has_speedup': len(speedups) > 0,
}
with open(args.output, 'w') as f:
json.dump(comparison, f, indent=2)
print(f"\nFull results saved to {args.output}")
+ if args.no_fail_on_regression and exit_code != 0:
+ print("\n --no-fail-on-regression set, exiting with code 0 despite regressions.")
+ sys.exit(0)
sys.exit(exit_code)
finally:
diff --git a/tests/context_parallel/test_cp_bwd_gk_offset.py b/tests/context_parallel/test_cp_bwd_gk_offset.py
index 3ec15a19c6..6e0512b9a6 100644
--- a/tests/context_parallel/test_cp_bwd_gk_offset.py
+++ b/tests/context_parallel/test_cp_bwd_gk_offset.py
@@ -109,7 +109,7 @@ def test_stage1_gk_per_head_sensitivity(self, T: int, H: int, K: int, V: int):
dhm_a = torch.zeros(H, K, V + K, dtype=torch.float32, device=device)
pre_process_bwd_kernel_merged[grid](
q=q, k=k, w=w, g=None, gk=gk_zero, do=do, dhm=dhm_a, dv=dv,
- cu_seqlens=cu_seqlens, scale=1.0, T=T, H=H, K=K, V=V,
+ cu_seqlens=cu_seqlens, scale=1.0, T=T, H=H, Hq=H, K=K, V=V,
BT=BT, BK1=BK, USE_EXP2=True, BLOCK_SIZE=BLOCK_SIZE,
)
@@ -117,7 +117,7 @@ def test_stage1_gk_per_head_sensitivity(self, T: int, H: int, K: int, V: int):
dhm_b = torch.zeros(H, K, V + K, dtype=torch.float32, device=device)
pre_process_bwd_kernel_merged[grid](
q=q, k=k, w=w, g=None, gk=gk_diff, do=do, dhm=dhm_b, dv=dv,
- cu_seqlens=cu_seqlens, scale=1.0, T=T, H=H, K=K, V=V,
+ cu_seqlens=cu_seqlens, scale=1.0, T=T, H=H, Hq=H, K=K, V=V,
BT=BT, BK1=BK, USE_EXP2=True, BLOCK_SIZE=BLOCK_SIZE,
)
diff --git a/tests/context_parallel/test_cp_gdn.py b/tests/context_parallel/test_cp_gdn.py
index b536fdb3a3..873fdee9e5 100644
--- a/tests/context_parallel/test_cp_gdn.py
+++ b/tests/context_parallel/test_cp_gdn.py
@@ -100,6 +100,7 @@ def run_cp_gdn_test_worker(
lengths: list[int],
dtype,
transpose_state_layout: bool = False,
+ Hq: int | None = None,
):
"""
Worker function for CP GDN test.
@@ -115,14 +116,15 @@ def run_cp_gdn_test_worker(
if rank == 0:
print(f"\n{'='*60}")
print(f"Test: {test_name}")
- print(f"Config: T={T}, H={H}, D={D}, world_size={world_size}")
+ print(f"Config: T={T}, H={H}, D={D}, Hq={Hq}, world_size={world_size}")
print(f"Sequence lengths: {lengths}")
print(f"{'='*60}")
# Step 1: Prepare Global Data (all generated on rank 0, broadcast to all)
B = 1
- q_global = torch.empty(B, T, H, D, device=device, dtype=dtype)
- k_global = torch.empty(B, T, H, D, device=device, dtype=dtype)
+ Hq_actual = Hq if Hq is not None else H
+ q_global = torch.empty(B, T, Hq_actual, D, device=device, dtype=dtype)
+ k_global = torch.empty(B, T, Hq_actual, D, device=device, dtype=dtype)
v_global = torch.empty(B, T, H, D, device=device, dtype=dtype)
g_global = torch.empty(B, T, H, device=device, dtype=dtype)
beta_global = torch.empty(B, T, H, device=device, dtype=torch.float32)
@@ -130,8 +132,10 @@ def run_cp_gdn_test_worker(
if rank == 0:
torch.manual_seed(42)
- q_global.copy_(F.normalize(torch.randn(B, T, H, D, device=device, dtype=torch.float32), p=2, dim=-1).to(dtype))
- k_global.copy_(F.normalize(torch.randn(B, T, H, D, device=device, dtype=torch.float32), p=2, dim=-1).to(dtype))
+ q_global.copy_(F.normalize(torch.randn(B, T, Hq_actual, D, device=device,
+ dtype=torch.float32), p=2, dim=-1).to(dtype))
+ k_global.copy_(F.normalize(torch.randn(B, T, Hq_actual, D, device=device,
+ dtype=torch.float32), p=2, dim=-1).to(dtype))
v_global.copy_(torch.randn(B, T, H, D, device=device, dtype=dtype))
g_global.copy_(F.logsigmoid(torch.randn(B, T, H, device=device, dtype=dtype)))
beta_global.copy_(torch.randn(B, T, H, device=device, dtype=torch.float32).sigmoid())
@@ -255,7 +259,7 @@ def run_cp_gdn_test_worker(
try:
for name, ref, cp in tensors_to_verify:
- assert_close(name, ref, cp, ratio=2e-3, warning=False)
+ assert_close(name, ref, cp, ratio=3e-3, warning=False)
print(f"[{test_name}] Test Passed!\n")
except AssertionError as e:
print(f"[{test_name}] Test Failed: {e}\n")
@@ -281,6 +285,7 @@ def run_cp_test_with_spawn(
lengths: list[int],
dtype=torch.bfloat16,
transpose_state_layout: bool = False,
+ Hq: int | None = None,
):
"""
Run CP test using torch.multiprocessing.spawn.
@@ -288,7 +293,7 @@ def run_cp_test_with_spawn(
"""
mp.start_processes(
run_cp_gdn_test_worker,
- args=(world_size, test_name, T, H, D, lengths, dtype, transpose_state_layout),
+ args=(world_size, test_name, T, H, D, lengths, dtype, transpose_state_layout, Hq),
nprocs=world_size,
join=True,
start_method='spawn',
@@ -383,6 +388,34 @@ def test_cp2_many_short_sequences():
)
+def test_cp2_gqa_sequence_cut():
+ """CP2 GQA: sequences cut across rank boundary, Hq < H."""
+ if torch.cuda.device_count() < 2:
+ pytest.skip("At least 2 GPUs required")
+
+ run_cp_test_with_spawn(
+ world_size=2,
+ test_name="CP2_GQA_SequenceCut",
+ T=10240, H=4, D=64, Hq=2,
+ lengths=[3000, 4000, 3240],
+ dtype=torch.bfloat16,
+ )
+
+
+def test_cp2_gqa_single_sequence():
+ """CP2 GQA: single long sequence with Hq < H."""
+ if torch.cuda.device_count() < 2:
+ pytest.skip("At least 2 GPUs required")
+
+ run_cp_test_with_spawn(
+ world_size=2,
+ test_name="CP2_GQA_SingleSequence",
+ T=10240, H=8, D=64, Hq=2,
+ lengths=[10240],
+ dtype=torch.bfloat16,
+ )
+
+
# ============================================================
# Transpose State Layout Tests
# ============================================================
diff --git a/tests/ops/test_gated_delta.py b/tests/ops/test_gated_delta.py
index f9cf9e50c6..526794ddc7 100644
--- a/tests/ops/test_gated_delta.py
+++ b/tests/ops/test_gated_delta.py
@@ -156,6 +156,87 @@ def test_chunk(
assert_close('dh0', ref_dh0, tri_dh0, 0.008)
+@pytest.mark.parametrize(
+ ('B', 'T', 'Hq', 'H', 'D', 'scale', 'gate_logit_normalizer', 'use_qk_l2norm_in_kernel', 'dtype'),
+ [
+ pytest.param(
+ *test,
+ id="B{}-T{}-Hq{}-H{}-D{}-scale{}-gate_logit_normalizer{}-use_qk_l2norm_in_kernel{}-{}".format(*test),
+ )
+ for test in [
+ (2, 256, 2, 4, 64, 1, 1, False, torch.float16),
+ (2, 512, 1, 4, 64, 0.1, 1, False, torch.float16),
+ (2, 512, 2, 8, 64, 1, 0.1, True, torch.float16),
+ (2, 1024, 4, 8, 128, 0.1, 1, False, torch.float16),
+ ]
+ ],
+)
+def test_chunk_gqa(
+ B: int,
+ T: int,
+ Hq: int,
+ H: int,
+ D: int,
+ scale: float,
+ gate_logit_normalizer: float,
+ use_qk_l2norm_in_kernel: bool,
+ dtype: torch.dtype,
+):
+ torch.manual_seed(42)
+ if IS_INTEL_ALCHEMIST and D > 128:
+ pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128')
+ assert H % Hq == 0
+ G = H // Hq
+
+ q = torch.rand(B, T, Hq, D, dtype=dtype)
+ k = torch.rand(B, T, Hq, D, dtype=dtype)
+ v = torch.rand(B, T, H, D, dtype=dtype)
+ beta = torch.rand(B, T, H, dtype=torch.float).sigmoid()
+ g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.float32))
+ g = g / gate_logit_normalizer
+ h0 = torch.zeros(B, H, D, D, dtype=torch.float32)
+ q, k, v, beta, g, h0 = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, beta, g, h0))
+
+ tri, tri_ht = chunk_gated_delta_rule(
+ q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(),
+ k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(),
+ v=v.clone(),
+ g=g.clone(),
+ beta=beta.clone(),
+ scale=scale,
+ initial_state=h0.clone(),
+ output_final_state=True,
+ use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
+ )
+ do = torch.randn_like(v)
+ dht = torch.randn_like(h0)
+ ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True)
+ tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = q.grad, k.grad, v.grad, beta.grad, g.grad, h0.grad
+ q.grad = k.grad = v.grad = beta.grad = g.grad = h0.grad = None
+
+ ref, ref_ht = naive_recurrent_gated_delta_rule(
+ q=F.normalize(repeat(q.clone(), 'b t h d -> b t (h g) d', g=G), p=2, dim=-1),
+ k=F.normalize(repeat(k.clone(), 'b t h d -> b t (h g) d', g=G), p=2, dim=-1),
+ v=v.clone(),
+ beta=beta.clone(),
+ g=g.clone(),
+ scale=scale,
+ output_final_state=True,
+ initial_state=h0.clone(),
+ )
+
+ ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True)
+ ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dg, ref_dh0 = q.grad, k.grad, v.grad, beta.grad, g.grad, h0.grad
+ assert_close('o', ref, tri, 0.005)
+ assert_close('ht', ref_ht, tri_ht, 0.005)
+ assert_close('dq', ref_dq, tri_dq, 0.008)
+ assert_close('dk', ref_dk, tri_dk, 0.008)
+ assert_close('dv', ref_dv, tri_dv, 0.008)
+ assert_close('db', ref_dbeta, tri_dbeta, 0.02)
+ assert_close('dg', ref_dg, tri_dg, 0.02)
+ assert_close('dh0', ref_dh0, tri_dh0, 0.008)
+
+
@pytest.mark.parametrize(
('B', 'T', 'H', 'D', 'scale', 'gate_logit_normalizer', 'dtype'),
[
diff --git a/tests/ops/test_intracard_cache.py b/tests/ops/test_intracard_cache.py
index a531b1ba2e..4b106088f6 100644
--- a/tests/ops/test_intracard_cache.py
+++ b/tests/ops/test_intracard_cache.py
@@ -114,3 +114,59 @@ def test_intracard_backend_enabled_when_env_var_is_one(monkeypatch):
monkeypatch.setenv("FLA_INTRACARD_CP", "1")
assert IntraCardCPBackend.is_enabled() is True
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
+def test_chunk_gdn_intracard_gqa(monkeypatch):
+ """E2E: chunk_gated_delta_rule intracard path produces correct results with GQA (Hq < H).
+
+ Uses a long varlen sequence to exercise the intracard split path,
+ with Hq=2 key/query heads and H=4 value/output heads.
+ """
+ import torch.nn.functional as F
+
+ from fla.ops.gated_delta_rule import chunk_gated_delta_rule
+
+ torch.manual_seed(0)
+ dtype = torch.bfloat16
+
+ # T must be large enough to bypass early_return in intracard_fwd_h.
+ B, T, Hq, H, D = 1, 32768, 2, 4, 64
+
+ q = F.normalize(torch.randn(B, T, Hq, D, device=device, dtype=torch.float32), p=2, dim=-1).to(dtype)
+ k = F.normalize(torch.randn(B, T, Hq, D, device=device, dtype=torch.float32), p=2, dim=-1).to(dtype)
+ v = torch.randn(B, T, H, D, device=device, dtype=dtype)
+ g = F.logsigmoid(torch.randn(B, T, H, device=device, dtype=torch.float32))
+ beta = torch.randn(B, T, H, device=device, dtype=torch.float32).sigmoid()
+
+ cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32)
+ cu_seqlens_cpu = cu_seqlens.cpu()
+
+ # Run with intracard path (inference_mode triggers it)
+ with torch.inference_mode():
+ o_intra, ht_intra = chunk_gated_delta_rule(
+ q=q, k=k, v=v, g=g, beta=beta,
+ cu_seqlens=cu_seqlens,
+ cu_seqlens_cpu=cu_seqlens_cpu,
+ output_final_state=True,
+ )
+
+ # Run without intracard: disable the backend temporarily
+ from fla.ops.common.backends import common_registry
+ saved_backends = common_registry._backends.copy()
+ common_registry._backends.clear()
+ try:
+ with torch.inference_mode():
+ o_ref, ht_ref = chunk_gated_delta_rule(
+ q=q, k=k, v=v, g=g, beta=beta,
+ cu_seqlens=cu_seqlens,
+ cu_seqlens_cpu=cu_seqlens_cpu,
+ output_final_state=True,
+ )
+ finally:
+ common_registry._backends = saved_backends
+
+ assert torch.allclose(o_intra, o_ref, atol=1e-2, rtol=1e-2), \
+ f"Output mismatch: max diff={(o_intra - o_ref).abs().max().item()}"
+ assert torch.allclose(ht_intra, ht_ref, atol=1e-2, rtol=1e-2), \
+ f"Final state mismatch: max diff={(ht_intra - ht_ref).abs().max().item()}"