
Upload Gated OJA Operator#725

Closed
AwesomeSeq wants to merge 2 commits into fla-org:main from AwesomeSeq:main

Conversation

Contributor

@AwesomeSeq AwesomeSeq commented Jan 22, 2026

Summary by CodeRabbit

Release Notes

  • New Features
    • Added OJA2 attention operations with chunked and recurrent variants, including support for gradient computation and variable-length sequences.


Contributor

coderabbitai bot commented Jan 22, 2026

Walkthrough

This PR introduces a complete chunked OJA2 (contrastive Hebbian-like) recurrent attention implementation across seven files. It provides forward and backward passes via Triton kernels, PyTorch autograd integration, optional L2 normalization, and support for variable-length sequences.

Changes

Package Exports — fla/ops/oja2/__init__.py
Exposes chunk_oja2 and fused_recurrent_oja2 as public API entry points via an __all__ declaration.

High-Level Orchestration — fla/ops/oja2/chunk.py
Implements forward (chunk_oja2_fwd) and backward (chunk_oja2_bwd) passes with PyTorch autograd integration. Includes the ChunkOJA2Function autograd wrapper and the public chunk_oja2 function, which handles scale defaults, L2-norm options, cu_seqlens validation, and deprecation warnings.

Core Kernel Implementations — fla/ops/oja2/chunk_h.py, fla/ops/oja2/chunk_kkt.py, fla/ops/oja2/chunk_o.py
Three complementary Triton kernel modules: chunk_h.py for recurrent hidden-state updates with optional GV gating and state handling; chunk_kkt.py for scaled dot-product attention-like (K·K^T) computation with optional variable-length sequence support; chunk_o.py for output generation and dA computation via forward/backward passes.

Fused Recurrent Path — fla/ops/oja2/fused_recurrent.py
Implements a single-pass fused recurrent forward pathway with the FusedRecurrentFunction autograd wrapper (backward unimplemented). Includes a Triton kernel supporting variable-length sequences, optional initial/final state, per-head vs. per-tensor beta broadcasting, and row-wise L2 normalization.

WY Representation Helpers — fla/ops/oja2/wy_fast.py
Provides recompute_w_u_fwd and prepare_wy_repr_bwd kernels for W/U representation recomputation and backward gradient preparation, with optional GV and variable-length sequence support.

Sequence Diagram(s)

sequenceDiagram
    participant User as User Code
    participant Chunk as chunk_oja2<br/>(High-Level)
    participant ChunkKKT as chunk_scaled_dot_kkt_fwd<br/>(KKT Kernel)
    participant ChunkH as chunk_oja2_fwd_h<br/>(H Kernel)
    participant ChunkO as chunk_oja2_fwd_o<br/>(O Kernel)
    participant Output as Autograd Output

    User->>Chunk: q, k, v, gv, beta, scale<br/>initial_state, cu_seqlens
    Chunk->>Chunk: Validate inputs & defaults
    Chunk->>ChunkKKT: Compute A = β·K·K^T<br/>with optional GK
    ChunkKKT-->>Chunk: A (scaled dot product)
    Chunk->>Chunk: Solve triangular system<br/>extract w, u, vg
    Chunk->>ChunkH: Process via h computation<br/>with w, u, vg, gv
    ChunkH-->>Chunk: h, k_new, final_state
    Chunk->>ChunkO: Compute output o & A<br/>from h and inputs
    ChunkO-->>Chunk: o, A
    Chunk->>Output: Save intermediates for backward
    Chunk-->>User: o, final_state
sequenceDiagram
    participant User as User Code
    participant Chunk as chunk_oja2_bwd<br/>(Backward)
    participant DH as chunk_oja2_bwd_dhu<br/>(dH Kernel)
    participant DV as chunk_oja2_bwd_dvwg_h<br/>(dV Kernel)
    participant DA as chunk_oja2_bwd_dA<br/>(dA Kernel)
    participant DQK as chunk_oja2_bwd_dqk<br/>(dQK Kernel)

    User->>Chunk: do, dht, saved intermediates
    Chunk->>Chunk: Recompute w, u, vg
    Chunk->>DH: Compute dh, dh0, dk_new<br/>from do and dht
    DH-->>Chunk: dh, dh0, dk_new
    Chunk->>DV: Compute dv, dw, dgv_last<br/>from dh and dk_new
    DV-->>Chunk: dv, dw, dgv_last
    Chunk->>DA: Compute dA<br/>from do and intermediates
    DA-->>Chunk: dA
    Chunk->>DQK: Compute dq, dk<br/>from dA and dh
    DQK-->>Chunk: dq, dk
    Chunk->>Chunk: Aggregate all gradients
    Chunk-->>User: dq, dk, dv, db, dgv, dh0

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

This review requires careful analysis of: (1) four interdependent Triton kernel modules with complex tiling strategies and auto-tuning configurations, (2) forward/backward mathematical correctness across multiple kernels, (3) PyTorch autograd integration patterns, (4) variable-length sequence handling with cu_seqlens, (5) optional parameter propagation (GV, L2 norm, beta broadcasting), and (6) state management and intermediate tensor saving for backpropagation.

Suggested reviewers

  • sustcsonglin

Poem

🐰 Chunked OJA leaps through recurrence streams,
Kernels fuse in Triton's compute dreams,
H and O dance with K and V so keen,
Gradients flow in backward's graceful scene,
Variable-length sequences bend with care—
A complete recurrent attention, beyond compare!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Docstring Coverage — ⚠️ Warning
Docstring coverage is 2.44%, below the required threshold of 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

Description Check — ✅ Passed
Check skipped: CodeRabbit's high-level summary is enabled.

Title Check — ✅ Passed
The title 'Upload Gated OJA Operator' accurately describes the primary addition: a new gated OJA2 operator implemented across multiple modules with forward and backward passes.



Contributor

Summary of Changes

Hello @AwesomeSeq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a new Gated OJA Operator, providing both chunked and fused recurrent approaches. The core functionality is built upon custom Triton kernels designed for performance, handling the forward and backward passes for the operator's internal state, output, and key-key-transpose matrix computations. The implementation is robust, supporting variable sequence lengths and including comprehensive gradient computations, though the fused recurrent backward pass is explicitly noted as not yet implemented.

Highlights

  • New Gated OJA Operator: Introduced a new Gated OJA Operator with both chunked and fused recurrent implementations for efficient sequence processing.
  • Optimized Triton Kernels: Implemented highly optimized Triton kernels for the forward and backward passes of various components, including state (h), output (o), and the KKT matrix computation.
  • Comprehensive Gradient Calculation: The chunked implementation includes detailed gradient calculations for all relevant tensors, ensuring full backpropagation support.
  • Variable Sequence Length Support: Both chunked and fused recurrent operators now support variable sequence lengths through cu_seqlens.
  • WY Representation Recomputation: Added functionality for recomputing w and u (WY representation) in both forward and backward passes for the OJA operator.
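For reviewers unfamiliar with the convention, cu_seqlens follows the usual packed-varlen layout: sequences of lengths L_i are concatenated along the time axis, and cu_seqlens holds the cumulative offsets [0, L_1, L_1+L_2, ...]. A minimal sketch of that bookkeeping in plain Python (independent of the actual fla API; the `spans` recovery mirrors the bos/eos computation inside the kernels):

```python
import itertools

# Three sequences of lengths 3, 5, 2 packed into one time axis.
seq_lens = [3, 5, 2]

# Cumulative offsets: cu_seqlens[i] is where sequence i starts,
# cu_seqlens[i+1] is where it ends (exclusive).
cu_seqlens = [0] + list(itertools.accumulate(seq_lens))
assert cu_seqlens == [0, 3, 8, 10]

# Recovering each sequence's (begin, end) slice from cu_seqlens,
# analogous to the per-program bos/eos lookups in the varlen kernels:
spans = [(cu_seqlens[i], cu_seqlens[i + 1]) for i in range(len(seq_lens))]
assert spans == [(0, 3), (3, 8), (8, 10)]
```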



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the Gated OJA Operator, including chunked and fused recurrent implementations. The changes involve several new Triton kernels for forward and backward passes, as well as utility functions for handling variable sequence lengths and numerical stability. Overall, the code is well-structured and follows good practices for Triton kernel development, such as using autotune and input_guard.

However, there are several areas that require attention, particularly regarding code clarity, potential correctness issues, and a critical limitation in the fused recurrent operator's backward pass. Please review the specific comments below for detailed feedback.

Comment on lines +207 to +214
@staticmethod
@input_guard
def backward(ctx, do, dht):
    raise NotImplementedError(
        "Backward pass is not implemented yet and we do not have plans to implement it "
        "because we haven't figured out how to compute dg without materializing the full "
        "hidden states for all time steps."
    )

critical

The backward method for FusedRecurrentFunction is not implemented, raising a NotImplementedError. This is a critical limitation as it prevents the use of this operator in training scenarios that require backpropagation. This should be clearly documented in the function's docstring and the pull request description, as it significantly impacts the usability of this fused operator.
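Until a backward exists, callers have to treat the fused path as inference-only and route training through the chunked implementation. A hedged sketch of that dispatch pattern (oja2_forward is a hypothetical helper; chunk_impl/fused_impl stand in for chunk_oja2 and fused_recurrent_oja2):

```python
# Hypothetical dispatch guard: the fused path has no backward, so it is
# only safe when no gradient is required. Names here are illustrative.
def oja2_forward(q, k, v, training, chunk_impl, fused_impl):
    if training:
        # chunked path: full backward support per this PR
        return chunk_impl(q, k, v)
    # fused path: forward-only (backward raises NotImplementedError)
    return fused_impl(q, k, v)

# Verify the routing with recording stand-ins:
chunk_calls, fused_calls = [], []
oja2_forward(1, 2, 3, training=True,
             chunk_impl=lambda *a: chunk_calls.append(a),
             fused_impl=lambda *a: fused_calls.append(a))
oja2_forward(1, 2, 3, training=False,
             chunk_impl=lambda *a: chunk_calls.append(a),
             fused_impl=lambda *a: fused_calls.append(a))
assert chunk_calls == [(1, 2, 3)]
assert fused_calls == [(1, 2, 3)]
```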

Comment thread fla/ops/oja2/chunk_h.py
Comment on lines +531 to +650
def chunk_gsa_bwd_k_kernel_dqkvg(
    q,
    k,
    v,
    h,
    g,
    A,
    do,
    dh,
    dq,
    dk,
    dv,
    dg,
    dgv,
    dA,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    B: tl.constexpr,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        all = T
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
        all = B * T

    o_i = tl.arange(0, BT)
    o_t = min(i_t * BT + BT, T)
    m_s = o_i[:, None] >= o_i[None, :]

    p_q = tl.make_block_ptr(q + (bos*HQ+i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_A = tl.make_block_ptr(A + ((i_k*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    # [BT, BT]
    b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k))
    b_A = tl.where(m_s, b_A, 0.)
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        o_v = i_v * BV + tl.arange(0, BV)
        p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_g = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_gn = g + (bos + o_t - 1) * H*V + i_h * V + o_v
        p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + ((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dgv = tl.make_block_ptr(dgv+((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dh = tl.make_block_ptr(dh + (i_tg * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        m_v = o_v < V

        # [BV,]
        b_gn = tl.load(p_gn, mask=m_v, other=0)
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_gv = exp(b_gn[None, :] - b_g)
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * exp(b_g) * scale).to(b_do.dtype)
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BV]
        b_dg = tl.sum(tl.trans(b_h) * b_dh, 0) * exp(b_gn)

        b_dh = b_dh.to(b_k.dtype)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_k.dtype))
        b_dk += tl.dot((b_v * b_gv).to(b_v.dtype), tl.trans(b_dh))
        # [BT, BV]
        b_dv = tl.dot(b_k, b_dh) * b_gv
        # [BV]
        b_dg += tl.sum(b_dv * b_v, 0)

        if i_k == 0:
            b_dgv = tl.load(p_dg, boundary_check=(0, 1)) + b_dg[None, :]
        else:
            b_dgv = tl.zeros([BT, BV], dtype=tl.float32) + b_dg[None, :]

        tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    p_dA = tl.make_block_ptr(dA + (bos*HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_dq = tl.make_block_ptr(dq + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    # [BT, BT]
    b_dA = tl.load(p_dA, boundary_check=(0, 1))
    # [BT, BK]
    b_dq += tl.dot(b_dA, b_k)
    b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q)

    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

high

The kernel chunk_gsa_bwd_k_kernel_dqkvg is named gsa but is located in the oja2 directory. This suggests a naming inconsistency or that this kernel might be an artifact from another module. Please ensure that all kernels are appropriately named and belong to their respective modules to maintain code clarity and modularity.

Comment thread fla/ops/oja2/chunk_kkt.py
b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
b_A = tl.sum(b_k * b_kt[None, :] * exp(b_g - b_gk[None, :]), 1)
b_A = tl.where(o_i > j, b_A, 0.)

high

The line b_A = tl.where(o_i > j, b_A, 0.) might not correctly enforce strict lower triangularity for the matrix b_A within the loop. o_i is a block of indices, and j is a scalar. For proper element-wise comparison across the matrix, it should likely be o_i[:, None] > j or a similar construct to ensure the masking applies correctly to the matrix dimensions. Please verify this logic for correctness.
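Whether the 1-D form is a bug depends on the shape of b_A at that point: for a (BT,) column vector (as the tl.sum over axis 1 suggests), `o_i > j` is the correct element-wise mask; for a (BT, BT) tile, the 1-D mask silently broadcasts over the wrong axis. A small numpy sketch of the distinction (numpy stands in for Triton's broadcasting semantics here):

```python
import numpy as np

BT = 4
o_i = np.arange(BT)   # row indices within the block
j = 1                 # scalar column index from the loop
b_A = np.ones((BT, BT))

# 1-D mask: when broadcast against a 2-D tile, it masks along the LAST
# axis, i.e. it keeps/zeroes *columns*, not rows.
mask_1d = o_i > j            # shape (BT,)
# Explicit row-wise mask: enforces "row index > j" for every column.
mask_2d = o_i[:, None] > j   # shape (BT, 1)

col_masked = np.where(mask_1d, b_A, 0.0)
row_masked = np.where(mask_2d, b_A, 0.0)

assert col_masked[0, 3] == 1.0   # row 0 survives wherever the column index > j
assert row_masked[0, 3] == 0.0   # row 0 is zeroed entirely under o_i[:, None] > j
```

For a (BT,) vector, both `o_i > j` and a vector-shaped mask coincide, which is why the review hedges with "might not".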

Comment thread fla/ops/oja2/chunk_o.py
# [BC, BV]
b_vg = b_v[None, :] * exp(b_g - b_gv[None, :])
# avoid 0 * inf = inf
b_o += tl.where(o_i[:, None] >= j, b_A[:, None] * b_vg, 0.)

high

The comment avoid 0 * inf = inf highlights a potential numerical stability issue. While tl.where is used to mask the addition, the multiplication b_A[:, None] * b_vg might still produce inf if b_vg contains inf values (e.g., from exp) and b_A is non-zero. This could lead to NaN propagation. Please ensure that inf values are not generated or are handled robustly before this multiplication.
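To make the concern concrete, a small pure-Python sketch (standing in for Triton's element-wise semantics): the tl.where does neutralize the 0·inf → nan case at masked-out positions, which is its stated purpose, but an overflow at an unmasked position still propagates into the output.

```python
import math

# IEEE-754: 0 * inf is nan — the case the kernel's mask guards against.
assert math.isnan(0.0 * math.inf)

# Element-wise select mimicking tl.where: both branches are evaluated
# eagerly before the selection.
def where(mask, a, b):
    return [x if m else y for m, x, y in zip(mask, a, b)]

b_A  = [0.0, 2.0]
b_vg = [math.inf, math.inf]   # e.g. exp() overflow in the gating term
mask = [False, True]

prod = [a * v for a, v in zip(b_A, b_vg)]   # [nan, inf] created before masking
out  = where(mask, prod, [0.0, 0.0])
assert out[0] == 0.0          # masked-out nan is discarded, as intended
assert math.isinf(out[1])     # ...but an inf at an unmasked slot survives
```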

Comment thread fla/ops/oja2/chunk.py
Comment on lines +286 to +323
# === Iterate over all gradients to pinpoint which one contains NaN ===
# Map variable names to their tensors
# grad_tensors = {
#     'dq': dq, 'dk': dk, 'dv': dv, 'db': db,
#     'dg': dg, 'dh0': dh0
# }

# for name, t in grad_tensors.items():
#     if t is not None and torch.isnan(t).any():
#         import os
#         import torch.distributed as dist

#         # Get the rank ID
#         # try:
#         #     rank = dist.get_rank() if dist.is_initialized() else 0
#         # except:
#         #     rank = 0
#         rank = 0

#         base_dir = "/mnt/moonfs/hujiaxi-m2/oja_nan_12"
#         os.makedirs(base_dir, exist_ok=True)

#         # Save path: nan_dump_rank{rank}.pt
#         save_path = os.path.join(base_dir, f"nan_dump_rank{rank}.pt")

#         torch.save({
#             "q": q,
#             "k": k,
#             "v": v,
#             "beta": beta,
#             "gv": gv,
#             "do": do,
#             "cu_seqlens": cu_seqlens,
#             "error_source": name  # also record which variable triggered the error
#         }, save_path)

#         # Raise an explicit error naming the offending variable
#         raise RuntimeError(f"NaN detected in [{name}] on Rank {rank}! Context saved to: {save_path}")

medium

The commented-out debugging code for NaN detection should be removed before merging to keep the codebase clean and prevent accidental re-enabling.
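If the NaN dump is still useful during bring-up, one common alternative to commented-out code is gating it behind an environment variable, so the check can stay in the tree without running on the hot path. A hedged sketch (the flag name and the injected `is_nan` predicate are hypothetical; in the real code the predicate would wrap `torch.isnan(t).any()`):

```python
import os

# Hypothetical flag; enable with FLA_DEBUG_NAN=1.
DEBUG_NAN = os.environ.get("FLA_DEBUG_NAN", "0") == "1"

def check_nan(grads, is_nan):
    """Return the name of the first gradient containing NaN, else None.

    `grads` maps names to tensors; `is_nan` is injected so this sketch
    stays framework-agnostic.
    """
    if not DEBUG_NAN:
        return None
    for name, t in grads.items():
        if t is not None and is_nan(t):
            return name
    return None

# Plain-Python stand-ins for tensors:
DEBUG_NAN = True
bad = check_nan(
    {"dq": [0.0], "dk": [float("nan")]},
    is_nan=lambda t: any(x != x for x in t),  # NaN is the only value != itself
)
assert bad == "dk"
```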

Comment thread fla/ops/oja2/chunk.py
Comment on lines +346 to +350
if 'head_first' in kwargs:
    warnings.warn(
        "head_first is deprecated and will be removed in a future version. "
        "Please use head_first=False for now instead."
    )

medium

The warning message for the deprecated head_first argument is confusing. If head_first is deprecated, suggesting head_first=False "for now instead" implies it might still be used or is a temporary workaround. It would be clearer to either remove the argument entirely if it's no longer supported or provide a clearer migration path if its functionality is replaced.
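A conventional shape for such a deprecation warning uses the DeprecationWarning category, names the replacement behavior, and sets stacklevel so the warning points at the caller. A sketch (the function name and wording are illustrative, not the PR's actual code):

```python
import warnings

def chunk_oja2_stub(**kwargs):
    # Hypothetical rewording: state what replaces the argument and what
    # callers should do, instead of suggesting the deprecated flag "for now".
    if "head_first" in kwargs:
        warnings.warn(
            "`head_first` is deprecated and will be removed in a future "
            "release; inputs are always interpreted as [B, T, H, ...]. "
            "Drop the argument and transpose your tensors if needed.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the calling site
        )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    chunk_oja2_stub(head_first=True)
assert len(caught) == 1
assert issubclass(caught[0].category, DeprecationWarning)
```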

Comment thread fla/ops/oja2/chunk_h.py
gv: Optional[torch.Tensor] = None,
initial_state: Optional[torch.Tensor] = None,
output_final_state: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?

medium

The comment SY: remove this argument and force chunk size 64? indicates an unresolved design decision. If the chunk_size is always intended to be 64, this argument should be removed from the function signature to simplify the API and prevent potential misuse or confusion.

Comment thread fla/ops/oja2/chunk_h.py
dht: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?

medium

Similar to chunk_oja2_fwd_h, the comment SY: remove this argument and force chunk size 64? suggests an unresolved design decision regarding chunk_size. Please clarify if this argument is necessary or if the chunk size should be fixed.

Comment thread fla/ops/oja2/chunk_o.py
Comment on lines +288 to +292
    all = T
    T = eos - bos
else:
    bos, eos = i_b * T, i_b * T + T
    all = B * T

medium

The variable all is conditionally assigned T or B * T based on IS_VARLEN. Using a generic name like all for a variable that changes its meaning and is used in pointer arithmetic can be confusing and error-prone. Consider renaming it to something more descriptive, like total_sequence_elements or batch_time_elements, to improve clarity and prevent potential bugs.

Comment thread fla/ops/oja2/wy_fast.py

b_dA = tl.where(m_A, -b_dA, 0)

# if USE_GV:

medium

The commented-out if USE_GV: suggests that USE_GV might be intended for conditional logic here, but it's currently inactive. Please either remove this commented-out line if it's no longer relevant or properly implement the conditional logic if it's intended to be used.

@AwesomeSeq AwesomeSeq closed this Jan 22, 2026

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 14

🤖 Fix all issues with AI agents
In `@fla/ops/oja2/chunk_h.py`:
- Line 1: Re-run the project's pre-commit hooks/formatter on
fla/ops/oja2/chunk_h.py to remove trailing whitespace and apply EOF/autopep8
formatting fixes; specifically ensure the import line "from typing import
Optional, Tuple" and the file ending are formatted per the repo's style (no
trailing spaces, proper newline at EOF) and commit the updated file so CI lint
passes.
- Line 93: The assignment to the unused variable BV should be removed to satisfy
linting; locate the BV = 64 statement in chunk_h.py (symbol BV) and delete that
line (or remove any unused constant/variable declaration named BV) so no unused
BV symbol remains in the module.
- Around line 148-163: The GV gating uses K instead of V, which can skip loads
or use wrong masks; update the gating and masks so comparisons use V (not K):
change the conditional checks if K > 64 / 128 / 192 to if V > 64 / 128 / 192 and
ensure the tl.load mask arguments for o_v2, o_v3, o_v4 use (o_vX < V) (e.g., the
last load currently uses (o_v4 < K) — change it to (o_v4 < V)); update
references around b_h2, b_h3, b_h4 and their corresponding o_v2/o_v3/o_v4 loads
in chunk_h.py.
- Around line 470-485: The code currently hardcodes BT = 64 but calls
prepare_chunk_indices(cu_seqlens, chunk_size), causing mismatch if chunk_size !=
64; fix by enforcing a single source of truth: either require chunk_size == 64
or thread chunk_size through BT. Implementation: add a runtime assertion and/or
normalize BT from the provided chunk_size (e.g., assert chunk_size == 64 or set
BT = chunk_size) and then use that BT consistently for prepare_chunk_indices and
prepare_chunk_offsets; reference symbols: BT, chunk_size, prepare_chunk_indices,
prepare_chunk_offsets.
- Around line 197-207: The function chunk_oja2_fwd_h currently annotates its
return as a 2‑tuple but actually returns (h, k_new, final_state); update the
return type on chunk_oja2_fwd_h to reflect the third, optional tensor (e.g.
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] or Tuple[torch.Tensor,
torch.Tensor, torch.Tensor] if you prefer always returning final_state) so the
signature matches the returned values (h, k_new, final_state) and adjust any
related type hints/call sites if necessary.
- Around line 795-810: The code unconditionally calls torch.empty_like(gv) and
returns dgv_last, but gv is Optional and may be None; add an explicit validation
at the start of the function that raises a clear TypeError (or ValueError) if gv
is None (e.g., "gv must be provided"), or alternatively change the signature to
make gv required; update callers if you choose the latter. Ensure the check
happens before creating dgv_last and before invoking
chunk_oja2_bwd_kernel_dvwg_h so torch.empty_like(gv) is never called with None;
reference gv, dgv_last, and chunk_oja2_bwd_kernel_dvwg_h when making the change.
- Around line 391-445: The gv pointer creation and loads (p_gv and b_gv) must be
moved inside the USE_GV guard because gv can be None; in each of the four blocks
(the V>0, V>64, V>128, V>192 branches) remove or skip tl.make_block_ptr(gv, ...)
and tl.load(...) when outside the if USE_GV, and instead create p_gv and load
b_gv only inside the if USE_GV: use the existing symbols p_gv, b_gv, gv, and
USE_GV and keep the subsequent operations that reference b_gv (e.g., b_gv_last*,
b_dh* *= exp(...), b_do *= exp(b_gv)) inside that guard so no gv access happens
when USE_GV is false.
- Around line 710-769: The code uses b_gn and b_dv unconditionally though they
are only set inside the USE_GV branch, causing crashes when USE_GV is false; fix
by providing safe fallbacks or guarding uses: ensure b_gn and b_dv are
initialized before the loop (e.g., zeros or proper shapes) or move/guard the
computations that reference b_gn and b_dv (the tl.sum using exp(b_gn) in the
loop that updates b_dgv_last, the b_dv-dependent b_dgv_last accumulation and
tl.store(p_dv,...)/tl.store(p_dgv_last,...)) under the same USE_GV condition,
updating references to b_dv, b_gn, and b_dgv_last consistently (symbols: USE_GV,
b_gn, b_dv, b_dgv_last, p_dv, p_dgv_last).
- Around line 470-505: The function allows scale: Optional[float]=None but
passes scale into the kernel chunk_oja2_bwd_kernel_dhu_blockdim64 and uses it in
arithmetic; add a validation at the start of this function to ensure scale is
not None (e.g., assert scale is not None or raise ValueError with context) so
the kernel never receives None, or alternatively assign a safe default (e.g.,
scale = 1.0) before calling prepare_chunk_indices/prepare_chunk_offsets and
before launching chunk_oja2_bwd_kernel_dhu_blockdim64; reference the parameter
name scale and the kernel chunk_oja2_bwd_kernel_dhu_blockdim64 to locate where
to add the check.

In `@fla/ops/oja2/chunk_kkt.py`:
- Around line 384-417: The beta parameter is currently Optional but required by
the native kernels; update the APIs to make beta mandatory instead of Optional
(remove the default None and Optional[...] in the signatures and docs) for
chunk_scaled_dot_kkt_fwd and the corresponding backward wrapper (the one at
lines ~474-495), or alternatively add an early check that raises a clear
ValueError if beta is None before calling the kernels; reference
chunk_scaled_dot_kkt_fwd (and its backward counterpart) when applying the change
so callers and docs are updated consistently.

In `@fla/ops/oja2/chunk_o.py`:
- Around line 7-15: Remove the duplicate/conflicting exp import and the unused
chunk_local_cumsum import in chunk_o.py: delete the line importing exp from
fla.ops.utils.op (or stop reassigning exp = tl.exp) so that only the intended
exp symbol from tl.exp remains, and remove the unused chunk_local_cumsum import;
keep prepare_chunk_indices and the shared-memory checks (BKV_LIST/NUM_WARPS)
intact. Ensure there are no other references to fla.ops.utils.op.exp or
chunk_local_cumsum elsewhere in this module before removing.

In `@fla/ops/oja2/wy_fast.py`:
- Around line 11-13: Remove the unused import symbol chunk_local_cumsum from the
import statement in this module: update the line "from fla.ops.utils import
chunk_local_cumsum, prepare_chunk_indices" so it only imports
prepare_chunk_indices; keep the rest of imports (exp from fla.ops.utils.op and
check_shared_mem) unchanged to avoid affecting other references.
- Around line 240-289: prepare_wy_repr_bwd must guard against gv being None and
derive the chunk size BT from A instead of hardcoding 64; replace the hardcoded
BT=64 with BT = A.shape[-1] (or the appropriate last-dimension of A used for
tiling) and before allocating dgv/db/etc. ensure gv is non-None by doing
something like if gv is None: gv = torch.zeros_like(v) (so torch.empty_like(gv,
dtype=torch.float) is safe). Update uses of BT (NT computation and kernel args)
to use the new BT variable and keep the rest of allocations (dgv, dA, db)
unchanged.
- Around line 199-237: The function recompute_w_u_fwd currently declares gv as
Optional[torch.Tensor] = None but the triton kernel dereferences gv
unconditionally, and the function only annotates returning two Tensors while it
actually returns w, u, vg; update the signature to make gv a required
torch.Tensor (remove Optional and default None) and change the return annotation
to Tuple[torch.Tensor, torch.Tensor, torch.Tensor]; ensure vg is always
allocated (vg = torch.empty_like(v)) and passed/returned consistently to match
the kernel's expectation and the returned triple (symbols: recompute_w_u_fwd,
gv, vg, w, u, recompute_w_u_fwd_kernel).
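Several of the items above reduce to the same pattern: validate optional arguments before they reach a Triton kernel that dereferences them unconditionally. A minimal sketch of the shape such guards could take (the helper name and the defaulting choices are assumptions; the specific checks follow the findings listed above):

```python
def validate_bwd_args(gv, scale, chunk_size):
    # gv is dereferenced unconditionally downstream -> fail fast and clearly
    if gv is None:
        raise ValueError("gv must be provided; got None")
    # scale participates in kernel arithmetic -> substitute a safe default
    if scale is None:
        scale = 1.0
    # BT is hardcoded to 64 in the kernels -> enforce one source of truth
    if chunk_size != 64:
        raise ValueError(f"chunk_size must be 64, got {chunk_size}")
    return gv, scale, chunk_size

# Valid call: None scale is defaulted rather than passed to the kernel.
gv, scale, bt = validate_bwd_args(gv=[1.0], scale=None, chunk_size=64)
assert scale == 1.0

# Missing gv is rejected before any torch.empty_like(gv) could run.
try:
    validate_bwd_args(gv=None, scale=1.0, chunk_size=64)
    ok = False
except ValueError:
    ok = True
assert ok
```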

Comment thread fla/ops/oja2/chunk_h.py
@@ -0,0 +1,821 @@
from typing import Optional, Tuple

⚠️ Potential issue | 🟡 Minor

Re-run pre-commit to apply formatting fixes.
CI reports trailing whitespace/EOF/autopep8 changes in this file; please re-run the formatter and commit the result so lint passes.

🧰 Tools
🪛 GitHub Actions: lint

[error] 1-1: Trailing whitespace detected and fixed by pre-commit hook 'trailing-whitespace'.


[error] 1-1: End-of-file fixer modified the file to ensure proper EOF; re-run pre-commit.


[error] 1-1: autopep8 formatting applied. Please re-run pre-commit to apply changes.


Comment thread fla/ops/oja2/chunk_h.py
h0 = h0 + i_nh * K*V
if STORE_FINAL_STATE:
    ht = ht + i_nh * K*V
BV=64

⚠️ Potential issue | 🟡 Minor

Remove unused BV to clear lint errors.
Line 93 assigns BV but it isn't used and is already failing Ruff/Flake8.

🧹 Proposed fix
-    BV=64
🧰 Tools
🪛 Flake8 (7.3.0)

[error] 93-93: local variable 'BV' is assigned to but never used

(F841)

🪛 Ruff (0.14.13)

93-93: Local variable BV is assigned to but never used

Remove assignment to unused variable BV

(F841)

🤖 Prompt for AI Agents
In `@fla/ops/oja2/chunk_h.py` at line 93, The assignment to the unused variable BV
should be removed to satisfy linting; locate the BV = 64 statement in chunk_h.py
(symbol BV) and delete that line (or remove any unused constant/variable
declaration named BV) so no unused BV symbol remains in the module.

Comment thread fla/ops/oja2/chunk_h.py
Comment on lines +148 to +163
if USE_GV:
o_v1 = tl.arange(0, 64)
b_gk_last1 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v1, mask=(o_v1 < V), other=0.)
b_h1 *= exp(b_gk_last1)[None, :]
if K > 64:
o_v2 = 64 + o_v1
b_gk_last2 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v2, mask=(o_v2 < V), other=0.)
b_h2 *= exp(b_gk_last2)[None, :]
if K > 128:
o_v3 = 128 + o_v1
b_gk_last3 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v3, mask=(o_v3 < V), other=0.)
b_h3 *= exp(b_gk_last3)[None, :]
if K > 192:
o_v4 = 192 + o_v1
b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < K), other=0.)
b_h4 *= exp(b_gk_last4)[None, :]

⚠️ Potential issue | 🟠 Major

Fix GV gating thresholds to use V (not K).
Lines 152–162 gate b_h2/b_h3/b_h4 on K rather than V, which can skip gating entirely or touch value tiles that do not exist when K != V; the mask at Line 162 should likewise compare against V.

🧭 Proposed fix
-            if K > 64:
+            if V > 64:
                 o_v2 = 64 + o_v1
                 b_gk_last2 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v2, mask=(o_v2 < V), other=0.)
                 b_h2 *= exp(b_gk_last2)[None, :]
-            if K > 128:
+            if V > 128:
                 o_v3 = 128 + o_v1
                 b_gk_last3 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v3, mask=(o_v3 < V), other=0.)
                 b_h3 *= exp(b_gk_last3)[None, :]
-            if K > 192:
+            if V > 192:
                 o_v4 = 192 + o_v1
-                b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < K), other=0.)
+                b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < V), other=0.)
                 b_h4 *= exp(b_gk_last4)[None, :]
🤖 Prompt for AI Agents
In `@fla/ops/oja2/chunk_h.py` around lines 148 - 163, The GV gating uses K instead
of V, which can skip loads or use wrong masks; update the gating and masks so
comparisons use V (not K): change the conditional checks if K > 64 / 128 / 192
to if V > 64 / 128 / 192 and ensure the tl.load mask arguments for o_v2, o_v3,
o_v4 use (o_vX < V) (e.g., the last load currently uses (o_v4 < K) — change it
to (o_v4 < V)); update references around b_h2, b_h3, b_h4 and their
corresponding o_v2/o_v3/o_v4 loads in chunk_h.py.
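The fix above hinges on tiling the value dimension in 64-wide blocks and gating each tile by V. A torch-free sketch of which tiles are active for a given V (`active_v_tiles` is a hypothetical helper, not part of the kernel):

```python
def active_v_tiles(V: int, tile: int = 64) -> list:
    """Return (start, end) bounds of each 64-wide tile along the value
    dimension. The guards must compare against V: with K=256 and V=128,
    comparing against K would wrongly activate the third and fourth tiles."""
    return [(start, min(start + tile, V)) for start in range(0, V, tile)]

# With V=130, only three tiles are active and the last is partially masked:
# active_v_tiles(130) -> [(0, 64), (64, 128), (128, 130)]
```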

Comment thread fla/ops/oja2/chunk_h.py
Comment on lines +197 to +207
def chunk_oja2_fwd_h(
v: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
gv: Optional[torch.Tensor] = None,
initial_state: Optional[torch.Tensor] = None,
output_final_state: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
save_new_key: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

find . -type f -name "chunk_h.py" | head -20

Repository: fla-org/flash-linear-attention

Length of output: 127


🏁 Script executed:

cat -n ./fla/ops/oja2/chunk_h.py | head -250

Repository: fla-org/flash-linear-attention

Length of output: 11723


Return annotation doesn't match the 3‑tuple.
Line 207 declares a 2‑tuple but the function returns (h, k_new, final_state). Update typing to include the optional tensors.

🧩 Proposed fix
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
🤖 Prompt for AI Agents
In `@fla/ops/oja2/chunk_h.py` around lines 197 - 207, The function
chunk_oja2_fwd_h currently annotates its return as a 2‑tuple but actually
returns (h, k_new, final_state); update the return type on chunk_oja2_fwd_h to
reflect the third, optional tensor (e.g. Tuple[torch.Tensor, torch.Tensor,
Optional[torch.Tensor]] or Tuple[torch.Tensor, torch.Tensor, torch.Tensor] if
you prefer always returning final_state) so the signature matches the returned
values (h, k_new, final_state) and adjust any related type hints/call sites if
necessary.
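The corrected signature shape can be sketched without torch; the optional elements are None unless the caller opts in (names and return values are illustrative, loosely mirroring chunk_oja2_fwd_h):

```python
from typing import Optional, Tuple

def fwd_h_sketch(
    save_new_key: bool = True,
    output_final_state: bool = False,
) -> Tuple[list, Optional[list], Optional[list]]:
    h = [0.0]                                            # always returned
    k_new = [1.0] if save_new_key else None              # optional second element
    final_state = [2.0] if output_final_state else None  # optional third element
    return h, k_new, final_state
```

Annotating all three slots keeps static checkers aligned with what the function actually returns.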

Comment thread fla/ops/oja2/chunk_h.py
Comment on lines +391 to +445
if V > 0:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v1 = tl.arange(0, 64)
b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
b_dh1 *= exp(b_gv_last1[None, :])
b_do *= exp(b_gv)
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w) # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]

if V > 64:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v2 = 64 + o_v1
b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.)
b_dh2 *= exp(b_gv_last2[None, :])
b_do *= exp(b_gv)
b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)

if V > 128:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v3 = 128 + o_v1
b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.)
b_dh3 *= exp(b_gv_last3[None, :])
b_do *= exp(b_gv)
b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)

if V > 192:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v4 = 192 + o_v1
b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.)
b_dh4 *= exp(b_gv_last4[None, :])
b_do *= exp(b_gv)
b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, let me check the file structure
fd chunk_h.py

Repository: fla-org/flash-linear-attention

Length of output: 123


🏁 Script executed:

# Get the actual code at the specified lines
head -445 fla/ops/oja2/chunk_h.py | tail -55

Repository: fla-org/flash-linear-attention

Length of output: 3544


🏁 Script executed:

# Get more context - search for USE_GV and gv parameter definition
rg -n "USE_GV|def.*chunk_h|gv.*=" fla/ops/oja2/chunk_h.py -A 2 -B 2 | head -100

Repository: fla-org/flash-linear-attention

Length of output: 4394


Move gv loads and pointer creation inside the USE_GV guard.

Lines 396–397 (and the V>64/128/192 blocks at 410–411, 424–425, 436–437) create and load gv outside the USE_GV branch. When gv=None and USE_GV=false, tl.make_block_ptr(gv, ...) will fail. Since b_gv is only used inside the if USE_GV block, move both the pointer creation and load inside the guard.

Apply to all four V blocks
-            p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
-            b_gv = tl.load(p_gv, boundary_check=(0, 1))
             if USE_GV:
+                p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
+                b_gv = tl.load(p_gv, boundary_check=(0, 1))
                 o_v1 = tl.arange(0, 64)
                 b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
                 b_dh1 *= exp(b_gv_last1[None, :])
                 b_do *= exp(b_gv)
📝 Committable suggestion


Suggested change (the original block, followed by the fully updated block):
if V > 0:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v1 = tl.arange(0, 64)
b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
b_dh1 *= exp(b_gv_last1[None, :])
b_do *= exp(b_gv)
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w) # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]
if V > 64:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v2 = 64 + o_v1
b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.)
b_dh2 *= exp(b_gv_last2[None, :])
b_do *= exp(b_gv)
b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
if V > 128:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v3 = 128 + o_v1
b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.)
b_dh3 *= exp(b_gv_last3[None, :])
b_do *= exp(b_gv)
b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
if V > 192:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v4 = 192 + o_v1
b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.)
b_dh4 *= exp(b_gv_last4[None, :])
b_do *= exp(b_gv)
b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
if V > 0:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_GV:
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
o_v1 = tl.arange(0, 64)
b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
b_dh1 *= exp(b_gv_last1[None, :])
b_do *= exp(b_gv)
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w) # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]
if V > 64:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_GV:
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
o_v2 = 64 + o_v1
b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.)
b_dh2 *= exp(b_gv_last2[None, :])
b_do *= exp(b_gv)
b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
if V > 128:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_GV:
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
o_v3 = 128 + o_v1
b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.)
b_dh3 *= exp(b_gv_last3[None, :])
b_do *= exp(b_gv)
b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
if V > 192:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
if USE_GV:
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
o_v4 = 192 + o_v1
b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.)
b_dh4 *= exp(b_gv_last4[None, :])
b_do *= exp(b_gv)
b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
🤖 Prompt for AI Agents
In `@fla/ops/oja2/chunk_h.py` around lines 391 - 445, The gv pointer creation and
loads (p_gv and b_gv) must be moved inside the USE_GV guard because gv can be
None; in each of the four blocks (the V>0, V>64, V>128, V>192 branches) remove
or skip tl.make_block_ptr(gv, ...) and tl.load(...) when outside the if USE_GV,
and instead create p_gv and load b_gv only inside the if USE_GV: use the
existing symbols p_gv, b_gv, gv, and USE_GV and keep the subsequent operations
that reference b_gv (e.g., b_gv_last*, b_dh* *= exp(...), b_do *= exp(b_gv))
inside that guard so no gv access happens when USE_GV is false.

Comment thread fla/ops/oja2/chunk_kkt.py
Comment on lines +384 to +417
def chunk_scaled_dot_kkt_fwd(
k: torch.Tensor,
g: Optional[torch.Tensor] = None,
gk: Optional[torch.Tensor] = None,
beta: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32
) -> torch.Tensor:
r"""
Compute beta * K * K^T.

Args:
k (torch.Tensor):
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
g (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
gk (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float32`

Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""
B, T, H, K = k.shape
BT = chunk_size

⚠️ Potential issue | 🟠 Major

beta is optional in the signature but required by the kernels.

With the current defaults, calling these wrappers without beta will fail at runtime. Either make beta required or guard against None early.

🐛 Proposed fix
 def chunk_scaled_dot_kkt_fwd(
     k: torch.Tensor,
     g: Optional[torch.Tensor] = None,
     gk: Optional[torch.Tensor] = None,
     beta: Optional[torch.Tensor] = None,
@@
 ) -> torch.Tensor:
+    if beta is None:
+        raise ValueError("beta must be provided for chunk_scaled_dot_kkt_fwd")
@@
 def chunk_scaled_dot_kkt_bwd_gk(
     k: torch.Tensor,
     g: torch.Tensor,
     beta: torch.Tensor,
     dA: torch.Tensor,
@@
 ):
+    if beta is None:
+        raise ValueError("beta must be provided for chunk_scaled_dot_kkt_bwd_gk")

Also applies to: 474-495

🤖 Prompt for AI Agents
In `@fla/ops/oja2/chunk_kkt.py` around lines 384 - 417, The beta parameter is
currently Optional but required by the native kernels; update the APIs to make
beta mandatory instead of Optional (remove the default None and Optional[...] in
the signatures and docs) for chunk_scaled_dot_kkt_fwd and the corresponding
backward wrapper (the one at lines ~474-495), or alternatively add an early
check that raises a clear ValueError if beta is None before calling the kernels;
reference chunk_scaled_dot_kkt_fwd (and its backward counterpart) when applying
the change so callers and docs are updated consistently.
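The early-guard variant of the fix can be sketched torch-free; raising before any kernel launch gives callers a clear error instead of an opaque Triton failure (`kkt_fwd_checked` is a hypothetical wrapper name):

```python
from typing import Optional

def kkt_fwd_checked(k: list, beta: Optional[list] = None) -> list:
    """Hypothetical wrapper: fail fast instead of letting the kernel
    dereference a missing beta."""
    if beta is None:
        raise ValueError("beta must be provided for chunk_scaled_dot_kkt_fwd")
    # stand-in for the real kernel launch: per-row beta scaling
    return [b * x for x, b in zip(k, beta)]
```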

Comment thread fla/ops/oja2/chunk_o.py
Comment on lines +7 to +15
from fla.ops.utils import prepare_chunk_indices
from fla.ops.utils.op import exp
from fla.utils import check_shared_mem, is_nvidia_hopper
from fla.ops.utils.cumsum import chunk_local_cumsum

BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]

exp = tl.exp

⚠️ Potential issue | 🟡 Minor

Fix the duplicate exp import and remove the unused import.

exp is imported and then redefined, and chunk_local_cumsum is unused. This is already failing lint (F811/F401).

🧹 Proposed fix
-from fla.ops.utils.op import exp
-from fla.ops.utils.cumsum import chunk_local_cumsum
🧰 Tools
🪛 Flake8 (7.3.0)

[error] 10-10: 'fla.ops.utils.cumsum.chunk_local_cumsum' imported but unused

(F401)


[error] 15-15: redefinition of unused 'exp' from line 8

(F811)

🪛 GitHub Actions: lint

[error] 13-13: Ruff: 'exp' redefined; previous definition exists (F811). Remove duplicate definition.

🪛 Ruff (0.14.13)

15-15: Redefinition of unused exp from line 8: exp redefined here

(F811)

🤖 Prompt for AI Agents
In `@fla/ops/oja2/chunk_o.py` around lines 7 - 15, Remove the
duplicate/conflicting exp import and the unused chunk_local_cumsum import in
chunk_o.py: delete the line importing exp from fla.ops.utils.op (or stop
reassigning exp = tl.exp) so that only the intended exp symbol from tl.exp
remains, and remove the unused chunk_local_cumsum import; keep
prepare_chunk_indices and the shared-memory checks (BKV_LIST/NUM_WARPS) intact.
Ensure there are no other references to fla.ops.utils.op.exp or
chunk_local_cumsum elsewhere in this module before removing.
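The F811 pattern is easy to reproduce in isolation; the second binding silently discards the first, which is exactly what the linter warns about (illustrative, not the module's actual definitions):

```python
from math import exp  # first binding (mirrors `from fla.ops.utils.op import exp`)

def exp(x):  # second binding shadows the import: this is the F811 pattern
    """Mirrors the later `exp = tl.exp`; the import above is now dead code."""
    return 2.0 ** x

# Only the last binding survives, so the imported math.exp is unreachable.
```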

Comment thread fla/ops/oja2/wy_fast.py
Comment on lines +11 to +13
from fla.ops.utils import chunk_local_cumsum, prepare_chunk_indices
from fla.ops.utils.op import exp
from fla.utils import check_shared_mem

⚠️ Potential issue | 🟡 Minor

Remove the unused import.

chunk_local_cumsum isn’t referenced in this module and is already flagged by lint.

🧹 Proposed fix
-from fla.ops.utils import chunk_local_cumsum, prepare_chunk_indices
+from fla.ops.utils import prepare_chunk_indices
📝 Committable suggestion


Suggested change (original imports, then updated):
from fla.ops.utils import chunk_local_cumsum, prepare_chunk_indices
from fla.ops.utils.op import exp
from fla.utils import check_shared_mem
from fla.ops.utils import prepare_chunk_indices
from fla.ops.utils.op import exp
from fla.utils import check_shared_mem
🧰 Tools
🪛 Flake8 (7.3.0)

[error] 11-11: 'fla.ops.utils.chunk_local_cumsum' imported but unused

(F401)

🤖 Prompt for AI Agents
In `@fla/ops/oja2/wy_fast.py` around lines 11 - 13, Remove the unused import
symbol chunk_local_cumsum from the import statement in this module: update the
line "from fla.ops.utils import chunk_local_cumsum, prepare_chunk_indices" so it
only imports prepare_chunk_indices; keep the rest of imports (exp from
fla.ops.utils.op and check_shared_mem) unchanged to avoid affecting other
references.

Comment thread fla/ops/oja2/wy_fast.py
Comment on lines +199 to +237
def recompute_w_u_fwd(
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
gv: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = A.shape[-1]
BK = 64
BV = 64

chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

w = torch.empty_like(v)
u = torch.empty_like(k)
vg = torch.empty_like(v) if gv is not None else None
recompute_w_u_fwd_kernel[(NT, B*H)](
k=k,
v=v,
vg=vg,
beta=beta,
w=w,
u=u,
A=A,
gv=gv,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
)
return w, u, vg

⚠️ Potential issue | 🟠 Major

Make gv required and fix the return annotation.

gv is dereferenced unconditionally inside the kernel, so the default None is a crash path. Also, the function returns three values but is annotated as two.

🐛 Proposed fix
-def recompute_w_u_fwd(
-    k: torch.Tensor,
-    v: torch.Tensor,
-    beta: torch.Tensor,
-    A: torch.Tensor,
-    gv: Optional[torch.Tensor] = None,
-    cu_seqlens: Optional[torch.LongTensor] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def recompute_w_u_fwd(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    A: torch.Tensor,
+    gv: torch.Tensor,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    if gv is None:
+        raise ValueError("gv must be provided for recompute_w_u_fwd")
@@
-    vg = torch.empty_like(v) if gv is not None else None
+    vg = torch.empty_like(v)
@@
-    return w, u, vg
+    return w, u, vg
🤖 Prompt for AI Agents
In `@fla/ops/oja2/wy_fast.py` around lines 199 - 237, The function
recompute_w_u_fwd currently declares gv as Optional[torch.Tensor] = None but the
triton kernel dereferences gv unconditionally, and the function only annotates
returning two Tensors while it actually returns w, u, vg; update the signature
to make gv a required torch.Tensor (remove Optional and default None) and change
the return annotation to Tuple[torch.Tensor, torch.Tensor, torch.Tensor]; ensure
vg is always allocated (vg = torch.empty_like(v)) and passed/returned
consistently to match the kernel's expectation and the returned triple (symbols:
recompute_w_u_fwd, gv, vg, w, u, recompute_w_u_fwd_kernel).

Comment thread fla/ops/oja2/wy_fast.py
Comment on lines +240 to +289
def prepare_wy_repr_bwd(
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
dw: torch.Tensor,
du: torch.Tensor,
gv: torch.Tensor = None,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = 64
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
CONST_TILING = 64 if check_shared_mem() else 32
BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)

dk = torch.empty_like(k)
dv = torch.empty_like(v, dtype=torch.float)

dgv = torch.empty_like(gv, dtype=torch.float)
dA = torch.empty_like(A, dtype=torch.float)
db = torch.empty_like(beta, dtype=torch.float)

prepare_wy_repr_bwd_kernel[(NT, B * H)](
k=k,
v=v,
beta=beta,
gv=gv,
A=A,
dA=dA,
dw=dw,
du=du,
dk=dk,
dv=dv,
db=db,
dgv=dgv,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
)

return dk, dv, db, dgv, dA

⚠️ Potential issue | 🟠 Major

Guard against gv=None and derive BT from A.

gv defaults to None but is dereferenced immediately (torch.empty_like(gv) and the kernel loads), so the default is a guaranteed runtime failure. Deriving BT from A.shape[-1] also keeps the chunk size correct if the tiling ever changes.

🐛 Proposed fix
-def prepare_wy_repr_bwd(
-    k: torch.Tensor,
-    v: torch.Tensor,
-    beta: torch.Tensor,
-    A: torch.Tensor,
-    dw: torch.Tensor,
-    du: torch.Tensor,
-    gv: torch.Tensor = None,
-    cu_seqlens: Optional[torch.LongTensor] = None,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+def prepare_wy_repr_bwd(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    A: torch.Tensor,
+    dw: torch.Tensor,
+    du: torch.Tensor,
+    gv: torch.Tensor,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    if gv is None:
+        raise ValueError("gv must be provided for prepare_wy_repr_bwd")
@@
-    BT = 64
+    BT = A.shape[-1]
🤖 Prompt for AI Agents
In `@fla/ops/oja2/wy_fast.py` around lines 240 - 289, prepare_wy_repr_bwd must
guard against gv being None and derive the chunk size BT from A instead of
hardcoding 64; replace the hardcoded BT=64 with BT = A.shape[-1] (or the
appropriate last-dimension of A used for tiling) and before allocating
dgv/db/etc. ensure gv is non-None by doing something like if gv is None: gv =
torch.zeros_like(v) (so torch.empty_like(gv, dtype=torch.float) is safe). Update
uses of BT (NT computation and kernel args) to use the new BT variable and keep
the rest of allocations (dgv, dA, db) unchanged.
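Both parts of the fix can be sketched torch-free: derive the chunk size from A's trailing dimension and normalize gv before any empty_like-style allocation (`resolve_bwd_inputs` is a hypothetical helper illustrating the pattern):

```python
from typing import Optional, Tuple

def resolve_bwd_inputs(
    A_shape: tuple,
    v: list,
    gv: Optional[list] = None,
) -> Tuple[int, list]:
    BT = A_shape[-1]          # chunk size comes from A, not a hardcoded 64
    if gv is None:            # make the later empty_like(gv) allocation safe
        gv = [0.0] * len(v)   # stand-in for torch.zeros_like(v)
    return BT, gv
```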
