Skip to content

[OJA] Integrate Gated OJA Rule#730

Merged
zhiyuan1i merged 8 commits intofla-org:mainfrom
AwesomeSeq:main
Feb 18, 2026
Merged

[OJA] Integrate Gated OJA Rule#730
zhiyuan1i merged 8 commits intofla-org:mainfrom
AwesomeSeq:main

Conversation

@AwesomeSeq
Copy link
Copy Markdown
Contributor

@AwesomeSeq AwesomeSeq commented Jan 27, 2026

from hujiaxi@moonshot.cn

Summary by CodeRabbit

  • New Features

    • Two optimized gated OJA paths exposed for users: a chunked, chunk-size configurable implementation and a fused recurrent implementation for sequential workloads.
    • Triton-accelerated kernels with variable-length sequence support, optional per-tensor L2 normalization, and initial/final state propagation; autotuning and hardware-aware heuristics included.
  • Public API

    • Package now exposes both chunked and fused gated OJA entry points at the top-level.
  • Tests

    • Extensive validation suite with reference implementations covering forward/backward behavior, gradients, varlen cases, and multiple dtypes/configurations.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Jan 27, 2026

Walkthrough

Adds a complete gated OJA-rule implementation: Triton kernels and Python bindings for chunked KKT, hidden-state (h), output (o), WY recompute/prepare, a fused recurrent path, public API exports, and tests covering forward and backward (including varlen sequences).

Changes

Cohort / File(s) Summary
Public API Exports
fla/ops/gated_oja_rule/__init__.py
Exports chunk_gated_oja_rule and fused_recurrent_gated_oja_rule via __all__.
Chunked Python Orchestration
fla/ops/gated_oja_rule/chunk.py
Adds chunked forward/backward implementations (chunk_oja_fwd, chunk_oja_bwd), ChunkOJAFunction autograd wrapper, and top-level chunk_gated_oja_rule with L2-norm, cu_seqlens and input guards.
Triton: hidden-state kernels
fla/ops/gated_oja_rule/chunk_h.py
New Triton kernels/wrappers for hidden-state forward/backward (block-dim64 variants), GV-conditioned scaling, varlen support, state save/restore, and autotuning.
Triton: KKT kernels
fla/ops/gated_oja_rule/chunk_kkt.py
Adds Triton kernels and Python wrappers for chunked scaled-dot KKT forward/backward (supports gated gk, varlen sequences, chunk/grid orchestration).
Triton: output & grads kernels
fla/ops/gated_oja_rule/chunk_o.py
Adds forward/backward Triton kernels and host orchestration for outputs o and gradients (dA, dq/dk, dv, dgv, etc.) with inter/intra-block tiling, varlen support, and autotuning.
Fused recurrent path
fla/ops/gated_oja_rule/fused_recurrent.py
Adds fused_recurrent Triton binding, FusedRecurrentFunction (forward-only), and fused_recurrent_gated_oja_rule API with defaults, validation, and input guards; backward not implemented.
WY recompute / bwd prep
fla/ops/gated_oja_rule/wy_fast.py
Adds recompute_w_u_fwd and prepare_wy_repr_bwd Triton kernels/wrappers for WY representation forward/backward prep (w/u/vg), varlen support, and autotuning.
Tests & references
tests/ops/test_oja.py
Adds reference recurrent/chunk implementations and extensive parameterized tests validating fused and chunk implementations for forward and backward (including varlen), checking outputs and gradients across dtypes.

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant ChunkAPI as chunk_gated_oja_rule
    participant ChunkFunc as ChunkOJAFunction
    participant KKT as chunk_scaled_dot_kkt_fwd
    participant WY as recompute_w_u_fwd
    participant Hkern as chunk_oja_fwd_h
    participant Okern as chunk_oja_fwd_o

    User->>ChunkAPI: q,k,v,gv,beta,...
    ChunkAPI->>ChunkFunc: apply forward
    ChunkFunc->>KKT: compute A
    KKT-->>ChunkFunc: A
    ChunkFunc->>WY: recompute w/u (from k,v,A)
    WY-->>ChunkFunc: w,u,vg
    ChunkFunc->>Hkern: compute h (hidden states)
    Hkern-->>ChunkFunc: h,final_state
    ChunkFunc->>Okern: compute o
    Okern-->>ChunkFunc: o
    ChunkFunc-->>User: o, final_state
Loading
sequenceDiagram
    actor User
    participant UserAPI as fused_recurrent_gated_oja_rule
    participant FusedFunc as FusedRecurrentFunction
    participant FusedKernel as fused_recurrent_oja_fwd_kernel

    User->>UserAPI: q,k,v,gv,beta,initial_state,...
    UserAPI->>FusedFunc: forward (prepare, validate)
    FusedFunc->>FusedKernel: per-timestep fused kernel (update h, o)
    FusedKernel-->>FusedFunc: o, final_state
    FusedFunc-->>User: o, final_state
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested reviewers

  • yzhangcs

Poem

🐰 I hopped through kernels, tiled and neat,

Chunks and gates aligning every beat,
Hidden states snug, gradients on the trail,
A carrot of tests ensures we don't fail.
🥕✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 2.08% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title '[OJA] Integrate Gated OJA Rule' accurately and concisely describes the main change—integration of a gated OJA rule implementation across multiple modules and kernels shown in the changeset.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @AwesomeSeq, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the library's capabilities by introducing a new gated_oja_operator module. It provides both a chunk-based and a fused recurrent implementation of the Gated Oja Rule, designed for efficient processing of sequence data. The new operators are backed by highly optimized Triton kernels and are thoroughly tested to ensure correctness across various scenarios, including variable-length inputs.

Highlights

  • New Gated Oja Rule Operators: Introduced a new module fla/ops/gated_oja_rule containing implementations for the Gated Oja Rule.
  • Chunked Implementation: Added chunk_gated_oja_rule for efficient chunk-wise processing, including forward and backward passes with specialized Triton kernels for hidden states, KKT matrix computations, and output gradients.
  • Fused Recurrent Implementation: Provided fused_recurrent_gated_oja_rule for a fused recurrent approach, though its backward pass is noted as not yet implemented.
  • Triton Kernels: Developed multiple Triton kernels (chunk_oja_fwd_h, chunk_oja_bwd_dhu, chunk_oja_bwd_dvwg_h, chunk_scaled_dot_kkt_fwd, chunk_scaled_dot_kkt_bwd_gk, chunk_oja_fwd_o, chunk_oja_bwd_dA, chunk_oja_bwd_dqk, chunk_oja_bwd_dv_o, recompute_w_u_fwd, prepare_wy_repr_bwd) to optimize performance.
  • Comprehensive Testing: Included extensive unit tests covering both chunked and fused recurrent Oja rules, validating forward and backward passes, and supporting variable-length sequences.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

