[OJA] Integrate Gated OJA Rule #730
Conversation
**Walkthrough**

Adds a complete gated OJA-rule implementation: Triton kernels and Python bindings for chunked KKT, hidden-state (h), output (o), WY recompute/prepare, a fused recurrent path, public API exports, and tests covering forward and backward (including varlen sequences).

**Changes**
**Sequence Diagram(s)**

```mermaid
sequenceDiagram
    actor User
    participant ChunkAPI as chunk_gated_oja_rule
    participant ChunkFunc as ChunkOJAFunction
    participant KKT as chunk_scaled_dot_kkt_fwd
    participant WY as recompute_w_u_fwd
    participant Hkern as chunk_oja_fwd_h
    participant Okern as chunk_oja_fwd_o
    User->>ChunkAPI: q,k,v,gv,beta,...
    ChunkAPI->>ChunkFunc: apply forward
    ChunkFunc->>KKT: compute A
    KKT-->>ChunkFunc: A
    ChunkFunc->>WY: recompute w/u (from k,v,A)
    WY-->>ChunkFunc: w,u,vg
    ChunkFunc->>Hkern: compute h (hidden states)
    Hkern-->>ChunkFunc: h,final_state
    ChunkFunc->>Okern: compute o
    Okern-->>ChunkFunc: o
    ChunkFunc-->>User: o, final_state
```

```mermaid
sequenceDiagram
    actor User
    participant UserAPI as fused_recurrent_gated_oja_rule
    participant FusedFunc as FusedRecurrentFunction
    participant FusedKernel as fused_recurrent_oja_fwd_kernel
    User->>UserAPI: q,k,v,gv,beta,initial_state,...
    UserAPI->>FusedFunc: forward (prepare, validate)
    FusedFunc->>FusedKernel: per-timestep fused kernel (update h, o)
    FusedKernel-->>FusedFunc: o, final_state
    FusedFunc-->>User: o, final_state
```
**Estimated code review effort**: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
**Summary of Changes**

Hello @AwesomeSeq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly expands the library's capabilities by introducing a new gated OJA-rule operator.

**Highlights**
Code Review
The pull request introduces gated Oja operator implementations, including chunked and fused recurrent versions, along with corresponding tests. The overall structure is well-organized, separating forward and backward passes into distinct functions and Triton kernels. The addition of comprehensive test cases, including variable-length sequences and backward pass checks, is highly commendable. However, several critical bugs related to indexing and conditional variable usage in Triton kernels have been identified, which need immediate attention to ensure correctness and prevent potential runtime errors.
Referenced code in `chunk_oja_bwd_kernel_dhu_blockdim64` (indentation reconstructed):

```python
if USE_GV:
    o_v1 = tl.arange(0, 64)
    b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
    b_dh1 *= exp(b_gv_last1[None, :])
    b_do *= exp(b_gv)
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)  # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]

if V > 64:
    p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))  # [BT, BV]
    b_w = tl.load(p_w, boundary_check=(0, 1))
    p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))  # [BT, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    if USE_GV:
        o_v2 = 64 + o_v1
        b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.)
        b_dh2 *= exp(b_gv_last2[None, :])
        b_do *= exp(b_gv)
    b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)

if V > 128:
    p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))  # [BT, BV]
    b_w = tl.load(p_w, boundary_check=(0, 1))
    p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))  # [BT, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    if USE_GV:
        o_v3 = 128 + o_v1
        b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.)
        b_dh3 *= exp(b_gv_last3[None, :])
        b_do *= exp(b_gv)
    b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)

if V > 192:
    p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))  # [BT, BV]
    b_w = tl.load(p_w, boundary_check=(0, 1))
    p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))  # [BT, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    if USE_GV:
        o_v4 = 192 + o_v1
        b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.)
        b_dh4 *= exp(b_gv_last4[None, :])
        b_do *= exp(b_gv)
    b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
```
In the chunk_oja_bwd_kernel_dhu_blockdim64 kernel, the b_dhX variables are multiplied by exp(b_gv_lastX) unconditionally if USE_GV is true. However, b_gv_lastX (for X=2,3,4) are loaded only if V is greater than a certain threshold (e.g., V > 64). If V is smaller, these b_gv_lastX variables might contain uninitialized or garbage values, leading to incorrect calculations. Each multiplication should be guarded by the corresponding if V > ... condition.
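The failure mode is easy to demonstrate outside Triton. In plain Python the analogous unguarded pattern fails with an `UnboundLocalError`, whereas a Triton kernel would silently read an uninitialized value. A minimal sketch (names loosely modeled on the kernel's, not the actual kernel code):

```python
import math

def scale_grad_buggy(b_dh2, gv_last2, V):
    # The scale factor is bound only past the threshold...
    if V > 64:
        g = gv_last2
    # ...but applied unconditionally. When V <= 64, `g` was never bound:
    # Python raises UnboundLocalError; Triton would use garbage instead.
    return b_dh2 * math.exp(g)

def scale_grad_fixed(b_dh2, gv_last2, V):
    # Guard the use with the same condition as the load.
    if V > 64:
        return b_dh2 * math.exp(gv_last2)
    return b_dh2

print(scale_grad_fixed(3.0, 0.0, V=32))   # 3.0 (no scaling below the threshold)
print(scale_grad_fixed(3.0, 0.0, V=128))  # 3.0 (exp(0.0) == 1.0)
```

The guarded version is what the suggested fix amounts to: every `b_gv_lastX` use shares the `V > ...` condition of its load.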
Suggested fix (guard each scale with the matching `V` threshold):

```python
if USE_GV:
    o_v1 = tl.arange(0, 64)
    b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
    b_dh1 *= exp(b_gv_last1)[None, :]
    b_do *= exp(b_gv)
if V > 64 and USE_GV:
    o_v2 = 64 + o_v1
    b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.)
    b_dh2 *= exp(b_gv_last2)[None, :]
    b_do *= exp(b_gv)
if V > 128 and USE_GV:
    o_v3 = 128 + o_v1
    b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.)
    b_dh3 *= exp(b_gv_last3)[None, :]
    b_do *= exp(b_gv)
if V > 192 and USE_GV:
    o_v4 = 192 + o_v1
    b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.)
    b_dh4 *= exp(b_gv_last4)[None, :]
    b_do *= exp(b_gv)
```
Actionable comments posted: 10
🤖 Fix all issues with AI agents
In `@fla/ops/gated_oja_rule/chunk_h.py`:
- Line 93: Remove the unused BV variable assignment (BV=64) in chunk_h.py:
delete the BV definition since the kernel uses a hardcoded 64 and BV is never
referenced elsewhere; ensure no other code in the module refers to BV and run
lint/tests to confirm no remaining references (look for the symbol BV in
chunk_h.py to locate the line to remove).
- Around line 148-163: The conditional blocks that apply gv scaling incorrectly
compare against K instead of V; change all occurrences of "if K > 64/128/192" to
"if V > 64/128/192" and update the final load mask from "(o_v4 < K)" to "(o_v4 <
V)" so the gv loads and masks (e.g., in the blocks computing
b_gk_last1..b_gk_last4 and multiplying b_h1..b_h4) correctly use the
value-dimension V rather than the key-dimension K.
- Around line 747-767: The code uses b_dv unconditionally but only defines it
inside the if USE_GV branch; to fix, ensure b_dv is always assigned: when USE_GV
is True compute b_dv = b_dvg * exp(b_gn[None, :] - b_gv) as before, otherwise
initialize b_dv to a zero tensor with the same shape and dtype used later (shape
[BT, BV] matching b_v and p_dv.element_ty) so subsequent operations (b_dgv_last
update, tl.store(p_dv, ...), and interaction with b_v) work correctly; update
the block so b_dv, b_dvg, b_gn, b_gv, p_dv, b_v, and b_dgv_last remain the
referenced symbols.
- Around line 396-403: The code unconditionally creates and loads p_gv and b_gv
(using gv, p_gv, b_gv) inside the V>0 handling even when gv may be None if
USE_GV is False; wrap the creation of p_gv and any tl.load(gv + ...) or
tl.load(p_gv, ...) calls with a guard on USE_GV (same pattern used elsewhere
when gv is offset at line 319) so that all accesses to gv happen only when
USE_GV is True, and apply the same guard pattern to the other V>64, V>128, and
V>192 blocks to prevent null pointer/runtime loads when gv is not provided.
- Around line 531-651: The function chunk_gsa_bwd_k_kernel_dqkvg defined in this
file is dead/duplicated and should be removed: delete the entire
chunk_gsa_bwd_k_kernel_dqkvg(...) definition from
fla/ops/gated_oja_rule/chunk_h.py so the codebase uses the single implementation
in fla/ops/gsa/chunk.py; after removal, run tests and search for any local
references to chunk_gsa_bwd_k_kernel_dqkvg to ensure no callers depend on this
definition and update imports/call sites to reference the gsa implementation if
needed.
In `@fla/ops/gated_oja_rule/chunk_o.py`:
- Around line 8-15: Remove the redundant and unused imports: delete the import
of exp from fla.ops.utils.op and the import of chunk_local_cumsum from
fla.ops.utils.cumsum, keeping the intended tl.exp assignment (exp = tl.exp) as
the single definition of exp; ensure no other code depends on
fla.ops.utils.op.exp or chunk_local_cumsum in this file (references to exp
should use the tl-backed exp symbol).
In `@fla/ops/gated_oja_rule/chunk.py`:
- Around line 1-2: The file header lines have duplicated comment markers ("#
#"), so remove the extra '#' characters in those header comments: change the
leading "# # -*- coding: utf-8 -*-" to "# -*- coding: utf-8 -*-" and similarly
change "# # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang" to "# Copyright (c)
2023-2025, Songlin Yang, Yu Zhang" to restore proper comment syntax.
In `@fla/ops/gated_oja_rule/fused_recurrent.py`:
- Around line 95-97: The load of gv must apply the mask for partial vector
blocks to avoid OOB reads: when USE_GV is true, change the load of p_gv (symbol
b_gv) to use mask_v (the mask for the last V block) instead of an unconditional
tl.load; keep the subsequent scaling of b_h (symbol b_h *= exp(b_gv[None, :]))
the same so that only valid lanes are loaded and used when V % BV != 0.
In `@fla/ops/gated_oja_rule/wy_fast.py`:
- Around line 199-237: Update the return type annotation of recompute_w_u_fwd to
match the actual returned values (w, u, vg): change the declared return from
Tuple[torch.Tensor, torch.Tensor] to Tuple[torch.Tensor, torch.Tensor,
Optional[torch.Tensor]] and import Optional if not already present; ensure the
function signature and any callers/types align with the new signature for
recompute_w_u_fwd.
- Around line 247-261: The gv parameter is declared without Optional typing and
the code unconditionally allocates dgv with torch.empty_like(gv), which will
crash if gv is None; update the function signature to annotate gv as
Optional[torch.Tensor] and change the local dgv to be Optional[torch.Tensor] (or
torch.Tensor | None) and only allocate dgv when gv is not None (e.g., after
checking gv) — leave dgv as None otherwise; ensure any later uses of dgv handle
the None case or assert/raise if those code paths require gv to be present.
🧹 Nitpick comments (12)
fla/ops/gated_oja_rule/chunk_kkt.py (1)
**131-143: Inconsistent naming convention for block pointer.**

The variable `b_kt` at line 131 is a block pointer (created via `tl.make_block_ptr`), but follows the `b_` prefix convention used for block tensors throughout this file. Consider renaming to `p_kt` for consistency with other pointers (`p_k`, `p_g`, etc.).

♻️ Suggested naming fix:

```diff
-    b_kt = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
+    p_kt = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
```

And at line 143:

```diff
-    b_kt = tl.load(b_kt, boundary_check=(0, 1)) * exp(b_gn[:, None] - b_gk)
+    b_kt = tl.load(p_kt, boundary_check=(0, 1)) * exp(b_gn[:, None] - b_gk)
```

fla/ops/gated_oja_rule/wy_fast.py (2)
**11-11: Remove unused import.**

The static analysis correctly identifies that `chunk_local_cumsum` is imported but not used in this file.

♻️ Proposed fix:

```diff
-from fla.ops.utils import chunk_local_cumsum, prepare_chunk_indices
+from fla.ops.utils import prepare_chunk_indices
```
**193-193: Remove or complete the commented-out code.**

Line 193 contains a commented-out conditional `# if USE_GV:`. This appears to be either dead code or an incomplete TODO. Please remove it or implement the intended logic.

fla/ops/gated_oja_rule/chunk.py (1)
**283-287: Add `stacklevel` to `warnings.warn`.**

Per best practices, specify `stacklevel=2` so the warning points to the caller's location rather than this line.

♻️ Proposed fix:

```diff
     if 'head_first' in kwargs:
         warnings.warn(
             "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
+            "Please use head_first=False for now instead.",
+            stacklevel=2
         )
```

fla/ops/gated_oja_rule/fused_recurrent.py (1)
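For reference, the effect of `stacklevel` can be exercised with the standard `warnings` machinery: `stacklevel=2` attributes the warning to the caller's frame, which is what makes a deprecation notice actionable. A standalone sketch (function names made up for illustration):

```python
import warnings

def api_with_deprecated_kwarg(**kwargs):
    if 'head_first' in kwargs:
        warnings.warn(
            "head_first is deprecated and will be removed in a future version.",
            DeprecationWarning,
            stacklevel=2,  # point at our caller's line, not this warn() call
        )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    api_with_deprecated_kwarg(head_first=False)

print(len(caught), caught[0].category.__name__)  # 1 DeprecationWarning
```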
**133-133: Document or relax the V <= 128 constraint.**

The assertion `assert V <= 128` limits the value dimension without explanation. Consider adding a comment explaining why this limit exists, or raising a more informative error.

♻️ Suggested improvement:

```diff
-    assert V <= 128
+    if V > 128:
+        raise ValueError(
+            f"fused_recurrent_oja_fwd currently supports V <= 128, got V={V}. "
+            "Use chunk_gated_oja_rule for larger value dimensions."
+        )
```

tests/ops/test_oja.py (4)
**4-4: Remove unused imports.**

`Optional` from `typing` and `repeat` from `einops` are imported but never used.

♻️ Proposed fix:

```diff
-from typing import List, Optional
+from typing import List
-from einops import rearrange, repeat
+from einops import rearrange
```
**82-82: Rename ambiguous variable `l`.**

The variable `l` at line 82 is flagged by linters as ambiguous (looks like `1`). Consider renaming to `seq_len` or `L` for clarity.

♻️ Proposed fix:

```diff
-    b, h, l, d_k = q.shape
+    b, h, seq_len, d_k = q.shape
     d_v = v.shape[-1]
     q = q * scale  # B H T D
-    assert l % chunk_size == 0
+    assert seq_len % chunk_size == 0
```

And update other usages of `l` (lines 85, 121) to `seq_len`.
**341-341: Remove debug print statement.**

The `print` call at line 341 is leftover debug output and should be deleted.

♻️ Proposed fix:

```diff
-    print('================== Running forward and backward ==================')
```
**412-412: Consider isolating environment variable modification.**

Setting `os.environ['TRITON_F32_DEFAULT']` at line 412 persists beyond this test and may affect subsequent tests. Consider using a fixture or context manager to ensure cleanup.

♻️ Suggested approach:

```python
@pytest.fixture(autouse=True)
def set_triton_f32_default():
    old_value = os.environ.get('TRITON_F32_DEFAULT')
    os.environ['TRITON_F32_DEFAULT'] = 'ieee'
    yield
    if old_value is None:
        del os.environ['TRITON_F32_DEFAULT']
    else:
        os.environ['TRITON_F32_DEFAULT'] = old_value
```

fla/ops/gated_oja_rule/chunk_h.py (2)
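If a fixture is undesirable, a stdlib-only context manager gives the same save-and-restore guarantee. A minimal sketch (the variable name below is hypothetical, chosen to avoid clashing with a real setting):

```python
import os
from contextlib import contextmanager

@contextmanager
def temp_env(name, value):
    """Set an environment variable for the duration of a block, then restore it."""
    old = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = old

with temp_env('_OJA_TEST_F32_DEFAULT', 'ieee'):
    assert os.environ['_OJA_TEST_F32_DEFAULT'] == 'ieee'
# Outside the block, the variable is back to its prior (unset) state.
assert '_OJA_TEST_F32_DEFAULT' not in os.environ
```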
**480-484: Unused `chunk_indices` computation.**

`chunk_indices` is computed but never passed to the kernel. Consider removing this unnecessary computation.

Proposed fix:

```diff
-    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
     if cu_seqlens is None:
         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
     else:
-        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
+        chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
+        N = len(cu_seqlens) - 1
+        NT = chunk_offsets[-1].item()  # or compute directly
```
386-386: Use ASCII commas in comments for consistency.The comment contains fullwidth commas (,) which triggers linter warnings. Consider using standard ASCII commas or translating comments to English.
fla/ops/gated_oja_rule/chunk_o.py (1)
**453-461: Redundant computation of attention matrix A.**

The attention matrix `A` is computed identically in each `i_k` block and then summed (line 540), which is wasteful. Since `A = dot(q*scale, k.T)` is the same regardless of which K-block is being processed, this results in `NK` redundant computations.

Consider computing `A` once in a separate kernel or only in the first K-block.
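The suggested change is a standard loop-invariant hoist. A schematic pure-Python stand-in (not the kernel; `compute_A` is a toy placeholder for the `dot(q*scale, k.T)` tile):

```python
def compute_A(q, k, scale):
    # Toy stand-in for the attention tile; depends only on q, k, scale.
    return [[scale * qi * kj for kj in k] for qi in q]

def output_redundant(q, k, scale, NK):
    # Mirrors the flagged code: A is rebuilt once per K-block.
    calls, acc = 0, 0.0
    for _ in range(NK):
        A = compute_A(q, k, scale)
        calls += 1
        acc += sum(sum(row) for row in A)
    return acc, calls

def output_hoisted(q, k, scale, NK):
    # Invariant computation moved out of the loop: one call instead of NK.
    A = compute_A(q, k, scale)
    return NK * sum(sum(row) for row in A), 1

q, k = [1.0, 2.0], [3.0, 4.0]
print(output_redundant(q, k, 0.5, NK=4))  # (42.0, 4)
print(output_hoisted(q, k, 0.5, NK=4))    # (42.0, 1)
```

Both variants produce the same sum; only the number of `compute_A` evaluations differs.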
Referenced code (comment translated from Chinese):

```python
if USE_GV:
    b_dv = b_dvg * exp(b_gn[None, :] - b_gv)

p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dw = tl.make_block_ptr(dw, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dgv_last = tl.make_block_ptr(dgv_last, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))

b_dgv_last += tl.sum(b_dv * b_v, axis=0)

# Interface reserved for GSA2
if HAVE_GK:
    dgk += (bos * H + i_h) * V
    p_dgk = tl.make_block_ptr(dgk, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_dgk = tl.load(p_dgk, boundary_check=(0, 1))
    b_dgv_last = b_dgk + b_dgv_last[None, :]
else:
    b_dgv_last = tl.zeros([BT, BV], dtype=tl.float32) + b_dgv_last[None, :]

tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
```
Critical bug: b_dv undefined when USE_GV is False.
When USE_GV is False, b_dv is never assigned (line 748 is inside if USE_GV), but it's used unconditionally at lines 756 and 767. This will cause a runtime error.
Proposed fix:

```diff
     if USE_GV:
         b_dv = b_dvg * exp(b_gn[None, :] - b_gv)
-
+    else:
+        b_dv = b_dvg
+
     p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
```
Referenced code in `chunk_o.py`:

```python
from fla.ops.utils.op import exp
from fla.utils import check_shared_mem, is_nvidia_hopper
from fla.ops.utils.cumsum import chunk_local_cumsum

BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]

exp = tl.exp
```
Remove redundant import and duplicate exp definition.
exp is imported from fla.ops.utils.op at line 8 but immediately overwritten with tl.exp at line 15. Also, chunk_local_cumsum is imported but never used. Remove the unused imports.
Proposed fix:

```diff
 from fla.ops.utils import prepare_chunk_indices
-from fla.ops.utils.op import exp
 from fla.utils import check_shared_mem, is_nvidia_hopper
-from fla.ops.utils.cumsum import chunk_local_cumsum

 BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
 NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]

 exp = tl.exp
```

🧰 Tools
🪛 Flake8 (7.3.0)
[error] 10-10: 'fla.ops.utils.cumsum.chunk_local_cumsum' imported but unused
(F401)
[error] 15-15: redefinition of unused 'exp' from line 8
(F811)
🪛 GitHub Actions: lint
[error] 13-13: Ruff: F811 Redefinition of unused 'exp' from line 7.
🪛 Ruff (0.14.14)
15-15: Redefinition of unused exp from line 8: exp redefined here
(F811)
Referenced code:

```python
# # -*- coding: utf-8 -*-
# # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
```
Fix comment syntax.
Line 1 has doubled comment markers # # which appears to be a typo.
🐛 Proposed fix:

```diff
-# # -*- coding: utf-8 -*-
+# -*- coding: utf-8 -*-
```
🧰 Tools
🪛 GitHub Actions: lint
[error] 1-1: Trailing whitespace detected by pre-commit; file was modified.
Referenced code:

```python
    gv: torch.Tensor = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = 64
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    CONST_TILING = 64 if check_shared_mem() else 32
    BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
    BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)

    dk = torch.empty_like(k)
    dv = torch.empty_like(v, dtype=torch.float)

    dgv = torch.empty_like(gv, dtype=torch.float)
```
Add type hint and guard for gv parameter.
The gv parameter at line 247 lacks a proper type hint (defaults to None but no Optional annotation). Additionally, dgv is unconditionally allocated at line 261 using torch.empty_like(gv), which will fail if gv is None.
🐛 Proposed fix:

```diff
 def prepare_wy_repr_bwd(
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     A: torch.Tensor,
     dw: torch.Tensor,
     du: torch.Tensor,
-    gv: torch.Tensor = None,
+    gv: Optional[torch.Tensor] = None,
     cu_seqlens: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    if gv is None:
+        raise ValueError("gv is required for prepare_wy_repr_bwd")
```
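Stripped of the tensor machinery, the pattern being requested here (an `Optional` annotation plus allocation only when the input exists) looks like this; a standalone sketch with hypothetical names and plain lists standing in for tensors:

```python
from typing import List, Optional

def prepare_bwd_buffers(v: List[float], gv: Optional[List[float]] = None) -> Optional[List[float]]:
    # Allocate the gradient buffer only when its source exists;
    # callers that never pass gv get None back instead of a crash.
    dgv: Optional[List[float]] = None
    if gv is not None:
        dgv = [0.0] * len(gv)
    return dgv

print(prepare_bwd_buffers([1.0, 2.0]))                 # None
print(prepare_bwd_buffers([1.0, 2.0], gv=[0.1, 0.2]))  # [0.0, 0.0]
```

Downstream code then either handles the `None` case or raises early, as the alternative fix above does.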
Cool work @AwesomeSeq! Have you trained any 340M/1.3B models with this recurrence yet? If so, is there a paper I can refer to?
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In `@fla/ops/gated_oja_rule/chunk_h.py`:
- Around line 194-238: The return type annotation of chunk_oja_fwd_h is wrong:
the function actually returns three values (h, k_new, final_state) where k_new
and final_state can be None depending on save_new_key and output_final_state;
update the function signature's return type to reflect three elements (e.g.
tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]) and ensure any
callers or tests expecting the old two-tuple are adjusted accordingly; reference
the chunk_oja_fwd_h definition and the variables h, k_new, final_state in your
change.
In `@fla/ops/gated_oja_rule/chunk_o.py`:
- Around line 178-241: The function chunk_oja_fwd_o has a return type mismatch:
its annotation declares four tensors but the implementation returns only A and
o; update the function signature's return annotation to match the actual return
(tuple[torch.Tensor, torch.Tensor]) or modify the body to return the additional
tensors if intended; locate chunk_oja_fwd_o and change the annotated return type
to reflect only A and o (or add the missing tensors to the return) and ensure
callers expect the corrected shape.
In `@fla/ops/gated_oja_rule/chunk.py`:
- Around line 287-297: The error messages use adjacent f-strings that get
implicitly concatenated without a separating space; update the messages in the
cu_seqlens check (where variables q, cu_seqlens, and initial_state are
referenced) to ensure proper spacing — either merge the two f-strings into one
or insert an explicit leading/trailing space or punctuation between them so the
resulting strings read correctly (do the same fix in the analogous checks in
fused_recurrent.py around the initial_state/cu_seqlens validation).
In `@fla/ops/gated_oja_rule/wy_fast.py`:
- Around line 61-79: The kernel unconditionally loads from gv (e.g.,
tl.load(p_gv) and tl.load(gv + ...)) which will crash if gv is None; update the
kernels (recompute_w_u_fwd_kernel and prepare_wy_repr_bwd_kernel) to either
(preferred) add a compile-time/use-time guard like a boolean USE_GV and wrap all
gv loads and STORE_VG-dependent logic (the p_gv/tl.load uses and computing
b_vb/b_vg) behind if USE_GV so the code never dereferences gv when absent, or
alternatively make gv a required parameter in the Python wrappers so callers
cannot pass None; ensure referenced symbols include gv, p_gv, b_gv, b_gn,
STORE_VG, and vg when applying the guard so no tl.load or tl.store touches gv/
vg unless USE_GV is true.
In `@tests/ops/test_oja.py`:
- Around line 404-407: Fix two issues: update the skip reason text and avoid
mutating global env. Change the pytest.skip call (condition using
is_intel_alchemist and D) to use the correct message 'chunk_gated_oja_rule'
instead of 'chunk_gated_delta_rule'; and replace the direct
os.environ['TRITON_F32_DEFAULT'] = 'ieee' side-effect with a test-scoped
environment change (use pytest's monkeypatch to setenv or save and restore the
original value around the test) so TRITON_F32_DEFAULT is not left modified for
other tests.
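For reference, the three-element return annotation suggested above for `chunk_oja_fwd_h` can be sketched in plain Python. `Tensor`, `chunk_fwd_h_sketch`, and its parameters are hypothetical stand-ins for illustration, not the real kernel wrapper:

```python
class Tensor:
    """Hypothetical stand-in for torch.Tensor; only used to illustrate typing."""
    pass

def chunk_fwd_h_sketch(
    save_new_key: bool,
    output_final_state: bool,
) -> "tuple[Tensor, Tensor | None, Tensor | None]":
    # h is always produced; the other two are conditionally None,
    # matching the three-tuple described in the review comment above.
    h = Tensor()
    k_new = Tensor() if save_new_key else None
    final_state = Tensor() if output_final_state else None
    return h, k_new, final_state

h, k_new, final_state = chunk_fwd_h_sketch(save_new_key=False, output_final_state=True)
print(k_new is None, final_state is None)  # True False
```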
🧹 Nitpick comments (2)
fla/ops/gated_oja_rule/fused_recurrent.py (1)
117-168: Use explicit `T | None` for optional parameters. Several parameters use implicit `Optional` (PEP 484 violation): `scale` at line 123, `initial_state` at line 124. This also applies to `FusedRecurrentFunction.forward` (line 182) and `fused_recurrent_gated_oja_rule` (lines 221-222).

Proposed fix (for the wrapper):

```diff
-    scale: float = None,
-    initial_state: torch.Tensor = None,
+    scale: float | None = None,
+    initial_state: torch.Tensor | None = None,
```

tests/ops/test_oja.py (1)

337-337: Remove leftover debug print. Line 337 contains a `print(...)` that shouldn't be in committed test code. Also, there are leftover `# breakpoint()` comments at lines 37 and 369.

Proposed fix:

```diff
-    print('================== Running forward and backward ==================')
```
```python
if cu_seqlens is not None:
    if q.shape[0] != 1:
        raise ValueError(
            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
            f"Please flatten variable-length inputs before processing."
        )
    if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
        raise ValueError(
            f"The number of initial states is expected to be equal to the number of input sequences, "
            f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
        )
```
Missing space between concatenated f-strings in error messages.
The two adjacent f-strings at lines 290–291 and 295–296 are implicitly concatenated without a separator, producing messages like "...cu_seqlens.Please flatten...". The same issue exists in `fused_recurrent.py` lines 237–239.
🐛 Proposed fix
```diff
     if q.shape[0] != 1:
         raise ValueError(
-            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
+            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
             f"Please flatten variable-length inputs before processing."
         )
     if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
         raise ValueError(
             f"The number of initial states is expected to be equal to the number of input sequences, "
             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
         )
```
)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if cu_seqlens is not None:
    if q.shape[0] != 1:
        raise ValueError(
            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
            f"Please flatten variable-length inputs before processing."
        )
    if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
        raise ValueError(
            f"The number of initial states is expected to be equal to the number of input sequences, "
            f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
        )
```
🧰 Tools
🪛 Ruff (0.14.14)
[warning] 289-292: Avoid specifying long messages outside the exception class
(TRY003)
[warning] 294-297: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk.py` around lines 287 - 297, The error messages
use adjacent f-strings that get implicitly concatenated without a separating
space; update the messages in the cu_seqlens check (where variables q,
cu_seqlens, and initial_state are referenced) to ensure proper spacing — either
merge the two f-strings into one or insert an explicit leading/trailing space or
punctuation between them so the resulting strings read correctly (do the same
fix in the analogous checks in fused_recurrent.py around the
initial_state/cu_seqlens validation).
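To see the pitfall in isolation, here is a minimal sketch of implicit string-literal concatenation in plain Python, independent of the project code (`n` is a hypothetical batch size):

```python
n = 4  # hypothetical batch size, standing in for q.shape[0]

# Adjacent string literals are joined at parse time with no separator,
# which is exactly how the ValueError message above loses its space.
broken = (
    f"The batch size is expected to be 1 rather than {n} when using `cu_seqlens`."
    f"Please flatten variable-length inputs before processing."
)
fixed = (
    f"The batch size is expected to be 1 rather than {n} when using `cu_seqlens`. "
    f"Please flatten variable-length inputs before processing."
)
print(broken)  # ...`cu_seqlens`.Please flatten...
print(fixed)   # ...`cu_seqlens`. Please flatten...
```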
```python
for i_v in range(tl.cdiv(V, BV)):
    p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_w = tl.make_block_ptr(w + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_vb = b_v * b_b[:, None]

    p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    b_vb *= exp(b_gv)
    if STORE_VG:
        last_idx = min(i_t * BT + BT, T) - 1
        o_v = i_v * BV + tl.arange(0, BV)
        m_v = o_v < V
        b_gn = tl.load(gv + ((bos + last_idx) * H + i_h) * V + o_v, mask=m_v, other=0.)
        b_vg = b_v * exp(b_gn - b_gv)
        p_vg = tl.make_block_ptr(vg + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        tl.store(p_vg, b_vg.to(p_vg.dtype.element_ty), boundary_check=(0, 1))
```
gv is loaded unconditionally in the kernel — will crash if gv is None.
Both recompute_w_u_fwd_kernel (line 67) and prepare_wy_repr_bwd_kernel (line 152) load from gv without guarding on whether gv is actually provided. Although callers in chunk.py always pass a valid gv, the Python wrapper signatures allow gv=None. Either add a USE_GV heuristic guard in the kernels or make gv a required parameter in the wrappers to prevent a latent null-pointer crash.
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/wy_fast.py` around lines 61 - 79, The kernel
unconditionally loads from gv (e.g., tl.load(p_gv) and tl.load(gv + ...)) which
will crash if gv is None; update the kernels (recompute_w_u_fwd_kernel and
prepare_wy_repr_bwd_kernel) to either (preferred) add a compile-time/use-time
guard like a boolean USE_GV and wrap all gv loads and STORE_VG-dependent logic
(the p_gv/tl.load uses and computing b_vb/b_vg) behind if USE_GV so the code
never dereferences gv when absent, or alternatively make gv a required parameter
in the Python wrappers so callers cannot pass None; ensure referenced symbols
include gv, p_gv, b_gv, b_gn, STORE_VG, and vg when applying the guard so no
tl.load or tl.store touches gv/ vg unless USE_GV is true.
```python
if is_intel_alchemist and D > 128:
    pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128')
torch.manual_seed(42)
os.environ['TRITON_F32_DEFAULT'] = 'ieee'
```
Two issues: incorrect skip message and global env var side-effect.

- Line 405: The skip reason says `chunk_gated_delta_rule` but should say `chunk_gated_oja_rule`.
- Line 407: Setting `os.environ['TRITON_F32_DEFAULT'] = 'ieee'` modifies global process state, which can leak into other tests. Consider scoping this with `monkeypatch` or restoring the original value in a fixture.

Proposed fix for the skip message:

```diff
- pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128')
+ pytest.skip(reason='chunk_gated_oja_rule is not supported on alchemist for D>128')
```

🤖 Prompt for AI Agents
+ pytest.skip(reason='chunk_gated_oja_rule is not supported on alchemist for D>128')🤖 Prompt for AI Agents
In `@tests/ops/test_oja.py` around lines 404 - 407, Fix two issues: update the
skip reason text and avoid mutating global env. Change the pytest.skip call
(condition using is_intel_alchemist and D) to use the correct message
'chunk_gated_oja_rule' instead of 'chunk_gated_delta_rule'; and replace the
direct os.environ['TRITON_F32_DEFAULT'] = 'ieee' side-effect with a test-scoped
environment change (use pytest's monkeypatch to setenv or save and restore the
original value around the test) so TRITON_F32_DEFAULT is not left modified for
other tests.
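The save-and-restore alternative mentioned above can be sketched as follows; the helper name `run_with_env` is hypothetical, and pytest's `monkeypatch.setenv` achieves the same scoping with less code:

```python
import os

# Minimal sketch of a test-scoped env-var change without pytest's monkeypatch:
# save the original value (or note its absence) and restore it afterwards.
def run_with_env(key, value, fn):
    sentinel = object()
    old = os.environ.get(key, sentinel)
    os.environ[key] = value
    try:
        return fn()
    finally:
        if old is sentinel:
            del os.environ[key]
        else:
            os.environ[key] = old

seen = run_with_env('TRITON_F32_DEFAULT', 'ieee',
                    lambda: os.environ['TRITON_F32_DEFAULT'])
print(seen)  # ieee
```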
Please refrain from merging just yet, waiting for further review and full testing |
Actionable comments posted: 3
🧹 Nitpick comments (3)
fla/ops/gated_oja_rule/chunk.py (1)
279-282: `warnings.warn` missing `stacklevel=2`. Without `stacklevel=2`, the warning points to this internal call site rather than the caller's code, making the deprecation message less useful in tracebacks.

♻️ Proposed fix

```diff
     warnings.warn(
         "head_first is deprecated and will be removed in a future version. "
         "Please use head_first=False for now instead.",
+        stacklevel=2,
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_oja_rule/chunk.py` around lines 279 - 282, update the `warnings.warn` invocation that emits the "head_first is deprecated..." message to include `stacklevel=2`, so the traceback points to the caller rather than the internal call site, while preserving the existing message and warning category.

fla/ops/gated_oja_rule/chunk_o.py (1)
17-25: Autotune `key=['BT']` is too narrow — configs won't re-tune across different `K`/`V` sizes.

`chunk_oja_fwd_inter`, `chunk_oja_bwd_kernel_dA`, and `chunk_oja_bwd_kernel_dqk` all use `key=['BT']`. This means the autotuned `BK`/`BV` choice (and warp/stage config) is shared across all calls with the same `BT`, regardless of the actual `K` and `V` dimensions. For example, `BK=64` tuned at `K=128` is reused at `K=60`, even though `BK=32` might be optimal there. Other kernels in this PR (e.g., `recompute_w_u_fwd_kernel`) correctly include `'K'` and `'V'` in the autotune key. Consider expanding the key:

♻️ Proposed fix (representative for chunk_oja_fwd_inter)

```diff
-    key=['BT']
+    key=['H', 'K', 'V', 'BT', 'IS_VARLEN']
```

Also applies to: 247-252, 386-392

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_oja_rule/chunk_o.py` around lines 17 - 25, update the triton.autotune decorators for chunk_oja_fwd_inter, chunk_oja_bwd_kernel_dA, and chunk_oja_bwd_kernel_dqk (and the other occurrences noted) to include the input dimension identifiers so tuning varies with K and V; specifically add 'K' and 'V' to the key list so BK/BV and warp/stage choices are re-tuned when K or V change.

tests/ops/test_oja.py (1)
37-37: Remove leftover `# breakpoint()` debug comments. Lines 37 and 369 contain `# breakpoint()` debug artifacts that should be removed before merge.

Also applies to: 369-369

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/test_oja.py` at line 37, remove the two leftover "# breakpoint()" comment lines (lines 37 and 369) so no debug artifacts remain.
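The `stacklevel=2` behaviour flagged in the first nitpick above can be demonstrated in isolation; the function names here are hypothetical:

```python
import warnings

# Sketch of why stacklevel=2 matters: the reported warning location becomes
# the caller of the deprecated wrapper rather than the wrapper itself.
def deprecated_api():
    warnings.warn("deprecated_api is deprecated", DeprecationWarning, stacklevel=2)

def user_code():
    deprecated_api()  # with stacklevel=2, the warning points at this line

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    user_code()

print(str(caught[0].message))  # deprecated_api is deprecated
```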
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/gated_oja_rule/fused_recurrent.py`:
- Around line 143-167: The kernel call unconditionally evaluates beta.ndim in
fused_recurrent_oja_fwd_kernel which crashes when beta is None; before invoking
fused_recurrent_oja_fwd_kernel compute the boolean for IS_BETA_HEADWISE safely
(e.g. set is_beta_headwise = (beta is not None) and (beta.ndim != v.ndim) or
mirror the guard used in fused_recurrent_gated_oja_rule) and pass that variable
to the kernel instead of evaluating beta.ndim inline.
- Line 31: The kernel parameter B is declared as a constexpr in
fused_recurrent_oja_fwd_kernel but never used and is also passed from Python
(B=B); remove the unused parameter declaration from the
fused_recurrent_oja_fwd_kernel signature (delete the "B: tl.constexpr" entry)
and remove the corresponding B=B argument at the Python call site that invokes
fused_recurrent_oja_fwd_kernel; ensure any related references/comments or tests
that expect that parameter are updated accordingly.
In `@tests/ops/test_oja.py`:
- Line 78: Rename the ambiguous variable l from the unpacking b, h, l, d_k =
q.shape to a clearer name like seq_len (e.g., b, h, seq_len, d_k = q.shape) and
update every subsequent reference to l in this test (all uses of l at and after
the q.shape line) to seq_len so variable meaning is clear and avoids confusion
with 1/I; ensure any related asserts or shape computations (the references
originally at lines 81, 117, 131–133) are changed consistently.
---
Duplicate comments:
In `@fla/ops/gated_oja_rule/chunk_h.py`:
- Around line 731-753: The code uses b_dv unconditionally but only assigns it
inside the USE_GV branch, so when USE_GV is false b_dv is undefined; fix by
adding an else branch alongside the existing if USE_GV that initializes b_dv to
a zero tensor matching the expected shape/type (same shape/dtype as b_dvg / the
block [BT, BV]) before it is used in tl.sum and stored via p_dv, ensuring
b_dv.to(...) and subsequent arithmetic with b_v remain valid; update the same
pattern if any dependent variables (e.g., b_dv usage in tl.store(p_dv, ...))
expect a specific dtype.
- Around line 520-639: The function chunk_gsa_bwd_k_kernel_dqkvg is a dead
duplicate (copied from gsa/chunk.py) and should be removed from this module to
avoid duplicate maintenance; delete the entire chunk_gsa_bwd_k_kernel_dqkvg
definition from this file, verify there are no remaining references to
chunk_gsa_bwd_k_kernel_dqkvg in the module (and remove any imports or helper
symbols that become unused as a result), and run tests/linters to ensure no
unintended breakage.
- Around line 384-440: The kernel chunk_oja_bwd_kernel_dhu_blockdim64 currently
constructs p_gv and calls tl.load into b_gv unconditionally in each V-block
(e.g. the p_gv/b_gv in the V>0, V>64, V>128, V>192 branches) which dereferences
gv when USE_GV is False; to fix, move the p_gv = tl.make_block_ptr(...) and b_gv
= tl.load(...) lines inside the corresponding if USE_GV: block for each branch
so gv is only accessed when USE_GV is True, keeping the surrounding b_do/b_dh
updates unchanged and ensuring any use of b_gv (e.g. b_do *= exp(b_gv)) remains
inside that guard.
In `@fla/ops/gated_oja_rule/chunk.py`:
- Around line 289-292: The ValueError message uses two adjacent f-strings that
get implicitly concatenated without a space, producing a malformed message;
update the raises in gated_oja_rule.chunk (the ValueError that mentions
q.shape[0] and `cu_seqlens`, and the analogous one referencing q.shape[0] at the
second occurrence) to use a single f-string or explicitly include a
space/newline between the parts so the final message reads correctly (e.g.,
combine into one f-string: f"...when using `cu_seqlens`. Please flatten
variable-length inputs..." or add " " between the two f-strings).
In `@fla/ops/gated_oja_rule/wy_fast.py`:
- Around line 244-258: The parameter gv is annotated/defined with a default None
but the function (and the kernel that loads gv) uses it unconditionally (e.g.,
torch.empty_like(gv, dtype=torch.float)), so change gv to be required (remove
the default None/Optional typing) and update its type to plain torch.Tensor;
then remove any conditional handling around gv in this function (e.g., the
torch.empty_like(gv, dtype=torch.float) and the kernel references that assume gv
exists will be valid) and update any call sites to always pass a gv tensor.
- Around line 61-79: The kernel unconditionally dereferences gv (p_gv/b_gv and
b_vb *= exp(b_gv)) even when callers pass gv=None; either introduce a boolean
heuristic USE_GV (parallel to STORE_VG) and wrap all gv accesses (construction
of p_gv, tl.load of b_gv, and any uses like b_vb *= exp(b_gv) and the STORE_VG
branch) behind if USE_GV, or make gv a required parameter by removing the None
default in recompute_w_u_fwd and callers so gv is always non-null; update
function signature and call sites if you choose the required-parameter route.
In `@tests/ops/test_oja.py`:
- Around line 404-407: The test mutates global env var
os.environ['TRITON_F32_DEFAULT'] without restoring it, leaking into other tests;
change the test (e.g., inside test_chunk_varlen or the surrounding test in
tests/ops/test_oja.py) to set the env value in a test-scoped way by using the
pytest monkeypatch fixture (monkeypatch.setenv('TRITON_F32_DEFAULT', 'ieee')) or
by saving the original value and restoring it after the test completes, ensuring
the modification is confined to this test only.
---
Nitpick comments:
In `@fla/ops/gated_oja_rule/chunk_o.py`:
- Around line 17-25: The autotune key is too narrow: update the triton.autotune
decorators for chunk_oja_fwd_inter, chunk_oja_bwd_kernel_dA, and
chunk_oja_bwd_kernel_dqk (and the other occurrences noted) to include the input
dimension identifiers so tuning varies with K and V; specifically add 'K' and
'V' to the key list (e.g., key=['BT','K','V'] or similar) so BK/BV and
warp/stage choices are re-tuned when K or V change.
In `@fla/ops/gated_oja_rule/chunk.py`:
- Around line 279-282: The deprecation warning emitted in chunk.py currently
calls warnings.warn without a stacklevel, so update the warnings.warn invocation
in the function/method where "head_first is deprecated..." is emitted (the
warnings.warn call in chunk.py) to include stacklevel=2; this will make the
traceback point to the caller rather than the internal site—add the stacklevel=2
kwarg to that warnings.warn call while preserving the existing message and
warning category.
In `@tests/ops/test_oja.py`:
- Line 37: Remove the leftover debug comments consisting of the literal "#
breakpoint()" (two occurrences) from the test file; locate the occurrences of "#
breakpoint()" in test_oja.py and delete those comment lines so no debug
artifacts remain.
```python
    cu_seqlens,
    scale,
    T,
    B: tl.constexpr,
```
Unused B: tl.constexpr in fused_recurrent_oja_fwd_kernel.
B is declared as a kernel constexpr at line 31 but never referenced in the kernel body. It is also passed at the Python call site (line 155: B=B). Ruff (ARG001) flags this. Remove it from both the kernel signature and the call site.
🐛 Proposed fix

```diff
-    B: tl.constexpr,
```

```diff
-        B=B,
```

🧰 Tools
🪛 Ruff (0.15.1)
[warning] 31-31: Unused function argument: B
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/gated_oja_rule/fused_recurrent.py` at line 31, The kernel parameter B
is declared as a constexpr in fused_recurrent_oja_fwd_kernel but never used and
is also passed from Python (B=B); remove the unused parameter declaration from
the fused_recurrent_oja_fwd_kernel signature (delete the "B: tl.constexpr"
entry) and remove the corresponding B=B argument at the Python call site that
invokes fused_recurrent_oja_fwd_kernel; ensure any related references/comments
or tests that expect that parameter are updated accordingly.
```python
fused_recurrent_oja_fwd_kernel[grid](
    q=q,
    k=k,
    v=v,
    gv=gv,
    beta=beta,
    o=o,
    h0=initial_state,
    ht=final_state,
    cu_seqlens=cu_seqlens,
    scale=scale,
    T=T,
    B=B,
    H=H,
    HV=HV,
    K=K,
    V=V,
    BK=BK,
    BV=BV,
    IS_BETA_HEADWISE=beta.ndim != v.ndim,
    USE_Q_L2NORM=use_q_l2norm,
    USE_K_L2NORM=use_k_l2norm,
    num_warps=num_warps,
    num_stages=num_stages,
)
```
fused_recurrent_oja_fwd crashes with AttributeError when beta=None.
Line 162 evaluates beta.ndim != v.ndim unconditionally as part of the kernel-call keyword arguments. Since the function signature declares beta: torch.Tensor | None = None, a direct caller passing beta=None will get AttributeError: 'NoneType' object has no attribute 'ndim' before the kernel is invoked. Guard the expression or fill the default the same way fused_recurrent_gated_oja_rule does:
🐛 Proposed fix
```diff
+    if beta is None:
+        beta = torch.ones_like(q[..., 0])
     fused_recurrent_oja_fwd_kernel[grid](
         ...
         IS_BETA_HEADWISE=beta.ndim != v.ndim,
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if beta is None:
    beta = torch.ones_like(q[..., 0])
fused_recurrent_oja_fwd_kernel[grid](
    q=q,
    k=k,
    v=v,
    gv=gv,
    beta=beta,
    o=o,
    h0=initial_state,
    ht=final_state,
    cu_seqlens=cu_seqlens,
    scale=scale,
    T=T,
    B=B,
    H=H,
    HV=HV,
    K=K,
    V=V,
    BK=BK,
    BV=BV,
    IS_BETA_HEADWISE=beta.ndim != v.ndim,
    USE_Q_L2NORM=use_q_l2norm,
    USE_K_L2NORM=use_k_l2norm,
    num_warps=num_warps,
    num_stages=num_stages,
)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/gated_oja_rule/fused_recurrent.py` around lines 143 - 167, The kernel
call unconditionally evaluates beta.ndim in fused_recurrent_oja_fwd_kernel which
crashes when beta is None; before invoking fused_recurrent_oja_fwd_kernel
compute the boolean for IS_BETA_HEADWISE safely (e.g. set is_beta_headwise =
(beta is not None) and (beta.ndim != v.ndim) or mirror the guard used in
fused_recurrent_gated_oja_rule) and pass that variable to the kernel instead of
evaluating beta.ndim inline.
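The None-safe flag computation suggested above can be sketched without any GPU code; `FakeTensor` is a hypothetical stand-in exposing only `.ndim`, and the real code would use torch tensors:

```python
# Sketch of evaluating the headwise-beta flag without dereferencing None,
# mirroring the guard suggested for fused_recurrent_oja_fwd.
class FakeTensor:
    def __init__(self, ndim):
        self.ndim = ndim

def beta_flag(beta, v):
    # Short-circuits before touching beta.ndim when beta is None.
    return beta is not None and beta.ndim != v.ndim

v = FakeTensor(ndim=4)
print(beta_flag(None, v))           # False (no AttributeError)
print(beta_flag(FakeTensor(3), v))  # True
```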
```python
g = F.pad(g, (0, 0, 0, pad_len))
q, k, v, beta, g = map(lambda x: x.to(torch.float32), [q, k, v, beta, g])
chunk_size = BT
b, h, l, d_k = q.shape
```
Ambiguous variable name l (Ruff E741).
`b, h, l, d_k = q.shape` — lowercase `l` is easily confused with `1` or `I`. Rename to `seq_len` or `total_len`.

🐛 Proposed fix

```diff
-    b, h, l, d_k = q.shape
+    b, h, seq_len, d_k = q.shape
```

Update all subsequent references to `l` (lines 81, 117, 131–133) accordingly.
🧰 Tools
🪛 Ruff (0.15.1)
[error] 78-78: Ambiguous variable name: l
(E741)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/ops/test_oja.py` at line 78, Rename the ambiguous variable l from the
unpacking b, h, l, d_k = q.shape to a clearer name like seq_len (e.g., b, h,
seq_len, d_k = q.shape) and update every subsequent reference to l in this test
(all uses of l at and after the q.shape line) to seq_len so variable meaning is
clear and avoids confusion with 1/I; ensure any related asserts or shape
computations (the references originally at lines 81, 117, 131–133) are changed
consistently.
Thanks to @AwesomeSeq for bringing us another great possibility. This leads to the same question: Would Oja's rule be a better alternative to the Delta rule?

To others who want to learn Oja's rule:

Background

Erkki Oja proposed the idea of naturally integrating a forgetting/constraint mechanism into the local update rule. Thus, he introduced the classic Oja update

$$w_{t+1} = w_t + \eta\, y_t \,(x_t - y_t w_t), \qquad y_t = w_t^\top x_t$$

Simple Integration

From the perspective of online learning, the output is actually

With Gate (fast decay with gate + slow decay with oja)
from hujiaxi@moonshot.cn
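For readers new to Oja's rule, its stabilising effect can be seen in a tiny plain-Python sketch (illustrative only, unrelated to the PR's Triton kernels): feeding a fixed input direction drives `w` toward that direction with unit norm, which is the built-in forgetting/constraint mechanism mentioned above.

```python
import math

# Plain-Python sketch of the classic Oja update:
#   w <- w + eta * y * (x - y * w),  with y = <w, x>.
def oja_step(w, x, eta=0.1):
    y = sum(wi * xi for wi, xi in zip(w, x))
    return [wi + eta * y * (xi - y * wi) for wi, xi in zip(w, x)]

w = [0.3, 0.3]
for _ in range(200):  # feed the same input direction repeatedly
    w = oja_step(w, [1.0, 0.0])

norm = math.sqrt(sum(wi * wi for wi in w))
print([round(wi, 3) for wi in w], round(norm, 3))  # [1.0, 0.0] 1.0
```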
Summary by CodeRabbit
New Features
Public API
Tests