Conversation
Walkthrough

This PR introduces a complete chunked OJA2 (contrastive Hebbian-like) recurrent attention implementation across seven files. It provides forward and backward passes via Triton kernels, PyTorch autograd integration, optional L2 normalization, and support for variable-length sequences.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User as User Code
    participant Chunk as chunk_oja2<br/>(High-Level)
    participant ChunkKKT as chunk_scaled_dot_kkt_fwd<br/>(KKT Kernel)
    participant ChunkH as chunk_oja2_fwd_h<br/>(H Kernel)
    participant ChunkO as chunk_oja2_fwd_o<br/>(O Kernel)
    participant Output as Autograd Output
    User->>Chunk: q, k, v, gv, beta, scale<br/>initial_state, cu_seqlens
    Chunk->>Chunk: Validate inputs & defaults
    Chunk->>ChunkKKT: Compute A = β·K·K^T<br/>with optional GK
    ChunkKKT-->>Chunk: A (scaled dot product)
    Chunk->>Chunk: Solve triangular system<br/>extract w, u, vg
    Chunk->>ChunkH: Process via h computation<br/>with w, u, vg, gv
    ChunkH-->>Chunk: h, k_new, final_state
    Chunk->>ChunkO: Compute output o & A<br/>from h and inputs
    ChunkO-->>Chunk: o, A
    Chunk->>Output: Save intermediates for backward
    Chunk-->>User: o, final_state
```
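The forward flow above follows the standard chunked linear-attention decomposition: an intra-chunk masked attention term plus a carried inter-chunk state. A minimal eager-PyTorch sketch of that generic split — illustrative only, not the OJA2 state update, and `chunked_linear_attn` is a hypothetical name:

```python
import torch

def chunked_linear_attn(q, k, v, chunk_size=4):
    # Generic chunked linear attention: o_t = q_t @ S_t with S_t = sum_{s<=t} k_s v_s^T.
    # Shows the intra-chunk (masked Q K^T V) + inter-chunk (Q @ state) split that
    # kernels like chunk_oja2_fwd_o exploit; the OJA2 state update itself is richer.
    S = torch.zeros(q.shape[1], v.shape[1])
    outs = []
    for start in range(0, q.shape[0], chunk_size):
        qc = q[start:start + chunk_size]
        kc = k[start:start + chunk_size]
        vc = v[start:start + chunk_size]
        A = torch.tril(qc @ kc.t())   # causal intra-chunk scores (the "A" of the diagram)
        o = A @ vc + qc @ S           # intra-chunk term + carried inter-chunk state
        S = S + kc.t() @ vc           # chunk-level state update
        outs.append(o)
    return torch.cat(outs)

torch.manual_seed(0)
q, k, v = torch.randn(8, 4), torch.randn(8, 4), torch.randn(8, 2)
o_chunk = chunked_linear_attn(q, k, v, chunk_size=4)
o_naive = torch.tril(q @ k.t()) @ v   # O(T^2) reference
```

The chunked result matches the quadratic reference up to floating-point reordering, which is the invariant the kernel tests should preserve.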
```mermaid
sequenceDiagram
    participant User as User Code
    participant Chunk as chunk_oja2_bwd<br/>(Backward)
    participant DH as chunk_oja2_bwd_dhu<br/>(dH Kernel)
    participant DV as chunk_oja2_bwd_dvwg_h<br/>(dV Kernel)
    participant DA as chunk_oja2_bwd_dA<br/>(dA Kernel)
    participant DQK as chunk_oja2_bwd_dqk<br/>(dQK Kernel)
    User->>Chunk: do, dht, saved intermediates
    Chunk->>Chunk: Recompute w, u, vg
    Chunk->>DH: Compute dh, dh0, dk_new<br/>from do and dht
    DH-->>Chunk: dh, dh0, dk_new
    Chunk->>DV: Compute dv, dw, dgv_last<br/>from dh and dk_new
    DV-->>Chunk: dv, dw, dgv_last
    Chunk->>DA: Compute dA<br/>from do and intermediates
    DA-->>Chunk: dA
    Chunk->>DQK: Compute dq, dk<br/>from dA and dh
    DQK-->>Chunk: dq, dk
    Chunk->>Chunk: Aggregate all gradients
    Chunk-->>User: dq, dk, dv, db, dgv, dh0
```
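A common way to sanity-check such a backward decomposition is to compare kernel gradients against autograd through a naive O(T²) reference. A generic sketch of that pattern (not the PR's actual test suite):

```python
import torch

def naive_linear_attn(q, k, v):
    # O(T^2) causal reference; autograd through it yields ground-truth gradients
    # against which hand-written kernel gradients (dq, dk, dv, ...) can be compared.
    return torch.tril(q @ k.t()) @ v

torch.manual_seed(0)
q = torch.randn(6, 3, dtype=torch.double, requires_grad=True)
k = torch.randn(6, 3, dtype=torch.double, requires_grad=True)
v = torch.randn(6, 2, dtype=torch.double, requires_grad=True)
# gradcheck perturbs each input numerically and compares against autograd grads
ok = torch.autograd.gradcheck(naive_linear_attn, (q, k, v))
```

Double precision matters here: gradcheck's finite-difference comparison is too noisy in float32.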
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes

This review requires careful analysis of: (1) four interdependent Triton kernel modules with complex tiling strategies and auto-tuning configurations, (2) forward/backward mathematical correctness across multiple kernels, (3) PyTorch autograd integration patterns, (4) variable-length sequence handling with cu_seqlens, (5) optional parameter propagation (GV, L2 norm, beta broadcasting), and (6) state management and intermediate tensor saving for backpropagation.
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes

Hello @AwesomeSeq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates a new Gated OJA Operator, providing both chunked and fused recurrent approaches. The core functionality is built upon custom Triton kernels designed for performance, handling the forward and backward passes for the operator's internal state, output, and key-key-transpose matrix computations. The implementation is robust, supporting variable sequence lengths and including comprehensive gradient computations, though the fused recurrent backward pass is explicitly noted as not yet implemented.

Highlights
Code Review
This pull request introduces the Gated OJA Operator, including chunked and fused recurrent implementations. The changes involve several new Triton kernels for forward and backward passes, as well as utility functions for handling variable sequence lengths and numerical stability. Overall, the code is well-structured and follows good practices for Triton kernel development, such as using autotune and input_guard.
However, there are several areas that require attention, particularly regarding code clarity, potential correctness issues, and a critical limitation in the fused recurrent operator's backward pass. Please review the specific comments below for detailed feedback.
```python
@staticmethod
@input_guard
def backward(ctx, do, dht):
    raise NotImplementedError(
        "Backward pass is not implemented yet and we do not have plans to implement it "
        "because we haven't figured out how to compute dg without materializing the full "
        "hidden states for all time steps."
    )
```
The backward method for FusedRecurrentFunction is not implemented, raising a NotImplementedError. This is a critical limitation as it prevents the use of this operator in training scenarios that require backpropagation. This should be clearly documented in the function's docstring and the pull request description, as it significantly impacts the usability of this fused operator.
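The failure mode is easy to reproduce with a minimal custom `torch.autograd.Function` that mirrors this shape (hypothetical class, illustrative only):

```python
import torch

class FwdOnlyFunction(torch.autograd.Function):
    # Hypothetical stand-in for the limitation described above:
    # forward works, backward deliberately raises.
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        raise NotImplementedError("Backward pass is not implemented yet.")

x = torch.ones(3, requires_grad=True)
y = FwdOnlyFunction.apply(x).sum()   # inference-style use is fine
try:
    y.backward()                      # any training use fails loudly
    backward_raised = False
except NotImplementedError:
    backward_raised = True
```

The error only surfaces at `.backward()` time, which is why documenting the limitation in the docstring (rather than relying on the runtime error) helps users who wire the op into a training loop.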
```python
def chunk_gsa_bwd_k_kernel_dqkvg(
    q,
    k,
    v,
    h,
    g,
    A,
    do,
    dh,
    dq,
    dk,
    dv,
    dg,
    dgv,
    dA,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    B: tl.constexpr,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        all = T
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
        all = B * T

    o_i = tl.arange(0, BT)
    o_t = min(i_t * BT + BT, T)
    m_s = o_i[:, None] >= o_i[None, :]

    p_q = tl.make_block_ptr(q + (bos*HQ+i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_A = tl.make_block_ptr(A + ((i_k*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    # [BT, BT]
    b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k))
    b_A = tl.where(m_s, b_A, 0.)
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        o_v = i_v * BV + tl.arange(0, BV)
        p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_g = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_gn = g + (bos + o_t - 1) * H*V + i_h * V + o_v
        p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + ((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dgv = tl.make_block_ptr(dgv+((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dh = tl.make_block_ptr(dh + (i_tg * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        m_v = o_v < V

        # [BV,]
        b_gn = tl.load(p_gn, mask=m_v, other=0)
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_gv = exp(b_gn[None, :] - b_g)
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * exp(b_g) * scale).to(b_do.dtype)
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BV]
        b_dg = tl.sum(tl.trans(b_h) * b_dh, 0) * exp(b_gn)

        b_dh = b_dh.to(b_k.dtype)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_k.dtype))
        b_dk += tl.dot((b_v * b_gv).to(b_v.dtype), tl.trans(b_dh))
        # [BT, BV]
        b_dv = tl.dot(b_k, b_dh) * b_gv
        # [BV]
        b_dg += tl.sum(b_dv * b_v, 0)

        if i_k == 0:
            b_dgv = tl.load(p_dg, boundary_check=(0, 1)) + b_dg[None, :]
        else:
            b_dgv = tl.zeros([BT, BV], dtype=tl.float32) + b_dg[None, :]

        tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    p_dA = tl.make_block_ptr(dA + (bos*HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_dq = tl.make_block_ptr(dq + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    # [BT, BT]
    b_dA = tl.load(p_dA, boundary_check=(0, 1))
    # [BT, BK]
    b_dq += tl.dot(b_dA, b_k)
    b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q)

    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
```
The kernel chunk_gsa_bwd_k_kernel_dqkvg is named gsa but is located in the oja2 directory. This suggests a naming inconsistency or that this kernel might be an artifact from another module. Please ensure that all kernels are appropriately named and belong to their respective modules to maintain code clarity and modularity.
```python
b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
b_A = tl.sum(b_k * b_kt[None, :] * exp(b_g - b_gk[None, :]), 1)
b_A = tl.where(o_i > j, b_A, 0.)
```
The line b_A = tl.where(o_i > j, b_A, 0.) might not correctly enforce strict lower triangularity for the matrix b_A within the loop. o_i is a block of indices, and j is a scalar. For proper element-wise comparison across the matrix, it should likely be o_i[:, None] > j or a similar construct to ensure the masking applies correctly to the matrix dimensions. Please verify this logic for correctness.
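Whether the original line is correct depends on the rank of b_A at that point: against a 1-D per-column vector the broadcast is fine, against a 2-D block it masks whole rows. The difference is easy to check in eager PyTorch (illustrative tensors):

```python
import torch

o_i = torch.arange(4)   # block row indices
j = 1                   # scalar column index, as in the inner loop

# 1-D case: one score per row for column j; elementwise compare is correct.
b_A_vec = torch.ones(4)
masked_vec = torch.where(o_i > j, b_A_vec, torch.zeros(4))

# 2-D case: the same condition would need an explicit row axis,
# otherwise it zeroes entire rows instead of entries below the diagonal.
b_A_mat = torch.ones(4, 4)
masked_mat = torch.where(o_i[:, None] > j, b_A_mat, torch.zeros(4, 4))
```

If b_A here is the per-column vector produced by the scalar loop over j, the original line is fine; the reviewer's suggested `o_i[:, None] > j` applies only if b_A is a 2-D block.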
```python
# [BC, BV]
b_vg = b_v[None, :] * exp(b_g - b_gv[None, :])
# avoid 0 * inf = inf
b_o += tl.where(o_i[:, None] >= j, b_A[:, None] * b_vg, 0.)
```
The comment avoid 0 * inf = inf highlights a potential numerical stability issue. While tl.where is used to mask the addition, the multiplication b_A[:, None] * b_vg might still produce inf if b_vg contains inf values (e.g., from exp) and b_A is non-zero. This could lead to NaN propagation. Please ensure that inf values are not generated or are handled robustly before this multiplication.
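The concern reproduces directly in eager PyTorch: a where-style mask does neutralize a NaN produced at a masked-out position (the select discards it), but an inf at an unmasked position survives the mask:

```python
import torch

b_A = torch.tensor([0.0, 2.0, 3.0])
b_vg = torch.tensor([float('inf'), 1.0, float('inf')])
mask = torch.tensor([False, True, True])

raw = b_A * b_vg                                  # position 0: 0 * inf -> nan
out = torch.where(mask, b_A * b_vg, torch.zeros(3))
# position 0: nan replaced by 0 (mask False); position 2: inf survives (mask True)
```

So the masking is sound only if every position where b_vg can be inf is guaranteed to be masked out; an inf from `exp` at an unmasked position would still propagate, as the review warns.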
```python
# === Iterate over all gradients to locate which one is NaN ===
# Map variable names to their tensors
# grad_tensors = {
#     'dq': dq, 'dk': dk, 'dv': dv, 'db': db,
#     'dg': dg, 'dh0': dh0
# }

# for name, t in grad_tensors.items():
#     if t is not None and torch.isnan(t).any():
#         import os
#         import torch.distributed as dist

#         # Get the rank ID
#         # try:
#         #     rank = dist.get_rank() if dist.is_initialized() else 0
#         # except:
#         #     rank = 0
#         rank = 0

#         base_dir = "/mnt/moonfs/hujiaxi-m2/oja_nan_12"
#         os.makedirs(base_dir, exist_ok=True)

#         # Save path: nan_dump_rank{rank}.pt
#         save_path = os.path.join(base_dir, f"nan_dump_rank{rank}.pt")

#         torch.save({
#             "q": q,
#             "k": k,
#             "v": v,
#             "beta": beta,
#             "gv": gv,
#             "do": do,
#             "cu_seqlens": cu_seqlens,
#             "error_source": name  # also record which variable triggered the dump
#         }, save_path)

#         # Raise explicitly, naming the offending variable
#         raise RuntimeError(f"NaN detected in [{name}] on Rank {rank}! Context saved to: {save_path}")
```
```python
if 'head_first' in kwargs:
    warnings.warn(
        "head_first is deprecated and will be removed in a future version. "
        "Please use head_first=False for now instead."
    )
```
The warning message for the deprecated head_first argument is confusing. If head_first is deprecated, suggesting head_first=False "for now instead" implies it might still be used or is a temporary workaround. It would be clearer to either remove the argument entirely if it's no longer supported or provide a clearer migration path if its functionality is replaced.
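One clearer pattern is to emit a `DeprecationWarning` with an explicit migration instruction and drop the argument before it reaches the implementation. A sketch with a hypothetical wrapper name:

```python
import warnings

def chunk_oja2_api(**kwargs):
    # Hypothetical wrapper: name the removal, use DeprecationWarning,
    # and discard the dead argument explicitly.
    if 'head_first' in kwargs:
        warnings.warn(
            "`head_first` is deprecated and ignored; inputs are always "
            "interpreted as [B, T, H, ...]. Remove it from the call site.",
            DeprecationWarning,
            stacklevel=2,
        )
        kwargs.pop('head_first')
    return kwargs

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    remaining = chunk_oja2_api(head_first=True, chunk_size=64)
```

Using the `DeprecationWarning` category also lets downstream users filter or escalate these warnings in CI, which a bare `UserWarning` does not.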
```python
gv: Optional[torch.Tensor] = None,
initial_state: Optional[torch.Tensor] = None,
output_final_state: bool = False,
chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
```
```python
dht: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
```
```python
    all = T
    T = eos - bos
else:
    bos, eos = i_b * T, i_b * T + T
    all = B * T
```
The variable all is conditionally assigned T or B * T based on IS_VARLEN. Using a generic name like all for a variable that changes its meaning and is used in pointer arithmetic can be confusing and error-prone. Consider renaming it to something more descriptive, like total_sequence_elements or batch_time_elements, to improve clarity and prevent potential bugs.
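For context on the branch the comment refers to, the varlen bookkeeping can be mirrored in plain PyTorch (values illustrative): bos/eos come from cu_seqlens, T is reassigned to the per-sequence length, and `all` plays the role of the packed total length used in pointer arithmetic:

```python
import torch

# Packed varlen layout: cu_seqlens holds cumulative boundaries of 3 sequences
# of lengths 5, 7, and 8 concatenated along the time axis.
cu_seqlens = torch.tensor([0, 5, 12, 20])

i_n = 1                          # which sequence this program instance handles
bos = int(cu_seqlens[i_n])       # begin-of-sequence offset into the packed axis
eos = int(cu_seqlens[i_n + 1])   # end-of-sequence offset
seq_len = eos - bos              # per-sequence T after the in-kernel reassignment
total = int(cu_seqlens[-1])      # the role `all` plays in varlen mode (B*T otherwise)
```

Seeing the two meanings side by side makes the reviewer's renaming suggestion concrete: `total` (or similar) would distinguish the packed axis length from the per-sequence T.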
```python
b_dA = tl.where(m_A, -b_dA, 0)

# if USE_GV:
```
Actionable comments posted: 14
🤖 Fix all issues with AI agents
In `@fla/ops/oja2/chunk_h.py`:
- Line 1: Re-run the project's pre-commit hooks/formatter on
fla/ops/oja2/chunk_h.py to remove trailing whitespace and apply EOF/autopep8
formatting fixes; specifically ensure the import line "from typing import
Optional, Tuple" and the file ending are formatted per the repo's style (no
trailing spaces, proper newline at EOF) and commit the updated file so CI lint
passes.
- Line 93: The assignment to the unused variable BV should be removed to satisfy
linting; locate the BV = 64 statement in chunk_h.py (symbol BV) and delete that
line (or remove any unused constant/variable declaration named BV) so no unused
BV symbol remains in the module.
- Around line 148-163: The GV gating uses K instead of V, which can skip loads
or use wrong masks; update the gating and masks so comparisons use V (not K):
change the conditional checks if K > 64 / 128 / 192 to if V > 64 / 128 / 192 and
ensure the tl.load mask arguments for o_v2, o_v3, o_v4 use (o_vX < V) (e.g., the
last load currently uses (o_v4 < K) — change it to (o_v4 < V)); update
references around b_h2, b_h3, b_h4 and their corresponding o_v2/o_v3/o_v4 loads
in chunk_h.py.
- Around line 470-485: The code currently hardcodes BT = 64 but calls
prepare_chunk_indices(cu_seqlens, chunk_size), causing mismatch if chunk_size !=
64; fix by enforcing a single source of truth: either require chunk_size == 64
or thread chunk_size through BT. Implementation: add a runtime assertion and/or
normalize BT from the provided chunk_size (e.g., assert chunk_size == 64 or set
BT = chunk_size) and then use that BT consistently for prepare_chunk_indices and
prepare_chunk_offsets; reference symbols: BT, chunk_size, prepare_chunk_indices,
prepare_chunk_offsets.
- Around line 197-207: The function chunk_oja2_fwd_h currently annotates its
return as a 2‑tuple but actually returns (h, k_new, final_state); update the
return type on chunk_oja2_fwd_h to reflect the third, optional tensor (e.g.
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] or Tuple[torch.Tensor,
torch.Tensor, torch.Tensor] if you prefer always returning final_state) so the
signature matches the returned values (h, k_new, final_state) and adjust any
related type hints/call sites if necessary.
- Around line 795-810: The code unconditionally calls torch.empty_like(gv) and
returns dgv_last, but gv is Optional and may be None; add an explicit validation
at the start of the function that raises a clear TypeError (or ValueError) if gv
is None (e.g., "gv must be provided"), or alternatively change the signature to
make gv required; update callers if you choose the latter. Ensure the check
happens before creating dgv_last and before invoking
chunk_oja2_bwd_kernel_dvwg_h so torch.empty_like(gv) is never called with None;
reference gv, dgv_last, and chunk_oja2_bwd_kernel_dvwg_h when making the change.
- Around line 391-445: The gv pointer creation and loads (p_gv and b_gv) must be
moved inside the USE_GV guard because gv can be None; in each of the four blocks
(the V>0, V>64, V>128, V>192 branches) remove or skip tl.make_block_ptr(gv, ...)
and tl.load(...) when outside the if USE_GV, and instead create p_gv and load
b_gv only inside the if USE_GV: use the existing symbols p_gv, b_gv, gv, and
USE_GV and keep the subsequent operations that reference b_gv (e.g., b_gv_last*,
b_dh* *= exp(...), b_do *= exp(b_gv)) inside that guard so no gv access happens
when USE_GV is false.
- Around line 710-769: The code uses b_gn and b_dv unconditionally though they
are only set inside the USE_GV branch, causing crashes when USE_GV is false; fix
by providing safe fallbacks or guarding uses: ensure b_gn and b_dv are
initialized before the loop (e.g., zeros or proper shapes) or move/guard the
computations that reference b_gn and b_dv (the tl.sum using exp(b_gn) in the
loop that updates b_dgv_last, the b_dv-dependent b_dgv_last accumulation and
tl.store(p_dv,...)/tl.store(p_dgv_last,...)) under the same USE_GV condition,
updating references to b_dv, b_gn, and b_dgv_last consistently (symbols: USE_GV,
b_gn, b_dv, b_dgv_last, p_dv, p_dgv_last).
- Around line 470-505: The function allows scale: Optional[float]=None but
passes scale into the kernel chunk_oja2_bwd_kernel_dhu_blockdim64 and uses it in
arithmetic; add a validation at the start of this function to ensure scale is
not None (e.g., assert scale is not None or raise ValueError with context) so
the kernel never receives None, or alternatively assign a safe default (e.g.,
scale = 1.0) before calling prepare_chunk_indices/prepare_chunk_offsets and
before launching chunk_oja2_bwd_kernel_dhu_blockdim64; reference the parameter
name scale and the kernel chunk_oja2_bwd_kernel_dhu_blockdim64 to locate where
to add the check.
In `@fla/ops/oja2/chunk_kkt.py`:
- Around line 384-417: The beta parameter is currently Optional but required by
the native kernels; update the APIs to make beta mandatory instead of Optional
(remove the default None and Optional[...] in the signatures and docs) for
chunk_scaled_dot_kkt_fwd and the corresponding backward wrapper (the one at
lines ~474-495), or alternatively add an early check that raises a clear
ValueError if beta is None before calling the kernels; reference
chunk_scaled_dot_kkt_fwd (and its backward counterpart) when applying the change
so callers and docs are updated consistently.
In `@fla/ops/oja2/chunk_o.py`:
- Around line 7-15: Remove the duplicate/conflicting exp import and the unused
chunk_local_cumsum import in chunk_o.py: delete the line importing exp from
fla.ops.utils.op (or stop reassigning exp = tl.exp) so that only the intended
exp symbol from tl.exp remains, and remove the unused chunk_local_cumsum import;
keep prepare_chunk_indices and the shared-memory checks (BKV_LIST/NUM_WARPS)
intact. Ensure there are no other references to fla.ops.utils.op.exp or
chunk_local_cumsum elsewhere in this module before removing.
In `@fla/ops/oja2/wy_fast.py`:
- Around line 11-13: Remove the unused import symbol chunk_local_cumsum from the
import statement in this module: update the line "from fla.ops.utils import
chunk_local_cumsum, prepare_chunk_indices" so it only imports
prepare_chunk_indices; keep the rest of imports (exp from fla.ops.utils.op and
check_shared_mem) unchanged to avoid affecting other references.
- Around line 240-289: prepare_wy_repr_bwd must guard against gv being None and
derive the chunk size BT from A instead of hardcoding 64; replace the hardcoded
BT=64 with BT = A.shape[-1] (or the appropriate last-dimension of A used for
tiling) and before allocating dgv/db/etc. ensure gv is non-None by doing
something like if gv is None: gv = torch.zeros_like(v) (so torch.empty_like(gv,
dtype=torch.float) is safe). Update uses of BT (NT computation and kernel args)
to use the new BT variable and keep the rest of allocations (dgv, dA, db)
unchanged.
- Around line 199-237: The function recompute_w_u_fwd currently declares gv as
Optional[torch.Tensor] = None but the triton kernel dereferences gv
unconditionally, and the function only annotates returning two Tensors while it
actually returns w, u, vg; update the signature to make gv a required
torch.Tensor (remove Optional and default None) and change the return annotation
to Tuple[torch.Tensor, torch.Tensor, torch.Tensor]; ensure vg is always
allocated (vg = torch.empty_like(v)) and passed/returned consistently to match
the kernel's expectation and the returned triple (symbols: recompute_w_u_fwd,
gv, vg, w, u, recompute_w_u_fwd_kernel).
```diff
@@ -0,0 +1,821 @@
+from typing import Optional, Tuple
```
Re-run pre-commit to apply formatting fixes.
CI reports trailing whitespace/EOF/autopep8 changes in this file; please re-run the formatter and commit the result so lint passes.
🧰 Tools
🪛 GitHub Actions: lint
[error] 1-1: Trailing whitespace detected and fixed by pre-commit hook 'trailing-whitespace'.
[error] 1-1: End-of-file fixer modified the file to ensure proper EOF; re-run pre-commit.
[error] 1-1: autopep8 formatting applied. Please re-run pre-commit to apply changes.
🤖 Prompt for AI Agents
In `@fla/ops/oja2/chunk_h.py` at line 1, Re-run the project's pre-commit
hooks/formatter on fla/ops/oja2/chunk_h.py to remove trailing whitespace and
apply EOF/autopep8 formatting fixes; specifically ensure the import line "from
typing import Optional, Tuple" and the file ending are formatted per the repo's
style (no trailing spaces, proper newline at EOF) and commit the updated file so
CI lint passes.
```python
h0 = h0 + i_nh * K*V
if STORE_FINAL_STATE:
    ht = ht + i_nh * K*V
BV=64
```
Remove unused BV to clear lint errors.
Line 93 assigns BV but it isn't used and is already failing Ruff/Flake8.
🧹 Proposed fix

```diff
-    BV=64
```

📝 Committable suggestion — ‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
🧰 Tools
🪛 Flake8 (7.3.0)
[error] 93-93: local variable 'BV' is assigned to but never used
(F841)
🪛 Ruff (0.14.13)
93-93: Local variable BV is assigned to but never used
Remove assignment to unused variable BV
(F841)
🤖 Prompt for AI Agents
In `@fla/ops/oja2/chunk_h.py` at line 93, The assignment to the unused variable BV
should be removed to satisfy linting; locate the BV = 64 statement in chunk_h.py
(symbol BV) and delete that line (or remove any unused constant/variable
declaration named BV) so no unused BV symbol remains in the module.
```python
if USE_GV:
    o_v1 = tl.arange(0, 64)
    b_gk_last1 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v1, mask=(o_v1 < V), other=0.)
    b_h1 *= exp(b_gk_last1)[None, :]
    if K > 64:
        o_v2 = 64 + o_v1
        b_gk_last2 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v2, mask=(o_v2 < V), other=0.)
        b_h2 *= exp(b_gk_last2)[None, :]
    if K > 128:
        o_v3 = 128 + o_v1
        b_gk_last3 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v3, mask=(o_v3 < V), other=0.)
        b_h3 *= exp(b_gk_last3)[None, :]
    if K > 192:
        o_v4 = 192 + o_v1
        b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < K), other=0.)
        b_h4 *= exp(b_gk_last4)[None, :]
```
Fix GV gating thresholds to use V (not K).
Lines 152–162 gate b_h2/b_h3/b_h4 by K, which can skip gating or reference undefined buffers when K != V; the mask at Line 162 should also be V.
🧭 Proposed fix

```diff
-    if K > 64:
+    if V > 64:
         o_v2 = 64 + o_v1
         b_gk_last2 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v2, mask=(o_v2 < V), other=0.)
         b_h2 *= exp(b_gk_last2)[None, :]
-    if K > 128:
+    if V > 128:
         o_v3 = 128 + o_v1
         b_gk_last3 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v3, mask=(o_v3 < V), other=0.)
         b_h3 *= exp(b_gk_last3)[None, :]
-    if K > 192:
+    if V > 192:
         o_v4 = 192 + o_v1
-        b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < K), other=0.)
+        b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < V), other=0.)
         b_h4 *= exp(b_gk_last4)[None, :]
```

🤖 Prompt for AI Agents
b_h4 *= exp(b_gk_last4)[None, :]🤖 Prompt for AI Agents
In `@fla/ops/oja2/chunk_h.py` around lines 148 - 163, The GV gating uses K instead
of V, which can skip loads or use wrong masks; update the gating and masks so
comparisons use V (not K): change the conditional checks if K > 64 / 128 / 192
to if V > 64 / 128 / 192 and ensure the tl.load mask arguments for o_v2, o_v3,
o_v4 use (o_vX < V) (e.g., the last load currently uses (o_v4 < K) — change it
to (o_v4 < V)); update references around b_h2, b_h3, b_h4 and their
corresponding o_v2/o_v3/o_v4 loads in chunk_h.py.
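The effect of gating on the wrong dimension is easy to demonstrate: with K=64 and V=256 (distinct head dims, illustrative values), a mask computed against K zeroes out an in-bounds V tile entirely:

```python
import torch

K_dim, V_dim = 64, 256          # distinct head dims, where the bug bites
o_v1 = torch.arange(64)
o_v4 = 192 + o_v1               # fourth 64-wide tile of the V axis

wrong_mask = o_v4 < K_dim       # gating on K: all False, tile silently loads zeros
right_mask = o_v4 < V_dim       # gating on V: tile fully in bounds
```

Because a masked `tl.load` returns `other=0.` rather than failing, the wrong bound produces silently incorrect state instead of a crash, which is why the review flags it as a correctness issue.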
```python
def chunk_oja2_fwd_h(
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    gv: Optional[torch.Tensor] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
    save_new_key: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
```
Return annotation doesn't match the 3‑tuple.
Line 207 declares a 2‑tuple but the function returns (h, k_new, final_state). Update typing to include the optional tensors.
🧩 Proposed fix

```diff
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
```

🤖 Prompt for AI Agents
In `@fla/ops/oja2/chunk_h.py` around lines 197 - 207, The function
chunk_oja2_fwd_h currently annotates its return as a 2‑tuple but actually
returns (h, k_new, final_state); update the return type on chunk_oja2_fwd_h to
reflect the third, optional tensor (e.g. Tuple[torch.Tensor, torch.Tensor,
Optional[torch.Tensor]] or Tuple[torch.Tensor, torch.Tensor, torch.Tensor] if
you prefer always returning final_state) so the signature matches the returned
values (h, k_new, final_state) and adjust any related type hints/call sites if
necessary.
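A stub with the corrected annotation (hypothetical, mirroring the suggested 3-tuple) shows the shape the signature should advertise:

```python
import torch
from typing import Optional, Tuple, get_type_hints

def chunk_oja2_fwd_h_stub(
    v: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    # Hypothetical stub with the corrected annotation: (h, k_new, final_state),
    # where final_state is None unless output_final_state is requested.
    h, k_new, final_state = v, v.clone(), None
    return h, k_new, final_state

out = chunk_oja2_fwd_h_stub(torch.zeros(2))
hints = get_type_hints(chunk_oja2_fwd_h_stub)
```

Keeping the annotation honest matters beyond style: static checkers and callers that destructure the result rely on the declared arity.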
```python
if V > 0:
    p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0))  # [BT, BV]
    b_do = tl.load(p_do, boundary_check=(0, 1))
    p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0))  # [BT, BV]
    b_w = tl.load(p_w, boundary_check=(0, 1))
    p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0))  # [BT, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    if USE_GV:
        o_v1 = tl.arange(0, 64)
        b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
        b_dh1 *= exp(b_gv_last1[None, :])
        b_do *= exp(b_gv)
    b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)  # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]

if V > 64:
    p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))  # [BT, BV]
    b_w = tl.load(p_w, boundary_check=(0, 1))
    p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))  # [BT, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    if USE_GV:
        o_v2 = 64 + o_v1
        b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.)
        b_dh2 *= exp(b_gv_last2[None, :])
        b_do *= exp(b_gv)
    b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)

if V > 128:
    p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))  # [BT, BV]
    b_w = tl.load(p_w, boundary_check=(0, 1))
    p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))  # [BT, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    if USE_GV:
        o_v3 = 128 + o_v1
        b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.)
        b_dh3 *= exp(b_gv_last3[None, :])
        b_do *= exp(b_gv)
    b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)

if V > 192:
    p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))  # [BT, BV]
    b_w = tl.load(p_w, boundary_check=(0, 1))
    p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))  # [BT, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    if USE_GV:
        o_v4 = 192 + o_v1
        b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.)
        b_dh4 *= exp(b_gv_last4[None, :])
        b_do *= exp(b_gv)
    b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
```
🧩 Analysis chain

🏁 Script executed:

```shell
# First, let me check the file structure
fd chunk_h.py
```

Repository: fla-org/flash-linear-attention
Length of output: 123

🏁 Script executed:

```shell
# Get the actual code at the specified lines
head -445 fla/ops/oja2/chunk_h.py | tail -55
```

Repository: fla-org/flash-linear-attention
Length of output: 3544

🏁 Script executed:

```shell
# Get more context - search for USE_GV and gv parameter definition
rg -n "USE_GV|def.*chunk_h|gv.*=" fla/ops/oja2/chunk_h.py -A 2 -B 2 | head -100
```

Repository: fla-org/flash-linear-attention
Length of output: 4394
**Move gv loads and pointer creation inside the USE_GV guard.**

Lines 396–397 (and the V>64/128/192 blocks at 410–411, 424–425, 436–437) create and load gv outside the USE_GV branch. When gv=None and USE_GV=False, tl.make_block_ptr(gv, ...) will fail. Since b_gv is only used inside the if USE_GV block, move both the pointer creation and the load inside the guard.

Apply to all four V blocks:

```diff
-    p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0))  # [BT, BV]
-    b_gv = tl.load(p_gv, boundary_check=(0, 1))
     if USE_GV:
+        p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0))  # [BT, BV]
+        b_gv = tl.load(p_gv, boundary_check=(0, 1))
         o_v1 = tl.arange(0, 64)
         b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
         b_dh1 *= exp(b_gv_last1[None, :])
         b_do *= exp(b_gv)
```

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if V > 0:
    p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0))  # [BT, BV]
    b_do = tl.load(p_do, boundary_check=(0, 1))
    p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0))  # [BT, BV]
    b_w = tl.load(p_w, boundary_check=(0, 1))
    if USE_GV:
        p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0))  # [BT, BV]
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        o_v1 = tl.arange(0, 64)
        b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
        b_dh1 *= exp(b_gv_last1[None, :])
        b_do *= exp(b_gv)
    b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)  # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]
if V > 64:
    p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))  # [BT, BV]
    b_w = tl.load(p_w, boundary_check=(0, 1))
    if USE_GV:
        p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))  # [BT, BV]
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        o_v2 = 64 + o_v1
        b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.)
        b_dh2 *= exp(b_gv_last2[None, :])
        b_do *= exp(b_gv)
    b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
if V > 128:
    p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))  # [BT, BV]
    b_w = tl.load(p_w, boundary_check=(0, 1))
    if USE_GV:
        p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))  # [BT, BV]
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        o_v3 = 128 + o_v1
        b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.)
        b_dh3 *= exp(b_gv_last3[None, :])
        b_do *= exp(b_gv)
    b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
if V > 192:
    p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))  # [BT, BV]
    b_w = tl.load(p_w, boundary_check=(0, 1))
    if USE_GV:
        p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))  # [BT, BV]
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        o_v4 = 192 + o_v1
        b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.)
        b_dh4 *= exp(b_gv_last4[None, :])
        b_do *= exp(b_gv)
    b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
```
🤖 Prompt for AI Agents
In `@fla/ops/oja2/chunk_h.py` around lines 391–445: the gv pointer creation and
loads (p_gv and b_gv) must be moved inside the USE_GV guard because gv can be
None; in each of the four blocks (the V>0, V>64, V>128, V>192 branches) remove
or skip tl.make_block_ptr(gv, ...) and tl.load(...) when outside the if USE_GV,
and instead create p_gv and load b_gv only inside the if USE_GV: use the
existing symbols p_gv, b_gv, gv, and USE_GV and keep the subsequent operations
that reference b_gv (e.g., b_gv_last*, b_dh* *= exp(...), b_do *= exp(b_gv))
inside that guard so no gv access happens when USE_GV is false.
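The guard pattern this comment asks for can be illustrated with a small pure-Python sketch (plain Python, not Triton; `accumulate_block`, `use_gv`, and the list-based blocks are illustrative stand-ins, not the kernel's API): every access to the optional gv buffer happens inside the flag check, so passing gv=None with the gate off can never dereference it.

```python
import math

def accumulate_block(do_block, gv_block=None, use_gv=False):
    """Scale do_block by exp(gv) only when the gate is enabled.

    gv_block is touched exclusively inside the use_gv branch, mirroring
    the suggested kernel fix of moving p_gv/b_gv under `if USE_GV:`.
    """
    if use_gv:
        if gv_block is None:
            raise ValueError("gv must be provided when use_gv=True")
        # gv accessed only under the guard
        return [d * math.exp(g) for d, g in zip(do_block, gv_block)]
    # gv is never read on this path, so gv_block=None is safe
    return list(do_block)
```

With the gate off, `accumulate_block([1.0, 2.0])` returns the block unchanged; with the gate on and gv all zeros, `exp(0) = 1` leaves it unchanged as well.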
```python
def chunk_scaled_dot_kkt_fwd(
    k: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    gk: Optional[torch.Tensor] = None,
    beta: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    r"""
    Compute beta * K * K^T.

    Args:
        k (torch.Tensor):
            The key tensor of shape `[B, T, H, K]`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]`.
        g (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
        gk (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths of the input tensor. Default: `None`.
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float32`.

    Returns:
        beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
    """
    B, T, H, K = k.shape
    BT = chunk_size
```
**beta is optional in the signature but required by the kernels.**

With the current defaults, calling these wrappers without beta will fail at runtime. Either make beta required or guard against None early.

🐛 Proposed fix

```diff
 def chunk_scaled_dot_kkt_fwd(
     k: torch.Tensor,
     g: Optional[torch.Tensor] = None,
     gk: Optional[torch.Tensor] = None,
     beta: Optional[torch.Tensor] = None,
@@
 ) -> torch.Tensor:
+    if beta is None:
+        raise ValueError("beta must be provided for chunk_scaled_dot_kkt_fwd")
@@
 def chunk_scaled_dot_kkt_bwd_gk(
     k: torch.Tensor,
     g: torch.Tensor,
     beta: torch.Tensor,
     dA: torch.Tensor,
@@
 ):
+    if beta is None:
+        raise ValueError("beta must be provided for chunk_scaled_dot_kkt_bwd_gk")
```

Also applies to: 474-495
🤖 Prompt for AI Agents
In `@fla/ops/oja2/chunk_kkt.py` around lines 384–417: the beta parameter is
currently Optional but required by the native kernels; update the APIs to make
beta mandatory instead of Optional (remove the default None and Optional[...] in
the signatures and docs) for chunk_scaled_dot_kkt_fwd and the corresponding
backward wrapper (the one at lines ~474-495), or alternatively add an early
check that raises a clear ValueError if beta is None before calling the kernels;
reference chunk_scaled_dot_kkt_fwd (and its backward counterpart) when applying
the change so callers and docs are updated consistently.
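The fail-fast alternative can be sketched in plain Python (the `_checked` wrapper and the elementwise list stand-in for beta * K * K^T are hypothetical, not the library's API): validate beta before any kernel launch so the caller gets a clear error instead of a cryptic Triton failure.

```python
def chunk_scaled_dot_kkt_fwd_checked(k, beta=None):
    """Sketch: reject a missing beta up front, before any kernel work."""
    if beta is None:
        raise ValueError("beta must be provided for chunk_scaled_dot_kkt_fwd")
    # stand-in computation for beta * K * K^T on 1-D toy inputs
    return [b * x * x for b, x in zip(beta, k)]
```

Calling it without beta raises immediately; `chunk_scaled_dot_kkt_fwd_checked([2.0], beta=[3.0])` returns `[12.0]`.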
```python
from fla.ops.utils import prepare_chunk_indices
from fla.ops.utils.op import exp
from fla.utils import check_shared_mem, is_nvidia_hopper
from fla.ops.utils.cumsum import chunk_local_cumsum

BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]

exp = tl.exp
```
**Fix the duplicate exp import and remove the unused import.**

exp is imported and then redefined, and chunk_local_cumsum is unused. This is already failing lint (F811/F401).

🧹 Proposed fix

```diff
-from fla.ops.utils.op import exp
-from fla.ops.utils.cumsum import chunk_local_cumsum
```

🧰 Tools

🪛 Flake8 (7.3.0)
[error] 10-10: 'fla.ops.utils.cumsum.chunk_local_cumsum' imported but unused (F401)
[error] 15-15: redefinition of unused 'exp' from line 8 (F811)

🪛 GitHub Actions: lint
[error] 13-13: Ruff: 'exp' redefined; previous definition exists (F811). Remove duplicate definition.

🪛 Ruff (0.14.13)
15-15: Redefinition of unused `exp` from line 8: `exp` redefined here (F811)
🤖 Prompt for AI Agents
In `@fla/ops/oja2/chunk_o.py` around lines 7–15: remove the
duplicate/conflicting exp import and the unused chunk_local_cumsum import in
chunk_o.py: delete the line importing exp from fla.ops.utils.op (or stop
reassigning exp = tl.exp) so that only the intended exp symbol from tl.exp
remains, and remove the unused chunk_local_cumsum import; keep
prepare_chunk_indices and the shared-memory checks (BKV_LIST/NUM_WARPS) intact.
Ensure there are no other references to fla.ops.utils.op.exp or
chunk_local_cumsum elsewhere in this module before removing.
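The F811 at issue is ordinary Python name shadowing; a minimal stand-alone sketch (using `math.exp` and a doubling lambda as stand-ins for the two bindings in chunk_o.py) shows why only one definition should survive:

```python
import math

# First binding stands in for `from fla.ops.utils.op import exp`.
exp = math.exp
# Second binding stands in for `exp = tl.exp`; it silently shadows the
# import above, turning it into dead code -- exactly what F811 flags.
exp = lambda x: 2.0 * x

# Every subsequent caller resolves to the second binding only.
```

After the rebinding, `exp(3.0)` yields `6.0`, not `math.exp(3.0)`, so the first import contributes nothing and should be deleted.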
```python
from fla.ops.utils import chunk_local_cumsum, prepare_chunk_indices
from fla.ops.utils.op import exp
from fla.utils import check_shared_mem
```
**Remove the unused import.**

chunk_local_cumsum isn't referenced in this module and is already flagged by lint.

🧹 Proposed fix

```diff
-from fla.ops.utils import chunk_local_cumsum, prepare_chunk_indices
+from fla.ops.utils import prepare_chunk_indices
```

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
from fla.ops.utils import prepare_chunk_indices
from fla.ops.utils.op import exp
from fla.utils import check_shared_mem
```

🧰 Tools

🪛 Flake8 (7.3.0)
[error] 11-11: 'fla.ops.utils.chunk_local_cumsum' imported but unused (F401)
🤖 Prompt for AI Agents
In `@fla/ops/oja2/wy_fast.py` around lines 11–13: remove the unused import
symbol chunk_local_cumsum from the import statement in this module: update the
line "from fla.ops.utils import chunk_local_cumsum, prepare_chunk_indices" so it
only imports prepare_chunk_indices; keep the rest of imports (exp from
fla.ops.utils.op and check_shared_mem) unchanged to avoid affecting other
references.
```python
def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    gv: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = A.shape[-1]
    BK = 64
    BV = 64

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    w = torch.empty_like(v)
    u = torch.empty_like(k)
    vg = torch.empty_like(v) if gv is not None else None
    recompute_w_u_fwd_kernel[(NT, B*H)](
        k=k,
        v=v,
        vg=vg,
        beta=beta,
        w=w,
        u=u,
        A=A,
        gv=gv,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return w, u, vg
```
**Make gv required and fix the return annotation.**

gv is dereferenced unconditionally inside the kernel, so the default None is a crash path. Also, the function returns three values but is annotated as two.

🐛 Proposed fix

```diff
 def recompute_w_u_fwd(
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     A: torch.Tensor,
-    gv: Optional[torch.Tensor] = None,
+    gv: torch.Tensor,
     cu_seqlens: Optional[torch.LongTensor] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    if gv is None:
+        raise ValueError("gv must be provided for recompute_w_u_fwd")
@@
-    vg = torch.empty_like(v) if gv is not None else None
+    vg = torch.empty_like(v)
@@
     return w, u, vg
```
+ return w, u, vg🤖 Prompt for AI Agents
In `@fla/ops/oja2/wy_fast.py` around lines 199–237: the function
recompute_w_u_fwd currently declares gv as Optional[torch.Tensor] = None but the
triton kernel dereferences gv unconditionally, and the function only annotates
returning two Tensors while it actually returns w, u, vg; update the signature
to make gv a required torch.Tensor (remove Optional and default None) and change
the return annotation to Tuple[torch.Tensor, torch.Tensor, torch.Tensor]; ensure
vg is always allocated (vg = torch.empty_like(v)) and passed/returned
consistently to match the kernel's expectation and the returned triple (symbols:
recompute_w_u_fwd, gv, vg, w, u, recompute_w_u_fwd_kernel).
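The two fixes together (required gv, three-element return annotation) can be sketched without torch or Triton; `FakeTensor` and `recompute_w_u_fwd_sketch` below are hypothetical stand-ins carrying only shapes, not the real implementation:

```python
from typing import Tuple

class FakeTensor:
    """Tiny stand-in for torch.Tensor that carries only a shape."""
    def __init__(self, shape):
        self.shape = shape

def recompute_w_u_fwd_sketch(
    k: FakeTensor, v: FakeTensor, gv: FakeTensor
) -> Tuple[FakeTensor, FakeTensor, FakeTensor]:
    # gv is required: fail fast instead of crashing inside the kernel.
    if gv is None:
        raise ValueError("gv must be provided for recompute_w_u_fwd")
    w = FakeTensor(v.shape)   # allocated like v
    u = FakeTensor(k.shape)   # allocated like k
    vg = FakeTensor(v.shape)  # always allocated once gv is mandatory
    return w, u, vg           # three values, matching the annotation
```

The returned triple now agrees with the annotation, so static checkers and callers that unpack `w, u, vg` stay consistent.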
```python
def prepare_wy_repr_bwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    gv: torch.Tensor = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = 64
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    CONST_TILING = 64 if check_shared_mem() else 32
    BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
    BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)

    dk = torch.empty_like(k)
    dv = torch.empty_like(v, dtype=torch.float)

    dgv = torch.empty_like(gv, dtype=torch.float)
    dA = torch.empty_like(A, dtype=torch.float)
    db = torch.empty_like(beta, dtype=torch.float)

    prepare_wy_repr_bwd_kernel[(NT, B * H)](
        k=k,
        v=v,
        beta=beta,
        gv=gv,
        A=A,
        dA=dA,
        dw=dw,
        du=du,
        dk=dk,
        dv=dv,
        db=db,
        dgv=dgv,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )

    return dk, dv, db, dgv, dA
```
**Guard against gv=None and derive BT from A.**

gv defaults to None but is used immediately (torch.empty_like(gv) and kernel loads), which will fail at runtime. Also, using A.shape[-1] keeps the chunk size correct if it ever changes.

🐛 Proposed fix

```diff
 def prepare_wy_repr_bwd(
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     A: torch.Tensor,
     dw: torch.Tensor,
     du: torch.Tensor,
-    gv: torch.Tensor = None,
+    gv: torch.Tensor,
     cu_seqlens: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    if gv is None:
+        raise ValueError("gv must be provided for prepare_wy_repr_bwd")
@@
-    BT = 64
+    BT = A.shape[-1]
```
+ BT = A.shape[-1]🤖 Prompt for AI Agents
In `@fla/ops/oja2/wy_fast.py` around lines 240–289: prepare_wy_repr_bwd must
guard against gv being None and derive the chunk size BT from A instead of
hardcoding 64; replace the hardcoded BT=64 with BT = A.shape[-1] (or the
appropriate last-dimension of A used for tiling) and before allocating
dgv/db/etc. ensure gv is non-None by doing something like if gv is None: gv =
torch.zeros_like(v) (so torch.empty_like(gv, dtype=torch.float) is safe). Update
uses of BT (NT computation and kernel args) to use the new BT variable and keep
the rest of allocations (dgv, dA, db) unchanged.
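Deriving the chunk size from A rather than hardcoding it can be shown with a small shape-only sketch (`FakeTensor`, `derive_bt`, and `chunk_count` are hypothetical helpers, and `math.ceil(T / BT)` mirrors what `triton.cdiv(T, BT)` computes):

```python
import math

class FakeTensor:
    """Shape-only stand-in for torch.Tensor."""
    def __init__(self, shape):
        self.shape = shape

def derive_bt(A: FakeTensor) -> int:
    # A has shape [B, T, H, BT]; its trailing dimension is the forward
    # pass's chunk size, so the backward stays in sync automatically.
    return A.shape[-1]

def chunk_count(T: int, A: FakeTensor) -> int:
    BT = derive_bt(A)
    return math.ceil(T / BT)  # same value as triton.cdiv(T, BT)
```

If the forward ever switches from 64-wide to 32-wide chunks, `derive_bt` picks that up with no edit to the backward, whereas a hardcoded `BT = 64` would silently mis-tile.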
Summary by CodeRabbit
Release Notes
✏️ Tip: You can customize this high-level summary in your review settings.