Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a complete KDA feature: a new KDA layer; a model family (config, model, causal LM); Triton-backed KDA ops (chunk, fused-recurrent, intra, inter, wy, gate); autograd wrappers; tests; a benchmark; and package exports.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant KDA_Layer
    participant Projections
    participant KDA_Op
    participant Triton_Kernels
    participant Output
    Client->>KDA_Layer: forward(hidden_states, attention_mask?, past_key_values?)
    KDA_Layer->>Projections: compute q, k, v, g, beta
    Projections-->>KDA_Layer: q,k,v,g,beta
    alt fused_recurrent path
        KDA_Layer->>KDA_Op: fused_recurrent_kda(q,k,v,g,beta,...)
    else chunked path
        KDA_Layer->>KDA_Op: chunk_kda(q,k,v,g,beta,...)
    end
    KDA_Op->>Triton_Kernels: launch kernels (intra/inter/wy/gate as needed)
    Triton_Kernels-->>KDA_Op: o (and optional final_state)
    KDA_Op-->>KDA_Layer: o, final_state?
    KDA_Layer->>KDA_Layer: output projection, norm, update past_key_values
    KDA_Layer-->>Output: (o, None, updated_past_key_values)
    Output-->>Client: return
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ❌ 1 failed check (warning), ✅ 2 passed checks.
Summary of Changes

Hello @yzhangcs, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces the Kernelized Delta Attention (KDA) mechanism, a new approach to attention that promises improved efficiency. It includes the full implementation of KDA as a PyTorch layer, along with highly optimized Triton kernels for its core operations. The changes also integrate KDA into the Hugging Face Transformers library, making it readily available for model building and experimentation. Comprehensive tests ensure the correctness and performance of the new components.
Code Review
This pull request introduces the KDA (Kernelized Delta Attention) model, including its layers, operations, model configuration, and benchmarks. The implementation is comprehensive, adding a new attention mechanism to the library.
My review focuses on correctness, consistency, and clarity. I've found a few issues:
- A logical bug in the benchmark script that would lead to incorrect results.
- Several inconsistencies in docstrings, type hints, and parameter initializations which could cause confusion.
- Minor issues like typos and incorrect examples in docstrings.
Overall, the core implementation of KDA seems solid, but the surrounding code and documentation need some polishing. Please see my detailed comments for suggestions.
```python
):
    if isinstance(module, KDA) and next(module.parameters()).device.type != 'meta':
        with torch.no_grad():
            module.A.copy_(nn.init.uniform_(module.A, a=1, b=16).log())
```
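For intuition, the init above draws decay rates uniformly in [1, 16] and stores their logs, so that exp(A) recovers a rate in that range. A minimal pure-Python sketch of the same idea (the helper name is hypothetical; the real code applies nn.init.uniform_ to a parameter tensor in place):

```python
import math
import random

def log_uniform_init(n, a=1.0, b=16.0, seed=0):
    # Draw uniformly in [a, b], then store the log; downstream code
    # exponentiates the stored value to recover a rate in [a, b].
    rng = random.Random(seed)
    return [math.log(rng.uniform(a, b)) for _ in range(n)]

A = log_uniform_init(8)
rates = [math.exp(x) for x in A]
assert all(1.0 <= r <= 16.0 for r in rates)
```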
```python
g = F.logsigmoid(torch.randn(1, T, H, D, dtype=torch.float))
g = g * (torch.rand_like(g) > mask_p)
beta = torch.rand(1, T, H, dtype=dtype).sigmoid()
h0 = torch.randn((N, H, D, D), dtype=dtype)
```
The initial state h0 is created with dtype=dtype, which can be torch.float16. However, recurrent states are typically expected to be in torch.float32 for precision. In other tests like test_chunk, h0 is explicitly created with torch.float32. Please ensure consistency and use torch.float32 for h0 to avoid potential precision issues.
```diff
- h0 = torch.randn((N, H, D, D), dtype=dtype)
+ h0 = torch.randn((N, H, D, D), dtype=torch.float32)
```
Actionable comments posted: 9
🧹 Nitpick comments (24)
fla/ops/kda/chunk_inter.py (1)
161-165: Gradient dtype choices — align with input dtypes

Allocating dq/dk as float32 may surprise downstream code if bf16/half gradients are expected. Recommend matching the input dtypes (the kernel still accumulates in fp32 and casts on store).

```diff
- dq = torch.empty_like(q, dtype=torch.float)
- dk = torch.empty_like(k, dtype=torch.float)
- dw = torch.empty_like(w)
- dg = torch.empty_like(g)
+ dq = torch.empty_like(q)
+ dk = torch.empty_like(k)
+ dw = torch.empty_like(w)
+ dg = torch.empty_like(g)
```

Confirm tests expect dq/dk to match q/k dtypes; if a global policy prefers fp32 grads, we can keep the current behavior and cast at call sites.
fla/models/kda/configuration_kda.py (3)
9-11: Type class attributes as ClassVar

Improves clarity and satisfies linters.

```diff
-from typing import Dict, Optional
+from typing import ClassVar, Dict, List, Optional
@@
-class KDAConfig(PretrainedConfig):
-    model_type = 'kda'
-    keys_to_ignore_at_inference = ['past_key_values']
+class KDAConfig(PretrainedConfig):
+    model_type: ClassVar[str] = 'kda'
+    keys_to_ignore_at_inference: ClassVar[List[str]] = ['past_key_values']
```

29-36: Fix Optional typing and defaults

Avoid implicit Optional; align the annotation with the default value.

```diff
-        pad_token_id: int = None,
+        pad_token_id: Optional[int] = None,
```

Optionally prefer PEP 604:

```diff
-        pad_token_id: Optional[int] = None,
+        pad_token_id: int | None = None,
```

69-80: Avoid mutating the user-supplied attn dict in place

Copy before filling defaults.

```diff
-        if attn is not None:
+        if attn is not None:
+            attn = dict(attn)  # defensive copy
             if not isinstance(attn, Dict):
                 raise ValueError("attn must be a dictionary")
```

fla/ops/kda/chunk_intra.py (2)
431-437: Optional: tighten BK tiling for consistency

Not a blocker. Consider keeping BK within [16, 256] and a power of two.

```diff
- BK = max(triton.next_power_of_2(K), 16)
+ BK = max(16, min(256, triton.next_power_of_2(K)))
```

482-540: Return values and in-place semantics

dq/dk/db/dg are updated in place, but dq/dk are reassigned to *_2 buffers for return. Clarify this in the docstring and ensure callers don't rely on preallocated tensors being filled.
tests/models/test_modeling_kda.py (1)
39-56: Generation test: add a couple more toggles

Optional: add cases for allow_neg_eigval=True and num_v_heads > num_heads (GVA) to exercise those branches.
fla/ops/kda/fused_recurrent.py (3)
11-23: Type hints: avoid implicit Optional and remove unused kwargs
- Replace implicit optionals (T = None) with union types (T | None) to satisfy RUF013.
- Drop the unused **kwargs (ARG001) or rename to **_ if intentional.
```diff
-def fused_recurrent_kda(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    g: torch.Tensor,
-    beta: torch.Tensor = None,
-    scale: float = None,
-    initial_state: torch.Tensor = None,
-    output_final_state: bool = False,
-    use_qk_l2norm_in_kernel: bool = False,
-    cu_seqlens: Optional[torch.LongTensor] = None,
-    **kwargs
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def fused_recurrent_kda(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor | None = None,
+    scale: float | None = None,
+    initial_state: torch.Tensor | None = None,
+    output_final_state: bool = False,
+    use_qk_l2norm_in_kernel: bool = False,
+    cu_seqlens: torch.LongTensor | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
```
58-86: Docs: fix example function name and g shape

The examples call a non-existent fused_gated_recurrent_delta_rule; they should reference this wrapper or the actual callee. Also ensure g matches [B, T, HV, K] per this wrapper.

```diff
  >>> g = F.logsigmoid(torch.rand(B, T, HV, K, device='cuda'))
  >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
  >>> h0 = torch.randn(B, HV, K, V, device='cuda')
- >>> o, ht = fused_gated_recurrent_delta_rule(
-     q, k, v, g, beta,
-     initial_state=h0,
-     output_final_state=True
- )
+ >>> o, ht = fused_recurrent_kda(
+ ...     q, k, v, g, beta,
+ ...     initial_state=h0,
+ ...     output_final_state=True,
+ ... )
@@
- >>> o_var, ht_var = fused_gated_recurrent_delta_rule(
-     q, k, v, g, beta,
-     initial_state=h0,
-     output_final_state=True,
-     cu_seqlens=cu_seqlens
- )
+ >>> o_var, ht_var = fused_recurrent_kda(
+ ...     q, k, v, g, beta,
+ ...     initial_state=h0,
+ ...     output_final_state=True,
+ ...     cu_seqlens=cu_seqlens,
+ ... )
```
88-101: Minor lint: long exception messages (TRY003)

Optional: store the long f-strings in variables before raising to appease TRY003, or keep as-is if TRY rules aren't enforced CI-wide.
tests/ops/test_kda.py (1)
340-376: Minor test hygiene: avoid retain_graph=True where unnecessary

Both backward passes in test_kda_gate re-run separate graphs; retain_graph=True isn't needed and can slightly slow tests. Safe to drop it.
benchmarks/ops/benchmark_kda.py (1)
42-51: Benchmark hygiene: fix TMA restore, include 'attn', remove unused var, sort quantiles
- Preserve original FLA_USE_TMA presence (unset vs set).
- Add 'attn' to provider lists so its path is exercised.
- Remove unused p in kda branch.
- Use ascending quantiles for clarity.
```diff
@@
- original_tma_env = os.environ.get('FLA_USE_TMA', '0')
+ original_tma_env = os.environ.get('FLA_USE_TMA')
@@
- quantiles = [0.5, 0.2, 0.8]
+ quantiles = [0.2, 0.5, 0.8]
@@
- line_vals=['gdn', 'comba', 'kda', 'dplr'],
+ line_vals=['gdn', 'comba', 'kda', 'dplr', 'attn'],
@@
- line_names=['gdn', 'comba', 'kda', 'dplr'],
+ line_names=['gdn', 'comba', 'kda', 'dplr', 'attn'],
@@
- styles=[('blue', '-'), ('red', '-.'), ('green', '-'), ('orange', '-.'),
-         ('purple', '-'), ('brown', '-.'), ('pink', '-'), ('gray', '-.')],
+ styles=[('blue', '-'), ('red', '-.'), ('green', '-'), ('orange', '-.'),
+         ('purple', '-')],
@@
- p = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
  v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
@@
- os.environ['FLA_USE_TMA'] = original_tma_env
+ if original_tma_env is None:
+     os.environ.pop('FLA_USE_TMA', None)
+ else:
+     os.environ['FLA_USE_TMA'] = original_tma_env
```

Also applies to: 52-55, 86-105, 105-123, 143-145
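The save/restore pattern this diff suggests can be sketched in isolation (the helper name is hypothetical); the key point is remembering whether the variable existed at all, not just its value:

```python
import os

def set_env_temporarily(key, value):
    # Save whether the variable was set at all, so an originally-unset
    # variable ends up unset again after restore (not set to a default).
    original = os.environ.get(key)  # None means "was not set"
    os.environ[key] = value

    def restore():
        if original is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = original
    return restore

restore = set_env_temporarily('FLA_USE_TMA', '1')
assert os.environ['FLA_USE_TMA'] == '1'
restore()
```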
fla/ops/kda/naive.py (1)
64-78: Reduce transient allocations in naive_chunk_kda

Cloning and allocating A inside loops is heavy. Using in-place ops where safe and preallocating outside the inner loops can reduce overhead without changing semantics.
Also applies to: 82-100
fla/ops/kda/chunk.py (1)
249-262: Type hints and unused kwargs cleanup (Ruff RUF013, ARG001)

Prefer explicit Optional typing and drop the unused kwargs to quiet linters and clarify the API.

```diff
 @torch.compiler.disable
 def chunk_kda(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     g: torch.Tensor,
     beta: torch.Tensor,
-    scale: float = None,
-    initial_state: torch.Tensor = None,
+    scale: Optional[float] = None,
+    initial_state: Optional[torch.Tensor] = None,
     output_final_state: bool = False,
     use_qk_l2norm_in_kernel: bool = False,
     cu_seqlens: Optional[torch.LongTensor] = None,
-    **kwargs
 ):
@@
     if scale is None:
         scale = k.shape[-1] ** -0.5
```

Also applies to: 326-339
fla/layers/kda.py (1)
77-93: Constructor typing and message polish

Use Optional typing for None defaults and fix the assertion message typo.

```diff
-        num_v_heads: int = None,
+        num_v_heads: Optional[int] = None,
@@
-        layer_idx: int = None,
+        layer_idx: Optional[int] = None,
@@
-        assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
+        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
```

Also applies to: 114-130
fla/ops/kda/wy_fast.py (1)
227-234: Varlen NT calculation with chunk_indices length: confirm shape

len(chunk_indices) assumes shape [NT, 2] from prepare_chunk_indices; OK if stable. Add a brief assert to guard against accidental shape regressions.

```diff
- NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else chunk_indices.shape[0]
```

fla/ops/kda/gate.py (2)
237-247: Keep output dtype consistent with input g (avoid forced fp32 allocation)

y is allocated in fp32, which may force extra casts and memory. Prefer g.dtype for the output tensor and let kernels upcast internally as needed.

```diff
- y = torch.empty_like(g, dtype=torch.float32)
+ y = torch.empty_like(g)
```

If you want fp32 compute but original-dtype output, keep the current internal casts and only change the allocation.

18-55: Reference path: ensure output dtype mirrors input

kda_gate_ref casts to float32 for the math and returns float32, which can diverge from the fused path if the dtype change above is adopted. Return to g.dtype for parity.

```diff
- A_exp = -A.float().exp().unsqueeze(-1)  # [H, 1]
- g_softplus = F.softplus(g.float(), beta, threshold)  # [..., H, D]
-
- return A_exp * g_softplus
+ A_exp = -A.float().exp().unsqueeze(-1)  # [H, 1]
+ g_softplus = F.softplus(g.float(), beta, threshold)  # [..., H, D]
+ out = A_exp * g_softplus
+ return out.to(g.dtype)
```

fla/models/kda/modeling_kda.py (6)
115-121: Annotate mutable class attributes with ClassVar

Silence RUF012 and clarify intent.

```diff
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, Tuple, Union
...
-    config_class = KDAConfig
-    base_model_prefix = 'model'
-    supports_gradient_checkpointing = True
-    _no_split_modules = ['KDABlock']
-    _supports_cache_class = True
+    config_class: ClassVar = KDAConfig
+    base_model_prefix: ClassVar[str] = 'model'
+    supports_gradient_checkpointing: ClassVar[bool] = True
+    _no_split_modules: ClassVar[List[str]] = ['KDABlock']
+    _supports_cache_class: ClassVar[bool] = True
```

Also apply similarly to KDAForCausalLM._tied_weights_keys.
210-214: warnings.warn without stacklevel

Add stacklevel=2 to point at the caller.

```diff
- warnings.warn("`KDAModel` does not `output_attentions` now, setting it to `False`.")
+ warnings.warn("`KDAModel` does not `output_attentions` now, setting it to `False`.", stacklevel=2)
```
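A small self-contained illustration of what stacklevel=2 buys (hypothetical function name): the warning is attributed to the caller's line rather than to the helper's internals, which makes the message actionable in user logs.

```python
import warnings

def forward_helper():
    # stacklevel=2 makes the warning point at whoever called
    # forward_helper(), not at this line inside the helper.
    warnings.warn("does not support output_attentions, setting it to False", stacklevel=2)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    forward_helper()

assert len(caught) == 1
assert "output_attentions" in str(caught[0].message)
```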
218-223: Long ValueError messages: acceptable, but consider brevity

The static analyzer flags TRY003. Optional: shorten or extract to a constant.
312-326: Prefer 'raise ... from None' and bare 'raise' to preserve the trace

In generate(), chain the custom AttributeError with from None; re-raise with a bare raise.

```diff
-        except AttributeError as exception:
+        except AttributeError as exception:
             if 'past_key_values' in str(exception):
                 raise AttributeError(
                     f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                     f"which is not supported for {self.__class__.__name__}. "
                     f"Try another generation strategy instead. "
                     f"For the available generation strategies, check this doc: "
                     f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-                )
+                ) from None
             else:
-                raise exception
+                raise
```
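The effect of from None can be checked directly; a minimal sketch (the message text is illustrative, not the model's exact wording):

```python
def translate_error():
    try:
        {}['past_key_values']
    except KeyError:
        # 'from None' suppresses the implicit "During handling of the above
        # exception, another exception occurred" chaining, so users see only
        # the friendly translated message.
        raise AttributeError("unsupported decoding strategy") from None

try:
    translate_error()
except AttributeError as e:
    assert e.__cause__ is None
    assert e.__suppress_context__ is True
```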
280-289: _tied_weights_keys should be ClassVar

Align with RUF012.

```diff
- _tied_weights_keys = ["lm_head.weight"]
+ _tied_weights_keys: ClassVar[List[str]] = ["lm_head.weight"]
```
384-386: Prefer tuple splat to concatenation

Minor ergonomics per RUF005.

```diff
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
+ output = (logits, *outputs[1:])
+ return (loss, *output) if loss is not None else output
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (20)
- benchmarks/ops/benchmark_kda.py (1 hunk)
- fla/layers/__init__.py (2 hunks)
- fla/layers/gated_deltanet.py (2 hunks)
- fla/layers/kda.py (1 hunk)
- fla/models/__init__.py (2 hunks)
- fla/models/kda/__init__.py (1 hunk)
- fla/models/kda/configuration_kda.py (1 hunk)
- fla/models/kda/modeling_kda.py (1 hunk)
- fla/ops/__init__.py (2 hunks)
- fla/ops/gla/chunk.py (7 hunks)
- fla/ops/kda/__init__.py (1 hunk)
- fla/ops/kda/chunk.py (1 hunk)
- fla/ops/kda/chunk_inter.py (1 hunk)
- fla/ops/kda/chunk_intra.py (1 hunk)
- fla/ops/kda/fused_recurrent.py (1 hunk)
- fla/ops/kda/gate.py (1 hunk)
- fla/ops/kda/naive.py (1 hunk)
- fla/ops/kda/wy_fast.py (1 hunk)
- tests/models/test_modeling_kda.py (1 hunk)
- tests/ops/test_kda.py (1 hunk)
🧰 Additional context used
🧬 Code graph analysis (16)
fla/models/__init__.py (2)
- fla/models/kda/configuration_kda.py (1): KDAConfig (8-87)
- fla/models/kda/modeling_kda.py (2): KDAForCausalLM (280-394), KDAModel (177-277)

fla/layers/__init__.py (1)
- fla/layers/kda.py (1): KDA (25-283)

fla/ops/kda/__init__.py (2)
- fla/ops/kda/chunk.py (1): chunk_kda (250-351)
- fla/ops/kda/fused_recurrent.py (1): fused_recurrent_kda (11-114)

fla/ops/kda/fused_recurrent.py (2)
- fla/ops/common/fused_recurrent.py (1): fused_recurrent (540-569)
- fla/ops/gated_delta_rule/fused_recurrent.py (1): fused_recurrent_gated_delta_rule (242-353)

fla/models/kda/__init__.py (2)
- fla/models/kda/configuration_kda.py (1): KDAConfig (8-87)
- fla/models/kda/modeling_kda.py (2): KDAForCausalLM (280-394), KDAModel (177-277)

fla/ops/kda/chunk_intra.py (3)
- fla/ops/utils/cumsum.py (1): chunk_local_cumsum (432-469)
- fla/ops/utils/index.py (1): prepare_chunk_indices (116-121)
- fla/ops/utils/solve_tril.py (1): solve_tril (337-384)

fla/layers/kda.py (6)
- fla/layers/utils.py (2): get_unpad_data (75-98), pad_input (176-197)
- fla/modules/fused_norm_gate.py (1): FusedRMSNormGated (997-1058)
- fla/modules/convolution.py (1): ShortConvolution (794-1009)
- fla/ops/kda/chunk.py (2): chunk_kda (250-351), forward (186-218)
- fla/ops/kda/fused_recurrent.py (1): fused_recurrent_kda (11-114)
- fla/ops/kda/gate.py (2): fused_kda_gate (329-349), forward (304-314)

fla/ops/kda/chunk.py (9)
- fla/modules/l2norm.py (3): l2norm (266-271), l2norm_bwd (195-239), l2norm_fwd (149-192)
- fla/ops/common/chunk_delta_h.py (2): chunk_gated_delta_rule_bwd_dhu (484-535), chunk_gated_delta_rule_fwd_h (435-481)
- fla/ops/common/chunk_o.py (1): chunk_bwd_dv_local (582-628)
- fla/ops/gla/chunk.py (2): chunk_gla_bwd_dA (879-908), chunk_gla_fwd_o_gk (842-876)
- fla/ops/kda/chunk_inter.py (1): chunk_kda_bwd_dqkwg (141-188)
- fla/ops/kda/chunk_intra.py (2): chunk_kda_bwd_intra (482-540), chunk_kda_fwd_intra (389-479)
- fla/ops/kda/wy_fast.py (2): prepare_wy_repr_bwd (258-304), recompute_w_u_fwd (213-255)
- fla/ops/utils/cumsum.py (1): chunk_local_cumsum (432-469)
- fla/utils.py (1): input_guard (135-166)

fla/ops/kda/gate.py (3)
- fla/utils.py (1): input_guard (135-166)
- fla/layers/kda.py (1): forward (170-283)
- fla/ops/kda/chunk.py (2): forward (186-218), backward (223-246)

tests/models/test_modeling_kda.py (2)
- fla/models/kda/configuration_kda.py (1): KDAConfig (8-87)
- tests/models/test_modeling_base.py (2): run_test_generation (67-126), run_test_model_forward_backward (27-61)

tests/ops/test_kda.py (5)
- fla/ops/kda/chunk.py (2): chunk_kda (250-351), backward (223-246)
- fla/ops/kda/fused_recurrent.py (1): fused_recurrent_kda (11-114)
- fla/ops/kda/gate.py (3): fused_kda_gate (329-349), kda_gate_ref (18-54), backward (318-326)
- fla/ops/kda/naive.py (2): naive_chunk_kda (41-102), naive_recurrent_kda (9-38)
- fla/utils.py (1): assert_close (82-93)

fla/ops/kda/chunk_inter.py (2)
- fla/ops/utils/index.py (1): prepare_chunk_indices (116-121)
- fla/utils.py (1): check_shared_mem (445-451)

fla/ops/__init__.py (2)
- fla/ops/kda/chunk.py (1): chunk_kda (250-351)
- fla/ops/kda/fused_recurrent.py (1): fused_recurrent_kda (11-114)

benchmarks/ops/benchmark_kda.py (4)
- fla/ops/comba/chunk.py (1): chunk_comba (235-342)
- fla/ops/gated_delta_rule/chunk.py (1): chunk_gated_delta_rule (219-325)
- fla/ops/generalized_delta_rule/dplr/chunk.py (1): chunk_dplr_delta_rule (267-359)
- fla/ops/kda/chunk.py (2): chunk_kda (250-351), backward (223-246)

fla/ops/kda/wy_fast.py (2)
- fla/ops/utils/index.py (1): prepare_chunk_indices (116-121)
- fla/utils.py (1): check_shared_mem (445-451)

fla/models/kda/modeling_kda.py (7)
- fla/layers/attn.py (1): Attention (33-174)
- fla/layers/kda.py (2): KDA (25-283), forward (170-283)
- fla/models/kda/configuration_kda.py (1): KDAConfig (8-87)
- fla/modules/fused_cross_entropy.py (1): FusedCrossEntropyLoss (344-419)
- fla/modules/fused_linear_cross_entropy.py (1): FusedLinearCrossEntropyLoss (493-567)
- fla/modules/mlp.py (1): GatedMLP (26-69)
- fla/modules/layernorm.py (1): RMSNorm (1064-1111)
🪛 Ruff (0.14.1)
fla/ops/kda/fused_recurrent.py
- 17-17: PEP 484 prohibits implicit Optional; convert to T | None (RUF013)
- 22-22: Unused function argument: kwargs (ARG001)
- 90-93: Avoid specifying long messages outside the exception class (TRY003)
- 95-98: Avoid specifying long messages outside the exception class (TRY003)

fla/layers/kda.py
- 83-83: PEP 484 prohibits implicit Optional; convert to T | None (RUF013)
- 89-89: PEP 484 prohibits implicit Optional; convert to T | None (RUF013)
- 91-91: Unused method argument: kwargs (ARG002)
- 116-119: Avoid specifying long messages outside the exception class (TRY003)
- 121-123: Avoid specifying long messages outside the exception class (TRY003)
- 126-129: Avoid specifying long messages outside the exception class (TRY003)
- 176-176: Unused method argument: output_attentions (ARG002)

fla/ops/kda/chunk.py
- 256-256: PEP 484 prohibits implicit Optional; convert to T | None (RUF013)
- 261-261: Unused function argument: kwargs (ARG001)
- 328-331: Avoid specifying long messages outside the exception class (TRY003)
- 333-336: Avoid specifying long messages outside the exception class (TRY003)

fla/models/kda/configuration_kda.py
- 10-10: Mutable class attributes should be annotated with typing.ClassVar (RUF012)
- 31-31: PEP 484 prohibits implicit Optional; convert to T | None (RUF013)
- 71-71: Avoid specifying long messages outside the exception class (TRY003)
- 73-73: Avoid specifying long messages outside the exception class (TRY003)
- 75-75: Avoid specifying long messages outside the exception class (TRY003)

fla/models/kda/modeling_kda.py
- 118-118: Mutable class attributes should be annotated with typing.ClassVar (RUF012)
- 174-174: Avoid specifying long messages outside the exception class (TRY003)
- 201-201: Unused blanket noqa directive; remove it (RUF100)
- 211-211: No explicit stacklevel keyword argument found; set stacklevel=2 (B028)
- 220-220: Avoid specifying long messages outside the exception class (TRY003)
- 222-222: Avoid specifying long messages outside the exception class (TRY003)
- 282-282: Mutable class attributes should be annotated with typing.ClassVar (RUF012)
- 317-323: Within an except clause, raise with raise ... from err or raise ... from None to distinguish from errors in exception handling (B904)
- 317-323: Avoid specifying long messages outside the exception class (TRY003)
- 325-325: Use raise without specifying the exception name (TRY201)
- 385-385: Consider (logits, *outputs[1:]) instead of concatenation (RUF005)
- 386-386: Consider (loss, *output) instead of concatenation (RUF005)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (20)
fla/ops/gla/chunk.py (7)
- 100-102: LGTM: Autotune configuration expansion. The addition of the num_stages parameter appropriately expands the search space for performance tuning. The range [2, 3] is conservative and suitable for this kernel.
- 287-291: LGTM: Autotune configuration expansion. The addition of num_stages in [2, 3, 4] appropriately expands the autotune search space. The broader range is justified given this kernel's complexity with multiple block size parameters (BK, BV).
- 367-369: LGTM: Autotune configuration expansion. The expansion with num_stages in [2, 3, 4] is consistent with similar kernels and appropriate for this backward pass kernel.
- 503-505: LGTM: Autotune configuration expansion. Consistent with other backward pass kernels in this file.
- 554-558: LGTM: Autotune configuration expansion. The use of BK_LIST and BV_LIST with num_stages in [2, 3, 4] provides comprehensive tuning coverage.
- 1192-1192: LGTM: Dynamic chunk size calculation. The chunk size now adapts to the sequence length while maintaining safe bounds (16-64). The logic correctly handles edge cases:
  - Small sequences: bounded by a minimum of 16
  - Large sequences: bounded by a maximum of 64
  - Variable-length inputs: uses the total concatenated length appropriately
- 1301-1301: LGTM: Documentation consistency fix. The docstring example now correctly uses return variable names that match the actual function signature.
fla/ops/kda/chunk_inter.py (2)
98-106: Clarify expected shape/stride for h

The kernel treats h as (V, K) with strides (1, V) after offsetting by (i_tg, i_h). Please document the expected layout (e.g., h shaped [NT, H, V, K] and contiguous) to prevent subtle misalignment.

118-118: Sign of dw

dw stores -b_dw. If w is defined as exp(g)·v-projected weights, the negative may be intentional. Please confirm the sign convention against the forward definition to avoid flipped gradients.
fla/ops/kda/chunk_intra.py (1)
224-231: Varlen path indexing sanity check

The kernel uses all = B*T (global T) for the db layout, then overrides T per sequence. Likely intentional given db2 shape [NK, B, T, H], but please confirm no overflow for mixed-length batches.
fla/ops/__init__.py (1)
19-19: Public API surfacing for KDA — verified

Both chunk_kda and fused_recurrent_kda are properly defined, re-exported from fla/ops/kda/__init__.py, and correctly imported into the main fla/ops namespace. The public API is properly exposed with no issues.

fla/layers/gated_deltanet.py (1)
284-285: Trailing comma additions confirmed valid

Both chunk_gated_delta_rule (fla/ops/gated_delta_rule/chunk.py:228) and fused_recurrent_gated_delta_rule (fla/ops/gated_delta_rule/fused_recurrent.py:253) define use_qk_l2norm_in_kernel: bool = False, so the kwarg is properly supported at the call sites. No runtime issues.

fla/layers/__init__.py (1)
18-18: Expose KDA in layers: LGTM

The import and __all__ update look correct; consistent with other layer exports.
Also applies to: 49-49
fla/models/__init__.py (1)
19-19: Public model API updated: LGTM

KDAConfig/KDAModel/KDAForCausalLM are properly re-exported.
Also applies to: 50-50
fla/ops/kda/__init__.py (1)
1-7: Clean ops surface: LGTM

Exporting chunk_kda and fused_recurrent_kda via __all__ is straightforward and consistent.
tests/models/test_modeling_kda.py (1)
14-34: Forward/backward test parameters: LGTM

Covers bf16 with/without l2warp and validates the varlen path via run_test_model_forward_backward.
fla/ops/kda/fused_recurrent.py (1)
102-113: Shape contract: confirm g produced upstream is [B, T, HV, K]

You pass gk=g. Ensure fused_kda_gate in fla/ops/kda/gate.py returns [B, T, HV, K] to match fused_recurrent_gated_delta_rule's gk requirement. Otherwise, add a lightweight assertion on g.dim() and trailing sizes.

fla/models/kda/__init__.py (1)
8-10: Review comment requires no action

The exist_ok parameter is fully supported in transformers 4.53.0 across AutoConfig.register(), AutoModel.register(), and AutoModelForCausalLM.register(). The code at lines 8-10 is correct and will not produce import-time errors.

fla/ops/kda/chunk.py (1)
183-219: Autograd wiring LGTM

Forward/backward save/restore and dtype handling are correct; the L2-norm reversal is applied conditionally and safely.
Also applies to: 220-246
fla/ops/kda/wy_fast.py (1)
21-26: The is_tf32_supported variable in fla/utils.py (line 399) is a boolean value, not a function. It is computed once at module load time as:

```python
is_tf32_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8)
```

The usage in fla/ops/kda/wy_fast.py (line 25) is correct: it uses this boolean directly in the conditional expression without calling it. No refactoring is needed.

Likely an incorrect or invalid review comment.
```python
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
```
Fix BT/indices mismatch and missing default scale (runtime bugs)
- When cu_seqlens is used, chunk_indices must be built with BT, not chunk_size.
- scale can be None but is multiplied in-kernel, causing a failure.
Apply:
```diff
- BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
-
- chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+ if scale is None:
+     scale = k.shape[-1] ** -0.5
+ chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
  NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
```

🤖 Prompt for AI Agents
In fla/ops/kda/chunk_inter.py around lines 156 to 160, there are two runtime
bugs: when cu_seqlens is provided you must compute BT first and pass BT (not
chunk_size) into prepare_chunk_indices so chunk_indices aligns with the actual
block size, and scale may be None but is later multiplied inside the kernel —
set a default scale = 1.0 (or cast to the kernel expected dtype) before
launching the kernel so multiplication never receives None; update the call to
prepare_chunk_indices to use BT and ensure scale is normalized to a numeric
value prior to kernel launch.
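As a rough pure-Python illustration of the invariant this fix restores (the helper below is hypothetical; the real prepare_chunk_indices lives in fla/ops/utils/index.py), the (sequence, chunk) pairs must be enumerated with the same block size BT the kernel tiles with, or the count no longer matches the launch grid:

```python
def chunk_indices(cu_seqlens, BT):
    # For a packed batch, emit one (sequence_id, chunk_id) pair per BT-sized
    # chunk; ceil-division so a partial trailing chunk still gets an entry.
    pairs = []
    for i in range(len(cu_seqlens) - 1):
        length = cu_seqlens[i + 1] - cu_seqlens[i]
        for c in range(-(-length // BT)):
            pairs.append((i, c))
    return pairs

# Two sequences of lengths 5 and 7, block size BT=4 -> 2 chunks each.
assert chunk_indices([0, 5, 12], 4) == [(0, 0), (0, 1), (1, 0), (1, 1)]
# With a mismatched block size (e.g. enumerating with chunk_size=8 while the
# kernel tiles with BT=4), the pair count diverges from NT.
assert len(chunk_indices([0, 5, 12], 8)) != len(chunk_indices([0, 5, 12], 4))
```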
```python
B, T, H, K = k.shape
assert K <= 256
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

BC = min(16, BT)
NC = triton.cdiv(BT, BC)
BK = max(triton.next_power_of_2(K), 16)
```
Default scale to 1/sqrt(K) when None
Prevent None multiplication in kernels and match other KDA paths.
```diff
- BC = min(16, BT)
+ if scale is None:
+     scale = k.shape[-1] ** -0.5
+ BC = min(16, BT)
```

🤖 Prompt for AI Agents
In fla/ops/kda/chunk_intra.py around lines 425 to 434, the variable scale can be
None and later gets multiplied in kernels which causes None propagation; set a
default numeric scale equal to 1/sqrt(K) when scale is None (compute using K
from k.shape) so subsequent math and kernel calls always receive a float;
implement this normalization immediately after K is derived to avoid None
multiplication and to match other KDA code paths.
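The normalization described above is a one-liner; this sketch (hypothetical helper name) mirrors the convention the other KDA entry points use, defaulting to 1/sqrt(K) so the kernel never multiplies by None:

```python
import math

def resolve_scale(scale, K):
    # Normalize scale before the kernel launch so downstream math always
    # receives a float; the default matches the other KDA paths: 1/sqrt(K).
    return K ** -0.5 if scale is None else scale

K = 64
s = resolve_scale(None, K)
assert abs(s * math.sqrt(K) - 1.0) < 1e-12
assert resolve_scale(0.25, K) == 0.25
```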
```python
Args:
    q (torch.Tensor):
        queries of shape `[B, T, H, K]`.
    k (torch.Tensor):
        keys of shape `[B, T, H, K]`.
    v (torch.Tensor):
        values of shape `[B, T, HV, V]`.
        GVA is applied if `HV > H`.
    g (torch.Tensor):
        g (decays) of shape `[B, T, HV]`.
    beta (torch.Tensor):
        betas of shape `[B, T, HV]`.
    scale (Optional[float]):
        Scale factor for the RetNet attention scores.
        If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
    initial_state (Optional[torch.Tensor]):
        Initial state of shape `[N, HV, K, V]` for `N` input sequences.
        For equal-length input sequences, `N` equals the batch size `B`.
        Default: `None`.
    output_final_state (Optional[bool]):
        Whether to output the final state of shape `[N, HV, K, V]`. Default: `False`.
    use_qk_l2norm_in_kernel (Optional[bool]):
        Whether to use L2 normalization in the kernel. Default: `False`.
    cu_seqlens (torch.LongTensor):
        Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
        consistent with the FlashAttention API.
```
Doc param shapes: clarify that g here is passed as gk (per‑K decay)
This wrapper forwards g as gk=..., implying shape [B, T, HV, K], not [B, T, HV]. Please update the arg docs to reflect this and avoid confusion with the base API where g (no K) and gk are distinct.
```python
@input_guard
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None, None, None]:
    g, A = ctx.saved_tensors
    head_k_dim = ctx.head_k_dim
    beta = ctx.beta
    threshold = ctx.threshold
    g_bias = ctx.g_bias

    grad_g, grad_A, grad_gbias = kda_gate_bwd(grad_output, g, A, head_k_dim, g_bias, beta, threshold)
    return grad_g, grad_A, None, grad_gbias, None, None
```
Same decorator-order bug in backward

Swap to keep backward a true staticmethod.

```diff
-    @input_guard
-    @staticmethod
+    @staticmethod
+    @input_guard
     def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None, None, None]:
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-    @input_guard
-    @staticmethod
+    @staticmethod
+    @input_guard
     def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None, None, None]:
         g, A = ctx.saved_tensors
         head_k_dim = ctx.head_k_dim
         beta = ctx.beta
         threshold = ctx.threshold
         g_bias = ctx.g_bias
         grad_g, grad_A, grad_gbias = kda_gate_bwd(grad_output, g, A, head_k_dim, g_bias, beta, threshold)
         return grad_g, grad_A, None, grad_gbias, None, None
```
🤖 Prompt for AI Agents
In fla/ops/kda/gate.py around lines 316 to 326, the decorators for backward are
in the wrong order (input_guard is applied before staticmethod) which prevents
backward from being a true staticmethod; swap the two decorators so
@staticmethod appears first and @input_guard second, preserving the existing
method signature and return types.
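Why the order matters can be shown with a stand-in decorator (names are illustrative; the real `input_guard` lives in fla): only when `@staticmethod` is outermost does the class attribute remain a staticmethod descriptor, so no `self` leaks into `ctx` at call time.

```python
def input_guard(fn):
    # Hypothetical stand-in for fla's input_guard: a plain wrapping decorator.
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

class Right:
    @staticmethod       # outermost: the class attribute stays a staticmethod
    @input_guard
    def backward(ctx, grad):
        return (ctx, grad)

class Wrong:
    @input_guard        # outermost: the attribute is now a plain function,
    @staticmethod       # so instance access binds it and `self` leaks into ctx
    def backward(ctx, grad):
        return (ctx, grad)
```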
```python
BT = 64
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
CONST_TILING = 64 if check_shared_mem() else 32
BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)
```
Hard-coded BT=64 in backward wrapper will desync from forward BT

prepare_wy_repr_bwd sets BT=64 regardless of A.shape[-1]. If forward used a different BT (e.g., 32/128), backward tiling and block_ptr strides become inconsistent, causing wrong results or OOB.

```diff
-    BT = 64
+    BT = A.shape[-1]
```

🤖 Prompt for AI Agents
In fla/ops/kda/wy_fast.py around lines 269 to 275, the backward wrapper
hard-codes BT = 64 which can mismatch the forward tiling and lead to incorrect
strides/OOB; change the logic to derive BT from the actual block size used by
forward (e.g., accept BT as an argument to the backward function or compute BT =
A.shape[-1] or the forward-provided tiling value), then use that BT for
prepare_chunk_indices and NT computation (NT = triton.cdiv(T, BT) if cu_seqlens
is None else len(chunk_indices)); ensure subsequent BK/BV computations still use
CONST_TILING but not overwrite BT so backward tiling and block_ptr strides match
forward.
Actionable comments posted: 4
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- benchmarks/ops/benchmark_kda.py (1 hunks)

🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/ops/benchmark_kda.py (7)
- fla/ops/comba/chunk.py (1): chunk_comba (235-342)
- fla/ops/gated_delta_rule/chunk.py (1): chunk_gated_delta_rule (219-325)
- fla/ops/generalized_delta_rule/dplr/chunk.py (1): chunk_dplr_delta_rule (267-359)
- fla/ops/kda/chunk.py (2): chunk_kda (250-351), backward (223-246)
- fla/modules/activations.py (1): logsigmoid (201-202)
- fla/ops/utils/matmul.py (1): sigmoid (160-162)
- fla/ops/kda/gate.py (1): backward (318-326)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (2)
benchmarks/ops/benchmark_kda.py (2)

1-36: LGTM! The imports and benchmark configuration are well-structured and include all necessary dependencies for the KDA and related attention mechanism benchmarks.

56-139: Past review comments appear resolved. The conditional logic now uses a proper if/elif chain for provider selection, and the 'attn' provider is correctly listed in `line_vals` at line 25. The previous issues mentioned in the past review comments have been addressed.
```python
# Set TMA environment variable based on provider
original_tma_env = os.environ.get('FLA_USE_TMA', '0')

if provider.endswith('_no_tma'):
    os.environ['FLA_USE_TMA'] = '0'
    provider_base = provider.replace('_no_tma', '')
else:
    os.environ['FLA_USE_TMA'] = '1'
    provider_base = provider
```
Wrap environment variable restoration in try-finally to ensure cleanup.

The TMA environment variable restoration at line 142 won't execute if an exception occurs during benchmarking. This could leave the environment in an inconsistent state for subsequent runs.

Apply this diff to ensure proper cleanup:

```diff
     quantiles = [0.5, 0.2, 0.8]
     results = 0, 0, 0
-    do = torch.randn(B, T, H, D, dtype=dtype, device=device)
-    if provider_base == 'gdn':
+    try:
+        do = torch.randn(B, T, H, D, dtype=dtype, device=device)
+        if provider_base == 'gdn':
+            ...
+        elif provider_base == 'dplr':
+            ...
+    finally:
+        # Restore original TMA environment variable
+        os.environ['FLA_USE_TMA'] = original_tma_env
-    # Restore original TMA environment variable
-    os.environ['FLA_USE_TMA'] = original_tma_env
     return results
```

Also applies to: 141-142
🤖 Prompt for AI Agents
In benchmarks/ops/benchmark_kda.py around lines 42 to 50, the code sets
FLA_USE_TMA based on the provider but does not guarantee restoration if an
exception occurs during benchmarking; wrap the benchmarking logic that follows
these lines in a try...finally block and move the restoration of
os.environ['FLA_USE_TMA'] (using the saved original_tma_env) into the finally so
the environment variable is always reset regardless of errors.
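One way to package the restoration pattern suggested above is a small context manager. This is a sketch, not code from the benchmark script; it assumes only that `FLA_USE_TMA` is the toggle, as in the diff.

```python
import os
from contextlib import contextmanager

@contextmanager
def tma_env(enabled: bool):
    # Set FLA_USE_TMA for the duration of a benchmark run and restore the
    # previous value even if the benchmark raises.
    original = os.environ.get('FLA_USE_TMA')
    os.environ['FLA_USE_TMA'] = '1' if enabled else '0'
    try:
        yield
    finally:
        if original is None:
            os.environ.pop('FLA_USE_TMA', None)
        else:
            os.environ['FLA_USE_TMA'] = original
```

Usage would be `with tma_env(not provider.endswith('_no_tma')): ...` around the `do_bench` calls, which makes the try/finally impossible to forget.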
```python
elif provider_base == 'kda':
    q = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
    k = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
    v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
    g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device)).requires_grad_(True)
    beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid().requires_grad_(True)
    results = triton.testing.do_bench(
        lambda: chunk_kda(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            use_qk_l2norm_in_kernel=True,
        )[0].backward(do),
        quantiles=quantiles
    )
```
Fix critical shape mismatch for g tensor in KDA provider.

Line 108 creates g with shape [B, T, H, D], but chunk_kda expects g to have shape [B, T, H] according to its signature. This mismatch will cause runtime errors or incorrect computation.

Apply this diff to fix the shape:

```diff
-    g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device)).requires_grad_(True)
+    g = F.logsigmoid(torch.randn(B, T, H, dtype=dtype, device=device)).requires_grad_(True)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
 elif provider_base == 'kda':
     q = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
     k = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
     v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
-    g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device)).requires_grad_(True)
+    g = F.logsigmoid(torch.randn(B, T, H, dtype=dtype, device=device)).requires_grad_(True)
     beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid().requires_grad_(True)
     results = triton.testing.do_bench(
         lambda: chunk_kda(
             q=q,
             k=k,
             v=v,
             g=g,
             beta=beta,
             use_qk_l2norm_in_kernel=True,
         )[0].backward(do),
         quantiles=quantiles
     )
```
🤖 Prompt for AI Agents
In benchmarks/ops/benchmark_kda.py around lines 104 to 120, the tensor g is
created with shape [B, T, H, D] but chunk_kda expects g of shape [B, T, H];
change the g initializer to create a tensor of shape [B, T, H] (preserving
dtype, device and requires_grad_) and then apply F.logsigmoid to that tensor so
g has the correct shape and gradient behavior; no other changes to the
surrounding benchmark call are needed.
```python
elif provider_base == 'dplr':
    q = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
    k = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
    a = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
    b = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
    v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
    g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device)).requires_grad_(True)
    beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid().requires_grad_(True)
    results = triton.testing.do_bench(
        lambda: chunk_dplr_delta_rule(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            gk=g,
        )[0].backward(do),
        quantiles=quantiles
    )
```
Remove unused beta variable in DPLR provider.

Line 128 creates a beta tensor that is never used in the chunk_dplr_delta_rule call. The function signature doesn't accept a beta parameter, making this unnecessary computation that wastes resources during benchmarking.

Apply this diff to remove the unused variable:

```diff
     v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
     g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device)).requires_grad_(True)
-    beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid().requires_grad_(True)
     results = triton.testing.do_bench(
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
 elif provider_base == 'dplr':
     q = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
     k = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
     a = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
     b = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
     v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
     g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device)).requires_grad_(True)
-    beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid().requires_grad_(True)
     results = triton.testing.do_bench(
         lambda: chunk_dplr_delta_rule(
             q=q,
             k=k,
             v=v,
             a=a,
             b=b,
             gk=g,
         )[0].backward(do),
         quantiles=quantiles
     )
```
🤖 Prompt for AI Agents
In benchmarks/ops/benchmark_kda.py around lines 121 to 139, the DPLR branch
creates an unused tensor `beta` at line 128 which is not passed to
chunk_dplr_delta_rule and wastes compute during benchmarks; remove the creation
of `beta` (and any associated .requires_grad_(True) call) and ensure the call to
chunk_dplr_delta_rule remains unchanged, then run tests/benchmarks to confirm no
references to `beta` remain.
Actionable comments posted: 3
♻️ Duplicate comments (2)
fla/ops/kda/fused_recurrent.py (1)
33-37: Doc/API mismatch for g (actually gk). Also accept 3D g via broadcast.

This wrapper passes g to the underlying op as gk (per-K decays), so the documented shape should be [B, T, HV, K], not [B, T, HV]. The example already uses [B, T, HV, K]. To avoid user error and support legacy callers, broadcast a 3D g to 4D.

Apply:

```diff
@@
-    g (torch.Tensor):
-        g (decays) of shape `[B, T, HV]`.
+    g (torch.Tensor):
+        gk (per-K decays) of shape `[B, T, HV, K]`. Note: if a 3D `[B, T, HV]`
+        tensor is provided, it will be broadcast across K.
@@
     if scale is None:
         scale = k.shape[-1] ** -0.5
+    # Backward-compat: allow per-HV decays by broadcasting across K
+    if g.dim() == 3:
+        g = g.unsqueeze(-1).expand(*g.shape, k.shape[-1])
@@
     o, final_state = fused_recurrent_gated_delta_rule(
         q=q,
         k=k,
         v=v,
         gk=g,
         beta=beta,
```

Also applies to: 102-113
fla/layers/kda.py (1)
189-192: Guard past_key_values access when layer_idx is None.

len(past_key_values) > self.layer_idx raises if layer_idx is None. Add an explicit guard.

```diff
-        last_state = None
-        if past_key_values is not None and len(past_key_values) > self.layer_idx:
-            last_state = past_key_values[self.layer_idx]
+        last_state = None
+        if (
+            past_key_values is not None
+            and self.layer_idx is not None
+            and isinstance(self.layer_idx, int)
+            and self.layer_idx >= 0
+            and len(past_key_values) > self.layer_idx
+        ):
+            last_state = past_key_values[self.layer_idx]
```
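The guard can be factored into a small helper; a sketch with illustrative names (not fla's API), using a plain list as a stand-in for the cache:

```python
def get_last_state(past_key_values, layer_idx):
    # Guarded cache access: return None instead of raising when the cache
    # is absent, layer_idx is None, or layer_idx is out of range.
    if past_key_values is None or layer_idx is None:
        return None
    if not (0 <= layer_idx < len(past_key_values)):
        return None
    return past_key_values[layer_idx]
```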
🧹 Nitpick comments (8)
fla/models/kda/configuration_kda.py (2)
10-10: Annotate mutable class attribute with `ClassVar`.

The `keys_to_ignore_at_inference` list is a mutable class attribute and should be annotated with `typing.ClassVar` to make the intent explicit and improve type safety.

Apply this diff:

```diff
-from typing import Dict, Optional
+from typing import ClassVar, Dict, Optional
@@
-    keys_to_ignore_at_inference = ['past_key_values']
+    keys_to_ignore_at_inference: ClassVar[list[str]] = ['past_key_values']
```
69-79: Consider extracting error messages to constants (optional).

The validation logic is correct. For improved maintainability, you could extract the error messages to module-level constants, though this is a minor style point.

Example refactor (optional):

```python
_ERR_ATTN_NOT_DICT = "attn must be a dictionary"
_ERR_ATTN_NO_LAYERS = "Layer indices must be provided to initialize hybrid attention layers"
_ERR_ATTN_NO_HEADS = "Number of heads must be provided to initialize hybrid attention layers"

# Then in __init__:
if not isinstance(attn, Dict):
    raise ValueError(_ERR_ATTN_NOT_DICT)
```

fla/ops/kda/fused_recurrent.py (2)
21-23: Drop unused kwargs from signature.

kwargs is not consumed. Removing it tightens the public API.

```diff
-    cu_seqlens: Optional[torch.LongTensor] = None,
-    **kwargs
+    cu_seqlens: Optional[torch.LongTensor] = None,
```
4-7: Modernize type hints (PEP 604) and ensure runtime safety.

Use `T | None` and add the future import to keep runtime compatibility.

```diff
+# from __future__ import annotations
@@
-from typing import Optional, Tuple
+from typing import Tuple
@@
-    beta: torch.Tensor = None,
-    scale: float = None,
-    initial_state: torch.Tensor = None,
+    beta: torch.Tensor | None = None,
+    scale: float | None = None,
+    initial_state: torch.Tensor | None = None,
@@
-    cu_seqlens: Optional[torch.LongTensor] = None,
+    cu_seqlens: torch.LongTensor | None = None,
```

Also applies to: 15-21
88-98: Condense exception messages (TRY003).

Shorten or factor long f-strings to reduce noise.

```diff
-    if q.shape[0] != 1:
-        raise ValueError(
-            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
-            f"Please flatten variable-length inputs before processing."
-        )
+    if q.shape[0] != 1:
+        raise ValueError(f"B must be 1 when using cu_seqlens; got {q.shape[0]}.")
@@
-    if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
-        raise ValueError(
-            f"The number of initial states is expected to be equal to the number of input sequences, "
-            f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
-        )
+    if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+        raise ValueError(
+            f"initial_state.shape[0] ({initial_state.shape[0]}) must equal len(cu_seqlens)-1 ({len(cu_seqlens) - 1})."
+        )
```

fla/layers/kda.py (3)
80-87: Use PEP 604 optionals in init signature.

Avoid implicit Optional and match Ruff guidance.

```diff
-        num_v_heads: int = None,
+        num_v_heads: int | None = None,
@@
-        layer_idx: int = None,
+        layer_idx: int | None = None,
```
174-175: kwargs typing and usage.

Unpack[Dict] is unusual here and triggers tooling noise; kwargs is used only for cu_seqlens. Prefer a simpler annotation or none.

```diff
-        **kwargs: Unpack[Dict]
+        **kwargs,
@@
-        cu_seqlens = kwargs.get('cu_seqlens', None)
+        cu_seqlens = kwargs.get('cu_seqlens')
```

Also applies to: 193-196
185-188: Make the decode threshold a constant (optional).

q_len <= 64 is embedded magic. Consider a class attribute (e.g., self.decode_threshold = 64) to make tuning easier.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- fla/layers/kda.py (1 hunks)
- fla/models/kda/configuration_kda.py (1 hunks)
- fla/ops/kda/fused_recurrent.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
fla/ops/kda/fused_recurrent.py (2)
- fla/ops/common/fused_recurrent.py (1): fused_recurrent (540-569)
- fla/ops/gated_delta_rule/fused_recurrent.py (1): fused_recurrent_gated_delta_rule (242-353)

fla/layers/kda.py (6)
- fla/layers/utils.py (2): get_unpad_data (75-98), pad_input (176-197)
- fla/modules/fused_norm_gate.py (1): FusedRMSNormGated (997-1058)
- fla/modules/convolution.py (1): ShortConvolution (794-1009)
- fla/ops/kda/chunk.py (2): chunk_kda (250-351), forward (186-218)
- fla/ops/kda/fused_recurrent.py (1): fused_recurrent_kda (11-114)
- fla/ops/kda/gate.py (2): fused_kda_gate (329-349), forward (304-314)
🪛 Ruff (0.14.1)
fla/ops/kda/fused_recurrent.py
17-17: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
22-22: Unused function argument: kwargs
(ARG001)
90-93: Avoid specifying long messages outside the exception class
(TRY003)
95-98: Avoid specifying long messages outside the exception class
(TRY003)
fla/layers/kda.py
79-79: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
85-85: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
87-87: Unused method argument: kwargs
(ARG002)
112-115: Avoid specifying long messages outside the exception class
(TRY003)
117-119: Avoid specifying long messages outside the exception class
(TRY003)
122-125: Avoid specifying long messages outside the exception class
(TRY003)
172-172: Unused method argument: output_attentions
(ARG002)
fla/models/kda/configuration_kda.py
10-10: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
31-31: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
71-71: Avoid specifying long messages outside the exception class
(TRY003)
73-73: Avoid specifying long messages outside the exception class
(TRY003)
75-75: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (2)
fla/models/kda/configuration_kda.py (2)
16-16: LGTM! Type hint corrected. The `expand_v` parameter is now correctly typed as `float`, addressing the previous review feedback.

81-87: LGTM! Parent class initialization is correct and passes the appropriate parameters.
```python
# Consistency check: Ensure expand_v produces integer values
if not math.isclose(self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5):
    raise ValueError(
        f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
        f"Resulting value_dim would be {self.num_v_heads * self.head_dim * expand_v}, which is invalid for nn.Linear."
    )
```
Error message wording: refers to key_dim but checks value_dim.
Clarify message to match the check.
```diff
-            raise ValueError(
-                f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
-                f"Resulting value_dim would be {self.num_v_heads * self.head_dim * expand_v}, which is invalid for nn.Linear."
-            )
+            raise ValueError(
+                f"expand_v={expand_v} must make num_v_heads*head_dim*expand_v an integer "
+                f"(got {self.num_v_heads * self.head_dim * expand_v}); invalid for nn.Linear."
+            )
```

🧰 Tools
🪛 Ruff (0.14.1)
112-115: Avoid specifying long messages outside the exception class
(TRY003)
🪛 Ruff (0.14.1)
112-115: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In fla/layers/kda.py around lines 111 to 116, the raised ValueError message
incorrectly references key_dim while the check is validating the computed
value_dim; update the error text to reference value_dim (or resulting value
dimension) and include the computed numeric result and the expand_v input for
clarity. Change the f-string to mention value_dim and the actual computed value
(self.num_v_heads * self.head_dim * expand_v) and that it must be an integer
suitable for nn.Linear.
```python
if self.num_v_heads > self.num_heads:
    q, k = map(lambda x: repeat(x, '... h d -> ... (h g) d', g=self.num_v_heads // self.num_heads), (q, k))

if self.allow_neg_eigval:
    beta = beta * 2.
```
Bug: g and beta are not expanded when num_v_heads > num_heads (HV>H).
q,k are repeated to HV but g ([B,T,H,K]) and beta ([B,T,H]) are not, causing shape mismatch for fused ops expecting per‑HV decays/betas. Broadcast them to HV.
```diff
         if self.num_v_heads > self.num_heads:
-            q, k = map(lambda x: repeat(x, '... h d -> ... (h g) d', g=self.num_v_heads // self.num_heads), (q, k))
+            expand = self.num_v_heads // self.num_heads
+            q, k = map(lambda x: repeat(x, '... h d -> ... (h g) d', g=expand), (q, k))
+            # match per-HV shapes expected by ops
+            g = repeat(g, '... h k -> ... (h g_) k', g_=expand)
+            beta = repeat(beta, '... h -> ... (h g_)', g_=expand)
```

Also applies to: 251-262
LGTM
Summary by CodeRabbit
New Features
Tests