From 44eb3f9e51bc6ee15c614e2ca49f3037780ca94a Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Fri, 22 May 2026 17:38:18 +0800 Subject: [PATCH] [Attention] Mamba attention module refactor - LINEAR Signed-off-by: wangxiyuan --- .../test_attention_backends_selection.py | 19 +- .../layers/mamba/linear/__init__.py | 0 .../mamba/linear/bailing_linear_attn.py | 384 +++++++++++++++ .../layers/mamba/linear/base.py | 66 +++ .../minimax_linear_attn.py} | 86 +--- .../models/bailing_moe_linear.py | 452 +----------------- vllm/model_executor/models/minimax_text_01.py | 49 +- 7 files changed, 505 insertions(+), 551 deletions(-) create mode 100644 vllm/model_executor/layers/mamba/linear/__init__.py create mode 100644 vllm/model_executor/layers/mamba/linear/bailing_linear_attn.py create mode 100644 vllm/model_executor/layers/mamba/linear/base.py rename vllm/model_executor/layers/mamba/{linear_attn.py => linear/minimax_linear_attn.py} (81%) diff --git a/tests/v1/attention/test_attention_backends_selection.py b/tests/v1/attention/test_attention_backends_selection.py index 4242cc5ff2e2..e3d2e9dc457d 100644 --- a/tests/v1/attention/test_attention_backends_selection.py +++ b/tests/v1/attention/test_attention_backends_selection.py @@ -54,15 +54,14 @@ ( MiniMaxText01LinearAttention, dict( - hidden_size=128, - hidden_inner_size=256, - num_heads=8, - head_dim=32, - max_position=2048, - block_size=64, - num_hidden_layer=12, - layer_idx=0, - linear_layer_idx=0, + config=SimpleNamespace( + hidden_size=256, + num_attention_heads=8, + head_dim=32, + num_hidden_layers=12, + block=64, + ), + prefix="layers.0.self_attn", ), LinearAttentionBackend, MambaAttentionBackendEnum.LINEAR, @@ -88,6 +87,8 @@ def test_mamba_layers_get_attn_backend( expected_mamba_type, ): """Test that Mamba-like layers return the correct attention backend.""" + if layer_class is MiniMaxText01LinearAttention: + init_kwargs["vllm_config"] = default_vllm_config layer = layer_class(**init_kwargs) backend_class = layer.get_attn_backend() diff --git a/vllm/model_executor/layers/mamba/linear/__init__.py b/vllm/model_executor/layers/mamba/linear/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/model_executor/layers/mamba/linear/bailing_linear_attn.py b/vllm/model_executor/layers/mamba/linear/bailing_linear_attn.py new file mode 100644 index 000000000000..dd963f829d81 --- /dev/null +++ b/vllm/model_executor/layers/mamba/linear/bailing_linear_attn.py @@ -0,0 +1,384 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy + +import torch +import torch.nn.functional as F +from transformers.configuration_utils import PretrainedConfig + +from vllm.config import ( + VllmConfig, + get_current_vllm_config, +) +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.forward_context import get_forward_context +from vllm.model_executor.custom_op import PluggableLayer +from vllm.model_executor.layers.fla.ops.layernorm_guard import ( + RMSNormGated, + layernorm_fn, +) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.mamba.linear.base import LinearAttention +from vllm.model_executor.layers.mamba.linear.minimax_linear_attn import ( + MiniMaxText01LinearAttention, + MiniMaxText01LinearKernel, + clear_linear_attention_cache_for_new_sequences, + linear_attention_decode, + linear_attention_prefill_and_mix, +) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata + + +def _build_rope_parameters(config: PretrainedConfig) -> dict | None: + rope_parameters = copy.deepcopy(getattr(config, "rope_parameters", None)) or {} + if "rope_theta" not in rope_parameters and hasattr(config, "rope_theta"): + rope_parameters["rope_theta"] = config.rope_theta + if "partial_rotary_factor" not in rope_parameters and hasattr( + config, "partial_rotary_factor" + ): + rope_parameters["partial_rotary_factor"] = config.partial_rotary_factor + + rope_scaling = getattr(config, "rope_scaling", None) + if isinstance(rope_scaling, dict): + rope_scaling = copy.deepcopy(rope_scaling) + if "type" in rope_scaling and "rope_type" not in rope_scaling: + rope_scaling["rope_type"] = rope_scaling.pop("type") + rope_parameters.update(rope_scaling) + + return rope_parameters or None + + +class BailingGroupRMSNormGate(RMSNormGated): + def __init__( + self, + hidden_size, + eps=1e-5, + group_size=None, + norm_before_gate=True, + device=None, + dtype=None, + ): + super().__init__( + hidden_size, + eps=eps, + group_size=group_size, + norm_before_gate=norm_before_gate, + device=device, + dtype=dtype, + activation="sigmoid", + ) + # Add custom weight loader for TP sharding + self.weight.weight_loader = self._weight_loader + + @staticmethod + def _weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None: + """Load weight with TP sharding.""" + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + shard_size = loaded_weight.shape[0] // tp_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + param.data.copy_(loaded_weight[shard].contiguous()) + + +# --8<-- [start:bailing_moe_linear_attention] +@PluggableLayer.register("bailing_moe_linear_attention") +class BailingMoELinearAttention(LinearAttention): + """Pluggable Bailing MoE Linear Attention layer which allows OOT backends + to add custom implementations. + + This implements the linear attention mechanism from sglang, adapted for + vLLM's v1 engine with MambaBase interface support. + """ + + # --8<-- [end:bailing_moe_linear_attention] + def __init__( + self, + config: PretrainedConfig, + vllm_config: VllmConfig, + prefix: str = "linear_attn", + ): + super().__init__(config, vllm_config, prefix) + + self.scaling = self.head_dim**-0.5 + + self.tp_heads = self.num_heads // self.tp_size + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = getattr(config, "rope_theta", 600000) + + self.tp_kv_heads = self.num_heads // self.tp_size + self.q_size_per_rank = self.head_dim * self.tp_heads + self.kv_size_per_rank = self.head_dim * self.tp_kv_heads + + self.use_qk_norm = getattr(config, "use_qk_norm", False) + self.linear_backend = "minimax" + self.linear_scale = self.linear_backend == "minimax" + self.linear_rope = getattr(config, "linear_rope", True) + if hasattr(config, "use_linear_silu"): + self.linear_silu = config.use_linear_silu + elif hasattr(config, "linear_silu"): + self.linear_silu = config.linear_silu + else: + self.linear_silu = False + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_heads, # MHA: kv_heads = num_heads + bias=(config.use_bias or config.use_qkv_bias), + quant_config=self.quant_config, + prefix=f"{prefix}.query_key_value", + ) + + if self.use_qk_norm: + self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + self.g_proj = ColumnParallelLinear( + self.hidden_size, + self.hidden_inner_size, + bias=False, + quant_config=self.quant_config, + prefix=f"{prefix}.g_proj", + ) + self.dense = RowParallelLinear( + self.hidden_inner_size, + self.hidden_size, + bias=config.use_bias, + quant_config=self.quant_config, + prefix=f"{prefix}.dense", + reduce_results=True, + ) + + self.group_norm_size = getattr(config, "group_norm_size", 1) + self.rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-5)) + assert self.tp_size <= self.group_norm_size, ( + "tp_size must be <= group_norm_size for local rms norm" + ) + assert self.group_norm_size % self.tp_size == 0, ( + "group_norm_size must be divisible by tp_size" + ) + + # When group_norm_size == 1, group_size equals hidden_size // tp_size + self.g_norm = BailingGroupRMSNormGate( + hidden_size=self.hidden_inner_size // self.tp_size, + eps=self.rms_norm_eps, + group_size=( + self.hidden_inner_size // self.group_norm_size + if self.group_norm_size > 1 + else self.hidden_inner_size // self.tp_size + ), + ) + + # use fp32 rotary embedding + rope_parameters = _build_rope_parameters(config) + + self.rotary_emb = get_rope( + self.head_dim, + max_position=self.max_position_embeddings, + is_neox_style=True, + rope_parameters=rope_parameters or None, + ) + + # Build slope tensor for linear attention decay + slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(self.num_heads) + if self.num_hidden_layers <= 1: + self.slope_rate = slope_rate * (1 + 1e-5) + else: + self.slope_rate = slope_rate * ( + 1 - self.layer_idx / (self.num_hidden_layers - 1) + 1e-5 + ) + self.tp_slope = self.slope_rate[ + self.tp_rank * self.tp_heads : (self.tp_rank + 1) * self.tp_heads + ].contiguous() + + # Register for compilation + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + @staticmethod + def weight_direct_load(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Load weight for linear attention layers. + + For FP8 quantized parameters, we need to use the weight_loader if available, + as it handles special cases like tensor parallelism sharding. + """ + # Check if param has a weight_loader (for vLLM ModelWeightParameter) + weight_loader = getattr(param, "weight_loader", None) + if weight_loader is not None: + # Use the weight_loader which handles TP sharding and quantization + weight_loader(param, loaded_weight) + else: + # Fall back to direct copy for standard tensors + assert param.size() == loaded_weight.size(), ( + f"Shape mismatch: {param.shape} vs {loaded_weight.shape}" + ) + param.data.copy_(loaded_weight) + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + ) -> None: + """Forward method called by torch.ops.vllm.linear_attention""" + torch.ops.vllm.linear_attention( + hidden_states, + output, + positions, + self.prefix, + ) + + def _forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + ) -> None: + """Actual forward implementation.""" + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] # type: ignore + assert isinstance(attn_metadata, LinearAttentionMetadata) + num_actual_tokens = ( + attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens + ) + else: + num_actual_tokens = hidden_states.shape[0] + + # QKV projection + qkv, _ = self.query_key_value(hidden_states[:num_actual_tokens]) + + # use rotary_emb support fp32 + qkv = qkv.to(torch.float32) + if self.linear_silu: + qkv = F.silu(qkv) + + # Split q, k, v + q, k, v = torch.split( + qkv, + [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank], + dim=-1, + ) + + # Apply QK norm if needed + if self.use_qk_norm: + q = q.reshape(-1, self.tp_heads, self.head_dim) + k = k.reshape(-1, self.tp_kv_heads, self.head_dim) + q = layernorm_fn( + q, + self.query_layernorm.weight.data, + bias=None, + eps=self.rms_norm_eps, + is_rms_norm=True, + ) + k = layernorm_fn( + k, + self.key_layernorm.weight.data, + bias=None, + eps=self.rms_norm_eps, + is_rms_norm=True, + ) + q = q.reshape(-1, self.q_size_per_rank) + k = k.reshape(-1, self.kv_size_per_rank) + + # Apply rotary embeddings + if self.linear_rope: + q, k = self.rotary_emb(positions[:num_actual_tokens], q, k) + + # Reshape to [batch, heads, seq_len, head_dim] + q = q.view((qkv.shape[0], self.tp_heads, self.head_dim)) + k = k.view((qkv.shape[0], self.tp_kv_heads, self.head_dim)) + v = v.view((qkv.shape[0], self.tp_kv_heads, self.head_dim)) + + # Apply scaling if using minimax backend + if self.linear_scale: + q = q * self.scaling + + # Get KV cache and state indices + if attn_metadata is not None: + kv_cache = self.kv_cache[0] + state_indices_tensor = attn_metadata.state_indices_tensor + clear_linear_attention_cache_for_new_sequences( + kv_cache, state_indices_tensor, attn_metadata + ) + + # Compute attention + decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 + if attn_metadata is None: + hidden = torch.empty( + (q.shape[0], q.shape[1] * q.shape[2]), device=q.device, dtype=q.dtype + ) + else: + if not decode_only: + hidden = self._prefill_and_mix_infer( + q, k, v, kv_cache, state_indices_tensor, attn_metadata + ) + else: + hidden = self._decode_infer( + q, k, v, kv_cache, state_indices_tensor, attn_metadata + ) + + # Apply group norm and gate (matching SGLang behavior) + gate, _ = self.g_proj(hidden_states[:num_actual_tokens]) + + if self.group_norm_size > 1: + hidden = self.g_norm(hidden, gate) + else: + hidden = self.g_norm(hidden) + hidden = F.sigmoid(gate) * hidden + + hidden = hidden.to(hidden_states.dtype) + + # Output projection + dense_out, _ = self.dense(hidden) + output[:num_actual_tokens] = dense_out + + def _prefill_and_mix_infer( + self, q, k, v, kv_cache, state_indices_tensor, attn_metadata + ): + """Handle prefill (mixed with decode if any).""" + return linear_attention_prefill_and_mix( + q=q, + k=k, + v=v, + kv_cache=kv_cache, + state_indices_tensor=state_indices_tensor, + attn_metadata=attn_metadata, + slope_rate=self.tp_slope, + block_size=self.BLOCK, + decode_fn=self._decode_infer, + prefix_fn=MiniMaxText01LinearKernel.jit_linear_forward_prefix, + layer_idx=self.layer_idx, + ) + + def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata): + """Handle decode (single token per sequence).""" + hidden = linear_attention_decode( + q, + k, + v, + kv_cache, + self.tp_slope, + state_indices_tensor, + q_start=0, + q_end=attn_metadata.num_decode_tokens, + slot_start=0, + slot_end=attn_metadata.num_decodes, + block_size=32, + ) + return hidden diff --git a/vllm/model_executor/layers/mamba/linear/base.py b/vllm/model_executor/layers/mamba/linear/base.py new file mode 100644 index 000000000000..73df07187303 --- /dev/null +++ b/vllm/model_executor/layers/mamba/linear/base.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from transformers import PretrainedConfig + +from vllm.config import ( + VllmConfig, +) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.custom_op import PluggableLayer +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) +from vllm.model_executor.models.utils import extract_layer_index +from vllm.v1.attention.backends.registry import MambaAttentionBackendEnum + + +class LinearAttention(PluggableLayer, MambaBase): + """Base class for Linear attention layer.""" + + def __init__( + self, config: PretrainedConfig, vllm_config: VllmConfig, prefix: str = "" + ): + super().__init__() + self.layer_idx = extract_layer_index(prefix) + self.prefix = prefix + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.quant_config = vllm_config.quant_config + + self.BLOCK = getattr(config, "block", 256) + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_hidden_layers = config.num_hidden_layers + self.head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // self.num_heads + ) + self.hidden_inner_size = self.head_dim * self.num_heads + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + assert self.num_heads % self.tp_size == 0 + + @property + def mamba_type(self) -> MambaAttentionBackendEnum: + return MambaAttentionBackendEnum.LINEAR + + def get_state_dtype(self) -> tuple[torch.dtype]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.linear_attention_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + ) + + def get_state_shape(self) -> tuple[tuple[int, int, int], ...]: + return MambaStateShapeCalculator.linear_attention_state_shape( + num_heads=self.num_heads, tp_size=self.tp_size, head_dim=self.head_dim + ) diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear/minimax_linear_attn.py similarity index 81% rename from vllm/model_executor/layers/mamba/linear_attn.py rename to vllm/model_executor/layers/mamba/linear/minimax_linear_attn.py index 5724e037c661..14c7d3d5f04c 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear/minimax_linear_attn.py @@ -7,30 +7,20 @@ import torch import torch.nn.functional as F from einops import rearrange -from torch import nn -from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config -from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) +from vllm.config import get_current_vllm_config from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.lightning_attn import ( lightning_attention, linear_decode_forward_triton, ) from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear -from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, - MambaStateShapeCalculator, -) +from vllm.model_executor.layers.mamba.linear.base import LinearAttention from vllm.model_executor.layers.minimax_rms_norm import MiniMaxText01RMSNormTP -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata -from vllm.v1.attention.backends.registry import MambaAttentionBackendEnum def clear_linear_attention_cache_for_new_sequences( @@ -157,79 +147,39 @@ def jit_linear_forward_prefix( return rearrange(output.squeeze(0), "h n d -> n (h d)") -class MiniMaxText01LinearAttention(nn.Module, MambaBase): - @property - def mamba_type(self) -> MambaAttentionBackendEnum: - return MambaAttentionBackendEnum.LINEAR - - def get_state_dtype(self) -> tuple[torch.dtype]: - assert self.model_config is not None - assert self.cache_config is not None - return MambaStateDtypeCalculator.linear_attention_state_dtype( - self.model_config.dtype, - self.cache_config.mamba_cache_dtype, - ) - - def get_state_shape(self) -> tuple[tuple[int, int, int], ...]: - return MambaStateShapeCalculator.linear_attention_state_shape( - num_heads=self.num_heads, tp_size=self.tp_size, head_dim=self.head_dim - ) - +@PluggableLayer.register("minimax_text_01_attention") +class MiniMaxText01LinearAttention(LinearAttention): def __init__( self, - hidden_size: int, - hidden_inner_size: int, - num_heads: int, - head_dim: int, - max_position: int, - block_size: int, - num_hidden_layer: int, - model_config: ModelConfig | None = None, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - layer_idx: int = 0, - linear_layer_idx: int = 0, + config, + vllm_config, prefix: str = "linear_attn", ) -> None: - super().__init__() - - self.layer_idx = layer_idx - self.BLOCK = block_size - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = head_dim - self.total_num_heads = num_heads - self.hidden_inner_size = hidden_inner_size - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - assert self.total_num_heads % self.tp_size == 0 - self.tp_heads = self.total_num_heads // self.tp_size + super().__init__(config, vllm_config, prefix) + + self.tp_heads = self.num_heads // self.tp_size self.qkv_size = self.num_heads * self.head_dim self.tp_hidden = self.head_dim * self.tp_heads - self.model_config = model_config - self.cache_config = cache_config - self.prefix = prefix self.qkv_proj = ColumnParallelLinear( - hidden_size, + self.hidden_size, self.hidden_inner_size * 3, bias=False, - quant_config=quant_config, + quant_config=self.quant_config, prefix=f"{prefix}.qkv_proj", ) self.output_gate = ColumnParallelLinear( - hidden_size, + self.hidden_size, self.hidden_inner_size, bias=False, - quant_config=quant_config, + quant_config=self.quant_config, prefix=f"{prefix}.output_gate", ) self.out_proj = RowParallelLinear( self.hidden_inner_size, - hidden_size, + self.hidden_size, bias=False, - quant_config=quant_config, + quant_config=self.quant_config, prefix=f"{prefix}.out_proj", ) self.norm = MiniMaxText01RMSNormTP( @@ -238,11 +188,11 @@ def __init__( ) slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(self.num_heads) - if num_hidden_layer <= 1: + if self.num_hidden_layers <= 1: self.slope_rate = slope_rate * (1 + 1e-5) else: self.slope_rate = slope_rate * ( - 1 - layer_idx / (num_hidden_layer - 1) + 1e-5 + 1 - self.layer_idx / (self.num_hidden_layers - 1) + 1e-5 ) self.tp_slope = self.slope_rate[ self.tp_rank * self.tp_heads : (self.tp_rank + 1) * self.tp_heads diff --git a/vllm/model_executor/models/bailing_moe_linear.py b/vllm/model_executor/models/bailing_moe_linear.py index c66ae9102701..3857e993c7c7 100644 --- a/vllm/model_executor/models/bailing_moe_linear.py +++ b/vllm/model_executor/models/bailing_moe_linear.py @@ -9,7 +9,7 @@ from transformers.configuration_utils import PretrainedConfig from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import ( get_pp_group, get_tensor_model_parallel_rank, @@ -17,11 +17,6 @@ ) from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.model_executor.custom_op import PluggableLayer -from vllm.model_executor.layers.fla.ops.layernorm_guard import ( - RMSNormGated, - layernorm_fn, -) from vllm.model_executor.layers.fused_moe import ( FusedMoE, fused_moe_make_expert_params_mapping, @@ -30,25 +25,19 @@ from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, - QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.linear_attn import ( - MiniMaxText01LinearAttention, - MiniMaxText01LinearKernel, - clear_linear_attention_cache_for_new_sequences, - linear_attention_decode, - linear_attention_prefill_and_mix, +from vllm.model_executor.layers.mamba.linear.bailing_linear_attn import ( + BailingMoELinearAttention, + _build_rope_parameters, ) from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateCopyFuncCalculator, MambaStateDtypeCalculator, MambaStateShapeCalculator, ) -from vllm.model_executor.layers.minimax_rms_norm import MiniMaxText01RMSNormTP from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope @@ -63,8 +52,6 @@ from vllm.model_executor.models.bailing_moe import BailingMLP from vllm.sequence import IntermediateTensors from vllm.v1.attention.backend import AttentionMetadata -from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata -from vllm.v1.attention.backends.registry import MambaAttentionBackendEnum from .interfaces import HasInnerState, IsHybrid, SupportsPP from .utils import ( @@ -87,25 +74,6 @@ def is_linear_layer(layer_idx, layer_group_size): return False -def _build_rope_parameters(config: PretrainedConfig) -> dict | None: - rope_parameters = copy.deepcopy(getattr(config, "rope_parameters", None)) or {} - if "rope_theta" not in rope_parameters and hasattr(config, "rope_theta"): - rope_parameters["rope_theta"] = config.rope_theta - if "partial_rotary_factor" not in rope_parameters and hasattr( - config, "partial_rotary_factor" - ): - rope_parameters["partial_rotary_factor"] = config.partial_rotary_factor - - rope_scaling = getattr(config, "rope_scaling", None) - if isinstance(rope_scaling, dict): - rope_scaling = copy.deepcopy(rope_scaling) - if "type" in rope_scaling and "rope_type" not in rope_scaling: - rope_scaling["rope_type"] = rope_scaling.pop("type") - rope_parameters.update(rope_scaling) - - return rope_parameters or None - - class BailingMoeV25MLAAttention(nn.Module): """ MLA Attention for BailingMoeV2.5 full attention layers. @@ -397,400 +365,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states.view(num_tokens, hidden_size) -BailingRMSNormTP = MiniMaxText01RMSNormTP - - -class BailingGroupRMSNormGate(RMSNormGated): - def __init__( - self, - hidden_size, - eps=1e-5, - group_size=None, - norm_before_gate=True, - device=None, - dtype=None, - ): - super().__init__( - hidden_size, - eps=eps, - group_size=group_size, - norm_before_gate=norm_before_gate, - device=device, - dtype=dtype, - activation="sigmoid", - ) - # Add custom weight loader for TP sharding - self.weight.weight_loader = self._weight_loader - - @staticmethod - def _weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None: - """Load weight with TP sharding.""" - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - shard_size = loaded_weight.shape[0] // tp_size - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - param.data.copy_(loaded_weight[shard].contiguous()) - - -# --8<-- [start:bailing_moe_linear_attention] -@PluggableLayer.register("bailing_moe_linear_attention") -class BailingMoELinearAttention(PluggableLayer, MambaBase): - """Pluggable Bailing MoE Linear Attention layer which allows OOT backends - to add custom implementations. - - This implements the linear attention mechanism from sglang, adapted for - vLLM's v1 engine with MambaBase interface support. - """ - - # --8<-- [end:bailing_moe_linear_attention] - - @property - def mamba_type(self) -> MambaAttentionBackendEnum: - return MambaAttentionBackendEnum.LINEAR - - def get_state_shape(self) -> tuple[tuple[int, ...], ...]: - """Return state shape for linear attention cache. - - Must match the calculation in get_mamba_state_shape_from_config. - """ - return MambaStateShapeCalculator.linear_attention_state_shape( - num_heads=self.total_num_heads, - tp_size=self.tp_size, - head_dim=self.head_dim, - ) - - def get_state_dtype(self) -> tuple[torch.dtype, ...]: - """Return state dtype for linear attention cache. - - Must match the calculation in get_mamba_state_dtype_from_config. - """ - return MambaStateDtypeCalculator.linear_attention_state_dtype( - self.model_config.dtype, - self.cache_config.mamba_cache_dtype, - ) - - def __init__( - self, - config: PretrainedConfig, - quant_config: QuantizationConfig | None = None, - layer_id: int = 0, - prefix: str = "linear_attn", - model_config: ModelConfig | None = None, - cache_config: CacheConfig | None = None, - ): - super().__init__() - - self.layer_id = layer_id - self.hidden_size = config.hidden_size - self.total_num_heads = config.num_attention_heads - self.total_kv_heads = config.num_attention_heads # MHA - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.model_config = model_config - self.cache_config = cache_config - self.prefix = prefix - - self.head_dim = ( - config.head_dim - if hasattr(config, "head_dim") - else config.hidden_size // self.total_num_heads - ) - - self.hidden_inner_size = self.head_dim * self.total_num_heads - self.scaling = self.head_dim**-0.5 - - assert self.total_num_heads % self.tp_size == 0 - self.tp_heads = self.total_num_heads // self.tp_size - - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = getattr(config, "rope_theta", 600000) - - self.tp_kv_heads = self.total_kv_heads // self.tp_size - self.q_size_per_rank = self.head_dim * self.tp_heads - self.kv_size_per_rank = self.head_dim * self.tp_kv_heads - - self.use_qk_norm = getattr(config, "use_qk_norm", False) - self.linear_backend = "minimax" - self.linear_scale = self.linear_backend == "minimax" - self.linear_rope = getattr(config, "linear_rope", True) - if hasattr(config, "use_linear_silu"): - self.linear_silu = config.use_linear_silu - elif hasattr(config, "linear_silu"): - self.linear_silu = config.linear_silu - else: - self.linear_silu = False - - # Block size for lightning attention - self.BLOCK = getattr(config, "block", 256) - - self.query_key_value = QKVParallelLinear( - self.hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_heads, # MHA: kv_heads = num_heads - bias=(config.use_bias or config.use_qkv_bias), - quant_config=quant_config, - prefix=f"{prefix}.query_key_value", - ) - - if self.use_qk_norm: - self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) - - self.g_proj = ColumnParallelLinear( - self.hidden_size, - self.hidden_inner_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.g_proj", - ) - self.dense = RowParallelLinear( - self.hidden_inner_size, - self.hidden_size, - bias=config.use_bias, - quant_config=quant_config, - prefix=f"{prefix}.dense", - reduce_results=True, - ) - - self.group_norm_size = getattr(config, "group_norm_size", 1) - self.rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-5)) - assert self.tp_size <= self.group_norm_size, ( - "tp_size must be <= group_norm_size for local rms norm" - ) - assert self.group_norm_size % self.tp_size == 0, ( - "group_norm_size must be divisible by tp_size" - ) - - # When group_norm_size == 1, group_size equals hidden_size // tp_size - self.g_norm = BailingGroupRMSNormGate( - hidden_size=self.hidden_inner_size // self.tp_size, - eps=self.rms_norm_eps, - group_size=( - self.hidden_inner_size // self.group_norm_size - if self.group_norm_size > 1 - else self.hidden_inner_size // self.tp_size - ), - ) - - # use fp32 rotary embedding - rope_parameters = _build_rope_parameters(config) - - self.rotary_emb = get_rope( - self.head_dim, - max_position=self.max_position_embeddings, - is_neox_style=True, - rope_parameters=rope_parameters or None, - ) - - # Build slope tensor for linear attention decay - num_hidden_layers = config.num_hidden_layers - slope_rate = MiniMaxText01LinearAttention._build_slope_tensor( - self.total_num_heads - ) - if num_hidden_layers <= 1: - self.slope_rate = slope_rate * (1 + 1e-5) - else: - self.slope_rate = slope_rate * ( - 1 - layer_id / (num_hidden_layers - 1) + 1e-5 - ) - self.tp_slope = self.slope_rate[ - self.tp_rank * self.tp_heads : (self.tp_rank + 1) * self.tp_heads - ].contiguous() - - # Register for compilation - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - - @staticmethod - def weight_direct_load(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: - """Load weight for linear attention layers. - - For FP8 quantized parameters, we need to use the weight_loader if available, - as it handles special cases like tensor parallelism sharding. - """ - # Check if param has a weight_loader (for vLLM ModelWeightParameter) - weight_loader = getattr(param, "weight_loader", None) - if weight_loader is not None: - # Use the weight_loader which handles TP sharding and quantization - weight_loader(param, loaded_weight) - else: - # Fall back to direct copy for standard tensors - assert param.size() == loaded_weight.size(), ( - f"Shape mismatch: {param.shape} vs {loaded_weight.shape}" - ) - param.data.copy_(loaded_weight) - - def forward( - self, - hidden_states: torch.Tensor, - output: torch.Tensor, - positions: torch.Tensor, - ) -> None: - """Forward method called by torch.ops.vllm.linear_attention""" - torch.ops.vllm.linear_attention( - hidden_states, - output, - positions, - self.prefix, - ) - - def _forward( - self, - hidden_states: torch.Tensor, - output: torch.Tensor, - positions: torch.Tensor, - ) -> None: - """Actual forward implementation.""" - forward_context = get_forward_context() - attn_metadata: AttentionMetadata = forward_context.attn_metadata - if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - assert isinstance(attn_metadata, LinearAttentionMetadata) - num_actual_tokens = ( - attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens - ) - else: - num_actual_tokens = hidden_states.shape[0] - - # QKV projection - qkv, _ = self.query_key_value(hidden_states[:num_actual_tokens]) - - # use rotary_emb support fp32 - qkv = qkv.to(torch.float32) - if self.linear_silu: - qkv = F.silu(qkv) - - # Split q, k, v - q, k, v = torch.split( - qkv, - [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank], - dim=-1, - ) - - # Apply QK norm if needed - if self.use_qk_norm: - q = q.reshape(-1, self.tp_heads, self.head_dim) - k = k.reshape(-1, self.tp_kv_heads, self.head_dim) - q = layernorm_fn( - q, - self.query_layernorm.weight.data, - bias=None, - eps=self.rms_norm_eps, - is_rms_norm=True, - ) - k = layernorm_fn( - k, - self.key_layernorm.weight.data, - bias=None, - eps=self.rms_norm_eps, - is_rms_norm=True, - ) - q = q.reshape(-1, self.q_size_per_rank) - k = k.reshape(-1, self.kv_size_per_rank) - - # Apply rotary embeddings - if self.linear_rope: - q, k = self.rotary_emb(positions[:num_actual_tokens], q, k) - - # Reshape to [batch, heads, seq_len, head_dim] - q = q.view((qkv.shape[0], self.tp_heads, self.head_dim)) - k = k.view((qkv.shape[0], self.tp_kv_heads, self.head_dim)) - v = v.view((qkv.shape[0], self.tp_kv_heads, self.head_dim)) - - # Apply scaling if using minimax backend - if self.linear_scale: - q = q * self.scaling - - # Get KV cache and state indices - if attn_metadata is not None: - kv_cache = self.kv_cache[0] - state_indices_tensor = attn_metadata.state_indices_tensor - clear_linear_attention_cache_for_new_sequences( - kv_cache, state_indices_tensor, attn_metadata - ) - - # Compute attention - decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 - if attn_metadata is None: - hidden = torch.empty( - (q.shape[0], q.shape[1] * q.shape[2]), device=q.device, dtype=q.dtype - ) - else: - if not decode_only: - hidden = self._prefill_and_mix_infer( - q, k, v, kv_cache, state_indices_tensor, attn_metadata - ) - else: - hidden = self._decode_infer( - q, k, v, kv_cache, state_indices_tensor, attn_metadata - ) - - # Apply group norm and gate (matching SGLang behavior) - gate, _ = self.g_proj(hidden_states[:num_actual_tokens]) - - if self.group_norm_size > 1: - hidden = self.g_norm(hidden, gate) - else: - hidden = self.g_norm(hidden) - hidden = F.sigmoid(gate) * hidden - - hidden = hidden.to(hidden_states.dtype) - - # Output projection - dense_out, _ = self.dense(hidden) - output[:num_actual_tokens] = dense_out - - def _prefill_and_mix_infer( - self, q, k, v, kv_cache, state_indices_tensor, attn_metadata - ): - """Handle prefill (mixed with decode if any).""" - return linear_attention_prefill_and_mix( - q=q, - k=k, - v=v, - kv_cache=kv_cache, - state_indices_tensor=state_indices_tensor, - attn_metadata=attn_metadata, - slope_rate=self.tp_slope, - block_size=self.BLOCK, - decode_fn=self._decode_infer, - prefix_fn=MiniMaxText01LinearKernel.jit_linear_forward_prefix, - layer_idx=self.layer_id, - ) - - def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata): - """Handle decode (single token per sequence).""" - hidden = linear_attention_decode( - q, - k, - v, - kv_cache, - self.tp_slope, - state_indices_tensor, - q_start=0, - q_end=attn_metadata.num_decode_tokens, - slot_start=0, - slot_end=attn_metadata.num_decodes, - block_size=32, - ) - return hidden - - class BailingMoeV25DecoderLayer(nn.Module): """Decoder layer supporting both linear and full attention.""" def __init__( self, config: PretrainedConfig, - quant_config: QuantizationConfig | None = None, - layer_id: int = 0, + vllm_config: VllmConfig, prefix: str = "layer", - model_config: ModelConfig | None = None, - cache_config: CacheConfig | None = None, + layer_id: int = 0, ) -> None: super().__init__() self.layer_id = layer_id @@ -802,19 +385,16 @@ def __init__( if self.attention_type == 0: # Linear attention self.self_attn = BailingMoELinearAttention( config, - quant_config=quant_config, - layer_id=layer_id, + vllm_config, prefix=f"{prefix}.self_attn", - model_config=model_config, - cache_config=cache_config, ) else: # Full attention self.self_attn = BailingMoeV25MLAAttention( config, - quant_config=quant_config, + quant_config=vllm_config.quant_config, layer_id=layer_id, prefix=f"{prefix}.self_attn", - cache_config=cache_config, + cache_config=vllm_config.cache_config, ) # MLP/MoE @@ -825,7 +405,7 @@ def __init__( if is_moe_layer: self.mlp = BailingMoeV25( config, - quant_config=quant_config, + quant_config=vllm_config.quant_config, layer_id=layer_id, prefix=f"{prefix}.mlp", ) @@ -833,7 +413,7 @@ def __init__( self.mlp = BailingMLP( intermediate_size=config.intermediate_size, config=config, - quant_config=quant_config, + quant_config=vllm_config.quant_config, reduce_results=True, prefix=f"{prefix}.mlp", ) @@ -896,10 +476,6 @@ def __init__( ): super().__init__() config = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - quant_config = vllm_config.quant_config - cache_config = vllm_config.cache_config - self.config = config self.vocab_size = config.vocab_size self.embed_dim = config.hidden_size @@ -934,11 +510,9 @@ def layer_fn(prefix): return BailingMoeV25DecoderLayer( config=layer_config, - quant_config=quant_config, - layer_id=layer_idx, + vllm_config=vllm_config, prefix=prefix, - model_config=model_config, - cache_config=cache_config, + layer_id=layer_idx, ) self.start_layer, self.end_layer, self.layers = make_layers( diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index c73fbf7009d6..890dbe590ae9 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -15,7 +15,7 @@ from transformers import MiniMaxConfig from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed.parallel_state import ( get_pp_group, get_tensor_model_parallel_rank, @@ -35,7 +35,9 @@ RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01LinearAttention +from vllm.model_executor.layers.mamba.linear.minimax_linear_attn import ( + MiniMaxText01LinearAttention, +) from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateCopyFunc, MambaStateCopyFuncCalculator, @@ -277,9 +279,7 @@ class MiniMaxText01DecoderLayer(nn.Module): def __init__( self, config: MiniMaxConfig, - model_config: ModelConfig | None = None, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, + vllm_config: VllmConfig, expert_num: int = 1, layer_id: int = None, linear_layer_id: int | None = None, @@ -303,25 +303,9 @@ def __init__( config.max_position_embeddings, config.max_model_len ) if config.attention_type == 0: - use_headxdim = True - hidden_inner = ( - head_dim * config.num_attention_heads - if use_headxdim - else config.hidden_size - ) self.self_attn = MiniMaxText01LinearAttention( - hidden_size=self.hidden_size, - hidden_inner_size=hidden_inner, - num_heads=config.num_attention_heads, - head_dim=head_dim, - max_position=max_position_embeddings, - block_size=config.block if hasattr(config, "block") else 256, - num_hidden_layer=config.num_hidden_layers, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - layer_idx=self._ilayer, - linear_layer_idx=linear_layer_id, + config, + vllm_config, prefix=prefix, ) elif config.attention_type == 1: @@ -333,9 +317,9 @@ def __init__( max_position=max_position_embeddings, rope_parameters=config.rope_parameters, sliding_window=config.sliding_window, - quant_config=quant_config, + quant_config=vllm_config.quant_config, layer_idx=self._ilayer, - cache_config=cache_config, + cache_config=vllm_config.cache_config, prefix=prefix, ) else: @@ -348,7 +332,7 @@ def __init__( self.mlp = MiniMaxText01MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, - quant_config=quant_config, + quant_config=vllm_config.quant_config, layer_idx=self._ilayer, prefix=prefix, ) @@ -359,7 +343,7 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, layer_idx=self._ilayer, - quant_config=quant_config, + quant_config=vllm_config.quant_config, prefix=prefix, ) @@ -410,7 +394,7 @@ def __init__( self.shared_mlp = MiniMaxText01MLP( hidden_size=self.hidden_size, intermediate_size=shared_intermediate, - quant_config=quant_config, + quant_config=vllm_config.quant_config, layer_idx=self._ilayer, prefix=prefix, ) @@ -418,7 +402,7 @@ def __init__( self.hidden_size, 1, bias=False, - quant_config=quant_config, + quant_config=vllm_config.quant_config, params_dtype=torch.float32, ) self.coefficient.weight.weight_loader = self.shared_moe_coefficient_loader @@ -496,9 +480,6 @@ class MiniMaxText01Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: MiniMaxConfig = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - quant_config = vllm_config.quant_config - cache_config = vllm_config.cache_config scheduler_config = vllm_config.scheduler_config self.config = config self.CONCAT_FFN = True @@ -541,10 +522,8 @@ def layer_fn(prefix): layer_config.layer_idx = layer_idx decoder_kwargs = { - "quant_config": quant_config, "layer_id": layer_idx, - "model_config": model_config, - "cache_config": cache_config, + "vllm_config": vllm_config, } if layer_config.attention_type == 0: