[Example] Add KDA algorithm implementation in tilelang #1660

Merged
LeiWang1999 merged 23 commits into tile-ai:main from wfloveiu:example-kda-algorithm on Jan 26, 2026

Conversation

@wfloveiu (Contributor) commented Jan 12, 2026

The KDA algorithm, originally implemented with Triton in the flash-linear-attention repo, has been fully reimplemented in TileLang (0.1.6.post2+cuda.git729e66ca). Precision alignment and performance testing have been completed for every individual operator.

Summary by CodeRabbit

  • New Features

    • Added comprehensive KDA (Key-Delta-Attention) kernel implementations with GPU acceleration using Triton and TileLang, including forward and backward passes for chunked attention, cumulative sums, and delta-rule operations with support for variable-length sequences.
    • Added GPU utility functions for caching, hardware introspection, and device context management.
  • Bug Fixes

    • Removed public method from atomic operations class.
  • Documentation

    • Added KDA kernel implementation setup guide.
  • Tests

    • Added test utilities and example benchmarking scripts for KDA implementations.
  • Chores

    • Updated pre-commit tool configuration.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai bot commented Jan 12, 2026

📝 Walkthrough

Introduces comprehensive KDA (Key-Delta-Attention) GPU kernel implementations using Triton and TileLang, spanning cumulative sum, gated delta rules, inter-intra attention, and output computation operations. Includes utilities, examples, and reference implementations. Updates pre-commit configuration and removes a public method from AtomicAddNode.

Changes

  • Configuration Update (.pre-commit-config.yaml): Updated the ruff version from v0.14.10 to v0.14.11 and added the --diff flag to the ruff-format arguments.
  • API Removal (src/op/atomic_add.h): Removed the public method ReturnIndicesAndSize(int src_dst) const from AtomicAddNode, narrowing the public API surface.
  • KDA Core Triton Kernels (examples/kda/FLA_KDA/cumsum.py, fla_chunk_delta.py, fla_chunk_inter.py, fla_chunk_intra.py, fla_chunk_intra_token_parallel.py, fla_chunk_o.py, fla_wy_fast.py): New Triton-based kernel implementations for chunked prefix sum, the gated delta rule, inter/intra-chunk attention, token-parallel computation, output fusion, and W/U recomputation; they support variable-length sequences, reverse operations, scaling, and mixed precision.
  • KDA Utilities (examples/kda/FLA_KDA/fla_utils.py, examples/kda/test_utils_kda.py): Introduce device/hardware introspection, caching decorators, input validation, tensor comparison helpers, benchmark utilities, and PyTorch version-compatibility wrappers.
  • KDA Example/Test Files (examples/kda/chunk_bwd_dqkwg.py, chunk_bwd_dv.py, chunk_bwd_gla_dA.py, chunk_bwd_intra.py, chunk_delta_bwd.py, chunk_delta_h_fwd.py, chunk_inter_solve_fused.py, chunk_intra_token_parallel.py, chunk_o.py, wy_fast.py, wy_fast_bwd.py): TileLang-based example scripts implementing and benchmarking forward/backward kernel variants with autotuning, input/output preparation, reference comparison, and timing utilities.
  • Documentation (examples/kda/README.md): Added setup notes documenting the required versions of TileLang, Triton, and FLA.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Suggested labels

enhancement

Poem

🐰 A warren of kernels, all chunked and so fast,
Triton and TileLang compute unsurpassed,
Delta rules gate, and cumulative sums flow,
Variable lengths handled—forward and back we go!
GPU magic crystallized, batched and refined—
KDA attention redefined!

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed
❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 7.01%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title '[Example] Add KDA algorithm implementation in tilelang' clearly and concisely describes the main change: adding KDA algorithm examples implemented in TileLang.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 15

Note

Due to the large number of review comments, Critical severity comments were prioritized as inline comments.

🤖 Fix all issues with AI agents
In @examples/KDA/chunk_bwd_dqkwg.py:
- Around line 89-132: In kernel, G_shared and Gn_shared are allocated with
input_dtype but should use gate_dtype; change the dtype argument for the
alloc_shared calls for G_shared and Gn_shared to gate_dtype so gate values keep
their intended precision (update the lines allocating G_shared and Gn_shared
inside the kernel).

In @examples/KDA/chunk_delta_h_fwd.py:
- Around line 168-200: The loop uses T.ceildiv(S, block_S) but h was allocated
using BS = S // block_S causing OOB when S % block_S != 0; update the tensor
allocation symbol h_shape to use T.ceildiv(S, block_S) (i.e. h_shape = (B,
T.ceildiv(S, block_S), H, DK, DV)) so its second dimension matches the loop
bound, and apply the same fix for dh_shape in chunk_delta_bwd.py (replace BS
with T.ceildiv(S, block_S)); locate and change the allocations that reference BS
to use T.ceildiv(S, block_S) instead.

In @examples/KDA/chunk_o.py:
- Around line 125-133: The inner Parallel loop mistakenly iterates over block_DV
while indexing Q_shared/GK_shared/GQ_shared which have shape [block_S,
block_DK]; change the loop in the block that assigns Q_shared/GQ_shared from
"for i_s, i_v in T.Parallel(block_S, block_DV):" to iterate over block_DK (e.g.
"for i_s, i_k in T.Parallel(block_S, block_DK):") and update the index names
used when reading/writing Q_shared, GK_shared and GQ_shared so the second index
uses the block_DK iterator; leave HIDDEN_shared and the subsequent T.gemm call
unchanged.

In @examples/KDA/FLA_KDA/fla_chunk_delta.py:
- Around line 538-541: In the backward function the variable chunk_offsets can
remain undefined when cu_seqlens is not None; update the logic so that when
chunk_indices is None and cu_seqlens is not None you also set chunk_offsets
(mirror forward): call prepare_chunk_indices(cu_seqlens, chunk_size) to produce
chunk_indices and derive chunk_offsets (or compute chunk_offsets from the same
cu_seqlens/chunk_size helper) before proceeding; modify the branch around
chunk_indices, cu_seqlens and the subsequent use of chunk_offsets to ensure
chunk_offsets is always initialized (refer to chunk_indices, cu_seqlens,
prepare_chunk_indices, and chunk_offsets).
- Around line 483-491: The variable chunk_offsets is left undefined when
cu_seqlens is provided, causing an UnboundLocalError later when passed to the
kernel; fix by ensuring chunk_offsets is initialized in the branch that handles
cu_seqlens (e.g., after computing chunk_indices via
prepare_chunk_indices(cu_seqlens, chunk_size) compute or derive chunk_offsets
from chunk_indices, or explicitly set chunk_offsets = None if the kernel accepts
that), so that the subsequent kernel invocation that uses chunk_offsets, h, and
final_state always has a defined value.

In @examples/KDA/FLA_KDA/fla_chunk_intra.py:
- Around line 607-618: Add the two missing imports for the functions used in
this file: import chunk_kda_fwd_intra_token_parallel from
.fla_chunk_intra_token_parallel and import recompute_w_u_fwd from .fla_wy_fast;
place these import statements near the other module imports at the top of the
file so chunk_kda_fwd_intra_token_parallel (used around the Aqk, Akk_diag call)
and recompute_w_u_fwd (used later) are defined and available at runtime.
- Around line 638-647: The call to recompute_w_u_fwd in the return path will
raise NameError because the function is not imported or defined; add a proper
import for recompute_w_u_fwd at the top of the module (for example: from
.fla_wy_fast import recompute_w_u_fwd or the actual module where
recompute_w_u_fwd is defined) so the symbol is available to the function using
it, or if the implementation belongs in this file, paste the function definition
above its usage; ensure tests/imports run to verify the NameError is resolved.

In @examples/KDA/FLA_KDA/fla_chunk_o.py:
- Around line 547-552: The conditional uses the function object check_shared_mem
instead of calling it, causing the middle branch to always be truthy; change the
second conditional to call check_shared_mem with the same arguments used in the
first branch (e.g., check_shared_mem('hopper', k.device.index)) and ensure
CONST_TILING is set to 64 in that branch so the three branches correctly use
128, 64, and 32 respectively.
- Around line 513-521: The branch guard for USE_A leaves b_A uninitialized when
USE_A is False; ensure b_A is always defined before it is used by either (a)
initializing b_A in the else case to a zero tensor with the same shape/dtype
expected by the later tl.where (matching the shape produced by p_A/tl.load) or
(b) move the tl.where logic inside the USE_A block and otherwise set b_A =
zeros(...). Reference symbols: USE_A, p_A, b_A, tl.load, tl.where, o_t, m_A, and
do.dtype.element_ty to create the zero tensor with correct shape and element
type so tl.where never sees an undefined b_A.

In @examples/KDA/test_utils.py:
- Around line 5-40: calc_sim and assert_similar use .data and perform Python
control flow on torch tensors which can raise on CUDA; change calc_sim to avoid
.data (use .detach().double() or .to(torch.double)) and return a Python float
(use sim.item()) instead of a tensor, ensure the zero-denominator check uses a
scalar (.item() or .eq(0).all().item()); in assert_similar stop doing chained
tensor comparisons by converting sim to a float before computing diff (e.g., sim
= calc_sim(...); diff = 1.0 - float(sim)), keep the isfinite/masking logic but
ensure masked_fill uses the correct masks (x.masked_fill(~x_mask, 0) /
y.masked_fill(~y_mask, 0) as already used) and replace any remaining .data
usages with .detach()/.to(...)

In @examples/KDA/wy_fast.py:
- Around line 204-205: The assignment "use_qg = False," creates a one-element
tuple instead of a boolean; remove the trailing comma so use_qg is assigned the
boolean False (i.e., change use_qg = False, to use_qg = False) where the
kernel/config options are set (the nearby use_kg = True line in wy_fast.py) so
any conditional checks on use_qg behave correctly; verify related conditional
logic that references use_qg expects a bool.
🟠 Major comments (15)
examples/KDA/chunk_bwd_dqkwg.py-89-94 (1)

89-94: Guard or handle S % chunk_size != 0 to avoid shape/OOB hazards.

BS = S // block_S is used to shape h/Dh and the kernel copies [bs*block_S:(bs+1)*block_S] without boundary handling. If S isn’t an exact multiple of chunk_size, this will mis-shape h and risks invalid memory accesses.

Also applies to: 155-164
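
A minimal host-side guard would make the assumption explicit (a sketch only; S, chunk_size, block_S, and BS are the names used in this comment, and the example may instead prefer T.ceildiv plus in-kernel boundary masks):

# Hypothetical guard in the input-preparation / run_test path of chunk_bwd_dqkwg.py
assert S % chunk_size == 0, (
    f"S={S} must be a multiple of chunk_size={chunk_size}: the kernel copies full "
    f"[bs*block_S:(bs+1)*block_S] tiles without boundary handling"
)
BS = S // block_S  # only safe once the assert above holds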

examples/KDA/chunk_bwd_gla_dA.py-12-16 (1)

12-16: Avoid import-time CUDA_VISIBLE_DEVICES mutation (same concern as other scripts).

examples/KDA/chunk_o.py-183-228 (1)

183-228: run_test() kernel config parameters are ignored (misleading API).

run_test() accepts block_DK/block_DV/threads/num_stages, but the tilelang_chunk_fwd_o(...) call only forwards block_S. Either forward the rest (to allow fixed-config runs) or remove them from run_test() to match actual behavior.

examples/KDA/chunk_bwd_dqkwg.py-12-16 (1)

12-16: Avoid import-time CUDA_VISIBLE_DEVICES mutation.

Setting os.environ['CUDA_VISIBLE_DEVICES'] = '7' inside the module is a global side effect and will surprise users (and can break CI / multi-GPU boxes). Prefer documenting it in README or making it a CLI flag / env var read (no write).
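
A minimal sketch of the "read, don't write" alternative (the --device flag and its name are hypothetical, not part of this PR):

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--device", type=int, default=0, help="CUDA device index to run the benchmark on")
args = parser.parse_args()
torch.cuda.set_device(args.device)  # scoped to this process; no env mutation

Users who want a specific GPU can still launch with CUDA_VISIBLE_DEVICES=7 python chunk_bwd_dqkwg.py without the script overriding their choice.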

examples/KDA/chunk_bwd_dv.py-12-16 (1)

12-16: Avoid import-time CUDA_VISIBLE_DEVICES mutation (same issue as other scripts).

examples/KDA/chunk_bwd_gla_dA.py-78-102 (1)

78-102: Fix V_shared dtype: should not be do_dtype.

V is declared as dtype=input_dtype but is copied into V_shared allocated as do_dtype. If these differ, you’ll get an implicit cast that can change numerics (and it’s surprising).

Proposed fix
-            V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
+            V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
examples/KDA/FLA_KDA/cumsum.py-41-56 (1)

41-56: Fix HEAD_FIRST + IS_VARLEN pointer arithmetic in cumsum kernels.

When IS_VARLEN=True, you set T = eos - bos (per-sequence length), but the HEAD_FIRST=True base pointer arithmetic uses bos*H + i_h*T, which is incorrect for the [B, H, T_global] memory layout. The head stride should be T_global (the full sequence length passed to the kernel), not the per-sequence span.

For example, with shape [1, 8, 1024], sequence starting at bos=100, eos=150, and i_h=2:

  • Correct offset: i_h * T_global + bos = 2 * 1024 + 100
  • Current code: bos*H + i_h*T = 100*8 + 2*50 = 900

This affects all four cumsum kernels (scalar/vector × local/global). Since chunk_local_cumsum enforces B=1 when cu_seqlens is provided, this combination is reachable but untested. Update the offset calculation so the head stride uses the kernel's T parameter (the full sequence length) rather than the per-sequence span, or recalculate the head stride explicitly within the varlen branch.

Also applies to: lines 156–172, 216–234, and chunk_global_cumsum_vector_kernel.
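
To make the stride mix-up concrete, a standalone arithmetic sketch (plain Python, not the Triton kernel) reproducing the numbers above:

# Illustrative only: offsets into a HEAD_FIRST [B, H, T_global] layout for a varlen sequence.
T_global, H = 1024, 8          # full padded length and number of heads
bos, eos, i_h = 100, 150, 2    # sequence span and head index from the example
T_seq = eos - bos              # per-sequence length set when IS_VARLEN=True

correct_offset = i_h * T_global + bos   # 2 * 1024 + 100 = 2148 (head stride = T_global)
current_offset = bos * H + i_h * T_seq  # 100 * 8 + 2 * 50 = 900 (wrong stride)
assert (correct_offset, current_offset) == (2148, 900)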

examples/KDA/chunk_delta_bwd.py-268-306 (1)

268-306: Enable result verification to validate correctness.

The reference implementation results (dh_ref, dh0_ref, dv2_ref) are computed but the comparison code is commented out (Lines 304-306). This means the TileLang kernel's correctness isn't being validated.

Suggested fix
-    # compare_tensors("dh", dh_ref, dh_tilelang)
-    # compare_tensors("dh0", dh0_ref, dh0_tilelang)
-    # compare_tensors("dv2", dv2_ref, dv2_tilelang)
+    compare_tensors("dh", dh_ref, dh_tilelang)
+    compare_tensors("dh0", dh0_ref, dh0_tilelang)
+    compare_tensors("dv2", dv2_ref, dv2_tilelang)
examples/KDA/chunk_intra_token_parallel.py-276-286 (1)

276-286: Enable result verification to validate correctness.

The comparison code is commented out, so the TileLang results aren't being validated against the FLA reference. Given the PR description mentions "precision alignment testing completed," this verification should be enabled.

Suggested fix
     print(f"fla time: {fla_time} ms")
     print(f"tilelang time: {tilelang_time} ms")
     
-
-    # compare_tensors("Aqk", Aqk_ref, Aqk_tilelang)
-    # compare_tensors("Akk", Akk_ref, Akk_tilelang)
+    compare_tensors("Aqk", Aqk_ref, Aqk_tilelang)
+    compare_tensors("Akk", Akk_ref, Akk_tilelang)
examples/KDA/chunk_inter_solve_fused.py-12-13 (1)

12-13: Remove hardcoded CUDA device selection.

Setting CUDA_VISIBLE_DEVICES in code is problematic as it affects the entire process and prevents users from choosing their preferred GPU. This should be configurable or removed.

 import os
-os.environ['CUDA_VISIBLE_DEVICES'] = '7'
examples/KDA/FLA_KDA/fla_wy_fast.py-270-277 (1)

270-277: Hardcoded BT=64 may not match input tensor dimensions.

The forward function recompute_w_u_fwd derives BT from A.shape[-1], but the backward function hardcodes BT=64. If A was created with a different chunk size, this will cause incorrect behavior.

🔧 Suggested fix
 def prepare_wy_repr_bwd(
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     gk: torch.Tensor,
     A: torch.Tensor,
     dk: torch.Tensor,
     dw: torch.Tensor,
     du: torch.Tensor,
     dg: torch.Tensor,
     cu_seqlens: torch.LongTensor  = None,
     chunk_indices: torch.LongTensor  = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     B, T, H, K, V = *k.shape, v.shape[-1]
-    BT = 64
+    BT = A.shape[-1]
     if chunk_indices is None and cu_seqlens is not None:
         chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
examples/KDA/FLA_KDA/fla_utils.py-154-156 (1)

154-156: Generator is consumed before function execution.

contiguous_args is a generator expression, so it can only be consumed once. If anything iterates over it before the fn(*contiguous_args, ...) call, or tries to reuse it afterwards, the call sees an exhausted, empty argument sequence; materializing it eagerly avoids this.

🔧 Suggested fix - use tuple instead of generator
     @functools.wraps(fn)
     def wrapper(*args, **kwargs):
-        contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
+        contiguous_args = tuple(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
         contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}
examples/KDA/wy_fast_bwd.py-344-348 (1)

344-348: Correctness verification is disabled.

The compare_tensors calls are commented out, meaning the test only benchmarks without verifying correctness. This should be enabled to ensure the TileLang kernel produces correct results.

-    # compare_tensors("dA", dA_tilelang, dA_ref)
-    # compare_tensors("dk", dk_tilelang, dk_ref)
-    # compare_tensors("dv", dv_tilelang, dv_ref)
-    # compare_tensors("dbeta", dbeta_tilelang, dbeta_ref)
-    # compare_tensors("dg", dg_tilelang, dg_ref)
+    compare_tensors("dA", dA_tilelang, dA_ref)
+    compare_tensors("dk", dk_tilelang, dk_ref)
+    compare_tensors("dv", dv_tilelang, dv_ref)
+    compare_tensors("dbeta", dbeta_tilelang, dbeta_ref)
+    compare_tensors("dg", dg_tilelang, dg_ref)
examples/KDA/FLA_KDA/fla_utils.py-25-26 (1)

25-26: Device check at import time may fail without GPU.

IS_NVIDIA_HOPPER accesses torch.cuda.get_device_name(0) at module import time, which will raise an error if no CUDA device is available. This prevents the module from being imported for testing or on CPU-only systems.

🔧 Suggested fix
-IS_NVIDIA_HOPPER = (True and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
+def _is_nvidia_hopper():
+    try:
+        return 'NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9
+    except Exception:
+        return False
+
+IS_NVIDIA_HOPPER = _is_nvidia_hopper()
examples/KDA/wy_fast_bwd.py-56-61 (1)

56-61: Fix dA shape to use chunk_size instead of DK.

The dA tensor is allocated with shape (B, S, H, DK) but should match the shape of A, which is (B, S, H, chunk_size) (see line 36 where A = torch.randn(B, S, H, BS, ...) and BS = chunk_size). Since DK and chunk_size are independent parameters, this shape mismatch will cause incorrect results. Use chunk_size for the final dimension instead:

Suggested fix
dA = torch.empty(B, S, H, chunk_size, dtype=output_dtype).cuda()
🟡 Minor comments (9)
examples/KDA/chunk_bwd_dv.py-127-167 (1)

127-167: Drop unused scale from run_test() (or apply it) to avoid misleading results.

Ruff flags scale as unused here; either remove it from this dv benchmark harness or use it in the reference/kernel calls if it’s mathematically required.

examples/KDA/test_utils.py-42-82 (1)

42-82: Either use atol/rtol or remove them from compare_tensors().

Ruff flags atol/rtol as unused; consider adding an explicit pass/fail check (e.g., torch.testing.assert_close / torch.allclose) and printing the result.
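
A minimal sketch of wiring the tolerances into an explicit pass/fail check (the positional signature (name, x, y) mirrors the call sites in this PR; this is not the author's implementation):

import torch

def compare_tensors(name, x, y, atol=1e-3, rtol=1e-3):
    try:
        torch.testing.assert_close(x, y, atol=atol, rtol=rtol)
        print(f"{name}: PASS (atol={atol}, rtol={rtol})")
    except AssertionError:
        max_err = (x.float() - y.float()).abs().max().item()
        print(f"{name}: FAIL, max abs error {max_err:.3e}")
        raise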

examples/KDA/wy_fast.py-9-9 (1)

9-9: Remove hardcoded CUDA device selection.

Same issue as in other files - hardcoding CUDA_VISIBLE_DEVICES = '7' breaks portability.

examples/KDA/chunk_delta_bwd.py-8-8 (1)

8-8: Remove debug print statement.

The print(tilelang.__file__, flush=True) appears to be debug code that should be removed before merging.

Suggested fix
-print(tilelang.__file__, flush=True)
examples/KDA/chunk_delta_h_fwd.py-308-318 (1)

308-318: Inconsistent parameter name in benchmark call.

The benchmark call at Line 313 uses g=G, but the reference call at Line 278 uses gk=G. This inconsistency may cause the benchmark to fail or use wrong parameters.

Suggested fix
     fla_time = do_bench(
         chunk_gated_delta_rule_fwd_h,
         k=K,
         w=W,
         u=U,
-        g=G,
+        gk=G,
         initial_state=initial_state,
         output_final_state=store_final_state,
         chunk_size=chunk_size,
         save_new_value=save_new_value,
     )
examples/KDA/chunk_bwd_intra_op.py-3-4 (1)

3-4: Remove unused imports.

from re import I is never used and appears to be an erroneous import. The sys import is also flagged but has a noqa directive.

Suggested fix
-from re import I
-import sys  # noqa: F401
examples/KDA/chunk_intra_token_parallel.py-7-7 (1)

7-7: Remove hardcoded CUDA device selection.

Hardcoding CUDA_VISIBLE_DEVICES = '7' breaks portability across different systems. This should be configurable or removed entirely.

Suggested fix
-os.environ['CUDA_VISIBLE_DEVICES'] = '7'
+# Remove or make configurable via environment/CLI argument
examples/KDA/chunk_bwd_intra.py-514-529 (1)

514-529: Consider using string dtype identifiers for consistency.

The main() function uses T.float32 directly, but run_test uses getattr(torch, input_dtype) expecting string dtype names. This inconsistency could cause confusion.

🔧 Suggested fix
 def main():
     DK = 128
     run_test(
         B=1,
         S=8192,
         H=8,
         DK=DK,
-        input_dtype=T.float32,
-        output_dtype=T.float32,
-        accum_dtype=T.float32,
-        gate_dtype=T.float32,
-        state_dtype=T.float32,
+        input_dtype="float32",
+        output_dtype="float32",
+        accum_dtype="float32",
+        gate_dtype="float32",
+        state_dtype="float32",
         chunk_size=64,
         threads=128,
         num_stages=0,
     )
examples/KDA/FLA_KDA/fla_chunk_intra.py-9-14 (1)

9-14: Dead code: conditional precision assignment is overwritten.

Lines 9-13 conditionally set SOLVE_TRIL_DOT_PRECISION based on IS_TF32_SUPPORTED, but line 14 unconditionally overwrites the result to 'tf32'. This makes the conditional logic dead code.

Additionally, the comment on line 278 mentions tf32x3 for precision, but the actual value used is 'tf32'. Please clarify the intended precision setting.

Suggested fix (remove dead code)
-IS_TF32_SUPPORTED=False 
-if IS_TF32_SUPPORTED:
-    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32x3')
-else:
-    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('ieee')
-SOLVE_TRIL_DOT_PRECISION= tl.constexpr('tf32')
+# Use tf32 precision for matrix inverse operations
+SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32')
🧹 Nitpick comments (25)
examples/KDA/README.md (1)

1-7: Add “how to run” + environment details for reproducibility.

This README pins TileLang/Triton/FLA commit, but doesn’t say how to execute the benchmarks/tests (command lines, expected GPU/driver/CUDA/PyTorch versions, and any required env vars). That will make reproduction/debugging harder for reviewers/users.

examples/KDA/chunk_bwd_dqkwg.py (1)

44-58: Clean up unused output prep + unused args in run_test() for clarity.

Right now prepare_output() is unused and run_test() accepts many flags/configs that are ignored (per Ruff). Either wire them into the kernel invocation (to bypass autotune when desired) or remove them from this example to reduce confusion.

Also applies to: 238-310

examples/KDA/chunk_o.py (1)

140-145: Remove/replace debug-commented code in final example.

The commented “why is it wrong” block (and the alternative T.If snippet) reads like an unresolved investigation. If this is important, convert it into a short note explaining the TileLang constraint (or link an issue); otherwise, delete it.

examples/KDA/chunk_bwd_gla_dA.py (1)

106-129: Don’t print raw timing arrays in do_bench() by default.

print(times) makes the benchmark noisy and harder to parse; consider gating it behind a verbose flag.
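
A sketch of gating the raw output behind an opt-in flag (a self-contained restatement of the CUDA-event timing loop used by these examples, with a hypothetical verbose argument):

import torch

def do_bench(fn, *args, rep=100, verbose=False, **kwargs):
    start = [torch.cuda.Event(enable_timing=True) for _ in range(rep)]
    end = [torch.cuda.Event(enable_timing=True) for _ in range(rep)]
    torch.cuda.synchronize()
    for i in range(rep):
        start[i].record()
        fn(*args, **kwargs)
        end[i].record()
    torch.cuda.synchronize()
    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start, end)])
    if verbose:
        print(times)  # raw per-iteration timings only when explicitly requested
    return times.mean().item()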

examples/KDA/FLA_KDA/cumsum.py (1)

246-323: Public wrappers: consider aligning input guards / constraints consistently.

chunk_local_cumsum_scalar/vector aren’t @input_guard’d and don’t enforce the “cu_seqlens implies batch==1” constraint (only chunk_local_cumsum() does). If these are meant to be public entry points, it’s safer to apply the same guard/constraint consistently.

Also applies to: 390-468
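
A sketch of what "consistent" could look like (the exact signature of chunk_local_cumsum_scalar in cumsum.py is assumed; the decorator and the varlen batch check are the point, and input_guard is already imported from fla_utils in this file):

@input_guard
def chunk_local_cumsum_scalar(g, chunk_size, reverse=False, scale=None, cu_seqlens=None, **kwargs):
    if cu_seqlens is not None:
        assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens is provided"
    ...  # existing kernel-launch body unchanged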

examples/KDA/chunk_bwd_dv.py (1)

103-126: Don’t print raw timing arrays in do_bench() by default.

examples/KDA/FLA_KDA/fla_chunk_inter.py (1)

150-153: Add explicit Optional type hints for nullable parameters.

Parameters with None defaults should use Optional[T] for PEP 484 compliance and better type checking. Based on static analysis hint.

Suggested fix
+from typing import Optional
+
 def chunk_kda_bwd_dqkwg(
     q: torch.Tensor,
     k: torch.Tensor,
     w: torch.Tensor,
     v: torch.Tensor,
     h: torch.Tensor,
     g: torch.Tensor,
     do: torch.Tensor,
     dh: torch.Tensor,
     dv: torch.Tensor,
-    scale: float  = None,
-    cu_seqlens: torch.LongTensor  = None,
+    scale: Optional[float] = None,
+    cu_seqlens: Optional[torch.LongTensor] = None,
     chunk_size: int = 64,
-    chunk_indices: torch.LongTensor  = None,
+    chunk_indices: Optional[torch.LongTensor] = None,
 ):
examples/KDA/chunk_bwd_intra_op.py (1)

514-529: Inconsistent dtype specification in main().

The main() function uses T.float32 directly (TileLang types) instead of string literals like "float32" that are converted via getattr(torch, ...) in run_test. This inconsistency could cause issues if the dtype handling differs.

Suggested fix for consistency
 def main():
     DK = 128
     run_test(
         B=1,
         S=8192,
         H=8,
         DK=DK,
-        input_dtype=T.float32,
-        output_dtype=T.float32,
-        accum_dtype=T.float32,
-        gate_dtype=T.float32,
-        state_dtype=T.float32,
+        input_dtype="float32",
+        output_dtype="float32",
+        accum_dtype="float32",
+        gate_dtype="float32",
+        state_dtype="float32",
         chunk_size=64,
         threads=128,
         num_stages=0,
     )
examples/KDA/wy_fast.py (1)

172-178: Inconsistent synchronization pattern in benchmark loop.

This do_bench function synchronizes inside the timing loop (Line 177), while other files in this PR synchronize only before and after the entire timing loop. This inconsistency may affect benchmark comparisons.

Suggested fix for consistency
     torch.cuda.synchronize()
     for i in range(rep):
         start_event[i].record()
         fn(*args, **kwargs)
         end_event[i].record()
-        torch.cuda.synchronize()
-    
+    torch.cuda.synchronize()
examples/KDA/FLA_KDA/fla_wy_fast.py (1)

77-77: Consider using English comments for consistency.

The comment # 乘beta (meaning "multiply beta") should be in English for better maintainability and team collaboration.

-        b_kb = b_k * b_b[:, None] # 乘beta
+        b_kb = b_k * b_b[:, None]  # multiply by beta
examples/KDA/FLA_KDA/fla_chunk_delta.py (1)

527-527: Use explicit Optional type annotation.

Per PEP 484, implicit Optional (using = None without Optional[T]) is prohibited.

-    scale: float  = None,
+    scale: Optional[float] = None,

Don't forget to add from typing import Optional at the top of the file.

examples/KDA/FLA_KDA/fla_utils.py (3)

50-54: Add stacklevel to warnings.warn for proper source attribution.

Without stacklevel, the warning will point to this utility function instead of the caller's code.

-            warnings.warn(msg)
+            warnings.warn(msg, stacklevel=2)

34-34: Replace fullwidth comma with ASCII comma.

The comment contains a fullwidth comma (，) which should be a regular comma for consistency.

-# error check，copy from
+# error check, copy from

128-145: Consider logging exceptions instead of silently passing.

The broad exception handling with pass makes debugging difficult. Consider logging the exception or using more specific exception types.

🔧 Suggested improvement
     # ---- Try the newer Triton 2.2+ API ----
     try:
         drv = triton.runtime.driver.active
         props = drv.utils.get_device_properties(tensor_idx)
         return props.get("multiprocessor_count") or props.get("num_vectorcore") or 1
-    except Exception:
-        pass
+    except Exception as e:
+        import logging
+        logging.debug(f"Triton 2.2+ API failed: {e}")

     # ---- Fallback: Triton 2.0 / 2.1 API ----
     try:
         cuda = triton.runtime.driver.CudaDriver
         dev = cuda.get_current_device()
         props = cuda.get_device_properties(dev)
         return props.get("multiprocessor_count", 1)
-    except Exception:
-        pass
+    except Exception as e:
+        import logging
+        logging.debug(f"Triton 2.0/2.1 API failed: {e}")
examples/KDA/chunk_bwd_intra.py (2)

3-4: Remove unused import.

from re import I is unused and appears to be a typo or leftover from debugging.

-from re import I
-import sys  # noqa: F401
+import sys  # noqa: F401

482-486: Debug print statement should be removed.

The print(db_tilelang.shape) statement appears to be debug output that should be removed before merging.

-    print(db_tilelang.shape)
     dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = kernel(
         q, k, g, beta, dAqk, dAkk, dq, dk, db, dg
     )
examples/KDA/wy_fast_bwd.py (1)

3-3: Remove unused noqa directive.

The # noqa: F401 directive is flagged as unused (RUF100) because F401 is not enabled in the project's Ruff configuration, so the suppression has no effect and can be dropped.

-import sys  # noqa: F401
examples/KDA/chunk_inter_solve_fused.py (2)

507-507: Remove debug print statement.

The print(times) statement should be removed or converted to a debug log.

-    print(times)
     return times.mean().item()

111-170: Consider documenting the shared memory allocation strategy.

The kernel allocates many shared memory buffers (Aqk/Akk fragments, K/Q/GK buffers for 4 sub-chunks, Ai matrices, etc.). A brief comment explaining the memory layout and purpose would improve maintainability.

examples/KDA/FLA_KDA/fla_chunk_o.py (1)

534-537: Use explicit Optional type annotation.

+from typing import Optional
+
 def chunk_bwd_dv_local(
     q: torch.Tensor,
     k: torch.Tensor,
     do: torch.Tensor,
-    g: torch.Tensor  = None,
-    g_gamma: torch.Tensor  = None,
-    A: torch.Tensor  = None,
-    scale: float = None,
+    g: Optional[torch.Tensor] = None,
+    g_gamma: Optional[torch.Tensor] = None,
+    A: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
examples/KDA/FLA_KDA/fla_chunk_intra.py (5)

33-63: Document the chunk size constraint (BT = 4 × BC).

The kernel hardcodes four sub-chunks (i_tc0 through i_tc3), implying BT = 4 * BC. This constraint isn't documented or asserted. Consider adding a comment or runtime assertion in the wrapper to prevent misuse with incompatible chunk sizes.
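
A one-line runtime check in the wrapper would be enough (a sketch; BT and BC follow the kernel parameters named above):

# In the host-side wrapper, before launching the intra kernel
assert BT == 4 * BC, (
    f"chunk size BT={BT} must be exactly 4x the sub-chunk size BC={BC}; "
    "the kernel unrolls four sub-chunks (i_tc0..i_tc3)"
)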


387-387: Avoid shadowing Python builtin all.

The variable name all shadows Python's builtin function. Consider renaming to total_tokens or similar.

-    all = B * T
+    total_tokens = B * T

Also update line 415:

-    db += (i_k * all + bos) * H + i_h
+    db += (i_k * total_tokens + bos) * H + i_h

706-717: Clarify in-place vs. return semantics for gradient tensors.

The function accepts dq, dk, db, dg as parameters, but the returned dq and dk are actually newly allocated tensors (dq2, dk2), not the modified inputs. This could confuse callers expecting in-place updates.

Consider either:

  1. Documenting this behavior clearly in the docstring
  2. Renaming parameters to clarify they're inputs to be accumulated (e.g., dq_in, dk_in)
  3. Copying results back to input tensors if in-place semantics are intended

720-732: Add docstring for public API function.

This function lacks documentation. Since it's a public wrapper for the fused kernel, consider adding a docstring explaining its purpose, parameters, and relationship to chunk_kda_fwd_intra.


562-567: Use Optional[T] for parameters with None defaults.

PEP 484 prohibits implicit Optional. Parameters like gk: torch.Tensor = None should be gk: Optional[torch.Tensor] = None.

Suggested fix
+from typing import Optional
+
 def chunk_kda_fwd_intra(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    gk: torch.Tensor  = None,
-    beta: torch.Tensor  = None,
-    scale: float  = None,
-    cu_seqlens: torch.LongTensor  = None,
+    gk: Optional[torch.Tensor] = None,
+    beta: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    cu_seqlens: Optional[torch.LongTensor] = None,
     chunk_size: int = 64,
-    chunk_indices: torch.LongTensor  = None,
+    chunk_indices: Optional[torch.LongTensor] = None,
 ):
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9936636 and e38fba5.

📒 Files selected for processing (22)
  • examples/KDA/FLA_KDA/cumsum.py
  • examples/KDA/FLA_KDA/fla_chunk_delta.py
  • examples/KDA/FLA_KDA/fla_chunk_inter.py
  • examples/KDA/FLA_KDA/fla_chunk_intra.py
  • examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py
  • examples/KDA/FLA_KDA/fla_chunk_o.py
  • examples/KDA/FLA_KDA/fla_utils.py
  • examples/KDA/FLA_KDA/fla_wy_fast.py
  • examples/KDA/README.md
  • examples/KDA/chunk_bwd_dqkwg.py
  • examples/KDA/chunk_bwd_dv.py
  • examples/KDA/chunk_bwd_gla_dA.py
  • examples/KDA/chunk_bwd_intra.py
  • examples/KDA/chunk_bwd_intra_op.py
  • examples/KDA/chunk_delta_bwd.py
  • examples/KDA/chunk_delta_h_fwd.py
  • examples/KDA/chunk_inter_solve_fused.py
  • examples/KDA/chunk_intra_token_parallel.py
  • examples/KDA/chunk_o.py
  • examples/KDA/test_utils.py
  • examples/KDA/wy_fast.py
  • examples/KDA/wy_fast_bwd.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • examples/KDA/README.md
🧬 Code graph analysis (12)
examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py (4)
tilelang/math/__init__.py (1)
  • next_power_of_2 (1-2)
examples/KDA/FLA_KDA/cumsum.py (1)
  • grid (304-304)
examples/KDA/FLA_KDA/fla_chunk_inter.py (1)
  • grid (167-167)
examples/KDA/FLA_KDA/fla_chunk_o.py (1)
  • grid (419-419)
examples/KDA/chunk_intra_token_parallel.py (1)
examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py (1)
  • chunk_kda_fwd_intra_token_parallel (116-168)
examples/KDA/FLA_KDA/cumsum.py (2)
examples/KDA/FLA_KDA/fla_utils.py (2)
  • prepare_chunk_indices (100-105)
  • input_guard (147-178)
tilelang/math/__init__.py (1)
  • next_power_of_2 (1-2)
examples/KDA/chunk_bwd_dqkwg.py (2)
examples/KDA/FLA_KDA/fla_chunk_inter.py (1)
  • chunk_kda_bwd_dqkwg (140-190)
examples/KDA/test_utils.py (1)
  • compare_tensors (42-82)
examples/KDA/FLA_KDA/fla_wy_fast.py (2)
examples/KDA/FLA_KDA/fla_utils.py (1)
  • prepare_chunk_indices (100-105)
tilelang/math/__init__.py (1)
  • next_power_of_2 (1-2)
examples/KDA/chunk_inter_solve_fused.py (4)
examples/KDA/FLA_KDA/fla_chunk_intra.py (1)
  • chunk_kda_fwd_inter_solve_fused (720-759)
examples/KDA/test_utils.py (1)
  • compare_tensors (42-82)
examples/KDA/FLA_KDA/cumsum.py (1)
  • chunk_local_cumsum (428-468)
tilelang/language/dtypes.py (1)
  • float32 (310-310)
examples/KDA/test_utils.py (1)
tilelang/carver/roller/policy/default.py (1)
  • sim (290-291)
examples/KDA/FLA_KDA/fla_chunk_o.py (2)
examples/KDA/FLA_KDA/fla_utils.py (3)
  • prepare_chunk_indices (100-105)
  • check_shared_mem (225-231)
  • input_guard (147-178)
tilelang/math/__init__.py (1)
  • next_power_of_2 (1-2)
examples/KDA/wy_fast_bwd.py (2)
examples/KDA/FLA_KDA/fla_wy_fast.py (1)
  • prepare_wy_repr_bwd (257-312)
examples/KDA/wy_fast.py (2)
  • prepare_input (17-24)
  • kernel (85-159)
examples/KDA/chunk_o.py (3)
tilelang/autotuner/tuner.py (1)
  • autotune (691-786)
examples/KDA/FLA_KDA/fla_chunk_o.py (1)
  • chunk_gla_fwd_o_gk (399-437)
tilelang/language/copy_op.py (1)
  • copy (14-116)
examples/KDA/FLA_KDA/fla_utils.py (1)
tilelang/utils/device.py (1)
  • get_current_device (14-21)
examples/KDA/FLA_KDA/fla_chunk_delta.py (3)
examples/KDA/FLA_KDA/fla_utils.py (1)
  • prepare_chunk_indices (100-105)
examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py (1)
  • grid (151-151)
examples/KDA/FLA_KDA/fla_chunk_o.py (1)
  • grid (419-419)
🪛 Ruff (0.14.10)
examples/KDA/chunk_delta_h_fwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


17-17: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


41-41: Unused function argument: output_dtype

(ARG001)


42-42: Unused function argument: accum_dtype

(ARG001)


108-108: Unused function argument: block_DK

(ARG001)


249-249: Unused function argument: block_DK

(ARG001)


250-250: Unused function argument: block_DV

(ARG001)


251-251: Unused function argument: threads

(ARG001)


252-252: Unused function argument: num_stages

(ARG001)

examples/KDA/chunk_intra_token_parallel.py

21-21: Unused function argument: output_dtype

(ARG001)


22-22: Unused function argument: accum_dtype

(ARG001)


258-258: Unpacked variable Aqk_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


258-258: Unpacked variable Akk_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/FLA_KDA/cumsum.py

33-33: Unused function argument: B

(ARG001)


88-88: Unused function argument: B

(ARG001)


146-146: Unused function argument: B

(ARG001)


206-206: Unused function argument: B

(ARG001)


250-250: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


287-287: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


330-330: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


361-361: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


395-395: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


420-424: Avoid specifying long messages outside the exception class

(TRY003)


432-432: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


437-437: Unused function argument: kwargs

(ARG001)


464-468: Avoid specifying long messages outside the exception class

(TRY003)

examples/KDA/chunk_delta_bwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


31-31: Unused function argument: output_dtype

(ARG001)


32-32: Unused function argument: accum_dtype

(ARG001)


34-34: Unused function argument: state_dtype

(ARG001)


63-63: Unused function argument: gate_dtype

(ARG001)


132-132: Unused function argument: h0

(ARG001)


244-244: Unused function argument: block_DV

(ARG001)


245-245: Unused function argument: threads

(ARG001)


246-246: Unused function argument: num_stages

(ARG001)


247-247: Unused function argument: use_torch

(ARG001)


271-271: Unpacked variable dh_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


271-271: Unpacked variable dh0_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


271-271: Unpacked variable dv2_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


295-295: Unpacked variable dh_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


295-295: Unpacked variable dh0_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


295-295: Unpacked variable dv2_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/chunk_bwd_dqkwg.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


49-49: Unused function argument: DV

(ARG001)


50-50: Unused function argument: chunk_size

(ARG001)


51-51: Unused function argument: qk_dtype

(ARG001)


249-249: Unused function argument: use_gk

(ARG001)


250-250: Unused function argument: use_initial_state

(ARG001)


251-251: Unused function argument: store_final_state

(ARG001)


252-252: Unused function argument: save_new_value

(ARG001)


253-253: Unused function argument: block_DK

(ARG001)


254-254: Unused function argument: block_DV

(ARG001)


255-255: Unused function argument: threads

(ARG001)


256-256: Unused function argument: num_stages

(ARG001)


260-260: Unpacked variable dq_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


260-260: Unpacked variable dk_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


260-260: Unpacked variable dw_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


260-260: Unpacked variable dg_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


285-285: Unpacked variable dq is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


285-285: Unpacked variable dk is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


285-285: Unpacked variable dw is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


285-285: Unpacked variable dg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/FLA_KDA/fla_chunk_inter.py

150-150: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

examples/KDA/chunk_inter_solve_fused.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


26-26: Unused function argument: output_dtype

(ARG001)


27-27: Unused function argument: accum_dtype

(ARG001)


47-47: Unused function argument: sub_chunk_size

(ARG001)

examples/KDA/chunk_bwd_gla_dA.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


22-22: Unused function argument: chunk_size

(ARG001)


34-34: Unused function argument: DV

(ARG001)


134-134: Unused function argument: DK

(ARG001)

examples/KDA/chunk_bwd_intra_op.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


29-29: Unused function argument: output_dtype

(ARG001)


30-30: Unused function argument: accum_dtype

(ARG001)


32-32: Unused function argument: state_dtype

(ARG001)


59-59: Unused function argument: chunk_size

(ARG001)


63-63: Unused function argument: state_dtype

(ARG001)


98-98: Unused function argument: state_dtype

(ARG001)


135-135: Unused function argument: db

(ARG001)


427-427: Unused function argument: threads

(ARG001)


428-428: Unused function argument: num_stages

(ARG001)


429-429: Unused function argument: cu_seqlens

(ARG001)


430-430: Unused function argument: chunk_indices

(ARG001)

examples/KDA/chunk_bwd_dv.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


38-38: Unused function argument: chunk_size

(ARG001)


133-133: Unused function argument: scale

(ARG001)

examples/KDA/test_utils.py

42-42: Unused function argument: atol

(ARG001)


42-42: Unused function argument: rtol

(ARG001)


53-53: Comment contains ambiguous （ (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?

(RUF003)


53-53: Comment contains ambiguous ） (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?

(RUF003)

examples/KDA/chunk_bwd_intra.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


29-29: Unused function argument: output_dtype

(ARG001)


30-30: Unused function argument: accum_dtype

(ARG001)


32-32: Unused function argument: state_dtype

(ARG001)


59-59: Unused function argument: chunk_size

(ARG001)


63-63: Unused function argument: state_dtype

(ARG001)


98-98: Unused function argument: state_dtype

(ARG001)


135-135: Unused function argument: db

(ARG001)


427-427: Unused function argument: threads

(ARG001)


428-428: Unused function argument: num_stages

(ARG001)


429-429: Unused function argument: cu_seqlens

(ARG001)


430-430: Unused function argument: chunk_indices

(ARG001)

examples/KDA/FLA_KDA/fla_chunk_o.py

323-323: Undefined name chunk_gla_fwd_A_kernel_intra_sub_inter

(F821)


343-343: Undefined name chunk_gla_fwd_A_kernel_intra_sub_intra

(F821)


365-365: Undefined name chunk_gla_fwd_A_kernel_intra_sub_intra_split

(F821)


384-384: Undefined name chunk_gla_fwd_A_kernel_intra_sub_intra_merge

(F821)


478-478: Unused function argument: g

(ARG001)


479-479: Unused function argument: g_gamma

(ARG001)


485-485: Unused function argument: scale

(ARG001)


491-491: Unused function argument: BK

(ARG001)


493-493: Unused function argument: USE_G

(ARG001)


494-494: Unused function argument: USE_G_GAMMA

(ARG001)


537-537: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

examples/KDA/wy_fast.py

6-6: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


17-17: Unused function argument: output_dtype

(ARG001)


67-67: Unused function argument: use_qg

(ARG001)


93-93: Unused function argument: QG

(ARG001)


199-199: Unused function argument: block_DK

(ARG001)


200-200: Unused function argument: block_DV

(ARG001)


201-201: Unused function argument: threads

(ARG001)


202-202: Unused function argument: num_stages

(ARG001)


209-209: Unpacked variable QG_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


210-210: Unpacked variable QG_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/wy_fast_bwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


26-26: Unused function argument: output_dtype

(ARG001)


27-27: Unused function argument: accum_dtype

(ARG001)


29-29: Unused function argument: state_dtype

(ARG001)


51-51: Unused function argument: chunk_size

(ARG001)


54-54: Unused function argument: state_dtype

(ARG001)


92-92: Unused function argument: state_dtype

(ARG001)


291-291: Unused function argument: block_DK

(ARG001)


292-292: Unused function argument: block_DV

(ARG001)


293-293: Unused function argument: threads

(ARG001)


294-294: Unused function argument: num_stages

(ARG001)


316-316: Unpacked variable dk_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


316-316: Unpacked variable dv_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


316-316: Unpacked variable dbeta_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


316-316: Unpacked variable dg_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


316-316: Unpacked variable dA_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


342-342: Unpacked variable dA_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


342-342: Unpacked variable dk_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


342-342: Unpacked variable dv_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


342-342: Unpacked variable dbeta_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


342-342: Unpacked variable dg_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/chunk_o.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


21-21: Unused function argument: output_dtype

(ARG001)


22-22: Unused function argument: accum_dtype

(ARG001)


37-37: Unused function argument: DK

(ARG001)


39-39: Unused function argument: chunk_size

(ARG001)


42-42: Ambiguous variable name: O

(E741)


99-99: Ambiguous variable name: O

(E741)


194-194: Unused function argument: block_DK

(ARG001)


195-195: Unused function argument: block_DV

(ARG001)


196-196: Unused function argument: threads

(ARG001)


197-197: Unused function argument: num_stages

(ARG001)

examples/KDA/FLA_KDA/fla_utils.py

34-34: Comment contains ambiguous ， (FULLWIDTH COMMA). Did you mean , (COMMA)?

(RUF003)


52-52: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


133-134: try-except-pass detected, consider logging the exception

(S110)


133-133: Do not catch blind exception: Exception

(BLE001)


142-143: try-except-pass detected, consider logging the exception

(S110)


142-142: Do not catch blind exception: Exception

(BLE001)


220-220: Do not catch blind exception: BaseException

(BLE001)


230-230: Do not catch blind exception: Exception

(BLE001)

examples/KDA/FLA_KDA/fla_chunk_intra.py

564-564: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


607-607: Undefined name chunk_kda_fwd_intra_token_parallel

(F821)


638-638: Undefined name recompute_w_u_fwd

(F821)

examples/KDA/FLA_KDA/fla_chunk_delta.py

527-527: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

🔇 Additional comments (10)
examples/KDA/FLA_KDA/fla_chunk_inter.py (1)

14-138: LGTM - Kernel implementation looks correct.

The inter-chunk backward kernel properly handles both variable-length and fixed-length sequences, with appropriate boundary checks and gradient accumulation logic.

examples/KDA/chunk_intra_token_parallel.py (1)

59-174: LGTM - TileLang kernel structure is well-organized.

The kernel factory correctly sets up shared memory allocations, fragment handling, and pipelined loops. The conditional handling of Sum_Akk_shared vs Sum_Aqk_shared appears intentional for the algorithm's requirements.

examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py (2)

44-58: Binary search implementation for varlen looks correct.

The unrolled binary search (20 iterations for up to ~1M batch size) efficiently finds the sequence index for variable-length sequences. The implementation correctly handles the cu_seqlens lookup pattern.


98-114: LGTM - Computation loop correctly implements gated attention.

The loop properly computes the Q-K interaction with exponential gating, applies appropriate masking for the diagonal block structure, and stores results with boundary checks.

examples/KDA/chunk_bwd_intra_op.py (1)

143-391: Complex kernel structure - verify correctness with thorough testing.

The backward intra-chunk kernel has multiple processing phases (previous sub-chunks, current diagonal, subsequent sub-chunks, lower triangular). While the structure appears logical, the complexity warrants careful validation through the comparison with the reference implementation.

examples/KDA/chunk_delta_bwd.py (1)

125-224: LGTM - Backward kernel structure is correct.

The kernel properly implements the backward pass with reverse iteration, handles optional initial state and final state gradients, and uses appropriate memory layouts with swizzling.

examples/KDA/FLA_KDA/fla_wy_fast.py (1)

210-254: LGTM!

The forward kernel wrapper correctly handles variable-length sequences, derives block dimensions from input tensors, and properly initializes output tensors. The autotuning configuration covers a reasonable search space.

examples/KDA/FLA_KDA/fla_chunk_delta.py (1)

70-76: LGTM - Manual unrolling for performance.

The manual unrolling of state buffers (b_h1 through b_h4) for different K dimension ranges is a valid performance optimization pattern in Triton kernels, trading code verbosity for better register allocation and reduced branching overhead.

examples/KDA/chunk_bwd_intra.py (1)

83-105: LGTM - Well-structured autotuned kernel.

The kernel configuration with autotune and appropriate pass configs is well-structured. The tiling parameters and thread configurations provide a good search space for optimization.

examples/KDA/chunk_inter_solve_fused.py (1)

310-374: Forward substitution loops have correct boundary handling.

The pipelined loops for forward substitution on diagonal blocks correctly handle boundaries with T.min(BC, S-i_tc0) and similar guards. The identity matrix addition at lines 370-374 properly completes the inverse computation.

Comment on lines +483 to +491
    if chunk_indices is None and cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    assert K <= 256, "current kernel does not support head dimension larger than 256."

    h = k.new_empty(B, NT, H, K, V)
    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None

⚠️ Potential issue | 🔴 Critical

chunk_offsets is undefined when cu_seqlens is provided.

When cu_seqlens is not None, the code skips the else branch where chunk_offsets is assigned, but chunk_offsets is still passed to the kernel on line 506, causing an UnboundLocalError.

🐛 Proposed fix
     if chunk_indices is None and cu_seqlens is not None:
         chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
     # N: the actual number of sequences in the batch with either equal or variable lengths
     if cu_seqlens is None:
         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+    else:
+        N, NT = len(cu_seqlens) - 1, len(chunk_indices)
+        chunk_offsets = torch.zeros(N, dtype=torch.int32, device=k.device)
+        # Compute chunk offsets for variable-length sequences
+        for i in range(N):
+            chunk_offsets[i] = (chunk_indices[:, 0] == i).nonzero()[0].item() if i > 0 else 0
     assert K <= 256, "current kernel does not support head dimension larger than 256."
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_chunk_delta.py around lines 483 - 491, The variable
chunk_offsets is left undefined when cu_seqlens is provided, causing an
UnboundLocalError later when passed to the kernel; fix by ensuring chunk_offsets
is initialized in the branch that handles cu_seqlens (e.g., after computing
chunk_indices via prepare_chunk_indices(cu_seqlens, chunk_size) compute or
derive chunk_offsets from chunk_indices, or explicitly set chunk_offsets = None
if the kernel accepts that), so that the subsequent kernel invocation that uses
chunk_offsets, h, and final_state always has a defined value.

Comment on lines +538 to +541
    if chunk_indices is None and cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None

⚠️ Potential issue | 🔴 Critical

Same chunk_offsets issue in backward function.

Similar to the forward function, chunk_offsets is undefined when cu_seqlens is not None.

🐛 Proposed fix
     if chunk_indices is None and cu_seqlens is not None:
         chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
     if cu_seqlens is None:
         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+    else:
+        N, NT = len(cu_seqlens) - 1, len(chunk_indices)
+        chunk_offsets = torch.zeros(N, dtype=torch.int32, device=q.device)
+        for i in range(N):
+            chunk_offsets[i] = (chunk_indices[:, 0] == i).nonzero()[0].item() if i > 0 else 0
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_chunk_delta.py around lines 538 - 541, In the
backward function the variable chunk_offsets can remain undefined when
cu_seqlens is not None; update the logic so that when chunk_indices is None and
cu_seqlens is not None you also set chunk_offsets (mirror forward): call
prepare_chunk_indices(cu_seqlens, chunk_size) to produce chunk_indices and
derive chunk_offsets (or compute chunk_offsets from the same
cu_seqlens/chunk_size helper) before proceeding; modify the branch around
chunk_indices, cu_seqlens and the subsequent use of chunk_offsets to ensure
chunk_offsets is always initialized (refer to chunk_indices, cu_seqlens,
prepare_chunk_indices, and chunk_offsets).

Comment on lines +5 to +40
def print_red_warning(message):
    print(f"\033[31mWARNING: {message}\033[0m")


def calc_sim(x, y, name="tensor"):
    x, y = x.data.double(), y.data.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        print_red_warning(f"{name} all zero")
        return 1
    sim = 2 * (x * y).sum() / denominator
    return sim


def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
    x_mask = torch.isfinite(x)
    y_mask = torch.isfinite(y)
    if not torch.all(x_mask == y_mask):
        print_red_warning(f"{name} Error: isfinite mask mismatch")
        if raise_assert:
            raise AssertionError
    if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
        print_red_warning(f"{name} Error: nonfinite value mismatch")
        if raise_assert:
            raise AssertionError
    x = x.masked_fill(~x_mask, 0)
    y = y.masked_fill(~y_mask, 0)
    sim = calc_sim(x, y, name)
    diff = 1.0 - sim
    if not (0 <= diff <= eps):
        print_red_warning(f"{name} Error: {diff}")
        if raise_assert:
            raise AssertionError
    else:
        print(f"{name} {data} passed")


⚠️ Potential issue | 🔴 Critical

Fix calc_sim() / assert_similar() tensor-vs-Python control-flow bugs.

calc_sim() returns a torch scalar tensor, so diff = 1.0 - sim becomes a tensor; if not (0 <= diff <= eps): is not safe (chained comparisons over tensors) and is likely to throw or behave unexpectedly (esp. on CUDA). Also .data should be avoided.

Proposed fix
 def calc_sim(x, y, name="tensor"):
-    x, y = x.data.double(), y.data.double()
-    denominator = (x * x + y * y).sum()
-    if denominator == 0:
+    x, y = x.detach().double(), y.detach().double()
+    denominator = (x * x + y * y).sum()
+    denom = float(denominator.item())
+    if denom == 0.0:
         print_red_warning(f"{name} all zero")
-        return 1
-    sim = 2 * (x * y).sum() / denominator
-    return sim
+        return 1.0
+    sim = 2.0 * (x * y).sum() / denominator
+    return float(sim.item())

 def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
@@
-    sim = calc_sim(x, y, name)
-    diff = 1.0 - sim
-    if not (0 <= diff <= eps):
+    sim = calc_sim(x, y, name)
+    diff = 1.0 - sim
+    if not (0.0 <= diff <= eps):
         print_red_warning(f"{name} Error: {diff}")
         if raise_assert:
             raise AssertionError
🤖 Prompt for AI Agents
In @examples/KDA/test_utils.py around lines 5 - 40, calc_sim and assert_similar
use .data and perform Python control flow on torch tensors which can raise on
CUDA; change calc_sim to avoid .data (use .detach().double() or
.to(torch.double)) and return a Python float (use sim.item()) instead of a
tensor, ensure the zero-denominator check uses a scalar (.item() or
.eq(0).all().item()); in assert_similar stop doing chained tensor comparisons by
converting sim to a float before computing diff (e.g., sim = calc_sim(...); diff
= 1.0 - float(sim)), keep the isfinite/masking logic but ensure masked_fill uses
the correct masks (x.masked_fill(~x_mask, 0) / y.masked_fill(~y_mask, 0) as
already used) and replace any remaining .data usages with .detach()/.to(...)

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Fix all issues with AI agents
In @examples/KDA/FLA_KDA/fla_chunk_intra.py:
- Around line 9-14: The conditional assignment to SOLVE_TRIL_DOT_PRECISION using
IS_TF32_SUPPORTED is currently dead because it is immediately overwritten by
SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32'); either remove the unconditional
overwrite and keep the conditional block (use SOLVE_TRIL_DOT_PRECISION =
tl.constexpr('tf32') or 'ieee' based on IS_TF32_SUPPORTED) or delete the
conditional block and retain the single explicit SOLVE_TRIL_DOT_PRECISION =
tl.constexpr('tf32'); ensure you only assign SOLVE_TRIL_DOT_PRECISION once and
use the tl.constexpr(...) call consistently.

In @examples/KDA/FLA_KDA/fla_chunk_o.py:
- Around line 421-429: The branch guarding creation of b_A (when USE_A is True)
leaves b_A undefined if USE_A is False; ensure b_A is always initialized before
its use by adding an else that sets b_A to a zero tensor with the same shape and
dtype as the loaded value (matching do.dtype.element_ty and shape of p_A load /
m_A result), or assert that USE_A must be True; update the block around the
USE_A check (where p_A, b_A are created) so b_A is defined for both branches
before the tl.where(...) on line using m_A.
- Around line 455-460: The conditional mistakenly uses the function object
check_shared_mem instead of calling it, causing the middle branch to always run;
update the second branch to call check_shared_mem with the same signature as the
first (e.g., check_shared_mem('hopper', k.device.index)) or the appropriate
args, so CONST_TILING is set to 64 only when the function returns truthy; ensure
the first branch remains check_shared_mem('hopper', k.device.index) and the else
remains CONST_TILING = 32.

In @examples/KDA/FLA_KDA/fla_utils.py:
- Line 25: The module-level expression that sets IS_NVIDIA_HOPPER calls
torch.cuda.get_device_name(0) and torch.cuda.get_device_capability() which will
raise if no CUDA device exists; change the initialization of IS_NVIDIA_HOPPER in
fla_utils.py to first check torch.cuda.is_available() (and/or
torch.cuda.device_count() > 0) and then safely call
get_device_name/get_device_capability inside a try/except, defaulting
IS_NVIDIA_HOPPER to False on any exception so import-time errors are avoided.
🧹 Nitpick comments (7)
examples/KDA/FLA_KDA/fla_utils.py (4)

1-31: Module setup and imports look reasonable.

The module correctly sets up environment detection for Hopper GPUs and CUDA graph usage. A few minor observations:

  1. Line 25: The (True and ...) pattern is redundant; the True and prefix does nothing.
  2. Line 34: The comment contains a fullwidth comma (，) instead of an ASCII comma (as flagged by static analysis).
🔧 Suggested cleanup
-IS_NVIDIA_HOPPER = (True and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
-USE_CUDA_GRAPH = (True and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')
+IS_NVIDIA_HOPPER = 'NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9
+USE_CUDA_GRAPH = os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1'
-# error check,copy from
+# error check, copy from

118-145: Silent exception swallowing in get_multiprocessor_count.

The function catches broad Exception and silently passes, which can hide real bugs. Consider logging at debug level for troubleshooting.

🔧 Optional: Add debug logging
+import logging
+logger = logging.getLogger(__name__)
+
 @functools.cache
 def get_multiprocessor_count(tensor_idx: int = 0) -> int:
     # ---- Try the newer Triton 2.2+ API ----
     try:
         drv = triton.runtime.driver.active
         props = drv.utils.get_device_properties(tensor_idx)
         return props.get("multiprocessor_count") or props.get("num_vectorcore") or 1
     except Exception:
-        pass
+        logger.debug("Triton 2.2+ API not available, falling back")

     # ---- Fallback: Triton 2.0 / 2.1 API ----
     try:
         cuda = triton.runtime.driver.CudaDriver
         dev = cuda.get_current_device()
         props = cuda.get_device_properties(dev)
         return props.get("multiprocessor_count", 1)
     except Exception:
-        pass
+        logger.debug("Triton 2.0/2.1 API not available, returning default")

     return 1

147-178: input_guard builds contiguous_args as a lazy generator; prefer a tuple.

Line 156 creates a generator expression for contiguous_args. It is consumed exactly once when fn(*contiguous_args, ...) is called, so the current code works, but the laziness obscures intent: the tensor lookup on lines 160-168 runs before the generator is ever consumed, and any future change that iterates the arguments a second time would silently see an exhausted generator. Materializing a tuple is clearer and safer.

🔧 Use tuple for clarity and safety
 @functools.wraps(fn)
 def wrapper(*args, **kwargs):
-    contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
+    contiguous_args = tuple(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
     contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}

214-231: get_all_max_shared_mem catches BaseException which is too broad.

Catching BaseException can suppress KeyboardInterrupt and SystemExit. Use Exception instead.

🔧 Narrow exception type
 def get_all_max_shared_mem():
     try:
         return [
             triton.runtime.driver.active.utils.get_device_properties(i)['max_shared_mem']
             for i in range(device_torch_lib.device_count())
         ]
-    except BaseException:
+    except Exception:
         return [-1]
examples/KDA/FLA_KDA/fla_chunk_o.py (3)

12-13: check_shared_mem() called with no arguments vs. with 'ampere'.

Line 12 calls check_shared_mem() with no arguments (defaults to arch="none"), while line 13 calls it with 'ampere'. The default "none" won't match any Backend enum key, so it falls back to DEFAULT. Consider being explicit about the intended architecture.
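A minimal sketch of the suggested explicitness, assuming the check_shared_mem(arch, tensor_idx) signature used elsewhere in this file; the import path, device index, and tiling values are illustrative, not taken from the source:

from fla_utils import check_shared_mem  # illustrative import; helper lives in examples/KDA/FLA_KDA/fla_utils.py

device_index = 0  # the wrappers use k.device.index
if check_shared_mem('hopper', device_index):
    BLOCK_V = 128   # illustrative tile sizes only
elif check_shared_mem('ampere', device_index):
    BLOCK_V = 64
else:
    BLOCK_V = 32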


89-92: Redundant condition if i_k >= 0 is always true.

i_k is a loop counter from range(tl.cdiv(K, BK)) which always starts at 0, so i_k >= 0 is always true. The comment "works but dkw, owing to divine benevolence" suggests this might be a workaround for a compiler issue.

If this is a known workaround, consider documenting it more clearly:

-        # works but dkw, owing to divine benevolence
-        # [BT, BV]
-        if i_k >= 0:
+        # NOTE: Redundant check works around Triton compiler issue
+        if i_k >= 0:  # Always true, but needed for correctness

438-449: Type annotations use implicit Optional.

Lines 442-445 use = None defaults without Optional[T] type hints. While this works, PEP 484 recommends explicit Optional.

🔧 Use explicit Optional types
+from typing import Optional
+
 def chunk_bwd_dv_local(
     q: torch.Tensor,
     k: torch.Tensor,
     do: torch.Tensor,
-    g: torch.Tensor  = None,
-    g_gamma: torch.Tensor  = None,
-    A: torch.Tensor  = None,
-    scale: float = None,
-    cu_seqlens: torch.LongTensor  = None,
+    g: Optional[torch.Tensor] = None,
+    g_gamma: Optional[torch.Tensor] = None,
+    A: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    cu_seqlens: Optional[torch.LongTensor] = None,
     chunk_size: int = 64,
-    chunk_indices: torch.LongTensor  = None,
+    chunk_indices: Optional[torch.LongTensor] = None,
 ) -> torch.Tensor:
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e38fba5 and f7adecc.

📒 Files selected for processing (3)
  • examples/KDA/FLA_KDA/fla_chunk_intra.py
  • examples/KDA/FLA_KDA/fla_chunk_o.py
  • examples/KDA/FLA_KDA/fla_utils.py
🧰 Additional context used
🧬 Code graph analysis (2)
examples/KDA/FLA_KDA/fla_chunk_intra.py (2)
examples/KDA/FLA_KDA/fla_utils.py (1)
  • prepare_chunk_indices (100-105)
examples/KDA/FLA_KDA/cumsum.py (1)
  • chunk_local_cumsum (428-468)
examples/KDA/FLA_KDA/fla_chunk_o.py (3)
examples/KDA/FLA_KDA/fla_utils.py (3)
  • prepare_chunk_indices (100-105)
  • check_shared_mem (225-231)
  • input_guard (147-178)
examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py (1)
  • grid (151-151)
tilelang/math/__init__.py (1)
  • next_power_of_2 (1-2)
🪛 Ruff (0.14.10)
examples/KDA/FLA_KDA/fla_utils.py

34-34: Comment contains ambiguous `，` (FULLWIDTH COMMA). Did you mean `,` (COMMA)?

(RUF003)


133-134: try-except-pass detected, consider logging the exception

(S110)


133-133: Do not catch blind exception: Exception

(BLE001)


142-143: try-except-pass detected, consider logging the exception

(S110)


142-142: Do not catch blind exception: Exception

(BLE001)


220-220: Do not catch blind exception: BaseException

(BLE001)


230-230: Do not catch blind exception: Exception

(BLE001)

examples/KDA/FLA_KDA/fla_chunk_o.py

386-386: Unused function argument: g

(ARG001)


387-387: Unused function argument: g_gamma

(ARG001)


393-393: Unused function argument: scale

(ARG001)


399-399: Unused function argument: BK

(ARG001)


401-401: Unused function argument: USE_G

(ARG001)


402-402: Unused function argument: USE_G_GAMMA

(ARG001)


445-445: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

🔇 Additional comments (9)
examples/KDA/FLA_KDA/fla_utils.py (2)

57-93: Single-entry tensor cache implementation is correct.

The tensor_cache decorator correctly implements identity-based caching for tensor inputs. The use of is comparison for tensors is appropriate since tensor identity matters for caching.
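For reference, a stripped-down sketch of the single-entry, identity-keyed caching pattern described here; this illustrates the idea and is not the repository's implementation:

import functools
import torch

def tensor_cache(fn):
    """Cache only the most recent call, keyed by argument identity (is), not value equality."""
    last_args, last_result = None, None

    @functools.wraps(fn)
    def wrapper(*args):
        nonlocal last_args, last_result
        if last_args is not None and len(args) == len(last_args) and all(
                a is b for a, b in zip(args, last_args)):
            return last_result          # same tensor objects as last time: reuse the result
        last_result, last_args = fn(*args), args
        return last_result

    return wrapper

@tensor_cache
def seq_lengths(cu_seqlens: torch.Tensor) -> torch.Tensor:
    return cu_seqlens[1:] - cu_seqlens[:-1]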


200-211: Backend enum provides architecture-specific shared memory sizes.

The implementation is clean. Minor note: the comment references specific GPU models but the enum names are architecture generations, which is appropriate.
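For context, the pattern under review looks roughly like the sketch below; the byte values are commonly quoted per-SM shared-memory budgets and are illustrative rather than copied from fla_utils.py:

from enum import IntEnum

class Backend(IntEnum):
    ADA = 101376        # ~99 KB of shared memory per SM (illustrative)
    AMPERE = 166912     # ~163 KB (illustrative)
    HOPPER = 232448     # ~227 KB (illustrative)
    DEFAULT = 101376

def max_shared_mem(arch: str = "none") -> int:
    # Unknown keys fall back to DEFAULT, which is what the earlier check_shared_mem() nitpick is about.
    key = arch.upper()
    return Backend[key].value if key in Backend.__members__ else Backend.DEFAULT.value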

examples/KDA/FLA_KDA/fla_chunk_intra.py (5)

20-51: Forward kernel setup looks correct.

The heuristics for IS_VARLEN and autotune configurations are appropriate. The kernel signature properly declares compile-time constants with tl.constexpr.


64-91: Varlen handling and boundary setup look correct.

The chunk index loading and sequence boundary calculations handle both variable-length and fixed-length cases appropriately. The early return for out-of-bounds tiles (line 74-75) is correct.


629-668: Forward inter-solve wrapper implementation looks correct.

The function properly handles chunk index preparation and kernel launch. The assertion K <= 256 at line 643 provides a clear constraint.


559-626: The function's return value usage and chunk_local_cumsum parameter passing are correct. Callers at lines 447 and 499 in chunk_bwd_intra_op.py and chunk_bwd_intra.py properly capture the return values (dq_ref, dk_ref, db_ref, dg_ref = chunk_kda_bwd_intra(...)). The chunk_local_cumsum call correctly passes both chunk_size (required size parameter) and chunk_indices (optional parameter for sequence-level chunking support)—these parameters serve different purposes and are intentionally used together. No changes needed.

Likely an incorrect or invalid review comment.


248-267: Pointer arithmetic is correct; no issue identified.

After examining the code, the pointer offset and stride calculations are self-consistent. The Akk_diag offset at line 87 (bos * H + i_h) * BC correctly positions the pointer to the start of a diagonal block for a specific batch and head. The subsequent loads in the forward substitution loops (lines 249, 254, 259, 264) correctly use stride H*BC across all four loops, which matches the stride specified in the make_block_ptr calls (lines 225-228). The loop index variables handle the offset ranges correctly: loop 2 covers rows starting at i_tc0 + BC, loop 3 covers i_tc0 + 2*BC, and loop 4 covers i_tc0 + 3*BC, all properly expressed through the combined formula (i_tc0 + i) * H*BC. There is no double-counting of offsets and no stride inconsistency.

examples/KDA/FLA_KDA/fla_chunk_o.py (2)

307-345: chunk_gla_fwd_o_gk wrapper implementation looks correct.

The function properly handles chunk indices preparation, output tensor allocation, and kernel launch with appropriate parameters.


549-580: chunk_gla_bwd_dA wrapper implementation looks correct.

The function correctly allocates the output tensor with float32 dtype and launches the kernel with proper grid dimensions.

Comment on lines 9 to 14
IS_TF32_SUPPORTED=False
if IS_TF32_SUPPORTED:
    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32x3')
else:
    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('ieee')
SOLVE_TRIL_DOT_PRECISION= tl.constexpr('tf32')

⚠️ Potential issue | 🟡 Minor

SOLVE_TRIL_DOT_PRECISION is overwritten immediately, dead code.

Lines 9-13 set SOLVE_TRIL_DOT_PRECISION based on IS_TF32_SUPPORTED, but line 14 unconditionally overwrites it to 'tf32'. This makes the conditional logic dead code.

🔧 Remove dead code or fix the logic

If 'tf32' is the intended precision, remove the dead code:

-IS_TF32_SUPPORTED=False 
-if IS_TF32_SUPPORTED:
-    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32x3')
-else:
-    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('ieee')
-SOLVE_TRIL_DOT_PRECISION= tl.constexpr('tf32')
+SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32')

Or if you intended to use the conditional:

-IS_TF32_SUPPORTED=False 
+IS_TF32_SUPPORTED = False
 if IS_TF32_SUPPORTED:
     SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32x3')
 else:
     SOLVE_TRIL_DOT_PRECISION = tl.constexpr('ieee')
-SOLVE_TRIL_DOT_PRECISION= tl.constexpr('tf32')
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_chunk_intra.py around lines 9 - 14, The conditional
assignment to SOLVE_TRIL_DOT_PRECISION using IS_TF32_SUPPORTED is currently dead
because it is immediately overwritten by SOLVE_TRIL_DOT_PRECISION =
tl.constexpr('tf32'); either remove the unconditional overwrite and keep the
conditional block (use SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32') or 'ieee'
based on IS_TF32_SUPPORTED) or delete the conditional block and retain the
single explicit SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32'); ensure you only
assign SOLVE_TRIL_DOT_PRECISION once and use the tl.constexpr(...) call
consistently.

Comment on lines 383 to 436
def chunk_bwd_kernel_dv_local(
    q,
    k,
    g,
    g_gamma,
    A,
    do,
    dv,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_G_GAMMA: tl.constexpr,
    USE_A: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    do += (bos * H + i_h) * V
    dv += (bos * H + i_h) * V

    if USE_A:
        p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
        b_A = tl.load(p_A, boundary_check=(0, 1))


    o_t = i_t * BT + tl.arange(0, BT)
    m_t = o_t < T
    m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t)
    b_A = tl.where(m_A, b_A, 0).to(do.dtype.element_ty)

    for i_v in range(tl.cdiv(V, BV)):
        p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_dv = tl.dot(b_A.to(b_do.dtype), b_do)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

⚠️ Potential issue | 🟡 Minor

Multiple unused kernel arguments in chunk_bwd_kernel_dv_local.

Static analysis correctly identifies that g, g_gamma, scale, BK, USE_G, and USE_G_GAMMA are declared but never used in the kernel body. This suggests either incomplete implementation or arguments reserved for future use.

If these are placeholders for future features, consider adding a comment. If they're dead code, remove them to reduce confusion.

 @triton.jit(do_not_specialize=['T'])
 def chunk_bwd_kernel_dv_local(
     q,
     k,
-    g,
-    g_gamma,
+    g,  # Reserved for gating support (currently unused)
+    g_gamma,  # Reserved for gamma gating (currently unused)
     A,
     do,
     dv,
🧰 Tools
🪛 Ruff (0.14.10)

386-386: Unused function argument: g

(ARG001)


387-387: Unused function argument: g_gamma

(ARG001)


393-393: Unused function argument: scale

(ARG001)


399-399: Unused function argument: BK

(ARG001)


401-401: Unused function argument: USE_G

(ARG001)


402-402: Unused function argument: USE_G_GAMMA

(ARG001)

Comment on lines 421 to 429
    if USE_A:
        p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
        b_A = tl.load(p_A, boundary_check=(0, 1))


    o_t = i_t * BT + tl.arange(0, BT)
    m_t = o_t < T
    m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t)
    b_A = tl.where(m_A, b_A, 0).to(do.dtype.element_ty)

⚠️ Potential issue | 🔴 Critical

b_A used before assignment when USE_A is False.

If USE_A is False (line 421), b_A is never assigned, but it's used on line 429. This will cause a runtime error.

🐛 Initialize b_A before conditional
     if USE_A:
         p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
         b_A = tl.load(p_A, boundary_check=(0, 1))
-    
+    else:
+        # Initialize identity-like matrix when A is not provided
+        b_A = tl.zeros([BT, BT], dtype=tl.float32)

     o_t = i_t * BT + tl.arange(0, BT)

Alternatively, if USE_A=False is not a valid code path, add an assertion in the wrapper.

🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_chunk_o.py around lines 421 - 429, The branch
guarding creation of b_A (when USE_A is True) leaves b_A undefined if USE_A is
False; ensure b_A is always initialized before its use by adding an else that
sets b_A to a zero tensor with the same shape and dtype as the loaded value
(matching do.dtype.element_ty and shape of p_A load / m_A result), or assert
that USE_A must be True; update the block around the USE_A check (where p_A, b_A
are created) so b_A is defined for both branches before the tl.where(...) on
line using m_A.

Comment on lines 455 to 460
    if check_shared_mem('hopper', k.device.index):
        CONST_TILING = 128
    elif check_shared_mem:
        CONST_TILING = 64
    else:
        CONST_TILING = 32

⚠️ Potential issue | 🔴 Critical

Bug: check_shared_mem used as boolean instead of being called.

Line 457 uses check_shared_mem (the function object) in a boolean context instead of calling it with check_shared_mem(). A function object is always truthy, so this branch will always execute.

🐛 Fix the missing function call
     if check_shared_mem('hopper', k.device.index):
         CONST_TILING = 128
-    elif check_shared_mem:
+    elif check_shared_mem('ampere', k.device.index):
         CONST_TILING = 64
     else:
         CONST_TILING = 32
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_chunk_o.py around lines 455 - 460, The conditional
mistakenly uses the function object check_shared_mem instead of calling it,
causing the middle branch to always run; update the second branch to call
check_shared_mem with the same signature as the first (e.g.,
check_shared_mem('hopper', k.device.index)) or the appropriate args, so
CONST_TILING is set to 64 only when the function returns truthy; ensure the
first branch remains check_shared_mem('hopper', k.device.index) and the else
remains CONST_TILING = 32.

log = tldevice.fast_logf
log2 = tldevice.fast_log2f

IS_NVIDIA_HOPPER = (True and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))

⚠️ Potential issue | 🟡 Minor

Potential crash if no CUDA device is available.

torch.cuda.get_device_name(0) and torch.cuda.get_device_capability() will raise an error if no CUDA device is present. This executes at module import time.

🛡️ Suggested defensive initialization
-IS_NVIDIA_HOPPER = (True and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
+def _detect_hopper():
+    try:
+        return 'NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9
+    except Exception:
+        return False
+
+IS_NVIDIA_HOPPER = _detect_hopper()
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_utils.py at line 25, The module-level expression
that sets IS_NVIDIA_HOPPER calls torch.cuda.get_device_name(0) and
torch.cuda.get_device_capability() which will raise if no CUDA device exists;
change the initialization of IS_NVIDIA_HOPPER in fla_utils.py to first check
torch.cuda.is_available() (and/or torch.cuda.device_count() > 0) and then safely
call get_device_name/get_device_capability inside a try/except, defaulting
IS_NVIDIA_HOPPER to False on any exception so import-time errors are avoided.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (4)
examples/KDA/FLA_KDA/cumsum.py (4)

56-64: Inconsistent reverse cumsum implementation between scalar and vector kernels.

The scalar kernel manually computes reverse cumsum while the vector kernel (line 116) uses tl.cumsum(b_s, axis=0, reverse=True). Consider using the native reverse=True parameter here for consistency if supported for 1D tensors, or add a comment explaining why the manual approach is necessary.
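The two formulations are equivalent; a standalone PyTorch check (independent of Triton) of the identity the scalar kernel relies on:

import torch

x = torch.randn(64, dtype=torch.float64)

# Manual form used in the scalar kernel: rev_cumsum(x) = -cumsum(x) + sum(x) + x
manual = -torch.cumsum(x, dim=0) + x.sum() + x
# Native reverse cumulative sum for comparison
reference = torch.flip(torch.cumsum(torch.flip(x, dims=[0]), dim=0), dims=[0])

assert torch.allclose(manual, reference)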


178-179: Redundant condition: if i_c >= 0: is always true.

Since i_c iterates from range(NT) starting at 0, this condition is always satisfied. The check can be removed.

Suggested fix
         b_o += b_z
-        if i_c >= 0:
-            b_z += b_ss
+        b_z += b_ss
         if HAS_SCALE:
             b_o *= scale

246-255: Consider using explicit Optional type hints for nullable parameters.

PEP 484 prohibits implicit Optional. Parameters like scale, cu_seqlens, and chunk_indices that default to None should be typed as Optional[T].

Suggested fix
+from typing import Optional
+
 def chunk_local_cumsum_scalar(
     g: torch.Tensor,
     chunk_size: int,
     reverse: bool = False,
-    scale: float = None,
-    cu_seqlens: torch.Tensor  = None,
+    scale: Optional[float] = None,
+    cu_seqlens: Optional[torch.Tensor] = None,
     head_first: bool = False,
-    output_dtype: torch.dtype  = torch.float,
-    chunk_indices: torch.LongTensor  = None,
+    output_dtype: torch.dtype = torch.float,
+    chunk_indices: Optional[torch.LongTensor] = None,
 ) -> torch.Tensor:

This applies to similar signatures in other wrapper functions as well.


427-438: Unused **kwargs parameter.

The kwargs parameter is accepted but never used. If this is for forward compatibility, consider documenting it. Otherwise, remove to avoid confusion.

Option 1: Remove unused kwargs
 def chunk_local_cumsum(
     g: torch.Tensor,
     chunk_size: int,
     reverse: bool = False,
     scale: float = None,
     cu_seqlens: torch.Tensor  = None,
     head_first: bool = False,
     output_dtype: torch.dtype  = torch.float,
     chunk_indices: torch.LongTensor  = None,
-    **kwargs,
 ) -> torch.Tensor:
Option 2: Document if intentional for API stability
 def chunk_local_cumsum(
     g: torch.Tensor,
     chunk_size: int,
     reverse: bool = False,
     scale: float = None,
     cu_seqlens: torch.Tensor  = None,
     head_first: bool = False,
     output_dtype: torch.dtype  = torch.float,
     chunk_indices: torch.LongTensor  = None,
-    **kwargs,
+    **kwargs,  # Reserved for future parameters / API compatibility
 ) -> torch.Tensor:
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f7adecc and 1ef405a.

📒 Files selected for processing (1)
  • examples/KDA/FLA_KDA/cumsum.py
🧰 Additional context used
🪛 Ruff (0.14.10)
examples/KDA/FLA_KDA/cumsum.py

33-33: Unused function argument: B

(ARG001)


88-88: Unused function argument: B

(ARG001)


146-146: Unused function argument: B

(ARG001)


206-206: Unused function argument: B

(ARG001)


250-250: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


287-287: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


330-330: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


361-361: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


395-395: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


420-424: Avoid specifying long messages outside the exception class

(TRY003)


432-432: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


437-437: Unused function argument: kwargs

(ARG001)


464-468: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (7)
examples/KDA/FLA_KDA/cumsum.py (7)

1-11: LGTM!

Standard imports and constants for Triton kernel development.


67-122: LGTM!

Clean implementation using native Triton reverse cumsum support.


185-244: LGTM!

Global vector cumsum implementation is correct with proper accumulation logic.


283-323: LGTM!

Vector cumsum wrapper correctly sets up the dynamic grid and delegates to the kernel.


325-354: LGTM!

Global scalar cumsum wrapper correctly handles both fixed-length and variable-length sequences.


356-388: LGTM!

Global vector cumsum wrapper correctly computes block size and grid dimensions.


390-425: LGTM!

Dispatcher correctly routes based on tensor dimensionality with appropriate validation.

@wfloveiu wfloveiu force-pushed the example-kda-algorithm branch from 1ef405a to ffd4fa5 Compare January 12, 2026 15:14
@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In @examples/KDA/FLA_KDA/cumsum.py:
- Around line 437-438: The function in examples/KDA/FLA_KDA/cumsum.py whose
signature currently includes an unused **kwargs parameter should either drop
**kwargs or explicitly handle/validate unexpected keyword args; update the
function signature to remove **kwargs if no extra kwargs are intended, or add
code at the start of that function to raise a TypeError listing unexpected keys
(so callers aren’t silently ignored). Locate the function with the signature
that returns torch.Tensor (the one capturing **kwargs) and apply the chosen change consistently across any internal calls or tests; a sketch of the validation option follows this list.
- Around line 178-179: The condition "if i_c >= 0" is redundant because the loop
initializes i_c at 0 so it is always true; remove the conditional and directly
perform the accumulation (replace the guarded statement with a direct "b_z +=
b_ss" inside the loop) at the location using variables i_c, b_z, and b_ss to
eliminate dead code.
- Around line 464-468: The error message raised for unsupported input shape when
checking g.shape is inaccurate for the 3D scalar case; update the message in the
cumsum module (the raise that references g.shape) to mirror chunk_global_cumsum
style and list both allowed shapes like "[B, T, H] / [B, T, H, D]" (or
vice-versa depending on head_first) so it correctly documents the 3D scalar and
4D tensor possibilities.
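
A minimal sketch of the validation option mentioned in the first item above; the function name mirrors the example file, and the body is a placeholder rather than the real dispatcher:

import torch

def chunk_local_cumsum(g: torch.Tensor, chunk_size: int, **kwargs) -> torch.Tensor:
    # Fail loudly instead of silently ignoring unexpected keyword arguments.
    if kwargs:
        raise TypeError(
            f"chunk_local_cumsum() got unexpected keyword arguments: {sorted(kwargs)}")
    # ... real code would dispatch to the scalar/vector implementation here ...
    return torch.cumsum(g, dim=-1)  # placeholder so the sketch runs end to end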
🧹 Nitpick comments (2)
examples/KDA/FLA_KDA/cumsum.py (2)

297-301: Inconsistent assertion ordering compared to scalar version.

The power-of-2 assertion is placed after prepare_chunk_indices is called with chunk_size, whereas in chunk_local_cumsum_scalar (line 260), the assertion is placed before any usage. Consider moving the assertion before line 298 for consistency and to catch invalid input earlier.

Suggested fix
     if head_first:
         B, H, T, S = g.shape
     else:
         B, T, H, S = g.shape
+    assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
     BT = chunk_size
     if chunk_indices is None and cu_seqlens is not None:
         chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
-    assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"

     g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
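
As a side note on the assertion itself, chunk_size == 2**(chunk_size.bit_length()-1) is a positive power-of-two test, equivalent to the usual bit trick; a standalone check (not repository code):

def is_power_of_two(n: int) -> bool:
    return n > 0 and n == 2 ** (n.bit_length() - 1)   # same as: n > 0 and n & (n - 1) == 0

assert is_power_of_two(64) and is_power_of_two(1)
assert not is_power_of_two(48) and not is_power_of_two(0)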

246-255: Consider using explicit Optional[T] type hints.

Several parameters across the file use implicit optional types (e.g., scale: float = None). PEP 484 recommends using explicit Optional[float] for clarity. This applies to scale, cu_seqlens, and chunk_indices parameters throughout the file.

Example fix for chunk_local_cumsum_scalar
+from typing import Optional
+
 def chunk_local_cumsum_scalar(
     g: torch.Tensor,
     chunk_size: int,
     reverse: bool = False,
-    scale: float = None,
-    cu_seqlens: torch.Tensor  = None,
+    scale: Optional[float] = None,
+    cu_seqlens: Optional[torch.Tensor] = None,
     head_first: bool = False,
     output_dtype: torch.dtype  = torch.float,
-    chunk_indices: torch.LongTensor  = None,
+    chunk_indices: Optional[torch.LongTensor] = None,
 ) -> torch.Tensor:
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1ef405a and ffd4fa5.

📒 Files selected for processing (1)
  • examples/KDA/FLA_KDA/cumsum.py
🧰 Additional context used
🧬 Code graph analysis (1)
examples/KDA/FLA_KDA/cumsum.py (2)
examples/KDA/FLA_KDA/fla_utils.py (2)
  • prepare_chunk_indices (100-105)
  • input_guard (147-178)
tilelang/math/__init__.py (1)
  • next_power_of_2 (1-2)
🪛 Ruff (0.14.10)
examples/KDA/FLA_KDA/cumsum.py

33-33: Unused function argument: B

(ARG001)


88-88: Unused function argument: B

(ARG001)


146-146: Unused function argument: B

(ARG001)


206-206: Unused function argument: B

(ARG001)


250-250: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


287-287: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


330-330: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


361-361: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


395-395: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


420-424: Avoid specifying long messages outside the exception class

(TRY003)


432-432: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


437-437: Unused function argument: kwargs

(ARG001)


464-468: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (5)
examples/KDA/FLA_KDA/cumsum.py (5)

13-65: LGTM!

The scalar local cumsum kernel logic is correct. The reverse cumsum implementation properly computes the suffix sums, and the pointer arithmetic correctly handles both HEAD_FIRST and non-HEAD_FIRST memory layouts. The B parameter is intentionally included in the autotune key even though it's not used in the kernel body.
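
For anyone verifying the layout handling, the HEAD_FIRST distinction comes down to which axis a (batch, head) time series is strided along; a standalone PyTorch illustration with made-up shapes:

import torch

B, T, H = 2, 8, 4
x_bth = torch.arange(B * T * H).reshape(B, T, H)        # head_first=False layout
x_bht = x_bth.permute(0, 2, 1).contiguous()             # head_first=True (HEAD_FIRST) layout

b, h = 1, 2
series_bth = x_bth[b, :, h]   # elements are H apart in memory
series_bht = x_bht[b, h, :]   # elements are contiguous in memory
assert torch.equal(series_bth, series_bht)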


67-122: LGTM!

The vector local cumsum kernel correctly uses Triton's native reverse=True parameter for the reverse cumsum operation. The 2D block pointer setup properly handles the additional state dimension S.


185-244: LGTM!

The vector global cumsum kernel correctly accumulates state across chunks with proper ordering of the store and accumulator update operations.


325-387: LGTM!

Both global cumsum wrappers correctly configure the grid dimensions and kernel parameters. The BS constraint in the vector version appropriately limits block size for better performance.


390-424: LGTM!

The public API correctly dispatches to scalar or vector implementations based on input dimensionality, with clear error messaging for unsupported shapes.

Comment on lines +178 to +179
        if i_c >= 0:
            b_z += b_ss

⚠️ Potential issue | 🟡 Minor

Unnecessary condition: if i_c >= 0 is always true.

Since the loop starts at i_c = 0, this condition will always evaluate to True. This appears to be dead code or a leftover from previous logic.

Suggested fix
         b_o += b_z
-        if i_c >= 0:
-            b_z += b_ss
+        b_z += b_ss
         if HAS_SCALE:
             b_o *= scale
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/cumsum.py around lines 178 - 179, The condition "if i_c
>= 0" is redundant because the loop initializes i_c at 0 so it is always true;
remove the conditional and directly perform the accumulation (replace the
guarded statement with a direct "b_z += b_ss" inside the loop) at the location
using variables i_c, b_z, and b_ss to eliminate dead code.

Comment on lines +437 to +438
    **kwargs,
) -> torch.Tensor:

⚠️ Potential issue | 🟡 Minor

Unused **kwargs may silently ignore invalid arguments.

The **kwargs parameter is captured but never used, which could lead to confusing behavior where callers pass invalid keyword arguments that are silently ignored rather than raising an error. Consider either removing it or documenting its intended purpose.

Suggested fix (remove if not needed)
 def chunk_local_cumsum(
     g: torch.Tensor,
     chunk_size: int,
     reverse: bool = False,
     scale: float = None,
     cu_seqlens: torch.Tensor  = None,
     head_first: bool = False,
     output_dtype: torch.dtype  = torch.float,
     chunk_indices: torch.LongTensor  = None,
-    **kwargs,
 ) -> torch.Tensor:
🧰 Tools
🪛 Ruff (0.14.10)

437-437: Unused function argument: kwargs

(ARG001)

🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/cumsum.py around lines 437 - 438, The function in
examples/KDA/FLA_KDA/cumsum.py whose signature currently includes an unused
**kwargs parameter should either drop **kwargs or explicitly handle/validate
unexpected keyword args; update the function signature to remove **kwargs if no
extra kwargs are intended, or add code at the start of that function to raise a
TypeError listing unexpected keys (so callers aren’t silently ignored). Locate
the function with the signature that returns torch.Tensor (the one capturing
**kwargs) and apply the chosen change consistently across any internal calls or
tests.

Comment on lines 464 to 468
        raise ValueError(
            f"Unsupported input shape {g.shape}, "
            f"which should be (B, T, H, D) if `head_first=False` "
            f"or (B, H, T, D) otherwise",
        )

⚠️ Potential issue | 🟡 Minor

Error message is inaccurate for the scalar case.

The error message states the expected shape is (B, T, H, D) but the 3D scalar case expects (B, T, H) without the D dimension. Consider updating to match the style used in chunk_global_cumsum which mentions both [B, T, H]/[B, T, H, D].

Suggested fix
     else:
         raise ValueError(
             f"Unsupported input shape {g.shape}, "
-            f"which should be (B, T, H, D) if `head_first=False` "
-            f"or (B, H, T, D) otherwise",
+            f"which should be [B, T, H]/[B, T, H, D] if `head_first=False` "
+            f"or [B, H, T]/[B, H, T, D] otherwise",
         )
🧰 Tools
🪛 Ruff (0.14.10)

464-468: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/cumsum.py around lines 464 - 468, The error message
raised for unsupported input shape when checking g.shape is inaccurate for the
3D scalar case; update the message in the cumsum module (the raise that
references g.shape) to mirror chunk_global_cumsum style and list both allowed
shapes like "[B, T, H] / [B, T, H, D]" (or vice-versa depending on head_first)
so it correctly documents the 3D scalar and 4D tensor possibilities.

@wfloveiu wfloveiu force-pushed the example-kda-algorithm branch from ffd4fa5 to 801ae9b Compare January 12, 2026 15:50
@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (4)
examples/KDA/FLA_KDA/cumsum.py (4)

59-61: Inconsistent reverse cumsum implementation.

The vector kernel uses tl.cumsum(b_s, axis=0, reverse=True) directly (line 116), but this scalar kernel manually computes the reverse. Consider using the built-in parameter for consistency:

♻️ Suggested simplification
-    b_o = tl.cumsum(b_s, axis=0)
-    if REVERSE:
-        b_z = tl.sum(b_s, axis=0)
-        b_o = -b_o + b_z[None] + b_s
+    if REVERSE:
+        b_o = tl.cumsum(b_s, axis=0, reverse=True)
+    else:
+        b_o = tl.cumsum(b_s, axis=0)

178-179: Dead code: condition is always true.

i_c ranges from 0 to NT-1, so if i_c >= 0 is always True. This branch is redundant.

♻️ Remove dead condition
         b_o += b_z
-        if i_c >= 0:
-            b_z += b_ss
+        b_z += b_ss
         if HAS_SCALE:

250-254: Use explicit Optional type hints per PEP 484.

Parameters with = None default should use Optional[T] type annotation for clarity and static analysis compatibility.

♻️ Suggested type hint fix
+from typing import Optional
+
 def chunk_local_cumsum_scalar(
     g: torch.Tensor,
     chunk_size: int,
     reverse: bool = False,
-    scale: float = None,
-    cu_seqlens: torch.Tensor  = None,
+    scale: Optional[float] = None,
+    cu_seqlens: Optional[torch.Tensor] = None,
     head_first: bool = False,
-    output_dtype: torch.dtype  = torch.float,
-    chunk_indices: torch.LongTensor  = None,
+    output_dtype: Optional[torch.dtype] = torch.float,
+    chunk_indices: Optional[torch.LongTensor] = None,
 ) -> torch.Tensor:

Note: This same pattern applies to all wrapper functions (chunk_local_cumsum_vector, chunk_global_cumsum_scalar, chunk_global_cumsum_vector, chunk_global_cumsum, and chunk_local_cumsum).


437-438: Consider removing or documenting unused **kwargs.

The **kwargs parameter is not used within the function. If it's intended for forward compatibility or interface consistency with other modules, consider adding a docstring to document this intent. Otherwise, removing it would improve clarity.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ffd4fa5 and 801ae9b.

📒 Files selected for processing (1)
  • examples/KDA/FLA_KDA/cumsum.py
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
🪛 Ruff (0.14.10)
examples/KDA/FLA_KDA/cumsum.py

33-33: Unused function argument: B

(ARG001)


88-88: Unused function argument: B

(ARG001)


146-146: Unused function argument: B

(ARG001)


206-206: Unused function argument: B

(ARG001)


250-250: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


287-287: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


330-330: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


361-361: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


395-395: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


420-424: Avoid specifying long messages outside the exception class

(TRY003)


432-432: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


437-437: Unused function argument: kwargs

(ARG001)


464-468: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (7)
examples/KDA/FLA_KDA/cumsum.py (7)

1-11: LGTM!

Imports and constants are appropriate for the Triton kernel implementation.


67-122: LGTM!

The vector kernel correctly handles 2D blocks with proper boundary checks and uses the built-in reverse=True parameter for cumulative sum.


185-244: LGTM!

The global vector kernel correctly implements the chunked global cumulative sum with proper handling of both forward and reverse directions.


283-323: LGTM with minor note.

Implementation is correct. The power-of-2 assertion placement differs from the scalar version (line 301 vs 260), but this doesn't affect functionality. Same Optional type hint improvements apply here as noted above.


325-354: LGTM!

The global scalar wrapper correctly handles both regular and variable-length inputs.


356-388: LGTM!

The global vector wrapper correctly computes block size and handles the vectorized cumsum.


390-425: LGTM!

The dispatcher correctly routes to scalar/vector implementations based on input dimensionality with helpful error messages.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 18

Note

Due to the large number of review comments, Critical, Major severity comments were prioritized as inline comments.

🤖 Fix all issues with AI agents
In @examples/KDA/chunk_bwd_dqkwg.py:
- Around line 302-304: main() is passing TileLang types (T.float32) into
run_test via parameters input_dtype/gate_dtype/qk_dtype, but run_test expects
string dtypes used with getattr(torch, ...); change the call sites in main() to
pass string names like "float32" (or whatever exact torch dtype attribute name
you need) for input_dtype, gate_dtype and qk_dtype, or alternatively update
run_test to accept TileLang types by converting them to the corresponding torch
dtype before calling getattr; update the parameters consistently so
getattr(torch, ...) receives a string key (a sketch of this pattern follows this list).

In @examples/KDA/chunk_bwd_gla_dA.py:
- Around line 83-84: The V_shared buffer is allocated with do_dtype but later
populated from V which uses input_dtype, risking implicit casts; change the
allocation for V_shared to use input_dtype (i.e., T.alloc_shared((block_S,
block_DV), dtype=input_dtype)) or explicitly cast V to do_dtype at the copy site
so the types match; update the allocation of V_shared (and any corresponding
uses) to prevent precision loss between do_dtype and input_dtype.
- Line 12: Remove the hardcoded os.environ["CUDA_VISIBLE_DEVICES"] = "7"
assignment and either delete the line entirely or replace it with reading a
configurable value (e.g., via argparse or
os.environ.get("CUDA_VISIBLE_DEVICES")) so device selection is not forced in the
script; ensure any replacement uses a fallback (None or empty) and document that
users should set CUDA_VISIBLE_DEVICES externally if needed.

In @examples/KDA/chunk_bwd_intra_op.py:
- Around line 508-512: The dtype parameters (input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype) are using TileLang type objects
(T.float32) instead of string names; update each to use the matching string
representation (e.g., "float32") wherever these kwargs are set in this module
(replace T.float32 with "float32" for
input_dtype/output_dtype/accum_dtype/gate_dtype/state_dtype) to match the
pattern used in the other files.

In @examples/KDA/chunk_delta_bwd.py:
- Around line 267-307: The test computes FLA references (dh_ref, dh0_ref,
dv2_ref) but never verifies TileLang outputs because the compare_tensors calls
are commented out; uncomment and invoke compare_tensors to validate correctness
(compare_tensors("dh", dh_ref, dh_tilelang), compare_tensors("dh0", dh0_ref,
dh0_tilelang), compare_tensors("dv2", dv2_ref, dv2_tilelang)). Ensure these
comparisons only run when the references are actually computed (i.e., wrap them
in the same use_gk conditional or compute refs unconditionally), and import or
define compare_tensors if missing so the comparisons execute after running
chunk_gated_delta_rule_bwd_dhu and tilelang_chunk_gated_delta_rule_bwd_dhu.

In @examples/KDA/chunk_delta_h_fwd.py:
- Around line 307-317: The benchmark invocation passes the wrong keyword to the
reference function: change the do_bench call that references
chunk_gated_delta_rule_fwd_h to use the gk= parameter instead of g=G so the
reference gets the intended gate constant; locate the do_bench call where
chunk_gated_delta_rule_fwd_h is passed and replace the g=G argument with gk=G.
- Around line 336-340: The test is passing TileLang type objects (T.float16,
T.float32) to run_test but run_test expects string dtype names; update the call
site in main() (the parameters input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype) to pass corresponding string names ("float16",
"float32", etc.) instead of T.float16/T.float32 so the dtype formats match
run_test's expectations.

In @examples/KDA/chunk_inter_solve_fused.py:
- Line 13: Remove the hardcoded CUDA device assignment that sets
os.environ["CUDA_VISIBLE_DEVICES"] = "7"; either delete this line or replace it
with a configurable approach (read the device index from an environment variable
or a CLI/config parameter and only set CUDA_VISIBLE_DEVICES if that value is
present/valid). Update the initialization in the script (the line using
os.environ["CUDA_VISIBLE_DEVICES"]) to validate the provided value before
applying it, and document or expose the configuration parameter so users on
systems with different GPU counts can override it.

In @examples/KDA/chunk_intra_token_parallel.py:
- Around line 260-266: Aqk_tilelang and Akk_tilelang are computed but never
validated because the comparison/assertion block is commented out; restore or
add explicit comparisons against the reference tensors (Aqk_ref, Akk_ref) using
existing helpers like assert_similar or compare_tensors from test_utils.py and
ensure test_utils is imported, e.g., call compare_tensors("Aqk", Aqk_ref,
Aqk_tilelang) and compare_tensors("Akk", Akk_ref, Akk_tilelang) (or equivalent
assert_similar checks) where the commented lines (around the existing 291-296
comparison block) currently reside so the kernel outputs are actually validated.

In @examples/KDA/chunk_o.py:
- Around line 132-134: The parallel loop uses indices (i_s, i_v) over (block_S,
block_DV) but then indexes Q_shared and GK_shared which are shaped (block_S,
block_DK); change the loop to iterate over (block_S, block_DK) (e.g., (i_s,
i_k)) or otherwise use the DK-range for the second index, and update uses of
Q_shared and GK_shared to Q_shared[i_s, i_k] and GK_shared[i_s, i_k] (and
GQ_shared[i_s, i_k]) so the second index matches the buffers' block_DK
dimension.

In @examples/KDA/FLA_KDA/fla_chunk_delta.py:
- Around line 487-494: When cu_seqlens is provided you never assign N, NT, or
chunk_offsets, causing UnboundLocalError; update the branch that handles
cu_seqlens (around the prepare_chunk_indices call) to also set N (number of
sequences), NT (number of chunks, e.g., from chunk_indices or using triton.cdiv
on totals), and chunk_offsets (derived from cu_seqlens or prepare_chunk_indices
output) before they are used; make the same fix inside
chunk_gated_delta_rule_bwd_dhu so both code paths initialize N, NT, and
chunk_offsets when cu_seqlens is not None.

In @examples/KDA/FLA_KDA/fla_chunk_o.py:
- Around line 432-437: The elif currently tests the function object instead of
calling it; change "elif check_shared_mem:" to call the function (e.g., "elif
check_shared_mem(\"hopper\", k.device.index):" or the correct argument set for
your use case) so the condition actually invokes check_shared_mem rather than
always evaluating truthiness of the function object, and keep CONST_TILING
assignments as-is.

In @examples/KDA/FLA_KDA/fla_wy_fast.py:
- Around line 269-277: The backward helper prepare_wy_repr_bwd hardcodes BT = 64
which can mismatch the forward recompute_w_u_fwd where BT is derived as
A.shape[-1]; change prepare_wy_repr_bwd to compute BT = A.shape[-1] (or
otherwise derive BT from the same input used in forward, e.g., A or the relevant
tensor shape) instead of the hardcoded 64 so forward and backward use the same
tiling.

In @examples/KDA/test_utils.py:
- Around line 27-30: The non-finite comparison in the block using
x.masked_fill(x_mask, 0) and y.masked_fill(y_mask, 0) has the mask inverted;
currently finite positions are being filled and non-finite values compared.
Change the masked_fill calls to use the inverted masks (~x_mask and ~y_mask) so
you mask out (fill) finite values and compare only non-finite positions (keep
using equal_nan=True and the existing error/raise behavior), referencing the
variables x, y, x_mask, y_mask and the masked_fill calls to locate the change.

In @examples/KDA/wy_fast_bwd.py:
- Around line 364-374: The call to main() in wy_fast_bwd.py is passing TileLang
types (T.float32) for dtype params while run_test expects string names to use
getattr(torch, ...); update the main(...) invocation to pass string dtype names
(e.g., "float32") for input_dtype, output_dtype, accum_dtype, gate_dtype, and
state_dtype so run_test can resolve torch dtypes via getattr; keep all other
numeric/size params (chunk_size, block_DK, block_DV, threads, num_stages)
unchanged.

In @examples/KDA/wy_fast.py:
- Around line 207-208: The variable use_qg is mistakenly assigned as a
one-element tuple (False,) which is truthy; change its assignment to the boolean
False (remove the trailing comma) so conditional checks using use_qg behave
correctly; confirm similar assignments (e.g., use_kg) are booleans and update
any other occurrences of use_qg in the file to expect a boolean value.
- Around line 266-269: The bug is that main() passes TileLang type objects
(T.bfloat16, T.float32) into run_test which expects string dtype names resolved
with getattr(torch, ...); update the call site in main() to pass dtype names as
strings (e.g., "bfloat16" for input_dtype/output_dtype and "float32" for
gate_dtype/accum_dtype) so run_test can successfully call getattr(torch, dtype).
Reference symbols: main(), run_test(), and the dtype parameters
input_dtype/output_dtype/gate_dtype/accum_dtype.
🟡 Minor comments (7)
examples/KDA/FLA_KDA/fla_utils.py-24-24 (1)

24-24: Potential crash on systems without CUDA devices.

torch.cuda.get_device_name(0) will raise an exception if no CUDA device is available. This executes at module import time, making the entire module unimportable on CPU-only systems.

Consider guarding with a device availability check:

IS_NVIDIA_HOPPER = (
    torch.cuda.is_available() and 
    ("NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9)
)
examples/KDA/chunk_intra_token_parallel.py-8-8 (1)

8-8: Hardcoded GPU device index.

CUDA_VISIBLE_DEVICES="7" is hardcoded, which will fail on systems with fewer GPUs or different configurations. Consider removing this or making it configurable via environment variable fallback.

🔧 Suggested fix
-os.environ["CUDA_VISIBLE_DEVICES"] = "7"
+# os.environ["CUDA_VISIBLE_DEVICES"] = "7"  # Uncomment to pin specific GPU
examples/KDA/FLA_KDA/cumsum.py-176-177 (1)

176-177: Condition i_c >= 0 is always true.

Since i_c iterates over range(NT), it starts at 0 and is always non-negative, so this condition provides no filtering.

🐛 If this was meant to skip the first iteration, use:
-        if i_c >= 0:
-            b_z += b_ss
+        if i_c > 0:
+            b_z += b_ss

Otherwise, remove the conditional entirely:

-        if i_c >= 0:
-            b_z += b_ss
+        b_z += b_ss
examples/KDA/chunk_bwd_dqkwg.py-271-274 (1)

271-274: Correctness validation is disabled.

The compare_tensors calls are commented out. Enable these to validate the TileLang kernel against the reference implementation.

Suggested fix
-    # compare_tensors("dq", dq_ref, dq)
-    # compare_tensors("dk", dk_ref, dk)
-    # compare_tensors("dw", dw_ref, dw)
-    # compare_tensors("dg", dg_ref, dg)
+    compare_tensors("dq", dq_ref, dq)
+    compare_tensors("dk", dk_ref, dk)
+    compare_tensors("dw", dw_ref, dw)
+    compare_tensors("dg", dg_ref, dg)
examples/KDA/wy_fast_bwd.py-334-338 (1)

334-338: Correctness validation is disabled.

The compare_tensors calls are commented out, meaning the test only benchmarks without validating correctness. Consider enabling these checks or adding a flag to toggle validation.

Suggested fix
-    # compare_tensors("dA", dA_tilelang, dA_ref)
-    # compare_tensors("dk", dk_tilelang, dk_ref)
-    # compare_tensors("dv", dv_tilelang, dv_ref)
-    # compare_tensors("dbeta", dbeta_tilelang, dbeta_ref)
-    # compare_tensors("dg", dg_tilelang, dg_ref)
+    compare_tensors("dA", dA_tilelang, dA_ref)
+    compare_tensors("dk", dk_tilelang, dk_ref)
+    compare_tensors("dv", dv_tilelang, dv_ref)
+    compare_tensors("dbeta", dbeta_tilelang, dbeta_ref)
+    compare_tensors("dg", dg_tilelang, dg_ref)
examples/KDA/FLA_KDA/fla_chunk_intra.py-10-15 (1)

10-15: Dead code: SOLVE_TRIL_DOT_PRECISION is overwritten.

Lines 10-14 define SOLVE_TRIL_DOT_PRECISION conditionally, but line 15 unconditionally overwrites it with "tf32". The conditional logic is dead code.

If you intend to always use tf32, simplify to:
-IS_TF32_SUPPORTED = False
-if IS_TF32_SUPPORTED:
-    SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32x3")
-else:
-    SOLVE_TRIL_DOT_PRECISION = tl.constexpr("ieee")
 SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32")

Otherwise, remove line 15 to enable the conditional logic.

examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py-114-168 (1)

114-168: Incorrect return type annotation.

The function signature declares -> None but the function returns Aqk, Akk on line 168. Update the return type annotation to match the actual behavior.

Suggested fix
 def chunk_kda_fwd_intra_token_parallel(
     ...
-) -> None:
+) -> tuple[torch.Tensor, torch.Tensor]:
🧹 Nitpick comments (38)
examples/KDA/chunk_bwd_gla_dA.py (8)

4-4: Remove unused sys import.

The sys module is imported but never used. The noqa: F401 directive is also unnecessary, since that rule isn't enabled in this configuration and the suppression has no effect.

Suggested fix
-import sys  # noqa: F401

17-28: Unused chunk_size parameter.

The chunk_size parameter is declared but never used in this function. Either remove it or document why it's kept for interface consistency.

Suggested fix (if not needed for interface consistency)
 def prepare_input(
     B,
     S,
     H,
     DV,
-    chunk_size,
     input_dtype,
     do_dtype,
 ):

31-40: Unused DV parameter.

The DV parameter is unused in this function.

Suggested fix
 def prepare_output(
     B,
     S,
     H,
-    DV,
     chunk_size,
     d_type,
 ):

43-51: Consider moving itertools import to module level.

The itertools import inside the function works but is unconventional. For consistency and slight performance benefit on repeated calls, consider moving it to the top-level imports.


86-86: Remove commented-out code and translate comments.

The commented-out dA_shared allocation and copy statements should be removed if no longer needed. Also, the Chinese comment on line 102 (下三角矩阵, "lower triangular matrix") should be translated to English for broader accessibility.

Suggested fix
-            # dA_shared = T.alloc_shared((block_S, block_S), dtype=da_dtype)
...
             for i_s1, i_s2 in T.Parallel(block_S, block_S):
-                dA_fragment[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, dA_fragment[i_s1, i_s2] * scale, 0.0)  # 下三角矩阵
-            # T.copy(dA_fragment, dA_shared)
+                dA_fragment[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, dA_fragment[i_s1, i_s2] * scale, 0.0)  # lower triangular matrix

Also applies to: 102-103


130-130: Consider removing debug print statement.

Printing raw timing data on every benchmark call may be noisy. Consider removing it or gating it behind a verbose flag.

Suggested fix
-    print(times)

150-162: Wasted tensor allocation on line 150.

dA_tilelang is allocated via prepare_output on line 150, but then immediately overwritten by the kernel output on line 162. The allocation serves no purpose. If you intended to pass dA_tilelang as an output buffer to the kernel, the kernel invocation pattern would need to change.

Suggested fix
-    dA_tilelang = prepare_output(B, S, H, DV, chunk_size, getattr(torch, da_dtype))
     kernel = tilelang_chunk_bwd_kernel_dv_local(
         B=B,
         S=S,
         H=H,
         DV=DV,
         scale=scale,
         input_dtype=input_dtype,
         da_dtype=da_dtype,
         do_dtype=do_dtype,
         chunk_size=chunk_size,
     )
     dA_tilelang = kernel(DO, V_new)

138-138: Unused DK parameter.

The DK parameter is declared but never used in run_test.

examples/KDA/chunk_delta_bwd.py (3)

3-9: Remove debug artifacts and unused import.

The sys import is unused, and the print(tilelang.__file__) statement will execute on every module import, which is not suitable for production/example code.

Proposed fix
-import sys  # noqa: F401
-
 import tilelang
 import tilelang.language as T
 from tilelang.autotuner import autotune
-
-print(tilelang.__file__, flush=True)

22-49: Unused function parameters.

Parameters output_dtype, accum_dtype, and state_dtype are declared but never used. Consider removing them or prefixing with _ if they're placeholders for future use.


263-265: Unnecessary pre-allocation of output tensors.

prepare_output allocates dh_tilelang, dh0_tilelang, and dv2_tilelang, but these are immediately overwritten by the kernel call at line 295. The @tilelang.jit(out_idx=[-3, -2, -1]) decorator already handles output allocation.

Proposed fix
-    dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(
-        B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
-    )
-
     # fla ref
examples/KDA/test_utils.py (2)

43-45: Unused atol and rtol parameters.

The function signature accepts atol and rtol but they are never used in the comparison logic. Either implement tolerance-based checks or remove these parameters to avoid confusion.
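If the tolerance route is taken, a sketch of how atol/rtol could be wired in (this is not the file's current logic, just one possible shape):

import torch

def compare_with_tolerance(name, x, y, atol=1e-3, rtol=1e-3):
    # compare only finite positions; non-finite handling stays with the existing checks
    finite = torch.isfinite(x) & torch.isfinite(y)
    if not torch.isclose(x[finite], y[finite], atol=atol, rtol=rtol).all():
        max_err = (x[finite] - y[finite]).abs().max().item()
        raise AssertionError(f"{name}: max abs error {max_err:.3e} exceeds atol={atol}, rtol={rtol}")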


49-54: Minor: Chinese comments and fullwidth characters.

Comments on lines 49, 54, 64, and 68 use Chinese text and fullwidth parentheses. Consider using English comments for broader accessibility, or at minimum replace fullwidth characters with ASCII equivalents for consistency.

examples/KDA/FLA_KDA/fla_utils.py (3)

24-25: Redundant True and prefix.

True and (...) is equivalent to just the condition itself. This pattern suggests a toggle that was left enabled.

♻️ Simplification
-IS_NVIDIA_HOPPER = True and ("NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9)
-USE_CUDA_GRAPH = True and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
+IS_NVIDIA_HOPPER = "NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9
+USE_CUDA_GRAPH = os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"

136-150: Silent exception swallowing may hide configuration issues.

Both try/except blocks catch broad Exception and silently pass. If the Triton API changes or returns unexpected data, failures will be masked, and the function falls back to returning 1 without any indication of why.

Consider logging at debug level:

except Exception as e:
    import logging
    logging.debug(f"Triton 2.2+ API failed: {e}")

224-230: Catching BaseException is overly broad.

BaseException includes KeyboardInterrupt and SystemExit, which should typically propagate. Use Exception instead.

♻️ Proposed fix
     try:
         return [
             triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"] for i in range(device_torch_lib.device_count())
         ]
-    except BaseException:
+    except Exception:
         return [-1]
examples/KDA/FLA_KDA/cumsum.py (2)

260-261: Power-of-2 check uses non-standard pattern.

The check chunk_size == 2 ** (chunk_size.bit_length() - 1) works but is fragile for edge cases (fails for 0). Consider the canonical form.

♻️ More robust alternative
-    assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
+    assert chunk_size > 0 and (chunk_size & (chunk_size - 1)) == 0, "chunk_size must be a power of 2"

440-441: Unused **kwargs parameter.

The kwargs parameter is accepted but never used. If it's for future extensibility, document it; otherwise, remove it to avoid confusion.

examples/KDA/wy_fast.py (1)

10-10: Avoid hardcoding CUDA_VISIBLE_DEVICES.

Hardcoding a specific GPU device index limits portability and may cause issues in different environments. Consider removing this or using an environment variable fallback.

Suggested fix
-os.environ["CUDA_VISIBLE_DEVICES"] = "7"
+# Allow override via environment variable, default to all visible devices
+# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
examples/KDA/chunk_bwd_dv.py (2)

12-12: Avoid hardcoding CUDA_VISIBLE_DEVICES.

Same issue as in wy_fast.py. Consider removing or making configurable.


133-144: Unused parameters: DK and scale.

These parameters are declared but never used in run_test. Consider removing them or documenting their intended future use.

Suggested fix
 def run_test(
     B,
     S,
     H,
-    DK,
+    DK,  # Used for prepare_input
     DV,
-    scale,
     input_dtype,
     do_dtype,
     output_dtype,
     chunk_size,
 ):
examples/KDA/chunk_bwd_dqkwg.py (2)

11-11: Avoid hardcoding CUDA_VISIBLE_DEVICES.

Same issue across multiple files.


241-248: Consider removing unused parameters.

Multiple parameters (use_gk, use_initial_state, store_final_state, save_new_value, block_DK, block_DV, threads, num_stages) are declared but never used. This adds confusion.

examples/KDA/FLA_KDA/fla_chunk_inter.py (1)

141-155: Unused parameter w and implicit Optional type.

  1. Parameter w is only used to create dw with the same shape; it never participates in the computation itself.
  2. Per PEP 484, scale: float = None should be scale: Optional[float] = None.
Suggested fix for type hint
+from typing import Optional
+
 def chunk_kda_bwd_dqkwg(
     q: torch.Tensor,
     k: torch.Tensor,
     w: torch.Tensor,
     v: torch.Tensor,
     h: torch.Tensor,
     g: torch.Tensor,
     do: torch.Tensor,
     dh: torch.Tensor,
     dv: torch.Tensor,
-    scale: float = None,
+    scale: Optional[float] = None,
     cu_seqlens: torch.LongTensor = None,
     chunk_size: int = 64,
     chunk_indices: torch.LongTensor = None,
 ):
examples/KDA/chunk_bwd_intra_op.py (2)

470-470: Remove debug print statement.

This debug print should be removed or guarded behind a debug flag.

Suggested fix
-    print(db_tilelang.shape)

200-200: Minor typo in comment.

"i_j is index ofprevious sub_chunks" should be "i_j is index of previous sub-chunks".

Suggested fix
-                for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index ofprevious sub_chunks
+                for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index of previous sub-chunks
examples/KDA/chunk_o.py (3)

4-4: Remove unused noqa directive.

The sys import is not used and the noqa: F401 directive is unnecessary since F401 is not enabled.

Suggested fix
-import sys  # noqa: F401

213-214: Wasteful allocation immediately overwritten.

prepare_output allocates O_ref on line 213, but it's immediately overwritten by the result of chunk_gla_fwd_o_gk on line 214. Remove the unnecessary allocation.

Suggested fix
-    O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
     O_ref = chunk_gla_fwd_o_gk(Q, V, G, A, HIDDEN, scale, chunk_size=chunk_size, use_exp2=True)

198-201: Unused function arguments.

Parameters block_DK, block_DV, threads, and num_stages are accepted but never used in run_test. Since autotuning is enabled on the kernel, these could be removed from the function signature, or they should be passed to the kernel if manual configuration is intended.

examples/KDA/FLA_KDA/fla_chunk_delta.py (1)

534-534: Use explicit Optional type annotation.

PEP 484 prohibits implicit Optional. The scale parameter should use Optional[float] for clarity.

Suggested fix
+from typing import Optional
+
 def chunk_gated_delta_rule_bwd_dhu(
     ...
-    scale: float = None,
+    scale: Optional[float] = None,
     ...
examples/KDA/chunk_bwd_intra.py (3)

3-3: Remove unused noqa directive.

Suggested fix
-import sys  # noqa: F401

200-200: Typo in comment.

Minor typo: "ofprevious" should be "of previous".

Suggested fix
-                for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index ofprevious sub_chunks
+                for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index of previous sub_chunks

470-470: Remove debug print statement.

This appears to be leftover debug code that should be removed before merging.

Suggested fix
-    print(db_tilelang.shape)
     dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = kernel(q, k, g, beta, dAqk, dAkk, dq, dk, db, dg)
examples/KDA/chunk_inter_solve_fused.py (2)

3-3: Remove unused noqa directive.

Suggested fix
-import sys  # noqa: F401

528-528: Remove debug print statement.

The print(times) inside do_bench appears to be debug output that should be removed.

Suggested fix
-    print(times)
     return times.mean().item()
examples/KDA/FLA_KDA/fla_chunk_intra.py (1)

370-370: Variable all shadows Python builtin.

Using all as a variable name shadows the built-in all() function. Consider using a more descriptive name like total_tokens or batch_tokens.

Suggested fix
-    all = B * T
+    total_tokens = B * T
     if IS_VARLEN:
         ...
     ...
-    db += (i_k * all + bos) * H + i_h
+    db += (i_k * total_tokens + bos) * H + i_h
examples/KDA/FLA_KDA/fla_chunk_o.py (2)

87-88: Condition if i_k >= 0: is always true.

Since i_k comes from range(tl.cdiv(K, BK)), it is always non-negative. This condition appears to be dead code or a placeholder. If it's intentional for clarity/future use, consider adding a comment.


360-381: Multiple unused kernel arguments.

The kernel chunk_bwd_kernel_dv_local declares parameters g, g_gamma, scale, BK, USE_G, and USE_G_GAMMA but never uses them. If these are placeholders for future functionality, consider adding a TODO comment. Otherwise, they should be removed.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 801ae9b and 5d39797.

📒 Files selected for processing (22)
  • .pre-commit-config.yaml
  • examples/KDA/FLA_KDA/cumsum.py
  • examples/KDA/FLA_KDA/fla_chunk_delta.py
  • examples/KDA/FLA_KDA/fla_chunk_inter.py
  • examples/KDA/FLA_KDA/fla_chunk_intra.py
  • examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py
  • examples/KDA/FLA_KDA/fla_chunk_o.py
  • examples/KDA/FLA_KDA/fla_utils.py
  • examples/KDA/FLA_KDA/fla_wy_fast.py
  • examples/KDA/chunk_bwd_dqkwg.py
  • examples/KDA/chunk_bwd_dv.py
  • examples/KDA/chunk_bwd_gla_dA.py
  • examples/KDA/chunk_bwd_intra.py
  • examples/KDA/chunk_bwd_intra_op.py
  • examples/KDA/chunk_delta_bwd.py
  • examples/KDA/chunk_delta_h_fwd.py
  • examples/KDA/chunk_inter_solve_fused.py
  • examples/KDA/chunk_intra_token_parallel.py
  • examples/KDA/chunk_o.py
  • examples/KDA/test_utils.py
  • examples/KDA/wy_fast.py
  • examples/KDA/wy_fast_bwd.py
✅ Files skipped from review due to trivial changes (1)
  • .pre-commit-config.yaml
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
🧬 Code graph analysis (12)
examples/KDA/test_utils.py (1)
tilelang/carver/roller/policy/default.py (1)
  • sim (290-291)
examples/KDA/FLA_KDA/fla_chunk_delta.py (4)
examples/KDA/FLA_KDA/fla_utils.py (1)
  • prepare_chunk_indices (107-112)
examples/KDA/FLA_KDA/fla_chunk_inter.py (1)
  • grid (168-169)
examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py (1)
  • grid (149-150)
examples/KDA/FLA_KDA/fla_chunk_o.py (1)
  • grid (320-321)
examples/KDA/chunk_delta_bwd.py (2)
examples/KDA/FLA_KDA/fla_chunk_delta.py (1)
  • chunk_gated_delta_rule_bwd_dhu (524-579)
examples/KDA/FLA_KDA/cumsum.py (1)
  • chunk_local_cumsum (431-469)
examples/KDA/chunk_o.py (6)
tilelang/autotuner/tuner.py (1)
  • autotune (691-786)
examples/KDA/FLA_KDA/fla_chunk_o.py (1)
  • chunk_gla_fwd_o_gk (299-340)
tilelang/language/kernel.py (1)
  • threads (215-219)
tilelang/language/annotations.py (1)
  • annotate_layout (27-40)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (62-71)
tilelang/language/dtypes.py (2)
  • bfloat16 (407-407)
  • float32 (310-310)
examples/KDA/chunk_bwd_dqkwg.py (1)
examples/KDA/FLA_KDA/fla_chunk_inter.py (1)
  • chunk_kda_bwd_dqkwg (141-193)
examples/KDA/chunk_inter_solve_fused.py (3)
examples/KDA/FLA_KDA/fla_chunk_intra.py (1)
  • chunk_kda_fwd_inter_solve_fused (611-650)
examples/KDA/test_utils.py (1)
  • compare_tensors (43-83)
examples/KDA/FLA_KDA/cumsum.py (1)
  • chunk_local_cumsum (431-469)
examples/KDA/FLA_KDA/fla_utils.py (1)
tilelang/utils/device.py (1)
  • get_current_device (14-21)
examples/KDA/wy_fast.py (4)
tilelang/autotuner/tuner.py (1)
  • autotune (691-786)
examples/KDA/FLA_KDA/fla_wy_fast.py (1)
  • recompute_w_u_fwd (210-254)
examples/KDA/test_utils.py (1)
  • compare_tensors (43-83)
tilelang/language/proxy.py (1)
  • Tensor (233-233)
examples/KDA/FLA_KDA/fla_chunk_inter.py (1)
examples/KDA/FLA_KDA/fla_utils.py (2)
  • prepare_chunk_indices (107-112)
  • check_shared_mem (234-240)
examples/KDA/wy_fast_bwd.py (2)
tilelang/autotuner/tuner.py (1)
  • autotune (691-786)
examples/KDA/FLA_KDA/fla_wy_fast.py (1)
  • prepare_wy_repr_bwd (257-312)
examples/KDA/FLA_KDA/fla_chunk_intra.py (2)
examples/KDA/FLA_KDA/fla_utils.py (1)
  • prepare_chunk_indices (107-112)
examples/KDA/FLA_KDA/cumsum.py (1)
  • chunk_local_cumsum (431-469)
examples/KDA/FLA_KDA/fla_chunk_o.py (1)
examples/KDA/FLA_KDA/fla_utils.py (2)
  • prepare_chunk_indices (107-112)
  • check_shared_mem (234-240)
🪛 Ruff (0.14.10)
examples/KDA/test_utils.py

43-43: Unused function argument: atol

(ARG001)


43-43: Unused function argument: rtol

(ARG001)


54-54: Comment contains ambiguous （ (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?

(RUF003)


54-54: Comment contains ambiguous ） (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?

(RUF003)

examples/KDA/chunk_bwd_intra_op.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


27-27: Unused function argument: output_dtype

(ARG001)


28-28: Unused function argument: accum_dtype

(ARG001)


30-30: Unused function argument: state_dtype

(ARG001)


57-57: Unused function argument: chunk_size

(ARG001)


61-61: Unused function argument: state_dtype

(ARG001)


96-96: Unused function argument: state_dtype

(ARG001)


133-133: Unused function argument: db

(ARG001)


419-419: Unused function argument: threads

(ARG001)


420-420: Unused function argument: num_stages

(ARG001)


421-421: Unused function argument: cu_seqlens

(ARG001)


422-422: Unused function argument: chunk_indices

(ARG001)

examples/KDA/FLA_KDA/fla_chunk_delta.py

534-534: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

examples/KDA/FLA_KDA/cumsum.py

32-32: Unused function argument: B

(ARG001)


85-85: Unused function argument: B

(ARG001)


144-144: Unused function argument: B

(ARG001)


206-206: Unused function argument: B

(ARG001)


250-250: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


287-287: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


333-333: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


364-364: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


398-398: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


423-427: Avoid specifying long messages outside the exception class

(TRY003)


435-435: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


440-440: Unused function argument: kwargs

(ARG001)


467-469: Avoid specifying long messages outside the exception class

(TRY003)

examples/KDA/chunk_delta_bwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


30-30: Unused function argument: output_dtype

(ARG001)


31-31: Unused function argument: accum_dtype

(ARG001)


33-33: Unused function argument: state_dtype

(ARG001)


60-60: Unused function argument: gate_dtype

(ARG001)


130-130: Unused function argument: h0

(ARG001)


244-244: Unused function argument: block_DV

(ARG001)


245-245: Unused function argument: threads

(ARG001)


246-246: Unused function argument: num_stages

(ARG001)


247-247: Unused function argument: use_torch

(ARG001)


270-270: Unpacked variable dh_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


270-270: Unpacked variable dh0_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


270-270: Unpacked variable dv2_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


295-295: Unpacked variable dh_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


295-295: Unpacked variable dh0_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


295-295: Unpacked variable dv2_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/chunk_o.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


22-22: Unused function argument: output_dtype

(ARG001)


23-23: Unused function argument: accum_dtype

(ARG001)


39-39: Unused function argument: DK

(ARG001)


41-41: Unused function argument: chunk_size

(ARG001)


44-44: Ambiguous variable name: O

(E741)


99-99: Ambiguous variable name: O

(E741)


198-198: Unused function argument: block_DK

(ARG001)


199-199: Unused function argument: block_DV

(ARG001)


200-200: Unused function argument: threads

(ARG001)


201-201: Unused function argument: num_stages

(ARG001)

examples/KDA/chunk_bwd_dqkwg.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


47-47: Unused function argument: DV

(ARG001)


48-48: Unused function argument: chunk_size

(ARG001)


49-49: Unused function argument: qk_dtype

(ARG001)


241-241: Unused function argument: use_gk

(ARG001)


242-242: Unused function argument: use_initial_state

(ARG001)


243-243: Unused function argument: store_final_state

(ARG001)


244-244: Unused function argument: save_new_value

(ARG001)


245-245: Unused function argument: block_DK

(ARG001)


246-246: Unused function argument: block_DV

(ARG001)


247-247: Unused function argument: threads

(ARG001)


248-248: Unused function argument: num_stages

(ARG001)


252-252: Unpacked variable dq_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


252-252: Unpacked variable dk_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


252-252: Unpacked variable dw_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


252-252: Unpacked variable dg_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


269-269: Unpacked variable dq is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


269-269: Unpacked variable dk is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


269-269: Unpacked variable dw is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


269-269: Unpacked variable dg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/chunk_inter_solve_fused.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


27-27: Unused function argument: output_dtype

(ARG001)


28-28: Unused function argument: accum_dtype

(ARG001)


49-49: Unused function argument: sub_chunk_size

(ARG001)

examples/KDA/FLA_KDA/fla_utils.py

33-33: Comment contains ambiguous ， (FULLWIDTH COMMA). Did you mean , (COMMA)?

(RUF003)


140-141: try-except-pass detected, consider logging the exception

(S110)


140-140: Do not catch blind exception: Exception

(BLE001)


149-150: try-except-pass detected, consider logging the exception

(S110)


149-149: Do not catch blind exception: Exception

(BLE001)


229-229: Do not catch blind exception: BaseException

(BLE001)


239-239: Do not catch blind exception: Exception

(BLE001)

examples/KDA/chunk_delta_h_fwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


17-17: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


41-41: Unused function argument: output_dtype

(ARG001)


42-42: Unused function argument: accum_dtype

(ARG001)


108-108: Unused function argument: block_DK

(ARG001)


248-248: Unused function argument: block_DK

(ARG001)


249-249: Unused function argument: block_DV

(ARG001)


250-250: Unused function argument: threads

(ARG001)


251-251: Unused function argument: num_stages

(ARG001)

examples/KDA/wy_fast.py

6-6: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


18-18: Unused function argument: output_dtype

(ARG001)


71-71: Unused function argument: use_qg

(ARG001)


97-97: Unused function argument: QG

(ARG001)


202-202: Unused function argument: block_DK

(ARG001)


203-203: Unused function argument: block_DV

(ARG001)


204-204: Unused function argument: threads

(ARG001)


205-205: Unused function argument: num_stages

(ARG001)


212-212: Unpacked variable QG_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


213-213: Unpacked variable QG_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/FLA_KDA/fla_chunk_inter.py

151-151: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

examples/KDA/wy_fast_bwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


24-24: Unused function argument: output_dtype

(ARG001)


25-25: Unused function argument: accum_dtype

(ARG001)


27-27: Unused function argument: state_dtype

(ARG001)


49-49: Unused function argument: chunk_size

(ARG001)


52-52: Unused function argument: state_dtype

(ARG001)


91-91: Unused function argument: state_dtype

(ARG001)


282-282: Unused function argument: block_DK

(ARG001)


283-283: Unused function argument: block_DV

(ARG001)


284-284: Unused function argument: threads

(ARG001)


285-285: Unused function argument: num_stages

(ARG001)


306-306: Unpacked variable dk_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


306-306: Unpacked variable dv_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


306-306: Unpacked variable dbeta_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


306-306: Unpacked variable dg_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


306-306: Unpacked variable dA_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


332-332: Unpacked variable dA_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


332-332: Unpacked variable dk_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


332-332: Unpacked variable dv_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


332-332: Unpacked variable dbeta_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


332-332: Unpacked variable dg_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/chunk_intra_token_parallel.py

22-22: Unused function argument: output_dtype

(ARG001)


23-23: Unused function argument: accum_dtype

(ARG001)


260-260: Unpacked variable Aqk_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


260-260: Unpacked variable Akk_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/FLA_KDA/fla_chunk_o.py

363-363: Unused function argument: g

(ARG001)


364-364: Unused function argument: g_gamma

(ARG001)


370-370: Unused function argument: scale

(ARG001)


376-376: Unused function argument: BK

(ARG001)


378-378: Unused function argument: USE_G

(ARG001)


379-379: Unused function argument: USE_G_GAMMA

(ARG001)


422-422: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

examples/KDA/chunk_bwd_dv.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


39-39: Unused function argument: chunk_size

(ARG001)


139-139: Unused function argument: scale

(ARG001)

examples/KDA/chunk_bwd_gla_dA.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


22-22: Unused function argument: chunk_size

(ARG001)


35-35: Unused function argument: DV

(ARG001)


138-138: Unused function argument: DK

(ARG001)

examples/KDA/chunk_bwd_intra.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


27-27: Unused function argument: output_dtype

(ARG001)


28-28: Unused function argument: accum_dtype

(ARG001)


30-30: Unused function argument: state_dtype

(ARG001)


57-57: Unused function argument: chunk_size

(ARG001)


61-61: Unused function argument: state_dtype

(ARG001)


96-96: Unused function argument: state_dtype

(ARG001)


133-133: Unused function argument: db

(ARG001)


419-419: Unused function argument: threads

(ARG001)


420-420: Unused function argument: num_stages

(ARG001)


421-421: Unused function argument: cu_seqlens

(ARG001)


422-422: Unused function argument: chunk_indices

(ARG001)

🔇 Additional comments (19)
examples/KDA/chunk_bwd_gla_dA.py (1)

170-182: LGTM!

Test configuration with reasonable parameters for benchmarking. The hardcoded values are appropriate for an example/benchmark script.

examples/KDA/chunk_delta_bwd.py (6)

52-67: LGTM!

Output tensor preparation looks correct. The unused gate_dtype parameter mirrors the pattern in prepare_input for API consistency.


70-79: LGTM!

Autotune configuration generation is well-structured, producing 60 combinations for exploration.


130-130: Unused h0 input parameter in kernel.

The h0 (initial state) tensor is passed to the kernel but never used in the computation. Comparing with the reference implementation chunk_gated_delta_rule_bwd_dhu in FLA_KDA/fla_chunk_delta.py, h0 is also passed to the Triton kernel. Please verify whether h0 should contribute to the backward gradient computation, or if it's only needed for the dh0 output derivation that occurs at line 223.


182-224: LGTM!

The backward pass kernel logic correctly implements:

  • Reverse iteration for gradient accumulation
  • Pipelined memory operations with configurable stages
  • GEMM-based gradient updates for dv and dh
  • Optional GK scaling with exp2

310-332: LGTM!

Standard CUDA event-based benchmarking implementation with proper synchronization.


360-361: LGTM!

Standard entry point pattern.

examples/KDA/chunk_intra_token_parallel.py (1)

112-115: Shared memory allocation size mismatch.

Aqk_shared and Akk_shared are allocated with shape (block_H, DK) but are used for intermediate reductions that eventually write to Sum_Aqk_shared of shape (block_H, CS). This seems correct for element-wise products, but verify that DK matches the expected dimension for the dot product accumulation.

examples/KDA/FLA_KDA/fla_wy_fast.py (1)

309-311: Input gradients dk and dg are overwritten without accumulation.

The function receives dk and dg as inputs (potentially containing existing gradients), but lines 309-310 unconditionally replace them with dk2 and dg2. If these inputs are meant to accumulate gradients, this discards prior values.

Verify if this is intentional. If accumulation is expected:

-    dk = dk2
-    dg = dg2
+    dk.copy_(dk2)
+    dg.copy_(dg2)

Or if the inputs serve only as workspace, consider renaming them to clarify intent.

examples/KDA/FLA_KDA/cumsum.py (1)

57-60: Inconsistent reverse cumsum implementation between scalar and vector kernels.

In chunk_local_cumsum_scalar_kernel (lines 57-60), reverse cumsum is manually computed as -cumsum + sum + original. In chunk_local_cumsum_vector_kernel (line 113), the built-in tl.cumsum(..., reverse=True) is used. This could lead to subtle numerical differences.

Verify that both approaches produce identical results, or consider using the same method for consistency.
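In exact arithmetic the two are the same identity, reverse_cumsum(x)[i] = total - cumsum(x)[i] + x[i]; only floating-point rounding can separate them. A quick PyTorch check of the identity:

import torch

x = torch.randn(64, dtype=torch.float64)
manual = -torch.cumsum(x, 0) + x.sum() + x                        # scalar-kernel formulation
reference = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])  # true reverse cumsum
assert torch.allclose(manual, reference)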

examples/KDA/wy_fast.py (2)

57-58: LGTM on autotune and JIT configuration.

The autotune decorator with config generation and JIT compilation with fast math enabled follows established TileLang patterns.


115-126: LGTM on swizzled layout annotations.

Proper use of swizzled layouts for shared memory buffers to optimize memory access patterns.

examples/KDA/wy_fast_bwd.py (1)

135-241: LGTM on backward kernel implementation.

The kernel correctly implements the backward pass with proper shared memory allocation, fragment management, and gradient computations for dA, dk, dv, dbeta, and dg.

examples/KDA/chunk_bwd_dv.py (2)

85-86: Potential type mismatch for A_shared.

A_shared is allocated with do_dtype but A tensor has input_dtype. This may cause implicit type conversions. Verify this is intentional.


76-104: LGTM on DV backward kernel.

The kernel correctly implements the backward DV computation with proper triangular masking and pipelined execution.

examples/KDA/FLA_KDA/fla_chunk_inter.py (2)

14-29: LGTM on autotune configuration.

Good use of Triton's heuristics for IS_VARLEN detection and autotune with multiple block size and warp configurations.


56-90: LGTM on kernel pointer setup and varlen handling.

Proper handling of variable-length sequences with chunk indices and boundary conditions.

examples/KDA/chunk_delta_h_fwd.py (1)

168-198: LGTM on forward kernel recurrence implementation.

The pipelined loop correctly stores the previous hidden state, computes the recurrence V_new = U - W * h, applies gating with exp2, and updates the hidden state.

examples/KDA/chunk_bwd_intra_op.py (1)

141-382: LGTM on intra-chunk backward kernel.

The kernel correctly implements the complex intra-chunk backward pass with:

  • Sub-chunk processing with proper indexing
  • Upper and lower triangular handling for diagonal blocks
  • Proper gradient accumulation for dq, dk, db, dg

Comment on lines +83 to +84
do_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
Copy link
Contributor


⚠️ Potential issue | 🟠 Major

Potential dtype mismatch for V_shared.

V_shared is allocated with do_dtype, but it's used to copy from V which has input_dtype. If do_dtype != input_dtype, this could cause implicit type conversion or precision loss.

Suggested fix
             do_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
-            V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
+            V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 do_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
-V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
+V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
🤖 Prompt for AI Agents
In @examples/KDA/chunk_bwd_gla_dA.py around lines 83 - 84, The V_shared buffer
is allocated with do_dtype but later populated from V which uses input_dtype,
risking implicit casts; change the allocation for V_shared to use input_dtype
(i.e., T.alloc_shared((block_S, block_DV), dtype=input_dtype)) or explicitly
cast V to do_dtype at the copy site so the types match; update the allocation of
V_shared (and any corresponding uses) to prevent precision loss between do_dtype
and input_dtype.

Comment on lines 508 to 512
input_dtype=T.float32,
output_dtype=T.float32,
accum_dtype=T.float32,
gate_dtype=T.float32,
state_dtype=T.float32,
Copy link
Contributor


⚠️ Potential issue | 🔴 Critical

Type mismatch: passing TileLang types instead of strings.

Same pattern as other files.

Proposed fix
-        input_dtype=T.float32,
-        output_dtype=T.float32,
-        accum_dtype=T.float32,
-        gate_dtype=T.float32,
-        state_dtype=T.float32,
+        input_dtype="float32",
+        output_dtype="float32",
+        accum_dtype="float32",
+        gate_dtype="float32",
+        state_dtype="float32",
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-input_dtype=T.float32,
-output_dtype=T.float32,
-accum_dtype=T.float32,
-gate_dtype=T.float32,
-state_dtype=T.float32,
+input_dtype="float32",
+output_dtype="float32",
+accum_dtype="float32",
+gate_dtype="float32",
+state_dtype="float32",
🤖 Prompt for AI Agents
In @examples/KDA/chunk_bwd_intra_op.py around lines 508 - 512, The dtype
parameters (input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype) are
using TileLang type objects (T.float32) instead of string names; update each to
use the matching string representation (e.g., "float32") wherever these kwargs
are set in this module (replace T.float32 with "float32" for
input_dtype/output_dtype/accum_dtype/gate_dtype/state_dtype) to match the
pattern used in the other files.

Comment on lines +432 to +437
if check_shared_mem("hopper", k.device.index):
CONST_TILING = 128
elif check_shared_mem:
CONST_TILING = 64
else:
CONST_TILING = 32
Copy link
Contributor


⚠️ Potential issue | 🔴 Critical

Bug: Missing function call parentheses.

check_shared_mem is a function, but elif check_shared_mem: checks if the function object is truthy (always True) instead of calling it. This means the elif branch will always be taken when the first condition is false.

Suggested fix
     if check_shared_mem("hopper", k.device.index):
         CONST_TILING = 128
-    elif check_shared_mem:
+    elif check_shared_mem():
         CONST_TILING = 64
     else:
         CONST_TILING = 32
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_chunk_o.py around lines 432 - 437, The elif
currently tests the function object instead of calling it; change "elif
check_shared_mem:" to call the function (e.g., "elif
check_shared_mem(\"hopper\", k.device.index):" or the correct argument set for
your use case) so the condition actually invokes check_shared_mem rather than
always evaluating truthiness of the function object, and keep CONST_TILING
assignments as-is.

Comment on lines +269 to +277
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = 64
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
CONST_TILING = 64
BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)
Copy link
Contributor


⚠️ Potential issue | 🟠 Major

Inconsistent BT derivation between forward and backward.

In recompute_w_u_fwd (line 221), BT = A.shape[-1] is derived from the input. In prepare_wy_repr_bwd (line 271), BT = 64 is hardcoded. If the forward pass uses a different BT, the backward pass will produce incorrect results.

🐛 Proposed fix
 def prepare_wy_repr_bwd(
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     gk: torch.Tensor,
     A: torch.Tensor,
     dk: torch.Tensor,
     dw: torch.Tensor,
     du: torch.Tensor,
     dg: torch.Tensor,
     cu_seqlens: torch.LongTensor = None,
     chunk_indices: torch.LongTensor = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     B, T, H, K, V = *k.shape, v.shape[-1]
-    BT = 64
+    BT = A.shape[-1]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     B, T, H, K, V = *k.shape, v.shape[-1]
-    BT = 64
+    BT = A.shape[-1]
     if chunk_indices is None and cu_seqlens is not None:
         chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
     CONST_TILING = 64
     BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
     BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_wy_fast.py around lines 269 - 277, The backward
helper prepare_wy_repr_bwd hardcodes BT = 64 which can mismatch the forward
recompute_w_u_fwd where BT is derived as A.shape[-1]; change prepare_wy_repr_bwd
to compute BT = A.shape[-1] (or otherwise derive BT from the same input used in
forward, e.g., A or the relevant tensor shape) instead of the hardcoded 64 so
forward and backward use the same tiling.

Comment on lines +27 to +30
if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
print_red_warning(f"{name} Error: nonfinite value mismatch")
if raise_assert:
raise AssertionError
Copy link
Contributor


⚠️ Potential issue | 🟠 Major

Inverted mask logic in non-finite value comparison.

The comparison on line 27 uses masked_fill(x_mask, 0) which fills finite positions with 0, leaving non-finite values for comparison. This appears inverted—you likely want to compare the non-finite values by masking out the finite ones with ~x_mask.

🐛 Proposed fix
-    if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
+    if not torch.isclose(x.masked_fill(~x_mask, 0), y.masked_fill(~y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
🤖 Prompt for AI Agents
In @examples/KDA/test_utils.py around lines 27 - 30, The non-finite comparison
in the block using x.masked_fill(x_mask, 0) and y.masked_fill(y_mask, 0) has the
mask inverted; currently finite positions are being filled and non-finite values
compared. Change the masked_fill calls to use the inverted masks (~x_mask and
~y_mask) so you mask out (fill) finite values and compare only non-finite
positions (keep using equal_nan=True and the existing error/raise behavior),
referencing the variables x, y, x_mask, y_mask and the masked_fill calls to
locate the change.

Comment on lines +364 to +374
input_dtype=T.float32,
output_dtype=T.float32,
accum_dtype=T.float32,
gate_dtype=T.float32,
state_dtype=T.float32,
chunk_size=64,
block_DK=32,
block_DV=32,
threads=128,
num_stages=0,
)
Copy link
Contributor


⚠️ Potential issue | 🔴 Critical

Type mismatch: passing TileLang types instead of strings.

Similar to wy_fast.py, main() passes T.float32 but run_test expects string dtype arguments for getattr(torch, ...).

Proposed fix
     run_test(
         B=1,
         S=32768,
         H=8,
         DK=DK,
         DV=DV,
-        input_dtype=T.float32,
-        output_dtype=T.float32,
-        accum_dtype=T.float32,
-        gate_dtype=T.float32,
-        state_dtype=T.float32,
+        input_dtype="float32",
+        output_dtype="float32",
+        accum_dtype="float32",
+        gate_dtype="float32",
+        state_dtype="float32",
         chunk_size=64,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        input_dtype=T.float32,
-        output_dtype=T.float32,
-        accum_dtype=T.float32,
-        gate_dtype=T.float32,
-        state_dtype=T.float32,
+        input_dtype="float32",
+        output_dtype="float32",
+        accum_dtype="float32",
+        gate_dtype="float32",
+        state_dtype="float32",
         chunk_size=64,
         block_DK=32,
         block_DV=32,
         threads=128,
         num_stages=0,
     )
🤖 Prompt for AI Agents
In @examples/KDA/wy_fast_bwd.py around lines 364 - 374, The call to main() in
wy_fast_bwd.py is passing TileLang types (T.float32) for dtype params while
run_test expects string names to use getattr(torch, ...); update the main(...)
invocation to pass string dtype names (e.g., "float32") for input_dtype,
output_dtype, accum_dtype, gate_dtype, and state_dtype so run_test can resolve
torch dtypes via getattr; keep all other numeric/size params (chunk_size,
block_DK, block_DV, threads, num_stages) unchanged.

Comment on lines +266 to +269
input_dtype=T.bfloat16,
output_dtype=T.bfloat16,
gate_dtype=T.float32,
accum_dtype=T.float32,
Copy link
Contributor

@coderabbitai coderabbitai bot Jan 13, 2026


⚠️ Potential issue | 🔴 Critical

Type mismatch: passing TileLang types instead of strings.

run_test expects string dtype arguments (e.g., "bfloat16") that are resolved via getattr(torch, input_dtype), but main() passes T.bfloat16, T.float32 directly. This will cause getattr(torch, T.bfloat16) to fail.

Proposed fix
     run_test(
         B=1,
         S=8192,
         H=64,
         DK=128,
         DV=128,
         chunk_size=64,
-        input_dtype=T.bfloat16,
-        output_dtype=T.bfloat16,
-        gate_dtype=T.float32,
-        accum_dtype=T.float32,
+        input_dtype="bfloat16",
+        output_dtype="bfloat16",
+        gate_dtype="float32",
+        accum_dtype="float32",
         block_DK=64,
         block_DV=32,
         threads=128,
         num_stages=3,
     )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
     run_test(
         B=1,
         S=8192,
         H=64,
         DK=128,
         DV=128,
         chunk_size=64,
-        input_dtype=T.bfloat16,
-        output_dtype=T.bfloat16,
-        gate_dtype=T.float32,
-        accum_dtype=T.float32,
+        input_dtype="bfloat16",
+        output_dtype="bfloat16",
+        gate_dtype="float32",
+        accum_dtype="float32",
         block_DK=64,
         block_DV=32,
         threads=128,
         num_stages=3,
     )
🤖 Prompt for AI Agents
In @examples/KDA/wy_fast.py around lines 266 - 269, The bug is that main()
passes TileLang type objects (T.bfloat16, T.float32) into run_test which expects
string dtype names resolved with getattr(torch, ...); update the call site in
main() to pass dtype names as strings (e.g., "bfloat16" for
input_dtype/output_dtype and "float32" for gate_dtype/accum_dtype) so run_test
can successfully call getattr(torch, dtype). Reference symbols: main(),
run_test(), and the dtype parameters
input_dtype/output_dtype/gate_dtype/accum_dtype.

Copy link
Contributor Author


This is valid

Copy link
Contributor


@wfloveiu, thank you for confirming! Once you've applied the fix to pass string dtype names ("bfloat16", "float32") instead of TileLang type objects, the code should work correctly. Let me know if you need any assistance with the change.


🧠 Learnings used
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.

@wfloveiu
Copy link
Contributor Author

Fixed linting issues, ready for review.

@LeiWang1999 LeiWang1999 self-requested a review January 13, 2026 05:08
@tzj-fxz tzj-fxz self-requested a review January 13, 2026 08:27
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 9

🤖 Fix all issues with AI agents
In @examples/KDA/chunk_bwd_dqkwg.py:
- Around line 242-245: Re-enable the correctness checks by uncommenting the
compare_tensors calls (compare_tensors("dq", dq_ref, dq), compare_tensors("dk",
dk_ref, dk), compare_tensors("dw", dw_ref, dw), compare_tensors("dg", dg_ref,
dg")) so the TileLang outputs are validated against the reference tensors; if
you expect numerical differences, call compare_tensors with an appropriate
tolerance or document known precision differences near these calls and adjust
the tolerance parameter (or add a comment) to avoid false failures while
preserving verification.

In @examples/KDA/chunk_bwd_intra_op.py:
- Line 16: Remove the debug print setting call
torch.set_printoptions(profile="full") from the top-level of
chunk_bwd_intra_op.py; simply delete that line (and any other ad-hoc
torch.set_printoptions calls) so tensors aren’t forced into verbose output
during normal execution, leaving normal print behavior unchanged.

In @examples/KDA/chunk_delta_bwd.py:
- Around line 301-303: Uncomment the validation calls so TileLang outputs are
compared to the reference: restore the three compare_tensors calls for "dh",
"dh0", and "dv2" (i.e., compare_tensors("dh", dh_ref, dh_tilelang),
compare_tensors("dh0", dh0_ref, dh0_tilelang), compare_tensors("dv2", dv2_ref,
dv2_tilelang")) in chunk_delta_bwd.py so the kernel outputs (dh_tilelang,
dh0_tilelang, dv2_tilelang) are actually validated against
dh_ref/dh0_ref/dv2_ref.
- Around line 121-136: The kernel function declares an unused parameter h0 in
its signature; either remove h0 from the kernel(...) parameter list or implement
the initial state handling referenced by use_initial_state so h0 is consumed. If
removing, update any callers and tests that pass h0 and delete related
h0_shape/h0 dtype defs; if implementing, add logic in kernel (e.g., within the
use_initial_state branch) to incorporate h0 into the backward computation and
ensure dh0 (or dh) is computed correctly from h0. Update any docstrings/comments
to reflect the chosen behavior.

In @examples/KDA/chunk_delta_h_fwd.py:
- Around line 272-282: The benchmark call to do_bench for
chunk_gated_delta_rule_fwd_h is passing the gate parameter as g=G but the
reference implementation expects gk=G, so update the argument name from g to gk
in that do_bench invocation (and verify the chunk_gated_delta_rule_fwd_h
signature accepts gk) to ensure the correct value is forwarded; adjust any other
mismatched calls to use gk consistently.

In @examples/KDA/chunk_intra_token_parallel.py:
- Around line 185-197: The run_test function accepts a scale parameter but it is
only forwarded to the reference implementation
chunk_kda_fwd_intra_token_parallel and not to the TileLang kernel; update the
TileLang kernel's function signature to accept a scale argument (matching the
reference) and pass scale in the kernel invocation inside run_test (the call
currently missing scale); ensure any internal uses of scale in the kernel are
wired through and remove or update the comment "# scale 如何传值" ("how should scale be passed?") once fixed so both
reference and kernel implementations receive identical parameters for correct
comparison.

In @examples/KDA/wy_fast_bwd.py:
- Around line 310-314: The correctness checks were disabled by commenting out
the compare_tensors calls; re-enable verification by uncommenting the
compare_tensors invocations for dA, dk, dv, dbeta, and dg so dk_tilelang/dk_ref,
dv_tilelang/dv_ref, dbeta_tilelang/dbeta_ref, dg_tilelang/dg_ref, and
dA_tilelang/dA_ref are compared; ensure these reference tensors are populated
before the calls and keep the comparisons after the TileLang outputs are
computed so any mismatches are reported.

In @examples/KDA/wy_fast.py:
- Around line 179-180: The variable use_qg is mistakenly assigned a
single-element tuple (False,) which is truthy; change the assignment of use_qg
in wy_fast.py from the tuple to a boolean (use_qg = False) so conditionals
behave correctly, and scan for any other occurrences of use_qg to ensure they
expect a boolean.
- Line 222: The print statement prints "tritron time" which is a typo; find the
print call that references triton_time (print("tritron time:", triton_time)) and
correct the string to "triton time:" so the output label matches the variable
name and intended wording.
🧹 Nitpick comments (19)
examples/KDA/test_utils.py (3)

43-83: Unused atol and rtol parameters.

The function signature accepts atol and rtol but they're never used. Either remove them or integrate them into the comparison logic (e.g., using torch.allclose with these tolerances).

♻️ Suggested fix to use or remove unused parameters
 def compare_tensors(name, x, y, atol=1e-5, rtol=1e-5):
     import numpy as np
     import torch

     diff = (x - y).abs()

     # ========= maximum absolute error =========
     max_abs_err = diff.max().item()
     abs_flat_idx = diff.argmax()
     abs_idx = list(np.unravel_index(abs_flat_idx.cpu().numpy(), diff.shape))
     ...
     print("=====================================\n")
+
+    # Return pass/fail based on tolerances
+    return torch.allclose(x, y, atol=atol, rtol=rtol)

86-107: Duplicate do_bench implementation.

This function is duplicated in examples/KDA/chunk_o.py (lines 161-183). Consider consolidating into a single shared utility to avoid maintenance burden.
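
A consolidated helper in test_utils.py could look like the sketch below, which mirrors the CUDA-event timing loop already duplicated in chunk_o.py (warmup/rep defaults follow that copy); treat it as a reference sketch rather than the PR's final code.

import torch


def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
    """Return the mean runtime of fn(*args, **kwargs) in milliseconds."""
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(rep)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(rep)]

    # Warm-up runs so compilation and caching do not skew the measurement.
    for _ in range(warmup):
        fn(*args, **kwargs)
    torch.cuda.synchronize()

    for i in range(rep):
        start_events[i].record()
        fn(*args, **kwargs)
        end_events[i].record()
    torch.cuda.synchronize()

    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_events, end_events)], dtype=torch.float)
    return times.mean().item()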


3-3: Remove unused constant RCP_LN2 or document its intended purpose.

The constant is defined but never referenced anywhere in the codebase.

examples/KDA/chunk_bwd_dv.py (2)

31-40: Unused chunk_size parameter.

The chunk_size parameter is accepted but never used. Consider removing it for clarity.

♻️ Suggested fix
 def prepare_output(
     B,
     S,
     H,
     DV,
-    chunk_size,
     output_dtype,
 ):
     dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
     return dv

117-132: Unused variable dv_tilelang from prepare_output.

Line 120 allocates dv_tilelang via prepare_output, but it's immediately overwritten by the kernel result on line 131. The pre-allocation is wasted.

♻️ Suggested fix - remove unnecessary pre-allocation
-    dv_tilelang = prepare_output(B, S, H, DV, chunk_size, getattr(torch, output_dtype))
     kernel = tilelang_chunk_bwd_kernel_dv_local(
         B=B,
         S=S,
         H=H,
         DV=DV,
         input_dtype=input_dtype,
         output_dtype=output_dtype,
         do_dtype=do_dtype,
         chunk_size=chunk_size,
     )
     dv_tilelang = kernel(DO, A)
examples/KDA/chunk_bwd_gla_dA.py (2)

14-37: Unused parameters in preparation functions.

  • prepare_input: chunk_size (line 19) is unused.
  • prepare_output: DV (line 32) is unused; the output shape uses chunk_size for the last dimension.

Consider removing unused parameters or documenting the intent for API consistency.


118-135: Pre-allocated dA_tilelang is unused.

Similar to chunk_bwd_dv.py, line 122 allocates dA_tilelang but it's overwritten on line 134. The allocation is wasted.

examples/KDA/chunk_bwd_intra_op.py (2)

199-199: Minor typo in comment.

"ofprevious" should be "of previous".

📝 Fix typo
-                for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index ofprevious sub_chunks
+                for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index of previous sub_chunks

444-448: Pre-allocated outputs are immediately overwritten.

Lines 444-446 prepare output tensors that are immediately overwritten by the kernel call on line 447. The prepare_output call is unnecessary.

♻️ Suggested fix
-    dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = prepare_output(
-        B, S, H, DK, chunk_size, NK, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
-    )
     dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = kernel(q, k, g, beta, dAqk, dAkk, dq, dk, db, dg)
examples/KDA/chunk_delta_h_fwd.py (2)

3-3: Remove unused noqa directive.

The sys import isn't actually used and has an unnecessary noqa: F401 directive. Consider removing the import entirely if not needed, or remove the noqa comment if the import is required.

Suggested fix
-import sys  # noqa: F401

301-305: Dtype passed as TileLang type instead of string.

In main(), the dtypes are passed as T.float16, T.float32 etc., but in run_test() lines 225-228, they are converted using getattr(torch, input_dtype). This works because TileLang types have string representations, but it's inconsistent with how strings are used elsewhere and may be confusing.
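
A small, hypothetical normalization helper (not part of the PR) would let the scripts accept either form while keeping the getattr(torch, ...) convention; it assumes, as noted above, that str() of a TileLang dtype yields the plain name such as "bfloat16".

import torch


def to_torch_dtype(dtype) -> torch.dtype:
    # Accept a torch.dtype, a string like "bfloat16", or an object whose
    # str() is that plain name (e.g. a TileLang dtype).
    if isinstance(dtype, torch.dtype):
        return dtype
    return getattr(torch, str(dtype))


assert to_torch_dtype("float32") is torch.float32
assert to_torch_dtype(torch.bfloat16) is torch.bfloat16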

examples/KDA/wy_fast.py (1)

184-185: Unused variables should use underscore prefix.

QG_ref and QG_tilelang are unpacked but never used. Use underscore prefix to indicate intentional discard.

Suggested fix
-    W_ref, U_ref, QG_ref, KG_ref = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype))
-    W_tilelang, U_tilelang, QG_tilelang, KG_tilelang = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype))
+    W_ref, U_ref, _QG_ref, KG_ref = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype))
+    W_tilelang, U_tilelang, _QG_tilelang, KG_tilelang = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype))
examples/KDA/chunk_bwd_dqkwg.py (2)

48-52: Use torch.empty instead of torch.randn for output tensors.

Output tensors are initialized with random values using torch.randn, but they will be overwritten by the kernel. Using torch.empty is more appropriate and slightly more efficient.

Suggested fix
-    dq = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
-    dk = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
-    dw = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
-    dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
+    dq = torch.empty(B, S, H, DK, dtype=torch.float32).cuda()
+    dk = torch.empty(B, S, H, DK, dtype=torch.float32).cuda()
+    dw = torch.empty(B, S, H, DK, dtype=gate_dtype).cuda()
+    dg = torch.empty(B, S, H, DK, dtype=gate_dtype).cuda()

223-240: Unused reference output variables.

The reference outputs dq_ref, dk_ref, dw_ref, dg_ref and TileLang outputs dq, dk, dw, dg are computed but never used (comparison is commented out). Use underscore prefix for intentional discard or enable the comparison.

examples/KDA/chunk_intra_token_parallel.py (1)

137-144: Commented-out swizzled layout annotations.

The T.annotate_layout block is commented out. If swizzling improves performance, consider enabling it. If it causes issues, document why it's disabled.

examples/KDA/chunk_bwd_intra.py (2)

199-199: Minor typo in comment.

"ofprevious" should be "of previous".

Suggested fix
-                for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index ofprevious sub_chunks
+                for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index of previous sub_chunks

80-84: TMA disabled - verify necessity.

The kernel disables TMA lowering via TL_DISABLE_TMA_LOWER: True. If this is a workaround for a known issue, consider adding a comment explaining why.

examples/KDA/wy_fast_bwd.py (2)

282-292: Unused reference output variables.

All reference outputs (dk_ref, dv_ref, dbeta_ref, dg_ref, dA_ref) are computed but never used since comparison is disabled. Use underscore prefixes or enable the comparison.

Suggested fix if keeping disabled
-    dk_ref, dv_ref, dbeta_ref, dg_ref, dA_ref = prepare_wy_repr_bwd(
+    _dk_ref, _dv_ref, _dbeta_ref, _dg_ref, _dA_ref = prepare_wy_repr_bwd(

76-79: Both TMA and warp specialization disabled.

The kernel disables both TL_DISABLE_TMA_LOWER and TL_DISABLE_WARP_SPECIALIZED. Consider adding a comment explaining why these optimizations are disabled for this specific kernel.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5d39797 and 71f0fc9.

📒 Files selected for processing (12)
  • examples/KDA/chunk_bwd_dqkwg.py
  • examples/KDA/chunk_bwd_dv.py
  • examples/KDA/chunk_bwd_gla_dA.py
  • examples/KDA/chunk_bwd_intra.py
  • examples/KDA/chunk_bwd_intra_op.py
  • examples/KDA/chunk_delta_bwd.py
  • examples/KDA/chunk_delta_h_fwd.py
  • examples/KDA/chunk_inter_solve_fused.py
  • examples/KDA/chunk_intra_token_parallel.py
  • examples/KDA/test_utils.py
  • examples/KDA/wy_fast.py
  • examples/KDA/wy_fast_bwd.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • examples/KDA/chunk_bwd_intra_op.py
  • examples/KDA/wy_fast.py
  • examples/KDA/chunk_delta_h_fwd.py
  • examples/KDA/chunk_delta_bwd.py
  • examples/KDA/chunk_bwd_dqkwg.py
  • examples/KDA/wy_fast_bwd.py
  • examples/KDA/chunk_intra_token_parallel.py
  • examples/KDA/chunk_bwd_dv.py
📚 Learning: 2026-01-06T05:20:51.649Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.

Applied to files:

  • examples/KDA/wy_fast.py
  • examples/KDA/chunk_delta_h_fwd.py
  • examples/KDA/chunk_delta_bwd.py
  • examples/KDA/chunk_intra_token_parallel.py
🧬 Code graph analysis (8)
examples/KDA/wy_fast.py (3)
tilelang/autotuner/tuner.py (1)
  • autotune (691-786)
examples/KDA/FLA_KDA/fla_wy_fast.py (1)
  • recompute_w_u_fwd (210-254)
examples/KDA/test_utils.py (2)
  • compare_tensors (43-83)
  • do_bench (86-108)
examples/KDA/chunk_delta_h_fwd.py (1)
examples/KDA/test_utils.py (2)
  • compare_tensors (43-83)
  • do_bench (86-108)
examples/KDA/chunk_delta_bwd.py (1)
examples/KDA/test_utils.py (1)
  • do_bench (86-108)
examples/KDA/test_utils.py (3)
tilelang/carver/roller/policy/default.py (1)
  • sim (290-291)
tilelang/language/tir/op.py (1)
  • all (1913-1930)
examples/KDA/chunk_o.py (1)
  • do_bench (162-184)
examples/KDA/chunk_bwd_intra.py (2)
examples/KDA/FLA_KDA/fla_chunk_intra.py (1)
  • chunk_kda_bwd_intra (541-608)
examples/KDA/FLA_KDA/cumsum.py (1)
  • chunk_local_cumsum (431-469)
examples/KDA/wy_fast_bwd.py (3)
tilelang/autotuner/tuner.py (1)
  • autotune (691-786)
examples/KDA/FLA_KDA/fla_wy_fast.py (1)
  • prepare_wy_repr_bwd (257-312)
examples/KDA/test_utils.py (1)
  • do_bench (86-108)
examples/KDA/chunk_intra_token_parallel.py (2)
examples/KDA/FLA_KDA/cumsum.py (1)
  • chunk_local_cumsum (431-469)
examples/KDA/test_utils.py (1)
  • do_bench (86-108)
examples/KDA/chunk_bwd_dv.py (1)
examples/KDA/test_utils.py (2)
  • compare_tensors (43-83)
  • do_bench (86-108)
🪛 Ruff (0.14.10)
examples/KDA/chunk_bwd_intra_op.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


26-26: Unused function argument: output_dtype

(ARG001)


27-27: Unused function argument: accum_dtype

(ARG001)


29-29: Unused function argument: state_dtype

(ARG001)


56-56: Unused function argument: chunk_size

(ARG001)


60-60: Unused function argument: state_dtype

(ARG001)


95-95: Unused function argument: state_dtype

(ARG001)


132-132: Unused function argument: db

(ARG001)


396-396: Unused function argument: threads

(ARG001)


397-397: Unused function argument: num_stages

(ARG001)


398-398: Unused function argument: cu_seqlens

(ARG001)


399-399: Unused function argument: chunk_indices

(ARG001)

examples/KDA/wy_fast.py

6-6: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


15-15: Unused function argument: output_dtype

(ARG001)


68-68: Unused function argument: use_qg

(ARG001)


94-94: Unused function argument: QG

(ARG001)


174-174: Unused function argument: block_DK

(ARG001)


175-175: Unused function argument: block_DV

(ARG001)


176-176: Unused function argument: threads

(ARG001)


177-177: Unused function argument: num_stages

(ARG001)


184-184: Unpacked variable QG_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


185-185: Unpacked variable QG_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/chunk_delta_h_fwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


31-31: Unused function argument: output_dtype

(ARG001)


32-32: Unused function argument: accum_dtype

(ARG001)


98-98: Unused function argument: block_DK

(ARG001)


213-213: Unused function argument: block_DK

(ARG001)


214-214: Unused function argument: block_DV

(ARG001)


215-215: Unused function argument: threads

(ARG001)


216-216: Unused function argument: num_stages

(ARG001)

examples/KDA/chunk_delta_bwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


28-28: Unused function argument: output_dtype

(ARG001)


29-29: Unused function argument: accum_dtype

(ARG001)


31-31: Unused function argument: state_dtype

(ARG001)


58-58: Unused function argument: gate_dtype

(ARG001)


128-128: Unused function argument: h0

(ARG001)


242-242: Unused function argument: block_DV

(ARG001)


243-243: Unused function argument: threads

(ARG001)


244-244: Unused function argument: num_stages

(ARG001)


245-245: Unused function argument: use_torch

(ARG001)


268-268: Unpacked variable dh_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


268-268: Unpacked variable dh0_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


268-268: Unpacked variable dv2_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


291-291: Unpacked variable dh_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


291-291: Unpacked variable dh0_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


291-291: Unpacked variable dv2_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/test_utils.py

43-43: Unused function argument: atol

(ARG001)


43-43: Unused function argument: rtol

(ARG001)


54-54: Comment contains ambiguous （ (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?

(RUF003)


54-54: Comment contains ambiguous ） (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?

(RUF003)

examples/KDA/chunk_bwd_intra.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


26-26: Unused function argument: output_dtype

(ARG001)


27-27: Unused function argument: accum_dtype

(ARG001)


29-29: Unused function argument: state_dtype

(ARG001)


56-56: Unused function argument: chunk_size

(ARG001)


60-60: Unused function argument: state_dtype

(ARG001)


95-95: Unused function argument: state_dtype

(ARG001)


132-132: Unused function argument: db

(ARG001)


396-396: Unused function argument: threads

(ARG001)


397-397: Unused function argument: num_stages

(ARG001)


398-398: Unused function argument: cu_seqlens

(ARG001)


399-399: Unused function argument: chunk_indices

(ARG001)

examples/KDA/chunk_bwd_dqkwg.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


43-43: Unused function argument: DV

(ARG001)


44-44: Unused function argument: chunk_size

(ARG001)


45-45: Unused function argument: qk_dtype

(ARG001)


212-212: Unused function argument: use_gk

(ARG001)


213-213: Unused function argument: use_initial_state

(ARG001)


214-214: Unused function argument: store_final_state

(ARG001)


215-215: Unused function argument: save_new_value

(ARG001)


216-216: Unused function argument: block_DK

(ARG001)


217-217: Unused function argument: block_DV

(ARG001)


218-218: Unused function argument: threads

(ARG001)


219-219: Unused function argument: num_stages

(ARG001)


223-223: Unpacked variable dq_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


223-223: Unpacked variable dk_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


223-223: Unpacked variable dw_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


223-223: Unpacked variable dg_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


240-240: Unpacked variable dq is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


240-240: Unpacked variable dk is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


240-240: Unpacked variable dw is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


240-240: Unpacked variable dg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/chunk_inter_solve_fused.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


24-24: Unused function argument: output_dtype

(ARG001)


25-25: Unused function argument: accum_dtype

(ARG001)


46-46: Unused function argument: sub_chunk_size

(ARG001)

examples/KDA/wy_fast_bwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


25-25: Unused function argument: output_dtype

(ARG001)


26-26: Unused function argument: accum_dtype

(ARG001)


28-28: Unused function argument: state_dtype

(ARG001)


50-50: Unused function argument: chunk_size

(ARG001)


53-53: Unused function argument: state_dtype

(ARG001)


92-92: Unused function argument: state_dtype

(ARG001)


258-258: Unused function argument: block_DK

(ARG001)


259-259: Unused function argument: block_DV

(ARG001)


260-260: Unused function argument: threads

(ARG001)


261-261: Unused function argument: num_stages

(ARG001)


282-282: Unpacked variable dk_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


282-282: Unpacked variable dv_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


282-282: Unpacked variable dbeta_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


282-282: Unpacked variable dg_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


282-282: Unpacked variable dA_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


308-308: Unpacked variable dA_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


308-308: Unpacked variable dk_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


308-308: Unpacked variable dv_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


308-308: Unpacked variable dbeta_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


308-308: Unpacked variable dg_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/chunk_intra_token_parallel.py

20-20: Unused function argument: output_dtype

(ARG001)


21-21: Unused function argument: accum_dtype

(ARG001)


233-233: Unpacked variable Aqk_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


233-233: Unpacked variable Akk_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/chunk_bwd_dv.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


36-36: Unused function argument: chunk_size

(ARG001)


111-111: Unused function argument: scale

(ARG001)

examples/KDA/chunk_bwd_gla_dA.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


19-19: Unused function argument: chunk_size

(ARG001)


32-32: Unused function argument: DV

(ARG001)


110-110: Unused function argument: DK

(ARG001)

🔇 Additional comments (12)
examples/KDA/test_utils.py (1)

20-40: Inverted mask logic in non-finite value comparison.

Line 27 uses masked_fill(x_mask, 0), which fills finite values with 0 and keeps non-finite values; this appears intentional for checking that non-finite values (NaN, Inf) occur at the same positions in both tensors, though the variable naming and comments could make that clearer.

Lines 31-32 correctly zero out non-finite values for the similarity calculation.
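
As an illustration of that intent (not the exact code in test_utils.py), the pattern amounts to checking that NaN/Inf occur at the same positions and then zeroing them out before computing a similarity metric; the cosine similarity below is only a stand-in.

import torch


def compare_with_nonfinite_check(x: torch.Tensor, y: torch.Tensor) -> float:
    x_finite = torch.isfinite(x)
    y_finite = torch.isfinite(y)

    # Non-finite values (NaN, Inf) must appear at the same positions.
    assert torch.equal(x_finite, y_finite), "NaN/Inf positions differ"

    # Zero out the non-finite entries so they do not poison the metric.
    x_clean = x.masked_fill(~x_finite, 0)
    y_clean = y.masked_fill(~y_finite, 0)

    return torch.nn.functional.cosine_similarity(
        x_clean.double().flatten(), y_clean.double().flatten(), dim=0
    ).item()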

examples/KDA/chunk_bwd_dv.py (1)

54-102: TileLang kernel implementation looks correct.

The kernel properly:

  • Uses swizzled layouts for shared memory
  • Applies lower-triangular masking on A
  • Uses pipelined loops for DV dimension
  • Performs GEMM with A transposed
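
A plain-PyTorch analogue of the mask-then-GEMM steps above, with simplified shapes and no claim to match the kernel's exact semantics:

import torch

block_S, block_DV = 64, 32
A = torch.randn(block_S, block_S)
dO = torch.randn(block_S, block_DV)

A_masked = torch.tril(A)   # keep only the lower-triangular part of A
dv = A_masked.T @ dO       # GEMM with A transposed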
examples/KDA/chunk_bwd_gla_dA.py (1)

51-103: Kernel implementation is well-structured.

The kernel correctly:

  • Accumulates dA via pipelined GEMM across V blocks
  • Applies lower-triangular masking with scale
  • Uses appropriate swizzled layouts
examples/KDA/chunk_delta_bwd.py (1)

80-103: Kernel factory implementation follows expected patterns.

The kernel properly handles:

  • Multiple tensor outputs via out_idx=[-3, -2, -1]
  • Autotuning with configurable parameters
  • Proper gradient accumulation in reverse order (line 182)

Note: The use_initial_state flag conditionally stores dh0 (lines 220-221), but the h0 input is still unused.

examples/KDA/chunk_bwd_intra_op.py (2)

121-382: Complex intra-chunk backward kernel implementation.

The kernel handles:

  • Sub-chunk processing with proper indexing
  • Inter-sub-chunk gradient flow via kg_fragment and gating
  • Lower triangular diagonal processing
  • Proper accumulation of dq2, dk2, db, dg2

The implementation follows the expected pattern from the FLA reference and correctly handles boundary conditions.


130-133: Input tensor db declared but unused in kernel.

The kernel signature includes db: T.Tensor(db_shape, dtype=input_dtype) but it's never read inside the kernel. The actual db addition happens post-kernel in run_test (line 448: db_tilelang.sum(0).add_(db)).

This is correct behavior since the kernel computes partial db2 values that are summed and added to the original db afterward. Consider adding a brief comment to clarify this design.
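
In isolation, the post-kernel step amounts to the following (the leading NK split dimension and the shapes are assumed from the context above):

import torch

NK, B, S, H = 2, 1, 8, 4
db = torch.zeros(B, S, H)                 # original db passed into run_test
db_partial = torch.randn(NK, B, S, H)     # per-split partial db2 from the kernel

db_total = db_partial.sum(0).add_(db)     # sum over NK, then add db in place
assert db_total.shape == (B, S, H)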

examples/KDA/chunk_delta_h_fwd.py (1)

158-160: Verify h store ordering in pipelined loop.

The hidden state h is stored at the beginning of each iteration before the recurrence update is computed. This means h[i_s] receives the value computed in the previous iteration (or the initial state for i_s=0). Confirm this matches the intended semantics where h[i] should contain the state before processing chunk i.
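
A minimal runnable sketch of that ordering, with a placeholder recurrence instead of the real decay/delta-rule update (the h0/ht/use_initial_state names follow the surrounding discussion):

import torch

num_chunks, DK, DV = 4, 8, 8
h = torch.empty(num_chunks, DK, DV)
state = torch.zeros(DK, DV)   # replaced by h0 when use_initial_state is set


def update(state: torch.Tensor) -> torch.Tensor:
    # Placeholder for the per-chunk recurrence, not the kernel's actual math.
    return 0.9 * state + torch.randn(DK, DV)


for i in range(num_chunks):
    h[i] = state           # value computed in the previous iteration (or h0)
    state = update(state)  # recurrence for chunk i

ht = state  # final state after the last chunk, stored when requested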

examples/KDA/chunk_inter_solve_fused.py (3)

558-559: LGTM: Correctness verification is enabled.

Unlike some other files in this PR, this file properly enables compare_tensors calls to verify the TileLang kernel output matches the reference implementation.


326-336: Verify forward substitution loop bounds.

The pipelined loop starts at index 2 (T.Pipelined(2, T.min(BC, S - i_tc0), ...)). Verify this is intentional and correctly handles the first two rows of the block matrix during forward substitution. A comment explaining why the loop starts at 2 would improve clarity.
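
For context only: in plain forward substitution on a unit lower-triangular system, row 0 needs no substitution and row 1 depends only on row 0, so a loop that begins at index 2 is plausible if the first two rows are produced beforehand. The sketch below illustrates that structure and does not claim to match the kernel's blockwise scheme.

import torch

n = 6
A = torch.tril(torch.randn(n, n, dtype=torch.float64), diagonal=-1) + torch.eye(n, dtype=torch.float64)
b = torch.randn(n, dtype=torch.float64)

x = torch.zeros(n, dtype=torch.float64)
x[0] = b[0]                      # row 0: no substitution needed
x[1] = b[1] - A[1, 0] * x[0]     # row 1: depends only on row 0
for i in range(2, n):            # remaining rows handled in the loop
    x[i] = b[i] - A[i, :i] @ x[:i]

assert torch.allclose(x, torch.linalg.solve(A, b))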


89-100: Helpful docstring for complex kernel.

Good documentation of the kernel's purpose and the six-step algorithm. This makes the long kernel much more maintainable.

examples/KDA/chunk_bwd_intra.py (1)

455-458: LGTM: Correctness verification is enabled.

The compare_tensors calls are properly enabled to verify dq, dk, db, and dg outputs against the reference implementation.

examples/KDA/chunk_intra_token_parallel.py (1)

264-269: Correctness verification is disabled.

The compare_tensors calls and debug prints are commented out. Enable these to verify the TileLang implementation matches the reference.

⛔ Skipped due to learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Comment on lines +121 to +136
@T.prim_func
def kernel(
# Input
Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype),
GK: T.Tensor(G_shape, dtype=gate_dtype),
h0: T.Tensor(h0_shape, dtype=input_dtype),
dht: T.Tensor(dht_shape, dtype=input_dtype),
dO: T.Tensor(dO_shape, dtype=input_dtype),
dv: T.Tensor(dv_shape, dtype=input_dtype),
# Output
dh: T.Tensor(dh_shape, dtype=output_dtype),
dh0: T.Tensor(dh0_shape, dtype=state_dtype),
dv2: T.Tensor(dv2_shape, dtype=output_dtype),
):
Contributor

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

cat -n examples/KDA/chunk_delta_bwd.py | head -180 | tail -70

Repository: tile-ai/tilelang

Length of output: 3903


🏁 Script executed:

rg -n '\bh0\b' examples/KDA/chunk_delta_bwd.py

Repository: tile-ai/tilelang

Length of output: 653


🏁 Script executed:

wc -l examples/KDA/chunk_delta_bwd.py

Repository: tile-ai/tilelang

Length of output: 95


🏁 Script executed:

cat -n examples/KDA/chunk_delta_bwd.py | sed -n '180,240p'

Repository: tile-ai/tilelang

Length of output: 3286


Remove unused parameter h0 from kernel signature or implement initial state handling logic.

The kernel declares h0: T.Tensor(h0_shape, dtype=input_dtype) at line 128 but never uses it in the kernel body. Given the use_initial_state conditional at line 220, h0 may have been intended for initialization logic that is not yet implemented. Either remove the parameter if it's unused, or implement the logic to incorporate the initial hidden state in the backward pass.

🧰 Tools
🪛 Ruff (0.14.10)

128-128: Unused function argument: h0

(ARG001)

🤖 Prompt for AI Agents
In @examples/KDA/chunk_delta_bwd.py around lines 121 - 136, The kernel function
declares an unused parameter h0 in its signature; either remove h0 from the
kernel(...) parameter list or implement the initial state handling referenced by
use_initial_state so h0 is consumed. If removing, update any callers and tests
that pass h0 and delete related h0_shape/h0 dtype defs; if implementing, add
logic in kernel (e.g., within the use_initial_state branch) to incorporate h0
into the backward computation and ensure dh0 (or dh) is computed correctly from
h0. Update any docstrings/comments to reflect the chosen behavior.

Comment on lines +185 to +197
def run_test(
B,
S,
H,
DK,
scale,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
chunk_size,
sub_chunk_size,
):
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat examples/KDA/chunk_intra_token_parallel.py | head -300 | tail -250

Repository: tile-ai/tilelang

Length of output: 8024


🏁 Script executed:

wc -l examples/KDA/chunk_intra_token_parallel.py

Repository: tile-ai/tilelang

Length of output: 106


🏁 Script executed:

# First, let's look at the file structure - check lines around the kernel definition and run_test
sed -n '55,80p' examples/KDA/chunk_intra_token_parallel.py

Repository: tile-ai/tilelang

Length of output: 618


🏁 Script executed:

# Check the run_test function signature (lines 185-197)
sed -n '185,197p' examples/KDA/chunk_intra_token_parallel.py

Repository: tile-ai/tilelang

Length of output: 220


🏁 Script executed:

# Check how scale is used in the kernel call and reference implementation
sed -n '200,270p' examples/KDA/chunk_intra_token_parallel.py

Repository: tile-ai/tilelang

Length of output: 1954


Unused scale parameter in run_test creates inconsistency between reference and kernel implementations.

The scale parameter is passed to run_test (line 185) and used when calling the reference implementation chunk_kda_fwd_intra_token_parallel (line 249), but is not passed to the TileLang kernel (line 252). The kernel function signature (lines 60-73) lacks a scale parameter entirely. Additionally, a code comment (# scale 如何传值 on line 243) indicates the developer was aware this parameter needs to be handled. The TileLang kernel and reference implementation must receive the same parameters for correctness comparison.

🤖 Prompt for AI Agents
In @examples/KDA/chunk_intra_token_parallel.py around lines 185 - 197, The
run_test function accepts a scale parameter but it is only forwarded to the
reference implementation chunk_kda_fwd_intra_token_parallel and not to the
TileLang kernel; update the TileLang kernel's function signature to accept a
scale argument (matching the reference) and pass scale in the kernel invocation
inside run_test (the call currently missing scale); ensure any internal uses of
scale in the kernel are wired through and remove or update the comment "# scale
如何传值" once fixed so both reference and kernel implementations receive identical
parameters for correct comparison.

tilelang_time = do_bench(kernel, K, V, Beta, G, A)
triton_time = do_bench(recompute_w_u_fwd, k=K, v=V, beta=Beta, gk=G, A=A)
print("tilelang time:", tilelang_time)
print("tritron time:", triton_time)
Contributor

⚠️ Potential issue | 🟡 Minor

Typo: "tritron" should be "triton".

Suggested fix
     print("tilelang time:", tilelang_time)
-    print("tritron time:", triton_time)
+    print("triton time:", triton_time)
🤖 Prompt for AI Agents
In @examples/KDA/wy_fast.py at line 222, The print statement prints "tritron
time" which is a typo; find the print call that references triton_time
(print("tritron time:", triton_time)) and correct the string to "triton time:"
so the output label matches the variable name and intended wording.

@tzj-fxz
Contributor

tzj-fxz commented Jan 13, 2026

Hi! Thank you for the contribution!
A small question: in your implementation of the KDA kernels, the two operators chunk_bwd_intra.py and chunk_bwd_intra_op.py share most of their code and appear to have exactly the same functionality. Could you check them and keep only one?

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 16

🤖 Fix all issues with AI agents
In @examples/KDA/chunk_bwd_dqkwg.py:
- Around line 242-245: The tensor validation calls have been commented out so
correctness isn't verified; re-enable the comparisons by restoring the
compare_tensors invocations for "dq", "dk", "dw", and "dg" (using dq_ref/dq,
dk_ref/dk, dw_ref/dw, dg_ref/dg) in chunk_bwd_dqkwg.py (around the current
commented block) or wrap them behind a debug/validation flag so they run during
test/CI; ensure that compare_tensors is imported/defined and invoked with the
same argument order and any necessary tolerance parameters so validation
succeeds.
- Around line 38-52: The prepare_output function ignores the qk_dtype parameter
by hardcoding dq and dk to torch.float32; update dq and dk allocations in
prepare_output to use dtype=qk_dtype (or convert them to qk_dtype after
creation) so the passed qk_dtype controls their dtype, leaving dw and dg using
gate_dtype as before; locate prepare_output and replace the dtype=torch.float32
occurrences for dq and dk with dtype=qk_dtype.
- Around line 273-275: main() is passing TileLang type objects (T.float32) into
run_test which expects string names and uses getattr(torch, input_dtype); change
the three arguments passed (input_dtype, gate_dtype, qk_dtype) from T.float32
etc. to string names like "float32" (or "float16" as appropriate) so
getattr(torch, ...) works, or alternatively change run_test to accept dtype
objects and skip getattr(torch, ...) — update either the call sites in main() or
the implementation of run_test (referenced by the run_test function name) so the
types and usage match.

In @examples/KDA/chunk_delta_bwd.py:
- Around line 298-300: The three commented-out validation calls disable
correctness checks; restore or justify them by uncommenting the compare_tensors
calls for "dh", "dh0", and "dv2" (i.e., compare_tensors("dh", dh_ref,
dh_tilelang); compare_tensors("dh0", dh0_ref, dh0_tilelang);
compare_tensors("dv2", dv2_ref, dv2_tilelang)) so the test actually validates
outputs, or if they must remain disabled, add a clear inline comment explaining
why (e.g., non-determinism, known numerical drift) and add an alternative
lightweight assertion (shape/dtype checks) to ensure the test still provides
basic validation.
- Around line 311-315: main() is passing TileLang types (e.g., T.bfloat16,
T.float32) into run_test where run_test calls getattr(torch, dtype) expecting
dtype to be a string; change the arguments passed to run_test for input_dtype,
output_dtype, accum_dtype, gate_dtype, and state_dtype to be dtype name strings
(e.g., "bfloat16", "float32") instead of TileLang types so getattr(torch, dtype)
works correctly, or alternatively convert TileLang types to their name strings
before calling getattr inside run_test (refer to the run_test call site and the
parameter names input_dtype/output_dtype/accum_dtype/gate_dtype/state_dtype).

In @examples/KDA/chunk_delta_h_fwd.py:
- Line 58: BS is computed with floor division (BS = S // chunk_size) which can
mismatch the kernel's ceil-div iteration (T.ceildiv(S, block_S)) and cause
out-of-bounds writes; change the computation to use ceiling division (e.g., BS =
T.ceildiv(S, chunk_size)) or add an explicit validation that S % chunk_size == 0
and raise/handle the error; update references where BS is used (and any code
expecting exact divisibility) to match the chosen approach (a ceil-div sketch follows this file's items).
- Around line 294-315: The bug is a dtype type mismatch: main() passes TileLang
dtype objects (e.g., T.float16/T.float32) to run_test, but run_test expects
string names and uses getattr(torch, input_dtype); fix by changing the dtype
arguments in main() to string names (e.g., "float16", "float32") for
input_dtype, output_dtype, accum_dtype, gate_dtype, and state_dtype so
getattr(torch, ...) works as intended; alternatively update run_test to accept
TileLang dtype objects by converting them to strings before calling getattr
(detect non-str and map to their name) but the quickest fix is to replace
T.float16/T.float32 in main() with the corresponding string names.
- Around line 272-282: The benchmark call to do_bench uses the scalar gate
argument g=G but the correctness test uses the vector gate gk=G; update the
do_bench invocation for chunk_gated_delta_rule_fwd_h to pass gk=G (instead of
g=G) so the benchmark exercises the same gate configuration as the correctness
test (also scan for other do_bench calls referencing g vs gk and make them
consistent with the reference function signature).
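
As referenced above, a small self-contained sketch of the ceil-div fix for the first chunk_delta_h_fwd.py item (the helper name is illustrative, not the file's):

def num_chunks(S: int, chunk_size: int, require_exact: bool = False) -> int:
    if require_exact and S % chunk_size != 0:
        raise ValueError(f"S={S} must be divisible by chunk_size={chunk_size}")
    return (S + chunk_size - 1) // chunk_size  # ceiling division


assert num_chunks(1024, 64) == 16
assert num_chunks(1000, 64) == 16  # floor division would give 15 and under-allocate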

In @examples/KDA/chunk_o.py:
- Around line 132-134: The loop uses the wrong inner index variable and shape:
in the DK loop the parallel iterator is "for i_s, i_v in T.Parallel(block_S,
block_DV)" but Q_shared and GQ_shared are shaped (block_S, block_DK) and the DK
index is named i_k; change the parallel iterator to iterate over block_DK (e.g.,
T.Parallel(block_S, block_DK)) and rename the inner loop variable to i_k, then
replace occurrences of i_v inside the body (Q_shared[i_s, i_v], GQ_shared[i_s,
i_v], GK_shared[i_s, i_v]) with i_k so indexing uses the DK dimension
consistently.
- Around line 247-250: The call to run_test from main() passes TileLang type
objects (e.g., T.bfloat16, T.float32) into parameters
input_dtype/output_dtype/accum_dtype/gate_dtype while run_test expects strings
(it calls getattr(torch, dtype)); change the arguments in main() to pass dtype
names as strings (e.g., "bfloat16", "float32") for input_dtype, output_dtype,
accum_dtype and gate_dtype so getattr(torch, dtype) works as intended, or
alternatively update run_test to accept TileLang types and convert them to the
corresponding torch dtype names before calling getattr.

In @examples/KDA/wy_fast_bwd.py:
- Around line 310-314: The correctness checks are disabled because the
compare_tensors calls are commented out; either re-enable those calls (uncomment
compare_tensors for dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang,
dg_tilelang to validate against dA_ref, dk_ref, dv_ref, dbeta_ref, dg_ref) or
add a runtime flag (e.g., validate=True/--validate) that conditionally runs
compare_tensors so CI can enable validation while benchmarks can disable it;
alternatively move these compare_tensors calls into a dedicated test that
imports the same variables and runs validation.
- Around line 341-345: main() is passing TileLang type objects like T.float32
into run_test, but run_test uses getattr(torch, input_dtype) expecting a string;
change either main() or run_test so run_test receives a torch dtype or a string:
in main(), pass a string like "float32" (e.g., input_dtype="float32") or pass
torch.float32 directly, or update run_test to detect TileLang types (e.g.,
T.float32) and convert them to torch dtypes by using the TileLang type's name or
mapping (e.g., input_dtype = getattr(torch, input_dtype.name) or via a small
dict), making sure to apply the same fix for output_dtype, accum_dtype,
gate_dtype, and state_dtype and to locate the conversion logic near the run_test
signature and any getattr(torch, ...) calls.
- Around line 295-307: The run_test call passes raw string dtype names into
tilelang_wy_fast_bwd, but tilelang_wy_fast_bwd expects TileLang dtype objects
(e.g., T.float32); before calling tilelang_wy_fast_bwd in run_test, convert each
string dtype parameter to the corresponding TileLang dtype using getattr(T,
input_dtype) and likewise for output_dtype, accum_dtype, gate_dtype, and
state_dtype so the arguments passed to tilelang_wy_fast_bwd are TileLang dtype
objects instead of strings.

In @examples/KDA/wy_fast.py:
- Around line 54-76: The function tilelang_recompute_w_u_fwd currently accepts
the flag use_qg and declares an output QG but never populates it; either remove
the unused parameter and QG output from the function signature and from the
@tilelang.jit out_idx list, or implement the QG computation/write-path when
use_qg is true (e.g., produce and store QG inside the kernel) and keep the flag;
if you intentionally defer implementation, add a clear TODO comment in
tilelang_recompute_w_u_fwd explaining that QG is unimplemented and ensure the
@tilelang.jit out_idx and any downstream callers are kept consistent.
- Around line 179-180: The variable use_qg is mistakenly assigned as a
one-element tuple (False,) rather than a boolean, causing wrong truthy checks;
change the assignment of use_qg to the boolean False (remove the trailing comma)
so any conditionals or kernel parameters expecting a bool receive a proper
boolean value, and verify usages of use_qg (and related flags like use_kg)
continue to treat it as a boolean.
- Around line 237-240: The test is passing TileLang dtype constants (e.g.,
T.bfloat16) to run_test which expects string dtype names (it uses getattr(torch,
input_dtype)), causing an AttributeError; update the main() invocation to pass
string dtype names like "bfloat16", "float32" (e.g., replace
input_dtype=T.bfloat16 with input_dtype="bfloat16", output_dtype="bfloat16",
gate_dtype="float32", accum_dtype="float32") so getattr(torch, ...) resolves
correctly to torch dtypes for the run_test call.
🧹 Nitpick comments (25)
examples/KDA/chunk_bwd_gla_dA.py (7)

4-4: Remove unused sys import.

The sys module is imported but never used. The noqa: F401 directive is also unnecessary.

Suggested fix
-import sys  # noqa: F401

14-25: Unused chunk_size parameter.

The chunk_size parameter is declared but never used. If it's kept for API consistency with other functions, consider prefixing with underscore (_chunk_size) to indicate it's intentionally unused.

Suggested fix
 def prepare_input(
     B,
     S,
     H,
     DV,
-    chunk_size,
+    _chunk_size,
     input_dtype,
     do_dtype,
 ):

28-37: Unused DV parameter.

Similar to prepare_input, the DV parameter is unused. Consider removing or prefixing with underscore.

Suggested fix
 def prepare_output(
     B,
     S,
     H,
-    DV,
+    _DV,
     chunk_size,
     d_type,
 ):

83-91: Remove commented-out code.

Dead code (commented allocation and layout annotations) should be removed to improve maintainability. If these are needed for reference, consider documenting them elsewhere or in a commit message.

Suggested fix
             dA_fragment = T.alloc_fragment((block_S, block_S), dtype=T.float32)
-            # dA_shared = T.alloc_shared((block_S, block_S), dtype=da_dtype)
-
-            # T.annotate_layout(
-            #     {
-            #         do_shared: tilelang.layout.make_swizzled_layout(do_shared),
-            #         V_shared: tilelang.layout.make_swizzled_layout(V_shared),
-            #     }
-            # )
-            # T.use_swizzle(10)

             T.clear(dA_fragment)

99-101: Translate comment to English and remove commented-out code.

The comment "下三角矩阵" should be translated to "lower triangular mask" for consistency. Also remove the commented-out copy on line 100.

Suggested fix
             for i_s1, i_s2 in T.Parallel(block_S, block_S):
-                dA_fragment[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, dA_fragment[i_s1, i_s2] * scale, 0.0)  # 下三角矩阵
-            # T.copy(dA_fragment, dA_shared)
+                dA_fragment[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, dA_fragment[i_s1, i_s2] * scale, 0.0)  # lower triangular mask
             T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, 0:block_S])

106-117: Unused DK parameter.

The DK parameter is declared but never used in run_test. Remove it or prefix with underscore if kept for signature consistency.

Suggested fix
 def run_test(
     B,
     S,
     H,
-    DK,
     DV,
     scale,
     input_dtype,
     do_dtype,
     da_dtype,
     chunk_size,
 ):

And update the call in main():

     run_test(
         B=1,
         S=1024 * 8,  # 32768
         H=64,
-        DK=128,
         DV=128,

119-119: Remove debug print statement.

This debug print should be removed before merging, or converted to proper logging if needed for diagnostics.

Suggested fix
-    print(DO.dtype, V_new.dtype)
examples/KDA/chunk_delta_h_fwd.py (4)

3-3: Remove unused sys import.

The sys import and noqa directive are no longer needed since the sys.path.insert on line 10 is commented out. This is also flagged by static analysis (RUF100).

Suggested fix
-import sys  # noqa: F401
 import tilelang

23-45: Unused parameters output_dtype and accum_dtype.

These parameters are defined but never used in the function body. If they're placeholders for future use, consider documenting this intent; otherwise, remove them to avoid confusion.

Suggested fix (if not needed)
 def prepare_input(
     B,
     S,
     H,
     DK,
     DV,
     chunk_size,
     input_dtype,
-    output_dtype,
-    accum_dtype,
     gate_dtype,
 ):

98-98: Unused block_DK parameter in kernel configuration.

The block_DK parameter is included in the autotuning configs but never used in the kernel. The kernel allocates tensors with full DK dimension (e.g., b_h_shared at line 129, K_shared at line 137) rather than tiling over it.

If tiling over DK is intended for performance or memory optimization, the kernel logic needs updating. Otherwise, consider removing block_DK from get_configs() and the function signature to reduce the autotuning search space (currently 54 configs, would be 18 without block_DK).
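
Purely to illustrate the search-space arithmetic (the candidate values below are invented, chosen only so the counts reproduce the 54 vs. 18 mentioned above):

from itertools import product

block_DK = [32, 64, 128]      # hypothetical candidates
block_DV = [32, 64, 128]
threads = [128, 256]
num_stages = [1, 2, 3]

with_dk = list(product(block_DK, block_DV, threads, num_stages))
without_dk = list(product(block_DV, threads, num_stages))
assert len(with_dk) == 54
assert len(without_dk) == 18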


213-216: Unused kernel config parameters in run_test.

The parameters block_DK, block_DV, threads, and num_stages are accepted but never used since the @autotune decorator automatically selects configurations.

If manual config selection is needed for debugging, consider adding support to bypass autotuning. Otherwise, remove these parameters to avoid confusion.

examples/KDA/wy_fast_bwd.py (5)

3-3: Remove unused sys import.

The sys module is imported but never used in this file. The noqa: F401 directive is also unnecessary since there's no valid reason to suppress the warning.

Suggested fix
-import sys  # noqa: F401
-

17-42: Unused function parameters and minor efficiency improvement.

Parameters output_dtype, accum_dtype, and state_dtype are declared but never used. Consider prefixing with underscore or removing if they serve no purpose.

Also, creating tensors directly on CUDA device is slightly more efficient than creating on CPU and then moving:

Suggested improvement
 def prepare_input(
     B,
     S,
     H,
     DK,
     DV,
     chunk_size,
     input_dtype,
-    output_dtype,
-    accum_dtype,
+    _output_dtype,
+    _accum_dtype,
     gate_dtype,
-    state_dtype,
+    _state_dtype,
 ):
     BS = chunk_size
-    K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
+    K = torch.randn(B, S, H, DK, dtype=input_dtype, device="cuda")

44-61: Unused parameters in prepare_output.

Similar to prepare_input, parameters chunk_size and state_dtype are declared but unused. Consider prefixing with underscore for consistency.


154-170: Consider removing commented-out code.

Lines 154, 165, and 170 contain commented-out allocations and a disabled swizzle call. If these are no longer needed, consider removing them to keep the code clean.


277-279: Unused output tensors from prepare_output.

The tensors allocated here (dk_tilelang, dv_tilelang, etc.) are immediately overwritten by the kernel call at line 308. This allocation serves no purpose since the autotuned kernel returns new tensors via out_idx.

Suggested fix
-    dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang, dA_tilelang = prepare_output(
-        B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
-    )
-
examples/KDA/wy_fast.py (2)

15-22: Unused output_dtype parameter.

The output_dtype parameter is accepted but never used in this function. If it's for API consistency with other functions, consider prefixing with underscore or adding a comment.


174-177: Unused function parameters in run_test.

The parameters block_DK, block_DV, threads, and num_stages are accepted but never passed to the kernel. The kernel relies on autotuning to select these values. Either remove these parameters or pass them to override autotuning.

examples/KDA/chunk_delta_bwd.py (1)

265-267: Reference results are computed but never used.

dh_ref, dh0_ref, and dv2_ref are unpacked but never compared against TileLang outputs since the compare_tensors calls are commented out. Consider prefixing with underscore if intentionally unused, or enable the validation.

examples/KDA/chunk_bwd_dv.py (2)

31-40: Unused chunk_size parameter in prepare_output.

The chunk_size parameter is accepted but not used. Consider removing it or prefixing with underscore if kept for API consistency.


105-116: Unused scale parameter in run_test.

The scale parameter is accepted but never used in the function. It's passed in main() but has no effect.

♻️ Proposed fix - remove unused parameter
 def run_test(
     B,
     S,
     H,
     DK,
     DV,
-    scale,
     input_dtype,
     do_dtype,
     output_dtype,
     chunk_size,
 ):

And in main():

     run_test(
         B=1,
         S=1024 * 8,  # 32768
         H=64,
         DK=128,
         DV=128,
-        scale=1.0,
         input_dtype="bfloat16",
examples/KDA/chunk_bwd_dqkwg.py (1)

212-219: Many unused parameters in run_test.

Parameters use_gk, use_initial_state, store_final_state, save_new_value, block_DK, block_DV, threads, and num_stages are accepted but never used. Consider removing them or using them to configure the kernel/test behavior.

examples/KDA/chunk_o.py (3)

142-146: Consider removing or translating Chinese comment.

The comment "改成下面的代码为什么就错了" (translated: "Why does it fail when changed to the code below?") should either be removed, translated to English, or replaced with a proper explanation for maintainability.


161-183: Duplicate do_bench implementation.

This file defines its own do_bench function while other files (wy_fast.py, chunk_delta_bwd.py, etc.) import it from test_utils. Consider using the shared implementation for consistency.

♻️ Proposed fix
 from FLA_KDA.fla_chunk_o import chunk_gla_fwd_o_gk
-from test_utils import compare_tensors
+from test_utils import compare_tensors, do_bench

...

-def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
-    """
-    Do benchmark for a function.
-    """
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    for _ in range(warmup):
-        fn(*args, **kwargs)
-
-    torch.cuda.synchronize()
-    for i in range(rep):
-        start_event[i].record()
-        fn(*args, **kwargs)
-        end_event[i].record()
-    torch.cuda.synchronize()
-
-    # Record clocks
-    times = torch.tensor(
-        [s.elapsed_time(e) for s, e in zip(start_event, end_event)],
-        dtype=torch.float,
-    )
-
-    return times.mean().item()

197-200: Unused parameters in run_test.

Parameters block_DK, block_DV, threads, and num_stages are accepted but not used. The kernel relies on autotuning. Consider removing these or passing them to override autotuning behavior.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 71f0fc9 and 68bfdc1.

📒 Files selected for processing (8)
  • examples/KDA/chunk_bwd_dqkwg.py
  • examples/KDA/chunk_bwd_dv.py
  • examples/KDA/chunk_bwd_gla_dA.py
  • examples/KDA/chunk_delta_bwd.py
  • examples/KDA/chunk_delta_h_fwd.py
  • examples/KDA/chunk_o.py
  • examples/KDA/wy_fast.py
  • examples/KDA/wy_fast_bwd.py
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • examples/KDA/chunk_bwd_dqkwg.py
  • examples/KDA/chunk_delta_bwd.py
  • examples/KDA/wy_fast.py
  • examples/KDA/chunk_delta_h_fwd.py
  • examples/KDA/wy_fast_bwd.py
📚 Learning: 2026-01-06T05:20:51.649Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.

Applied to files:

  • examples/KDA/chunk_delta_bwd.py
  • examples/KDA/wy_fast.py
  • examples/KDA/chunk_delta_h_fwd.py
🧬 Code graph analysis (2)
examples/KDA/wy_fast.py (3)
examples/KDA/FLA_KDA/fla_wy_fast.py (1)
  • recompute_w_u_fwd (210-254)
examples/KDA/test_utils.py (1)
  • compare_tensors (43-83)
tilelang/language/allocate.py (2)
  • alloc_shared (39-54)
  • alloc_fragment (71-82)
examples/KDA/chunk_delta_h_fwd.py (2)
examples/KDA/FLA_KDA/fla_chunk_delta.py (1)
  • chunk_gated_delta_rule_fwd_h (470-521)
examples/KDA/FLA_KDA/cumsum.py (1)
  • chunk_local_cumsum (431-469)
🪛 Ruff (0.14.10)
examples/KDA/chunk_bwd_dqkwg.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


43-43: Unused function argument: DV

(ARG001)


44-44: Unused function argument: chunk_size

(ARG001)


45-45: Unused function argument: qk_dtype

(ARG001)


212-212: Unused function argument: use_gk

(ARG001)


213-213: Unused function argument: use_initial_state

(ARG001)


214-214: Unused function argument: store_final_state

(ARG001)


215-215: Unused function argument: save_new_value

(ARG001)


216-216: Unused function argument: block_DK

(ARG001)


217-217: Unused function argument: block_DV

(ARG001)


218-218: Unused function argument: threads

(ARG001)


219-219: Unused function argument: num_stages

(ARG001)


223-223: Unpacked variable dq_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


223-223: Unpacked variable dk_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


223-223: Unpacked variable dw_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


223-223: Unpacked variable dg_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


240-240: Unpacked variable dq is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


240-240: Unpacked variable dk is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


240-240: Unpacked variable dw is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


240-240: Unpacked variable dg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/chunk_delta_bwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


28-28: Unused function argument: output_dtype

(ARG001)


29-29: Unused function argument: accum_dtype

(ARG001)


31-31: Unused function argument: state_dtype

(ARG001)


58-58: Unused function argument: gate_dtype

(ARG001)


128-128: Unused function argument: h0

(ARG001)


239-239: Unused function argument: block_DV

(ARG001)


240-240: Unused function argument: threads

(ARG001)


241-241: Unused function argument: num_stages

(ARG001)


242-242: Unused function argument: use_torch

(ARG001)


265-265: Unpacked variable dh_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


265-265: Unpacked variable dh0_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


265-265: Unpacked variable dv2_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


288-288: Unpacked variable dh_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


288-288: Unpacked variable dh0_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


288-288: Unpacked variable dv2_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/wy_fast.py

6-6: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


15-15: Unused function argument: output_dtype

(ARG001)


68-68: Unused function argument: use_qg

(ARG001)


94-94: Unused function argument: QG

(ARG001)


174-174: Unused function argument: block_DK

(ARG001)


175-175: Unused function argument: block_DV

(ARG001)


176-176: Unused function argument: threads

(ARG001)


177-177: Unused function argument: num_stages

(ARG001)


184-184: Unpacked variable QG_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


185-185: Unpacked variable QG_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/chunk_delta_h_fwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


31-31: Unused function argument: output_dtype

(ARG001)


32-32: Unused function argument: accum_dtype

(ARG001)


98-98: Unused function argument: block_DK

(ARG001)


213-213: Unused function argument: block_DK

(ARG001)


214-214: Unused function argument: block_DV

(ARG001)


215-215: Unused function argument: threads

(ARG001)


216-216: Unused function argument: num_stages

(ARG001)

examples/KDA/chunk_o.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


22-22: Unused function argument: output_dtype

(ARG001)


23-23: Unused function argument: accum_dtype

(ARG001)


39-39: Unused function argument: DK

(ARG001)


41-41: Unused function argument: chunk_size

(ARG001)


44-44: Ambiguous variable name: O

(E741)


99-99: Ambiguous variable name: O

(E741)


197-197: Unused function argument: block_DK

(ARG001)


198-198: Unused function argument: block_DV

(ARG001)


199-199: Unused function argument: threads

(ARG001)


200-200: Unused function argument: num_stages

(ARG001)

examples/KDA/chunk_bwd_gla_dA.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


19-19: Unused function argument: chunk_size

(ARG001)


32-32: Unused function argument: DV

(ARG001)


110-110: Unused function argument: DK

(ARG001)

examples/KDA/wy_fast_bwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


25-25: Unused function argument: output_dtype

(ARG001)


26-26: Unused function argument: accum_dtype

(ARG001)


28-28: Unused function argument: state_dtype

(ARG001)


50-50: Unused function argument: chunk_size

(ARG001)


53-53: Unused function argument: state_dtype

(ARG001)


92-92: Unused function argument: state_dtype

(ARG001)


258-258: Unused function argument: block_DK

(ARG001)


259-259: Unused function argument: block_DV

(ARG001)


260-260: Unused function argument: threads

(ARG001)


261-261: Unused function argument: num_stages

(ARG001)


282-282: Unpacked variable dk_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


282-282: Unpacked variable dv_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


282-282: Unpacked variable dbeta_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


282-282: Unpacked variable dg_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


282-282: Unpacked variable dA_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


308-308: Unpacked variable dA_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


308-308: Unpacked variable dk_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


308-308: Unpacked variable dv_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


308-308: Unpacked variable dbeta_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


308-308: Unpacked variable dg_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/chunk_bwd_dv.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


36-36: Unused function argument: chunk_size

(ARG001)


111-111: Unused function argument: scale

(ARG001)

🔇 Additional comments (19)
examples/KDA/chunk_bwd_gla_dA.py (2)

80-81: Verify dtype for V_shared allocation.

do_shared uses do_dtype (correct), but V_shared also uses do_dtype while the source tensor V has input_dtype. If input_dtype and do_dtype differ, an implicit cast occurs during the copy. This may be intentional for computation precision alignment, but consider using input_dtype or adding a comment explaining the design choice.

Suggested fix if intentional cast
             do_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
-            V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
+            V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)  # Cast V to do_dtype for GEMM precision

142-158: LGTM!

The main entry point is well-structured with reasonable default parameters for testing.

examples/KDA/chunk_delta_h_fwd.py (2)

65-75: LGTM!

The configuration generation for autotuning is well-structured with reasonable parameter ranges.


158-192: Kernel logic appears correct.

The pipelined loop correctly implements the chunked forward pass:

  • Stores the previous state to h before computing the current chunk
  • Applies the gated recurrence with exp2 (matching use_exp2=True in reference)
  • Accumulates via GEMM operations

The commented-out swizzle annotations (lines 140-150) are appropriately disabled per the commit message "Remove redundant swizzle as it can be done automatically."

examples/KDA/wy_fast_bwd.py (2)

63-74: LGTM!

The autotuning configuration space is well-defined with reasonable parameter ranges.


220-220: This concern is based on an incorrect assumption about the default value of clear_accum. According to the T.gemm definition in tilelang/language/gemm_op.py, the default value of clear_accum is False, not True. Therefore, line 220 does not inadvertently clear dA_fragment. Both line 195 (DK loop) and line 220 (DV loop) accumulate contributions into dA_fragment as intended, with no bug present.

Likely an incorrect or invalid review comment.

examples/KDA/wy_fast.py (2)

1-12: LGTM - Imports and setup are appropriate.

The imports and random seed initialization follow the pattern established across the KDA examples.


143-158: Kernel logic for W and KG computation is correct.

The pipelined loop correctly computes W with gating (T.exp2) and conditionally computes KG when use_kg is enabled. The use of shared memory and fragments follows TileLang patterns.

examples/KDA/chunk_delta_bwd.py (4)

1-17: LGTM - Imports and setup are appropriate.

Imports include necessary TileLang, PyTorch, and reference implementations. Random seed is set for reproducibility.


171-176: Good handling of final state gradient initialization.

The conditional initialization of b_dh_fragment from dht or clearing it based on use_final_state_gradient is correct.


177-218: Backward pass iteration logic is well-structured.

The reverse iteration (i_s_inv) correctly processes chunks in reverse order for the backward pass. The gradient updates for dh, dv, and final dh0 follow the expected pattern.


125-136: Review comment is incorrect.

The reviewer mistakenly referenced the forward kernel. In the FLA implementation, h0 is used in the forward pass (chunk_gated_delta_rule_fwd_kernel_h_blockdim64, lines 97-107) to initialize hidden state, not in the backward pass. The backward kernel (chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64) does not accept h0 as a parameter—it only computes and outputs dh0 (the gradient with respect to the initial state).

The tilelang implementation correctly matches this behavior: h0 is accepted as a parameter (for API consistency) but is not used during backward computation. This is the correct approach since the backward pass computes gradients, not reads the forward state.

examples/KDA/chunk_bwd_dv.py (4)

1-11: LGTM - Imports and setup follow established patterns.


94-95: Lower triangular mask implementation is correct.

The masking T.if_then_else(i_s1 >= i_s2, A_shared[i_s1, i_s2], 0.0) correctly implements a lower triangular mask where elements above the diagonal are zeroed.


96-100: Pipelined DV computation with proper staging.

The pipelined loop over DV dimension with GEMM (transpose_A) and copy operations follows correct TileLang patterns.


140-152: Correct dtype passing pattern.

Unlike other files, this one correctly passes string dtypes ("bfloat16", "float32") to run_test, which is compatible with the getattr(torch, dtype) usage.
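
For reference, a small self-contained illustration of why the string form is required by this pattern (torch.bfloat16 stands in for any non-string value; the exact exception raised for a TileLang dtype object depends on its type):

import torch

# getattr(torch, <str>) resolves dtype names to torch dtype objects:
assert getattr(torch, "bfloat16") is torch.bfloat16
assert getattr(torch, "float32") is torch.float32

# Any non-string argument is rejected, since getattr requires a string attribute name:
try:
    getattr(torch, torch.bfloat16)
except TypeError as exc:
    print("non-string dtype rejected:", exc)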

examples/KDA/chunk_bwd_dqkwg.py (2)

1-11: LGTM - Imports and setup follow established patterns.


106-198: Complex kernel logic with proper gradient computation.

The kernel correctly handles:

  • Chunk-wise gradient computation with proper indexing
  • Gating with T.exp2 for GK
  • Reduction operations for dgkn_fragment
  • Conditional accumulation for the last token in chunk

The mathematical operations appear to follow the expected backward pass pattern.

examples/KDA/chunk_o.py (1)

1-11: LGTM - Imports and setup follow established patterns.

Comment on lines 38 to 52
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
qk_dtype,
gate_dtype,
):
dq = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
dk = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
dw = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
return dq, dk, dw, dg

⚠️ Potential issue | 🟡 Minor

qk_dtype parameter is ignored - outputs use hardcoded torch.float32.

The qk_dtype parameter is accepted but dq and dk are always allocated with dtype=torch.float32 regardless of the passed value. Either use the parameter or remove it.

♻️ Proposed fix
 def prepare_output(
     B,
     S,
     H,
     DK,
     DV,
     chunk_size,
     qk_dtype,
     gate_dtype,
 ):
-    dq = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
-    dk = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
+    dq = torch.randn(B, S, H, DK, dtype=qk_dtype).cuda()
+    dk = torch.randn(B, S, H, DK, dtype=qk_dtype).cuda()
     dw = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
     dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
     return dq, dk, dw, dg
🧰 Tools
🪛 Ruff (0.14.10)

43-43: Unused function argument: DV

(ARG001)


44-44: Unused function argument: chunk_size

(ARG001)


45-45: Unused function argument: qk_dtype

(ARG001)

🤖 Prompt for AI Agents
In @examples/KDA/chunk_bwd_dqkwg.py around lines 38 - 52, the prepare_output
function ignores the qk_dtype parameter by hardcoding dq and dk to
torch.float32; update dq and dk allocations in prepare_output to use
dtype=qk_dtype (or convert them to qk_dtype after creation) so the passed
qk_dtype controls their dtype, leaving dw and dg using gate_dtype as before;
locate prepare_output and replace the dtype=torch.float32 occurrences for dq and
dk with dtype=qk_dtype.

Comment on lines +341 to +345
input_dtype=T.float32,
output_dtype=T.float32,
accum_dtype=T.float32,
gate_dtype=T.float32,
state_dtype=T.float32,

⚠️ Potential issue | 🔴 Critical

Type mismatch: T.float32 is not a valid argument for getattr(torch, ...).

The main() function passes TileLang types (e.g., T.float32) to run_test, but run_test uses getattr(torch, input_dtype) which expects a string like "float32". This will cause an AttributeError at runtime.

Suggested fix
     run_test(
         B=1,
         S=32768,
         H=8,
         DK=DK,
         DV=DV,
-        input_dtype=T.float32,
-        output_dtype=T.float32,
-        accum_dtype=T.float32,
-        gate_dtype=T.float32,
-        state_dtype=T.float32,
+        input_dtype="float32",
+        output_dtype="float32",
+        accum_dtype="float32",
+        gate_dtype="float32",
+        state_dtype="float32",
         chunk_size=64,
🤖 Prompt for AI Agents
In @examples/KDA/wy_fast_bwd.py around lines 341 - 345, main() is passing
TileLang type objects like T.float32 into run_test, but run_test uses
getattr(torch, input_dtype) expecting a string; change either main() or run_test
so run_test receives a torch dtype or a string: in main(), pass a string like
"float32" (e.g., input_dtype="float32") or pass torch.float32 directly, or
update run_test to detect TileLang types (e.g., T.float32) and convert them to
torch dtypes by using the TileLang type's name or mapping (e.g., input_dtype =
getattr(torch, input_dtype.name) or via a small dict), making sure to apply the
same fix for output_dtype, accum_dtype, gate_dtype, and state_dtype and to
locate the conversion logic near the run_test signature and any getattr(torch,
...) calls.

Comment on lines +54 to +76
@autotune(configs=get_configs(), warmup=3, rep=5)
@tilelang.jit(out_idx=[-4, -3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def tilelang_recompute_w_u_fwd(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
chunk_size,
use_qg,
use_kg,
# kernel config
block_S=64,
block_DK=32,
block_DV=32,
threads=128,
num_stages=0,
):

⚠️ Potential issue | 🟡 Minor

Unused use_qg parameter - QG tensor is never populated.

The use_qg parameter and QG tensor are declared but the kernel never writes to QG. If QG computation is intentionally omitted, consider either:

  1. Removing the parameter and output tensor
  2. Adding a TODO comment explaining future implementation

The kernel signature declares QG as output but the body never computes or writes to it.

🧰 Tools
🪛 Ruff (0.14.10)

68-68: Unused function argument: use_qg

(ARG001)

🤖 Prompt for AI Agents
In @examples/KDA/wy_fast.py around lines 54 - 76, the function
tilelang_recompute_w_u_fwd currently accepts the flag use_qg and declares an
output QG but never populates it; either remove the unused parameter and QG
output from the function signature and from the @tilelang.jit out_idx list, or
implement the QG computation/write-path when use_qg is true (e.g., produce and
store QG inside the kernel) and keep the flag; if you intentionally defer
implementation, add a clear TODO comment in tilelang_recompute_w_u_fwd
explaining that QG is unimplemented and ensure the @tilelang.jit out_idx and any
downstream callers are kept consistent.

Comment on lines +237 to +240
input_dtype=T.bfloat16,
output_dtype=T.bfloat16,
gate_dtype=T.float32,
accum_dtype=T.float32,

⚠️ Potential issue | 🔴 Critical

Type mismatch: passing TileLang types where string dtypes are expected.

run_test expects string dtype parameters (e.g., "bfloat16") based on getattr(torch, input_dtype) usage, but main() passes TileLang types like T.bfloat16. This will cause an AttributeError since getattr(torch, T.bfloat16) is invalid.

🐛 Proposed fix - use string dtypes
     run_test(
         B=1,
         S=8192,
         H=64,
         DK=128,
         DV=128,
         chunk_size=64,
-        input_dtype=T.bfloat16,
-        output_dtype=T.bfloat16,
-        gate_dtype=T.float32,
-        accum_dtype=T.float32,
+        input_dtype="bfloat16",
+        output_dtype="bfloat16",
+        gate_dtype="float32",
+        accum_dtype="float32",
         block_DK=64,
         block_DV=32,
         threads=128,
         num_stages=3,
     )
🤖 Prompt for AI Agents
In @examples/KDA/wy_fast.py around lines 237 - 240, the test is passing TileLang
dtype constants (e.g., T.bfloat16) to run_test which expects string dtype names
(it uses getattr(torch, input_dtype)), causing an AttributeError; update the
main() invocation to pass string dtype names like "bfloat16", "float32" (e.g.,
replace input_dtype=T.bfloat16 with input_dtype="bfloat16",
output_dtype="bfloat16", gate_dtype="float32", accum_dtype="float32") so
getattr(torch, ...) resolves correctly to torch dtypes for the run_test call.

@wfloveiu

Hi! Thank you for contribution! A small question: In your implementation of KDA kernels, there are two operators chunk_bwd_intra.py and chunk_bwd_intra_op.py that share most part of code. It seems that they have totally same functionality. Could you check them and only keep one operator?

Thank you for your review, I will check and fix them.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@examples/KDA/chunk_bwd_intra.py`:
- Around line 478-493: main() is passing TileLang types (e.g. T.float32) into
run_test(), but run_test() (which uses getattr(torch, input_dtype) inside the
run_test function) expects string names like "float32", causing an
AttributeError; fix by changing the arguments in main() to string dtype names
(e.g. "float32" for input_dtype/output_dtype/accum_dtype/gate_dtype/state_dtype)
so run_test() can resolve them via getattr(torch, ...). Ensure the update
references the main() call site and the run_test(...) parameter names so all
dtype args are passed as strings.

In `@examples/KDA/chunk_delta_h_fwd.py`:
- Around line 178-182: The access GK[bb, (i_s + 1) * block_S - 1, bh, :] can go
out of bounds when S is not divisible by block_S; either clamp the computed
index or enforce divisibility. Fix option A: replace the raw index with a
clamped value (e.g., end_idx = min((i_s + 1) * block_S - 1, S - 1)) and use
GK[bb, end_idx, bh, :] when setting GK_last_shared in the use_gk branch; or
option B: add a runtime assertion in run_test or prepare_input that S % block_S
== 0 to guarantee no partial final block. Ensure references to use_gk,
GK_last_shared, i_s, block_S and S are adjusted accordingly.
♻️ Duplicate comments (7)
examples/KDA/wy_fast.py (3)

68-68: Unused use_qg parameter and QG output tensor.

The use_qg parameter is accepted but never used, and the QG tensor is declared as output but never written to. Either implement the QG computation path or remove these unused declarations.

Also applies to: 94-94


222-222: Typo: "tritron" should be "triton".

-    print("tritron time:", triton_time)
+    print("triton time:", triton_time)

237-240: Type mismatch: passing TileLang types where string dtypes are expected.

run_test uses getattr(torch, input_dtype) which expects string arguments like "bfloat16", but main() passes TileLang type objects (T.bfloat16, T.float32). This will cause an AttributeError at runtime.

     run_test(
         B=1,
         S=8192,
         H=64,
         DK=128,
         DV=128,
         chunk_size=64,
-        input_dtype=T.bfloat16,
-        output_dtype=T.bfloat16,
-        gate_dtype=T.float32,
-        accum_dtype=T.float32,
+        input_dtype="bfloat16",
+        output_dtype="bfloat16",
+        gate_dtype="float32",
+        accum_dtype="float32",
         block_DK=64,
         block_DV=32,
         threads=128,
         num_stages=3,
     )
examples/KDA/chunk_delta_bwd.py (1)

121-136: Unused h0 parameter in kernel signature.

The kernel declares h0 but never uses it. Per the use_initial_state logic at line 217-218, only dh0 (the gradient w.r.t. initial state) is written. If h0 is not needed for this backward computation, consider removing it from the kernel signature to avoid confusion.

examples/KDA/chunk_bwd_dqkwg.py (2)

47-48: dq and dk always use torch.float32 regardless of parameters.

The prepare_output function ignores its parameters and hardcodes torch.float32 for dq and dk. While this may be intentional (gradients often need higher precision), it contradicts the function signature which accepts gate_dtype.


122-123: Gate buffers G_shared and Gn_shared allocate with input_dtype instead of gate_dtype.

These buffers store gate values (sourced from the G tensor at line 95 which uses gate_dtype), but are allocated with input_dtype. When input_dtype differs from gate_dtype (e.g., bfloat16 vs float32), this causes precision loss and dtype mismatch.

Suggested fix
-            G_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)  # chunk G
-            Gn_shared = T.alloc_shared((block_DK), dtype=input_dtype)  # chunk last token G
+            G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype)  # chunk G
+            Gn_shared = T.alloc_shared((block_DK,), dtype=gate_dtype)  # chunk last token G
examples/KDA/wy_fast_bwd.py (1)

332-351: Critical: main() passes TileLang types but run_test expects strings.

main() passes TileLang type objects (T.float32) to run_test, but run_test uses getattr(torch, input_dtype) at lines 270-274 which expects string arguments like "float32". This will cause an AttributeError at runtime.

🐛 Proposed fix
 def main():
     DK = 128
     DV = 128
     run_test(
         B=1,
         S=32768,
         H=8,
         DK=DK,
         DV=DV,
-        input_dtype=T.float32,
-        output_dtype=T.float32,
-        accum_dtype=T.float32,
-        gate_dtype=T.float32,
-        state_dtype=T.float32,
+        input_dtype="float32",
+        output_dtype="float32",
+        accum_dtype="float32",
+        gate_dtype="float32",
+        state_dtype="float32",
         chunk_size=64,
         block_DK=32,
         block_DV=32,
         threads=128,
         num_stages=0,
     )
🧹 Nitpick comments (16)
examples/KDA/wy_fast.py (5)

6-6: Remove unused sys import.

The sys module is imported but never used, and the noqa directive is unnecessary.

-import sys  # noqa: F401

15-22: Unused output_dtype parameter.

The output_dtype parameter is declared but never used in the function body. Consider removing it or prefixing with underscore if reserved for future use.

-def prepare_input(B, S, H, DK, DV, chunk_size, input_dtype, output_dtype, gate_dtype=torch.float32):
+def prepare_input(B, S, H, DK, DV, chunk_size, input_dtype, gate_dtype=torch.float32):

112-123: Remove or document commented-out layout annotations.

This block of commented code for swizzled layouts appears to be either dead code or work-in-progress. Consider removing it to reduce clutter, or add a TODO/NOTE explaining why it's retained.


174-177: Unused kernel configuration parameters.

The block_DK, block_DV, threads, and num_stages parameters are accepted but never passed to the kernel (lines 202-216). If these are intended for manual configuration (bypassing autotune), pass them to the kernel call. Otherwise, remove them from the function signature.

 def run_test(
     B,
     S,
     H,
     DK,
     DV,
     chunk_size,
     input_dtype,
     output_dtype,
     gate_dtype,
     accum_dtype,
-    block_DK,
-    block_DV,
-    threads,
-    num_stages,
 ):

184-185: Prefix unused unpacked variables with underscore.

QG_ref and QG_tilelang are unpacked but never used. Prefix them with underscore to indicate intentional disuse.

-    W_ref, U_ref, QG_ref, KG_ref = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype))
-    W_tilelang, U_tilelang, QG_tilelang, KG_tilelang = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype))
+    W_ref, U_ref, _QG_ref, KG_ref = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype))
+    W_tilelang, U_tilelang, _QG_tilelang, KG_tilelang = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype))
examples/KDA/chunk_delta_h_fwd.py (4)

3-3: Remove unused sys import.

The sys module is imported but never used (the sys.path.insert on line 10 is commented out). The # noqa: F401 directive is also unnecessary.

Suggested fix
-import sys  # noqa: F401
 import tilelang

23-45: Unused parameters output_dtype and accum_dtype.

These parameters are declared but never used in the function body. Consider removing them if not needed for API consistency, or prefixing with underscore to indicate intentional non-use.

Suggested fix
 def prepare_input(
     B,
     S,
     H,
     DK,
     DV,
     chunk_size,
     input_dtype,
-    output_dtype,
-    accum_dtype,
+    _output_dtype,
+    _accum_dtype,
     gate_dtype,
 ):

98-98: Unused block_DK parameter in autotuned kernel.

The block_DK parameter is included in autotune configs (line 68) but is never used in the kernel—DK is used directly for all allocations and operations. Either remove block_DK from configs, or use it to tile the DK dimension for potential performance gains.


213-216: Unused kernel configuration parameters.

The parameters block_DK, block_DV, threads, and num_stages are accepted but never used—the kernel relies on autotuning. Consider removing them from the function signature or passing them to the kernel call if manual config override is desired.

Suggested fix
 def run_test(
     B,
     S,
     H,
     DK,
     DV,
     input_dtype,
     output_dtype,
     accum_dtype,
     gate_dtype,
     state_dtype,
     chunk_size,
     use_gk=True,
     use_initial_state=True,
     store_final_state=True,
     save_new_value=True,
-    block_DK=64,
-    block_DV=32,
-    threads=128,
-    num_stages=0,
 ):
examples/KDA/chunk_o.py (2)

197-231: Unused kernel configuration parameters in run_test.

The parameters block_DK, block_DV, threads, and num_stages are accepted by run_test but never passed to tilelang_chunk_fwd_o. The kernel is invoked at line 217-230 without these configuration options, meaning the autotuner will select them instead. If this is intentional (relying on autotuning), consider removing these parameters from run_test's signature to avoid confusion. If manual configuration was intended, these should be passed to the kernel.

♻️ Option A: Remove unused parameters
 def run_test(
     B,
     S,
     H,
     DK,
     DV,
     chunk_size,
     input_dtype,
     output_dtype,
     accum_dtype,
     gate_dtype,
-    block_DK,
-    block_DV,
-    threads,
-    num_stages,
 ):
♻️ Option B: Pass parameters to kernel
     kernel = tilelang_chunk_fwd_o(
         B,
         S,
         H,
         DK,
         DV,
         input_dtype,
         output_dtype,
         accum_dtype,
         gate_dtype,
         chunk_size,
         scale,
         block_S,
+        block_DK=block_DK,
+        block_DV=block_DV,
+        threads=threads,
+        num_stages=num_stages,
     )

3-4: Remove unused sys import.

The sys module is imported but never used, and the noqa directive is unnecessary.

♻️ Proposed fix
-import sys  # noqa: F401
examples/KDA/chunk_delta_bwd.py (1)

3-3: Remove unused sys import.

♻️ Proposed fix
-import sys  # noqa: F401
examples/KDA/chunk_bwd_dqkwg.py (1)

4-4: Remove unused sys import.

♻️ Proposed fix
-import sys  # noqa: F401
examples/KDA/wy_fast_bwd.py (1)

3-3: Remove unused sys import.

♻️ Proposed fix
-import sys  # noqa: F401
examples/KDA/chunk_bwd_intra.py (2)

130-133: Input parameter db is declared but unused in kernel.

The kernel signature includes db: T.Tensor(db_shape, dtype=input_dtype) but it's never read in the kernel body. The kernel writes to db2 (output) but doesn't use the input db. If db is meant to be accumulated into the output (similar to how dq and dk are added at lines 248-250 and 362-376), ensure it's added to the final result.

Note: Looking at line 448 in run_test, db_tilelang.sum(0).add_(db) is called after the kernel, which does add db. However, this is inconsistent with how dq and dk are handled inside the kernel.


3-3: Remove unused sys import.

♻️ Proposed fix
-import sys  # noqa: F401
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 68bfdc1 and f73f1a5.

📒 Files selected for processing (7)
  • examples/KDA/chunk_bwd_dqkwg.py
  • examples/KDA/chunk_bwd_intra.py
  • examples/KDA/chunk_delta_bwd.py
  • examples/KDA/chunk_delta_h_fwd.py
  • examples/KDA/chunk_o.py
  • examples/KDA/wy_fast.py
  • examples/KDA/wy_fast_bwd.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • examples/KDA/wy_fast.py
  • examples/KDA/wy_fast_bwd.py
  • examples/KDA/chunk_delta_h_fwd.py
  • examples/KDA/chunk_bwd_dqkwg.py
  • examples/KDA/chunk_o.py
  • examples/KDA/chunk_delta_bwd.py
📚 Learning: 2026-01-06T05:20:51.649Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.

Applied to files:

  • examples/KDA/wy_fast.py
  • examples/KDA/wy_fast_bwd.py
  • examples/KDA/chunk_delta_h_fwd.py
  • examples/KDA/chunk_delta_bwd.py
🧬 Code graph analysis (2)
examples/KDA/chunk_delta_h_fwd.py (3)
examples/KDA/FLA_KDA/fla_chunk_delta.py (1)
  • chunk_gated_delta_rule_fwd_h (470-521)
examples/KDA/FLA_KDA/cumsum.py (1)
  • chunk_local_cumsum (431-469)
examples/KDA/test_utils.py (2)
  • compare_tensors (43-83)
  • do_bench (86-108)
examples/KDA/chunk_bwd_dqkwg.py (4)
tilelang/language/dtypes.py (1)
  • float32 (310-310)
tilelang/language/allocate.py (2)
  • alloc_fragment (71-82)
  • alloc_shared (39-54)
tilelang/language/loop.py (2)
  • Pipelined (97-134)
  • Parallel (13-72)
tilelang/language/reduce_op.py (1)
  • reduce_sum (144-166)
🪛 Ruff (0.14.11)
examples/KDA/wy_fast.py

6-6: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


15-15: Unused function argument: output_dtype

(ARG001)


68-68: Unused function argument: use_qg

(ARG001)


94-94: Unused function argument: QG

(ARG001)


174-174: Unused function argument: block_DK

(ARG001)


175-175: Unused function argument: block_DV

(ARG001)


176-176: Unused function argument: threads

(ARG001)


177-177: Unused function argument: num_stages

(ARG001)


184-184: Unpacked variable QG_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


185-185: Unpacked variable QG_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/wy_fast_bwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


25-25: Unused function argument: output_dtype

(ARG001)


26-26: Unused function argument: accum_dtype

(ARG001)


28-28: Unused function argument: state_dtype

(ARG001)


50-50: Unused function argument: chunk_size

(ARG001)


53-53: Unused function argument: state_dtype

(ARG001)


92-92: Unused function argument: state_dtype

(ARG001)


258-258: Unused function argument: block_DK

(ARG001)


259-259: Unused function argument: block_DV

(ARG001)


260-260: Unused function argument: threads

(ARG001)


261-261: Unused function argument: num_stages

(ARG001)

examples/KDA/chunk_bwd_intra.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


26-26: Unused function argument: output_dtype

(ARG001)


27-27: Unused function argument: accum_dtype

(ARG001)


29-29: Unused function argument: state_dtype

(ARG001)


56-56: Unused function argument: chunk_size

(ARG001)


60-60: Unused function argument: state_dtype

(ARG001)


95-95: Unused function argument: state_dtype

(ARG001)


132-132: Unused function argument: db

(ARG001)


396-396: Unused function argument: threads

(ARG001)


397-397: Unused function argument: num_stages

(ARG001)


398-398: Unused function argument: cu_seqlens

(ARG001)


399-399: Unused function argument: chunk_indices

(ARG001)

examples/KDA/chunk_delta_h_fwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


31-31: Unused function argument: output_dtype

(ARG001)


32-32: Unused function argument: accum_dtype

(ARG001)


98-98: Unused function argument: block_DK

(ARG001)


213-213: Unused function argument: block_DK

(ARG001)


214-214: Unused function argument: block_DV

(ARG001)


215-215: Unused function argument: threads

(ARG001)


216-216: Unused function argument: num_stages

(ARG001)

examples/KDA/chunk_bwd_dqkwg.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


43-43: Unused function argument: DV

(ARG001)


44-44: Unused function argument: chunk_size

(ARG001)


209-209: Unused function argument: qk_dtype

(ARG001)


211-211: Unused function argument: use_gk

(ARG001)


212-212: Unused function argument: use_initial_state

(ARG001)


213-213: Unused function argument: store_final_state

(ARG001)


214-214: Unused function argument: save_new_value

(ARG001)


215-215: Unused function argument: block_DK

(ARG001)


216-216: Unused function argument: block_DV

(ARG001)


217-217: Unused function argument: threads

(ARG001)


218-218: Unused function argument: num_stages

(ARG001)

examples/KDA/chunk_o.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


22-22: Unused function argument: output_dtype

(ARG001)


23-23: Unused function argument: accum_dtype

(ARG001)


39-39: Unused function argument: DK

(ARG001)


41-41: Unused function argument: chunk_size

(ARG001)


44-44: Ambiguous variable name: O

(E741)


99-99: Ambiguous variable name: O

(E741)


197-197: Unused function argument: block_DK

(ARG001)


198-198: Unused function argument: block_DV

(ARG001)


199-199: Unused function argument: threads

(ARG001)


200-200: Unused function argument: num_stages

(ARG001)

examples/KDA/chunk_delta_bwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


28-28: Unused function argument: output_dtype

(ARG001)


29-29: Unused function argument: accum_dtype

(ARG001)


31-31: Unused function argument: state_dtype

(ARG001)


58-58: Unused function argument: gate_dtype

(ARG001)


128-128: Unused function argument: h0

(ARG001)


239-239: Unused function argument: block_DV

(ARG001)


240-240: Unused function argument: threads

(ARG001)


241-241: Unused function argument: num_stages

(ARG001)


242-242: Unused function argument: use_torch

(ARG001)

🔇 Additional comments (14)
examples/KDA/wy_fast.py (2)

25-39: LGTM!

The conditional tensor allocation based on use_qg and use_kg flags is clean and appropriate.


42-51: LGTM!

The configuration generator for autotuning is well-structured.

examples/KDA/chunk_delta_h_fwd.py (3)

48-62: LGTM!

The ceiling division on line 58 correctly matches the kernel iteration bounds, and the tensor allocations are appropriate.


65-75: LGTM!

Configuration generation for autotuning is correctly implemented.


295-316: LGTM!

The dtype values are correctly passed as strings. Note that the current test configuration (S=8192, chunk_size=64) ensures S is evenly divisible, avoiding the potential out-of-bounds issue identified in the kernel. Consider adding a test case with non-divisible values to verify the fix once applied.
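
A possible follow-up case (a sketch only; it assumes run_test keeps the signature used in this file, and the sizes are illustrative):

run_test(
    B=1,
    S=8192 + 17,  # deliberately not a multiple of chunk_size
    H=8,
    DK=128,
    DV=128,
    input_dtype="float32",
    output_dtype="float32",
    accum_dtype="float32",
    gate_dtype="float32",
    state_dtype="float32",
    chunk_size=64,
)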

examples/KDA/chunk_o.py (1)

1-11: LGTM on the core kernel implementation.

The indexing bug (using block_DV instead of block_DK for Q_shared/GK_shared/GQ_shared) from past reviews has been correctly fixed - line 132 now properly iterates over T.Parallel(block_S, block_DK) with i_k2. The dtype mismatch in main() has also been addressed with string arguments.

examples/KDA/chunk_delta_bwd.py (2)

298-300: LGTM - Validation is now enabled.

The compare_tensors calls are now active, addressing the previous review concern about disabled validation. The test will now properly verify correctness against the reference implementation.


303-325: LGTM - Dtype issue has been fixed.

The main() function now correctly passes string dtype names ("bfloat16", "float32") to run_test, which uses getattr(torch, ...) to resolve them. This addresses the previous review concern about type mismatch.

examples/KDA/chunk_bwd_dqkwg.py (2)

241-244: LGTM - Validation is now enabled.

The compare_tensors calls are active, properly validating dq, dk, dw, and dg against the reference implementation. This addresses the previous review concern.


264-284: LGTM - Dtype handling in main() is correct.

The main() function now passes string dtype names ("float32") to run_test, correctly aligning with the getattr(torch, ...) usage.

examples/KDA/wy_fast_bwd.py (2)

310-314: LGTM - Validation is now enabled.

The compare_tensors calls are active for all outputs (dA, dk, dv, dbeta, dg), properly verifying correctness against the reference implementation.


295-307: TileLang dtype handling is correct in the kernel. The tilelang_wy_fast_bwd function properly accepts dtype parameters and uses them consistently in T.Tensor, T.alloc_shared, and T.alloc_fragment declarations. The API signatures in TileLang's proxy layer support both string dtype names and TVM dtype objects, making the current implementation valid.

examples/KDA/chunk_bwd_intra.py (2)

455-458: LGTM - Validation is enabled and comprehensive.

The compare_tensors calls validate all outputs (dq, dk, db, dg) against the reference implementation. The post-processing at lines 448-453 (summing db_tilelang and applying cumsum to dg_tilelang) is correctly applied before comparison.


140-146: Complex grid decomposition looks correct.

The grid decomposition (i_k * NC + i_i, i_t, i_bh) properly handles the Cartesian product of K-blocks and sub-chunks, with correct index extraction. The sub-chunk index calculation at line 145 (i_ti = i_t * BT + i_i * BC) is mathematically sound.
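
For readers following the indexing, a plain-Python restatement of that decomposition (names mirror the comment; the concrete block sizes are illustrative):

def decode_grid(bx: int, i_t: int, BT: int = 64, BC: int = 16):
    """Recover (K-block, sub-chunk, token offset) from the packed first grid axis."""
    NC = BT // BC                   # sub-chunks per chunk
    i_k, i_i = bx // NC, bx % NC    # undo bx = i_k * NC + i_i
    i_ti = i_t * BT + i_i * BC      # first token of this sub-chunk in the sequence
    return i_k, i_i, i_ti

# e.g. i_k=2, i_i=3 packed as bx = 2 * 4 + 3 = 11, chunk index i_t=5:
print(decode_grid(11, 5))  # -> (2, 3, 368)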


@wfloveiu

Hi! Thank you for contribution! A small question: In your implementation of KDA kernels, there are two operators chunk_bwd_intra.py and chunk_bwd_intra_op.py that share most part of code. It seems that they have totally same functionality. Could you check them and only keep one operator?

Thank you for your review, I will check and fix them.

I have committed the code. However, pre-commit.ci reports an error; when I run the check locally, it shows no error.

tzj-fxz and others added 4 commits January 14, 2026 14:11
…d corresponding test case (tile-ai#1654)

* Add unroll loop functionality and corresponding test case

- Introduced a new `UnrollLoop` function in the transform module to unroll loops based on various configuration options.
- Added a test case in `test_tilelang_language_unroll.py` to validate the behavior of `T.unroll` with only the extent parameter, ensuring correct kernel generation with unroll pragmas.

* Refactor unroll kernel implementation and update test case

- Changed the kernel function in `test_tilelang_language_unroll.py` to use a new `unroll_kernel` function that compiles and returns the output tensor, improving clarity and structure.
- Updated the `OptimizeForTarget` function in `phase.py` to ensure the `UnrollLoop` transformation is applied correctly, maintaining consistency in optimization phases.

* lint fix

* lint fix
LeiWang1999 and others added 7 commits January 14, 2026 14:26
…nsitive LetStmt dependencies (tile-ai#1657)

* [Enhancement] Update global load/store functions for CUDA compatibility (tile-ai#1652)

Refactor the `ld_global_256` and `st_global_256` functions to support both CUDA versions above 12.9 and earlier versions. This change ensures that 256-bit loads and stores are handled correctly across different CUDA versions, improving performance and compatibility. The implementation now uses two 128-bit loads/stores for older versions, enhancing the robustness of the codebase.

* Update comments in global load/store functions for CUDA compatibility

Clarified comments in `ld_global_256` and `st_global_256` functions to indicate that the fallback for CUDA versions below 12.9 may have performance regressions. This change enhances code readability and provides better context for developers working with different CUDA versions.

* Update submodule and enhance LetStmt handling in inject_pipeline.cc

- Updated the TVM submodule to the latest commit.
- Improved the handling of LetStmt in the inject_pipeline.cc file to account for transitive dependencies on loop variables, ensuring correct variable substitution in rewritten blocks.
- Adjusted test_tilelang_issue_1263.py to remove unnecessary jit decorator and updated the kernel compilation process with specific pass configurations.

* lint fix

* revert tvm

* remove unused test

* test fix
…s operations (tile-ai#1663)

* [Enhancement] Update CallNode handling to include annotations in various operations

- Modified CallNode invocations in multiple files to ensure that annotations are passed correctly, enhancing the consistency and functionality of the codebase.
- Removed the "use_tma" annotation from AtomicAddNode and adjusted related calls to maintain expected behavior.
- Updated CUDA intrinsic dispatch functions to include annotations, improving compatibility and correctness in CUDA operations.

* lint fix
…#1664)

* [Fix] Refactor type hint extraction logic in DSLMutator for better clarity and handling of annotations

* [Refactor] Remove redundant tensor creation in loop layout tests and update kernel compilation parameters
* [Feat] Add tilelang autodd for delta debugging

* fix typos

* fix lint error

* fix typos

* fix lint error

* fix bugs

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix codeview comments

* [Refactor] Move AutoDD detection to env module and update import logic

* Refactor: Relocate the _is_running_autodd function to the env module for better organization and encapsulation.
* Update initialization logic to skip logger and heavy imports based on a new light import mode, enhancing flexibility in module usage.
* Ensure consistent handling of environment variables across the package, improving overall code clarity and maintainability.

* [Documentation] Add AutoDD section to debug_tools_for_tilelang.md

* Introduced a comprehensive guide on AutoDD (Automatic Delta Debugging) for isolating bugs in TileLang programs.
* Explained Delta Debugging methodology, usage, parameters, and provided examples for clarity.
* Highlighted the benefits of using AutoDD for large codebases and hard-to-locate errors, emphasizing time-saving aspects.
* Included tips for effective usage and a reference to a complete example in the documentation.

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: kurisu6912 <227995639+kurisu6912@users.noreply.github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
@tzj-fxz

tzj-fxz commented Jan 14, 2026

Hi! Thank you for contribution! A small question: In your implementation of KDA kernels, there are two operators chunk_bwd_intra.py and chunk_bwd_intra_op.py that share most part of code. It seems that they have totally same functionality. Could you check them and only keep one operator?

Thank you for your review, I will check and fix them.

I have committed the code. However, pre-commit.ci reports an error; when I run the check locally, it shows no error.

The latest TileLang PR has fixed this lint error. I have also pulled from the main branch and fixed it.

tzj-fxz
tzj-fxz previously approved these changes Jan 14, 2026
* support cp.reduce.async.bulk.tensor and add test

* Refactor flash attention example by removing unnecessary layout annotations

* support swizzle layout for tma reduce

* auto swizzle for non-1d tma atomic add

* upd example and test

* lint

* typo

* add constraint for test

* Refactor CUDA data type mapping by moving the to_CUtensorMapDataType function to utils.cc and utils.h, while removing redundant definitions from atomic_add.cc and copy.cc.

* lint
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/flash_attention/example_gqa_bwd_tma_reduce.py (1)

340-343: Inconsistent copy pattern: dk should be copied via dk_shared.

Line 343 copies directly from dk fragment to global memory, but dk_shared was populated at line 342 and not used. This is inconsistent with the dv pattern at lines 340-341 and differs from the varlen implementation at example_gqa_bwd_tma_reduce_varlen.py line 481 which correctly uses dk_shared.

Proposed fix
             T.copy(dv, dv_shared)
             T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
             T.copy(dk, dk_shared)
-            T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
+            T.copy(dk_shared, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
🤖 Fix all issues with AI agents
In `@examples/KDA/chunk_delta_bwd.py`:
- Around line 262-268: The reference tensors dh_ref, dh0_ref, dv2_ref are only
produced when use_gk is True but compare_tensors is called unconditionally;
either compute the reference via chunk_gated_delta_rule_bwd_dhu when use_gk is
False (call the same function path or its non-gk equivalent) or wrap the
compare_tensors calls in an if use_gk block so they only run when
dh_ref/dh0_ref/dv2_ref exist; update the logic around the
chunk_gated_delta_rule_bwd_dhu call and the subsequent compare_tensors calls to
ensure the reference variables are always defined before use or comparisons are
conditional on use_gk.

In `@examples/KDA/chunk_inter_solve_fused.py`:
- Around line 81-82: Add a runtime check at the start of the kernel function to
enforce the relationship between chunk_size and sub_chunk_size that the kernel
assumes elsewhere: assert chunk_size == 4 * sub_chunk_size, or raise a ValueError
with a clear message. Locate where block_S = BS = chunk_size and BC = sub_chunk_size
are set and insert the check immediately after those assignments, so the kernel's
assumption of exactly 4 sub-chunks per chunk cannot fail silently (see the sketch
after this list).

In `@src/op/atomic_add.cc`:
- Around line 373-379: The code dereferences as_const_int(...) results for
mat_stride and mat_continuous without checking for nullptr, which can crash if
shared_tensor->shape elements are not compile-time constants; update the block
using as_const_int(shared_tensor->shape[dim - 2]) and
as_const_int(shared_tensor->shape[dim - 1]) to first capture each result into a
local pointer, verify they are non-null (and handle the non-constant case by
returning an error, throwing, or using a safe fallback), and only then assign
mat_stride/mat_continuous and call makeGemmABLayoutHopper; ensure error handling
is consistent with surrounding code paths that construct Layout when shape
values are unknown.
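
As referenced in the chunk_inter_solve_fused.py item above, a minimal sketch of that guard; the helper name is hypothetical and the sizes are illustrative:

def check_sub_chunking(chunk_size: int, sub_chunk_size: int) -> None:
    # Hypothetical helper enforcing the kernel's hard-coded 4-sub-chunks-per-chunk layout.
    if chunk_size != 4 * sub_chunk_size:
        raise ValueError(
            f"kernel assumes exactly 4 sub-chunks per chunk, got "
            f"chunk_size={chunk_size}, sub_chunk_size={sub_chunk_size}"
        )

check_sub_chunking(64, 16)    # passes
# check_sub_chunking(64, 32)  # would raise ValueError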
♻️ Duplicate comments (15)
examples/KDA/test_utils_kda.py (2)

10-17: Fix tensor-to-Python control flow in calc_sim().

This function has several issues that were flagged in past reviews but remain unaddressed:

  1. .data is deprecated; use .detach() instead
  2. denominator == 0 compares a tensor to a scalar, which works but is fragile
  3. Returns a tensor instead of a Python float, causing issues in assert_similar
Proposed fix
 def calc_sim(x, y, name="tensor"):
-    x, y = x.data.double(), y.data.double()
+    x, y = x.detach().double(), y.detach().double()
     denominator = (x * x + y * y).sum()
-    if denominator == 0:
+    if denominator.item() == 0:
         print_red_warning(f"{name} all zero")
-        return 1
+        return 1.0
     sim = 2 * (x * y).sum() / denominator
-    return sim
+    return float(sim.item())

27-30: Inverted mask logic in non-finite value comparison.

The comparison fills finite positions with 0 (since x_mask = torch.isfinite(x)), leaving non-finite values. To compare non-finite values while masking out finite ones, the masks should be inverted.

Proposed fix
-    if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
+    if not torch.isclose(x.masked_fill(~x_mask, 0), y.masked_fill(~y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
examples/KDA/chunk_bwd_dqkwg.py (2)

47-51: qk_dtype parameter is still ignored.

The prepare_output function accepts qk_dtype but dq and dk are hardcoded to torch.float32. This was flagged in a previous review. Either use the parameter or remove it.

Proposed fix
 def prepare_output(
     B,
     S,
     H,
     DK,
     DV,
     chunk_size,
     gate_dtype,
 ):
-    dq = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
-    dk = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
+    # Note: Reference implementation uses float32 for dq/dk accumulation
+    dq = torch.empty(B, S, H, DK, dtype=torch.float32).cuda()
+    dk = torch.empty(B, S, H, DK, dtype=torch.float32).cuda()
     dw = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
     dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
     return dq, dk, dw, dg

Note: Also consider using torch.empty instead of torch.randn for output tensors that will be completely overwritten.


122-123: Gate dtype mismatch for G_shared / Gn_shared.

These buffers store values copied from G (which has gate_dtype) but are allocated with input_dtype. This can silently change precision. A previous review marked this as addressed, but the code still shows input_dtype.

Proposed fix
-            G_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)  # chunk G
-            Gn_shared = T.alloc_shared((block_DK), dtype=input_dtype)  # chunk last token G
+            G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype)  # chunk G
+            Gn_shared = T.alloc_shared((block_DK,), dtype=gate_dtype)  # chunk last token G
examples/KDA/wy_fast.py (3)

237-240: Type mismatch: passing TileLang types instead of strings.

run_test expects string dtype arguments (resolved via getattr(torch, input_dtype)), but main() passes T.bfloat16, T.float32. This causes an AttributeError at runtime. This was flagged in a previous review but remains unfixed.

Proposed fix
     run_test(
         B=1,
         S=8192,
         H=64,
         DK=128,
         DV=128,
         chunk_size=64,
-        input_dtype=T.bfloat16,
-        output_dtype=T.bfloat16,
-        gate_dtype=T.float32,
-        accum_dtype=T.float32,
+        input_dtype="bfloat16",
+        output_dtype="bfloat16",
+        gate_dtype="float32",
+        accum_dtype="float32",
         block_DK=64,
         block_DV=32,
         threads=128,
         num_stages=3,
     )

222-222: Typo: "tritron" should be "triton".

Proposed fix
-    print("tritron time:", triton_time)
+    print("triton time:", triton_time)

68-69: use_qg parameter declared but QG tensor never populated.

The use_qg parameter and QG output tensor are declared, but the kernel never computes or writes to QG. If QG computation is intentionally deferred, add a TODO comment. Otherwise, remove the unused parameter and output.

Also applies to: 94-94

examples/KDA/chunk_delta_h_fwd.py (1)

178-182: Potential out-of-bounds access when S is not divisible by block_S.

At line 180, GK[bb, (i_s + 1) * block_S - 1, bh, :] can exceed bounds when S is not evenly divisible by block_S. For example, if S=100 and block_S=64, when i_s=1, the index becomes 127, exceeding the valid range [0, 99].

This was flagged in a previous review but appears unaddressed.

Proposed fix using index clamping
                 if use_gk:
-                    T.copy(GK[bb, (i_s + 1) * block_S - 1, bh, :], GK_last_shared)  # block last token
+                    last_idx = T.min((i_s + 1) * block_S - 1, S - 1)
+                    T.copy(GK[bb, last_idx, bh, :], GK_last_shared)  # block last token

Alternatively, add a runtime assertion that S % chunk_size == 0.
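For reference, the assertion variant could be as simple as the following one-line sketch (S and chunk_size are the kernel parameters referenced above; placing it at kernel-construction time is an assumption):

assert S % chunk_size == 0, f"S ({S}) must be divisible by chunk_size ({chunk_size})"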

examples/KDA/chunk_intra_token_parallel.py (2)

60-73: Kernel lacks scale parameter while reference implementation uses it.

The tilelang_chunk_kda_fwd_intra_token_parallel kernel signature doesn't include a scale parameter, but the reference implementation chunk_kda_fwd_intra_token_parallel (called at line 212-214) receives scale. This creates an inconsistency where the reference and TileLang kernel operate with different scaling behaviors, potentially causing validation failures when comparisons are enabled.


233-239: Kernel outputs are computed but never validated against reference.

Aqk_tilelang and Akk_tilelang are computed at line 233 but never compared against Aqk_ref and Akk_ref. The comparison code at lines 264-269 is commented out. For a test/validation file, this defeats the purpose of correctness verification mentioned in the PR objectives.

Also applies to: 264-269

examples/KDA/chunk_bwd_gla_dA.py (1)

80-81: Potential dtype mismatch for V_shared.

V_shared is allocated with do_dtype at line 81, but it's used to copy from V which has input_dtype (line 96). If do_dtype != input_dtype, this could cause implicit type conversion or precision loss.

🐛 Suggested fix
             do_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
-            V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
+            V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
examples/KDA/chunk_delta_bwd.py (1)

121-136: Unused h0 parameter in kernel signature.

The kernel declares h0: T.Tensor(h0_shape, dtype=input_dtype) at line 128 but never uses it in the kernel body. Given the use_initial_state flag at line 217-218 only writes to dh0, the h0 input may have been intended for initialization logic that's not yet implemented.

examples/KDA/wy_fast_bwd.py (2)

341-351: Critical type mismatch: T.float32 passed where string expected.

main() passes TileLang type objects (e.g., T.float32) to run_test, but run_test uses getattr(torch, input_dtype) at lines 270-274, which expects string arguments like "float32". This will cause an AttributeError at runtime.

🐛 Proposed fix
 def main():
     DK = 128
     DV = 128
     run_test(
         B=1,
         S=32768,
         H=8,
         DK=DK,
         DV=DV,
-        input_dtype=T.float32,
-        output_dtype=T.float32,
-        accum_dtype=T.float32,
-        gate_dtype=T.float32,
-        state_dtype=T.float32,
+        input_dtype="float32",
+        output_dtype="float32",
+        accum_dtype="float32",
+        gate_dtype="float32",
+        state_dtype="float32",
         chunk_size=64,

295-307: String dtypes passed to TileLang kernel expecting TileLang types.

run_test passes string dtype parameters directly to tilelang_wy_fast_bwd, but the kernel's T.Tensor(..., dtype=input_dtype) declarations expect TileLang dtype objects. This creates an inconsistency - either convert strings to TileLang types before the kernel call, or have the kernel handle string conversion internally.

🔧 Proposed fix - convert dtypes before kernel call
     kernel = tilelang_wy_fast_bwd(
         B,
         S,
         H,
         DK,
         DV,
-        input_dtype,
-        output_dtype,
-        accum_dtype,
-        gate_dtype,
-        state_dtype,
+        getattr(T, input_dtype),
+        getattr(T, output_dtype),
+        getattr(T, accum_dtype),
+        getattr(T, gate_dtype),
+        getattr(T, state_dtype),
         chunk_size,
     )
examples/KDA/chunk_bwd_intra.py (1)

478-493: Critical: main() passes TileLang types but run_test expects strings.

This issue was flagged in a previous review and remains unfixed. main() passes T.float32 (a TileLang type object) to run_test, but run_test uses getattr(torch, input_dtype) at lines 407-411, which expects string arguments like "float32". This will cause an AttributeError at runtime.

🐛 Proposed fix
 def main():
     DK = 128
     run_test(
         B=1,
         S=8192,
         H=8,
         DK=DK,
-        input_dtype=T.float32,
-        output_dtype=T.float32,
-        accum_dtype=T.float32,
-        gate_dtype=T.float32,
-        state_dtype=T.float32,
+        input_dtype="float32",
+        output_dtype="float32",
+        accum_dtype="float32",
+        gate_dtype="float32",
+        state_dtype="float32",
         chunk_size=64,
         threads=128,
         num_stages=0,
     )
🧹 Nitpick comments (17)
src/tl_templates/cuda/copy_sm90.h (1)

265-331: Implementation looks correct and follows established patterns.

The new tma_store_add overloads correctly implement descriptor-based TMA atomic add operations for 1D through 5D tensors. The inline assembly uses the proper PTX instruction format for cp.reduce.async.bulk.tensor.Xd.global.shared::cta.add.bulk_group.

One minor inconsistency: the tma_store overloads (lines 184-253) include a CacheHintSm90 template parameter with .L2::cache_hint in the assembly, while these tma_store_add overloads omit cache hint support. The PTX ISA does support L2 cache hints for reduce operations, so adding this for API consistency could be beneficial.

♻️ Optional: Add cache hint support for consistency with tma_store
+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
 TL_DEVICE void tma_store_add(const CUtensorMap &descriptor,
                              void const *const smem_ptr, int32_t const &crd0) {
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
   asm volatile(
-      "cp.reduce.async.bulk.tensor.1d.global.shared::cta.add.bulk_group "
-      "[%0, {%2}], [%1];"
+      "cp.reduce.async.bulk.tensor.1d.global.shared::cta.add.bulk_group"
+      ".L2::cache_hint [%0, {%2}], [%1], %3;"
       :
-      : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0)
+      : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "l"(cache_hint)
       : "memory");
 }

Apply similar changes to the 2D-5D overloads if cache hint control is desired.

src/op/atomic_add.cc (1)

317-332: Consider extracting to a shared utility function.

This implementation is identical to CopyNode::ComputeLinearLayout in src/op/copy.cc (lines 269-284). Consider extracting this to a shared utility function to reduce duplication and ensure consistent behavior.

testing/python/language/test_tilelang_language_atomic_add.py (1)

389-401: Consider also running the explicit_swizzle variant to verify correctness.

The test compiles with explicit_swizzle=True and verifies kernel source equivalence, but doesn't actually execute the explicit swizzle variant to confirm numerical correctness.

♻️ Suggested enhancement
     kernel_with_explicit_swizzle = tma_atomic_add_program.compile(out=T.Tensor[(16, 16), T.float32], explicit_swizzle=True)
     # Ensure auto swizzled layout is applied
     assert kernel.get_kernel_source() == kernel_with_explicit_swizzle.get_kernel_source()
+
+    # Also verify explicit_swizzle variant produces correct results
+    out_explicit = torch.zeros((16, 16), dtype=torch.float32, device="cuda")
+    tma_atomic_add_program(out_explicit, explicit_swizzle=True)
+    torch.testing.assert_close(out_explicit, torch.ones((16, 16), dtype=torch.float32, device="cuda") * 16)
examples/KDA/test_utils_kda.py (1)

43-43: Unused atol and rtol parameters.

These parameters are accepted but never used in the function body. Either use them for tolerance-based pass/fail logic or remove them to avoid confusion.
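One possible way to wire them in, shown only as a sketch (the real helper also reports similarity via calc_sim, which is omitted here):

import torch

def compare_tensors(name, ref, out, atol=1e-3, rtol=1e-3):
    # Sketch: use atol/rtol for an explicit pass/fail verdict.
    max_err = (ref.float() - out.float()).abs().max().item()
    passed = torch.allclose(ref.float(), out.float(), atol=atol, rtol=rtol)
    print(f"{name}: max abs err = {max_err:.3e}, pass = {passed}")
    return passed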

examples/KDA/chunk_bwd_dv.py (1)

31-40: Unused chunk_size parameter in prepare_output.

The chunk_size parameter is accepted but never used. Consider removing it or documenting why it's reserved for future use.

examples/KDA/chunk_bwd_dqkwg.py (1)

200-219: Multiple unused parameters in run_test.

Parameters qk_dtype, use_gk, use_initial_state, store_final_state, save_new_value, block_DK, block_DV, threads, and num_stages are accepted but never used. Consider removing unused parameters or documenting they're reserved for future use.

examples/KDA/chunk_intra_token_parallel.py (1)

13-30: Unused function parameters in prepare_input.

The output_dtype and accum_dtype parameters are declared but never used within this function. Consider removing them if they serve no purpose, or prefix with underscore to indicate intentional non-use.

♻️ Suggested fix
 def prepare_input(
     B,
     S,
     H,
     DK,
     chunk_size,
     input_dtype,
-    output_dtype,
-    accum_dtype,
+    _output_dtype,
+    _accum_dtype,
     gate_dtype,
 ):
examples/KDA/chunk_bwd_gla_dA.py (1)

106-117: Unused DK parameter in run_test.

The DK parameter is declared at line 110 but never used in the function body. This appears to be a copy-paste artifact from other similar test functions.

♻️ Suggested fix
 def run_test(
     B,
     S,
     H,
-    DK,
     DV,
     scale,
     input_dtype,
     do_dtype,
     da_dtype,
     chunk_size,
 ):

And update the call in main():

     run_test(
         B=1,
         S=1024 * 8,  # 32768
         H=64,
-        DK=128,
         DV=128,
examples/KDA/chunk_delta_bwd.py (1)

236-242: Unused tuning parameters in run_test.

The parameters block_DV, threads, num_stages, and use_torch are declared but never used. These appear to be intended for manual kernel configuration override but aren't wired through to tilelang_chunk_gated_delta_rule_bwd_dhu.

♻️ Either remove unused parameters or wire them through

Option 1 - Remove:

 def run_test(
     ...
     use_final_state_gradient=True,
-    block_DV=64,
-    threads=256,
-    num_stages=0,
-    use_torch=False,
 ):

Option 2 - Wire through to kernel:

     kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(
         ...
         use_final_state_gradient,
+        block_DV=block_DV,
+        threads=threads,
+        num_stages=num_stages,
     )
examples/KDA/chunk_inter_solve_fused.py (1)

326-398: Forward substitution loops have off-by-one potential in start index.

The T.Pipelined loops for forward substitution use unusual start indices:

  • Line 326: T.Pipelined(2, T.min(BC, S - i_tc0), ...)
  • Line 340: T.Pipelined(BC + 2, T.min(2 * BC, S - i_tc0), ...)
  • Line 362: T.Pipelined(2 * BC + 2, T.min(3 * BC, S - i_tc0), ...)
  • Line 380: T.Pipelined(3 * BC + 2, T.min(4 * BC, S - i_tc0), ...)

The + 2 offset in start indices appears intentional to skip the first two rows of each diagonal block, but this pattern is non-obvious and should be documented.

Consider adding a comment explaining why the loop starts at i*BC + 2 rather than i*BC:

# Start at offset 2 to skip the first two rows already handled by identity initialization
for i_i in T.Pipelined(2, T.min(BC, S - i_tc0), num_stages=num_stages):
examples/KDA/wy_fast_bwd.py (1)

14-14: Debug print settings left enabled.

torch.set_printoptions(profile="full") is set at line 14, which will produce verbose output for all tensor prints. This is typically a debugging artifact that should be removed before merging.

♻️ Remove debug setting
 torch.random.manual_seed(0)
-torch.set_printoptions(profile="full")
examples/KDA/chunk_bwd_intra.py (6)

3-3: Remove unused sys import.

The sys module is imported but never used. The # noqa: F401 directive is also flagged as unnecessary by static analysis.

🧹 Proposed fix
-import sys  # noqa: F401
-

19-49: Consider removing or documenting unused parameters.

The parameters output_dtype, accum_dtype, and state_dtype are never used in this function. If they're kept for API consistency with other functions, consider adding a brief comment explaining this. Otherwise, remove them to reduce confusion.


199-199: Minor typo in comment.

"ofprevious" should be "of previous".

📝 Proposed fix
-                for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index ofprevious sub_chunks
+                for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index of previous sub_chunks

274-274: Remove commented-out code.

Lines 274 and 331 contain commented-out code (T.use_swizzle(10) and dkt_lower_temp allocation). If these are no longer needed, consider removing them to keep the code clean.

Also applies to: 331-331


444-447: Remove redundant tensor allocation.

The tensors allocated by prepare_output on line 444-446 are immediately overwritten by the kernel call on line 447. The prepare_output call is unnecessary here.

♻️ Proposed fix
-    dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = prepare_output(
-        B, S, H, DK, chunk_size, NK, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
-    )
     dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = kernel(q, k, g, beta, dAqk, dAkk, dq, dk, db, dg)

385-400: Consider removing or documenting unused parameters.

The parameters threads, num_stages, cu_seqlens, and chunk_indices are declared but never used in this function. If they're placeholders for future functionality, consider adding a TODO comment or removing them.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 99db818 and 619fd5f.

📒 Files selected for processing (22)
  • examples/KDA/chunk_bwd_dqkwg.py
  • examples/KDA/chunk_bwd_dv.py
  • examples/KDA/chunk_bwd_gla_dA.py
  • examples/KDA/chunk_bwd_intra.py
  • examples/KDA/chunk_delta_bwd.py
  • examples/KDA/chunk_delta_h_fwd.py
  • examples/KDA/chunk_inter_solve_fused.py
  • examples/KDA/chunk_intra_token_parallel.py
  • examples/KDA/test_utils_kda.py
  • examples/KDA/wy_fast.py
  • examples/KDA/wy_fast_bwd.py
  • examples/autodd/tilelang_buggy.py
  • examples/autodd/tilelang_minimized_expected.py
  • examples/flash_attention/example_gqa_bwd_tma_reduce.py
  • examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
  • src/op/atomic_add.cc
  • src/op/atomic_add.h
  • src/op/copy.cc
  • src/op/utils.cc
  • src/op/utils.h
  • src/tl_templates/cuda/copy_sm90.h
  • testing/python/language/test_tilelang_language_atomic_add.py
💤 Files with no reviewable changes (1)
  • src/op/copy.cc
✅ Files skipped from review due to trivial changes (2)
  • examples/autodd/tilelang_minimized_expected.py
  • examples/autodd/tilelang_buggy.py
🧰 Additional context used
🧠 Learnings (4)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • examples/KDA/chunk_bwd_dv.py
  • examples/KDA/wy_fast.py
  • examples/KDA/chunk_delta_bwd.py
  • examples/KDA/chunk_bwd_dqkwg.py
  • testing/python/language/test_tilelang_language_atomic_add.py
  • examples/KDA/chunk_intra_token_parallel.py
  • examples/KDA/wy_fast_bwd.py
  • examples/KDA/chunk_delta_h_fwd.py
📚 Learning: 2025-12-15T07:23:50.065Z
Learnt from: cherichy
Repo: tile-ai/tilelang PR: 1421
File: tilelang/contrib/cutedsl/cpasync.py:45-55
Timestamp: 2025-12-15T07:23:50.065Z
Learning: In tilelang/contrib/cutedsl/cpasync.py, using AddressSpace.generic for TMA descriptor pointers (tensormap_ptr) in the extract_tensormap_ptr function is correct. When creating ptr_type with _cute_ir.PtrType.get for TMA descriptors in CuTeDSL, AddressSpace.generic should be used, not a device-specific or constant address space.

Applied to files:

  • src/tl_templates/cuda/copy_sm90.h
📚 Learning: 2026-01-12T07:25:35.591Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1631
File: src/transform/thread_storage_sync.cc:1126-1137
Timestamp: 2026-01-12T07:25:35.591Z
Learning: In TileLang's thread storage synchronization pass (src/transform/thread_storage_sync.cc), at the IR level where PointerAccessIsDisjoint is called, the threads array in AccessEntry is guaranteed to contain all three thread dimensions (threadIdx.x, threadIdx.y, threadIdx.z), making access to the last 3 elements via `threads[threads.size() + idx - 3]` safe.

Applied to files:

  • src/op/atomic_add.cc
📚 Learning: 2026-01-06T05:20:51.649Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.

Applied to files:

  • examples/KDA/wy_fast.py
  • examples/KDA/chunk_delta_bwd.py
  • testing/python/language/test_tilelang_language_atomic_add.py
  • examples/KDA/chunk_intra_token_parallel.py
  • examples/KDA/wy_fast_bwd.py
  • examples/KDA/chunk_delta_h_fwd.py
🧬 Code graph analysis (12)
src/op/atomic_add.h (2)
src/op/atomic_add.cc (2)
  • ComputeLinearLayout (317-332)
  • ComputeLinearLayout (317-317)
src/op/copy.cc (2)
  • ComputeLinearLayout (270-285)
  • ComputeLinearLayout (270-270)
src/op/utils.h (1)
src/op/utils.cc (2)
  • to_CUtensorMapDataType (96-157)
  • to_CUtensorMapDataType (96-96)
examples/KDA/chunk_bwd_dv.py (2)
examples/KDA/test_utils_kda.py (2)
  • compare_tensors (43-83)
  • do_bench (86-108)
examples/KDA/chunk_o.py (2)
  • do_bench (161-183)
  • kernel (93-156)
examples/KDA/test_utils_kda.py (2)
tilelang/carver/roller/policy/default.py (1)
  • sim (290-291)
tilelang/language/tir/op.py (1)
  • all (1913-1930)
examples/KDA/wy_fast.py (2)
examples/KDA/test_utils_kda.py (1)
  • compare_tensors (43-83)
tilelang/language/allocate.py (1)
  • alloc_shared (39-54)
examples/KDA/chunk_delta_bwd.py (2)
examples/KDA/FLA_KDA/fla_chunk_delta.py (1)
  • chunk_gated_delta_rule_bwd_dhu (524-579)
examples/KDA/FLA_KDA/cumsum.py (1)
  • chunk_local_cumsum (431-469)
examples/KDA/chunk_bwd_dqkwg.py (2)
examples/KDA/FLA_KDA/fla_chunk_inter.py (1)
  • chunk_kda_bwd_dqkwg (141-193)
examples/KDA/test_utils_kda.py (2)
  • do_bench (86-108)
  • compare_tensors (43-83)
examples/KDA/chunk_inter_solve_fused.py (2)
examples/KDA/FLA_KDA/cumsum.py (1)
  • chunk_local_cumsum (431-469)
examples/KDA/test_utils_kda.py (2)
  • compare_tensors (43-83)
  • do_bench (86-108)
examples/KDA/chunk_bwd_gla_dA.py (1)
examples/KDA/test_utils_kda.py (2)
  • compare_tensors (43-83)
  • do_bench (86-108)
examples/KDA/chunk_intra_token_parallel.py (3)
examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py (1)
  • chunk_kda_fwd_intra_token_parallel (114-168)
examples/KDA/FLA_KDA/cumsum.py (1)
  • chunk_local_cumsum (431-469)
examples/KDA/test_utils_kda.py (1)
  • do_bench (86-108)
examples/KDA/wy_fast_bwd.py (1)
examples/KDA/test_utils_kda.py (2)
  • do_bench (86-108)
  • compare_tensors (43-83)
examples/flash_attention/example_gqa_bwd_tma_reduce.py (1)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (1)
  • flashattn_bwd_postprocess (198-228)
🪛 Ruff (0.14.11)
examples/KDA/chunk_bwd_dv.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


36-36: Unused function argument: chunk_size

(ARG001)


111-111: Unused function argument: scale

(ARG001)

examples/KDA/test_utils_kda.py

43-43: Unused function argument: atol

(ARG001)


43-43: Unused function argument: rtol

(ARG001)


54-54: Comment contains ambiguous `（` (FULLWIDTH LEFT PARENTHESIS). Did you mean `(` (LEFT PARENTHESIS)?

(RUF003)


54-54: Comment contains ambiguous `）` (FULLWIDTH RIGHT PARENTHESIS). Did you mean `)` (RIGHT PARENTHESIS)?

(RUF003)

examples/KDA/wy_fast.py

6-6: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


15-15: Unused function argument: output_dtype

(ARG001)


68-68: Unused function argument: use_qg

(ARG001)


94-94: Unused function argument: QG

(ARG001)


174-174: Unused function argument: block_DK

(ARG001)


175-175: Unused function argument: block_DV

(ARG001)


176-176: Unused function argument: threads

(ARG001)


177-177: Unused function argument: num_stages

(ARG001)


184-184: Unpacked variable QG_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


185-185: Unpacked variable QG_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/chunk_delta_bwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


28-28: Unused function argument: output_dtype

(ARG001)


29-29: Unused function argument: accum_dtype

(ARG001)


31-31: Unused function argument: state_dtype

(ARG001)


58-58: Unused function argument: gate_dtype

(ARG001)


128-128: Unused function argument: h0

(ARG001)


239-239: Unused function argument: block_DV

(ARG001)


240-240: Unused function argument: threads

(ARG001)


241-241: Unused function argument: num_stages

(ARG001)


242-242: Unused function argument: use_torch

(ARG001)

examples/KDA/chunk_bwd_dqkwg.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


43-43: Unused function argument: DV

(ARG001)


44-44: Unused function argument: chunk_size

(ARG001)


209-209: Unused function argument: qk_dtype

(ARG001)


211-211: Unused function argument: use_gk

(ARG001)


212-212: Unused function argument: use_initial_state

(ARG001)


213-213: Unused function argument: store_final_state

(ARG001)


214-214: Unused function argument: save_new_value

(ARG001)


215-215: Unused function argument: block_DK

(ARG001)


216-216: Unused function argument: block_DV

(ARG001)


217-217: Unused function argument: threads

(ARG001)


218-218: Unused function argument: num_stages

(ARG001)

examples/KDA/chunk_inter_solve_fused.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


24-24: Unused function argument: output_dtype

(ARG001)


25-25: Unused function argument: accum_dtype

(ARG001)


46-46: Unused function argument: sub_chunk_size

(ARG001)

examples/KDA/chunk_bwd_intra.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


26-26: Unused function argument: output_dtype

(ARG001)


27-27: Unused function argument: accum_dtype

(ARG001)


29-29: Unused function argument: state_dtype

(ARG001)


56-56: Unused function argument: chunk_size

(ARG001)


60-60: Unused function argument: state_dtype

(ARG001)


95-95: Unused function argument: state_dtype

(ARG001)


132-132: Unused function argument: db

(ARG001)


396-396: Unused function argument: threads

(ARG001)


397-397: Unused function argument: num_stages

(ARG001)


398-398: Unused function argument: cu_seqlens

(ARG001)


399-399: Unused function argument: chunk_indices

(ARG001)

examples/KDA/chunk_bwd_gla_dA.py

4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


19-19: Unused function argument: chunk_size

(ARG001)


32-32: Unused function argument: DV

(ARG001)


110-110: Unused function argument: DK

(ARG001)

examples/KDA/chunk_intra_token_parallel.py

20-20: Unused function argument: output_dtype

(ARG001)


21-21: Unused function argument: accum_dtype

(ARG001)


233-233: Unpacked variable Aqk_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


233-233: Unpacked variable Akk_tilelang is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/KDA/wy_fast_bwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


25-25: Unused function argument: output_dtype

(ARG001)


26-26: Unused function argument: accum_dtype

(ARG001)


28-28: Unused function argument: state_dtype

(ARG001)


50-50: Unused function argument: chunk_size

(ARG001)


53-53: Unused function argument: state_dtype

(ARG001)


92-92: Unused function argument: state_dtype

(ARG001)


258-258: Unused function argument: block_DK

(ARG001)


259-259: Unused function argument: block_DV

(ARG001)


260-260: Unused function argument: threads

(ARG001)


261-261: Unused function argument: num_stages

(ARG001)

examples/KDA/chunk_delta_h_fwd.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


31-31: Unused function argument: output_dtype

(ARG001)


32-32: Unused function argument: accum_dtype

(ARG001)


98-98: Unused function argument: block_DK

(ARG001)


213-213: Unused function argument: block_DK

(ARG001)


214-214: Unused function argument: block_DV

(ARG001)


215-215: Unused function argument: threads

(ARG001)


216-216: Unused function argument: num_stages

(ARG001)

🔇 Additional comments (32)
src/op/utils.h (2)

9-9: LGTM: Include added for CUDA stub types.

The include is necessary to access CUtensorMapDataType enum used by the new to_CUtensorMapDataType function.


20-26: LGTM: Utility declarations are well-documented and correctly placed.

The to_CUtensorMapDataType declaration matches its implementation in utils.cc, and the ReverseArray template is a clean, correct implementation using reverse iterators.

src/op/utils.cc (1)

95-157: LGTM: Comprehensive data type mapping with proper error handling.

The function correctly handles all common TVM data types. The mapping of int16 to CU_TENSOR_MAP_DATA_TYPE_UINT16 is correct—CUDA's TensorMap API defines only UINT16 for 16-bit integers and lacks a signed INT16 type.

examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (1)

576-582: LGTM!

The mod_post invocation correctly matches the function signature, and the pattern of applying postprocessing only to dq while handling dk/dv via reduction is appropriate since they're stored in a different shape (groups, *k.shape).

examples/flash_attention/example_gqa_bwd_tma_reduce.py (1)

399-408: LGTM!

The mod_post invocation correctly matches the function signature (batch, heads, head_kv, seq_len, dim_qk, dim_v), and the usage pattern aligns with the varlen implementation.

src/op/atomic_add.h (1)

80-81: LGTM!

The method declaration is consistent with the implementation pattern used by other operators (e.g., CopyNode::ComputeLinearLayout).

src/op/atomic_add.cc (3)

8-8: LGTM!

The include is required for accessing layout utility functions (makeGemmABLayoutHopper, makeLinearLayout, swizzle layout functions) used in the TMA atomic add implementation.


538-545: Warnings for unsupported layouts should consider failing compilation.

When an unsupported swizzle layout is detected (padded or unknown), the code logs a warning but proceeds with SWIZZLE_NONE. This could lead to incorrect results if the shared memory was actually written using a swizzled pattern.

Consider whether this should be an error that fails compilation, or at minimum document this behavior so users understand the potential for silent correctness issues.


616-657: LGTM!

The TMA reduce implementation correctly handles both single-step and multi-step (split) invocation paths. The loop unrolling approach for split operations is appropriate for TMA, and the thread predicate guard ensures only one thread per block executes the TMA operation.

testing/python/language/test_tilelang_language_atomic_add.py (3)

2-4: LGTM!

The new imports are required for the TMA atomic add test functionality.


355-368: LGTM!

The test program correctly exercises both automatic and explicit swizzle layout paths for TMA atomic add operations.


370-374: Good addition of half-precision type coverage.

Testing float16 and bfloat16 alongside float32 ensures the atomic operations work correctly across common data types.

examples/KDA/test_utils_kda.py (1)

86-107: LGTM!

The benchmarking function correctly uses CUDA events for timing with proper warmup and synchronization. The debug print at line 107 may be intentional but could be removed or made optional if noisy output is undesirable.
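If the print is kept, it could be gated behind a flag; a minimal sketch of that pattern (the signature here is an assumption, not the file's actual API):

import torch

def do_bench(fn, warmup=10, rep=100, verbose=False):
    # Warm up, then time `fn` with CUDA events.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(rep):
        fn()
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / rep
    if verbose:  # debug output only when requested
        print(f"avg latency: {ms:.3f} ms")
    return ms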

examples/KDA/chunk_bwd_dv.py (2)

54-102: LGTM!

The TileLang kernel correctly implements the chunked backward computation for dv. The lower triangular mask application and pipelined GEMM operations are properly structured.


105-137: LGTM with minor note.

The test flow correctly compares the TileLang kernel against the reference implementation and benchmarks both. The scale parameter at line 111 is unused but may be kept for API consistency with related test functions.

examples/KDA/chunk_bwd_dqkwg.py (1)

264-288: LGTM!

The main() function correctly passes string dtype names, and the validation calls are properly enabled.

examples/KDA/wy_fast.py (1)

97-159: LGTM for kernel implementation.

The kernel correctly implements the forward computation for W, U, and optional KG with proper shared memory management and pipelining.

examples/KDA/chunk_delta_h_fwd.py (3)

272-283: LGTM!

The benchmark now correctly uses gk=G and use_exp2=True, matching the correctness test parameters.


115-194: LGTM for kernel implementation.

The kernel correctly implements the forward chunked gated delta rule with proper handling of initial/final states, pipelined iteration, and optional V_new output. The ceildiv usage at line 104 ensures proper tensor allocation.
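For readers unfamiliar with the pattern, ceildiv-based allocation covers a trailing partial chunk; a tiny illustration (names are generic, not the file's actual code):

def ceildiv(a, b):
    return (a + b - 1) // b

S, chunk_size = 100, 64
num_chunks = ceildiv(S, chunk_size)  # 2 chunks, so the partial last chunk is covered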


295-316: LGTM!

The main() function correctly passes string dtype names, addressing the previous type mismatch issue.

examples/KDA/chunk_intra_token_parallel.py (2)

93-99: Index computation logic appears correct.

The chunk and sub-chunk index calculations (i_c, i_s, i_tc, i_ts, loops) correctly derive sequence positions for the token-parallel iteration pattern.


161-180: Pipelined loop and GEMM-like reduction logic looks correct.

The inner loop properly handles the Q-K gated interactions with exponential gating factors and accumulates results into Sum_Aqk and Sum_Akk. The conditional write at line 176-177 correctly avoids writing Akk for the diagonal element (j < bs).

examples/KDA/chunk_bwd_gla_dA.py (2)

93-101: Lower-triangular masking and GEMM logic is correct.

The kernel correctly computes dA by accumulating the GEMM of DO and V^T, then applies a lower-triangular mask (i_s1 >= i_s2) with scaling. The logic aligns with the backward pass semantics.
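As a plain-PyTorch illustration of the described computation for a single chunk (shapes and the scale value are assumptions, not the kernel's actual parameters):

import torch

BT, DV = 64, 128           # chunk length, value dim (illustrative)
DO = torch.randn(BT, DV)   # output gradient for one chunk
V = torch.randn(BT, DV)
scale = 0.5
# GEMM of DO and V^T, then keep i_s1 >= i_s2 (diagonal included) and scale.
dA = torch.tril(DO @ V.T) * scale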


135-139: Correctness validation is enabled - good practice.

Unlike some other files in this PR, this file properly calls compare_tensors to validate the TileLang kernel output against the reference.

examples/KDA/chunk_delta_bwd.py (1)

177-218: Backward pass loop structure and gradient updates look correct.

The kernel correctly processes chunks in reverse order (i_s_inv), updates dv by accumulating with the existing value, and computes dh updates through GEMMs with Q, W, and gating factors. The conditional use_gk and use_initial_state branches are properly handled.

examples/KDA/chunk_inter_solve_fused.py (2)

400-404: Identity matrix addition for diagonal blocks.

The code correctly adds identity to diagonal blocks after forward substitution, which is standard for computing (I - L)^{-1} where L is strictly lower triangular.
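A small numerical check of that property (sizes are illustrative): the inverse of I - L with L strictly lower triangular is unit lower triangular, so its diagonal entries are exactly 1, which is why adding the identity to the diagonal blocks completes the result.

import torch

n = 8
L = torch.tril(torch.randn(n, n, dtype=torch.float64), diagonal=-1)  # strictly lower triangular
inv = torch.linalg.inv(torch.eye(n, dtype=torch.float64) - L)
assert torch.allclose(torch.diag(inv), torch.ones(n, dtype=torch.float64))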


558-559: Validation is enabled - good.

The file properly validates both Aqk and Akk outputs against the reference implementation.

examples/KDA/wy_fast_bwd.py (3)

183-212: dk update loop logic appears correct.

The kernel properly computes dk updates by:

  1. Loading K, dk_old, dg_old, and GK
  2. Computing gated K (K_shared_beta_g = K * Beta * exp2(GK))
  3. Accumulating GEMM results for dA and dk
  4. Computing dbeta contribution via reduction
  5. Updating dg

The accumulation pattern and gating logic align with the WY representation backward pass.


233-241: dA correction logic is correct.

The dA correction computes -(A^T @ tril(dA) @ A) which matches the mathematical formulation for the WY backward pass. The lower-triangular masking (i_s1 > i_s2) correctly excludes the diagonal.
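The stated formula can be mirrored in plain PyTorch for a single (BT, BT) chunk block (shapes are assumptions, not the kernel's actual layout):

import torch

BT = 64
A = torch.randn(BT, BT)
dA = torch.randn(BT, BT)
# Strictly lower-triangular mask (i_s1 > i_s2) excludes the diagonal.
dA_corr = -(A.T @ torch.tril(dA, diagonal=-1) @ A)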


310-314: Validation is now enabled - good.

The file properly calls compare_tensors for all five outputs (dA, dk, dv, dbeta, dg).

examples/KDA/chunk_bwd_intra.py (2)

51-67: LGTM with minor note.

The function correctly prepares output tensors. Same note as prepare_input regarding unused parameters (chunk_size, state_dtype) - consider documenting if kept for API consistency.


69-78: LGTM!

Clean configuration generation for autotuning.

✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.

Comment on lines +262 to +268
# fla ref
print("fla running...", flush=True)
if use_gk:
    dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(
        q=Q, k=K, w=W, do=dO, dv=dv, gk=G, h0=h0, dht=dht, scale=scale, use_exp2=True
    )


⚠️ Potential issue | 🟠 Major

Reference computation only runs when use_gk=True.

The reference variables dh_ref, dh0_ref, dv2_ref are only assigned when use_gk=True (lines 264-267), but compare_tensors at lines 298-300 unconditionally uses them. If use_gk=False, this will raise a NameError.

🐛 Proposed fix
     compare_tensors("dh", dh_ref, dh_tilelang)
     compare_tensors("dh0", dh0_ref, dh0_tilelang)
     compare_tensors("dv2", dv2_ref, dv2_tilelang)
+    # Note: These comparisons assume use_gk=True; wrap in conditional if use_gk can be False

Or wrap in conditional:

+    if use_gk:
         compare_tensors("dh", dh_ref, dh_tilelang)
         compare_tensors("dh0", dh0_ref, dh0_tilelang)
         compare_tensors("dv2", dv2_ref, dv2_tilelang)
🤖 Prompt for AI Agents
In `@examples/KDA/chunk_delta_bwd.py` around lines 262 - 268, The reference
tensors dh_ref, dh0_ref, dv2_ref are only produced when use_gk is True but
compare_tensors is called unconditionally; either compute the reference via
chunk_gated_delta_rule_bwd_dhu when use_gk is False (call the same function path
or its non-gk equivalent) or wrap the compare_tensors calls in an if use_gk
block so they only run when dh_ref/dh0_ref/dv2_ref exist; update the logic
around the chunk_gated_delta_rule_bwd_dhu call and the subsequent
compare_tensors calls to ensure the reference variables are always defined
before use or comparisons are conditional on use_gk.

Comment on lines +81 to +82
block_S = BS = chunk_size
BC = sub_chunk_size

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if chunk_size / sub_chunk_size ratio is documented or enforced elsewhere
rg -n "chunk_size.*sub_chunk_size|BC.*4|4.*BC" examples/KDA/

Repository: tile-ai/tilelang

Length of output: 1998


🏁 Script executed:

#!/bin/bash
# Read the file to understand the kernel logic and find any constraints
wc -l examples/KDA/chunk_inter_solve_fused.py

Repository: tile-ai/tilelang

Length of output: 103


🏁 Script executed:

#!/bin/bash
# Examine the file around lines 81-82 and look for constraints
sed -n '70,150p' examples/KDA/chunk_inter_solve_fused.py | cat -n

Repository: tile-ai/tilelang

Length of output: 4457


🏁 Script executed:

#!/bin/bash
# Search for assertions or validation related to chunk_size and sub_chunk_size
rg -n "assert|chunk_size|sub_chunk_size" examples/KDA/chunk_inter_solve_fused.py | head -40

Repository: tile-ai/tilelang

Length of output: 989


🏁 Script executed:

#!/bin/bash
# Look for the test or main section to see what chunk sizes are used
rg -n "chunk_size.*=|sub_chunk_size.*=" examples/KDA/chunk_inter_solve_fused.py

Repository: tile-ai/tilelang

Length of output: 344


🏁 Script executed:

#!/bin/bash
# Look at the kernel's main loop structure around line 380 and 503 for context
sed -n '375,385p' examples/KDA/chunk_inter_solve_fused.py | cat -n

Repository: tile-ai/tilelang

Length of output: 916


🏁 Script executed:

#!/bin/bash
# Check line 503 context
sed -n '500,510p' examples/KDA/chunk_inter_solve_fused.py | cat -n

Repository: tile-ai/tilelang

Length of output: 502


🏁 Script executed:

#!/bin/bash
# Look for any documentation comments in the file about constraints
rg -n "chunk.*4|4.*chunk|sub.chunk.*ratio" examples/KDA/chunk_inter_solve_fused.py

Repository: tile-ai/tilelang

Length of output: 86


🏁 Script executed:

#!/bin/bash
# Check if there are similar patterns in other KDA files to understand if this is a known constraint
rg -n "chunk_size.*4|4.*sub_chunk" examples/KDA/ -A 2 -B 2

Repository: tile-ai/tilelang

Length of output: 7682


Add assertion to enforce chunk_size = 4 * sub_chunk_size constraint.

The kernel hardcodes logic for exactly 4 sub-chunks (see line 380's T.min(4 * BC, ...) and line 503's 3 * BC : 4 * BC indexing). This requires chunk_size = 4 * sub_chunk_size, but there is no validation to enforce it. Passing incompatible sizes (e.g., chunk_size=64, sub_chunk_size=32) would silently produce incorrect results. Add an assertion at the start of the kernel function to validate this constraint.

🤖 Prompt for AI Agents
In `@examples/KDA/chunk_inter_solve_fused.py` around lines 81 - 82, Add a runtime
check at the start of the kernel function to enforce the relationship between
chunk_size and sub_chunk_size used elsewhere: assert chunk_size == 4 *
sub_chunk_size (or raise a ValueError with a clear message). Locate where
block_S = BS = chunk_size and BC = sub_chunk_size are set (symbols block_S, BS,
BC, chunk_size, sub_chunk_size) and insert the assertion immediately after those
assignments to prevent silent incorrect computation when the kernel assumes
exactly 4 sub-chunks.

Comment on lines 373 to 379
int dim = shared_tensor->shape.size();
const int64_t mat_stride = *as_const_int(shared_tensor->shape[dim - 2]);
const int64_t mat_continuous =
    *as_const_int(shared_tensor->shape[dim - 1]);
Layout swizzle_layout =
    makeGemmABLayoutHopper(mat_stride, mat_continuous, mat_continuous,
                           shared_tensor->dtype.bits(), /*k_inner=*/true);

⚠️ Potential issue | 🟠 Major

Potential null pointer dereference if tensor dimensions are not constants.

as_const_int returns nullptr when the expression is not a compile-time constant. Dereferencing without a null check could cause a segmentation fault.

🐛 Proposed fix with null checks
       // TMA atomic add is similar to TMA Store - we should perform swizzle if
       // possible Use the last two dimensions to analyze swizzling
       int dim = shared_tensor->shape.size();
-      const int64_t mat_stride = *as_const_int(shared_tensor->shape[dim - 2]);
-      const int64_t mat_continuous =
-          *as_const_int(shared_tensor->shape[dim - 1]);
+      auto mat_stride_ptr = as_const_int(shared_tensor->shape[dim - 2]);
+      auto mat_continuous_ptr = as_const_int(shared_tensor->shape[dim - 1]);
+      if (mat_stride_ptr == nullptr || mat_continuous_ptr == nullptr) {
+        // Cannot determine swizzle layout for non-constant dimensions
+        result_map.Set(shared_tensor, ComputeLinearLayout(shared_tensor));
+        return result_map;
+      }
+      const int64_t mat_stride = *mat_stride_ptr;
+      const int64_t mat_continuous = *mat_continuous_ptr;
       Layout swizzle_layout =
           makeGemmABLayoutHopper(mat_stride, mat_continuous, mat_continuous,
                                  shared_tensor->dtype.bits(), /*k_inner=*/true);
🤖 Prompt for AI Agents
In `@src/op/atomic_add.cc` around lines 373 - 379, The code dereferences
as_const_int(...) results for mat_stride and mat_continuous without checking for
nullptr, which can crash if shared_tensor->shape elements are not compile-time
constants; update the block using as_const_int(shared_tensor->shape[dim - 2])
and as_const_int(shared_tensor->shape[dim - 1]) to first capture each result
into a local pointer, verify they are non-null (and handle the non-constant case
by returning an error, throwing, or using a safe fallback), and only then assign
mat_stride/mat_continuous and call makeGemmABLayoutHopper; ensure error handling
is consistent with surrounding code paths that construct Layout when shape
values are unknown.

@tzj-fxz tzj-fxz self-requested a review January 14, 2026 10:17
tzj-fxz previously approved these changes Jan 14, 2026
@SiriusNEO SiriusNEO changed the title add KDA algorithm implementation in tilelang [Example] Add KDA algorithm implementation in tilelang Jan 19, 2026
- Updated the TVM submodule to commit 354eef9a.
- Removed several outdated KDA example files and utility scripts that are no longer in use, including chunk_bwd_dqkwg.py, chunk_bwd_dv.py, chunk_bwd_gla_dA.py, chunk_bwd_intra.py, chunk_delta_bwd.py, chunk_delta_h_fwd.py, chunk_inter_solve_fused.py, chunk_intra_token_parallel.py, chunk_o.py, README.md, test_utils_kda.py, wy_fast_bwd.py, wy_fast.py, and various FLA_KDA implementations.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 8

🤖 Fix all issues with AI agents
In `@examples/kda/chunk_bwd_dv.py`:
- Line 137: The inline comment for the constant S is incorrect: the expression S
= 1024 * 8 evaluates to 8192, not 32768; update either the numeric expression or
the comment. Locate the S definition (symbol "S") and either change the
multiplication to match 32768 (e.g., 1024 * 32) if 32768 was intended, or update
the trailing comment to "# 8192" if the current value is correct; make sure the
code and comment are consistent.

In `@examples/kda/chunk_bwd_gla_dA.py`:
- Line 134: The inline comment for the constant S is incorrect: the expression S
= 1024 * 8 evaluates to 8192, not 32768; update the comment next to the S
assignment in examples/kda/chunk_bwd_gla_dA.py (the S variable declaration) to
reflect the correct value (8192) or remove the misleading comment, and ensure
consistency with the same fix applied to the S comment in chunk_bwd_dv.py if
present.

In `@examples/kda/FLA_KDA/fla_chunk_delta.py`:
- Around line 536-543: The function currently hard-codes BT = 64 while accepting
a chunk_size parameter, causing mismatches when callers pass a different
chunk_size; either bind BT to the parameter or enforce chunk_size==64. Fix by
replacing the hard-coded BT assignment with BT = chunk_size (or adding an assertion
like assert chunk_size == 64 with a clear message) and ensuring any use of
chunk_indices and downstream offsets relies on BT (the same value) so the kernel
and Python-side chunking remain consistent; update references to BT/chunk_size
in this function (and any callers using chunk_indices) accordingly.
- Around line 524-535: The parameter annotation for scale in
chunk_gated_delta_rule_bwd_dhu is incorrect: it has a default of None but is
typed as float; change the annotation to Optional[float] and add the
corresponding import from typing (e.g., from typing import Optional) at the top
of the module so the function signature reads scale: Optional[float] = None;
keep other parameters unchanged and ensure any linters/type checkers accept the
Optional import.

In `@examples/kda/FLA_KDA/fla_chunk_intra_token_parallel.py`:
- Around line 114-125: The function chunk_kda_fwd_intra_token_parallel currently
declares a None return but actually returns Aqk and Akk; update the signature to
reflect the real return type (e.g., -> Tuple[torch.Tensor, torch.Tensor]) and
add the necessary typing import (from typing import Tuple) or use a
forward-compatible annotation (e.g., -> tuple[torch.Tensor, torch.Tensor] for
Py3.9+), ensuring the function name chunk_kda_fwd_intra_token_parallel and its
callers are consistent with the corrected return type.

In `@examples/kda/FLA_KDA/fla_chunk_intra.py`:
- Around line 10-16: The code currently sets SOLVE_TRIL_DOT_PRECISION based on
IS_TF32_SUPPORTED (the conditional block setting tl.constexpr("tf32x3") or
tl.constexpr("ieee")) and then immediately overwrites it with
SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32"), which defeats the conditional;
remove the unconditional override line so SOLVE_TRIL_DOT_PRECISION respects
IS_TF32_SUPPORTED (or conversely delete the conditional and keep the hardcoded
tl.constexpr("tf32") line if tf32 is intentionally fixed)—update references to
SOLVE_TRIL_DOT_PRECISION used by tl.dot operations accordingly.

In `@examples/kda/FLA_KDA/fla_utils.py`:
- Around line 16-25: The module currently calls CUDA APIs at import time which
can raise on CPU-only hosts; update the initialization of device,
device_torch_lib, IS_NVIDIA_HOPPER, and USE_CUDA_GRAPH to first guard with
torch.cuda.is_available() (or wrap torch.cuda calls in try/except) and provide
safe CPU fallbacks (e.g., device="cpu", device_torch_lib=torch) and default
IS_NVIDIA_HOPPER=False when CUDA is unavailable or the queries fail; ensure
USE_CUDA_GRAPH still reads the env var but only enables when CUDA is available (a sketch of this guard follows this list).

In `@examples/kda/wy_fast_bwd.py`:
- Around line 55-60: The preallocated dA uses DK but the kernel writes (B, S, H,
chunk_size); change the dA allocation in prepare_output (the line creating dA)
to use chunk_size instead of DK so dA is torch.empty(B, S, H, chunk_size,
dtype=output_dtype).cuda(); keep the other tensors and dtypes unchanged and
ensure any downstream code expecting dA uses that chunk_size shape.
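
Regarding the fla_utils.py import-time guard above, a minimal sketch of one possible shape for it (the Hopper test via compute capability and the FLA_USE_CUDA_GRAPH variable name are assumptions, not the module's actual definitions):

import os
import torch

if torch.cuda.is_available():
    device = "cuda"
    device_torch_lib = torch.cuda
    try:
        major, _minor = torch.cuda.get_device_capability()
        IS_NVIDIA_HOPPER = major >= 9  # assumption: Hopper detected via compute capability
    except Exception:
        IS_NVIDIA_HOPPER = False
else:
    device = "cpu"
    device_torch_lib = torch
    IS_NVIDIA_HOPPER = False

# Still driven by the env var, but only meaningful when CUDA is present.
USE_CUDA_GRAPH = torch.cuda.is_available() and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"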
♻️ Duplicate comments (22)
examples/kda/test_utils_kda.py (2)

10-17: Fix tensor-to-Python control flow issues in calc_sim().

This function has multiple issues that were flagged in a previous review:

  1. .data is deprecated; use .detach() instead
  2. denominator == 0 comparison on a tensor is ambiguous
  3. Returns a tensor instead of a Python float, causing issues in downstream comparisons

27-30: Inverted mask logic when comparing non-finite values.

The mask logic is inverted: x_mask is True for finite values, so masked_fill(x_mask, 0) zeros out the finite positions and leaves non-finite values for comparison. This should use ~x_mask to mask out the non-finite values and compare only finite ones (or vice versa depending on intent).

examples/kda/chunk_delta_bwd.py (2)

117-132: Unused h0 parameter in kernel signature.

The kernel declares h0: T.Tensor(h0_shape, dtype=input_dtype) but never uses it. This was flagged in a previous review. Either implement the initial state handling logic or remove the parameter.


244-280: NameError when use_gk=False: reference variables undefined.

The reference tensors dh_ref, dh0_ref, dv2_ref are only assigned when use_gk=True (lines 244-247), but compare_tensors at lines 278-280 uses them unconditionally. If use_gk=False, this will raise a NameError.

examples/kda/chunk_bwd_gla_dA.py (1)

79-80: Potential dtype mismatch for V_shared.

V_shared is allocated with do_dtype but is used to copy from V which has input_dtype. This was flagged in a previous review.

examples/kda/FLA_KDA/cumsum.py (2)

176-177: Remove redundant if i_c >= 0.
i_c starts at 0 and is always non‑negative, so the branch is dead code.

♻️ Suggested refactor
-        if i_c >= 0:
-            b_z += b_ss
+        b_z += b_ss

430-440: Tighten chunk_local_cumsum API and error message.
**kwargs is unused (silently ignores unexpected args) and the error message omits the 3D scalar shape.

🧹 Suggested fix
 def chunk_local_cumsum(
     g: torch.Tensor,
     chunk_size: int,
     reverse: bool = False,
     scale: float = None,
     cu_seqlens: torch.Tensor = None,
     head_first: bool = False,
     output_dtype: torch.dtype = torch.float,
     chunk_indices: torch.LongTensor = None,
-    **kwargs,
 ) -> torch.Tensor:
@@
     else:
         raise ValueError(
             f"Unsupported input shape {g.shape}, "
-            f"which should be (B, T, H, D) if `head_first=False` or (B, H, T, D) otherwise",
+            f"which should be [B, T, H]/[B, T, H, D] if `head_first=False` "
+            f"or [B, H, T]/[B, H, T, D] otherwise",
         )

Also applies to: 467-469

examples/kda/wy_fast.py (3)

53-94: use_qg/QG are declared but never written.
QG remains uninitialized even when returned. Either remove the flag/output or implement the QG write path.


203-204: Fix benchmark label typo.

✅ Suggested fix
-    print("tritron time:", triton_time)
+    print("triton time:", triton_time)

219-222: Pass dtype names as strings in main().
run_test uses getattr(torch, input_dtype) so T.bfloat16/T.float32 will raise AttributeError.

🐛 Suggested fix
-        input_dtype=T.bfloat16,
-        output_dtype=T.bfloat16,
-        gate_dtype=T.float32,
-        accum_dtype=T.float32,
+        input_dtype="bfloat16",
+        output_dtype="bfloat16",
+        gate_dtype="float32",
+        accum_dtype="float32",
#!/bin/bash
# Verify run_test expects string dtype names
rg -n "getattr\\(torch, input_dtype\\)|input_dtype=T" examples/kda/wy_fast.py
examples/kda/FLA_KDA/fla_chunk_o.py (2)

398-405: Initialize b_A for the USE_A=False path.
b_A is used unconditionally but only assigned when USE_A is true, which will crash at runtime on the false path.

🐛 Suggested fix
     if USE_A:
         p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))
         b_A = tl.load(p_A, boundary_check=(0, 1))
+    else:
+        b_A = tl.zeros([BT, BT], dtype=tl.float32)

432-435: Call check_shared_mem in the elif branch.
The current branch tests the function object instead of invoking it, so the elif always evaluates truthy when the first condition fails.

🔧 Suggested fix
-    elif check_shared_mem:
+    elif check_shared_mem("ampere", k.device.index):
         CONST_TILING = 64
examples/kda/chunk_intra_token_parallel.py (1)

176-254: TileLang path ignores scale and outputs aren’t validated.
scale is passed to the reference but never to the TileLang kernel, and the TileLang outputs are not compared against references.

✅ Suggested fix (validation)
-from test_utils_kda import do_bench
+from test_utils_kda import do_bench, compare_tensors
@@
     Aqk_tilelang, Akk_tilelang = kernel(
         q,
         k,
         gk,
         beta,
     )
@@
+    compare_tensors("Aqk", Aqk_ref, Aqk_tilelang)
+    compare_tensors("Akk", Akk_ref, Akk_tilelang)

Please also thread scale into tilelang_chunk_kda_fwd_intra_token_parallel and apply it consistently with the reference kernel.

examples/kda/chunk_delta_h_fwd.py (1)

167-169: Clamp GK last-token index for partial chunks.
When S isn’t divisible by block_S, (i_s + 1) * block_S - 1 can exceed S - 1.

🐛 Suggested fix
-                if use_gk:
-                    T.copy(GK[bb, (i_s + 1) * block_S - 1, bh, :], GK_last_shared)  # block last token
+                if use_gk:
+                    last_idx = T.min((i_s + 1) * block_S - 1, S - 1)
+                    T.copy(GK[bb, last_idx, bh, :], GK_last_shared)  # block last token
examples/kda/chunk_inter_solve_fused.py (1)

69-70: Enforce kernel size invariants.
The kernel hardcodes 4 sub‑chunks per chunk and later loads diagonals unconditionally; without explicit checks, incompatible sizes can cause incorrect results or OOB access.

✅ Suggested fix
     block_S = BS = chunk_size
     BC = sub_chunk_size
+    assert chunk_size == 4 * sub_chunk_size, "chunk_size must be 4 * sub_chunk_size"
+    assert S % chunk_size == 0, "S must be divisible by chunk_size"
examples/kda/wy_fast_bwd.py (2)

289-302: Convert dtype strings to TileLang dtypes before kernel construction.

run_test uses string dtypes for torch conversion, but those strings are passed directly to tilelang_wy_fast_bwd, which expects TileLang dtypes in T.Tensor(..., dtype=...). This mismatch will fail unless TileLang explicitly accepts strings here.

🐛 Suggested fix
+    tl_input_dtype = getattr(T, input_dtype) if isinstance(input_dtype, str) else input_dtype
+    tl_output_dtype = getattr(T, output_dtype) if isinstance(output_dtype, str) else output_dtype
+    tl_accum_dtype = getattr(T, accum_dtype) if isinstance(accum_dtype, str) else accum_dtype
+    tl_gate_dtype = getattr(T, gate_dtype) if isinstance(gate_dtype, str) else gate_dtype
+    tl_state_dtype = getattr(T, state_dtype) if isinstance(state_dtype, str) else state_dtype
     kernel = tilelang_wy_fast_bwd(
         B,
         S,
         H,
         DK,
         DV,
-        input_dtype,
-        output_dtype,
-        accum_dtype,
-        gate_dtype,
-        state_dtype,
+        tl_input_dtype,
+        tl_output_dtype,
+        tl_accum_dtype,
+        tl_gate_dtype,
+        tl_state_dtype,
         chunk_size,
     )

327-340: Fix dtype arguments in main() to match run_test.

run_test uses getattr(torch, input_dtype) and expects strings like "float32". Passing T.float32 will raise AttributeError.

🐛 Suggested fix
-        input_dtype=T.float32,
-        output_dtype=T.float32,
-        accum_dtype=T.float32,
-        gate_dtype=T.float32,
-        state_dtype=T.float32,
+        input_dtype="float32",
+        output_dtype="float32",
+        accum_dtype="float32",
+        gate_dtype="float32",
+        state_dtype="float32",
examples/kda/FLA_KDA/fla_wy_fast.py (1)

269-272: Keep backward BT consistent with forward tiling.

prepare_wy_repr_bwd hardcodes BT = 64 while the forward path uses BT = A.shape[-1]. If chunk size changes, backward will be incorrect.

🐛 Suggested fix
-    BT = 64
+    BT = A.shape[-1]
examples/kda/chunk_bwd_dqkwg.py (2)

121-122: Use gate_dtype for gate buffers.

G_shared and Gn_shared store gate values but are allocated with input_dtype, which can silently reduce precision.

🐛 Suggested fix
-            G_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)  # chunk G
-            Gn_shared = T.alloc_shared((block_DK), dtype=input_dtype)  # chunk last token G
+            G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype)  # chunk G
+            Gn_shared = T.alloc_shared((block_DK,), dtype=gate_dtype)  # chunk last token G

37-50: Honor qk_dtype for dq/dk outputs.

qk_dtype is accepted by run_test but dq/dk are always float32. Either remove the parameter or use it consistently.

♻️ Suggested fix
-def prepare_output(
+def prepare_output(
     B,
     S,
     H,
     DK,
     DV,
     chunk_size,
-    gate_dtype,
+    qk_dtype,
+    gate_dtype,
 ):
-    dq = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
-    dk = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
+    dq = torch.randn(B, S, H, DK, dtype=qk_dtype).cuda()
+    dk = torch.randn(B, S, H, DK, dtype=qk_dtype).cuda()
     dw = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
     dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
     return dq, dk, dw, dg
-    dq, dk, dw, dg = prepare_output(B, S, H, DK, DV, chunk_size, getattr(torch, gate_dtype))
+    dq, dk, dw, dg = prepare_output(
+        B, S, H, DK, DV, chunk_size, getattr(torch, qk_dtype), getattr(torch, gate_dtype)
+    )
examples/kda/chunk_bwd_intra.py (1)

474-486: Fix dtype arguments in main() to match run_test.

run_test expects strings for getattr(torch, ...); passing T.float32 raises AttributeError.

🐛 Suggested fix
-        input_dtype=T.float32,
-        output_dtype=T.float32,
-        accum_dtype=T.float32,
-        gate_dtype=T.float32,
-        state_dtype=T.float32,
+        input_dtype="float32",
+        output_dtype="float32",
+        accum_dtype="float32",
+        gate_dtype="float32",
+        state_dtype="float32",
examples/kda/FLA_KDA/fla_chunk_delta.py (1)

487-492: Initialize N/NT/chunk_offsets for varlen inputs (still uninitialized).

When cu_seqlens is provided, the cu_seqlens is None branch is skipped, so N, NT, and chunk_offsets are never set before allocation or kernel launch. This still triggers UnboundLocalError and/or invalid kernel arguments.

🐛 Suggested fix pattern (apply in both forward and backward)
     if cu_seqlens is None:
         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+    else:
+        N = len(cu_seqlens) - 1
+        NT = len(chunk_indices)
+        # derive per-sequence chunk offsets from chunk_indices
+        counts = torch.bincount(chunk_indices[:, 0], minlength=N)
+        chunk_offsets = torch.zeros(N, device=k.device, dtype=torch.int32)
+        chunk_offsets[1:] = torch.cumsum(counts, 0)[:-1]

Also applies to: 545-548
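
For reference, a tiny CPU-runnable illustration of the offset derivation suggested above. It assumes chunk_indices is an (num_chunks, 2) integer tensor whose first column is the sequence index each chunk belongs to (as the FLA chunk-index helpers produce); the concrete numbers are made up.

🧪 Hedged illustration
import torch

# seq 0 contributes 2 chunks, seq 1 contributes 3 chunks
chunk_indices = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]])
N = 2

counts = torch.bincount(chunk_indices[:, 0], minlength=N)  # tensor([2, 3])
chunk_offsets = torch.zeros(N, dtype=torch.long)
chunk_offsets[1:] = torch.cumsum(counts, 0)[:-1]
print(chunk_offsets)  # tensor([0, 2]) -> index of each sequence's first chunk in the flattened chunk list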

🧹 Nitpick comments (18)
examples/kda/FLA_KDA/fla_utils.py (4)

33-55: Replace the fullwidth comma in the comment.

Ruff flags the character; use a standard comma to avoid lint churn.

✏️ Lint-friendly comment
-# error check，copy from
+# error check, copy from

125-152: Don’t swallow driver errors silently in get_multiprocessor_count.

Catching Exception and silently passing can hide real driver/API failures and makes debugging harder. Prefer narrower exceptions and/or a warning (the example below assumes warnings is imported at module scope).

🛠️ Example with warning
-    except Exception:
-        pass
+    except Exception as e:
+        warnings.warn(f"Failed to read multiprocessor count (new API): {e}", stacklevel=2)
-    except Exception:
-        pass
+    except Exception as e:
+        warnings.warn(f"Failed to read multiprocessor count (legacy API): {e}", stacklevel=2)

155-186: Guard custom_device_ctx usage to CUDA tensors only.

input_guard currently uses a CUDA device context for any tensor type; if a CPU tensor slips in, this can create confusing failures. Consider checking tensor.is_cuda before entering a CUDA context.

🧭 Safer device-context selection
-        if tensor is not None:
-            ctx = custom_device_ctx(tensor.device.index)
-        else:
-            ctx = contextlib.nullcontext()
+        if tensor is not None and tensor.is_cuda:
+            ctx = custom_device_ctx(tensor.device.index)
+        else:
+            ctx = contextlib.nullcontext()
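
For context, the same guard can be expressed without any repo-specific helpers. The decorator below is a generic sketch of the pattern (it is not the FLA input_guard and skips the contiguity handling the real helper may do): it only enters a CUDA device context when the first tensor argument actually lives on a GPU.

🧪 Hedged sketch of the guard pattern
import contextlib
import functools
import torch

def input_guard(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        tensor = next((a for a in args if isinstance(a, torch.Tensor)), None)
        # Only switch CUDA devices when the tensor is actually on a GPU.
        if tensor is not None and tensor.is_cuda:
            ctx = torch.cuda.device(tensor.device.index)
        else:
            ctx = contextlib.nullcontext()
        with ctx:
            return fn(*args, **kwargs)
    return wrapper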

224-240: Handle shared-memory query failures more explicitly.

BaseException/Exception with silent fallbacks can mask real issues (driver misconfig, bad indices). Prefer narrower exceptions and a warning.

🛠️ Narrow exceptions + warning
-    except BaseException:
-        return [-1]
+    except (RuntimeError, AttributeError, KeyError, IndexError) as e:
+        warnings.warn(f"Failed to query max shared mem: {e}", stacklevel=2)
+        return [-1]
-    except Exception:
-        return False
+    except (RuntimeError, AttributeError, KeyError, IndexError) as e:
+        warnings.warn(f"Failed to check shared mem: {e}", stacklevel=2)
+        return False
examples/kda/test_utils_kda.py (2)

43-84: Remove or use the atol and rtol parameters.

The compare_tensors function accepts atol and rtol parameters but never uses them. Either implement tolerance-based comparison or remove these parameters to avoid misleading callers.

♻️ Proposed fix to use the tolerance parameters
 def compare_tensors(name, x, y, atol=1e-5, rtol=1e-5):
     import numpy as np
     import torch

     diff = (x - y).abs()
+    
+    # Check if tensors are close within tolerance
+    is_close = torch.allclose(x, y, atol=atol, rtol=rtol)
+    print(f"Tensors close (atol={atol}, rtol={rtol}): {is_close}")

     # ========= Max Absolute Error =========
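
A fully tolerance-aware variant could look like the sketch below; it is self-contained (torch only) and illustrative rather than a drop-in replacement for the existing helper, which also reports several other statistics.

🧪 Hedged sketch
import torch

def compare_tensors(name, x, y, atol=1e-5, rtol=1e-5):
    diff = (x - y).abs()
    max_abs = diff.max().item()
    ok = torch.allclose(x, y, atol=atol, rtol=rtol)
    print(f"[{name}] max_abs_err={max_abs:.3e} allclose(atol={atol}, rtol={rtol})={ok}")
    return ok

# e.g. compare_tensors("dv", dv_tilelang.float(), dv_ref.float(), atol=2e-3, rtol=2e-3)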

107-108: Remove debug print(times) statement.

This debug statement prints the raw timing tensor before returning the mean. Consider removing it or gating it behind a verbose parameter to reduce noise in production benchmarks.

♻️ Proposed fix
-    print(times)
     return times.mean().item()
examples/kda/FLA_KDA/fla_chunk_intra_token_parallel.py (1)

70-71: Translate Chinese comments for consistency.

Consider translating these comments to English for broader accessibility:

  • "chunk 首坐标" → "chunk start coordinate"
  • "subchunk 首坐标" → "subchunk start coordinate"
♻️ Proposed fix
-    i_tc = i_c * BT  # chunk 首坐标
-    i_ts = i_tc + i_s * BC  # subchunk 首坐标
+    i_tc = i_c * BT  # chunk start coordinate
+    i_ts = i_tc + i_s * BC  # subchunk start coordinate
examples/kda/chunk_bwd_dv.py (3)

4-4: Remove unused sys import.

The sys module is imported but never used. The noqa: F401 directive suppresses the warning but doesn't address the underlying issue.

♻️ Proposed fix
-import sys  # noqa: F401

114-125: Dead code: prepare_output result is immediately overwritten.

The output tensor dv_tilelang allocated at line 114 is never used because the kernel at line 125 returns a new tensor that overwrites it. Either remove the prepare_output call or use it for in-place kernel execution.

♻️ Proposed fix
-    dv_tilelang = prepare_output(B, S, H, DV, chunk_size, getattr(torch, output_dtype))
     kernel = tilelang_chunk_bwd_kernel_dv_local(
         B=B,
         S=S,
         H=H,
         DV=DV,
         input_dtype=input_dtype,
         output_dtype=output_dtype,
         do_dtype=do_dtype,
         chunk_size=chunk_size,
     )
     dv_tilelang = kernel(DO, A)

31-40: Remove unused chunk_size parameter from prepare_output.

The chunk_size parameter is accepted but never used in the function body. Remove it to clean up the API.

♻️ Proposed fix
 def prepare_output(
     B,
     S,
     H,
     DV,
-    chunk_size,
     output_dtype,
 ):
     dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
     return dv
examples/kda/chunk_delta_bwd.py (1)

219-222: Unused parameters in run_test function.

The parameters block_DV, threads, num_stages, and use_torch are accepted but never used. These appear to be remnants of manual tuning that's now handled by autotune. Consider removing them or passing them to override autotuning.
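
If the intent is to keep them as optional manual overrides, one lightweight option is to drop unset knobs before calling the kernel factory so autotuning still applies by default. The helper below is a generic sketch; the idea of forwarding it as keyword arguments assumes the factory accepts such kwargs, which is not guaranteed by this file.

🧪 Hedged sketch
def select_overrides(**candidates):
    # Keep only the tuning knobs the caller actually set.
    return {k: v for k, v in candidates.items() if v is not None}

overrides = select_overrides(block_DV=None, threads=256, num_stages=None)
# -> {"threads": 256}; forward as **overrides only if the factory accepts these kwargs.

Otherwise, simply deleting the unused parameters keeps the run_test signature honest.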

examples/kda/chunk_bwd_gla_dA.py (2)

89-89: Translate Chinese comment for consistency.

Consider translating "下三角矩阵" to "lower triangular matrix" for broader accessibility.

♻️ Proposed fix
-                dA_fragment[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, dA_fragment[i_s1, i_s2] * scale, 0.0)  # 下三角矩阵
+                dA_fragment[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, dA_fragment[i_s1, i_s2] * scale, 0.0)  # lower triangular matrix

107-108: Remove debug print statement.

This debug print for dtypes should be removed or gated behind a verbose flag.

♻️ Proposed fix
     DO, V_new = prepare_input(B, S, H, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, do_dtype))
-    print(DO.dtype, V_new.dtype)
     dA_ref = chunk_gla_bwd_dA(v=V_new, do=DO, scale=scale)
examples/kda/FLA_KDA/fla_chunk_inter.py (2)

151-154: Use explicit Optional type annotation for nullable parameters.

Per PEP 484, scale: float = None should be scale: Optional[float] = None (or scale: float | None = None in Python 3.10+) to explicitly indicate the parameter can be None.

♻️ Proposed fix
+from typing import Optional
+
 def chunk_kda_bwd_dqkwg(
     q: torch.Tensor,
     k: torch.Tensor,
     w: torch.Tensor,
     v: torch.Tensor,
     h: torch.Tensor,
     g: torch.Tensor,
     do: torch.Tensor,
     dh: torch.Tensor,
     dv: torch.Tensor,
-    scale: float = None,
-    cu_seqlens: torch.LongTensor = None,
+    scale: Optional[float] = None,
+    cu_seqlens: Optional[torch.LongTensor] = None,
     chunk_size: int = 64,
-    chunk_indices: torch.LongTensor = None,
+    chunk_indices: Optional[torch.LongTensor] = None,
 ):

163-166: Intentional dtype difference for gradient outputs.

The code allocates dq and dk with dtype=torch.float while dw and dg use empty_like (inheriting the input dtype). This appears intentional for numerical stability in gradient accumulation, but consider adding a comment to clarify the rationale.

♻️ Proposed documentation
+    # dq, dk need float32 for gradient accumulation stability
     dq = torch.empty_like(q, dtype=torch.float)
     dk = torch.empty_like(k, dtype=torch.float)
     dw = torch.empty_like(w)
     dg = torch.empty_like(g)
examples/kda/wy_fast_bwd.py (1)

241-257: Kernel config args are accepted but never used.

block_DK, block_DV, threads, and num_stages are passed into run_test but not forwarded to the kernel, so caller intent is ignored.

♻️ Suggested fix
     kernel = tilelang_wy_fast_bwd(
         B,
         S,
         H,
         DK,
         DV,
         input_dtype,
         output_dtype,
         accum_dtype,
         gate_dtype,
         state_dtype,
         chunk_size,
+        block_DK=block_DK,
+        block_DV=block_DV,
+        threads=threads,
+        num_stages=num_stages,
     )
examples/kda/chunk_o.py (1)

169-213: Forward kernel config params to tilelang_chunk_fwd_o.

block_DK, block_DV, threads, and num_stages are accepted but never used, so the benchmark ignores caller-provided configs.

♻️ Suggested fix
     kernel = tilelang_chunk_fwd_o(
         B,
         S,
         H,
         DK,
         DV,
         input_dtype,
         output_dtype,
         accum_dtype,
         gate_dtype,
         chunk_size,
         scale,
         block_S,
+        block_DK=block_DK,
+        block_DV=block_DV,
+        threads=threads,
+        num_stages=num_stages,
     )
examples/kda/chunk_bwd_intra.py (1)

381-438: Forward kernel config args to tilelang_chunk_bwd_intra.

threads and num_stages are accepted but unused, so the test ignores the caller’s tuning choices.

♻️ Suggested fix
     kernel = tilelang_chunk_bwd_intra(
         B=B,
         S=S,
         H=H,
         DK=DK,
         input_dtype=input_dtype,
         output_dtype=output_dtype,
         accum_dtype=accum_dtype,
         gate_dtype=gate_dtype,
         state_dtype=state_dtype,
         chunk_size=chunk_size,
         block_DK=block_DK,
+        threads=threads,
+        num_stages=num_stages,
     )

Comment on examples/kda/chunk_bwd_dv.py, line 137

def main():
    run_test(
        B=1,
        S=1024 * 8,  # 32768

⚠️ Potential issue | 🟡 Minor

Incorrect comment: 1024 × 8 = 8192, not 32768.

The inline comment states # 32768 but 1024 * 8 = 8192. Either fix the comment or update the value if 32768 was intended.

🐛 Proposed fix (if 8192 is intended)
-        S=1024 * 8,  # 32768
+        S=1024 * 8,  # 8192
🐛 Proposed fix (if 32768 is intended)
-        S=1024 * 8,  # 32768
+        S=1024 * 32,  # 32768

Comment on examples/kda/chunk_bwd_gla_dA.py, line 134

def main():
    run_test(
        B=1,
        S=1024 * 8,  # 32768

⚠️ Potential issue | 🟡 Minor

Incorrect comment: 1024 × 8 = 8192, not 32768.

Same issue as in chunk_bwd_dv.py: the comment states # 32768 but the expression evaluates to 8192.

🐛 Proposed fix
-        S=1024 * 8,  # 32768
+        S=1024 * 8,  # 8192

Comment on lines +524 to +535 in examples/kda/FLA_KDA/fla_chunk_delta.py

def chunk_gated_delta_rule_bwd_dhu(
    q: torch.Tensor,
    k: torch.Tensor,
    w: torch.Tensor,
    do: torch.Tensor,
    dv: torch.Tensor,
    g: torch.Tensor = None,
    gk: torch.Tensor = None,
    h0: torch.Tensor = None,
    dht: torch.Tensor = None,
    scale: float = None,
    cu_seqlens: torch.LongTensor = None,

⚠️ Potential issue | 🟡 Minor


Use Optional[float] for scale's default None (RUF013).

The scale parameter has a default value of None but is typed as float. Update the type annotation to Optional[float] and add the import from typing.

Proposed fix
-import torch
+import torch
+from typing import Optional
@@
-    scale: float = None,
+    scale: Optional[float] = None,
🧰 Tools
🪛 Ruff (0.14.13)

534-534: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


Comment on lines +536 to +543 in examples/kda/FLA_KDA/fla_chunk_delta.py

    chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
    chunk_indices: torch.LongTensor = None,
    use_exp2: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *q.shape, do.shape[-1]
    # N: the actual number of sequences in the batch with either equal or variable lengths
    BT = 64
    assert K <= 256, "current kernel does not support head dimension being larger than 256."

⚠️ Potential issue | 🟠 Major

Backward chunk size mismatch (BT hard-coded to 64).

chunk_size is still accepted (and used for chunk_indices), but BT is forced to 64. If callers pass a non-64 chunk_size, chunking and offsets will diverge from the kernel’s BT, producing incorrect results. Either enforce 64 explicitly or wire BT to chunk_size.

🔧 Option: enforce 64 to avoid silent mismatch
-    BT = 64
+    assert chunk_size == 64, "chunk_size must be 64 for blockdim64 kernels"
+    BT = chunk_size

Comment on lines +114 to +125 in examples/kda/FLA_KDA/fla_chunk_intra_token_parallel.py

def chunk_kda_fwd_intra_token_parallel(
    q: torch.Tensor,
    k: torch.Tensor,
    gk: torch.Tensor,
    beta: torch.Tensor,
    Aqk: torch.Tensor,
    Akk: torch.Tensor,
    scale: float,
    cu_seqlens: torch.LongTensor = None,
    chunk_size: int = 64,
    sub_chunk_size: int = 16,
) -> None:

⚠️ Potential issue | 🟡 Minor

Return type annotation contradicts actual behavior.

The function signature declares -> None but the function returns Aqk, Akk at line 168. This inconsistency can mislead callers and type checkers.

🐛 Proposed fix
-) -> None:
+) -> tuple[torch.Tensor, torch.Tensor]:

Comment on lines +10 to +16 in examples/kda/FLA_KDA/fla_chunk_intra.py

IS_TF32_SUPPORTED = False
if IS_TF32_SUPPORTED:
    SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32x3")
else:
    SOLVE_TRIL_DOT_PRECISION = tl.constexpr("ieee")
SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32")
# ============================================================================

⚠️ Potential issue | 🟠 Major


Remove the unconditional precision override.

SOLVE_TRIL_DOT_PRECISION is set based on IS_TF32_SUPPORTED on lines 11-14, then immediately overwritten to "tf32" on line 15, which defeats the conditional and renders the flag meaningless. The precision is used extensively in downstream tl.dot operations (lines 278+), so the effective behavior is always "tf32" regardless of the IS_TF32_SUPPORTED check.

Either remove the conditional logic if "tf32" is the intended hardcoded value, or remove line 15 to respect the conditional logic:

Suggested fix
 IS_TF32_SUPPORTED = False
-if IS_TF32_SUPPORTED:
-    SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32x3")
-else:
-    SOLVE_TRIL_DOT_PRECISION = tl.constexpr("ieee")
-SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32")
+SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32x3") if IS_TF32_SUPPORTED else tl.constexpr("ieee")

or simply:

 IS_TF32_SUPPORTED = False
-if IS_TF32_SUPPORTED:
-    SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32x3")
-else:
-    SOLVE_TRIL_DOT_PRECISION = tl.constexpr("ieee")
 SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32")
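
If the conditional is kept, the flag could also be derived from the device rather than hardcoded; the snippet below is a hedged sketch of that idea (TF32 tensor-core paths exist on compute capability 8.0 and newer), not a change the kernel requires.

🧪 Hedged sketch
import torch

def tf32_supported() -> bool:
    # TF32 tensor cores are available on Ampere (SM80) and newer GPUs.
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability(0)
    return major >= 8

IS_TF32_SUPPORTED = tf32_supported()
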

Comment on lines +16 to +25 in examples/kda/FLA_KDA/fla_utils.py
device = "cuda"
device_torch_lib = getattr(torch, device)

exp = tldevice.fast_expf
exp2 = tldevice.exp2
log = tldevice.fast_logf
log2 = tldevice.fast_log2f

IS_NVIDIA_HOPPER = True and ("NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9)
USE_CUDA_GRAPH = True and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"

⚠️ Potential issue | 🟠 Major


Guard CUDA device queries at import time.

Line 24 calls CUDA APIs unconditionally at module import, which will raise RuntimeError on CPU-only hosts or when no GPU is visible, breaking imports entirely. Defer evaluation or guard with torch.cuda.is_available() and a safe fallback.

🔧 Suggested guard for safe import
-IS_NVIDIA_HOPPER = True and ("NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9)
+def _is_nvidia_hopper() -> bool:
+    if not torch.cuda.is_available():
+        return False
+    try:
+        name = torch.cuda.get_device_name(0)
+        major, _ = torch.cuda.get_device_capability(0)
+        return ("NVIDIA H" in name) or (major >= 9)
+    except Exception:
+        return False
+
+IS_NVIDIA_HOPPER = _is_nvidia_hopper()

Comment on lines +55 to +60 in examples/kda/wy_fast_bwd.py

    dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
    dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
    dbeta = torch.empty(B, S, H, dtype=output_dtype).cuda()
    dg = torch.empty(B, S, H, DK, dtype=gate_dtype).cuda()
    dA = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
    return dk, dv, dbeta, dg, dA

⚠️ Potential issue | 🟡 Minor

Fix dA output shape to match kernel output.

dA is produced with shape (B, S, H, chunk_size) in the kernel, but prepare_output allocates (B, S, H, DK). This will misalign if the preallocated output is ever used.

🐛 Suggested fix
-    dA = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
+    dA = torch.empty(B, S, H, chunk_size, dtype=output_dtype).cuda()

Member
@LeiWang1999 LeiWang1999 left a comment

LGTM, and I refactored some of the code; I think we can let this commit in now.

@LeiWang1999 LeiWang1999 merged commit 5748841 into tile-ai:main Jan 26, 2026
2 checks passed