From 1a219c45692da48f61b853af019a054d136282b6 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 18 Nov 2025 15:21:19 +0000 Subject: [PATCH 01/21] wip --- fast_llm/models/auto.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 7830c69a..3f67fe71 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -4,6 +4,7 @@ from fast_llm.layers.attention.config import AttentionConfig # isort: skip from fast_llm.layers.ssm.config import MambaConfig, Mamba2Config, DiscreteMamba2Config # isort: skip +from fast_llm.layers.attention.config import AttentionConfig # isort: skip from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip from fast_llm.models.multimodal.config import MultiModalModelConfig, MultiModalTrainerConfig # isort: skip from fast_llm.engine.evaluation.evaluators import EvaluatorsConfig # isort: skip From 5242eb6b0be0e755ae2d86ef1a7836bca0a97754 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 18 Nov 2025 21:23:36 +0000 Subject: [PATCH 02/21] added gdn --- .../layers/common/normalization/config.py | 11 + .../common/normalization/normalization.py | 42 ++ fast_llm/layers/ssm/config.py | 90 +++++ fast_llm/layers/ssm/gdn.py | 372 ++++++++++++++++++ fast_llm/models/auto.py | 7 +- 5 files changed, 521 insertions(+), 1 deletion(-) create mode 100644 fast_llm/layers/ssm/gdn.py diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index a80a1928..4ecb7a3b 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -127,3 +127,14 @@ def module_class(self): from fast_llm.layers.common.normalization.normalization import RMSNormalization return RMSNormalization + + +@config_class(dynamic_type={NormalizationConfig: "gated_rms_norm"}) +class GatedRMSNormalizationConfig(RMSNormalizationConfig): + _abstract = False + + @property + def module_class(self): + from fast_llm.layers.common.normalization.normalization import GatedRMSNormalization + + return GatedRMSNormalization diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index d0a5ab15..ec8a52e2 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -1,6 +1,7 @@ import abc import torch +import torch.nn.functional as F from fast_llm.config import Configurable from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ @@ -9,6 +10,7 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.normalization.config import ( + GatedRMSNormalizationConfig, LayerNormalizationConfig, NoNormalizationConfig, NormalizationConfig, @@ -33,6 +35,12 @@ _fast_normalization_available = False +try: + from fla.modules.fused_norm_gate import rms_norm_gated # noqa +except ImportError: + rms_norm_gated = None + + _PERSIST_LN_SIZES = ( 1024, 1536, @@ -292,3 +300,37 @@ def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: return torch.rms_norm(input_.to(self.weight.dtype), self._normalized_shape, self.weight, self._config.epsilon) + + +class GatedRMSNormalization[ConfigType: GatedRMSNormalizationConfig](RMSNormalization[ConfigType], torch.nn.Module): + """ + A gated RMS normalization layer. 
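+
+    Computes rms_norm(input_) * silu(gate): the input is RMS-normalized with the learned weight
+    and then scaled by the SiLU of the gate tensor. The fused rms_norm_gated kernel from fla is
+    used when it can be imported; otherwise the normalization and the gate are applied separately
+    in plain PyTorch.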
+ """ + + def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + super().__init__(config, hidden_dim, lr_scale) + + if rms_norm_gated is not None: + self._forward = self._forward_fused + else: + self._forward = self._forward + + def forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + return self._forward(input_.view(-1, *self._normalized_shape), gate).view_as(input_) + + def _forward_fused(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + return rms_norm_gated( + input_, + gate, + self.weight, + None, + activation="silu", + eps=self._config.epsilon, + residual=None, + prenorm=False, + residual_in_fp32=False, + ) + + def _forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + normalized = self.rmsnorm(input_) + return normalized * F.silu(gate) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index e541341e..35ed6f6a 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -4,7 +4,9 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer from fast_llm.engine.config_utils.parameter import ParameterConfig +from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.decoder.config import MixerConfig from fast_llm.utils import Assert @@ -16,6 +18,94 @@ from fast_llm.tensor import ParameterMeta +@config_class(dynamic_type={MixerConfig: "gdn"}) +class GatedDeltaNetConfig(MixerConfig): + """ + Configuration for the gated DeltaNet mixer used in Qwen3Next style linear attention blocks. 
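+
+    The mixer projects hidden states into fused query/key/value/gate (qkvz) and beta/decay (ba)
+    streams, runs a depth-wise causal convolution over the concatenated QKV channels, applies the
+    chunked gated delta rule, and gates the result through a gated RMS norm before the output
+    projection. Schematically, per head and per time step, the delta-rule kernel maintains
+
+        S'  = exp(g_t) * S_{t-1}
+        S_t = S' + beta_t * k_t (v_t - S'^T k_t)^T
+        o_t = S_t^T q_t
+
+    with g_t <= 0 derived from the A_log and dt_bias parameters, and beta_t = sigmoid(b_t).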
+ """ + + _abstract = False + normalization: NormalizationConfig = Field( + desc="Configuration for the block normalization layers.", + hint=FieldHint.architecture, + ) + qkv_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces query, key, value and modulation vectors.", + hint=FieldHint.architecture, + ) + ba_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces the decay and beta terms.", + hint=FieldHint.architecture, + ) + convolution_layer: CausalConv1dConfig = Field( + desc="Depth-wise convolution applied to the concatenated QKV streams.", + hint=FieldHint.architecture, + ) + output_layer: AffineLinearConfig = Field( + desc="Output projection applied after the DeltaNet recurrence and gated RMS norm.", + hint=FieldHint.architecture, + ) + dt_bias_weight: ParameterConfig = Field( + desc="Parameter configuration for the DeltaNet time-step bias.", + hint=FieldHint.architecture, + ) + a_log_weight: ParameterConfig = Field( + desc="Parameter configuration for the DeltaNet decay rates.", + hint=FieldHint.architecture, + ) + + value_heads: int = Field( + default=16, + desc="Number of value heads.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + key_heads: int = Field( + default=8, + desc="Number of key heads.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + key_head_dim: int = Field( + default=64, + desc="Dimension of each key head.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + value_head_dim: int = Field( + default=64, + desc="Dimension of each value head.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + norm_epsilon: float = Field( + default=1e-6, + desc="Epsilon used by the gated RMS norm.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + use_qk_l2norm: bool = Field( + default=True, + desc="Apply L2 normalization on query/key vectors inside the Delta rule kernel.", + hint=FieldHint.architecture, + ) + activation: ActivationType = Field( + default=ActivationType.silu, + desc="Activation used after the convolution.", + hint=FieldHint.architecture, + ) + + def _validate(self) -> None: + super()._validate() + Assert.multiple(self.value_heads, self.key_heads) + + @property + def layer_class(self) -> "type": + from fast_llm.layers.ssm.gdn import GatedDeltaNet + + return GatedDeltaNet + + @config_class() class SSMConfig(MixerConfig): # Layers diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py new file mode 100644 index 00000000..a1a62a5a --- /dev/null +++ b/fast_llm/layers/ssm/gdn.py @@ -0,0 +1,372 @@ +import logging +import typing + +import torch +import torch.nn.functional as F + +from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias +from fast_llm.layers.ssm.config import GatedDeltaNetConfig +from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.utils import div + +logger = logging.getLogger(__name__) + +try: + from fla.ops.gated_delta_rule import chunk_gated_delta_rule +except ImportError: + chunk_gated_delta_rule = 
None + + +is_fast_path_available = chunk_gated_delta_rule is not None + + +def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: + return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + + +def torch_recurrent_gated_delta_rule( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + *, + use_qk_l2norm_in_kernel: bool, +) -> torch.Tensor: + """ + Simplified gated Delta rule used during training. + Args expect tensors shaped as (batch, heads, seq, dim) except for g/beta which are (batch, heads, seq). + """ + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = _l2norm(query, dim=-1) + key = _l2norm(key, dim=-1) + + query = query.to(torch.float32) + key = key.to(torch.float32) + value = value.to(torch.float32) + beta = beta.to(torch.float32) + g = g.to(torch.float32) + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + state = torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim, device=key.device, dtype=key.dtype) + outputs = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim, device=value.device, dtype=value.dtype) + + for idx in range(sequence_length): + q_t = query[:, :, idx] + k_t = key[:, :, idx] + v_t = value[:, :, idx] + g_t = g[:, :, idx].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, idx].unsqueeze(-1) + state = state * g_t + kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + outputs[:, :, idx] = (state * q_t.unsqueeze(-1)).sum(dim=-2) + + return outputs.to(initial_dtype), state + + +def torch_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = _l2norm(query, dim=-1, eps=1e-6) + key = _l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = ( + x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ) + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_sequence_length = sequence_length + pad_size + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = ( + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta) + ) + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0) + + # chunk decay + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + 
last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + core_attn_out = torch.zeros_like(value) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1) + + # for each chunk + for i in range(0, total_sequence_length // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new + ) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :sequence_length] + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +# class _GatedRMSNorm(torch.nn.Module): +# def __init__(self, hidden_size: int, eps: float): +# super().__init__() +# self.weight = torch.nn.Parameter(torch.ones(hidden_size)) +# self.eps = eps + +# def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: +# dtype = hidden_states.dtype +# hidden_states = hidden_states.to(torch.float32) +# variance = hidden_states.pow(2).mean(dim=-1, keepdim=True) +# hidden_states = hidden_states * torch.rsqrt(variance + self.eps) +# hidden_states = self.weight * hidden_states.to(dtype) +# hidden_states = hidden_states * F.silu(gate.to(torch.float32)) +# return hidden_states.to(dtype) + + +class GatedDeltaNet[ConfigType: GatedDeltaNetConfig](BlockWithBias[ConfigType]): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_bias: bool = True, + ): + super().__init__( + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias + ) + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._value_heads_dim = TensorDim( + "gdn_value_heads", self._config.value_heads, self._parallel_dim if self._config.value_heads > 1 else None + ) + self._key_heads_dim = TensorDim( + "gdn_key_heads", self._config.key_heads, self._parallel_dim if self._config.key_heads > 1 else None + ) + self._value_head_dim = TensorDim("gdn_value_head_dim", self._config.value_head_dim) + self._key_head_dim = TensorDim("gdn_key_head_dim", self._config.key_head_dim) + self._local_value_heads = self._value_heads_dim.size + self._local_key_heads = self._key_heads_dim.size + self._value_heads_per_key = div(self._local_value_heads, max(self._local_key_heads, 1)) + + query_dim = CompositeTensorDim("gdn_query", (self._key_heads_dim, self._key_head_dim)) + key_dim = CompositeTensorDim("gdn_key", (self._key_heads_dim, self._key_head_dim)) + value_dim = CompositeTensorDim("gdn_value", (self._value_heads_dim, self._value_head_dim)) + z_dim = CompositeTensorDim("gdn_z", (self._value_heads_dim, self._value_head_dim)) + qkvz_dim = ConcatenatedTensorDim("gdn_qkvz", (query_dim, key_dim, 
value_dim, z_dim)) + ba_dim = ConcatenatedTensorDim( + "gdn_ba", + ( + CompositeTensorDim("gdn_beta", (self._value_heads_dim,)), + CompositeTensorDim("gdn_alpha", (self._value_heads_dim,)), + ), + ) + + qkv_channels_dim = ConcatenatedTensorDim("gdn_qkv", (query_dim, key_dim, value_dim)) + + self.in_proj_qkvz = self._config.qkv_projection_layer.get_layer( + hidden_dim, + qkvz_dim, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.in_proj_ba = self._config.ba_projection_layer.get_layer( + hidden_dim, + ba_dim, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.convolution = self._config.convolution_layer.get_layer( + qkv_channels_dim, + default_add_bias=False, + default_activation=self._config.activation, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.out_proj = self._config.output_layer.get_layer( + value_dim, + hidden_dim, + default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.dt_bias: ParameterMeta = self._config.dt_bias_weight.get_parameter( + (self._value_heads_dim,), + default_initialization=init_ones_, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.A_log: ParameterMeta = self._config.a_log_weight.get_parameter( + (self._value_heads_dim,), + default_initialization=LambdaInitializer( + lambda _, tensor, generator: tensor.uniform_(0, 16, generator=generator).log_() + ), + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.norm = self._config.normalization.get_layer( + self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft + ) + # _GatedRMSNorm(self._config.value_head_dim, self._config.norm_epsilon) + self._use_qk_l2norm = self._config.use_qk_l2norm + + self._value_dim = value_dim + self._query_dim = query_dim + self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule + + if not is_fast_path_available: + logger.warning( + "Fast paths for GatedDeltaNet are not available. Please ensure that 'causal_conv1d' and 'fla' are properly installed." 
+ ) + + def _reshape_heads(self, tensor: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor: + batch, seq, _ = tensor.shape + return tensor.view(batch, seq, num_heads, head_dim) + + def _forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + sequence_first = kwargs[BlockKwargs.sequence_first] + if sequence_first: + hidden_states = input_.transpose(0, 1) + else: + hidden_states = input_ + + batch_size, sequence_length, _ = hidden_states.shape + qkvz = self.in_proj_qkvz(hidden_states) + ba = self.in_proj_ba(hidden_states) + key_size = self._query_dim.size + value_size = self._value_dim.size + query, key, value, z = torch.split(qkvz, (key_size, key_size, value_size, value_size), dim=-1) + beta, alpha = torch.split(ba, (self._local_value_heads, self._local_value_heads), dim=-1) + + query = self._reshape_heads(query, self._local_key_heads, self._config.key_head_dim) + key = self._reshape_heads(key, self._local_key_heads, self._config.key_head_dim) + value = self._reshape_heads(value, self._local_value_heads, self._config.value_head_dim) + z = self._reshape_heads(z, self._local_value_heads, self._config.value_head_dim) + + mixed_qkv = torch.cat( + ( + query.reshape(batch_size, sequence_length, -1), + key.reshape(batch_size, sequence_length, -1), + value.reshape(batch_size, sequence_length, -1), + ), + dim=-1, + ) + mixed_qkv = mixed_qkv.transpose(1, 2) + mixed_qkv = self.convolution(mixed_qkv) + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + ( + self._local_key_heads * self._config.key_head_dim, + self._local_key_heads * self._config.key_head_dim, + self._local_value_heads * self._config.value_head_dim, + ), + dim=-1, + ) + query = self._reshape_heads(query, self._local_key_heads, self._config.key_head_dim) + key = self._reshape_heads(key, self._local_key_heads, self._config.key_head_dim) + value = self._reshape_heads(value, self._local_value_heads, self._config.value_head_dim) + + beta = beta.view(batch_size, sequence_length, self._local_value_heads).sigmoid() + alpha = alpha.view(batch_size, sequence_length, self._local_value_heads) + dt_bias = self.dt_bias.to(hidden_states.dtype) + a_log = self.A_log.to(hidden_states.dtype) + g = -torch.exp(a_log) * F.softplus(alpha + dt_bias) + + if self._value_heads_per_key > 1: + query = query.repeat_interleave(self._value_heads_per_key, dim=2) + key = key.repeat_interleave(self._value_heads_per_key, dim=2) + + core_attn_out, _ = self.chunk_gated_delta_rule( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + g=g.permute(0, 2, 1), + beta=beta.permute(0, 2, 1), + use_qk_l2norm_in_kernel=self._use_qk_l2norm, + ) + + core_attn_out = core_attn_out.permute(0, 2, 1, 3).reshape( + batch_size, sequence_length, -1, self._config.value_head_dim + ) + z = z.reshape(batch_size, sequence_length, -1, self._config.value_head_dim) + norm_input = core_attn_out.reshape(-1, self._config.value_head_dim) + norm_gate = z.reshape(-1, self._config.value_head_dim) + norm_out = self.norm(norm_input, norm_gate).view(batch_size, sequence_length, -1) + output = self.out_proj(norm_out) + + if sequence_first: + output = output.transpose(0, 1) + return output + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + # return ( + # self.in_proj_qkvz.get_compute_usage(input_, 
config) + # + self.in_proj_ba.get_compute_usage(input_, config) + # + self.out_proj.get_compute_usage(input_, config) + # ) + raise NotImplementedError() diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 3f67fe71..f7c34a97 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -3,7 +3,12 @@ """ from fast_llm.layers.attention.config import AttentionConfig # isort: skip -from fast_llm.layers.ssm.config import MambaConfig, Mamba2Config, DiscreteMamba2Config # isort: skip +from fast_llm.layers.ssm.config import ( + MambaConfig, + Mamba2Config, + DiscreteMamba2Config, + GatedDeltaNetConfig, +) # isort: skip from fast_llm.layers.attention.config import AttentionConfig # isort: skip from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip from fast_llm.models.multimodal.config import MultiModalModelConfig, MultiModalTrainerConfig # isort: skip From bec22ded712f7b8b721ad3bf1e5a7f14030e8328 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 19 Nov 2025 14:27:59 +0000 Subject: [PATCH 03/21] gdn layer --- fast_llm/layers/ssm/gdn.py | 178 ++++++++++++++----------------------- 1 file changed, 69 insertions(+), 109 deletions(-) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index a1a62a5a..62360acc 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -30,50 +30,6 @@ def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) -def torch_recurrent_gated_delta_rule( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - *, - use_qk_l2norm_in_kernel: bool, -) -> torch.Tensor: - """ - Simplified gated Delta rule used during training. - Args expect tensors shaped as (batch, heads, seq, dim) except for g/beta which are (batch, heads, seq). 
- """ - initial_dtype = query.dtype - if use_qk_l2norm_in_kernel: - query = _l2norm(query, dim=-1) - key = _l2norm(key, dim=-1) - - query = query.to(torch.float32) - key = key.to(torch.float32) - value = value.to(torch.float32) - beta = beta.to(torch.float32) - g = g.to(torch.float32) - - batch_size, num_heads, sequence_length, k_head_dim = key.shape - v_head_dim = value.shape[-1] - state = torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim, device=key.device, dtype=key.dtype) - outputs = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim, device=value.device, dtype=value.dtype) - - for idx in range(sequence_length): - q_t = query[:, :, idx] - k_t = key[:, :, idx] - v_t = value[:, :, idx] - g_t = g[:, :, idx].exp().unsqueeze(-1).unsqueeze(-1) - beta_t = beta[:, :, idx].unsqueeze(-1) - state = state * g_t - kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2) - delta = (v_t - kv_mem) * beta_t - state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) - outputs[:, :, idx] = (state * q_t.unsqueeze(-1)).sum(dim=-2) - - return outputs.to(initial_dtype), state - - def torch_chunk_gated_delta_rule( query, key, @@ -154,23 +110,11 @@ def torch_chunk_gated_delta_rule( return core_attn_out, last_recurrent_state -# class _GatedRMSNorm(torch.nn.Module): -# def __init__(self, hidden_size: int, eps: float): -# super().__init__() -# self.weight = torch.nn.Parameter(torch.ones(hidden_size)) -# self.eps = eps - -# def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: -# dtype = hidden_states.dtype -# hidden_states = hidden_states.to(torch.float32) -# variance = hidden_states.pow(2).mean(dim=-1, keepdim=True) -# hidden_states = hidden_states * torch.rsqrt(variance + self.eps) -# hidden_states = self.weight * hidden_states.to(dtype) -# hidden_states = hidden_states * F.silu(gate.to(torch.float32)) -# return hidden_states.to(dtype) - - class GatedDeltaNet[ConfigType: GatedDeltaNetConfig](BlockWithBias[ConfigType]): + """ + Follows implementation here: https://github.com/huggingface/transformers/blob/a5c903f877fda21e739027eed133e03162eb7712/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L593 + """ + _config: ConfigType def __init__( @@ -265,11 +209,7 @@ def __init__( self.norm = self._config.normalization.get_layer( self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft ) - # _GatedRMSNorm(self._config.value_head_dim, self._config.norm_epsilon) - self._use_qk_l2norm = self._config.use_qk_l2norm - self._value_dim = value_dim - self._query_dim = query_dim self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule if not is_fast_path_available: @@ -277,9 +217,41 @@ def __init__( "Fast paths for GatedDeltaNet are not available. Please ensure that 'causal_conv1d' and 'fla' are properly installed." ) - def _reshape_heads(self, tensor: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor: - batch, seq, _ = tensor.shape - return tensor.view(batch, seq, num_heads, head_dim) + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. 
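+
+        The fused projections emit, for each key head, the contiguous blocks
+        [q (key_head_dim) | k (key_head_dim) | v (ratio * value_head_dim) | z (ratio * value_head_dim)]
+        and [b (ratio) | a (ratio)], where ratio = value_heads // key_heads. Both tensors are
+        therefore first viewed as (batch, seq, key_heads, -1), split along the last dimension, and
+        the value/gate/beta/decay pieces are then flattened back to a per-value-head layout.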
+ """ + + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + self._local_key_heads, + 2 * self._config.key_head_dim + + 2 * self._config.value_head_dim * self._local_value_heads // self._local_key_heads, + ) + new_tensor_shape_ba = mixed_ba.size()[:-1] + ( + self._local_key_heads, + 2 * self._local_value_heads // self._local_key_heads, + ) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + split_arg_list_qkvz = [ + self._config.key_head_dim, + self._config.key_head_dim, + (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), + (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), + ] + split_arg_list_ba = [ + self._local_value_heads // self._local_key_heads, + self._local_value_heads // self._local_key_heads, + ] + query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3) + b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3) + # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + value = value.reshape(value.size(0), value.size(1), -1, self._config.value_head_dim) + z = z.reshape(z.size(0), z.size(1), -1, self._config.value_head_dim) + b = b.reshape(b.size(0), b.size(1), self._local_value_heads) + a = a.reshape(a.size(0), a.size(1), self._local_value_heads) + return query, key, value, z, b, a def _forward( self, @@ -289,32 +261,22 @@ def _forward( metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: sequence_first = kwargs[BlockKwargs.sequence_first] + + # TODO: do we need maksing of padding tokens? if sequence_first: hidden_states = input_.transpose(0, 1) else: hidden_states = input_ - batch_size, sequence_length, _ = hidden_states.shape - qkvz = self.in_proj_qkvz(hidden_states) - ba = self.in_proj_ba(hidden_states) - key_size = self._query_dim.size - value_size = self._value_dim.size - query, key, value, z = torch.split(qkvz, (key_size, key_size, value_size, value_size), dim=-1) - beta, alpha = torch.split(ba, (self._local_value_heads, self._local_value_heads), dim=-1) - - query = self._reshape_heads(query, self._local_key_heads, self._config.key_head_dim) - key = self._reshape_heads(key, self._local_key_heads, self._config.key_head_dim) - value = self._reshape_heads(value, self._local_value_heads, self._config.value_head_dim) - z = self._reshape_heads(z, self._local_value_heads, self._config.value_head_dim) - - mixed_qkv = torch.cat( - ( - query.reshape(batch_size, sequence_length, -1), - key.reshape(batch_size, sequence_length, -1), - value.reshape(batch_size, sequence_length, -1), - ), - dim=-1, + # batch_size, sequence_length, _ = hidden_states.shape + projected_states_qkvz = self.in_proj_qkvz(hidden_states) # bs x seq_len x (qkvz) + projected_states_ba = self.in_proj_ba(hidden_states) # bs x seq_len x (b a) + query, key, value, z, beta, alpha = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba ) + query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) + + mixed_qkv = torch.cat((query, key, value), dim=-1) mixed_qkv = mixed_qkv.transpose(1, 2) mixed_qkv = self.convolution(mixed_qkv) mixed_qkv = mixed_qkv.transpose(1, 2) @@ -327,37 +289,35 @@ def _forward( ), dim=-1, ) - query = self._reshape_heads(query, self._local_key_heads, self._config.key_head_dim) - key = self._reshape_heads(key, self._local_key_heads, self._config.key_head_dim) - value = self._reshape_heads(value, self._local_value_heads, self._config.value_head_dim) + query = 
query.reshape(query.shape[0], query.shape[1], -1, self._config.key_head_dim) + key = key.reshape(key.shape[0], key.shape[1], -1, self._config.key_head_dim) + value = value.reshape(value.shape[0], value.shape[1], -1, self._config.value_head_dim) - beta = beta.view(batch_size, sequence_length, self._local_value_heads).sigmoid() - alpha = alpha.view(batch_size, sequence_length, self._local_value_heads) - dt_bias = self.dt_bias.to(hidden_states.dtype) - a_log = self.A_log.to(hidden_states.dtype) - g = -torch.exp(a_log) * F.softplus(alpha + dt_bias) + beta = beta.sigmoid() + g = -torch.exp(self.A_log) * F.softplus(alpha + self.dt_bias) if self._value_heads_per_key > 1: query = query.repeat_interleave(self._value_heads_per_key, dim=2) key = key.repeat_interleave(self._value_heads_per_key, dim=2) core_attn_out, _ = self.chunk_gated_delta_rule( - query.permute(0, 2, 1, 3), - key.permute(0, 2, 1, 3), - value.permute(0, 2, 1, 3), - g=g.permute(0, 2, 1), - beta=beta.permute(0, 2, 1), - use_qk_l2norm_in_kernel=self._use_qk_l2norm, + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, ) - core_attn_out = core_attn_out.permute(0, 2, 1, 3).reshape( - batch_size, sequence_length, -1, self._config.value_head_dim - ) - z = z.reshape(batch_size, sequence_length, -1, self._config.value_head_dim) - norm_input = core_attn_out.reshape(-1, self._config.value_head_dim) - norm_gate = z.reshape(-1, self._config.value_head_dim) - norm_out = self.norm(norm_input, norm_gate).view(batch_size, sequence_length, -1) - output = self.out_proj(norm_out) + z_shape_og = z.shape + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) + output = self.out_proj(core_attn_out) if sequence_first: output = output.transpose(0, 1) From 7f7990983c064bb9a0627dbeb93ace5b28d87b58 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 19 Nov 2025 16:44:03 +0000 Subject: [PATCH 04/21] kda --- fast_llm/layers/ssm/config.py | 91 ++++++++++++ fast_llm/layers/ssm/kda.py | 268 ++++++++++++++++++++++++++++++++++ 2 files changed, 359 insertions(+) create mode 100644 fast_llm/layers/ssm/kda.py diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 35ed6f6a..95ef9bed 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -106,6 +106,97 @@ def layer_class(self) -> "type": return GatedDeltaNet +@config_class(dynamic_type={MixerConfig: "kda"}) +class KimiDeltaAttentionConfig(MixerConfig): + """ + Configuration for the KimiDeltaAttention mixer inspired by the Kimi Linear models. 
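+
+    The mixer applies separate query/key/value projections, each followed by a short causal
+    convolution, builds the per-channel decay gate from the low-rank f_a/f_b projections combined
+    with the a_log and dt_bias parameters through the fused KDA gate, computes beta with a sigmoid,
+    runs the chunked KDA delta rule, and gates its output through a gated RMS norm driven by the
+    g_a/g_b projections before the final output projection.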
+ """ + + _abstract = False + normalization: NormalizationConfig = Field( + desc="Configuration for the gated normalization applied to the KDA output.", + hint=FieldHint.architecture, + ) + q_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces query vectors.", + hint=FieldHint.architecture, + ) + k_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces key vectors.", + hint=FieldHint.architecture, + ) + v_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces value vectors.", + hint=FieldHint.architecture, + ) + f_a_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the Delta gating pre-activation.", + hint=FieldHint.architecture, + ) + f_b_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the Delta gating expansion.", + hint=FieldHint.architecture, + ) + g_a_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the output gating pre-activation.", + hint=FieldHint.architecture, + ) + g_b_projection_layer: AffineLinearConfig = Field( + desc="Projection used for the output gating expansion.", + hint=FieldHint.architecture, + ) + beta_projection_layer: AffineLinearConfig = Field( + desc="Projection that produces the Beta gate.", + hint=FieldHint.architecture, + ) + output_projection_layer: AffineLinearConfig = Field( + desc="Projection applied after the Delta recurrence and gated normalization.", + hint=FieldHint.architecture, + ) + convolution_layer: CausalConv1dConfig = Field( + desc="Depth-wise convolution applied independently on each Q, K and V stream.", + hint=FieldHint.architecture, + ) + dt_bias_weight: ParameterConfig = Field( + desc="Parameter configuration for the Delta gate bias.", + hint=FieldHint.architecture, + ) + a_log_weight: ParameterConfig = Field( + desc="Parameter configuration for the decay rates.", + hint=FieldHint.architecture, + ) + + heads: int = Field( + default=16, + desc="Number of attention heads.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + head_dim: int = Field( + default=64, + desc="Dimension of each head.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + recurrent_threshold: int = Field( + default=64, + desc="Switch to the fused recurrent kernel below this sequence length.", + hint=FieldHint.performance, + valid=check_field(Assert.gt, 0), + ) + use_qk_l2norm: bool = Field( + default=True, + desc="Apply L2 normalization to query/key vectors inside the Delta kernel.", + hint=FieldHint.architecture, + ) + + @property + def layer_class(self) -> "type": + from fast_llm.layers.ssm.kda import KimiDeltaAttention + + return KimiDeltaAttention + + @config_class() class SSMConfig(MixerConfig): # Layers diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py new file mode 100644 index 00000000..78cabe8d --- /dev/null +++ b/fast_llm/layers/ssm/kda.py @@ -0,0 +1,268 @@ +import logging +import typing + +import torch + +from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias +from fast_llm.layers.ssm.config import 
KimiDeltaAttentionConfig +from fast_llm.tensor import ParameterMeta, TensorMeta + +logger = logging.getLogger(__name__) + +try: + from fla.ops.kda import chunk_kda + from fla.ops.kda.gate import fused_kda_gate +except ImportError: + chunk_kda = None + fused_kda_gate = None + + +class KimiDeltaAttention[ConfigType: KimiDeltaAttentionConfig](BlockWithBias[ConfigType]): + """ + Implementation of the Kimi Delta Attention mixer. + Reference Implementation: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/modeling_kimi.py + """ + + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_bias: bool = True, + ): + super().__init__( + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias + ) + if chunk_kda is None or fused_kda_gate is None: + raise ImportError( + "KimiDeltaAttention requires the `fla-core` package. " + "Please install it with `pip install -U fla-core`." + ) + + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._heads_dim = TensorDim( + "kda_heads", self._config.heads, self._parallel_dim if self._config.heads > 1 else None + ) + self._head_dim = TensorDim("kda_head_dim", self._config.head_dim) + self._projection_dim = CompositeTensorDim("kda_projection", (self._heads_dim, self._head_dim)) + self._local_heads = self._heads_dim.size + self._projection_size = self._projection_dim.size + + init = init_normal_(std=self._hidden_size**-0.5) + self.q_proj = self._config.q_projection_layer.get_layer( + hidden_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.k_proj = self._config.k_projection_layer.get_layer( + hidden_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.v_proj = self._config.v_projection_layer.get_layer( + hidden_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.q_conv = self._config.convolution_layer.get_layer( + self._projection_dim, + default_add_bias=False, + default_activation=self._config.convolution_layer.activation, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.k_conv = self._config.convolution_layer.get_layer( + self._projection_dim, + default_add_bias=False, + default_activation=self._config.convolution_layer.activation, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.v_conv = self._config.convolution_layer.get_layer( + self._projection_dim, + default_add_bias=False, + default_activation=self._config.convolution_layer.activation, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.f_a_proj = self._config.f_a_projection_layer.get_layer( + hidden_dim, + self._head_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.f_b_proj = self._config.f_b_projection_layer.get_layer( + self._head_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + 
lr_scale=self._lr_scale, + peft=self._peft, + ) + self.g_a_proj = self._config.g_a_projection_layer.get_layer( + hidden_dim, + self._head_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.g_b_proj = self._config.g_b_projection_layer.get_layer( + self._head_dim, + self._projection_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.beta_proj = self._config.beta_projection_layer.get_layer( + hidden_dim, + self._heads_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.o_proj = self._config.output_projection_layer.get_layer( + self._projection_dim, + hidden_dim, + default_weight_initialization=init, + default_add_bias=False, + sequence_parallel=self._sequence_parallel, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + self.dt_bias: ParameterMeta = self._config.dt_bias_weight.get_parameter( + (self._projection_dim,), + default_initialization=init_ones_, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.a_log: ParameterMeta = self._config.a_log_weight.get_parameter( + (self._heads_dim,), + default_initialization=LambdaInitializer( + lambda _, tensor, generator: tensor.uniform_(1, 16, generator=generator).log_() + ), + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.norm = self._config.normalization.get_layer( + self._head_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + def _apply_conv(self, tensor: torch.Tensor, conv: torch.nn.Module) -> torch.Tensor: + """ + Applies convolution. + Note, in the reference code they use Short Convolution from flash-linear-attention/fla/modules/convolution.py, but that one kjust uses causal_conv1danyways. + TODO: make sure varlen is supported correctly. + """ + tensor = tensor.transpose(1, 2).contiguous() + tensor = conv(tensor) + return tensor.transpose(1, 2).contiguous() + + def _reshape_heads(self, tensor: torch.Tensor) -> torch.Tensor: + tensor = tensor.contiguous() + # since head_dim is the same vor k,q and v + # same as rearrange(v, '... (h d) -> ... 
h d', d=self.head_dim) + return tensor.view(tensor.shape[0], tensor.shape[1], self._local_heads, self._config.head_dim) + + def _get_dt_bias(self) -> torch.Tensor: + return self.dt_bias.view(1, 1, self._local_heads, self._config.head_dim) + + def _get_a_log(self) -> torch.Tensor: + return self.a_log.view(1, 1, self._local_heads, 1) + + def _forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # TODO: make sure varlen is supported + # TODO: make sure we dont need to mask padding tokens in training + # TODO: make sure sequence first is handdled correctly + sequence_first = kwargs[BlockKwargs.sequence_first] + hidden_states = input_.transpose(0, 1) if sequence_first else input_ + batch_size, sequence_length, _ = hidden_states.shape + residual_dtype = hidden_states.dtype + + q = self._apply_conv(self.q_proj(hidden_states), self.q_conv) + k = self._apply_conv(self.k_proj(hidden_states), self.k_conv) + v = self._apply_conv(self.v_proj(hidden_states), self.v_conv) + + g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) + g_kernel = fused_kda_gate(g_kernel, self._get_a_log(), self._config.head_dim, g_bias=self._get_dt_bias()) + + beta = torch.sigmoid(self.beta_proj(hidden_states).float()) + + q = self._reshape_heads(q) + k = self._reshape_heads(k) + v = self._reshape_heads(v) + # currently on supports Ampere??? + attn_out, _ = chunk_kda( + q=q, + k=k, + v=v, + g=g_kernel, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=self._config.use_qk_l2norm, + cu_seqlens=None, + ) + + attn_out = attn_out.to(residual_dtype) + attn_out = self._reshape_heads(attn_out) + + g_out = self._reshape_heads(self.g_b_proj(self.g_a_proj(hidden_states))) + + attn_out = attn_out.reshape(-1, self._config.head_dim) + g_out = g_out.reshape(-1, self._config.head_dim) + attn_out = self.norm(attn_out, g_out) + attn_out = attn_out.view(batch_size, sequence_length, self._projection_size) + attn_out = self.o_proj(attn_out) + + if sequence_first: + attn_out = attn_out.transpose(0, 1) + + return attn_out + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + raise NotImplementedError() From 8636f092ba89738a98e3c43e875ca38ce23b7d32 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 19 Nov 2025 21:23:25 +0000 Subject: [PATCH 05/21] wip --- fast_llm/layers/ssm/kda.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 78cabe8d..f6f77654 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -202,12 +202,6 @@ def _reshape_heads(self, tensor: torch.Tensor) -> torch.Tensor: # same as rearrange(v, '... (h d) -> ... 
h d', d=self.head_dim) return tensor.view(tensor.shape[0], tensor.shape[1], self._local_heads, self._config.head_dim) - def _get_dt_bias(self) -> torch.Tensor: - return self.dt_bias.view(1, 1, self._local_heads, self._config.head_dim) - - def _get_a_log(self) -> torch.Tensor: - return self.a_log.view(1, 1, self._local_heads, 1) - def _forward( self, input_: torch.Tensor, @@ -228,14 +222,14 @@ def _forward( v = self._apply_conv(self.v_proj(hidden_states), self.v_conv) g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) - g_kernel = fused_kda_gate(g_kernel, self._get_a_log(), self._config.head_dim, g_bias=self._get_dt_bias()) + g_kernel = fused_kda_gate(g_kernel, self.a_log, self._config.head_dim, g_bias=self.dt_bias) beta = torch.sigmoid(self.beta_proj(hidden_states).float()) q = self._reshape_heads(q) k = self._reshape_heads(k) v = self._reshape_heads(v) - # currently on supports Ampere??? + # need to install nightly triton for now attn_out, _ = chunk_kda( q=q, k=k, @@ -244,7 +238,7 @@ def _forward( beta=beta, initial_state=None, output_final_state=False, - use_qk_l2norm_in_kernel=self._config.use_qk_l2norm, + use_qk_l2norm_in_kernel=True, cu_seqlens=None, ) From a20c9586387d5a9605c90ac345f1515c99f93017 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 20 Nov 2025 01:01:18 +0000 Subject: [PATCH 06/21] convertion kda --- fast_llm/layers/ssm/config.py | 31 +++---- fast_llm/layers/ssm/kda.py | 6 +- fast_llm/models/gpt/conversion/apriel.py | 113 ++++++++++++++++++++++- 3 files changed, 127 insertions(+), 23 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 95ef9bed..b8a5a64c 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -6,7 +6,7 @@ from fast_llm.engine.config_utils.parameter import ParameterConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig -from fast_llm.layers.common.normalization.config import NormalizationConfig +from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig, NormalizationConfig from fast_llm.layers.decoder.config import MixerConfig from fast_llm.utils import Assert @@ -84,11 +84,6 @@ class GatedDeltaNetConfig(MixerConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - use_qk_l2norm: bool = Field( - default=True, - desc="Apply L2 normalization on query/key vectors inside the Delta rule kernel.", - hint=FieldHint.architecture, - ) activation: ActivationType = Field( default=ActivationType.silu, desc="Activation used after the convolution.", @@ -113,7 +108,7 @@ class KimiDeltaAttentionConfig(MixerConfig): """ _abstract = False - normalization: NormalizationConfig = Field( + normalization: GatedRMSNormalizationConfig = Field( desc="Configuration for the gated normalization applied to the KDA output.", hint=FieldHint.architecture, ) @@ -178,17 +173,6 @@ class KimiDeltaAttentionConfig(MixerConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - recurrent_threshold: int = Field( - default=64, - desc="Switch to the fused recurrent kernel below this sequence length.", - hint=FieldHint.performance, - valid=check_field(Assert.gt, 0), - ) - use_qk_l2norm: bool = Field( - default=True, - desc="Apply L2 normalization to query/key vectors inside the Delta kernel.", - hint=FieldHint.architecture, - ) @property def layer_class(self) -> "type": @@ -196,6 +180,17 @@ def layer_class(self) -> "type": return KimiDeltaAttention + def 
_validate(self) -> None: + with self._set_implicit_default(): + if "epsilon" not in self.normalization._explicit_fields: + self.normalization.epsilon = 1.0e-5 + if "activation" not in self.convolution_layer._explicit_fields: + self.convolution_layer.activation = "silu" + if "kernel_size" not in self.convolution_layer._explicit_fields: + self.convolution_layer.kernel_size = 4 + + super()._validate() + @config_class() class SSMConfig(MixerConfig): diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index f6f77654..1ce6bed7 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -172,7 +172,7 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) - self.a_log: ParameterMeta = self._config.a_log_weight.get_parameter( + self.A_log: ParameterMeta = self._config.a_log_weight.get_parameter( (self._heads_dim,), default_initialization=LambdaInitializer( lambda _, tensor, generator: tensor.uniform_(1, 16, generator=generator).log_() @@ -210,8 +210,8 @@ def _forward( metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: # TODO: make sure varlen is supported - # TODO: make sure we dont need to mask padding tokens in training # TODO: make sure sequence first is handdled correctly + # TODO: make sure we dont need to mask padding tokens in training sequence_first = kwargs[BlockKwargs.sequence_first] hidden_states = input_.transpose(0, 1) if sequence_first else input_ batch_size, sequence_length, _ = hidden_states.shape @@ -222,7 +222,7 @@ def _forward( v = self._apply_conv(self.v_proj(hidden_states), self.v_conv) g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) - g_kernel = fused_kda_gate(g_kernel, self.a_log, self._config.head_dim, g_bias=self.dt_bias) + g_kernel = fused_kda_gate(g_kernel, self.A_log, self._config.head_dim, g_bias=self.dt_bias) beta = torch.sigmoid(self.beta_proj(hidden_states).float()) diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index e16eac4d..215cc525 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -8,7 +8,13 @@ from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.ssm.config import DiscreteMamba2Config, Mamba2Config +from fast_llm.layers.decoder.mlp.config import MLPConfig +from fast_llm.layers.ssm.config import ( + DiscreteMamba2Config, + GatedDeltaNetConfig, + KimiDeltaAttentionConfig, + Mamba2Config, +) from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters @@ -224,7 +230,102 @@ def get_converters( ] -class AprielDiscreteMamba2BlockConverter(MistralBlockConverter): +class AprielMLPConverter(LlamaMLPConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + config["mlp_bias"] = False + return super().import_config(config) + + @classmethod + def export_config(cls, config: MLPConfig) -> dict: + out = super().export_config(config) + del out["mlp_bias"] + return out + + +class GatedDeltaNetConverter: + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "type": "gated_delta_net", + "value_heads": config["linear_attn_config"]["gdn_value_head_dim"], + "key_heads": 
config["linear_attn_config"]["gdn_num_key_heads"], + "key_head_dim": config["linear_attn_config"]["gdn_key_head_dim"], + "value_head_dim": config["linear_attn_config"]["value_head_dim"], + "convolution_layer": { + "kernel_size": config["linear_attn_config"]["gdn_linear_conv_kernel_size"], + }, + } + + @classmethod + def export_config(cls, config: GatedDeltaNetConfig) -> dict: + return { + "linear_attn_config": { + "gdn_num_value_heads": config.value_heads, + "gdn_num_key_heads": config.key_heads, + "gdn_key_head_dim": config.key_head_dim, + "gdn_value_head_dim": config.value_head_dim, + "gdn_linear_conv_kernel_size": config.convolution_layer.kernel_size, + }, + } + + @classmethod + def get_converters( + cls, + config: KimiDeltaAttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.in_proj_qkvz", + f"{hf_prefix}.in_proj_qkvz", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.in_proj_ba", + f"{hf_prefix}.in_proj_ba", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.convolution", + f"{hf_prefix}.convolution", + False, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.out_proj", + f"{hf_prefix}.out_proj", + False, + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.A_log", + f"{hf_prefix}.A_log", + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.dt_bias", + f"{hf_prefix}.dt_bias", + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.norm", + f"{hf_prefix}.norm", + False, + drop_on_export=drop_on_export, + ), + ] + + +class AprielBlockConverterBase(MistralBlockConverter): + mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter + + +class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase): mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter hf_mixer_name: typing.ClassVar[str] = "mixer" @@ -234,16 +335,24 @@ class AprielMamba2BlockConverter(MistralBlockConverter): hf_mixer_name: typing.ClassVar[str] = "mixer" +class AprielGatedDeltaNetBlockConverter(AprielBlockConverterBase): + mixer_converter_class: typing.ClassVar[type[GatedDeltaNetConverter]] = GatedDeltaNetConverter + hf_mixer_name: typing.ClassVar[str] = "mixer" + + class AprielBlockConverter: layout_names = { AttentionConfig: "t", Mamba2Config: "m2", DiscreteMamba2Config: "m2d", + KimiDeltaAttentionConfig: "kda", + GatedDeltaNetConfig: "gdn", } _converter_classes = { AttentionConfig: MistralBlockConverter, Mamba2Config: AprielMamba2BlockConverter, DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter, + GatedDeltaNetConfig: AprielGatedDeltaNetBlockConverter, } _config_classes = {value: key for key, value in layout_names.items()} From 8ac5167f62fe047ff2004093df8fa72f91b5e19c Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 21 Nov 2025 16:26:07 +0000 Subject: [PATCH 07/21] tp and sequence tp --- fast_llm/layers/ssm/config.py | 15 ++++++++++-- fast_llm/layers/ssm/gdn.py | 21 ++++++++++------- fast_llm/layers/ssm/kda.py | 44 +++++++++++++++++++++++++---------- 3 files changed, 58 insertions(+), 22 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index b8a5a64c..29f66c8b 100644 --- a/fast_llm/layers/ssm/config.py 
+++ b/fast_llm/layers/ssm/config.py @@ -6,7 +6,7 @@ from fast_llm.engine.config_utils.parameter import ParameterConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig -from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig, NormalizationConfig +from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig from fast_llm.layers.decoder.config import MixerConfig from fast_llm.utils import Assert @@ -25,7 +25,7 @@ class GatedDeltaNetConfig(MixerConfig): """ _abstract = False - normalization: NormalizationConfig = Field( + normalization: GatedRMSNormalizationConfig = Field( desc="Configuration for the block normalization layers.", hint=FieldHint.architecture, ) @@ -100,6 +100,17 @@ def layer_class(self) -> "type": return GatedDeltaNet + def _validate(self) -> None: + with self._set_implicit_default(): + if "epsilon" not in self.normalization._explicit_fields: + self.normalization.epsilon = 1.0e-5 + if "activation" not in self.convolution_layer._explicit_fields: + self.convolution_layer.activation = "silu" + if "kernel_size" not in self.convolution_layer._explicit_fields: + self.convolution_layer.kernel_size = 4 + + super()._validate() + @config_class(dynamic_type={MixerConfig: "kda"}) class KimiDeltaAttentionConfig(MixerConfig): diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index 62360acc..d07cb5e2 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -113,6 +113,9 @@ def torch_chunk_gated_delta_rule( class GatedDeltaNet[ConfigType: GatedDeltaNetConfig](BlockWithBias[ConfigType]): """ Follows implementation here: https://github.com/huggingface/transformers/blob/a5c903f877fda21e739027eed133e03162eb7712/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L593 + - For tensor parallel implementtion (no sequnece prallel): we scatter teh heads accross ranks. + - Sequence Tensor parallel: in_proj_qkvz all reduces across sequence dim. --> each rank performs work on full sequence but only a subset of heads (standrd TP). + """ _config: ConfigType @@ -261,16 +264,18 @@ def _forward( metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: sequence_first = kwargs[BlockKwargs.sequence_first] - - # TODO: do we need maksing of padding tokens? - if sequence_first: - hidden_states = input_.transpose(0, 1) - else: - hidden_states = input_ + # in sequence parallel TP the input here is already scattered across sequence dimension + # TODO: do we need masking of padding tokens? 
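+        # With sequence-tensor-parallel, the input projections gather over the sequence dimension,
+        # so after in_proj_qkvz / in_proj_ba each rank holds the full sequence for its local shard
+        # of heads (standard tensor parallelism over heads).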
+ # TODO: make sure varlen is supported + hidden_states = input_ # batch_size, sequence_length, _ = hidden_states.shape projected_states_qkvz = self.in_proj_qkvz(hidden_states) # bs x seq_len x (qkvz) projected_states_ba = self.in_proj_ba(hidden_states) # bs x seq_len x (b a) + if sequence_first: + projected_states_qkvz = projected_states_qkvz.transpose(0, 1) + projected_states_ba = projected_states_ba.transpose(0, 1) + query, key, value, z, beta, alpha = self.fix_query_key_value_ordering( projected_states_qkvz, projected_states_ba ) @@ -317,10 +322,10 @@ def _forward( core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) + if sequence_first: + core_attn_out = core_attn_out.transpose(0, 1) output = self.out_proj(core_attn_out) - if sequence_first: - output = output.transpose(0, 1) return output def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 1ce6bed7..9f85cd06 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -115,7 +115,7 @@ def __init__( self._head_dim, default_weight_initialization=init, default_add_bias=False, - sequence_parallel=self._sequence_parallel, + sequence_parallel=False, # self._sequence_parallel, lr_scale=self._lr_scale, peft=self._peft, ) @@ -133,7 +133,7 @@ def __init__( self._head_dim, default_weight_initialization=init, default_add_bias=False, - sequence_parallel=self._sequence_parallel, + sequence_parallel=False, # self._sequence_parallel, lr_scale=self._lr_scale, peft=self._peft, ) @@ -210,25 +210,46 @@ def _forward( metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: # TODO: make sure varlen is supported - # TODO: make sure sequence first is handdled correctly # TODO: make sure we dont need to mask padding tokens in training sequence_first = kwargs[BlockKwargs.sequence_first] - hidden_states = input_.transpose(0, 1) if sequence_first else input_ - batch_size, sequence_length, _ = hidden_states.shape + hidden_states = input_ + residual_dtype = hidden_states.dtype - q = self._apply_conv(self.q_proj(hidden_states), self.q_conv) - k = self._apply_conv(self.k_proj(hidden_states), self.k_conv) - v = self._apply_conv(self.v_proj(hidden_states), self.v_conv) + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + if sequence_first: + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + q = self._apply_conv(q, self.q_conv) + k = self._apply_conv(k, self.k_conv) + v = self._apply_conv(v, self.v_conv) + + if sequence_first: + _, batch_size, _ = hidden_states.shape + sequence_length = q.size(1) + # hidden_states = gather_op(hidden_states, self._distributed.tensor_group, dim=0, async_op=False).transpose( + # 0, 1 + # ) + # hidden_states = hidden_states.transpose(0, 1) + else: + batch_size, sequence_length, _ = hidden_states.shape g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) + if sequence_first: + g_kernel = g_kernel.transpose(0, 1) g_kernel = fused_kda_gate(g_kernel, self.A_log, self._config.head_dim, g_bias=self.dt_bias) beta = torch.sigmoid(self.beta_proj(hidden_states).float()) - q = self._reshape_heads(q) k = self._reshape_heads(k) v = self._reshape_heads(v) + if sequence_first: + beta = beta.transpose(0, 1) + # need to install nightly triton for now attn_out, _ = 
chunk_kda( q=q, @@ -245,16 +266,15 @@ def _forward( attn_out = attn_out.to(residual_dtype) attn_out = self._reshape_heads(attn_out) - g_out = self._reshape_heads(self.g_b_proj(self.g_a_proj(hidden_states))) + g_out = self._reshape_heads(self.g_b_proj(self.g_a_proj(hidden_states))) # bs x seq x n_local_heads x head dim attn_out = attn_out.reshape(-1, self._config.head_dim) g_out = g_out.reshape(-1, self._config.head_dim) attn_out = self.norm(attn_out, g_out) attn_out = attn_out.view(batch_size, sequence_length, self._projection_size) - attn_out = self.o_proj(attn_out) - if sequence_first: attn_out = attn_out.transpose(0, 1) + attn_out = self.o_proj(attn_out) return attn_out From f1a51f2754e90bc6982d0ee3edbded8010416460 Mon Sep 17 00:00:00 2001 From: oleksost Date: Sat, 22 Nov 2025 00:19:41 +0000 Subject: [PATCH 08/21] varlen kda --- fast_llm/layers/common/linear/convolution.py | 3 +- fast_llm/layers/ssm/config.py | 6 + fast_llm/layers/ssm/kda.py | 114 ++++++-- tests/test_ssm_varlen.py | 259 +++++++++++++++++++ 4 files changed, 355 insertions(+), 27 deletions(-) create mode 100644 tests/test_ssm_varlen.py diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py index 6281348e..2f682c46 100644 --- a/fast_llm/layers/common/linear/convolution.py +++ b/fast_llm/layers/common/linear/convolution.py @@ -45,12 +45,13 @@ def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: )[..., : input_.size(1)] ) - def _forward_causal_conv1d(self, input_: torch.Tensor) -> torch.Tensor: + def _forward_causal_conv1d(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: return _causal_conv1d_fn( input_, self.weight.squeeze(1), self.bias, activation=(None if self._activation == ActivationType.identity else self._activation.value), + **kwargs, ) def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int: diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 29f66c8b..2fa90aff 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -5,6 +5,7 @@ from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, LambdaInitializer from fast_llm.engine.config_utils.parameter import ParameterConfig from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig from fast_llm.layers.common.normalization.config import GatedRMSNormalizationConfig from fast_llm.layers.decoder.config import MixerConfig @@ -18,6 +19,11 @@ from fast_llm.tensor import ParameterMeta +class LinearAttentionKwargs(BlockKwargs): + cu_seqlens = "cu_seqlens" + seq_idx = "seq_idx" + + @config_class(dynamic_type={MixerConfig: "gdn"}) class GatedDeltaNetConfig(MixerConfig): """ diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 9f85cd06..b14fd459 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -2,6 +2,7 @@ import typing import torch +from einops import rearrange, repeat from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ @@ -10,7 +11,7 @@ from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias -from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig +from 
fast_llm.layers.ssm.config import KimiDeltaAttentionConfig, LinearAttentionKwargs from fast_llm.tensor import ParameterMeta, TensorMeta logger = logging.getLogger(__name__) @@ -23,6 +24,16 @@ fused_kda_gate = None +def index_first_axis(x, indices): + other_shape = x.shape[1:] + second_dim = other_shape.numel() + return torch.gather( + rearrange(x, "b ... -> b (...)"), + 0, + repeat(indices, "z -> z d", d=second_dim), + ).reshape(-1, *other_shape) + + class KimiDeltaAttention[ConfigType: KimiDeltaAttentionConfig](BlockWithBias[ConfigType]): """ Implementation of the Kimi Delta Attention mixer. @@ -186,14 +197,16 @@ def __init__( peft=self._peft, ) - def _apply_conv(self, tensor: torch.Tensor, conv: torch.nn.Module) -> torch.Tensor: + def _apply_conv(self, tensor: torch.Tensor, conv: torch.nn.Module, seq_idx: torch.Tensor = None) -> torch.Tensor: """ Applies convolution. - Note, in the reference code they use Short Convolution from flash-linear-attention/fla/modules/convolution.py, but that one kjust uses causal_conv1danyways. - TODO: make sure varlen is supported correctly. + Note, in the reference code they use Short Convolution from flash-linear-attention/fla/modules/convolution.py, but that one just uses causal_conv1d anyways. + Varlen: + - seq. idx are only suppored in channel last layout, i.e. no transpose """ - tensor = tensor.transpose(1, 2).contiguous() - tensor = conv(tensor) + tensor = rearrange(tensor, "b t d -> b d t") + # tensor = tensor.transpose(1, 2).contiguous() if seq_idx is None else tensor.transpose(1, 2) + tensor = conv(tensor, seq_idx=seq_idx) return tensor.transpose(1, 2).contiguous() def _reshape_heads(self, tensor: torch.Tensor) -> torch.Tensor: @@ -210,37 +223,44 @@ def _forward( metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: # TODO: make sure varlen is supported - # TODO: make sure we dont need to mask padding tokens in training + # TODO: do we need to deal with padding tokens? sequence_first = kwargs[BlockKwargs.sequence_first] hidden_states = input_ + cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) + seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) + # TODO: can be made more efficeint by rearranging hidden states directly residual_dtype = hidden_states.dtype q = self.q_proj(hidden_states) k = self.k_proj(hidden_states) v = self.v_proj(hidden_states) + if sequence_first: + # make bs first dim again q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) - q = self._apply_conv(q, self.q_conv) - k = self._apply_conv(k, self.k_conv) - v = self._apply_conv(v, self.v_conv) + batch_size, sequence_length, _ = q.size() - if sequence_first: - _, batch_size, _ = hidden_states.shape - sequence_length = q.size(1) - # hidden_states = gather_op(hidden_states, self._distributed.tensor_group, dim=0, async_op=False).transpose( - # 0, 1 - # ) - # hidden_states = hidden_states.transpose(0, 1) - else: - batch_size, sequence_length, _ = hidden_states.shape + # work with bs = 1 to make sure varlen works correctly, only needed if micro batch size is > 1 + # can this be applied once to hidden state only? pr + q = rearrange(q, "b s ... -> (b s) ...").unsqueeze(0) + k = rearrange(k, "b s ... -> (b s) ...").unsqueeze(0) + v = rearrange(v, "b s ... 
-> (b s) ...").unsqueeze(0) + + # because we use cu_seqlens, chunk_kda requires batch size to be 1 (flatten, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303) + # similarly to ShortConvolution from fla we already operate on flattened batches here (https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/modules/convolution.py#L914) + q = self._apply_conv(q, self.q_conv, seq_idx=seq_idx) + k = self._apply_conv(k, self.k_conv, seq_idx=seq_idx) + v = self._apply_conv(v, self.v_conv, seq_idx=seq_idx) g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) if sequence_first: g_kernel = g_kernel.transpose(0, 1) + g_kernel = rearrange(g_kernel, "b s ... -> (b s) ...").unsqueeze(0) + g_kernel = fused_kda_gate(g_kernel, self.A_log, self._config.head_dim, g_bias=self.dt_bias) beta = torch.sigmoid(self.beta_proj(hidden_states).float()) @@ -249,8 +269,10 @@ def _forward( v = self._reshape_heads(v) if sequence_first: beta = beta.transpose(0, 1) + beta = rearrange(beta, "b s h -> (b s) h").unsqueeze(0) - # need to install nightly triton for now + # need to install nightly triton for this to work on H100, see https://github.com/fla-org/flash-linear-attention/blob/main/FAQs.md + # cu_seqlens requires batch ssize to be 1, i.e. flattened bacthes attn_out, _ = chunk_kda( q=q, k=k, @@ -260,18 +282,19 @@ def _forward( initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=True, - cu_seqlens=None, + cu_seqlens=cu_seqlens, ) attn_out = attn_out.to(residual_dtype) - attn_out = self._reshape_heads(attn_out) - g_out = self._reshape_heads(self.g_b_proj(self.g_a_proj(hidden_states))) # bs x seq x n_local_heads x head dim + g_out = self.g_b_proj(self.g_a_proj(hidden_states)) # bs x seq x n_local_heads x head dim + g_out = self._reshape_heads(g_out) + if sequence_first: + g_out = g_out.transpose(0, 1) - attn_out = attn_out.reshape(-1, self._config.head_dim) - g_out = g_out.reshape(-1, self._config.head_dim) + attn_out = rearrange(attn_out.squeeze(0), "(b s) h d -> b s h d", b=batch_size, s=sequence_length) attn_out = self.norm(attn_out, g_out) - attn_out = attn_out.view(batch_size, sequence_length, self._projection_size) + attn_out = rearrange(attn_out, "b s h d -> b s (h d)") if sequence_first: attn_out = attn_out.transpose(0, 1) attn_out = self.o_proj(attn_out) @@ -280,3 +303,42 @@ def _forward( def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: raise NotImplementedError() + + def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] + if sequence_lengths is None: + raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") + sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] + seqlens = torch.cat(sequence_lengths) + cu_seqlens = torch.cat( + ( + torch.zeros(1, dtype=torch.long, device=batch.device), + torch.cumsum(seqlens, dim=0, dtype=torch.long).to(batch.device), + ) + ) + # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 + # also whenever cu_seqlens is used, batchs size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 + kwargs[LinearAttentionKwargs.cu_seqlens] 
= cu_seqlens + # seq_idx has to be (bs, seqlen), but bs is forced to 1 + kwargs[LinearAttentionKwargs.seq_idx] = ( + ( + torch.cat( + [ + torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) + for n in (torch.diff(cu_seqlens).to(torch.int32)) + ], + dim=0, + ) + .eq(0) + .cumsum(0) + - 1 + ) + .to(torch.int32) + .unsqueeze(0) + ) + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + if LinearAttentionKwargs.sequence_lengths in kwargs: + # TODO: packing is enabled by default, i.e. its always used? + # only get here when cross_document_attention is False + self._preprocess_for_varlen(batch, kwargs) diff --git a/tests/test_ssm_varlen.py b/tests/test_ssm_varlen.py new file mode 100644 index 00000000..9ca491e3 --- /dev/null +++ b/tests/test_ssm_varlen.py @@ -0,0 +1,259 @@ +import inspect +import itertools + +import pytest +import torch + +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.ssm import kda as kda_module +from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig, LinearAttentionKwargs + +# from mamba2 import NemotronHMamba2 + + +_mamba_varlen = False +try: + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa + + _mamba_available = True + sig = inspect.signature(selective_scan_fn) + if "position_indices" in sig.parameters: + _mamba_varlen = True + else: + _mamba_varlen = False + # for training with packing install https://github.com/jxiw/varlen_mamba + # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md + +except (ImportError, RuntimeError): + _mamba_available = False + + +@pytest.fixture +def distributed_config(): + return DistributedConfig( + tensor_parallel=1, + pipeline_parallel=1, + sequence_data_parallel=1, + local_world_size=1, + world_size=1, + ) + + +@pytest.fixture +def distributed(distributed_config): + return Distributed(config=distributed_config) + + +def materialize_meta_tensors(model, tensor_space): + # Materialize parameters that are on meta device + for name, param in model.named_parameters(): + if param.device.type == "meta": + # Check if the parameter is a custom tensor type + if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): + param_data = param.new_empty(param.shape, device="cuda") + # Initialize param_data + param.init_parameter(param_data, tensor_space.distributed) + # Replace the parameter in the module + module_path, param_name = name.rsplit(".", 1) if "." 
in name else (None, name) + module = model + if module_path is not None: + for part in module_path.split("."): + module = getattr(module, part) + param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation + param.grad = None + param.grad_buffer = torch.empty_like(param) + param.param_grad_is_zero = True + module._parameters[param_name] = param + return model + + +def unpack(packed_hidden_states, cu_seqlens): + batch_size = packed_hidden_states.shape[0] + package_num = cu_seqlens.shape[0] - 1 + seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + hidden_dim = packed_hidden_states.shape[2] + hidden_states = torch.zeros( + package_num * batch_size, + seq_len, + hidden_dim, + dtype=packed_hidden_states.dtype, + device=packed_hidden_states.device, + ) + for j in range(batch_size): + for i in range(package_num): + line = j * package_num + i + hidden_states[line, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[ + j, cu_seqlens[i] : cu_seqlens[i + 1], : + ] + return hidden_states + + +def pack(hidden_states, cu_seqlens, batch_size): + package_num, seq_len, hidden_dim = hidden_states.shape + seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1] + seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2) + indices_3d = ( + torch.arange(seq_len, device=hidden_states.device).unsqueeze(0).unsqueeze(2).repeat(package_num, 1, hidden_dim) + ) + mask_3d = indices_3d < seq_len_list_3d.repeat(batch_size, 1, 1) + packed_hidden_states = hidden_states[mask_3d].view(batch_size, -1, hidden_dim) + return packed_hidden_states + + +def generate_random_cu_seqlens(seq_len, packages_num=2): + if packages_num < 1: + raise ValueError("packages_num must be at least 1") + + # base size of each chunk, and how many get an extra token + base, rem = divmod(seq_len, packages_num) + # lengths: e.g. for seq_len=10, packages=3 → [4,3,3] + lengths = [base + 1 if i < rem else base for i in range(packages_num)] + + # split points exclude the final cumulative (seq_len) + split_points = list(itertools.accumulate(lengths))[:-1] + + cu_seqlens = [0] + split_points + [seq_len] + # cu_seqlens = split_points # + [seq_len] + + # index: for each chunk, we emit 0,1,...,length-1 + index = [] + for length in lengths: + index.extend(range(length)) + + # sanity check + assert len(cu_seqlens) - 1 == packages_num + assert sum(lengths) == seq_len + assert len(index) == seq_len + + return cu_seqlens, index + + +def _materialize_kda_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: + """ + Materialize meta parameters on the requested device for KDA mixer layers. + """ + for name, param in module.named_parameters(): + if param.device.type != "meta": + continue + param_data = torch.empty_like(param, device=device) + param.init_parameter(param_data, distributed) + module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) + target = module + if module_path is not None: + for part in module_path.split("."): + target = getattr(target, part) + new_param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + new_param.grad = None + new_param.grad_buffer = torch.zeros_like(param_data) + new_param.param_grad_is_zero = True + target._parameters[param_name] = new_param + + +def _param_grad(param: torch.nn.Parameter) -> torch.Tensor | None: + # ParameterMeta stores grads in grad_buffer; fall back to .grad otherwise. 
+ return param.grad_buffer if hasattr(param, "grad_buffer") and param.grad_buffer is not None else param.grad + + +@pytest.mark.slow +@pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA varlen needs CUDA") +@pytest.mark.skipif( + kda_module.chunk_kda is None or kda_module.fused_kda_gate is None, + reason="KDA fused kernels not available", +) +def test_kda_varlen_stacking_equivalence(distributed_config, distributed): + """ + Check that KDA forward/backward match with and without stacking using the real kernels. + """ + device = torch.device("cuda") + dtype = torch.float16 + heads, head_dim = 2, 16 + hidden_size = heads * head_dim + + config = KimiDeltaAttentionConfig(heads=heads, head_dim=head_dim) + hidden_dim = TensorDim("hidden", hidden_size) + kda_packed = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) + kda_ref = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) + kda_packed.setup(distributed) + kda_ref.setup(distributed) + _materialize_kda_tensors(kda_packed, distributed, device) + _materialize_kda_tensors(kda_ref, distributed, device) + kda_ref.load_state_dict(kda_packed.state_dict()) + kda_packed.to(device=device, dtype=dtype) + kda_ref.to(device=device, dtype=dtype) + + batch_size = 2 # cu_seqlens path requires flattened batch + seq_len = 15 + packages_num = torch.randint(2, 5, (1, batch_size))[0] # randomize packages num between 2 and 4 + lengths = [ + torch.tensor( + generate_random_cu_seqlens(seq_len, packages_num=packages_num[i].item())[0], + device=device, + dtype=torch.long, + ).diff() + for i in range(batch_size) + ] + + # lengths = torch.tensor(cu_seqlens, device=device, dtype=torch.long)#.diff() + # total_tokens = lengths.sum().item() + packed = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=True) + + kwargs_packed = { + LinearAttentionKwargs.sequence_lengths: lengths, + BlockKwargs.sequence_first: False, + BlockKwargs.hidden_dims: (hidden_dim,), + # BlockKwargs.sequence_q_dim: TensorDim("sequence_q", lengths.sum().item()), + } + # Use the layer's preprocess to construct cu_seqlens/seq_idx the same way as the implementation. + kda_packed.preprocess(packed, kwargs_packed) + + kwargs_ref = { + BlockKwargs.sequence_first: False, + BlockKwargs.hidden_dims: (hidden_dim,), + } + + out_packed = kda_packed(packed, kwargs_packed) + # Run reference path separately per sequence without varlen packing, then concatenate. 
+ ref_outs = [] + for b in range(batch_size): + out_batch = [] + length = lengths[b] + ref_seqs = torch.split(packed[b].unsqueeze(0), length.tolist(), dim=1) + for seq in ref_seqs: + kwargs_ref_seq = { + **kwargs_ref, + BlockKwargs.sequence_q_dim: TensorDim("sequence_q", seq.shape[1]), + } + out_batch.append(kda_ref(seq, kwargs_ref_seq)) + ref_outs.append(torch.cat(out_batch, dim=1)) + out_ref = torch.cat(ref_outs, dim=0) + out_ref_packed = out_ref + + assert out_packed.shape == packed.shape + assert out_ref_packed.shape == out_packed.shape + assert torch.allclose(out_packed, out_ref_packed, atol=1e-3, rtol=1e-3) + + out_packed.sum().backward() + out_ref_packed.sum().backward() + + assert _param_grad(kda_packed.q_proj.weight) is not None + assert _param_grad(kda_ref.q_proj.weight) is not None + assert torch.allclose( + _param_grad(kda_packed.q_proj.weight), _param_grad(kda_ref.q_proj.weight), atol=1e-3, rtol=1e-3 + ) + assert torch.allclose( + _param_grad(kda_packed.k_proj.weight), _param_grad(kda_ref.k_proj.weight), atol=1e-3, rtol=1e-3 + ) + assert torch.allclose( + _param_grad(kda_packed.v_proj.weight), _param_grad(kda_ref.v_proj.weight), atol=1e-3, rtol=1e-3 + ) + assert torch.allclose( + _param_grad(kda_packed.o_proj.weight), _param_grad(kda_ref.o_proj.weight), atol=1e-3, rtol=1e-3 + ) + + +if __name__ == "__main__": + pytest.main([__file__]) From 3b367d871b658818091469be1502d001ae033db2 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 24 Nov 2025 20:54:00 +0000 Subject: [PATCH 09/21] gdn only: varlen test --- fast_llm/layers/ssm/config.py | 91 --------- fast_llm/layers/ssm/gdn.py | 91 +++++++-- fast_llm/layers/ssm/kda.py | 344 ---------------------------------- tests/test_ssm_varlen.py | 108 ++++++----- 4 files changed, 134 insertions(+), 500 deletions(-) delete mode 100644 fast_llm/layers/ssm/kda.py diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 2fa90aff..6f36321e 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -118,97 +118,6 @@ def _validate(self) -> None: super()._validate() -@config_class(dynamic_type={MixerConfig: "kda"}) -class KimiDeltaAttentionConfig(MixerConfig): - """ - Configuration for the KimiDeltaAttention mixer inspired by the Kimi Linear models. 
- """ - - _abstract = False - normalization: GatedRMSNormalizationConfig = Field( - desc="Configuration for the gated normalization applied to the KDA output.", - hint=FieldHint.architecture, - ) - q_projection_layer: AffineLinearConfig = Field( - desc="Projection that produces query vectors.", - hint=FieldHint.architecture, - ) - k_projection_layer: AffineLinearConfig = Field( - desc="Projection that produces key vectors.", - hint=FieldHint.architecture, - ) - v_projection_layer: AffineLinearConfig = Field( - desc="Projection that produces value vectors.", - hint=FieldHint.architecture, - ) - f_a_projection_layer: AffineLinearConfig = Field( - desc="Projection used for the Delta gating pre-activation.", - hint=FieldHint.architecture, - ) - f_b_projection_layer: AffineLinearConfig = Field( - desc="Projection used for the Delta gating expansion.", - hint=FieldHint.architecture, - ) - g_a_projection_layer: AffineLinearConfig = Field( - desc="Projection used for the output gating pre-activation.", - hint=FieldHint.architecture, - ) - g_b_projection_layer: AffineLinearConfig = Field( - desc="Projection used for the output gating expansion.", - hint=FieldHint.architecture, - ) - beta_projection_layer: AffineLinearConfig = Field( - desc="Projection that produces the Beta gate.", - hint=FieldHint.architecture, - ) - output_projection_layer: AffineLinearConfig = Field( - desc="Projection applied after the Delta recurrence and gated normalization.", - hint=FieldHint.architecture, - ) - convolution_layer: CausalConv1dConfig = Field( - desc="Depth-wise convolution applied independently on each Q, K and V stream.", - hint=FieldHint.architecture, - ) - dt_bias_weight: ParameterConfig = Field( - desc="Parameter configuration for the Delta gate bias.", - hint=FieldHint.architecture, - ) - a_log_weight: ParameterConfig = Field( - desc="Parameter configuration for the decay rates.", - hint=FieldHint.architecture, - ) - - heads: int = Field( - default=16, - desc="Number of attention heads.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - head_dim: int = Field( - default=64, - desc="Dimension of each head.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - - @property - def layer_class(self) -> "type": - from fast_llm.layers.ssm.kda import KimiDeltaAttention - - return KimiDeltaAttention - - def _validate(self) -> None: - with self._set_implicit_default(): - if "epsilon" not in self.normalization._explicit_fields: - self.normalization.epsilon = 1.0e-5 - if "activation" not in self.convolution_layer._explicit_fields: - self.convolution_layer.activation = "silu" - if "kernel_size" not in self.convolution_layer._explicit_fields: - self.convolution_layer.kernel_size = 4 - - super()._validate() - - @config_class() class SSMConfig(MixerConfig): # Layers diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index d07cb5e2..6fd86c6e 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -3,6 +3,7 @@ import torch import torch.nn.functional as F +from einops import rearrange from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ @@ -11,7 +12,7 @@ from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias -from fast_llm.layers.ssm.config import GatedDeltaNetConfig +from fast_llm.layers.ssm.config import 
GatedDeltaNetConfig, LinearAttentionKwargs from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import div @@ -263,28 +264,46 @@ def _forward( losses: dict[str, typing.Any] | None = None, metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + - we flatten batch + seq + - forward as packed sequence, i.e. BS = 1, cu_seqlens and seq_idx created in the preprocessing step must reflect this (these are None if cross_document_attention is True) + - scatter results back to B x T x D + - note, if there are padding tokens they are note removed, they are assumed to be ignored later in the loss calculation and are assumed to be always ont he right + """ + sequence_first = kwargs[BlockKwargs.sequence_first] # in sequence parallel TP the input here is already scattered across sequence dimension # TODO: do we need masking of padding tokens? - # TODO: make sure varlen is supported + # TODO: fuse soome of the reshapes into rearranges hidden_states = input_ - # batch_size, sequence_length, _ = hidden_states.shape - projected_states_qkvz = self.in_proj_qkvz(hidden_states) # bs x seq_len x (qkvz) - projected_states_ba = self.in_proj_ba(hidden_states) # bs x seq_len x (b a) + # these are not + cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) + seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) + + projected_states_qkvz = self.in_proj_qkvz(hidden_states) # bs/seq x seq_len/bs x (qkvz) + projected_states_ba = self.in_proj_ba(hidden_states) # bs/seq x seq_len/bs x (b a) if sequence_first: projected_states_qkvz = projected_states_qkvz.transpose(0, 1) projected_states_ba = projected_states_ba.transpose(0, 1) + batch_size, sequence_length = projected_states_qkvz.shape[:2] + + # note: to support var len training (packing) we need to flatten hidden states to batch_size = 1 + # this is does not seem to be required by causal_conv1d_fn, but it it required by chunked_gdn_rule: https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/gated_delta_rule/chunk.py#L299 + # similarly to kimi linear and to SHortCOnv from fla, we pass it flattened tro conv_1d as well, i.e. see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/modules/convolution.py#L914 query, key, value, z, beta, alpha = self.fix_query_key_value_ordering( projected_states_qkvz, projected_states_ba ) query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) mixed_qkv = torch.cat((query, key, value), dim=-1) - mixed_qkv = mixed_qkv.transpose(1, 2) - mixed_qkv = self.convolution(mixed_qkv) - mixed_qkv = mixed_qkv.transpose(1, 2) + mixed_qkv = rearrange(mixed_qkv, "b s ... -> (b s) ...").unsqueeze(0) # 1 s d + mixed_qkv = rearrange(mixed_qkv, "b t d -> b d t") # mixed_qkv.transpose(1, 2) + mixed_qkv = self.convolution( + mixed_qkv, seq_idx=seq_idx + ) # conv func. gets sequence dim as last dim, see https://github.com/Dao-AILab/causal-conv1d/blob/22a4577d8ace9d5703daea91a7fb56695492152b/causal_conv1d/causal_conv1d_interface.py#L110 + mixed_qkv = rearrange(mixed_qkv, "b d t -> b t d") # mixed_qkv.transpose(1, 2) query, key, value = torch.split( mixed_qkv, ( @@ -301,6 +320,9 @@ def _forward( beta = beta.sigmoid() g = -torch.exp(self.A_log) * F.softplus(alpha + self.dt_bias) + beta = rearrange(beta, "b s ... -> (b s) ...").unsqueeze(0) + g = rearrange(g, "b s ... 
-> (b s) ...").unsqueeze(0) + if self._value_heads_per_key > 1: query = query.repeat_interleave(self._value_heads_per_key, dim=2) key = key.repeat_interleave(self._value_heads_per_key, dim=2) @@ -314,9 +336,12 @@ def _forward( initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, ) z_shape_og = z.shape + core_attn_out = rearrange(core_attn_out.squeeze(0), "(b s) ... -> b s ...", b=batch_size, s=sequence_length) + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) core_attn_out = self.norm(core_attn_out, z) @@ -328,10 +353,50 @@ def _forward( return output + def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + """ + Creates seqlens and cu_seqlens for packed training (varlen). + This assumes that forward pass is performed on a fully packed sequence, i.e. where sequences are flattened out into BS = 1. + + Sets: + - seq_idx to [1, BS x T] tensor, where each elemnt is the sequence index of the corresponding token + - cu_seqlens to [N+1] tensor, where N is the total number of sequences in the batch, each element is the cumulative sequence length of packed sequences sofar + """ + + sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] + if sequence_lengths is None: + raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") + seqlens = torch.cat(sequence_lengths) + cu_seqlens = torch.cat( + ( + torch.zeros(1, dtype=torch.long, device=batch.device), + torch.cumsum(seqlens, dim=0, dtype=torch.long).to(batch.device), + ) + ) + # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 + # also whenever cu_seqlens is used, batchs size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 + kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens + # seq_idx has to be (bs, seqlen), but bs is forced to 1 + kwargs[LinearAttentionKwargs.seq_idx] = ( + ( + torch.cat( + [ + torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) + for n in (torch.diff(cu_seqlens).to(torch.int32)) + ], + dim=0, + ) + .eq(0) + .cumsum(0) + - 1 + ) + .to(torch.int32) + .unsqueeze(0) + ) + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + if LinearAttentionKwargs.sequence_lengths in kwargs: + self._preprocess_for_varlen(batch, kwargs) + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - # return ( - # self.in_proj_qkvz.get_compute_usage(input_, config) - # + self.in_proj_ba.get_compute_usage(input_, config) - # + self.out_proj.get_compute_usage(input_, config) - # ) raise NotImplementedError() diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py deleted file mode 100644 index b14fd459..00000000 --- a/fast_llm/layers/ssm/kda.py +++ /dev/null @@ -1,344 +0,0 @@ -import logging -import typing - -import torch -from einops import rearrange, repeat - -from fast_llm.engine.base_model.config import ResourceUsageConfig -from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ -from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.layers.block.config import BlockKwargs 
-from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.decoder.block import BlockWithBias -from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig, LinearAttentionKwargs -from fast_llm.tensor import ParameterMeta, TensorMeta - -logger = logging.getLogger(__name__) - -try: - from fla.ops.kda import chunk_kda - from fla.ops.kda.gate import fused_kda_gate -except ImportError: - chunk_kda = None - fused_kda_gate = None - - -def index_first_axis(x, indices): - other_shape = x.shape[1:] - second_dim = other_shape.numel() - return torch.gather( - rearrange(x, "b ... -> b (...)"), - 0, - repeat(indices, "z -> z d", d=second_dim), - ).reshape(-1, *other_shape) - - -class KimiDeltaAttention[ConfigType: KimiDeltaAttentionConfig](BlockWithBias[ConfigType]): - """ - Implementation of the Kimi Delta Attention mixer. - Reference Implementation: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/modeling_kimi.py - """ - - _config: ConfigType - - def __init__( - self, - config: ConfigType, - distributed_config: DistributedConfig, - *, - hidden_dim: TensorDim, - lr_scale: float | None, - peft: PeftConfig | None, - return_bias: bool = True, - ): - super().__init__( - config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias - ) - if chunk_kda is None or fused_kda_gate is None: - raise ImportError( - "KimiDeltaAttention requires the `fla-core` package. " - "Please install it with `pip install -U fla-core`." - ) - - self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) - self._heads_dim = TensorDim( - "kda_heads", self._config.heads, self._parallel_dim if self._config.heads > 1 else None - ) - self._head_dim = TensorDim("kda_head_dim", self._config.head_dim) - self._projection_dim = CompositeTensorDim("kda_projection", (self._heads_dim, self._head_dim)) - self._local_heads = self._heads_dim.size - self._projection_size = self._projection_dim.size - - init = init_normal_(std=self._hidden_size**-0.5) - self.q_proj = self._config.q_projection_layer.get_layer( - hidden_dim, - self._projection_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.k_proj = self._config.k_projection_layer.get_layer( - hidden_dim, - self._projection_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.v_proj = self._config.v_projection_layer.get_layer( - hidden_dim, - self._projection_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - - self.q_conv = self._config.convolution_layer.get_layer( - self._projection_dim, - default_add_bias=False, - default_activation=self._config.convolution_layer.activation, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.k_conv = self._config.convolution_layer.get_layer( - self._projection_dim, - default_add_bias=False, - default_activation=self._config.convolution_layer.activation, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.v_conv = self._config.convolution_layer.get_layer( - self._projection_dim, - default_add_bias=False, - default_activation=self._config.convolution_layer.activation, - lr_scale=self._lr_scale, - peft=self._peft, - ) - - self.f_a_proj = self._config.f_a_projection_layer.get_layer( 
- hidden_dim, - self._head_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=False, # self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.f_b_proj = self._config.f_b_projection_layer.get_layer( - self._head_dim, - self._projection_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.g_a_proj = self._config.g_a_projection_layer.get_layer( - hidden_dim, - self._head_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=False, # self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.g_b_proj = self._config.g_b_projection_layer.get_layer( - self._head_dim, - self._projection_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - - self.beta_proj = self._config.beta_projection_layer.get_layer( - hidden_dim, - self._heads_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.o_proj = self._config.output_projection_layer.get_layer( - self._projection_dim, - hidden_dim, - default_weight_initialization=init, - default_add_bias=False, - sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._peft, - ) - - self.dt_bias: ParameterMeta = self._config.dt_bias_weight.get_parameter( - (self._projection_dim,), - default_initialization=init_ones_, - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.A_log: ParameterMeta = self._config.a_log_weight.get_parameter( - (self._heads_dim,), - default_initialization=LambdaInitializer( - lambda _, tensor, generator: tensor.uniform_(1, 16, generator=generator).log_() - ), - lr_scale=self._lr_scale, - peft=self._peft, - ) - self.norm = self._config.normalization.get_layer( - self._head_dim, - lr_scale=self._lr_scale, - peft=self._peft, - ) - - def _apply_conv(self, tensor: torch.Tensor, conv: torch.nn.Module, seq_idx: torch.Tensor = None) -> torch.Tensor: - """ - Applies convolution. - Note, in the reference code they use Short Convolution from flash-linear-attention/fla/modules/convolution.py, but that one just uses causal_conv1d anyways. - Varlen: - - seq. idx are only suppored in channel last layout, i.e. no transpose - """ - tensor = rearrange(tensor, "b t d -> b d t") - # tensor = tensor.transpose(1, 2).contiguous() if seq_idx is None else tensor.transpose(1, 2) - tensor = conv(tensor, seq_idx=seq_idx) - return tensor.transpose(1, 2).contiguous() - - def _reshape_heads(self, tensor: torch.Tensor) -> torch.Tensor: - tensor = tensor.contiguous() - # since head_dim is the same vor k,q and v - # same as rearrange(v, '... (h d) -> ... h d', d=self.head_dim) - return tensor.view(tensor.shape[0], tensor.shape[1], self._local_heads, self._config.head_dim) - - def _forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - # TODO: make sure varlen is supported - # TODO: do we need to deal with padding tokens? 
- sequence_first = kwargs[BlockKwargs.sequence_first] - hidden_states = input_ - - cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) - seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) - # TODO: can be made more efficeint by rearranging hidden states directly - residual_dtype = hidden_states.dtype - - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - - if sequence_first: - # make bs first dim again - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - - batch_size, sequence_length, _ = q.size() - - # work with bs = 1 to make sure varlen works correctly, only needed if micro batch size is > 1 - # can this be applied once to hidden state only? pr - q = rearrange(q, "b s ... -> (b s) ...").unsqueeze(0) - k = rearrange(k, "b s ... -> (b s) ...").unsqueeze(0) - v = rearrange(v, "b s ... -> (b s) ...").unsqueeze(0) - - # because we use cu_seqlens, chunk_kda requires batch size to be 1 (flatten, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303) - # similarly to ShortConvolution from fla we already operate on flattened batches here (https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/modules/convolution.py#L914) - q = self._apply_conv(q, self.q_conv, seq_idx=seq_idx) - k = self._apply_conv(k, self.k_conv, seq_idx=seq_idx) - v = self._apply_conv(v, self.v_conv, seq_idx=seq_idx) - - g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) - if sequence_first: - g_kernel = g_kernel.transpose(0, 1) - g_kernel = rearrange(g_kernel, "b s ... -> (b s) ...").unsqueeze(0) - - g_kernel = fused_kda_gate(g_kernel, self.A_log, self._config.head_dim, g_bias=self.dt_bias) - - beta = torch.sigmoid(self.beta_proj(hidden_states).float()) - q = self._reshape_heads(q) - k = self._reshape_heads(k) - v = self._reshape_heads(v) - if sequence_first: - beta = beta.transpose(0, 1) - beta = rearrange(beta, "b s h -> (b s) h").unsqueeze(0) - - # need to install nightly triton for this to work on H100, see https://github.com/fla-org/flash-linear-attention/blob/main/FAQs.md - # cu_seqlens requires batch ssize to be 1, i.e. 
flattened bacthes - attn_out, _ = chunk_kda( - q=q, - k=k, - v=v, - g=g_kernel, - beta=beta, - initial_state=None, - output_final_state=False, - use_qk_l2norm_in_kernel=True, - cu_seqlens=cu_seqlens, - ) - - attn_out = attn_out.to(residual_dtype) - - g_out = self.g_b_proj(self.g_a_proj(hidden_states)) # bs x seq x n_local_heads x head dim - g_out = self._reshape_heads(g_out) - if sequence_first: - g_out = g_out.transpose(0, 1) - - attn_out = rearrange(attn_out.squeeze(0), "(b s) h d -> b s h d", b=batch_size, s=sequence_length) - attn_out = self.norm(attn_out, g_out) - attn_out = rearrange(attn_out, "b s h d -> b s (h d)") - if sequence_first: - attn_out = attn_out.transpose(0, 1) - attn_out = self.o_proj(attn_out) - - return attn_out - - def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - raise NotImplementedError() - - def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] - if sequence_lengths is None: - raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") - sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] - seqlens = torch.cat(sequence_lengths) - cu_seqlens = torch.cat( - ( - torch.zeros(1, dtype=torch.long, device=batch.device), - torch.cumsum(seqlens, dim=0, dtype=torch.long).to(batch.device), - ) - ) - # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 - # also whenever cu_seqlens is used, batchs size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 - kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens - # seq_idx has to be (bs, seqlen), but bs is forced to 1 - kwargs[LinearAttentionKwargs.seq_idx] = ( - ( - torch.cat( - [ - torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) - for n in (torch.diff(cu_seqlens).to(torch.int32)) - ], - dim=0, - ) - .eq(0) - .cumsum(0) - - 1 - ) - .to(torch.int32) - .unsqueeze(0) - ) - - def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - if LinearAttentionKwargs.sequence_lengths in kwargs: - # TODO: packing is enabled by default, i.e. its always used? 
- # only get here when cross_document_attention is False - self._preprocess_for_varlen(batch, kwargs) diff --git a/tests/test_ssm_varlen.py b/tests/test_ssm_varlen.py index 9ca491e3..1f7a83e6 100644 --- a/tests/test_ssm_varlen.py +++ b/tests/test_ssm_varlen.py @@ -8,8 +8,9 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.ssm import kda as kda_module -from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig, LinearAttentionKwargs +from fast_llm.layers.decoder.config import MixerConfig +from fast_llm.layers.ssm import gdn as gdn_module +from fast_llm.layers.ssm.config import GatedDeltaNetConfig # from mamba2 import NemotronHMamba2 @@ -71,9 +72,8 @@ def materialize_meta_tensors(model, tensor_space): return model -def unpack(packed_hidden_states, cu_seqlens): +def unpack_and_padd(packed_hidden_states, cu_seqlens, package_num): batch_size = packed_hidden_states.shape[0] - package_num = cu_seqlens.shape[0] - 1 seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() hidden_dim = packed_hidden_states.shape[2] hidden_states = torch.zeros( @@ -132,7 +132,7 @@ def generate_random_cu_seqlens(seq_len, packages_num=2): return cu_seqlens, index -def _materialize_kda_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: +def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: """ Materialize meta parameters on the requested device for KDA mixer layers. """ @@ -154,41 +154,46 @@ def _materialize_kda_tensors(module: torch.nn.Module, distributed: Distributed, def _param_grad(param: torch.nn.Parameter) -> torch.Tensor | None: - # ParameterMeta stores grads in grad_buffer; fall back to .grad otherwise. return param.grad_buffer if hasattr(param, "grad_buffer") and param.grad_buffer is not None else param.grad +# TODO: include mamba varlen @pytest.mark.slow -@pytest.mark.skipif(not torch.cuda.is_available(), reason="KDA varlen needs CUDA") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Varlen test needs CUDA") @pytest.mark.skipif( - kda_module.chunk_kda is None or kda_module.fused_kda_gate is None, - reason="KDA fused kernels not available", + gdn_module.chunk_gated_delta_rule is None, + reason="Gated Delta Net fused kernels not available", ) -def test_kda_varlen_stacking_equivalence(distributed_config, distributed): +@pytest.mark.parametrize( + "config, sequence_first", + [ + pytest.param(GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), False), + pytest.param(GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), True), + # pytest.param(KimiDeltaAttentionConfig) + ], +) +def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: bool, distributed_config, distributed): """ - Check that KDA forward/backward match with and without stacking using the real kernels. + Check that Gated Delta Net forward/backward match with and without packing. 
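+    Each batch row packs several sub-sequences (2 and 3 documents here), and the packed forward must match per-sequence reference runs.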
""" device = torch.device("cuda") dtype = torch.float16 - heads, head_dim = 2, 16 - hidden_size = heads * head_dim - - config = KimiDeltaAttentionConfig(heads=heads, head_dim=head_dim) + hidden_size = 32 hidden_dim = TensorDim("hidden", hidden_size) - kda_packed = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) - kda_ref = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) - kda_packed.setup(distributed) - kda_ref.setup(distributed) - _materialize_kda_tensors(kda_packed, distributed, device) - _materialize_kda_tensors(kda_ref, distributed, device) - kda_ref.load_state_dict(kda_packed.state_dict()) - kda_packed.to(device=device, dtype=dtype) - kda_ref.to(device=device, dtype=dtype) + mixer_packed = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) + mixer_ref = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) + mixer_packed.setup(distributed) + mixer_ref.setup(distributed) + _materialize_mixer_tensors(mixer_packed, distributed, device) + _materialize_mixer_tensors(mixer_ref, distributed, device) + mixer_ref.load_state_dict(mixer_packed.state_dict()) + mixer_packed.to(device=device, dtype=dtype) + mixer_ref.to(device=device, dtype=dtype) batch_size = 2 # cu_seqlens path requires flattened batch seq_len = 15 - packages_num = torch.randint(2, 5, (1, batch_size))[0] # randomize packages num between 2 and 4 - lengths = [ + packages_num = torch.tensor([2, 3], device=device, dtype=torch.long) + sequence_lengths = [ torch.tensor( generate_random_cu_seqlens(seq_len, packages_num=packages_num[i].item())[0], device=device, @@ -196,63 +201,62 @@ def test_kda_varlen_stacking_equivalence(distributed_config, distributed): ).diff() for i in range(batch_size) ] + seqlens = torch.cat(sequence_lengths) + cu_seqlen = torch.cat( + ( + torch.zeros(1, dtype=torch.long, device=device), + torch.cumsum(seqlens, dim=0, dtype=torch.long).to(device), + ) + ) - # lengths = torch.tensor(cu_seqlens, device=device, dtype=torch.long)#.diff() - # total_tokens = lengths.sum().item() packed = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=True) + if sequence_first: + packed = packed.transpose(0, 1) kwargs_packed = { - LinearAttentionKwargs.sequence_lengths: lengths, - BlockKwargs.sequence_first: False, + BlockKwargs.sequence_lengths: sequence_lengths, + BlockKwargs.sequence_first: sequence_first, BlockKwargs.hidden_dims: (hidden_dim,), - # BlockKwargs.sequence_q_dim: TensorDim("sequence_q", lengths.sum().item()), } - # Use the layer's preprocess to construct cu_seqlens/seq_idx the same way as the implementation. - kda_packed.preprocess(packed, kwargs_packed) + mixer_packed.preprocess(packed, kwargs_packed) + assert torch.all(kwargs_packed["cu_seqlens"] == cu_seqlen) kwargs_ref = { BlockKwargs.sequence_first: False, BlockKwargs.hidden_dims: (hidden_dim,), } - out_packed = kda_packed(packed, kwargs_packed) + out_packed = mixer_packed(packed, kwargs_packed) + if sequence_first: + out_packed = out_packed.transpose(0, 1) # Run reference path separately per sequence without varlen packing, then concatenate. 
ref_outs = [] for b in range(batch_size): out_batch = [] - length = lengths[b] - ref_seqs = torch.split(packed[b].unsqueeze(0), length.tolist(), dim=1) + length = sequence_lengths[b] + if sequence_first: + ref_seqs = torch.split(packed[:, b].unsqueeze(0), length.tolist(), dim=1) + else: + ref_seqs = torch.split(packed[b].unsqueeze(0), length.tolist(), dim=1) for seq in ref_seqs: kwargs_ref_seq = { **kwargs_ref, BlockKwargs.sequence_q_dim: TensorDim("sequence_q", seq.shape[1]), } - out_batch.append(kda_ref(seq, kwargs_ref_seq)) + out_batch.append(mixer_ref(seq, kwargs_ref_seq)) ref_outs.append(torch.cat(out_batch, dim=1)) out_ref = torch.cat(ref_outs, dim=0) out_ref_packed = out_ref - assert out_packed.shape == packed.shape assert out_ref_packed.shape == out_packed.shape assert torch.allclose(out_packed, out_ref_packed, atol=1e-3, rtol=1e-3) out_packed.sum().backward() out_ref_packed.sum().backward() - assert _param_grad(kda_packed.q_proj.weight) is not None - assert _param_grad(kda_ref.q_proj.weight) is not None - assert torch.allclose( - _param_grad(kda_packed.q_proj.weight), _param_grad(kda_ref.q_proj.weight), atol=1e-3, rtol=1e-3 - ) - assert torch.allclose( - _param_grad(kda_packed.k_proj.weight), _param_grad(kda_ref.k_proj.weight), atol=1e-3, rtol=1e-3 - ) - assert torch.allclose( - _param_grad(kda_packed.v_proj.weight), _param_grad(kda_ref.v_proj.weight), atol=1e-3, rtol=1e-3 - ) - assert torch.allclose( - _param_grad(kda_packed.o_proj.weight), _param_grad(kda_ref.o_proj.weight), atol=1e-3, rtol=1e-3 - ) + for (name, param), (_, param_ref) in zip(mixer_packed.named_parameters(), mixer_ref.named_parameters()): + if param.requires_grad: + assert torch.allclose(_param_grad(param), _param_grad(param_ref), atol=1e-3, rtol=1e-3) if __name__ == "__main__": From c48d4ee44ccf71fd66db298c3b6a316925412570 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 24 Nov 2025 21:02:42 +0000 Subject: [PATCH 10/21] clean up --- fast_llm/layers/ssm/gdn.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index 6fd86c6e..af770e38 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -267,13 +267,16 @@ def _forward( """ - we flatten batch + seq - forward as packed sequence, i.e. BS = 1, cu_seqlens and seq_idx created in the preprocessing step must reflect this (these are None if cross_document_attention is True) - - scatter results back to B x T x D - - note, if there are padding tokens they are note removed, they are assumed to be ignored later in the loss calculation and are assumed to be always ont he right + - scatter results back to B/T x T/B x D + - note, if there are padding tokens they are not treated in a special way here. + They are + - assumed to be ignored later in the loss calculation and + - are assumed to be always on the right and, hence, will be reflected in seq_idx and cu_seqlens (i.e. treated as a seperate packed sequence?) + - """ sequence_first = kwargs[BlockKwargs.sequence_first] # in sequence parallel TP the input here is already scattered across sequence dimension - # TODO: do we need masking of padding tokens? # TODO: fuse soome of the reshapes into rearranges hidden_states = input_ @@ -355,8 +358,9 @@ def _forward( def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: """ - Creates seqlens and cu_seqlens for packed training (varlen). + Creates seqlens and cu_seqlens for packed forward. 
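+        For example, sequence_lengths [[3, 2], [5]] (two rows of a length-5 batch) give cu_seqlens [0, 3, 5, 10] and seq_idx [[0, 0, 0, 1, 1, 2, 2, 2, 2, 2]].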
This assumes that forward pass is performed on a fully packed sequence, i.e. where sequences are flattened out into BS = 1. + Note: padding tokens are always on the right and get their own entry in LinearAttentionKwargs.sequence_lengths --> they are treated as seperate sequence. Sets: - seq_idx to [1, BS x T] tensor, where each elemnt is the sequence index of the corresponding token From e2bb25cf49168a6ec43603691745ac2930168a62 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 25 Nov 2025 14:35:37 +0000 Subject: [PATCH 11/21] test config --- tests/utils/model_configs.py | 39 ++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index f7797e3c..7ee095d3 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -817,6 +817,45 @@ def _update_and_add_testing_config( ) +_update_and_add_testing_config( + # Tests hybrid with gated delta net mixer. + "llama", + "hybrid_gdn", + updates={ + ("model", "base_model", "decoder"): { + "type": "pattern", + "blocks": { + "t": copy.deepcopy(_llama_block), + "gdn": { + **copy.deepcopy(_llama_block), + "mixer": { + "type": "gdn", + "value_heads": 4, + "key_heads": 2, + "key_head_dim": 16, + "value_head_dim": 16, + }, + }, + }, + "num_blocks": 2, + "pattern": ["t", "gdn"], + }, + }, + megatron_args=None, + checkpoint_format=AprielHybridSSMCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + }, + compare_factor=16, + skip_tests=("sdp", "ms", "stp"), +) + + @pytest.fixture(scope="session", params=MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: models = request.config.getoption("--models") From d4f9b856f534cb7899b3dc5d887909c10c78cf7a Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 25 Nov 2025 19:08:04 +0000 Subject: [PATCH 12/21] wip --- fast_llm/layers/ssm/gdn.py | 106 +++++++++++++------ setup.cfg | 4 +- tests/{test_ssm_varlen.py => test_varlen.py} | 20 ---- tests/utils/model_configs.py | 2 +- 4 files changed, 75 insertions(+), 57 deletions(-) rename tests/{test_ssm_varlen.py => test_varlen.py} (94%) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index af770e38..cba40e48 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -210,9 +210,9 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) - self.norm = self._config.normalization.get_layer( - self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft - ) + # self.norm = self._config.normalization.get_layer( + # self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft + # ) self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule @@ -221,41 +221,65 @@ def __init__( "Fast paths for GatedDeltaNet are not available. Please ensure that 'causal_conv1d' and 'fla' are properly installed." 
) + # def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): + # """Derives query, key and value tensors from mixed_qkvz and mixed_ba.""" + # new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + # self._local_key_heads, + # 2 * self._config.key_head_dim + # + 2 * self._config.value_head_dim * self._local_value_heads // self._local_key_heads, + # ) + # new_tensor_shape_ba = mixed_ba.size()[:-1] + ( + # self._local_key_heads, + # 2 * self._local_value_heads // self._local_key_heads, + # ) + # mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + # mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + # split_arg_list_qkvz = [ + # self._config.key_head_dim, + # self._config.key_head_dim, + # (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), + # (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), + # ] + # split_arg_list_ba = [ + # self._local_value_heads // self._local_key_heads, + # self._local_value_heads // self._local_key_heads, + # ] + # query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3) + # b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3) + # # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + # value = value.reshape(value.size(0), value.size(1), -1, self._config.value_head_dim) + # z = z.reshape(z.size(0), z.size(1), -1, self._config.value_head_dim) + # b = b.reshape(b.size(0), b.size(1), self._local_value_heads) + # a = a.reshape(a.size(0), a.size(1), self._local_value_heads) + # return query, key, value, z, b, a + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): """ + note this must be the right way to split the TP, because TP splits each subdimention of ConcatenatedTensorDim("gdn_qkvz", (query_dim, key_dim, value_dim, z_dim)) seperately. Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. """ - new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( - self._local_key_heads, - 2 * self._config.key_head_dim - + 2 * self._config.value_head_dim * self._local_value_heads // self._local_key_heads, + # Split contiguous q/k/v/z blocks and only then project them into per-head shapes. 
+ local_qkv_sizes = ( + self._local_key_heads * self._config.key_head_dim, + self._local_key_heads * self._config.key_head_dim, + self._local_value_heads * self._config.value_head_dim, + self._local_value_heads * self._config.value_head_dim, ) - new_tensor_shape_ba = mixed_ba.size()[:-1] + ( - self._local_key_heads, - 2 * self._local_value_heads // self._local_key_heads, + query, key, value, z = torch.split(mixed_qkvz, local_qkv_sizes, dim=-1) + query = query.reshape(*query.shape[:-1], self._local_key_heads, self._config.key_head_dim) + key = key.reshape(*key.shape[:-1], self._local_key_heads, self._config.key_head_dim) + value = value.reshape(*value.shape[:-1], self._local_value_heads, self._config.value_head_dim) + z = z.reshape(*z.shape[:-1], self._local_value_heads, self._config.value_head_dim) + + beta, alpha = torch.split( + mixed_ba, + (self._local_value_heads, self._local_value_heads), + dim=-1, ) - - mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) - mixed_ba = mixed_ba.view(*new_tensor_shape_ba) - split_arg_list_qkvz = [ - self._config.key_head_dim, - self._config.key_head_dim, - (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), - (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), - ] - split_arg_list_ba = [ - self._local_value_heads // self._local_key_heads, - self._local_value_heads // self._local_key_heads, - ] - query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3) - b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3) - # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] - value = value.reshape(value.size(0), value.size(1), -1, self._config.value_head_dim) - z = z.reshape(z.size(0), z.size(1), -1, self._config.value_head_dim) - b = b.reshape(b.size(0), b.size(1), self._local_value_heads) - a = a.reshape(a.size(0), a.size(1), self._local_value_heads) - return query, key, value, z, b, a + beta = beta.reshape(*beta.shape[:-1], self._local_value_heads) + alpha = alpha.reshape(*alpha.shape[:-1], self._local_value_heads) + return query, key, value, z, beta, alpha def _forward( self, @@ -280,7 +304,6 @@ def _forward( # TODO: fuse soome of the reshapes into rearranges hidden_states = input_ - # these are not cu_seqlens = kwargs.get(LinearAttentionKwargs.cu_seqlens, None) seq_idx = kwargs.get(LinearAttentionKwargs.seq_idx, None) @@ -321,7 +344,7 @@ def _forward( value = value.reshape(value.shape[0], value.shape[1], -1, self._config.value_head_dim) beta = beta.sigmoid() - g = -torch.exp(self.A_log) * F.softplus(alpha + self.dt_bias) + g = -self.A_log.float().exp() * F.softplus(alpha.float() + self.dt_bias) beta = rearrange(beta, "b s ... -> (b s) ...").unsqueeze(0) g = rearrange(g, "b s ... -> (b s) ...").unsqueeze(0) @@ -347,7 +370,7 @@ def _forward( core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) - core_attn_out = self.norm(core_attn_out, z) + # core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) if sequence_first: @@ -398,9 +421,24 @@ def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.A .unsqueeze(0) ) + def _preprocess_for_cross_doc_attetion(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + """ + Since forward is packed by default, this is needed for tests to path. 
+ """ + if LinearAttentionKwargs.sequence_lengths in kwargs: + return self._preprocess_for_varlen(batch, kwargs) + bs, sequence_lengths = ( + batch.shape[:2] if not kwargs[BlockKwargs.sequence_first] else (batch.shape[1], batch.shape[0]) + ) + sequence_lengths = [torch.tensor([sequence_lengths] * bs, device=batch.device)] + kwargs[LinearAttentionKwargs.sequence_lengths] = sequence_lengths + self._preprocess_for_varlen(batch, kwargs) + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: if LinearAttentionKwargs.sequence_lengths in kwargs: self._preprocess_for_varlen(batch, kwargs) + else: + self._preprocess_for_cross_doc_attetion(batch, kwargs) def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: raise NotImplementedError() diff --git a/setup.cfg b/setup.cfg index 14e9dba2..77664cd0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,8 +52,8 @@ HUGGINGFACE = # To install on cpu environment (ex. for IDE support): # MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = - mamba_ssm[causal-conv1d]>=2.2.4 - cartesia_pytorch>=0.0.2 + mamba_ssm[causal-conv1d]==2.2.4 + flash-linear-attention>=0.4.1 GENERATION = lm_eval>=0.4.9 diff --git a/tests/test_ssm_varlen.py b/tests/test_varlen.py similarity index 94% rename from tests/test_ssm_varlen.py rename to tests/test_varlen.py index 1f7a83e6..0b5a6cac 100644 --- a/tests/test_ssm_varlen.py +++ b/tests/test_varlen.py @@ -1,4 +1,3 @@ -import inspect import itertools import pytest @@ -12,25 +11,6 @@ from fast_llm.layers.ssm import gdn as gdn_module from fast_llm.layers.ssm.config import GatedDeltaNetConfig -# from mamba2 import NemotronHMamba2 - - -_mamba_varlen = False -try: - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa - - _mamba_available = True - sig = inspect.signature(selective_scan_fn) - if "position_indices" in sig.parameters: - _mamba_varlen = True - else: - _mamba_varlen = False - # for training with packing install https://github.com/jxiw/varlen_mamba - # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md - -except (ImportError, RuntimeError): - _mamba_available = False - @pytest.fixture def distributed_config(): diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 7ee095d3..ac431f7d 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -831,7 +831,7 @@ def _update_and_add_testing_config( "mixer": { "type": "gdn", "value_heads": 4, - "key_heads": 2, + "key_heads": 4, "key_head_dim": 16, "value_head_dim": 16, }, From 8017a80d053585b6577566186db68a8798f5f418 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 25 Nov 2025 20:57:03 +0000 Subject: [PATCH 13/21] gdn tests --- .../common/normalization/normalization.py | 12 +- fast_llm/layers/ssm/gdn.py | 9 +- tests/utils/model_configs.py | 113 ++---------------- 3 files changed, 22 insertions(+), 112 deletions(-) diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index ec8a52e2..ae46ee1d 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -311,14 +311,14 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | super().__init__(config, hidden_dim, lr_scale) if rms_norm_gated is not None: - self._forward = self._forward_fused + self._forward_gated = self._forward_local 
else: - self._forward = self._forward + self._forward_gated = self._forward_local def forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: - return self._forward(input_.view(-1, *self._normalized_shape), gate).view_as(input_) + return self._forward_gated(input_.view(-1, *self._normalized_shape), gate).view_as(input_) - def _forward_fused(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + def _forward_fla(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: return rms_norm_gated( input_, gate, @@ -331,6 +331,6 @@ def _forward_fused(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tens residual_in_fp32=False, ) - def _forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: - normalized = self.rmsnorm(input_) + def _forward_local(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + normalized = self._forward(input_) return normalized * F.silu(gate) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index cba40e48..cb3249b9 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -210,9 +210,9 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) - # self.norm = self._config.normalization.get_layer( - # self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft - # ) + self.norm = self._config.normalization.get_layer( + self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft + ) self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule @@ -259,7 +259,6 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. """ - # Split contiguous q/k/v/z blocks and only then project them into per-head shapes. local_qkv_sizes = ( self._local_key_heads * self._config.key_head_dim, self._local_key_heads * self._config.key_head_dim, @@ -370,7 +369,7 @@ def _forward( core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) - # core_attn_out = self.norm(core_attn_out, z) + core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) if sequence_first: diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index ac431f7d..31e84eec 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -4,6 +4,7 @@ import functools import os import pathlib +import re import typing import pytest @@ -80,7 +81,7 @@ class ModelTestingConfig: groups: dict[ModelTestingGroup, ModelTestingGroupAction] # Scale the comparison thresholds for specific models. compare_factor: float = 1.0 - # Option to skip specific distributed configuration with name containing any of the provided strings. + # Option to skip specific distributed configuration with name matching any of the provided regex patterns. 
skip_tests: tuple[str] = () get_dataset: typing.Callable[[bool], tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path]] = ( get_model_test_dataset @@ -136,7 +137,7 @@ def base_model_config_class(self): return self.model_config_class.get_base_model_config_class() def should_skip(self, distributed_config: DistributedTestingConfig) -> bool: - return any(key in distributed_config.name for key in self.skip_tests) + return any(re.search(pattern, distributed_config.name) for pattern in self.skip_tests) def _update_and_add_testing_config( @@ -461,7 +462,7 @@ def _update_and_add_testing_config( }, compare_factor=2.0, # Arg update for cross-entropy splits doesn't work here. - skip_tests=("ce4", "ms"), + skip_tests=(r"ce4", r"ms"), ) _update_and_add_testing_config( @@ -594,7 +595,7 @@ def _update_and_add_testing_config( }, compare_factor=2.0, # Micro-sequence split not supported. - skip_tests=("sdp", "ms"), + skip_tests=(r"sdp", r"ms"), ) _update_and_add_testing_config( @@ -636,8 +637,8 @@ def _update_and_add_testing_config( compare_factor=2.0, # Micro-sequence split not supported. skip_tests=( - "sdp", - "ms", + r"sdp", + r"ms", ), # "pp","dp", "ce","16", "bf", "df", "stp"), ) @@ -721,99 +722,7 @@ def _update_and_add_testing_config( }, compare_factor=6.0, # Micro-sequence split and sequence-first not supported. - # TODO: Gradient accumulation works but comparison is broken. - skip_tests=("sdp", "ms", "bf4", "df"), -) - - -_update_and_add_testing_config( - # Tests apriel2 format with pattern decoder mixing all mixer types. - # This comprehensive test exercises: attention, mamba, stochastic mixer, sliding window attention. - "llama", - "apriel2", - updates={ - ("model", "base_model", "tied_embedding_weight"): True, - ("model", "base_model", "decoder"): { - "type": "pattern", - "blocks": { - "attn_full": { - **copy.deepcopy(_llama_block), - "mixer": { - "type": "attention", - "rotary": {"type": "default", "theta": 10000}, - "heads": 8, - "head_groups": 4, - "head_size": 32, - "add_linear_biases": False, - }, - }, - "mamba": { - **copy.deepcopy(_llama_block), - "mixer": { - "type": "mamba_2", - "d_inner": 512, - "state_size": 16, - "dt_rank": 16, - "d_xb": 256, - "add_linear_biases": False, - }, - }, - "stochastic": { - **copy.deepcopy(_llama_block), - "mixer": { - "type": "stochastic", - "mixers": { - "attn": { - "type": "attention", - "rotary": {"type": "default", "theta": 10000}, - "heads": 8, - "head_groups": 4, - "head_size": 32, - "add_linear_biases": False, - }, - "mamba": { - "type": "mamba_2", - "d_inner": 512, - "state_size": 16, - "dt_rank": 16, - "d_xb": 256, - "add_linear_biases": False, - }, - }, - "sampling_strategy": "uniform", - "main_mixer_name": "attn", - }, - }, - "attn_swa": { - **copy.deepcopy(_llama_block), - "mixer": { - "type": "attention", - "rotary": {"type": "default", "theta": 10000}, - "heads": 8, - "head_groups": 4, - "head_size": 32, - "window_size": 128, - "add_linear_biases": False, - }, - }, - }, - "pattern": ["attn_full", "mamba", "stochastic", "attn_swa"], - "num_blocks": 4, - }, - }, - megatron_args=None, - checkpoint_format=Apriel2CheckpointFormat, - groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.normal, - ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.normal, - }, - 
compare_factor=2.0, - # Micro-sequence split not supported for Mamba. - skip_tests=("sdp", "ms"), + skip_tests=(r"sdp", r"ms", r"bf4", r"df"), ) @@ -851,8 +760,10 @@ def _update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, - compare_factor=16, - skip_tests=("sdp", "ms", "stp"), + compare_factor=10.0, # with compare_factor 2 fails fp16 and bf16 tests in the normalizaiton layer when using rms_norm_gated from fla (passes with local non-fla norm) + # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). + # we should be using STP with this model! + skip_tests=(r"sdp", r"ms", r"^tp2$"), ) From 1e016014f88375151f80006ecbd9a093906b8aff Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 25 Nov 2025 21:37:28 +0000 Subject: [PATCH 14/21] tests --- fast_llm/layers/common/normalization/normalization.py | 2 +- tests/utils/model_configs.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index ae46ee1d..651d8e4b 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -311,7 +311,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | super().__init__(config, hidden_dim, lr_scale) if rms_norm_gated is not None: - self._forward_gated = self._forward_local + self._forward_gated = self._forward_fla else: self._forward_gated = self._forward_local diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 31e84eec..1eacb784 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -760,9 +760,9 @@ def _update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, - compare_factor=10.0, # with compare_factor 2 fails fp16 and bf16 tests in the normalizaiton layer when using rms_norm_gated from fla (passes with local non-fla norm) + compare_factor=10.0, # with compare_factor 2 fails fp16 and bf16 tests in the normalizaiton layer when using rms_norm_gated from fla # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). - # we should be using STP with this model! + # we should be using STP with this model, not TP! skip_tests=(r"sdp", r"ms", r"^tp2$"), ) From ca8cb5cb4514359b48ebacb6013906e01d1d14d0 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 25 Nov 2025 22:19:29 +0000 Subject: [PATCH 15/21] tests --- fast_llm/layers/ssm/gdn.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index cb3249b9..3feac971 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -378,7 +378,7 @@ def _forward( return output - def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + def _preprocess_for_varlen(self, kwargs: dict[str, typing.Any]) -> None: """ Creates seqlens and cu_seqlens for packed forward. This assumes that forward pass is performed on a fully packed sequence, i.e. where sequences are flattened out into BS = 1. 
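As a standalone illustration of the packed layout this docstring describes (not part of the patch; the document lengths below are made up), the cu_seqlens and seq_idx tensors built by the preprocessing hunk that follows can be derived from per-document lengths roughly like this:

import torch

# Two micro-batch rows holding documents of lengths [3, 5] and [4, 4], flattened into one packed sequence (BS = 1).
sequence_lengths = [[3, 5], [4, 4]]
flat_lengths = [length for lengths in sequence_lengths for length in lengths]

# Cumulative document boundaries, starting at 0 -> tensor([0, 3, 8, 12, 16]).
cu_seqlens = torch.tensor([0] + flat_lengths, dtype=torch.int32).cumsum(0)

# One sequence index per token, shape [1, total_tokens], matching the seq_idx layout described above.
seq_idx = torch.cat(
    [torch.full((length,), index, dtype=torch.int32) for index, length in enumerate(flat_lengths)]
).unsqueeze(0)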
@@ -390,15 +390,21 @@ def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.A """ sequence_lengths = kwargs[LinearAttentionKwargs.sequence_lengths] + device = kwargs.get("device", None) if sequence_lengths is None: raise ValueError("sequence_lengths must be provided in kwargs for variable-length sequences.") - seqlens = torch.cat(sequence_lengths) - cu_seqlens = torch.cat( - ( - torch.zeros(1, dtype=torch.long, device=batch.device), - torch.cumsum(seqlens, dim=0, dtype=torch.long).to(batch.device), - ) + seqlens = torch.tensor( + [ + 0, + *( + sequence_length + for sequence_lengths in kwargs[LinearAttentionKwargs.sequence_lengths] + for sequence_length in sequence_lengths + ), + ], + dtype=torch.int32, ) + cu_seqlens = seqlens.cumsum_(0).to(device) # this is supposed to be flattened, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303 # also whenever cu_seqlens is used, batchs size must be forced to 1: see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L347 kwargs[LinearAttentionKwargs.cu_seqlens] = cu_seqlens @@ -433,11 +439,8 @@ def _preprocess_for_cross_doc_attetion(self, batch: torch.Tensor, kwargs: dict[s kwargs[LinearAttentionKwargs.sequence_lengths] = sequence_lengths self._preprocess_for_varlen(batch, kwargs) - def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - if LinearAttentionKwargs.sequence_lengths in kwargs: - self._preprocess_for_varlen(batch, kwargs) - else: - self._preprocess_for_cross_doc_attetion(batch, kwargs) + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + self._preprocess_for_varlen(kwargs) def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: raise NotImplementedError() From 694d2877d05df28a4979295146af65c2335e4e37 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 25 Nov 2025 22:27:37 +0000 Subject: [PATCH 16/21] nvm --- fast_llm/layers/ssm/gdn.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index 3feac971..d9db6ad4 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -426,19 +426,6 @@ def _preprocess_for_varlen(self, kwargs: dict[str, typing.Any]) -> None: .unsqueeze(0) ) - def _preprocess_for_cross_doc_attetion(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - """ - Since forward is packed by default, this is needed for tests to path. 
- """ - if LinearAttentionKwargs.sequence_lengths in kwargs: - return self._preprocess_for_varlen(batch, kwargs) - bs, sequence_lengths = ( - batch.shape[:2] if not kwargs[BlockKwargs.sequence_first] else (batch.shape[1], batch.shape[0]) - ) - sequence_lengths = [torch.tensor([sequence_lengths] * bs, device=batch.device)] - kwargs[LinearAttentionKwargs.sequence_lengths] = sequence_lengths - self._preprocess_for_varlen(batch, kwargs) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self._preprocess_for_varlen(kwargs) From d3bd916f77ff47cfe107a8e2207dc5c73b9c4f42 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 13:43:50 +0000 Subject: [PATCH 17/21] requirements --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 77664cd0..f4b2c904 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,7 +53,7 @@ HUGGINGFACE = # MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 - flash-linear-attention>=0.4.1 + flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@main GENERATION = lm_eval>=0.4.9 From 9a53c5b93b5e7cd85fa94f644425ac216084e632 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 13:53:05 +0000 Subject: [PATCH 18/21] clean up --- fast_llm/models/gpt/conversion/apriel.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 215cc525..d9ddf57d 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -9,12 +9,7 @@ from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.decoder.mlp.config import MLPConfig -from fast_llm.layers.ssm.config import ( - DiscreteMamba2Config, - GatedDeltaNetConfig, - KimiDeltaAttentionConfig, - Mamba2Config, -) +from fast_llm.layers.ssm.config import DiscreteMamba2Config, GatedDeltaNetConfig, Mamba2Config from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters @@ -268,11 +263,11 @@ def export_config(cls, config: GatedDeltaNetConfig) -> dict: "gdn_linear_conv_kernel_size": config.convolution_layer.kernel_size, }, } - + @classmethod def get_converters( cls, - config: KimiDeltaAttentionConfig, + config: GatedDeltaNetConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False, @@ -345,7 +340,6 @@ class AprielBlockConverter: AttentionConfig: "t", Mamba2Config: "m2", DiscreteMamba2Config: "m2d", - KimiDeltaAttentionConfig: "kda", GatedDeltaNetConfig: "gdn", } _converter_classes = { From 80041ce7f7c40a8b6d8382370999df3ed5039c39 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 13:59:29 +0000 Subject: [PATCH 19/21] conversion --- fast_llm/models/gpt/conversion/apriel.py | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index d9ddf57d..41c444df 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -8,7 +8,6 @@ from fast_llm.layers.attention.config import AttentionConfig 
 from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig
 from fast_llm.layers.decoder.config import DecoderBlockConfig
-from fast_llm.layers.decoder.mlp.config import MLPConfig
 from fast_llm.layers.ssm.config import DiscreteMamba2Config, GatedDeltaNetConfig, Mamba2Config
 from fast_llm.models.gpt.config import GPTModelConfig
 from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat
 from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters
@@ -225,19 +224,6 @@ def get_converters(
     ]
 
 
-class AprielMLPConverter(LlamaMLPConverter):
-    @classmethod
-    def import_config(cls, config: dict) -> dict:
-        config["mlp_bias"] = False
-        return super().import_config(config)
-
-    @classmethod
-    def export_config(cls, config: MLPConfig) -> dict:
-        out = super().export_config(config)
-        del out["mlp_bias"]
-        return out
-
-
 class GatedDeltaNetConverter:
     @classmethod
     def import_config(cls, config: dict) -> dict:
@@ -316,11 +302,7 @@ def get_converters(
     ]
 
 
-class AprielBlockConverterBase(MistralBlockConverter):
-    mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter
-
-
-class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase):
+class AprielDiscreteMamba2BlockConverter(MistralBlockConverter):
     mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter
     hf_mixer_name: typing.ClassVar[str] = "mixer"
 
@@ -330,7 +312,7 @@ class AprielMamba2BlockConverter(MistralBlockConverter):
     hf_mixer_name: typing.ClassVar[str] = "mixer"
 
 
-class AprielGatedDeltaNetBlockConverter(AprielBlockConverterBase):
+class AprielGatedDeltaNetBlockConverter(MistralBlockConverter):
     mixer_converter_class: typing.ClassVar[type[GatedDeltaNetConverter]] = GatedDeltaNetConverter
     hf_mixer_name: typing.ClassVar[str] = "mixer"
 

From d6677b08f0baf5a79f25e9adbfba8dc111631cb1 Mon Sep 17 00:00:00 2001
From: oleksost
Date: Wed, 26 Nov 2025 16:16:35 +0000
Subject: [PATCH 20/21] comments on the layout + HF forward equivalence test

---
 fast_llm/layers/ssm/gdn.py           |  85 ++++++++++------
 tests/layers/test_gdn_equivalence.py | 134 +++++++++++++++++++++++++++
 2 files changed, 187 insertions(+), 32 deletions(-)
 create mode 100644 tests/layers/test_gdn_equivalence.py

diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py
index d9db6ad4..9f3a5526 100644
--- a/fast_llm/layers/ssm/gdn.py
+++ b/fast_llm/layers/ssm/gdn.py
@@ -117,6 +117,9 @@ class GatedDeltaNet[ConfigType: GatedDeltaNetConfig](BlockWithBias[ConfigType]):
     - For tensor parallel implementtion (no sequnece prallel): we scatter teh heads accross ranks.
     - Sequence Tensor parallel: in_proj_qkvz all reduces across sequence dim. --> each rank performs work on full sequence but only a subset of heads (standrd TP).
 
+    Note, Qwen3_Next follows a different layout, where gdn_qkvz is assumed to be laid out as [h0: Q,K,V,Z][h1: Q,K,V,Z][h2: Q,K,V,Z].
+    Here we follow a more natural layout for gdn_qkvz: [Q_all_heads | K_all_heads | V_all_heads | Z_all_heads]. This should also make it easier to apply MIL init here.
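To make the layout difference concrete, here is a small standalone sketch (not part of the diff; head counts and head dimensions are arbitrary) of how a fused qkvz projection output would be split under each convention:

import torch

key_heads, value_heads, key_dim, value_dim = 2, 4, 3, 3
total_channels = 2 * key_heads * key_dim + 2 * value_heads * value_dim
mixed_qkvz = torch.randn(1, 5, total_channels)  # [batch, sequence, qkvz channels]

# Layout used here: contiguous [Q_all_heads | K_all_heads | V_all_heads | Z_all_heads] blocks.
query, key, value, z = torch.split(
    mixed_qkvz,
    (key_heads * key_dim, key_heads * key_dim, value_heads * value_dim, value_heads * value_dim),
    dim=-1,
)

# Qwen3-Next layout: one [Q | K | V group | Z group] slice per key head, interleaved across heads.
value_group = value_heads // key_heads
per_head = 2 * key_dim + 2 * value_dim * value_group
query_i, key_i, value_i, z_i = torch.split(
    mixed_qkvz.view(1, 5, key_heads, per_head),
    (key_dim, key_dim, value_dim * value_group, value_dim * value_group),
    dim=-1,
)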
+ """ _config: ConfigType @@ -131,6 +134,7 @@ def __init__( peft: PeftConfig | None, return_bias: bool = True, ): + super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias ) @@ -141,6 +145,7 @@ def __init__( self._key_heads_dim = TensorDim( "gdn_key_heads", self._config.key_heads, self._parallel_dim if self._config.key_heads > 1 else None ) + self._value_head_dim = TensorDim("gdn_value_head_dim", self._config.value_head_dim) self._key_head_dim = TensorDim("gdn_key_head_dim", self._config.key_head_dim) self._local_value_heads = self._value_heads_dim.size @@ -150,8 +155,18 @@ def __init__( query_dim = CompositeTensorDim("gdn_query", (self._key_heads_dim, self._key_head_dim)) key_dim = CompositeTensorDim("gdn_key", (self._key_heads_dim, self._key_head_dim)) value_dim = CompositeTensorDim("gdn_value", (self._value_heads_dim, self._value_head_dim)) + z_dim = CompositeTensorDim("gdn_z", (self._value_heads_dim, self._value_head_dim)) qkvz_dim = ConcatenatedTensorDim("gdn_qkvz", (query_dim, key_dim, value_dim, z_dim)) + # for Qwen's layour use soemthing like this instead: + # n_vheads_per_k_head = self._config.value_heads // self._config.key_heads + # head_size = 2 * self._config.key_head_dim + 2 * self._config.value_head_dim * n_vheads_per_k_head + # n_heads = self._config.key_heads + # qkvz_dim = TensorDim(e + # "gdn_qkvz", + # n_heads * head_size, + # self._parallel_dim if n_heads > 1 else None, + # ) ba_dim = ConcatenatedTensorDim( "gdn_ba", ( @@ -159,6 +174,12 @@ def __init__( CompositeTensorDim("gdn_alpha", (self._value_heads_dim,)), ), ) + # for Qwen's layour use something like this instead: + # ba_dim = TensorDim( + # "gdn_ba", + # 2 * self._config.value_heads, + # self._parallel_dim if 2 * self._config.value_heads > 1 else None, + # ) qkv_channels_dim = ConcatenatedTensorDim("gdn_qkv", (query_dim, key_dim, value_dim)) @@ -221,42 +242,42 @@ def __init__( "Fast paths for GatedDeltaNet are not available. Please ensure that 'causal_conv1d' and 'fla' are properly installed." 
) - # def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): - # """Derives query, key and value tensors from mixed_qkvz and mixed_ba.""" - # new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( - # self._local_key_heads, - # 2 * self._config.key_head_dim - # + 2 * self._config.value_head_dim * self._local_value_heads // self._local_key_heads, - # ) - # new_tensor_shape_ba = mixed_ba.size()[:-1] + ( - # self._local_key_heads, - # 2 * self._local_value_heads // self._local_key_heads, - # ) - # mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) - # mixed_ba = mixed_ba.view(*new_tensor_shape_ba) - # split_arg_list_qkvz = [ - # self._config.key_head_dim, - # self._config.key_head_dim, - # (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), - # (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), - # ] - # split_arg_list_ba = [ - # self._local_value_heads // self._local_key_heads, - # self._local_value_heads // self._local_key_heads, - # ] - # query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3) - # b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3) - # # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] - # value = value.reshape(value.size(0), value.size(1), -1, self._config.value_head_dim) - # z = z.reshape(z.size(0), z.size(1), -1, self._config.value_head_dim) - # b = b.reshape(b.size(0), b.size(1), self._local_value_heads) - # a = a.reshape(a.size(0), a.size(1), self._local_value_heads) - # return query, key, value, z, b, a + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): + """Derives query, key and value tensors from mixed_qkvz and mixed_ba.""" + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + self._local_key_heads, + 2 * self._config.key_head_dim + + 2 * self._config.value_head_dim * self._local_value_heads // self._local_key_heads, + ) + new_tensor_shape_ba = mixed_ba.size()[:-1] + ( + self._local_key_heads, + 2 * self._local_value_heads // self._local_key_heads, + ) + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + split_arg_list_qkvz = [ + self._config.key_head_dim, + self._config.key_head_dim, + (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), + (self._local_value_heads // self._local_key_heads * self._config.value_head_dim), + ] + split_arg_list_ba = [ + self._local_value_heads // self._local_key_heads, + self._local_value_heads // self._local_key_heads, + ] + query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3) + b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3) + # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + value = value.reshape(value.size(0), value.size(1), -1, self._config.value_head_dim) + z = z.reshape(z.size(0), z.size(1), -1, self._config.value_head_dim) + b = b.reshape(b.size(0), b.size(1), self._local_value_heads) + a = a.reshape(a.size(0), a.size(1), self._local_value_heads) + return query, key, value, z, b, a def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): """ - note this must be the right way to split the TP, because TP splits each subdimention of ConcatenatedTensorDim("gdn_qkvz", (query_dim, key_dim, value_dim, z_dim)) seperately. Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. + Replaces fix_query_key_value_ordering from Qwen due to layout differences. 
""" local_qkv_sizes = ( diff --git a/tests/layers/test_gdn_equivalence.py b/tests/layers/test_gdn_equivalence.py new file mode 100644 index 00000000..9886056e --- /dev/null +++ b/tests/layers/test_gdn_equivalence.py @@ -0,0 +1,134 @@ +import pytest +import torch + +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.ssm.config import GatedDeltaNetConfig + +try: + from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextConfig, Qwen3NextGatedDeltaNet +except ImportError: + Qwen3NextConfig, Qwen3NextGatedDeltaNet = None, None + + +def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: + """ + Instantiate meta-allocated parameters on the requested device so the layer can run standalone. + """ + for name, param in module.named_parameters(): + if param.device.type != "meta": + continue + param_data = torch.empty_like(param, device=device) + param.init_parameter(param_data, distributed) + module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) + target = module + if module_path is not None: + for part in module_path.split("."): + target = getattr(target, part) + new_param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) + new_param.grad = None + new_param.grad_buffer = torch.zeros_like(param_data) + new_param.param_grad_is_zero = True + target._parameters[param_name] = new_param + + +@pytest.mark.slow +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Varlen test needs CUDA") +@pytest.mark.skipif(Qwen3NextConfig is None, reason="transformers with Qwen3-Next not installed") +def test_fast_llm_gdn_matches_qwen3_next_forward(): + torch.manual_seed(0) + device = torch.device("cuda") + dtype = torch.bfloat16 + + hidden_size = 16 + seq_len = 6 + num_k_heads = 2 + num_v_heads = 4 + head_k_dim = 4 + head_v_dim = 4 + kernel_size = 4 + + hf_config = Qwen3NextConfig( + hidden_size=hidden_size, + linear_num_key_heads=num_k_heads, + linear_num_value_heads=num_v_heads, + linear_key_head_dim=head_k_dim, + linear_value_head_dim=head_v_dim, + linear_conv_kernel_dim=kernel_size, + hidden_act="silu", + rms_norm_eps=1e-6, + dtype=dtype, + ) + hf_layer = Qwen3NextGatedDeltaNet(hf_config, layer_idx=0).to(device=device, dtype=dtype).eval() + + fast_config = GatedDeltaNetConfig( + value_heads=num_v_heads, + key_heads=num_k_heads, + value_head_dim=head_v_dim, + key_head_dim=head_k_dim, + activation=ActivationType.silu, + normalization={"epsilon": hf_config.rms_norm_eps}, + convolution_layer={"kernel_size": kernel_size, "activation": ActivationType.silu}, + ) + distributed_config = DistributedConfig( + tensor_parallel=1, + pipeline_parallel=1, + sequence_data_parallel=1, + local_world_size=1, + world_size=1, + ) + hidden_dim = TensorDim("hidden", hidden_size) + fast_layer = fast_config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) + distributed = Distributed(config=distributed_config) + fast_layer.setup(distributed) + _materialize_mixer_tensors(fast_layer, distributed, device) + fast_layer.to(device=device, dtype=dtype).eval() + + with torch.no_grad(): + fast_layer.in_proj_qkvz.weight.copy_(hf_layer.in_proj_qkvz.weight) + fast_layer.in_proj_ba.weight.copy_(hf_layer.in_proj_ba.weight) + 
fast_layer.convolution.weight.copy_(hf_layer.conv1d.weight) + if fast_layer.convolution.bias is not None and hf_layer.conv1d.bias is not None: + fast_layer.convolution.bias.copy_(hf_layer.conv1d.bias) + fast_layer.out_proj.weight.copy_(hf_layer.out_proj.weight) + fast_layer.A_log.copy_(hf_layer.A_log) + fast_layer.dt_bias.copy_(hf_layer.dt_bias) + fast_layer.norm.weight.copy_(hf_layer.norm.weight) + + hidden_states = torch.randn(1, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=False) + + param_map = { + "in_proj_qkvz.weight": "in_proj_qkvz.weight", + "in_proj_ba.weight": "in_proj_ba.weight", + "convolution.weight": "conv1d.weight", + "convolution.bias": "conv1d.bias", + "out_proj.weight": "out_proj.weight", + "A_log": "A_log", + "dt_bias": "dt_bias", + "norm.weight": "norm.weight", + } + for k, p in fast_layer.state_dict().items(): + torch.testing.assert_close(p, hf_layer.state_dict()[param_map[k]], atol=1e-6, rtol=1e-6) + + # need to monkey patch the hf implementation with our fix_query_key_value_ordering due to the layout differences + hf_layer.fix_query_key_value_ordering = fast_layer.fix_query_key_value_ordering + hf_layer._local_key_heads = fast_layer._local_key_heads + hf_layer._local_value_heads = fast_layer._local_value_heads + hf_layer._config = fast_layer._config + + hf_out = hf_layer(hidden_states) + + fast_kwargs = { + BlockKwargs.sequence_first: False, + BlockKwargs.hidden_dims: (hidden_dim,), + } + fast_out = fast_layer(hidden_states, fast_kwargs) + + torch.testing.assert_close(fast_out, hf_out, atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + pytest.main([__file__]) From 6e2c1fe90406f1f6a64ceae1f1ed87d4201151dc Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 26 Nov 2025 20:13:48 +0000 Subject: [PATCH 21/21] varlen test --- tests/test_varlen.py | 57 +++++++++++++------------------------------- 1 file changed, 17 insertions(+), 40 deletions(-) diff --git a/tests/test_varlen.py b/tests/test_varlen.py index 0b5a6cac..126a3e1e 100644 --- a/tests/test_varlen.py +++ b/tests/test_varlen.py @@ -1,5 +1,3 @@ -import itertools - import pytest import torch @@ -84,7 +82,7 @@ def pack(hidden_states, cu_seqlens, batch_size): return packed_hidden_states -def generate_random_cu_seqlens(seq_len, packages_num=2): +def generate_random_seq_len(seq_len, packages_num=2): if packages_num < 1: raise ValueError("packages_num must be at least 1") @@ -92,24 +90,9 @@ def generate_random_cu_seqlens(seq_len, packages_num=2): base, rem = divmod(seq_len, packages_num) # lengths: e.g. 
for seq_len=10, packages=3 → [4,3,3] lengths = [base + 1 if i < rem else base for i in range(packages_num)] - - # split points exclude the final cumulative (seq_len) - split_points = list(itertools.accumulate(lengths))[:-1] - - cu_seqlens = [0] + split_points + [seq_len] - # cu_seqlens = split_points # + [seq_len] - - # index: for each chunk, we emit 0,1,...,length-1 - index = [] - for length in lengths: - index.extend(range(length)) - - # sanity check - assert len(cu_seqlens) - 1 == packages_num assert sum(lengths) == seq_len - assert len(index) == seq_len - - return cu_seqlens, index + assert len(lengths) == packages_num + return lengths def _materialize_mixer_tensors(module: torch.nn.Module, distributed: Distributed, device: torch.device) -> None: @@ -149,7 +132,6 @@ def _param_grad(param: torch.nn.Parameter) -> torch.Tensor | None: [ pytest.param(GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), False), pytest.param(GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), True), - # pytest.param(KimiDeltaAttentionConfig) ], ) def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: bool, distributed_config, distributed): @@ -174,34 +156,23 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: seq_len = 15 packages_num = torch.tensor([2, 3], device=device, dtype=torch.long) sequence_lengths = [ - torch.tensor( - generate_random_cu_seqlens(seq_len, packages_num=packages_num[i].item())[0], - device=device, - dtype=torch.long, - ).diff() - for i in range(batch_size) + generate_random_seq_len(seq_len, packages_num=packages_num[i].item()) for i in range(batch_size) ] - seqlens = torch.cat(sequence_lengths) - cu_seqlen = torch.cat( - ( - torch.zeros(1, dtype=torch.long, device=device), - torch.cumsum(seqlens, dim=0, dtype=torch.long).to(device), - ) - ) packed = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=True) if sequence_first: packed = packed.transpose(0, 1) kwargs_packed = { + BlockKwargs.device: device, BlockKwargs.sequence_lengths: sequence_lengths, BlockKwargs.sequence_first: sequence_first, BlockKwargs.hidden_dims: (hidden_dim,), } - mixer_packed.preprocess(packed, kwargs_packed) - assert torch.all(kwargs_packed["cu_seqlens"] == cu_seqlen) + mixer_packed.preprocess(kwargs_packed) kwargs_ref = { + BlockKwargs.device: device, BlockKwargs.sequence_first: False, BlockKwargs.hidden_dims: (hidden_dim,), } @@ -215,13 +186,13 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: out_batch = [] length = sequence_lengths[b] if sequence_first: - ref_seqs = torch.split(packed[:, b].unsqueeze(0), length.tolist(), dim=1) + ref_seqs = torch.split(packed[:, b].unsqueeze(0), length, dim=1) else: - ref_seqs = torch.split(packed[b].unsqueeze(0), length.tolist(), dim=1) + ref_seqs = torch.split(packed[b].unsqueeze(0), length, dim=1) for seq in ref_seqs: kwargs_ref_seq = { **kwargs_ref, - BlockKwargs.sequence_q_dim: TensorDim("sequence_q", seq.shape[1]), + BlockKwargs.sequence_lengths: [seq.shape[1]], } out_batch.append(mixer_ref(seq, kwargs_ref_seq)) ref_outs.append(torch.cat(out_batch, dim=1)) @@ -236,7 +207,13 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig, sequence_first: for (name, param), (_, param_ref) in zip(mixer_packed.named_parameters(), mixer_ref.named_parameters()): if param.requires_grad: - assert torch.allclose(_param_grad(param), _param_grad(param_ref), atol=1e-3, rtol=1e-3) + 
torch.testing.assert_close( + _param_grad(param), + _param_grad(param_ref), + atol=1e-3, + rtol=1e-3, + msg=f"Grad mismatch for parameter {name}", + ) if __name__ == "__main__":
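For reference, when the fla kernel rms_norm_gated is not installed, the gated RMS normalization exercised by these tests falls back to a plain PyTorch path; conceptually it reduces to the following standalone sketch (an illustration with assumed tensor shapes, not the exact code in the patch):

import torch
import torch.nn.functional as F

def gated_rms_norm_reference(input_: torch.Tensor, gate: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # RMS-normalize the last dimension, then gate the result with SiLU, mirroring the non-fused fallback.
    normalized = torch.rms_norm(input_, (input_.shape[-1],), weight, eps)
    return normalized * F.silu(gate)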