The pull request introduces gated Oja operator implementations, including chunked and fused recurrent versions, along with corresponding tests. The overall structure is well-organized, separating forward and backward passes into distinct functions and Triton kernels. The addition of comprehensive test cases, including variable-length sequences and backward pass checks, is highly commendable. However, several critical bugs related to indexing and conditional variable usage in Triton kernels have been identified, which need immediate attention to ensure correctness and prevent potential runtime errors.

Comment thread fla/ops/gated_oja_rule/chunk_h.py Outdated
Comment thread fla/ops/gated_oja_rule/chunk_h.py
Comment thread fla/ops/gated_oja_rule/chunk_o.py
Comment thread fla/ops/gated_oja_rule/chunk_o.py
Comment on lines +398 to +445
if USE_GV:
o_v1 = tl.arange(0, 64)
b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
b_dh1 *= exp(b_gv_last1[None, :])
b_do *= exp(b_gv)
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w) # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]

if V > 64:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v2 = 64 + o_v1
b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.)
b_dh2 *= exp(b_gv_last2[None, :])
b_do *= exp(b_gv)
b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)

if V > 128:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v3 = 128 + o_v1
b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.)
b_dh3 *= exp(b_gv_last3[None, :])
b_do *= exp(b_gv)
b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)

if V > 192:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v4 = 192 + o_v1
b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.)
b_dh4 *= exp(b_gv_last4[None, :])
b_do *= exp(b_gv)
b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In the chunk_oja_bwd_kernel_dhu_blockdim64 kernel, the b_dhX variables are multiplied by exp(b_gv_lastX) unconditionally if USE_GV is true. However, b_gv_lastX (for X=2,3,4) are loaded only if V is greater than a certain threshold (e.g., V > 64). If V is smaller, these b_gv_lastX variables might contain uninitialized or garbage values, leading to incorrect calculations. Each multiplication should be guarded by the corresponding if V > ... condition.

Suggested change
if USE_GV:
o_v1 = tl.arange(0, 64)
b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
b_dh1 *= exp(b_gv_last1[None, :])
b_do *= exp(b_gv)
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w) # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]
if V > 64:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v2 = 64 + o_v1
b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.)
b_dh2 *= exp(b_gv_last2[None, :])
b_do *= exp(b_gv)
b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
if V > 128:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v3 = 128 + o_v1
b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.)
b_dh3 *= exp(b_gv_last3[None, :])
b_do *= exp(b_gv)
b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
if V > 192:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v4 = 192 + o_v1
b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.)
b_dh4 *= exp(b_gv_last4[None, :])
b_do *= exp(b_gv)
b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
if USE_GV:
o_v1 = tl.arange(0, 64)
b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
b_dh1 *= exp(b_gv_last1)[None, :]
b_do *= exp(b_gv)
if V > 64 and USE_GV:
o_v2 = 64 + o_v1
b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.)
b_dh2 *= exp(b_gv_last2)[None, :]
b_do *= exp(b_gv)
if V > 128 and USE_GV:
o_v3 = 128 + o_v1
b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.)
b_dh3 *= exp(b_gv_last3)[None, :]
b_do *= exp(b_gv)
if V > 192 and USE_GV:
o_v4 = 192 + o_v1
b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.)
b_dh4 *= exp(b_gv_last4)[None, :]
b_do *= exp(b_gv)

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 10

🤖 Fix all issues with AI agents
In `@fla/ops/gated_oja_rule/chunk_h.py`:
- Line 93: Remove the unused BV variable assignment (BV=64) in chunk_h.py:
delete the BV definition since the kernel uses a hardcoded 64 and BV is never
referenced elsewhere; ensure no other code in the module refers to BV and run
lint/tests to confirm no remaining references (look for the symbol BV in
chunk_h.py to locate the line to remove).
- Around line 148-163: The conditional blocks that apply gv scaling incorrectly
compare against K instead of V; change all occurrences of "if K > 64/128/192" to
"if V > 64/128/192" and update the final load mask from "(o_v4 < K)" to "(o_v4 <
V)" so the gv loads and masks (e.g., in the blocks computing
b_gk_last1..b_gk_last4 and multiplying b_h1..b_h4) correctly use the
value-dimension V rather than the key-dimension K.
- Around line 747-767: The code uses b_dv unconditionally but only defines it
inside the if USE_GV branch; to fix, ensure b_dv is always assigned: when USE_GV
is True compute b_dv = b_dvg * exp(b_gn[None, :] - b_gv) as before, otherwise
initialize b_dv to a zero tensor with the same shape and dtype used later (shape
[BT, BV] matching b_v and p_dv.element_ty) so subsequent operations (b_dgv_last
update, tl.store(p_dv, ...), and interaction with b_v) work correctly; update
the block so b_dv, b_dvg, b_gn, b_gv, p_dv, b_v, and b_dgv_last remain the
referenced symbols.
- Around line 396-403: The code unconditionally creates and loads p_gv and b_gv
(using gv, p_gv, b_gv) inside the V>0 handling even when gv may be None if
USE_GV is False; wrap the creation of p_gv and any tl.load(gv + ...) or
tl.load(p_gv, ...) calls with a guard on USE_GV (same pattern used elsewhere
when gv is offset at line 319) so that all accesses to gv happen only when
USE_GV is True, and apply the same guard pattern to the other V>64, V>128, and
V>192 blocks to prevent null pointer/runtime loads when gv is not provided.
- Around line 531-651: The function chunk_gsa_bwd_k_kernel_dqkvg defined in this
file is dead/duplicated and should be removed: delete the entire
chunk_gsa_bwd_k_kernel_dqkvg(...) definition from
fla/ops/gated_oja_rule/chunk_h.py so the codebase uses the single implementation
in fla/ops/gsa/chunk.py; after removal, run tests and search for any local
references to chunk_gsa_bwd_k_kernel_dqkvg to ensure no callers depend on this
definition and update imports/call sites to reference the gsa implementation if
needed.

In `@fla/ops/gated_oja_rule/chunk_o.py`:
- Around line 8-15: Remove the redundant and unused imports: delete the import
of exp from fla.ops.utils.op and the import of chunk_local_cumsum from
fla.ops.utils.cumsum, keeping the intended tl.exp assignment (exp = tl.exp) as
the single definition of exp; ensure no other code depends on
fla.ops.utils.op.exp or chunk_local_cumsum in this file (references to exp
should use the tl-backed exp symbol).

