Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ class CompilationConfig:
"vllm::short_conv",
"vllm::linear_attention",
"vllm::plamo2_mamba_mixer",
"vllm::gdn_attention_core",
"vllm::qwen_gdn_attention_core",
"vllm::gdn_attention_core_xpu",
"vllm::olmo_hybrid_gdn_full_forward",
"vllm::kda_attention",
Expand Down
Empty file.
58 changes: 58 additions & 0 deletions vllm/model_executor/layers/mamba/gdn/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from transformers import PretrainedConfig

from vllm.config import (
VllmConfig,
)
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
)
from vllm.model_executor.models.utils import extract_layer_index
from vllm.v1.attention.backends.registry import MambaAttentionBackendEnum


class GatedDeltaNetAttention(PluggableLayer, MambaBase):
"""Base class for GatedDeltaNet attention layer."""

def __init__(
self,
config: PretrainedConfig,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
self.prefix = prefix
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.layer_idx = extract_layer_index(prefix)
self.hidden_size = config.hidden_size
self.activation = config.hidden_act
self.layer_norm_epsilon = config.rms_norm_eps
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.quant_config = vllm_config.quant_config
self.speculative_config = vllm_config.speculative_config
self.num_spec = (
self.speculative_config.num_speculative_tokens
if self.speculative_config
else 0
)

@property
def mamba_type(self) -> MambaAttentionBackendEnum:
return MambaAttentionBackendEnum.GDN_ATTN

def get_state_dtype(self) -> tuple[torch.dtype, ...]:
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
self.cache_config.mamba_ssm_cache_dtype,
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,37 @@
from einops import rearrange
from torch import nn

from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.mamba.gdn.base import GatedDeltaNetAttention
from vllm.model_executor.model_loader.weight_utils import sharded_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm.v1.attention.backends.registry import MambaAttentionBackendEnum

from .fla.ops.kda import (
from ...fla.ops.kda import (
FusedRMSNormGated,
chunk_kda,
fused_kda_gate,
fused_recurrent_kda,
)
from .linear import (
from ...linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from .mamba.abstract import MambaBase
from .mamba.mamba_utils import (
from ..mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
is_conv_state_dim_first,
)
from .mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from .quantization.base_config import QuantizationConfig
from ..ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update

logger = init_logger(__name__)

Expand Down Expand Up @@ -83,11 +81,8 @@ def kda_attention_fake(
)


class KimiDeltaAttention(nn.Module, MambaBase):
@property
def mamba_type(self) -> MambaAttentionBackendEnum:
return MambaAttentionBackendEnum.GDN_ATTN

@PluggableLayer.register("kimi_gated_delta_net_attention")
class KimiGatedDeltaNetAttention(GatedDeltaNetAttention):
def get_state_dtype(
self,
) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
Expand All @@ -106,28 +101,16 @@ def get_state_shape(

def __init__(
self,
layer_idx: int,
hidden_size: int,
quant_config: QuantizationConfig | None = None,
cache_config: CacheConfig | None = None,
model_config: ModelConfig | None = None,
rms_norm_eps: float = 1e-5,
config: KimiLinearConfig,
vllm_config: VllmConfig,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.hidden_size = hidden_size
self.model_config = model_config
self.cache_config = cache_config
if model_config is None:
raise ValueError("model_config must be provided")
kda_config = model_config.linear_attn_config # type: ignore[attr-defined]
super().__init__(config, vllm_config, prefix)

kda_config = config.linear_attn_config # type: ignore[attr-defined]
assert kda_config is not None, "linear_attn_config must be set"
self.head_dim = kda_config["head_dim"]
self.num_heads = kda_config["num_heads"]
self.layer_idx = layer_idx
self.prefix = prefix
assert self.num_heads % self.tp_size == 0
self.local_num_heads = divide(self.num_heads, self.tp_size)

Expand All @@ -138,37 +121,37 @@ def __init__(
self.hidden_size,
projection_size,
bias=False,
quant_config=quant_config,
quant_config=self.quant_config,
prefix=f"{prefix}.q_proj",
)
self.k_proj = ColumnParallelLinear(
self.hidden_size,
projection_size,
bias=False,
quant_config=quant_config,
quant_config=self.quant_config,
prefix=f"{prefix}.k_proj",
)
self.v_proj = ColumnParallelLinear(
self.hidden_size,
projection_size,
bias=False,
quant_config=quant_config,
quant_config=self.quant_config,
prefix=f"{prefix}.v_proj",
)

self.f_a_proj = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
quant_config=self.quant_config,
prefix=f"{prefix}.f_a_proj",
)

self.f_b_proj = ColumnParallelLinear(
self.head_dim,
projection_size,
bias=False,
quant_config=quant_config,
quant_config=self.quant_config,
prefix=f"{prefix}.f_b_proj",
)
self.dt_bias = nn.Parameter(
Expand All @@ -181,7 +164,7 @@ def __init__(
self.hidden_size,
self.num_heads,
bias=False,
quant_config=quant_config,
quant_config=self.quant_config,
prefix=f"{prefix}.b_proj",
)

Expand Down Expand Up @@ -223,24 +206,22 @@ def __init__(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
quant_config=self.quant_config,
prefix=f"{prefix}.g_a_proj",
)
self.g_b_proj = ColumnParallelLinear(
self.head_dim,
projection_size,
bias=False,
quant_config=quant_config,
quant_config=self.quant_config,
prefix=f"{prefix}.g_b_proj",
)
self.o_norm = FusedRMSNormGated(
self.head_dim, eps=rms_norm_eps, activation="sigmoid"
)
self.o_norm = FusedRMSNormGated(self.head_dim, activation="sigmoid")
self.o_proj = RowParallelLinear(
projection_size,
self.hidden_size,
bias=False,
quant_config=quant_config,
quant_config=self.quant_config,
prefix=f"{prefix}.o_proj",
)

Expand Down
Loading
Loading