[Ops] Fix int32 overflow in pointer arithmetic across all Triton kernels#818

Open
tmct wants to merge 9 commits into fla-org:main from tmct:fix/int32-overflow-triton-kernels

Conversation


@tmct tmct commented Apr 8, 2026

Summary

  • Fix int32 overflow in Triton kernel pointer arithmetic across all FLA
    operator families (~83 files). When tensor element counts exceed
    INT32_MAX (2^31 - 1, about 2.15 billion), tl.program_id() (which returns
    int32) multiplied by large strides overflows, causing illegal CUDA
    memory accesses or silently wrong results.
  • Add AGENTS.md with int64 casting guidelines so that agent code reviewers
    (Gemini, etc.) can flag this class of bug in future PRs, keeping the
    codebase scale-friendly.

This is the same class of bug fixed for the conv kernels in #783 and #803,
now applied comprehensively across all remaining operator families.

Root Cause

tl.program_id() returns int32. In expressions like bos = i_b * T
followed by ptr + (bos * H + i_h) * K, the intermediate products
overflow when B * T * H * K > 2^31.

Example: B=4096, T=576, H=8, K=128. For the last program instance (i_b = 4095, i_h = 7), bos = 4095 * 576 = 2,358,720, so (bos * H + i_h) * K = 18,869,767 * 128 = 2,415,330,176, which overflows int32.
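The wraparound is easy to reproduce outside Triton with a few lines of plain Python that emulate signed 32-bit arithmetic (a sketch; the as_int32 helper is illustrative, Triton itself wraps silently inside the kernel):

```python
# Emulate C-style wrap-around of a value stored in a signed 32-bit integer.
def as_int32(x: int) -> int:
    return (x + 2**31) % 2**32 - 2**31

B, T, H, K = 4096, 576, 8, 128
i_b, i_h = B - 1, H - 1        # worst-case program instance: last batch, last head
bos = i_b * T                  # 2,358,720: still fits in int32
offset = (bos * H + i_h) * K   # 2,415,330,176: exceeds INT32_MAX

print(offset)                  # 2415330176
print(as_int32(offset))        # -1879637120: a negative, bogus pointer offset
```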

Fix

Cast batch/sequence indices to int64 before any stride multiplication:

# Before: int32 overflow
bos, eos = i_b * T, i_b * T + T
# After: int64 arithmetic
bos, eos = tl.cast(i_b, tl.int64) * T, tl.cast(i_b, tl.int64) * T + T

Same pattern applied to varlen cu_seqlens loads (.to(tl.int64) instead
of .to(tl.int32)), direct pointer arithmetic, stride offsets, and
compound indices. After casting bos/eos to int64, T = eos - bos should
be cast back to int32 via T = (eos - bos).to(tl.int32) because
tl.make_block_ptr requires 32-bit shape/offset arguments.
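The order of operations is the crux: widening after the multiply is too late, because the int32 product has already wrapped. A plain-Python sketch of the two orderings (the as_int32 helper emulates Triton's int32 intermediates and is illustrative only):

```python
def as_int32(x: int) -> int:
    # emulate wrap-around of an int32 intermediate product
    return (x + 2**31) % 2**32 - 2**31

B, T, H, K = 4096, 576, 8, 128
i_b = B - 1

# Too late: multiplying in int32 and then widening, as in
# (i_b * T).to(tl.int64), promotes an already-wrapped value.
cast_after = as_int32(i_b * T * H * K)

# Correct: widen i_b first, as in tl.cast(i_b, tl.int64) * T, so every
# intermediate product is computed in 64 bits.
cast_before = i_b * T * H * K

print(cast_after)   # -1879638016: wrapped
print(cast_before)  # 2415329280: exact
```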

Testing

All 10 kernel families were tested with B=4096, T=576, H=8, K=128, V=128
(B*T*H*K ≈ 2.4B > INT32_MAX) with CUDA_LAUNCH_BLOCKING=1 on an 80GB
A100. Every family crashes without the fix and passes with it. The Triton
autotuning cache (~/.triton/cache/) was cleared between runs.

Families tested: gated_delta_rule (chunk + fused_recurrent), delta_rule,
gla, hgrn, rwkv6, comba (chunk + fused_recurrent), simple_gla (chunk +
parallel).
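As a sanity check on the stress shape, plain arithmetic on the dimensions quoted above confirms the element count clears the 32-bit boundary:

```python
B, T, H, K = 4096, 576, 8, 128
INT32_MAX = 2**31 - 1

numel = B * T * H * K       # elements addressed across a (B, T, H, K) layout
print(numel)                # 2415919104, i.e. ~2.4B
assert numel > INT32_MAX    # so int32 element offsets into q/k buffers can wrap
```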

Summary by CodeRabbit

  • Bug Fixes

    • Improved robustness across many kernels by using wider integer arithmetic and explicit casts for sequence/offset computations to reduce overflow/truncation risk for large sequences/batches.
  • Tests

    • Added a gated GPU test placeholder that runs only on machines with large GPU memory to enable future overflow validations.
  • Documentation

    • Added an AI review guide with rules and a checklist for spotting integer-width and casting issues in kernel index arithmetic.

Cast batch/sequence indices to int64 before stride multiplication in ~83
Triton kernel files. tl.program_id() returns int32; when multiplied by
large strides the intermediate products overflow INT32_MAX, causing
illegal CUDA memory accesses or silently wrong results.

Also adds AGENTS.md with review guidelines to help agent code reviewers
keep future Triton kernels scale-friendly.

coderabbitai bot commented Apr 8, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Widespread Triton integer-width hardening: loads from cu_seqlens/offsets promoted to tl.int64, batch/sequence multipliers cast to tl.int64 before multiplication, and per-sequence lengths cast back to tl.int32 for loop/block math; new AGENTS.md documents review rules for these patterns.

Changes

Cohort / File(s): Summary

  • Documentation (AGENTS.md): Add reviewer guidelines for promoting tl.program_id()/tl.load(...cu_seqlens...) to tl.int64 before multiplications and for casting T back to tl.int32; lists concrete review flags.
  • Attention / parallel kernels (fla/ops/attn/decoding.py, fla/ops/attn/parallel.py, fla/ops/nsa/parallel.py): Promote bos/eos/program-id-derived operands to tl.int64; compute T from int64 then cast to tl.int32; explicit tl.cast(i_n/i_b, tl.int64) in fixed-length paths.
  • Chunked kernels, common families (fla/ops/*/chunk*.py, fla/ops/*/chunk_*/*.py, fla/ops/*/chunk_*.py): Consistent pattern: cu_seqlens/chunk_offsets/split_offsets loads to tl.int64, T = (eos - bos).to(tl.int32), and batch/seq indices cast to tl.int64 before multiplication.
  • Fused / recurrent / WY-fast (fla/ops/*/fused_recurrent.py, fla/ops/*/wy_fast.py, fla/ops/*/fused_chunk.py): Widen offset loads to int64, downcast per-sequence T to tl.int32, and use tl.cast(..., tl.int64) * T in non-varlen branches.
  • Gated / delta / generalized-delta families (fla/ops/gated_*/*, fla/ops/delta_rule/*, fla/ops/generalized_delta_rule/*): Systematic promotion of offset/index arithmetic to int64 and explicit cast-back of T to int32; boh/chunk_offsets widened where applicable.
  • Algorithm families, KDA/NSA/Path-attn/Log-linear/RWKV/TTT/Mesa/HGRN/GLA/GSA (fla/ops/kda/*, fla/ops/nsa/*, fla/ops/path_attn/*, fla/ops/log_linear_attn/*, fla/ops/rwkv*/*, fla/ops/ttt/*, fla/ops/mesa_net/*, fla/ops/hgrn/*, fla/ops/gla/*, fla/ops/gsa/*): Broad application of the same cast/widening fixes to avoid 32-bit index overflow in varlen and fixed-length code paths.
  • Utilities / helpers (fla/ops/utils/cumsum.py, fla/ops/utils/index.py, fla/ops/utils/matmul.py, fla/ops/utils/pack.py, fla/ops/utils/pooling.py, fla/ops/utils/solve_tril.py): Pointer/batch-stride arithmetic updated to cast i_b/i_n to tl.int64 before multiplications; cu_seqlens loads to tl.int64; T cast to tl.int32 for tiling/loop bounds.
  • Tests (tests/ops/test_int32_overflow.py): New skipped large-GPU test scaffold that checks CUDA allocation and provides a placeholder test for future int32-overflow validation.
  • Miscellaneous kernels (...): Many other Triton kernels updated with the same int64-promotion / int32 cast-back pattern; no public APIs changed.


Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • zhiyuan1i
  • yzhangcs

"🐰 I hop through kernels wide and small,
casting indices large to guard them all.
Int64 for journeys, then T back to two-byte song,
no overflow to trample — code hops along! 🥕"

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 4.13%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title accurately summarizes the main change: fixing int32 overflow in pointer arithmetic across Triton kernels, which aligns with the extensive modifications across ~83 kernel files.



Comment @coderabbitai help to get the list of available commands and usage tips.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a set of guidelines for preventing integer overflows in Triton kernels and applies these changes across numerous files. The changes ensure that index arithmetic involving tl.program_id() or cu_seqlens is performed using 64-bit integers to avoid potential overflows with large tensor dimensions. I have reviewed the proposed changes and they correctly address the integer overflow issues identified in the new guidelines.

Comment threads (3, outdated): fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 19

Note

Due to the large number of review comments, Critical and Major severity comments were prioritized as inline comments.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
fla/ops/kda/chunk_intra_token_parallel.py (1)

70-72: ⚠️ Potential issue | 🟠 Major

Missing int64 cast in non-varlen path could cause overflow.

The non-varlen path computes bos = (i_tg // T) * T using int32 arithmetic throughout. When B * T is large, this multiplication can overflow. For consistency with other files in this PR and to prevent overflow, cast to int64 before the multiplication.

🐛 Proposed fix
     else:
-        bos = (i_tg // T) * T
+        bos = tl.cast(i_tg // T, tl.int64) * T
         i_t = i_tg % T
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/kda/chunk_intra_token_parallel.py` around lines 70 - 72, The
non-varlen branch computes bos = (i_tg // T) * T using int32 math which can
overflow for large B*T; update the calculation in chunk_intra_token_parallel.py
so that the multiplication operates in int64 by casting the divisor or operands
to int64 (e.g., cast (i_tg // T) or T to int64) before doing the * T, and keep
i_t = i_tg % T unchanged; ensure the cast mirrors the int64 usage pattern used
elsewhere in this PR for consistency.
fla/ops/rwkv7/fused_addcmul.py (1)

300-300: ⚠️ Potential issue | 🟡 Minor

Duplicate return statement.

Line 300 is an unreachable duplicate of line 299.

🧹 Proposed fix
     else:
         return oxr, oxw, oxk, oxv, oxa, None
-        return oxr, oxw, oxk, oxv, oxa, None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/rwkv7/fused_addcmul.py` at line 300, The function currently contains
a duplicated return of the same tuple (oxr, oxw, oxk, oxv, oxa, None) which is
unreachable and should be removed; locate the redundant return that returns oxr,
oxw, oxk, oxv, oxa, None (duplicate of the previous return) and delete the extra
statement so the function only returns that tuple once.
fla/ops/gla/chunk.py (1)

336-345: ⚠️ Potential issue | 🟠 Major

This forward path still widens i_b * NT / i_b * T too late.

Line 344 and Line 345 cast after the multiply, so the fixed-length path can still wrap before either value ever becomes tl.int64.

Proposed fix
-        i_tg = (i_b * NT + i_t).to(tl.int64)
-        bos, eos = (i_b * T).to(tl.int64), (i_b * T + T).to(tl.int64)
+        i_tg = tl.cast(i_b, tl.int64) * NT + i_t
+        bos, eos = tl.cast(i_b, tl.int64) * T, tl.cast(i_b, tl.int64) * T + T
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gla/chunk.py` around lines 336 - 345, The else branch widens i_b * NT
and i_b * T only after the multiplication which can overflow; cast operands to
tl.int64 before multiplying so i_tg, bos and eos are computed using 64-bit math:
convert i_b (and NT/T as needed) to tl.int64 prior to the multiplications that
assign i_tg and compute bos/eos (the calculations in the else branch that create
i_tg, bos, eos using i_b, NT, and T).
fla/ops/common/chunk_h_split.py (1)

64-75: ⚠️ Potential issue | 🟠 Major

The split-state indices are still 32-bit.

These changes widen bos/eos, but the split buffers are still addressed from int32 i_sh / i_n * NS values. Lines 84, 130, 191, 192, 294, 336, 399, and 400 later multiply those indices by K * V, so the split/reduction paths can still overflow when split_size is small enough (for example, S == BT == 64 with the PR's large-batch stress shape).

Also applies to: 178-186, 274-285, 387-395

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/common/chunk_h_split.py` around lines 64 - 75, The split-state index
variables (i_sh, i_ss, i_n, i_s and any derived offsets like i_nh used to index
split buffers) are still 32-bit and can overflow when later multiplied by K*V;
change these to 64-bit (use tl.int64) or explicitly cast them to tl.int64 before
any arithmetic/loads/stores that compute buffer offsets (e.g., reading
split_indices, computing bos/eos, NS-based division, and any subsequent offset
multiplications used in the split/reduction paths). Update uses of i_sh, i_ss,
i_n, i_s, i_nh, and any expressions that multiply by K or V to use tl.int64 so
buffer addressing and multiplications cannot overflow. Ensure the same
widening/casts are applied consistently in the other code blocks that mirror
this logic (the later branches/functions that compute offsets for split buffers
and reductions).
🟡 Minor comments (2)
fla/ops/gla/chunk.py-206-213 (1)

206-213: ⚠️ Potential issue | 🟡 Minor

Rename all to clear Ruff A001.

Line 209 and Line 281 shadow the builtin all, which Ruff is already flagging for this file.

Also applies to: 278-285

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gla/chunk.py` around lines 206 - 213, The variable named all in
fla.ops.gla.chunk (assigned in both the IS_VARLEN branch and the else branch)
shadows the builtin and triggers Ruff A001; rename it (for example to
total_elems or all_elems) and update every usage of that identifier within this
module (including the other occurrences around the i_n/i_t handling and the
block referenced near lines 278–285) so the new name replaces all reads/writes
of all without changing logic or types.
AGENTS.md-5-12 (1)

5-12: ⚠️ Potential issue | 🟡 Minor

Tighten the cu_seqlens load rule to account for int64 inputs.

Line 5 incorrectly states tl.load() from cu_seqlens returns int32, but for torch.LongTensor inputs (which match the repo's type), tl.load(cu_seqlens + ...) produces int64. The real overflow risk is when int64 loads are narrowed back to int32 or mixed with int32 tl.program_id() arithmetic before widening. Reword to reflect this distinction and prevent false positives in kernel reviews.
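The narrowing hazard described here can be sketched in plain Python (the to_int32 helper emulates .to(tl.int32) wrap-around; the offset value is illustrative, not taken from the repo):

```python
def to_int32(x: int) -> int:
    # emulate .to(tl.int32): wrap-around narrowing to a signed 32-bit value
    return (x + 2**31) % 2**32 - 2**31

bos = 2_300_000_000         # already int64 when loaded from a torch.LongTensor
narrowed = to_int32(bos)    # explicit narrowing corrupts the offset

print(narrowed)             # -1994967296: negative, no longer a valid offset
```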

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@AGENTS.md` around lines 5 - 12, Update the AGENTS.md rule to clarify that
tl.load(cu_seqlens + ...) can return int64 for torch.LongTensor inputs and that
the real bug is when int64 values are narrowed to int32 or mixed with int32
results from tl.program_id() before widening; instruct reviewers to ensure any
arithmetic involving tl.program_id() or tl.load(cu_seqlens + ...) is promoted to
tl.int64 via tl.cast(...) before multiplication with strides/dimensions, and
only cast back to tl.int32 immediately before calling tl.make_block_ptr (which
requires 32-bit shape/offsets), while flagging explicit narrowing like
.to(tl.int32) or mixed int32*int64 expressions.
🧹 Nitpick comments (3)
fla/ops/utils/pack.py (1)

39-41: Consider adding explicit int64 cast for cu_seqlens loads for consistency.

While cu_seqlens is passed as torch.LongTensor (int64) from Python, other files in this PR explicitly cast the loaded values to tl.int64 for clarity and consistency. The current code should work correctly since Triton's tl.load preserves the source dtype, but an explicit cast would align with the pattern used elsewhere.

-    bos, eos = tl.load(cu_seqlens + i_b), tl.load(cu_seqlens + i_b + 1)
+    bos, eos = tl.load(cu_seqlens + i_b).to(tl.int64), tl.load(cu_seqlens + i_b + 1).to(tl.int64)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/utils/pack.py` around lines 39 - 41, Cast the values loaded from
cu_seqlens to tl.int64 for consistency: when loading cu_seqlens with tl.load
(the variables bos and eos), explicitly cast them to tl.int64 before computing T
so the types match the project pattern (e.g., use tl.cast on the results of
tl.load for cu_seqlens in pack.py where bos, eos and T are computed).
fla/ops/rwkv7/fused_addcmul.py (1)

146-147: Int64 computation followed by uint32 cast may still have limitations for very large tensors.

Line 146 correctly computes offset_base in int64, but line 147 casts x_idx to tl.uint32. While uint32 supports indices up to ~4.29B (vs int32's ~2.15B), this still imposes an upper limit. For tensors exceeding 4B elements, this cast could cause incorrect addressing.

If the expected use case is within 4B elements, this is acceptable. Otherwise, consider keeping the index as int64.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/rwkv7/fused_addcmul.py` around lines 146 - 147, The computation casts
x_idx to tl.uint32 after computing offset_base in tl.int64, which will overflow
for tensors >4B elements; change x_idx to remain int64 (e.g., remove or replace
the .to(tl.uint32) with .to(tl.int64) or leave as int64) so indices use full
64-bit addressing, and update any downstream uses that assume uint32 to accept
int64; specifically modify the expressions involving offset_base, x_idx,
tl.cast, tl.uint32, and tl.int64 (and respect T, T_OFFSET, D) to preserve 64-bit
indexing.
fla/ops/mesa_net/chunk_h_fwd.py (1)

59-69: Overflow fix applied correctly for bos/eos, but boh handling is inconsistent.

The varlen path correctly loads boh as tl.int64 (line 64), but the non-varlen path computes boh = i_n * NS using int32 arithmetic (line 69). While boh itself is unlikely to overflow, the downstream computation o_h = ((boh + i_s) * H + i_h).to(tl.int64) * K*V on line 88 relies on an explicit int64 cast. This works but is inconsistent with the varlen path where boh is already int64.

For consistency with the varlen path and defensive coding, consider:

     else:
         bos, eos = tl.cast(i_n, tl.int64) * T, tl.cast(i_n, tl.int64) * T + T
         NT = tl.cdiv(T, BT)
         NS = tl.cdiv(T, BS)
-        boh = i_n * NS
+        boh = tl.cast(i_n, tl.int64) * NS
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/mesa_net/chunk_h_fwd.py` around lines 59 - 69, The non-varlen branch
computes boh with int32 arithmetic (boh = i_n * NS) while the varlen branch
loads boh as tl.int64; make boh consistently int64 in the non-varlen path to
avoid mixed-width arithmetic and match downstream expectations (used in o_h
calculation). Locate boh in chunk_h_fwd.py inside the IS_VARLEN else branch and
change its computation to produce a tl.int64 (e.g., cast i_n and/or the product
to tl.int64) so boh has the same type as the varlen path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/common/chunk_delta_h.py`:
- Around line 70-78: The fixed-length branch computes boh = i_n * NT in 32-bit
which can overflow before later cast to int64; change the else-path to produce
boh as int64 (e.g., compute boh = tl.cast(i_n, tl.int64) * tl.cast(NT, tl.int64)
or cast the product with .to(tl.int64)) so downstream uses like (boh * HV + i_h)
are done in 64-bit arithmetic; update both occurrences where boh is set in the
else branch (the fixed-length path) so boh matches the varlen path's int64 type.

In `@fla/ops/common/fused_recurrent.py`:
- Around line 65-67: The product "all" used for pointer offsets is currently
computed in int32 (from B and T) and can overflow; change the computation of
"all" (and any places computing B * T like the branches that set bos/eos and i_n
* T) to use int64 promotion (e.g., cast B and/or T to int64 before multiplying)
so offsets used for o, dq, dk, dv, dg, dgk, dgv are 64-bit; update both forward
and backward kernels where "all" is computed (the initial branch setting
T/bos/eos and the backward kernel computation) to ensure all pointer arithmetic
uses int64.

In `@fla/ops/delta_rule/fused_recurrent.py`:
- Around line 47-53: The variable all is left as int32 while bos/eos are widened
to int64, which can cause overflow when later multiplied by program_id()-derived
int32 indices (i_v/i_k); locate the assignments to all in the fused_recurrent.py
logic (the branches where IS_VARLEN is checked) and cast/convert all to int64
(e.g., use tl.cast(..., tl.int64) or .to(tl.int64)) immediately after it's
computed so subsequent arithmetic with i_v/i_k and additions with bos (int64)
happen in 64-bit and avoid overflow.

In `@fla/ops/gated_delta_product/chunk_deltaproduct_h.py`:
- Around line 65-70: The fixed-length branch still computes boh in int32 which
overflows when later used in (boh * H + i_h) * K * V; update both fixed-path boh
calculations so boh is computed in int64 by casting i_n to int64 before
multiplying: in the first fixed-length case compute boh as int64(i_n) * cdiv(T
// num_householder, BT) and in the second fixed-length case compute boh as
int64(i_n) * NT, ensuring chunk_offsets/boh math uses int64 throughout.

In `@fla/ops/gated_oja_rule/chunk_o.py`:
- Around line 58-67: The non-varlen branch computes i_tg = i_b * NT + i_t in
int32 which can overflow when later used to index h; change the expression to
promote to int64 (e.g., i_tg = tl.cast(i_b, tl.int64) * NT + i_t or cast the
whole result) so i_tg is int64 before any pointer arithmetic (references: i_tg,
i_b, NT, i_t, and uses where (i_tg * H + i_h) * K * V indexes h and where
bos/eos are int64).

In `@fla/ops/gated_oja_rule/fused_recurrent.py`:
- Around line 54-57: The offset calculations using i_v and i_nh can overflow
int32 (e.g., i_nh * K * V); cast i_v and i_nh to int64 before performing pointer
arithmetic so multiplications use 64-bit integers (e.g., use tl.cast(i_nh,
tl.int64) or i64 variant) when computing state-buffer offsets (the expressions
like i_nh * K * V and i_v * ...); apply the same change consistently in the
related kernels (rwkv6, rwkv7, common, dplr, logsumexp) where similar
multiplications occur to ensure correct indexing.

In `@fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py`:
- Around line 68-74: The fixed-length (else) branch still performs i_b * T in
int32 before casting, which can overflow; update the computations in the else
branch (and the other occurrences noted) to perform the multiplication in int64
by casting one operand to tl.int64 first (e.g., cast i_b or T to tl.int64 before
multiplying), then compute bos and eos as int64 and downcast only if/when
needed; ensure the same change is applied at the other occurrences referenced
(the other i_b * T sites in this file).

In `@fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py`:
- Around line 141-147: In chunk_dplr_bwd_o_kernel and chunk_dplr_bwd_kernel_dv
the loop index i_tg is computed as int32 and later used in pointer offset math
(e.g., i_tg * H + i_h) which can overflow for large chunk counts; change the
VARLEN path (i_tg = i_t) and the non-VARLEN path (i_tg = i_b * NT + i_t) to cast
i_tg to int64 immediately after computation so all subsequent multiplications
and offsets (h, dh, dgk_last) use int64 arithmetic; apply the same fix for the
other occurrence around lines 260-266.

In `@fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py`:
- Around line 61-67: The computed thread-group index i_tg must be cast to int64
in both branches to avoid 32-bit overflow when used in pointer-offset
arithmetic; update the branch where i_tg = i_b * NT + i_t and any branch using
i_tg to use tl.cast(i_tg, tl.int64) (or cast the factors so the multiplication
promotes to int64) before downstream computations like (i_tg * H + i_h) * K * V,
ensuring bos/eos/T/NT remain consistent and the offset arithmetic is performed
in int64.

In `@fla/ops/gla/chunk.py`:
- Around line 608-617: The varlen branch leaves i_tg as an int32 (i_tg = i_t)
which later gets multiplied by H*K*V when indexing h/dh and can overflow int32
for large B/T shapes; ensure i_tg is promoted to 64-bit in both branches (cast
i_t or computed i_tg via tl.cast(..., tl.int64)) or cast immediately where used
in offset math, and also ensure any intermediate multiplications for buffer
offsets (uses of H, K, V and i_tg when indexing h/dh) are done with int64 to
avoid overflow; update the IS_VARLEN branch (and the non-varlen branch where
i_tg is computed) to produce an int64 i_tg or wrap offset calculations with
tl.cast(..., tl.int64) so h/dh indexing uses 64-bit offsets.

In `@fla/ops/kda/chunk_bwd.py`:
- Around line 72-75: In chunk_kda_bwd_kernel_wy_dqkg_fused, the expressions
computing i_tg and bos/eos perform multiplication in int32 then cast, causing
overflow for large B*NT or B*T; change those to cast i_b to tl.int64 before
multiplying (e.g. compute i_tg as tl.cast(i_b, tl.int64) * NT + tl.cast(i_t,
tl.int64) and compute bos/eos as tl.cast(i_b, tl.int64) * T and tl.cast(i_b,
tl.int64) * T + T) so the multiplications happen in int64.

In `@fla/ops/log_linear_attn/chunk.py`:
- Around line 77-85: The kernel widens sequence offsets but leaves the
recurrent-state index i_n as int32, causing overflow in pointer arithmetic when
computing state-buffer addresses (used to build h0, ht, dh, h_l); fix by
promoting i_n to int64 immediately at the kernel entry (cast i_n to tl.int64
once) so all subsequent calculations that multiply i_n by L_IN, L_OUT, H, K, V,
NT, etc. use int64; update references to i_n in the state-pointer/address
computations (locations building h0, ht, dh, h_l and the regions near the
previously noted sites) to use the new int64 variable.

In `@fla/ops/mesa_net/chunk_h_kk_intra_bwd.py`:
- Around line 51-57: Widen i_tg to int64 before any arithmetic to avoid int32
overflow: in the varlen branch set i_tg = tl.cast(i_t, tl.int64) (instead of
keeping it int32), and in the fixed-len branch compute i_tg with all operands
cast to int64, e.g. i_tg = tl.cast(i_b, tl.int64) * tl.cast(NT, tl.int64) +
tl.cast(i_t, tl.int64); then remove the later .to(tl.int64) casts on the (i_tg *
H + i_h) expressions (lines that compute the chunk/state offsets) since i_tg and
operands will already be int64.

In `@fla/ops/nsa/compression.py`:
- Around line 55-62: The fixed-length branch leaves boc as int32 which can
overflow later when used for buffer pointer arithmetic; in the else branch where
boc is set via boc = i_b * tl.cdiv(T, BS) (and in the other occurrences at the
same pattern), cast or compute boc as tl.int64 (e.g., wrap with tl.cast(...,
tl.int64) or ensure operands are int64) so boc has the same widened integer type
as the varlen path (which uses chunk_offsets loaded as int64) before any
base-pointer math; update all instances of boc creation in the fixed-length path
(the else blocks around the IS_VARLEN checks) to produce an int64 boc.

In `@fla/ops/path_attn/intra_chunk_preprocess_bwd_prepare.py`:
- Around line 60-66: The fixed-length branch computes boh = i_n * NT as int32
which can overflow later when used in the h pointer offset; change the
computation in the fixed-length path of intra_chunk_preprocess_bwd_prepare so
boh is promoted to int64 (mirror the varlen branch which does
tl.load(...).to(tl.int64)), e.g. cast the product (or operands) to int64 so boh
is int64 before any use (especially the ((boh + i_t) * H + i_h) * K * K
expression under RETURN_H=True).

In `@fla/ops/path_attn/parallel_path_bwd_inter_dkv.py`:
- Around line 60-63: The fixed-length branch computes boh_large as a 32-bit
product which can overflow; make boh_large 64-bit like bos/eos by casting
operands to tl.int64 — e.g. replace boh_large = i_n * tl.cdiv(T, S) with
boh_large = tl.cast(i_n, tl.int64) * tl.cast(tl.cdiv(T, S), tl.int64) (or
equivalently cast the whole product to tl.int64) so subsequent pointer
arithmetic ((boh_large * H + i_h) * K * K) uses 64-bit integers.

In `@fla/ops/rwkv6/chunk.py`:
- Around line 439-447: The non-varlen branch can produce int32 overflow for
backward-state offsets; update the non-varlen computation of boh and the later
index i_tg to use int64 casts like the varlen path: when computing boh in the
else branch replace boh = i_n * NT with boh = tl.cast(i_n, tl.int64) *
tl.cast(NT, tl.int64) (or cast the product to tl.int64), and when computing i_tg
replace arithmetic i_b * NT + i_t with tl.cast(i_b, tl.int64) * tl.cast(NT,
tl.int64) + tl.cast(i_t, tl.int64) so subsequent pointer multiplies by H*K*V use
int64 indices.

In `@fla/ops/ttt/fused_chunk.py`:
- Around line 186-190: The hidden-state base index boh is computed in int32 and
later multiplied into ((boh + i_t) * H + i_h) * K * V which overflows; fix by
computing boh and all intermediate index arithmetic in 64-bit: compute boh =
tl.cast(i_n, tl.int64) * tl.cast(NT, tl.int64) (and where NT is used ensure it's
int64), and make the downstream address expression use tl.cast(...) to int64 for
(boh + i_t), H, i_h, K, V before the multiplications so the final index remains
in int64 (update occurrences around boh creation and uses near variables boh,
i_t, i_h, H, K, V at both the forward and backward sites).

In `@fla/ops/utils/pooling.py`:
- Around line 43-55: The output pointer arithmetic uses a 32-bit i_tg causing
potential overflow; cast i_tg to int64 before computing p_o in both forward and
backward kernels: in the VARLEN path set i_tg to tl.cast(i_t, tl.int64), and in
the fixed-length path compute i_tg as tl.cast(i_b, tl.int64) * NT + tl.cast(i_t,
tl.int64); update uses of i_tg where p_o is created (the make_block_ptr call) so
the (i_tg * H + i_h) term is evaluated in int64 to avoid wrapping when building
the output block pointer.

---

Outside diff comments:
In `@fla/ops/common/chunk_h_split.py`:
- Around line 64-75: The split-state index variables (i_sh, i_ss, i_n, i_s and
any derived offsets like i_nh used to index split buffers) are still 32-bit and
can overflow when later multiplied by K*V; change these to 64-bit (use tl.int64)
or explicitly cast them to tl.int64 before any arithmetic/loads/stores that
compute buffer offsets (e.g., reading split_indices, computing bos/eos, NS-based
division, and any subsequent offset multiplications used in the split/reduction
paths). Update uses of i_sh, i_ss, i_n, i_s, i_nh, and any expressions that
multiply by K or V to use tl.int64 so buffer addressing and multiplications
cannot overflow. Ensure the same widening/casts are applied consistently in the
other code blocks that mirror this logic (the later branches/functions that
compute offsets for split buffers and reductions).

In `@fla/ops/gla/chunk.py`:
- Around line 336-345: The else branch widens i_b * NT and i_b * T only after
the multiplication which can overflow; cast operands to tl.int64 before
multiplying so i_tg, bos and eos are computed using 64-bit math: convert i_b
(and NT/T as needed) to tl.int64 prior to the multiplications that assign i_tg
and compute bos/eos (the calculations in the else branch that create i_tg, bos,
eos using i_b, NT, and T).

In `@fla/ops/kda/chunk_intra_token_parallel.py`:
- Around line 70-72: The non-varlen branch computes bos = (i_tg // T) * T using
int32 math which can overflow for large B*T; update the calculation in
chunk_intra_token_parallel.py so that the multiplication operates in int64 by
casting the divisor or operands to int64 (e.g., cast (i_tg // T) or T to int64)
before doing the * T, and keep i_t = i_tg % T unchanged; ensure the cast mirrors
the int64 usage pattern used elsewhere in this PR for consistency.

In `@fla/ops/rwkv7/fused_addcmul.py`:
- Line 300: The function currently contains a duplicated return of the same
tuple (oxr, oxw, oxk, oxv, oxa, None) which is unreachable and should be
removed; locate the redundant return that returns oxr, oxw, oxk, oxv, oxa, None
(duplicate of the previous return) and delete the extra statement so the
function only returns that tuple once.

---

Minor comments:
In `@AGENTS.md`:
- Around line 5-12: Update the AGENTS.md rule to clarify that tl.load(cu_seqlens
+ ...) can return int64 for torch.LongTensor inputs and that the real bug is
when int64 values are narrowed to int32 or mixed with int32 results from
tl.program_id() before widening; instruct reviewers to ensure any arithmetic
involving tl.program_id() or tl.load(cu_seqlens + ...) is promoted to tl.int64
via tl.cast(...) before multiplication with strides/dimensions, and only cast
back to tl.int32 immediately before calling tl.make_block_ptr (which requires
32-bit shape/offsets), while flagging explicit narrowing like .to(tl.int32) or
mixed int32*int64 expressions.
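As a concrete illustration of the narrowing hazard (pure Python; `to_int32` is a hypothetical helper emulating what a `.to(tl.int32)` cast would keep, and the boundary value is made up):

```python
def to_int32(x):
    # Emulate truncation to a signed 32-bit value, as .to(tl.int32) would do
    x &= 0xFFFFFFFF
    return x - 0x100000000 if x >= 0x80000000 else x

eos = 2_415_330_176    # an int64 sequence boundary past INT32_MAX
print(to_int32(eos))   # negative garbage once narrowed
```

So the safe narrowing point is only where the value is known to fit, e.g. a per-sequence length `T = eos - bos` just before `tl.make_block_ptr`.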

In `@fla/ops/gla/chunk.py`:
- Around line 206-213: The variable named all in fla.ops.gla.chunk (assigned in
both the IS_VARLEN branch and the else branch) shadows the builtin and triggers
Ruff A001; rename it (for example to total_elems or all_elems) and update every
usage of that identifier within this module (including the other occurrences
around the i_n/i_t handling and the block referenced near lines 278–285) so the
new name replaces all reads/writes of all without changing logic or types.

---

Nitpick comments:
In `@fla/ops/mesa_net/chunk_h_fwd.py`:
- Around line 59-69: The non-varlen branch computes boh with int32 arithmetic
(boh = i_n * NS) while the varlen branch loads boh as tl.int64; make boh
consistently int64 in the non-varlen path to avoid mixed-width arithmetic and
match downstream expectations (used in o_h calculation). Locate boh in
chunk_h_fwd.py inside the IS_VARLEN else branch and change its computation to
produce a tl.int64 (e.g., cast i_n and/or the product to tl.int64) so boh has
the same type as the varlen path.

In `@fla/ops/rwkv7/fused_addcmul.py`:
- Around line 146-147: The computation casts x_idx to tl.uint32 after computing
offset_base in tl.int64, which will overflow for tensors >4B elements; change
x_idx to remain int64 (e.g., remove or replace the .to(tl.uint32) with
.to(tl.int64) or leave as int64) so indices use full 64-bit addressing, and
update any downstream uses that assume uint32 to accept int64; specifically
modify the expressions involving offset_base, x_idx, tl.cast, tl.uint32, and
tl.int64 (and respect T, T_OFFSET, D) to preserve 64-bit indexing.
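Unsigned 32-bit indices fail the same way, just at 2^32 instead of 2^31. A sketch with a made-up element count:

```python
n_elements = 5_000_000_000   # hypothetical tensor with more than 2**32 elements
idx = n_elements - 1         # a valid int64 index into that tensor
idx_u32 = idx % (1 << 32)    # what truncating to tl.uint32 would keep
print(idx, idx_u32)          # the uint32 index silently points at the wrong element
```

Unlike the signed case, the wrapped index stays positive, so the kernel may read in-bounds but wrong data rather than crashing, which is harder to catch.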

In `@fla/ops/utils/pack.py`:
- Around line 39-41: Cast the values loaded from cu_seqlens to tl.int64 for
consistency: when loading cu_seqlens with tl.load (the variables bos and eos),
explicitly cast them to tl.int64 before computing T so the types match the
project pattern (e.g., use tl.cast on the results of tl.load for cu_seqlens in
pack.py where bos, eos and T are computed).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 9e5d2318-a4d0-4897-ad54-43c7ace65a9c

📥 Commits

Reviewing files that changed from the base of the PR and between a8cb120 and f0295f5.

📒 Files selected for processing (84)
  • AGENTS.md
  • fla/ops/attn/decoding.py
  • fla/ops/attn/parallel.py
  • fla/ops/comba/fused_recurrent.py
  • fla/ops/comba/utils.py
  • fla/ops/comba/wy_fast.py
  • fla/ops/common/chunk_delta_h.py
  • fla/ops/common/chunk_h.py
  • fla/ops/common/chunk_h_parallel.py
  • fla/ops/common/chunk_h_split.py
  • fla/ops/common/chunk_o.py
  • fla/ops/common/chunk_scaled_dot_kkt.py
  • fla/ops/common/fused_chunk.py
  • fla/ops/common/fused_recurrent.py
  • fla/ops/cp/chunk_delta_h.py
  • fla/ops/delta_rule/fused_recurrent.py
  • fla/ops/delta_rule/wy_fast.py
  • fla/ops/gated_delta_product/chunk_deltaproduct_h.py
  • fla/ops/gated_delta_product/chunk_deltaproduct_o.py
  • fla/ops/gated_delta_rule/chunk_fwd.py
  • fla/ops/gated_delta_rule/fused_recurrent.py
  • fla/ops/gated_delta_rule/gate.py
  • fla/ops/gated_delta_rule/wy_fast.py
  • fla/ops/gated_oja_rule/chunk_h.py
  • fla/ops/gated_oja_rule/chunk_kkt.py
  • fla/ops/gated_oja_rule/chunk_o.py
  • fla/ops/gated_oja_rule/fused_recurrent.py
  • fla/ops/gated_oja_rule/wy_fast.py
  • fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py
  • fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py
  • fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py
  • fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py
  • fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py
  • fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py
  • fla/ops/generalized_delta_rule/dplr/fused_recurrent.py
  • fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py
  • fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py
  • fla/ops/generalized_delta_rule/iplr/chunk.py
  • fla/ops/generalized_delta_rule/iplr/fused_recurrent.py
  • fla/ops/generalized_delta_rule/iplr/wy_fast.py
  • fla/ops/gla/chunk.py
  • fla/ops/gsa/chunk.py
  • fla/ops/hgrn/chunk.py
  • fla/ops/hgrn/fused_recurrent.py
  • fla/ops/kda/chunk_bwd.py
  • fla/ops/kda/chunk_intra.py
  • fla/ops/kda/chunk_intra_token_parallel.py
  • fla/ops/kda/fused_recurrent.py
  • fla/ops/kda/gate.py
  • fla/ops/kda/wy_fast.py
  • fla/ops/log_linear_attn/chunk.py
  • fla/ops/mesa_net/chunk_cg_solver_bwd.py
  • fla/ops/mesa_net/chunk_cg_solver_fwd.py
  • fla/ops/mesa_net/chunk_h_fwd.py
  • fla/ops/mesa_net/chunk_h_kk_intra_bwd.py
  • fla/ops/mesa_net/chunk_h_kv_intra_bwd.py
  • fla/ops/mesa_net/chunk_h_kv_intra_bwd_separate.py
  • fla/ops/nsa/compression.py
  • fla/ops/nsa/parallel.py
  • fla/ops/path_attn/cumprod_householder_bwd.py
  • fla/ops/path_attn/cumprod_householder_fwd.py
  • fla/ops/path_attn/intra_chunk_preprocess_bwd.py
  • fla/ops/path_attn/intra_chunk_preprocess_bwd_prepare.py
  • fla/ops/path_attn/intra_chunk_preprocess_fwd.py
  • fla/ops/path_attn/parallel_path_bwd_inter_dkv.py
  • fla/ops/path_attn/parallel_path_bwd_inter_dqh.py
  • fla/ops/path_attn/parallel_path_bwd_intra.py
  • fla/ops/path_attn/parallel_path_fwd.py
  • fla/ops/path_attn/prepare_k_cache.py
  • fla/ops/path_attn/transform_q.py
  • fla/ops/rwkv6/chunk.py
  • fla/ops/rwkv6/fused_recurrent.py
  • fla/ops/rwkv7/fused_addcmul.py
  • fla/ops/rwkv7/fused_k_update.py
  • fla/ops/rwkv7/fused_recurrent.py
  • fla/ops/simple_gla/parallel.py
  • fla/ops/ttt/chunk.py
  • fla/ops/ttt/fused_chunk.py
  • fla/ops/utils/cumsum.py
  • fla/ops/utils/index.py
  • fla/ops/utils/matmul.py
  • fla/ops/utils/pack.py
  • fla/ops/utils/pooling.py
  • fla/ops/utils/solve_tril.py

Comment thread fla/ops/common/chunk_delta_h.py
Comment thread fla/ops/common/fused_recurrent.py
Comment thread fla/ops/delta_rule/fused_recurrent.py
Comment thread fla/ops/gated_delta_product/chunk_deltaproduct_h.py
Comment thread fla/ops/gated_oja_rule/chunk_o.py
Comment thread fla/ops/path_attn/intra_chunk_preprocess_bwd_prepare.py
Comment thread fla/ops/path_attn/parallel_path_bwd_inter_dkv.py
Comment thread fla/ops/rwkv6/chunk.py
Comment thread fla/ops/ttt/fused_chunk.py
Comment thread fla/ops/utils/pooling.py
@tmct
Contributor Author

tmct commented Apr 8, 2026

I don't think the coderabbit suggestions are as applicable as the Gemini ones - I will ignore them.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py`:
- Around line 68-73: The tail-row pointer arithmetic can overflow because i_t
(from tl.program_id()) and cu_seqlens loads are used as int32 when multiplied by
large strides (e.g., T, H*K, D); cast tl.program_id() results and any
tl.load(cu_seqlens + ...) to tl.int64 before performing multiplications for
pointer/row computations (specifically update the pointer math that derives the
last row from i_t in the intra kernels and the similar computations referenced
around the blocks using i_t/i_b and cu_seqlens — e.g., the sites currently at
the locations corresponding to the previous lines 206, 274-279, and 366) so all
multiplications with T, H*K, or D are done in int64 and then cast down only if
needed.
- Around line 424-433: The compound flattened index i_tg (computed from i_b and
i_t) is kept as int32 and later multiplied into the dgk_last base pointer,
violating the rule to widen program_id()-derived indices to int64; change the
assignment so i_tg is created/widened to tl.int64 (or cast to tl.int64
immediately before any arithmetic with dgk_last), e.g., compute i_tg from (i_b *
NT + i_t) as int64 when IS_VARLEN is false (or cast the existing i_tg to
tl.int64 before it is used with dgk_last) to ensure all arithmetic involving
tl.program_id()-sourced indices is done in int64.

Run ID: 62773b5f-c991-415b-8259-54c86a8b0d91

📥 Commits

Reviewing files that changed from the base of the PR and between f0295f5 and 9d97e38.

📒 Files selected for processing (1)
  • fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py

Comment thread fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py
Comment thread fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py
@zhiyuan1i

This comment was marked as duplicate.

Comment thread AGENTS.md Outdated
Collaborator


Please delete this file

Contributor Author


Please could you suggest an alternative way to prevent regressions of this form? For example, does any of the pre-merge automated testing run on bf16-compatible GPUs? If so I can add some tests, but it would be a shame if these were large and/or slow.

Contributor Author


Apologies, I am thinking of my other PR that involves bf16 - clearly that is not so relevant here :)

@tmct
Contributor Author

tmct commented Apr 14, 2026

Bear with me while I attempt to replace AGENTS.md with large tensor tests

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/ops/test_int32_overflow.py`:
- Around line 45-48: Replace the placeholder test_placeholder (decorated with
requires_large_gpu) with a real regression test named e.g. test_int32_overflow
that constructs an overflow-scale input shape exercising the Triton path (the
int64 pointer-arithmetic fix), runs one representative kernel family used in CI,
computes a reference result on CPU, and asserts elementwise equality (or
tolerance) between kernel output and reference; ensure the test fails before the
fix and passes after by choosing dimensions that trigger 32-bit pointer
overflow, and include clear usage of requires_large_gpu, the kernel invocation,
and the final assert to validate correctness.
- Around line 25-42: The skip condition currently calls
_has_enough_gpu_memory(20) at import/collection time via pytest.mark.skipif
(requires_large_gpu), causing a 20GB allocation during collection; instead
remove the pytest.mark.skipif usage and create a runtime check (e.g., a fixture
named large_gpu or call inside the test body) that calls
_has_enough_gpu_memory(20) and invokes pytest.skip("Requires CUDA with >= 20GB
allocatable memory") when it returns False; update tests that referenced
requires_large_gpu to accept the new fixture or call the runtime skip so the
allocation only happens when the test runs, and keep or enhance the placeholder
test to perform the intended int32-overflow related assertions once implemented.
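Choosing dimensions that cross INT32_MAX, as the first comment asks, is a quick arithmetic check; the shapes below are the ones quoted in the PR description:

```python
INT32_MAX = 2**31 - 1
B, T, H, K = 4096, 576, 8, 128

total = B * T * H * K
assert total > INT32_MAX  # this shape exercises the overflow path
print(total)              # 2415919104
```

Any shape with the same property would do; these values keep the tensor just past the boundary so the test stays as small as possible.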

Run ID: 26827e7d-143b-4205-8e1a-5cff83ebc0d0

📥 Commits

Reviewing files that changed from the base of the PR and between 03ab3ad and 37a4394.

📒 Files selected for processing (1)
  • tests/ops/test_int32_overflow.py

Comment on lines +25 to +42
def _has_enough_gpu_memory(min_gb=20):
    """Check CUDA availability and allocate min_gb to confirm it's usable."""
    if not torch.cuda.is_available():
        return False
    try:
        # Actually allocate to confirm memory is available, not just reported
        x = torch.empty(int(min_gb * 1024**3 // 4), dtype=torch.float32, device='cuda')
        del x
        torch.cuda.empty_cache()
        return True
    except torch.cuda.OutOfMemoryError:
        return False


requires_large_gpu = pytest.mark.skipif(
    not _has_enough_gpu_memory(20),
    reason='Requires CUDA with >= 20GB allocatable memory'
)
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Pytest parallelization config =="
fd -i 'pytest.ini|pyproject.toml|setup.cfg|tox.ini' -x sh -c '
  echo "--- $1 ---"
  rg -n "addopts|xdist|\\-n\\s|--numprocesses" "$1" || true
' sh {}

echo
echo "== Collection-time guard usage in overflow test =="
rg -nC3 "_has_enough_gpu_memory|pytest\\.mark\\.skipif|requires_large_gpu" tests/ops/test_int32_overflow.py

Repository: fla-org/flash-linear-attention

Length of output: 730


🏁 Script executed:

# Check pytest configuration in pyproject.toml
cat pyproject.toml | grep -A 20 "\[tool.pytest"

Repository: fla-org/flash-linear-attention

Length of output: 359


🏁 Script executed:

# Check the full test file to see if there are other tests
wc -l tests/ops/test_int32_overflow.py
head -50 tests/ops/test_int32_overflow.py

Repository: fla-org/flash-linear-attention

Length of output: 1612


🏁 Script executed:

# Check if pytest-xdist is in dependencies
rg -i "xdist|parallel" pyproject.toml setup.py setup.cfg || echo "No xdist found in config files"

Repository: fla-org/flash-linear-attention

Length of output: 171


Move GPU memory probe from collection time to test/fixture body.

pytest.mark.skipif(not _has_enough_gpu_memory(20), ...) evaluates at import time, triggering a 20GB allocation during test collection. This can cause CI timeouts or resource exhaustion, especially in container/sandbox environments where collection often precedes parallelization. Instead, use a fixture that calls pytest.skip(...) at runtime, or probe memory inside the test body.

Also, test_placeholder is a stub test (assert True). While it's marked to be replaced, consider adding minimal assertions related to the PR objective (int32 overflow validation with large shapes where B*T*H*K > INT32_MAX) once real tests are implemented.


Comment thread tests/ops/test_int32_overflow.py Outdated
tmct added 2 commits April 14, 2026 21:59
… catch the bug

Adds regression tests for 7 kernel families at B=4096,T=576,H=8,K=128
(B*T*H*K=2.4B > INT32_MAX). With int64 fixes reverted, these tests
should fail with illegal memory access, proving they catch the overflow.
@tmct tmct force-pushed the fix/int32-overflow-triton-kernels branch from cea6978 to 0781576 Compare April 14, 2026 21:46
@tmct
Contributor Author

tmct commented Apr 18, 2026

Still working on this but I have got access to proper GPUs on it now. Working on creating suitable tests before pushing here again!
