Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 14 additions & 20 deletions fla/layers/gated_deltanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,23 @@
from fla.layers.utils import get_layer_cache, get_unpad_data, index_first_axis, pad_input, update_layer_cache
from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
from fla.ops.gated_delta_rule.gate import fused_gdn_gate

if TYPE_CHECKING:
from transformers.processing_utils import Unpack

from fla.models.utils import Cache


@torch.compile
def elu_p1(x):
return (F.elu(x, 1., False) + 1.).to(x)


@torch.compile
def sum_norm(x):
return (x / x.sum(-1, keepdim=True)).to(x)


class GatedDeltaNet(nn.Module):
"""
The layer implementaion for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa
Gated Delta Networks (GDN) layer implementation.

Reference: `Gated Delta Networks: Improving Mamba2 with Delta Rule <https://arxiv.org/abs/2412.06464>`_

Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters.

Parameter alloation when use_gate=True:
Parameter allocation when use_gate=True:
- 0.75 * hidden_size * hidden_size for the q_proj and k_proj each
- 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each
- Others are ignorably small.
Expand All @@ -54,11 +47,11 @@ class GatedDeltaNet(nn.Module):
hidden_size (int, Optional):
The hidden size of the input. Default: 2048.
expand_v (float, Optional):
The expansion ratio for the value dim. Default: 2.0.
The expansion ratio for the value dimension. Default: 2.0.
head_dim (int, Optional):
The dimension of each head. Default: 256.
num_heads (int, Optional):
The number of heads. Default: 4.
The number of heads. Default: 6.
num_v_heads (int, Optional):
The number of heads for the value projection, equal to `num_heads` if `None`.
GVA (Grouped Value Attention) is applied if `num_v_heads` > `num_heads`,
Expand All @@ -69,15 +62,14 @@ class GatedDeltaNet(nn.Module):
Which Gated DeltaNet kernel to use.
Currently available: `chunk` and `fused_recurrent`.
Default: `chunk`.
use_beta (bool, Optional):
Whether to use beta. Default: `True`.
use_gate (bool, Optional):
Whether to use output gate. Default: `True`.
use_short_conv (bool, Optional):
Whether to use short convolutions. Default: `True`.
allow_neg_eigval (bool, Optional):
Allow negative eigenvalues. Default: `False`. If set to `True`, the beta will be multiplied by 2.
See reference: [Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues](https://arxiv.org/abs/2411.12537)
See reference:
`Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues <https://arxiv.org/abs/2411.12537>`_
conv_size (int, Optional):
The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
conv_bias (bool, Optional):
Expand Down Expand Up @@ -265,22 +257,24 @@ def forward(
if self.allow_neg_eigval:
beta = beta * 2.

g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)

recurrent_state = last_state['recurrent_state'] if last_state is not None else None
if mode == 'chunk':
o, recurrent_state = chunk_gated_delta_rule(
q=q,
k=k,
v=v,
g=g,
g=self.a_proj(hidden_states),
beta=beta,
initial_state=recurrent_state,
output_final_state=use_cache,
cu_seqlens=cu_seqlens,
use_qk_l2norm_in_kernel=True,
use_gate_in_kernel=True,
A_log=self.A_log,
dt_bias=self.dt_bias,
)
elif mode == 'fused_recurrent':
g = fused_gdn_gate(g=self.a_proj(hidden_states), A_log=self.A_log, dt_bias=self.dt_bias)
o, recurrent_state = fused_recurrent_gated_delta_rule(
q=q,
k=k,
Expand Down
4 changes: 2 additions & 2 deletions fla/ops/gated_delta_product/chunk.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang

import torch
from einops import rearrange
Expand Down Expand Up @@ -174,7 +174,7 @@ def backward(
# call the gated deltanet kernel for now.
# TODO: optimize the backward pass like the forward pass.
if g is not None:
dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
dq, dk, dv, db, dg, dh0, _, _ = chunk_gated_delta_rule_bwd(
q=q,
k=k,
v=v,
Expand Down
96 changes: 81 additions & 15 deletions fla/ops/gated_delta_rule/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
expand_h0,
)
from fla.ops.gated_delta_rule.chunk_fwd import chunk_gated_delta_rule_fwd_intra
from fla.ops.gated_delta_rule.gate import gdn_gate_bwd, gdn_gate_chunk_cumsum
from fla.ops.gated_delta_rule.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd
from fla.ops.utils import chunk_local_cumsum
from fla.ops.utils.constant import RCP_LN2
Expand All @@ -36,14 +37,29 @@ def chunk_gated_delta_rule_fwd(
chunk_indices: torch.LongTensor | None = None,
use_exp2: bool = True,
transpose_state_layout: bool = False,
use_gate_in_kernel: bool = False,
A_log: torch.Tensor | None = None,
dt_bias: torch.Tensor | None = None,
):
g = chunk_local_cumsum(
g,
chunk_size=64,
scale=RCP_LN2 if use_exp2 else None,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
)
g_input = g if use_gate_in_kernel else None
if use_gate_in_kernel:
g = gdn_gate_chunk_cumsum(
g=g,
A_log=A_log,
chunk_size=64,
scale=RCP_LN2 if use_exp2 else None,
dt_bias=dt_bias,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
)
else:
g = chunk_local_cumsum(
g,
chunk_size=64,
scale=RCP_LN2 if use_exp2 else None,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
)
# obtain WY representation. u is actually the new v.
# fused kkt + solve_tril + recompute_w_u
w, u, A = chunk_gated_delta_rule_fwd_intra(
Expand Down Expand Up @@ -97,7 +113,7 @@ def chunk_gated_delta_rule_fwd(
use_exp2=use_exp2,
transpose_state_layout=transpose_state_layout,
)
return g, o, A, final_state, initial_state
return g, o, A, final_state, initial_state, g_input


def chunk_gated_delta_rule_bwd(
Expand All @@ -116,6 +132,10 @@ def chunk_gated_delta_rule_bwd(
chunk_indices: torch.LongTensor | None = None,
use_exp2: bool = True,
transpose_state_layout: bool = False,
use_gate_in_kernel: bool = False,
g_input: torch.Tensor | None = None,
A_log: torch.Tensor | None = None,
dt_bias: torch.Tensor | None = None,
):
w, u = recompute_w_u_fwd(
k=k,
Expand Down Expand Up @@ -219,7 +239,10 @@ def chunk_gated_delta_rule_bwd(
dk.add_(dk2)
dg.add_(dg2)
dg = chunk_local_cumsum(dg, chunk_size=64, reverse=True, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices)
return dq, dk, dv, db, dg, dh0
dA_log, ddt_bias = None, None
if use_gate_in_kernel:
dg, dA_log, ddt_bias = gdn_gate_bwd(g=g_input, A_log=A_log, dt_bias=dt_bias, dyg=dg)
return dq, dk, dv, db, dg, dh0, dA_log, ddt_bias


class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
Expand All @@ -242,6 +265,9 @@ def forward(
use_qk_l2norm_in_kernel: bool = False,
cp_context: FLACPContext | None = None,
transpose_state_layout: bool = False,
use_gate_in_kernel: bool = False,
A_log: torch.Tensor | None = None,
dt_bias: torch.Tensor | None = None,
):
q_rstd, k_rstd = None, None
if use_qk_l2norm_in_kernel:
Expand All @@ -250,7 +276,7 @@ def forward(

chunk_indices = prepare_chunk_indices(
cu_seqlens, 64, cu_seqlens_cpu=cu_seqlens_cpu) if cu_seqlens is not None else None
g, o, A, final_state, initial_state = chunk_gated_delta_rule_fwd(
g, o, A, final_state, initial_state, g_input = chunk_gated_delta_rule_fwd(
q=q,
k=k,
v=v,
Expand All @@ -263,12 +289,20 @@ def forward(
cp_context=cp_context,
chunk_indices=chunk_indices,
transpose_state_layout=transpose_state_layout,
use_gate_in_kernel=use_gate_in_kernel,
A_log=A_log,
dt_bias=dt_bias,
)
ctx.save_for_backward(
q, q_rstd, k, k_rstd, v, g, beta, A,
initial_state, cu_seqlens, chunk_indices,
g_input, A_log, dt_bias,
)
ctx.save_for_backward(q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens, chunk_indices)
ctx.scale = scale
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
ctx.cp_context = cp_context
ctx.transpose_state_layout = transpose_state_layout
ctx.use_gate_in_kernel = use_gate_in_kernel
return o.to(q.dtype), final_state

@staticmethod
Expand All @@ -279,8 +313,10 @@ def backward(
do: torch.Tensor,
dht: torch.Tensor,
):
q, q_rstd, k, k_rstd, v, g, beta, A, initial_state, cu_seqlens, chunk_indices = ctx.saved_tensors
dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
(q, q_rstd, k, k_rstd, v, g, beta, A,
initial_state, cu_seqlens, chunk_indices,
g_input, A_log, dt_bias) = ctx.saved_tensors
dq, dk, dv, db, dg, dh0, dA_log, ddt_bias = chunk_gated_delta_rule_bwd(
q=q,
k=k,
v=v,
Expand All @@ -295,11 +331,19 @@ def backward(
cp_context=ctx.cp_context,
chunk_indices=chunk_indices,
transpose_state_layout=ctx.transpose_state_layout,
use_gate_in_kernel=ctx.use_gate_in_kernel,
g_input=g_input,
A_log=A_log,
dt_bias=dt_bias,
)
if ctx.use_qk_l2norm_in_kernel:
dq = l2norm_bwd(q, q_rstd, dq)
dk = l2norm_bwd(k, k_rstd, dk)
return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None, None, None
return (
dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta),
None, dh0, None, None, None, None, None, None,
None, dA_log, ddt_bias,
)


@torch.compiler.disable
Expand Down Expand Up @@ -329,7 +373,10 @@ def chunk_gated_delta_rule(
values of shape `[B, T, HV, V]`.
GVA (Grouped Value Attention) is applied if `HV > H`, where `HV` must be divisible by `H`.
g (torch.Tensor):
(forget) gating tensor (in log space!) of shape `[B, T, HV]`.
(forget) gating tensor of shape `[B, T, HV]`.
When `use_gate_in_kernel=False` (default), `g` should be in log space (pre-computed decay).
When `use_gate_in_kernel=True`, `g` is the raw input before gate activation;
the kernel fuses `-exp(A_log) * softplus(g + dt_bias)` + chunk cumsum internally.
beta (torch.Tensor):
betas of shape `[B, T, HV]`.
scale (Optional[float]):
Expand All @@ -353,6 +400,16 @@ def chunk_gated_delta_rule(
transpose_state_layout (Optional[bool]):
Whether to use the transposed state layout for the hidden state.
Default: `False`.
use_gate_in_kernel (bool):
Whether to compute the log-space GDN decay internally.
When `True`, the passed `g` is the raw input, and `A_log` must be provided.
The kernel fuses gate activation + chunk cumsum in a single pass.
Default: `False`.
A_log (Optional[torch.Tensor]):
Decay parameter of shape `[HV]`. Required when `use_gate_in_kernel=True`.
dt_bias (Optional[torch.Tensor]):
Bias added to `g` before activation, of shape `[HV]`.
Only used when `use_gate_in_kernel=True`.

Returns:
o (torch.Tensor):
Expand Down Expand Up @@ -427,6 +484,12 @@ def chunk_gated_delta_rule(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.",
)
use_gate_in_kernel = kwargs.get('use_gate_in_kernel', False)
A_log = kwargs.get('A_log')
dt_bias = kwargs.get('dt_bias')
if use_gate_in_kernel:
assert A_log is not None, "A_log must be provided when use_gate_in_kernel=True."

if scale is None:
scale = k.shape[-1] ** -0.5
o, final_state = ChunkGatedDeltaRuleFunction.apply(
Expand All @@ -443,6 +506,9 @@ def chunk_gated_delta_rule(
use_qk_l2norm_in_kernel,
cp_context,
transpose_state_layout,
use_gate_in_kernel,
A_log,
dt_bias,
)
return o, final_state

Expand Down
Loading
Loading