Conversation
Add `use_exp2` flag throughout the gated delta rule chunk operations, enabling the use of `exp2` (base-2 exponential) instead of `exp` (natural exponential) in Triton kernels. When enabled, gate values are pre-scaled by `RCP_LN2` (1/ln2) so that `exp2(g * RCP_LN2) == exp(g)`, allowing the compiler to emit faster native exp2 instructions. Affected kernels: chunk_o fwd/bwd, chunk_fwd intra (kkt+solve), wy_fast fwd/bwd, fused_recurrent fwd, and the top-level chunk fwd/bwd dispatch. Co-Authored-By: Claude (claude-opus-4-6) <noreply@anthropic.com>
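The identity the pre-scaling relies on is easy to sanity-check in plain Python (a standalone sketch; `RCP_LN2` here is simply `1/ln 2`, mirroring the constant the kernels use):

```python
import math

RCP_LN2 = 1.0 / math.log(2.0)  # 1/ln(2) ≈ 1.4426950408889634

def decay_exp(g: float) -> float:
    # natural-exponential path: exp(g)
    return math.exp(g)

def decay_exp2(g: float) -> float:
    # base-2 path: pre-scale by RCP_LN2 so that exp2(g * RCP_LN2) == exp(g)
    return 2.0 ** (g * RCP_LN2)

# the two paths agree to floating-point precision for typical gate values
for g in (-0.37, -1.5, 0.25):
    assert abs(decay_exp(g) - decay_exp2(g)) < 1e-12
```

On GPUs this matters because `exp2` maps to a native hardware instruction, while `exp` is usually lowered to the same instruction plus this very multiply; hoisting the `RCP_LN2` scale onto the gate values does the conversion once instead of per-exponential.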
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
Walkthrough: adds a runtime `use_exp2` flag across the gated delta rule kernel operations.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant PyAPI as Python API
    participant Wrapper as Host wrapper
    participant TritonK as Triton Kernel (USE_EXP2)
    participant Recompute as Recompute / WY / Backward kernels
    participant Device as Device tensors
    PyAPI->>Wrapper: call(..., use_exp2=bool)
    Wrapper->>Device: prepare tensors, flags
    Wrapper->>TritonK: launch(kernel, USE_EXP2=use_exp2)
    TritonK->>Device: read/write tensors (use exp or exp2)
    TritonK->>Recompute: when needed, request recompute (USE_EXP2 passed)
    Recompute->>Device: recompute/backward using exp or exp2
    Recompute->>Wrapper: return gradients/results
    Wrapper->>PyAPI: return outputs
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a significant performance enhancement for gated delta rule chunk operations by allowing the use of the base-2 exponential (`exp2`) in place of the natural exponential in Triton kernels.
Code Review
This pull request introduces a use_exp2 flag to various kernel operations across several files, allowing for flexible use of exp or exp2 functions in exponential calculations. Feedback suggests exposing the use_exp2 flag in the chunk_gated_delta_rule function for better configurability and refactoring duplicated code in fused_recurrent.py to improve readability and maintainability.
```python
cu_seqlens: torch.LongTensor | None = None,
cp_context: FLACPContext | None = None,
chunk_indices: torch.LongTensor | None = None,
use_exp2: bool = True,
```
The use_exp2 flag is introduced here with a default value of True, but it's not configurable from the public-facing chunk_gated_delta_rule function. This means exp2 will always be used for this operation. To make this configurable, you should consider adding use_exp2 to the signature of chunk_gated_delta_rule and ChunkGatedDeltaRuleFunction, and then pass it down through the function calls. This would align its behavior with other functions in this PR like fused_recurrent_gated_delta_rule.
```diff
-if TRANSPOSE_STATE:
-    b_h *= exp(b_gk[None, :])
-else:
-    b_h *= exp(b_gk[:, None])
+if USE_EXP2:
+    if TRANSPOSE_STATE:
+        b_h *= exp2(b_gk[None, :])
+    else:
+        b_h *= exp2(b_gk[:, None])
+else:
+    if TRANSPOSE_STATE:
+        b_h *= exp(b_gk[None, :])
+    else:
+        b_h *= exp(b_gk[:, None])
```
The nested if/else statements for USE_EXP2 and TRANSPOSE_STATE lead to code duplication. You can refactor this to improve readability and maintainability without impacting performance by first determining the value to be exponentiated based on TRANSPOSE_STATE.
```python
if TRANSPOSE_STATE:
    g_val = b_gk[None, :]
else:
    g_val = b_gk[:, None]
if USE_EXP2:
    b_h *= exp2(g_val)
else:
    b_h *= exp(g_val)
```

```diff
-if TRANSPOSE_STATE:
-    b_h *= exp(b_gv[:, None])
-else:
-    b_h *= exp(b_gv[None, :])
+if USE_EXP2:
+    if TRANSPOSE_STATE:
+        b_h *= exp2(b_gv[:, None])
+    else:
+        b_h *= exp2(b_gv[None, :])
+else:
+    if TRANSPOSE_STATE:
+        b_h *= exp(b_gv[:, None])
+    else:
+        b_h *= exp(b_gv[None, :])
```
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
fla/ops/gated_delta_rule/chunk.py (1)
25-38: ⚠️ Potential issue | 🟠 Major — The new flag never reaches the public autograd path.

`chunk_gated_delta_rule_fwd()` and `chunk_gated_delta_rule_bwd()` now take `use_exp2`, but `ChunkGatedDeltaRuleFunction` and `chunk_gated_delta_rule(...)` still never accept, save, or forward it. The public API therefore always uses the hard-coded defaults, and `use_exp2=` passed via `**kwargs` is silently ignored.

💡 Suggested fix

```diff
 class ChunkGatedDeltaRuleFunction(torch.autograd.Function):

     @staticmethod
     @input_guard
     @autocast_custom_fwd
@@
         cu_seqlens_cpu: torch.LongTensor | None = None,
         use_qk_l2norm_in_kernel: bool = False,
         cp_context: FLACPContext | None = None,
+        use_exp2: bool = True,
         transpose_state_layout: bool = False,
     ):
@@
         g, o, A, final_state, initial_state = chunk_gated_delta_rule_fwd(
             q=q, k=k, v=v, g=g, beta=beta, scale=scale,
             initial_state=initial_state, output_final_state=output_final_state,
             cu_seqlens=cu_seqlens, cp_context=cp_context, chunk_indices=chunk_indices,
+            use_exp2=use_exp2,
             transpose_state_layout=transpose_state_layout,
         )
@@
         ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
         ctx.cp_context = cp_context
+        ctx.use_exp2 = use_exp2
         ctx.transpose_state_layout = transpose_state_layout
         return o.to(q.dtype), final_state
@@
         dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
             q=q, k=k, v=v, g=g, beta=beta, A=A, scale=ctx.scale,
             initial_state=initial_state, do=do, dht=dht,
             cu_seqlens=cu_seqlens, cp_context=ctx.cp_context, chunk_indices=chunk_indices,
+            use_exp2=ctx.use_exp2,
             transpose_state_layout=ctx.transpose_state_layout,
         )
@@
-        return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None, None, None
+        return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None, None, None, None


 def chunk_gated_delta_rule(
@@
     cu_seqlens: torch.LongTensor | None = None,
     cu_seqlens_cpu: torch.LongTensor | None = None,
     cp_context: FLACPContext | None = None,
+    use_exp2: bool = True,
     transpose_state_layout: bool = False,
     **kwargs,
 ):
@@
     o, final_state = ChunkGatedDeltaRuleFunction.apply(
         q, k, v, g, beta, scale, initial_state, output_final_state,
         cu_seqlens, cu_seqlens_cpu, use_qk_l2norm_in_kernel, cp_context,
+        use_exp2,
         transpose_state_layout,
     )
```

Also applies to: 103-118
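The ctx-threading pattern this fix relies on can be illustrated with a minimal standalone `torch.autograd.Function` (a sketch only; the class and variable names here are illustrative, not taken from the PR):

```python
import torch

RCP_LN2 = 1.4426950408889634  # 1/ln(2)

class ToyExpFunction(torch.autograd.Function):
    """Minimal sketch of threading a non-tensor flag through autograd:
    stash it on ctx in forward, read it back in backward."""

    @staticmethod
    def forward(ctx, x, use_exp2: bool = True):
        ctx.use_exp2 = use_exp2  # non-tensor flags go on ctx, not save_for_backward
        y = torch.exp2(x * RCP_LN2) if use_exp2 else torch.exp(x)
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_out):
        (y,) = ctx.saved_tensors
        # d/dx exp(x) = exp(x); exp2(x * RCP_LN2) == exp(x), so both paths agree
        grad_x = grad_out * y
        # every forward input needs a gradient slot; the bool flag gets None
        return grad_x, None

x = torch.tensor([0.5, -1.0], requires_grad=True)
y = ToyExpFunction.apply(x, True)
y.sum().backward()
```

Flags stored on `ctx` this way survive into `backward`, which is exactly what the suggested `ctx.use_exp2 = use_exp2` line accomplishes; the extra trailing `None` in the suggested `return` covers the new non-tensor input.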
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 25ca7896-b622-4bff-8235-143ec0af0f96
📒 Files selected for processing (6)
fla/ops/common/chunk_o.py
fla/ops/gated_delta_product/chunk.py
fla/ops/gated_delta_rule/chunk.py
fla/ops/gated_delta_rule/chunk_fwd.py
fla/ops/gated_delta_rule/fused_recurrent.py
fla/ops/gated_delta_rule/wy_fast.py
```diff
 if USE_G_GAMMA:
     b_gamma = tl.load(g_gamma + i_h)
     b_g = b_gamma * (tl.arange(0, BT) + 1)
-    b_o = b_o * exp(b_g)[:, None]
-    b_A = b_A * exp(b_g[:, None] - b_g[None, :])
+    if USE_EXP2:
+        b_o = b_o * exp2(b_g)[:, None]
+        b_A = b_A * exp2(b_g[:, None] - b_g[None, :])
+    else:
+        b_o = b_o * exp(b_g)[:, None]
+        b_A = b_A * exp(b_g[:, None] - b_g[None, :])
```
Scale g_gamma before the exp2 branches.
The USE_G path can stay equivalent because fla/ops/gated_delta_rule/chunk.py now pre-scales cumulative g by RCP_LN2 before it reaches this file. g_gamma is generated locally here, though, so the new exp2(...) branches are still consuming raw natural-log decays. With use_exp2=True, every USE_G_GAMMA path here changes the operator in both forward and backward.
💡 Suggested fix
```diff
+from fla.ops.utils.constant import RCP_LN2
 from fla.ops.utils import prepare_chunk_indices
 from fla.ops.utils.op import exp, exp2
```

```diff
 if USE_G_GAMMA:
     b_gamma = tl.load(g_gamma + i_h)
     b_g = b_gamma * (tl.arange(0, BT) + 1)
+    if USE_EXP2:
+        b_g = b_g * RCP_LN2
```

```diff
 if USE_G_GAMMA:
     b_gamma = tl.load(g_gamma + i_h)
     b_g = b_gamma * (tl.arange(0, BT) + 1)
     b_g_last = b_gamma * min(BT, T - i_t * BT)
+    if USE_EXP2:
+        b_g = b_g * RCP_LN2
+        b_g_last = b_g_last * RCP_LN2
```

Apply the same normalization in each kernel here that materializes b_g / b_g_last from g_gamma.
Also applies to: 311-319, 421-427, 512-516
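The numeric effect of skipping the `RCP_LN2` pre-scale can be reproduced in a few lines of plain Python (illustrative values only; `b_g` mirrors the `b_gamma * (tl.arange(0, BT) + 1)` construction above):

```python
import math

RCP_LN2 = 1.0 / math.log(2.0)

gamma, BT = -0.05, 4  # made-up per-head log-decay and chunk size
b_g = [gamma * (t + 1) for t in range(BT)]  # cumulative decays per position

ref = [math.exp(g) for g in b_g]              # what the exp path computes
raw = [2.0 ** g for g in b_g]                 # exp2 on raw natural-log values: wrong
scaled = [2.0 ** (g * RCP_LN2) for g in b_g]  # pre-scaled by RCP_LN2: matches exp

assert all(abs(s - r) < 1e-12 for s, r in zip(scaled, ref))  # exact agreement
assert all(abs(w - r) > 1e-3 for w, r in zip(raw, ref))      # visible error
```

Because `2**g > e**g` for negative `g`, the unscaled variant systematically under-decays, which is the semantic change the comment warns about.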
```diff
 if USE_G:
     b_g = tl.load(p_g).to(tl.float32)
-    b_h *= exp(b_g)
+    if USE_EXP2:
+        b_h *= exp2(b_g)
+    else:
+        b_h *= exp(b_g)

 if USE_GK:
     b_gk = tl.load(p_gk).to(tl.float32)
-    if TRANSPOSE_STATE:
-        b_h *= exp(b_gk[None, :])
-    else:
-        b_h *= exp(b_gk[:, None])
+    if USE_EXP2:
+        if TRANSPOSE_STATE:
+            b_h *= exp2(b_gk[None, :])
+        else:
+            b_h *= exp2(b_gk[:, None])
+    else:
+        if TRANSPOSE_STATE:
+            b_h *= exp(b_gk[None, :])
+        else:
+            b_h *= exp(b_gk[:, None])

 if USE_GV:
     b_gv = tl.load(p_gv).to(tl.float32)
-    if TRANSPOSE_STATE:
-        b_h *= exp(b_gv[:, None])
-    else:
-        b_h *= exp(b_gv[None, :])
+    if USE_EXP2:
+        if TRANSPOSE_STATE:
+            b_h *= exp2(b_gv[:, None])
+        else:
+            b_h *= exp2(b_gv[None, :])
+    else:
+        if TRANSPOSE_STATE:
+            b_h *= exp(b_gv[:, None])
+        else:
+            b_h *= exp(b_gv[None, :])
```
use_exp2=True changes the recurrent decay math.
Unlike the chunked path, this file never applies RCP_LN2 to g, gk, or gv before switching from exp to exp2. As written, use_exp2=True computes base-2 decays from raw natural-log terms, so the new flag changes both the output and the stored final state instead of preserving the existing semantics.
💡 Suggested fix

```diff
+from fla.ops.utils.constant import RCP_LN2
 from fla.ops.utils.op import exp, exp2
```

```diff
 if USE_G:
     b_g = tl.load(p_g).to(tl.float32)
     if USE_EXP2:
-        b_h *= exp2(b_g)
+        b_h *= exp2(b_g * RCP_LN2)
     else:
         b_h *= exp(b_g)
 if USE_GK:
     b_gk = tl.load(p_gk).to(tl.float32)
     if USE_EXP2:
         if TRANSPOSE_STATE:
-            b_h *= exp2(b_gk[None, :])
+            b_h *= exp2(b_gk[None, :] * RCP_LN2)
         else:
-            b_h *= exp2(b_gk[:, None])
+            b_h *= exp2(b_gk[:, None] * RCP_LN2)
     else:
         if TRANSPOSE_STATE:
             b_h *= exp(b_gk[None, :])
         else:
             b_h *= exp(b_gk[:, None])
 if USE_GV:
     b_gv = tl.load(p_gv).to(tl.float32)
     if USE_EXP2:
         if TRANSPOSE_STATE:
-            b_h *= exp2(b_gv[:, None])
+            b_h *= exp2(b_gv[:, None] * RCP_LN2)
         else:
-            b_h *= exp2(b_gv[None, :])
+            b_h *= exp2(b_gv[None, :] * RCP_LN2)
     else:
         if TRANSPOSE_STATE:
             b_h *= exp(b_gv[:, None])
         else:
             b_h *= exp(b_gv[None, :])
```

Also applies to: 174-233, 255-255, 300-300
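The review's point — that `exp2` applied to raw natural-log gates changes the decay math — compounds across recurrent steps, which a standalone sketch makes concrete (the per-step decays below are made up for illustration):

```python
import math

RCP_LN2 = 1.0 / math.log(2.0)

gs = [-0.10, -0.20, -0.05, -0.15]  # illustrative per-step log-decays

h_ref = h_raw = h_scaled = 1.0  # toy scalar recurrent state
for g in gs:
    h_ref *= math.exp(g)               # existing exp path
    h_raw *= 2.0 ** g                  # exp2 without RCP_LN2: drifts every step
    h_scaled *= 2.0 ** (g * RCP_LN2)   # exp2 with pre-scaling: tracks exp

assert abs(h_scaled - h_ref) < 1e-12   # pre-scaled path matches
assert abs(h_raw - h_ref) > 1e-2       # final state has visibly diverged
```

Since every step multiplies in another mis-scaled factor, both the outputs and the stored final state diverge, not just a single gate value.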
Actionable comments posted: 1
🧹 Nitpick comments (3)
fla/ops/comba/fused_recurrent.py (1)
132-132: `use_exp2` parameter is not exposed in the public API.

The `use_exp2` parameter is added to `fused_recurrent_comba_fwd`, but neither `FusedRecurrentCombaFunction.forward` (line 184) nor the public `fused_recurrent_comba` function (line 224) exposes or threads this parameter. This means `use_exp2` will always be `False` when using the public API.

If this is intentional (keeping exp2 as an internal optimization not exposed to users), this is fine. Otherwise, consider threading `use_exp2` through the autograd function and public interface for consistency with how it's handled in `chunk.py`.

Also applies to: 172-172
fla/ops/comba/wy_fast.py (2)
205-208: Inconsistent use of `tl.exp` vs imported `exp`.

Line 208 uses `tl.exp(b_g0)` directly, while the `USE_EXP2` branch at line 206 uses the imported `exp2`. For consistency with other kernels in this PR that use the imported `exp` wrapper, consider using `exp(b_g0)` instead.

♻️ Suggested fix

```diff
 if USE_EXP2:
     b_g0_exp = exp2(b_g0)
 else:
-    b_g0_exp = tl.exp(b_g0)
+    b_g0_exp = exp(b_g0)
```
329-332: Inconsistent use of `tl.exp` vs imported `exp`.

Same issue as in `prepare_wy_repr_bwd_kernel`: line 332 uses `tl.exp` directly instead of the imported `exp` wrapper.

♻️ Suggested fix

```diff
 if USE_EXP2:
     b_g = exp2(tl.load(p_g, boundary_check=(0,)))
 else:
-    b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
+    b_g = exp(tl.load(p_g, boundary_check=(0,)))
```
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/comba/chunk.py`:
- Around line 381-388: Remove the accidental duplicated lines that re-declare
cu_seqlens and cu_seqlens_cpu and the duplicated "return o, final_state" which
cause syntax errors; locate the end of the function that uses cu_seqlens and
cu_seqlens_cpu (the block that currently returns "o, final_state") and delete
the repeated second occurrence so the function ends with a single "return o,
final_state" and no repeated variable lines.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 7a32f469-4e65-4554-ab58-9a7dbe9383df
📒 Files selected for processing (4)
fla/ops/comba/chunk.py
fla/ops/comba/fused_recurrent.py
fla/ops/comba/utils.py
fla/ops/comba/wy_fast.py