
[GDN] Add fused gate kernel with use_gate_in_kernel support#813

Merged

yzhangcs merged 6 commits into main from fused-gdn-gate on Apr 6, 2026

Conversation

@yzhangcs
Member

@yzhangcs yzhangcs commented Apr 5, 2026

Summary

  • Create fla/ops/gated_delta_rule/gate.py with fused Triton kernels:
    • gdn_gate_chunk_cumsum: fuses gate activation (-exp(A_log) * softplus(g + dt_bias)) with chunk-local cumsum in a single kernel pass, eliminating an HBM round-trip.
    • fused_gdn_gate: standalone gate with autograd support, used by the fused_recurrent inference path.
    • naive_gdn_gate: PyTorch reference implementation for testing.
  • Add use_gate_in_kernel / A_log / dt_bias parameters to chunk_gated_delta_rule API and wire through the autograd function.
  • Update GatedDeltaNet layer to pass raw gate input + parameters instead of pre-computed gate values in both chunk and fused_recurrent modes.
  • Clean up layer: remove unused elu_p1/sum_norm, update docstring.
  • Add test_gate to verify fused_gdn_gate against naive_gdn_gate (forward + backward).
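As a rough reference for what the new kernels compute, here is a minimal NumPy sketch of the gate activation from the summary (`-exp(A_log) * softplus(g + dt_bias)`) followed by a chunk-local cumsum. The shapes and helper names (`naive_gate`, `chunk_local_cumsum`, `T`, `H`, `BT`) are illustrative assumptions, not the actual `fla` API; the fused kernel performs both steps in one pass instead of materializing the intermediate.

```python
import numpy as np

def softplus(x):
    # numerically fine for the small inputs used here
    return np.log1p(np.exp(x))

def naive_gate(g, A_log, dt_bias=None):
    # gate activation described in the PR summary:
    #   -exp(A_log) * softplus(g + dt_bias)
    x = g + dt_bias if dt_bias is not None else g
    return -np.exp(A_log) * softplus(x)

def chunk_local_cumsum(y, chunk_size):
    # cumulative sum that restarts at every chunk boundary (time axis 0)
    out = np.empty_like(y)
    for s in range(0, y.shape[0], chunk_size):
        out[s:s + chunk_size] = np.cumsum(y[s:s + chunk_size], axis=0)
    return out

T, H, BT = 8, 2, 4  # illustrative sizes, not the real defaults
rng = np.random.default_rng(0)
g = rng.standard_normal((T, H))
A_log = rng.standard_normal(H)
dt_bias = rng.standard_normal(H)

gated = naive_gate(g, A_log, dt_bias)  # roughly what fused_gdn_gate computes
g_cum = chunk_local_cumsum(gated, BT)  # what gdn_gate_chunk_cumsum fuses in
```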

Test plan

  • pytest tests/ops/test_gated_delta.py::test_gate
  • pytest tests/ops/test_gated_delta.py::test_chunk_gate_in_kernel
  • pytest tests/ops/test_gated_delta.py::test_chunk
  • pytest tests/ops/test_gated_delta.py::test_fused_recurrent

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Added an optional fused gate computation path that speeds up gated updates.
  • Refactor

    • Gate computation module reorganized to support both fused (kernel) and unfused (reference) processing with clearer parameter behavior.
  • Tests

    • Expanded tests to validate outputs and gradients across standard, grouped-attention, and variable-length sequences.

@coderabbitai
Contributor

coderabbitai bot commented Apr 5, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: e8e6cc09-fae9-44c7-9f5e-612623c3521b

📥 Commits

Reviewing files that changed from the base of the PR and between 3a0e923 and 38a3014.

📒 Files selected for processing (1)
  • fla/ops/gated_delta_product/chunk.py

Walkthrough

This change adds an in-kernel GDN gate implementation (Triton + autograd), threads A_log/dt_bias/use_gate_in_kernel through chunked gated-delta forward/backward, updates GatedDeltaNet to call the fused gate for both chunked and fused-recurrent paths, and adds tests validating outputs and gradients.

Changes

  • GDN Gate Operations (New) — fla/ops/gated_delta_rule/gate.py:
    New Triton + reference implementations for GDN gating: gdn_gate_chunk_cumsum, gdn_gate_fwd, gdn_gate_bwd, autograd GDNGateFunction, and fused_gdn_gate. Supports varlen sequences and optional dt_bias, and exposes the fused gate for kernels.
  • Chunk Gated Delta Rule — fla/ops/gated_delta_rule/chunk.py:
    Added use_gate_in_kernel, A_log, dt_bias params to fwd/bwd and ChunkGatedDeltaRuleFunction; forward now optionally fuses the gate activation into the chunk cumsum via gdn_gate_chunk_cumsum and returns g_input; backward returns extra gradients dA_log, ddt_bias. Revised shape/divisibility checks to GVA-style (HV % H == 0).
  • Gated DeltaNet Layer — fla/layers/gated_deltanet.py:
    Removed two local helper functions and refactored gating: the chunk path passes raw g plus use_gate_in_kernel=True, A_log, dt_bias into chunk_gated_delta_rule; fused_recurrent uses fused_gdn_gate. Minor docstring and default-doc updates.
  • Gated Delta Product Backward — fla/ops/gated_delta_product/chunk.py:
    Adjusted unpacking of chunk_gated_delta_rule_bwd results to accept the two additional backward outputs (captured as placeholders).
  • Tests — tests/ops/test_gated_delta.py:
    Added tests test_chunk_gate_in_kernel, test_chunk_gate_in_kernel_gqa, and test_chunk_gate_in_kernel_varlen comparing the in-kernel fused gate against the reference naive_gdn_gate, including gradient checks for A_log and optional dt_bias.

Sequence Diagram(s)

sequenceDiagram
    participant Layer as GatedDeltaNet Layer
    participant Chunk as chunk_gated_delta_rule
    participant Gate as gdn_gate_chunk_cumsum
    participant Triton as Triton Gate Kernel
    participant Back as gdn_gate_bwd

    Layer->>Chunk: forward(..., g_raw, A_log, dt_bias, use_gate_in_kernel=True)
    Chunk->>Gate: gdn_gate_chunk_cumsum(g_raw, A_log, chunk_size, dt_bias)
    Gate->>Triton: launch gdn_gate_chunk_cumsum_scalar_kernel
    Triton-->>Gate: g_gated, g_input
    Gate-->>Chunk: g_gated, g_input
    Chunk->>Layer: (o, ht, A, final_state, g_input)

    Layer->>Back: backward(...)
    Back->>Triton: launch gdn_gate_bwd_kernel(dyg, g_raw, A_log, dt_bias)
    Triton-->>Back: dg, dA_log, ddt_bias
    Back->>Chunk: chunk_gated_delta_rule_bwd(..., dA_log, ddt_bias)
    Chunk-->>Layer: gradients for q,k,v,g_raw,A_log,dt_bias,h0

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Suggested reviewers

  • Nathancgy

Poem

🐇 I hopped through kernels, quiet and bright,

Gates stitched in triton, glowing light,
A_log and bias, gradients hum,
Chunked magic fused — the tensors come,
A rabbit cheers: fast forward, run! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 24.19%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Description Check ✅ Passed — Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The pull request title clearly and concisely describes the main change: adding fused gate kernels with use_gate_in_kernel support, which aligns with the core modifications across multiple files in the changeset.


@yzhangcs changed the title from "[GDN] Native GVA support and fused gate kernel" to "[GDN] Add fused gate kernel with use_gate_in_kernel support" on Apr 5, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for Grouped Value Attention (GVA) in Gated Delta Networks (GDN) by decoupling the number of query/key heads from value heads, and adds a fused gate activation kernel to improve performance. My feedback suggests simplifying the head divisibility check logic in the chunk_gated_delta_rule function to reduce redundancy.

Comment on lines +450 to 460
if q.shape[2] != k.shape[2]:
raise ValueError(
f"q and k must have the same number of heads, "
f"but got q.shape[2]={q.shape[2]} and k.shape[2]={k.shape[2]}"
)
H, HV = q.shape[2], v.shape[2]
if HV % H != 0:
raise ValueError(
f"For GQA, num_heads (H={H}) must be evenly divisible by "
f"num_kv_heads (Hq={Hq}), but got H % Hq = {H % Hq}"
f"For GVA, num_v_heads (HV={HV}) must be evenly divisible by "
f"num_heads (H={H}), but got HV % H = {HV % H}"
)

medium

The head divisibility check logic is redundant and can be simplified. q.shape[2] != k.shape[2] is already covered by the subsequent check HV % H != 0 if H is defined as q.shape[2] and HV as v.shape[2]. Additionally, the error message for HV % H != 0 should be clear about the requirement.

Suggested change
if q.shape[2] != k.shape[2]:
raise ValueError(
f"q and k must have the same number of heads, "
f"but got q.shape[2]={q.shape[2]} and k.shape[2]={k.shape[2]}"
)
H, HV = q.shape[2], v.shape[2]
if HV % H != 0:
raise ValueError(
f"For GQA, num_heads (H={H}) must be evenly divisible by "
f"num_kv_heads (Hq={Hq}), but got H % Hq = {H % Hq}"
f"For GVA, num_v_heads (HV={HV}) must be evenly divisible by "
f"num_heads (H={H}), but got HV % H = {HV % H}"
)
H, HV = q.shape[2], v.shape[2]
if HV % H != 0:
raise ValueError(
f"For GVA, num_v_heads (HV={HV}) must be evenly divisible by "
f"num_heads (H={H}), but got HV % H = {HV % H}"
)
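For illustration only, the single check that the suggestion boils down to can be sketched as a standalone helper (the function name is hypothetical, not the fla signature); HV == H trivially satisfies HV % H == 0, so the plain non-GVA case passes through unchanged:

```python
def check_gva_heads(H: int, HV: int) -> None:
    # single divisibility test covering both plain attention (HV == H)
    # and GVA (HV a multiple of H), per the review suggestion
    if HV % H != 0:
        raise ValueError(
            f"For GVA, num_v_heads (HV={HV}) must be evenly divisible by "
            f"num_heads (H={H}), but got HV % H = {HV % H}"
        )

check_gva_heads(8, 8)   # plain attention: passes
check_gva_heads(8, 16)  # GVA with 2 value heads per head: passes
```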


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
fla/ops/gated_delta_rule/wy_fast.py (1)

181-199: ⚠️ Potential issue | 🔴 Critical

dk is uninitialized on the ungated backward path.

tl.store(p_dk, ...) is unconditional, but b_dk is only assigned inside if USE_G:. prepare_wy_repr_bwd(..., g=None) will therefore write garbage or fail before the later H != HV reduction even runs.

🔧 Suggested fix
         b_dA += tl.dot(b_dw, tl.trans(b_kbg).to(b_dw.dtype))
         b_dkbg = tl.dot(b_A, b_dw)
         if USE_G:
             b_dk = b_dkbg * (b_g_exp * b_b)[:, None]
             b_db += tl.sum(b_dkbg * b_k * b_g_exp[:, None], 1)
             b_dg += tl.sum(b_dkbg * b_kbg, 1)
+        else:
+            b_dk = b_dkbg * b_b[:, None]
+            b_db += tl.sum(b_dkbg * b_k, 1)
         tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/wy_fast.py` around lines 181 - 199, The backward
loop unconditionally stores b_dk via tl.store(p_dk, ...) but b_dk is only set
inside the if USE_G branch, so when prepare_wy_repr_bwd(..., g=None) the store
writes uninitialized data; fix by assigning b_dk in the ungated path (else)
before the tl.store call — e.g., compute b_dk from b_dkbg and the bias term (use
b_dk = b_dkbg * b_b[:, None]) so both gated (USE_G) and ungated branches define
b_dk prior to tl.store in the loop.
🧹 Nitpick comments (2)
tests/context_parallel/test_cp_bwd_gk_offset.py (1)

112-113: Add one HV > H case here too.

Both launches at Line 112 and Line 120 still set HV=H, so the new grouping path i_h // (HV // H) never runs in this regression test. That leaves the GVA-specific offset logic uncovered.

Also applies to: 120-121

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/context_parallel/test_cp_bwd_gk_offset.py` around lines 112 - 113, The
test currently only exercises HV == H so the GVA-specific grouping branch (i_h
// (HV // H)) is never taken; update the two test launches that pass HV=H (the
calls passing cu_seqlens=..., scale=1.0, T=T, H=H, HV=H, K=K, V=V, ...) to
include at least one variant where HV > H (for example set HV = H * 2 or HV = H
+ 1) so the grouping path and GVA offset logic are exercised in this regression
test.
fla/ops/gated_delta_rule/chunk.py (1)

487-491: Consider extracting kwargs values earlier to avoid potential confusion.

The code extracts use_gate_in_kernel, A_log, and dt_bias from kwargs at lines 487-489, but use_gate_in_kernel is also a named parameter in the public API (line 364 states it's in **kwargs). This works correctly, but the pattern of re-extracting from kwargs after the function signature could be clearer.

The assertion at line 491 correctly enforces that A_log must be provided when use_gate_in_kernel=True.

Consider documenting the kwargs contract more explicitly

The use_gate_in_kernel, A_log, and dt_bias parameters are currently passed through **kwargs and extracted twice (once implicitly through kwargs destructuring for the docstring, once explicitly at lines 487-489). This works but could be clearer if these were promoted to explicit keyword arguments with defaults in the function signature.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/chunk.py` around lines 487 - 491, The kwargs
extraction for use_gate_in_kernel, A_log, and dt_bias is done late and
duplicates the implicit contract; change the function signature to accept
explicit keyword arguments (use_gate_in_kernel=False, A_log=None, dt_bias=None)
instead of pulling them out of **kwargs so their intent is clear, then remove
the kwargs.get(...) lines (relating to use_gate_in_kernel, A_log, dt_bias) and
keep the existing assertion that A_log is provided when use_gate_in_kernel is
True (the assertion around A_log), updating any call sites that relied on
passing these via **kwargs to use explicit named parameters.
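The refactor this nitpick asks for can be sketched with a stub (the real chunk_gated_delta_rule takes many more parameters; this placeholder keeps only the three the comment is about, and its body is illustrative, not the actual implementation):

```python
# Sketch of the suggested signature change: promote the three gate parameters
# from **kwargs to explicit keyword-only arguments with defaults.
def chunk_gated_delta_rule_stub(q, k, v, *,
                                use_gate_in_kernel=False,
                                A_log=None,
                                dt_bias=None):
    # same contract the existing assertion enforces:
    # A_log is required whenever the in-kernel gate path is requested
    if use_gate_in_kernel:
        assert A_log is not None, "A_log must be provided when use_gate_in_kernel=True"
    return use_gate_in_kernel, A_log, dt_bias
```

With explicit keyword-only arguments, the contract is visible in the signature and call sites no longer need to thread the values through **kwargs.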

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 5dd20087-27fe-48b1-9bd1-ea8b966e9a5d

📥 Commits

Reviewing files that changed from the base of the PR and between 87ff243 and 8491c16.

📒 Files selected for processing (13)
  • fla/layers/gated_deltanet.py
  • fla/ops/common/chunk_delta_h.py
  • fla/ops/common/chunk_o.py
  • fla/ops/common/chunk_scaled_dot_kkt.py
  • fla/ops/common/intracard_cp.py
  • fla/ops/cp/chunk_delta_h.py
  • fla/ops/gated_delta_rule/chunk.py
  • fla/ops/gated_delta_rule/chunk_fwd.py
  • fla/ops/gated_delta_rule/fused_recurrent.py
  • fla/ops/gated_delta_rule/gate.py
  • fla/ops/gated_delta_rule/wy_fast.py
  • tests/context_parallel/test_cp_bwd_gk_offset.py
  • tests/ops/test_gated_delta.py

Comment thread fla/layers/gated_deltanet.py Outdated
Comment on lines +55 to +58
num_v_heads (int, Optional):
The number of heads for the value projection, equal to `num_heads` if `None`.
GVA is applied if `num_v_heads` > `num_heads`. Default: `None`.
GVA (Grouped Value Attention) is applied if `num_v_heads` > `num_heads`,
where `num_v_heads` must be divisible by `num_heads`. Default: `None`.

⚠️ Potential issue | 🟠 Major

Reject num_v_heads < num_heads up front.

The new native GVA path only supports HV >= H, but the module currently only checks divisibility when num_v_heads > num_heads. Smaller values still get through and later hit invalid HV // H math inside the kernels.

🔧 Suggested fix
-        if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
+        if self.num_v_heads < self.num_heads or self.num_v_heads % self.num_heads != 0:
             raise ValueError(
-                f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.",
+                f"num_v_heads={self.num_v_heads} must be >= num_heads={self.num_heads} "
+                f"and divisible by it.",
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/layers/gated_deltanet.py` around lines 55 - 58, Add an input validation
in the module's initializer (where num_v_heads and num_heads are processed) to
immediately reject cases where num_v_heads is not None and num_v_heads <
num_heads by raising a ValueError; keep the existing divisibility check for
num_v_heads > num_heads (i.e., ensure num_v_heads % num_heads == 0) but precede
it with this new check so the native GVA path never receives HV < H and avoids
invalid HV // H math inside the kernels.
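The combined up-front validation the review asks for can be sketched as a small helper (name and placement hypothetical; the real check lives in the GatedDeltaNet initializer):

```python
def resolve_num_v_heads(num_heads: int, num_v_heads=None) -> int:
    # Combined check from the review: the native GVA path assumes
    # HV >= H and HV % H == 0, so reject anything else immediately
    # instead of failing later inside the kernels on HV // H math.
    if num_v_heads is None:
        num_v_heads = num_heads
    if num_v_heads < num_heads or num_v_heads % num_heads != 0:
        raise ValueError(
            f"num_v_heads={num_v_heads} must be >= num_heads={num_heads} "
            f"and divisible by it."
        )
    return num_v_heads
```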

Comment on lines 309 to +310
values of shape `[B, T, HV, V]`.
GVA is applied if `HV > H`.
GVA (Grouped Value Attention) is applied if `HV > H`, where `HV` must be divisible by `H`.

⚠️ Potential issue | 🟠 Major

The wrapper doesn't fully honor the new GVA contract.

With beta=None, the fallback below still creates beta from q[..., 0], i.e. shape [B, T, H]. Once HV > H, the kernel indexes beta with HV stride, so this default is wrong, and invalid HV < H / non-divisible ratios still make it to HV // H.

🔧 Suggested fix
 def fused_recurrent_gated_delta_rule(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
@@
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    HV = v.shape[2]
+    H = q.shape[2]
+    if HV < H or HV % H != 0:
+        raise ValueError(
+            f"`v.shape[2]` ({HV}) must be >= `q.shape[2]` ({H}) and divisible by it for GVA."
+        )
     if cu_seqlens is not None:
         ...
     if scale is None:
         scale = k.shape[-1] ** -0.5
     if beta is None:
-        beta = torch.ones_like(q[..., 0])
+        beta = torch.ones_like(v[..., 0])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/fused_recurrent.py` around lines 309 - 310, The
wrapper for GVA in fused_recurrent.py does not honor the new contract: when beta
is None it builds beta from q[..., 0] producing shape [B,T,H], which is
incorrect for HV>H and allows invalid HV values; update the wrapper logic (the
code path that constructs beta when beta is None) to produce beta shaped for
grouped value attention (match HV dimensions) or require beta to be provided
when HV>H, validate that HV is divisible by H (raise/handle if not), and ensure
any indexing into beta in the kernel uses the new HV-shaped beta; reference the
symbols beta, q, H, HV, and the GVA fallback code in fused_recurrent.py to
locate and change the default-construction and validation logic accordingly.

Comment on lines +195 to +212
dg = torch.empty_like(g, dtype=torch.float32)
dA = A_log.new_empty(NT, H, dtype=torch.float32)

gdn_gate_bwd_kernel[(NT, H)](
g=g,
A_log=A_log,
dt_bias=dt_bias,
dyg=dyg,
dg=dg,
dA=dA,
T=T,
H=H,
BT=BT,
)

dg = dg.view_as(g).type_as(g)
dA = dA.sum(0).view_as(A_log).type_as(A_log)
dbias = dg.view(-1, H).sum(0).to(dt_bias) if dt_bias is not None else None

⚠️ Potential issue | 🟠 Major

Keep dt_bias accumulation in fp32.

dbias is computed from dg after dg has been cast back to g.dtype. Under bf16/fp16 that quantizes every timestep contribution before the reduction, so dt_bias.grad is less accurate than the fp32 buffer you already produced.

🔧 Suggested fix
-    dg = dg.view_as(g).type_as(g)
-    dA = dA.sum(0).view_as(A_log).type_as(A_log)
-    dbias = dg.view(-1, H).sum(0).to(dt_bias) if dt_bias is not None else None
+    dg_fp32 = dg.view_as(g)
+    dA = dA.sum(0).view_as(A_log).type_as(A_log)
+    dbias = dg_fp32.view(-1, H).sum(0).to(dt_bias) if dt_bias is not None else None
+    dg = dg_fp32.type_as(g)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/gate.py` around lines 195 - 212, The backward kernel
currently casts dg from fp32 to g.dtype before computing dbias, causing timestep
contributions to be quantized for bf16/fp16; compute dbias from the fp32 dg
buffer produced by gdn_gate_bwd_kernel (use the dg fp32 tensor created at the
top) before you call dg = dg.view_as(g).type_as(g), e.g. compute dbias =
dg.view(-1, H).sum(0).to(dt_bias) if dt_bias is not None, then proceed to cast
dg and reduce dA—this preserves fp32 accumulation for dt_bias while keeping the
rest of outputs converted to g.dtype as before.
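The precision concern here can be demonstrated with a deterministic NumPy toy (this is an illustration of the half-precision quantization effect, not the actual fla buffers): summing 10,000 contributions of 1e-4 should give 1.0, and casting each element to fp16 before the reduction visibly shifts the result.

```python
import numpy as np

# Reduce 10,000 gradient contributions of 1e-4 each; the exact sum is 1.0.
dg_fp32 = np.full(10_000, 1e-4, dtype=np.float32)

# Accumulating the fp32 buffer keeps the result close to 1.0.
sum_from_fp32 = dg_fp32.sum(dtype=np.float64)

# Casting to fp16 first quantizes every element (nearest fp16 to 1e-4
# is ~1.00017e-4), and the per-element error survives the reduction.
sum_after_cast = dg_fp32.astype(np.float16).astype(np.float64).sum()
```

This is why the suggested fix computes dbias from the fp32 dg buffer before the final cast to g.dtype.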

# Conflicts:
#	fla/layers/gated_deltanet.py
#	fla/ops/gated_delta_rule/chunk.py

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
fla/layers/gated_deltanet.py (1)

128-130: ⚠️ Potential issue | 🟠 Major

Reject num_v_heads < num_heads up front.

The new gate-in-kernel and fused_recurrent paths still assume HV >= H, so configs like num_heads=6, num_v_heads=4 still instantiate here and then fail inside the kernels.

🔧 Suggested fix
-        if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
+        if self.num_v_heads < self.num_heads or self.num_v_heads % self.num_heads != 0:
             raise ValueError(
-                f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.",
+                f"num_v_heads={self.num_v_heads} must be >= num_heads={self.num_heads} "
+                f"and divisible by it.",
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/layers/gated_deltanet.py` around lines 128 - 130, The constructor
validation currently only rejects cases where num_v_heads > num_heads but not
divisible; you must also reject configurations where num_v_heads < num_heads
since downstream paths (gate-in-kernel and fused_recurrent) assume HV >= H. In
the same validation block (referencing self.num_v_heads and self.num_heads in
gated_deltanet.py, e.g., in the class initializer), add a check that raises a
ValueError when self.num_v_heads < self.num_heads (or combine into a single
condition requiring self.num_v_heads >= self.num_heads and, if greater,
divisible by self.num_heads) with a clear error message mentioning both
num_v_heads and num_heads.
🧹 Nitpick comments (1)
fla/layers/gated_deltanet.py (1)

215-217: Add a layer-level fused_recurrent regression before merge.

This branch is now taken automatically for every eval call with q_len <= 64, so the standalone gate-op test does not cover the cache/padding/GVA integration exercised here.

Also applies to: 276-288


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: ed515f48-fb0f-4239-8507-39de1da183a3

📥 Commits

Reviewing files that changed from the base of the PR and between 8491c16 and 3a0e923.

📒 Files selected for processing (1)
  • fla/layers/gated_deltanet.py

@github-actions

github-actions bot commented Apr 6, 2026

⚠️ Benchmark Results (NVIDIA-H100-PT2-7)

Status: 1 regression(s) detected

GPU NVIDIA H200
CUDA 12.8
PyTorch 2.7.1+cu128
Base ceacb10dd2
Head c0c191482c
Threshold 5.0%
📊 View Details (1 significant change)

Op        | Mode | B | T    | H | D  | Base (ms) | Head (ms) | Change
chunk_gdn | fwd  | 8 | 1024 | 8 | 64 | 0.534     | 0.563     | +5.5% 🔴

This comment is automatically updated with the latest benchmark results.

@yzhangcs yzhangcs merged commit e3c8896 into main Apr 6, 2026
6 checks passed
@yzhangcs yzhangcs deleted the fused-gdn-gate branch April 6, 2026 04:16
