[GDN] Add fused gate kernel with use_gate_in_kernel support#813
Conversation
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID:
📒 Files selected for processing (1)
Walkthrough

This change adds an in-kernel GDN gate implementation (Triton + autograd) and threads `use_gate_in_kernel`, `A_log`, and `dt_bias` through the `chunk_gated_delta_rule` API.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Layer as GatedDeltaNet Layer
    participant Chunk as chunk_gated_delta_rule
    participant Gate as gdn_gate_chunk_cumsum
    participant Triton as Triton Gate Kernel
    participant Back as gdn_gate_bwd
    Layer->>Chunk: forward(..., g_raw, A_log, dt_bias, use_gate_in_kernel=True)
    Chunk->>Gate: gdn_gate_chunk_cumsum(g_raw, A_log, chunk_size, dt_bias)
    Gate->>Triton: launch gdn_gate_chunk_cumsum_scalar_kernel
    Triton-->>Gate: g_gated, g_input
    Gate-->>Chunk: g_gated, g_input
    Chunk->>Layer: (o, ht, A, final_state, g_input)
    Layer->>Back: backward(...)
    Back->>Triton: launch gdn_gate_bwd_kernel(dyg, g_raw, A_log, dt_bias)
    Triton-->>Back: dg, dA_log, ddt_bias
    Back->>Chunk: chunk_gated_delta_rule_bwd(..., dA_log, ddt_bias)
    Chunk-->>Layer: gradients for q,k,v,g_raw,A_log,dt_bias,h0
```
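The gate activation these kernels fuse computes `-exp(A_log) * softplus(g + dt_bias)` per head and timestep (the formula appears in the PR summary below). A scalar, stdlib-only sketch of that math, for orientation only; the real kernels operate on tensors in fp32:

```python
import math

def softplus(x: float) -> float:
    # Numerically stable softplus: log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def gdn_gate(g: float, A_log: float, dt_bias: float = 0.0) -> float:
    # Gate activation fused into the kernel: -exp(A_log) * softplus(g + dt_bias).
    # Always <= 0, so exp(gate) is a decay factor in (0, 1].
    return -math.exp(A_log) * softplus(g + dt_bias)

print(gdn_gate(0.0, 0.0))  # -softplus(0) = -log(2) ≈ -0.6931
```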
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Code Review
This pull request introduces support for Grouped Value Attention (GVA) in Gated Delta Networks (GDN) by decoupling the number of query/key heads from value heads, and adds a fused gate activation kernel to improve performance. My feedback suggests simplifying the head divisibility check logic in the `chunk_gated_delta_rule` function to reduce redundancy.
```diff
     if q.shape[2] != k.shape[2]:
         raise ValueError(
             f"q and k must have the same number of heads, "
             f"but got q.shape[2]={q.shape[2]} and k.shape[2]={k.shape[2]}"
         )
     H, HV = q.shape[2], v.shape[2]
     if HV % H != 0:
         raise ValueError(
-            f"For GQA, num_heads (H={H}) must be evenly divisible by "
-            f"num_kv_heads (Hq={Hq}), but got H % Hq = {H % Hq}"
+            f"For GVA, num_v_heads (HV={HV}) must be evenly divisible by "
+            f"num_heads (H={H}), but got HV % H = {HV % H}"
         )
```
The head divisibility check logic is redundant and can be simplified. `q.shape[2] != k.shape[2]` is already covered by the subsequent check `HV % H != 0` if `H` is defined as `q.shape[2]` and `HV` as `v.shape[2]`. Additionally, the error message for `HV % H != 0` should be clear about the requirement.
Suggested change:

```python
H, HV = q.shape[2], v.shape[2]
if HV % H != 0:
    raise ValueError(
        f"For GVA, num_v_heads (HV={HV}) must be evenly divisible by "
        f"num_heads (H={H}), but got HV % H = {HV % H}"
    )
```
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
fla/ops/gated_delta_rule/wy_fast.py (1)
181-199: ⚠️ Potential issue | 🔴 Critical

`dk` is uninitialized on the ungated backward path.

`tl.store(p_dk, ...)` is unconditional, but `b_dk` is only assigned inside `if USE_G:`. `prepare_wy_repr_bwd(..., g=None)` will therefore write garbage or fail before the later `H != HV` reduction even runs.

🔧 Suggested fix

```diff
         b_dA += tl.dot(b_dw, tl.trans(b_kbg).to(b_dw.dtype))
         b_dkbg = tl.dot(b_A, b_dw)
         if USE_G:
             b_dk = b_dkbg * (b_g_exp * b_b)[:, None]
             b_db += tl.sum(b_dkbg * b_k * b_g_exp[:, None], 1)
             b_dg += tl.sum(b_dkbg * b_kbg, 1)
+        else:
+            b_dk = b_dkbg * b_b[:, None]
+            b_db += tl.sum(b_dkbg * b_k, 1)
         tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/wy_fast.py` around lines 181 - 199, the backward loop unconditionally stores b_dk via tl.store(p_dk, ...) but b_dk is only set inside the if USE_G branch, so when prepare_wy_repr_bwd(..., g=None) the store writes uninitialized data; fix by assigning b_dk in the ungated path (else) before the tl.store call — e.g., compute b_dk from b_dkbg and the bias term (use b_dk = b_dkbg * b_b[:, None]) so both gated (USE_G) and ungated branches define b_dk prior to tl.store in the loop.
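The failure mode described above maps onto a familiar Python pattern (a reduced illustration, not the kernel code; Python raises `NameError`, whereas the Triton store silently reads uninitialized registers):

```python
def store_dk(use_g: bool) -> float:
    if use_g:
        b_dk = 1.0
    # No else branch: b_dk is never assigned on the ungated path.
    return b_dk  # raises NameError when use_g is False

try:
    store_dk(False)
except NameError as e:
    print("ungated path:", e)
```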
🧹 Nitpick comments (2)
tests/context_parallel/test_cp_bwd_gk_offset.py (1)
112-113: Add one `HV > H` case here too.

Both launches at Line 112 and Line 120 still set `HV=H`, so the new grouping path `i_h // (HV // H)` never runs in this regression test. That leaves the GVA-specific offset logic uncovered.

Also applies to: 120-121
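For intuition, the grouping arithmetic in question can be checked in plain Python (the head counts here are illustrative, not taken from the test):

```python
H, HV = 2, 4              # 2 q/k heads serving 4 value heads
group_size = HV // H      # value heads per q/k head
# i_h // (HV // H): maps a value-head index to the q/k head it reads from
mapping = [i_h // group_size for i_h in range(HV)]
print(mapping)  # [0, 0, 1, 1]
```

When `HV == H` the group size is 1 and the mapping is the identity, which is why such a test never exercises the grouped branch.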
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/context_parallel/test_cp_bwd_gk_offset.py` around lines 112 - 113, the test currently only exercises HV == H so the GVA-specific grouping branch (i_h // (HV // H)) is never taken; update the two test launches that pass HV=H (the calls passing cu_seqlens=..., scale=1.0, T=T, H=H, HV=H, K=K, V=V, ...) to include at least one variant where HV > H (for example set HV = H * 2 or HV = H + 1) so the grouping path and GVA offset logic are exercised in this regression test.

fla/ops/gated_delta_rule/chunk.py (1)
487-491: Consider extracting kwargs values earlier to avoid potential confusion.

The code extracts `use_gate_in_kernel`, `A_log`, and `dt_bias` from `kwargs` at lines 487-489, but `use_gate_in_kernel` is also a named parameter in the public API (line 364 states it's in `**kwargs`). This works correctly, but the pattern of re-extracting from kwargs after the function signature could be clearer.

The assertion at line 491 correctly enforces that `A_log` must be provided when `use_gate_in_kernel=True`.

Consider documenting the kwargs contract more explicitly. The `use_gate_in_kernel`, `A_log`, and `dt_bias` parameters are currently passed through `**kwargs` and extracted twice (once implicitly through kwargs destructuring for the docstring, once explicitly at lines 487-489). This works but could be clearer if these were promoted to explicit keyword arguments with defaults in the function signature.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/chunk.py` around lines 487 - 491, The kwargs extraction for use_gate_in_kernel, A_log, and dt_bias is done late and duplicates the implicit contract; change the function signature to accept explicit keyword arguments (use_gate_in_kernel=False, A_log=None, dt_bias=None) instead of pulling them out of **kwargs so their intent is clear, then remove the kwargs.get(...) lines (relating to use_gate_in_kernel, A_log, dt_bias) and keep the existing assertion that A_log is provided when use_gate_in_kernel is True (the assertion around A_log), updating any call sites that relied on passing these via **kwargs to use explicit named parameters.
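The refactor this prompt asks for follows a standard pattern. A minimal sketch with a hypothetical body (only the parameter names come from the review; the real function does far more):

```python
def chunk_gated_delta_rule_sketch(
    q, k, v,
    *,
    use_gate_in_kernel: bool = False,
    A_log=None,
    dt_bias=None,
):
    # Explicit keyword arguments make the contract visible in the signature,
    # replacing ad-hoc kwargs.get(...) extraction inside the body.
    if use_gate_in_kernel and A_log is None:
        raise ValueError("A_log must be provided when use_gate_in_kernel=True")
    ...
```

Call sites then pass these options by name, and a typo in an option name becomes a `TypeError` instead of a silently ignored kwargs entry.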
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/layers/gated_deltanet.py`:
- Around line 55-58: Add an input validation in the module's initializer (where
num_v_heads and num_heads are processed) to immediately reject cases where
num_v_heads is not None and num_v_heads < num_heads by raising a ValueError;
keep the existing divisibility check for num_v_heads > num_heads (i.e., ensure
num_v_heads % num_heads == 0) but precede it with this new check so the native
GVA path never receives HV < H and avoids invalid HV // H math inside the
kernels.
In `@fla/ops/gated_delta_rule/fused_recurrent.py`:
- Around line 309-310: The wrapper for GVA in fused_recurrent.py does not honor
the new contract: when beta is None it builds beta from q[..., 0] producing
shape [B,T,H], which is incorrect for HV>H and allows invalid HV values; update
the wrapper logic (the code path that constructs beta when beta is None) to
produce beta shaped for grouped value attention (match HV dimensions) or require
beta to be provided when HV>H, validate that HV is divisible by H (raise/handle
if not), and ensure any indexing into beta in the kernel uses the new HV-shaped
beta; reference the symbols beta, q, H, HV, and the GVA fallback code in
fused_recurrent.py to locate and change the default-construction and validation
logic accordingly.
In `@fla/ops/gated_delta_rule/gate.py`:
- Around line 195-212: The backward kernel currently casts dg from fp32 to
g.dtype before computing dbias, causing timestep contributions to be quantized
for bf16/fp16; compute dbias from the fp32 dg buffer produced by
gdn_gate_bwd_kernel (use the dg fp32 tensor created at the top) before you call
dg = dg.view_as(g).type_as(g), e.g. compute dbias = dg.view(-1,
H).sum(0).to(dt_bias) if dt_bias is not None, then proceed to cast dg and reduce
dA—this preserves fp32 accumulation for dt_bias while keeping the rest of
outputs converted to g.dtype as before.
---
Outside diff comments:
In `@fla/ops/gated_delta_rule/wy_fast.py`:
- Around line 181-199: The backward loop unconditionally stores b_dk via
tl.store(p_dk, ...) but b_dk is only set inside the if USE_G branch, so when
prepare_wy_repr_bwd(..., g=None) the store writes uninitialized data; fix by
assigning b_dk in the ungated path (else) before the tl.store call — e.g.,
compute b_dk from b_dkbg and the bias term (use b_dk = b_dkbg * b_b[:, None]) so
both gated (USE_G) and ungated branches define b_dk prior to tl.store in the
loop.
---
Nitpick comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Around line 487-491: The kwargs extraction for use_gate_in_kernel, A_log, and
dt_bias is done late and duplicates the implicit contract; change the function
signature to accept explicit keyword arguments (use_gate_in_kernel=False,
A_log=None, dt_bias=None) instead of pulling them out of **kwargs so their
intent is clear, then remove the kwargs.get(...) lines (relating to
use_gate_in_kernel, A_log, dt_bias) and keep the existing assertion that A_log
is provided when use_gate_in_kernel is True (the assertion around A_log),
updating any call sites that relied on passing these via **kwargs to use
explicit named parameters.
In `@tests/context_parallel/test_cp_bwd_gk_offset.py`:
- Around line 112-113: The test currently only exercises HV == H so the
GVA-specific grouping branch (i_h // (HV // H)) is never taken; update the two
test launches that pass HV=H (the calls passing cu_seqlens=..., scale=1.0, T=T,
H=H, HV=H, K=K, V=V, ...) to include at least one variant where HV > H (for
example set HV = H * 2 or HV = H + 1) so the grouping path and GVA offset logic
are exercised in this regression test.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 5dd20087-27fe-48b1-9bd1-ea8b966e9a5d
📒 Files selected for processing (13)
- fla/layers/gated_deltanet.py
- fla/ops/common/chunk_delta_h.py
- fla/ops/common/chunk_o.py
- fla/ops/common/chunk_scaled_dot_kkt.py
- fla/ops/common/intracard_cp.py
- fla/ops/cp/chunk_delta_h.py
- fla/ops/gated_delta_rule/chunk.py
- fla/ops/gated_delta_rule/chunk_fwd.py
- fla/ops/gated_delta_rule/fused_recurrent.py
- fla/ops/gated_delta_rule/gate.py
- fla/ops/gated_delta_rule/wy_fast.py
- tests/context_parallel/test_cp_bwd_gk_offset.py
- tests/ops/test_gated_delta.py
```diff
         num_v_heads (int, Optional):
             The number of heads for the value projection, equal to `num_heads` if `None`.
-            GVA is applied if `num_v_heads` > `num_heads`. Default: `None`.
+            GVA (Grouped Value Attention) is applied if `num_v_heads` > `num_heads`,
+            where `num_v_heads` must be divisible by `num_heads`. Default: `None`.
```
Reject `num_v_heads < num_heads` up front.

The new native GVA path only supports `HV >= H`, but the module currently only checks divisibility when `num_v_heads > num_heads`. Smaller values still get through and later hit invalid `HV // H` math inside the kernels.
🔧 Suggested fix

```diff
-        if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
+        if self.num_v_heads < self.num_heads or self.num_v_heads % self.num_heads != 0:
             raise ValueError(
-                f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.",
+                f"num_v_heads={self.num_v_heads} must be >= num_heads={self.num_heads} "
+                f"and divisible by it.",
             )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/layers/gated_deltanet.py` around lines 55 - 58, Add an input validation
in the module's initializer (where num_v_heads and num_heads are processed) to
immediately reject cases where num_v_heads is not None and num_v_heads <
num_heads by raising a ValueError; keep the existing divisibility check for
num_v_heads > num_heads (i.e., ensure num_v_heads % num_heads == 0) but precede
it with this new check so the native GVA path never receives HV < H and avoids
invalid HV // H math inside the kernels.
```diff
             values of shape `[B, T, HV, V]`.
-            GVA is applied if `HV > H`.
+            GVA (Grouped Value Attention) is applied if `HV > H`, where `HV` must be divisible by `H`.
```
The wrapper doesn't fully honor the new GVA contract.
With `beta=None`, the fallback below still creates `beta` from `q[..., 0]`, i.e. shape `[B, T, H]`. Once `HV > H`, the kernel indexes `beta` with `HV` stride, so this default is wrong, and invalid `HV < H` / non-divisible ratios still make it to `HV // H`.
🔧 Suggested fix

```diff
 def fused_recurrent_gated_delta_rule(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
@@
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    HV = v.shape[2]
+    H = q.shape[2]
+    if HV < H or HV % H != 0:
+        raise ValueError(
+            f"`v.shape[2]` ({HV}) must be >= `q.shape[2]` ({H}) and divisible by it for GVA."
+        )
     if cu_seqlens is not None:
         ...
     if scale is None:
         scale = k.shape[-1] ** -0.5
     if beta is None:
-        beta = torch.ones_like(q[..., 0])
+        beta = torch.ones_like(v[..., 0])
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/gated_delta_rule/fused_recurrent.py` around lines 309 - 310, The
wrapper for GVA in fused_recurrent.py does not honor the new contract: when beta
is None it builds beta from q[..., 0] producing shape [B,T,H], which is
incorrect for HV>H and allows invalid HV values; update the wrapper logic (the
code path that constructs beta when beta is None) to produce beta shaped for
grouped value attention (match HV dimensions) or require beta to be provided
when HV>H, validate that HV is divisible by H (raise/handle if not), and ensure
any indexing into beta in the kernel uses the new HV-shaped beta; reference the
symbols beta, q, H, HV, and the GVA fallback code in fused_recurrent.py to
locate and change the default-construction and validation logic accordingly.
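Why a `q`-shaped default breaks for `HV > H` can be seen with plain shape bookkeeping (a sketch; the dimension values are arbitrary):

```python
B, T, H, HV, K, V = 2, 8, 2, 4, 16, 32
q_shape = (B, T, H, K)    # [B, T, H, K]
v_shape = (B, T, HV, V)   # [B, T, HV, V]

beta_from_q = q_shape[:3]  # (2, 8, 2): what ones_like(q[..., 0]) would produce
beta_needed = v_shape[:3]  # (2, 8, 4): what a kernel indexing beta per value head expects
print(beta_from_q, beta_needed)  # the shapes diverge whenever HV > H
```

Building the default from `v[..., 0]` instead makes `beta` per-value-head, which also degenerates to the old behavior when `HV == H`.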
```python
dg = torch.empty_like(g, dtype=torch.float32)
dA = A_log.new_empty(NT, H, dtype=torch.float32)

gdn_gate_bwd_kernel[(NT, H)](
    g=g,
    A_log=A_log,
    dt_bias=dt_bias,
    dyg=dyg,
    dg=dg,
    dA=dA,
    T=T,
    H=H,
    BT=BT,
)

dg = dg.view_as(g).type_as(g)
dA = dA.sum(0).view_as(A_log).type_as(A_log)
dbias = dg.view(-1, H).sum(0).to(dt_bias) if dt_bias is not None else None
```
Keep `dt_bias` accumulation in fp32.

`dbias` is computed from `dg` after `dg` has been cast back to `g.dtype`. Under bf16/fp16 that quantizes every timestep contribution before the reduction, so `dt_bias.grad` is less accurate than the fp32 buffer you already produced.
🔧 Suggested fix
- dg = dg.view_as(g).type_as(g)
- dA = dA.sum(0).view_as(A_log).type_as(A_log)
- dbias = dg.view(-1, H).sum(0).to(dt_bias) if dt_bias is not None else None
+ dg_fp32 = dg.view_as(g)
+ dA = dA.sum(0).view_as(A_log).type_as(A_log)
+ dbias = dg_fp32.view(-1, H).sum(0).to(dt_bias) if dt_bias is not None else None
+ dg = dg_fp32.type_as(g)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/gated_delta_rule/gate.py` around lines 195 - 212, The backward kernel
currently casts dg from fp32 to g.dtype before computing dbias, causing timestep
contributions to be quantized for bf16/fp16; compute dbias from the fp32 dg
buffer produced by gdn_gate_bwd_kernel (use the dg fp32 tensor created at the
top) before you call dg = dg.view_as(g).type_as(g), e.g. compute dbias =
dg.view(-1, H).sum(0).to(dt_bias) if dt_bias is not None, then proceed to cast
dg and reduce dA—this preserves fp32 accumulation for dt_bias while keeping the
rest of outputs converted to g.dtype as before.
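The precision argument can be reproduced with a stdlib-only toy that quantizes values to bfloat16 precision before a reduction (an illustration of the effect, not the kernel's actual tensors; real bf16 rounds to nearest, truncation is a simplification):

```python
import struct

def to_bf16(x: float) -> float:
    # Drop the low 16 bits of the fp32 bit pattern, i.e. quantize to bfloat16 precision.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

vals = [0.001] * 1000
sum_then_cast = sum(vals)                      # accumulate in high precision first
cast_then_sum = sum(to_bf16(v) for v in vals)  # quantize each contribution first
print(sum_then_cast, cast_then_sum)            # the second sum drifts noticeably
```

Per-element quantization error is tiny, but it is systematic, so it scales with the number of timesteps being reduced; summing in fp32 and casting once at the end avoids this.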
# Conflicts:
#   fla/layers/gated_deltanet.py
#   fla/ops/gated_delta_rule/chunk.py
♻️ Duplicate comments (1)
fla/layers/gated_deltanet.py (1)
128-130: ⚠️ Potential issue | 🟠 Major

Reject `num_v_heads < num_heads` up front.

The new gate-in-kernel and `fused_recurrent` paths still assume `HV >= H`, so configs like `num_heads=6, num_v_heads=4` still instantiate here and then fail inside the kernels.

🔧 Suggested fix

```diff
-        if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
+        if self.num_v_heads < self.num_heads or self.num_v_heads % self.num_heads != 0:
             raise ValueError(
-                f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.",
+                f"num_v_heads={self.num_v_heads} must be >= num_heads={self.num_heads} "
+                f"and divisible by it.",
             )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/layers/gated_deltanet.py` around lines 128 - 130, The constructor validation currently only rejects cases where num_v_heads > num_heads but not divisible; you must also reject configurations where num_v_heads < num_heads since downstream paths (gate-in-kernel and fused_recurrent) assume HV >= H. In the same validation block (referencing self.num_v_heads and self.num_heads in gated_deltanet.py, e.g., in the class initializer), add a check that raises a ValueError when self.num_v_heads < self.num_heads (or combine into a single condition requiring self.num_v_heads >= self.num_heads and, if greater, divisible by self.num_heads) with a clear error message mentioning both num_v_heads and num_heads.
🧹 Nitpick comments (1)
fla/layers/gated_deltanet.py (1)
215-217: Add a layer-level `fused_recurrent` regression before merge.

This branch is now taken automatically for every eval call with `q_len <= 64`, so the standalone gate-op test does not cover the cache/padding/GVA integration exercised here.

Also applies to: 276-288
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@fla/layers/gated_deltanet.py`:
- Around line 128-130: The constructor validation currently only rejects cases
where num_v_heads > num_heads but not divisible; you must also reject
configurations where num_v_heads < num_heads since downstream paths
(gate-in-kernel and fused_recurrent) assume HV >= H. In the same validation
block (referencing self.num_v_heads and self.num_heads in gated_deltanet.py,
e.g., in the class initializer), add a check that raises a ValueError when
self.num_v_heads < self.num_heads (or combine into a single condition requiring
self.num_v_heads >= self.num_heads and, if greater, divisible by self.num_heads)
with a clear error message mentioning both num_v_heads and num_heads.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: ed515f48-fb0f-4239-8507-39de1da183a3
📒 Files selected for processing (1)
fla/layers/gated_deltanet.py
| Setting | Value |
|---|---|
| GPU | NVIDIA H200 |
| CUDA | 12.8 |
| PyTorch | 2.7.1+cu128 |
| Base | ceacb10dd2 |
| Head | c0c191482c |
| Threshold | 5.0% |
📊 View Details (1 significant change)
| Op | Mode | B | T | H | D | Base (ms) | Head (ms) | Change |
|---|---|---|---|---|---|---|---|---|
| chunk_gdn | fwd | 8 | 1024 | 8 | 64 | 0.534 | 0.563 | +5.5% 🔴 |
This comment is automatically updated with the latest benchmark results.
Summary
- Add `fla/ops/gated_delta_rule/gate.py` with fused Triton kernels:
  - `gdn_gate_chunk_cumsum`: fuses gate activation (`-exp(A_log) * softplus(g + dt_bias)`) with chunk-local cumsum in a single kernel pass, eliminating an HBM round-trip.
  - `fused_gdn_gate`: standalone gate with autograd support, used by the fused_recurrent inference path.
  - `naive_gdn_gate`: PyTorch reference implementation for testing.
- Add `use_gate_in_kernel` / `A_log` / `dt_bias` parameters to the `chunk_gated_delta_rule` API and wire them through the autograd function.
- `elu_p1` / `sum_norm`, update docstring.
- Add `test_gate` to verify `fused_gdn_gate` against `naive_gdn_gate` (forward + backward).

Test plan
```
pytest tests/ops/test_gated_delta.py::test_gate
pytest tests/ops/test_gated_delta.py::test_chunk_gate_in_kernel
pytest tests/ops/test_gated_delta.py::test_chunk
pytest tests/ops/test_gated_delta.py::test_fused_recurrent
```

🤖 Generated with Claude Code
Summary by CodeRabbit
New Features
Refactor
Tests