[GDN] Fuse kkt + solve_tril kernel & unified benchmark infrastructure#789
Conversation
Reduce the GDN forward WY representation from 3 kernel launches to 2
by fusing chunk_scaled_dot_kkt and solve_tril into a single Triton kernel,
eliminating the HBM round-trip for the intermediate A matrix.
The fused kernel computes all 10 lower-triangular [16,16] blocks of
beta * K @ K^T in registers, performs forward substitution on diagonal
blocks in-register (extracting rows via tl.sum/tl.where), then does the
block merge to produce (I+A)^{-1} — all without writing intermediate
results to HBM.
Co-Authored-By: Claude (claude-opus-4-6) <noreply@anthropic.com>
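For intuition, the math the fused kernel performs per chunk can be sketched in plain PyTorch. This is a naive single-chunk reference (shapes and helper name are illustrative, not the Triton implementation): build the strictly lower-triangular `A = tril(beta * K @ K^T, -1)` and invert the unit lower-triangular `(I + A)` by forward substitution.

```python
import torch

def wy_inverse_reference(k: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    """Reference for one chunk: build A = tril(beta * K @ K^T, -1), return (I + A)^{-1}.

    k: [T, K] keys for a single chunk; beta: [T] per-row scaling.
    (I + A) is unit lower-triangular, so its inverse is too, and it can be
    obtained by plain forward substitution, row by row.
    """
    T = k.shape[0]
    A = torch.tril(beta[:, None] * (k @ k.T), diagonal=-1)
    inv = torch.eye(T, dtype=k.dtype)
    for i in range(1, T):
        # Row i of the inverse depends only on already-solved rows < i.
        inv[i, :i] = -A[i, :i] @ inv[:i, :i]
    return inv

k = torch.randn(16, 8, dtype=torch.float64)
beta = torch.rand(16, dtype=torch.float64)
inv = wy_inverse_reference(k, beta)
M = torch.eye(16, dtype=torch.float64) + torch.tril(beta[:, None] * (k @ k.T), -1)
assert torch.allclose(inv @ M, torch.eye(16, dtype=torch.float64), atol=1e-10)
```

The fused kernel does the same solve blockwise on [16, 16] tiles held in registers rather than row by row in HBM.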
Summary of Changes (Gemini Code Assist): This pull request optimizes the Gated Delta Rule's forward pass for the WY representation by consolidating two previously separate kernel launches into a single, more efficient Triton kernel. The change reduces the number of kernel calls and minimizes memory latency by keeping intermediate computations in GPU registers, yielding a faster and more streamlined execution flow for the GDN model.
Walkthrough: Replaces a three-step intra-chunk WY construction with a single fused Triton kernel and Python entrypoint (`chunk_gated_delta_rule_fwd_intra`).

Sequence diagram:

```mermaid
sequenceDiagram
    participant Caller as Chunk forward caller
    participant Py as chunk_gated_delta_rule_fwd_intra
    participant Triton as fused Triton kernel
    participant Recomp as recompute_w_u_fwd
    Caller->>Py: call(k, v, g, beta, cu_seqlens, ...)
    Py->>Triton: launch chunk_gated_delta_rule_fwd_kkt_solve_kernel (per-chunk)
    Triton-->>Py: return solved inverse blocks A
    Py->>Recomp: recompute_w_u_fwd(k, v, beta, g, A, chunk_indices)
    Recomp-->>Py: return w, u
    Py-->>Caller: return w, u, A
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Code Review
This pull request introduces a significant optimization for the Gated Delta Rule's forward pass by fusing the chunk_scaled_dot_kkt and solve_tril operations into a single Triton kernel. This is a great improvement as it reduces kernel launch overhead and eliminates an HBM round-trip for the intermediate A matrix. The new kernel chunk_gated_delta_rule_fwd_kkt_solve_kernel appears to correctly implement the fused logic. My only feedback is on a code duplication within the new kernel that could be refactored for better maintainability.
```python
for i in range(2, min(BC, T - i_tc0)):
    b_a00 = tl.sum(tl.where((o_i == i)[:, None], -b_A00, 0.), 0)
    b_a00 = tl.where(o_i < i, b_a00, 0.)
    b_a00 = b_a00 + tl.sum(b_a00[:, None] * b_Ai00, 0)
    b_Ai00 = tl.where((o_i == i)[:, None], b_a00, b_Ai00)
for i in range(2, min(BC, T - i_tc1)):
    b_a11 = tl.sum(tl.where((o_i == i)[:, None], -b_A11, 0.), 0)
    b_a11 = tl.where(o_i < i, b_a11, 0.)
    b_a11 = b_a11 + tl.sum(b_a11[:, None] * b_Ai11, 0)
    b_Ai11 = tl.where((o_i == i)[:, None], b_a11, b_Ai11)
for i in range(2, min(BC, T - i_tc2)):
    b_a22 = tl.sum(tl.where((o_i == i)[:, None], -b_A22, 0.), 0)
    b_a22 = tl.where(o_i < i, b_a22, 0.)
    b_a22 = b_a22 + tl.sum(b_a22[:, None] * b_Ai22, 0)
    b_Ai22 = tl.where((o_i == i)[:, None], b_a22, b_Ai22)
for i in range(2, min(BC, T - i_tc3)):
    b_a33 = tl.sum(tl.where((o_i == i)[:, None], -b_A33, 0.), 0)
    b_a33 = tl.where(o_i < i, b_a33, 0.)
    b_a33 = b_a33 + tl.sum(b_a33[:, None] * b_Ai33, 0)
    b_Ai33 = tl.where((o_i == i)[:, None], b_a33, b_Ai33)
```
These four for loops for forward substitution on the diagonal blocks are nearly identical. This code duplication harms readability and maintainability.
Consider refactoring this logic into a helper triton.jit function defined above this kernel. For example:
```python
@triton.jit
def _solve_tril_diag_block(b_A, b_Ai, T_sub, BC):
    o_i = tl.arange(0, BC)
    for i in range(2, T_sub):
        b_a = tl.sum(tl.where((o_i == i)[:, None], -b_A, 0.), 0)
        b_a = tl.where(o_i < i, b_a, 0.)
        b_a = b_a + tl.sum(b_a[:, None] * b_Ai, 0)
        b_Ai = tl.where((o_i == i)[:, None], b_a, b_Ai)
    return b_Ai
```

Then, you could replace the duplicated loops (lines 210-229) with clearer calls to this helper:

```python
b_Ai00 = _solve_tril_diag_block(b_A00, b_Ai00, min(BC, T - i_tc0), BC)
b_Ai11 = _solve_tril_diag_block(b_A11, b_Ai11, min(BC, T - i_tc1), BC)
b_Ai22 = _solve_tril_diag_block(b_A22, b_Ai22, min(BC, T - i_tc2), BC)
b_Ai33 = _solve_tril_diag_block(b_A33, b_Ai33, min(BC, T - i_tc3), BC)
```
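As an aside, for readers unfamiliar with the masked-sum idiom in these loops: a Triton block cannot be indexed with a runtime row number, so `tl.sum(tl.where((o_i == i)[:, None], -b_A, 0.), 0)` extracts row `i` by zeroing every other row and summing over axis 0. A plain PyTorch sketch of the same trick:

```python
import torch

BC = 16
b_A = torch.randn(BC, BC)
o_i = torch.arange(BC)
i = 5
# Mask every row except row i to zero, then sum over rows: yields row i of -b_A.
row_i = torch.where((o_i == i)[:, None], -b_A, torch.tensor(0.)).sum(dim=0)
assert torch.allclose(row_i, -b_A[i])
```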
Actionable comments posted: 1
🧹 Nitpick comments (1)
fla/ops/gated_delta_rule/chunk_fwd.py (1)
179-195: Good implicit boundary handling via beta.

The off-diagonal blocks rely on the fact that `beta` is loaded with `boundary_check=(0,)`, so out-of-bounds rows (where the sub-chunk extends past T) will have `b_b* = 0`, effectively zeroing those rows. This is correct but implicit; consider adding a brief comment noting this invariant for clarity.
📒 Files selected for processing (2)

- fla/ops/gated_delta_rule/chunk.py
- fla/ops/gated_delta_rule/chunk_fwd.py
```python
B, T, H, K = k.shape
BT = chunk_size
BC = 16

if chunk_indices is None and cu_seqlens is not None:
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

# Step 1: fused kkt + solve_tril
A = torch.zeros(B, T, H, BT, device=k.device, dtype=k.dtype)
chunk_gated_delta_rule_fwd_kkt_solve_kernel[(NT, B * H)](
    k=k,
    g=g,
    beta=beta,
    A=A,
    cu_seqlens=cu_seqlens,
    chunk_indices=chunk_indices,
    T=T,
    H=H,
    K=K,
    BT=BT,
    BC=BC,
)
```
BC=16 hardcoded assumes chunk_size=64.

The kernel processes 4 sub-chunks of size BC=16 each, totaling BT=64. If chunk_size is changed to something other than 64, the relationship BT = 4 * BC breaks, leading to incorrect results.

Consider either:

- Adding an assertion `assert chunk_size == 64` with a comment explaining the constraint, or
- Computing `BC = chunk_size // 4` with a check that `chunk_size % 4 == 0 and chunk_size // 4 >= 16`
🛡️ Proposed fix to add assertion

```diff
 B, T, H, K = k.shape
 BT = chunk_size
 BC = 16
+assert BT == 4 * BC, f"chunk_size must be 64 (4 * BC), got {chunk_size}"
 if chunk_indices is None and cu_seqlens is not None:
```
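The derived-BC alternative could look like the following sketch (`derive_bc` is a hypothetical helper name; the constraints are the ones stated in the review comment):

```python
def derive_bc(chunk_size: int) -> int:
    """Derive the sub-chunk size from chunk_size instead of hardcoding BC = 16.

    The fused kernel splits each chunk into 4 sub-chunks, so chunk_size must
    be divisible by 4 and each sub-chunk must be at least 16 rows.
    """
    if chunk_size % 4 != 0 or chunk_size // 4 < 16:
        raise ValueError(
            f"chunk_size must be a multiple of 4 with chunk_size // 4 >= 16, "
            f"got {chunk_size}"
        )
    return chunk_size // 4
```

Note that sub-chunk sizes other than 16 would still require the kernel's block shapes to be audited; the validation only rules out obviously invalid launches.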
commit 7978c0b
Author: yzhangcs <yzhang.cs@outlook.com>
Date: Mon Mar 23 10:17:00 2026 +0000

    Update README.md for clarity
Actionable comments posted: 2
```markdown
- **[2025-04]** 🎉 Add DeltaProduct implementation to `fla` ([paper](https://arxiv.org/abs/2502.10297)).
- **[2025-04]** 🎉 Add FoX implementation to `fla` ([paper](https://arxiv.org/abs/2503.02130)).
- **[2025-03]** ~~We have changed the default `initializer_range` to the magic 🐳 0.006~~ The `initializer_range` was rolled back to the default value of 0.02. For actual training, we recommend trying both.
- **[2025-02]** 🐳 Add NSA implementations to `fla`. See kernels [here](fla/ops/nsa).
```
Use descriptive link text instead of "here" (Line 45).

This triggers markdownlint MD059 and reduces accessibility/clarity in rendered docs.

Suggested diff:

```diff
-- **[2025-02]** 🐳 Add NSA implementations to `fla`. See kernels [here](fla/ops/nsa).
+- **[2025-02]** 🐳 Add NSA implementations to `fla`. See [NSA kernels](fla/ops/nsa).
```
```markdown
- **[2024-12]** 🚀 `fla` now officially supports kernels with variable-length inputs.
- **[2024-11]** The inputs are now switched from head-first to seq-first format.
- **[2024-11]** 💥 `fla` now provides a flexible way for training hybrid models.
- **[2024-10]** 🔥 Announcing `flame`, a minimal and scalable framework for training `fla` models. Check out the details [here](https://github.com/fla-org/flame).
```
Replace generic "here" link text with a descriptive label (Line 57).

Same MD059 issue; descriptive text is clearer for readers and assistive tooling.

Suggested diff:

```diff
-- **[2024-10]** 🔥 Announcing `flame`, a minimal and scalable framework for training `fla` models. Check out the details [here](https://github.com/fla-org/flame).
+- **[2024-10]** 🔥 Announcing `flame`, a minimal and scalable framework for training `fla` models. Check out the [flame repository](https://github.com/fla-org/flame).
```
Actionable comments posted: 7
🤖 Prompt for all review comments with AI agents

Verify each finding against the current code and only fix it if needed. Two findings have no inline comment below:

In `benchmarks/ops/run.py`:
- Around lines 136-137: The warmup currently always runs _fwdbwd_fn regardless of the effective modes or per-op skip_backward flags; update the warmup logic to honor the computed modes list and each op's skip_backward attribute: call _fwd_fn only if 'fwd' is in modes, call _fwdbwd_fn only if 'fwdbwd' is in modes, and filter out ops with op.skip_backward=True before scheduling any 'fwdbwd' warmup. Apply the same change to the later warmup block in the 172-182 region so both warmup locations use modes and op.skip_backward consistently.

In `scripts/run_benchmark_compare.py`:
- Around lines 151-159: checkout_and_install currently ignores pip install failures, which can cause benchmarking the wrong revision; detect failure of the subprocess.run call that runs pip (e.g., by using subprocess.run(..., check=True) or checking result.returncode) and abort or raise if installation fails, logging the captured stdout/stderr from the failed run so callers know why the editable install failed.
📒 Files selected for processing (24)

- .github/workflows/nvidia-a100.yml
- .github/workflows/nvidia-h100.yml
- .github/workflows/reusable-ci-benchmarks.yml
- benchmarks/ops/benchmark.py
- benchmarks/ops/benchmark_abc.py
- benchmarks/ops/benchmark_based.py
- benchmarks/ops/benchmark_delta_rule.py
- benchmarks/ops/benchmark_fla.py
- benchmarks/ops/benchmark_gla.py
- benchmarks/ops/benchmark_gsa.py
- benchmarks/ops/benchmark_hgrn.py
- benchmarks/ops/benchmark_kda.py
- benchmarks/ops/benchmark_nsa.py
- benchmarks/ops/benchmark_retention.py
- benchmarks/ops/benchmark_rwkv.py
- benchmarks/ops/benchmark_rwkv7_fused_addcmul.py
- benchmarks/ops/benchmark_rwkv7_k_update.py
- benchmarks/ops/benchmark_simple_gla_vs_mamba2.py
- benchmarks/ops/benchmark_solv_tril.py
- benchmarks/ops/benchmark_titans.py
- benchmarks/ops/benchmark_ttt.py
- benchmarks/ops/registry.py
- benchmarks/ops/run.py
- scripts/run_benchmark_compare.py
💤 Files with no reviewable changes (18)
- benchmarks/ops/benchmark_hgrn.py
- benchmarks/ops/benchmark_kda.py
- benchmarks/ops/benchmark_nsa.py
- benchmarks/ops/benchmark_based.py
- benchmarks/ops/benchmark_abc.py
- benchmarks/ops/benchmark_retention.py
- benchmarks/ops/benchmark_gla.py
- benchmarks/ops/benchmark_simple_gla_vs_mamba2.py
- benchmarks/ops/benchmark_solv_tril.py
- benchmarks/ops/benchmark_delta_rule.py
- benchmarks/ops/benchmark_gsa.py
- benchmarks/ops/benchmark_ttt.py
- benchmarks/ops/benchmark_fla.py
- benchmarks/ops/benchmark_rwkv7_k_update.py
- benchmarks/ops/benchmark.py
- benchmarks/ops/benchmark_rwkv.py
- benchmarks/ops/benchmark_rwkv7_fused_addcmul.py
- benchmarks/ops/benchmark_titans.py
```yaml
with:
  runner: 'nvidia-a100'
  conda_env_name: 'pytorch_2_7'
  base_ref: ${{ github.event.pull_request.base.ref }}
```
Pin the benchmark baseline to the PR base SHA.

Using `github.event.pull_request.base.ref` lets the comparison move underneath the PR if the base branch advances before this job runs. `scripts/run_benchmark_compare.py --base` should get the event's immutable base commit instead.

Suggested change:

```diff
-  base_ref: ${{ github.event.pull_request.base.ref }}
+  base_ref: ${{ github.event.pull_request.base.sha }}
```
```yaml
with:
  runner: 'nvidia-h100-pt2-7'
  conda_env_name: 'pytorch_2_7'
  base_ref: ${{ github.event.pull_request.base.ref }}
```
Use the PR base SHA here, not the base branch name.

`github.event.pull_request.base.ref` is a moving ref, so this job can end up benchmarking against newer main commits if the queue is backed up. The reusable workflow passes this straight to `scripts/run_benchmark_compare.py --base`, so the baseline should be the immutable PR event SHA.

Suggested change:

```diff
-  base_ref: ${{ github.event.pull_request.base.ref }}
+  base_ref: ${{ github.event.pull_request.base.sha }}
```
```yaml
- name: Check skip keyword in LATEST commit
  id: check_skip
  run: |
    COMMIT_MSG=$(git log --no-merges -1 --pretty=%s)
    echo "Latest commit message: $COMMIT_MSG"
    if echo "$COMMIT_MSG" | grep -qF "[skip bench]"; then
      echo "::notice::Benchmarks skipped by commit message"
      echo "skip_bench=true" >> $GITHUB_OUTPUT
    else
      echo "skip_bench=false" >> $GITHUB_OUTPUT
    fi
```
Read `[skip bench]` from the PR head commit explicitly.

This currently inspects the latest non-merge commit reachable from the checked-out HEAD, which is indirect and can pick up the wrong message when HEAD is not the PR head commit. Since the intent is PR-scoped, query `github.event.pull_request.head.sha` directly.

Suggested change:

```diff
-    COMMIT_MSG=$(git log --no-merges -1 --pretty=%s)
+    COMMIT_MSG=$(git show -s --format=%s ${{ github.event.pull_request.head.sha }})
```
```python
def _rwkv7_post_init(inputs, B, T, H, D, **kw):
    """RWKV7 needs a/b to be initialized as small positive values."""
    with torch.no_grad():
        inputs['a'] = (torch.randn_like(inputs['a']) * 0.1).requires_grad_(True)
        inputs['b'] = (torch.randn_like(inputs['b']) * 0.1).requires_grad_(True)
```
_rwkv7_post_init is not actually generating positive a/b.

`torch.randn_like(...) * 0.1` is centered at zero, so this produces negative values about half the time even though the benchmark says RWKV7 needs small positive inputs. That can put the benchmark in the wrong regime.

Suggested change:

```diff
 def _rwkv7_post_init(inputs, B, T, H, D, **kw):
     """RWKV7 needs a/b to be initialized as small positive values."""
     with torch.no_grad():
-        inputs['a'] = (torch.randn_like(inputs['a']) * 0.1).requires_grad_(True)
-        inputs['b'] = (torch.randn_like(inputs['b']) * 0.1).requires_grad_(True)
+        inputs['a'] = (torch.rand_like(inputs['a']) * 0.1).requires_grad_(True)
+        inputs['b'] = (torch.rand_like(inputs['b']) * 0.1).requires_grad_(True)
```
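A quick sanity check of the two initializations (illustrative only): `randn_like` is a zero-centered Gaussian while `rand_like` is uniform in `[0, 1)`, so only the latter guarantees non-negative values after scaling.

```python
import torch

torch.manual_seed(0)
x = torch.empty(10_000)
signed = torch.randn_like(x) * 0.1    # current: zero-centered, roughly half negative
positive = torch.rand_like(x) * 0.1   # suggested: uniform in [0, 0.1)
assert (signed < 0).any()             # randn does produce negatives
assert (positive >= 0).all() and (positive < 0.1).all()
```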
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def _rwkv7_post_init(inputs, B, T, H, D, **kw): | |
| """RWKV7 needs a/b to be initialized as small positive values.""" | |
| with torch.no_grad(): | |
| inputs['a'] = (torch.randn_like(inputs['a']) * 0.1).requires_grad_(True) | |
| inputs['b'] = (torch.randn_like(inputs['b']) * 0.1).requires_grad_(True) | |
| def _rwkv7_post_init(inputs, B, T, H, D, **kw): | |
| """RWKV7 needs a/b to be initialized as small positive values.""" | |
| with torch.no_grad(): | |
| inputs['a'] = (torch.rand_like(inputs['a']) * 0.1).requires_grad_(True) | |
| inputs['b'] = (torch.rand_like(inputs['b']) * 0.1).requires_grad_(True) |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmarks/ops/registry.py` around lines 323 - 327, _rwkv7_post_init
currently uses torch.randn_like(...)*0.1 which yields values centered on zero
and therefore negative about half the time; change the initialization for
inputs['a'] and inputs['b'] to guarantee small positive values (e.g., use
torch.abs(torch.randn_like(...))*0.1 or torch.rand_like(...)*0.1 + a small
positive offset) and keep requires_grad_(True) so the tensors remain trainable;
update the assignments in _rwkv7_post_init for inputs['a'] and inputs['b']
accordingly.
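The distributional point behind this suggestion can be checked without torch; here is a stdlib sketch where `random.gauss` and `random.random` stand in for `randn_like` and `rand_like`:

```python
import random

random.seed(0)
# Analogue of torch.randn_like(...) * 0.1: zero-mean Gaussian, scaled.
gauss = [random.gauss(0, 1) * 0.1 for _ in range(1000)]
# Analogue of torch.rand_like(...) * 0.1: uniform in [0, 0.1).
uniform = [random.random() * 0.1 for _ in range(1000)]

neg_gauss = sum(s < 0 for s in gauss)      # roughly half the samples
neg_uniform = sum(s < 0 for s in uniform)  # always zero
```

The Gaussian draw lands below zero about half the time, which is exactly the regime mismatch the comment describes.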
try:
    sys.path.insert(0, str(PROJECT_ROOT / 'benchmarks' / 'ops'))
    from registry import _REGISTRY
except ImportError:
    print("Warning: could not import registry", file=sys.stderr)
    return []
Don't turn a broken registry import into a green benchmark job.
If benchmarks/ops/registry.py stops importing, this returns [], and main() later exits 0 with "No affected ops found". That silently disables the benchmark gate instead of failing the PR.
Suggested change
-except ImportError:
-    print("Warning: could not import registry", file=sys.stderr)
-    return []
+except ImportError as e:
+    print(f"Error: could not import benchmark registry: {e}", file=sys.stderr)
+    sys.exit(1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@scripts/run_benchmark_compare.py` around lines 80 - 85, The current
try/except around inserting PROJECT_ROOT into sys.path and "from registry import
_REGISTRY" swallows ImportError and returns an empty list, which masks a broken
registry and makes CI pass; change the handler so the failure surfaces: in the
except block log or print the ImportError with details (include the exception
message) and then abort with a non-zero exit (e.g., re-raise the ImportError or
call sys.exit(1)) instead of returning []; keep the import attempt using
sys.path.insert(0, str(PROJECT_ROOT / 'benchmarks' / 'ops')) and the "from
registry import _REGISTRY" lines to locate where to modify.
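The fail-loud pattern can be factored into a small helper; this is a sketch — the `load_registry` name and the path argument are illustrative, not the script's actual code:

```python
import sys


def load_registry(search_path):
    """Import the benchmark registry, aborting the job on failure
    instead of silently returning an empty op list."""
    sys.path.insert(0, search_path)
    try:
        from registry import _REGISTRY
        return _REGISTRY
    except ImportError as e:
        print(f"Error: could not import benchmark registry: {e}", file=sys.stderr)
        sys.exit(1)
```

With this shape, a broken `benchmarks/ops/registry.py` turns the CI job red rather than producing a green "No affected ops found" run.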
def checkout_and_install(ref: str, clear_cache: bool = True):
    """Checkout a ref and reinstall the package."""
    print(f"\n Checking out {ref}...")
    run_cmd(["git", "checkout", ref])
    print(" Installing package...")
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", ".", "-q"],
        cwd=str(PROJECT_ROOT), capture_output=True,
    )
Abort when the editable install fails.
The temp runner imports fla from the current environment. If pip install -e . fails here and the script keeps going, you can end up benchmarking the previously installed revision instead of the ref you just checked out.
Suggested change
def checkout_and_install(ref: str, clear_cache: bool = True):
"""Checkout a ref and reinstall the package."""
print(f"\n Checking out {ref}...")
run_cmd(["git", "checkout", ref])
print(" Installing package...")
- subprocess.run(
+ result = subprocess.run(
[sys.executable, "-m", "pip", "install", "-e", ".", "-q"],
- cwd=str(PROJECT_ROOT), capture_output=True,
+ cwd=str(PROJECT_ROOT), capture_output=True, text=True,
)
+ if result.returncode != 0:
+ print(f" Install failed for {ref}: {result.stderr}", file=sys.stderr)
+ sys.exit(1)
if clear_cache:
clear_triton_cache()🧰 Tools
🪛 Ruff (0.15.6)
[error] 156-156: subprocess call: check for execution of untrusted input
(S603)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@scripts/run_benchmark_compare.py` around lines 151 - 159, The
checkout_and_install function currently ignores pip install failures which can
cause benchmarking the wrong revision; modify checkout_and_install to detect
failure of the subprocess.run call that runs pip (e.g., by using
subprocess.run(..., check=True) or checking result.returncode) and abort/raise
an exception if installation fails, logging the captured stdout/stderr from the
failed run so callers know why the editable install of the package (invoked in
checkout_and_install) failed.
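The check-the-return-code discipline generalizes to a small helper; this is a sketch under assumed names (`run_checked` is hypothetical, not in the PR):

```python
import subprocess
import sys


def run_checked(cmd, cwd=None):
    """Run a command and abort with its stderr if it fails, so a broken
    `pip install -e .` cannot leave a stale revision being benchmarked."""
    result = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
    if result.returncode != 0:
        print(f"Command failed ({' '.join(cmd)}): {result.stderr}", file=sys.stderr)
        sys.exit(1)
    return result
```

An alternative with the same effect is `subprocess.run(..., check=True)`, which raises `CalledProcessError` instead of exiting; either way the failure surfaces at the point of the install.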
🧹 Nitpick comments (2)
scripts/run_benchmark_compare.py (2)
283-286: Two separate temp directories are created but only one is documented.
`tmpdir` (line 283) stores JSON results while `runner_tmpdir` (line 284) stores the runner scripts. Both are cleaned up in `finally`, but this split is non-obvious. Consider consolidating or adding a brief comment.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@scripts/run_benchmark_compare.py` around lines 283 - 286, The code creates two temp dirs (tmpdir and runner_tmpdir via copy_runner_to_temp()) but only one is documented; update the code to either consolidate storage into a single temp dir (e.g., use tmpdir as the base and make runner files a subdirectory) or add a concise comment above the creation explaining the split responsibilities (tmpdir holds JSON results, runner_tmpdir holds runner scripts) and referencing copy_runner_to_temp() and the cleanup in the finally block so future readers understand both temporary locations.
131-134: Benchmark failure should surface more clearly in the comparison flow.
When `run_unified_benchmark` returns `False`, the caller continues and may produce incomplete or empty results. At lines 298 and 314, failures are silently ignored until the final check at lines 327-329. Consider logging a clearer warning at the call site or accumulating failure info for the final report.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@scripts/run_benchmark_compare.py` around lines 131 - 134, Call sites of run_unified_benchmark are ignoring a False return which lets the comparison continue with incomplete data; update the callers (where run_unified_benchmark(...) is invoked in the main flow) to detect a False result, immediately log a clear warning that includes identifying context (benchmark name/params) and/or append a structured failure entry to a failures list, and then ensure the final report/final-check uses that failures list to fail or summarize all benchmark errors instead of silently proceeding; reference run_unified_benchmark and the main invocation loop so the fix is applied where the function is called.
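One way to accumulate failures rather than silently proceeding looks roughly like this (a sketch; `run_all` and its arguments are hypothetical names, not the script's actual API):

```python
import sys


def run_all(ops, run_one):
    """Run each benchmark, warn at the call site on failure, and return
    the failure list so the final report can fail the job explicitly."""
    failures = []
    for op in ops:
        if not run_one(op):
            print(f"Warning: benchmark failed for {op}", file=sys.stderr)
            failures.append(op)
    return failures
```

The caller can then exit non-zero (or annotate the report) whenever the returned list is non-empty, instead of discovering the gap only at the final check.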
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 02db21fb-c770-4854-b6ae-5e4ecdd7958b
📒 Files selected for processing (4)
- .github/workflows/reusable-ci-benchmarks.yml
- benchmarks/ops/registry.py
- benchmarks/ops/run.py
- scripts/run_benchmark_compare.py
🚧 Files skipped from review as they are similar to previous changes (1)
- .github/workflows/reusable-ci-benchmarks.yml
Reduce the GDN forward WY representation from 3 kernel launches to 2 by fusing chunk_scaled_dot_kkt and solve_tril into a single Triton kernel, eliminating the HBM round-trip for the intermediate A matrix.
The fused kernel computes all 10 lower-triangular [16,16] blocks of beta * K @ K^T in registers, performs forward substitution on diagonal blocks in-register (extracting rows via tl.sum/tl.where), then does the block merge to produce (I+A)^{-1} — all without writing intermediate results to HBM.
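As a scalar reference for the per-chunk math (a toy sketch in plain Python, not the blocked Triton kernel): for a strictly lower-triangular A, the inverse of (I + A) can be built column by column with forward substitution, which is the substitution the kernel performs on its diagonal blocks in registers.

```python
def invert_I_plus_A(A):
    """Invert (I + A) for a strictly lower-triangular A by forward
    substitution, one column (unit-vector right-hand side) at a time."""
    n = len(A)
    cols = []
    for j in range(n):
        x = [0.0] * n
        for i in range(n):
            e = 1.0 if i == j else 0.0
            # (I + A) x = e_j  =>  x_i = e_i - sum_{k<i} A[i][k] * x_k
            x[i] = e - sum(A[i][k] * x[k] for k in range(i))
        cols.append(x)
    # Transpose the columns into a row-major matrix.
    return [[cols[j][i] for j in range(n)] for i in range(n)]
```

Because the diagonal of (I + A) is all ones, no division is needed, which is what makes the in-register substitution on [16,16] blocks cheap.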