
Add KDA #621

Merged
zhiyuan1i merged 90 commits into main from kda on Oct 27, 2025

Conversation

yzhangcs (Member) commented Oct 27, 2025

Summary by CodeRabbit

  • New Features

    • Added a KDA model family (config, model, and causal LM) and a new KDA layer with configurable attention modes, gating, short-conv options, and generation caching.
    • Exposed KDA compute kernels and optimized ops with chunked and fused recurrent paths, including varlen support and autotuning.
    • Added a Triton-based benchmark for KDA and related attention variants.
  • Tests

    • Added comprehensive unit tests covering KDA models, ops, fused gate kernels, and end-to-end validation.
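
For orientation, the delta rule with per-channel gating that these kernels accelerate can be sketched as a naive per-timestep recurrence. This is an illustrative numpy reference under assumed tensor layouts ([B, T, H, K] for q/k/g, [B, T, H, V] for v, [B, T, H] for beta), not the repository's code:

```python
import numpy as np

def naive_kda_recurrent(q, k, v, g, beta, scale=None):
    """Gated delta rule, one timestep at a time (illustrative reference)."""
    B, T, H, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5 if scale is None else scale
    S = np.zeros((B, H, K, V), dtype=np.float64)   # recurrent state in high precision
    o = np.empty_like(v, dtype=np.float64)
    for t in range(T):
        S = S * np.exp(g[:, t])[..., None]                     # per-channel decay on key dim
        v_pred = np.einsum('bhk,bhkv->bhv', k[:, t], S)        # state's prediction of v_t
        delta = beta[:, t][..., None] * (v[:, t] - v_pred)     # delta-rule correction
        S = S + np.einsum('bhk,bhv->bhkv', k[:, t], delta)
        o[:, t] = np.einsum('bhk,bhkv->bhv', q[:, t] * scale, S)
    return o, S
```

The chunked kernels compute the same recurrence blockwise; the fused recurrent path evaluates it step by step, which is why the two are cross-checked against a naive reference in the tests.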

coderabbitai bot commented Oct 27, 2025

Warning

Rate limit exceeded

@zhiyuan1i has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 5 minutes and 32 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between eccea7e and 0511a50.

📒 Files selected for processing (4)
  • fla/layers/__init__.py (2 hunks)
  • fla/layers/kda.py (1 hunks)
  • fla/models/kda/configuration_kda.py (1 hunks)
  • fla/models/kda/modeling_kda.py (1 hunks)

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a complete KDA feature: new KDA layer, model family (config, model, causal LM), Triton-backed KDA ops (chunk/fused-recurrent/intra/inter/wy/gate), autograd wrappers, tests, a benchmark, and package exports.

Changes

Cohort / File(s) Summary
KDA Layer
fla/layers/kda.py, fla/layers/__init__.py
New KDA layer class with multi-head gating, optional short convs, chunk/fused_recurrent modes, and state caching; exported via the package `__all__`.
Model family
fla/models/kda/configuration_kda.py, fla/models/kda/modeling_kda.py, fla/models/kda/__init__.py, fla/models/__init__.py
Adds KDAConfig, KDAModel, KDAForCausalLM, KDABlock, KDAPreTrainedModel; registers with Auto* and exposes API.
Chunk ops (autograd + kernels)
fla/ops/kda/chunk.py, fla/ops/kda/chunk_inter.py, fla/ops/kda/chunk_intra.py
Implements chunk_kda autograd Function and wrapper; adds Triton-backed intra/inter kernels and Python launchers for forward/backward with varlen and autotuning.
Fused recurrent & references
fla/ops/kda/fused_recurrent.py, fla/ops/kda/naive.py
Adds fused_recurrent_kda wrapper (calls fused gated delta rule) and naive reference implementations (recurrent & chunked).
WY / recompute kernels
fla/ops/kda/wy_fast.py
Adds recompute/prepare WY Triton kernels and Python wrappers for fast W/U recompute and backward scaffolding.
Gate mechanism
fla/ops/kda/gate.py
Adds gated KDA kernels, reference gate, autograd Function (KDAGateFunction), and fused_kda_gate API.
Ops exports
fla/ops/kda/__init__.py, fla/ops/__init__.py
Exposes chunk_kda, fused_recurrent_kda, fused_kda_gate, and related KDA ops at package level.
GLA chunk autotune update
fla/ops/gla/chunk.py
Expands autotune configs (adds num_stages), small chunk-size calculation tweaks, and docstring/example adjustments.
Minor syntactic
fla/layers/gated_deltanet.py
Trailing comma added to argument lists (no semantic change).
Tests
tests/ops/test_kda.py, tests/models/test_modeling_kda.py
New unit tests covering naive vs optimized paths, varlen, forward/backward, gate numerics, model forward/backward and generation.
Benchmark
benchmarks/ops/benchmark_kda.py
New Triton-based benchmark script exercising providers (gdn,comba,kda,dplr,attn) across sequence lengths and reporting quantile timings.

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant KDA_Layer
    participant Projections
    participant KDA_Op
    participant Triton_Kernels
    participant Output

    Client->>KDA_Layer: forward(hidden_states, attention_mask?, past_key_values?)
    KDA_Layer->>Projections: compute q, k, v, g, beta
    Projections-->>KDA_Layer: q,k,v,g,beta
    alt fused_recurrent path
        KDA_Layer->>KDA_Op: fused_recurrent_kda(q,k,v,g,beta,...)
    else chunked path
        KDA_Layer->>KDA_Op: chunk_kda(q,k,v,g,beta,...)
    end
    KDA_Op->>Triton_Kernels: launch kernels (intra/inter/wy/gate as needed)
    Triton_Kernels-->>KDA_Op: o (and optional final_state)
    KDA_Op-->>KDA_Layer: o, final_state?
    KDA_Layer->>KDA_Layer: output projection, norm, update past_key_values
    KDA_Layer-->>Output: (o, None, updated_past_key_values)
    Output-->>Client: return

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Focus areas:
    • Triton kernels (chunk_intra/inter, wy_fast, gate): indexing, memory bounds, autotune params, varlen branches.
    • Autograd wrappers (ChunkKDAFunction, KDAGateFunction): ctx saves, dtypes, backward math and memory lifetimes.
    • Model integration (modeling_kda.py): past_key_values layout, generate() compatibility, fused loss/prenorm_residual logic.
    • Tests/benchmarks: determinism, tolerances, env-gated behavior (TMA, varlen).

Possibly related PRs

Suggested reviewers

  • yzhangcs

Poem

🐰
I hopped through kernels, triton-lit and bright,
Counting betas, gates, and chunks by night.
From naive loops to fused recurrent song,
I nibbled bugs and hopped the code along.
Now benchmarks hum — the rabbit dances light!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 13.33%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The pull request title "Add KDA" is fully aligned with the changeset's primary objective. The PR introduces comprehensive KDA (Kimi Delta Attention) support across the entire codebase, including a new layer implementation (fla/layers/kda.py), a complete model family (KDAConfig, KDAModel, KDAForCausalLM), operations (chunk, fused_recurrent, gate, naive, wy_fast), benchmarking utilities, and test suites. The title is concise, specific, and avoids vague terminology, making it immediately clear to developers reviewing the history what the primary change entails.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

gemini-code-assist bot (Contributor) commented

Summary of Changes

Hello @yzhangcs, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the Kimi Delta Attention (KDA) mechanism, a new approach to attention that promises improved efficiency. It includes the full implementation of KDA as a PyTorch layer, along with highly optimized Triton kernels for its core operations. The changes also integrate KDA into the Hugging Face Transformers library, making it readily available for model building and experimentation. Comprehensive tests ensure the correctness and performance of the new components.

Highlights

  • New Attention Mechanism: Introduced Kimi Delta Attention (KDA), a novel attention mechanism designed for efficient and scalable sequence processing.
  • Optimized Implementations: Provided highly optimized Triton kernels for KDA's chunk and fused_recurrent modes, covering intra-chunk, inter-chunk, gating, and WY representation computations for enhanced performance.
  • Hugging Face Transformers Integration: Integrated KDA into the Hugging Face Transformers ecosystem with dedicated KDAConfig, KDAModel, and KDAForCausalLM classes, allowing for seamless use within existing model architectures.
  • Comprehensive Testing: Added extensive unit tests for KDA operations and the KDA model, including comparisons against naive implementations and validation for various scenarios like variable-length sequences.
  • Layer and Operation Exposure: Exposed the new KDA layer (fla.layers.kda.KDA) and its core operations (fla.ops.kda.chunk_kda, fla.ops.kda.fused_recurrent_kda) within the framework's API.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review — /gemini review: Performs a code review for the current pull request in its current state.
  • Pull Request Summary — /gemini summary: Provides a summary of the current pull request in its current state.
  • Comment — @gemini-code-assist: Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help — /gemini help: Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

gemini-code-assist bot left a comment

Code Review

This pull request introduces the KDA (Kimi Delta Attention) model, including its layers, operations, model configuration, and benchmarks. The implementation is comprehensive, adding a new attention mechanism to the library.

My review focuses on correctness, consistency, and clarity. I've found a few issues:

  • A logical bug in the benchmark script that would lead to incorrect results.
  • Several inconsistencies in docstrings, type hints, and parameter initializations which could cause confusion.
  • Minor issues like typos and incorrect examples in docstrings.

Overall, the core implementation of KDA seems solid, but the surrounding code and documentation need some polishing. Please see my detailed comments for suggestions.

Comment thread benchmarks/ops/benchmark_kda.py
Comment thread benchmarks/ops/benchmark_kda.py Outdated
Comment thread fla/layers/kda.py
Comment thread fla/layers/kda.py Outdated
Comment thread fla/models/kda/configuration_kda.py Outdated
):
    if isinstance(module, KDA) and next(module.parameters()).device.type != 'meta':
        with torch.no_grad():
            module.A.copy_(nn.init.uniform_(module.A, a=1, b=16).log())
Severity: medium

The A parameter of the KDA module is being initialized here, but it's already initialized within the KDA class's __init__ method. This redundant initialization can be confusing and should be removed to keep the initialization logic in one place.

Comment thread fla/ops/kda/fused_recurrent.py Outdated
Comment thread tests/ops/test_kda.py
g = F.logsigmoid(torch.randn(1, T, H, D, dtype=torch.float))
g = g * (torch.rand_like(g) > mask_p)
beta = torch.rand(1, T, H, dtype=dtype).sigmoid()
h0 = torch.randn((N, H, D, D), dtype=dtype)
Severity: medium

The initial state h0 is created with dtype=dtype, which can be torch.float16. However, recurrent states are typically expected to be in torch.float32 for precision. In other tests like test_chunk, h0 is explicitly created with torch.float32. Please ensure consistency and use torch.float32 for h0 to avoid potential precision issues.

Suggested change:
-    h0 = torch.randn((N, H, D, D), dtype=dtype)
+    h0 = torch.randn((N, H, D, D), dtype=torch.float32)
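
The precision concern is concrete: float16's 10-bit significand means the spacing between representable values grows with magnitude, so a long-lived recurrent state can silently drop small updates. A quick numpy illustration:

```python
import numpy as np

# at magnitude 2048 the float16 ulp is 2, so adding 1.0 rounds back to 2048
state16 = np.float16(2048.0)
updated16 = np.float16(state16 + np.float16(1.0))

# the same update survives in float32
state32 = np.float32(2048.0)
updated32 = state32 + np.float32(1.0)
```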

coderabbitai bot left a comment

Actionable comments posted: 9

🧹 Nitpick comments (24)
fla/ops/kda/chunk_inter.py (1)

161-165: Gradient dtype choices — align with input dtypes

Allocating dq/dk as float32 may surprise downstream if bf16/half gradients are expected. Recommend matching input dtypes (kernel still accumulates in fp32 and casts on store).

-    dq = torch.empty_like(q, dtype=torch.float)
-    dk = torch.empty_like(k, dtype=torch.float)
-    dw = torch.empty_like(w)
-    dg = torch.empty_like(g)
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)
+    dw = torch.empty_like(w)
+    dg = torch.empty_like(g)

Confirm tests expect dq/dk to match q/k dtypes; if a global policy prefers fp32 grads, we can keep current behavior and cast at call sites.

fla/models/kda/configuration_kda.py (3)

9-11: Type class attributes as ClassVar

Improves clarity and satisfies linters.

-from typing import Dict, Optional
+from typing import Dict, Optional, ClassVar, List
@@
-class KDAConfig(PretrainedConfig):
-    model_type = 'kda'
-    keys_to_ignore_at_inference = ['past_key_values']
+class KDAConfig(PretrainedConfig):
+    model_type: ClassVar[str] = 'kda'
+    keys_to_ignore_at_inference: ClassVar[List[str]] = ['past_key_values']

29-36: Fix Optional typing and defaults

Avoid implicit Optional; also align annotation with default value.

-        pad_token_id: int = None,
+        pad_token_id: Optional[int] = None,

Optionally prefer PEP 604:

-        pad_token_id: Optional[int] = None,
+        pad_token_id: int | None = None,

69-80: Avoid mutating user-supplied attn dict in-place

Copy before filling defaults.

         if attn is not None:
+            attn = dict(attn)  # defensive copy
             if not isinstance(attn, Dict):
                 raise ValueError("attn must be a dictionary")
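
The in-place mutation hazard is easy to demonstrate in plain Python (the function and key names here are hypothetical, not the config's actual fields):

```python
def apply_defaults(attn=None):
    if attn is not None:
        attn = dict(attn)                     # defensive copy before filling defaults
        attn.setdefault('window_size', None)  # mutates only the copy
    return attn

user_attn = {'num_heads': 8}
cfg = apply_defaults(user_attn)
```

Without the `dict(attn)` copy, the `setdefault` call would write the default back into the caller's dictionary.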
fla/ops/kda/chunk_intra.py (2)

431-437: Optional: tighten BK tiling for consistency

Not a blocker. Consider keeping BK within [16, 256] and power-of-two.

-    BK = max(triton.next_power_of_2(K), 16)
+    BK = max(16, min(256, triton.next_power_of_2(K)))

482-540: Return values and in-place semantics

dq/dk/db/dg are updated in-place but dq/dk are reassigned to *_2 buffers for return. Clarify this in the docstring and ensure callers don’t rely on preallocated tensors being filled.

tests/models/test_modeling_kda.py (1)

39-56: Generation test: add a couple more toggles

Optional: add cases for allow_neg_eigval=True and num_v_heads > num_heads (GVA) to exercise those branches.

fla/ops/kda/fused_recurrent.py (3)

11-23: Type hints: avoid implicit Optional and remove unused kwargs

  • Replace implicit optionals (T = None) with union types (T | None) to satisfy RUF013.
  • Drop the unused **kwargs (ARG001) or rename to **_ if intentional.
-def fused_recurrent_kda(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    g: torch.Tensor,
-    beta: torch.Tensor = None,
-    scale: float = None,
-    initial_state: torch.Tensor = None,
-    output_final_state: bool = False,
-    use_qk_l2norm_in_kernel: bool = False,
-    cu_seqlens: Optional[torch.LongTensor] = None,
-    **kwargs
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def fused_recurrent_kda(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor | None = None,
+    scale: float | None = None,
+    initial_state: torch.Tensor | None = None,
+    output_final_state: bool = False,
+    use_qk_l2norm_in_kernel: bool = False,
+    cu_seqlens: torch.LongTensor | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
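
As background for the RUF013 fixes above: PEP 484 does not treat a bare `None` default as implicitly optional, so the annotation must say so explicitly. A minimal sketch:

```python
from typing import Optional

def scaled(x: float, scale: Optional[float] = None) -> float:
    # `scale: float = None` would violate PEP 484; the None default must be
    # reflected in the annotation (Optional[float], or `float | None` on 3.10+)
    return x * (scale if scale is not None else 1.0)
```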

58-86: Docs: fix example function name and g shape

Examples call a non‑existent fused_gated_recurrent_delta_rule. It should reference this wrapper or the actual callee. Also ensure g matches [B, T, HV, K] per this wrapper.

         >>> g = F.logsigmoid(torch.rand(B, T, HV, K, device='cuda'))
         >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
         >>> h0 = torch.randn(B, HV, K, V, device='cuda')
-        >>> o, ht = fused_gated_recurrent_delta_rule(
-            q, k, v, g, beta,
-            initial_state=h0,
-            output_final_state=True
-        )
+        >>> o, ht = fused_recurrent_kda(
+        ...     q, k, v, g, beta,
+        ...     initial_state=h0,
+        ...     output_final_state=True,
+        ... )
@@
-        >>> o_var, ht_var = fused_gated_recurrent_delta_rule(
-            q, k, v, g, beta,
-            initial_state=h0,
-            output_final_state=True,
-            cu_seqlens=cu_seqlens
-        )
+        >>> o_var, ht_var = fused_recurrent_kda(
+        ...     q, k, v, g, beta,
+        ...     initial_state=h0,
+        ...     output_final_state=True,
+        ...     cu_seqlens=cu_seqlens,
+        ... )
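
The `cu_seqlens` argument in these examples follows the usual varlen convention (as in FlashAttention's varlen API): cumulative token offsets of sequences packed along one dimension. A small sketch of how such offsets are built:

```python
import numpy as np

lengths = [5, 3, 7]                                            # three packed sequences
cu_seqlens = np.concatenate(([0], np.cumsum(lengths))).astype(np.int64)
# tokens of sequence i occupy the half-open range [cu_seqlens[i], cu_seqlens[i + 1])
spans = [(int(cu_seqlens[i]), int(cu_seqlens[i + 1])) for i in range(len(lengths))]
```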

88-101: Minor lint: long exception messages (TRY003)

Optional: store long f-strings in variables before raising to appease TRY003, or keep as-is if TRY rules aren’t enforced CI-wide.

tests/ops/test_kda.py (1)

340-376: Minor test hygiene: avoid retain_graph=True where unnecessary

Both backward passes in test_kda_gate re-run separate graphs; retain_graph=True isn’t needed and can slightly slow tests. Safe to drop it.

benchmarks/ops/benchmark_kda.py (1)

42-51: Benchmark hygiene: fix TMA restore, include 'attn', remove unused var, sort quantiles

  • Preserve original FLA_USE_TMA presence (unset vs set).
  • Add 'attn' to provider lists so its path is exercised.
  • Remove unused p in kda branch.
  • Use ascending quantiles for clarity.
@@
-    original_tma_env = os.environ.get('FLA_USE_TMA', '0')
+    original_tma_env = os.environ.get('FLA_USE_TMA')
@@
-    quantiles = [0.5, 0.2, 0.8]
+    quantiles = [0.2, 0.5, 0.8]
@@
-        line_vals=['gdn', 'comba', 'kda', 'dplr'],
+        line_vals=['gdn', 'comba', 'kda', 'dplr', 'attn'],
@@
-        line_names=['gdn', 'comba', 'kda', 'dplr'],
+        line_names=['gdn', 'comba', 'kda', 'dplr', 'attn'],
@@
-        styles=[('blue', '-'), ('red', '-.'), ('green', '-'), ('orange', '-.'),
-                ('purple', '-'), ('brown', '-.'), ('pink', '-'), ('gray', '-.')],
+        styles=[('blue', '-'), ('red', '-.'), ('green', '-'), ('orange', '-.'),
+                ('purple', '-')],
@@
-        p = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
         v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
@@
-    os.environ['FLA_USE_TMA'] = original_tma_env
+    if original_tma_env is None:
+        os.environ.pop('FLA_USE_TMA', None)
+    else:
+        os.environ['FLA_USE_TMA'] = original_tma_env

Also applies to: 52-55, 86-105, 105-123, 143-145
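
The save/restore pattern recommended above distinguishes "variable was unset" from "variable was set to some value", which plain reassignment cannot. A standalone sketch of the fix:

```python
import os

def with_tma_disabled(run):
    """Temporarily set FLA_USE_TMA, preserving unset vs set on restore."""
    original = os.environ.get('FLA_USE_TMA')      # None <=> variable was unset
    os.environ['FLA_USE_TMA'] = '0'
    try:
        return run()
    finally:
        if original is None:
            os.environ.pop('FLA_USE_TMA', None)   # restore the unset state
        else:
            os.environ['FLA_USE_TMA'] = original
```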

fla/ops/kda/naive.py (1)

64-78: Reduce transient allocations in naive_chunk_kda

Cloning and allocating A inside loops is heavy. Using in-place ops where safe and preallocating outside inner loops can reduce overhead without changing semantics.

Also applies to: 82-100

fla/ops/kda/chunk.py (1)

249-262: Type hints and unused kwargs cleanup (Ruff RUF013, ARG001)

Prefer explicit Optional typing and drop unused kwargs to quiet linters and clarify API.

 @torch.compiler.disable
 def chunk_kda(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     g: torch.Tensor,
     beta: torch.Tensor,
-    scale: float = None,
-    initial_state: torch.Tensor = None,
+    scale: Optional[float] = None,
+    initial_state: Optional[torch.Tensor] = None,
     output_final_state: bool = False,
     use_qk_l2norm_in_kernel: bool = False,
     cu_seqlens: Optional[torch.LongTensor] = None,
-    **kwargs
 ):
@@
     if scale is None:
         scale = k.shape[-1] ** -0.5

Also applies to: 326-339

fla/layers/kda.py (1)

77-93: Constructor typing and message polish

Use Optional typing for None defaults and fix assertion message typo.

-        num_v_heads: int = None,
+        num_v_heads: Optional[int] = None,
@@
-        layer_idx: int = None,
+        layer_idx: Optional[int] = None,
@@
-        assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
+        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."

Also applies to: 114-130

fla/ops/kda/wy_fast.py (1)

227-234: Varlen NT calculation with chunk_indices length: confirm shape

len(chunk_indices) assumes shape [NT, 2] from prepare_chunk_indices; OK if stable. Add a brief assert to guard against accidental shape regressions.

-    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+    NT = triton.cdiv(T, BT) if cu_seqlens is None else chunk_indices.shape[0]
fla/ops/kda/gate.py (2)

237-247: Keep output dtype consistent with input g (avoid forced fp32 allocation)

y is allocated in fp32, which may force extra casts and memory. Prefer g.dtype for output tensor and let kernels upcast internally as needed.

-    y = torch.empty_like(g, dtype=torch.float32)
+    y = torch.empty_like(g)

If you want fp32 compute but original dtype output, keep current internal casts and only change the allocation.


18-55: Reference path: ensure output dtype mirrors input

kda_gate_ref casts to float32 for math and returns float32, which can diverge from fused path if you adopt the dtype change above. Return to g.dtype for parity.

-    A_exp = -A.float().exp().unsqueeze(-1)  # [H, 1]
-    g_softplus = F.softplus(g.float(), beta, threshold)      # [..., H, D]
-
-    return A_exp * g_softplus
+    A_exp = -A.float().exp().unsqueeze(-1)                   # [H, 1]
+    g_softplus = F.softplus(g.float(), beta, threshold)      # [..., H, D]
+    out = A_exp * g_softplus
+    return out.to(g.dtype)
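
In pure numpy, the reference gate math discussed here (a negative per-head exponential magnitude times a thresholded softplus) looks roughly like this; shapes and defaults are assumptions based on the diff above:

```python
import numpy as np

def softplus(x, beta=1.0, threshold=20.0):
    # mirrors torch.nn.functional.softplus: linear above the threshold for stability
    scaled = beta * x
    return np.where(scaled > threshold, x,
                    np.log1p(np.exp(np.minimum(scaled, threshold))) / beta)

def kda_gate_ref_np(g, A, beta=1.0, threshold=20.0):
    # g: [..., H, D] gate pre-activations; A: [H] learned per-head log-magnitude
    # output is always <= 0, i.e. a valid log-space decay factor
    return -np.exp(A)[:, None] * softplus(g, beta, threshold)
```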
fla/models/kda/modeling_kda.py (6)

115-121: Annotate mutable class attributes with ClassVar

Silence RUF012 and clarify intent.

-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, ClassVar
...
-    config_class = KDAConfig
-    base_model_prefix = 'model'
-    supports_gradient_checkpointing = True
-    _no_split_modules = ['KDABlock']
-    _supports_cache_class = True
+    config_class: ClassVar = KDAConfig
+    base_model_prefix: ClassVar[str] = 'model'
+    supports_gradient_checkpointing: ClassVar[bool] = True
+    _no_split_modules: ClassVar[List[str]] = ['KDABlock']
+    _supports_cache_class: ClassVar[bool] = True

Also apply similarly for KDAForCausalLM._tied_weights_keys.
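
A minimal illustration of the `ClassVar` convention being suggested (the class here is a hypothetical stand-in, not the actual config):

```python
from typing import ClassVar, List

class ConfigSketch:
    # ClassVar marks these as shared class-level state rather than per-instance
    # fields, which is what Ruff's RUF012 rule asks for on mutable class attributes
    model_type: ClassVar[str] = 'kda'
    keys_to_ignore_at_inference: ClassVar[List[str]] = ['past_key_values']
```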


210-214: warnings.warn without stacklevel

Add stacklevel=2 to point at the caller.

-            warnings.warn("`KDAModel` does not `output_attentions` now, setting it to `False`.")
+            warnings.warn("`KDAModel` does not `output_attentions` now, setting it to `False`.", stacklevel=2)
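
Why `stacklevel=2` matters: it attributes the warning to the caller's line rather than to the library internals. A runnable sketch (the function name is hypothetical):

```python
import warnings

def forward_sketch(output_attentions=False):
    if output_attentions:
        # stacklevel=2 points the warning at whoever called forward_sketch
        warnings.warn("`output_attentions` is not supported; ignoring.", stacklevel=2)
    return None
```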

218-223: Long ValueError messages: acceptable but consider brevity

Static analyzer flags TRY003. Optional: shorten or extract to constant.


312-326: Prefer 'raise ... from None' and bare 'raise' to preserve trace

In generate(), chain the custom AttributeError with from None; re-raise with bare raise.

-        except AttributeError as exception:
+        except AttributeError as exception:
             if 'past_key_values' in str(exception):
-                raise AttributeError(
+                raise AttributeError(
                     f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                     f"which is not supported for {self.__class__.__name__}. "
                     f"Try another generation strategy instead. "
                     f"For the available generation strategies, check this doc: "
                     f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-                )
+                ) from None
             else:
-                raise exception
+                raise
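
The exception-chaining behavior being requested can be checked directly: `from None` suppresses the implicit context, while a bare `raise` inside `except` re-raises the original with its traceback intact. A small sketch with hypothetical names:

```python
def get_required(cfg, key):
    try:
        return cfg[key]
    except KeyError:
        # `from None` deliberately hides the chained KeyError context
        raise AttributeError(f"missing required option {key!r}") from None

try:
    get_required({}, 'past_key_values')
except AttributeError as exc:
    caught = exc
```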

280-289: _tied_weights_keys should be ClassVar

Align with RUF012.

-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys: ClassVar[List[str]] = ["lm_head.weight"]

384-386: Prefer tuple splat to concatenation

Minor ergonomics per RUF005.

-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
+            output = (logits, *outputs[1:])
+            return (loss, *output) if loss is not None else output
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 38cc619 and 26fe007.

📒 Files selected for processing (20)
  • benchmarks/ops/benchmark_kda.py (1 hunks)
  • fla/layers/__init__.py (2 hunks)
  • fla/layers/gated_deltanet.py (2 hunks)
  • fla/layers/kda.py (1 hunks)
  • fla/models/__init__.py (2 hunks)
  • fla/models/kda/__init__.py (1 hunks)
  • fla/models/kda/configuration_kda.py (1 hunks)
  • fla/models/kda/modeling_kda.py (1 hunks)
  • fla/ops/__init__.py (2 hunks)
  • fla/ops/gla/chunk.py (7 hunks)
  • fla/ops/kda/__init__.py (1 hunks)
  • fla/ops/kda/chunk.py (1 hunks)
  • fla/ops/kda/chunk_inter.py (1 hunks)
  • fla/ops/kda/chunk_intra.py (1 hunks)
  • fla/ops/kda/fused_recurrent.py (1 hunks)
  • fla/ops/kda/gate.py (1 hunks)
  • fla/ops/kda/naive.py (1 hunks)
  • fla/ops/kda/wy_fast.py (1 hunks)
  • tests/models/test_modeling_kda.py (1 hunks)
  • tests/ops/test_kda.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (16)
fla/models/__init__.py (2)
fla/models/kda/configuration_kda.py (1)
  • KDAConfig (8-87)
fla/models/kda/modeling_kda.py (2)
  • KDAForCausalLM (280-394)
  • KDAModel (177-277)
fla/layers/__init__.py (1)
fla/layers/kda.py (1)
  • KDA (25-283)
fla/ops/kda/__init__.py (2)
fla/ops/kda/chunk.py (1)
  • chunk_kda (250-351)
fla/ops/kda/fused_recurrent.py (1)
  • fused_recurrent_kda (11-114)
fla/ops/kda/fused_recurrent.py (2)
fla/ops/common/fused_recurrent.py (1)
  • fused_recurrent (540-569)
fla/ops/gated_delta_rule/fused_recurrent.py (1)
  • fused_recurrent_gated_delta_rule (242-353)
fla/models/kda/__init__.py (2)
fla/models/kda/configuration_kda.py (1)
  • KDAConfig (8-87)
fla/models/kda/modeling_kda.py (2)
  • KDAForCausalLM (280-394)
  • KDAModel (177-277)
fla/ops/kda/chunk_intra.py (3)
fla/ops/utils/cumsum.py (1)
  • chunk_local_cumsum (432-469)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (116-121)
fla/ops/utils/solve_tril.py (1)
  • solve_tril (337-384)
fla/layers/kda.py (6)
fla/layers/utils.py (2)
  • get_unpad_data (75-98)
  • pad_input (176-197)
fla/modules/fused_norm_gate.py (1)
  • FusedRMSNormGated (997-1058)
fla/modules/convolution.py (1)
  • ShortConvolution (794-1009)
fla/ops/kda/chunk.py (2)
  • chunk_kda (250-351)
  • forward (186-218)
fla/ops/kda/fused_recurrent.py (1)
  • fused_recurrent_kda (11-114)
fla/ops/kda/gate.py (2)
  • fused_kda_gate (329-349)
  • forward (304-314)
fla/ops/kda/chunk.py (9)
fla/modules/l2norm.py (3)
  • l2norm (266-271)
  • l2norm_bwd (195-239)
  • l2norm_fwd (149-192)
fla/ops/common/chunk_delta_h.py (2)
  • chunk_gated_delta_rule_bwd_dhu (484-535)
  • chunk_gated_delta_rule_fwd_h (435-481)
fla/ops/common/chunk_o.py (1)
  • chunk_bwd_dv_local (582-628)
fla/ops/gla/chunk.py (2)
  • chunk_gla_bwd_dA (879-908)
  • chunk_gla_fwd_o_gk (842-876)
fla/ops/kda/chunk_inter.py (1)
  • chunk_kda_bwd_dqkwg (141-188)
fla/ops/kda/chunk_intra.py (2)
  • chunk_kda_bwd_intra (482-540)
  • chunk_kda_fwd_intra (389-479)
fla/ops/kda/wy_fast.py (2)
  • prepare_wy_repr_bwd (258-304)
  • recompute_w_u_fwd (213-255)
fla/ops/utils/cumsum.py (1)
  • chunk_local_cumsum (432-469)
fla/utils.py (1)
  • input_guard (135-166)
fla/ops/kda/gate.py (3)
fla/utils.py (1)
  • input_guard (135-166)
fla/layers/kda.py (1)
  • forward (170-283)
fla/ops/kda/chunk.py (2)
  • forward (186-218)
  • backward (223-246)
tests/models/test_modeling_kda.py (2)
fla/models/kda/configuration_kda.py (1)
  • KDAConfig (8-87)
tests/models/test_modeling_base.py (2)
  • run_test_generation (67-126)
  • run_test_model_forward_backward (27-61)
tests/ops/test_kda.py (5)
fla/ops/kda/chunk.py (2)
  • chunk_kda (250-351)
  • backward (223-246)
fla/ops/kda/fused_recurrent.py (1)
  • fused_recurrent_kda (11-114)
fla/ops/kda/gate.py (3)
  • fused_kda_gate (329-349)
  • kda_gate_ref (18-54)
  • backward (318-326)
fla/ops/kda/naive.py (2)
  • naive_chunk_kda (41-102)
  • naive_recurrent_kda (9-38)
fla/utils.py (1)
  • assert_close (82-93)
fla/ops/kda/chunk_inter.py (2)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (116-121)
fla/utils.py (1)
  • check_shared_mem (445-451)
fla/ops/__init__.py (2)
fla/ops/kda/chunk.py (1)
  • chunk_kda (250-351)
fla/ops/kda/fused_recurrent.py (1)
  • fused_recurrent_kda (11-114)
benchmarks/ops/benchmark_kda.py (4)
fla/ops/comba/chunk.py (1)
  • chunk_comba (235-342)
fla/ops/gated_delta_rule/chunk.py (1)
  • chunk_gated_delta_rule (219-325)
fla/ops/generalized_delta_rule/dplr/chunk.py (1)
  • chunk_dplr_delta_rule (267-359)
fla/ops/kda/chunk.py (2)
  • chunk_kda (250-351)
  • backward (223-246)
fla/ops/kda/wy_fast.py (2)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (116-121)
fla/utils.py (1)
  • check_shared_mem (445-451)
fla/models/kda/modeling_kda.py (7)
fla/layers/attn.py (1)
  • Attention (33-174)
fla/layers/kda.py (2)
  • KDA (25-283)
  • forward (170-283)
fla/models/kda/configuration_kda.py (1)
  • KDAConfig (8-87)
fla/modules/fused_cross_entropy.py (1)
  • FusedCrossEntropyLoss (344-419)
fla/modules/fused_linear_cross_entropy.py (1)
  • FusedLinearCrossEntropyLoss (493-567)
fla/modules/mlp.py (1)
  • GatedMLP (26-69)
fla/modules/layernorm.py (1)
  • RMSNorm (1064-1111)
🪛 Ruff (0.14.1)
fla/ops/kda/fused_recurrent.py

17-17: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


22-22: Unused function argument: kwargs

(ARG001)


90-93: Avoid specifying long messages outside the exception class

(TRY003)


95-98: Avoid specifying long messages outside the exception class

(TRY003)

fla/layers/kda.py

83-83: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


89-89: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


91-91: Unused method argument: kwargs

(ARG002)


116-119: Avoid specifying long messages outside the exception class

(TRY003)


121-123: Avoid specifying long messages outside the exception class

(TRY003)


126-129: Avoid specifying long messages outside the exception class

(TRY003)


176-176: Unused method argument: output_attentions

(ARG002)

fla/ops/kda/chunk.py

256-256: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


261-261: Unused function argument: kwargs

(ARG001)


328-331: Avoid specifying long messages outside the exception class

(TRY003)


333-336: Avoid specifying long messages outside the exception class

(TRY003)

fla/models/kda/configuration_kda.py

10-10: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


31-31: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


71-71: Avoid specifying long messages outside the exception class

(TRY003)


73-73: Avoid specifying long messages outside the exception class

(TRY003)


75-75: Avoid specifying long messages outside the exception class

(TRY003)

fla/models/kda/modeling_kda.py

118-118: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


174-174: Avoid specifying long messages outside the exception class

(TRY003)


201-201: Unused blanket noqa directive

Remove unused noqa directive

(RUF100)


211-211: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


220-220: Avoid specifying long messages outside the exception class

(TRY003)


222-222: Avoid specifying long messages outside the exception class

(TRY003)


282-282: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


317-323: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


317-323: Avoid specifying long messages outside the exception class

(TRY003)


325-325: Use raise without specifying exception name

Remove exception name

(TRY201)


385-385: Consider (logits, *outputs[1:]) instead of concatenation

Replace with (logits, *outputs[1:])

(RUF005)


386-386: Consider (loss, *output) instead of concatenation

Replace with (loss, *output)

(RUF005)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (20)
fla/ops/gla/chunk.py (7)

100-102: LGTM: Autotune configuration expansion.

The addition of num_stages parameter to the autotune configuration appropriately expands the search space for performance tuning. The range [2, 3] is conservative and suitable for this kernel.


287-291: LGTM: Autotune configuration expansion.

The addition of num_stages in [2, 3, 4] appropriately expands the autotune search space. The broader range is justified given this kernel's complexity with multiple block size parameters (BK, BV).


367-369: LGTM: Autotune configuration expansion.

The autotune configuration expansion with num_stages in [2, 3, 4] is consistent with similar kernels and appropriate for this backward pass kernel.


503-505: LGTM: Autotune configuration expansion.

The autotune configuration expansion is consistent with other backward pass kernels in this file.


554-558: LGTM: Autotune configuration expansion.

The autotune configuration expansion is appropriate for this kernel. The use of BK_LIST and BV_LIST with num_stages in [2, 3, 4] provides comprehensive tuning coverage.


1192-1192: LGTM: Dynamic chunk size calculation.

The chunk size calculation now adapts to the sequence length while maintaining safe bounds (16-64). The logic correctly handles edge cases:

  • Small sequences: bounded by minimum of 16
  • Large sequences: bounded by maximum of 64
  • Variable-length inputs: uses total concatenated length appropriately
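The clamping rule above can be sketched in plain Python (a hedged illustration; `next_power_of_2` here is a stand-in for `triton.next_power_of_2`):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, mirroring triton.next_power_of_2.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def chunk_size_for(T: int) -> int:
    # Adapt the chunk size to the (possibly concatenated) sequence length,
    # clamped to the safe bounds [16, 64].
    return min(64, max(16, next_power_of_2(T)))
```

Small `T` (e.g. 5) hits the 16 floor; large `T` saturates at 64.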

1301-1301: LGTM: Documentation consistency fix.

The docstring example now correctly uses the return variable names that match the actual function signature.

fla/ops/kda/chunk_inter.py (2)

98-106: Clarify expected shape/stride for h

Kernel treats h as (V, K) with strides (1, V) after offsetting by (i_tg, i_h). Please document the expected layout (e.g., h shaped [NT, H, V, K] and contiguous) to prevent subtle misalignment.
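For reference, row-major strides of a contiguous tensor can be computed with a generic sketch (element units, innermost stride 1; this is not tied to the kernel's actual layout):

```python
def contiguous_strides(shape: tuple) -> tuple:
    # C-contiguous (row-major) strides in elements: the innermost dim has
    # stride 1; each outer stride is the product of all inner extents.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)
```

For `h` shaped `[NT, H, V, K]` and contiguous, the `(V, K)` tile would have strides `(K, 1)`; observed strides of `(1, V)` would instead indicate a `[..., K, V]` layout, which is worth documenting either way.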


118-118: Sign of dw

dw stores -b_dw. If w is defined as exp(g)·v-projected weights, the negative may be intentional. Please confirm sign convention against the forward definition to avoid flipped gradients.

fla/ops/kda/chunk_intra.py (1)

224-231: Varlen path indexing sanity check

Kernel uses all = B*T (global T) for db layout, then overrides T per sequence. Likely intentional given db2 shape [NK, B, T, H], but please confirm no overflow for mixed-length batches.

fla/ops/__init__.py (1)

19-19: Public API surfacing for KDA — verified

Both chunk_kda and fused_recurrent_kda are properly defined, re-exported from fla/ops/kda/__init__.py, and correctly imported into the main fla/ops namespace. The public API is properly exposed with no issues.

fla/layers/gated_deltanet.py (1)

284-285: Trailing comma additions confirmed valid

Both chunk_gated_delta_rule (fla/ops/gated_delta_rule/chunk.py:228) and fused_recurrent_gated_delta_rule (fla/ops/gated_delta_rule/fused_recurrent.py:253) define use_qk_l2norm_in_kernel: bool = False, so the kwarg is properly supported at the call sites. No runtime issues.

fla/layers/__init__.py (1)

18-18: Expose KDA in layers: LGTM

Import and all update look correct; consistent with other layer exports.

Also applies to: 49-49

fla/models/__init__.py (1)

19-19: Public model API updated: LGTM

KDAConfig/KDAModel/KDAForCausalLM are properly re-exported.

Also applies to: 50-50

fla/ops/kda/__init__.py (1)

1-7: Clean ops surface: LGTM

Exporting chunk_kda and fused_recurrent_kda via all is straightforward and consistent.

tests/models/test_modeling_kda.py (1)

14-34: Forward/backward test parameters: LGTM

Covers bf16 with/without l2warp and validates varlen path via run_test_model_forward_backward.

fla/ops/kda/fused_recurrent.py (1)

102-113: Shape contract: confirm g produced upstream is [B, T, HV, K]

You pass gk=g. Ensure fused_kda_gate in fla/ops/kda/gate.py returns [B, T, HV, K] to match fused_recurrent_gated_delta_rule’s gk requirement. Otherwise, add a lightweight assertion on g.dim() and trailing sizes.
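A lightweight shape guard along these lines would catch the mismatch early (a hypothetical helper operating on shape tuples only, not fla's API):

```python
def check_gk_shape(g_shape: tuple, k_last_dim: int, HV: int) -> None:
    # gk must carry per-K decays: [B, T, HV, K] with K matching k's last dim.
    if len(g_shape) != 4 or g_shape[2] != HV or g_shape[3] != k_last_dim:
        raise ValueError(
            f"expected g of shape [B, T, {HV}, {k_last_dim}], got {g_shape}"
        )
```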

fla/models/kda/__init__.py (1)

8-10: Review comment requires no action

The exist_ok parameter is fully supported in transformers 4.53.0 across AutoConfig.register(), AutoModel.register(), and AutoModelForCausalLM.register(). The code at lines 8-10 is correct and will not produce import-time errors.

fla/ops/kda/chunk.py (1)

183-219: Autograd wiring LGTM

Forward/backward save/restore and dtype handling are correct; L2-norm reversal is applied conditionally and safely.

Also applies to: 220-246

fla/ops/kda/wy_fast.py (1)

21-26: is_tf32_supported is a boolean, not a function — no change needed

The is_tf32_supported variable in fla/utils.py (line 399) is a boolean value, not a function. It is computed once at module load time as:

is_tf32_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8)

The usage in fla/ops/kda/wy_fast.py (line 25) is correct—it uses this boolean directly in the conditional expression without calling it. No refactoring is needed.

Likely an incorrect or invalid review comment.

Comment thread fla/layers/kda.py
Comment on lines +156 to +160
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)


⚠️ Potential issue | 🔴 Critical

Fix BT/indices mismatch and missing default scale (runtime bugs)

  • When cu_seqlens is used, chunk_indices must be built with BT, not chunk_size.
  • scale can be None but is multiplied in-kernel, causing a failure.

Apply:

-    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
-
-    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
+    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
🤖 Prompt for AI Agents
In fla/ops/kda/chunk_inter.py around lines 156 to 160, there are two runtime
bugs: when cu_seqlens is provided you must compute BT first and pass BT (not
chunk_size) into prepare_chunk_indices so chunk_indices aligns with the actual
block size, and scale may be None but is later multiplied inside the kernel —
set a default scale = k.shape[-1] ** -0.5 (matching the suggested diff) before
launching the kernel so multiplication never receives None; update the call to
prepare_chunk_indices to use BT and ensure scale is normalized to a numeric
value prior to kernel launch.

Comment thread fla/ops/kda/chunk_intra.py
Comment on lines +425 to +434
B, T, H, K = k.shape
assert K <= 256
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

BC = min(16, BT)
NC = triton.cdiv(BT, BC)
BK = max(triton.next_power_of_2(K), 16)


⚠️ Potential issue | 🔴 Critical

Default scale to 1/sqrt(K) when None

Prevent None multiplication in kernels and match other KDA paths.

-    BC = min(16, BT)
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    BC = min(16, BT)
🤖 Prompt for AI Agents
In fla/ops/kda/chunk_intra.py around lines 425 to 434, the variable scale can be
None and later gets multiplied in kernels which causes None propagation; set a
default numeric scale equal to 1/sqrt(K) when scale is None (compute using K
from k.shape) so subsequent math and kernel calls always receive a float;
implement this normalization immediately after K is derived to avoid None
multiplication and to match other KDA code paths.
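The normalization the comment asks for is a one-liner; a minimal sketch:

```python
def resolve_scale(scale, K: int) -> float:
    # Default to 1/sqrt(K) when the caller passes None, so kernels never
    # multiply by None; explicit values pass through unchanged.
    return K ** -0.5 if scale is None else float(scale)
```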

Comment on lines +25 to +51
Args:
q (torch.Tensor):
queries of shape `[B, T, H, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]`.
v (torch.Tensor):
values of shape `[B, T, HV, V]`.
GVA is applied if `HV > H`.
g (torch.Tensor):
g (decays) of shape `[B, T, HV]`.
beta (torch.Tensor):
betas of shape `[B, T, HV]`.
scale (Optional[float]):
Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, HV, K, V]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, HV, K, V]`. Default: `False`.
use_qk_l2norm_in_kernel (Optional[bool]):
Whether to use L2 normalization in the kernel. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.


⚠️ Potential issue | 🟡 Minor

Doc param shapes: clarify that g here is passed as gk (per‑K decay)

This wrapper forwards g as gk=..., implying shape [B, T, HV, K], not [B, T, HV]. Please update the arg docs to reflect this and avoid confusion with the base API where g (no K) and gk are distinct.

Comment thread fla/ops/kda/gate.py
Comment on lines +316 to +326
@input_guard
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None, None, None]:
g, A = ctx.saved_tensors
head_k_dim = ctx.head_k_dim
beta = ctx.beta
threshold = ctx.threshold
g_bias = ctx.g_bias

grad_g, grad_A, grad_gbias = kda_gate_bwd(grad_output, g, A, head_k_dim, g_bias, beta, threshold)
return grad_g, grad_A, None, grad_gbias, None, None

⚠️ Potential issue | 🔴 Critical

Same decorator-order bug in backward

Swap to keep backward a true staticmethod.

-    @input_guard
-    @staticmethod
+    @staticmethod
+    @input_guard
     def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None, None, None]:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@input_guard
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None, None, None]:
g, A = ctx.saved_tensors
head_k_dim = ctx.head_k_dim
beta = ctx.beta
threshold = ctx.threshold
g_bias = ctx.g_bias
grad_g, grad_A, grad_gbias = kda_gate_bwd(grad_output, g, A, head_k_dim, g_bias, beta, threshold)
return grad_g, grad_A, None, grad_gbias, None, None
@staticmethod
@input_guard
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None, None, None]:
g, A = ctx.saved_tensors
head_k_dim = ctx.head_k_dim
beta = ctx.beta
threshold = ctx.threshold
g_bias = ctx.g_bias
grad_g, grad_A, grad_gbias = kda_gate_bwd(grad_output, g, A, head_k_dim, g_bias, beta, threshold)
return grad_g, grad_A, None, grad_gbias, None, None
🤖 Prompt for AI Agents
In fla/ops/kda/gate.py around lines 316 to 326, the decorators for backward are
in the wrong order (input_guard is applied before staticmethod) which prevents
backward from being a true staticmethod; swap the two decorators so
@staticmethod appears first and @input_guard second, preserving the existing
method signature and return types.
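The ordering matters because decorators apply bottom-up; a toy reproduction (with a stand-in `input_guard`, not fla's real one):

```python
def input_guard(fn):
    # Stand-in for fla's input_guard: a plain function-wrapping decorator.
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

class Op:
    @staticmethod       # outermost: the wrapped function becomes a staticmethod
    @input_guard        # innermost: applied to the raw function first
    def backward(ctx, grad):
        return grad * 2
```

With the order reversed, `input_guard` receives a staticmethod object rather than a function; before Python 3.10 that object is not callable inside `wrapper`, so the call raises `TypeError`.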

Comment thread fla/ops/kda/wy_fast.py
Comment on lines +269 to +275
BT = 64
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
CONST_TILING = 64 if check_shared_mem() else 32
BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)


⚠️ Potential issue | 🔴 Critical

Hard-coded BT=64 in backward wrapper will desync from forward BT

prepare_wy_repr_bwd sets BT=64 regardless of A.shape[-1]. If forward used a different BT (e.g., 32/128), backward tiling and block_ptr strides become inconsistent, causing wrong results or OOB.

-    BT = 64
+    BT = A.shape[-1]
🤖 Prompt for AI Agents
In fla/ops/kda/wy_fast.py around lines 269 to 275, the backward wrapper
hard-codes BT = 64 which can mismatch the forward tiling and lead to incorrect
strides/OOB; change the logic to derive BT from the actual block size used by
forward (e.g., accept BT as an argument to the backward function or compute BT =
A.shape[-1] or the forward-provided tiling value), then use that BT for
prepare_chunk_indices and NT computation (NT = triton.cdiv(T, BT) if cu_seqlens
is None else len(chunk_indices)); ensure subsequent BK/BV computations still use
CONST_TILING but not overwrite BT so backward tiling and block_ptr strides match
forward.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 26fe007 and 00b1783.

📒 Files selected for processing (1)
  • benchmarks/ops/benchmark_kda.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/ops/benchmark_kda.py (7)
fla/ops/comba/chunk.py (1)
  • chunk_comba (235-342)
fla/ops/gated_delta_rule/chunk.py (1)
  • chunk_gated_delta_rule (219-325)
fla/ops/generalized_delta_rule/dplr/chunk.py (1)
  • chunk_dplr_delta_rule (267-359)
fla/ops/kda/chunk.py (2)
  • chunk_kda (250-351)
  • backward (223-246)
fla/modules/activations.py (1)
  • logsigmoid (201-202)
fla/ops/utils/matmul.py (1)
  • sigmoid (160-162)
fla/ops/kda/gate.py (1)
  • backward (318-326)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (2)
benchmarks/ops/benchmark_kda.py (2)

1-36: LGTM!

The imports and benchmark configuration are well-structured and include all necessary dependencies for the KDA and related attention mechanism benchmarks.


56-139: Past review comments appear resolved.

The conditional logic now uses a proper if/elif chain for provider selection, and the 'attn' provider is correctly listed in line_vals at line 25. The previous issues mentioned in the past review comments have been addressed.

Comment on lines +42 to +50
# Set TMA environment variable based on provider
original_tma_env = os.environ.get('FLA_USE_TMA', '0')

if provider.endswith('_no_tma'):
os.environ['FLA_USE_TMA'] = '0'
provider_base = provider.replace('_no_tma', '')
else:
os.environ['FLA_USE_TMA'] = '1'
provider_base = provider

⚠️ Potential issue | 🟠 Major

Wrap environment variable restoration in try-finally to ensure cleanup.

The TMA environment variable restoration at line 142 won't execute if an exception occurs during benchmarking. This could leave the environment in an inconsistent state for subsequent runs.

Apply this diff to ensure proper cleanup:

     quantiles = [0.5, 0.2, 0.8]
     results = 0, 0, 0
 
-    do = torch.randn(B, T, H, D, dtype=dtype, device=device)
-    if provider_base == 'gdn':
+    try:
+        do = torch.randn(B, T, H, D, dtype=dtype, device=device)
+        if provider_base == 'gdn':
+            ...
+        elif provider_base == 'dplr':
+            ...
+    finally:
+        # Restore original TMA environment variable
+        os.environ['FLA_USE_TMA'] = original_tma_env
+    
-    # Restore original TMA environment variable
-    os.environ['FLA_USE_TMA'] = original_tma_env
     return results

Also applies to: 141-142

🤖 Prompt for AI Agents
In benchmarks/ops/benchmark_kda.py around lines 42 to 50, the code sets
FLA_USE_TMA based on the provider but does not guarantee restoration if an
exception occurs during benchmarking; wrap the benchmarking logic that follows
these lines in a try...finally block and move the restoration of
os.environ['FLA_USE_TMA'] (using the saved original_tma_env) into the finally so
the environment variable is always reset regardless of errors.
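An equivalent, reusable way to guarantee the restoration is a small context manager (a sketch; the variable name matches the benchmark, but the helper itself is hypothetical):

```python
import os
from contextlib import contextmanager

@contextmanager
def scoped_env(name: str, value: str):
    # Set an environment variable for the duration of a block and restore
    # the previous state in `finally`, even if the block raises.
    old = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = old
```

The benchmark body would then run under `with scoped_env('FLA_USE_TMA', '1'): ...` with no explicit cleanup code.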

Comment thread benchmarks/ops/benchmark_kda.py
Comment on lines +104 to +120
elif provider_base == 'kda':
q = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
k = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device)).requires_grad_(True)
beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid().requires_grad_(True)
results = triton.testing.do_bench(
lambda: chunk_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
use_qk_l2norm_in_kernel=True,
)[0].backward(do),
quantiles=quantiles
)

⚠️ Potential issue | 🔴 Critical

Fix critical shape mismatch for g tensor in KDA provider.

Line 108 creates g with shape [B, T, H, D], but chunk_kda expects g to have shape [B, T, H] according to its signature. This mismatch will cause runtime errors or incorrect computation.

Apply this diff to fix the shape:

-        g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device)).requires_grad_(True)
+        g = F.logsigmoid(torch.randn(B, T, H, dtype=dtype, device=device)).requires_grad_(True)

Suggested change
elif provider_base == 'kda':
q = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
k = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device)).requires_grad_(True)
beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid().requires_grad_(True)
results = triton.testing.do_bench(
lambda: chunk_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
use_qk_l2norm_in_kernel=True,
)[0].backward(do),
quantiles=quantiles
)
elif provider_base == 'kda':
q = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
k = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
g = F.logsigmoid(torch.randn(B, T, H, dtype=dtype, device=device)).requires_grad_(True)
beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid().requires_grad_(True)
results = triton.testing.do_bench(
lambda: chunk_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
use_qk_l2norm_in_kernel=True,
)[0].backward(do),
quantiles=quantiles
)
🤖 Prompt for AI Agents
In benchmarks/ops/benchmark_kda.py around lines 104 to 120, the tensor g is
created with shape [B, T, H, D] but chunk_kda expects g of shape [B, T, H];
change the g initializer to create a tensor of shape [B, T, H] (preserving
dtype, device and requires_grad_) and then apply F.logsigmoid to that tensor so
g has the correct shape and gradient behavior; no other changes to the
surrounding benchmark call are needed.

Comment on lines +121 to +139
elif provider_base == 'dplr':
q = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
k = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
a = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
b = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device)).requires_grad_(True)
beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid().requires_grad_(True)
results = triton.testing.do_bench(
lambda: chunk_dplr_delta_rule(
q=q,
k=k,
v=v,
a=a,
b=b,
gk=g,
)[0].backward(do),
quantiles=quantiles
)

⚠️ Potential issue | 🟠 Major

Remove unused beta variable in DPLR provider.

Line 128 creates a beta tensor that is never used in the chunk_dplr_delta_rule call. The function signature doesn't accept a beta parameter, making this unnecessary computation that wastes resources during benchmarking.

Apply this diff to remove the unused variable:

         v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
         g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device)).requires_grad_(True)
-        beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid().requires_grad_(True)
         results = triton.testing.do_bench(

Suggested change
elif provider_base == 'dplr':
q = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
k = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
a = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
b = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device)).requires_grad_(True)
beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid().requires_grad_(True)
results = triton.testing.do_bench(
lambda: chunk_dplr_delta_rule(
q=q,
k=k,
v=v,
a=a,
b=b,
gk=g,
)[0].backward(do),
quantiles=quantiles
)
elif provider_base == 'dplr':
q = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
k = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
a = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
b = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device)).requires_grad_(True)
results = triton.testing.do_bench(
lambda: chunk_dplr_delta_rule(
q=q,
k=k,
v=v,
a=a,
b=b,
gk=g,
)[0].backward(do),
quantiles=quantiles
)
🤖 Prompt for AI Agents
In benchmarks/ops/benchmark_kda.py around lines 121 to 139, the DPLR branch
creates an unused tensor `beta` at line 128 which is not passed to
chunk_dplr_delta_rule and wastes compute during benchmarks; remove the creation
of `beta` (and any associated .requires_grad_(True) call) and ensure the call to
chunk_dplr_delta_rule remains unchanged, then run tests/benchmarks to confirm no
references to `beta` remain.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (2)
fla/ops/kda/fused_recurrent.py (1)

33-37: Doc/API mismatch for g (actually gk). Also accept 3D g via broadcast.

This wrapper passes g to the underlying op as gk (per‑K decays), so the documented shape should be [B, T, HV, K], not [B, T, HV]. The example already uses [B, T, HV, K]. To avoid user error and support legacy callers, broadcast a 3D g to 4D.

Apply:

@@
-        g (torch.Tensor):
-            g (decays) of shape `[B, T, HV]`.
+        g (torch.Tensor):
+            gk (per‑K decays) of shape `[B, T, HV, K]`.  Note: if a 3D `[B, T, HV]`
+            tensor is provided, it will be broadcast across K.
@@
     if scale is None:
         scale = k.shape[-1] ** -0.5
+    # Backward-compat: allow per‑HV decays by broadcasting across K
+    if g.dim() == 3:
+        g = g.unsqueeze(-1).expand(*g.shape, k.shape[-1])
@@
     o, final_state = fused_recurrent_gated_delta_rule(
         q=q,
         k=k,
         v=v,
-        gk=g,
+        gk=g,
         beta=beta,

Also applies to: 102-113

fla/layers/kda.py (1)

189-192: Guard past_key_values access when layer_idx is None.

len(past_key_values) > self.layer_idx raises if layer_idx is None. Add an explicit guard.

-        last_state = None
-        if past_key_values is not None and len(past_key_values) > self.layer_idx:
-            last_state = past_key_values[self.layer_idx]
+        last_state = None
+        if (
+            past_key_values is not None
+            and self.layer_idx is not None
+            and isinstance(self.layer_idx, int)
+            and self.layer_idx >= 0
+            and len(past_key_values) > self.layer_idx
+        ):
+            last_state = past_key_values[self.layer_idx]
🧹 Nitpick comments (8)
fla/models/kda/configuration_kda.py (2)

10-10: Annotate mutable class attribute with ClassVar.

The keys_to_ignore_at_inference list is a mutable class attribute and should be annotated with typing.ClassVar to make the intent explicit and improve type safety.

Apply this diff:

+from typing import ClassVar, Dict, Optional
-from typing import Dict, Optional

-    keys_to_ignore_at_inference = ['past_key_values']
+    keys_to_ignore_at_inference: ClassVar[list[str]] = ['past_key_values']

69-79: Consider extracting error messages to constants (optional).

The validation logic is correct. For improved maintainability, you could extract the error messages to module-level constants, though this is a minor style point.

Example refactor (optional):

_ERR_ATTN_NOT_DICT = "attn must be a dictionary"
_ERR_ATTN_NO_LAYERS = "Layer indices must be provided to initialize hybrid attention layers"
_ERR_ATTN_NO_HEADS = "Number of heads must be provided to initialize hybrid attention layers"

# Then in __init__:
if not isinstance(attn, Dict):
    raise ValueError(_ERR_ATTN_NOT_DICT)
fla/ops/kda/fused_recurrent.py (3)

21-23: Drop unused kwargs from signature.

kwargs is not consumed. Removing it tightens the public API.

-    cu_seqlens: Optional[torch.LongTensor] = None,
-    **kwargs
+    cu_seqlens: Optional[torch.LongTensor] = None,

4-7: Modernize type hints (PEP 604) and ensure runtime safety.

Use T | None and add the future import to keep runtime compatibility.

+# from __future__ import annotations
@@
-from typing import Optional, Tuple
+from typing import Tuple
@@
-    beta: torch.Tensor = None,
-    scale: float = None,
-    initial_state: torch.Tensor = None,
+    beta: torch.Tensor | None = None,
+    scale: float | None = None,
+    initial_state: torch.Tensor | None = None,
@@
-    cu_seqlens: Optional[torch.LongTensor] = None,
+    cu_seqlens: torch.LongTensor | None = None,

Also applies to: 15-21


88-98: Condense exception messages (TRY003).

Shorten or factor long f-strings to reduce noise.

-        if q.shape[0] != 1:
-            raise ValueError(
-                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
-                f"Please flatten variable-length inputs before processing."
-            )
+        if q.shape[0] != 1:
+            raise ValueError(f"B must be 1 when using cu_seqlens; got {q.shape[0]}.")
@@
-        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
-            raise ValueError(
-                f"The number of initial states is expected to be equal to the number of input sequences, "
-                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
-            )
+        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+            raise ValueError(
+                f"initial_state.shape[0] ({initial_state.shape[0]}) must equal len(cu_seqlens)-1 ({len(cu_seqlens)-1})."
+            )
fla/layers/kda.py (3)

80-87: Use PEP 604 optionals in init signature.

Avoid implicit Optional and match Ruff guidance.

-        num_v_heads: int = None,
+        num_v_heads: int | None = None,
@@
-        layer_idx: int = None,
+        layer_idx: int | None = None,

174-175: kwargs typing and usage.

Unpack[Dict] is unusual here and triggers tooling noise; kwargs is used only for cu_seqlens. Prefer a simpler annotation or none.

-        **kwargs: Unpack[Dict]
+        **kwargs,
@@
-        cu_seqlens = kwargs.get('cu_seqlens', None)
+        cu_seqlens = kwargs.get('cu_seqlens')

Also applies to: 193-196


185-188: Make the decode threshold a constant (optional).

q_len <= 64 is embedded magic. Consider a class attribute (e.g., self.decode_threshold = 64) to make tuning easier.
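
A sketch of the suggested refactor (names are hypothetical; the real layer would keep its existing dispatch logic around this attribute):

```python
class PathSelector:
    # hoisted magic number; tune per hardware / kernel benchmarks
    decode_threshold: int = 64

    def use_fused_recurrent(self, q_len: int) -> bool:
        # short (decode-like) sequences take the fused recurrent path,
        # longer prefill sequences take the chunked path
        return q_len <= self.decode_threshold


sel = PathSelector()
assert sel.use_fused_recurrent(1)
assert not sel.use_fused_recurrent(2048)
```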

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 00b1783 and eccea7e.

📒 Files selected for processing (3)
  • fla/layers/kda.py (1 hunks)
  • fla/models/kda/configuration_kda.py (1 hunks)
  • fla/ops/kda/fused_recurrent.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
fla/ops/kda/fused_recurrent.py (2)
fla/ops/common/fused_recurrent.py (1)
  • fused_recurrent (540-569)
fla/ops/gated_delta_rule/fused_recurrent.py (1)
  • fused_recurrent_gated_delta_rule (242-353)
fla/layers/kda.py (6)
fla/layers/utils.py (2)
  • get_unpad_data (75-98)
  • pad_input (176-197)
fla/modules/fused_norm_gate.py (1)
  • FusedRMSNormGated (997-1058)
fla/modules/convolution.py (1)
  • ShortConvolution (794-1009)
fla/ops/kda/chunk.py (2)
  • chunk_kda (250-351)
  • forward (186-218)
fla/ops/kda/fused_recurrent.py (1)
  • fused_recurrent_kda (11-114)
fla/ops/kda/gate.py (2)
  • fused_kda_gate (329-349)
  • forward (304-314)
🪛 Ruff (0.14.1)
fla/ops/kda/fused_recurrent.py

17-17: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


22-22: Unused function argument: kwargs

(ARG001)


90-93: Avoid specifying long messages outside the exception class

(TRY003)


95-98: Avoid specifying long messages outside the exception class

(TRY003)

fla/layers/kda.py

79-79: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


85-85: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


87-87: Unused method argument: kwargs

(ARG002)


112-115: Avoid specifying long messages outside the exception class

(TRY003)


117-119: Avoid specifying long messages outside the exception class

(TRY003)


122-125: Avoid specifying long messages outside the exception class

(TRY003)


172-172: Unused method argument: output_attentions

(ARG002)

fla/models/kda/configuration_kda.py

10-10: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


31-31: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


71-71: Avoid specifying long messages outside the exception class

(TRY003)


73-73: Avoid specifying long messages outside the exception class

(TRY003)


75-75: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (2)
fla/models/kda/configuration_kda.py (2)

16-16: LGTM! Type hint corrected.

The expand_v parameter is now correctly typed as float, addressing the previous review feedback.


81-87: LGTM!

Parent class initialization is correct and passes the appropriate parameters.

Comment thread fla/layers/kda.py
Comment on lines +111 to +116
# Consistency check: Ensure expand_v produces integer values
if not math.isclose(self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5):
raise ValueError(
f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
f"Resulting value_dim would be {self.num_v_heads * self.head_dim * expand_v}, which is invalid for nn.Linear."
)
Contributor


⚠️ Potential issue | 🟡 Minor

Error message wording: refers to key_dim but checks value_dim.

Clarify message to match the check.

-            raise ValueError(
-                f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
-                f"Resulting value_dim would be {self.num_v_heads * self.head_dim * expand_v}, which is invalid for nn.Linear."
-            )
+            raise ValueError(
+                f"expand_v={expand_v} must make num_v_heads*head_dim*expand_v an integer "
+                f"(got {self.num_v_heads * self.head_dim * expand_v}); invalid for nn.Linear."
+            )
🧰 Tools
🪛 Ruff (0.14.1)

112-115: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In fla/layers/kda.py around lines 111 to 116, the raised ValueError message
incorrectly references key_dim while the check is validating the computed
value_dim; update the error text to reference value_dim (or resulting value
dimension) and include the computed numeric result and the expand_v input for
clarity. Change the f-string to mention value_dim and the actual computed value
(self.num_v_heads * self.head_dim * expand_v) and that it must be an integer
suitable for nn.Linear.

Comment thread fla/layers/kda.py
Comment on lines +232 to +237
if self.num_v_heads > self.num_heads:
q, k = map(lambda x: repeat(x, '... h d -> ... (h g) d', g=self.num_v_heads // self.num_heads), (q, k))

if self.allow_neg_eigval:
beta = beta * 2.

Contributor

⚠️ Potential issue | 🔴 Critical

Bug: g and beta are not expanded when num_v_heads > num_heads (HV>H).

q and k are repeated to HV heads, but g ([B, T, H, K]) and beta ([B, T, H]) are not, causing a shape mismatch in the fused ops, which expect per‑HV decays and betas. Broadcast them to HV as well.

         if self.num_v_heads > self.num_heads:
-            q, k = map(lambda x: repeat(x, '... h d -> ... (h g) d', g=self.num_v_heads // self.num_heads), (q, k))
+            expand = self.num_v_heads // self.num_heads
+            q, k = map(lambda x: repeat(x, '... h d -> ... (h g) d', g=expand), (q, k))
+            # match per‑HV shapes expected by ops
+            g = repeat(g, '... h k -> ... (h g_) k', g_=expand)
+            beta = repeat(beta, '... h -> ... (h g_)', g_=expand)

Also applies to: 251-262
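
The head-axis broadcast in the suggested fix can be sanity-checked with NumPy; `np.repeat` on the head axis reproduces einops' `'h -> (h g_)'` grouping (shapes below are illustrative, not the model's real dimensions):

```python
import numpy as np

B, T, H, K = 1, 2, 2, 3
expand = 2  # num_v_heads // num_heads

g = np.arange(B * T * H * K, dtype=np.float32).reshape(B, T, H, K)
beta = np.arange(B * T * H, dtype=np.float32).reshape(B, T, H)

# each original head's decay/beta is duplicated contiguously for its
# `expand` value heads, matching repeat(x, '... h k -> ... (h g_) k')
g_hv = np.repeat(g, expand, axis=2)        # (B, T, H*expand, K)
beta_hv = np.repeat(beta, expand, axis=2)  # (B, T, H*expand)

assert g_hv.shape == (B, T, H * expand, K)
assert beta_hv.shape == (B, T, H * expand)
assert np.array_equal(g_hv[0, 0, 0], g_hv[0, 0, 1])  # copies of head 0
assert np.array_equal(g_hv[0, 0, 2], g[0, 0, 1])     # head 1 starts at index 2
```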

Comment thread fla/models/kda/configuration_kda.py Outdated
@zhiyuan1i
Collaborator

LGTM
