Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings. Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
Walkthrough

Adds grouped-query-attention (GQA) support by introducing a compile-time Triton constexpr Hq.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as Test/Caller
    participant Host as Python wrapper
    participant Kernel as Triton kernel
    participant Mem as GPU memory / tensors
    Test->>Host: call op(q shaped (B,T,Hq,K), v shaped (B,T,H,V))
    Host->>Mem: prepare pointers, derive H from v, compute Hq/H mapping
    Host->>Kernel: launch kernel with constexpr Hq, H, K, T
    Kernel->>Mem: load q/k using Hq-based mapping (index i_h // (H//Hq))
    Kernel->>Mem: write outputs indexed by H (or temporary accumulators)
    Kernel-->>Host: return outputs (A/w/d*)
    Host->>Host: if Hq != H, reshape & sum grouped heads to (B,T,Hq,K)
    Host->>Test: return final outputs and gradients
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Add grouped-query attention support to all gated delta rule kernels. When Hq < H, query/key tensors have fewer heads than value/output, with head mapping: i_h // (H // Hq).

Kernel changes:
- chunk_delta_h: fwd_h and bwd_dhu use Hq offset/stride for k (fwd) and q/k (bwd)
- chunk_o: all kernels use Hq mapping for q/k reads; dq/dk allocated at [B,T,H,K] and reduced via .view(B,T,Hq,H//Hq,K).sum(3) to avoid write race conditions
- chunk_scaled_dot_kkt: k uses Hq mapping (read-only)
- wy_fast: k reads use Hq mapping; dk writes at H heads, reduced post-kernel
- chunk_fwd: k uses Hq mapping (read-only)
- cp/chunk_delta_h: fwd/bwd merged pre-process kernels use Hq for q/k
- intracard_cp: pass scalar g to pre_process_fwd_kernel_merged in pre-scan

Tests:
- Add test_chunk_gqa with 4 GQA configurations
- Add CP GQA tests and intracard GQA test
- All existing tests pass (no regression)
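As a sanity check, the head mapping and the grouped-head gradient reduction described above can be sketched in plain Python. Sizes here are illustrative only; the real kernels operate on [B, T, H, K] tensors with Triton constexprs.

```python
# Illustrative sizes; H % Hq == 0 is assumed, as the PR's validation enforces.
H, Hq = 8, 2          # value/output heads vs. query/key heads
G = H // Hq           # group size

# Each of the H kernel programs i_h reads q/k from head i_h // G,
# mirroring the i_h // (H // Hq) mapping in the kernels.
mapping = [i_h // G for i_h in range(H)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]

# dq/dk are materialized at H heads, then grouped heads are summed,
# mirroring dq.view(B, T, Hq, H // Hq, K).sum(3) from the PR.
dq_per_head = [float(i_h) for i_h in range(H)]  # stand-in gradient per head
dq_grouped = [sum(dq_per_head[g * G:(g + 1) * G]) for g in range(Hq)]
print(dq_grouped)  # [6.0, 22.0]
```

The post-kernel reduction exists because multiple value heads in one group would otherwise race when writing the same query/key head's gradient.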
Force-pushed a36331f to 44563eb
Code Review
This pull request implements Grouped Query Attention (GQA) support across various Triton-based operations, including gated delta rules and context-parallel backends. By introducing the Hq parameter, the kernels now support configurations where the number of query/key heads differs from value heads, with corresponding updates to offset calculations and gradient aggregation logic. The PR also includes extensive new tests for GQA in both standard and distributed settings. Feedback is provided to improve parameter naming consistency in the intracard_pre_scan function call.
  hm = intracard_pre_scan(
-     kg=k, w=w, u=u, gk=gk,
+     kg=k, w=w, u=u, g=g, gk=gk,
Actionable comments posted: 2
🧹 Nitpick comments (3)
fla/ops/common/intracard_cp.py (1)
438-440: Unused variable `Hq` should be prefixed with underscore.

Static analysis correctly identifies that `Hq` is unpacked but never used in `intracard_fwd_h`. The `Hq` handling is delegated to called functions (`_raw_chunk_gated_delta_rule_fwd_h` and `intracard_pre_scan`).

🔧 Suggested fix

- _, _, Hq, K = k.shape
+ _, _, _Hq, K = k.shape

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/common/intracard_cp.py` around lines 438 - 440: in `intracard_fwd_h`, the unpacked variable `Hq` is unused and should be renamed to indicate that (prefix with an underscore) to satisfy static analysis; change the unpacking `_, _, Hq, K = k.shape` to use a prefixed name (e.g., `_Hq`) so references remain correct, leaving logic that delegates `Hq` handling to `_raw_chunk_gated_delta_rule_fwd_h` and `intracard_pre_scan` unchanged.

tests/ops/test_intracard_cache.py (1)
154-167: Backend disabling approach may be fragile.

Directly manipulating `common_registry._backends` is reaching into implementation details. The `_intracard_cache` (cleared by fixture) is separate from the backend registry, so the isolation should work for this test. However, this approach could break if the registry internals change. Consider using a more robust mechanism like mocking or a dedicated test mode flag if available.
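A minimal sketch of the patch-based alternative. The `common_registry` object below is a hypothetical stand-in for fla's registry module; only `unittest.mock.patch.object` is real standard-library API.

```python
from types import SimpleNamespace
from unittest.mock import patch

# Hypothetical stand-in for fla's common_registry and its private dict.
common_registry = SimpleNamespace(_backends={"triton": object()})

# patch.object swaps the attribute for the duration of the block and
# restores the original automatically, even if the body raises.
with patch.object(common_registry, "_backends", {}):
    assert common_registry._backends == {}  # backends disabled inside
print("triton" in common_registry._backends)  # restored on exit
```

With pytest, `monkeypatch.setattr(common_registry, "_backends", {})` gives the same automatic-restore guarantee per test.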
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/test_intracard_cache.py` around lines 154 - 167: the test is brittle because it directly mutates the private `common_registry._backends` dict; instead, temporarily replace the registry via a safe mock/patch so internals aren't relied on. Update the test to use `unittest.mock.patch.object` or pytest's `monkeypatch` to set `common_registry._backends` to an empty dict (or call a public clear method if one exists) around the `chunk_gated_delta_rule` call (refer to `common_registry._backends` and `chunk_gated_delta_rule` in the diff) so the original value is automatically restored after the test.

fla/ops/common/chunk_scaled_dot_kkt.py (1)
86-91: Docstring is outdated — update to reflect GQA semantics.

The docstring still states `k` has shape `[B, T, H, K]`, but the implementation now interprets it as `[B, T, Hq, K]` where `Hq` may differ from `H`.

📝 Suggested docstring update

 Args:
     k (torch.Tensor):
-        The key tensor of shape `[B, T, H, K]`.
+        The key tensor of shape `[B, T, Hq, K]` where `Hq` is the number of query/key heads.
     beta (torch.Tensor):
-        The beta tensor of shape `[B, T, H]`.
+        The beta tensor of shape `[B, T, H]` where `H` is the number of value/output heads.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/common/chunk_scaled_dot_kkt.py` around lines 86 - 91, Update the outdated docstring in chunk_scaled_dot_kkt.py to reflect GQA semantics: state that the input k has shape [B, T, Hq, K] (where Hq may differ from H used for queries), describe that the function computes beta * K @ K^T per head across the sequence producing a tensor of shape [B, Hq, T, T] (or equivalent ordering used in the implementation), and clarify how Hq relates to any query head dimension referenced elsewhere so readers understand head-count mismatches.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/common/chunk_o.py`:
- Around line 771-773: Add a validation in the layer constructors to enforce GQA
head divisibility: in the __init__ of the attention-like layers (e.g.,
SimpleGLA, GLA, or any class that defines self.num_heads and self.num_kv_heads)
assert that if self.num_kv_heads is not None then self.num_heads %
self.num_kv_heads == 0, raising a clear message like "num_heads (..) must be
evenly divisible by num_kv_heads (..)"; this prevents the view() in chunk_o.py
that reshapes dq/dk (the H/Hq split where H is num_heads and Hq is num_kv_heads)
from failing later during forward/backward.
In `@fla/ops/common/chunk_scaled_dot_kkt.py`:
- Around line 109-110: The docstring for the function in chunk_scaled_dot_kkt.py
is out-of-date: update the Args section to state that k has shape [B, T, Hq, K]
where Hq is the number of key/query heads (previously documented as [B, T, H,
K]), and add a note that when Hq < H grouped-query attention is used with head
mapping i_h // (H // Hq); also clarify that H is derived from beta.shape[2] (H =
beta.shape[2]) and that for the standard case Hq == H the behavior is unchanged.
---
Nitpick comments:
In `@fla/ops/common/chunk_scaled_dot_kkt.py`:
- Around line 86-91: Update the outdated docstring in chunk_scaled_dot_kkt.py to
reflect GQA semantics: state that the input k has shape [B, T, Hq, K] (where Hq
may differ from H used for queries), describe that the function computes beta *
K @ K^T per head across the sequence producing a tensor of shape [B, Hq, T, T]
(or equivalent ordering used in the implementation), and clarify how Hq relates
to any query head dimension referenced elsewhere so readers understand
head-count mismatches.
In `@fla/ops/common/intracard_cp.py`:
- Around line 438-440: In intracard_fwd_h, the unpacked variable Hq is unused
and should be renamed to indicate that (prefix with an underscore) to satisfy
static analysis; change the unpacking "_, _, Hq, K = k.shape" to use a prefixed
name (e.g., "_Hq") so references remain correct, leaving logic that delegates Hq
handling to _raw_chunk_gated_delta_rule_fwd_h and intracard_pre_scan unchanged.
In `@tests/ops/test_intracard_cache.py`:
- Around line 154-167: The test is brittle because it directly mutates the
private common_registry._backends dict; instead, temporarily replace the
registry via a safe mock/patch so internals aren't relied on. Update the test to
use unittest.mock.patch.object or pytest.monkeypatch to set
common_registry._backends to an empty dict (or call a public clear method if one
exists) around the chunk_gated_delta_rule call (refer to
common_registry._backends and chunk_gated_delta_rule in the diff) so the
original value is automatically restored after the test.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: c10fb80b-4e1c-438a-8811-57a72d8a24b3
📒 Files selected for processing (10)
- fla/ops/common/chunk_delta_h.py
- fla/ops/common/chunk_o.py
- fla/ops/common/chunk_scaled_dot_kkt.py
- fla/ops/common/intracard_cp.py
- fla/ops/cp/chunk_delta_h.py
- fla/ops/gated_delta_rule/chunk_fwd.py
- fla/ops/gated_delta_rule/wy_fast.py
- tests/context_parallel/test_cp_gdn.py
- tests/ops/test_gated_delta.py
- tests/ops/test_intracard_cache.py
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/cp/chunk_delta_h.py`:
- Around line 727-730: Add a validation that H is divisible by Hq before using
the head mapping formula: assert H % Hq == 0 (or raise a clear ValueError) so
the expression i_h // (H // Hq) cannot produce incorrect or out-of-range
indices; place this check near where H and Hq are computed (after the lines
reading B, T, Hq, K = k.shape and H = u.shape[2]) and include a descriptive
error message referencing H and Hq.
- Around line 812-814: Add the same shape divisibility checks in the backward
wrapper as in the forward pass: after computing B, T, Hq, K = q.shape and H =
do.shape[2], V = do.shape[-1], validate that Hq % K == 0 and that V is divisible
by (Hq // K) (or whatever exact relation the forward check enforces); if the
checks fail, raise a ValueError with a clear message referencing q.shape and
do.shape so the issue is obvious. Ensure these checks are placed at the start of
the backward wrapper (before further computation) and use the same error
text/logic as the forward implementation for consistency.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 84c96282-a505-48b4-a7cc-62ed720db251
📒 Files selected for processing (3)
- benchmarks/cp/benchmark_chunk_delta_h_kernels.py
- fla/ops/cp/chunk_delta_h.py
- tests/context_parallel/test_cp_bwd_gk_offset.py
💤 Files with no reviewable changes (1)
- benchmarks/cp/benchmark_chunk_delta_h_kernels.py
- Add H % Hq == 0 validation at chunk_gated_delta_rule entry
- Update docstrings to reflect GQA semantics (Hq vs H)
- Prefix unused Hq variable with underscore in intracard_fwd_h
🧹 Nitpick comments (1)
fla/ops/gated_delta_rule/chunk.py (1)
385-391: Consider validating that `k` has the same head count as `q`.

The validation correctly enforces `H % Hq == 0`, which is essential for the kernel's `i_h // (H // Hq)` head mapping to work correctly. However, the docstring specifies that both `q` and `k` should have shape `[B, T, Hq, K]`, but only `q.shape[2]` is used to derive `Hq`. If a caller inadvertently passes `k` with a different head count, the kernels would silently produce incorrect results.

🛡️ Proposed fix to validate k's head dimension

 # Validate GQA head divisibility
 Hq, H = q.shape[2], v.shape[2]
+if k.shape[2] != Hq:
+    raise ValueError(
+        f"q and k must have the same number of heads, but got "
+        f"q.shape[2]={Hq} and k.shape[2]={k.shape[2]}"
+    )
 if H % Hq != 0:
     raise ValueError(
         f"For GQA, num_heads (H={H}) must be evenly divisible by "
         f"num_kv_heads (Hq={Hq}), but got H % Hq = {H % Hq}"
     )

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/chunk.py` around lines 385 - 391, Validate that k has the same head count as q before the GQA divisibility check: read Hq from q.shape[2] (as currently done) and assert k.shape[2] == Hq, raising a ValueError that calls out mismatched head counts (mention q, k, and Hq in the message). Keep this check adjacent to the existing GQA validation (the block using Hq, H and H % Hq) so the kernel mapping i_h // (H // Hq) is guaranteed correct.
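The proposed checks can be exercised in isolation. This standalone sketch mirrors the suggested validation; the helper name and the concrete shapes are hypothetical, not fla's API.

```python
def validate_gqa_heads(q_shape, k_shape, v_shape):
    """Hypothetical helper mirroring the validation proposed above."""
    Hq, H = q_shape[2], v_shape[2]
    if k_shape[2] != Hq:
        raise ValueError(
            f"q and k must have the same number of heads, but got "
            f"q.shape[2]={Hq} and k.shape[2]={k_shape[2]}"
        )
    if H % Hq != 0:
        raise ValueError(
            f"For GQA, num_heads (H={H}) must be evenly divisible by "
            f"num_kv_heads (Hq={Hq}), but got H % Hq = {H % Hq}"
        )
    return Hq, H

# A valid GQA config (4 kv heads serving 16 heads) passes...
print(validate_gqa_heads((4, 2048, 4, 64), (4, 2048, 4, 64), (4, 2048, 16, 64)))

# ...while a q/k head mismatch is rejected loudly instead of silently
# producing wrong results in the kernels.
try:
    validate_gqa_heads((4, 2048, 4, 64), (4, 2048, 8, 64), (4, 2048, 16, 64))
except ValueError as e:
    print("rejected:", e)
```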
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Around line 385-391: Validate that k has the same head count as q before the
GQA divisibility check: read Hq from q.shape[2] (as currently done) and assert
k.shape[2] == Hq, raising a ValueError that calls out mismatched head counts
(mention q, k, and Hq in the message). Keep this check adjacent to the existing
GQA validation (the block using Hq, H and H % Hq) so the kernel mapping i_h //
(H // Hq) is guaranteed correct.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 88dd2191-b8e6-465c-9a84-1311cd023ebf
📒 Files selected for processing (3)
- fla/ops/common/chunk_scaled_dot_kkt.py
- fla/ops/common/intracard_cp.py
- fla/ops/gated_delta_rule/chunk.py
🚧 Files skipped from review as they are similar to previous changes (2)
- fla/ops/common/chunk_scaled_dot_kkt.py
- fla/ops/common/intracard_cp.py
Actionable comments posted: 2
🧹 Nitpick comments (1)
.github/workflows/reusable-ci-benchmarks.yml (1)
134-136: Declare the write permission this new comment step needs.

If the caller or repo defaults `GITHUB_TOKEN` to read-only, the new `issues.createComment`/`issues.updateComment` calls fail after the benchmark already succeeded. Declaring the requirement here makes the reusable workflow self-describing.

🔐 Suggested change

 jobs:
   benchmark:
     runs-on: ${{ inputs.runner }}
+    permissions:
+      contents: read
+      issues: write
     env:
       FLA_CI_ENV: 1

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In @.github/workflows/reusable-ci-benchmarks.yml around lines 134 - 136, The new "Post benchmark results to PR" step (uses: actions/github-script@v7) needs explicit write permission for issues to successfully call issues.createComment/issues.updateComment when GITHUB_TOKEN defaults to read-only; update the workflow to declare permissions: issues: write (either at the job or workflow level) so the actions/github-script step can create/update PR comments.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In @.github/workflows/reusable-ci-benchmarks.yml:
- Around line 207-216: The lookup for an existing benchmark comment uses
github.rest.issues.listComments which only returns the first page (30 comments),
causing duplicates; replace the single-page call with a paginated fetch (e.g.,
use github.paginate or iterate pages of github.rest.issues.listComments with
per_page=100) to retrieve all comments, then search that full list for the bot
comment (the logic that assigns botComment). Keep the same matching criteria
(comment.user.type === 'Bot' && comment.body.includes(`Benchmark Results
(${runnerName.toUpperCase()})`)) and then proceed to update or create the
comment as before.
In `@scripts/run_benchmark_compare.py`:
- Around line 405-407: The current branch unconditionally triggers when
args.no_fail_on_regression is set, printing "despite regressions" and exiting 0
even for clean runs; change the logic around the block that checks
args.no_fail_on_regression to only run when a regression actually exists (e.g.,
regressions is truthy or exit_code != 0). Update the conditional to something
like "if args.no_fail_on_regression and (regressions or exit_code != 0):" so the
message and sys.exit(0) are only used to suppress a failing exit code, leaving
normal successful exits unchanged.
---
Nitpick comments:
In @.github/workflows/reusable-ci-benchmarks.yml:
- Around line 134-136: The new "Post benchmark results to PR" step (uses:
actions/github-script@v7) needs explicit write permission for issues to
successfully call issues.createComment/issues.updateComment when GITHUB_TOKEN
defaults to read-only; update the workflow to declare permissions: issues: write
(either at the job or workflow level) so the actions/github-script step can
create/update PR comments.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: bce5ccb2-1a0a-4daf-8194-61faeef9622a
📒 Files selected for processing (5)
- .github/workflows/nvidia-4090.yml
- .github/workflows/nvidia-a100.yml
- .github/workflows/nvidia-h100.yml
- .github/workflows/reusable-ci-benchmarks.yml
- scripts/run_benchmark_compare.py
✅ Files skipped from review due to trivial changes (3)
- .github/workflows/nvidia-4090.yml
- .github/workflows/nvidia-h100.yml
- .github/workflows/nvidia-a100.yml
const { data: comments } = await github.rest.issues.listComments({
  owner: context.repo.owner,
  repo: context.repo.repo,
  issue_number: context.issue.number,
});

const botComment = comments.find(comment =>
  comment.user.type === 'Bot' &&
  comment.body.includes(`Benchmark Results (${runnerName.toUpperCase()})`)
);
🧩 Analysis chain
🏁 Script executed:
# First, let's check if the file exists and read the specific lines
head -n 220 .github/workflows/reusable-ci-benchmarks.yml | tail -n 20

Repository: fla-org/flash-linear-attention
Length of output: 811
🌐 Web query:
GitHub API issues listComments pagination default per_page documentation
💡 Result:
The GitHub REST API endpoint for listing comments on an issue (GET /repos/{owner}/{repo}/issues/{issue_number}/comments), often referred to as "listComments" or "list comments on an issue", supports pagination via the standard per_page and page query parameters. The default value for per_page is 30 (maximum 100). This is explicitly documented in the official GitHub REST API documentation for issue comments endpoints. Pagination uses Link headers in responses for navigating pages, as described in the general "Using pagination in the REST API" guide, which notes that endpoints like issues return 30 items by default (e.g., GET /repos/octocat/Spoon-Knife/issues). Example request: GET https://api.github.com/repos/OWNER/REPO/issues/ISSUE_NUMBER/comments?per_page=100&page=1
Citations:
- 1: https://docs.github.com/rest/using-the-rest-api/using-pagination-in-the-rest-api
- 2: https://docs.github.com/en/rest/issues/comments
- 3: https://docs.github.com/rest/issues/comments
- 4: https://www.skapi.com
- 5: https://docs.github.com/rest/issues/issues?apiVersion=2022-11-28
Paginate the comment lookup before deciding to create a new one.
issues.listComments returns 30 items by default. If the benchmark comment appears beyond the first page, this code fails to find it and creates duplicates instead of updating the existing comment.
💡 Suggested fix
- const { data: comments } = await github.rest.issues.listComments({
- owner: context.repo.owner,
- repo: context.repo.repo,
- issue_number: context.issue.number,
- });
+ const comments = await github.paginate(github.rest.issues.listComments, {
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ issue_number: context.issue.number,
+ per_page: 100,
+ });

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In @.github/workflows/reusable-ci-benchmarks.yml around lines 207 - 216, The
lookup for an existing benchmark comment uses github.rest.issues.listComments
which only returns the first page (30 comments), causing duplicates; replace the
single-page call with a paginated fetch (e.g., use github.paginate or iterate
pages of github.rest.issues.listComments with per_page=100) to retrieve all
comments, then search that full list for the bot comment (the logic that assigns
botComment). Keep the same matching criteria (comment.user.type === 'Bot' &&
comment.body.includes(`Benchmark Results (${runnerName.toUpperCase()})`)) and
then proceed to update or create the comment as before.
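To see why the single-page lookup misses older comments, here is a toy pagination model in Python; the `list_comments` and `paginate` helpers are stand-ins for `github.rest.issues.listComments` and `github.paginate`, not real API calls.

```python
PER_PAGE = 30  # GitHub's documented default page size for issue comments
comments = [{"id": i, "body": f"comment {i}"} for i in range(75)]

def list_comments(page):
    """Stand-in for a single listComments call: returns one page of results."""
    start = (page - 1) * PER_PAGE
    return comments[start:start + PER_PAGE]

def paginate(fetch):
    """Stand-in for github.paginate: keeps fetching until a page comes back empty."""
    page, out = 1, []
    while chunk := fetch(page):
        out.extend(chunk)
        page += 1
    return out

print(len(list_comments(1)))         # 30 — the unpaginated lookup sees only these
print(len(paginate(list_comments)))  # 75 — the full comment list
```

Any benchmark comment sitting past index 29 is invisible to the first version, so the step creates a duplicate instead of updating.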
if args.no_fail_on_regression:
    print("\n  --no-fail-on-regression set, exiting with code 0 despite regressions.")
    sys.exit(0)
Only suppress the exit code when a regression actually exists.
This branch fires whenever --no-fail-on-regression is set, so clean runs still log “despite regressions.” Gating it on regressions (or exit_code != 0) keeps the logs accurate and avoids masking future non-regression exit codes.
💡 Suggested fix
- if args.no_fail_on_regression:
+ if args.no_fail_on_regression and regressions:
print("\n --no-fail-on-regression set, exiting with code 0 despite regressions.")
sys.exit(0)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@scripts/run_benchmark_compare.py` around lines 405 - 407, The current branch
unconditionally triggers when args.no_fail_on_regression is set, printing
"despite regressions" and exiting 0 even for clean runs; change the logic around
the block that checks args.no_fail_on_regression to only run when a regression
actually exists (e.g., regressions is truthy or exit_code != 0). Update the
conditional to something like "if args.no_fail_on_regression and (regressions or
exit_code != 0):" so the message and sys.exit(0) are only used to suppress a
failing exit code, leaving normal successful exits unchanged.
| GPU | NVIDIA H200 |
| CUDA | 12.8 |
| PyTorch | 2.7.1+cu128 |
| Base | 46c5f6be5 |
| Head | aba2408c4 |
| Threshold | 5.0% |
📊 View Details (1 significant change)
| Op | Mode | B | T | H | D | Base (ms) | Head (ms) | Change |
|---|---|---|---|---|---|---|---|---|
| chunk_delta_rule | fwdbwd | 4 | 2048 | 16 | 128 | 1.158 | 1.372 | +18.5% 🔴 |
This comment is automatically updated with the latest benchmark results.
Actionable comments posted: 2
♻️ Duplicate comments (1)
.github/workflows/reusable-ci-benchmarks.yml (1)
212-221: ⚠️ Potential issue | 🟡 Minor

Paginate the benchmark-comment lookup.

`issues.listComments` only fetches the first page, so once a PR has more than 30 comments this step can miss the existing benchmark comment and create a duplicate instead of updating it.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In @.github/workflows/reusable-ci-benchmarks.yml around lines 212 - 221, The current lookup uses github.rest.issues.listComments which only returns the first page; replace it with Octokit pagination to fetch all comments (e.g., use github.paginate(github.rest.issues.listComments, { owner: context.repo.owner, repo: context.repo.repo, issue_number: context.issue.number }) ), then run the existing find for botComment (checking comment.user.type === 'Bot' and comment.body.includes(`Benchmark Results (${runnerName.toUpperCase()})`)); ensure you remove or replace the single-page call to github.rest.issues.listComments and use the paginated array for deduplication/updating.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/ops/run.py`:
- Around line 580-592: The current --from-diff branch treats an empty
_diff.find_affected_op_names(changed) as “no changes” and returns early, but
find_affected_op_names only maps fla/ops/** and misses changes to benchmark
registry or runner code; after computing changed = _diff.get_changed_files(...)
check changed for registry/runner updates (e.g. 'benchmarks/ops/registry.py',
'benchmarks/ops/run.py', and 'scripts/run_benchmark_compare.py') and if any are
present do not return early—either continue normal execution or treat it as “all
ops affected” (so the registry changes are picked up) instead of printing "No
affected ops..." and returning; update the logic around variables changed and
op_names to implement this fallback.
In `@fla/ops/common/chunk_delta_h.py`:
- Around line 673-675: Before launching the remapped-head kernels, validate the
head counts: ensure Hq <= H and H % Hq == 0 (the same invariant enforced in
fla/ops/gated_delta_rule/chunk.py:390-399). Specifically, where B, T, Hq, K =
k.shape and H = u.shape[2] and where the mapping uses i_h // (H // Hq), add a
guard that raises/returns a clear error (or assert) if Hq > H or H % Hq != 0 to
prevent division by zero and out-of-bounds base pointer walks; apply the same
check in both entry points that set H/Hq (the occurrences around the B,T,Hq,K
and V,u/H assignments).
---
Duplicate comments:
In @.github/workflows/reusable-ci-benchmarks.yml:
- Around line 212-221: The current lookup uses github.rest.issues.listComments
which only returns the first page; replace it with Octokit pagination to fetch
all comments (e.g., use github.paginate(github.rest.issues.listComments, {
owner: context.repo.owner, repo: context.repo.repo, issue_number:
context.issue.number }) ), then run the existing find for botComment (checking
comment.user.type === 'Bot' and comment.body.includes(`Benchmark Results
(${runnerName.toUpperCase()})`)); ensure you remove or replace the single-page
call to github.rest.issues.listComments and use the paginated array for
deduplication/updating.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 4d7f043e-4dfd-4aaf-8ac5-6c2a842241f6
📒 Files selected for processing (6)
- .github/workflows/reusable-ci-benchmarks.yml
- benchmarks/ops/registry.py
- benchmarks/ops/run.py
- fla/ops/common/chunk_delta_h.py
- fla/ops/gated_delta_rule/chunk.py
- scripts/run_benchmark_compare.py
🚧 Files skipped from review as they are similar to previous changes (2)
- fla/ops/gated_delta_rule/chunk.py
- scripts/run_benchmark_compare.py
if args.from_diff:
    if args.op is not None:
        parser.error('--from-diff cannot be used with --op')
    project_root = _find_project_root()
    scripts_dir = os.path.join(project_root, 'scripts')
    if scripts_dir not in sys.path:
        sys.path.insert(0, scripts_dir)
    import run_benchmark_compare as _diff
    changed = _diff.get_changed_files(args.diff_base, args.diff_head)
    op_names = _diff.find_affected_op_names(changed)
    if not op_names:
        print('No affected ops for this diff.', file=sys.stderr)
        return
--from-diff misses benchmark-registry changes.
This path treats an empty find_affected_op_names() result as “nothing changed”, but that helper only maps fla/ops/** paths. A diff that only touches benchmarks/ops/registry.py, benchmarks/ops/run.py, or scripts/run_benchmark_compare.py will therefore exit early even though benchmark behavior changed; the chunk_rwkv7 registry update in this PR is one concrete example.
🛠️ Suggested fallback
changed = _diff.get_changed_files(args.diff_base, args.diff_head)
- op_names = _diff.find_affected_op_names(changed)
+ benchmark_meta_files = {
+ 'benchmarks/ops/registry.py',
+ 'benchmarks/ops/run.py',
+ 'scripts/run_benchmark_compare.py',
+ }
+ op_names = (
+ list_ops()
+ if any(path in benchmark_meta_files for path in changed)
+ else _diff.find_affected_op_names(changed)
+ )
if not op_names:
print('No affected ops for this diff.', file=sys.stderr)
return🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmarks/ops/run.py` around lines 580 - 592, The current --from-diff branch
treats an empty _diff.find_affected_op_names(changed) as “no changes” and
returns early, but find_affected_op_names only maps fla/ops/** and misses
changes to benchmark registry or runner code; after computing changed =
_diff.get_changed_files(...) check changed for registry/runner updates (e.g.
'benchmarks/ops/registry.py', 'benchmarks/ops/run.py', and
'scripts/run_benchmark_compare.py') and if any are present do not return
early—either continue normal execution or treat it as “all ops affected” (so the
registry changes are picked up) instead of printing "No affected ops..." and
returning; update the logic around variables changed and op_names to implement
this fallback.
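A toy model of the suggested fallback, in plain Python; `affected_ops`, `op_map`, and the sample op names are illustrative stand-ins for `find_affected_op_names` and `list_ops()`, not the repository's actual functions.

```python
# Infrastructure files whose changes should trigger a full benchmark run.
BENCHMARK_META_FILES = {
    'benchmarks/ops/registry.py',
    'benchmarks/ops/run.py',
    'scripts/run_benchmark_compare.py',
}

def affected_ops(changed, op_map, all_ops):
    """op_map: fla/ops path -> op name (stand-in for find_affected_op_names)."""
    if any(path in BENCHMARK_META_FILES for path in changed):
        return list(all_ops)  # registry/runner changed: treat all ops as affected
    return sorted({op_map[p] for p in changed if p in op_map})

all_ops = ['chunk_delta_rule', 'chunk_rwkv7']
op_map = {'fla/ops/common/chunk_delta_h.py': 'chunk_delta_rule'}

# A registry-only diff no longer short-circuits to "no affected ops":
print(affected_ops(['benchmarks/ops/registry.py'], op_map, all_ops))
# A kernel-only diff still narrows to the mapped op:
print(affected_ops(['fla/ops/common/chunk_delta_h.py'], op_map, all_ops))
```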
B, T, Hq, K = k.shape
V = u.shape[-1]
H = u.shape[2]
Validate H/Hq before launching the remapped-head kernels.
Line 95 and Lines 400-401 use i_h // (H // Hq), but these wrappers never enforce the GQA invariant. If Hq > H, the divisor becomes zero; if H % Hq != 0, the mapped head can reach Hq, so the q/k base pointers walk past the (B, T, Hq, K) buffers. fla/ops/gated_delta_rule/chunk.py:390-399 already guards this, so these lower-level entry points should do the same.
🛡️ Proposed fix
def chunk_gated_delta_rule_fwd_h(
k: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
@@
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
B, T, Hq, K = k.shape
V = u.shape[-1]
H = u.shape[2]
+ if H % Hq != 0:
+ raise RuntimeError(
+ f"H (num_heads={H}) must be divisible by "
+ f"num_kv_heads (Hq={Hq}), but got H % Hq = {H % Hq}"
+ )
BT = chunk_size
@@
def chunk_gated_delta_rule_bwd_dhu(
q: torch.Tensor,
k: torch.Tensor,
w: torch.Tensor,
@@
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, Hq, K = q.shape
V = do.shape[-1]
H = do.shape[2]
+ if H % Hq != 0:
+ raise RuntimeError(
+ f"H (num_heads={H}) must be divisible by "
+ f"num_kv_heads (Hq={Hq}), but got H % Hq = {H % Hq}"
+ )
    # N: the actual number of sequences in the batch with either equal or variable lengths

Also applies to: 737-739
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/common/chunk_delta_h.py` around lines 673 - 675, Before launching the
remapped-head kernels, validate the head counts: ensure Hq <= H and H % Hq == 0
(the same invariant enforced in fla/ops/gated_delta_rule/chunk.py:390-399).
Specifically, where B, T, Hq, K = k.shape and H = u.shape[2] and where the
mapping uses i_h // (H // Hq), add a guard that raises/returns a clear error (or
assert) if Hq > H or H % Hq != 0 to prevent division by zero and out-of-bounds
base pointer walks; apply the same check in both entry points that set H/Hq (the
occurrences around the B,T,Hq,K and V,u/H assignments).
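A quick numeric check of the invariant, in plain Python with illustrative sizes: when `H % Hq != 0`, floor division sends some programs to a head index at or beyond `Hq`, and when `Hq > H` the divisor `H // Hq` is zero.

```python
H, Hq = 8, 3                 # invalid GQA config: H % Hq != 0
G = H // Hq                  # floor division gives 2, not a clean group size
mapped = [i_h // G for i_h in range(H)]
print(mapped)                # [0, 0, 1, 1, 2, 2, 3, 3]
print(max(mapped) >= Hq)     # True: head index 3 walks past the (B, T, Hq, K) buffer

try:
    _ = 0 // (4 // 8)        # Hq > H: H // Hq == 0, so the mapping divides by zero
except ZeroDivisionError:
    print("division by zero when Hq > H")
```

This is exactly why the wrappers should raise a clear error up front rather than let the kernel read out of bounds.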
Add grouped-query attention support to all gated delta rule kernels. When Hq < H, query/key tensors have fewer heads than value/output, with head mapping: i_h // (H // Hq).
Kernel changes:
Tests:
Summary by CodeRabbit
New Features
Tests
Chores