In `@fla/ops/gated_oja_rule/chunk.py`:
- Around line 1-2: The file header lines have duplicated comment markers ("#
#"), so remove the extra '#' characters in those header comments: change the
leading "# # -*- coding: utf-8 -*-" to "# -*- coding: utf-8 -*-" and similarly
change "# # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang" to "# Copyright (c)
2023-2025, Songlin Yang, Yu Zhang" to restore proper comment syntax.

In `@fla/ops/gated_oja_rule/fused_recurrent.py`:
- Around line 95-97: The load of gv must apply the mask for partial vector
blocks to avoid OOB reads: when USE_GV is true, change the load of p_gv (symbol
b_gv) to use mask_v (the mask for the last V block) instead of an unconditional
tl.load; keep the subsequent scaling of b_h (symbol b_h *= exp(b_gv[None, :]))
the same so that only valid lanes are loaded and used when V % BV != 0.

In `@fla/ops/gated_oja_rule/wy_fast.py`:
- Around line 199-237: Update the return type annotation of recompute_w_u_fwd to
match the actual returned values (w, u, vg): change the declared return from
Tuple[torch.Tensor, torch.Tensor] to Tuple[torch.Tensor, torch.Tensor,
Optional[torch.Tensor]] and import Optional if not already present; ensure the
function signature and any callers/types align with the new signature for
recompute_w_u_fwd.
- Around line 247-261: The gv parameter is declared without Optional typing and
the code unconditionally allocates dgv with torch.empty_like(gv), which will
crash if gv is None; update the function signature to annotate gv as
Optional[torch.Tensor] and change the local dgv to be Optional[torch.Tensor] (or
torch.Tensor | None) and only allocate dgv when gv is not None (e.g., after
checking gv) — leave dgv as None otherwise; ensure any later uses of dgv handle
the None case or assert/raise if those code paths require gv to be present.
🧹 Nitpick comments (12)
fla/ops/gated_oja_rule/chunk_kkt.py (1)

131-143: Inconsistent naming convention for block pointer.

The variable b_kt at line 131 is a block pointer (created via tl.make_block_ptr), but follows the b_ prefix convention used for block tensors throughout this file. Consider renaming to p_kt for consistency with other pointers (p_k, p_g, etc.).

♻️ Suggested naming fix
-        b_kt = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
+        p_kt = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))

And at line 143:

-        b_kt = tl.load(b_kt, boundary_check=(0, 1)) * exp(b_gn[:, None] - b_gk)
+        b_kt = tl.load(p_kt, boundary_check=(0, 1)) * exp(b_gn[:, None] - b_gk)
fla/ops/gated_oja_rule/wy_fast.py (2)

11-11: Remove unused import.

The static analysis correctly identifies that chunk_local_cumsum is imported but not used in this file.

♻️ Proposed fix
-from fla.ops.utils import chunk_local_cumsum, prepare_chunk_indices
+from fla.ops.utils import prepare_chunk_indices

193-193: Remove or complete the commented-out code.

Line 193 contains a commented-out conditional # if USE_GV:. This appears to be either dead code or an incomplete TODO. Please remove it or implement the intended logic.

fla/ops/gated_oja_rule/chunk.py (1)

283-287: Add stacklevel to warnings.warn.

Per best practices, specify stacklevel=2 so the warning points to the caller's location rather than this line.

♻️ Proposed fix
     if 'head_first' in kwargs:
         warnings.warn(
             "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
+            "Please use head_first=False for now instead.",
+            stacklevel=2
         )
fla/ops/gated_oja_rule/fused_recurrent.py (1)

133-133: Document or relax the V <= 128 constraint.

The assertion assert V <= 128 limits the value dimension without explanation. Consider adding a comment explaining why this limit exists, or raising a more informative error.

♻️ Suggested improvement
-    assert V <= 128
+    if V > 128:
+        raise ValueError(
+            f"fused_recurrent_oja_fwd currently supports V <= 128, got V={V}. "
+            "Use chunk_gated_oja_rule for larger value dimensions."
+        )
tests/ops/test_oja.py (4)

4-4: Remove unused imports.

Optional from typing and repeat from einops are imported but never used.

♻️ Proposed fix
-from typing import List, Optional
+from typing import List
-from einops import rearrange, repeat
+from einops import rearrange

82-82: Rename ambiguous variable l.

The variable l at line 82 is flagged by linters as ambiguous (looks like 1). Consider renaming to seq_len or L for clarity.

♻️ Proposed fix
-    b, h, l, d_k = q.shape
+    b, h, seq_len, d_k = q.shape
     d_v = v.shape[-1]
     q = q * scale # B H T D
-    assert l % chunk_size == 0
+    assert seq_len % chunk_size == 0

And update other usages of l (lines 85, 121) to seq_len.


341-341: Remove debug print statement.

The print statement at line 341 appears to be debug output. Consider removing it or using proper logging.

♻️ Proposed fix
-    print('================== Running forward and backward ==================')

412-412: Consider isolating environment variable modification.

Setting os.environ['TRITON_F32_DEFAULT'] at line 412 persists beyond this test and may affect subsequent tests. Consider using a fixture or context manager to ensure cleanup.

♻️ Suggested approach
`@pytest.fixture`(autouse=True)
def set_triton_f32_default():
    old_value = os.environ.get('TRITON_F32_DEFAULT')
    os.environ['TRITON_F32_DEFAULT'] = 'ieee'
    yield
    if old_value is None:
        del os.environ['TRITON_F32_DEFAULT']
    else:
        os.environ['TRITON_F32_DEFAULT'] = old_value
fla/ops/gated_oja_rule/chunk_h.py (2)

480-484: Unused chunk_indices computation.

chunk_indices is computed but never passed to the kernel. Consider removing this unnecessary computation.

Proposed fix
-    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
     if cu_seqlens is None:
         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
     else:
-        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
+        chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
+        N = len(cu_seqlens) - 1
+        NT = chunk_offsets[-1].item()  # or compute directly

386-386: Use ASCII commas in comments for consistency.

The comment contains fullwidth commas (,) which triggers linter warnings. Consider using standard ASCII commas or translating comments to English.

fla/ops/gated_oja_rule/chunk_o.py (1)

453-461: Redundant computation of attention matrix A.

The attention matrix A is computed identically in each i_k block and then summed (line 540), which is wasteful. Since A = dot(q*scale, k.T) is the same regardless of which K-block is being processed, this results in NK redundant computations.

Consider computing A once in a separate kernel or only in the first K-block.

