[GDN] Fuse kkt + solve_tril kernel & unified benchmark infrastructure#789

Merged
yzhangcs merged 17 commits into main from fuse-gdn-kkt-solve
Mar 24, 2026
Conversation

@yzhangcs
Member

@yzhangcs yzhangcs commented Mar 22, 2026

Reduce the GDN forward WY representation from 3 kernel launches to 2 by fusing chunk_scaled_dot_kkt and solve_tril into a single Triton kernel, eliminating the HBM round-trip for the intermediate A matrix.

The fused kernel computes all 10 lower-triangular [16,16] blocks of beta * K @ K^T in registers, performs forward substitution on diagonal blocks in-register (extracting rows via tl.sum/tl.where), then does the block merge to produce (I+A)^{-1} — all without writing intermediate results to HBM.
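For intuition, the per-chunk math can be modeled outside Triton: with BT = 64 and BC = 16, A is a strictly lower-triangular 64x64 matrix assembled from ten [16,16] blocks (4 diagonal + 6 off-diagonal), and (I+A)^{-1} follows from forward substitution because I+A is unit lower-triangular. A small NumPy sketch of that math (illustrative only; the kernel works on register tiles block by block, never on a full matrix, and the sizes below are assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
T, K = 64, 32                            # one chunk (BT = 64), head dim K
k = rng.standard_normal((T, K)) * 0.05   # small values keep the demo well-conditioned
beta = rng.random(T)

# Strictly lower-triangular A = tril(beta * K @ K^T, -1), as in the kernel.
A = (beta[:, None] * (k @ k.T)) * np.tri(T, k=-1)

# I + A is unit lower-triangular, so its inverse comes from forward
# substitution: row i of the inverse depends only on rows < i.
Minv = np.eye(T)
for i in range(1, T):
    Minv[i, :i] = -A[i, :i] @ Minv[:i, :i]

assert np.allclose((np.eye(T) + A) @ Minv, np.eye(T), atol=1e-9)
```

The kernel performs the same recurrence, but restricted to each 16x16 diagonal block in registers, then merges the off-diagonal blocks to obtain the full inverse.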

Summary by CodeRabbit

  • New Features

    • Added a fused forward kernel for the gated delta-rule that returns solved inverse blocks alongside final outputs.
    • Added a centralized benchmark registry, a unified benchmark runner CLI, and a benchmark-compare script.
    • Added CI benchmark workflows for A100 and H100.
  • Performance Improvements

    • Fused execution reduces intermediate memory traffic and improves throughput.
  • Chores

    • Removed many legacy standalone benchmark scripts in favor of the new registry/runner.


Co-Authored-By: Claude (claude-opus-4-6) <noreply@anthropic.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request optimizes the Gated Delta Rule's forward pass for the WY representation by consolidating two previously separate kernel launches into a single, more efficient Triton kernel. This reduces the number of kernel calls and cuts intermediate memory traffic by keeping intermediate computations in GPU registers, yielding a faster, more streamlined forward pass for the GDN model.

Highlights

  • Kernel Fusion: The chunk_scaled_dot_kkt_fwd and solve_tril operations for the Gated Delta Rule (GDN) forward pass in WY representation have been fused into a single Triton kernel.
  • Performance Optimization: This fusion reduces the number of kernel launches required for the GDN forward WY representation from three to two, improving computational efficiency.
  • Memory Access Reduction: The intermediate A matrix no longer requires a round-trip to High Bandwidth Memory (HBM), as its computation and subsequent operations are now performed in registers.
  • In-Register Computation: The new fused kernel computes all lower-triangular blocks of beta * K @ K^T, performs forward substitution on diagonal blocks, and executes the block merge to produce (I+A)^{-1} entirely within GPU registers.


@coderabbitai
Contributor

coderabbitai bot commented Mar 22, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Replaces a three-step intra-chunk WY construction with a single fused Triton kernel and Python entrypoint (chunk_gated_delta_rule_fwd_intra) that produces solved inverse blocks A and returns (w, u, A); also adds a new benchmark registry/runner and converts CI to reusable benchmark workflows while deleting many ad-hoc benchmark scripts.

Changes

  • Gated delta-rule caller (fla/ops/gated_delta_rule/chunk.py): Replaced explicit calls to chunk_scaled_dot_kkt_fwd + solve_tril + recompute_w_u_fwd with a single call to chunk_gated_delta_rule_fwd_intra; removed unused imports.
  • Fused Triton kernel & fused-forward entrypoint (fla/ops/gated_delta_rule/chunk_fwd.py): Added chunk_gated_delta_rule_fwd_kkt_solve_kernel (fused KKT compute + forward substitution) and a new public chunk_gated_delta_rule_fwd_intra(...) that allocates/returns A and invokes recompute_w_u_fwd to produce (w, u, A).
  • CI: reusable benchmark workflows (.github/workflows/reusable-ci-benchmarks.yml, .github/workflows/nvidia-a100.yml, .github/workflows/nvidia-h100.yml): Added a reusable CI benchmark workflow and two runner-specific workflow callers (benchmark-a100, benchmark-h100) that invoke it with runner/conda inputs.
  • Removed ad-hoc benchmarks (benchmarks/ops/benchmark.py, benchmark_abc.py, benchmark_based.py, benchmark_delta_rule.py, benchmark_fla.py, benchmark_gla.py, benchmark_gsa.py, benchmark_hgrn.py, benchmark_kda.py, benchmark_nsa.py, benchmark_retention.py, benchmark_rwkv.py, benchmark_rwkv7_fused_addcmul.py, benchmark_rwkv7_k_update.py, benchmark_simple_gla_vs_mamba2.py, benchmark_solv_tril.py, benchmark_titans.py, benchmark_ttt.py): Removed many standalone Triton benchmark scripts and their perf_report entrypoints.
  • Benchmark infra added (benchmarks/ops/registry.py, benchmarks/ops/run.py, scripts/run_benchmark_compare.py): Added a benchmark registry (registry.py) with op/tensor specs, a unified CLI runner with warmup/timing (run.py), and a compare script (scripts/run_benchmark_compare.py) to run base/head benchmarks and report regressions.
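The registry/runner split can be pictured with a toy sketch. Names like BenchSpec and register are illustrative only; the real benchmarks/ops/registry.py defines richer op/tensor specs, and run.py adds CUDA-aware warmup, timing, and CLI flags:

```python
import time
from dataclasses import dataclass
from typing import Callable

# Hypothetical miniature of the registry pattern: ops self-register into a
# module-level dict, and a single runner iterates over the specs.
_REGISTRY: dict[str, "BenchSpec"] = {}


@dataclass
class BenchSpec:
    name: str
    fn: Callable[[], None]
    skip_backward: bool = False


def register(name: str, **kwargs):
    def deco(fn):
        _REGISTRY[name] = BenchSpec(name=name, fn=fn, **kwargs)
        return fn
    return deco


@register("matmul_small")
def bench_matmul():
    sum(i * i for i in range(1000))  # stand-in workload


def run(selected=None, warmup=2, iters=5):
    results = {}
    for name, spec in _REGISTRY.items():
        if selected and name not in selected:
            continue
        for _ in range(warmup):
            spec.fn()
        t0 = time.perf_counter()
        for _ in range(iters):
            spec.fn()
        results[name] = (time.perf_counter() - t0) / iters  # mean seconds/iter
    return results


print(run())
```

A compare script can then invoke run() on the base and head revisions and diff the resulting dicts to flag regressions.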

Sequence Diagram(s)

sequenceDiagram
    participant Caller as Chunk forward caller
    participant Py as chunk_gated_delta_rule_fwd_intra
    participant Triton as fused Triton kernel
    participant Recomp as recompute_w_u_fwd

    Caller->>Py: call(k, v, g, beta, cu_seqlens, ...)
    Py->>Triton: launch chunk_gated_delta_rule_fwd_kkt_solve_kernel (per-chunk)
    Triton-->>Py: return solved inverse blocks A
    Py->>Recomp: recompute_w_u_fwd(k, v, beta, g, A, chunk_indices)
    Recomp-->>Py: return w, u
    Py-->>Caller: return w, u, A

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • zhiyuan1i

Poem

"I hopped through kernels, stitched three to one,
solved blocks in-register beneath the sun.
W and U snug, A gleams in place,
benchmarks reborn — a rabbit's happy race.
🐇"

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 40.54%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check: ✅ Passed. The title directly describes the main change: fusing KKT + solve_tril kernels and introducing unified benchmark infrastructure. It matches the primary objective of reducing kernel launches and HBM access.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant optimization for the Gated Delta Rule's forward pass by fusing the chunk_scaled_dot_kkt and solve_tril operations into a single Triton kernel. This is a great improvement as it reduces kernel launch overhead and eliminates an HBM round-trip for the intermediate A matrix. The new kernel chunk_gated_delta_rule_fwd_kkt_solve_kernel appears to correctly implement the fused logic. My only feedback is on a code duplication within the new kernel that could be refactored for better maintainability.

Comment on lines +210 to +229
for i in range(2, min(BC, T - i_tc0)):
    b_a00 = tl.sum(tl.where((o_i == i)[:, None], -b_A00, 0.), 0)
    b_a00 = tl.where(o_i < i, b_a00, 0.)
    b_a00 = b_a00 + tl.sum(b_a00[:, None] * b_Ai00, 0)
    b_Ai00 = tl.where((o_i == i)[:, None], b_a00, b_Ai00)
for i in range(2, min(BC, T - i_tc1)):
    b_a11 = tl.sum(tl.where((o_i == i)[:, None], -b_A11, 0.), 0)
    b_a11 = tl.where(o_i < i, b_a11, 0.)
    b_a11 = b_a11 + tl.sum(b_a11[:, None] * b_Ai11, 0)
    b_Ai11 = tl.where((o_i == i)[:, None], b_a11, b_Ai11)
for i in range(2, min(BC, T - i_tc2)):
    b_a22 = tl.sum(tl.where((o_i == i)[:, None], -b_A22, 0.), 0)
    b_a22 = tl.where(o_i < i, b_a22, 0.)
    b_a22 = b_a22 + tl.sum(b_a22[:, None] * b_Ai22, 0)
    b_Ai22 = tl.where((o_i == i)[:, None], b_a22, b_Ai22)
for i in range(2, min(BC, T - i_tc3)):
    b_a33 = tl.sum(tl.where((o_i == i)[:, None], -b_A33, 0.), 0)
    b_a33 = tl.where(o_i < i, b_a33, 0.)
    b_a33 = b_a33 + tl.sum(b_a33[:, None] * b_Ai33, 0)
    b_Ai33 = tl.where((o_i == i)[:, None], b_a33, b_Ai33)


medium

These four for loops for forward substitution on the diagonal blocks are nearly identical. This code duplication harms readability and maintainability.

Consider refactoring this logic into a helper triton.jit function defined above this kernel. For example:

@triton.jit
def _solve_tril_diag_block(b_A, b_Ai, T_sub, BC):
    o_i = tl.arange(0, BC)
    for i in range(2, T_sub):
        b_a = tl.sum(tl.where((o_i == i)[:, None], -b_A, 0.), 0)
        b_a = tl.where(o_i < i, b_a, 0.)
        b_a = b_a + tl.sum(b_a[:, None] * b_Ai, 0)
        b_Ai = tl.where((o_i == i)[:, None], b_a, b_Ai)
    return b_Ai

Then, you could replace the duplicated loops (lines 210-229) with clearer calls to this helper:

b_Ai00 = _solve_tril_diag_block(b_A00, b_Ai00, min(BC, T - i_tc0), BC)
b_Ai11 = _solve_tril_diag_block(b_A11, b_Ai11, min(BC, T - i_tc1), BC)
b_Ai22 = _solve_tril_diag_block(b_A22, b_Ai22, min(BC, T - i_tc2), BC)
b_Ai33 = _solve_tril_diag_block(b_A33, b_Ai33, min(BC, T - i_tc3), BC)


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
fla/ops/gated_delta_rule/chunk_fwd.py (1)

179-195: Good implicit boundary handling via beta.

The off-diagonal blocks rely on the fact that beta is loaded with boundary_check=(0,), so out-of-bounds rows (where the sub-chunk extends past T) will have b_b* = 0, effectively zeroing those rows. This is correct but implicit—consider adding a brief comment noting this invariant for clarity.
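The invariant can be illustrated with a small NumPy analogue: because an out-of-bounds load under boundary_check=(0,) yields 0, scaling an off-diagonal block by the beta sub-vector zeroes exactly the rows past the end of the sequence. The names b_b1 and b_A10 mirror the kernel's register tiles; this is a model of the behavior, not the Triton code:

```python
import numpy as np

BC, K, valid = 16, 8, 5                  # only 5 valid rows in this sub-chunk
rng = np.random.default_rng(1)
k0 = rng.standard_normal((BC, K))
k1 = rng.standard_normal((BC, K))

# boundary_check=(0,) makes out-of-bounds rows of beta load as 0.
b_b1 = np.where(np.arange(BC) < valid, rng.random(BC), 0.0)

# Off-diagonal block: beta-scaled K1 @ K0^T. Rows past `valid` vanish.
b_A10 = b_b1[:, None] * (k1 @ k0.T)
assert np.all(b_A10[valid:] == 0.0)      # out-of-bounds rows implicitly zeroed
```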


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: cb5101f1-ad48-42d5-8f3d-4edc19af3638

📥 Commits

Reviewing files that changed from the base of the PR and between 04f9c17 and 4274513.

📒 Files selected for processing (2)
  • fla/ops/gated_delta_rule/chunk.py
  • fla/ops/gated_delta_rule/chunk_fwd.py

Comment on lines +344 to +366
B, T, H, K = k.shape
BT = chunk_size
BC = 16

if chunk_indices is None and cu_seqlens is not None:
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

# Step 1: fused kkt + solve_tril
A = torch.zeros(B, T, H, BT, device=k.device, dtype=k.dtype)
chunk_gated_delta_rule_fwd_kkt_solve_kernel[(NT, B * H)](
    k=k,
    g=g,
    beta=beta,
    A=A,
    cu_seqlens=cu_seqlens,
    chunk_indices=chunk_indices,
    T=T,
    H=H,
    K=K,
    BT=BT,
    BC=BC,
)


⚠️ Potential issue | 🟡 Minor

BC=16 hardcoded assumes chunk_size=64.

The kernel processes 4 sub-chunks of size BC=16 each, totaling BT=64. If chunk_size is changed to something other than 64, the relationship BT = 4 * BC breaks, leading to incorrect results.

Consider either:

  1. Adding an assertion assert chunk_size == 64 with a comment explaining the constraint, or
  2. Computing BC = chunk_size // 4 with a check that chunk_size % 4 == 0 and chunk_size // 4 >= 16
🛡️ Proposed fix to add assertion
     B, T, H, K = k.shape
     BT = chunk_size
     BC = 16
+    assert BT == 4 * BC, f"chunk_size must be 64 (4 * BC), got {chunk_size}"

     if chunk_indices is None and cu_seqlens is not None:

commit 7978c0b
Author: yzhangcs <yzhang.cs@outlook.com>
Date:   Mon Mar 23 10:17:00 2026 +0000

    Update README.md for clarity

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: a43adb81-aa20-4da3-afcf-d3d939f0b9c4

📥 Commits

Reviewing files that changed from the base of the PR and between 4274513 and 25ae389.

📒 Files selected for processing (1)
  • README.md

Comment thread README.md
- **[2025-04]** 🎉 Add DeltaProduct implementation to `fla` ([paper](https://arxiv.org/abs/2502.10297)).
- **[2025-04]** 🎉 Add FoX implementation to `fla` ([paper](https://arxiv.org/abs/2503.02130)).
- **[2025-03]** ~~We have changed the default `initializer_range` to the magic 🐳 0.006~~ The `initializer_range` was rolled back to the default value of 0.02. For actual training, we recommend trying both.
- **[2025-02]** 🐳 Add NSA implementations to `fla`. See kernels [here](fla/ops/nsa).


⚠️ Potential issue | 🟡 Minor

Use descriptive link text instead of “here” (Line 45).

This triggers markdownlint MD059 and reduces accessibility/clarity in rendered docs.

Suggested diff
-- **[2025-02]** 🐳 Add NSA implementations to `fla`. See kernels [here](fla/ops/nsa).
+- **[2025-02]** 🐳 Add NSA implementations to `fla`. See [NSA kernels](fla/ops/nsa).
🧰 Tools
🪛 markdownlint-cli2 (0.21.0)

[warning] 45-45: Link text should be descriptive

(MD059, descriptive-link-text)


Comment thread README.md
- **[2024-12]** 🚀 `fla` now officially supports kernels with variable-length inputs.
- **[2024-11]** The inputs are now switched from head-first to seq-first format.
- **[2024-11]** 💥 `fla` now provides a flexible way for training hybrid models.
- **[2024-10]** 🔥 Announcing `flame`, a minimal and scalable framework for training `fla` models. Check out the details [here](https://github.com/fla-org/flame).


⚠️ Potential issue | 🟡 Minor

Replace generic “here” link text with a descriptive label (Line 57).

Same MD059 issue; descriptive text is clearer for readers and assistive tooling.

Suggested diff
-- **[2024-10]** 🔥 Announcing `flame`, a minimal and scalable framework for training `fla` models. Check out the details [here](https://github.com/fla-org/flame).
+- **[2024-10]** 🔥 Announcing `flame`, a minimal and scalable framework for training `fla` models. Check out the [flame repository](https://github.com/fla-org/flame).
🧰 Tools
🪛 markdownlint-cli2 (0.21.0)

[warning] 57-57: Link text should be descriptive

(MD059, descriptive-link-text)



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In @.github/workflows/nvidia-a100.yml:
- Around line 32-35: The workflow currently sets base_ref using
github.event.pull_request.base.ref which can change if the base branch advances;
change it to use the immutable pull request base SHA by setting the base_ref
input to github.event.pull_request.base.sha and ensure the job that runs
scripts/run_benchmark_compare.py uses that base_ref value for the --base
argument (look for the base_ref input key and the invocation of
scripts/run_benchmark_compare.py --base in the job steps).

In @.github/workflows/nvidia-h100.yml:
- Around line 34-37: Replace the moving branch ref input (base_ref currently set
to github.event.pull_request.base.ref) with the immutable PR base SHA
(github.event.pull_request.base.sha) so the job’s baseline is fixed; update the
workflow input named base_ref to use github.event.pull_request.base.sha and
ensure that value is passed unchanged into scripts/run_benchmark_compare.py
--base (the runner and conda_env_name settings stay the same).

In @.github/workflows/reusable-ci-benchmarks.yml:
- Around line 53-63: The check in step id "check_skip" currently gets the latest
non-merge commit via `git log --no-merges -1` which can miss the PR head;
replace that logic to explicitly read the PR head SHA and then get its subject.
Concretely, set COMMIT_SHA="${{ github.event.pull_request.head.sha }}" (or use
the expression inline) and use `git show -s --format=%s "$COMMIT_SHA"` (assign
to COMMIT_MSG) instead of `git log --no-merges -1 --pretty=%s`, then keep the
existing grep/echo logic to set skip_bench.

In `@benchmarks/ops/registry.py`:
- Around line 323-327: _rwkv7_post_init currently uses torch.randn_like(...)*0.1
which yields values centered on zero and therefore negative about half the time;
change the initialization for inputs['a'] and inputs['b'] to guarantee small
positive values (e.g., use torch.abs(torch.randn_like(...))*0.1 or
torch.rand_like(...)*0.1 + a small positive offset) and keep
requires_grad_(True) so the tensors remain trainable; update the assignments in
_rwkv7_post_init for inputs['a'] and inputs['b'] accordingly.
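A NumPy analogue of the suggested change (the actual _rwkv7_post_init uses torch tensors and keeps requires_grad_(True); the torch equivalents would be torch.abs(torch.randn_like(x)) * 0.1 or torch.rand_like(x) * 0.1 plus a small offset):

```python
import numpy as np

rng = np.random.default_rng(0)
shape = (4, 128)

# randn * 0.1 is zero-centred: roughly half the entries come out negative.
centred = rng.standard_normal(shape) * 0.1

# Two ways to guarantee small *positive* values, as the review suggests:
abs_init = np.abs(rng.standard_normal(shape)) * 0.1       # half-normal in [0, inf)
uniform_init = rng.random(shape) * 0.1 + 1e-3             # uniform in [1e-3, 0.101)

assert (centred < 0).any()
assert (abs_init >= 0).all() and (uniform_init > 0).all()
```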

In `@benchmarks/ops/run.py`:
- Around line 136-137: The warmup currently always runs _fwdbwd_fn regardless of
the effective modes or per-op skip_backward flags; update the warmup logic to
honor the computed modes list and each op's skip_backward attribute: when
building modes earlier (respecting config.skip_backward and 'fwdbwd'), use that
same modes variable to decide which warmup function to call (call _fwd_fn only
if 'fwd' in modes, call _fwdbwd_fn only if 'fwdbwd' in modes) and filter out ops
with op.skip_backward=True before scheduling any 'fwdbwd' warmup. Apply the same
change to the later warmup block at the 172-182 region so both warmup locations
use modes and op.skip_backward consistently.
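A hypothetical sketch of that fix: derive the warmup calls from the same modes list used for timing, and drop 'fwdbwd' for ops that set skip_backward (the names mirror the review comment, not the real run.py):

```python
def plan_warmup(requested_modes, op_skip_backward):
    """Return the warmup modes to run for one op.

    Mirrors the review's suggestion: honor the effective modes list and
    the per-op skip_backward flag instead of unconditionally warming up
    the fwd+bwd path.
    """
    modes = [m for m in requested_modes if m in ('fwd', 'fwdbwd')]
    if op_skip_backward:
        modes = [m for m in modes if m != 'fwdbwd']
    return modes


assert plan_warmup(['fwd', 'fwdbwd'], op_skip_backward=True) == ['fwd']
assert plan_warmup(['fwd', 'fwdbwd'], op_skip_backward=False) == ['fwd', 'fwdbwd']
assert plan_warmup(['fwd'], op_skip_backward=False) == ['fwd']
```

The same helper would serve both warmup sites (the 136-137 and 172-182 regions), keeping the two locations consistent.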

In `@scripts/run_benchmark_compare.py`:
- Around line 80-85: The current try/except around inserting PROJECT_ROOT into
sys.path and "from registry import _REGISTRY" swallows ImportError and returns
an empty list, which masks a broken registry and makes CI pass; change the
handler so the failure surfaces: in the except block log or print the
ImportError with details (include the exception message) and then abort with a
non-zero exit (e.g., re-raise the ImportError or call sys.exit(1)) instead of
returning []; keep the import attempt using sys.path.insert(0, str(PROJECT_ROOT
/ 'benchmarks' / 'ops')) and the "from registry import _REGISTRY" lines to
locate where to modify.
- Around line 151-159: The checkout_and_install function currently ignores pip
install failures which can cause benchmarking the wrong revision; modify
checkout_and_install to detect failure of the subprocess.run call that runs pip
(e.g., by using subprocess.run(..., check=True) or checking result.returncode)
and abort/raise an exception if installation fails, logging the captured
stdout/stderr from the failed run so callers know why the editable install of
the package (invoked in checkout_and_install) failed.
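A minimal sketch of that hardening, assuming a helper like run_checked (hypothetical name) wraps the pip invocation inside checkout_and_install:

```python
import subprocess
import sys


def run_checked(cmd: list[str]) -> subprocess.CompletedProcess:
    """Run a command, surfacing its output and aborting on failure."""
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        # Log why the step failed so CI doesn't silently benchmark a
        # stale or missing build, then abort.
        print(result.stdout)
        print(result.stderr, file=sys.stderr)
        raise RuntimeError(f"command failed ({result.returncode}): {' '.join(cmd)}")
    return result


# Inside checkout_and_install, the pip step would then read, e.g.:
# run_checked([sys.executable, "-m", "pip", "install", "-e", str(repo_dir)])
```

Using subprocess.run(..., check=True) achieves the same abort-on-failure behavior, but the explicit branch lets the captured stdout/stderr be logged before raising.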

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 15b909dd-0084-4253-bd4a-1361d146d785

📥 Commits

Reviewing files that changed from the base of the PR and between 9fe7d6e and 046595a.

📒 Files selected for processing (24)
  • .github/workflows/nvidia-a100.yml
  • .github/workflows/nvidia-h100.yml
  • .github/workflows/reusable-ci-benchmarks.yml
  • benchmarks/ops/benchmark.py
  • benchmarks/ops/benchmark_abc.py
  • benchmarks/ops/benchmark_based.py
  • benchmarks/ops/benchmark_delta_rule.py
  • benchmarks/ops/benchmark_fla.py
  • benchmarks/ops/benchmark_gla.py
  • benchmarks/ops/benchmark_gsa.py
  • benchmarks/ops/benchmark_hgrn.py
  • benchmarks/ops/benchmark_kda.py
  • benchmarks/ops/benchmark_nsa.py
  • benchmarks/ops/benchmark_retention.py
  • benchmarks/ops/benchmark_rwkv.py
  • benchmarks/ops/benchmark_rwkv7_fused_addcmul.py
  • benchmarks/ops/benchmark_rwkv7_k_update.py
  • benchmarks/ops/benchmark_simple_gla_vs_mamba2.py
  • benchmarks/ops/benchmark_solv_tril.py
  • benchmarks/ops/benchmark_titans.py
  • benchmarks/ops/benchmark_ttt.py
  • benchmarks/ops/registry.py
  • benchmarks/ops/run.py
  • scripts/run_benchmark_compare.py
💤 Files with no reviewable changes (18)
  • benchmarks/ops/benchmark_hgrn.py
  • benchmarks/ops/benchmark_kda.py
  • benchmarks/ops/benchmark_nsa.py
  • benchmarks/ops/benchmark_based.py
  • benchmarks/ops/benchmark_abc.py
  • benchmarks/ops/benchmark_retention.py
  • benchmarks/ops/benchmark_gla.py
  • benchmarks/ops/benchmark_simple_gla_vs_mamba2.py
  • benchmarks/ops/benchmark_solv_tril.py
  • benchmarks/ops/benchmark_delta_rule.py
  • benchmarks/ops/benchmark_gsa.py
  • benchmarks/ops/benchmark_ttt.py
  • benchmarks/ops/benchmark_fla.py
  • benchmarks/ops/benchmark_rwkv7_k_update.py
  • benchmarks/ops/benchmark.py
  • benchmarks/ops/benchmark_rwkv.py
  • benchmarks/ops/benchmark_rwkv7_fused_addcmul.py
  • benchmarks/ops/benchmark_titans.py

Comment thread .github/workflows/nvidia-a100.yml Outdated
Comment on lines +32 to +35
with:
  runner: 'nvidia-a100'
  conda_env_name: 'pytorch_2_7'
  base_ref: ${{ github.event.pull_request.base.ref }}


⚠️ Potential issue | 🟠 Major

Pin the benchmark baseline to the PR base SHA.

Using github.event.pull_request.base.ref lets the comparison move underneath the PR if the base branch advances before this job runs. scripts/run_benchmark_compare.py --base should get the event's immutable base commit instead.

Suggested change
-      base_ref: ${{ github.event.pull_request.base.ref }}
+      base_ref: ${{ github.event.pull_request.base.sha }}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In @.github/workflows/nvidia-a100.yml around lines 32 - 35, The workflow
currently sets base_ref using github.event.pull_request.base.ref which can
change if the base branch advances; change it to use the immutable pull request
base SHA by setting the base_ref input to github.event.pull_request.base.sha and
ensure the job that runs scripts/run_benchmark_compare.py uses that base_ref
value for the --base argument (look for the base_ref input key and the
invocation of scripts/run_benchmark_compare.py --base in the job steps).

Comment thread .github/workflows/nvidia-h100.yml Outdated
Comment on lines +34 to +37
with:
  runner: 'nvidia-h100-pt2-7'
  conda_env_name: 'pytorch_2_7'
  base_ref: ${{ github.event.pull_request.base.ref }}

⚠️ Potential issue | 🟠 Major

Use the PR base SHA here, not the base branch name.

github.event.pull_request.base.ref is a moving ref, so this job can end up benchmarking against newer main commits if the queue is backed up. The reusable workflow passes this straight to scripts/run_benchmark_compare.py --base, so the baseline should be the immutable PR event SHA.

Suggested change
-      base_ref: ${{ github.event.pull_request.base.ref }}
+      base_ref: ${{ github.event.pull_request.base.sha }}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In @.github/workflows/nvidia-h100.yml around lines 34 - 37, Replace the moving
branch ref input (base_ref currently set to github.event.pull_request.base.ref)
with the immutable PR base SHA (github.event.pull_request.base.sha) so the job’s
baseline is fixed; update the workflow input named base_ref to use
github.event.pull_request.base.sha and ensure that value is passed unchanged
into scripts/run_benchmark_compare.py --base (the runner and conda_env_name
settings stay the same).

Comment on lines +53 to +63
- name: Check skip keyword in LATEST commit
  id: check_skip
  run: |
    COMMIT_MSG=$(git log --no-merges -1 --pretty=%s)
    echo "Latest commit message: $COMMIT_MSG"
    if echo "$COMMIT_MSG" | grep -qF "[skip bench]"; then
      echo "::notice::Benchmarks skipped by commit message"
      echo "skip_bench=true" >> $GITHUB_OUTPUT
    else
      echo "skip_bench=false" >> $GITHUB_OUTPUT
    fi

⚠️ Potential issue | 🟡 Minor

Read [skip bench] from the PR head commit explicitly.

This currently inspects the latest non-merge commit reachable from the checked-out HEAD, which is indirect and can pick up the wrong message when HEAD is not the PR head commit. Since the intent is PR-scoped, query github.event.pull_request.head.sha directly.

Suggested change
-          COMMIT_MSG=$(git log --no-merges -1 --pretty=%s)
+          COMMIT_MSG=$(git show -s --format=%s ${{ github.event.pull_request.head.sha }})
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In @.github/workflows/reusable-ci-benchmarks.yml around lines 53 - 63, The check
in step id "check_skip" currently gets the latest non-merge commit via `git log
--no-merges -1` which can miss the PR head; replace that logic to explicitly
read the PR head SHA and then get its subject. Concretely, set COMMIT_SHA="${{
github.event.pull_request.head.sha }}" (or use the expression inline) and use
`git show -s --format=%s "$COMMIT_SHA"` (assign to COMMIT_MSG) instead of `git
log --no-merges -1 --pretty=%s`, then keep the existing grep/echo logic to set
skip_bench.
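
A side note on the grep itself: the `-F` flag is what keeps the `[skip bench]` brackets literal — without it, `[skip bench]` is a bracket expression matching any single character from that set, so nearly every commit message would trigger the skip. A minimal shell check of that behavior, independent of git and the workflow:

```shell
# With -F the marker is matched as a literal string; without -F,
# '[skip bench]' would be a bracket expression matching one character.
msg='perf: tune gdn kernel [skip bench]'

if echo "$msg" | grep -qF '[skip bench]'; then
  echo 'skip_bench=true'
else
  echo 'skip_bench=false'
fi
```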

Comment on lines +323 to +327
def _rwkv7_post_init(inputs, B, T, H, D, **kw):
    """RWKV7 needs a/b to be initialized as small positive values."""
    with torch.no_grad():
        inputs['a'] = (torch.randn_like(inputs['a']) * 0.1).requires_grad_(True)
        inputs['b'] = (torch.randn_like(inputs['b']) * 0.1).requires_grad_(True)

⚠️ Potential issue | 🟠 Major

_rwkv7_post_init is not actually generating positive a/b.

torch.randn_like(...) * 0.1 is centered at zero, so this produces negative values about half the time even though the benchmark says RWKV7 needs small positive inputs. That can put the benchmark in the wrong regime.

Suggested change
 def _rwkv7_post_init(inputs, B, T, H, D, **kw):
     """RWKV7 needs a/b to be initialized as small positive values."""
     with torch.no_grad():
-        inputs['a'] = (torch.randn_like(inputs['a']) * 0.1).requires_grad_(True)
-        inputs['b'] = (torch.randn_like(inputs['b']) * 0.1).requires_grad_(True)
+        inputs['a'] = (torch.rand_like(inputs['a']) * 0.1).requires_grad_(True)
+        inputs['b'] = (torch.rand_like(inputs['b']) * 0.1).requires_grad_(True)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/ops/registry.py` around lines 323 - 327, _rwkv7_post_init
currently uses torch.randn_like(...)*0.1 which yields values centered on zero
and therefore negative about half the time; change the initialization for
inputs['a'] and inputs['b'] to guarantee small positive values (e.g., use
torch.abs(torch.randn_like(...))*0.1 or torch.rand_like(...)*0.1 + a small
positive offset) and keep requires_grad_(True) so the tensors remain trainable;
update the assignments in _rwkv7_post_init for inputs['a'] and inputs['b']
accordingly.
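
The distributional point is easy to verify without torch; a torch-free sketch using Python's random module, with the same 0.1 scale as the benchmark init:

```python
import random

random.seed(0)
N = 10_000

# randn-style init: zero-mean Gaussian scaled by 0.1 -> roughly half negative
gauss = [random.gauss(0.0, 1.0) * 0.1 for _ in range(N)]
# rand-style init: uniform on [0, 1) scaled by 0.1 -> never negative
uniform = [random.random() * 0.1 for _ in range(N)]

neg_gauss = sum(x < 0 for x in gauss)
neg_uniform = sum(x < 0 for x in uniform)
print(neg_uniform)  # 0
```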

Comment thread benchmarks/ops/run.py
Comment on lines +80 to +85
try:
    sys.path.insert(0, str(PROJECT_ROOT / 'benchmarks' / 'ops'))
    from registry import _REGISTRY
except ImportError:
    print("Warning: could not import registry", file=sys.stderr)
    return []

⚠️ Potential issue | 🟠 Major

Don't turn a broken registry import into a green benchmark job.

If benchmarks/ops/registry.py stops importing, this returns [], and main() later exits 0 with "No affected ops found". That silently disables the benchmark gate instead of failing the PR.

Suggested change
-    except ImportError:
-        print("Warning: could not import registry", file=sys.stderr)
-        return []
+    except ImportError as e:
+        print(f"Error: could not import benchmark registry: {e}", file=sys.stderr)
+        sys.exit(1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@scripts/run_benchmark_compare.py` around lines 80 - 85, The current
try/except around inserting PROJECT_ROOT into sys.path and "from registry import
_REGISTRY" swallows ImportError and returns an empty list, which masks a broken
registry and makes CI pass; change the handler so the failure surfaces: in the
except block log or print the ImportError with details (include the exception
message) and then abort with a non-zero exit (e.g., re-raise the ImportError or
call sys.exit(1)) instead of returning []; keep the import attempt using
sys.path.insert(0, str(PROJECT_ROOT / 'benchmarks' / 'ops')) and the "from
registry import _REGISTRY" lines to locate where to modify.
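
A minimal fail-fast variant of that import — `load_registry` and the path layout are illustrative names, not the repo's actual function:

```python
import sys
from pathlib import Path


def load_registry(project_root: Path):
    """Import the benchmark registry, aborting CI on failure instead of
    returning an empty op list (which reads as 'nothing to benchmark')."""
    sys.path.insert(0, str(project_root / 'benchmarks' / 'ops'))
    try:
        from registry import _REGISTRY
    except ImportError as exc:
        print(f"Error: could not import benchmark registry: {exc}", file=sys.stderr)
        sys.exit(1)
    return _REGISTRY
```

Exiting non-zero here means a broken registry fails the PR instead of silently passing the benchmark gate.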

Comment on lines +151 to +159
def checkout_and_install(ref: str, clear_cache: bool = True):
    """Checkout a ref and reinstall the package."""
    print(f"\n  Checking out {ref}...")
    run_cmd(["git", "checkout", ref])
    print("  Installing package...")
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", ".", "-q"],
        cwd=str(PROJECT_ROOT), capture_output=True,
    )

⚠️ Potential issue | 🟠 Major

Abort when the editable install fails.

The temp runner imports fla from the current environment. If pip install -e . fails here and the script keeps going, you can end up benchmarking the previously installed revision instead of the ref you just checked out.

Suggested change
 def checkout_and_install(ref: str, clear_cache: bool = True):
     """Checkout a ref and reinstall the package."""
     print(f"\n  Checking out {ref}...")
     run_cmd(["git", "checkout", ref])
     print("  Installing package...")
-    subprocess.run(
+    result = subprocess.run(
         [sys.executable, "-m", "pip", "install", "-e", ".", "-q"],
-        cwd=str(PROJECT_ROOT), capture_output=True,
+        cwd=str(PROJECT_ROOT), capture_output=True, text=True,
     )
+    if result.returncode != 0:
+        print(f"  Install failed for {ref}: {result.stderr}", file=sys.stderr)
+        sys.exit(1)
     if clear_cache:
         clear_triton_cache()
🧰 Tools
🪛 Ruff (0.15.6)

[error] 156-156: subprocess call: check for execution of untrusted input

(S603)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@scripts/run_benchmark_compare.py` around lines 151 - 159, The
checkout_and_install function currently ignores pip install failures which can
cause benchmarking the wrong revision; modify checkout_and_install to detect
failure of the subprocess.run call that runs pip (e.g., by using
subprocess.run(..., check=True) or checking result.returncode) and abort/raise
an exception if installation fails, logging the captured stdout/stderr from the
failed run so callers know why the editable install of the package (invoked in
checkout_and_install) failed.
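
The same fail-fast idea as a small generic helper — `run_checked` is an illustrative name; `checkout_and_install` would route its pip invocation through it:

```python
import subprocess
import sys


def run_checked(cmd, **kwargs):
    """Run a command with captured output and raise on a non-zero exit,
    instead of discarding the return code like a bare subprocess.run()."""
    result = subprocess.run(cmd, capture_output=True, text=True, **kwargs)
    if result.returncode != 0:
        raise RuntimeError(
            f"command {cmd!r} failed ({result.returncode}): {result.stderr.strip()}"
        )
    return result
```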


@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (2)
scripts/run_benchmark_compare.py (2)

283-286: Two separate temp directories are created but only one is documented.

tmpdir (line 283) stores JSON results while runner_tmpdir (line 284) stores the runner scripts. Both are cleaned up in finally, but this split is non-obvious. Consider consolidating or adding a brief comment.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@scripts/run_benchmark_compare.py` around lines 283 - 286, The code creates
two temp dirs (tmpdir and runner_tmpdir via copy_runner_to_temp()) but only one
is documented; update the code to either consolidate storage into a single temp
dir (e.g., use tmpdir as the base and make runner files a subdirectory) or add a
concise comment above the creation explaining the split responsibilities (tmpdir
holds JSON results, runner_tmpdir holds runner scripts) and referencing
copy_runner_to_temp() and the cleanup in the finally block so future readers
understand both temporary locations.

131-134: Benchmark failure should surface more clearly in the comparison flow.

When run_unified_benchmark returns False, the caller continues and may produce incomplete or empty results. At lines 298 and 314, failures are silently ignored until the final check at line 327-329. Consider logging a clearer warning at the call site or accumulating failure info for the final report.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@scripts/run_benchmark_compare.py` around lines 131 - 134, Call sites of
run_unified_benchmark are ignoring a False return which lets the comparison
continue with incomplete data; update the callers (where
run_unified_benchmark(...) is invoked in the main flow) to detect a False
result, immediately log a clear warning that includes identifying context
(benchmark name/params) and/or append a structured failure entry to a failures
list, and then ensure the final report/final-check uses that failures list to
fail or summarize all benchmark errors instead of silently proceeding; reference
run_unified_benchmark and the main invocation loop so the fix is applied where
the function is called.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@scripts/run_benchmark_compare.py`:
- Around line 283-286: The code creates two temp dirs (tmpdir and runner_tmpdir
via copy_runner_to_temp()) but only one is documented; update the code to either
consolidate storage into a single temp dir (e.g., use tmpdir as the base and
make runner files a subdirectory) or add a concise comment above the creation
explaining the split responsibilities (tmpdir holds JSON results, runner_tmpdir
holds runner scripts) and referencing copy_runner_to_temp() and the cleanup in
the finally block so future readers understand both temporary locations.
- Around line 131-134: Call sites of run_unified_benchmark are ignoring a False
return which lets the comparison continue with incomplete data; update the
callers (where run_unified_benchmark(...) is invoked in the main flow) to detect
a False result, immediately log a clear warning that includes identifying
context (benchmark name/params) and/or append a structured failure entry to a
failures list, and then ensure the final report/final-check uses that failures
list to fail or summarize all benchmark errors instead of silently proceeding;
reference run_unified_benchmark and the main invocation loop so the fix is
applied where the function is called.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 02db21fb-c770-4854-b6ae-5e4ecdd7958b

📥 Commits

Reviewing files that changed from the base of the PR and between 046595a and dceb0c6.

📒 Files selected for processing (4)
  • .github/workflows/reusable-ci-benchmarks.yml
  • benchmarks/ops/registry.py
  • benchmarks/ops/run.py
  • scripts/run_benchmark_compare.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • .github/workflows/reusable-ci-benchmarks.yml

@yzhangcs yzhangcs changed the title [GDN] Fuse kkt + solve_tril into single kernel for WY representation [GDN] Fuse kkt + solve_tril kernel & unified benchmark infrastructure Mar 24, 2026
@yzhangcs yzhangcs merged commit baf57d5 into main Mar 24, 2026
6 checks passed
@yzhangcs yzhangcs deleted the fuse-gdn-kkt-solve branch March 24, 2026 13:18
@coderabbitai coderabbitai bot mentioned this pull request Mar 30, 2026