
[GDN] Add GVA support #799

Merged
zhiyuan1i merged 7 commits into main from lzy/gdn-different-kvheads on Mar 31, 2026
Conversation

@zhiyuan1i
Collaborator

@zhiyuan1i zhiyuan1i commented Mar 29, 2026

Add grouped-query attention support to all gated delta rule kernels. When Hq < H, query/key tensors have fewer heads than value/output, with head mapping: i_h // (H // Hq).

Kernel changes:

  • chunk_delta_h: fwd_h and bwd_dhu use Hq offset/stride for k (fwd) and q/k (bwd)
  • chunk_o: all kernels use Hq mapping for q/k reads; dq/dk allocated at [B,T,H,K] and reduced via .view(B,T,Hq,H//Hq,K).sum(3) to avoid write race conditions
  • chunk_scaled_dot_kkt: k uses Hq mapping (read-only)
  • wy_fast: k reads use Hq mapping; dk writes at H heads, reduced post-kernel
  • chunk_fwd: k uses Hq mapping (read-only)
  • cp/chunk_delta_h: fwd/bwd merged pre-process kernels use Hq for q/k
  • intracard_cp: pass scalar g to pre_process_fwd_kernel_merged in pre-scan
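The head mapping and the post-kernel gradient reduction described above can be sketched in NumPy (shapes and names are illustrative; the real code is Triton/PyTorch and uses .view(B, T, Hq, H // Hq, K).sum(3)):

```python
import numpy as np

B, T, Hq, K = 2, 8, 2, 16
H = 4                      # value/output heads; H % Hq == 0
group = H // Hq            # q/k heads are shared by `group` value heads

# Forward-side mapping: value head i_h reads q/k head i_h // (H // Hq).
head_map = [i_h // group for i_h in range(H)]   # -> [0, 0, 1, 1]

# Backward side: dq/dk are first written at H heads (one slot per value
# head, so no two programs race on the same q/k head), then reduced per
# group back to Hq heads; this mirrors the PR's reshape + sum.
dq_full = np.random.randn(B, T, H, K)
dq = dq_full.reshape(B, T, Hq, group, K).sum(axis=3)   # (B, T, Hq, K)
```

Each group of H // Hq value-head gradients collapses into one q/k-head gradient, which is why the temporary allocation at H heads avoids write races without atomics.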

Tests:

  • Add test_chunk_gqa with 4 GQA configurations
  • Add CP GQA tests and intracard GQA test
  • All existing tests pass (no regression)

Summary by CodeRabbit

  • New Features

    • Added grouped-query attention (GQA) support: q/k may use a separate head count (Hq) from v/beta/h, with runtime validation, updated shape contracts, and preserved output/head semantics.
  • Tests

    • Added/extended CUDA and unit tests covering GQA, intracard and context-parallel paths; relaxed tolerances and added new GQA scenarios.
  • Chores

    • Removed an old benchmark script; CI workflows updated; benchmark comparison now supports non-failing regression mode and posts PR benchmark summaries.

@coderabbitai
Contributor

coderabbitai bot commented Mar 29, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Adds grouped-query-attention (GQA) support by introducing a compile-time Triton constexpr Hq separate from H, threading Hq through many kernels and Python wrappers, remapping q/k head indexing and pointer strides to use Hq, updating host-side shape interpretation and post-processing, and updating tests and intracard/preprocess code for the Hq/H split.
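As a rough illustration of the stride change the walkthrough describes (plain Python, not the actual Triton block-pointer code; index names mirror the kernel's program ids but are assumptions here): with q/k stored as (B, T, Hq, K), the per-timestep stride becomes Hq*K instead of H*K, and the head offset uses the remapped index.

```python
# Flattened-offset sketch for a contiguous (B, T, Hq, K) q/k layout.
def qk_offset(i_b, i_t, i_h, T, H, Hq, K):
    i_hq = i_h // (H // Hq)                    # remap value head -> q/k head
    return ((i_b * T + i_t) * Hq + i_hq) * K   # stride Hq*K per step, not H*K

# With H = 4 value heads sharing Hq = 2 q/k heads, value heads 0 and 1
# resolve to the same q/k offset (they read the same grouped head).
```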

Changes

• Common chunk kernels & wrappers (fla/ops/common/chunk_delta_h.py, fla/ops/common/chunk_o.py, fla/ops/common/chunk_scaled_dot_kkt.py): Added Hq: tl.constexpr to Triton kernel signatures and launches; treat q/k as (B, T, Hq, K) and derive H from value/output tensors; remapped head index via i_h // (H // Hq) and changed block-pointer strides from H*K to Hq*K.
• Context-parallel / intracard (fla/ops/common/intracard_cp.py, fla/ops/cp/chunk_delta_h.py): Threaded Hq into intracard/preprocess merged kernels and host wrappers; removed some stage kernels in favor of merged kernels; updated pointer/stride math to use Hq for q/k addressing; intracard_pre_scan signature now accepts optional g/gk.
• Gated-delta-rule kernels & wrappers (fla/ops/gated_delta_rule/chunk_fwd.py, fla/ops/gated_delta_rule/wy_fast.py): Added the Hq constexpr to KKT and WY Triton kernels and launches; remapped k/q indexing and strides to use Hq; host wrappers now allocate outputs with head dim H; added post-processing when Hq != H (reshape + sum) for aggregated dq/dk results.
• Tests (tests/context_parallel/test_cp_gdn.py, tests/ops/test_gated_delta.py, tests/ops/test_intracard_cache.py, tests/context_parallel/test_cp_bwd_gk_offset.py): Threaded optional Hq through test workers; added GQA parameterized tests and an intracard E2E GQA test; adjusted q/k tensor shapes to use Hq, relaxed some tolerances, and updated merged-kernel test call sites to pass Hq.
• Chunk O backward changes (fla/ops/common/chunk_o.py): Changed dq/dk allocation to (B, T, H, K) and added a post-launch reduction when Hq != H to produce (B, T, Hq, K) via reshape + sum.
• Docs / API validation (fla/ops/gated_delta_rule/chunk.py): Docstring and runtime validation updated: q/k documented as [B, T, Hq, K], v as [B, T, H, V], with a check that H % Hq == 0 and an explicit error message.
• Benchmarks / CI / tooling (benchmarks/cp/benchmark_chunk_delta_h_kernels.py (deleted), .github/workflows/*.yml, scripts/run_benchmark_compare.py, benchmarks/ops/run.py, benchmarks/ops/registry.py): Removed the old CP benchmark script; added workflow permissions and PR benchmark-comment logic; extended the benchmark comparison script to report speedups and support --no-fail-on-regression; added diff-driven bench selection and env-configurable bench timing.
• Misc tests fixup (tests/context_parallel/test_cp_bwd_gk_offset.py): Updated the merged pre-process kernel invocation in the test to pass Hq=H.
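The runtime validation described for fla/ops/gated_delta_rule/chunk.py can be sketched as a minimal standalone check (not the library's exact code; the function name and signature here are hypothetical):

```python
def validate_gqa_heads(q_shape, k_shape, v_shape):
    """Raise if q/k/v head counts violate the GQA contract.

    q, k: [B, T, Hq, K]; v: [B, T, H, V]; requires H % Hq == 0.
    """
    Hq, H = q_shape[2], v_shape[2]
    if k_shape[2] != Hq:
        raise ValueError(
            f"q and k must have the same number of heads, "
            f"got {Hq} and {k_shape[2]}"
        )
    if H % Hq != 0:
        raise ValueError(
            f"For GQA, num_heads (H={H}) must be evenly divisible by "
            f"num_kv_heads (Hq={Hq}), but got H % Hq = {H % Hq}"
        )
    return H // Hq  # group size used by the kernels' i_h // (H // Hq) mapping
```

Failing fast at the wrapper entry keeps the constexpr head mapping inside the kernels free of bounds checks.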

Sequence Diagram(s)

sequenceDiagram
    participant Test as Test/Caller
    participant Host as Python wrapper
    participant Kernel as Triton kernel
    participant Mem as GPU memory / tensors

    Test->>Host: call op(q shaped (B,T,Hq,K), v shaped (B,T,H,V))
    Host->>Mem: prepare pointers, derive H from v, compute Hq/H mapping
    Host->>Kernel: launch kernel with constexpr Hq, H, K, T
    Kernel->>Mem: load q/k using Hq-based mapping (index i_h // (H//Hq))
    Kernel->>Mem: write outputs indexed by H (or temporary accumulators)
    Kernel-->>Host: return outputs (A/w/d*)
    Host->>Host: if Hq != H, reshape & sum grouped heads to (B,T,Hq,K)
    Host->>Test: return final outputs and gradients

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • yzhangcs

Poem

🐰 Hq hops in, a curious tune,
Small heads grouped beneath the moon,
Pointers shuffle, strides grow lean,
Kernels hum with tidy sheen,
Outputs bloom — GQA's bright boon!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 35.19%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Description Check (✅ Passed): check skipped; CodeRabbit's high-level summary is enabled.
  • Title check (✅ Passed): the title '[GDN] Add GVA support' clearly summarizes the main change, adding grouped-value attention support to the GDN (Gated Delta Rule) kernels, which aligns with the PR objectives.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@zhiyuan1i zhiyuan1i force-pushed the lzy/gdn-different-kvheads branch from a36331f to 44563eb on March 29, 2026 02:39
@zhiyuan1i zhiyuan1i requested a review from yzhangcs March 29, 2026 02:39
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements Grouped Query Attention (GQA) support across various Triton-based operations, including gated delta rules and context-parallel backends. By introducing the Hq parameter, the kernels now support configurations where the number of query/key heads differs from value heads, with corresponding updates to offset calculations and gradient aggregation logic. The PR also includes extensive new tests for GQA in both standard and distributed settings. Feedback is provided to improve parameter naming consistency in the intracard_pre_scan function call.


 hm = intracard_pre_scan(
-    kg=k, w=w, u=u, gk=gk,
+    kg=k, w=w, u=u, g=g, gk=gk,
Contributor


medium

To match the suggested renaming of kg to k in the intracard_pre_scan function definition for clarity, this call should be updated to pass k as the k argument.

Suggested change
-    kg=k, w=w, u=u, g=g, gk=gk,
+    k=k, w=w, u=u, g=g, gk=gk,

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (3)
fla/ops/common/intracard_cp.py (1)

438-440: Unused variable Hq should be prefixed with underscore.

Static analysis correctly identifies that Hq is unpacked but never used in intracard_fwd_h. The Hq handling is delegated to called functions (_raw_chunk_gated_delta_rule_fwd_h and intracard_pre_scan).

🔧 Suggested fix
-    _, _, Hq, K = k.shape
+    _, _, _Hq, K = k.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/common/intracard_cp.py` around lines 438 - 440, In intracard_fwd_h,
the unpacked variable Hq is unused and should be renamed to indicate that
(prefix with an underscore) to satisfy static analysis; change the unpacking "_,
_, Hq, K = k.shape" to use a prefixed name (e.g., "_Hq") so references remain
correct, leaving logic that delegates Hq handling to
_raw_chunk_gated_delta_rule_fwd_h and intracard_pre_scan unchanged.
tests/ops/test_intracard_cache.py (1)

154-167: Backend disabling approach may be fragile.

Directly manipulating common_registry._backends is reaching into implementation details. The _intracard_cache (cleared by fixture) is separate from the backend registry, so the isolation should work for this test. However, this approach could break if the registry internals change.

Consider using a more robust mechanism like mocking or a dedicated test mode flag if available.
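The suggested pattern, patching the registry instead of mutating it in place, might look like this (common_registry here is a stand-in object; the real import path and chunk_gated_delta_rule call are assumed from the review comment):

```python
from unittest import mock

class _Registry:
    """Stand-in for the real backend registry object."""
    def __init__(self):
        self._backends = {"intracard": object()}

common_registry = _Registry()  # in the real test this is imported

# patch.object swaps the attribute for the duration of the block and
# restores the original automatically, even if an assertion fails, so a
# mutated registry never leaks into other tests.
with mock.patch.object(common_registry, "_backends", {}):
    assert common_registry._backends == {}   # backends disabled here
    # ... call chunk_gated_delta_rule(...) under the patched registry ...

assert "intracard" in common_registry._backends  # original restored
```

pytest's monkeypatch fixture (monkeypatch.setattr) achieves the same automatic restoration per test.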

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/test_intracard_cache.py` around lines 154 - 167, The test is
brittle because it directly mutates the private common_registry._backends dict;
instead, temporarily replace the registry via a safe mock/patch so internals
aren't relied on. Update the test to use unittest.mock.patch.object or
pytest.monkeypatch to set common_registry._backends to an empty dict (or call a
public clear method if one exists) around the chunk_gated_delta_rule call (refer
to common_registry._backends and chunk_gated_delta_rule in the diff) so the
original value is automatically restored after the test.
fla/ops/common/chunk_scaled_dot_kkt.py (1)

86-91: Docstring is outdated — update to reflect GQA semantics.

The docstring still states k has shape [B, T, H, K], but the implementation now interprets it as [B, T, Hq, K] where Hq may differ from H.

📝 Suggested docstring update
     Args:
         k (torch.Tensor):
-            The key tensor of shape `[B, T, H, K]`.
+            The key tensor of shape `[B, T, Hq, K]` where `Hq` is the number of query/key heads.
         beta (torch.Tensor):
-            The beta tensor of shape `[B, T, H]`.
+            The beta tensor of shape `[B, T, H]` where `H` is the number of value/output heads.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/common/chunk_scaled_dot_kkt.py` around lines 86 - 91, Update the
outdated docstring in chunk_scaled_dot_kkt.py to reflect GQA semantics: state
that the input k has shape [B, T, Hq, K] (where Hq may differ from H used for
queries), describe that the function computes beta * K @ K^T per head across the
sequence producing a tensor of shape [B, Hq, T, T] (or equivalent ordering used
in the implementation), and clarify how Hq relates to any query head dimension
referenced elsewhere so readers understand head-count mismatches.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/common/chunk_o.py`:
- Around line 771-773: Add a validation in the layer constructors to enforce GQA
head divisibility: in the __init__ of the attention-like layers (e.g.,
SimpleGLA, GLA, or any class that defines self.num_heads and self.num_kv_heads)
assert that if self.num_kv_heads is not None then self.num_heads %
self.num_kv_heads == 0, raising a clear message like "num_heads (..) must be
evenly divisible by num_kv_heads (..)"; this prevents the view() in chunk_o.py
that reshapes dq/dk (the H/Hq split where H is num_heads and Hq is num_kv_heads)
from failing later during forward/backward.

In `@fla/ops/common/chunk_scaled_dot_kkt.py`:
- Around line 109-110: The docstring for the function in chunk_scaled_dot_kkt.py
is out-of-date: update the Args section to state that k has shape [B, T, Hq, K]
where Hq is the number of key/query heads (previously documented as [B, T, H,
K]), and add a note that when Hq < H grouped-query attention is used with head
mapping i_h // (H // Hq); also clarify that H is derived from beta.shape[2] (H =
beta.shape[2]) and that for the standard case Hq == H the behavior is unchanged.

---

Nitpick comments:
In `@fla/ops/common/chunk_scaled_dot_kkt.py`:
- Around line 86-91: Update the outdated docstring in chunk_scaled_dot_kkt.py to
reflect GQA semantics: state that the input k has shape [B, T, Hq, K] (where Hq
may differ from H used for queries), describe that the function computes beta *
K @ K^T per head across the sequence producing a tensor of shape [B, Hq, T, T]
(or equivalent ordering used in the implementation), and clarify how Hq relates
to any query head dimension referenced elsewhere so readers understand
head-count mismatches.

In `@fla/ops/common/intracard_cp.py`:
- Around line 438-440: In intracard_fwd_h, the unpacked variable Hq is unused
and should be renamed to indicate that (prefix with an underscore) to satisfy
static analysis; change the unpacking "_, _, Hq, K = k.shape" to use a prefixed
name (e.g., "_Hq") so references remain correct, leaving logic that delegates Hq
handling to _raw_chunk_gated_delta_rule_fwd_h and intracard_pre_scan unchanged.

In `@tests/ops/test_intracard_cache.py`:
- Around line 154-167: The test is brittle because it directly mutates the
private common_registry._backends dict; instead, temporarily replace the
registry via a safe mock/patch so internals aren't relied on. Update the test to
use unittest.mock.patch.object or pytest.monkeypatch to set
common_registry._backends to an empty dict (or call a public clear method if one
exists) around the chunk_gated_delta_rule call (refer to
common_registry._backends and chunk_gated_delta_rule in the diff) so the
original value is automatically restored after the test.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: c10fb80b-4e1c-438a-8811-57a72d8a24b3

📥 Commits

Reviewing files that changed from the base of the PR and between 8d6f019 and 44563eb.

📒 Files selected for processing (10)
  • fla/ops/common/chunk_delta_h.py
  • fla/ops/common/chunk_o.py
  • fla/ops/common/chunk_scaled_dot_kkt.py
  • fla/ops/common/intracard_cp.py
  • fla/ops/cp/chunk_delta_h.py
  • fla/ops/gated_delta_rule/chunk_fwd.py
  • fla/ops/gated_delta_rule/wy_fast.py
  • tests/context_parallel/test_cp_gdn.py
  • tests/ops/test_gated_delta.py
  • tests/ops/test_intracard_cache.py

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/cp/chunk_delta_h.py`:
- Around line 727-730: Add a validation that H is divisible by Hq before using
the head mapping formula: assert H % Hq == 0 (or raise a clear ValueError) so
the expression i_h // (H // Hq) cannot produce incorrect or out-of-range
indices; place this check near where H and Hq are computed (after the lines
reading B, T, Hq, K = k.shape and H = u.shape[2]) and include a descriptive
error message referencing H and Hq.
- Around line 812-814: Add the same shape divisibility checks in the backward
wrapper as in the forward pass: after computing B, T, Hq, K = q.shape and H =
do.shape[2], V = do.shape[-1], validate that Hq % K == 0 and that V is divisible
by (Hq // K) (or whatever exact relation the forward check enforces); if the
checks fail, raise a ValueError with a clear message referencing q.shape and
do.shape so the issue is obvious. Ensure these checks are placed at the start of
the backward wrapper (before further computation) and use the same error
text/logic as the forward implementation for consistency.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 84c96282-a505-48b4-a7cc-62ed720db251

📥 Commits

Reviewing files that changed from the base of the PR and between 44563eb and 5ce7930.

📒 Files selected for processing (3)
  • benchmarks/cp/benchmark_chunk_delta_h_kernels.py
  • fla/ops/cp/chunk_delta_h.py
  • tests/context_parallel/test_cp_bwd_gk_offset.py
💤 Files with no reviewable changes (1)
  • benchmarks/cp/benchmark_chunk_delta_h_kernels.py

- Add H % Hq == 0 validation at chunk_gated_delta_rule entry
- Update docstrings to reflect GQA semantics (Hq vs H)
- Prefix unused Hq variable with underscore in intracard_fwd_h
Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
fla/ops/gated_delta_rule/chunk.py (1)

385-391: Consider validating that k has the same head count as q.

The validation correctly enforces H % Hq == 0, which is essential for the kernel's i_h // (H // Hq) head mapping to work correctly. However, the docstring specifies that both q and k should have shape [B, T, Hq, K], but only q.shape[2] is used to derive Hq. If a caller inadvertently passes k with a different head count, the kernels would silently produce incorrect results.

🛡️ Proposed fix to validate k's head dimension
     # Validate GQA head divisibility
     Hq, H = q.shape[2], v.shape[2]
+    if k.shape[2] != Hq:
+        raise ValueError(
+            f"q and k must have the same number of heads, but got "
+            f"q.shape[2]={Hq} and k.shape[2]={k.shape[2]}"
+        )
     if H % Hq != 0:
         raise ValueError(
             f"For GQA, num_heads (H={H}) must be evenly divisible by "
             f"num_kv_heads (Hq={Hq}), but got H % Hq = {H % Hq}"
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/chunk.py` around lines 385 - 391, Validate that k
has the same head count as q before the GQA divisibility check: read Hq from
q.shape[2] (as currently done) and assert k.shape[2] == Hq, raising a ValueError
that calls out mismatched head counts (mention q, k, and Hq in the message).
Keep this check adjacent to the existing GQA validation (the block using Hq, H
and H % Hq) so the kernel mapping i_h // (H // Hq) is guaranteed correct.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Around line 385-391: Validate that k has the same head count as q before the
GQA divisibility check: read Hq from q.shape[2] (as currently done) and assert
k.shape[2] == Hq, raising a ValueError that calls out mismatched head counts
(mention q, k, and Hq in the message). Keep this check adjacent to the existing
GQA validation (the block using Hq, H and H % Hq) so the kernel mapping i_h //
(H // Hq) is guaranteed correct.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 88dd2191-b8e6-465c-9a84-1311cd023ebf

📥 Commits

Reviewing files that changed from the base of the PR and between 5ce7930 and 2b2b0e8.

📒 Files selected for processing (3)
  • fla/ops/common/chunk_scaled_dot_kkt.py
  • fla/ops/common/intracard_cp.py
  • fla/ops/gated_delta_rule/chunk.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • fla/ops/common/chunk_scaled_dot_kkt.py
  • fla/ops/common/intracard_cp.py

Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (1)
.github/workflows/reusable-ci-benchmarks.yml (1)

134-136: Declare the write permission this new comment step needs.

If the caller or repo defaults GITHUB_TOKEN to read-only, the new issues.createComment / issues.updateComment calls fail after the benchmark already succeeded. Declaring the requirement here makes the reusable workflow self-describing.

🔐 Suggested change
 jobs:
   benchmark:
     runs-on: ${{ inputs.runner }}
+    permissions:
+      contents: read
+      issues: write
     env:
       FLA_CI_ENV: 1
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In @.github/workflows/reusable-ci-benchmarks.yml around lines 134 - 136, The new
"Post benchmark results to PR" step (uses: actions/github-script@v7) needs
explicit write permission for issues to successfully call
issues.createComment/issues.updateComment when GITHUB_TOKEN defaults to
read-only; update the workflow to declare permissions: issues: write (either at
the job or workflow level) so the actions/github-script step can create/update
PR comments.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In @.github/workflows/reusable-ci-benchmarks.yml:
- Around line 207-216: The lookup for an existing benchmark comment uses
github.rest.issues.listComments which only returns the first page (30 comments),
causing duplicates; replace the single-page call with a paginated fetch (e.g.,
use github.paginate or iterate pages of github.rest.issues.listComments with
per_page=100) to retrieve all comments, then search that full list for the bot
comment (the logic that assigns botComment). Keep the same matching criteria
(comment.user.type === 'Bot' && comment.body.includes(`Benchmark Results
(${runnerName.toUpperCase()})`)) and then proceed to update or create the
comment as before.

In `@scripts/run_benchmark_compare.py`:
- Around line 405-407: The current branch unconditionally triggers when
args.no_fail_on_regression is set, printing "despite regressions" and exiting 0
even for clean runs; change the logic around the block that checks
args.no_fail_on_regression to only run when a regression actually exists (e.g.,
regressions is truthy or exit_code != 0). Update the conditional to something
like "if args.no_fail_on_regression and (regressions or exit_code != 0):" so the
message and sys.exit(0) are only used to suppress a failing exit code, leaving
normal successful exits unchanged.

---

Nitpick comments:
In @.github/workflows/reusable-ci-benchmarks.yml:
- Around line 134-136: The new "Post benchmark results to PR" step (uses:
actions/github-script@v7) needs explicit write permission for issues to
successfully call issues.createComment/issues.updateComment when GITHUB_TOKEN
defaults to read-only; update the workflow to declare permissions: issues: write
(either at the job or workflow level) so the actions/github-script step can
create/update PR comments.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: bce5ccb2-1a0a-4daf-8194-61faeef9622a

📥 Commits

Reviewing files that changed from the base of the PR and between 2b2b0e8 and 012c473.

📒 Files selected for processing (5)
  • .github/workflows/nvidia-4090.yml
  • .github/workflows/nvidia-a100.yml
  • .github/workflows/nvidia-h100.yml
  • .github/workflows/reusable-ci-benchmarks.yml
  • scripts/run_benchmark_compare.py
✅ Files skipped from review due to trivial changes (3)
  • .github/workflows/nvidia-4090.yml
  • .github/workflows/nvidia-h100.yml
  • .github/workflows/nvidia-a100.yml

Comment on lines +207 to +216
const { data: comments } = await github.rest.issues.listComments({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
});

const botComment = comments.find(comment =>
comment.user.type === 'Bot' &&
comment.body.includes(`Benchmark Results (${runnerName.toUpperCase()})`)
);
Contributor


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, let's check if the file exists and read the specific lines
head -n 220 .github/workflows/reusable-ci-benchmarks.yml | tail -n 20

Repository: fla-org/flash-linear-attention

Length of output: 811


🌐 Web query:

GitHub API issues listComments pagination default per_page documentation

💡 Result:

The GitHub REST API endpoint for listing comments on an issue (GET /repos/{owner}/{repo}/issues/{issue_number}/comments), often referred to as "listComments" or "list comments on an issue", supports pagination via the standard per_page and page query parameters. The default value for per_page is 30 (maximum 100). This is explicitly documented in the official GitHub REST API documentation for issue comments endpoints. Pagination uses Link headers in responses for navigating pages, as described in the general "Using pagination in the REST API" guide, which notes that endpoints like issues return 30 items by default (e.g., GET /repos/octocat/Spoon-Knife/issues). Example request: GET https://api.github.com/repos/OWNER/REPO/issues/ISSUE_NUMBER/comments?per_page=100&page=1


Paginate the comment lookup before deciding to create a new one.

issues.listComments returns 30 items by default. If the benchmark comment appears beyond the first page, this code fails to find it and creates duplicates instead of updating the existing comment.

💡 Suggested fix
-            const { data: comments } = await github.rest.issues.listComments({
-              owner: context.repo.owner,
-              repo: context.repo.repo,
-              issue_number: context.issue.number,
-            });
+            const comments = await github.paginate(github.rest.issues.listComments, {
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              issue_number: context.issue.number,
+              per_page: 100,
+            });
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In @.github/workflows/reusable-ci-benchmarks.yml around lines 207 - 216, The
lookup for an existing benchmark comment uses github.rest.issues.listComments
which only returns the first page (30 comments), causing duplicates; replace the
single-page call with a paginated fetch (e.g., use github.paginate or iterate
pages of github.rest.issues.listComments with per_page=100) to retrieve all
comments, then search that full list for the bot comment (the logic that assigns
botComment). Keep the same matching criteria (comment.user.type === 'Bot' &&
comment.body.includes(`Benchmark Results (${runnerName.toUpperCase()})`)) and
then proceed to update or create the comment as before.

Comment on lines +405 to +407
if args.no_fail_on_regression:
print("\n --no-fail-on-regression set, exiting with code 0 despite regressions.")
sys.exit(0)
Contributor


⚠️ Potential issue | 🟡 Minor

Only suppress the exit code when a regression actually exists.

This branch fires whenever --no-fail-on-regression is set, so clean runs still log “despite regressions.” Gating it on regressions (or exit_code != 0) keeps the logs accurate and avoids masking future non-regression exit codes.

💡 Suggested fix
-        if args.no_fail_on_regression:
+        if args.no_fail_on_regression and regressions:
             print("\n  --no-fail-on-regression set, exiting with code 0 despite regressions.")
             sys.exit(0)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@scripts/run_benchmark_compare.py` around lines 405 - 407, The current branch
unconditionally triggers when args.no_fail_on_regression is set, printing
"despite regressions" and exiting 0 even for clean runs; change the logic around
the block that checks args.no_fail_on_regression to only run when a regression
actually exists (e.g., regressions is truthy or exit_code != 0). Update the
conditional to something like "if args.no_fail_on_regression and (regressions or
exit_code != 0):" so the message and sys.exit(0) are only used to suppress a
failing exit code, leaving normal successful exits unchanged.

@github-actions

github-actions bot commented Mar 30, 2026

⚠️ Benchmark Results (NVIDIA-H100-PT2-7)

Status: 1 regression(s) detected

GPU: NVIDIA H200
CUDA: 12.8
PyTorch: 2.7.1+cu128
Base: 46c5f6be5
Head: aba2408c4
Threshold: 5.0%

📊 View Details (1 significant change)

Op                Mode    B  T     H   D    Base (ms)  Head (ms)  Change
chunk_delta_rule  fwdbwd  4  2048  16  128  1.158      1.372      +18.5% 🔴

This comment is automatically updated with the latest benchmark results.

@zhiyuan1i zhiyuan1i changed the title [GDN] Add GQA support (Hq != H) [GDN] Add GVA support (Hq != H) Mar 30, 2026
@yzhangcs yzhangcs changed the title [GDN] Add GVA support (Hq != H) [GDN] Add GVA support Mar 30, 2026

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
.github/workflows/reusable-ci-benchmarks.yml (1)

212-221: ⚠️ Potential issue | 🟡 Minor

Paginate the benchmark-comment lookup.

issues.listComments only fetches the first page, so once a PR has more than 30 comments this step can miss the existing benchmark comment and create a duplicate instead of updating it.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In @.github/workflows/reusable-ci-benchmarks.yml around lines 212 - 221, The
current lookup uses github.rest.issues.listComments which only returns the first
page; replace it with Octokit pagination to fetch all comments (e.g., use
github.paginate(github.rest.issues.listComments, { owner: context.repo.owner,
repo: context.repo.repo, issue_number: context.issue.number }) ), then run the
existing find for botComment (checking comment.user.type === 'Bot' and
comment.body.includes(`Benchmark Results (${runnerName.toUpperCase()})`));
ensure you remove or replace the single-page call to
github.rest.issues.listComments and use the paginated array for
deduplication/updating.
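The fix amounts to a standard pagination loop. The real workflow step uses Octokit's `github.paginate` in JavaScript; as a language-neutral sketch, here is the same idea in Python, where `fetch_page` is a stand-in for one REST call that returns at most `per_page` items (the GitHub API default page size is 30):

```python
def fetch_all_comments(fetch_page, per_page=30):
    """Collect every comment by requesting pages until a short page arrives."""
    comments, page = [], 1
    while True:
        batch = fetch_page(page=page, per_page=per_page)
        comments.extend(batch)
        if len(batch) < per_page:
            # A short (or empty) page means there are no further pages.
            return comments
        page += 1
```

Running the existing `find` over the fully paginated list, instead of over the first page only, is what prevents the duplicate benchmark comment.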
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/ops/run.py`:
- Around line 580-592: The current --from-diff branch treats an empty
_diff.find_affected_op_names(changed) as “no changes” and returns early, but
find_affected_op_names only maps fla/ops/** and misses changes to benchmark
registry or runner code; after computing changed = _diff.get_changed_files(...)
check changed for registry/runner updates (e.g. 'benchmarks/ops/registry.py',
'benchmarks/ops/run.py', and 'scripts/run_benchmark_compare.py') and if any are
present do not return early—either continue normal execution or treat it as “all
ops affected” (so the registry changes are picked up) instead of printing "No
affected ops..." and returning; update the logic around variables changed and
op_names to implement this fallback.

In `@fla/ops/common/chunk_delta_h.py`:
- Around line 673-675: Before launching the remapped-head kernels, validate the
head counts: ensure Hq <= H and H % Hq == 0 (the same invariant enforced in
fla/ops/gated_delta_rule/chunk.py:390-399). Specifically, where B, T, Hq, K =
k.shape and H = u.shape[2] and where the mapping uses i_h // (H // Hq), add a
guard that raises/returns a clear error (or assert) if Hq > H or H % Hq != 0 to
prevent division by zero and out-of-bounds base pointer walks; apply the same
check in both entry points that set H/Hq (the occurrences around the B,T,Hq,K
and V,u/H assignments).

---

Duplicate comments:
In @.github/workflows/reusable-ci-benchmarks.yml:
- Around line 212-221: The current lookup uses github.rest.issues.listComments
which only returns the first page; replace it with Octokit pagination to fetch
all comments (e.g., use github.paginate(github.rest.issues.listComments, {
owner: context.repo.owner, repo: context.repo.repo, issue_number:
context.issue.number }) ), then run the existing find for botComment (checking
comment.user.type === 'Bot' and comment.body.includes(`Benchmark Results
(${runnerName.toUpperCase()})`)); ensure you remove or replace the single-page
call to github.rest.issues.listComments and use the paginated array for
deduplication/updating.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 4d7f043e-4dfd-4aaf-8ac5-6c2a842241f6

📥 Commits

Reviewing files that changed from the base of the PR and between 012c473 and d412c84.

📒 Files selected for processing (6)
  • .github/workflows/reusable-ci-benchmarks.yml
  • benchmarks/ops/registry.py
  • benchmarks/ops/run.py
  • fla/ops/common/chunk_delta_h.py
  • fla/ops/gated_delta_rule/chunk.py
  • scripts/run_benchmark_compare.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • fla/ops/gated_delta_rule/chunk.py
  • scripts/run_benchmark_compare.py

Comment thread: benchmarks/ops/run.py
Comment on lines +580 to +592
if args.from_diff:
    if args.op is not None:
        parser.error('--from-diff cannot be used with --op')
    project_root = _find_project_root()
    scripts_dir = os.path.join(project_root, 'scripts')
    if scripts_dir not in sys.path:
        sys.path.insert(0, scripts_dir)
    import run_benchmark_compare as _diff
    changed = _diff.get_changed_files(args.diff_base, args.diff_head)
    op_names = _diff.find_affected_op_names(changed)
    if not op_names:
        print('No affected ops for this diff.', file=sys.stderr)
        return

⚠️ Potential issue | 🟠 Major

--from-diff misses benchmark-registry changes.

This path treats an empty find_affected_op_names() result as “nothing changed”, but that helper only maps fla/ops/** paths. A diff that only touches benchmarks/ops/registry.py, benchmarks/ops/run.py, or scripts/run_benchmark_compare.py will therefore exit early even though benchmark behavior changed; the chunk_rwkv7 registry update in this PR is one concrete example.

🛠️ Suggested fallback
         changed = _diff.get_changed_files(args.diff_base, args.diff_head)
-        op_names = _diff.find_affected_op_names(changed)
+        benchmark_meta_files = {
+            'benchmarks/ops/registry.py',
+            'benchmarks/ops/run.py',
+            'scripts/run_benchmark_compare.py',
+        }
+        op_names = (
+            list_ops()
+            if any(path in benchmark_meta_files for path in changed)
+            else _diff.find_affected_op_names(changed)
+        )
         if not op_names:
             print('No affected ops for this diff.', file=sys.stderr)
             return
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/ops/run.py` around lines 580 - 592, The current --from-diff branch
treats an empty _diff.find_affected_op_names(changed) as “no changes” and
returns early, but find_affected_op_names only maps fla/ops/** and misses
changes to benchmark registry or runner code; after computing changed =
_diff.get_changed_files(...) check changed for registry/runner updates (e.g.
'benchmarks/ops/registry.py', 'benchmarks/ops/run.py', and
'scripts/run_benchmark_compare.py') and if any are present do not return
early—either continue normal execution or treat it as “all ops affected” (so the
registry changes are picked up) instead of printing "No affected ops..." and
returning; update the logic around variables changed and op_names to implement
this fallback.
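The suggested fallback is easy to express as a standalone helper (the names `select_ops` and `BENCHMARK_META_FILES` are illustrative, not the script's actual API):

```python
BENCHMARK_META_FILES = {
    'benchmarks/ops/registry.py',
    'benchmarks/ops/run.py',
    'scripts/run_benchmark_compare.py',
}

def select_ops(changed, all_ops, find_affected):
    """Pick ops to benchmark for a diff.

    If benchmark meta-code itself changed, treat every op as affected;
    otherwise defer to the fla/ops/** -> op-name mapping.
    """
    if any(path in BENCHMARK_META_FILES for path in changed):
        return list(all_ops)
    return find_affected(changed)
```

Only when this helper returns an empty list is "No affected ops for this diff." actually true, so the early return no longer swallows registry-only changes.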

Comment on lines +673 to +675
B, T, Hq, K = k.shape
V = u.shape[-1]
H = u.shape[2]

⚠️ Potential issue | 🔴 Critical

Validate H/Hq before launching the remapped-head kernels.

Line 95 and Lines 400-401 use i_h // (H // Hq), but these wrappers never enforce the GQA invariant. If Hq > H, the divisor becomes zero; if H % Hq != 0, the mapped head can reach Hq, so the q/k base pointers walk past the (B, T, Hq, K) buffers. fla/ops/gated_delta_rule/chunk.py:390-399 already guards this, so these lower-level entry points should do the same.

🛡️ Proposed fix
 def chunk_gated_delta_rule_fwd_h(
     k: torch.Tensor,
     w: torch.Tensor,
     u: torch.Tensor,
@@
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
     B, T, Hq, K = k.shape
     V = u.shape[-1]
     H = u.shape[2]
+    if H % Hq != 0:
+        raise RuntimeError(
+            f"H (num_heads={H}) must be divisible by "
+            f"num_kv_heads (Hq={Hq}), but got H % Hq = {H % Hq}"
+        )
     BT = chunk_size
@@
 def chunk_gated_delta_rule_bwd_dhu(
     q: torch.Tensor,
     k: torch.Tensor,
     w: torch.Tensor,
@@
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     B, T, Hq, K = q.shape
     V = do.shape[-1]
     H = do.shape[2]
+    if H % Hq != 0:
+        raise RuntimeError(
+            f"H (num_heads={H}) must be divisible by "
+            f"num_kv_heads (Hq={Hq}), but got H % Hq = {H % Hq}"
+        )
     # N: the actual number of sequences in the batch with either equal or variable lengths

Also applies to: 737-739

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/common/chunk_delta_h.py` around lines 673 - 675, Before launching the
remapped-head kernels, validate the head counts: ensure Hq <= H and H % Hq == 0
(the same invariant enforced in fla/ops/gated_delta_rule/chunk.py:390-399).
Specifically, where B, T, Hq, K = k.shape and H = u.shape[2] and where the
mapping uses i_h // (H // Hq), add a guard that raises/returns a clear error (or
assert) if Hq > H or H % Hq != 0 to prevent division by zero and out-of-bounds
base pointer walks; apply the same check in both entry points that set H/Hq (the
occurrences around the B,T,Hq,K and V,u/H assignments).
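The invariant and the `i_h // (H // Hq)` mapping the comment refers to can be checked in plain Python (a sketch of the index arithmetic only, not the Triton kernel):

```python
def kv_head_index(i_h, H, Hq):
    """Map an output/value head index i_h in [0, H) to the q/k head it
    shares under grouped heads, after validating the GQA invariant."""
    if Hq > H or H % Hq != 0:
        raise ValueError(
            f"H (num_heads={H}) must be divisible by num_kv_heads (Hq={Hq})"
        )
    # Each q/k head serves a contiguous group of H // Hq output heads.
    return i_h // (H // Hq)
```

Without the guard, `Hq > H` makes `H // Hq` zero (division by zero in the mapping), and `H % Hq != 0` lets the mapped index reach `Hq`, which is exactly the out-of-bounds pointer walk the review flags.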

@zhiyuan1i zhiyuan1i merged commit ca3905f into main Mar 31, 2026
5 checks passed
@zhiyuan1i zhiyuan1i deleted the lzy/gdn-different-kvheads branch March 31, 2026 08:58