Comment thread fla/ops/gated_oja_rule/chunk_h.py Outdated
Comment thread fla/ops/gated_oja_rule/chunk_h.py
Comment thread fla/ops/gated_oja_rule/chunk_h.py Outdated
Comment thread fla/ops/gated_oja_rule/chunk_h.py Outdated
Comment on lines +747 to +767
if USE_GV:
b_dv = b_dvg * exp(b_gn[None, :] - b_gv)

p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dw = tl.make_block_ptr(dw, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dgv_last = tl.make_block_ptr(dgv_last, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))

b_dgv_last += tl.sum(b_dv * b_v, axis=0)

# 留给GSA2的接口
if HAVE_GK:
dgk += (bos * H + i_h) * V
p_dgk = tl.make_block_ptr(dgk, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_dgk = tl.load(p_dgk, boundary_check=(0, 1))
b_dgv_last = b_dgk + b_dgv_last[None, :]
else:
b_dgv_last = tl.zeros([BT, BV], dtype=tl.float32) + b_dgv_last[None, :]

tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Critical bug: b_dv undefined when USE_GV is False.

When USE_GV is False, b_dv is never assigned (line 748 is inside if USE_GV), but it's used unconditionally at lines 756 and 767. This will cause a runtime error.

Proposed fix
     if USE_GV:
         b_dv = b_dvg * exp(b_gn[None, :] - b_gv)
-    
+    else:
+        b_dv = b_dvg
+
     p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if USE_GV:
b_dv = b_dvg * exp(b_gn[None, :] - b_gv)
p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dw = tl.make_block_ptr(dw, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dgv_last = tl.make_block_ptr(dgv_last, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_dgv_last += tl.sum(b_dv * b_v, axis=0)
# 留给GSA2的接口
if HAVE_GK:
dgk += (bos * H + i_h) * V
p_dgk = tl.make_block_ptr(dgk, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_dgk = tl.load(p_dgk, boundary_check=(0, 1))
b_dgv_last = b_dgk + b_dgv_last[None, :]
else:
b_dgv_last = tl.zeros([BT, BV], dtype=tl.float32) + b_dgv_last[None, :]
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
if USE_GV:
b_dv = b_dvg * exp(b_gn[None, :] - b_gv)
else:
b_dv = b_dvg
p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dw = tl.make_block_ptr(dw, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dgv_last = tl.make_block_ptr(dgv_last, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_dgv_last += tl.sum(b_dv * b_v, axis=0)
# 留给GSA2的接口
if HAVE_GK:
dgk += (bos * H + i_h) * V
p_dgk = tl.make_block_ptr(dgk, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_dgk = tl.load(p_dgk, boundary_check=(0, 1))
b_dgv_last = b_dgk + b_dgv_last[None, :]
else:
b_dgv_last = tl.zeros([BT, BV], dtype=tl.float32) + b_dgv_last[None, :]
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk_h.py` around lines 747 - 767, The code uses b_dv
unconditionally but only defines it inside the if USE_GV branch; to fix, ensure
b_dv is always assigned: when USE_GV is True compute b_dv = b_dvg *
exp(b_gn[None, :] - b_gv) as before, otherwise initialize b_dv to a zero tensor
with the same shape and dtype used later (shape [BT, BV] matching b_v and
p_dv.element_ty) so subsequent operations (b_dgv_last update, tl.store(p_dv,
...), and interaction with b_v) work correctly; update the block so b_dv, b_dvg,
b_gn, b_gv, p_dv, b_v, and b_dgv_last remain the referenced symbols.

Comment thread fla/ops/gated_oja_rule/chunk_o.py Outdated
Comment on lines +8 to +15
from fla.ops.utils.op import exp
from fla.utils import check_shared_mem, is_nvidia_hopper
from fla.ops.utils.cumsum import chunk_local_cumsum

BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]

exp = tl.exp
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Remove redundant import and duplicate exp definition.

exp is imported from fla.ops.utils.op at line 8 but immediately overwritten with tl.exp at line 15. Also, chunk_local_cumsum is imported but never used. Remove the unused imports.

Proposed fix
 from fla.ops.utils import prepare_chunk_indices
-from fla.ops.utils.op import exp
 from fla.utils import check_shared_mem, is_nvidia_hopper
-from fla.ops.utils.cumsum import chunk_local_cumsum
 
 BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
 NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
 
 exp = tl.exp
🧰 Tools
🪛 Flake8 (7.3.0)

[error] 10-10: 'fla.ops.utils.cumsum.chunk_local_cumsum' imported but unused

(F401)


[error] 15-15: redefinition of unused 'exp' from line 8

(F811)

🪛 GitHub Actions: lint

[error] 13-13: Ruff: F811 Redefinition of unused 'exp' from line 7.

🪛 Ruff (0.14.14)

15-15: Redefinition of unused exp from line 8: exp redefined here

(F811)

🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk_o.py` around lines 8 - 15, Remove the redundant
and unused imports: delete the import of exp from fla.ops.utils.op and the
import of chunk_local_cumsum from fla.ops.utils.cumsum, keeping the intended
tl.exp assignment (exp = tl.exp) as the single definition of exp; ensure no
other code depends on fla.ops.utils.op.exp or chunk_local_cumsum in this file
(references to exp should use the tl-backed exp symbol).

Comment thread fla/ops/gated_oja_rule/chunk.py Outdated
Comment on lines +1 to +2
# # -*- coding: utf-8 -*-
# # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix comment syntax.

Line 1 has doubled comment markers # # which appears to be a typo.

🐛 Proposed fix
-# # -*- coding: utf-8 -*-
+# -*- coding: utf-8 -*-
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# # -*- coding: utf-8 -*-
# # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# -*- coding: utf-8 -*-
# # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
🧰 Tools
🪛 GitHub Actions: lint

[error] 1-1: Trailing whitespace detected by pre-commit; file was modified.

🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk.py` around lines 1 - 2, The file header lines
have duplicated comment markers ("# #"), so remove the extra '#' characters in
those header comments: change the leading "# # -*- coding: utf-8 -*-" to "# -*-
coding: utf-8 -*-" and similarly change "# # Copyright (c) 2023-2025, Songlin
Yang, Yu Zhang" to "# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang" to
restore proper comment syntax.

Comment thread fla/ops/gated_oja_rule/fused_recurrent.py
Comment thread fla/ops/gated_oja_rule/wy_fast.py
Comment on lines +247 to +261
gv: torch.Tensor = None,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = 64
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
CONST_TILING = 64 if check_shared_mem() else 32
BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)

dk = torch.empty_like(k)
dv = torch.empty_like(v, dtype=torch.float)

dgv = torch.empty_like(gv, dtype=torch.float)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Add type hint and guard for gv parameter.

The gv parameter at line 247 lacks a proper type hint (defaults to None but no Optional annotation). Additionally, dgv is unconditionally allocated at line 261 using torch.empty_like(gv), which will fail if gv is None.

🐛 Proposed fix
 def prepare_wy_repr_bwd(
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     A: torch.Tensor,
     dw: torch.Tensor,
     du: torch.Tensor,
-    gv: torch.Tensor = None,
+    gv: Optional[torch.Tensor] = None,
     cu_seqlens: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    if gv is None:
+        raise ValueError("gv is required for prepare_wy_repr_bwd")
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/wy_fast.py` around lines 247 - 261, The gv parameter
is declared without Optional typing and the code unconditionally allocates dgv
with torch.empty_like(gv), which will crash if gv is None; update the function
signature to annotate gv as Optional[torch.Tensor] and change the local dgv to
be Optional[torch.Tensor] (or torch.Tensor | None) and only allocate dgv when gv
is not None (e.g., after checking gv) — leave dgv as None otherwise; ensure any
later uses of dgv handle the None case or assert/raise if those code paths
require gv to be present.

@zhiyuan1i zhiyuan1i changed the title Upload gated_oja_operator and test [OJA] Integrate Gated OJA Rule Feb 4, 2026
@ntumm120
Copy link
Copy Markdown

ntumm120 commented Feb 9, 2026

Cool work @AwesomeSeq! Have you trained any 340M/1.3B models with this recurrence yet? If so, is there a paper I can refer to?

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🤖 Fix all issues with AI agents
In `@fla/ops/gated_oja_rule/chunk_h.py`:
- Around line 194-238: The return type annotation of chunk_oja_fwd_h is wrong:
the function actually returns three values (h, k_new, final_state) where k_new
and final_state can be None depending on save_new_key and output_final_state;
update the function signature's return type to reflect three elements (e.g.
tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]) and ensure any
callers or tests expecting the old two-tuple are adjusted accordingly; reference
the chunk_oja_fwd_h definition and the variables h, k_new, final_state in your
change.

In `@fla/ops/gated_oja_rule/chunk_o.py`:
- Around line 178-241: The function chunk_oja_fwd_o has a return type mismatch:
its annotation declares four tensors but the implementation returns only A and
o; update the function signature's return annotation to match the actual return
(tuple[torch.Tensor, torch.Tensor]) or modify the body to return the additional
tensors if intended; locate chunk_oja_fwd_o and change the annotated return type
to reflect only A and o (or add the missing tensors to the return) and ensure
callers expect the corrected shape.

In `@fla/ops/gated_oja_rule/chunk.py`:
- Around line 287-297: The error messages use adjacent f-strings that get
implicitly concatenated without a separating space; update the messages in the
cu_seqlens check (where variables q, cu_seqlens, and initial_state are
referenced) to ensure proper spacing — either merge the two f-strings into one
or insert an explicit leading/trailing space or punctuation between them so the
resulting strings read correctly (do the same fix in the analogous checks in
fused_recurrent.py around the initial_state/cu_seqlens validation).

In `@fla/ops/gated_oja_rule/wy_fast.py`:
- Around line 61-79: The kernel unconditionally loads from gv (e.g.,
tl.load(p_gv) and tl.load(gv + ...)) which will crash if gv is None; update the
kernels (recompute_w_u_fwd_kernel and prepare_wy_repr_bwd_kernel) to either
(preferred) add a compile-time/use-time guard like a boolean USE_GV and wrap all
gv loads and STORE_VG-dependent logic (the p_gv/tl.load uses and computing
b_vb/b_vg) behind if USE_GV so the code never dereferences gv when absent, or
alternatively make gv a required parameter in the Python wrappers so callers
cannot pass None; ensure referenced symbols include gv, p_gv, b_gv, b_gn,
STORE_VG, and vg when applying the guard so no tl.load or tl.store touches gv/
vg unless USE_GV is true.

In `@tests/ops/test_oja.py`:
- Around line 404-407: Fix two issues: update the skip reason text and avoid
mutating global env. Change the pytest.skip call (condition using
is_intel_alchemist and D) to use the correct message 'chunk_gated_oja_rule'
instead of 'chunk_gated_delta_rule'; and replace the direct
os.environ['TRITON_F32_DEFAULT'] = 'ieee' side-effect with a test-scoped
environment change (use pytest's monkeypatch to setenv or save and restore the
original value around the test) so TRITON_F32_DEFAULT is not left modified for
other tests.
🧹 Nitpick comments (2)
fla/ops/gated_oja_rule/fused_recurrent.py (1)

117-168: Use explicit T | None for optional parameters.

Several parameters use implicit Optional (PEP 484 violation): scale at line 123, initial_state at line 124. This also applies to FusedRecurrentFunction.forward (line 182) and fused_recurrent_gated_oja_rule (lines 221-222).

Proposed fix (for the wrapper)
-    scale: float = None,
-    initial_state: torch.Tensor = None,
+    scale: float | None = None,
+    initial_state: torch.Tensor | None = None,
tests/ops/test_oja.py (1)

337-337: Remove leftover debug print statement.

Line 337 contains a print(...) that shouldn't be in committed test code. Also, there are leftover # breakpoint() comments at lines 37 and 369.

Proposed fix
-    print('================== Running forward and backward ==================')

Comment thread fla/ops/gated_oja_rule/chunk_h.py
Comment thread fla/ops/gated_oja_rule/chunk_o.py
Comment on lines +287 to +297
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Missing space between concatenated f-strings in error messages.

The two adjacent f-strings at lines 290–291 and 295–296 are implicitly concatenated without a separator, producing messages like "...cu_seqlens.Please flatten...". The same issue exists in fused_recurrent.py` lines 237–239.

🐛 Proposed fix
         if q.shape[0] != 1:
             raise ValueError(
-                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
+                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                 f"Please flatten variable-length inputs before processing."
             )
         if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
             raise ValueError(
-                f"The number of initial states is expected to be equal to the number of input sequences, "
+                f"The number of initial states is expected to be equal to the number of input sequences, "
                 f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
             )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 289-292: Avoid specifying long messages outside the exception class

(TRY003)


[warning] 294-297: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk.py` around lines 287 - 297, The error messages
use adjacent f-strings that get implicitly concatenated without a separating
space; update the messages in the cu_seqlens check (where variables q,
cu_seqlens, and initial_state are referenced) to ensure proper spacing — either
merge the two f-strings into one or insert an explicit leading/trailing space or
punctuation between them so the resulting strings read correctly (do the same
fix in the analogous checks in fused_recurrent.py around the
initial_state/cu_seqlens validation).

Comment on lines +61 to +79
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_w = tl.make_block_ptr(w + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = b_v * b_b[:, None]

p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_gv = tl.load(p_gv, boundary_check=(0, 1))
b_vb *= exp(b_gv)
if STORE_VG:
last_idx = min(i_t * BT + BT, T) - 1

o_v = i_v * BV + tl.arange(0, BV)
m_v = o_v < V
b_gn = tl.load(gv + ((bos + last_idx) * H + i_h) * V + o_v, mask=m_v, other=0.)
b_vg = b_v * exp(b_gn - b_gv)

p_vg = tl.make_block_ptr(vg + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_vg, b_vg.to(p_vg.dtype.element_ty), boundary_check=(0, 1))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

gv is loaded unconditionally in the kernel — will crash if gv is None.

Both recompute_w_u_fwd_kernel (line 67) and prepare_wy_repr_bwd_kernel (line 152) load from gv without guarding on whether gv is actually provided. Although callers in chunk.py always pass a valid gv, the Python wrapper signatures allow gv=None. Either add a USE_GV heuristic guard in the kernels or make gv a required parameter in the wrappers to prevent a latent null-pointer crash.

🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/wy_fast.py` around lines 61 - 79, The kernel
unconditionally loads from gv (e.g., tl.load(p_gv) and tl.load(gv + ...)) which
will crash if gv is None; update the kernels (recompute_w_u_fwd_kernel and
prepare_wy_repr_bwd_kernel) to either (preferred) add a compile-time/use-time
guard like a boolean USE_GV and wrap all gv loads and STORE_VG-dependent logic
(the p_gv/tl.load uses and computing b_vb/b_vg) behind if USE_GV so the code
never dereferences gv when absent, or alternatively make gv a required parameter
in the Python wrappers so callers cannot pass None; ensure referenced symbols
include gv, p_gv, b_gv, b_gn, STORE_VG, and vg when applying the guard so no
tl.load or tl.store touches gv/ vg unless USE_GV is true.

Comment thread tests/ops/test_oja.py
Comment on lines +404 to +407
if is_intel_alchemist and D > 128:
pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128')
torch.manual_seed(42)
os.environ['TRITON_F32_DEFAULT'] = 'ieee'
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Two issues: incorrect skip message and global env var side-effect.

  1. Line 405: The skip reason says chunk_gated_delta_rule but should say chunk_gated_oja_rule.
  2. Line 407: Setting os.environ['TRITON_F32_DEFAULT'] = 'ieee' modifies global process state, which can leak into other tests. Consider scoping this with monkeypatch or restoring the original value in a fixture.
Proposed fix for the skip message
-        pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128')
+        pytest.skip(reason='chunk_gated_oja_rule is not supported on alchemist for D>128')
🤖 Prompt for AI Agents
In `@tests/ops/test_oja.py` around lines 404 - 407, Fix two issues: update the
skip reason text and avoid mutating global env. Change the pytest.skip call
(condition using is_intel_alchemist and D) to use the correct message
'chunk_gated_oja_rule' instead of 'chunk_gated_delta_rule'; and replace the
direct os.environ['TRITON_F32_DEFAULT'] = 'ieee' side-effect with a test-scoped
environment change (use pytest's monkeypatch to setenv or save and restore the
original value around the test) so TRITON_F32_DEFAULT is not left modified for
other tests.

@zhiyuan1i
Copy link
Copy Markdown
Collaborator

Please refrain from merging just yet, waiting for further review and full testing

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (3)
fla/ops/gated_oja_rule/chunk.py (1)

279-282: warnings.warn missing stacklevel=2.

Without stacklevel=2, the warning points to this internal call site rather than the caller's code, making the deprecation message less useful in tracebacks.

♻️ Proposed fix
         warnings.warn(
             "head_first is deprecated and will be removed in a future version. "
             "Please use head_first=False for now instead.",
+            stacklevel=2,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_oja_rule/chunk.py` around lines 279 - 282, The deprecation
warning emitted in chunk.py currently calls warnings.warn without a stacklevel,
so update the warnings.warn invocation in the function/method where "head_first
is deprecated..." is emitted (the warnings.warn call in chunk.py) to include
stacklevel=2; this will make the traceback point to the caller rather than the
internal site—add the stacklevel=2 kwarg to that warnings.warn call while
preserving the existing message and warning category.
fla/ops/gated_oja_rule/chunk_o.py (1)

17-25: Autotune key=['BT'] is too narrow — configs won't re-tune across different K/V sizes.

chunk_oja_fwd_inter, chunk_oja_bwd_kernel_dA, and chunk_oja_bwd_kernel_dqk all use key=['BT']. This means the autotuned BK/BV choice (and warp/stage config) is shared across all calls with the same BT, regardless of the actual K and V dimensions. For example, BK=64 tuned at K=128 is reused at K=60, even though BK=32 might be optimal there. Other kernels in this PR (e.g., recompute_w_u_fwd_kernel) correctly include 'K' and 'V' in the autotune key. Consider expanding the key:

♻️ Proposed fix (representative for chunk_oja_fwd_inter)
-    key=['BT']
+    key=['H', 'K', 'V', 'BT', 'IS_VARLEN']

Also applies to: 247-252, 386-392

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_oja_rule/chunk_o.py` around lines 17 - 25, The autotune key is
too narrow: update the triton.autotune decorators for chunk_oja_fwd_inter,
chunk_oja_bwd_kernel_dA, and chunk_oja_bwd_kernel_dqk (and the other occurrences
noted) to include the input dimension identifiers so tuning varies with K and V;
specifically add 'K' and 'V' to the key list (e.g., key=['BT','K','V'] or
similar) so BK/BV and warp/stage choices are re-tuned when K or V change.
tests/ops/test_oja.py (1)

37-37: Remove leftover # breakpoint() debug comments.

Lines 37 and 369 contain # breakpoint() debug artifacts that should be removed before merge.

Also applies to: 369-369

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/test_oja.py` at line 37, Remove the leftover debug comments
consisting of the literal "# breakpoint()" (two occurrences) from the test file;
locate the occurrences of "# breakpoint()" in test_oja.py and delete those
comment lines so no debug artifacts remain.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/gated_oja_rule/fused_recurrent.py`:
- Around line 143-167: The kernel call unconditionally evaluates beta.ndim in
fused_recurrent_oja_fwd_kernel which crashes when beta is None; before invoking
fused_recurrent_oja_fwd_kernel compute the boolean for IS_BETA_HEADWISE safely
(e.g. set is_beta_headwise = (beta is not None) and (beta.ndim != v.ndim) or
mirror the guard used in fused_recurrent_gated_oja_rule) and pass that variable
to the kernel instead of evaluating beta.ndim inline.
- Line 31: The kernel parameter B is declared as a constexpr in
fused_recurrent_oja_fwd_kernel but never used and is also passed from Python
(B=B); remove the unused parameter declaration from the
fused_recurrent_oja_fwd_kernel signature (delete the "B: tl.constexpr" entry)
and remove the corresponding B=B argument at the Python call site that invokes
fused_recurrent_oja_fwd_kernel; ensure any related references/comments or tests
that expect that parameter are updated accordingly.

In `@tests/ops/test_oja.py`:
- Line 78: Rename the ambiguous variable l from the unpacking b, h, l, d_k =
q.shape to a clearer name like seq_len (e.g., b, h, seq_len, d_k = q.shape) and
update every subsequent reference to l in this test (all uses of l at and after
the q.shape line) to seq_len so variable meaning is clear and avoids confusion
with 1/I; ensure any related asserts or shape computations (the references
originally at lines 81, 117, 131–133) are changed consistently.

---

Duplicate comments:
In `@fla/ops/gated_oja_rule/chunk_h.py`:
- Around line 731-753: The code uses b_dv unconditionally but only assigns it
inside the USE_GV branch, so when USE_GV is false b_dv is undefined; fix by
adding an else branch alongside the existing if USE_GV that initializes b_dv to
a zero tensor matching the expected shape/type (same shape/dtype as b_dvg / the
block [BT, BV]) before it is used in tl.sum and stored via p_dv, ensuring
b_dv.to(...) and subsequent arithmetic with b_v remain valid; update the same
pattern if any dependent variables (e.g., b_dv usage in tl.store(p_dv, ...))
expect a specific dtype.
- Around line 520-639: The function chunk_gsa_bwd_k_kernel_dqkvg is a dead
duplicate (copied from gsa/chunk.py) and should be removed from this module to
avoid duplicate maintenance; delete the entire chunk_gsa_bwd_k_kernel_dqkvg
definition from this file, verify there are no remaining references to
chunk_gsa_bwd_k_kernel_dqkvg in the module (and remove any imports or helper
symbols that become unused as a result), and run tests/linters to ensure no
unintended breakage.
- Around line 384-440: The kernel chunk_oja_bwd_kernel_dhu_blockdim64 currently
constructs p_gv and calls tl.load into b_gv unconditionally in each V-block
(e.g. the p_gv/b_gv in the V>0, V>64, V>128, V>192 branches) which dereferences
gv when USE_GV is False; to fix, move the p_gv = tl.make_block_ptr(...) and b_gv
= tl.load(...) lines inside the corresponding if USE_GV: block for each branch
so gv is only accessed when USE_GV is True, keeping the surrounding b_do/b_dh
updates unchanged and ensuring any use of b_gv (e.g. b_do *= exp(b_gv)) remains
inside that guard.

In `@fla/ops/gated_oja_rule/chunk.py`:
- Around line 289-292: The ValueError message uses two adjacent f-strings that
get implicitly concatenated without a space, producing a malformed message;
update the raises in gated_oja_rule.chunk (the ValueError that mentions
q.shape[0] and `cu_seqlens`, and the analogous one referencing q.shape[0] at the
second occurrence) to use a single f-string or explicitly include a
space/newline between the parts so the final message reads correctly (e.g.,
combine into one f-string: f"...when using `cu_seqlens`. Please flatten
variable-length inputs..." or add " " between the two f-strings).

In `@fla/ops/gated_oja_rule/wy_fast.py`:
- Around line 244-258: The parameter gv is annotated/defined with a default None
but the function (and the kernel that loads gv) uses it unconditionally (e.g.,
torch.empty_like(gv, dtype=torch.float)), so change gv to be required (remove
the default None/Optional typing) and update its type to plain torch.Tensor;
then remove any conditional handling around gv in this function (e.g., the
torch.empty_like(gv, dtype=torch.float) and the kernel references that assume gv
exists will be valid) and update any call sites to always pass a gv tensor.
- Around line 61-79: The kernel unconditionally dereferences gv (p_gv/b_gv and
b_vb *= exp(b_gv)) even when callers pass gv=None; either introduce a boolean
heuristic USE_GV (parallel to STORE_VG) and wrap all gv accesses (construction
of p_gv, tl.load of b_gv, and any uses like b_vb *= exp(b_gv) and the STORE_VG
branch) behind if USE_GV, or make gv a required parameter by removing the None
default in recompute_w_u_fwd and callers so gv is always non-null; update
function signature and call sites if you choose the required-parameter route.

In `@tests/ops/test_oja.py`:
- Around line 404-407: The test mutates global env var
os.environ['TRITON_F32_DEFAULT'] without restoring it, leaking into other tests;
change the test (e.g., inside test_chunk_varlen or the surrounding test in
tests/ops/test_oja.py) to set the env value in a test-scoped way by using the
pytest monkeypatch fixture (monkeypatch.setenv('TRITON_F32_DEFAULT', 'ieee')) or
by saving the original value and restoring it after the test completes, ensuring
the modification is confined to this test only.

---

Nitpick comments:
In `@fla/ops/gated_oja_rule/chunk_o.py`:
- Around line 17-25: The autotune key is too narrow: update the triton.autotune
decorators for chunk_oja_fwd_inter, chunk_oja_bwd_kernel_dA, and
chunk_oja_bwd_kernel_dqk (and the other occurrences noted) to include the input
dimension identifiers so tuning varies with K and V; specifically add 'K' and
'V' to the key list (e.g., key=['BT','K','V'] or similar) so BK/BV and
warp/stage choices are re-tuned when K or V change.

In `@fla/ops/gated_oja_rule/chunk.py`:
- Around line 279-282: The deprecation warning emitted in chunk.py currently
calls warnings.warn without a stacklevel, so update the warnings.warn invocation
in the function/method where "head_first is deprecated..." is emitted (the
warnings.warn call in chunk.py) to include stacklevel=2; this will make the
traceback point to the caller rather than the internal site—add the stacklevel=2
kwarg to that warnings.warn call while preserving the existing message and
warning category.

In `@tests/ops/test_oja.py`:
- Line 37: Remove the leftover debug comments consisting of the literal "#
breakpoint()" (two occurrences) from the test file; locate the occurrences of "#
breakpoint()" in test_oja.py and delete those comment lines so no debug
artifacts remain.

cu_seqlens,
scale,
T,
B: tl.constexpr,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Unused B: tl.constexpr in fused_recurrent_oja_fwd_kernel.

B is declared as a kernel constexpr at line 31 but never referenced in the kernel body. It is also passed at the Python call site (line 155: B=B). Ruff (ARG001) flags this. Remove it from both the kernel signature and the call site.

🐛 Proposed fix
-    B: tl.constexpr,
-        B=B,
🧰 Tools
🪛 Ruff (0.15.1)

[warning] 31-31: Unused function argument: B

(ARG001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_oja_rule/fused_recurrent.py` at line 31, The kernel parameter B
is declared as a constexpr in fused_recurrent_oja_fwd_kernel but never used and
is also passed from Python (B=B); remove the unused parameter declaration from
the fused_recurrent_oja_fwd_kernel signature (delete the "B: tl.constexpr"
entry) and remove the corresponding B=B argument at the Python call site that
invokes fused_recurrent_oja_fwd_kernel; ensure any related references/comments
or tests that expect that parameter are updated accordingly.

Comment on lines +143 to +167
fused_recurrent_oja_fwd_kernel[grid](
q=q,
k=k,
v=v,
gv=gv,
beta=beta,
o=o,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
scale=scale,
T=T,
B=B,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
IS_BETA_HEADWISE=beta.ndim != v.ndim,
USE_Q_L2NORM=use_q_l2norm,
USE_K_L2NORM=use_k_l2norm,
num_warps=num_warps,
num_stages=num_stages,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

fused_recurrent_oja_fwd crashes with AttributeError when beta=None.

Line 162 evaluates beta.ndim != v.ndim unconditionally as part of the kernel-call keyword arguments. Since the function signature declares beta: torch.Tensor | None = None, a direct caller passing beta=None will get AttributeError: 'NoneType' object has no attribute 'ndim' before the kernel is invoked. Guard the expression or fill the default the same way fused_recurrent_gated_oja_rule does:

🐛 Proposed fix
+    if beta is None:
+        beta = torch.ones_like(q[..., 0])
     fused_recurrent_oja_fwd_kernel[grid](
         ...
         IS_BETA_HEADWISE=beta.ndim != v.ndim,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
fused_recurrent_oja_fwd_kernel[grid](
q=q,
k=k,
v=v,
gv=gv,
beta=beta,
o=o,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
scale=scale,
T=T,
B=B,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
IS_BETA_HEADWISE=beta.ndim != v.ndim,
USE_Q_L2NORM=use_q_l2norm,
USE_K_L2NORM=use_k_l2norm,
num_warps=num_warps,
num_stages=num_stages,
)
if beta is None:
beta = torch.ones_like(q[..., 0])
fused_recurrent_oja_fwd_kernel[grid](
q=q,
k=k,
v=v,
gv=gv,
beta=beta,
o=o,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
scale=scale,
T=T,
B=B,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
IS_BETA_HEADWISE=beta.ndim != v.ndim,
USE_Q_L2NORM=use_q_l2norm,
USE_K_L2NORM=use_k_l2norm,
num_warps=num_warps,
num_stages=num_stages,
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_oja_rule/fused_recurrent.py` around lines 143 - 167, The kernel
call unconditionally evaluates beta.ndim in fused_recurrent_oja_fwd_kernel which
crashes when beta is None; before invoking fused_recurrent_oja_fwd_kernel
compute the boolean for IS_BETA_HEADWISE safely (e.g. set is_beta_headwise =
(beta is not None) and (beta.ndim != v.ndim) or mirror the guard used in
fused_recurrent_gated_oja_rule) and pass that variable to the kernel instead of
evaluating beta.ndim inline.

Comment thread tests/ops/test_oja.py
g = F.pad(g, (0, 0, 0, pad_len))
q, k, v, beta, g = map(lambda x: x.to(torch.float32), [q, k, v, beta, g])
chunk_size = BT
b, h, l, d_k = q.shape
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Ambiguous variable name l (Ruff E741).

b, h, l, d_k = q.shape — lowercase l is easily confused with 1 or I. Rename to seq_len or total_len.

🐛 Proposed fix
-    b, h, l, d_k = q.shape
+    b, h, seq_len, d_k = q.shape

Update all subsequent references to l (lines 81, 117, 131–133) accordingly.

🧰 Tools
🪛 Ruff (0.15.1)

[error] 78-78: Ambiguous variable name: l

(E741)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/test_oja.py` at line 78, Rename the ambiguous variable l from the
unpacking b, h, l, d_k = q.shape to a clearer name like seq_len (e.g., b, h,
seq_len, d_k = q.shape) and update every subsequent reference to l in this test
(all uses of l at and after the q.shape line) to seq_len so variable meaning is
clear and avoids confusion with 1/I; ensure any related asserts or shape
computations (the references originally at lines 81, 117, 131–133) are changed
consistently.

@zhiyuan1i zhiyuan1i merged commit 7ef6685 into fla-org:main Feb 18, 2026
4 checks passed
@MzeroMiko
Copy link
Copy Markdown

Cool work @AwesomeSeq! Have you trained any 340M/1.3B models with this recurrence yet? If so, is there a paper I can refer to?

Thanks to @AwesomeSeq for bringing us another great possibility. This leads to the same question: Would Oja's rule be a better alternative to the Delta rule?

To others who want to learn Oja's rule:

Background

Erkki Oja proposed the idea of naturally integrating a forgetting/constraint mechanism into the local update rule. Thus, he introduced the following formula:

$$ \Delta w_i = \eta \cdot (y \cdot x_i - y^2 \cdot w_i) $$

Simple Integration

From the perspective of online learning, the output is actually v and the input is k. So we have the equation:

$$ \mathbf{S}_t - \mathbf{S}_{t-1} = \beta_t (\boldsymbol{v}_t \boldsymbol{k}_t^\top - \boldsymbol{v}_t \boldsymbol{v}_t^\top \mathbf{S}_{t-1} ) $$

With Gate (fast decay with gate + slow decay with oja)

$$ \mathbf{S}_t = \mathbf{S}_{t-1} \text{Diag}(\boldsymbol{\alpha}_t) + \beta_t (\boldsymbol{v}_t \boldsymbol{k}_t^\top - \boldsymbol{v}_t \boldsymbol{v}_t^\top \mathbf{S}_{t-1} \text{Diag}(\boldsymbol{\alpha}_t) ) $$

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants