
[GDN] Add exp2 support across chunk kernels for improved performance #791

Merged
yzhangcs merged 5 commits into main from exp2
Mar 25, 2026
Conversation

@yzhangcs (Member) commented Mar 24, 2026

Add use_exp2 flag throughout the gated delta rule chunk operations, enabling the use of exp2 (base-2 exponential) instead of exp (natural exponential) in Triton kernels. When enabled, gate values are pre-scaled by RCP_LN2 (1/ln2) so that exp2(g * RCP_LN2) == exp(g), allowing the compiler to emit faster native exp2 instructions.

Affected kernels: chunk_o fwd/bwd, chunk_fwd intra (kkt+solve), wy_fast fwd/bwd, fused_recurrent fwd, and the top-level chunk fwd/bwd dispatch.
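The base-swap identity the change relies on can be verified in plain Python (a quick check, not repo code; `RCP_LN2` mirrors the constant named above):

```python
import math

RCP_LN2 = 1.0 / math.log(2.0)  # 1/ln(2) ~= 1.4426950408889634

# 2**(g / ln 2) == e**g, so exp2(g * RCP_LN2) reproduces exp(g) exactly
# up to floating-point rounding, for positive and negative gates alike.
for g in (-3.5, -0.1, 0.0, 0.7, 2.0):
    assert math.isclose(2.0 ** (g * RCP_LN2), math.exp(g), rel_tol=1e-12)
```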

Summary by CodeRabbit

  • New Features

    • Added an opt-in use_exp2 mode to enable an alternative base-2 exponential path across gated-delta, COMBA, and fused-recurrent operations; preserves prior behavior when disabled.
  • Public API

    • Several public wrappers now accept use_exp2; COMBA cumsum gained an explicit scale parameter and the head_first option was removed.
  • Bug Fixes

    • Fixed cumsum scaling behavior and removed an unreachable duplicated trailing code path.

Co-Authored-By: Claude (claude-opus-4-6) <noreply@anthropic.com>
@coderabbitai bot (Contributor) commented Mar 24, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 83f11e62-756d-41dc-b254-6ca89a93c567

📥 Commits

Reviewing files that changed from the base of the PR and between b16630c and d03a419.

📒 Files selected for processing (1)
  • fla/ops/comba/chunk.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • fla/ops/comba/chunk.py

Walkthrough

Adds a runtime flag use_exp2, propagated into many gated-delta, WY, COMBA, and chunked wrappers, together with a Triton constexpr USE_EXP2 that lets kernels select between exp(...) and exp2(...) at compile time in both forward and backward kernels.

Changes

| Cohort / File(s) | Summary |
|---|---|
| **Common chunk ops**<br>`fla/ops/common/chunk_o.py` | Added `use_exp2: bool = False` to public chunk functions; pass `USE_EXP2` into Triton launches and conditionally use `exp2` vs `exp`. |
| **Gated-delta core & wrappers**<br>`fla/ops/gated_delta_rule/chunk.py`, `fla/ops/gated_delta_rule/chunk_fwd.py` | Added `use_exp2` parameter to public fwd/bwd APIs and intra wrappers; introduced `USE_EXP2` constexpr in fused/KKT kernels; conditional `RCP_LN2` scaling for cumsum when using exp2. |
| **Gated-delta fused & WY helpers**<br>`fla/ops/gated_delta_rule/fused_recurrent.py`, `fla/ops/gated_delta_rule/wy_fast.py` | Public wrappers accept `use_exp2`; threaded `USE_EXP2` into Triton kernels to switch exp↔exp2 in recurrence and WY recompute/backward kernels; `exp2` imported. |
| **Gated-delta product caller**<br>`fla/ops/gated_delta_product/chunk.py` | Backward call site now explicitly passes `use_exp2=False` to the gated-delta-rule backward. |
| **COMBA high-level**<br>`fla/ops/comba/chunk.py`, `fla/ops/comba/fused_recurrent.py` | Threaded `use_exp2=True` into multiple gated/wy/comba calls; added `use_exp2` parameter to the fused_recurrent wrapper with pass-through to kernels. |
| **COMBA utils & WY**<br>`fla/ops/comba/utils.py`, `fla/ops/comba/wy_fast.py` | `chunk_comba_cumsum_scalar_fwd` gains a `scale` param (`head_first` removed); forward/backward kernels accept `scale`; WY helpers accept `use_exp2` and pass `USE_EXP2` into kernels. |
| **Triton kernels (various modules)**<br>`fla/ops/...` (multiple files) | Multiple Triton kernels extended with `USE_EXP2: tl.constexpr` and conditional `tl.exp`/`tl.exp2` usage; host wrappers updated to forward `USE_EXP2=use_exp2`. |
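The flag-threading pattern described in the table can be sketched in plain Python (illustrative names, not the repo's; in the real kernels `USE_EXP2` is a `tl.constexpr`, so the untaken branch is compiled out rather than tested at runtime):

```python
import math

RCP_LN2 = 1.0 / math.log(2.0)

def kernel(gates, USE_EXP2):
    # Stand-in for the Triton kernel body: same math, ordinary branch.
    if USE_EXP2:
        return [2.0 ** g for g in gates]   # tl.exp2 path
    return [math.exp(g) for g in gates]    # tl.exp path

def host_wrapper(gates, use_exp2=False):
    # Host side pre-scales the log-domain gates once so the kernel can
    # use the cheaper base-2 exponential without changing results.
    if use_exp2:
        gates = [g * RCP_LN2 for g in gates]
    return kernel(gates, USE_EXP2=use_exp2)

out_exp = host_wrapper([-0.5, -1.0, -2.0], use_exp2=False)
out_exp2 = host_wrapper([-0.5, -1.0, -2.0], use_exp2=True)
assert all(math.isclose(a, b, rel_tol=1e-12) for a, b in zip(out_exp, out_exp2))
```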

Sequence Diagram(s)

sequenceDiagram
  participant PyAPI as Python API
  participant Wrapper as Host wrapper
  participant TritonK as Triton Kernel (USE_EXP2)
  participant Recompute as Recompute / WY / Backward kernels
  participant Device as Device tensors

  PyAPI->>Wrapper: call(..., use_exp2=bool)
  Wrapper->>Device: prepare tensors, flags
  Wrapper->>TritonK: launch(kernel, USE_EXP2=use_exp2)
  TritonK->>Device: read/write tensors (use exp or exp2)
  TritonK->>Recompute: when needed, request recompute (USE_EXP2 passed)
  Recompute->>Device: recompute/backward using exp or exp2
  Recompute->>Wrapper: return gradients/results
  Wrapper->>PyAPI: return outputs

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • zhiyuan1i
  • Nathancgy

Poem

🐇 I hopped through kernels, quiet and spry,
Swapped exp for exp2 with a two-bit sigh,
Flags threaded down from Python to Triton,
Base‑two blooms where the exponents brighten,
A rabbit applauds — tiny, timely, and wry.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 8.70%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title Check | ✅ Passed | The title '[GDN] Add exp2 support across chunk kernels for improved performance' clearly and concisely describes the main change: enabling exp2 as an alternative to exp throughout multiple kernel operations. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@gemini-code-assist bot (Contributor) commented

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant performance enhancement for gated delta rule chunk operations by allowing the use of base-2 exponential (exp2) functions in Triton kernels. By leveraging hardware-optimized exp2 instructions, the system can achieve faster computations. The change is carefully implemented to maintain mathematical equivalence with the original natural exponential (exp) function through a pre-scaling mechanism, ensuring accuracy while boosting efficiency across several core components of the gated delta rule.

Highlights

  • Performance Optimization: Introduced a use_exp2 flag across various gated delta rule chunk operations to enable the use of exp2 (base-2 exponential) instead of exp (natural exponential) in Triton kernels.
  • Mathematical Equivalence: Implemented pre-scaling of gate values by RCP_LN2 (1/ln2) when use_exp2 is enabled, ensuring that exp2(g * RCP_LN2) is mathematically equivalent to exp(g).
  • Kernel and API Integration: Integrated the use_exp2 flag into the signatures of multiple forward and backward chunk kernels and their corresponding Python API calls, including chunk_o (fwd/bwd), chunk_fwd_intra (kkt+solve), wy_fast (fwd/bwd), and fused_recurrent (fwd).



@gemini-code-assist bot left a comment

Code Review

This pull request introduces a use_exp2 flag to various kernel operations across several files, allowing for flexible use of exp or exp2 functions in exponential calculations. Feedback suggests exposing the use_exp2 flag in the chunk_gated_delta_rule function for better configurability and refactoring duplicated code in fused_recurrent.py to improve readability and maintainability.

```python
    cu_seqlens: torch.LongTensor | None = None,
    cp_context: FLACPContext | None = None,
    chunk_indices: torch.LongTensor | None = None,
    use_exp2: bool = True,
```
Severity: high

The use_exp2 flag is introduced here with a default value of True, but it's not configurable from the public-facing chunk_gated_delta_rule function. This means exp2 will always be used for this operation. To make this configurable, you should consider adding use_exp2 to the signature of chunk_gated_delta_rule and ChunkGatedDeltaRuleFunction, and then pass it down through the function calls. This would align its behavior with other functions in this PR like fused_recurrent_gated_delta_rule.

Comment on lines +120 to +129
```python
if USE_EXP2:
    if TRANSPOSE_STATE:
        b_h *= exp2(b_gk[None, :])
    else:
        b_h *= exp2(b_gk[:, None])
else:
    if TRANSPOSE_STATE:
        b_h *= exp(b_gk[None, :])
    else:
        b_h *= exp(b_gk[:, None])
```
Severity: medium

The nested if/else statements for USE_EXP2 and TRANSPOSE_STATE lead to code duplication. You can refactor this to improve readability and maintainability without impacting performance by first determining the value to be exponentiated based on TRANSPOSE_STATE.

```python
if TRANSPOSE_STATE:
    g_val = b_gk[None, :]
else:
    g_val = b_gk[:, None]
if USE_EXP2:
    b_h *= exp2(g_val)
else:
    b_h *= exp(g_val)
```

Comment on lines +133 to +142
```python
if USE_EXP2:
    if TRANSPOSE_STATE:
        b_h *= exp2(b_gv[:, None])
    else:
        b_h *= exp2(b_gv[None, :])
else:
    if TRANSPOSE_STATE:
        b_h *= exp(b_gv[:, None])
    else:
        b_h *= exp(b_gv[None, :])
```
Severity: medium

Similar to the USE_GK block above, this block can be refactored to reduce code duplication and improve readability.

```python
if TRANSPOSE_STATE:
    g_val = b_gv[:, None]
else:
    g_val = b_gv[None, :]
if USE_EXP2:
    b_h *= exp2(g_val)
else:
    b_h *= exp(g_val)
```
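The refactor's equivalence to the nested original can be sanity-checked outside Triton with a NumPy stand-in (`np.exp`/`np.exp2` play the role of `tl.exp`/`tl.exp2`; toy shapes, not repo code):

```python
import numpy as np

def nested(b_h, b_gv, USE_EXP2, TRANSPOSE_STATE):
    # Shape of the original branch-per-combination kernel logic.
    if USE_EXP2:
        if TRANSPOSE_STATE:
            return b_h * np.exp2(b_gv[:, None])
        else:
            return b_h * np.exp2(b_gv[None, :])
    else:
        if TRANSPOSE_STATE:
            return b_h * np.exp(b_gv[:, None])
        else:
            return b_h * np.exp(b_gv[None, :])

def refactored(b_h, b_gv, USE_EXP2, TRANSPOSE_STATE):
    # Suggested shape: pick the broadcast axis first, then the base.
    g_val = b_gv[:, None] if TRANSPOSE_STATE else b_gv[None, :]
    return b_h * (np.exp2(g_val) if USE_EXP2 else np.exp(g_val))

rng = np.random.default_rng(0)
b_h, b_gv = rng.normal(size=(4, 4)), rng.normal(size=4)
for use_exp2 in (False, True):
    for transpose in (False, True):
        assert np.allclose(
            nested(b_h, b_gv, use_exp2, transpose),
            refactored(b_h, b_gv, use_exp2, transpose),
        )
```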

@coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
fla/ops/gated_delta_rule/chunk.py (1)

25-38: ⚠️ Potential issue | 🟠 Major

The new flag never reaches the public autograd path.

chunk_gated_delta_rule_fwd() and chunk_gated_delta_rule_bwd() now take use_exp2, but ChunkGatedDeltaRuleFunction and chunk_gated_delta_rule(...) still never accept, save, or forward it. The public API therefore always uses the hard-coded defaults, and use_exp2= passed via **kwargs is silently ignored.

💡 Suggested fix
 class ChunkGatedDeltaRuleFunction(torch.autograd.Function):

     @staticmethod
     @input_guard
     @autocast_custom_fwd
@@
         cu_seqlens_cpu: torch.LongTensor | None = None,
         use_qk_l2norm_in_kernel: bool = False,
         cp_context: FLACPContext | None = None,
+        use_exp2: bool = True,
         transpose_state_layout: bool = False,
     ):
@@
         g, o, A, final_state, initial_state = chunk_gated_delta_rule_fwd(
             q=q,
             k=k,
             v=v,
             g=g,
             beta=beta,
             scale=scale,
             initial_state=initial_state,
             output_final_state=output_final_state,
             cu_seqlens=cu_seqlens,
             cp_context=cp_context,
             chunk_indices=chunk_indices,
+            use_exp2=use_exp2,
             transpose_state_layout=transpose_state_layout,
         )
@@
         ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
         ctx.cp_context = cp_context
+        ctx.use_exp2 = use_exp2
         ctx.transpose_state_layout = transpose_state_layout
         return o.to(q.dtype), final_state
@@
         dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
             q=q,
             k=k,
             v=v,
             g=g,
             beta=beta,
             A=A,
             scale=ctx.scale,
             initial_state=initial_state,
             do=do,
             dht=dht,
             cu_seqlens=cu_seqlens,
             cp_context=ctx.cp_context,
             chunk_indices=chunk_indices,
+            use_exp2=ctx.use_exp2,
             transpose_state_layout=ctx.transpose_state_layout,
         )
@@
-        return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None, None, None
+        return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None, None, None, None
 def chunk_gated_delta_rule(
@@
     cu_seqlens: torch.LongTensor | None = None,
     cu_seqlens_cpu: torch.LongTensor | None = None,
     cp_context: FLACPContext | None = None,
+    use_exp2: bool = True,
     transpose_state_layout: bool = False,
     **kwargs,
 ):
@@
     o, final_state = ChunkGatedDeltaRuleFunction.apply(
         q,
         k,
         v,
         g,
         beta,
         scale,
         initial_state,
         output_final_state,
         cu_seqlens,
         cu_seqlens_cpu,
         use_qk_l2norm_in_kernel,
         cp_context,
+        use_exp2,
         transpose_state_layout,
     )

Also applies to: 103-118

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/chunk.py` around lines 25 - 38, The new use_exp2
flag is accepted by chunk_gated_delta_rule_fwd/bwd but never threaded through
the public autograd wrapper; update ChunkGatedDeltaRuleFunction and the
top-level chunk_gated_delta_rule to accept use_exp2, store it on the autograd
ctx in ChunkGatedDeltaRuleFunction.forward (e.g., ctx.use_exp2 = use_exp2), and
read it in ChunkGatedDeltaRuleFunction.backward to pass into
chunk_gated_delta_rule_bwd (and ensure chunk_gated_delta_rule forwards use_exp2
into chunk_gated_delta_rule_fwd when calling the autograd function); also
include use_exp2 in saved_for_backward or ctx attributes as needed so the
backward path receives the same flag.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/common/chunk_o.py`:
- Around line 110-118: The code uses locally-built g_gamma (symbol g_gamma) to
compute b_gamma and b_g, but when USE_EXP2 is true b_g must be pre-scaled by
RCP_LN2 (the same normalization used upstream) before calling exp2; update every
kernel that materializes b_g (and b_g_last) from b_gamma—i.e., where b_gamma =
tl.load(g_gamma + i_h) and b_g = b_gamma * (tl.arange(...) + 1)—to multiply
b_gamma (or b_g) by RCP_LN2 (the reciprocal of ln(2)) prior to using exp2, and
keep the USE_EXP2/else branches otherwise unchanged so exp2 sees base-2-scaled
exponents consistently (apply the same change in all occurrences mentioned:
around lines 110-118, 311-319, 421-427, 512-516).

In `@fla/ops/gated_delta_rule/fused_recurrent.py`:
- Around line 111-142: The exp2 path is using raw log-scale variables, changing
decay semantics; when USE_EXP2 is true you must convert stored natural-log terms
to base-2 by multiplying b_g, b_gk, and b_gv by RCP_LN2 before calling exp2 so
behavior matches exp(b_*). Update the three places in this block where exp2 is
used (the branches for USE_G, USE_GK, and USE_GV) to apply RCP_LN2 to the loaded
parameters (b_g, b_gk, b_gv) in the same transpose-aware manner you handle exp,
and apply the same change to the other similar blocks mentioned (the
corresponding USE_EXP2 branches around the other occurrences of these variables)
so exp2(b_*) uses b_* * RCP_LN2.

---

Outside diff comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Around line 25-38: The new use_exp2 flag is accepted by
chunk_gated_delta_rule_fwd/bwd but never threaded through the public autograd
wrapper; update ChunkGatedDeltaRuleFunction and the top-level
chunk_gated_delta_rule to accept use_exp2, store it on the autograd ctx in
ChunkGatedDeltaRuleFunction.forward (e.g., ctx.use_exp2 = use_exp2), and read it
in ChunkGatedDeltaRuleFunction.backward to pass into chunk_gated_delta_rule_bwd
(and ensure chunk_gated_delta_rule forwards use_exp2 into
chunk_gated_delta_rule_fwd when calling the autograd function); also include
use_exp2 in saved_for_backward or ctx attributes as needed so the backward path
receives the same flag.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 25ca7896-b622-4bff-8235-143ec0af0f96

📥 Commits

Reviewing files that changed from the base of the PR and between 4225ff9 and 8b287d5.

📒 Files selected for processing (6)
  • fla/ops/common/chunk_o.py
  • fla/ops/gated_delta_product/chunk.py
  • fla/ops/gated_delta_rule/chunk.py
  • fla/ops/gated_delta_rule/chunk_fwd.py
  • fla/ops/gated_delta_rule/fused_recurrent.py
  • fla/ops/gated_delta_rule/wy_fast.py

Comment thread fla/ops/common/chunk_o.py
Comment on lines 110 to +118
```diff
 if USE_G_GAMMA:
     b_gamma = tl.load(g_gamma + i_h)
     b_g = b_gamma * (tl.arange(0, BT) + 1)
-    b_o = b_o * exp(b_g)[:, None]
-    b_A = b_A * exp(b_g[:, None] - b_g[None, :])
+    if USE_EXP2:
+        b_o = b_o * exp2(b_g)[:, None]
+        b_A = b_A * exp2(b_g[:, None] - b_g[None, :])
+    else:
+        b_o = b_o * exp(b_g)[:, None]
+        b_A = b_A * exp(b_g[:, None] - b_g[None, :])
```
⚠️ Potential issue | 🟠 Major

Scale g_gamma before the exp2 branches.

The USE_G path can stay equivalent because fla/ops/gated_delta_rule/chunk.py now pre-scales cumulative g by RCP_LN2 before it reaches this file. g_gamma is generated locally here, though, so the new exp2(...) branches are still consuming raw natural-log decays. With use_exp2=True, every USE_G_GAMMA path here changes the operator in both forward and backward.

💡 Suggested fix
```diff
+from fla.ops.utils.constant import RCP_LN2
 from fla.ops.utils import prepare_chunk_indices
 from fla.ops.utils.op import exp, exp2
```

```diff
 if USE_G_GAMMA:
     b_gamma = tl.load(g_gamma + i_h)
     b_g = b_gamma * (tl.arange(0, BT) + 1)
+    if USE_EXP2:
+        b_g = b_g * RCP_LN2
 if USE_G_GAMMA:
     b_gamma = tl.load(g_gamma + i_h)
     b_g = b_gamma * (tl.arange(0, BT) + 1)
     b_g_last = b_gamma * min(BT, T - i_t * BT)
+    if USE_EXP2:
+        b_g = b_g * RCP_LN2
+        b_g_last = b_g_last * RCP_LN2
```

Apply the same normalization in each kernel here that materializes b_g / b_g_last from g_gamma.

Also applies to: 311-319, 421-427, 512-516

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/common/chunk_o.py` around lines 110 - 118, The code uses
locally-built g_gamma (symbol g_gamma) to compute b_gamma and b_g, but when
USE_EXP2 is true b_g must be pre-scaled by RCP_LN2 (the same normalization used
upstream) before calling exp2; update every kernel that materializes b_g (and
b_g_last) from b_gamma—i.e., where b_gamma = tl.load(g_gamma + i_h) and b_g =
b_gamma * (tl.arange(...) + 1)—to multiply b_gamma (or b_g) by RCP_LN2 (the
reciprocal of ln(2)) prior to using exp2, and keep the USE_EXP2/else branches
otherwise unchanged so exp2 sees base-2-scaled exponents consistently (apply the
same change in all occurrences mentioned: around lines 110-118, 311-319,
421-427, 512-516).
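The suggested `RCP_LN2` scaling works because it distributes over the `b_g[:, None] - b_g[None, :]` subtraction, so both the `b_o` factor `exp(g)` and the `b_A` factor `exp(g_i - g_j)` survive the base change. A toy numeric check (illustrative values, not repo code):

```python
import math

RCP_LN2 = 1.0 / math.log(2.0)

BT, gamma = 4, -0.25                        # toy chunk length and per-head decay
b_g = [gamma * (i + 1) for i in range(BT)]  # cumulative log decays, as in-kernel

for i in range(BT):
    # b_o factor: exp2 on the scaled value matches exp on the raw value.
    assert math.isclose(2.0 ** (b_g[i] * RCP_LN2), math.exp(b_g[i]), rel_tol=1e-12)
    for j in range(BT):
        # b_A factor: scaling first then subtracting equals subtracting
        # then scaling, so the pairwise decay matrix is preserved too.
        d = (b_g[i] - b_g[j]) * RCP_LN2
        assert math.isclose(2.0 ** d, math.exp(b_g[i] - b_g[j]), rel_tol=1e-12)
```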

Comment on lines 111 to +142
```diff
         if USE_G:
             b_g = tl.load(p_g).to(tl.float32)
-            b_h *= exp(b_g)
+            if USE_EXP2:
+                b_h *= exp2(b_g)
+            else:
+                b_h *= exp(b_g)

         if USE_GK:
             b_gk = tl.load(p_gk).to(tl.float32)
-            if TRANSPOSE_STATE:
-                b_h *= exp(b_gk[None, :])
-            else:
-                b_h *= exp(b_gk[:, None])
+            if USE_EXP2:
+                if TRANSPOSE_STATE:
+                    b_h *= exp2(b_gk[None, :])
+                else:
+                    b_h *= exp2(b_gk[:, None])
+            else:
+                if TRANSPOSE_STATE:
+                    b_h *= exp(b_gk[None, :])
+                else:
+                    b_h *= exp(b_gk[:, None])

         if USE_GV:
             b_gv = tl.load(p_gv).to(tl.float32)
-            if TRANSPOSE_STATE:
-                b_h *= exp(b_gv[:, None])
-            else:
-                b_h *= exp(b_gv[None, :])
+            if USE_EXP2:
+                if TRANSPOSE_STATE:
+                    b_h *= exp2(b_gv[:, None])
+                else:
+                    b_h *= exp2(b_gv[None, :])
+            else:
+                if TRANSPOSE_STATE:
+                    b_h *= exp(b_gv[:, None])
+                else:
+                    b_h *= exp(b_gv[None, :])
```
⚠️ Potential issue | 🟠 Major

use_exp2=True changes the recurrent decay math.

Unlike the chunked path, this file never applies RCP_LN2 to g, gk, or gv before switching from exp to exp2. As written, use_exp2=True computes base-2 decays from raw natural-log terms, so the new flag changes both the output and the stored final state instead of preserving the existing semantics.

💡 Suggested fix
```diff
+from fla.ops.utils.constant import RCP_LN2
 from fla.ops.utils.op import exp, exp2
```

```diff
         if USE_G:
             b_g = tl.load(p_g).to(tl.float32)
             if USE_EXP2:
-                b_h *= exp2(b_g)
+                b_h *= exp2(b_g * RCP_LN2)
             else:
                 b_h *= exp(b_g)

         if USE_GK:
             b_gk = tl.load(p_gk).to(tl.float32)
             if USE_EXP2:
                 if TRANSPOSE_STATE:
-                    b_h *= exp2(b_gk[None, :])
+                    b_h *= exp2(b_gk[None, :] * RCP_LN2)
                 else:
-                    b_h *= exp2(b_gk[:, None])
+                    b_h *= exp2(b_gk[:, None] * RCP_LN2)
             else:
                 if TRANSPOSE_STATE:
                     b_h *= exp(b_gk[None, :])
                 else:
                     b_h *= exp(b_gk[:, None])

         if USE_GV:
             b_gv = tl.load(p_gv).to(tl.float32)
             if USE_EXP2:
                 if TRANSPOSE_STATE:
-                    b_h *= exp2(b_gv[:, None])
+                    b_h *= exp2(b_gv[:, None] * RCP_LN2)
                 else:
-                    b_h *= exp2(b_gv[None, :])
+                    b_h *= exp2(b_gv[None, :] * RCP_LN2)
             else:
                 if TRANSPOSE_STATE:
                     b_h *= exp(b_gv[:, None])
                 else:
                     b_h *= exp(b_gv[None, :])
```

Also applies to: 174-233, 255-255, 300-300

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/fused_recurrent.py` around lines 111 - 142, The exp2
path is using raw log-scale variables, changing decay semantics; when USE_EXP2
is true you must convert stored natural-log terms to base-2 by multiplying b_g,
b_gk, and b_gv by RCP_LN2 before calling exp2 so behavior matches exp(b_*).
Update the three places in this block where exp2 is used (the branches for
USE_G, USE_GK, and USE_GV) to apply RCP_LN2 to the loaded parameters (b_g, b_gk,
b_gv) in the same transpose-aware manner you handle exp, and apply the same
change to the other similar blocks mentioned (the corresponding USE_EXP2
branches around the other occurrences of these variables) so exp2(b_*) uses b_*
* RCP_LN2.
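The drift described above can be seen with a toy recurrence (illustrative gate values, not repo code): exp2 on raw natural-log gates diverges from the exp path, while pre-scaling by `RCP_LN2` restores it:

```python
import math

RCP_LN2 = 1.0 / math.log(2.0)
gates = [-0.3, -0.8, -0.1, -0.5]  # toy per-step log-domain gates

h_exp = h_fixed = h_buggy = 1.0
for g in gates:
    h_exp *= math.exp(g)             # reference exp path
    h_fixed *= 2.0 ** (g * RCP_LN2)  # exp2 with RCP_LN2 scaling
    h_buggy *= 2.0 ** g              # exp2 on raw gates: wrong base

assert math.isclose(h_fixed, h_exp, rel_tol=1e-12)     # scaling preserves semantics
assert not math.isclose(h_buggy, h_exp, rel_tol=1e-3)  # raw exp2 drifts noticeably
```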

yzhangcs and others added 2 commits March 24, 2026 16:57
Co-Authored-By: Claude (claude-opus-4-6) <noreply@anthropic.com>

@coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
fla/ops/comba/fused_recurrent.py (1)

132-132: use_exp2 parameter is not exposed in the public API.

The use_exp2 parameter is added to fused_recurrent_comba_fwd, but neither FusedRecurrentCombaFunction.forward (line 184) nor the public fused_recurrent_comba function (line 224) exposes or threads this parameter. This means use_exp2 will always be False when using the public API.

If this is intentional (keeping exp2 as an internal optimization not exposed to users), this is fine. Otherwise, consider threading use_exp2 through the autograd function and public interface for consistency with how it's handled in chunk.py.

Also applies to: 172-172

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/comba/fused_recurrent.py` at line 132, The new use_exp2 flag is added
to fused_recurrent_comba_fwd but isn’t threaded through the autograd wrapper or
public API; update FusedRecurrentCombaFunction.forward to accept and pass
use_exp2 into fused_recurrent_comba_fwd, and extend the public
fused_recurrent_comba(...) function signature to accept use_exp2 and forward it
into FusedRecurrentCombaFunction.apply (or constructor) so the runtime uses the
provided flag instead of always defaulting to False; ensure the same change is
applied for the corresponding backward path if applicable (reference
fused_recurrent_comba_fwd, FusedRecurrentCombaFunction.forward, and
fused_recurrent_comba).
fla/ops/comba/wy_fast.py (2)

205-208: Inconsistent use of tl.exp vs imported exp.

Line 208 uses tl.exp(b_g0) directly, while the USE_EXP2 branch at line 206 uses the imported exp2. For consistency with other kernels in this PR that use the imported exp wrapper, consider using exp(b_g0) instead.

♻️ Suggested fix
```diff
     if USE_EXP2:
         b_g0_exp = exp2(b_g0)
     else:
-        b_g0_exp = tl.exp(b_g0)
+        b_g0_exp = exp(b_g0)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/comba/wy_fast.py` around lines 205 - 208, The branch computing
b_g0_exp is inconsistent: when USE_EXP2 is True it uses the imported exp2 but
the else branch calls tl.exp(b_g0); change the else branch to use the imported
exp wrapper (exp(b_g0)) so both branches use the same abstraction. Update the
expression that assigns b_g0_exp (the else branch referencing tl.exp) to call
exp instead, keeping the USE_EXP2 conditional and the variable name b_g0_exp
unchanged.

329-332: Inconsistent use of tl.exp vs imported exp.

Same issue as in prepare_wy_repr_bwd_kernel: line 332 uses tl.exp directly instead of the imported exp wrapper.

♻️ Suggested fix
```diff
     if USE_EXP2:
         b_g = exp2(tl.load(p_g, boundary_check=(0,)))
     else:
-        b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
+        b_g = exp(tl.load(p_g, boundary_check=(0,)))
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/comba/wy_fast.py` around lines 329 - 332, The branch that computes
b_g uses tl.exp directly (b_g = tl.exp(tl.load(p_g, ...))) when USE_EXP2 is
false; change it to use the imported exp wrapper instead (i.e., call
exp(tl.load(p_g, boundary_check=(0,)))) so behavior matches the other branch and
prepare_wy_repr_bwd_kernel; update the b_g assignment in the same conditional
that references USE_EXP2 and p_g.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/comba/chunk.py`:
- Around line 381-388: Remove the accidental duplicated lines that re-declare
cu_seqlens and cu_seqlens_cpu and the duplicated "return o, final_state" which
cause syntax errors; locate the end of the function that uses cu_seqlens and
cu_seqlens_cpu (the block that currently returns "o, final_state") and delete
the repeated second occurrence so the function ends with a single "return o,
final_state" and no repeated variable lines.

---

Nitpick comments:
In `@fla/ops/comba/fused_recurrent.py`:
- Line 132: The new use_exp2 flag is added to fused_recurrent_comba_fwd but
isn’t threaded through the autograd wrapper or public API; update
FusedRecurrentCombaFunction.forward to accept and pass use_exp2 into
fused_recurrent_comba_fwd, and extend the public fused_recurrent_comba(...)
function signature to accept use_exp2 and forward it into
FusedRecurrentCombaFunction.apply (or constructor) so the runtime uses the
provided flag instead of always defaulting to False; ensure the same change is
applied for the corresponding backward path if applicable (reference
fused_recurrent_comba_fwd, FusedRecurrentCombaFunction.forward, and
fused_recurrent_comba).

In `@fla/ops/comba/wy_fast.py`:
- Around line 205-208: The branch computing b_g0_exp is inconsistent: when
USE_EXP2 is True it uses the imported exp2 but the else branch calls
tl.exp(b_g0); change the else branch to use the imported exp wrapper (exp(b_g0))
so both branches use the same abstraction. Update the expression that assigns
b_g0_exp (the else branch referencing tl.exp) to call exp instead, keeping the
USE_EXP2 conditional and the variable name b_g0_exp unchanged.
- Around line 329-332: The branch that computes b_g uses tl.exp directly (b_g =
tl.exp(tl.load(p_g, ...))) when USE_EXP2 is false; change it to use the imported
exp wrapper instead (i.e., call exp(tl.load(p_g, boundary_check=(0,)))) so
behavior matches the other branch and prepare_wy_repr_bwd_kernel; update the b_g
assignment in the same conditional that references USE_EXP2 and p_g.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 7a32f469-4e65-4554-ab58-9a7dbe9383df

📥 Commits

Reviewing files that changed from the base of the PR and between 0a745db and b16630c.

📒 Files selected for processing (4)
  • fla/ops/comba/chunk.py
  • fla/ops/comba/fused_recurrent.py
  • fla/ops/comba/utils.py
  • fla/ops/comba/wy_fast.py

Comment thread fla/ops/comba/chunk.py Outdated
@yzhangcs yzhangcs merged commit 047b5df into main Mar 25, 2026
3 of 4 checks passed
@yzhangcs yzhangcs deleted the exp2 branch March 25, 2026 15:09