[Example] Add KDA algorithm implementation in tilelang #1660
LeiWang1999 merged 23 commits into tile-ai:main
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the pre-commit checks before pushing. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

Introduces comprehensive KDA (Key-Delta-Attention) GPU kernel implementations using Triton and TileLang, spanning cumulative sum, gated delta rules, inter-intra attention, and output computation operations. Includes utilities, examples, and reference implementations. Updates pre-commit configuration and removes a public method from AtomicAddNode.

Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs
Suggested labels
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
Actionable comments posted: 15
Note
Due to the large number of review comments, Critical severity comments were prioritized as inline comments.
🤖 Fix all issues with AI agents
In @examples/KDA/chunk_bwd_dqkwg.py:
- Around line 89-132: In kernel, G_shared and Gn_shared are allocated with
input_dtype but should use gate_dtype; change the dtype argument for the
alloc_shared calls for G_shared and Gn_shared to gate_dtype so gate values keep
their intended precision (update the lines allocating G_shared and Gn_shared
inside the kernel).
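A minimal sketch of the intended change; the buffer shapes below are placeholders, only the dtype argument is the point:

```python
# Sketch (shapes are illustrative): allocate the gate buffers with gate_dtype,
# not input_dtype, so gate values keep their intended precision.
G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype)   # was dtype=input_dtype
Gn_shared = T.alloc_shared((block_DK,), dtype=gate_dtype)          # was dtype=input_dtype
```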
In @examples/KDA/chunk_delta_h_fwd.py:
- Around line 168-200: The loop uses T.ceildiv(S, block_S) but h was allocated
using BS = S // block_S causing OOB when S % block_S != 0; update the tensor
allocation symbol h_shape to use T.ceildiv(S, block_S) (i.e. h_shape = (B,
T.ceildiv(S, block_S), H, DK, DV)) so its second dimension matches the loop
bound, and apply the same fix for dh_shape in chunk_delta_bwd.py (replace BS
with T.ceildiv(S, block_S)); locate and change the allocations that reference BS
to use T.ceildiv(S, block_S) instead.
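A small, self-contained illustration of the mismatch; `T.ceildiv(S, block_S)` in the kernel corresponds to the ceiling-division expression below:

```python
# Why floor division under-allocates the chunk dimension when S % block_S != 0.
S, block_S = 100, 64
floor_chunks = S // block_S                   # 1 -> h gets one chunk too few
ceil_chunks = (S + block_S - 1) // block_S    # 2 -> matches the loop bound T.ceildiv(S, block_S)
print(floor_chunks, ceil_chunks)

# Suggested allocation shape in the kernel (sketch):
# h_shape = (B, T.ceildiv(S, block_S), H, DK, DV)
```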
In @examples/KDA/chunk_o.py:
- Around line 125-133: The inner Parallel loop mistakenly iterates over block_DV
while indexing Q_shared/GK_shared/GQ_shared which have shape [block_S,
block_DK]; change the loop in the block that assigns Q_shared/GQ_shared from
"for i_s, i_v in T.Parallel(block_S, block_DV):" to iterate over block_DK (e.g.
"for i_s, i_k in T.Parallel(block_S, block_DK):") and update the index names
used when reading/writing Q_shared, GK_shared and GQ_shared so the second index
uses the block_DK iterator; leave HIDDEN_shared and the subsequent T.gemm call
unchanged.
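A sketch of the corrected loop header; the body is left as a placeholder since only the iterator range and the second index name change:

```python
# Before: for i_s, i_v in T.Parallel(block_S, block_DV):   # iterator does not match the buffers
# After: the iterator range matches the [block_S, block_DK] buffers.
for i_s, i_k in T.Parallel(block_S, block_DK):
    ...  # original Q_shared/GQ_shared update, with the second index renamed from i_v to i_k
```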
In @examples/KDA/FLA_KDA/fla_chunk_delta.py:
- Around line 538-541: In the backward function the variable chunk_offsets can
remain undefined when cu_seqlens is not None; update the logic so that when
chunk_indices is None and cu_seqlens is not None you also set chunk_offsets
(mirror forward): call prepare_chunk_indices(cu_seqlens, chunk_size) to produce
chunk_indices and derive chunk_offsets (or compute chunk_offsets from the same
cu_seqlens/chunk_size helper) before proceeding; modify the branch around
chunk_indices, cu_seqlens and the subsequent use of chunk_offsets to ensure
chunk_offsets is always initialized (refer to chunk_indices, cu_seqlens,
prepare_chunk_indices, and chunk_offsets).
- Around line 483-491: The variable chunk_offsets is left undefined when
cu_seqlens is provided, causing an UnboundLocalError later when passed to the
kernel; fix by ensuring chunk_offsets is initialized in the branch that handles
cu_seqlens (e.g., after computing chunk_indices via
prepare_chunk_indices(cu_seqlens, chunk_size) compute or derive chunk_offsets
from chunk_indices, or explicitly set chunk_offsets = None if the kernel accepts
that), so that the subsequent kernel invocation that uses chunk_offsets, h, and
final_state always has a defined value.
In @examples/KDA/FLA_KDA/fla_chunk_intra.py:
- Around line 607-618: Add the two missing imports for the functions used in
this file: import chunk_kda_fwd_intra_token_parallel from
.fla_chunk_intra_token_parallel and import recompute_w_u_fwd from .fla_wy_fast;
place these import statements near the other module imports at the top of the
file so chunk_kda_fwd_intra_token_parallel (used around the Aqk, Akk_diag call)
and recompute_w_u_fwd (used later) are defined and available at runtime.
- Around line 638-647: The call to recompute_w_u_fwd in the return path will
raise NameError because the function is not imported or defined; add a proper
import for recompute_w_u_fwd at the top of the module (for example: from
.fla_wy_fast import recompute_w_u_fwd or the actual module where
recompute_w_u_fwd is defined) so the symbol is available to the function using
it, or if the implementation belongs in this file, paste the function definition
above its usage; ensure tests/imports run to verify the NameError is resolved.
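Both fla_chunk_intra.py items reduce to two imports at the top of the module; a sketch, assuming the helpers live in the sibling modules named above:

```python
# Sketch: add alongside the other module-level imports in fla_chunk_intra.py.
from .fla_chunk_intra_token_parallel import chunk_kda_fwd_intra_token_parallel
from .fla_wy_fast import recompute_w_u_fwd
```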
In @examples/KDA/FLA_KDA/fla_chunk_o.py:
- Around line 547-552: The conditional uses the function object check_shared_mem
instead of calling it, causing the middle branch to always be truthy; change the
second conditional to call check_shared_mem with the same arguments used in the
first branch (e.g., check_shared_mem('hopper', k.device.index)) and ensure
CONST_TILING is set to 64 in that branch so the three branches correctly use
128, 64, and 32 respectively.
- Around line 513-521: The branch guard for USE_A leaves b_A uninitialized when
USE_A is False; ensure b_A is always defined before it is used by either (a)
initializing b_A in the else case to a zero tensor with the same shape/dtype
expected by the later tl.where (matching the shape produced by p_A/tl.load) or
(b) move the tl.where logic inside the USE_A block and otherwise set b_A =
zeros(...). Reference symbols: USE_A, p_A, b_A, tl.load, tl.where, o_t, m_A, and
do.dtype.element_ty to create the zero tensor with correct shape and element
type so tl.where never sees an undefined b_A.
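A sketch covering both fla_chunk_o.py items; the 'ampere' tier string and the [BT, BT] zero-tile shape are assumptions, and p_A/m_A refer to the existing kernel locals:

```python
# Host side: actually call the helper in every branch of the tiling selection.
if check_shared_mem('hopper', k.device.index):
    CONST_TILING = 128
elif check_shared_mem('ampere', k.device.index):   # assumption: middle tier is 'ampere'
    CONST_TILING = 64
else:
    CONST_TILING = 32

# Kernel side: give b_A a defined value on both sides of the USE_A branch.
if USE_A:
    b_A = tl.load(p_A, boundary_check=(0, 1))              # p_A set up as in the existing block
else:
    b_A = tl.zeros([BT, BT], dtype=do.dtype.element_ty)    # assumption: matches the loaded tile shape
```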
In @examples/KDA/test_utils.py:
- Around line 5-40: calc_sim and assert_similar use .data and perform Python
control flow on torch tensors which can raise on CUDA; change calc_sim to avoid
.data (use .detach().double() or .to(torch.double)) and return a Python float
(use sim.item()) instead of a tensor, ensure the zero-denominator check uses a
scalar (.item() or .eq(0).all().item()); in assert_similar stop doing chained
tensor comparisons by converting sim to a float before computing diff (e.g., sim
= calc_sim(...); diff = 1.0 - float(sim)), keep the isfinite/masking logic but
ensure masked_fill uses the correct masks (x.masked_fill(~x_mask, 0) /
y.masked_fill(~y_mask, 0) as already used) and replace any remaining .data
usages with .detach()/.to(...)
In @examples/KDA/wy_fast.py:
- Around line 204-205: The assignment "use_qg = False," creates a one-element
tuple instead of a boolean; remove the trailing comma so use_qg is assigned the
boolean False (i.e., change use_qg = False, to use_qg = False) where the
kernel/config options are set (the nearby use_kg = True line in wy_fast.py) so
any conditional checks on use_qg behave correctly; verify related conditional
logic that references use_qg expects a bool.
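A self-contained demonstration of the trailing-comma pitfall:

```python
use_qg = False,          # trailing comma -> binds the one-element tuple (False,)
print(type(use_qg))      # <class 'tuple'>
print(bool(use_qg))      # True: any non-empty tuple is truthy, so `if use_qg:` takes the branch

use_qg = False           # without the comma, conditionals behave as intended
print(bool(use_qg))      # False
```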
🟠 Major comments (15)
examples/KDA/chunk_bwd_dqkwg.py-89-94 (1)

89-94: Guard or handle `S % chunk_size != 0` to avoid shape/OOB hazards.

`BS = S // block_S` is used to shape `h`/`Dh`, and the kernel copies `[bs*block_S:(bs+1)*block_S]` without boundary handling. If `S` isn't an exact multiple of `chunk_size`, this will mis-shape `h` and risks invalid memory accesses.

Also applies to: 155-164
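Until boundary handling is added, a cheap guard in the host wrapper makes the constraint explicit; a sketch using the names from the comment above:

```python
# Sketch: fail fast instead of silently mis-shaping h when S is ragged.
def num_chunks(S: int, chunk_size: int) -> int:
    assert S % chunk_size == 0, (
        f"S={S} must be a multiple of chunk_size={chunk_size} "
        "until the kernel handles partial chunks")
    return S // chunk_size

print(num_chunks(8192, 64))  # 128
```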
examples/KDA/chunk_bwd_gla_dA.py-12-16 (1)

12-16: Avoid import-time `CUDA_VISIBLE_DEVICES` mutation (same concern as other scripts).

examples/KDA/chunk_o.py-183-228 (1)

183-228: `run_test()` kernel config parameters are ignored (misleading API).

`run_test()` accepts `block_DK`/`block_DV`/`threads`/`num_stages`, but the `tilelang_chunk_fwd_o(...)` call only forwards `block_S`. Either forward the rest (to allow fixed-config runs) or remove them from `run_test()` to match actual behavior.

examples/KDA/chunk_bwd_dqkwg.py-12-16 (1)

12-16: Avoid import-time `CUDA_VISIBLE_DEVICES` mutation.

Setting `os.environ['CUDA_VISIBLE_DEVICES'] = '7'` inside the module is a global side effect and will surprise users (and can break CI / multi-GPU boxes). Prefer documenting it in the README or making it a CLI flag / env var read (no write).
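A sketch of the read-only alternative; the KDA_DEVICE env var and --device flag are hypothetical names:

```python
# Sketch: select a GPU without mutating CUDA_VISIBLE_DEVICES for the whole process.
import argparse
import os

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--device", type=int,
                    default=int(os.environ.get("KDA_DEVICE", "0")))  # hypothetical env var
args = parser.parse_args()
torch.cuda.set_device(args.device)
```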
examples/KDA/chunk_bwd_dv.py-12-16 (1)

12-16: Avoid import-time `CUDA_VISIBLE_DEVICES` mutation (same issue as other scripts).

examples/KDA/chunk_bwd_gla_dA.py-78-102 (1)

78-102: Fix `V_shared` dtype: should not be `do_dtype`.

`V` is declared as `dtype=input_dtype` but is copied into `V_shared` allocated as `do_dtype`. If these differ, you'll get an implicit cast that can change numerics (and it's surprising).

Proposed fix

```diff
- V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
+ V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
```

examples/KDA/FLA_KDA/cumsum.py-41-56 (1)
41-56: Fix HEAD_FIRST + IS_VARLEN pointer arithmetic in cumsum kernels.

When `IS_VARLEN=True`, you set `T = eos - bos` (per-sequence length), but the `HEAD_FIRST=True` base pointer arithmetic uses `bos*H + i_h*T`, which is incorrect for the `[B, H, T_global]` memory layout. The head stride should be `T_global` (the full sequence length passed to the kernel), not the per-sequence span.

For example, with shape `[1, 8, 1024]`, sequence starting at `bos=100`, `eos=150`, and `i_h=2`:

- Correct offset: `i_h * T_global + bos = 2 * 1024 + 100`
- Current code: `bos*H + i_h*T = 100*8 + 2*50 = 900` ❌

This affects all four cumsum kernels (scalar/vector × local/global). Since `chunk_local_cumsum` enforces `B=1` when `cu_seqlens` is provided, this combination is supported but untested. Update the offset calculation to use `T` (the kernel parameter) instead of the per-sequence `T`, or recalculate the head stride explicitly within the varlen branch.

Also applies to: lines 156-172, 216-234, and chunk_global_cumsum_vector_kernel.
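The arithmetic from the worked example above, written out in plain Python (T_seq is the per-sequence span, T_global the full length passed to the kernel):

```python
# Layout [B, H, T_global] with B=1, H=8, T_global=1024; sequence at [100, 150), head 2.
T_global, H = 1024, 8
bos, eos, i_h = 100, 150, 2
T_seq = eos - bos                   # 50

current = bos * H + i_h * T_seq     # 900  <- what the HEAD_FIRST varlen path computes today
correct = i_h * T_global + bos      # 2148 <- flat offset of (head 2, position 100)
print(current, correct)
```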
examples/KDA/chunk_delta_bwd.py-268-306 (1)

268-306: Enable result verification to validate correctness.

The reference implementation results (`dh_ref`, `dh0_ref`, `dv2_ref`) are computed but the comparison code is commented out (Lines 304-306). This means the TileLang kernel's correctness isn't being validated.

Suggested fix

```diff
- # compare_tensors("dh", dh_ref, dh_tilelang)
- # compare_tensors("dh0", dh0_ref, dh0_tilelang)
- # compare_tensors("dv2", dv2_ref, dv2_tilelang)
+ compare_tensors("dh", dh_ref, dh_tilelang)
+ compare_tensors("dh0", dh0_ref, dh0_tilelang)
+ compare_tensors("dv2", dv2_ref, dv2_tilelang)
```

examples/KDA/chunk_intra_token_parallel.py-276-286 (1)
276-286: Enable result verification to validate correctness.

The comparison code is commented out, so the TileLang results aren't being validated against the FLA reference. Given the PR description mentions "precision alignment testing completed," this verification should be enabled.

Suggested fix

```diff
  print(f"fla time: {fla_time} ms")
  print(f"tilelang time: {tilelang_time} ms")
-
- # compare_tensors("Aqk", Aqk_ref, Aqk_tilelang)
- # compare_tensors("Akk", Akk_ref, Akk_tilelang)
+ compare_tensors("Aqk", Aqk_ref, Aqk_tilelang)
+ compare_tensors("Akk", Akk_ref, Akk_tilelang)
```

examples/KDA/chunk_inter_solve_fused.py-12-13 (1)
12-13: Remove hardcoded CUDA device selection.

Setting `CUDA_VISIBLE_DEVICES` in code is problematic as it affects the entire process and prevents users from choosing their preferred GPU. This should be configurable or removed.

```diff
 import os
-os.environ['CUDA_VISIBLE_DEVICES'] = '7'
```

examples/KDA/FLA_KDA/fla_wy_fast.py-270-277 (1)
270-277: Hardcoded `BT=64` may not match input tensor dimensions.

The forward function `recompute_w_u_fwd` derives `BT` from `A.shape[-1]`, but the backward function hardcodes `BT=64`. If `A` was created with a different chunk size, this will cause incorrect behavior.

🔧 Suggested fix

```diff
 def prepare_wy_repr_bwd(
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     gk: torch.Tensor,
     A: torch.Tensor,
     dk: torch.Tensor,
     dw: torch.Tensor,
     du: torch.Tensor,
     dg: torch.Tensor,
     cu_seqlens: torch.LongTensor = None,
     chunk_indices: torch.LongTensor = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     B, T, H, K, V = *k.shape, v.shape[-1]
-    BT = 64
+    BT = A.shape[-1]
     if chunk_indices is None and cu_seqlens is not None:
         chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
```

examples/KDA/FLA_KDA/fla_utils.py-154-156 (1)
154-156: Generator is consumed before function execution.

`contiguous_args` is a generator expression that will be exhausted after the first iteration. When passed to `fn`, it may result in empty arguments if any code path iterates over it before the function call.

🔧 Suggested fix - use tuple instead of generator

```diff
     @functools.wraps(fn)
     def wrapper(*args, **kwargs):
-        contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
+        contiguous_args = tuple(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
         contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}
```

examples/KDA/wy_fast_bwd.py-344-348 (1)
344-348: Correctness verification is disabled.

The `compare_tensors` calls are commented out, meaning the test only benchmarks without verifying correctness. This should be enabled to ensure the TileLang kernel produces correct results.

```diff
- # compare_tensors("dA", dA_tilelang, dA_ref)
- # compare_tensors("dk", dk_tilelang, dk_ref)
- # compare_tensors("dv", dv_tilelang, dv_ref)
- # compare_tensors("dbeta", dbeta_tilelang, dbeta_ref)
- # compare_tensors("dg", dg_tilelang, dg_ref)
+ compare_tensors("dA", dA_tilelang, dA_ref)
+ compare_tensors("dk", dk_tilelang, dk_ref)
+ compare_tensors("dv", dv_tilelang, dv_ref)
+ compare_tensors("dbeta", dbeta_tilelang, dbeta_ref)
+ compare_tensors("dg", dg_tilelang, dg_ref)
```

examples/KDA/FLA_KDA/fla_utils.py-25-26 (1)
25-26: Device check at import time may fail without GPU.

`IS_NVIDIA_HOPPER` accesses `torch.cuda.get_device_name(0)` at module import time, which will raise an error if no CUDA device is available. This prevents the module from being imported for testing or on CPU-only systems.

🔧 Suggested fix

```diff
-IS_NVIDIA_HOPPER = (True and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
+def _is_nvidia_hopper():
+    try:
+        return 'NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9
+    except Exception:
+        return False
+
+IS_NVIDIA_HOPPER = _is_nvidia_hopper()
```

examples/KDA/wy_fast_bwd.py-56-61 (1)
56-61: Fix `dA` shape to use `chunk_size` instead of `DK`.

The `dA` tensor is allocated with shape `(B, S, H, DK)` but should match the shape of `A`, which is `(B, S, H, chunk_size)` (see line 36 where `A = torch.randn(B, S, H, BS, ...)` and `BS = chunk_size`). Since `DK` and `chunk_size` are independent parameters, this shape mismatch will cause incorrect results. Use `chunk_size` for the final dimension instead:

Suggested fix

```python
dA = torch.empty(B, S, H, chunk_size, dtype=output_dtype).cuda()
```
🟡 Minor comments (9)
examples/KDA/chunk_bwd_dv.py-127-167 (1)

127-167: Drop unused `scale` from `run_test()` (or apply it) to avoid misleading results.

Ruff flags `scale` as unused here; either remove it from this dv benchmark harness or use it in the reference/kernel calls if it's mathematically required.

examples/KDA/test_utils.py-42-82 (1)
42-82: Either use `atol`/`rtol` or remove them from `compare_tensors()`.

Ruff flags `atol`/`rtol` as unused; consider adding an explicit pass/fail check (e.g., `torch.testing.assert_close`/`torch.allclose`) and printing the result.
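A sketch of what using the tolerances could look like; the signature mirrors the existing compare_tensors arguments flagged by Ruff, and torch.testing.assert_close does the actual check:

```python
import torch

def compare_tensors(name, ref, out, atol=1e-3, rtol=1e-3):
    # Sketch: turn the currently-unused tolerances into an explicit pass/fail check.
    try:
        torch.testing.assert_close(out, ref, atol=atol, rtol=rtol)
        print(f"{name}: PASSED (atol={atol}, rtol={rtol})")
    except AssertionError as exc:
        print(f"{name}: FAILED\n{exc}")
```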
examples/KDA/wy_fast.py-9-9 (1)

9-9: Remove hardcoded CUDA device selection.

Same issue as in other files - hardcoding `CUDA_VISIBLE_DEVICES = '7'` breaks portability.

examples/KDA/chunk_delta_bwd.py-8-8 (1)
8-8: Remove debug print statement.

The `print(tilelang.__file__, flush=True)` appears to be debug code that should be removed before merging.

Suggested fix

```diff
-print(tilelang.__file__, flush=True)
```

examples/KDA/chunk_delta_h_fwd.py-308-318 (1)
308-318: Inconsistent parameter name in benchmark call.

The benchmark call at Line 313 uses `g=G`, but the reference call at Line 278 uses `gk=G`. This inconsistency may cause the benchmark to fail or use wrong parameters.

Suggested fix

```diff
 fla_time = do_bench(
     chunk_gated_delta_rule_fwd_h,
     k=K,
     w=W,
     u=U,
-    g=G,
+    gk=G,
     initial_state=initial_state,
     output_final_state=store_final_state,
     chunk_size=chunk_size,
     save_new_value=save_new_value,
 )
```

examples/KDA/chunk_bwd_intra_op.py-3-4 (1)
3-4: Remove unused imports.

`from re import I` is never used and appears to be an erroneous import. The `sys` import is also flagged but has a noqa directive.

Suggested fix

```diff
-from re import I
-import sys  # noqa: F401
```

examples/KDA/chunk_intra_token_parallel.py-7-7 (1)
7-7: Remove hardcoded CUDA device selection.

Hardcoding `CUDA_VISIBLE_DEVICES = '7'` breaks portability across different systems. This should be configurable or removed entirely.

Suggested fix

```diff
-os.environ['CUDA_VISIBLE_DEVICES'] = '7'
+# Remove or make configurable via environment/CLI argument
```

examples/KDA/chunk_bwd_intra.py-514-529 (1)
514-529: Consider using string dtype identifiers for consistency.

The `main()` function uses `T.float32` directly, but `run_test` uses `getattr(torch, input_dtype)` expecting string dtype names. This inconsistency could cause confusion.

🔧 Suggested fix

```diff
 def main():
     DK = 128
     run_test(
         B=1,
         S=8192,
         H=8,
         DK=DK,
-        input_dtype=T.float32,
-        output_dtype=T.float32,
-        accum_dtype=T.float32,
-        gate_dtype=T.float32,
-        state_dtype=T.float32,
+        input_dtype="float32",
+        output_dtype="float32",
+        accum_dtype="float32",
+        gate_dtype="float32",
+        state_dtype="float32",
         chunk_size=64,
         threads=128,
         num_stages=0,
     )
```

examples/KDA/FLA_KDA/fla_chunk_intra.py-9-14 (1)
9-14: Dead code: conditional precision assignment is overwritten.

Lines 9-13 conditionally set `SOLVE_TRIL_DOT_PRECISION` based on `IS_TF32_SUPPORTED`, but line 14 unconditionally overwrites the result to `'tf32'`. This makes the conditional logic dead code.

Additionally, the comment on line 278 mentions `tf32x3` for precision, but the actual value used is `'tf32'`. Please clarify the intended precision setting.

Suggested fix (remove dead code)

```diff
-IS_TF32_SUPPORTED=False
-if IS_TF32_SUPPORTED:
-    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32x3')
-else:
-    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('ieee')
-SOLVE_TRIL_DOT_PRECISION= tl.constexpr('tf32')
+# Use tf32 precision for matrix inverse operations
+SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32')
```
🧹 Nitpick comments (25)
examples/KDA/README.md (1)
1-7: Add "how to run" + environment details for reproducibility.

This README pins the TileLang/Triton/FLA commit, but doesn't say how to execute the benchmarks/tests (command lines, expected GPU/driver/CUDA/PyTorch versions, and any required env vars). That will make reproduction/debugging harder for reviewers/users.
examples/KDA/chunk_bwd_dqkwg.py (1)
44-58: Clean up unused output prep + unused args in `run_test()` for clarity.

Right now `prepare_output()` is unused and `run_test()` accepts many flags/configs that are ignored (per Ruff). Either wire them into the kernel invocation (to bypass autotune when desired) or remove them from this example to reduce confusion.

Also applies to: 238-310
examples/KDA/chunk_o.py (1)
140-145: Remove/replace debug-commented code in final example.

The commented "why is it wrong" block (and the alternative `T.If` snippet) reads like an unresolved investigation. If this is important, convert it into a short note explaining the TileLang constraint (or link an issue); otherwise, delete it.

examples/KDA/chunk_bwd_gla_dA.py (1)
106-129: Don't print raw timing arrays in `do_bench()` by default.

`print(times)` makes the benchmark noisy and harder to parse; consider gating it behind a `verbose` flag.

examples/KDA/FLA_KDA/cumsum.py (1)
246-323: Public wrappers: consider aligning input guards / constraints consistently.

`chunk_local_cumsum_scalar/vector` aren't `@input_guard`'d and don't enforce the "`cu_seqlens` implies batch==1" constraint (only `chunk_local_cumsum()` does). If these are meant to be public entry points, it's safer to apply the same guard/constraint consistently.

Also applies to: 390-468
examples/KDA/chunk_bwd_dv.py (1)
103-126: Don't print raw timing arrays in `do_bench()` by default.

examples/KDA/FLA_KDA/fla_chunk_inter.py (1)
150-153: Add explicit `Optional` type hints for nullable parameters.

Parameters with `None` defaults should use `Optional[T]` for PEP 484 compliance and better type checking. Based on static analysis hint.

Suggested fix

```diff
+from typing import Optional
+
 def chunk_kda_bwd_dqkwg(
     q: torch.Tensor,
     k: torch.Tensor,
     w: torch.Tensor,
     v: torch.Tensor,
     h: torch.Tensor,
     g: torch.Tensor,
     do: torch.Tensor,
     dh: torch.Tensor,
     dv: torch.Tensor,
-    scale: float = None,
-    cu_seqlens: torch.LongTensor = None,
+    scale: Optional[float] = None,
+    cu_seqlens: Optional[torch.LongTensor] = None,
     chunk_size: int = 64,
-    chunk_indices: torch.LongTensor = None,
+    chunk_indices: Optional[torch.LongTensor] = None,
 ):
```

examples/KDA/chunk_bwd_intra_op.py (1)
514-529: Inconsistent dtype specification in main().

The `main()` function uses `T.float32` directly (TileLang types) instead of string literals like `"float32"` that are converted via `getattr(torch, ...)` in `run_test`. This inconsistency could cause issues if the dtype handling differs.

Suggested fix for consistency

```diff
 def main():
     DK = 128
     run_test(
         B=1,
         S=8192,
         H=8,
         DK=DK,
-        input_dtype=T.float32,
-        output_dtype=T.float32,
-        accum_dtype=T.float32,
-        gate_dtype=T.float32,
-        state_dtype=T.float32,
+        input_dtype="float32",
+        output_dtype="float32",
+        accum_dtype="float32",
+        gate_dtype="float32",
+        state_dtype="float32",
         chunk_size=64,
         threads=128,
         num_stages=0,
     )
```

examples/KDA/wy_fast.py (1)
172-178: Inconsistent synchronization pattern in benchmark loop.

This `do_bench` function synchronizes inside the timing loop (Line 177), while other files in this PR synchronize only before and after the entire timing loop. This inconsistency may affect benchmark comparisons.

Suggested fix for consistency

```diff
     torch.cuda.synchronize()
     for i in range(rep):
         start_event[i].record()
         fn(*args, **kwargs)
         end_event[i].record()
-        torch.cuda.synchronize()
-
+    torch.cuda.synchronize()
```

examples/KDA/FLA_KDA/fla_wy_fast.py (1)
77-77: Consider using English comments for consistency.

The comment `# 乘beta` (meaning "multiply beta") should be in English for better maintainability and team collaboration.

```diff
-    b_kb = b_k * b_b[:, None]  # 乘beta
+    b_kb = b_k * b_b[:, None]  # multiply by beta
```

examples/KDA/FLA_KDA/fla_chunk_delta.py (1)
527-527: Use explicit `Optional` type annotation.

Per PEP 484, implicit `Optional` (using `= None` without `Optional[T]`) is prohibited.

```diff
-    scale: float = None,
+    scale: Optional[float] = None,
```

Don't forget to add `from typing import Optional` at the top of the file.

examples/KDA/FLA_KDA/fla_utils.py (3)
50-54: Add `stacklevel` to `warnings.warn` for proper source attribution.

Without `stacklevel`, the warning will point to this utility function instead of the caller's code.

```diff
-    warnings.warn(msg)
+    warnings.warn(msg, stacklevel=2)
```
34-34: Replace fullwidth comma with ASCII comma.

The comment contains a fullwidth comma (,) which should be a regular comma for consistency.

```diff
-# error check,copy from
+# error check, copy from
```
128-145: Consider logging exceptions instead of silently passing.

The broad exception handling with `pass` makes debugging difficult. Consider logging the exception or using more specific exception types.

🔧 Suggested improvement

```diff
     # ---- Try the newer Triton 2.2+ API ----
     try:
         drv = triton.runtime.driver.active
         props = drv.utils.get_device_properties(tensor_idx)
         return props.get("multiprocessor_count") or props.get("num_vectorcore") or 1
-    except Exception:
-        pass
+    except Exception as e:
+        import logging
+        logging.debug(f"Triton 2.2+ API failed: {e}")

     # ---- Fallback: Triton 2.0 / 2.1 API ----
     try:
         cuda = triton.runtime.driver.CudaDriver
         dev = cuda.get_current_device()
         props = cuda.get_device_properties(dev)
         return props.get("multiprocessor_count", 1)
-    except Exception:
-        pass
+    except Exception as e:
+        import logging
+        logging.debug(f"Triton 2.0/2.1 API failed: {e}")
```

examples/KDA/chunk_bwd_intra.py (2)
3-4: Remove unused import.

`from re import I` is unused and appears to be a typo or leftover from debugging.

```diff
-from re import I
-import sys  # noqa: F401
+import sys  # noqa: F401
```
482-486: Debug print statement should be removed.

The `print(db_tilelang.shape)` statement appears to be debug output that should be removed before merging.

```diff
-    print(db_tilelang.shape)
     dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = kernel(
         q, k, g, beta, dAqk, dAkk, dq, dk, db, dg
     )
```
3-3: Remove unused noqa directive.

The `# noqa: F401` comment is unnecessary since `sys` is not imported with a specific symbol that would trigger F401.

```diff
-import sys  # noqa: F401
```
507-507: Remove debug print statement.

The `print(times)` statement should be removed or converted to a debug log.

```diff
-    print(times)
     return times.mean().item()
```
111-170: Consider documenting the shared memory allocation strategy.

The kernel allocates many shared memory buffers (Aqk/Akk fragments, K/Q/GK buffers for 4 sub-chunks, Ai matrices, etc.). A brief comment explaining the memory layout and purpose would improve maintainability.
examples/KDA/FLA_KDA/fla_chunk_o.py (1)
534-537: Use explicit `Optional` type annotation.

```diff
+from typing import Optional
+
 def chunk_bwd_dv_local(
     q: torch.Tensor,
     k: torch.Tensor,
     do: torch.Tensor,
-    g: torch.Tensor = None,
-    g_gamma: torch.Tensor = None,
-    A: torch.Tensor = None,
-    scale: float = None,
+    g: Optional[torch.Tensor] = None,
+    g_gamma: Optional[torch.Tensor] = None,
+    A: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
```

examples/KDA/FLA_KDA/fla_chunk_intra.py (5)
33-63: Document the chunk size constraint (BT = 4 × BC).

The kernel hardcodes four sub-chunks (`i_tc0` through `i_tc3`), implying `BT = 4 * BC`. This constraint isn't documented or asserted. Consider adding a comment or runtime assertion in the wrapper to prevent misuse with incompatible chunk sizes, as sketched below.
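A one-line guard in the Python wrapper would make the constraint explicit; sketch, assuming BT and BC are the wrapper's chunk and sub-chunk sizes:

```python
# Sketch: the kernel unrolls exactly four sub-chunks (i_tc0..i_tc3).
assert BT == 4 * BC, f"chunk_size (BT={BT}) must be 4x the sub-chunk size (BC={BC})"
```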
387-387: Avoid shadowing Python builtin `all`.

The variable name `all` shadows Python's builtin function. Consider renaming to `total_tokens` or similar.

```diff
-    all = B * T
+    total_tokens = B * T
```

Also update line 415:

```diff
-    db += (i_k * all + bos) * H + i_h
+    db += (i_k * total_tokens + bos) * H + i_h
```
706-717: Clarify in-place vs. return semantics for gradient tensors.

The function accepts `dq`, `dk`, `db`, `dg` as parameters, but the returned `dq` and `dk` are actually newly allocated tensors (`dq2`, `dk2`), not the modified inputs. This could confuse callers expecting in-place updates.

Consider either:

- Documenting this behavior clearly in the docstring
- Renaming parameters to clarify they're inputs to be accumulated (e.g., `dq_in`, `dk_in`)
- Copying results back to input tensors if in-place semantics are intended
720-732: Add docstring for public API function.

This function lacks documentation. Since it's a public wrapper for the fused kernel, consider adding a docstring explaining its purpose, parameters, and relationship to `chunk_kda_fwd_intra`.
562-567: Use `Optional[T]` for parameters with `None` defaults.

PEP 484 prohibits implicit `Optional`. Parameters like `gk: torch.Tensor = None` should be `gk: Optional[torch.Tensor] = None`.

Suggested fix

```diff
+from typing import Optional
+
 def chunk_kda_fwd_intra(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    gk: torch.Tensor = None,
-    beta: torch.Tensor = None,
-    scale: float = None,
-    cu_seqlens: torch.LongTensor = None,
+    gk: Optional[torch.Tensor] = None,
+    beta: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    cu_seqlens: Optional[torch.LongTensor] = None,
     chunk_size: int = 64,
-    chunk_indices: torch.LongTensor = None,
+    chunk_indices: Optional[torch.LongTensor] = None,
 ):
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (22)
- examples/KDA/FLA_KDA/cumsum.py
- examples/KDA/FLA_KDA/fla_chunk_delta.py
- examples/KDA/FLA_KDA/fla_chunk_inter.py
- examples/KDA/FLA_KDA/fla_chunk_intra.py
- examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py
- examples/KDA/FLA_KDA/fla_chunk_o.py
- examples/KDA/FLA_KDA/fla_utils.py
- examples/KDA/FLA_KDA/fla_wy_fast.py
- examples/KDA/README.md
- examples/KDA/chunk_bwd_dqkwg.py
- examples/KDA/chunk_bwd_dv.py
- examples/KDA/chunk_bwd_gla_dA.py
- examples/KDA/chunk_bwd_intra.py
- examples/KDA/chunk_bwd_intra_op.py
- examples/KDA/chunk_delta_bwd.py
- examples/KDA/chunk_delta_h_fwd.py
- examples/KDA/chunk_inter_solve_fused.py
- examples/KDA/chunk_intra_token_parallel.py
- examples/KDA/chunk_o.py
- examples/KDA/test_utils.py
- examples/KDA/wy_fast.py
- examples/KDA/wy_fast_bwd.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
examples/KDA/README.md
🧬 Code graph analysis (12)
examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py (4)
- tilelang/math/__init__.py (1): next_power_of_2 (1-2)
- examples/KDA/FLA_KDA/cumsum.py (1): grid (304-304)
- examples/KDA/FLA_KDA/fla_chunk_inter.py (1): grid (167-167)
- examples/KDA/FLA_KDA/fla_chunk_o.py (1): grid (419-419)

examples/KDA/chunk_intra_token_parallel.py (1)
- examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py (1): chunk_kda_fwd_intra_token_parallel (116-168)

examples/KDA/FLA_KDA/cumsum.py (2)
- examples/KDA/FLA_KDA/fla_utils.py (2): prepare_chunk_indices (100-105), input_guard (147-178)
- tilelang/math/__init__.py (1): next_power_of_2 (1-2)

examples/KDA/chunk_bwd_dqkwg.py (2)
- examples/KDA/FLA_KDA/fla_chunk_inter.py (1): chunk_kda_bwd_dqkwg (140-190)
- examples/KDA/test_utils.py (1): compare_tensors (42-82)

examples/KDA/FLA_KDA/fla_wy_fast.py (2)
- examples/KDA/FLA_KDA/fla_utils.py (1): prepare_chunk_indices (100-105)
- tilelang/math/__init__.py (1): next_power_of_2 (1-2)

examples/KDA/chunk_inter_solve_fused.py (4)
- examples/KDA/FLA_KDA/fla_chunk_intra.py (1): chunk_kda_fwd_inter_solve_fused (720-759)
- examples/KDA/test_utils.py (1): compare_tensors (42-82)
- examples/KDA/FLA_KDA/cumsum.py (1): chunk_local_cumsum (428-468)
- tilelang/language/dtypes.py (1): float32 (310-310)

examples/KDA/test_utils.py (1)
- tilelang/carver/roller/policy/default.py (1): sim (290-291)

examples/KDA/FLA_KDA/fla_chunk_o.py (2)
- examples/KDA/FLA_KDA/fla_utils.py (3): prepare_chunk_indices (100-105), check_shared_mem (225-231), input_guard (147-178)
- tilelang/math/__init__.py (1): next_power_of_2 (1-2)

examples/KDA/wy_fast_bwd.py (2)
- examples/KDA/FLA_KDA/fla_wy_fast.py (1): prepare_wy_repr_bwd (257-312)
- examples/KDA/wy_fast.py (2): prepare_input (17-24), kernel (85-159)

examples/KDA/chunk_o.py (3)
- tilelang/autotuner/tuner.py (1): autotune (691-786)
- examples/KDA/FLA_KDA/fla_chunk_o.py (1): chunk_gla_fwd_o_gk (399-437)
- tilelang/language/copy_op.py (1): copy (14-116)

examples/KDA/FLA_KDA/fla_utils.py (1)
- tilelang/utils/device.py (1): get_current_device (14-21)

examples/KDA/FLA_KDA/fla_chunk_delta.py (3)
- examples/KDA/FLA_KDA/fla_utils.py (1): prepare_chunk_indices (100-105)
- examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py (1): grid (151-151)
- examples/KDA/FLA_KDA/fla_chunk_o.py (1): grid (419-419)
🪛 Ruff (0.14.10)
examples/KDA/chunk_delta_h_fwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
17-17: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
41-41: Unused function argument: output_dtype
(ARG001)
42-42: Unused function argument: accum_dtype
(ARG001)
108-108: Unused function argument: block_DK
(ARG001)
249-249: Unused function argument: block_DK
(ARG001)
250-250: Unused function argument: block_DV
(ARG001)
251-251: Unused function argument: threads
(ARG001)
252-252: Unused function argument: num_stages
(ARG001)
examples/KDA/chunk_intra_token_parallel.py
21-21: Unused function argument: output_dtype
(ARG001)
22-22: Unused function argument: accum_dtype
(ARG001)
258-258: Unpacked variable Aqk_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
258-258: Unpacked variable Akk_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/FLA_KDA/cumsum.py
33-33: Unused function argument: B
(ARG001)
88-88: Unused function argument: B
(ARG001)
146-146: Unused function argument: B
(ARG001)
206-206: Unused function argument: B
(ARG001)
250-250: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
287-287: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
330-330: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
361-361: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
395-395: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
420-424: Avoid specifying long messages outside the exception class
(TRY003)
432-432: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
437-437: Unused function argument: kwargs
(ARG001)
464-468: Avoid specifying long messages outside the exception class
(TRY003)
examples/KDA/chunk_delta_bwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
31-31: Unused function argument: output_dtype
(ARG001)
32-32: Unused function argument: accum_dtype
(ARG001)
34-34: Unused function argument: state_dtype
(ARG001)
63-63: Unused function argument: gate_dtype
(ARG001)
132-132: Unused function argument: h0
(ARG001)
244-244: Unused function argument: block_DV
(ARG001)
245-245: Unused function argument: threads
(ARG001)
246-246: Unused function argument: num_stages
(ARG001)
247-247: Unused function argument: use_torch
(ARG001)
271-271: Unpacked variable dh_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
271-271: Unpacked variable dh0_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
271-271: Unpacked variable dv2_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
295-295: Unpacked variable dh_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
295-295: Unpacked variable dh0_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
295-295: Unpacked variable dv2_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/chunk_bwd_dqkwg.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
49-49: Unused function argument: DV
(ARG001)
50-50: Unused function argument: chunk_size
(ARG001)
51-51: Unused function argument: qk_dtype
(ARG001)
249-249: Unused function argument: use_gk
(ARG001)
250-250: Unused function argument: use_initial_state
(ARG001)
251-251: Unused function argument: store_final_state
(ARG001)
252-252: Unused function argument: save_new_value
(ARG001)
253-253: Unused function argument: block_DK
(ARG001)
254-254: Unused function argument: block_DV
(ARG001)
255-255: Unused function argument: threads
(ARG001)
256-256: Unused function argument: num_stages
(ARG001)
260-260: Unpacked variable dq_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
260-260: Unpacked variable dk_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
260-260: Unpacked variable dw_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
260-260: Unpacked variable dg_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
285-285: Unpacked variable dq is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
285-285: Unpacked variable dk is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
285-285: Unpacked variable dw is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
285-285: Unpacked variable dg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/FLA_KDA/fla_chunk_inter.py
150-150: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
examples/KDA/chunk_inter_solve_fused.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
26-26: Unused function argument: output_dtype
(ARG001)
27-27: Unused function argument: accum_dtype
(ARG001)
47-47: Unused function argument: sub_chunk_size
(ARG001)
examples/KDA/chunk_bwd_gla_dA.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
22-22: Unused function argument: chunk_size
(ARG001)
34-34: Unused function argument: DV
(ARG001)
134-134: Unused function argument: DK
(ARG001)
examples/KDA/chunk_bwd_intra_op.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
29-29: Unused function argument: output_dtype
(ARG001)
30-30: Unused function argument: accum_dtype
(ARG001)
32-32: Unused function argument: state_dtype
(ARG001)
59-59: Unused function argument: chunk_size
(ARG001)
63-63: Unused function argument: state_dtype
(ARG001)
98-98: Unused function argument: state_dtype
(ARG001)
135-135: Unused function argument: db
(ARG001)
427-427: Unused function argument: threads
(ARG001)
428-428: Unused function argument: num_stages
(ARG001)
429-429: Unused function argument: cu_seqlens
(ARG001)
430-430: Unused function argument: chunk_indices
(ARG001)
examples/KDA/chunk_bwd_dv.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
38-38: Unused function argument: chunk_size
(ARG001)
133-133: Unused function argument: scale
(ARG001)
examples/KDA/test_utils.py
42-42: Unused function argument: atol
(ARG001)
42-42: Unused function argument: rtol
(ARG001)
53-53: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF003)
53-53: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
examples/KDA/chunk_bwd_intra.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
29-29: Unused function argument: output_dtype
(ARG001)
30-30: Unused function argument: accum_dtype
(ARG001)
32-32: Unused function argument: state_dtype
(ARG001)
59-59: Unused function argument: chunk_size
(ARG001)
63-63: Unused function argument: state_dtype
(ARG001)
98-98: Unused function argument: state_dtype
(ARG001)
135-135: Unused function argument: db
(ARG001)
427-427: Unused function argument: threads
(ARG001)
428-428: Unused function argument: num_stages
(ARG001)
429-429: Unused function argument: cu_seqlens
(ARG001)
430-430: Unused function argument: chunk_indices
(ARG001)
examples/KDA/FLA_KDA/fla_chunk_o.py
323-323: Undefined name chunk_gla_fwd_A_kernel_intra_sub_inter
(F821)
343-343: Undefined name chunk_gla_fwd_A_kernel_intra_sub_intra
(F821)
365-365: Undefined name chunk_gla_fwd_A_kernel_intra_sub_intra_split
(F821)
384-384: Undefined name chunk_gla_fwd_A_kernel_intra_sub_intra_merge
(F821)
478-478: Unused function argument: g
(ARG001)
479-479: Unused function argument: g_gamma
(ARG001)
485-485: Unused function argument: scale
(ARG001)
491-491: Unused function argument: BK
(ARG001)
493-493: Unused function argument: USE_G
(ARG001)
494-494: Unused function argument: USE_G_GAMMA
(ARG001)
537-537: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
examples/KDA/wy_fast.py
6-6: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
17-17: Unused function argument: output_dtype
(ARG001)
67-67: Unused function argument: use_qg
(ARG001)
93-93: Unused function argument: QG
(ARG001)
199-199: Unused function argument: block_DK
(ARG001)
200-200: Unused function argument: block_DV
(ARG001)
201-201: Unused function argument: threads
(ARG001)
202-202: Unused function argument: num_stages
(ARG001)
209-209: Unpacked variable QG_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
210-210: Unpacked variable QG_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/wy_fast_bwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
26-26: Unused function argument: output_dtype
(ARG001)
27-27: Unused function argument: accum_dtype
(ARG001)
29-29: Unused function argument: state_dtype
(ARG001)
51-51: Unused function argument: chunk_size
(ARG001)
54-54: Unused function argument: state_dtype
(ARG001)
92-92: Unused function argument: state_dtype
(ARG001)
291-291: Unused function argument: block_DK
(ARG001)
292-292: Unused function argument: block_DV
(ARG001)
293-293: Unused function argument: threads
(ARG001)
294-294: Unused function argument: num_stages
(ARG001)
316-316: Unpacked variable dk_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
316-316: Unpacked variable dv_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
316-316: Unpacked variable dbeta_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
316-316: Unpacked variable dg_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
316-316: Unpacked variable dA_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
342-342: Unpacked variable dA_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
342-342: Unpacked variable dk_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
342-342: Unpacked variable dv_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
342-342: Unpacked variable dbeta_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
342-342: Unpacked variable dg_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/chunk_o.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
21-21: Unused function argument: output_dtype
(ARG001)
22-22: Unused function argument: accum_dtype
(ARG001)
37-37: Unused function argument: DK
(ARG001)
39-39: Unused function argument: chunk_size
(ARG001)
42-42: Ambiguous variable name: O
(E741)
99-99: Ambiguous variable name: O
(E741)
194-194: Unused function argument: block_DK
(ARG001)
195-195: Unused function argument: block_DV
(ARG001)
196-196: Unused function argument: threads
(ARG001)
197-197: Unused function argument: num_stages
(ARG001)
examples/KDA/FLA_KDA/fla_utils.py
34-34: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
52-52: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
133-134: try-except-pass detected, consider logging the exception
(S110)
133-133: Do not catch blind exception: Exception
(BLE001)
142-143: try-except-pass detected, consider logging the exception
(S110)
142-142: Do not catch blind exception: Exception
(BLE001)
220-220: Do not catch blind exception: BaseException
(BLE001)
230-230: Do not catch blind exception: Exception
(BLE001)
examples/KDA/FLA_KDA/fla_chunk_intra.py
564-564: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
607-607: Undefined name chunk_kda_fwd_intra_token_parallel
(F821)
638-638: Undefined name recompute_w_u_fwd
(F821)
examples/KDA/FLA_KDA/fla_chunk_delta.py
527-527: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
🔇 Additional comments (10)
examples/KDA/FLA_KDA/fla_chunk_inter.py (1)
14-138: LGTM - Kernel implementation looks correct.

The inter-chunk backward kernel properly handles both variable-length and fixed-length sequences, with appropriate boundary checks and gradient accumulation logic.

examples/KDA/chunk_intra_token_parallel.py (1)

59-174: LGTM - TileLang kernel structure is well-organized.

The kernel factory correctly sets up shared memory allocations, fragment handling, and pipelined loops. The conditional handling of `Sum_Akk_shared` vs `Sum_Aqk_shared` appears intentional for the algorithm's requirements.

examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py (2)

44-58: Binary search implementation for varlen looks correct.

The unrolled binary search (20 iterations for up to ~1M batch size) efficiently finds the sequence index for variable-length sequences. The implementation correctly handles the cu_seqlens lookup pattern.

98-114: LGTM - Computation loop correctly implements gated attention.

The loop properly computes the Q-K interaction with exponential gating, applies appropriate masking for the diagonal block structure, and stores results with boundary checks.

examples/KDA/chunk_bwd_intra_op.py (1)

143-391: Complex kernel structure - verify correctness with thorough testing.

The backward intra-chunk kernel has multiple processing phases (previous sub-chunks, current diagonal, subsequent sub-chunks, lower triangular). While the structure appears logical, the complexity warrants careful validation through the comparison with the reference implementation.

examples/KDA/chunk_delta_bwd.py (1)

125-224: LGTM - Backward kernel structure is correct.

The kernel properly implements the backward pass with reverse iteration, handles optional initial state and final state gradients, and uses appropriate memory layouts with swizzling.

examples/KDA/FLA_KDA/fla_wy_fast.py (1)

210-254: LGTM!

The forward kernel wrapper correctly handles variable-length sequences, derives block dimensions from input tensors, and properly initializes output tensors. The autotuning configuration covers a reasonable search space.

examples/KDA/FLA_KDA/fla_chunk_delta.py (1)

70-76: LGTM - Manual unrolling for performance.

The manual unrolling of state buffers (`b_h1` through `b_h4`) for different K dimension ranges is a valid performance optimization pattern in Triton kernels, trading code verbosity for better register allocation and reduced branching overhead.

examples/KDA/chunk_bwd_intra.py (1)

83-105: LGTM - Well-structured autotuned kernel.

The kernel configuration with autotune and appropriate pass configs is well-structured. The tiling parameters and thread configurations provide a good search space for optimization.

examples/KDA/chunk_inter_solve_fused.py (1)

310-374: Forward substitution loops have correct boundary handling.

The pipelined loops for forward substitution on diagonal blocks correctly handle boundaries with `T.min(BC, S-i_tc0)` and similar guards. The identity matrix addition at lines 370-374 properly completes the inverse computation.
```python
    if chunk_indices is None and cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    assert K <= 256, "current kernel does not support head dimension larger than 256."

    h = k.new_empty(B, NT, H, K, V)
    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
```
chunk_offsets is undefined when cu_seqlens is provided.
When cu_seqlens is not None, the code skips the else branch where chunk_offsets is assigned, but chunk_offsets is still passed to the kernel on line 506, causing an UnboundLocalError.
🐛 Proposed fix

```diff
 if chunk_indices is None and cu_seqlens is not None:
     chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
 # N: the actual number of sequences in the batch with either equal or variable lengths
 if cu_seqlens is None:
     N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+else:
+    N, NT = len(cu_seqlens) - 1, len(chunk_indices)
+    chunk_offsets = torch.zeros(N, dtype=torch.int32, device=k.device)
+    # Compute chunk offsets for variable-length sequences
+    for i in range(N):
+        chunk_offsets[i] = (chunk_indices[:, 0] == i).nonzero()[0].item() if i > 0 else 0
 assert K <= 256, "current kernel does not support head dimension larger than 256."
```

🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_chunk_delta.py around lines 483 - 491, The variable
chunk_offsets is left undefined when cu_seqlens is provided, causing an
UnboundLocalError later when passed to the kernel; fix by ensuring chunk_offsets
is initialized in the branch that handles cu_seqlens (e.g., after computing
chunk_indices via prepare_chunk_indices(cu_seqlens, chunk_size) compute or
derive chunk_offsets from chunk_indices, or explicitly set chunk_offsets = None
if the kernel accepts that), so that the subsequent kernel invocation that uses
chunk_offsets, h, and final_state always has a defined value.
```python
    if chunk_indices is None and cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
```
Same chunk_offsets issue in backward function.
Similar to the forward function, chunk_offsets is undefined when cu_seqlens is not None.
🐛 Proposed fix
```diff
 if chunk_indices is None and cu_seqlens is not None:
     chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
 if cu_seqlens is None:
     N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+else:
+    N, NT = len(cu_seqlens) - 1, len(chunk_indices)
+    chunk_offsets = torch.zeros(N, dtype=torch.int32, device=q.device)
+    for i in range(N):
+        chunk_offsets[i] = (chunk_indices[:, 0] == i).nonzero()[0].item() if i > 0 else 0
```
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if chunk_indices is None and cu_seqlens is not None: | |
| chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) | |
| if cu_seqlens is None: | |
| N, NT, chunk_offsets = B, triton.cdiv(T, BT), None | |
| if chunk_indices is None and cu_seqlens is not None: | |
| chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) | |
| if cu_seqlens is None: | |
| N, NT, chunk_offsets = B, triton.cdiv(T, BT), None | |
| else: | |
| N, NT = len(cu_seqlens) - 1, len(chunk_indices) | |
| chunk_offsets = torch.zeros(N, dtype=torch.int32, device=q.device) | |
| for i in range(N): | |
| chunk_offsets[i] = (chunk_indices[:, 0] == i).nonzero()[0].item() if i > 0 else 0 |
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_chunk_delta.py around lines 538 - 541, In the
backward function the variable chunk_offsets can remain undefined when
cu_seqlens is not None; update the logic so that when chunk_indices is None and
cu_seqlens is not None you also set chunk_offsets (mirror forward): call
prepare_chunk_indices(cu_seqlens, chunk_size) to produce chunk_indices and
derive chunk_offsets (or compute chunk_offsets from the same
cu_seqlens/chunk_size helper) before proceeding; modify the branch around
chunk_indices, cu_seqlens and the subsequent use of chunk_offsets to ensure
chunk_offsets is always initialized (refer to chunk_indices, cu_seqlens,
prepare_chunk_indices, and chunk_offsets).
```python
def print_red_warning(message):
    print(f"\033[31mWARNING: {message}\033[0m")


def calc_sim(x, y, name="tensor"):
    x, y = x.data.double(), y.data.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        print_red_warning(f"{name} all zero")
        return 1
    sim = 2 * (x * y).sum() / denominator
    return sim


def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
    x_mask = torch.isfinite(x)
    y_mask = torch.isfinite(y)
    if not torch.all(x_mask == y_mask):
        print_red_warning(f"{name} Error: isfinite mask mismatch")
        if raise_assert:
            raise AssertionError
    if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
        print_red_warning(f"{name} Error: nonfinite value mismatch")
        if raise_assert:
            raise AssertionError
    x = x.masked_fill(~x_mask, 0)
    y = y.masked_fill(~y_mask, 0)
    sim = calc_sim(x, y, name)
    diff = 1.0 - sim
    if not (0 <= diff <= eps):
        print_red_warning(f"{name} Error: {diff}")
        if raise_assert:
            raise AssertionError
    else:
        print(f"{name} {data} passed")
```
Fix calc_sim() / assert_similar() tensor-vs-Python control-flow bugs.
calc_sim() returns a torch scalar tensor, so diff = 1.0 - sim becomes a tensor; if not (0 <= diff <= eps): is not safe (chained comparisons over tensors) and is likely to throw or behave unexpectedly (esp. on CUDA). Also .data should be avoided.
Proposed fix
def calc_sim(x, y, name="tensor"):
- x, y = x.data.double(), y.data.double()
- denominator = (x * x + y * y).sum()
- if denominator == 0:
+ x, y = x.detach().double(), y.detach().double()
+ denominator = (x * x + y * y).sum()
+ denom = float(denominator.item())
+ if denom == 0.0:
print_red_warning(f"{name} all zero")
- return 1
- sim = 2 * (x * y).sum() / denominator
- return sim
+ return 1.0
+ sim = 2.0 * (x * y).sum() / denominator
+ return float(sim.item())
def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
@@
- sim = calc_sim(x, y, name)
- diff = 1.0 - sim
- if not (0 <= diff <= eps):
+ sim = calc_sim(x, y, name)
+ diff = 1.0 - sim
+ if not (0.0 <= diff <= eps):
print_red_warning(f"{name} Error: {diff}")
if raise_assert:
raise AssertionError🤖 Prompt for AI Agents
In @examples/KDA/test_utils.py around lines 5 - 40, calc_sim and assert_similar
use .data and perform Python control flow on torch tensors which can raise on
CUDA; change calc_sim to avoid .data (use .detach().double() or
.to(torch.double)) and return a Python float (use sim.item()) instead of a
tensor, ensure the zero-denominator check uses a scalar (.item() or
.eq(0).all().item()); in assert_similar stop doing chained tensor comparisons by
converting sim to a float before computing diff (e.g., sim = calc_sim(...); diff
= 1.0 - float(sim)), keep the isfinite/masking logic but ensure masked_fill uses
the correct masks (x.masked_fill(~x_mask, 0) / y.masked_fill(~y_mask, 0) as
already used) and replace any remaining .data usages with .detach()/.to(...)
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In @examples/KDA/FLA_KDA/fla_chunk_intra.py:
- Around line 9-14: The conditional assignment to SOLVE_TRIL_DOT_PRECISION using
IS_TF32_SUPPORTED is currently dead because it is immediately overwritten by
SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32'); either remove the unconditional
overwrite and keep the conditional block (use SOLVE_TRIL_DOT_PRECISION =
tl.constexpr('tf32') or 'ieee' based on IS_TF32_SUPPORTED) or delete the
conditional block and retain the single explicit SOLVE_TRIL_DOT_PRECISION =
tl.constexpr('tf32'); ensure you only assign SOLVE_TRIL_DOT_PRECISION once and
use the tl.constexpr(...) call consistently.
In @examples/KDA/FLA_KDA/fla_chunk_o.py:
- Around line 421-429: The branch guarding creation of b_A (when USE_A is True)
leaves b_A undefined if USE_A is False; ensure b_A is always initialized before
its use by adding an else that sets b_A to a zero tensor with the same shape and
dtype as the loaded value (matching do.dtype.element_ty and shape of p_A load /
m_A result), or assert that USE_A must be True; update the block around the
USE_A check (where p_A, b_A are created) so b_A is defined for both branches
before the tl.where(...) on line using m_A.
- Around line 455-460: The conditional mistakenly uses the function object
check_shared_mem instead of calling it, causing the middle branch to always run;
update the second branch to call check_shared_mem with the same signature as the
first (e.g., check_shared_mem('hopper', k.device.index)) or the appropriate
args, so CONST_TILING is set to 64 only when the function returns truthy; ensure
the first branch remains check_shared_mem('hopper', k.device.index) and the else
remains CONST_TILING = 32.
In @examples/KDA/FLA_KDA/fla_utils.py:
- Line 25: The module-level expression that sets IS_NVIDIA_HOPPER calls
torch.cuda.get_device_name(0) and torch.cuda.get_device_capability() which will
raise if no CUDA device exists; change the initialization of IS_NVIDIA_HOPPER in
fla_utils.py to first check torch.cuda.is_available() (and/or
torch.cuda.device_count() > 0) and then safely call
get_device_name/get_device_capability inside a try/except, defaulting
IS_NVIDIA_HOPPER to False on any exception so import-time errors are avoided.
🧹 Nitpick comments (7)
examples/KDA/FLA_KDA/fla_utils.py (4)
1-31: Module setup and imports look reasonable. The module correctly sets up environment detection for Hopper GPUs and CUDA graph usage. A few minor observations:

- Line 25: The (True and ...) pattern is redundant - the True and prefix does nothing.
- Line 34: Contains a fullwidth comma , instead of an ASCII comma , (as flagged by static analysis).

🔧 Suggested cleanup

-IS_NVIDIA_HOPPER = (True and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
-USE_CUDA_GRAPH = (True and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')
+IS_NVIDIA_HOPPER = 'NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9
+USE_CUDA_GRAPH = os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1'

-# error check,copy from
+# error check, copy from
118-145: Silent exception swallowing in get_multiprocessor_count. The function catches a broad Exception and silently passes, which can hide real bugs. Consider logging at debug level for troubleshooting.

🔧 Optional: Add debug logging

+import logging
+logger = logging.getLogger(__name__)
+
 @functools.cache
 def get_multiprocessor_count(tensor_idx: int = 0) -> int:
     # ---- Try the newer Triton 2.2+ API ----
     try:
         drv = triton.runtime.driver.active
         props = drv.utils.get_device_properties(tensor_idx)
         return props.get("multiprocessor_count") or props.get("num_vectorcore") or 1
     except Exception:
-        pass
+        logger.debug("Triton 2.2+ API not available, falling back")
     # ---- Fallback: Triton 2.0 / 2.1 API ----
     try:
         cuda = triton.runtime.driver.CudaDriver
         dev = cuda.get_current_device()
         props = cuda.get_device_properties(dev)
         return props.get("multiprocessor_count", 1)
     except Exception:
-        pass
+        logger.debug("Triton 2.0/2.1 API not available, returning default")
     return 1
147-178: input_guard decorator consumes generator prematurely. Line 156 creates a generator expression for contiguous_args, but it's passed to fn(), which will consume it. If fn needs to iterate over args multiple times, this will fail silently. Additionally, the iteration to find a tensor (lines 160-168) happens after the generator is created, so if args contains non-tensors before the first tensor, those won't be made contiguous when the generator is consumed.

Actually, looking more carefully - the generator is created but only consumed once when fn(*contiguous_args, ...) is called, so this should work. However, it would be safer to use a tuple.

🔧 Use tuple for clarity and safety

 @functools.wraps(fn)
 def wrapper(*args, **kwargs):
-    contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
+    contiguous_args = tuple(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
     contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}
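To make the tuple-vs-generator point concrete, a small standalone illustration (plain Python, not taken from the PR):

gen = (t * 2 for t in [1, 2, 3])
print(list(gen))   # [2, 4, 6]
print(list(gen))   # [] - the generator is already exhausted, so a second pass silently sees nothing

tup = tuple(t * 2 for t in [1, 2, 3])
print(list(tup))   # [2, 4, 6]
print(list(tup))   # [2, 4, 6] - a tuple can be iterated as many times as needed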
214-231: get_all_max_shared_mem catches BaseException, which is too broad. Catching BaseException can suppress KeyboardInterrupt and SystemExit. Use Exception instead.

🔧 Narrow exception type

 def get_all_max_shared_mem():
     try:
         return [
             triton.runtime.driver.active.utils.get_device_properties(i)['max_shared_mem']
             for i in range(device_torch_lib.device_count())
         ]
-    except BaseException:
+    except Exception:
         return [-1]

examples/KDA/FLA_KDA/fla_chunk_o.py (3)
12-13: check_shared_mem() called with no arguments vs. with 'ampere'. Line 12 calls check_shared_mem() with no arguments (defaults to arch="none"), while line 13 calls it with 'ampere'. The default "none" won't match any Backend enum key, so it falls back to DEFAULT. Consider being explicit about the intended architecture.
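A minimal sketch of being explicit, assuming check_shared_mem(arch, tensor_idx) from examples/KDA/FLA_KDA/fla_utils.py is importable as shown; the tile sizes chosen here are illustrative only:

from fla_utils import check_shared_mem

def pick_block_v(device_index: int) -> int:
    # spell out the architecture instead of relying on the "none" default
    if check_shared_mem('hopper', device_index):
        return 128
    if check_shared_mem('ampere', device_index):
        return 64
    return 32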
89-92: Redundant condition if i_k >= 0 is always true. i_k is a loop counter from range(tl.cdiv(K, BK)) which always starts at 0, so i_k >= 0 is always true. The comment "works but dkw, owing to divine benevolence" suggests this might be a workaround for a compiler issue. If this is a known workaround, consider documenting it more clearly:

-        # works but dkw, owing to divine benevolence
-        # [BT, BV]
-        if i_k >= 0:
+        # NOTE: Redundant check works around Triton compiler issue
+        if i_k >= 0:  # Always true, but needed for correctness
438-449: Type annotations use implicit Optional. Lines 442-445 use = None defaults without Optional[T] type hints. While this works, PEP 484 recommends explicit Optional.

🔧 Use explicit Optional types

+from typing import Optional
+
 def chunk_bwd_dv_local(
     q: torch.Tensor,
     k: torch.Tensor,
     do: torch.Tensor,
-    g: torch.Tensor = None,
-    g_gamma: torch.Tensor = None,
-    A: torch.Tensor = None,
-    scale: float = None,
-    cu_seqlens: torch.LongTensor = None,
+    g: Optional[torch.Tensor] = None,
+    g_gamma: Optional[torch.Tensor] = None,
+    A: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    cu_seqlens: Optional[torch.LongTensor] = None,
     chunk_size: int = 64,
-    chunk_indices: torch.LongTensor = None,
+    chunk_indices: Optional[torch.LongTensor] = None,
 ) -> torch.Tensor:
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
examples/KDA/FLA_KDA/fla_chunk_intra.py
examples/KDA/FLA_KDA/fla_chunk_o.py
examples/KDA/FLA_KDA/fla_utils.py
🧰 Additional context used
🧬 Code graph analysis (2)
examples/KDA/FLA_KDA/fla_chunk_intra.py (2)
examples/KDA/FLA_KDA/fla_utils.py (1)
prepare_chunk_indices (100-105)
examples/KDA/FLA_KDA/cumsum.py (1)
chunk_local_cumsum(428-468)
examples/KDA/FLA_KDA/fla_chunk_o.py (3)
examples/KDA/FLA_KDA/fla_utils.py (3)
prepare_chunk_indices (100-105), check_shared_mem (225-231), input_guard (147-178)
examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py (1)
grid (151-151)
tilelang/math/__init__.py (1)
next_power_of_2(1-2)
🪛 Ruff (0.14.10)
examples/KDA/FLA_KDA/fla_utils.py
34-34: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
133-134: try-except-pass detected, consider logging the exception
(S110)
133-133: Do not catch blind exception: Exception
(BLE001)
142-143: try-except-pass detected, consider logging the exception
(S110)
142-142: Do not catch blind exception: Exception
(BLE001)
220-220: Do not catch blind exception: BaseException
(BLE001)
230-230: Do not catch blind exception: Exception
(BLE001)
examples/KDA/FLA_KDA/fla_chunk_o.py
386-386: Unused function argument: g
(ARG001)
387-387: Unused function argument: g_gamma
(ARG001)
393-393: Unused function argument: scale
(ARG001)
399-399: Unused function argument: BK
(ARG001)
401-401: Unused function argument: USE_G
(ARG001)
402-402: Unused function argument: USE_G_GAMMA
(ARG001)
445-445: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
🔇 Additional comments (9)
examples/KDA/FLA_KDA/fla_utils.py (2)
57-93: Single-entry tensor cache implementation is correct. The tensor_cache decorator correctly implements identity-based caching for tensor inputs. The use of an identity (is) comparison for tensors is appropriate since tensor identity matters for caching.
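For readers unfamiliar with the pattern, a minimal sketch of the idea (not the exact fla_utils implementation): keep only the last (args, kwargs, result) triple and reuse it when every argument matches by identity:

import functools

def tensor_cache(fn):
    last_args, last_kwargs, last_result = None, None, None

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal last_args, last_kwargs, last_result
        if (last_args is not None
                and len(args) == len(last_args)
                and all(a is b for a, b in zip(args, last_args))
                and kwargs.keys() == last_kwargs.keys()
                and all(kwargs[k] is last_kwargs[k] for k in kwargs)):
            return last_result
        last_args, last_kwargs, last_result = args, kwargs, fn(*args, **kwargs)
        return last_result

    return wrapper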
200-211: Backend enum provides architecture-specific shared memory sizes. The implementation is clean. Minor note: the comment references specific GPU models but the enum names are architecture generations, which is appropriate.

examples/KDA/FLA_KDA/fla_chunk_intra.py (5)

20-51: Forward kernel setup looks correct. The heuristics for IS_VARLEN and the autotune configurations are appropriate. The kernel signature properly declares compile-time constants with tl.constexpr.

64-91: Varlen handling and boundary setup look correct. The chunk index loading and sequence boundary calculations handle both variable-length and fixed-length cases appropriately. The early return for out-of-bounds tiles (lines 74-75) is correct.

629-668: Forward inter-solve wrapper implementation looks correct. The function properly handles chunk index preparation and kernel launch. The assertion K <= 256 at line 643 provides a clear constraint.

559-626: The function's return value usage and chunk_local_cumsum parameter passing are correct. Callers at lines 447 and 499 in chunk_bwd_intra_op.py and chunk_bwd_intra.py properly capture the return values (dq_ref, dk_ref, db_ref, dg_ref = chunk_kda_bwd_intra(...)). The chunk_local_cumsum call correctly passes both chunk_size (the required size parameter) and chunk_indices (an optional parameter for sequence-level chunking support); these parameters serve different purposes and are intentionally used together. No changes needed. Likely an incorrect or invalid review comment.

248-267: Pointer arithmetic is correct; no issue identified. After examining the code, the pointer offset and stride calculations are self-consistent. The Akk_diag offset at line 87, (bos * H + i_h) * BC, correctly positions the pointer to the start of a diagonal block for a specific batch and head. The subsequent loads in the forward substitution loops (lines 249, 254, 259, 264) correctly use stride H*BC across all four loops, which matches the stride specified in the make_block_ptr calls (lines 225-228). The loop index variables handle the offset ranges correctly: loop 2 covers rows starting at i_tc0 + BC, loop 3 covers i_tc0 + 2*BC, and loop 4 covers i_tc0 + 3*BC, all properly expressed through the combined formula (i_tc0 + i) * H*BC. There is no double-counting of offsets and no stride inconsistency.

examples/KDA/FLA_KDA/fla_chunk_o.py (2)

307-345: chunk_gla_fwd_o_gk wrapper implementation looks correct. The function properly handles chunk indices preparation, output tensor allocation, and kernel launch with appropriate parameters.

549-580: chunk_gla_bwd_dA wrapper implementation looks correct. The function correctly allocates the output tensor with float32 dtype and launches the kernel with proper grid dimensions.
IS_TF32_SUPPORTED=False
if IS_TF32_SUPPORTED:
    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32x3')
else:
    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('ieee')
SOLVE_TRIL_DOT_PRECISION= tl.constexpr('tf32')
SOLVE_TRIL_DOT_PRECISION is overwritten immediately, dead code.
Lines 9-13 set SOLVE_TRIL_DOT_PRECISION based on IS_TF32_SUPPORTED, but line 14 unconditionally overwrites it to 'tf32'. This makes the conditional logic dead code.
🔧 Remove dead code or fix the logic
If 'tf32' is the intended precision, remove the dead code:
-IS_TF32_SUPPORTED=False
-if IS_TF32_SUPPORTED:
- SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32x3')
-else:
- SOLVE_TRIL_DOT_PRECISION = tl.constexpr('ieee')
-SOLVE_TRIL_DOT_PRECISION= tl.constexpr('tf32')
+SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32')

Or if you intended to use the conditional:
-IS_TF32_SUPPORTED=False
+IS_TF32_SUPPORTED = False
if IS_TF32_SUPPORTED:
SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32x3')
else:
SOLVE_TRIL_DOT_PRECISION = tl.constexpr('ieee')
-SOLVE_TRIL_DOT_PRECISION= tl.constexpr('tf32')
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Suggested change (remove the dead code):

SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32')

Suggested change (keep the conditional):

IS_TF32_SUPPORTED = False
if IS_TF32_SUPPORTED:
    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32x3')
else:
    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('ieee')
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_chunk_intra.py around lines 9 - 14, The conditional
assignment to SOLVE_TRIL_DOT_PRECISION using IS_TF32_SUPPORTED is currently dead
because it is immediately overwritten by SOLVE_TRIL_DOT_PRECISION =
tl.constexpr('tf32'); either remove the unconditional overwrite and keep the
conditional block (use SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32') or 'ieee'
based on IS_TF32_SUPPORTED) or delete the conditional block and retain the
single explicit SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32'); ensure you only
assign SOLVE_TRIL_DOT_PRECISION once and use the tl.constexpr(...) call
consistently.
def chunk_bwd_kernel_dv_local(
    q,
    k,
    g,
    g_gamma,
    A,
    do,
    dv,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_G_GAMMA: tl.constexpr,
    USE_A: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    do += (bos * H + i_h) * V
    dv += (bos * H + i_h) * V

    if USE_A:
        p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
        b_A = tl.load(p_A, boundary_check=(0, 1))

    o_t = i_t * BT + tl.arange(0, BT)
    m_t = o_t < T
    m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t)
    b_A = tl.where(m_A, b_A, 0).to(do.dtype.element_ty)

    for i_v in range(tl.cdiv(V, BV)):
        p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_dv = tl.dot(b_A.to(b_do.dtype), b_do)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
Multiple unused kernel arguments in chunk_bwd_kernel_dv_local.
Static analysis correctly identifies that g, g_gamma, scale, BK, USE_G, and USE_G_GAMMA are declared but never used in the kernel body. This suggests either incomplete implementation or arguments reserved for future use.
If these are placeholders for future features, consider adding a comment. If they're dead code, remove them to reduce confusion.
@triton.jit(do_not_specialize=['T'])
def chunk_bwd_kernel_dv_local(
q,
k,
- g,
- g_gamma,
+ g, # Reserved for gating support (currently unused)
+ g_gamma, # Reserved for gamma gating (currently unused)
A,
do,
 dv,
🧰 Tools
🪛 Ruff (0.14.10)
386-386: Unused function argument: g
(ARG001)
387-387: Unused function argument: g_gamma
(ARG001)
393-393: Unused function argument: scale
(ARG001)
399-399: Unused function argument: BK
(ARG001)
401-401: Unused function argument: USE_G
(ARG001)
402-402: Unused function argument: USE_G_GAMMA
(ARG001)
    if USE_A:
        p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
        b_A = tl.load(p_A, boundary_check=(0, 1))

    o_t = i_t * BT + tl.arange(0, BT)
    m_t = o_t < T
    m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t)
    b_A = tl.where(m_A, b_A, 0).to(do.dtype.element_ty)
b_A used before assignment when USE_A is False.
If USE_A is False (line 421), b_A is never assigned, but it's used on line 429. This will cause a runtime error.
🐛 Initialize b_A before conditional
if USE_A:
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
b_A = tl.load(p_A, boundary_check=(0, 1))
-
+ else:
+ # Initialize identity-like matrix when A is not provided
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
 o_t = i_t * BT + tl.arange(0, BT)

Alternatively, if USE_A=False is not a valid code path, add an assertion in the wrapper.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Suggested change:

    if USE_A:
        p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
        b_A = tl.load(p_A, boundary_check=(0, 1))
    else:
        # Initialize identity-like matrix when A is not provided
        b_A = tl.zeros([BT, BT], dtype=tl.float32)

    o_t = i_t * BT + tl.arange(0, BT)
    m_t = o_t < T
    m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t)
    b_A = tl.where(m_A, b_A, 0).to(do.dtype.element_ty)
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_chunk_o.py around lines 421 - 429, The branch
guarding creation of b_A (when USE_A is True) leaves b_A undefined if USE_A is
False; ensure b_A is always initialized before its use by adding an else that
sets b_A to a zero tensor with the same shape and dtype as the loaded value
(matching do.dtype.element_ty and shape of p_A load / m_A result), or assert
that USE_A must be True; update the block around the USE_A check (where p_A, b_A
are created) so b_A is defined for both branches before the tl.where(...) on
line using m_A.
examples/KDA/FLA_KDA/fla_chunk_o.py
Outdated
    if check_shared_mem('hopper', k.device.index):
        CONST_TILING = 128
    elif check_shared_mem:
        CONST_TILING = 64
    else:
        CONST_TILING = 32
Bug: check_shared_mem used as boolean instead of being called.
Line 457 uses check_shared_mem (the function object) in a boolean context instead of calling it with check_shared_mem(). A function object is always truthy, so this branch will always execute.
🐛 Fix the missing function call
if check_shared_mem('hopper', k.device.index):
CONST_TILING = 128
- elif check_shared_mem:
+ elif check_shared_mem('ampere', k.device.index):
CONST_TILING = 64
else:
         CONST_TILING = 32
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_chunk_o.py around lines 455 - 460, The conditional
mistakenly uses the function object check_shared_mem instead of calling it,
causing the middle branch to always run; update the second branch to call
check_shared_mem with the same signature as the first (e.g.,
check_shared_mem('hopper', k.device.index)) or the appropriate args, so
CONST_TILING is set to 64 only when the function returns truthy; ensure the
first branch remains check_shared_mem('hopper', k.device.index) and the else
remains CONST_TILING = 32.
examples/KDA/FLA_KDA/fla_utils.py
Outdated
log = tldevice.fast_logf
log2 = tldevice.fast_log2f

IS_NVIDIA_HOPPER = (True and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
Potential crash if no CUDA device is available.
torch.cuda.get_device_name(0) and torch.cuda.get_device_capability() will raise an error if no CUDA device is present. This executes at module import time.
🛡️ Suggested defensive initialization
-IS_NVIDIA_HOPPER = (True and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
+def _detect_hopper():
+ try:
+ return 'NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9
+ except Exception:
+ return False
+
+IS_NVIDIA_HOPPER = _detect_hopper()
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_utils.py at line 25, The module-level expression
that sets IS_NVIDIA_HOPPER calls torch.cuda.get_device_name(0) and
torch.cuda.get_device_capability() which will raise if no CUDA device exists;
change the initialization of IS_NVIDIA_HOPPER in fla_utils.py to first check
torch.cuda.is_available() (and/or torch.cuda.device_count() > 0) and then safely
call get_device_name/get_device_capability inside a try/except, defaulting
IS_NVIDIA_HOPPER to False on any exception so import-time errors are avoided.
Actionable comments posted: 0
🧹 Nitpick comments (4)
examples/KDA/FLA_KDA/cumsum.py (4)
56-64: Inconsistent reverse cumsum implementation between scalar and vector kernels. The scalar kernel manually computes the reverse cumsum while the vector kernel (line 116) uses tl.cumsum(b_s, axis=0, reverse=True). Consider using the native reverse=True parameter here for consistency if it is supported for 1D tensors, or add a comment explaining why the manual approach is necessary.
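A quick standalone check (plain PyTorch, added for illustration) of the identity the manual formulation relies on, which is exactly what a reverse (suffix) cumulative sum computes:

import torch

x = torch.randn(16, dtype=torch.float64)
manual = -torch.cumsum(x, dim=0) + x.sum() + x                        # the scalar kernel's formulation
reference = torch.flip(torch.cumsum(torch.flip(x, [0]), dim=0), [0])  # true reverse (suffix) cumsum
assert torch.allclose(manual, reference)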
178-179: Redundant condition:if i_c >= 0:is always true.Since
i_citerates fromrange(NT)starting at 0, this condition is always satisfied. The check can be removed.Suggested fix
b_o += b_z - if i_c >= 0: - b_z += b_ss + b_z += b_ss if HAS_SCALE: b_o *= scale
246-255: Consider using explicitOptionaltype hints for nullable parameters.PEP 484 prohibits implicit
Optional. Parameters likescale,cu_seqlens, andchunk_indicesthat default toNoneshould be typed asOptional[T].Suggested fix
+from typing import Optional + def chunk_local_cumsum_scalar( g: torch.Tensor, chunk_size: int, reverse: bool = False, - scale: float = None, - cu_seqlens: torch.Tensor = None, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.Tensor] = None, head_first: bool = False, - output_dtype: torch.dtype = torch.float, - chunk_indices: torch.LongTensor = None, + output_dtype: torch.dtype = torch.float, + chunk_indices: Optional[torch.LongTensor] = None, ) -> torch.Tensor:This applies to similar signatures in other wrapper functions as well.
427-438: Unused**kwargsparameter.The
kwargsparameter is accepted but never used. If this is for forward compatibility, consider documenting it. Otherwise, remove to avoid confusion.Option 1: Remove unused kwargs
def chunk_local_cumsum( g: torch.Tensor, chunk_size: int, reverse: bool = False, scale: float = None, cu_seqlens: torch.Tensor = None, head_first: bool = False, output_dtype: torch.dtype = torch.float, chunk_indices: torch.LongTensor = None, - **kwargs, ) -> torch.Tensor:Option 2: Document if intentional for API stability
def chunk_local_cumsum( g: torch.Tensor, chunk_size: int, reverse: bool = False, scale: float = None, cu_seqlens: torch.Tensor = None, head_first: bool = False, output_dtype: torch.dtype = torch.float, chunk_indices: torch.LongTensor = None, - **kwargs, + **kwargs, # Reserved for future parameters / API compatibility ) -> torch.Tensor:
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/KDA/FLA_KDA/cumsum.py
🧰 Additional context used
🪛 Ruff (0.14.10)
examples/KDA/FLA_KDA/cumsum.py
33-33: Unused function argument: B
(ARG001)
88-88: Unused function argument: B
(ARG001)
146-146: Unused function argument: B
(ARG001)
206-206: Unused function argument: B
(ARG001)
250-250: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
287-287: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
330-330: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
361-361: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
395-395: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
420-424: Avoid specifying long messages outside the exception class
(TRY003)
432-432: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
437-437: Unused function argument: kwargs
(ARG001)
464-468: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (7)
examples/KDA/FLA_KDA/cumsum.py (7)
1-11: LGTM!Standard imports and constants for Triton kernel development.
67-122: LGTM!Clean implementation using native Triton reverse cumsum support.
185-244: LGTM!Global vector cumsum implementation is correct with proper accumulation logic.
283-323: LGTM!Vector cumsum wrapper correctly sets up the dynamic grid and delegates to the kernel.
325-354: LGTM!Global scalar cumsum wrapper correctly handles both fixed-length and variable-length sequences.
356-388: LGTM!Global vector cumsum wrapper correctly computes block size and grid dimensions.
390-425: LGTM!Dispatcher correctly routes based on tensor dimensionality with appropriate validation.
The branch was force-pushed from 1ef405a to ffd4fa5 (Compare).
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In @examples/KDA/FLA_KDA/cumsum.py:
- Around line 437-438: The function in examples/KDA/FLA_KDA/cumsum.py whose
signature currently includes an unused **kwargs parameter should either drop
**kwargs or explicitly handle/validate unexpected keyword args; update the
function signature to remove **kwargs if no extra kwargs are intended, or add
code at the start of that function to raise a TypeError listing unexpected keys
(so callers aren’t silently ignored). Locate the function with the signature
that returns torch.Tensor (the one capturing **kwargs) and apply the chosen
change consistently across any internal calls or tests.
- Around line 178-179: The condition "if i_c >= 0" is redundant because the loop
initializes i_c at 0 so it is always true; remove the conditional and directly
perform the accumulation (replace the guarded statement with a direct "b_z +=
b_ss" inside the loop) at the location using variables i_c, b_z, and b_ss to
eliminate dead code.
- Around line 464-468: The error message raised for unsupported input shape when
checking g.shape is inaccurate for the 3D scalar case; update the message in the
cumsum module (the raise that references g.shape) to mirror chunk_global_cumsum
style and list both allowed shapes like "[B, T, H] / [B, T, H, D]" (or
vice-versa depending on head_first) so it correctly documents the 3D scalar and
4D tensor possibilities.
🧹 Nitpick comments (2)
examples/KDA/FLA_KDA/cumsum.py (2)
297-301: Inconsistent assertion ordering compared to scalar version.The power-of-2 assertion is placed after
prepare_chunk_indicesis called withchunk_size, whereas inchunk_local_cumsum_scalar(line 260), the assertion is placed before any usage. Consider moving the assertion before line 298 for consistency and to catch invalid input earlier.Suggested fix
if head_first: B, H, T, S = g.shape else: B, T, H, S = g.shape + assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2" BT = chunk_size if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) - assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2" g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
246-255: Consider using explicitOptional[T]type hints.Several parameters across the file use implicit optional types (e.g.,
scale: float = None). PEP 484 recommends using explicitOptional[float]for clarity. This applies toscale,cu_seqlens, andchunk_indicesparameters throughout the file.Example fix for chunk_local_cumsum_scalar
+from typing import Optional + def chunk_local_cumsum_scalar( g: torch.Tensor, chunk_size: int, reverse: bool = False, - scale: float = None, - cu_seqlens: torch.Tensor = None, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.Tensor] = None, head_first: bool = False, output_dtype: torch.dtype = torch.float, - chunk_indices: torch.LongTensor = None, + chunk_indices: Optional[torch.LongTensor] = None, ) -> torch.Tensor:
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/KDA/FLA_KDA/cumsum.py
🧰 Additional context used
🧬 Code graph analysis (1)
examples/KDA/FLA_KDA/cumsum.py (2)
examples/KDA/FLA_KDA/fla_utils.py (2)
prepare_chunk_indices (100-105), input_guard (147-178)
tilelang/math/__init__.py (1)
next_power_of_2(1-2)
🪛 Ruff (0.14.10)
examples/KDA/FLA_KDA/cumsum.py
33-33: Unused function argument: B
(ARG001)
88-88: Unused function argument: B
(ARG001)
146-146: Unused function argument: B
(ARG001)
206-206: Unused function argument: B
(ARG001)
250-250: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
287-287: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
330-330: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
361-361: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
395-395: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
420-424: Avoid specifying long messages outside the exception class
(TRY003)
432-432: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
437-437: Unused function argument: kwargs
(ARG001)
464-468: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (5)
examples/KDA/FLA_KDA/cumsum.py (5)
13-65: LGTM!The scalar local cumsum kernel logic is correct. The reverse cumsum implementation properly computes the suffix sums, and the pointer arithmetic correctly handles both
HEAD_FIRSTand non-HEAD_FIRST memory layouts. TheBparameter is intentionally included in the autotune key even though it's not used in the kernel body.
67-122: LGTM!The vector local cumsum kernel correctly uses Triton's native
reverse=Trueparameter for the reverse cumsum operation. The 2D block pointer setup properly handles the additional state dimensionS.
185-244: LGTM!The vector global cumsum kernel correctly accumulates state across chunks with proper ordering of the store and accumulator update operations.
325-387: LGTM!Both global cumsum wrappers correctly configure the grid dimensions and kernel parameters. The
BSconstraint in the vector version appropriately limits block size for better performance.
390-424: LGTM!The public API correctly dispatches to scalar or vector implementations based on input dimensionality, with clear error messaging for unsupported shapes.
        if i_c >= 0:
            b_z += b_ss
Unnecessary condition: if i_c >= 0 is always true.
Since the loop starts at i_c = 0, this condition will always evaluate to True. This appears to be dead code or a leftover from previous logic.
Suggested fix
b_o += b_z
- if i_c >= 0:
- b_z += b_ss
+ b_z += b_ss
if HAS_SCALE:
     b_o *= scale
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Suggested change:

        b_o += b_z
        b_z += b_ss
        if HAS_SCALE:
            b_o *= scale
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/cumsum.py around lines 178 - 179, The condition "if i_c
>= 0" is redundant because the loop initializes i_c at 0 so it is always true;
remove the conditional and directly perform the accumulation (replace the
guarded statement with a direct "b_z += b_ss" inside the loop) at the location
using variables i_c, b_z, and b_ss to eliminate dead code.
    **kwargs,
) -> torch.Tensor:
Unused **kwargs may silently ignore invalid arguments.
The **kwargs parameter is captured but never used, which could lead to confusing behavior where callers pass invalid keyword arguments that are silently ignored rather than raising an error. Consider either removing it or documenting its intended purpose.
Suggested fix (remove if not needed)
def chunk_local_cumsum(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
scale: float = None,
cu_seqlens: torch.Tensor = None,
head_first: bool = False,
output_dtype: torch.dtype = torch.float,
chunk_indices: torch.LongTensor = None,
- **kwargs,
) -> torch.Tensor:
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Suggested change:

def chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    scale: float = None,
    cu_seqlens: torch.Tensor = None,
    head_first: bool = False,
    output_dtype: torch.dtype = torch.float,
    chunk_indices: torch.LongTensor = None,
) -> torch.Tensor:
🧰 Tools
🪛 Ruff (0.14.10)
437-437: Unused function argument: kwargs
(ARG001)
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/cumsum.py around lines 437 - 438, The function in
examples/KDA/FLA_KDA/cumsum.py whose signature currently includes an unused
**kwargs parameter should either drop **kwargs or explicitly handle/validate
unexpected keyword args; update the function signature to remove **kwargs if no
extra kwargs are intended, or add code at the start of that function to raise a
TypeError listing unexpected keys (so callers aren’t silently ignored). Locate
the function with the signature that returns torch.Tensor (the one capturing
**kwargs) and apply the chosen change consistently across any internal calls or
tests.
    raise ValueError(
        f"Unsupported input shape {g.shape}, "
        f"which should be (B, T, H, D) if `head_first=False` "
        f"or (B, H, T, D) otherwise",
    )
Error message is inaccurate for the scalar case.
The error message states the expected shape is (B, T, H, D) but the 3D scalar case expects (B, T, H) without the D dimension. Consider updating to match the style used in chunk_global_cumsum which mentions both [B, T, H]/[B, T, H, D].
Suggested fix
else:
raise ValueError(
f"Unsupported input shape {g.shape}, "
- f"which should be (B, T, H, D) if `head_first=False` "
- f"or (B, H, T, D) otherwise",
+ f"which should be [B, T, H]/[B, T, H, D] if `head_first=False` "
+ f"or [B, H, T]/[B, H, T, D] otherwise",
     )
🧰 Tools
🪛 Ruff (0.14.10)
464-468: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/cumsum.py around lines 464 - 468, The error message
raised for unsupported input shape when checking g.shape is inaccurate for the
3D scalar case; update the message in the cumsum module (the raise that
references g.shape) to mirror chunk_global_cumsum style and list both allowed
shapes like "[B, T, H] / [B, T, H, D]" (or vice-versa depending on head_first)
so it correctly documents the 3D scalar and 4D tensor possibilities.
The branch was force-pushed from ffd4fa5 to 801ae9b (Compare).
Actionable comments posted: 0
🧹 Nitpick comments (4)
examples/KDA/FLA_KDA/cumsum.py (4)
59-61: Inconsistent reverse cumsum implementation.The vector kernel uses
tl.cumsum(b_s, axis=0, reverse=True)directly (line 116), but this scalar kernel manually computes the reverse. Consider using the built-in parameter for consistency:♻️ Suggested simplification
- b_o = tl.cumsum(b_s, axis=0) - if REVERSE: - b_z = tl.sum(b_s, axis=0) - b_o = -b_o + b_z[None] + b_s + if REVERSE: + b_o = tl.cumsum(b_s, axis=0, reverse=True) + else: + b_o = tl.cumsum(b_s, axis=0)
178-179: Dead code: condition is always true.
i_cranges from0toNT-1, soif i_c >= 0is alwaysTrue. This branch is redundant.♻️ Remove dead condition
b_o += b_z - if i_c >= 0: - b_z += b_ss + b_z += b_ss if HAS_SCALE:
250-254: Use explicitOptionaltype hints per PEP 484.Parameters with
= Nonedefault should useOptional[T]type annotation for clarity and static analysis compatibility.♻️ Suggested type hint fix
+from typing import Optional + def chunk_local_cumsum_scalar( g: torch.Tensor, chunk_size: int, reverse: bool = False, - scale: float = None, - cu_seqlens: torch.Tensor = None, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.Tensor] = None, head_first: bool = False, - output_dtype: torch.dtype = torch.float, - chunk_indices: torch.LongTensor = None, + output_dtype: Optional[torch.dtype] = torch.float, + chunk_indices: Optional[torch.LongTensor] = None, ) -> torch.Tensor:Note: This same pattern applies to all wrapper functions (
chunk_local_cumsum_vector,chunk_global_cumsum_scalar,chunk_global_cumsum_vector,chunk_global_cumsum, andchunk_local_cumsum).
437-438: Consider removing or documenting unused**kwargs.The
**kwargsparameter is not used within the function. If it's intended for forward compatibility or interface consistency with other modules, consider adding a docstring to document this intent. Otherwise, removing it would improve clarity.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/KDA/FLA_KDA/cumsum.py
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
🪛 Ruff (0.14.10)
examples/KDA/FLA_KDA/cumsum.py
33-33: Unused function argument: B
(ARG001)
88-88: Unused function argument: B
(ARG001)
146-146: Unused function argument: B
(ARG001)
206-206: Unused function argument: B
(ARG001)
250-250: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
287-287: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
330-330: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
361-361: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
395-395: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
420-424: Avoid specifying long messages outside the exception class
(TRY003)
432-432: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
437-437: Unused function argument: kwargs
(ARG001)
464-468: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (7)
examples/KDA/FLA_KDA/cumsum.py (7)
1-11: LGTM!Imports and constants are appropriate for the Triton kernel implementation.
67-122: LGTM!The vector kernel correctly handles 2D blocks with proper boundary checks and uses the built-in
reverse=Trueparameter for cumulative sum.
185-244: LGTM!The global vector kernel correctly implements the chunked global cumulative sum with proper handling of both forward and reverse directions.
283-323: LGTM with minor note.Implementation is correct. The power-of-2 assertion placement differs from the scalar version (line 301 vs 260), but this doesn't affect functionality. Same
Optionaltype hint improvements apply here as noted above.
325-354: LGTM!The global scalar wrapper correctly handles both regular and variable-length inputs.
356-388: LGTM!The global vector wrapper correctly computes block size and handles the vectorized cumsum.
390-425: LGTM!The dispatcher correctly routes to scalar/vector implementations based on input dimensionality with helpful error messages.
Actionable comments posted: 18
Note
Due to the large number of review comments, Critical, Major severity comments were prioritized as inline comments.
🤖 Fix all issues with AI agents
In @examples/KDA/chunk_bwd_dqkwg.py:
- Around line 302-304: main() is passing TileLang types (T.float32) into
run_test via parameters input_dtype/gate_dtype/qk_dtype, but run_test expects
string dtypes used with getattr(torch, ...); change the call sites in main() to
pass string names like "float32" (or whatever exact torch dtype attribute name
you need) for input_dtype, gate_dtype and qk_dtype, or alternatively update
run_test to accept TileLang types by converting them to the corresponding torch
dtype before calling getattr; update the parameters consistently so
getattr(torch, ...) receives a string key.
In @examples/KDA/chunk_bwd_gla_dA.py:
- Around line 83-84: The V_shared buffer is allocated with do_dtype but later
populated from V which uses input_dtype, risking implicit casts; change the
allocation for V_shared to use input_dtype (i.e., T.alloc_shared((block_S,
block_DV), dtype=input_dtype)) or explicitly cast V to do_dtype at the copy site
so the types match; update the allocation of V_shared (and any corresponding
uses) to prevent precision loss between do_dtype and input_dtype.
- Line 12: Remove the hardcoded os.environ["CUDA_VISIBLE_DEVICES"] = "7"
assignment and either delete the line entirely or replace it with reading a
configurable value (e.g., via argparse or
os.environ.get("CUDA_VISIBLE_DEVICES")) so device selection is not forced in the
script; ensure any replacement uses a fallback (None or empty) and document that
users should set CUDA_VISIBLE_DEVICES externally if needed.
In @examples/KDA/chunk_bwd_intra_op.py:
- Around line 508-512: The dtype parameters (input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype) are using TileLang type objects
(T.float32) instead of string names; update each to use the matching string
representation (e.g., "float32") wherever these kwargs are set in this module
(replace T.float32 with "float32" for
input_dtype/output_dtype/accum_dtype/gate_dtype/state_dtype) to match the
pattern used in the other files.
In @examples/KDA/chunk_delta_bwd.py:
- Around line 267-307: The test computes FLA references (dh_ref, dh0_ref,
dv2_ref) but never verifies TileLang outputs because the compare_tensors calls
are commented out; uncomment and invoke compare_tensors to validate correctness
(compare_tensors("dh", dh_ref, dh_tilelang), compare_tensors("dh0", dh0_ref,
dh0_tilelang), compare_tensors("dv2", dv2_ref, dv2_tilelang")). Ensure these
comparisons only run when the references are actually computed (i.e., wrap them
in the same use_gk conditional or compute refs unconditionally), and import or
define compare_tensors if missing so the comparisons execute after running
chunk_gated_delta_rule_bwd_dhu and tilelang_chunk_gated_delta_rule_bwd_dhu.
In @examples/KDA/chunk_delta_h_fwd.py:
- Around line 307-317: The benchmark invocation passes the wrong keyword to the
reference function: change the do_bench call that references
chunk_gated_delta_rule_fwd_h to use the gk= parameter instead of g=G so the
reference gets the intended gate constant; locate the do_bench call where
chunk_gated_delta_rule_fwd_h is passed and replace the g=G argument with gk=G.
- Around line 336-340: The test is passing TileLang type objects (T.float16,
T.float32) to run_test but run_test expects string dtype names; update the call
site in main() (the parameters input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype) to pass corresponding string names ("float16",
"float32", etc.) instead of T.float16/T.float32 so the dtype formats match
run_test's expectations.
In @examples/KDA/chunk_inter_solve_fused.py:
- Line 13: Remove the hardcoded CUDA device assignment that sets
os.environ["CUDA_VISIBLE_DEVICES"] = "7"; either delete this line or replace it
with a configurable approach (read the device index from an environment variable
or a CLI/config parameter and only set CUDA_VISIBLE_DEVICES if that value is
present/valid). Update the initialization in the script (the line using
os.environ["CUDA_VISIBLE_DEVICES"]) to validate the provided value before
applying it, and document or expose the configuration parameter so users on
systems with different GPU counts can override it.
In @examples/KDA/chunk_intra_token_parallel.py:
- Around line 260-266: Aqk_tilelang and Akk_tilelang are computed but never
validated because the comparison/assertion block is commented out; restore or
add explicit comparisons against the reference tensors (Aqk_ref, Akk_ref) using
existing helpers like assert_similar or compare_tensors from test_utils.py and
ensure test_utils is imported, e.g., call compare_tensors("Aqk", Aqk_ref,
Aqk_tilelang) and compare_tensors("Akk", Akk_ref, Akk_tilelang) (or equivalent
assert_similar checks) where the commented lines (around the existing 291-296
comparison block) currently reside so the kernel outputs are actually validated.
In @examples/KDA/chunk_o.py:
- Around line 132-134: The parallel loop uses indices (i_s, i_v) over (block_S,
block_DV) but then indexes Q_shared and GK_shared which are shaped (block_S,
block_DK); change the loop to iterate over (block_S, block_DK) (e.g., (i_s,
i_k)) or otherwise use the DK-range for the second index, and update uses of
Q_shared and GK_shared to Q_shared[i_s, i_k] and GK_shared[i_s, i_k] (and
GQ_shared[i_s, i_k]) so the second index matches the buffers' block_DK
dimension.
In @examples/KDA/FLA_KDA/fla_chunk_delta.py:
- Around line 487-494: When cu_seqlens is provided you never assign N, NT, or
chunk_offsets, causing UnboundLocalError; update the branch that handles
cu_seqlens (around the prepare_chunk_indices call) to also set N (number of
sequences), NT (number of chunks, e.g., from chunk_indices or using triton.cdiv
on totals), and chunk_offsets (derived from cu_seqlens or prepare_chunk_indices
output) before they are used; make the same fix inside
chunk_gated_delta_rule_bwd_dhu so both code paths initialize N, NT, and
chunk_offsets when cu_seqlens is not None.
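A minimal sketch of such a shared initialization helper; the import path and the chunk_offsets convention used here are assumptions and should mirror whatever the forward path actually uses:

import torch
import triton
from fla_utils import prepare_chunk_indices  # helper defined in examples/KDA/FLA_KDA/fla_utils.py

def resolve_varlen_meta(B, T, chunk_size, cu_seqlens=None, chunk_indices=None):
    if cu_seqlens is None:
        return B, triton.cdiv(T, chunk_size), None, chunk_indices
    if chunk_indices is None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    chunks_per_seq = (seqlens + chunk_size - 1) // chunk_size
    # chunk_offsets[i] = number of chunks preceding sequence i (assumed convention)
    chunk_offsets = torch.cat([chunks_per_seq.new_zeros(1), chunks_per_seq.cumsum(0)[:-1]])
    return len(seqlens), len(chunk_indices), chunk_offsets, chunk_indices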
In @examples/KDA/FLA_KDA/fla_chunk_o.py:
- Around line 432-437: The elif currently tests the function object instead of
calling it; change "elif check_shared_mem:" to call the function (e.g., "elif
check_shared_mem(\"hopper\", k.device.index):" or the correct argument set for
your use case) so the condition actually invokes check_shared_mem rather than
always evaluating truthiness of the function object, and keep CONST_TILING
assignments as-is.
In @examples/KDA/FLA_KDA/fla_wy_fast.py:
- Around line 269-277: The backward helper prepare_wy_repr_bwd hardcodes BT = 64
which can mismatch the forward recompute_w_u_fwd where BT is derived as
A.shape[-1]; change prepare_wy_repr_bwd to compute BT = A.shape[-1] (or
otherwise derive BT from the same input used in forward, e.g., A or the relevant
tensor shape) instead of the hardcoded 64 so forward and backward use the same
tiling.
In @examples/KDA/test_utils.py:
- Around line 27-30: The non-finite comparison in the block using
x.masked_fill(x_mask, 0) and y.masked_fill(y_mask, 0) has the mask inverted;
currently finite positions are being filled and non-finite values compared.
Change the masked_fill calls to use the inverted masks (~x_mask and ~y_mask) so
you mask out (fill) finite values and compare only non-finite positions (keep
using equal_nan=True and the existing error/raise behavior), referencing the
variables x, y, x_mask, y_mask and the masked_fill calls to locate the change.
In @examples/KDA/wy_fast_bwd.py:
- Around line 364-374: The call to main() in wy_fast_bwd.py is passing TileLang
types (T.float32) for dtype params while run_test expects string names to use
getattr(torch, ...); update the main(...) invocation to pass string dtype names
(e.g., "float32") for input_dtype, output_dtype, accum_dtype, gate_dtype, and
state_dtype so run_test can resolve torch dtypes via getattr; keep all other
numeric/size params (chunk_size, block_DK, block_DV, threads, num_stages)
unchanged.
In @examples/KDA/wy_fast.py:
- Around line 207-208: The variable use_qg is mistakenly assigned as a
one-element tuple (False,) which is truthy; change its assignment to the boolean
False (remove the trailing comma) so conditional checks using use_qg behave
correctly; confirm similar assignments (e.g., use_kg) are booleans and update
any other occurrences of use_qg in the file to expect a boolean value.
- Around line 266-269: The bug is that main() passes TileLang type objects
(T.bfloat16, T.float32) into run_test which expects string dtype names resolved
with getattr(torch, ...); update the call site in main() to pass dtype names as
strings (e.g., "bfloat16" for input_dtype/output_dtype and "float32" for
gate_dtype/accum_dtype) so run_test can successfully call getattr(torch, dtype).
Reference symbols: main(), run_test(), and the dtype parameters
input_dtype/output_dtype/gate_dtype/accum_dtype.
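Since the same string-vs-TileLang-dtype mismatch recurs across several of the call sites above, one option is a small normalization helper in the tests. This is a sketch under the assumption that str() of a TileLang dtype such as T.float32 yields the plain name "float32"; adjust if the repr differs:

import torch

def to_torch_dtype(dtype):
    if isinstance(dtype, torch.dtype):
        return dtype
    name = dtype if isinstance(dtype, str) else str(dtype)  # assumes str(T.float32) == "float32"
    return getattr(torch, name)

assert to_torch_dtype("float32") is torch.float32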
🟡 Minor comments (7)
examples/KDA/FLA_KDA/fla_utils.py-24-24 (1)
24-24: Potential crash on systems without CUDA devices.
torch.cuda.get_device_name(0)will raise an exception if no CUDA device is available. This executes at module import time, making the entire module unimportable on CPU-only systems.Consider guarding with a device availability check:
IS_NVIDIA_HOPPER = ( torch.cuda.is_available() and ("NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9) )examples/KDA/chunk_intra_token_parallel.py-8-8 (1)
8-8: Hardcoded GPU device index.
CUDA_VISIBLE_DEVICES="7"is hardcoded, which will fail on systems with fewer GPUs or different configurations. Consider removing this or making it configurable via environment variable fallback.🔧 Suggested fix
-os.environ["CUDA_VISIBLE_DEVICES"] = "7" +# os.environ["CUDA_VISIBLE_DEVICES"] = "7" # Uncomment to pin specific GPUexamples/KDA/FLA_KDA/cumsum.py-176-177 (1)
176-177: Conditioni_c >= 0is always true.Since
i_citerates fromrange(NT), it starts at 0 and is always non-negative. This condition provides no filtering.🐛 If this was meant to skip the first iteration, use:
- if i_c >= 0: - b_z += b_ss + if i_c > 0: + b_z += b_ssOtherwise, remove the conditional entirely:
- if i_c >= 0: - b_z += b_ss + b_z += b_ssexamples/KDA/chunk_bwd_dqkwg.py-271-274 (1)
271-274: Correctness validation is disabled.The
compare_tensorscalls are commented out. Enable these to validate the TileLang kernel against the reference implementation.Suggested fix
- # compare_tensors("dq", dq_ref, dq) - # compare_tensors("dk", dk_ref, dk) - # compare_tensors("dw", dw_ref, dw) - # compare_tensors("dg", dg_ref, dg) + compare_tensors("dq", dq_ref, dq) + compare_tensors("dk", dk_ref, dk) + compare_tensors("dw", dw_ref, dw) + compare_tensors("dg", dg_ref, dg)examples/KDA/wy_fast_bwd.py-334-338 (1)
334-338: Correctness validation is disabled.The
compare_tensorscalls are commented out, meaning the test only benchmarks without validating correctness. Consider enabling these checks or adding a flag to toggle validation.Suggested fix
- # compare_tensors("dA", dA_tilelang, dA_ref) - # compare_tensors("dk", dk_tilelang, dk_ref) - # compare_tensors("dv", dv_tilelang, dv_ref) - # compare_tensors("dbeta", dbeta_tilelang, dbeta_ref) - # compare_tensors("dg", dg_tilelang, dg_ref) + compare_tensors("dA", dA_tilelang, dA_ref) + compare_tensors("dk", dk_tilelang, dk_ref) + compare_tensors("dv", dv_tilelang, dv_ref) + compare_tensors("dbeta", dbeta_tilelang, dbeta_ref) + compare_tensors("dg", dg_tilelang, dg_ref)examples/KDA/FLA_KDA/fla_chunk_intra.py-10-15 (1)
10-15: Dead code:SOLVE_TRIL_DOT_PRECISIONis overwritten.Lines 10-14 define
SOLVE_TRIL_DOT_PRECISIONconditionally, but line 15 unconditionally overwrites it with"tf32". The conditional logic is dead code.If you intend to always use tf32, simplify to:
-IS_TF32_SUPPORTED = False -if IS_TF32_SUPPORTED: - SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32x3") -else: - SOLVE_TRIL_DOT_PRECISION = tl.constexpr("ieee") SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32")Otherwise, remove line 15 to enable the conditional logic.
examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py-114-168 (1)
114-168: Incorrect return type annotation.The function signature declares
-> Nonebut the function returnsAqk, Akkon line 168. Update the return type annotation to match the actual behavior.Suggested fix
def chunk_kda_fwd_intra_token_parallel( ... -) -> None: +) -> tuple[torch.Tensor, torch.Tensor]:
🧹 Nitpick comments (38)
examples/KDA/chunk_bwd_gla_dA.py (8)
4-4: Remove unusedsysimport.The
sysmodule is imported but never used. Thenoqa: F401directive is also unnecessary since the lint rule isn't triggered for valid reasons.Suggested fix
-import sys # noqa: F401
17-28: Unusedchunk_sizeparameter.The
chunk_sizeparameter is declared but never used in this function. Either remove it or document why it's kept for interface consistency.Suggested fix (if not needed for interface consistency)
def prepare_input( B, S, H, DV, - chunk_size, input_dtype, do_dtype, ):
31-40: UnusedDVparameter.The
DVparameter is unused in this function.Suggested fix
def prepare_output( B, S, H, - DV, chunk_size, d_type, ):
43-51: Consider movingitertoolsimport to module level.The
itertoolsimport inside the function works but is unconventional. For consistency and slight performance benefit on repeated calls, consider moving it to the top-level imports.
86-86: Remove commented-out code and translate comments.The commented-out
dA_sharedallocation and copy statements should be removed if no longer needed. Also, the Chinese comment on line 102 (下三角矩阵) should be translated to English for broader accessibility.Suggested fix
- # dA_shared = T.alloc_shared((block_S, block_S), dtype=da_dtype) ... for i_s1, i_s2 in T.Parallel(block_S, block_S): - dA_fragment[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, dA_fragment[i_s1, i_s2] * scale, 0.0) # 下三角矩阵 - # T.copy(dA_fragment, dA_shared) + dA_fragment[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, dA_fragment[i_s1, i_s2] * scale, 0.0) # lower triangular matrixAlso applies to: 102-103
130-130: Consider removing debug print statement.Printing raw timing data on every benchmark call may be noisy. Consider removing it or gating it behind a verbose flag.
Suggested fix
- print(times)
150-162: Wasted tensor allocation on line 150.
dA_tilelangis allocated viaprepare_outputon line 150, but then immediately overwritten by the kernel output on line 162. The allocation serves no purpose. If you intended to passdA_tilelangas an output buffer to the kernel, the kernel invocation pattern would need to change.Suggested fix
- dA_tilelang = prepare_output(B, S, H, DV, chunk_size, getattr(torch, da_dtype)) kernel = tilelang_chunk_bwd_kernel_dv_local( B=B, S=S, H=H, DV=DV, scale=scale, input_dtype=input_dtype, da_dtype=da_dtype, do_dtype=do_dtype, chunk_size=chunk_size, ) dA_tilelang = kernel(DO, V_new)
138-138: UnusedDKparameter.The
DKparameter is declared but never used inrun_test.examples/KDA/chunk_delta_bwd.py (3)
3-9: Remove debug artifacts and unused import.The
sysimport is unused, and theprint(tilelang.__file__)statement will execute on every module import, which is not suitable for production/example code.Proposed fix
-import sys # noqa: F401 - import tilelang import tilelang.language as T from tilelang.autotuner import autotune - -print(tilelang.__file__, flush=True)
22-49: Unused function parameters.Parameters
output_dtype,accum_dtype, andstate_dtypeare declared but never used. Consider removing them or prefixing with_if they're placeholders for future use.
263-265: Unnecessary pre-allocation of output tensors.
prepare_outputallocatesdh_tilelang,dh0_tilelang, anddv2_tilelang, but these are immediately overwritten by the kernel call at line 295. The@tilelang.jit(out_idx=[-3, -2, -1])decorator already handles output allocation.Proposed fix
- dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output( - B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) - ) - # fla refexamples/KDA/test_utils.py (2)
43-45: Unused atol and rtol parameters. The function signature accepts atol and rtol but they are never used in the comparison logic. Either implement tolerance-based checks or remove these parameters to avoid confusion.
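If the parameters are kept, a minimal sketch of actually honoring them; the compare_tensors signature shown here is hypothetical:

import torch

def compare_tensors(name, ref, out, atol=1e-3, rtol=1e-3):
    ok = torch.allclose(ref.float(), out.float(), atol=atol, rtol=rtol)
    max_err = (ref.float() - out.float()).abs().max().item()
    print(f"{name}: allclose={ok}, max_abs_err={max_err:.3e}")
    return ok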
49-54: Minor: Chinese comments and fullwidth characters.Comments on lines 49, 54, 64, and 68 use Chinese text and fullwidth parentheses. Consider using English comments for broader accessibility, or at minimum replace fullwidth characters with ASCII equivalents for consistency.
examples/KDA/FLA_KDA/fla_utils.py (3)
24-25: RedundantTrue andprefix.
True and (...)is equivalent to just the condition itself. This pattern suggests a toggle that was left enabled.♻️ Simplification
-IS_NVIDIA_HOPPER = True and ("NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9) -USE_CUDA_GRAPH = True and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" +IS_NVIDIA_HOPPER = "NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9 +USE_CUDA_GRAPH = os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
136-150: Silent exception swallowing may hide configuration issues.
Both try/except blocks catch broad Exception and silently pass. If the Triton API changes or returns unexpected data, failures will be masked, and the function falls back to returning 1 without any indication of why.
Consider logging at debug level:
    except Exception as e:
        import logging
        logging.debug(f"Triton 2.2+ API failed: {e}")
224-230: Catching BaseException is overly broad.
BaseException includes KeyboardInterrupt and SystemExit, which should typically propagate. Use Exception instead.
♻️ Proposed fix
     try:
         return [
             triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"] for i in range(device_torch_lib.device_count())
         ]
-    except BaseException:
+    except Exception:
         return [-1]
260-261: Power-of-2 check uses non-standard pattern.
The check chunk_size == 2 ** (chunk_size.bit_length() - 1) works but is fragile for edge cases (fails for 0). Consider the canonical form.
♻️ More robust alternative
-    assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
+    assert chunk_size > 0 and (chunk_size & (chunk_size - 1)) == 0, "chunk_size must be a power of 2"
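For intuition: a power of two has exactly one set bit, so it shares no bits with its predecessor, and the explicit n > 0 guard covers the zero case the bit_length form mishandles. A tiny standalone check, not tied to this repo:

def is_pow2(n: int) -> bool:
    # 64 -> 0b1000000 and 63 -> 0b0111111 share no bits, so the AND is zero
    return n > 0 and (n & (n - 1)) == 0

assert is_pow2(1) and is_pow2(64)
assert not is_pow2(0) and not is_pow2(48)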
440-441: Unused **kwargs parameter.
The kwargs parameter is accepted but never used. If it's for future extensibility, document it; otherwise, remove it to avoid confusion.
examples/KDA/wy_fast.py (1)
10-10: Avoid hardcoding CUDA_VISIBLE_DEVICES.
Hardcoding a specific GPU device index limits portability and may cause issues in different environments. Consider removing this or using an environment variable fallback.
Suggested fix
-os.environ["CUDA_VISIBLE_DEVICES"] = "7"
+# Allow override via environment variable, default to all visible devices
+# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
12-12: Avoid hardcoding CUDA_VISIBLE_DEVICES.
Same issue as in wy_fast.py. Consider removing or making configurable.
133-144: Unused parameters: DK and scale.
These parameters are declared but never used in run_test. Consider removing them or documenting their intended future use.
Suggested fix
 def run_test(
     B,
     S,
     H,
-    DK,
+    DK,  # Used for prepare_input
     DV,
-    scale,
     input_dtype,
     do_dtype,
     output_dtype,
     chunk_size,
 ):
11-11: Avoid hardcoding CUDA_VISIBLE_DEVICES.
Same issue across multiple files.
241-248: Consider removing unused parameters.
Multiple parameters (use_gk, use_initial_state, store_final_state, save_new_value, block_DK, block_DV, threads, num_stages) are declared but never used. This adds confusion.
examples/KDA/FLA_KDA/fla_chunk_inter.py (1)
141-155: Unused parameter w and implicit Optional type.
- Parameter w is passed but never used in the function body (only used to create dw with the same shape).
- Per PEP 484, scale: float = None should be scale: Optional[float] = None.
Suggested fix for type hint
+from typing import Optional
+
 def chunk_kda_bwd_dqkwg(
     q: torch.Tensor,
     k: torch.Tensor,
     w: torch.Tensor,
     v: torch.Tensor,
     h: torch.Tensor,
     g: torch.Tensor,
     do: torch.Tensor,
     dh: torch.Tensor,
     dv: torch.Tensor,
-    scale: float = None,
+    scale: Optional[float] = None,
     cu_seqlens: torch.LongTensor = None,
     chunk_size: int = 64,
     chunk_indices: torch.LongTensor = None,
 ):
470-470: Remove debug print statement.
This debug print should be removed or guarded behind a debug flag.
Suggested fix
- print(db_tilelang.shape)
200-200: Minor typo in comment.
"i_j is index ofprevious sub_chunks" should be "i_j is index of previous sub-chunks".
Suggested fix
-    for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index ofprevious sub_chunks
+    for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index of previous sub-chunks
4-4: Remove unused noqa directive.
The sys import is not used and the noqa: F401 directive is unnecessary since F401 is not enabled.
Suggested fix
-import sys # noqa: F401
213-214: Wasteful allocation immediately overwritten.
prepare_output allocates O_ref on line 213, but it's immediately overwritten by the result of chunk_gla_fwd_o_gk on line 214. Remove the unnecessary allocation.
Suggested fix
-    O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
     O_ref = chunk_gla_fwd_o_gk(Q, V, G, A, HIDDEN, scale, chunk_size=chunk_size, use_exp2=True)
198-201: Unused function arguments.
Parameters block_DK, block_DV, threads, and num_stages are accepted but never used in run_test. Since autotuning is enabled on the kernel, these could be removed from the function signature, or they should be passed to the kernel if manual configuration is intended.
examples/KDA/FLA_KDA/fla_chunk_delta.py (1)
534-534: Use explicit Optional type annotation.
PEP 484 prohibits implicit Optional. The scale parameter should use Optional[float] for clarity.
Suggested fix
+from typing import Optional
+
 def chunk_gated_delta_rule_bwd_dhu(
     ...
-    scale: float = None,
+    scale: Optional[float] = None,
     ...
3-3: Remove unused noqa directive.
Suggested fix
-import sys # noqa: F401
200-200: Typo in comment.
Minor typo: "ofprevious" should be "of previous".
Suggested fix
-    for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index ofprevious sub_chunks
+    for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index of previous sub_chunks
470-470: Remove debug print statement.
This appears to be leftover debug code that should be removed before merging.
Suggested fix
-    print(db_tilelang.shape)
     dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = kernel(q, k, g, beta, dAqk, dAkk, dq, dk, db, dg)
3-3: Remove unused noqa directive.
Suggested fix
-import sys # noqa: F401
528-528: Remove debug print statement.
The print(times) inside do_bench appears to be debug output that should be removed.
Suggested fix
-    print(times)
     return times.mean().item()
370-370: Variable all shadows Python builtin.
Using all as a variable name shadows the built-in all() function. Consider using a more descriptive name like total_tokens or batch_tokens.
Suggested fix
-    all = B * T
+    total_tokens = B * T
     if IS_VARLEN:
         ...
     ...
-    db += (i_k * all + bos) * H + i_h
+    db += (i_k * total_tokens + bos) * H + i_h
87-88: Condition if i_k >= 0: is always true.
Since i_k comes from range(tl.cdiv(K, BK)), it is always non-negative. This condition appears to be dead code or a placeholder. If it's intentional for clarity/future use, consider adding a comment.
360-381: Multiple unused kernel arguments.
The kernel chunk_bwd_kernel_dv_local declares parameters g, g_gamma, scale, BK, USE_G, and USE_G_GAMMA but never uses them. If these are placeholders for future functionality, consider adding a TODO comment. Otherwise, they should be removed.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (22)
.pre-commit-config.yaml
examples/KDA/FLA_KDA/cumsum.py
examples/KDA/FLA_KDA/fla_chunk_delta.py
examples/KDA/FLA_KDA/fla_chunk_inter.py
examples/KDA/FLA_KDA/fla_chunk_intra.py
examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py
examples/KDA/FLA_KDA/fla_chunk_o.py
examples/KDA/FLA_KDA/fla_utils.py
examples/KDA/FLA_KDA/fla_wy_fast.py
examples/KDA/chunk_bwd_dqkwg.py
examples/KDA/chunk_bwd_dv.py
examples/KDA/chunk_bwd_gla_dA.py
examples/KDA/chunk_bwd_intra.py
examples/KDA/chunk_bwd_intra_op.py
examples/KDA/chunk_delta_bwd.py
examples/KDA/chunk_delta_h_fwd.py
examples/KDA/chunk_inter_solve_fused.py
examples/KDA/chunk_intra_token_parallel.py
examples/KDA/chunk_o.py
examples/KDA/test_utils.py
examples/KDA/wy_fast.py
examples/KDA/wy_fast_bwd.py
✅ Files skipped from review due to trivial changes (1)
- .pre-commit-config.yaml
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
🧬 Code graph analysis (12)
examples/KDA/test_utils.py (1)
tilelang/carver/roller/policy/default.py (1)
sim(290-291)
examples/KDA/FLA_KDA/fla_chunk_delta.py (4)
examples/KDA/FLA_KDA/fla_utils.py (1)
prepare_chunk_indices(107-112)examples/KDA/FLA_KDA/fla_chunk_inter.py (1)
grid(168-169)examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py (1)
grid(149-150)examples/KDA/FLA_KDA/fla_chunk_o.py (1)
grid(320-321)
examples/KDA/chunk_delta_bwd.py (2)
examples/KDA/FLA_KDA/fla_chunk_delta.py (1)
chunk_gated_delta_rule_bwd_dhu(524-579)examples/KDA/FLA_KDA/cumsum.py (1)
chunk_local_cumsum(431-469)
examples/KDA/chunk_o.py (6)
tilelang/autotuner/tuner.py (1)
autotune(691-786)examples/KDA/FLA_KDA/fla_chunk_o.py (1)
chunk_gla_fwd_o_gk(299-340)tilelang/language/kernel.py (1)
threads(215-219)tilelang/language/annotations.py (1)
annotate_layout(27-40)tilelang/layout/swizzle.py (1)
make_swizzled_layout(62-71)tilelang/language/dtypes.py (2)
bfloat16(407-407)float32(310-310)
examples/KDA/chunk_bwd_dqkwg.py (1)
examples/KDA/FLA_KDA/fla_chunk_inter.py (1)
chunk_kda_bwd_dqkwg(141-193)
examples/KDA/chunk_inter_solve_fused.py (3)
examples/KDA/FLA_KDA/fla_chunk_intra.py (1)
chunk_kda_fwd_inter_solve_fused(611-650)examples/KDA/test_utils.py (1)
compare_tensors(43-83)examples/KDA/FLA_KDA/cumsum.py (1)
chunk_local_cumsum(431-469)
examples/KDA/FLA_KDA/fla_utils.py (1)
tilelang/utils/device.py (1)
get_current_device(14-21)
examples/KDA/wy_fast.py (4)
tilelang/autotuner/tuner.py (1)
autotune(691-786)examples/KDA/FLA_KDA/fla_wy_fast.py (1)
recompute_w_u_fwd(210-254)examples/KDA/test_utils.py (1)
compare_tensors(43-83)tilelang/language/proxy.py (1)
Tensor(233-233)
examples/KDA/FLA_KDA/fla_chunk_inter.py (1)
examples/KDA/FLA_KDA/fla_utils.py (2)
prepare_chunk_indices(107-112)check_shared_mem(234-240)
examples/KDA/wy_fast_bwd.py (2)
tilelang/autotuner/tuner.py (1)
autotune(691-786)examples/KDA/FLA_KDA/fla_wy_fast.py (1)
prepare_wy_repr_bwd(257-312)
examples/KDA/FLA_KDA/fla_chunk_intra.py (2)
examples/KDA/FLA_KDA/fla_utils.py (1)
prepare_chunk_indices(107-112)examples/KDA/FLA_KDA/cumsum.py (1)
chunk_local_cumsum(431-469)
examples/KDA/FLA_KDA/fla_chunk_o.py (1)
examples/KDA/FLA_KDA/fla_utils.py (2)
prepare_chunk_indices(107-112)check_shared_mem(234-240)
🪛 Ruff (0.14.10)
examples/KDA/test_utils.py
43-43: Unused function argument: atol
(ARG001)
43-43: Unused function argument: rtol
(ARG001)
54-54: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF003)
54-54: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
examples/KDA/chunk_bwd_intra_op.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
27-27: Unused function argument: output_dtype
(ARG001)
28-28: Unused function argument: accum_dtype
(ARG001)
30-30: Unused function argument: state_dtype
(ARG001)
57-57: Unused function argument: chunk_size
(ARG001)
61-61: Unused function argument: state_dtype
(ARG001)
96-96: Unused function argument: state_dtype
(ARG001)
133-133: Unused function argument: db
(ARG001)
419-419: Unused function argument: threads
(ARG001)
420-420: Unused function argument: num_stages
(ARG001)
421-421: Unused function argument: cu_seqlens
(ARG001)
422-422: Unused function argument: chunk_indices
(ARG001)
examples/KDA/FLA_KDA/fla_chunk_delta.py
534-534: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
examples/KDA/FLA_KDA/cumsum.py
32-32: Unused function argument: B
(ARG001)
85-85: Unused function argument: B
(ARG001)
144-144: Unused function argument: B
(ARG001)
206-206: Unused function argument: B
(ARG001)
250-250: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
287-287: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
333-333: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
364-364: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
398-398: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
423-427: Avoid specifying long messages outside the exception class
(TRY003)
435-435: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
440-440: Unused function argument: kwargs
(ARG001)
467-469: Avoid specifying long messages outside the exception class
(TRY003)
examples/KDA/chunk_delta_bwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
30-30: Unused function argument: output_dtype
(ARG001)
31-31: Unused function argument: accum_dtype
(ARG001)
33-33: Unused function argument: state_dtype
(ARG001)
60-60: Unused function argument: gate_dtype
(ARG001)
130-130: Unused function argument: h0
(ARG001)
244-244: Unused function argument: block_DV
(ARG001)
245-245: Unused function argument: threads
(ARG001)
246-246: Unused function argument: num_stages
(ARG001)
247-247: Unused function argument: use_torch
(ARG001)
270-270: Unpacked variable dh_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
270-270: Unpacked variable dh0_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
270-270: Unpacked variable dv2_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
295-295: Unpacked variable dh_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
295-295: Unpacked variable dh0_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
295-295: Unpacked variable dv2_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/chunk_o.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
22-22: Unused function argument: output_dtype
(ARG001)
23-23: Unused function argument: accum_dtype
(ARG001)
39-39: Unused function argument: DK
(ARG001)
41-41: Unused function argument: chunk_size
(ARG001)
44-44: Ambiguous variable name: O
(E741)
99-99: Ambiguous variable name: O
(E741)
198-198: Unused function argument: block_DK
(ARG001)
199-199: Unused function argument: block_DV
(ARG001)
200-200: Unused function argument: threads
(ARG001)
201-201: Unused function argument: num_stages
(ARG001)
examples/KDA/chunk_bwd_dqkwg.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
47-47: Unused function argument: DV
(ARG001)
48-48: Unused function argument: chunk_size
(ARG001)
49-49: Unused function argument: qk_dtype
(ARG001)
241-241: Unused function argument: use_gk
(ARG001)
242-242: Unused function argument: use_initial_state
(ARG001)
243-243: Unused function argument: store_final_state
(ARG001)
244-244: Unused function argument: save_new_value
(ARG001)
245-245: Unused function argument: block_DK
(ARG001)
246-246: Unused function argument: block_DV
(ARG001)
247-247: Unused function argument: threads
(ARG001)
248-248: Unused function argument: num_stages
(ARG001)
252-252: Unpacked variable dq_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
252-252: Unpacked variable dk_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
252-252: Unpacked variable dw_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
252-252: Unpacked variable dg_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
269-269: Unpacked variable dq is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
269-269: Unpacked variable dk is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
269-269: Unpacked variable dw is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
269-269: Unpacked variable dg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/chunk_inter_solve_fused.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
27-27: Unused function argument: output_dtype
(ARG001)
28-28: Unused function argument: accum_dtype
(ARG001)
49-49: Unused function argument: sub_chunk_size
(ARG001)
examples/KDA/FLA_KDA/fla_utils.py
33-33: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
140-141: try-except-pass detected, consider logging the exception
(S110)
140-140: Do not catch blind exception: Exception
(BLE001)
149-150: try-except-pass detected, consider logging the exception
(S110)
149-149: Do not catch blind exception: Exception
(BLE001)
229-229: Do not catch blind exception: BaseException
(BLE001)
239-239: Do not catch blind exception: Exception
(BLE001)
examples/KDA/chunk_delta_h_fwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
17-17: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
41-41: Unused function argument: output_dtype
(ARG001)
42-42: Unused function argument: accum_dtype
(ARG001)
108-108: Unused function argument: block_DK
(ARG001)
248-248: Unused function argument: block_DK
(ARG001)
249-249: Unused function argument: block_DV
(ARG001)
250-250: Unused function argument: threads
(ARG001)
251-251: Unused function argument: num_stages
(ARG001)
examples/KDA/wy_fast.py
6-6: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
18-18: Unused function argument: output_dtype
(ARG001)
71-71: Unused function argument: use_qg
(ARG001)
97-97: Unused function argument: QG
(ARG001)
202-202: Unused function argument: block_DK
(ARG001)
203-203: Unused function argument: block_DV
(ARG001)
204-204: Unused function argument: threads
(ARG001)
205-205: Unused function argument: num_stages
(ARG001)
212-212: Unpacked variable QG_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
213-213: Unpacked variable QG_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/FLA_KDA/fla_chunk_inter.py
151-151: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
examples/KDA/wy_fast_bwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
24-24: Unused function argument: output_dtype
(ARG001)
25-25: Unused function argument: accum_dtype
(ARG001)
27-27: Unused function argument: state_dtype
(ARG001)
49-49: Unused function argument: chunk_size
(ARG001)
52-52: Unused function argument: state_dtype
(ARG001)
91-91: Unused function argument: state_dtype
(ARG001)
282-282: Unused function argument: block_DK
(ARG001)
283-283: Unused function argument: block_DV
(ARG001)
284-284: Unused function argument: threads
(ARG001)
285-285: Unused function argument: num_stages
(ARG001)
306-306: Unpacked variable dk_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
306-306: Unpacked variable dv_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
306-306: Unpacked variable dbeta_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
306-306: Unpacked variable dg_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
306-306: Unpacked variable dA_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
332-332: Unpacked variable dA_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
332-332: Unpacked variable dk_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
332-332: Unpacked variable dv_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
332-332: Unpacked variable dbeta_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
332-332: Unpacked variable dg_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/chunk_intra_token_parallel.py
22-22: Unused function argument: output_dtype
(ARG001)
23-23: Unused function argument: accum_dtype
(ARG001)
260-260: Unpacked variable Aqk_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
260-260: Unpacked variable Akk_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/FLA_KDA/fla_chunk_o.py
363-363: Unused function argument: g
(ARG001)
364-364: Unused function argument: g_gamma
(ARG001)
370-370: Unused function argument: scale
(ARG001)
376-376: Unused function argument: BK
(ARG001)
378-378: Unused function argument: USE_G
(ARG001)
379-379: Unused function argument: USE_G_GAMMA
(ARG001)
422-422: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
examples/KDA/chunk_bwd_dv.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
39-39: Unused function argument: chunk_size
(ARG001)
139-139: Unused function argument: scale
(ARG001)
examples/KDA/chunk_bwd_gla_dA.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
22-22: Unused function argument: chunk_size
(ARG001)
35-35: Unused function argument: DV
(ARG001)
138-138: Unused function argument: DK
(ARG001)
examples/KDA/chunk_bwd_intra.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
27-27: Unused function argument: output_dtype
(ARG001)
28-28: Unused function argument: accum_dtype
(ARG001)
30-30: Unused function argument: state_dtype
(ARG001)
57-57: Unused function argument: chunk_size
(ARG001)
61-61: Unused function argument: state_dtype
(ARG001)
96-96: Unused function argument: state_dtype
(ARG001)
133-133: Unused function argument: db
(ARG001)
419-419: Unused function argument: threads
(ARG001)
420-420: Unused function argument: num_stages
(ARG001)
421-421: Unused function argument: cu_seqlens
(ARG001)
422-422: Unused function argument: chunk_indices
(ARG001)
🔇 Additional comments (19)
examples/KDA/chunk_bwd_gla_dA.py (1)
170-182: LGTM!
Test configuration with reasonable parameters for benchmarking. The hardcoded values are appropriate for an example/benchmark script.
examples/KDA/chunk_delta_bwd.py (6)
52-67: LGTM!
Output tensor preparation looks correct. The unused gate_dtype parameter mirrors the pattern in prepare_input for API consistency.
70-79: LGTM!
Autotune configuration generation is well-structured, producing 60 combinations for exploration.
130-130: Unused h0 input parameter in kernel.
The h0 (initial state) tensor is passed to the kernel but never used in the computation. Comparing with the reference implementation chunk_gated_delta_rule_bwd_dhu in FLA_KDA/fla_chunk_delta.py, h0 is also passed to the Triton kernel. Please verify whether h0 should contribute to the backward gradient computation, or if it's only needed for the dh0 output derivation that occurs at line 223.
182-224: LGTM!
The backward pass kernel logic correctly implements:
- Reverse iteration for gradient accumulation
- Pipelined memory operations with configurable stages
- GEMM-based gradient updates for dv and dh
- Optional GK scaling with exp2
310-332: LGTM!
Standard CUDA event-based benchmarking implementation with proper synchronization.
360-361: LGTM!
Standard entry point pattern.
examples/KDA/chunk_intra_token_parallel.py (1)
112-115: Shared memory allocation size mismatch.
Aqk_shared and Akk_shared are allocated with shape (block_H, DK) but are used for intermediate reductions that eventually write to Sum_Aqk_shared of shape (block_H, CS). This seems correct for element-wise products, but verify that DK matches the expected dimension for the dot product accumulation.
examples/KDA/FLA_KDA/fla_wy_fast.py (1)
309-311: Input gradients dk and dg are overwritten without accumulation.
The function receives dk and dg as inputs (potentially containing existing gradients), but lines 309-310 unconditionally replace them with dk2 and dg2. If these inputs are meant to accumulate gradients, this discards prior values.
Verify if this is intentional. If accumulation is expected:
-    dk = dk2
-    dg = dg2
+    dk.copy_(dk2)
+    dg.copy_(dg2)
Or if the inputs serve only as workspace, consider renaming them to clarify intent.
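The suggested copy_ matters because plain assignment only rebinds the local name inside the function, while copy_ writes into the tensor the caller passed in. A minimal standalone illustration (hypothetical helper names, not the repo's code):

import torch

def rebind(buf):
    buf = torch.ones(3)        # rebinds the local name; the caller's tensor is untouched

def write_inplace(buf):
    buf.copy_(torch.ones(3))   # writes into the caller's storage

x = torch.zeros(3); rebind(x);        print(x)  # tensor([0., 0., 0.])
y = torch.zeros(3); write_inplace(y); print(y)  # tensor([1., 1., 1.])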
examples/KDA/FLA_KDA/cumsum.py (1)
57-60: Inconsistent reverse cumsum implementation between scalar and vector kernels.
In chunk_local_cumsum_scalar_kernel (lines 57-60), reverse cumsum is manually computed as -cumsum + sum + original. In chunk_local_cumsum_vector_kernel (line 113), the built-in tl.cumsum(..., reverse=True) is used. This could lead to subtle numerical differences.
Verify that both approaches produce identical results, or consider using the same method for consistency.
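The two formulations agree algebraically, since sum(x) - cumsum(x)[i] + x[i] equals the suffix sum over j >= i; a quick PyTorch sanity check (arbitrary shapes, with a flip-based cumsum standing in for tl.cumsum(..., reverse=True)):

import torch

x = torch.randn(8, 64, dtype=torch.float64)

manual = -x.cumsum(dim=-1) + x.sum(dim=-1, keepdim=True) + x   # scalar-kernel formulation
reference = x.flip(-1).cumsum(dim=-1).flip(-1)                 # reverse cumulative sum

print(torch.allclose(manual, reference))  # True in float64; float32 can still differ in rounding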
examples/KDA/wy_fast.py (2)
57-58: LGTM on autotune and JIT configuration.
The autotune decorator with config generation and JIT compilation with fast math enabled follows established TileLang patterns.
115-126: LGTM on swizzled layout annotations.
Proper use of swizzled layouts for shared memory buffers to optimize memory access patterns.
examples/KDA/wy_fast_bwd.py (1)
135-241: LGTM on backward kernel implementation.
The kernel correctly implements the backward pass with proper shared memory allocation, fragment management, and gradient computations for dA, dk, dv, dbeta, and dg.
examples/KDA/chunk_bwd_dv.py (2)
85-86: Potential type mismatch for A_shared.
A_shared is allocated with do_dtype but the A tensor has input_dtype. This may cause implicit type conversions. Verify this is intentional.
76-104: LGTM on DV backward kernel.
The kernel correctly implements the backward DV computation with proper triangular masking and pipelined execution.
examples/KDA/FLA_KDA/fla_chunk_inter.py (2)
14-29: LGTM on autotune configuration.
Good use of Triton's heuristics for IS_VARLEN detection and autotune with multiple block size and warp configurations.
56-90: LGTM on kernel pointer setup and varlen handling.
Proper handling of variable-length sequences with chunk indices and boundary conditions.
examples/KDA/chunk_delta_h_fwd.py (1)
168-198: LGTM on forward kernel recurrence implementation.
The pipelined loop correctly stores the previous hidden state, computes the recurrence V_new = U - W * h, applies gating with exp2, and updates the hidden state.
141-382: LGTM on intra-chunk backward kernel.
The kernel correctly implements the complex intra-chunk backward pass with:
- Sub-chunk processing with proper indexing
- Upper and lower triangular handling for diagonal blocks
- Proper gradient accumulation for dq, dk, db, dg
do_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
Potential dtype mismatch for V_shared.
V_shared is allocated with do_dtype, but it's used to copy from V which has input_dtype. If do_dtype != input_dtype, this could cause implicit type conversion or precision loss.
Suggested fix
do_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
- V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
+ V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
🤖 Prompt for AI Agents
In @examples/KDA/chunk_bwd_gla_dA.py around lines 83 - 84, The V_shared buffer
is allocated with do_dtype but later populated from V which uses input_dtype,
risking implicit casts; change the allocation for V_shared to use input_dtype
(i.e., T.alloc_shared((block_S, block_DV), dtype=input_dtype)) or explicitly
cast V to do_dtype at the copy site so the types match; update the allocation of
V_shared (and any corresponding uses) to prevent precision loss between do_dtype
and input_dtype.
input_dtype=T.float32,
output_dtype=T.float32,
accum_dtype=T.float32,
gate_dtype=T.float32,
state_dtype=T.float32,
Type mismatch: passing TileLang types instead of strings.
Same pattern as other files.
Proposed fix
- input_dtype=T.float32,
- output_dtype=T.float32,
- accum_dtype=T.float32,
- gate_dtype=T.float32,
- state_dtype=T.float32,
+ input_dtype="float32",
+ output_dtype="float32",
+ accum_dtype="float32",
+ gate_dtype="float32",
+ state_dtype="float32",
🤖 Prompt for AI Agents
In @examples/KDA/chunk_bwd_intra_op.py around lines 508 - 512, The dtype
parameters (input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype) are
using TileLang type objects (T.float32) instead of string names; update each to
use the matching string representation (e.g., "float32") wherever these kwargs
are set in this module (replace T.float32 with "float32" for
input_dtype/output_dtype/accum_dtype/gate_dtype/state_dtype) to match the
pattern used in the other files.
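For context on why string names are the safer choice here: getattr resolves attributes by string name only, so whether the current code works hinges on the TileLang dtype object behaving like a string. A minimal standalone illustration, using torch.dtype purely as a stand-in for a non-string object:

import torch

print(getattr(torch, "float32"))    # torch.float32 — a plain string name resolves fine
try:
    getattr(torch, torch.float32)   # a non-string object in place of the attribute name
except TypeError as err:
    print(err)                      # "attribute name must be string, ..."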
| if check_shared_mem("hopper", k.device.index): | ||
| CONST_TILING = 128 | ||
| elif check_shared_mem: | ||
| CONST_TILING = 64 | ||
| else: | ||
| CONST_TILING = 32 |
Bug: Missing function call parentheses.
check_shared_mem is a function, but elif check_shared_mem: checks if the function object is truthy (always True) instead of calling it. This means the elif branch will always be taken when the first condition is false.
Suggested fix
if check_shared_mem("hopper", k.device.index):
CONST_TILING = 128
- elif check_shared_mem:
+ elif check_shared_mem():
CONST_TILING = 64
else:
CONST_TILING = 32🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_chunk_o.py around lines 432 - 437, The elif
currently tests the function object instead of calling it; change "elif
check_shared_mem:" to call the function (e.g., "elif
check_shared_mem(\"hopper\", k.device.index):" or the correct argument set for
your use case) so the condition actually invokes check_shared_mem rather than
always evaluating truthiness of the function object, and keep CONST_TILING
assignments as-is.
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = 64
    if chunk_indices is None and cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    CONST_TILING = 64
    BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
    BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)
Inconsistent BT derivation between forward and backward.
In recompute_w_u_fwd (line 221), BT = A.shape[-1] is derived from the input. In prepare_wy_repr_bwd (line 271), BT = 64 is hardcoded. If the forward pass uses a different BT, the backward pass will produce incorrect results.
🐛 Proposed fix
def prepare_wy_repr_bwd(
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
gk: torch.Tensor,
A: torch.Tensor,
dk: torch.Tensor,
dw: torch.Tensor,
du: torch.Tensor,
dg: torch.Tensor,
cu_seqlens: torch.LongTensor = None,
chunk_indices: torch.LongTensor = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
- BT = 64
+ BT = A.shape[-1]
🤖 Prompt for AI Agents
In @examples/KDA/FLA_KDA/fla_wy_fast.py around lines 269 - 277, The backward
helper prepare_wy_repr_bwd hardcodes BT = 64 which can mismatch the forward
recompute_w_u_fwd where BT is derived as A.shape[-1]; change prepare_wy_repr_bwd
to compute BT = A.shape[-1] (or otherwise derive BT from the same input used in
forward, e.g., A or the relevant tensor shape) instead of the hardcoded 64 so
forward and backward use the same tiling.
if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
    print_red_warning(f"{name} Error: nonfinite value mismatch")
    if raise_assert:
        raise AssertionError
Inverted mask logic in non-finite value comparison.
The comparison on line 27 uses masked_fill(x_mask, 0) which fills finite positions with 0, leaving non-finite values for comparison. This appears inverted—you likely want to compare the non-finite values by masking out the finite ones with ~x_mask.
🐛 Proposed fix
- if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
+ if not torch.isclose(x.masked_fill(~x_mask, 0), y.masked_fill(~y_mask, 0), rtol=0, atol=0, equal_nan=True).all():🤖 Prompt for AI Agents
In @examples/KDA/test_utils.py around lines 27 - 30, The non-finite comparison
in the block using x.masked_fill(x_mask, 0) and y.masked_fill(y_mask, 0) has the
mask inverted; currently finite positions are being filled and non-finite values
compared. Change the masked_fill calls to use the inverted masks (~x_mask and
~y_mask) so you mask out (fill) finite values and compare only non-finite
positions (keep using equal_nan=True and the existing error/raise behavior),
referencing the variables x, y, x_mask, y_mask and the masked_fill calls to
locate the change.
input_dtype=T.float32,
output_dtype=T.float32,
accum_dtype=T.float32,
gate_dtype=T.float32,
state_dtype=T.float32,
chunk_size=64,
block_DK=32,
block_DV=32,
threads=128,
num_stages=0,
)
Type mismatch: passing TileLang types instead of strings.
Similar to wy_fast.py, main() passes T.float32 but run_test expects string dtype arguments for getattr(torch, ...).
Proposed fix
run_test(
B=1,
S=32768,
H=8,
DK=DK,
DV=DV,
- input_dtype=T.float32,
- output_dtype=T.float32,
- accum_dtype=T.float32,
- gate_dtype=T.float32,
- state_dtype=T.float32,
+ input_dtype="float32",
+ output_dtype="float32",
+ accum_dtype="float32",
+ gate_dtype="float32",
+ state_dtype="float32",
chunk_size=64,
🤖 Prompt for AI Agents
In @examples/KDA/wy_fast_bwd.py around lines 364 - 374, The call to main() in
wy_fast_bwd.py is passing TileLang types (T.float32) for dtype params while
run_test expects string names to use getattr(torch, ...); update the main(...)
invocation to pass string dtype names (e.g., "float32") for input_dtype,
output_dtype, accum_dtype, gate_dtype, and state_dtype so run_test can resolve
torch dtypes via getattr; keep all other numeric/size params (chunk_size,
block_DK, block_DV, threads, num_stages) unchanged.
input_dtype=T.bfloat16,
output_dtype=T.bfloat16,
gate_dtype=T.float32,
accum_dtype=T.float32,
Type mismatch: passing TileLang types instead of strings.
run_test expects string dtype arguments (e.g., "bfloat16") that are resolved via getattr(torch, input_dtype), but main() passes T.bfloat16, T.float32 directly. This will cause getattr(torch, T.bfloat16) to fail.
Proposed fix
run_test(
B=1,
S=8192,
H=64,
DK=128,
DV=128,
chunk_size=64,
- input_dtype=T.bfloat16,
- output_dtype=T.bfloat16,
- gate_dtype=T.float32,
- accum_dtype=T.float32,
+ input_dtype="bfloat16",
+ output_dtype="bfloat16",
+ gate_dtype="float32",
+ accum_dtype="float32",
block_DK=64,
block_DV=32,
threads=128,
num_stages=3,
)
🤖 Prompt for AI Agents
In @examples/KDA/wy_fast.py around lines 266 - 269, The bug is that main()
passes TileLang type objects (T.bfloat16, T.float32) into run_test which expects
string dtype names resolved with getattr(torch, ...); update the call site in
main() to pass dtype names as strings (e.g., "bfloat16" for
input_dtype/output_dtype and "float32" for gate_dtype/accum_dtype) so run_test
can successfully call getattr(torch, dtype). Reference symbols: main(),
run_test(), and the dtype parameters
input_dtype/output_dtype/gate_dtype/accum_dtype.
@wfloveiu, thank you for confirming! Once you've applied the fix to pass string dtype names ("bfloat16", "float32") instead of TileLang type objects, the code should work correctly. Let me know if you need any assistance with the change.
🧠 Learnings used
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Fixed linting issues, ready for review.
Actionable comments posted: 9
🤖 Fix all issues with AI agents
In @examples/KDA/chunk_bwd_dqkwg.py:
- Around line 242-245: Re-enable the correctness checks by uncommenting the
compare_tensors calls (compare_tensors("dq", dq_ref, dq), compare_tensors("dk",
dk_ref, dk), compare_tensors("dw", dw_ref, dw), compare_tensors("dg", dg_ref,
dg")) so the TileLang outputs are validated against the reference tensors; if
you expect numerical differences, call compare_tensors with an appropriate
tolerance or document known precision differences near these calls and adjust
the tolerance parameter (or add a comment) to avoid false failures while
preserving verification.
In @examples/KDA/chunk_bwd_intra_op.py:
- Line 16: Remove the debug print setting call
torch.set_printoptions(profile="full") from the top-level of
chunk_bwd_intra_op.py; simply delete that line (and any other ad-hoc
torch.set_printoptions calls) so tensors aren’t forced into verbose output
during normal execution, leaving normal print behavior unchanged.
In @examples/KDA/chunk_delta_bwd.py:
- Around line 301-303: Uncomment the validation calls so TileLang outputs are
compared to the reference: restore the three compare_tensors calls for "dh",
"dh0", and "dv2" (i.e., compare_tensors("dh", dh_ref, dh_tilelang),
compare_tensors("dh0", dh0_ref, dh0_tilelang), compare_tensors("dv2", dv2_ref,
dv2_tilelang")) in chunk_delta_bwd.py so the kernel outputs (dh_tilelang,
dh0_tilelang, dv2_tilelang) are actually validated against
dh_ref/dh0_ref/dv2_ref.
- Around line 121-136: The kernel function declares an unused parameter h0 in
its signature; either remove h0 from the kernel(...) parameter list or implement
the initial state handling referenced by use_initial_state so h0 is consumed. If
removing, update any callers and tests that pass h0 and delete related
h0_shape/h0 dtype defs; if implementing, add logic in kernel (e.g., within the
use_initial_state branch) to incorporate h0 into the backward computation and
ensure dh0 (or dh) is computed correctly from h0. Update any docstrings/comments
to reflect the chosen behavior.
In @examples/KDA/chunk_delta_h_fwd.py:
- Around line 272-282: The benchmark call to do_bench for
chunk_gated_delta_rule_fwd_h is passing the gate parameter as g=G but the
reference implementation expects gk=G, so update the argument name from g to gk
in that do_bench invocation (and verify the chunk_gated_delta_rule_fwd_h
signature accepts gk) to ensure the correct value is forwarded; adjust any other
mismatched calls to use gk consistently.
In @examples/KDA/chunk_intra_token_parallel.py:
- Around line 185-197: The run_test function accepts a scale parameter but it is
only forwarded to the reference implementation
chunk_kda_fwd_intra_token_parallel and not to the TileLang kernel; update the
TileLang kernel's function signature to accept a scale argument (matching the
reference) and pass scale in the kernel invocation inside run_test (the call
currently missing scale); ensure any internal uses of scale in the kernel are
wired through and remove or update the comment "# scale 如何传值" (i.e. "how should scale be passed") once fixed so both
reference and kernel implementations receive identical parameters for correct
comparison.
In @examples/KDA/wy_fast_bwd.py:
- Around line 310-314: The correctness checks were disabled by commenting out
the compare_tensors calls; re-enable verification by uncommenting the
compare_tensors invocations for dA, dk, dv, dbeta, and dg so dk_tilelang/dk_ref,
dv_tilelang/dv_ref, dbeta_tilelang/dbeta_ref, dg_tilelang/dg_ref, and
dA_tilelang/dA_ref are compared; ensure these reference tensors are populated
before the calls and keep the comparisons after the TileLang outputs are
computed so any mismatches are reported.
In @examples/KDA/wy_fast.py:
- Around line 179-180: The variable use_qg is mistakenly assigned a
single-element tuple (False,) which is truthy; change the assignment of use_qg
in wy_fast.py from the tuple to a boolean (use_qg = False) so conditionals
behave correctly, and scan for any other occurrences of use_qg to ensure they
expect a boolean.
- Line 222: The print statement prints "tritron time" which is a typo; find the
print call that references triton_time (print("tritron time:", triton_time)) and
correct the string to "triton time:" so the output label matches the variable
name and intended wording.
🧹 Nitpick comments (19)
examples/KDA/test_utils.py (3)
43-83: Unused atol and rtol parameters.
The function signature accepts atol and rtol but they're never used. Either remove them or integrate them into the comparison logic (e.g., using torch.allclose with these tolerances).
♻️ Suggested fix to use or remove unused parameters
 def compare_tensors(name, x, y, atol=1e-5, rtol=1e-5):
     import numpy as np
     import torch

     diff = (x - y).abs()
     # ========= 最大绝对误差 =========
     max_abs_err = diff.max().item()
     abs_flat_idx = diff.argmax()
     abs_idx = list(np.unravel_index(abs_flat_idx.cpu().numpy(), diff.shape))
     ...
     print("=====================================\n")
+
+    # Return pass/fail based on tolerances
+    return torch.allclose(x, y, atol=atol, rtol=rtol)
86-107: Duplicate do_bench implementation.
This function is duplicated in examples/KDA/chunk_o.py (lines 161-183). Consider consolidating into a single shared utility to avoid maintenance burden.
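If the two copies are consolidated, the shared helper only needs the usual CUDA-event timing pattern; a rough sketch of what it could look like (generic code with made-up warmup/rep defaults, not the repo's exact implementation):

import torch

def do_bench(fn, warmup=10, rep=100):
    """Time fn with CUDA events and return the mean latency in milliseconds."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(rep):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / rep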
3-3: Remove unused constant RCP_LN2 or document its intended purpose.
The constant is defined but never referenced anywhere in the codebase.
examples/KDA/chunk_bwd_dv.py (2)
31-40: Unused chunk_size parameter.
The chunk_size parameter is accepted but never used. Consider removing it for clarity.
♻️ Suggested fix
 def prepare_output(
     B,
     S,
     H,
     DV,
-    chunk_size,
     output_dtype,
 ):
     dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
     return dv
117-132: Unused variable dv_tilelang from prepare_output.
Line 120 allocates dv_tilelang via prepare_output, but it's immediately overwritten by the kernel result on line 131. The pre-allocation is wasted.
♻️ Suggested fix - remove unnecessary pre-allocation
-    dv_tilelang = prepare_output(B, S, H, DV, chunk_size, getattr(torch, output_dtype))
     kernel = tilelang_chunk_bwd_kernel_dv_local(
         B=B,
         S=S,
         H=H,
         DV=DV,
         input_dtype=input_dtype,
         output_dtype=output_dtype,
         do_dtype=do_dtype,
         chunk_size=chunk_size,
     )
     dv_tilelang = kernel(DO, A)
examples/KDA/chunk_bwd_gla_dA.py (2)
14-37: Unused parameters in preparation functions.
- prepare_input: chunk_size (line 19) is unused.
- prepare_output: DV (line 32) is unused; the output shape uses chunk_size for the last dimension.
Consider removing unused parameters or documenting the intent for API consistency.
118-135: Pre-allocated dA_tilelang is unused.
Similar to chunk_bwd_dv.py, line 122 allocates dA_tilelang but it's overwritten on line 134. The allocation is wasted.
examples/KDA/chunk_bwd_intra_op.py (2)
199-199: Minor typo in comment.
"ofprevious" should be "of previous".
📝 Fix typo
-    for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index ofprevious sub_chunks
+    for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index of previous sub_chunks
444-448: Pre-allocated outputs are immediately overwritten.
Lines 444-446 prepare output tensors that are immediately overwritten by the kernel call on line 447. The prepare_output call is unnecessary.
♻️ Suggested fix
-    dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = prepare_output(
-        B, S, H, DK, chunk_size, NK, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
-    )
     dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = kernel(q, k, g, beta, dAqk, dAkk, dq, dk, db, dg)
3-3: Remove unused noqa directive.
The sys import isn't actually used and has an unnecessary noqa: F401 directive. Consider removing the import entirely if not needed, or remove the noqa comment if the import is required.
Suggested fix
-import sys  # noqa: F401
301-305: Dtype passed as TileLang type instead of string.
In main(), the dtypes are passed as T.float16, T.float32 etc., but in run_test() lines 225-228 they are converted using getattr(torch, input_dtype). This works because TileLang types have string representations, but it's inconsistent with how strings are used elsewhere and may be confusing.
examples/KDA/wy_fast.py (1)
184-185: Unused variables should use underscore prefix.
QG_ref and QG_tilelang are unpacked but never used. Use underscore prefix to indicate intentional discard.
Suggested fix
-    W_ref, U_ref, QG_ref, KG_ref = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype))
-    W_tilelang, U_tilelang, QG_tilelang, KG_tilelang = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype))
+    W_ref, U_ref, _QG_ref, KG_ref = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype))
+    W_tilelang, U_tilelang, _QG_tilelang, KG_tilelang = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype))
48-52: Use torch.empty instead of torch.randn for output tensors.
Output tensors are initialized with random values using torch.randn, but they will be overwritten by the kernel. Using torch.empty is more appropriate and slightly more efficient.
Suggested fix
-    dq = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
-    dk = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
-    dw = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
-    dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
+    dq = torch.empty(B, S, H, DK, dtype=torch.float32).cuda()
+    dk = torch.empty(B, S, H, DK, dtype=torch.float32).cuda()
+    dw = torch.empty(B, S, H, DK, dtype=gate_dtype).cuda()
+    dg = torch.empty(B, S, H, DK, dtype=gate_dtype).cuda()
223-240: Unused reference output variables.
The reference outputs dq_ref, dk_ref, dw_ref, dg_ref and TileLang outputs dq, dk, dw, dg are computed but never used (comparison is commented out). Use underscore prefix for intentional discard or enable the comparison.
examples/KDA/chunk_intra_token_parallel.py (1)
137-144: Commented-out swizzled layout annotations.
The T.annotate_layout block is commented out. If swizzling improves performance, consider enabling it. If it causes issues, document why it's disabled.
examples/KDA/chunk_bwd_intra.py (2)
199-199: Minor typo in comment.
"ofprevious" should be "of previous".
Suggested fix
-    for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index ofprevious sub_chunks
+    for i_j in T.Pipelined(i_i, num_stages=num_stages):  # i_j is index of previous sub_chunks
80-84: TMA disabled - verify necessity.
The kernel disables TMA lowering via TL_DISABLE_TMA_LOWER: True. If this is a workaround for a known issue, consider adding a comment explaining why.
examples/KDA/wy_fast_bwd.py (2)
282-292: Unused reference output variables.
All reference outputs (dk_ref, dv_ref, dbeta_ref, dg_ref, dA_ref) are computed but never used since comparison is disabled. Use underscore prefixes or enable the comparison.
Suggested fix if keeping disabled
-    dk_ref, dv_ref, dbeta_ref, dg_ref, dA_ref = prepare_wy_repr_bwd(
+    _dk_ref, _dv_ref, _dbeta_ref, _dg_ref, _dA_ref = prepare_wy_repr_bwd(
76-79: Both TMA and warp specialization disabled.
The kernel disables both TL_DISABLE_TMA_LOWER and TL_DISABLE_WARP_SPECIALIZED. Consider adding a comment explaining why these optimizations are disabled for this specific kernel.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
examples/KDA/chunk_bwd_dqkwg.py
examples/KDA/chunk_bwd_dv.py
examples/KDA/chunk_bwd_gla_dA.py
examples/KDA/chunk_bwd_intra.py
examples/KDA/chunk_bwd_intra_op.py
examples/KDA/chunk_delta_bwd.py
examples/KDA/chunk_delta_h_fwd.py
examples/KDA/chunk_inter_solve_fused.py
examples/KDA/chunk_intra_token_parallel.py
examples/KDA/test_utils.py
examples/KDA/wy_fast.py
examples/KDA/wy_fast_bwd.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
examples/KDA/chunk_bwd_intra_op.py
examples/KDA/wy_fast.py
examples/KDA/chunk_delta_h_fwd.py
examples/KDA/chunk_delta_bwd.py
examples/KDA/chunk_bwd_dqkwg.py
examples/KDA/wy_fast_bwd.py
examples/KDA/chunk_intra_token_parallel.py
examples/KDA/chunk_bwd_dv.py
📚 Learning: 2026-01-06T05:20:51.649Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Applied to files:
examples/KDA/wy_fast.py
examples/KDA/chunk_delta_h_fwd.py
examples/KDA/chunk_delta_bwd.py
examples/KDA/chunk_intra_token_parallel.py
🧬 Code graph analysis (8)
examples/KDA/wy_fast.py (3)
tilelang/autotuner/tuner.py (1)
autotune(691-786)examples/KDA/FLA_KDA/fla_wy_fast.py (1)
recompute_w_u_fwd(210-254)examples/KDA/test_utils.py (2)
compare_tensors(43-83)do_bench(86-108)
examples/KDA/chunk_delta_h_fwd.py (1)
examples/KDA/test_utils.py (2)
compare_tensors(43-83)do_bench(86-108)
examples/KDA/chunk_delta_bwd.py (1)
examples/KDA/test_utils.py (1)
do_bench(86-108)
examples/KDA/test_utils.py (3)
tilelang/carver/roller/policy/default.py (1)
sim(290-291)tilelang/language/tir/op.py (1)
all(1913-1930)examples/KDA/chunk_o.py (1)
do_bench(162-184)
examples/KDA/chunk_bwd_intra.py (2)
examples/KDA/FLA_KDA/fla_chunk_intra.py (1)
chunk_kda_bwd_intra(541-608)examples/KDA/FLA_KDA/cumsum.py (1)
chunk_local_cumsum(431-469)
examples/KDA/wy_fast_bwd.py (3)
tilelang/autotuner/tuner.py (1)
autotune(691-786)examples/KDA/FLA_KDA/fla_wy_fast.py (1)
prepare_wy_repr_bwd(257-312)examples/KDA/test_utils.py (1)
do_bench(86-108)
examples/KDA/chunk_intra_token_parallel.py (2)
examples/KDA/FLA_KDA/cumsum.py (1)
chunk_local_cumsum(431-469)examples/KDA/test_utils.py (1)
do_bench(86-108)
examples/KDA/chunk_bwd_dv.py (1)
examples/KDA/test_utils.py (2)
compare_tensors(43-83)do_bench(86-108)
🪛 Ruff (0.14.10)
examples/KDA/chunk_bwd_intra_op.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
26-26: Unused function argument: output_dtype
(ARG001)
27-27: Unused function argument: accum_dtype
(ARG001)
29-29: Unused function argument: state_dtype
(ARG001)
56-56: Unused function argument: chunk_size
(ARG001)
60-60: Unused function argument: state_dtype
(ARG001)
95-95: Unused function argument: state_dtype
(ARG001)
132-132: Unused function argument: db
(ARG001)
396-396: Unused function argument: threads
(ARG001)
397-397: Unused function argument: num_stages
(ARG001)
398-398: Unused function argument: cu_seqlens
(ARG001)
399-399: Unused function argument: chunk_indices
(ARG001)
examples/KDA/wy_fast.py
6-6: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
15-15: Unused function argument: output_dtype
(ARG001)
68-68: Unused function argument: use_qg
(ARG001)
94-94: Unused function argument: QG
(ARG001)
174-174: Unused function argument: block_DK
(ARG001)
175-175: Unused function argument: block_DV
(ARG001)
176-176: Unused function argument: threads
(ARG001)
177-177: Unused function argument: num_stages
(ARG001)
184-184: Unpacked variable QG_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
185-185: Unpacked variable QG_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/chunk_delta_h_fwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
31-31: Unused function argument: output_dtype
(ARG001)
32-32: Unused function argument: accum_dtype
(ARG001)
98-98: Unused function argument: block_DK
(ARG001)
213-213: Unused function argument: block_DK
(ARG001)
214-214: Unused function argument: block_DV
(ARG001)
215-215: Unused function argument: threads
(ARG001)
216-216: Unused function argument: num_stages
(ARG001)
examples/KDA/chunk_delta_bwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
28-28: Unused function argument: output_dtype
(ARG001)
29-29: Unused function argument: accum_dtype
(ARG001)
31-31: Unused function argument: state_dtype
(ARG001)
58-58: Unused function argument: gate_dtype
(ARG001)
128-128: Unused function argument: h0
(ARG001)
242-242: Unused function argument: block_DV
(ARG001)
243-243: Unused function argument: threads
(ARG001)
244-244: Unused function argument: num_stages
(ARG001)
245-245: Unused function argument: use_torch
(ARG001)
268-268: Unpacked variable dh_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
268-268: Unpacked variable dh0_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
268-268: Unpacked variable dv2_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
291-291: Unpacked variable dh_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
291-291: Unpacked variable dh0_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
291-291: Unpacked variable dv2_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/test_utils.py
43-43: Unused function argument: atol
(ARG001)
43-43: Unused function argument: rtol
(ARG001)
54-54: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF003)
54-54: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
examples/KDA/chunk_bwd_intra.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
26-26: Unused function argument: output_dtype
(ARG001)
27-27: Unused function argument: accum_dtype
(ARG001)
29-29: Unused function argument: state_dtype
(ARG001)
56-56: Unused function argument: chunk_size
(ARG001)
60-60: Unused function argument: state_dtype
(ARG001)
95-95: Unused function argument: state_dtype
(ARG001)
132-132: Unused function argument: db
(ARG001)
396-396: Unused function argument: threads
(ARG001)
397-397: Unused function argument: num_stages
(ARG001)
398-398: Unused function argument: cu_seqlens
(ARG001)
399-399: Unused function argument: chunk_indices
(ARG001)
examples/KDA/chunk_bwd_dqkwg.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
43-43: Unused function argument: DV
(ARG001)
44-44: Unused function argument: chunk_size
(ARG001)
45-45: Unused function argument: qk_dtype
(ARG001)
212-212: Unused function argument: use_gk
(ARG001)
213-213: Unused function argument: use_initial_state
(ARG001)
214-214: Unused function argument: store_final_state
(ARG001)
215-215: Unused function argument: save_new_value
(ARG001)
216-216: Unused function argument: block_DK
(ARG001)
217-217: Unused function argument: block_DV
(ARG001)
218-218: Unused function argument: threads
(ARG001)
219-219: Unused function argument: num_stages
(ARG001)
223-223: Unpacked variable dq_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
223-223: Unpacked variable dk_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
223-223: Unpacked variable dw_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
223-223: Unpacked variable dg_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
240-240: Unpacked variable dq is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
240-240: Unpacked variable dk is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
240-240: Unpacked variable dw is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
240-240: Unpacked variable dg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/chunk_inter_solve_fused.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
24-24: Unused function argument: output_dtype
(ARG001)
25-25: Unused function argument: accum_dtype
(ARG001)
46-46: Unused function argument: sub_chunk_size
(ARG001)
examples/KDA/wy_fast_bwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
25-25: Unused function argument: output_dtype
(ARG001)
26-26: Unused function argument: accum_dtype
(ARG001)
28-28: Unused function argument: state_dtype
(ARG001)
50-50: Unused function argument: chunk_size
(ARG001)
53-53: Unused function argument: state_dtype
(ARG001)
92-92: Unused function argument: state_dtype
(ARG001)
258-258: Unused function argument: block_DK
(ARG001)
259-259: Unused function argument: block_DV
(ARG001)
260-260: Unused function argument: threads
(ARG001)
261-261: Unused function argument: num_stages
(ARG001)
282-282: Unpacked variable dk_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
282-282: Unpacked variable dv_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
282-282: Unpacked variable dbeta_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
282-282: Unpacked variable dg_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
282-282: Unpacked variable dA_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
308-308: Unpacked variable dA_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
308-308: Unpacked variable dk_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
308-308: Unpacked variable dv_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
308-308: Unpacked variable dbeta_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
308-308: Unpacked variable dg_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/chunk_intra_token_parallel.py
20-20: Unused function argument: output_dtype
(ARG001)
21-21: Unused function argument: accum_dtype
(ARG001)
233-233: Unpacked variable Aqk_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
233-233: Unpacked variable Akk_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/chunk_bwd_dv.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
36-36: Unused function argument: chunk_size
(ARG001)
111-111: Unused function argument: scale
(ARG001)
examples/KDA/chunk_bwd_gla_dA.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
19-19: Unused function argument: chunk_size
(ARG001)
32-32: Unused function argument: DV
(ARG001)
110-110: Unused function argument: DK
(ARG001)
🔇 Additional comments (12)
examples/KDA/test_utils.py (1)
20-40: Inverted mask logic in non-finite value comparison. Line 27 uses `masked_fill(x_mask, 0)`, which fills finite values with 0, keeping non-finite values. This is correct for comparing non-finite values. However, the variable naming and comments could be clearer, and the logic appears intentional for checking that non-finite values (NaN, Inf) match in both tensors. Lines 31-32 correctly zero out non-finite values for the similarity calculation.
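For reference, a minimal standalone sketch of the masking pattern described above (the helper name nonfinite_match is illustrative and not part of test_utils.py):
import torch

def nonfinite_match(x: torch.Tensor, y: torch.Tensor) -> bool:
    """Check that NaN/Inf entries appear at the same positions (and with the same sign of Inf) in both tensors."""
    # masked_fill(finite_mask, 0) zeroes the finite entries and keeps only the non-finite ones.
    x_nf = x.masked_fill(torch.isfinite(x), 0.0)
    y_nf = y.masked_fill(torch.isfinite(y), 0.0)
    same_nan = torch.isnan(x_nf) == torch.isnan(y_nf)         # NaN pattern matches
    same_pinf = torch.isposinf(x_nf) == torch.isposinf(y_nf)  # +Inf pattern matches
    same_ninf = torch.isneginf(x_nf) == torch.isneginf(y_nf)  # -Inf pattern matches
    return bool((same_nan & same_pinf & same_ninf).all())

# Example: both tensors carry an Inf in the same slot, so only the non-finite pattern is compared.
a = torch.tensor([1.0, float("inf"), 3.0])
b = torch.tensor([9.0, float("inf"), 0.5])
print(nonfinite_match(a, b))  # True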
examples/KDA/chunk_bwd_dv.py (1)
54-102: TileLang kernel implementation looks correct. The kernel properly:
- Uses swizzled layouts for shared memory
- Applies lower-triangular masking on A
- Uses pipelined loops for DV dimension
- Performs GEMM with A transposed
examples/KDA/chunk_bwd_gla_dA.py (1)
51-103: Kernel implementation is well-structured. The kernel correctly:
- Accumulates dA via pipelined GEMM across V blocks
- Applies lower-triangular masking with scale
- Uses appropriate swizzled layouts
examples/KDA/chunk_delta_bwd.py (1)
80-103: Kernel factory implementation follows expected patterns. The kernel properly handles:
- Multiple tensor outputs via `out_idx=[-3, -2, -1]`
- Autotuning with configurable parameters
- Proper gradient accumulation in reverse order (line 182)
Note: The `use_initial_state` flag conditionally stores `dh0` (lines 220-221), but the `h0` input is still unused.
examples/KDA/chunk_bwd_intra_op.py (2)
121-382: Complex intra-chunk backward kernel implementation. The kernel handles:
- Sub-chunk processing with proper indexing
- Inter-sub-chunk gradient flow via `kg_fragment` and gating
- Lower triangular diagonal processing
- Proper accumulation of `dq2`, `dk2`, `db`, `dg2`
The implementation follows the expected pattern from the FLA reference and correctly handles boundary conditions.
130-133: Input tensor `db` declared but unused in kernel. The kernel signature includes `db: T.Tensor(db_shape, dtype=input_dtype)` but it's never read inside the kernel. The actual `db` addition happens post-kernel in `run_test` (line 448: `db_tilelang.sum(0).add_(db)`). This is correct behavior since the kernel computes partial `db2` values that are summed and added to the original `db` afterward. Consider adding a brief comment to clarify this design.
examples/KDA/chunk_delta_h_fwd.py (1)
158-160: Verify h store ordering in pipelined loop. The hidden state `h` is stored at the beginning of each iteration, before the recurrence update is computed. This means `h[i_s]` receives the value computed in the previous iteration (or the initial state for `i_s = 0`). Confirm this matches the intended semantics where `h[i]` should contain the state before processing chunk `i`.
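For intuition, a rough PyTorch sketch of the store-then-update ordering being questioned here (shapes and names below are illustrative, not taken from the kernel):
import torch

B, NT, K, V, C = 1, 4, 8, 8, 16          # batch, number of chunks, key dim, value dim, chunk length
k = torch.randn(B, NT, K, C)             # per-chunk keys
v = torch.randn(B, NT, C, V)             # per-chunk values
g_last = torch.rand(B, NT, K)            # gate at the last token of each chunk

h = torch.zeros(B, NT, K, V)             # h[i] should hold the state *before* chunk i
state = torch.zeros(B, K, V)             # initial state (h0)
for i in range(NT):
    h[:, i] = state                                  # store first ...
    decay = torch.exp2(g_last[:, i]).unsqueeze(-1)   # gated decay, exp2 as in the kernel
    state = decay * state + k[:, i] @ v[:, i]        # ... then apply the recurrence update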
examples/KDA/chunk_inter_solve_fused.py (3)
558-559: LGTM: Correctness verification is enabled. Unlike some other files in this PR, this file properly enables the `compare_tensors` calls to verify the TileLang kernel output matches the reference implementation.
326-336: Verify forward substitution loop bounds. The pipelined loop starts at index 2 (`T.Pipelined(2, T.min(BC, S - i_tc0), ...)`). Verify this is intentional and correctly handles the first two rows of the block matrix during forward substitution. A comment explaining why the loop starts at 2 would improve clarity.
89-100: Helpful docstring for complex kernel. Good documentation of the kernel's purpose and the six-step algorithm. This makes the long kernel much more maintainable.
examples/KDA/chunk_bwd_intra.py (1)
455-458: LGTM: Correctness verification is enabled. The `compare_tensors` calls are properly enabled to verify the `dq`, `dk`, `db`, and `dg` outputs against the reference implementation.
examples/KDA/chunk_intra_token_parallel.py (1)
264-269: Correctness verification is disabled. The `compare_tensors` calls and debug prints are commented out. Enable these to verify the TileLang implementation matches the reference.
⛔ Skipped due to learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
@T.prim_func
def kernel(
    # Input
    Q: T.Tensor(Q_shape, dtype=input_dtype),
    K: T.Tensor(K_shape, dtype=input_dtype),
    W: T.Tensor(W_shape, dtype=input_dtype),
    GK: T.Tensor(G_shape, dtype=gate_dtype),
    h0: T.Tensor(h0_shape, dtype=input_dtype),
    dht: T.Tensor(dht_shape, dtype=input_dtype),
    dO: T.Tensor(dO_shape, dtype=input_dtype),
    dv: T.Tensor(dv_shape, dtype=input_dtype),
    # Output
    dh: T.Tensor(dh_shape, dtype=output_dtype),
    dh0: T.Tensor(dh0_shape, dtype=state_dtype),
    dv2: T.Tensor(dv2_shape, dtype=output_dtype),
):
🧩 Analysis chain
🏁 Script executed:
cat -n examples/KDA/chunk_delta_bwd.py | head -180 | tail -70
Repository: tile-ai/tilelang
Length of output: 3903
🏁 Script executed:
rg -n '\bh0\b' examples/KDA/chunk_delta_bwd.py
Repository: tile-ai/tilelang
Length of output: 653
🏁 Script executed:
wc -l examples/KDA/chunk_delta_bwd.py
Repository: tile-ai/tilelang
Length of output: 95
🏁 Script executed:
cat -n examples/KDA/chunk_delta_bwd.py | sed -n '180,240p'
Repository: tile-ai/tilelang
Length of output: 3286
Remove unused parameter h0 from kernel signature or implement initial state handling logic.
The kernel declares h0: T.Tensor(h0_shape, dtype=input_dtype) at line 128 but never uses it in the kernel body. Given the use_initial_state conditional at line 220, h0 may have been intended for initialization logic that is not yet implemented. Either remove the parameter if it's unused, or implement the logic to incorporate the initial hidden state in the backward pass.
🧰 Tools
🪛 Ruff (0.14.10)
128-128: Unused function argument: h0
(ARG001)
🤖 Prompt for AI Agents
In @examples/KDA/chunk_delta_bwd.py around lines 121 - 136, The kernel function
declares an unused parameter h0 in its signature; either remove h0 from the
kernel(...) parameter list or implement the initial state handling referenced by
use_initial_state so h0 is consumed. If removing, update any callers and tests
that pass h0 and delete related h0_shape/h0 dtype defs; if implementing, add
logic in kernel (e.g., within the use_initial_state branch) to incorporate h0
into the backward computation and ensure dh0 (or dh) is computed correctly from
h0. Update any docstrings/comments to reflect the chosen behavior.
def run_test(
    B,
    S,
    H,
    DK,
    scale,
    input_dtype,
    output_dtype,
    accum_dtype,
    gate_dtype,
    chunk_size,
    sub_chunk_size,
):
🧩 Analysis chain
🏁 Script executed:
cat examples/KDA/chunk_intra_token_parallel.py | head -300 | tail -250
Repository: tile-ai/tilelang
Length of output: 8024
🏁 Script executed:
wc -l examples/KDA/chunk_intra_token_parallel.py
Repository: tile-ai/tilelang
Length of output: 106
🏁 Script executed:
# First, let's look at the file structure - check lines around the kernel definition and run_test
sed -n '55,80p' examples/KDA/chunk_intra_token_parallel.py
Repository: tile-ai/tilelang
Length of output: 618
🏁 Script executed:
# Check the run_test function signature (lines 185-197)
sed -n '185,197p' examples/KDA/chunk_intra_token_parallel.py
Repository: tile-ai/tilelang
Length of output: 220
🏁 Script executed:
# Check how scale is used in the kernel call and reference implementation
sed -n '200,270p' examples/KDA/chunk_intra_token_parallel.py
Repository: tile-ai/tilelang
Length of output: 1954
Unused scale parameter in run_test creates inconsistency between reference and kernel implementations.
The scale parameter is passed to run_test (line 185) and used when calling the reference implementation chunk_kda_fwd_intra_token_parallel (line 249), but is not passed to the TileLang kernel (line 252). The kernel function signature (lines 60-73) lacks a scale parameter entirely. Additionally, a code comment (`# scale 如何传值`, i.e. "how should scale be passed?", on line 243) indicates the developer was aware this parameter needs to be handled. The TileLang kernel and reference implementation must receive the same parameters for correctness comparison.
🤖 Prompt for AI Agents
In @examples/KDA/chunk_intra_token_parallel.py around lines 185 - 197, The
run_test function accepts a scale parameter but it is only forwarded to the
reference implementation chunk_kda_fwd_intra_token_parallel and not to the
TileLang kernel; update the TileLang kernel's function signature to accept a
scale argument (matching the reference) and pass scale in the kernel invocation
inside run_test (the call currently missing scale); ensure any internal uses of
scale in the kernel are wired through and remove or update the comment "# scale
如何传值" once fixed so both reference and kernel implementations receive identical
parameters for correct comparison.
tilelang_time = do_bench(kernel, K, V, Beta, G, A)
triton_time = do_bench(recompute_w_u_fwd, k=K, v=V, beta=Beta, gk=G, A=A)
print("tilelang time:", tilelang_time)
print("tritron time:", triton_time)
Typo: "tritron" should be "triton".
Suggested fix
print("tilelang time:", tilelang_time)
- print("tritron time:", triton_time)
+ print("triton time:", triton_time)🤖 Prompt for AI Agents
In @examples/KDA/wy_fast.py at line 222, The print statement prints "tritron
time" which is a typo; find the print call that references triton_time
(print("tritron time:", triton_time)) and correct the string to "triton time:"
so the output label matches the variable name and intended wording.
Hi! Thank you for your contribution!
Actionable comments posted: 16
🤖 Fix all issues with AI agents
In @examples/KDA/chunk_bwd_dqkwg.py:
- Around line 242-245: The tensor validation calls have been commented out so
correctness isn't verified; re-enable the comparisons by restoring the
compare_tensors invocations for "dq", "dk", "dw", and "dg" (using dq_ref/dq,
dk_ref/dk, dw_ref/dw, dg_ref/dg) in chunk_bwd_dqkwg.py (around the current
commented block) or wrap them behind a debug/validation flag so they run during
test/CI; ensure that compare_tensors is imported/defined and invoked with the
same argument order and any necessary tolerance parameters so validation
succeeds.
- Around line 38-52: The prepare_output function ignores the qk_dtype parameter
by hardcoding dq and dk to torch.float32; update dq and dk allocations in
prepare_output to use dtype=qk_dtype (or convert them to qk_dtype after
creation) so the passed qk_dtype controls their dtype, leaving dw and dg using
gate_dtype as before; locate prepare_output and replace the dtype=torch.float32
occurrences for dq and dk with dtype=qk_dtype.
- Around line 273-275: main() is passing TileLang type objects (T.float32) into
run_test which expects string names and uses getattr(torch, input_dtype); change
the three arguments passed (input_dtype, gate_dtype, qk_dtype) from T.float32
etc. to string names like "float32" (or "float16" as appropriate) so
getattr(torch, ...) works, or alternatively change run_test to accept dtype
objects and skip getattr(torch, ...) — update either the call sites in main() or
the implementation of run_test (referenced by the run_test function name) so the
types and usage match.
In @examples/KDA/chunk_delta_bwd.py:
- Around line 298-300: The three commented-out validation calls disable
correctness checks; restore or justify them by uncommenting the compare_tensors
calls for "dh", "dh0", and "dv2" (i.e., compare_tensors("dh", dh_ref,
dh_tilelang); compare_tensors("dh0", dh0_ref, dh0_tilelang);
compare_tensors("dv2", dv2_ref, dv2_tilelang)) so the test actually validates
outputs, or if they must remain disabled, add a clear inline comment explaining
why (e.g., non-determinism, known numerical drift) and add an alternative
lightweight assertion (shape/dtype checks) to ensure the test still provides
basic validation.
- Around line 311-315: main() is passing TileLang types (e.g., T.bfloat16,
T.float32) into run_test where run_test calls getattr(torch, dtype) expecting
dtype to be a string; change the arguments passed to run_test for input_dtype,
output_dtype, accum_dtype, gate_dtype, and state_dtype to be dtype name strings
(e.g., "bfloat16", "float32") instead of TileLang types so getattr(torch, dtype)
works correctly, or alternatively convert TileLang types to their name strings
before calling getattr inside run_test (refer to the run_test call site and the
parameter names input_dtype/output_dtype/accum_dtype/gate_dtype/state_dtype).
In @examples/KDA/chunk_delta_h_fwd.py:
- Line 58: BS is computed with floor division (BS = S // chunk_size) which can
mismatch the kernel's ceil-div iteration (T.ceildiv(S, block_S)) and cause
out-of-bounds writes; change the computation to use ceiling division (e.g., BS =
T.ceildiv(S, chunk_size)) or add an explicit validation that S % chunk_size == 0
and raise/handle the error; update references where BS is used (and any code
expecting exact divisibility) to match the chosen approach.
- Around line 294-315: The bug is a dtype type mismatch: main() passes TileLang
dtype objects (e.g., T.float16/T.float32) to run_test, but run_test expects
string names and uses getattr(torch, input_dtype); fix by changing the dtype
arguments in main() to string names (e.g., "float16", "float32") for
input_dtype, output_dtype, accum_dtype, gate_dtype, and state_dtype so
getattr(torch, ...) works as intended; alternatively update run_test to accept
TileLang dtype objects by converting them to strings before calling getattr
(detect non-str and map to their name) but the quickest fix is to replace
T.float16/T.float32 in main() with the corresponding string names.
- Around line 272-282: The benchmark call to do_bench uses the scalar gate
argument g=G but the correctness test uses the vector gate gk=G; update the
do_bench invocation for chunk_gated_delta_rule_fwd_h to pass gk=G (instead of
g=G) so the benchmark exercises the same gate configuration as the correctness
test (also scan for other do_bench calls referencing g vs gk and make them
consistent with the reference function signature).
In @examples/KDA/chunk_o.py:
- Around line 132-134: The loop uses the wrong inner index variable and shape:
in the DK loop the parallel iterator is "for i_s, i_v in T.Parallel(block_S,
block_DV)" but Q_shared and GQ_shared are shaped (block_S, block_DK) and the DK
index is named i_k; change the parallel iterator to iterate over block_DK (e.g.,
T.Parallel(block_S, block_DK)) and rename the inner loop variable to i_k, then
replace occurrences of i_v inside the body (Q_shared[i_s, i_v], GQ_shared[i_s,
i_v], GK_shared[i_s, i_v]) with i_k so indexing uses the DK dimension
consistently.
- Around line 247-250: The call to run_test from main() passes TileLang type
objects (e.g., T.bfloat16, T.float32) into parameters
input_dtype/output_dtype/accum_dtype/gate_dtype while run_test expects strings
(it calls getattr(torch, dtype)); change the arguments in main() to pass dtype
names as strings (e.g., "bfloat16", "float32") for input_dtype, output_dtype,
accum_dtype and gate_dtype so getattr(torch, dtype) works as intended, or
alternatively update run_test to accept TileLang types and convert them to the
corresponding torch dtype names before calling getattr.
In @examples/KDA/wy_fast_bwd.py:
- Around line 310-314: The correctness checks are disabled because the
compare_tensors calls are commented out; either re-enable those calls (uncomment
compare_tensors for dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang,
dg_tilelang to validate against dA_ref, dk_ref, dv_ref, dbeta_ref, dg_ref) or
add a runtime flag (e.g., validate=True/--validate) that conditionally runs
compare_tensors so CI can enable validation while benchmarks can disable it;
alternatively move these compare_tensors calls into a dedicated test that
imports the same variables and runs validation.
- Around line 341-345: main() is passing TileLang type objects like T.float32
into run_test, but run_test uses getattr(torch, input_dtype) expecting a string;
change either main() or run_test so run_test receives a torch dtype or a string:
in main(), pass a string like "float32" (e.g., input_dtype="float32") or pass
torch.float32 directly, or update run_test to detect TileLang types (e.g.,
T.float32) and convert them to torch dtypes by using the TileLang type's name or
mapping (e.g., input_dtype = getattr(torch, input_dtype.name) or via a small
dict), making sure to apply the same fix for output_dtype, accum_dtype,
gate_dtype, and state_dtype and to locate the conversion logic near the run_test
signature and any getattr(torch, ...) calls.
- Around line 295-307: The run_test call passes raw string dtype names into
tilelang_wy_fast_bwd, but tilelang_wy_fast_bwd expects TileLang dtype objects
(e.g., T.float32); before calling tilelang_wy_fast_bwd in run_test, convert each
string dtype parameter to the corresponding TileLang dtype using getattr(T,
input_dtype) and likewise for output_dtype, accum_dtype, gate_dtype, and
state_dtype so the arguments passed to tilelang_wy_fast_bwd are TileLang dtype
objects instead of strings.
In @examples/KDA/wy_fast.py:
- Around line 54-76: The function tilelang_recompute_w_u_fwd currently accepts
the flag use_qg and declares an output QG but never populates it; either remove
the unused parameter and QG output from the function signature and from the
@tilelang.jit out_idx list, or implement the QG computation/write-path when
use_qg is true (e.g., produce and store QG inside the kernel) and keep the flag;
if you intentionally defer implementation, add a clear TODO comment in
tilelang_recompute_w_u_fwd explaining that QG is unimplemented and ensure the
@tilelang.jit out_idx and any downstream callers are kept consistent.
- Around line 179-180: The variable use_qg is mistakenly assigned as a
one-element tuple (False,) rather than a boolean, causing wrong truthy checks;
change the assignment of use_qg to the boolean False (remove the trailing comma)
so any conditionals or kernel parameters expecting a bool receive a proper
boolean value, and verify usages of use_qg (and related flags like use_kg)
continue to treat it as a boolean.
- Around line 237-240: The test is passing TileLang dtype constants (e.g.,
T.bfloat16) to run_test which expects string dtype names (it uses getattr(torch,
input_dtype)), causing an AttributeError; update the main() invocation to pass
string dtype names like "bfloat16", "float32" (e.g., replace
input_dtype=T.bfloat16 with input_dtype="bfloat16", output_dtype="bfloat16",
gate_dtype="float32", accum_dtype="float32") so getattr(torch, ...) resolves
correctly to torch dtypes for the run_test call.
🧹 Nitpick comments (25)
examples/KDA/chunk_bwd_gla_dA.py (7)
4-4: Remove unused `sys` import. The `sys` module is imported but never used. The `noqa: F401` directive is also unnecessary.
Suggested fix
-import sys  # noqa: F401
14-25: Unused `chunk_size` parameter. The `chunk_size` parameter is declared but never used. If it's kept for API consistency with other functions, consider prefixing with underscore (`_chunk_size`) to indicate it's intentionally unused.
Suggested fix
 def prepare_input(
     B,
     S,
     H,
     DV,
-    chunk_size,
+    _chunk_size,
     input_dtype,
     do_dtype,
 ):
28-37: Unused `DV` parameter. Similar to `prepare_input`, the `DV` parameter is unused. Consider removing or prefixing with underscore.
Suggested fix
 def prepare_output(
     B,
     S,
     H,
-    DV,
+    _DV,
     chunk_size,
     d_type,
 ):
83-91: Remove commented-out code. Dead code (commented allocation and layout annotations) should be removed to improve maintainability. If these are needed for reference, consider documenting them elsewhere or in a commit message.
Suggested fix
     dA_fragment = T.alloc_fragment((block_S, block_S), dtype=T.float32)
-    # dA_shared = T.alloc_shared((block_S, block_S), dtype=da_dtype)
-
-    # T.annotate_layout(
-    #     {
-    #         do_shared: tilelang.layout.make_swizzled_layout(do_shared),
-    #         V_shared: tilelang.layout.make_swizzled_layout(V_shared),
-    #     }
-    # )
-    # T.use_swizzle(10)
     T.clear(dA_fragment)
99-101: Translate comment to English and remove commented-out code. The comment "下三角矩阵" should be translated to "lower triangular mask" for consistency. Also remove the commented-out copy on line 100.
Suggested fix
     for i_s1, i_s2 in T.Parallel(block_S, block_S):
-        dA_fragment[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, dA_fragment[i_s1, i_s2] * scale, 0.0)  # 下三角矩阵
-    # T.copy(dA_fragment, dA_shared)
+        dA_fragment[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, dA_fragment[i_s1, i_s2] * scale, 0.0)  # lower triangular mask
     T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, 0:block_S])
106-117: Unused `DK` parameter. The `DK` parameter is declared but never used in `run_test`. Remove it or prefix with underscore if kept for signature consistency.
Suggested fix
 def run_test(
     B,
     S,
     H,
-    DK,
     DV,
     scale,
     input_dtype,
     do_dtype,
     da_dtype,
     chunk_size,
 ):
And update the call in main():
 run_test(
     B=1,
     S=1024 * 8,  # 32768
     H=64,
-    DK=128,
     DV=128,
119-119: Remove debug print statement. This debug print should be removed before merging, or converted to proper logging if needed for diagnostics.
Suggested fix
- print(DO.dtype, V_new.dtype)
examples/KDA/chunk_delta_h_fwd.py (4)
3-3: Remove unused `sys` import. The `sys` import and `noqa` directive are no longer needed since the `sys.path.insert` on line 10 is commented out. This is also flagged by static analysis (RUF100).
Suggested fix
-import sys  # noqa: F401
 import tilelang
23-45: Unused parameters `output_dtype` and `accum_dtype`. These parameters are defined but never used in the function body. If they're placeholders for future use, consider documenting this intent; otherwise, remove them to avoid confusion.
Suggested fix (if not needed)
 def prepare_input(
     B,
     S,
     H,
     DK,
     DV,
     chunk_size,
     input_dtype,
-    output_dtype,
-    accum_dtype,
     gate_dtype,
 ):
98-98: Unused `block_DK` parameter in kernel configuration. The `block_DK` parameter is included in the autotuning configs but never used in the kernel. The kernel allocates tensors with the full `DK` dimension (e.g., `b_h_shared` at line 129, `K_shared` at line 137) rather than tiling over it. If tiling over `DK` is intended for performance or memory optimization, the kernel logic needs updating. Otherwise, consider removing `block_DK` from `get_configs()` and the function signature to reduce the autotuning search space (currently 54 configs, would be 18 without `block_DK`).
213-216: Unused kernel config parameters in `run_test`. The parameters `block_DK`, `block_DV`, `threads`, and `num_stages` are accepted but never used since the `@autotune` decorator automatically selects configurations. If manual config selection is needed for debugging, consider adding support to bypass autotuning. Otherwise, remove these parameters to avoid confusion.
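To illustrate the search-space arithmetic mentioned above (the axis values below are assumptions, not the actual get_configs() of this file):
from itertools import product

# Hypothetical tuning axes: 3 x 3 x 3 x 2 = 54 candidate configs.
block_DK = [32, 64, 128]
block_DV = [32, 64, 128]
threads = [128, 256, 512]
num_stages = [0, 2]

full_grid = list(product(block_DK, block_DV, threads, num_stages))
reduced_grid = list(product(block_DV, threads, num_stages))  # drop the unused block_DK axis
print(len(full_grid), len(reduced_grid))  # 54 18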
examples/KDA/wy_fast_bwd.py (5)
3-3: Remove unused `sys` import. The `sys` module is imported but never used in this file. The `noqa: F401` directive is also unnecessary since there's no valid reason to suppress the warning.
Suggested fix
-import sys  # noqa: F401
-
17-42: Unused function parameters and minor efficiency improvement. Parameters `output_dtype`, `accum_dtype`, and `state_dtype` are declared but never used. Consider prefixing with underscore or removing if they serve no purpose. Also, creating tensors directly on the CUDA device is slightly more efficient than creating them on CPU and then moving:
Suggested improvement
 def prepare_input(
     B,
     S,
     H,
     DK,
     DV,
     chunk_size,
     input_dtype,
-    output_dtype,
-    accum_dtype,
+    _output_dtype,
+    _accum_dtype,
     gate_dtype,
-    state_dtype,
+    _state_dtype,
 ):
     BS = chunk_size
-    K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
+    K = torch.randn(B, S, H, DK, dtype=input_dtype, device="cuda")
44-61: Unused parameters in `prepare_output`. Similar to `prepare_input`, parameters `chunk_size` and `state_dtype` are declared but unused. Consider prefixing with underscore for consistency.
154-170: Consider removing commented-out code. Lines 154, 165, and 170 contain commented-out allocations and a disabled swizzle call. If these are no longer needed, consider removing them to keep the code clean.
277-279: Unused output tensors from `prepare_output`. The tensors allocated here (`dk_tilelang`, `dv_tilelang`, etc.) are immediately overwritten by the kernel call at line 308. This allocation serves no purpose since the autotuned kernel returns new tensors via `out_idx`.
Suggested fix
-    dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang, dA_tilelang = prepare_output(
-        B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
-    )
-
examples/KDA/wy_fast.py (2)
15-22: Unused `output_dtype` parameter. The `output_dtype` parameter is accepted but never used in this function. If it's for API consistency with other functions, consider prefixing with underscore or adding a comment.
174-177: Unused function parameters in `run_test`. The parameters `block_DK`, `block_DV`, `threads`, and `num_stages` are accepted but never passed to the kernel. The kernel relies on autotuning to select these values. Either remove these parameters or pass them to override autotuning.
examples/KDA/chunk_delta_bwd.py (1)
265-267: Reference results are computed but never used. `dh_ref`, `dh0_ref`, and `dv2_ref` are unpacked but never compared against TileLang outputs since the `compare_tensors` calls are commented out. Consider prefixing with underscore if intentionally unused, or enable the validation.
examples/KDA/chunk_bwd_dv.py (2)
31-40: Unused `chunk_size` parameter in `prepare_output`. The `chunk_size` parameter is accepted but not used. Consider removing it or prefixing with underscore if kept for API consistency.
105-116: Unused `scale` parameter in `run_test`. The `scale` parameter is accepted but never used in the function. It's passed in `main()` but has no effect.
♻️ Proposed fix - remove unused parameter
 def run_test(
     B,
     S,
     H,
     DK,
     DV,
-    scale,
     input_dtype,
     do_dtype,
     output_dtype,
     chunk_size,
 ):
And in main():
 run_test(
     B=1,
     S=1024 * 8,  # 32768
     H=64,
     DK=128,
     DV=128,
-    scale=1.0,
     input_dtype="bfloat16",
examples/KDA/chunk_bwd_dqkwg.py (1)
212-219: Many unused parameters in `run_test`. Parameters `use_gk`, `use_initial_state`, `store_final_state`, `save_new_value`, `block_DK`, `block_DV`, `threads`, and `num_stages` are accepted but never used. Consider removing them or using them to configure the kernel/test behavior.
examples/KDA/chunk_o.py (3)
142-146: Consider removing or translating Chinese comment. The comment "改成下面的代码为什么就错了" (translated: "Why does it fail when changed to the code below?") should either be removed, translated to English, or replaced with a proper explanation for maintainability.
161-183: Duplicate `do_bench` implementation. This file defines its own `do_bench` function while other files (`wy_fast.py`, `chunk_delta_bwd.py`, etc.) import it from `test_utils`. Consider using the shared implementation for consistency.
♻️ Proposed fix
 from FLA_KDA.fla_chunk_o import chunk_gla_fwd_o_gk
-from test_utils import compare_tensors
+from test_utils import compare_tensors, do_bench
 ...
-def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
-    """
-    Do benchmark for a function.
-    """
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
-    for _ in range(warmup):
-        fn(*args, **kwargs)
-
-    torch.cuda.synchronize()
-    for i in range(rep):
-        start_event[i].record()
-        fn(*args, **kwargs)
-        end_event[i].record()
-    torch.cuda.synchronize()
-
-    # Record clocks
-    times = torch.tensor(
-        [s.elapsed_time(e) for s, e in zip(start_event, end_event)],
-        dtype=torch.float,
-    )
-
-    return times.mean().item()
197-200: Unused parameters in `run_test`. Parameters `block_DK`, `block_DV`, `threads`, and `num_stages` are accepted but not used. The kernel relies on autotuning. Consider removing these or passing them to override autotuning behavior.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
examples/KDA/chunk_bwd_dqkwg.py
examples/KDA/chunk_bwd_dv.py
examples/KDA/chunk_bwd_gla_dA.py
examples/KDA/chunk_delta_bwd.py
examples/KDA/chunk_delta_h_fwd.py
examples/KDA/chunk_o.py
examples/KDA/wy_fast.py
examples/KDA/wy_fast_bwd.py
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
examples/KDA/chunk_bwd_dqkwg.py
examples/KDA/chunk_delta_bwd.py
examples/KDA/wy_fast.py
examples/KDA/chunk_delta_h_fwd.py
examples/KDA/wy_fast_bwd.py
📚 Learning: 2026-01-06T05:20:51.649Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Applied to files:
examples/KDA/chunk_delta_bwd.py
examples/KDA/wy_fast.py
examples/KDA/chunk_delta_h_fwd.py
🧬 Code graph analysis (2)
examples/KDA/wy_fast.py (3)
  examples/KDA/FLA_KDA/fla_wy_fast.py (1)
    recompute_w_u_fwd (210-254)
  examples/KDA/test_utils.py (1)
    compare_tensors (43-83)
  tilelang/language/allocate.py (2)
    alloc_shared (39-54)
    alloc_fragment (71-82)
examples/KDA/chunk_delta_h_fwd.py (2)
  examples/KDA/FLA_KDA/fla_chunk_delta.py (1)
    chunk_gated_delta_rule_fwd_h (470-521)
  examples/KDA/FLA_KDA/cumsum.py (1)
    chunk_local_cumsum (431-469)
🪛 Ruff (0.14.10)
examples/KDA/chunk_bwd_dqkwg.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
43-43: Unused function argument: DV
(ARG001)
44-44: Unused function argument: chunk_size
(ARG001)
45-45: Unused function argument: qk_dtype
(ARG001)
212-212: Unused function argument: use_gk
(ARG001)
213-213: Unused function argument: use_initial_state
(ARG001)
214-214: Unused function argument: store_final_state
(ARG001)
215-215: Unused function argument: save_new_value
(ARG001)
216-216: Unused function argument: block_DK
(ARG001)
217-217: Unused function argument: block_DV
(ARG001)
218-218: Unused function argument: threads
(ARG001)
219-219: Unused function argument: num_stages
(ARG001)
223-223: Unpacked variable dq_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
223-223: Unpacked variable dk_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
223-223: Unpacked variable dw_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
223-223: Unpacked variable dg_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
240-240: Unpacked variable dq is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
240-240: Unpacked variable dk is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
240-240: Unpacked variable dw is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
240-240: Unpacked variable dg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/chunk_delta_bwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
28-28: Unused function argument: output_dtype
(ARG001)
29-29: Unused function argument: accum_dtype
(ARG001)
31-31: Unused function argument: state_dtype
(ARG001)
58-58: Unused function argument: gate_dtype
(ARG001)
128-128: Unused function argument: h0
(ARG001)
239-239: Unused function argument: block_DV
(ARG001)
240-240: Unused function argument: threads
(ARG001)
241-241: Unused function argument: num_stages
(ARG001)
242-242: Unused function argument: use_torch
(ARG001)
265-265: Unpacked variable dh_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
265-265: Unpacked variable dh0_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
265-265: Unpacked variable dv2_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
288-288: Unpacked variable dh_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
288-288: Unpacked variable dh0_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
288-288: Unpacked variable dv2_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/wy_fast.py
6-6: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
15-15: Unused function argument: output_dtype
(ARG001)
68-68: Unused function argument: use_qg
(ARG001)
94-94: Unused function argument: QG
(ARG001)
174-174: Unused function argument: block_DK
(ARG001)
175-175: Unused function argument: block_DV
(ARG001)
176-176: Unused function argument: threads
(ARG001)
177-177: Unused function argument: num_stages
(ARG001)
184-184: Unpacked variable QG_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
185-185: Unpacked variable QG_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/chunk_delta_h_fwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
31-31: Unused function argument: output_dtype
(ARG001)
32-32: Unused function argument: accum_dtype
(ARG001)
98-98: Unused function argument: block_DK
(ARG001)
213-213: Unused function argument: block_DK
(ARG001)
214-214: Unused function argument: block_DV
(ARG001)
215-215: Unused function argument: threads
(ARG001)
216-216: Unused function argument: num_stages
(ARG001)
examples/KDA/chunk_o.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
22-22: Unused function argument: output_dtype
(ARG001)
23-23: Unused function argument: accum_dtype
(ARG001)
39-39: Unused function argument: DK
(ARG001)
41-41: Unused function argument: chunk_size
(ARG001)
44-44: Ambiguous variable name: O
(E741)
99-99: Ambiguous variable name: O
(E741)
197-197: Unused function argument: block_DK
(ARG001)
198-198: Unused function argument: block_DV
(ARG001)
199-199: Unused function argument: threads
(ARG001)
200-200: Unused function argument: num_stages
(ARG001)
examples/KDA/chunk_bwd_gla_dA.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
19-19: Unused function argument: chunk_size
(ARG001)
32-32: Unused function argument: DV
(ARG001)
110-110: Unused function argument: DK
(ARG001)
examples/KDA/wy_fast_bwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
25-25: Unused function argument: output_dtype
(ARG001)
26-26: Unused function argument: accum_dtype
(ARG001)
28-28: Unused function argument: state_dtype
(ARG001)
50-50: Unused function argument: chunk_size
(ARG001)
53-53: Unused function argument: state_dtype
(ARG001)
92-92: Unused function argument: state_dtype
(ARG001)
258-258: Unused function argument: block_DK
(ARG001)
259-259: Unused function argument: block_DV
(ARG001)
260-260: Unused function argument: threads
(ARG001)
261-261: Unused function argument: num_stages
(ARG001)
282-282: Unpacked variable dk_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
282-282: Unpacked variable dv_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
282-282: Unpacked variable dbeta_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
282-282: Unpacked variable dg_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
282-282: Unpacked variable dA_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
308-308: Unpacked variable dA_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
308-308: Unpacked variable dk_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
308-308: Unpacked variable dv_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
308-308: Unpacked variable dbeta_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
308-308: Unpacked variable dg_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/chunk_bwd_dv.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
36-36: Unused function argument: chunk_size
(ARG001)
111-111: Unused function argument: scale
(ARG001)
🔇 Additional comments (19)
examples/KDA/chunk_bwd_gla_dA.py (2)
80-81: Verify dtype for `V_shared` allocation. `do_shared` uses `do_dtype` (correct), but `V_shared` also uses `do_dtype` while the source tensor `V` has `input_dtype`. If `input_dtype` and `do_dtype` differ, an implicit cast occurs during the copy. This may be intentional for computation precision alignment, but consider using `input_dtype` or adding a comment explaining the design choice.
Suggested fix if intentional cast
 do_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
-V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
+V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)  # Cast V to do_dtype for GEMM precision
142-158: LGTM! The main entry point is well-structured with reasonable default parameters for testing.
examples/KDA/chunk_delta_h_fwd.py (2)
65-75: LGTM! The configuration generation for autotuning is well-structured with reasonable parameter ranges.
158-192: Kernel logic appears correct. The pipelined loop correctly implements the chunked forward pass:
- Stores the previous state to `h` before computing the current chunk
- Applies the gated recurrence with `exp2` (matching `use_exp2=True` in the reference)
- Accumulates via GEMM operations
The commented-out swizzle annotations (lines 140-150) are appropriately disabled per the commit message "Remove redundant swizzle as it can be done automatically."
examples/KDA/wy_fast_bwd.py (2)
63-74: LGTM! The autotuning configuration space is well-defined with reasonable parameter ranges.
220-220: This concern is based on an incorrect assumption about the default value of `clear_accum`. According to the `T.gemm` definition in `tilelang/language/gemm_op.py`, the default value of `clear_accum` is False, not True. Therefore, line 220 does not inadvertently clear `dA_fragment`. Both line 195 (DK loop) and line 220 (DV loop) accumulate contributions into `dA_fragment` as intended, with no bug present.
Likely an incorrect or invalid review comment.
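For intuition, a NumPy sketch of the accumulate-versus-clear distinction discussed here (illustrative only; it does not use the TileLang API):
import numpy as np

def gemm(C, A, B, clear_accum=False):
    """Accumulate A @ B into C; optionally reset C first (mirrors the clear_accum flag)."""
    if clear_accum:
        C[...] = 0.0
    C += A @ B

rng = np.random.default_rng(0)
A1, B1 = rng.standard_normal((4, 3)), rng.standard_normal((3, 5))
A2, B2 = rng.standard_normal((4, 3)), rng.standard_normal((3, 5))

C = np.zeros((4, 5))
gemm(C, A1, B1)   # first GEMM (analogous to the DK loop)
gemm(C, A2, B2)   # second GEMM (analogous to the DV loop) adds on top because clear_accum defaults to False
assert np.allclose(C, A1 @ B1 + A2 @ B2)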
examples/KDA/wy_fast.py (2)
1-12: LGTM - Imports and setup are appropriate. The imports and random seed initialization follow the pattern established across the KDA examples.
143-158: Kernel logic for W and KG computation is correct. The pipelined loop correctly computes W with gating (`T.exp2`) and conditionally computes KG when `use_kg` is enabled. The use of shared memory and fragments follows TileLang patterns.
examples/KDA/chunk_delta_bwd.py (4)
1-17: LGTM - Imports and setup are appropriate. Imports include the necessary TileLang, PyTorch, and reference implementations. Random seed is set for reproducibility.
171-176: Good handling of final state gradient initialization. The conditional initialization of `b_dh_fragment` from `dht`, or clearing it based on `use_final_state_gradient`, is correct.
177-218: Backward pass iteration logic is well-structured. The reverse iteration (`i_s_inv`) correctly processes chunks in reverse order for the backward pass. The gradient updates for `dh`, `dv`, and final `dh0` follow the expected pattern.
125-136: Review comment is incorrect. The reviewer mistakenly referenced the forward kernel. In the FLA implementation, `h0` is used in the forward pass (`chunk_gated_delta_rule_fwd_kernel_h_blockdim64`, lines 97-107) to initialize the hidden state, not in the backward pass. The backward kernel (`chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64`) does not accept `h0` as a parameter; it only computes and outputs `dh0` (the gradient with respect to the initial state).
The tilelang implementation correctly matches this behavior: `h0` is accepted as a parameter (for API consistency) but is not used during backward computation. This is the correct approach since the backward pass computes gradients, not reads the forward state.
examples/KDA/chunk_bwd_dv.py (4)
94-95: Lower triangular mask implementation is correct.The masking
T.if_then_else(i_s1 >= i_s2, A_shared[i_s1, i_s2], 0.0)correctly implements a lower triangular mask where elements above the diagonal are zeroed.
96-100: Pipelined DV computation with proper staging.The pipelined loop over DV dimension with GEMM (transpose_A) and copy operations follows correct TileLang patterns.
140-152: Correct dtype passing pattern.Unlike other files, this one correctly passes string dtypes (
"bfloat16","float32") torun_test, which is compatible with thegetattr(torch, dtype)usage.examples/KDA/chunk_bwd_dqkwg.py (2)
1-11: LGTM - Imports and setup follow established patterns.
106-198: Complex kernel logic with proper gradient computation.The kernel correctly handles:
- Chunk-wise gradient computation with proper indexing
- Gating with
T.exp2for GK- Reduction operations for
dgkn_fragment- Conditional accumulation for the last token in chunk
The mathematical operations appear to follow the expected backward pass pattern.
examples/KDA/chunk_o.py (1)
1-11: LGTM - Imports and setup follow established patterns.
def prepare_output(
    B,
    S,
    H,
    DK,
    DV,
    chunk_size,
    qk_dtype,
    gate_dtype,
):
    dq = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
    dk = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
    dw = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
    dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
    return dq, dk, dw, dg
qk_dtype parameter is ignored - outputs use hardcoded torch.float32.
The qk_dtype parameter is accepted but dq and dk are always allocated with dtype=torch.float32 regardless of the passed value. Either use the parameter or remove it.
♻️ Proposed fix
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
qk_dtype,
gate_dtype,
):
- dq = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
- dk = torch.randn(B, S, H, DK, dtype=torch.float32).cuda()
+ dq = torch.randn(B, S, H, DK, dtype=qk_dtype).cuda()
+ dk = torch.randn(B, S, H, DK, dtype=qk_dtype).cuda()
dw = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
return dq, dk, dw, dg📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def prepare_output(
    B,
    S,
    H,
    DK,
    DV,
    chunk_size,
    qk_dtype,
    gate_dtype,
):
    dq = torch.randn(B, S, H, DK, dtype=qk_dtype).cuda()
    dk = torch.randn(B, S, H, DK, dtype=qk_dtype).cuda()
    dw = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
    dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
    return dq, dk, dw, dg
🧰 Tools
🪛 Ruff (0.14.10)
43-43: Unused function argument: DV
(ARG001)
44-44: Unused function argument: chunk_size
(ARG001)
45-45: Unused function argument: qk_dtype
(ARG001)
🤖 Prompt for AI Agents
In @examples/KDA/chunk_bwd_dqkwg.py around lines 38 - 52, The prepare_output
function ignores the qk_dtype parameter by hardcoding dq and dk to
torch.float32; update dq and dk allocations in prepare_output to use
dtype=qk_dtype (or convert them to qk_dtype after creation) so the passed
qk_dtype controls their dtype, leaving dw and dg using gate_dtype as before;
locate prepare_output and replace the dtype=torch.float32 occurrences for dq and
dk with dtype=qk_dtype.
    input_dtype=T.float32,
    output_dtype=T.float32,
    accum_dtype=T.float32,
    gate_dtype=T.float32,
    state_dtype=T.float32,
Type mismatch: T.float32 is not a valid argument for getattr(torch, ...).
The main() function passes TileLang types (e.g., T.float32) to run_test, but run_test uses getattr(torch, input_dtype) which expects a string like "float32". This will cause an AttributeError at runtime.
Suggested fix
run_test(
B=1,
S=32768,
H=8,
DK=DK,
DV=DV,
- input_dtype=T.float32,
- output_dtype=T.float32,
- accum_dtype=T.float32,
- gate_dtype=T.float32,
- state_dtype=T.float32,
+ input_dtype="float32",
+ output_dtype="float32",
+ accum_dtype="float32",
+ gate_dtype="float32",
+ state_dtype="float32",
chunk_size=64,📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
run_test(
    B=1,
    S=32768,
    H=8,
    DK=DK,
    DV=DV,
    input_dtype="float32",
    output_dtype="float32",
    accum_dtype="float32",
    gate_dtype="float32",
    state_dtype="float32",
    chunk_size=64,
)
🤖 Prompt for AI Agents
In @examples/KDA/wy_fast_bwd.py around lines 341 - 345, main() is passing
TileLang type objects like T.float32 into run_test, but run_test uses
getattr(torch, input_dtype) expecting a string; change either main() or run_test
so run_test receives a torch dtype or a string: in main(), pass a string like
"float32" (e.g., input_dtype="float32") or pass torch.float32 directly, or
update run_test to detect TileLang types (e.g., T.float32) and convert them to
torch dtypes by using the TileLang type's name or mapping (e.g., input_dtype =
getattr(torch, input_dtype.name) or via a small dict), making sure to apply the
same fix for output_dtype, accum_dtype, gate_dtype, and state_dtype and to
locate the conversion logic near the run_test signature and any getattr(torch,
...) calls.
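One way to make these call sites robust is a small normalization helper that accepts either a torch.dtype or a dtype name (a hypothetical utility, not part of this PR; TileLang dtype objects would still need to be mapped to their name string first):
import torch

def to_torch_dtype(dtype) -> torch.dtype:
    """Accept either a torch.dtype or a name string such as "float32" / "bfloat16"."""
    if isinstance(dtype, torch.dtype):
        return dtype
    return getattr(torch, str(dtype))  # e.g. "float32" -> torch.float32

print(to_torch_dtype("bfloat16"))     # torch.bfloat16
print(to_torch_dtype(torch.float32))  # torch.float32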
@autotune(configs=get_configs(), warmup=3, rep=5)
@tilelang.jit(out_idx=[-4, -3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def tilelang_recompute_w_u_fwd(
    # task config
    B,
    S,
    H,
    DK,
    DV,
    input_dtype,
    output_dtype,
    gate_dtype,
    accum_dtype,
    chunk_size,
    use_qg,
    use_kg,
    # kernel config
    block_S=64,
    block_DK=32,
    block_DV=32,
    threads=128,
    num_stages=0,
):
Unused use_qg parameter - QG tensor is never populated.
The use_qg parameter and QG tensor are declared but the kernel never writes to QG. If QG computation is intentionally omitted, consider either:
- Removing the parameter and output tensor
- Adding a TODO comment explaining future implementation
The kernel signature declares QG as output but the body never computes or writes to it.
🧰 Tools
🪛 Ruff (0.14.10)
68-68: Unused function argument: use_qg
(ARG001)
🤖 Prompt for AI Agents
In @examples/KDA/wy_fast.py around lines 54 - 76, The function
tilelang_recompute_w_u_fwd currently accepts the flag use_qg and declares an
output QG but never populates it; either remove the unused parameter and QG
output from the function signature and from the @tilelang.jit out_idx list, or
implement the QG computation/write-path when use_qg is true (e.g., produce and
store QG inside the kernel) and keep the flag; if you intentionally defer
implementation, add a clear TODO comment in tilelang_recompute_w_u_fwd
explaining that QG is unimplemented and ensure the @tilelang.jit out_idx and any
downstream callers are kept consistent.
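If the flag is kept, the host side can mirror the conditional allocation pattern already used for the optional outputs; a rough sketch (function name and shapes are illustrative, not the PR's exact signature):

```python
import torch

def make_optional_outputs(B, S, H, DK, use_qg, use_kg,
                          dtype=torch.bfloat16, device="cuda"):
    # Allocate QG/KG only when the corresponding flag is set; callers can then
    # skip comparisons for outputs that were never requested.
    QG = torch.empty(B, S, H, DK, dtype=dtype, device=device) if use_qg else None
    KG = torch.empty(B, S, H, DK, dtype=dtype, device=device) if use_kg else None
    return QG, KG

QG, KG = make_optional_outputs(1, 64, 2, 32, use_qg=False, use_kg=True, device="cpu")
assert QG is None and KG.shape == (1, 64, 2, 32)
```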
    input_dtype=T.bfloat16,
    output_dtype=T.bfloat16,
    gate_dtype=T.float32,
    accum_dtype=T.float32,
Type mismatch: passing TileLang types where string dtypes are expected.
run_test expects string dtype parameters (e.g., "bfloat16") based on getattr(torch, input_dtype) usage, but main() passes TileLang types like T.bfloat16. This will cause an AttributeError since getattr(torch, T.bfloat16) is invalid.
🐛 Proposed fix - use string dtypes
run_test(
B=1,
S=8192,
H=64,
DK=128,
DV=128,
chunk_size=64,
- input_dtype=T.bfloat16,
- output_dtype=T.bfloat16,
- gate_dtype=T.float32,
- accum_dtype=T.float32,
+ input_dtype="bfloat16",
+ output_dtype="bfloat16",
+ gate_dtype="float32",
+ accum_dtype="float32",
block_DK=64,
block_DV=32,
threads=128,
num_stages=3,
)
🤖 Prompt for AI Agents
In @examples/KDA/wy_fast.py around lines 237 - 240, The test is passing TileLang
dtype constants (e.g., T.bfloat16) to run_test which expects string dtype names
(it uses getattr(torch, input_dtype)), causing an AttributeError; update the
main() invocation to pass string dtype names like "bfloat16", "float32" (e.g.,
replace input_dtype=T.bfloat16 with input_dtype="bfloat16",
output_dtype="bfloat16", gate_dtype="float32", accum_dtype="float32") so
getattr(torch, ...) resolves correctly to torch dtypes for the run_test call.
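Equivalently, the call site can stay on plain strings and let run_test resolve them; a tiny sketch of that resolution (illustrative, using only the standard getattr-on-torch pattern the review relies on):

```python
import torch

# The call site keeps plain strings; run_test's getattr(torch, ...) pattern
# then resolves them to real torch dtypes.
names = {"input_dtype": "bfloat16", "output_dtype": "bfloat16",
         "gate_dtype": "float32", "accum_dtype": "float32"}
resolved = {k: getattr(torch, v) for k, v in names.items()}

assert resolved["input_dtype"] is torch.bfloat16
assert resolved["gate_dtype"] is torch.float32
```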
… and do some fix form coderabbitai
Thank you for your review, I will check and fix them.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@examples/KDA/chunk_bwd_intra.py`:
- Around line 478-493: main() is passing TileLang types (e.g. T.float32) into
run_test(), but run_test() (which uses getattr(torch, input_dtype) inside the
run_test function) expects string names like "float32", causing an
AttributeError; fix by changing the arguments in main() to string dtype names
(e.g. "float32" for input_dtype/output_dtype/accum_dtype/gate_dtype/state_dtype)
so run_test() can resolve them via getattr(torch, ...). Ensure the update
references the main() call site and the run_test(...) parameter names so all
dtype args are passed as strings.
In `@examples/KDA/chunk_delta_h_fwd.py`:
- Around line 178-182: The access GK[bb, (i_s + 1) * block_S - 1, bh, :] can go
out of bounds when S is not divisible by block_S; either clamp the computed
index or enforce divisibility. Fix option A: replace the raw index with a
clamped value (e.g., end_idx = min((i_s + 1) * block_S - 1, S - 1)) and use
GK[bb, end_idx, bh, :] when setting GK_last_shared in the use_gk branch; or
option B: add a runtime assertion in run_test or prepare_input that S % block_S
== 0 to guarantee no partial final block. Ensure references to use_gk,
GK_last_shared, i_s, block_S and S are adjusted accordingly.
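A plain-Python sketch of the two options described above for chunk_delta_h_fwd.py (illustrative only; the real fix lives in the kernel or in run_test / prepare_input):

```python
# Option A: clamp the last-token index of a chunk so it never exceeds S - 1.
def last_token_index(i_s, block_S, S):
    return min((i_s + 1) * block_S - 1, S - 1)

# Option B: enforce divisibility up front, e.g. in run_test / prepare_input.
def check_chunking(S, block_S):
    if S % block_S != 0:
        raise ValueError(f"S={S} must be divisible by block_S={block_S}")

assert last_token_index(i_s=1, block_S=64, S=100) == 99   # clamped from 127
check_chunking(S=8192, block_S=64)                        # ok: 128 full chunks
```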
♻️ Duplicate comments (7)
examples/KDA/wy_fast.py (3)
68-68: Unuseduse_qgparameter andQGoutput tensor.The
use_qgparameter is accepted but never used, and theQGtensor is declared as output but never written to. Either implement the QG computation path or remove these unused declarations.Also applies to: 94-94
222-222: Typo: "tritron" should be "triton".- print("tritron time:", triton_time) + print("triton time:", triton_time)
237-240: Type mismatch: passing TileLang types where string dtypes are expected.
run_testusesgetattr(torch, input_dtype)which expects string arguments like"bfloat16", butmain()passes TileLang type objects (T.bfloat16,T.float32). This will cause anAttributeErrorat runtime.run_test( B=1, S=8192, H=64, DK=128, DV=128, chunk_size=64, - input_dtype=T.bfloat16, - output_dtype=T.bfloat16, - gate_dtype=T.float32, - accum_dtype=T.float32, + input_dtype="bfloat16", + output_dtype="bfloat16", + gate_dtype="float32", + accum_dtype="float32", block_DK=64, block_DV=32, threads=128, num_stages=3, )examples/KDA/chunk_delta_bwd.py (1)
121-136: Unusedh0parameter in kernel signature.The kernel declares
h0but never uses it. Per theuse_initial_statelogic at line 217-218, onlydh0(the gradient w.r.t. initial state) is written. Ifh0is not needed for this backward computation, consider removing it from the kernel signature to avoid confusion.examples/KDA/chunk_bwd_dqkwg.py (2)
47-48:dqanddkalways usetorch.float32regardless of parameters.The
prepare_outputfunction ignores its parameters and hardcodestorch.float32fordqanddk. While this may be intentional (gradients often need higher precision), it contradicts the function signature which acceptsgate_dtype.
122-123: Gate buffersG_sharedandGn_sharedallocate withinput_dtypeinstead ofgate_dtype.These buffers store gate values (sourced from the
Gtensor at line 95 which usesgate_dtype), but are allocated withinput_dtype. Wheninput_dtypediffers fromgate_dtype(e.g.,bfloat16vsfloat32), this causes precision loss and dtype mismatch.Suggested fix
- G_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) # chunk G - Gn_shared = T.alloc_shared((block_DK), dtype=input_dtype) # chunk last token G + G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype) # chunk G + Gn_shared = T.alloc_shared((block_DK,), dtype=gate_dtype) # chunk last token Gexamples/KDA/wy_fast_bwd.py (1)
332-351: Critical:main()passes TileLang types butrun_testexpects strings.
main()passes TileLang type objects (T.float32) torun_test, butrun_testusesgetattr(torch, input_dtype)at lines 270-274 which expects string arguments like"float32". This will cause anAttributeErrorat runtime.🐛 Proposed fix
def main(): DK = 128 DV = 128 run_test( B=1, S=32768, H=8, DK=DK, DV=DV, - input_dtype=T.float32, - output_dtype=T.float32, - accum_dtype=T.float32, - gate_dtype=T.float32, - state_dtype=T.float32, + input_dtype="float32", + output_dtype="float32", + accum_dtype="float32", + gate_dtype="float32", + state_dtype="float32", chunk_size=64, block_DK=32, block_DV=32, threads=128, num_stages=0, )
🧹 Nitpick comments (16)
examples/KDA/wy_fast.py (5)
6-6: Remove unusedsysimport.The
sysmodule is imported but never used, and thenoqadirective is unnecessary.-import sys # noqa: F401
15-22: Unusedoutput_dtypeparameter.The
output_dtypeparameter is declared but never used in the function body. Consider removing it or prefixing with underscore if reserved for future use.-def prepare_input(B, S, H, DK, DV, chunk_size, input_dtype, output_dtype, gate_dtype=torch.float32): +def prepare_input(B, S, H, DK, DV, chunk_size, input_dtype, gate_dtype=torch.float32):
112-123: Remove or document commented-out layout annotations.This block of commented code for swizzled layouts appears to be either dead code or work-in-progress. Consider removing it to reduce clutter, or add a TODO/NOTE explaining why it's retained.
174-177: Unused kernel configuration parameters.The
block_DK,block_DV,threads, andnum_stagesparameters are accepted but never passed to the kernel (lines 202-216). If these are intended for manual configuration (bypassing autotune), pass them to the kernel call. Otherwise, remove them from the function signature.def run_test( B, S, H, DK, DV, chunk_size, input_dtype, output_dtype, gate_dtype, accum_dtype, - block_DK, - block_DV, - threads, - num_stages, ):
184-185: Prefix unused unpacked variables with underscore.
QG_refandQG_tilelangare unpacked but never used. Prefix them with underscore to indicate intentional disuse.- W_ref, U_ref, QG_ref, KG_ref = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype)) - W_tilelang, U_tilelang, QG_tilelang, KG_tilelang = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype)) + W_ref, U_ref, _QG_ref, KG_ref = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype)) + W_tilelang, U_tilelang, _QG_tilelang, KG_tilelang = prepare_output(B, S, H, DK, DV, use_qg, use_kg, getattr(torch, output_dtype))examples/KDA/chunk_delta_h_fwd.py (4)
3-3: Remove unusedsysimport.The
sysmodule is imported but never used (thesys.path.inserton line 10 is commented out). The# noqa: F401directive is also unnecessary.Suggested fix
-import sys # noqa: F401 import tilelang
23-45: Unused parametersoutput_dtypeandaccum_dtype.These parameters are declared but never used in the function body. Consider removing them if not needed for API consistency, or prefixing with underscore to indicate intentional non-use.
Suggested fix
def prepare_input( B, S, H, DK, DV, chunk_size, input_dtype, - output_dtype, - accum_dtype, + _output_dtype, + _accum_dtype, gate_dtype, ):
98-98: Unusedblock_DKparameter in autotuned kernel.The
block_DKparameter is included in autotune configs (line 68) but is never used in the kernel—DKis used directly for all allocations and operations. Either removeblock_DKfrom configs, or use it to tile the DK dimension for potential performance gains.
213-216: Unused kernel configuration parameters.The parameters
block_DK,block_DV,threads, andnum_stagesare accepted but never used—the kernel relies on autotuning. Consider removing them from the function signature or passing them to the kernel call if manual config override is desired.Suggested fix
def run_test( B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype, chunk_size, use_gk=True, use_initial_state=True, store_final_state=True, save_new_value=True, - block_DK=64, - block_DV=32, - threads=128, - num_stages=0, ):examples/KDA/chunk_o.py (2)
197-231: Unused kernel configuration parameters inrun_test.The parameters
block_DK,block_DV,threads, andnum_stagesare accepted byrun_testbut never passed totilelang_chunk_fwd_o. The kernel is invoked at line 217-230 without these configuration options, meaning the autotuner will select them instead. If this is intentional (relying on autotuning), consider removing these parameters fromrun_test's signature to avoid confusion. If manual configuration was intended, these should be passed to the kernel.♻️ Option A: Remove unused parameters
def run_test( B, S, H, DK, DV, chunk_size, input_dtype, output_dtype, accum_dtype, gate_dtype, - block_DK, - block_DV, - threads, - num_stages, ):♻️ Option B: Pass parameters to kernel
kernel = tilelang_chunk_fwd_o( B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, gate_dtype, chunk_size, scale, block_S, + block_DK=block_DK, + block_DV=block_DV, + threads=threads, + num_stages=num_stages, )
3-4: Remove unusedsysimport.The
sysmodule is imported but never used, and thenoqadirective is unnecessary.♻️ Proposed fix
-import sys # noqa: F401examples/KDA/chunk_delta_bwd.py (1)
3-3: Remove unusedsysimport.♻️ Proposed fix
-import sys # noqa: F401examples/KDA/chunk_bwd_dqkwg.py (1)
4-4: Remove unusedsysimport.♻️ Proposed fix
-import sys # noqa: F401examples/KDA/wy_fast_bwd.py (1)
3-3: Remove unusedsysimport.♻️ Proposed fix
-import sys # noqa: F401examples/KDA/chunk_bwd_intra.py (2)
130-133: Input parameterdbis declared but unused in kernel.The kernel signature includes
db: T.Tensor(db_shape, dtype=input_dtype)but it's never read in the kernel body. The kernel writes todb2(output) but doesn't use the inputdb. Ifdbis meant to be accumulated into the output (similar to howdqanddkare added at lines 248-250 and 362-376), ensure it's added to the final result.Note: Looking at line 448 in
run_test,db_tilelang.sum(0).add_(db)is called after the kernel, which does adddb. However, this is inconsistent with howdqanddkare handled inside the kernel.
3-3: Remove unusedsysimport.♻️ Proposed fix
-import sys # noqa: F401
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
examples/KDA/chunk_bwd_dqkwg.py
examples/KDA/chunk_bwd_intra.py
examples/KDA/chunk_delta_bwd.py
examples/KDA/chunk_delta_h_fwd.py
examples/KDA/chunk_o.py
examples/KDA/wy_fast.py
examples/KDA/wy_fast_bwd.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
examples/KDA/wy_fast.py
examples/KDA/wy_fast_bwd.py
examples/KDA/chunk_delta_h_fwd.py
examples/KDA/chunk_bwd_dqkwg.py
examples/KDA/chunk_o.py
examples/KDA/chunk_delta_bwd.py
📚 Learning: 2026-01-06T05:20:51.649Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Applied to files:
examples/KDA/wy_fast.py
examples/KDA/wy_fast_bwd.py
examples/KDA/chunk_delta_h_fwd.py
examples/KDA/chunk_delta_bwd.py
🧬 Code graph analysis (2)
examples/KDA/chunk_delta_h_fwd.py (3)
examples/KDA/FLA_KDA/fla_chunk_delta.py (1)
chunk_gated_delta_rule_fwd_h(470-521)
examples/KDA/FLA_KDA/cumsum.py (1)
chunk_local_cumsum(431-469)
examples/KDA/test_utils.py (2)
compare_tensors(43-83)
do_bench(86-108)
examples/KDA/chunk_bwd_dqkwg.py (4)
tilelang/language/dtypes.py (1)
float32(310-310)
tilelang/language/allocate.py (2)
alloc_fragment(71-82)
alloc_shared(39-54)
tilelang/language/loop.py (2)
Pipelined(97-134)
Parallel(13-72)
tilelang/language/reduce_op.py (1)
reduce_sum(144-166)
🪛 Ruff (0.14.11)
examples/KDA/wy_fast.py
6-6: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
15-15: Unused function argument: output_dtype
(ARG001)
68-68: Unused function argument: use_qg
(ARG001)
94-94: Unused function argument: QG
(ARG001)
174-174: Unused function argument: block_DK
(ARG001)
175-175: Unused function argument: block_DV
(ARG001)
176-176: Unused function argument: threads
(ARG001)
177-177: Unused function argument: num_stages
(ARG001)
184-184: Unpacked variable QG_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
185-185: Unpacked variable QG_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/wy_fast_bwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
25-25: Unused function argument: output_dtype
(ARG001)
26-26: Unused function argument: accum_dtype
(ARG001)
28-28: Unused function argument: state_dtype
(ARG001)
50-50: Unused function argument: chunk_size
(ARG001)
53-53: Unused function argument: state_dtype
(ARG001)
92-92: Unused function argument: state_dtype
(ARG001)
258-258: Unused function argument: block_DK
(ARG001)
259-259: Unused function argument: block_DV
(ARG001)
260-260: Unused function argument: threads
(ARG001)
261-261: Unused function argument: num_stages
(ARG001)
examples/KDA/chunk_bwd_intra.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
26-26: Unused function argument: output_dtype
(ARG001)
27-27: Unused function argument: accum_dtype
(ARG001)
29-29: Unused function argument: state_dtype
(ARG001)
56-56: Unused function argument: chunk_size
(ARG001)
60-60: Unused function argument: state_dtype
(ARG001)
95-95: Unused function argument: state_dtype
(ARG001)
132-132: Unused function argument: db
(ARG001)
396-396: Unused function argument: threads
(ARG001)
397-397: Unused function argument: num_stages
(ARG001)
398-398: Unused function argument: cu_seqlens
(ARG001)
399-399: Unused function argument: chunk_indices
(ARG001)
examples/KDA/chunk_delta_h_fwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
31-31: Unused function argument: output_dtype
(ARG001)
32-32: Unused function argument: accum_dtype
(ARG001)
98-98: Unused function argument: block_DK
(ARG001)
213-213: Unused function argument: block_DK
(ARG001)
214-214: Unused function argument: block_DV
(ARG001)
215-215: Unused function argument: threads
(ARG001)
216-216: Unused function argument: num_stages
(ARG001)
examples/KDA/chunk_bwd_dqkwg.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
43-43: Unused function argument: DV
(ARG001)
44-44: Unused function argument: chunk_size
(ARG001)
209-209: Unused function argument: qk_dtype
(ARG001)
211-211: Unused function argument: use_gk
(ARG001)
212-212: Unused function argument: use_initial_state
(ARG001)
213-213: Unused function argument: store_final_state
(ARG001)
214-214: Unused function argument: save_new_value
(ARG001)
215-215: Unused function argument: block_DK
(ARG001)
216-216: Unused function argument: block_DV
(ARG001)
217-217: Unused function argument: threads
(ARG001)
218-218: Unused function argument: num_stages
(ARG001)
examples/KDA/chunk_o.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
22-22: Unused function argument: output_dtype
(ARG001)
23-23: Unused function argument: accum_dtype
(ARG001)
39-39: Unused function argument: DK
(ARG001)
41-41: Unused function argument: chunk_size
(ARG001)
44-44: Ambiguous variable name: O
(E741)
99-99: Ambiguous variable name: O
(E741)
197-197: Unused function argument: block_DK
(ARG001)
198-198: Unused function argument: block_DV
(ARG001)
199-199: Unused function argument: threads
(ARG001)
200-200: Unused function argument: num_stages
(ARG001)
examples/KDA/chunk_delta_bwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
28-28: Unused function argument: output_dtype
(ARG001)
29-29: Unused function argument: accum_dtype
(ARG001)
31-31: Unused function argument: state_dtype
(ARG001)
58-58: Unused function argument: gate_dtype
(ARG001)
128-128: Unused function argument: h0
(ARG001)
239-239: Unused function argument: block_DV
(ARG001)
240-240: Unused function argument: threads
(ARG001)
241-241: Unused function argument: num_stages
(ARG001)
242-242: Unused function argument: use_torch
(ARG001)
🔇 Additional comments (14)
examples/KDA/wy_fast.py (2)
25-39: LGTM!The conditional tensor allocation based on
use_qganduse_kgflags is clean and appropriate.
42-51: LGTM!The configuration generator for autotuning is well-structured.
examples/KDA/chunk_delta_h_fwd.py (3)
48-62: LGTM!The ceiling division on line 58 correctly matches the kernel iteration bounds, and the tensor allocations are appropriate.
65-75: LGTM!Configuration generation for autotuning is correctly implemented.
295-316: LGTM!The dtype values are correctly passed as strings. Note that the current test configuration (
S=8192,chunk_size=64) ensuresSis evenly divisible, avoiding the potential out-of-bounds issue identified in the kernel. Consider adding a test case with non-divisible values to verify the fix once applied.examples/KDA/chunk_o.py (1)
1-11: LGTM on the core kernel implementation.The indexing bug (using
block_DVinstead ofblock_DKfor Q_shared/GK_shared/GQ_shared) from past reviews has been correctly fixed - line 132 now properly iterates overT.Parallel(block_S, block_DK)withi_k2. The dtype mismatch inmain()has also been addressed with string arguments.examples/KDA/chunk_delta_bwd.py (2)
298-300: LGTM - Validation is now enabled.The
compare_tensorscalls are now active, addressing the previous review concern about disabled validation. The test will now properly verify correctness against the reference implementation.
303-325: LGTM - Dtype issue has been fixed.The
main()function now correctly passes string dtype names ("bfloat16","float32") torun_test, which usesgetattr(torch, ...)to resolve them. This addresses the previous review concern about type mismatch.examples/KDA/chunk_bwd_dqkwg.py (2)
241-244: LGTM - Validation is now enabled.The
compare_tensorscalls are active, properly validatingdq,dk,dw, anddgagainst the reference implementation. This addresses the previous review concern.
264-284: LGTM - Dtype handling inmain()is correct.The
main()function now passes string dtype names ("float32") torun_test, correctly aligning with thegetattr(torch, ...)usage.examples/KDA/wy_fast_bwd.py (2)
310-314: LGTM - Validation is now enabled.The
compare_tensorscalls are active for all outputs (dA,dk,dv,dbeta,dg), properly verifying correctness against the reference implementation.
295-307: TileLang dtype handling is correct in the kernel. Thetilelang_wy_fast_bwdfunction properly accepts dtype parameters and uses them consistently inT.Tensor,T.alloc_shared, andT.alloc_fragmentdeclarations. The API signatures in TileLang's proxy layer support both string dtype names and TVM dtype objects, making the current implementation valid.examples/KDA/chunk_bwd_intra.py (2)
455-458: LGTM - Validation is enabled and comprehensive.The
compare_tensorscalls validate all outputs (dq,dk,db,dg) against the reference implementation. The post-processing at lines 448-453 (summingdb_tilelangand applying cumsum todg_tilelang) is correctly applied before comparison.
140-146: Complex grid decomposition looks correct.The grid decomposition
(i_k * NC + i_i, i_t, i_bh)properly handles the Cartesian product of K-blocks and sub-chunks, with correct index extraction. The sub-chunk index calculation at line 145 (i_ti = i_t * BT + i_i * BC) is mathematically sound.
I have committed the code. However, pre-commit.ci reports an error; when I check locally, no error shows up.
…d corresponding test case (tile-ai#1654) * Add unroll loop functionality and corresponding test case - Introduced a new `UnrollLoop` function in the transform module to unroll loops based on various configuration options. - Added a test case in `test_tilelang_language_unroll.py` to validate the behavior of `T.unroll` with only the extent parameter, ensuring correct kernel generation with unroll pragmas. * Refactor unroll kernel implementation and update test case - Changed the kernel function in `test_tilelang_language_unroll.py` to use a new `unroll_kernel` function that compiles and returns the output tensor, improving clarity and structure. - Updated the `OptimizeForTarget` function in `phase.py` to ensure the `UnrollLoop` transformation is applied correctly, maintaining consistency in optimization phases. * lint fix * lint fix
remove unnecessary debug print
…nsitive LetStmt dependencies (tile-ai#1657) * [Enhancement] Update global load/store functions for CUDA compatibility (tile-ai#1652) Refactor the `ld_global_256` and `st_global_256` functions to support both CUDA versions above 12.9 and earlier versions. This change ensures that 256-bit loads and stores are handled correctly across different CUDA versions, improving performance and compatibility. The implementation now uses two 128-bit loads/stores for older versions, enhancing the robustness of the codebase. * Update comments in global load/store functions for CUDA compatibility Clarified comments in `ld_global_256` and `st_global_256` functions to indicate that the fallback for CUDA versions below 12.9 may have performance regressions. This change enhances code readability and provides better context for developers working with different CUDA versions. * Update submodule and enhance LetStmt handling in inject_pipeline.cc - Updated the TVM submodule to the latest commit. - Improved the handling of LetStmt in the inject_pipeline.cc file to account for transitive dependencies on loop variables, ensuring correct variable substitution in rewritten blocks. - Adjusted test_tilelang_issue_1263.py to remove unnecessary jit decorator and updated the kernel compilation process with specific pass configurations. * lint fix * revert tvm * remove unused test * test fix
…s operations (tile-ai#1663) * [Enhancement] Update CallNode handling to include annotations in various operations - Modified CallNode invocations in multiple files to ensure that annotations are passed correctly, enhancing the consistency and functionality of the codebase. - Removed the "use_tma" annotation from AtomicAddNode and adjusted related calls to maintain expected behavior. - Updated CUDA intrinsic dispatch functions to include annotations, improving compatibility and correctness in CUDA operations. * lint fix
…#1664) * [Fix] Refactor type hint extraction logic in DSLMutator for better clarity and handling of annotations * [Refactor] Remove redundant tensor creation in loop layout tests and update kernel compilation parameters
* [Feat] Add tilelang autodd for delta debugging * fix typos * fix lint error * fix typos * fix lint error * fix bugs * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix codeview comments * [Refactor] Move AutoDD detection to env module and update import logic * Refactor: Relocate the _is_running_autodd function to the env module for better organization and encapsulation. * Update initialization logic to skip logger and heavy imports based on a new light import mode, enhancing flexibility in module usage. * Ensure consistent handling of environment variables across the package, improving overall code clarity and maintainability. * [Documentation] Add AutoDD section to debug_tools_for_tilelang.md * Introduced a comprehensive guide on AutoDD (Automatic Delta Debugging) for isolating bugs in TileLang programs. * Explained Delta Debugging methodology, usage, parameters, and provided examples for clarity. * Highlighted the benefits of using AutoDD for large codebases and hard-to-locate errors, emphasizing time-saving aspects. * Included tips for effective usage and a reference to a complete example in the documentation. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: kurisu6912 <227995639+kurisu6912@users.noreply.github.com> Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
The latest TileLang PR has fixed this lint error. I also pulled from the main branch and fixed it.
* support cp.reduce.async.bulk.tensor and add test * Refactor flash attention example by removing unnecessary layout annotations * support swizzle layout for tma reduce * auto swizzle for non-1d tma atomic add * upd example and test * lint * typo * add constraint for test * Refactor CUDA data type mapping by moving the to_CUtensorMapDataType function to utils.cc and utils.h, while removing redundant definitions from atomic_add.cc and copy.cc. * lint
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/flash_attention/example_gqa_bwd_tma_reduce.py (1)
340-343: Inconsistent copy pattern: `dk` should be copied via `dk_shared`. Line 343 copies directly from the `dk` fragment to global memory, but `dk_shared` was populated at line 342 and not used. This is inconsistent with the `dv` pattern at lines 340-341 and differs from the varlen implementation at `example_gqa_bwd_tma_reduce_varlen.py` line 481, which correctly uses `dk_shared`.
Proposed fix
  T.copy(dv, dv_shared)
  T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
  T.copy(dk, dk_shared)
- T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
+ T.copy(dk_shared, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
🤖 Fix all issues with AI agents
In `@examples/KDA/chunk_delta_bwd.py`:
- Around line 262-268: The reference tensors dh_ref, dh0_ref, dv2_ref are only
produced when use_gk is True but compare_tensors is called unconditionally;
either compute the reference via chunk_gated_delta_rule_bwd_dhu when use_gk is
False (call the same function path or its non-gk equivalent) or wrap the
compare_tensors calls in an if use_gk block so they only run when
dh_ref/dh0_ref/dv2_ref exist; update the logic around the
chunk_gated_delta_rule_bwd_dhu call and the subsequent compare_tensors calls to
ensure the reference variables are always defined before use or comparisons are
conditional on use_gk.
In `@examples/KDA/chunk_inter_solve_fused.py`:
- Around line 81-82: Add a runtime check at the start of the kernel function to
enforce the relationship between chunk_size and sub_chunk_size used elsewhere:
assert chunk_size == 4 * sub_chunk_size (or raise a ValueError with a clear
message). Locate where block_S = BS = chunk_size and BC = sub_chunk_size are set
(symbols block_S, BS, BC, chunk_size, sub_chunk_size) and insert the assertion
immediately after those assignments to prevent silent incorrect computation when
the kernel assumes exactly 4 sub-chunks.
In `@src/op/atomic_add.cc`:
- Around line 373-379: The code dereferences as_const_int(...) results for
mat_stride and mat_continuous without checking for nullptr, which can crash if
shared_tensor->shape elements are not compile-time constants; update the block
using as_const_int(shared_tensor->shape[dim - 2]) and
as_const_int(shared_tensor->shape[dim - 1]) to first capture each result into a
local pointer, verify they are non-null (and handle the non-constant case by
returning an error, throwing, or using a safe fallback), and only then assign
mat_stride/mat_continuous and call makeGemmABLayoutHopper; ensure error handling
is consistent with surrounding code paths that construct Layout when shape
values are unknown.
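For the chunk_inter_solve_fused.py item above, a minimal host-side sketch of the suggested guard (illustrative; in the real kernel block_S = BS = chunk_size and BC = sub_chunk_size):

```python
def check_sub_chunking(chunk_size, sub_chunk_size):
    # The fused kernel unrolls exactly four sub-chunks per chunk, so any other
    # ratio would silently compute the wrong result.
    if chunk_size != 4 * sub_chunk_size:
        raise ValueError(
            f"chunk_size={chunk_size} must equal 4 * sub_chunk_size "
            f"({4 * sub_chunk_size})")

check_sub_chunking(chunk_size=64, sub_chunk_size=16)   # passes
```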
♻️ Duplicate comments (15)
examples/KDA/test_utils_kda.py (2)
10-17: Fix tensor-to-Python control flow incalc_sim().This function has several issues that were flagged in past reviews but remain unaddressed:
.datais deprecated; use.detach()insteaddenominator == 0compares a tensor to a scalar, which works but is fragile- Returns a tensor instead of a Python float, causing issues in
assert_similarProposed fix
def calc_sim(x, y, name="tensor"): - x, y = x.data.double(), y.data.double() + x, y = x.detach().double(), y.detach().double() denominator = (x * x + y * y).sum() - if denominator == 0: + if denominator.item() == 0: print_red_warning(f"{name} all zero") - return 1 + return 1.0 sim = 2 * (x * y).sum() / denominator - return sim + return float(sim.item())
27-30: Inverted mask logic in non-finite value comparison.The comparison fills finite positions with 0 (since
x_mask = torch.isfinite(x)), leaving non-finite values. To compare non-finite values while masking out finite ones, the masks should be inverted.Proposed fix
- if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all(): + if not torch.isclose(x.masked_fill(~x_mask, 0), y.masked_fill(~y_mask, 0), rtol=0, atol=0, equal_nan=True).all():examples/KDA/chunk_bwd_dqkwg.py (2)
47-51:qk_dtypeparameter is still ignored.The
prepare_outputfunction acceptsqk_dtypebutdqanddkare hardcoded totorch.float32. This was flagged in a previous review. Either use the parameter or remove it.Proposed fix
def prepare_output( B, S, H, DK, DV, chunk_size, gate_dtype, ): - dq = torch.randn(B, S, H, DK, dtype=torch.float32).cuda() - dk = torch.randn(B, S, H, DK, dtype=torch.float32).cuda() + # Note: Reference implementation uses float32 for dq/dk accumulation + dq = torch.empty(B, S, H, DK, dtype=torch.float32).cuda() + dk = torch.empty(B, S, H, DK, dtype=torch.float32).cuda() dw = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() return dq, dk, dw, dgNote: Also consider using
torch.emptyinstead oftorch.randnfor output tensors that will be completely overwritten.
122-123: Gate dtype mismatch forG_shared/Gn_shared.These buffers store values copied from
G(which hasgate_dtype) but are allocated withinput_dtype. This can silently change precision. A previous review marked this as addressed, but the code still showsinput_dtype.Proposed fix
- G_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) # chunk G - Gn_shared = T.alloc_shared((block_DK), dtype=input_dtype) # chunk last token G + G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype) # chunk G + Gn_shared = T.alloc_shared((block_DK,), dtype=gate_dtype) # chunk last token Gexamples/KDA/wy_fast.py (3)
237-240: Type mismatch: passing TileLang types instead of strings.
run_testexpects string dtype arguments (resolved viagetattr(torch, input_dtype)), butmain()passesT.bfloat16,T.float32. This causes anAttributeErrorat runtime. This was flagged in a previous review but remains unfixed.Proposed fix
run_test( B=1, S=8192, H=64, DK=128, DV=128, chunk_size=64, - input_dtype=T.bfloat16, - output_dtype=T.bfloat16, - gate_dtype=T.float32, - accum_dtype=T.float32, + input_dtype="bfloat16", + output_dtype="bfloat16", + gate_dtype="float32", + accum_dtype="float32", block_DK=64, block_DV=32, threads=128, num_stages=3, )
222-222: Typo: "tritron" should be "triton".Proposed fix
- print("tritron time:", triton_time) + print("triton time:", triton_time)
68-69:use_qgparameter declared butQGtensor never populated.The
use_qgparameter andQGoutput tensor are declared, but the kernel never computes or writes toQG. If QG computation is intentionally deferred, add a TODO comment. Otherwise, remove the unused parameter and output.Also applies to: 94-94
examples/KDA/chunk_delta_h_fwd.py (1)
178-182: Potential out-of-bounds access whenSis not divisible byblock_S.At line 180,
GK[bb, (i_s + 1) * block_S - 1, bh, :]can exceed bounds whenSis not evenly divisible byblock_S. For example, ifS=100andblock_S=64, wheni_s=1, the index becomes127, exceeding the valid range[0, 99].This was flagged in a previous review but appears unaddressed.
Proposed fix using index clamping
if use_gk: - T.copy(GK[bb, (i_s + 1) * block_S - 1, bh, :], GK_last_shared) # block last token + last_idx = T.min((i_s + 1) * block_S - 1, S - 1) + T.copy(GK[bb, last_idx, bh, :], GK_last_shared) # block last tokenAlternatively, add a runtime assertion that
S % chunk_size == 0.examples/KDA/chunk_intra_token_parallel.py (2)
60-73: Kernel lacksscaleparameter while reference implementation uses it.The
tilelang_chunk_kda_fwd_intra_token_parallelkernel signature doesn't include ascaleparameter, but the reference implementationchunk_kda_fwd_intra_token_parallel(called at line 212-214) receivesscale. This creates an inconsistency where the reference and TileLang kernel operate with different scaling behaviors, potentially causing validation failures when comparisons are enabled.
233-239: Kernel outputs are computed but never validated against reference.
Aqk_tilelangandAkk_tilelangare computed at line 233 but never compared againstAqk_refandAkk_ref. The comparison code at lines 264-269 is commented out. For a test/validation file, this defeats the purpose of correctness verification mentioned in the PR objectives.Also applies to: 264-269
examples/KDA/chunk_bwd_gla_dA.py (1)
80-81: Potential dtype mismatch forV_shared.
V_sharedis allocated withdo_dtypeat line 81, but it's used to copy fromVwhich hasinput_dtype(line 96). Ifdo_dtype != input_dtype, this could cause implicit type conversion or precision loss.🐛 Suggested fix
do_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype) - V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype) + V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)examples/KDA/chunk_delta_bwd.py (1)
121-136: Unusedh0parameter in kernel signature.The kernel declares
h0: T.Tensor(h0_shape, dtype=input_dtype)at line 128 but never uses it in the kernel body. Given theuse_initial_stateflag at line 217-218 only writes todh0, theh0input may have been intended for initialization logic that's not yet implemented.examples/KDA/wy_fast_bwd.py (2)
341-351: Critical type mismatch:T.float32passed where string expected.
main()passes TileLang type objects (e.g.,T.float32) torun_test, butrun_testusesgetattr(torch, input_dtype)at lines 270-274, which expects string arguments like"float32". This will cause anAttributeErrorat runtime.🐛 Proposed fix
def main(): DK = 128 DV = 128 run_test( B=1, S=32768, H=8, DK=DK, DV=DV, - input_dtype=T.float32, - output_dtype=T.float32, - accum_dtype=T.float32, - gate_dtype=T.float32, - state_dtype=T.float32, + input_dtype="float32", + output_dtype="float32", + accum_dtype="float32", + gate_dtype="float32", + state_dtype="float32", chunk_size=64,
295-307: String dtypes passed to TileLang kernel expecting TileLang types.
run_testpasses string dtype parameters directly totilelang_wy_fast_bwd, but the kernel'sT.Tensor(..., dtype=input_dtype)declarations expect TileLang dtype objects. This creates an inconsistency - either convert strings to TileLang types before the kernel call, or have the kernel handle string conversion internally.🔧 Proposed fix - convert dtypes before kernel call
kernel = tilelang_wy_fast_bwd( B, S, H, DK, DV, - input_dtype, - output_dtype, - accum_dtype, - gate_dtype, - state_dtype, + getattr(T, input_dtype), + getattr(T, output_dtype), + getattr(T, accum_dtype), + getattr(T, gate_dtype), + getattr(T, state_dtype), chunk_size, )examples/KDA/chunk_bwd_intra.py (1)
478-493: Critical:main()passes TileLang types butrun_testexpects strings.This issue was flagged in a previous review and remains unfixed.
main()passesT.float32(a TileLang type object) torun_test, butrun_testusesgetattr(torch, input_dtype)at lines 407-411, which expects string arguments like"float32". This will cause anAttributeErrorat runtime.🐛 Proposed fix
def main(): DK = 128 run_test( B=1, S=8192, H=8, DK=DK, - input_dtype=T.float32, - output_dtype=T.float32, - accum_dtype=T.float32, - gate_dtype=T.float32, - state_dtype=T.float32, + input_dtype="float32", + output_dtype="float32", + accum_dtype="float32", + gate_dtype="float32", + state_dtype="float32", chunk_size=64, threads=128, num_stages=0, )
🧹 Nitpick comments (17)
src/tl_templates/cuda/copy_sm90.h (1)
265-331: Implementation looks correct and follows established patterns.The new
tma_store_addoverloads correctly implement descriptor-based TMA atomic add operations for 1D through 5D tensors. The inline assembly uses the proper PTX instruction format forcp.reduce.async.bulk.tensor.Xd.global.shared::cta.add.bulk_group.One minor inconsistency: the
tma_storeoverloads (lines 184-253) include aCacheHintSm90template parameter with.L2::cache_hintin the assembly, while thesetma_store_addoverloads omit cache hint support. The PTX ISA does support L2 cache hints for reduce operations, so adding this for API consistency could be beneficial.♻️ Optional: Add cache hint support for consistency with tma_store
+template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL> TL_DEVICE void tma_store_add(const CUtensorMap &descriptor, void const *const smem_ptr, int32_t const &crd0) { uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor); uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); asm volatile( - "cp.reduce.async.bulk.tensor.1d.global.shared::cta.add.bulk_group " - "[%0, {%2}], [%1];" + "cp.reduce.async.bulk.tensor.1d.global.shared::cta.add.bulk_group" + ".L2::cache_hint [%0, {%2}], [%1], %3;" : - : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0) + : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "l"(cache_hint) : "memory"); }Apply similar changes to the 2D-5D overloads if cache hint control is desired.
src/op/atomic_add.cc (1)
317-332: Consider extracting to a shared utility function.This implementation is identical to
CopyNode::ComputeLinearLayoutinsrc/op/copy.cc(lines 269-284). Consider extracting this to a shared utility function to reduce duplication and ensure consistent behavior.testing/python/language/test_tilelang_language_atomic_add.py (1)
389-401: Consider also running the explicit_swizzle variant to verify correctness.The test compiles with
explicit_swizzle=Trueand verifies kernel source equivalence, but doesn't actually execute the explicit swizzle variant to confirm numerical correctness.♻️ Suggested enhancement
kernel_with_explicit_swizzle = tma_atomic_add_program.compile(out=T.Tensor[(16, 16), T.float32], explicit_swizzle=True) # Ensure auto swizzled layout is applied assert kernel.get_kernel_source() == kernel_with_explicit_swizzle.get_kernel_source() + + # Also verify explicit_swizzle variant produces correct results + out_explicit = torch.zeros((16, 16), dtype=torch.float32, device="cuda") + tma_atomic_add_program(out_explicit, explicit_swizzle=True) + torch.testing.assert_close(out_explicit, torch.ones((16, 16), dtype=torch.float32, device="cuda") * 16)examples/KDA/test_utils_kda.py (1)
43-43: Unusedatolandrtolparameters.These parameters are accepted but never used in the function body. Either use them for tolerance-based pass/fail logic or remove them to avoid confusion.
examples/KDA/chunk_bwd_dv.py (1)
31-40: Unusedchunk_sizeparameter inprepare_output.The
chunk_sizeparameter is accepted but never used. Consider removing it or documenting why it's reserved for future use.examples/KDA/chunk_bwd_dqkwg.py (1)
200-219: Multiple unused parameters inrun_test.Parameters
qk_dtype,use_gk,use_initial_state,store_final_state,save_new_value,block_DK,block_DV,threads, andnum_stagesare accepted but never used. Consider removing unused parameters or documenting they're reserved for future use.examples/KDA/chunk_intra_token_parallel.py (1)
13-30: Unused function parameters inprepare_input.The
output_dtypeandaccum_dtypeparameters are declared but never used within this function. Consider removing them if they serve no purpose, or prefix with underscore to indicate intentional non-use.♻️ Suggested fix
def prepare_input( B, S, H, DK, chunk_size, input_dtype, - output_dtype, - accum_dtype, + _output_dtype, + _accum_dtype, gate_dtype, ):examples/KDA/chunk_bwd_gla_dA.py (1)
106-117: UnusedDKparameter inrun_test.The
DKparameter is declared at line 110 but never used in the function body. This appears to be a copy-paste artifact from other similar test functions.♻️ Suggested fix
def run_test( B, S, H, - DK, DV, scale, input_dtype, do_dtype, da_dtype, chunk_size, ):And update the call in
main():run_test( B=1, S=1024 * 8, # 32768 H=64, - DK=128, DV=128,examples/KDA/chunk_delta_bwd.py (1)
236-242: Unused tuning parameters inrun_test.The parameters
block_DV,threads,num_stages, anduse_torchare declared but never used. These appear to be intended for manual kernel configuration override but aren't wired through totilelang_chunk_gated_delta_rule_bwd_dhu.♻️ Either remove unused parameters or wire them through
Option 1 - Remove:
def run_test( ... use_final_state_gradient=True, - block_DV=64, - threads=256, - num_stages=0, - use_torch=False, ):Option 2 - Wire through to kernel:
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( ... use_final_state_gradient, + block_DV=block_DV, + threads=threads, + num_stages=num_stages, )examples/KDA/chunk_inter_solve_fused.py (1)
326-398: Forward substitution loops have off-by-one potential in start index.The
T.Pipelinedloops for forward substitution use unusual start indices:
- Line 326:
T.Pipelined(2, T.min(BC, S - i_tc0), ...)- Line 340:
T.Pipelined(BC + 2, T.min(2 * BC, S - i_tc0), ...)- Line 362:
T.Pipelined(2 * BC + 2, T.min(3 * BC, S - i_tc0), ...)- Line 380:
T.Pipelined(3 * BC + 2, T.min(4 * BC, S - i_tc0), ...)The
+ 2offset in start indices appears intentional to skip the first two rows of each diagonal block, but this pattern is non-obvious and should be documented.Consider adding a comment explaining why the loop starts at
i*BC + 2rather thani*BC:# Start at offset 2 to skip the first two rows already handled by identity initialization for i_i in T.Pipelined(2, T.min(BC, S - i_tc0), num_stages=num_stages):examples/KDA/wy_fast_bwd.py (1)
14-14: Debug print settings left enabled.
torch.set_printoptions(profile="full")is set at line 14, which will produce verbose output for all tensor prints. This is typically a debugging artifact that should be removed before merging.♻️ Remove debug setting
torch.random.manual_seed(0) -torch.set_printoptions(profile="full")examples/KDA/chunk_bwd_intra.py (6)
3-3: Remove unusedsysimport.The
sysmodule is imported but never used. The# noqa: F401directive is also flagged as unnecessary by static analysis.🧹 Proposed fix
-import sys # noqa: F401 -
19-49: Consider removing or documenting unused parameters.The parameters
output_dtype,accum_dtype, andstate_dtypeare never used in this function. If they're kept for API consistency with other functions, consider adding a brief comment explaining this. Otherwise, remove them to reduce confusion.
199-199: Minor typo in comment."ofprevious" should be "of previous".
📝 Proposed fix
- for i_j in T.Pipelined(i_i, num_stages=num_stages): # i_j is index ofprevious sub_chunks + for i_j in T.Pipelined(i_i, num_stages=num_stages): # i_j is index of previous sub_chunks
274-274: Remove commented-out code.Lines 274 and 331 contain commented-out code (
T.use_swizzle(10)anddkt_lower_tempallocation). If these are no longer needed, consider removing them to keep the code clean.Also applies to: 331-331
444-447: Remove redundant tensor allocation.The tensors allocated by
prepare_outputon line 444-446 are immediately overwritten by the kernel call on line 447. Theprepare_outputcall is unnecessary here.♻️ Proposed fix
- dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = prepare_output( - B, S, H, DK, chunk_size, NK, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) - ) dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = kernel(q, k, g, beta, dAqk, dAkk, dq, dk, db, dg)
385-400: Consider removing or documenting unused parameters.The parameters
threads,num_stages,cu_seqlens, andchunk_indicesare declared but never used in this function. If they're placeholders for future functionality, consider adding a TODO comment or removing them.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (22)
examples/KDA/chunk_bwd_dqkwg.pyexamples/KDA/chunk_bwd_dv.pyexamples/KDA/chunk_bwd_gla_dA.pyexamples/KDA/chunk_bwd_intra.pyexamples/KDA/chunk_delta_bwd.pyexamples/KDA/chunk_delta_h_fwd.pyexamples/KDA/chunk_inter_solve_fused.pyexamples/KDA/chunk_intra_token_parallel.pyexamples/KDA/test_utils_kda.pyexamples/KDA/wy_fast.pyexamples/KDA/wy_fast_bwd.pyexamples/autodd/tilelang_buggy.pyexamples/autodd/tilelang_minimized_expected.pyexamples/flash_attention/example_gqa_bwd_tma_reduce.pyexamples/flash_attention/example_gqa_bwd_tma_reduce_varlen.pysrc/op/atomic_add.ccsrc/op/atomic_add.hsrc/op/copy.ccsrc/op/utils.ccsrc/op/utils.hsrc/tl_templates/cuda/copy_sm90.htesting/python/language/test_tilelang_language_atomic_add.py
💤 Files with no reviewable changes (1)
- src/op/copy.cc
✅ Files skipped from review due to trivial changes (2)
- examples/autodd/tilelang_minimized_expected.py
- examples/autodd/tilelang_buggy.py
🧰 Additional context used
🧠 Learnings (4)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
examples/KDA/chunk_bwd_dv.pyexamples/KDA/wy_fast.pyexamples/KDA/chunk_delta_bwd.pyexamples/KDA/chunk_bwd_dqkwg.pytesting/python/language/test_tilelang_language_atomic_add.pyexamples/KDA/chunk_intra_token_parallel.pyexamples/KDA/wy_fast_bwd.pyexamples/KDA/chunk_delta_h_fwd.py
📚 Learning: 2025-12-15T07:23:50.065Z
Learnt from: cherichy
Repo: tile-ai/tilelang PR: 1421
File: tilelang/contrib/cutedsl/cpasync.py:45-55
Timestamp: 2025-12-15T07:23:50.065Z
Learning: In tilelang/contrib/cutedsl/cpasync.py, using AddressSpace.generic for TMA descriptor pointers (tensormap_ptr) in the extract_tensormap_ptr function is correct. When creating ptr_type with _cute_ir.PtrType.get for TMA descriptors in CuTeDSL, AddressSpace.generic should be used, not a device-specific or constant address space.
Applied to files:
src/tl_templates/cuda/copy_sm90.h
📚 Learning: 2026-01-12T07:25:35.591Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1631
File: src/transform/thread_storage_sync.cc:1126-1137
Timestamp: 2026-01-12T07:25:35.591Z
Learning: In TileLang's thread storage synchronization pass (src/transform/thread_storage_sync.cc), at the IR level where PointerAccessIsDisjoint is called, the threads array in AccessEntry is guaranteed to contain all three thread dimensions (threadIdx.x, threadIdx.y, threadIdx.z), making access to the last 3 elements via `threads[threads.size() + idx - 3]` safe.
Applied to files:
src/op/atomic_add.cc
📚 Learning: 2026-01-06T05:20:51.649Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Applied to files:
examples/KDA/wy_fast.pyexamples/KDA/chunk_delta_bwd.pytesting/python/language/test_tilelang_language_atomic_add.pyexamples/KDA/chunk_intra_token_parallel.pyexamples/KDA/wy_fast_bwd.pyexamples/KDA/chunk_delta_h_fwd.py
🧬 Code graph analysis (12)
src/op/atomic_add.h (2)
src/op/atomic_add.cc (2)
ComputeLinearLayout(317-332)ComputeLinearLayout(317-317)src/op/copy.cc (2)
ComputeLinearLayout(270-285)ComputeLinearLayout(270-270)
src/op/utils.h (1)
src/op/utils.cc (2)
to_CUtensorMapDataType(96-157)to_CUtensorMapDataType(96-96)
examples/KDA/chunk_bwd_dv.py (2)
examples/KDA/test_utils_kda.py (2)
compare_tensors(43-83)do_bench(86-108)examples/KDA/chunk_o.py (2)
do_bench(161-183)kernel(93-156)
examples/KDA/test_utils_kda.py (2)
tilelang/carver/roller/policy/default.py (1)
sim(290-291)tilelang/language/tir/op.py (1)
all(1913-1930)
examples/KDA/wy_fast.py (2)
examples/KDA/test_utils_kda.py (1)
compare_tensors(43-83)tilelang/language/allocate.py (1)
alloc_shared(39-54)
examples/KDA/chunk_delta_bwd.py (2)
examples/KDA/FLA_KDA/fla_chunk_delta.py (1)
chunk_gated_delta_rule_bwd_dhu(524-579)examples/KDA/FLA_KDA/cumsum.py (1)
chunk_local_cumsum(431-469)
examples/KDA/chunk_bwd_dqkwg.py (2)
examples/KDA/FLA_KDA/fla_chunk_inter.py (1)
chunk_kda_bwd_dqkwg(141-193)examples/KDA/test_utils_kda.py (2)
do_bench(86-108)compare_tensors(43-83)
examples/KDA/chunk_inter_solve_fused.py (2)
examples/KDA/FLA_KDA/cumsum.py (1)
chunk_local_cumsum(431-469)examples/KDA/test_utils_kda.py (2)
compare_tensors(43-83)do_bench(86-108)
examples/KDA/chunk_bwd_gla_dA.py (1)
examples/KDA/test_utils_kda.py (2)
compare_tensors(43-83)do_bench(86-108)
examples/KDA/chunk_intra_token_parallel.py (3)
examples/KDA/FLA_KDA/fla_chunk_intra_token_parallel.py (1)
chunk_kda_fwd_intra_token_parallel(114-168)examples/KDA/FLA_KDA/cumsum.py (1)
chunk_local_cumsum(431-469)examples/KDA/test_utils_kda.py (1)
do_bench(86-108)
examples/KDA/wy_fast_bwd.py (1)
examples/KDA/test_utils_kda.py (2)
do_bench(86-108)compare_tensors(43-83)
examples/flash_attention/example_gqa_bwd_tma_reduce.py (1)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (1)
flashattn_bwd_postprocess(198-228)
🪛 Ruff (0.14.11)
examples/KDA/chunk_bwd_dv.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
36-36: Unused function argument: chunk_size
(ARG001)
111-111: Unused function argument: scale
(ARG001)
examples/KDA/test_utils_kda.py
43-43: Unused function argument: atol
(ARG001)
43-43: Unused function argument: rtol
(ARG001)
54-54: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF003)
54-54: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
examples/KDA/wy_fast.py
6-6: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
15-15: Unused function argument: output_dtype
(ARG001)
68-68: Unused function argument: use_qg
(ARG001)
94-94: Unused function argument: QG
(ARG001)
174-174: Unused function argument: block_DK
(ARG001)
175-175: Unused function argument: block_DV
(ARG001)
176-176: Unused function argument: threads
(ARG001)
177-177: Unused function argument: num_stages
(ARG001)
184-184: Unpacked variable QG_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
185-185: Unpacked variable QG_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/chunk_delta_bwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
28-28: Unused function argument: output_dtype
(ARG001)
29-29: Unused function argument: accum_dtype
(ARG001)
31-31: Unused function argument: state_dtype
(ARG001)
58-58: Unused function argument: gate_dtype
(ARG001)
128-128: Unused function argument: h0
(ARG001)
239-239: Unused function argument: block_DV
(ARG001)
240-240: Unused function argument: threads
(ARG001)
241-241: Unused function argument: num_stages
(ARG001)
242-242: Unused function argument: use_torch
(ARG001)
examples/KDA/chunk_bwd_dqkwg.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
43-43: Unused function argument: DV
(ARG001)
44-44: Unused function argument: chunk_size
(ARG001)
209-209: Unused function argument: qk_dtype
(ARG001)
211-211: Unused function argument: use_gk
(ARG001)
212-212: Unused function argument: use_initial_state
(ARG001)
213-213: Unused function argument: store_final_state
(ARG001)
214-214: Unused function argument: save_new_value
(ARG001)
215-215: Unused function argument: block_DK
(ARG001)
216-216: Unused function argument: block_DV
(ARG001)
217-217: Unused function argument: threads
(ARG001)
218-218: Unused function argument: num_stages
(ARG001)
examples/KDA/chunk_inter_solve_fused.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
24-24: Unused function argument: output_dtype
(ARG001)
25-25: Unused function argument: accum_dtype
(ARG001)
46-46: Unused function argument: sub_chunk_size
(ARG001)
examples/KDA/chunk_bwd_intra.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
26-26: Unused function argument: output_dtype
(ARG001)
27-27: Unused function argument: accum_dtype
(ARG001)
29-29: Unused function argument: state_dtype
(ARG001)
56-56: Unused function argument: chunk_size
(ARG001)
60-60: Unused function argument: state_dtype
(ARG001)
95-95: Unused function argument: state_dtype
(ARG001)
132-132: Unused function argument: db
(ARG001)
396-396: Unused function argument: threads
(ARG001)
397-397: Unused function argument: num_stages
(ARG001)
398-398: Unused function argument: cu_seqlens
(ARG001)
399-399: Unused function argument: chunk_indices
(ARG001)
examples/KDA/chunk_bwd_gla_dA.py
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
19-19: Unused function argument: chunk_size
(ARG001)
32-32: Unused function argument: DV
(ARG001)
110-110: Unused function argument: DK
(ARG001)
examples/KDA/chunk_intra_token_parallel.py
20-20: Unused function argument: output_dtype
(ARG001)
21-21: Unused function argument: accum_dtype
(ARG001)
233-233: Unpacked variable Aqk_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
233-233: Unpacked variable Akk_tilelang is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/KDA/wy_fast_bwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
25-25: Unused function argument: output_dtype
(ARG001)
26-26: Unused function argument: accum_dtype
(ARG001)
28-28: Unused function argument: state_dtype
(ARG001)
50-50: Unused function argument: chunk_size
(ARG001)
53-53: Unused function argument: state_dtype
(ARG001)
92-92: Unused function argument: state_dtype
(ARG001)
258-258: Unused function argument: block_DK
(ARG001)
259-259: Unused function argument: block_DV
(ARG001)
260-260: Unused function argument: threads
(ARG001)
261-261: Unused function argument: num_stages
(ARG001)
examples/KDA/chunk_delta_h_fwd.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
31-31: Unused function argument: output_dtype
(ARG001)
32-32: Unused function argument: accum_dtype
(ARG001)
98-98: Unused function argument: block_DK
(ARG001)
213-213: Unused function argument: block_DK
(ARG001)
214-214: Unused function argument: block_DV
(ARG001)
215-215: Unused function argument: threads
(ARG001)
216-216: Unused function argument: num_stages
(ARG001)
🔇 Additional comments (32)
src/op/utils.h (2)
9-9: LGTM: Include added for CUDA stub types.
The include is necessary to access the CUtensorMapDataType enum used by the new to_CUtensorMapDataType function.
20-26: LGTM: Utility declarations are well-documented and correctly placed.
The to_CUtensorMapDataType declaration matches its implementation in utils.cc, and the ReverseArray template is a clean, correct implementation using reverse iterators.
src/op/utils.cc (1)
95-157: LGTM: Comprehensive data type mapping with proper error handling.
The function correctly handles all common TVM data types. The mapping of int16 to CU_TENSOR_MAP_DATA_TYPE_UINT16 is correct: CUDA's TensorMap API defines only UINT16 for 16-bit integers and lacks a signed INT16 type.
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (1)
576-582: LGTM!
The mod_post invocation correctly matches the function signature, and the pattern of applying postprocessing only to dq while handling dk/dv via reduction is appropriate since they're stored in a different shape (groups, *k.shape).
examples/flash_attention/example_gqa_bwd_tma_reduce.py (1)
399-408: LGTM!
The mod_post invocation correctly matches the function signature (batch, heads, head_kv, seq_len, dim_qk, dim_v), and the usage pattern aligns with the varlen implementation.
src/op/atomic_add.h (1)
80-81: LGTM!
The method declaration is consistent with the implementation pattern used by other operators (e.g., CopyNode::ComputeLinearLayout).
src/op/atomic_add.cc (3)
8-8: LGTM!
The include is required for accessing layout utility functions (makeGemmABLayoutHopper, makeLinearLayout, swizzle layout functions) used in the TMA atomic add implementation.
538-545: Warnings for unsupported layouts should consider failing compilation.
When an unsupported swizzle layout is detected (padded or unknown), the code logs a warning but proceeds with SWIZZLE_NONE. This could lead to incorrect results if the shared memory was actually written using a swizzled pattern. Consider whether this should be an error that fails compilation, or at minimum document this behavior so users understand the potential for silent correctness issues.
616-657: LGTM!
The TMA reduce implementation correctly handles both single-step and multi-step (split) invocation paths. The loop unrolling approach for split operations is appropriate for TMA, and the thread predicate guard ensures only one thread per block executes the TMA operation.
testing/python/language/test_tilelang_language_atomic_add.py (3)
2-4: LGTM!
The new imports are required for the TMA atomic add test functionality.
355-368: LGTM!
The test program correctly exercises both automatic and explicit swizzle layout paths for TMA atomic add operations.
370-374: Good addition of half-precision type coverage.
Testing float16 and bfloat16 alongside float32 ensures the atomic operations work correctly across common data types.
examples/KDA/test_utils_kda.py (1)
86-107: LGTM!
The benchmarking function correctly uses CUDA events for timing with proper warmup and synchronization. The debug print at line 107 may be intentional but could be removed or made optional if noisy output is undesirable.
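For context, a minimal sketch of the CUDA-event timing pattern this helper follows (illustrative only; the function name and warmup/repeat defaults are assumptions, not the helper's actual signature):

import torch

def bench_cuda(fn, warmup=25, rep=100):
    # Warm up so compilation and cache effects are excluded from the timing.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(rep)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(rep)]
    for i in range(rep):
        starts[i].record()
        fn()
        ends[i].record()
    # Events are recorded asynchronously on the stream; synchronize before reading.
    torch.cuda.synchronize()
    times_ms = [s.elapsed_time(e) for s, e in zip(starts, ends)]
    return sum(times_ms) / len(times_ms)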
examples/KDA/chunk_bwd_dv.py (2)
54-102: LGTM!
The TileLang kernel correctly implements the chunked backward computation for dv. The lower triangular mask application and pipelined GEMM operations are properly structured.
105-137: LGTM with minor note.
The test flow correctly compares the TileLang kernel against the reference implementation and benchmarks both. The scale parameter at line 111 is unused but may be kept for API consistency with related test functions.
examples/KDA/chunk_bwd_dqkwg.py (1)
264-288: LGTM!
The main() function correctly passes string dtype names, and the validation calls are properly enabled.
examples/KDA/wy_fast.py (1)
97-159: LGTM for kernel implementation.
The kernel correctly implements the forward computation for W, U, and optional KG with proper shared memory management and pipelining.
examples/KDA/chunk_delta_h_fwd.py (3)
272-283: LGTM!
The benchmark now correctly uses gk=G and use_exp2=True, matching the correctness test parameters.
115-194: LGTM for kernel implementation.
The kernel correctly implements the forward chunked gated delta rule with proper handling of initial/final states, pipelined iteration, and optional V_new output. The ceildiv usage at line 104 ensures proper tensor allocation.
295-316: LGTM!
The main() function correctly passes string dtype names, addressing the previous type mismatch issue.
examples/KDA/chunk_intra_token_parallel.py (2)
93-99: Index computation logic appears correct.
The chunk and sub-chunk index calculations (i_c, i_s, i_tc, i_ts, loops) correctly derive sequence positions for the token-parallel iteration pattern.
161-180: Pipelined loop and GEMM-like reduction logic looks correct.
The inner loop properly handles the Q-K gated interactions with exponential gating factors and accumulates results into Sum_Aqk and Sum_Akk. The conditional write at line 176-177 correctly avoids writing Akk for the diagonal element (j < bs).
examples/KDA/chunk_bwd_gla_dA.py (2)
93-101: Lower-triangular masking and GEMM logic is correct.
The kernel correctly computes dA by accumulating the GEMM of DO and V^T, then applies a lower-triangular mask (i_s1 >= i_s2) with scaling. The logic aligns with the backward pass semantics.
135-139: Correctness validation is enabled - good practice.
Unlike some other files in this PR, this file properly calls compare_tensors to validate the TileLang kernel output against the reference.
examples/KDA/chunk_delta_bwd.py (1)
177-218: Backward pass loop structure and gradient updates look correct.
The kernel correctly processes chunks in reverse order (i_s_inv), updates dv by accumulating with the existing value, and computes dh updates through GEMMs with Q, W, and gating factors. The conditional use_gk and use_initial_state branches are properly handled.
examples/KDA/chunk_inter_solve_fused.py (2)
400-404: Identity matrix addition for diagonal blocks.
The code correctly adds identity to diagonal blocks after forward substitution, which is standard for computing (I - L)^{-1} where L is strictly lower triangular.
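A small standalone PyTorch sketch of that identity (not the kernel code; the block size here is arbitrary):

import torch

BT = 8
# Strictly lower-triangular block, as produced by the masked scores.
L = torch.tril(torch.randn(BT, BT, dtype=torch.float64), diagonal=-1)
I = torch.eye(BT, dtype=torch.float64)

# X = (I - L)^{-1}, computed via a triangular solve of (I - L) X = I.
# The kernel instead builds the strictly-lower part by forward substitution and
# then adds I, which yields the same matrix because (I - L)^{-1} has a unit diagonal.
X = torch.linalg.solve_triangular(I - L, I, upper=False)
assert torch.allclose(X, torch.inverse(I - L))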
558-559: Validation is enabled - good.
The file properly validates both Aqk and Akk outputs against the reference implementation.
examples/KDA/wy_fast_bwd.py (3)
183-212: dk update loop logic appears correct.
The kernel properly computes dk updates by:
- Loading K, dk_old, dg_old, and GK
- Computing gated K (K_shared_beta_g = K * Beta * exp2(GK))
- Accumulating GEMM results for dA and dk
- Computing dbeta contribution via reduction
- Updating dg
The accumulation pattern and gating logic align with the WY representation backward pass.
233-241: dA correction logic is correct.
The dA correction computes -(A^T @ tril(dA) @ A), which matches the mathematical formulation for the WY backward pass. The lower-triangular masking (i_s1 > i_s2) correctly excludes the diagonal.
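A plain-PyTorch restatement of that correction for a single block (shapes and variable names here are illustrative assumptions, not the kernel's buffers):

import torch

BT = 16
A = torch.randn(BT, BT)
dA = torch.randn(BT, BT)

# Strictly lower-triangular mask (i_s1 > i_s2), matching the kernel's masking.
dA_strict = torch.tril(dA, diagonal=-1)
# dA correction: -(A^T @ tril(dA) @ A)
dA_corr = -(A.transpose(-1, -2) @ dA_strict @ A)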
310-314: Validation is now enabled - good.
The file properly calls compare_tensors for all five outputs (dA, dk, dv, dbeta, dg).
examples/KDA/chunk_bwd_intra.py (2)
51-67: LGTM with minor note.
The function correctly prepares output tensors. Same note as prepare_input regarding unused parameters (chunk_size, state_dtype) - consider documenting if kept for API consistency.
69-78: LGTM!
Clean configuration generation for autotuning.
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
| # fla ref | ||
| print("fla running...", flush=True) | ||
| if use_gk: | ||
| dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu( | ||
| q=Q, k=K, w=W, do=dO, dv=dv, gk=G, h0=h0, dht=dht, scale=scale, use_exp2=True | ||
| ) | ||
Reference computation only runs when use_gk=True.
The reference variables dh_ref, dh0_ref, dv2_ref are only assigned when use_gk=True (lines 264-267), but compare_tensors at lines 298-300 unconditionally uses them. If use_gk=False, this will raise a NameError.
🐛 Proposed fix
compare_tensors("dh", dh_ref, dh_tilelang)
compare_tensors("dh0", dh0_ref, dh0_tilelang)
compare_tensors("dv2", dv2_ref, dv2_tilelang)
+ # Note: These comparisons assume use_gk=True; wrap in conditional if use_gk can be False
Or wrap in conditional:
+ if use_gk:
compare_tensors("dh", dh_ref, dh_tilelang)
compare_tensors("dh0", dh0_ref, dh0_tilelang)
compare_tensors("dv2", dv2_ref, dv2_tilelang)🤖 Prompt for AI Agents
In `@examples/KDA/chunk_delta_bwd.py` around lines 262 - 268, The reference
tensors dh_ref, dh0_ref, dv2_ref are only produced when use_gk is True but
compare_tensors is called unconditionally; either compute the reference via
chunk_gated_delta_rule_bwd_dhu when use_gk is False (call the same function path
or its non-gk equivalent) or wrap the compare_tensors calls in an if use_gk
block so they only run when dh_ref/dh0_ref/dv2_ref exist; update the logic
around the chunk_gated_delta_rule_bwd_dhu call and the subsequent
compare_tensors calls to ensure the reference variables are always defined
before use or comparisons are conditional on use_gk.
| block_S = BS = chunk_size | ||
| BC = sub_chunk_size |
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check if chunk_size / sub_chunk_size ratio is documented or enforced elsewhere
rg -n "chunk_size.*sub_chunk_size|BC.*4|4.*BC" examples/KDA/Repository: tile-ai/tilelang
Length of output: 1998
🏁 Script executed:
#!/bin/bash
# Read the file to understand the kernel logic and find any constraints
wc -l examples/KDA/chunk_inter_solve_fused.pyRepository: tile-ai/tilelang
Length of output: 103
🏁 Script executed:
#!/bin/bash
# Examine the file around lines 81-82 and look for constraints
sed -n '70,150p' examples/KDA/chunk_inter_solve_fused.py | cat -nRepository: tile-ai/tilelang
Length of output: 4457
🏁 Script executed:
#!/bin/bash
# Search for assertions or validation related to chunk_size and sub_chunk_size
rg -n "assert|chunk_size|sub_chunk_size" examples/KDA/chunk_inter_solve_fused.py | head -40Repository: tile-ai/tilelang
Length of output: 989
🏁 Script executed:
#!/bin/bash
# Look for the test or main section to see what chunk sizes are used
rg -n "chunk_size.*=|sub_chunk_size.*=" examples/KDA/chunk_inter_solve_fused.pyRepository: tile-ai/tilelang
Length of output: 344
🏁 Script executed:
#!/bin/bash
# Look at the kernel's main loop structure around line 380 and 503 for context
sed -n '375,385p' examples/KDA/chunk_inter_solve_fused.py | cat -nRepository: tile-ai/tilelang
Length of output: 916
🏁 Script executed:
#!/bin/bash
# Check line 503 context
sed -n '500,510p' examples/KDA/chunk_inter_solve_fused.py | cat -nRepository: tile-ai/tilelang
Length of output: 502
🏁 Script executed:
#!/bin/bash
# Look for any documentation comments in the file about constraints
rg -n "chunk.*4|4.*chunk|sub.chunk.*ratio" examples/KDA/chunk_inter_solve_fused.pyRepository: tile-ai/tilelang
Length of output: 86
🏁 Script executed:
#!/bin/bash
# Check if there are similar patterns in other KDA files to understand if this is a known constraint
rg -n "chunk_size.*4|4.*sub_chunk" examples/KDA/ -A 2 -B 2Repository: tile-ai/tilelang
Length of output: 7682
Add assertion to enforce chunk_size = 4 * sub_chunk_size constraint.
The kernel hardcodes logic for exactly 4 sub-chunks (see line 380's T.min(4 * BC, ...) and line 503's 3 * BC : 4 * BC indexing). This requires chunk_size = 4 * sub_chunk_size, but there is no validation to enforce it. Passing incompatible sizes (e.g., chunk_size=64, sub_chunk_size=32) would silently produce incorrect results. Add an assertion at the start of the kernel function to validate this constraint.
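A minimal sketch of such a guard (hypothetical helper name and message; the exact placement inside the kernel factory is up to the authors):

def check_chunk_sizes(chunk_size: int, sub_chunk_size: int) -> None:
    # The kernel unrolls exactly 4 sub-chunks per chunk, so any other ratio
    # would silently produce wrong results.
    if chunk_size != 4 * sub_chunk_size:
        raise ValueError(
            f"chunk_size ({chunk_size}) must equal 4 * sub_chunk_size "
            f"({sub_chunk_size}) for this kernel"
        )

check_chunk_sizes(64, 16)   # passes
# check_chunk_sizes(64, 32) # raises ValueError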
🤖 Prompt for AI Agents
In `@examples/KDA/chunk_inter_solve_fused.py` around lines 81 - 82, Add a runtime
check at the start of the kernel function to enforce the relationship between
chunk_size and sub_chunk_size used elsewhere: assert chunk_size == 4 *
sub_chunk_size (or raise a ValueError with a clear message). Locate where
block_S = BS = chunk_size and BC = sub_chunk_size are set (symbols block_S, BS,
BC, chunk_size, sub_chunk_size) and insert the assertion immediately after those
assignments to prevent silent incorrect computation when the kernel assumes
exactly 4 sub-chunks.
| int dim = shared_tensor->shape.size(); | ||
| const int64_t mat_stride = *as_const_int(shared_tensor->shape[dim - 2]); | ||
| const int64_t mat_continuous = | ||
| *as_const_int(shared_tensor->shape[dim - 1]); | ||
| Layout swizzle_layout = | ||
| makeGemmABLayoutHopper(mat_stride, mat_continuous, mat_continuous, | ||
| shared_tensor->dtype.bits(), /*k_inner=*/true); |
Potential null pointer dereference if tensor dimensions are not constants.
as_const_int returns nullptr when the expression is not a compile-time constant. Dereferencing without a null check could cause a segmentation fault.
🐛 Proposed fix with null checks
// TMA atomic add is similar to TMA Store - we should perform swizzle if
// possible Use the last two dimensions to analyze swizzling
int dim = shared_tensor->shape.size();
- const int64_t mat_stride = *as_const_int(shared_tensor->shape[dim - 2]);
- const int64_t mat_continuous =
- *as_const_int(shared_tensor->shape[dim - 1]);
+ auto mat_stride_ptr = as_const_int(shared_tensor->shape[dim - 2]);
+ auto mat_continuous_ptr = as_const_int(shared_tensor->shape[dim - 1]);
+ if (mat_stride_ptr == nullptr || mat_continuous_ptr == nullptr) {
+ // Cannot determine swizzle layout for non-constant dimensions
+ result_map.Set(shared_tensor, ComputeLinearLayout(shared_tensor));
+ return result_map;
+ }
+ const int64_t mat_stride = *mat_stride_ptr;
+ const int64_t mat_continuous = *mat_continuous_ptr;
Layout swizzle_layout =
makeGemmABLayoutHopper(mat_stride, mat_continuous, mat_continuous,
shared_tensor->dtype.bits(), /*k_inner=*/true);🤖 Prompt for AI Agents
In `@src/op/atomic_add.cc` around lines 373 - 379, The code dereferences
as_const_int(...) results for mat_stride and mat_continuous without checking for
nullptr, which can crash if shared_tensor->shape elements are not compile-time
constants; update the block using as_const_int(shared_tensor->shape[dim - 2])
and as_const_int(shared_tensor->shape[dim - 1]) to first capture each result
into a local pointer, verify they are non-null (and handle the non-constant case
by returning an error, throwing, or using a safe fallback), and only then assign
mat_stride/mat_continuous and call makeGemmABLayoutHopper; ensure error handling
is consistent with surrounding code paths that construct Layout when shape
values are unknown.
- Updated the TVM submodule to commit 354eef9a.
- Removed several outdated KDA example files and utility scripts that are no longer in use, including chunk_bwd_dqkwg.py, chunk_bwd_dv.py, chunk_bwd_gla_dA.py, chunk_bwd_intra.py, chunk_delta_bwd.py, chunk_delta_h_fwd.py, chunk_inter_solve_fused.py, chunk_intra_token_parallel.py, chunk_o.py, README.md, test_utils_kda.py, wy_fast_bwd.py, wy_fast.py, and various FLA_KDA implementations.
Actionable comments posted: 8
🤖 Fix all issues with AI agents
In `@examples/kda/chunk_bwd_dv.py`:
- Line 137: The inline comment for the constant S is incorrect: the expression S
= 1024 * 8 evaluates to 8192, not 32768; update either the numeric expression or
the comment. Locate the S definition (symbol "S") and either change the
multiplication to match 32768 (e.g., 1024 * 32) if 32768 was intended, or update
the trailing comment to "# 8192" if the current value is correct; make sure the
code and comment are consistent.
In `@examples/kda/chunk_bwd_gla_dA.py`:
- Line 134: The inline comment for the constant S is incorrect: the expression S
= 1024 * 8 evaluates to 8192, not 32768; update the comment next to the S
assignment in examples/kda/chunk_bwd_gla_dA.py (the S variable declaration) to
reflect the correct value (8192) or remove the misleading comment, and ensure
consistency with the same fix applied to the S comment in chunk_bwd_dv.py if
present.
In `@examples/kda/FLA_KDA/fla_chunk_delta.py`:
- Around line 536-543: The function currently hard-codes BT = 64 while accepting
a chunk_size parameter, causing mismatches when callers pass a different
chunk_size; either bind BT to the parameter or enforce chunk_size==64. Fix by
replacing the hard-coded BT assignment with BT = chunk_size (or add an assertion
like assert chunk_size == 64 with a clear message) and ensure any use of
chunk_indices and downstream offsets relies on BT (the same value) so the kernel
and Python-side chunking remain consistent; update references to BT/chunk_size
in this function (and any callers using chunk_indices) accordingly.
- Around line 524-535: The parameter annotation for scale in
chunk_gated_delta_rule_bwd_dhu is incorrect: it has a default of None but is
typed as float; change the annotation to Optional[float] and add the
corresponding import from typing (e.g., from typing import Optional) at the top
of the module so the function signature reads scale: Optional[float] = None;
keep other parameters unchanged and ensure any linters/type checkers accept the
Optional import.
In `@examples/kda/FLA_KDA/fla_chunk_intra_token_parallel.py`:
- Around line 114-125: The function chunk_kda_fwd_intra_token_parallel currently
declares a None return but actually returns Aqk and Akk; update the signature to
reflect the real return type (e.g., -> Tuple[torch.Tensor, torch.Tensor]) and
add the necessary typing import (from typing import Tuple) or use a
forward-compatible annotation (e.g., -> tuple[torch.Tensor, torch.Tensor] for
Py3.9+), ensuring the function name chunk_kda_fwd_intra_token_parallel and its
callers are consistent with the corrected return type.
In `@examples/kda/FLA_KDA/fla_chunk_intra.py`:
- Around line 10-16: The code currently sets SOLVE_TRIL_DOT_PRECISION based on
IS_TF32_SUPPORTED (the conditional block setting tl.constexpr("tf32x3") or
tl.constexpr("ieee")) and then immediately overwrites it with
SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32"), which defeats the conditional;
remove the unconditional override line so SOLVE_TRIL_DOT_PRECISION respects
IS_TF32_SUPPORTED (or conversely delete the conditional and keep the hardcoded
tl.constexpr("tf32") line if tf32 is intentionally fixed)—update references to
SOLVE_TRIL_DOT_PRECISION used by tl.dot operations accordingly.
In `@examples/kda/FLA_KDA/fla_utils.py`:
- Around line 16-25: The module currently calls CUDA APIs at import time which
can raise on CPU-only hosts; update the initialization of device,
device_torch_lib, IS_NVIDIA_HOPPER, and USE_CUDA_GRAPH to first guard with
torch.cuda.is_available() (or wrap torch.cuda calls in try/except) and provide
safe CPU fallbacks (e.g., device="cpu", device_torch_lib=torch) and default
IS_NVIDIA_HOPPER=False when CUDA is unavailable or the queries fail; ensure
USE_CUDA_GRAPH still reads the env var but only enables when CUDA is available.
In `@examples/kda/wy_fast_bwd.py`:
- Around line 55-60: The preallocated dA uses DK but the kernel writes (B, S, H,
chunk_size); change the dA allocation in prepare_output (the line creating dA)
to use chunk_size instead of DK so dA is torch.empty(B, S, H, chunk_size,
dtype=output_dtype).cuda(); keep the other tensors and dtypes unchanged and
ensure any downstream code expecting dA uses that chunk_size shape.
♻️ Duplicate comments (22)
examples/kda/test_utils_kda.py (2)
10-17: Fix tensor-to-Python control flow issues in calc_sim().
This function has multiple issues that were flagged in a previous review:
- .data is deprecated; use .detach() instead
- denominator == 0 comparison on a tensor is ambiguous
- Returns a tensor instead of a Python float, causing issues in downstream comparisons
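A minimal sketch of a corrected helper along those lines (illustrative only; it assumes a cosine-style similarity and does not claim to match the original calc_sim beyond its two tensor arguments):

import torch

def calc_sim(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-12) -> float:
    # Use detach() and flatten instead of touching .data.
    x = x.detach().flatten().double()
    y = y.detach().flatten().double()
    denominator = (x.norm() * y.norm()).item()
    # Compare a Python scalar against a tolerance rather than testing
    # `denominator == 0` on a tensor, and return a Python float.
    if denominator < eps:
        return 1.0 if x.abs().max().item() < eps and y.abs().max().item() < eps else 0.0
    return torch.dot(x, y).item() / denominator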
27-30: Inverted mask logic when comparing non-finite values.
The mask logic is inverted: x_mask is True for finite values, so masked_fill(x_mask, 0) zeros out the finite positions and leaves non-finite values for comparison. This should use ~x_mask to mask out the non-finite values and compare only finite ones (or vice versa depending on intent).
examples/kda/chunk_delta_bwd.py (2)
117-132: Unused h0 parameter in kernel signature.
The kernel declares h0: T.Tensor(h0_shape, dtype=input_dtype) but never uses it. This was flagged in a previous review. Either implement the initial state handling logic or remove the parameter.
244-280: NameError when use_gk=False: reference variables undefined.
The reference tensors dh_ref, dh0_ref, dv2_ref are only assigned when use_gk=True (lines 244-247), but compare_tensors at lines 278-280 uses them unconditionally. If use_gk=False, this will raise a NameError.
examples/kda/chunk_bwd_gla_dA.py (1)
79-80: Potential dtype mismatch for V_shared.
V_shared is allocated with do_dtype but is used to copy from V which has input_dtype. This was flagged in a previous review.
examples/kda/FLA_KDA/cumsum.py (2)
176-177: Remove redundant if i_c >= 0.
i_c starts at 0 and is always non-negative, so the branch is dead code.
♻️ Suggested refactor
- if i_c >= 0:
-     b_z += b_ss
+ b_z += b_ss
430-440: Tighten chunk_local_cumsum API and error message.
**kwargs is unused (silently ignores unexpected args) and the error message omits the 3D scalar shape.
🧹 Suggested fix
def chunk_local_cumsum( g: torch.Tensor, chunk_size: int, reverse: bool = False, scale: float = None, cu_seqlens: torch.Tensor = None, head_first: bool = False, output_dtype: torch.dtype = torch.float, chunk_indices: torch.LongTensor = None, - **kwargs, ) -> torch.Tensor: @@ else: raise ValueError( f"Unsupported input shape {g.shape}, " - f"which should be (B, T, H, D) if `head_first=False` or (B, H, T, D) otherwise", + f"which should be [B, T, H]/[B, T, H, D] if `head_first=False` " + f"or [B, H, T]/[B, H, T, D] otherwise", )Also applies to: 467-469
examples/kda/wy_fast.py (3)
53-94: use_qg/QG are declared but never written.
QG remains uninitialized even when returned. Either remove the flag/output or implement the QG write path.
203-204: Fix benchmark label typo.
✅ Suggested fix
- print("tritron time:", triton_time) + print("triton time:", triton_time)
219-222: Pass dtype names as strings in main().
run_test uses getattr(torch, input_dtype), so T.bfloat16/T.float32 will raise AttributeError.
🐛 Suggested fix
- input_dtype=T.bfloat16, - output_dtype=T.bfloat16, - gate_dtype=T.float32, - accum_dtype=T.float32, + input_dtype="bfloat16", + output_dtype="bfloat16", + gate_dtype="float32", + accum_dtype="float32",#!/bin/bash # Verify run_test expects string dtype names rg -n "getattr\\(torch, input_dtype\\)|input_dtype=T" examples/kda/wy_fast.pyexamples/kda/FLA_KDA/fla_chunk_o.py (2)
398-405: Initialize b_A for the USE_A=False path.
b_A is used unconditionally but only assigned when USE_A is true, which will crash at runtime on the false path.
🐛 Suggested fix
if USE_A: p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)) b_A = tl.load(p_A, boundary_check=(0, 1)) + else: + b_A = tl.zeros([BT, BT], dtype=tl.float32)
432-435: Call check_shared_mem in the elif branch.
The current branch tests the function object instead of invoking it, so the elif always evaluates truthy when the first condition fails.
🔧 Suggested fix
- elif check_shared_mem:
+ elif check_shared_mem("ampere", k.device.index):
      CONST_TILING = 64
examples/kda/chunk_intra_token_parallel.py (1)
176-254: TileLang path ignores scale and outputs aren't validated.
scale is passed to the reference but never to the TileLang kernel, and the TileLang outputs are not compared against references.
✅ Suggested fix (validation)
-from test_utils_kda import do_bench +from test_utils_kda import do_bench, compare_tensors @@ Aqk_tilelang, Akk_tilelang = kernel( q, k, gk, beta, ) @@ + compare_tensors("Aqk", Aqk_ref, Aqk_tilelang) + compare_tensors("Akk", Akk_ref, Akk_tilelang)Please also thread
scaleintotilelang_chunk_kda_fwd_intra_token_paralleland apply it consistently with the reference kernel.examples/kda/chunk_delta_h_fwd.py (1)
167-169: Clamp GK last-token index for partial chunks.
When S isn't divisible by block_S, (i_s + 1) * block_S - 1 can exceed S - 1.
🐛 Suggested fix
- if use_gk:
-     T.copy(GK[bb, (i_s + 1) * block_S - 1, bh, :], GK_last_shared)  # block last token
+ if use_gk:
+     last_idx = T.min((i_s + 1) * block_S - 1, S - 1)
+     T.copy(GK[bb, last_idx, bh, :], GK_last_shared)  # block last token
69-70: Enforce kernel size invariants.
The kernel hardcodes 4 sub-chunks per chunk and later loads diagonals unconditionally; without explicit checks, incompatible sizes can cause incorrect results or OOB access.
✅ Suggested fix
block_S = BS = chunk_size BC = sub_chunk_size + assert chunk_size == 4 * sub_chunk_size, "chunk_size must be 4 * sub_chunk_size" + assert S % chunk_size == 0, "S must be divisible by chunk_size"examples/kda/wy_fast_bwd.py (2)
289-302: Convert dtype strings to TileLang dtypes before kernel construction.
run_test uses string dtypes for torch conversion, but those strings are passed directly to tilelang_wy_fast_bwd, which expects TileLang dtypes in T.Tensor(..., dtype=...). This mismatch will fail unless TileLang explicitly accepts strings here.
🐛 Suggested fix
+ tl_input_dtype = getattr(T, input_dtype) if isinstance(input_dtype, str) else input_dtype + tl_output_dtype = getattr(T, output_dtype) if isinstance(output_dtype, str) else output_dtype + tl_accum_dtype = getattr(T, accum_dtype) if isinstance(accum_dtype, str) else accum_dtype + tl_gate_dtype = getattr(T, gate_dtype) if isinstance(gate_dtype, str) else gate_dtype + tl_state_dtype = getattr(T, state_dtype) if isinstance(state_dtype, str) else state_dtype kernel = tilelang_wy_fast_bwd( B, S, H, DK, DV, - input_dtype, - output_dtype, - accum_dtype, - gate_dtype, - state_dtype, + tl_input_dtype, + tl_output_dtype, + tl_accum_dtype, + tl_gate_dtype, + tl_state_dtype, chunk_size, )
327-340: Fix dtype arguments in main() to match run_test.
run_test uses getattr(torch, input_dtype) and expects strings like "float32". Passing T.float32 will raise AttributeError.
🐛 Suggested fix
- input_dtype=T.float32, - output_dtype=T.float32, - accum_dtype=T.float32, - gate_dtype=T.float32, - state_dtype=T.float32, + input_dtype="float32", + output_dtype="float32", + accum_dtype="float32", + gate_dtype="float32", + state_dtype="float32",examples/kda/FLA_KDA/fla_wy_fast.py (1)
269-272: Keep backward BT consistent with forward tiling.
prepare_wy_repr_bwd hardcodes BT = 64 while the forward path uses BT = A.shape[-1]. If chunk size changes, backward will be incorrect.
🐛 Suggested fix
- BT = 64 + BT = A.shape[-1]examples/kda/chunk_bwd_dqkwg.py (2)
121-122: Use gate_dtype for gate buffers.
G_shared and Gn_shared store gate values but are allocated with input_dtype, which can silently reduce precision.
🐛 Suggested fix
- G_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) # chunk G - Gn_shared = T.alloc_shared((block_DK), dtype=input_dtype) # chunk last token G + G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype) # chunk G + Gn_shared = T.alloc_shared((block_DK,), dtype=gate_dtype) # chunk last token G
37-50: Honor qk_dtype for dq/dk outputs.
qk_dtype is accepted by run_test but dq/dk are always float32. Either remove the parameter or use it consistently.
♻️ Suggested fix
-def prepare_output( +def prepare_output( B, S, H, DK, DV, chunk_size, - gate_dtype, + qk_dtype, + gate_dtype, ): - dq = torch.randn(B, S, H, DK, dtype=torch.float32).cuda() - dk = torch.randn(B, S, H, DK, dtype=torch.float32).cuda() + dq = torch.randn(B, S, H, DK, dtype=qk_dtype).cuda() + dk = torch.randn(B, S, H, DK, dtype=qk_dtype).cuda() dw = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() return dq, dk, dw, dg- dq, dk, dw, dg = prepare_output(B, S, H, DK, DV, chunk_size, getattr(torch, gate_dtype)) + dq, dk, dw, dg = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, qk_dtype), getattr(torch, gate_dtype) + )examples/kda/chunk_bwd_intra.py (1)
474-486: Fix dtype arguments in main() to match run_test.
run_test expects strings for getattr(torch, ...); passing T.float32 raises AttributeError.
🐛 Suggested fix
- input_dtype=T.float32, - output_dtype=T.float32, - accum_dtype=T.float32, - gate_dtype=T.float32, - state_dtype=T.float32, + input_dtype="float32", + output_dtype="float32", + accum_dtype="float32", + gate_dtype="float32", + state_dtype="float32",examples/kda/FLA_KDA/fla_chunk_delta.py (1)
487-492: Initialize N/NT/chunk_offsets for varlen inputs (still uninitialized).
When cu_seqlens is provided, the cu_seqlens is None branch is skipped, so N, NT, and chunk_offsets are never set before allocation or kernel launch. This still triggers UnboundLocalError and/or invalid kernel arguments.
🐛 Suggested fix pattern (apply in both forward and backward)
if cu_seqlens is None: N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N = len(cu_seqlens) - 1 + NT = len(chunk_indices) + # derive per-sequence chunk offsets from chunk_indices + counts = torch.bincount(chunk_indices[:, 0], minlength=N) + chunk_offsets = torch.zeros(N, device=k.device, dtype=torch.int32) + chunk_offsets[1:] = torch.cumsum(counts, 0)[:-1]Also applies to: 545-548
🧹 Nitpick comments (18)
examples/kda/FLA_KDA/fla_utils.py (4)
33-55: Replace the fullwidth comma in the comment.
Ruff flags the fullwidth comma character; use a standard comma to avoid lint churn.
✏️ Lint-friendly comment
-# error check,copy from +# error check, copy from
125-152: Don't swallow driver errors silently in get_multiprocessor_count.
Catching Exception and pass can hide real driver/API failures and make debugging tough. Prefer narrower exceptions and/or a warning.
🛠️ Example with warning
- except Exception: - pass + except Exception as e: + warnings.warn(f"Failed to read multiprocessor count (new API): {e}", stacklevel=2)- except Exception: - pass + except Exception as e: + warnings.warn(f"Failed to read multiprocessor count (legacy API): {e}", stacklevel=2)
155-186: Guard custom_device_ctx usage to CUDA tensors only.
input_guard currently uses a CUDA device context for any tensor type; if a CPU tensor slips in, this can create confusing failures. Consider checking tensor.is_cuda before entering a CUDA context.
🧭 Safer device-context selection
- if tensor is not None: - ctx = custom_device_ctx(tensor.device.index) - else: - ctx = contextlib.nullcontext() + if tensor is not None and tensor.is_cuda: + ctx = custom_device_ctx(tensor.device.index) + else: + ctx = contextlib.nullcontext()
224-240: Handle shared-memory query failures more explicitly.
BaseException/Exception with silent fallbacks can mask real issues (driver misconfig, bad indices). Prefer narrower exceptions and a warning.
🛠️ Narrow exceptions + warning
- except BaseException: - return [-1] + except (RuntimeError, AttributeError, KeyError, IndexError) as e: + warnings.warn(f"Failed to query max shared mem: {e}", stacklevel=2) + return [-1]- except Exception: - return False + except (RuntimeError, AttributeError, KeyError, IndexError) as e: + warnings.warn(f"Failed to check shared mem: {e}", stacklevel=2) + return Falseexamples/kda/test_utils_kda.py (2)
43-84: Remove or use the atol and rtol parameters.
The compare_tensors function accepts atol and rtol parameters but never uses them. Either implement tolerance-based comparison or remove these parameters to avoid misleading callers.
♻️ Proposed fix to use the tolerance parameters
def compare_tensors(name, x, y, atol=1e-5, rtol=1e-5): import numpy as np import torch diff = (x - y).abs() + + # Check if tensors are close within tolerance + is_close = torch.allclose(x, y, atol=atol, rtol=rtol) + print(f"Tensors close (atol={atol}, rtol={rtol}): {is_close}") # ========= Max Absolute Error =========
107-108: Remove debug print(times) statement.
This debug statement prints the raw timing tensor before returning the mean. Consider removing it or gating it behind a verbose parameter to reduce noise in production benchmarks.
♻️ Proposed fix
- print(times) return times.mean().item()examples/kda/FLA_KDA/fla_chunk_intra_token_parallel.py (1)
70-71: Translate Chinese comments for consistency.
Consider translating these comments to English for broader accessibility:
- "chunk 首坐标" → "chunk start coordinate"
- "subchunk 首坐标" → "subchunk start coordinate"
♻️ Proposed fix
- i_tc = i_c * BT # chunk 首坐标 - i_ts = i_tc + i_s * BC # subchunk 首坐标 + i_tc = i_c * BT # chunk start coordinate + i_ts = i_tc + i_s * BC # subchunk start coordinateexamples/kda/chunk_bwd_dv.py (3)
4-4: Remove unused sys import.
The sys module is imported but never used. The noqa: F401 directive suppresses the warning but doesn't address the underlying issue.
♻️ Proposed fix
-import sys # noqa: F401
114-125: Dead code: prepare_output result is immediately overwritten.
The output tensor dv_tilelang allocated at line 114 is never used because the kernel at line 125 returns a new tensor that overwrites it. Either remove the prepare_output call or use it for in-place kernel execution.
♻️ Proposed fix
- dv_tilelang = prepare_output(B, S, H, DV, chunk_size, getattr(torch, output_dtype)) kernel = tilelang_chunk_bwd_kernel_dv_local( B=B, S=S, H=H, DV=DV, input_dtype=input_dtype, output_dtype=output_dtype, do_dtype=do_dtype, chunk_size=chunk_size, ) dv_tilelang = kernel(DO, A)
31-40: Remove unused chunk_size parameter from prepare_output.
The chunk_size parameter is accepted but never used in the function body. Remove it to clean up the API.
♻️ Proposed fix
def prepare_output( B, S, H, DV, - chunk_size, output_dtype, ): dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() return dvexamples/kda/chunk_delta_bwd.py (1)
219-222: Unused parameters in run_test function.
The parameters block_DV, threads, num_stages, and use_torch are accepted but never used. These appear to be remnants of manual tuning that's now handled by autotune. Consider removing them or passing them to override autotuning.
examples/kda/chunk_bwd_gla_dA.py (2)
89-89: Translate Chinese comment for consistency.
Consider translating "下三角矩阵" to "lower triangular matrix" for broader accessibility.
♻️ Proposed fix
- dA_fragment[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, dA_fragment[i_s1, i_s2] * scale, 0.0) # 下三角矩阵 + dA_fragment[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, dA_fragment[i_s1, i_s2] * scale, 0.0) # lower triangular matrix
107-108: Remove debug print.
This debug print for dtypes should be removed or gated behind a verbose flag.
♻️ Proposed fix
DO, V_new = prepare_input(B, S, H, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, do_dtype)) - print(DO.dtype, V_new.dtype) dA_ref = chunk_gla_bwd_dA(v=V_new, do=DO, scale=scale)examples/kda/FLA_KDA/fla_chunk_inter.py (2)
151-154: Use explicit Optional type annotation for nullable parameters.
Per PEP 484, scale: float = None should be scale: Optional[float] = None (or scale: float | None = None in Python 3.10+) to explicitly indicate the parameter can be None.
♻️ Proposed fix
+from typing import Optional + def chunk_kda_bwd_dqkwg( q: torch.Tensor, k: torch.Tensor, w: torch.Tensor, v: torch.Tensor, h: torch.Tensor, g: torch.Tensor, do: torch.Tensor, dh: torch.Tensor, dv: torch.Tensor, - scale: float = None, - cu_seqlens: torch.LongTensor = None, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, chunk_size: int = 64, - chunk_indices: torch.LongTensor = None, + chunk_indices: Optional[torch.LongTensor] = None, ):
163-166: Intentional dtype difference for gradient outputs.
The code allocates dq and dk with dtype=torch.float while dw and dg use empty_like (inheriting the input dtype). This appears intentional for numerical stability in gradient accumulation, but consider adding a comment to clarify the rationale.
♻️ Proposed documentation
+ # dq, dk need float32 for gradient accumulation stability dq = torch.empty_like(q, dtype=torch.float) dk = torch.empty_like(k, dtype=torch.float) dw = torch.empty_like(w) dg = torch.empty_like(g)examples/kda/wy_fast_bwd.py (1)
241-257: Kernel config args are accepted but never used.
block_DK, block_DV, threads, and num_stages are passed into run_test but not forwarded to the kernel, so caller intent is ignored.
♻️ Suggested fix
kernel = tilelang_wy_fast_bwd( B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, gate_dtype, state_dtype, chunk_size, + block_DK=block_DK, + block_DV=block_DV, + threads=threads, + num_stages=num_stages, )examples/kda/chunk_o.py (1)
169-213: Forward kernel config params to tilelang_chunk_fwd_o.
block_DK, block_DV, threads, and num_stages are accepted but never used, so the benchmark ignores caller-provided configs.
♻️ Suggested fix
kernel = tilelang_chunk_fwd_o( B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, gate_dtype, chunk_size, scale, block_S, + block_DK=block_DK, + block_DV=block_DV, + threads=threads, + num_stages=num_stages, )examples/kda/chunk_bwd_intra.py (1)
381-438: Forward kernel config args to tilelang_chunk_bwd_intra.
threads and num_stages are accepted but unused, so the test ignores the caller's tuning choices.
♻️ Suggested fix
kernel = tilelang_chunk_bwd_intra( B=B, S=S, H=H, DK=DK, input_dtype=input_dtype, output_dtype=output_dtype, accum_dtype=accum_dtype, gate_dtype=gate_dtype, state_dtype=state_dtype, chunk_size=chunk_size, block_DK=block_DK, + threads=threads, + num_stages=num_stages, )
| def main(): | ||
| run_test( | ||
| B=1, | ||
| S=1024 * 8, # 32768 |
Incorrect comment: 1024 × 8 = 8192, not 32768.
The inline comment states # 32768 but 1024 * 8 = 8192. Either fix the comment or update the value if 32768 was intended.
🐛 Proposed fix (if 8192 is intended)
- S=1024 * 8, # 32768
+ S=1024 * 8, # 8192🐛 Proposed fix (if 32768 is intended)
- S=1024 * 8, # 32768
+ S=1024 * 32, # 32768📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| S=1024 * 8, # 32768 | |
| S=1024 * 8, # 8192 |
| S=1024 * 8, # 32768 | |
| S=1024 * 32, # 32768 |
🤖 Prompt for AI Agents
In `@examples/kda/chunk_bwd_dv.py` at line 137, The inline comment for the
constant S is incorrect: the expression S = 1024 * 8 evaluates to 8192, not
32768; update either the numeric expression or the comment. Locate the S
definition (symbol "S") and either change the multiplication to match 32768
(e.g., 1024 * 32) if 32768 was intended, or update the trailing comment to "#
8192" if the current value is correct; make sure the code and comment are
consistent.
| def main(): | ||
| run_test( | ||
| B=1, | ||
| S=1024 * 8, # 32768 |
Incorrect comment: 1024 × 8 = 8192, not 32768.
Same issue as in chunk_bwd_dv.py: the comment states # 32768 but the expression evaluates to 8192.
🐛 Proposed fix
- S=1024 * 8, # 32768
+ S=1024 * 8, # 8192📝 Committable suggestion
| S=1024 * 8, # 32768 | |
| S=1024 * 8, # 8192 |
🤖 Prompt for AI Agents
In `@examples/kda/chunk_bwd_gla_dA.py` at line 134, The inline comment for the
constant S is incorrect: the expression S = 1024 * 8 evaluates to 8192, not
32768; update the comment next to the S assignment in
examples/kda/chunk_bwd_gla_dA.py (the S variable declaration) to reflect the
correct value (8192) or remove the misleading comment, and ensure consistency
with the same fix applied to the S comment in chunk_bwd_dv.py if present.
| def chunk_gated_delta_rule_bwd_dhu( | ||
| q: torch.Tensor, | ||
| k: torch.Tensor, | ||
| w: torch.Tensor, | ||
| do: torch.Tensor, | ||
| dv: torch.Tensor, | ||
| g: torch.Tensor = None, | ||
| gk: torch.Tensor = None, | ||
| h0: torch.Tensor = None, | ||
| dht: torch.Tensor = None, | ||
| scale: float = None, | ||
| cu_seqlens: torch.LongTensor = None, |
🧩 Analysis chain
🏁 Script executed:
find . -name "fla_chunk_delta.py" -type fRepository: tile-ai/tilelang
Length of output: 101
🏁 Script executed:
# Once found, let's read the relevant lines
head -n 550 examples/kda/FLA_KDA/fla_chunk_delta.py | tail -n 50Repository: tile-ai/tilelang
Length of output: 1493
🏁 Script executed:
# Check the imports at the top of the file
head -n 50 examples/kda/FLA_KDA/fla_chunk_delta.pyRepository: tile-ai/tilelang
Length of output: 1318
Use Optional[float] for scale's default None (RUF013).
The scale parameter has a default value of None but is typed as float. Update the type annotation to Optional[float] and add the import from typing.
Proposed fix
-import torch
+import torch
+from typing import Optional
@@
- scale: float = None,
+ scale: Optional[float] = None,📝 Committable suggestion
| def chunk_gated_delta_rule_bwd_dhu( | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| w: torch.Tensor, | |
| do: torch.Tensor, | |
| dv: torch.Tensor, | |
| g: torch.Tensor = None, | |
| gk: torch.Tensor = None, | |
| h0: torch.Tensor = None, | |
| dht: torch.Tensor = None, | |
| scale: float = None, | |
| cu_seqlens: torch.LongTensor = None, | |
| from typing import Optional | |
| import torch | |
| def chunk_gated_delta_rule_bwd_dhu( | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| w: torch.Tensor, | |
| do: torch.Tensor, | |
| dv: torch.Tensor, | |
| g: torch.Tensor = None, | |
| gk: torch.Tensor = None, | |
| h0: torch.Tensor = None, | |
| dht: torch.Tensor = None, | |
| scale: Optional[float] = None, | |
| cu_seqlens: torch.LongTensor = None, |
🧰 Tools
🪛 Ruff (0.14.13)
534-534: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
🤖 Prompt for AI Agents
In `@examples/kda/FLA_KDA/fla_chunk_delta.py` around lines 524 - 535, The
parameter annotation for scale in chunk_gated_delta_rule_bwd_dhu is incorrect:
it has a default of None but is typed as float; change the annotation to
Optional[float] and add the corresponding import from typing (e.g., from typing
import Optional) at the top of the module so the function signature reads scale:
Optional[float] = None; keep other parameters unchanged and ensure any
linters/type checkers accept the Optional import.
| chunk_size: int = 64, # SY: remove this argument and force chunk size 64? | ||
| chunk_indices: torch.LongTensor = None, | ||
| use_exp2: bool = False, | ||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||
| B, T, H, K, V = *q.shape, do.shape[-1] | ||
| # N: the actual number of sequences in the batch with either equal or variable lengths | ||
| BT = 64 | ||
| assert K <= 256, "current kernel does not support head dimension being larger than 256." |
Backward chunk size mismatch (BT hard-coded to 64).
chunk_size is still accepted (and used for chunk_indices), but BT is forced to 64. If callers pass a non-64 chunk_size, chunking and offsets will diverge from the kernel’s BT, producing incorrect results. Either enforce 64 explicitly or wire BT to chunk_size.
🔧 Option: enforce 64 to avoid silent mismatch
- BT = 64
+ assert chunk_size == 64, "chunk_size must be 64 for blockdim64 kernels"
+ BT = chunk_size🤖 Prompt for AI Agents
In `@examples/kda/FLA_KDA/fla_chunk_delta.py` around lines 536 - 543, The function
currently hard-codes BT = 64 while accepting a chunk_size parameter, causing
mismatches when callers pass a different chunk_size; either bind BT to the
parameter or enforce chunk_size==64. Fix by replacing the hard-coded BT
assignment with BT = chunk_size (or add an assertion like assert chunk_size ==
64 with a clear message) and ensure any use of chunk_indices and downstream
offsets relies on BT (the same value) so the kernel and Python-side chunking
remain consistent; update references to BT/chunk_size in this function (and any
callers using chunk_indices) accordingly.
| def chunk_kda_fwd_intra_token_parallel( | ||
| q: torch.Tensor, | ||
| k: torch.Tensor, | ||
| gk: torch.Tensor, | ||
| beta: torch.Tensor, | ||
| Aqk: torch.Tensor, | ||
| Akk: torch.Tensor, | ||
| scale: float, | ||
| cu_seqlens: torch.LongTensor = None, | ||
| chunk_size: int = 64, | ||
| sub_chunk_size: int = 16, | ||
| ) -> None: |
Return type annotation contradicts actual behavior.
The function signature declares -> None but the function returns Aqk, Akk at line 168. This inconsistency can mislead callers and type checkers.
🐛 Proposed fix
-) -> None:
+) -> tuple[torch.Tensor, torch.Tensor]:📝 Committable suggestion
| def chunk_kda_fwd_intra_token_parallel( | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| gk: torch.Tensor, | |
| beta: torch.Tensor, | |
| Aqk: torch.Tensor, | |
| Akk: torch.Tensor, | |
| scale: float, | |
| cu_seqlens: torch.LongTensor = None, | |
| chunk_size: int = 64, | |
| sub_chunk_size: int = 16, | |
| ) -> None: | |
| def chunk_kda_fwd_intra_token_parallel( | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| gk: torch.Tensor, | |
| beta: torch.Tensor, | |
| Aqk: torch.Tensor, | |
| Akk: torch.Tensor, | |
| scale: float, | |
| cu_seqlens: torch.LongTensor = None, | |
| chunk_size: int = 64, | |
| sub_chunk_size: int = 16, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
🤖 Prompt for AI Agents
In `@examples/kda/FLA_KDA/fla_chunk_intra_token_parallel.py` around lines 114 -
125, The function chunk_kda_fwd_intra_token_parallel currently declares a None
return but actually returns Aqk and Akk; update the signature to reflect the
real return type (e.g., -> Tuple[torch.Tensor, torch.Tensor]) and add the
necessary typing import (from typing import Tuple) or use a forward-compatible
annotation (e.g., -> tuple[torch.Tensor, torch.Tensor] for Py3.9+), ensuring the
function name chunk_kda_fwd_intra_token_parallel and its callers are consistent
with the corrected return type.
| IS_TF32_SUPPORTED = False | ||
| if IS_TF32_SUPPORTED: | ||
| SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32x3") | ||
| else: | ||
| SOLVE_TRIL_DOT_PRECISION = tl.constexpr("ieee") | ||
| SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32") | ||
| # ============================================================================ |
🧩 Analysis chain
🏁 Script executed:
# First, let's verify the file exists and read the relevant lines
if [ -f "examples/kda/FLA_KDA/fla_chunk_intra.py" ]; then
echo "=== File content (lines 1-30) ==="
head -30 examples/kda/FLA_KDA/fla_chunk_intra.py | cat -n
else
echo "File not found"
fiRepository: tile-ai/tilelang
Length of output: 1258
🏁 Script executed:
# Let's also check if this file exists with a different path pattern
find . -name "fla_chunk_intra.py" -type f 2>/dev/nullRepository: tile-ai/tilelang
Length of output: 101
🏁 Script executed:
# Search for usages of SOLVE_TRIL_DOT_PRECISION in the file
rg "SOLVE_TRIL_DOT_PRECISION" examples/kda/FLA_KDA/fla_chunk_intra.py -nRepository: tile-ai/tilelang
Length of output: 1367
🏁 Script executed:
# Check more context around lines 10-15 to see if there are comments explaining this
sed -n '5,20p' examples/kda/FLA_KDA/fla_chunk_intra.py | cat -nRepository: tile-ai/tilelang
Length of output: 765
Remove the unconditional precision override.
SOLVE_TRIL_DOT_PRECISION is set based on IS_TF32_SUPPORTED on lines 11-14, then immediately overwritten to "tf32" on line 15, which defeats the conditional and renders the flag meaningless. The precision is used extensively in downstream tl.dot operations (lines 278+), so the effective behavior is always "tf32" regardless of the IS_TF32_SUPPORTED check.
Either remove the conditional logic if "tf32" is the intended hardcoded value, or remove line 15 to respect the conditional logic:
Suggested fix
IS_TF32_SUPPORTED = False
-if IS_TF32_SUPPORTED:
- SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32x3")
-else:
- SOLVE_TRIL_DOT_PRECISION = tl.constexpr("ieee")
-SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32")
+SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32x3") if IS_TF32_SUPPORTED else tl.constexpr("ieee")or simply:
IS_TF32_SUPPORTED = False
-if IS_TF32_SUPPORTED:
- SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32x3")
-else:
- SOLVE_TRIL_DOT_PRECISION = tl.constexpr("ieee")
SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32")📝 Committable suggestion
| IS_TF32_SUPPORTED = False | |
| if IS_TF32_SUPPORTED: | |
| SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32x3") | |
| else: | |
| SOLVE_TRIL_DOT_PRECISION = tl.constexpr("ieee") | |
| SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32") | |
| # ============================================================================ | |
| IS_TF32_SUPPORTED = False | |
| SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32x3") if IS_TF32_SUPPORTED else tl.constexpr("ieee") | |
| # ============================================================================ |
🤖 Prompt for AI Agents
In `@examples/kda/FLA_KDA/fla_chunk_intra.py` around lines 10 - 16, The code
currently sets SOLVE_TRIL_DOT_PRECISION based on IS_TF32_SUPPORTED (the
conditional block setting tl.constexpr("tf32x3") or tl.constexpr("ieee")) and
then immediately overwrites it with SOLVE_TRIL_DOT_PRECISION =
tl.constexpr("tf32"), which defeats the conditional; remove the unconditional
override line so SOLVE_TRIL_DOT_PRECISION respects IS_TF32_SUPPORTED (or
conversely delete the conditional and keep the hardcoded tl.constexpr("tf32")
line if tf32 is intentionally fixed)—update references to
SOLVE_TRIL_DOT_PRECISION used by tl.dot operations accordingly.
| device = "cuda" | ||
| device_torch_lib = getattr(torch, device) | ||
|
|
||
| exp = tldevice.fast_expf | ||
| exp2 = tldevice.exp2 | ||
| log = tldevice.fast_logf | ||
| log2 = tldevice.fast_log2f | ||
|
|
||
| IS_NVIDIA_HOPPER = True and ("NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9) | ||
| USE_CUDA_GRAPH = True and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" |
🧩 Analysis chain
🏁 Script executed:
# First, check if the file exists and read it
find . -name "fla_utils.py" -path "*/FLA_KDA/*"Repository: tile-ai/tilelang
Length of output: 95
🏁 Script executed:
# Read the specific file to examine lines 16-25
cat -n examples/kda/FLA_KDA/fla_utils.py | head -30Repository: tile-ai/tilelang
Length of output: 1143
🏁 Script executed:
# Check the full context of the file, especially around imports and the device initialization
wc -l examples/kda/FLA_KDA/fla_utils.py
cat -n examples/kda/FLA_KDA/fla_utils.pyRepository: tile-ai/tilelang
Length of output: 9304
Guard CUDA device queries at import time.
Line 24 calls CUDA APIs unconditionally at module import, which will raise RuntimeError on CPU-only hosts or when no GPU is visible, breaking imports entirely. Defer evaluation or guard with torch.cuda.is_available() and a safe fallback.
🔧 Suggested guard for safe import
-IS_NVIDIA_HOPPER = True and ("NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9)
+def _is_nvidia_hopper() -> bool:
+ if not torch.cuda.is_available():
+ return False
+ try:
+ name = torch.cuda.get_device_name(0)
+ major, _ = torch.cuda.get_device_capability(0)
+ return ("NVIDIA H" in name) or (major >= 9)
+ except Exception:
+ return False
+
+IS_NVIDIA_HOPPER = _is_nvidia_hopper()🤖 Prompt for AI Agents
In `@examples/kda/FLA_KDA/fla_utils.py` around lines 16 - 25, The module currently
calls CUDA APIs at import time which can raise on CPU-only hosts; update the
initialization of device, device_torch_lib, IS_NVIDIA_HOPPER, and USE_CUDA_GRAPH
to first guard with torch.cuda.is_available() (or wrap torch.cuda calls in
try/except) and provide safe CPU fallbacks (e.g., device="cpu",
device_torch_lib=torch) and default IS_NVIDIA_HOPPER=False when CUDA is
unavailable or the queries fail; ensure USE_CUDA_GRAPH still reads the env var
but only enables when CUDA is available.
| dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() | ||
| dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() | ||
| dbeta = torch.empty(B, S, H, dtype=output_dtype).cuda() | ||
| dg = torch.empty(B, S, H, DK, dtype=gate_dtype).cuda() | ||
| dA = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() | ||
| return dk, dv, dbeta, dg, dA |
Fix dA output shape to match kernel output.
dA is produced with shape (B, S, H, chunk_size) in the kernel, but prepare_output allocates (B, S, H, DK). This will misalign if the preallocated output is ever used.
🐛 Suggested fix
- dA = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
+ dA = torch.empty(B, S, H, chunk_size, dtype=output_dtype).cuda()📝 Committable suggestion
| dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() | |
| dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() | |
| dbeta = torch.empty(B, S, H, dtype=output_dtype).cuda() | |
| dg = torch.empty(B, S, H, DK, dtype=gate_dtype).cuda() | |
| dA = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() | |
| return dk, dv, dbeta, dg, dA | |
| dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() | |
| dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() | |
| dbeta = torch.empty(B, S, H, dtype=output_dtype).cuda() | |
| dg = torch.empty(B, S, H, DK, dtype=gate_dtype).cuda() | |
| dA = torch.empty(B, S, H, chunk_size, dtype=output_dtype).cuda() | |
| return dk, dv, dbeta, dg, dA |
🤖 Prompt for AI Agents
In `@examples/kda/wy_fast_bwd.py` around lines 55 - 60, The preallocated dA uses
DK but the kernel writes (B, S, H, chunk_size); change the dA allocation in
prepare_output (the line creating dA) to use chunk_size instead of DK so dA is
torch.empty(B, S, H, chunk_size, dtype=output_dtype).cuda(); keep the other
tensors and dtypes unchanged and ensure any downstream code expecting dA uses
that chunk_size shape.
LeiWang1999
left a comment
LGTM. I refactored some of the code; I think we can now let this commit in.
The KDA algorithm, originally implemented with Triton in the flash-linear-attention repo, has been fully reimplemented in TileLang (0.1.6.post2+cuda.git729e66ca). Precision alignment and performance testing have been completed for every individual operator.
Summary by CodeRabbit
New Features
Bug Fixes
Documentation
Tests
Chores
✏️ Tip: You can customize this high-level summary in your review settings.