From d69c929c8e0d44daba438f5350ed2190be066432 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi Date: Tue, 10 Jun 2025 19:54:26 +0900 Subject: [PATCH 01/15] Replace custom rms_norm implementation Signed-off-by: Shinichi Hemmi --- vllm/model_executor/models/plamo2.py | 38 +++++++++++++++------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 670576c68efd..b850349a66f8 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -97,20 +97,6 @@ def is_mamba(config: Plamo2Config, i: int) -> bool: return (i % config.mamba_step) != (config.mamba_step // 2) -# TODO(Shinichi): Replace this with RMSNorm. -def _rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, - eps: float) -> torch.Tensor: - input_shape = hidden_states.shape - hidden_states = hidden_states.reshape(input_shape[:-1] + weight.shape) - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + eps) - hidden_states = hidden_states.to(input_dtype) - hidden_states = weight * hidden_states - return hidden_states.reshape(input_shape) - - def _swiglu(h: torch.Tensor) -> torch.Tensor: h0, h1 = h.chunk(2, dim=-1) return torch.nn.functional.silu(h0) * h1 @@ -405,10 +391,18 @@ def __init__(self, base=self.rope_theta, rope_scaling=self.rope_scaling, ) - self.q_weight = torch.nn.Parameter( + self.q_norm = RMSNorm(config.hidden_size_per_head, + eps=config.rms_norm_eps) + self.q_norm.weight = torch.nn.Parameter( torch.ones((self.num_heads, config.hidden_size_per_head))) - self.k_weight = torch.nn.Parameter( + set_weight_attrs(self.q_norm.weight, + {"weight_loader": sharded_weight_loader(0)}) + self.k_norm = RMSNorm(config.hidden_size_per_head, + eps=config.rms_norm_eps) + self.k_norm.weight = torch.nn.Parameter( torch.ones((self.num_kv_heads, config.hidden_size_per_head))) + set_weight_attrs(self.k_norm.weight, + {"weight_loader": sharded_weight_loader(0)}) self.attn = Attention( self.num_heads, @@ -428,8 +422,14 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = _rms_norm(q, self.q_weight, 1e-6) - k = _rms_norm(k, self.k_weight, 1e-6) + + q_shape = q.shape + q = q.reshape(q_shape[:-1] + self.q_norm.weight.shape) + q = self.q_norm.forward_native(q).reshape(q_shape) + k_shape = k.shape + k = k.reshape(k_shape[:-1] + self.k_norm.weight.shape) + k = self.k_norm.forward_native(k).reshape(k_shape) + q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) @@ -703,6 +703,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): ".B_norm_weight": ".B_norm.weight", ".C_norm_weight": ".C_norm.weight", ".dt_norm_weight": ".dt_norm.weight", + ".q_weight": ".q_norm.weight", + ".k_weight": ".k_norm.weight", } # Apply replacements based on the defined mappings for old, new in replacements.items(): From a5554db19d53addb12f973cba0b2b1c4ba4bc9e4 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Sun, 15 Jun 2025 03:26:10 +0900 Subject: [PATCH 02/15] Drop get_initial_dt_bias Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index b850349a66f8..6f107c3f3f55 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only PLaMo2 model.""" -import math from collections.abc import Iterable from typing import Optional @@ -77,17 +76,6 @@ def _init_weights(self, module: torch.nn.Module) -> None: module.weight.data[module.padding_idx].zero_() -def get_initial_dt_bias(num_heads: int) -> torch.Tensor: - dt_min = 0.001 - dt_max = 0.1 - dt = torch.exp( - torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) + - math.log(dt_min)) - dt = torch.clamp(dt, 1e-4) - inv_dt = dt + torch.log(-torch.expm1(-dt)) - return inv_dt - - def is_mamba(config: Plamo2Config, i: int) -> bool: assert config.mamba_step > 1 @@ -158,7 +146,6 @@ def __init__(self, bias=False, prefix=f"{prefix}.dt_proj", ) - self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads)) tp_size = get_tensor_model_parallel_world_size() self.A = nn.Parameter( @@ -168,11 +155,14 @@ def __init__(self, dtype=torch.float32, )) self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) a_weight_loader = composed_weight_loader( sharded_weight_loader(0), lambda x: -torch.exp(x.float())) set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + set_weight_attrs(self.dt_bias, + {"weight_loader": sharded_weight_loader(0)}) self.out_proj = RowParallelLinear( self.intermediate_size, From 2cb34771dca04d4ed273516317b16e1566f35a9a Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Sun, 15 Jun 2025 03:29:34 +0900 Subject: [PATCH 03/15] Use vLLM native SiluAndMul Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 6f107c3f3f55..207bb78f33a0 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -13,6 +13,7 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -85,11 +86,6 @@ def is_mamba(config: Plamo2Config, i: int) -> bool: return (i % config.mamba_step) != (config.mamba_step // 2) -def _swiglu(h: torch.Tensor) -> torch.Tensor: - h0, h1 = h.chunk(2, dim=-1) - return torch.nn.functional.silu(h0) * h1 - - # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer class Plamo2MambaMixer(nn.Module): # TODO(Shinichi): Rebase on Mamba2 implementation. @@ -312,6 +308,7 @@ def __init__( bias=False, prefix=f"{prefix}.gate_up_proj", quant_config=quant_config) + self.act = SiluAndMul() self.down_proj = RowParallelLinear(self.intermediate_size, self.hidden_size, bias=False, @@ -320,7 +317,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: h = self.gate_up_proj(hidden_states)[0] - h = _swiglu(h) + h = self.act(h) output, _ = self.down_proj(h) return output # type: ignore From 1caf6901eaa73b3b2928d12697f47bc7da06ec50 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Sun, 15 Jun 2025 03:43:13 +0900 Subject: [PATCH 04/15] Propagate configs through VLLMConfig Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 65 ++++++++++++++-------------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 207bb78f33a0..eec1e2d6379a 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -10,7 +10,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import CacheConfig, VllmConfig +from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul @@ -91,21 +91,21 @@ class Plamo2MambaMixer(nn.Module): # TODO(Shinichi): Rebase on Mamba2 implementation. def __init__(self, - config: Plamo2Config, - cache_config: CacheConfig, - quant_config: QuantizationConfig, - max_model_len: int, + vllm_config: VllmConfig, + *, prefix: str = "", **kwargs) -> None: super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.ssm_state_size = config.mamba_d_state - self.conv_kernel_size = config.mamba_d_conv - self.intermediate_size = (config.mamba_num_heads * - config.hidden_size_per_head) - self.hidden_size_per_head = config.hidden_size_per_head - self.num_heads = config.mamba_num_heads + self.config = vllm_config.model_config.hf_config + self.quant_config = vllm_config.quant_config + self.hidden_size = self.config.hidden_size + self.ssm_state_size = self.config.mamba_d_state + self.conv_kernel_size = self.config.mamba_d_conv + self.intermediate_size = (self.config.mamba_num_heads * + self.config.hidden_size_per_head) + self.head_dim = self.config.hidden_size_per_head + self.hidden_size_per_head = self.config.hidden_size_per_head + self.num_heads = self.config.mamba_num_heads self.time_step_rank = max(64, self.hidden_size // 16) self.use_conv_bias = False self.use_bias = False @@ -170,9 +170,12 @@ def __init__(self, # The activation function is fixed to SiLU. self.activation = "silu" - self.dt_norm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps) - self.B_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) - self.C_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + self.dt_norm = RMSNorm(self.time_step_rank, + eps=self.config.rms_norm_eps) + self.B_norm = RMSNorm(self.ssm_state_size, + eps=self.config.rms_norm_eps) + self.C_norm = RMSNorm(self.ssm_state_size, + eps=self.config.rms_norm_eps) def forward( self, @@ -325,13 +328,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Plamo2AttentionMixer(nn.Module): def __init__(self, - config: Plamo2Config, - cache_config: CacheConfig, - quant_config: QuantizationConfig, - max_model_len: int | None = None, + *, + vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads @@ -369,12 +373,16 @@ def __init__(self, "rope_theta") else 10000 self.rope_scaling = config.rope_scaling if hasattr( config, "rope_scaling") else None + max_position = config.max_position_embeddings + if hasattr(vllm_config.model_config, "max_model_len") and isinstance( + vllm_config.model_config.max_model_len, int): + max_position = min(max_position, + vllm_config.model_config.max_model_len) - assert max_model_len is not None, "max_model_len must be provided" self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, - max_position=max_model_len, + max_position=max_position, base=self.rope_theta, rope_scaling=self.rope_scaling, ) @@ -428,27 +436,18 @@ class Plamo2DecoderLayer(nn.Module): def __init__(self, vllm_config: VllmConfig, layer_idx: int, - max_model_len: int | None = None, prefix: str = "", **kwargs) -> None: super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - max_model_len = vllm_config.scheduler_config.max_model_len self.is_mamba = is_mamba(config, layer_idx) if self.is_mamba: - self.mixer = Plamo2MambaMixer(config=config, - cache_config=cache_config, - quant_config=quant_config, - max_model_len=max_model_len, + self.mixer = Plamo2MambaMixer(vllm_config=vllm_config, prefix=f"{prefix}.mixer") else: - self.mixer = Plamo2AttentionMixer(config=config, - cache_config=cache_config, - quant_config=quant_config, - max_model_len=max_model_len, + self.mixer = Plamo2AttentionMixer(vllm_config=vllm_config, prefix=f"{prefix}.mixer") self.mlp = DenseMLP(config=config, From 90cb67d541a9ccf7ee151ad6c94aa225a56872a8 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Sun, 15 Jun 2025 03:33:19 +0900 Subject: [PATCH 05/15] Split sample method Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index eec1e2d6379a..dbba2598c8f9 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -26,6 +26,7 @@ selective_scan_fn, selective_state_update) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -615,7 +616,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.config.vocab_size) - + self.sampler = get_sampler() # Initialize weights and apply final processing self.post_init() @@ -670,6 +671,14 @@ def compute_logits( sampling_metadata) return logits + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: From b35f827876beabe163734078e55b867af7c168c6 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Sun, 15 Jun 2025 04:00:59 +0900 Subject: [PATCH 06/15] Support pipeline parallel Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 64 +++++++++++++++++++++------- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index dbba2598c8f9..e570a205f757 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -12,6 +12,7 @@ from vllm.attention.layer import Attention from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -32,10 +33,12 @@ from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, - SupportsV0Only) + SupportsPP, SupportsV0Only) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) -from vllm.model_executor.models.utils import maybe_prefix +from vllm.model_executor.models.utils import ( + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, + make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors @@ -494,14 +497,18 @@ class Plamo2Decoder(torch.nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() - num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers + config = vllm_config.model_config.hf_config + extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)} + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + return Plamo2DecoderLayer(vllm_config=vllm_config, + layer_idx=layer_idx, + prefix=prefix, + **extra_kwargs) - self.layers = nn.ModuleList([ - Plamo2DecoderLayer(vllm_config=vllm_config, - layer_idx=i, - prefix=f"{prefix}.layers.{i}") - for i in range(num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") def forward( self, @@ -511,7 +518,7 @@ def forward( mamba_cache_params: MambaCacheParams, ) -> torch.Tensor: mamba_cache_index = 0 - for layer in self.layers: + for layer in self.layers[self.start_layer:self.end_layer]: layer_mamba_cache_params = None if layer.is_mamba: layer_mamba_cache_params = mamba_cache_params.at_layer_idx( @@ -544,10 +551,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, prefix=f"{prefix}.embed_tokens", ) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_init() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -556,21 +569,33 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # TODO(Shinichi): Implement pipeline parallelism. - hidden_states = self.embed_tokens(input_ids) - residual = None + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] hidden_states, residual = self.layers( positions=positions, hidden_states=hidden_states, residual=residual, mamba_cache_params=mamba_cache_params) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid, - SupportsV0Only): +class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, + IsHybrid, SupportsV0Only): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -617,9 +642,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.config.vocab_size) self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, @@ -729,6 +759,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): elif "model.norm.weight" in name: loaded_weight += 1.0 + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) From 6a782cad3899fbb0b6a5245772bbb608c9932dbe Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Sun, 15 Jun 2025 13:30:52 +0900 Subject: [PATCH 07/15] Properly set prefix and quant_config Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index e570a205f757..902d35718443 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -117,6 +117,7 @@ def __init__(self, input_size=self.conv_kernel_size, output_size=self.intermediate_size, bias=self.use_conv_bias, + prefix=f"{prefix}.conv1d", ) # unsqueeze to fit conv1d weights shape into the linear weights shape. # Can't do this in `weight_loader` since it already exists in @@ -128,6 +129,7 @@ def __init__(self, self.hidden_size, [self.intermediate_size] * 2, bias=self.use_bias, + quant_config=self.quant_config, prefix=f"{prefix}.in_proj", ) # selective projection used to make dt, B and C input dependent @@ -135,6 +137,7 @@ def __init__(self, self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False, + quant_config=self.quant_config, prefix=f"{prefix}.bcdt_proj", ) # time step projection (discretization) - @@ -144,6 +147,7 @@ def __init__(self, self.time_step_rank, self.num_heads, bias=False, + quant_config=self.quant_config, prefix=f"{prefix}.dt_proj", ) @@ -169,6 +173,7 @@ def __init__(self, self.hidden_size, bias=self.use_bias, input_is_parallel=True, + quant_config=self.quant_config, prefix=f"{prefix}.out_proj", ) # The activation function is fixed to SiLU. From 51ca55fbcdfd6e9ad0368f801364e02e83bbe2b3 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Sun, 15 Jun 2025 14:12:34 +0900 Subject: [PATCH 08/15] Modify bias handling Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 902d35718443..289417a59c16 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -111,13 +111,13 @@ def __init__(self, self.hidden_size_per_head = self.config.hidden_size_per_head self.num_heads = self.config.mamba_num_heads self.time_step_rank = max(64, self.hidden_size // 16) - self.use_conv_bias = False self.use_bias = False self.conv1d = ColumnParallelLinear( input_size=self.conv_kernel_size, output_size=self.intermediate_size, - bias=self.use_conv_bias, + bias=False, prefix=f"{prefix}.conv1d", + return_bias=False, ) # unsqueeze to fit conv1d weights shape into the linear weights shape. # Can't do this in `weight_loader` since it already exists in @@ -139,6 +139,7 @@ def __init__(self, bias=False, quant_config=self.quant_config, prefix=f"{prefix}.bcdt_proj", + return_bias=False, ) # time step projection (discretization) - # In the forward we need to apply dt_proj without the bias, @@ -149,6 +150,7 @@ def __init__(self, bias=False, quant_config=self.quant_config, prefix=f"{prefix}.dt_proj", + return_bias=False, ) tp_size = get_tensor_model_parallel_world_size() @@ -240,7 +242,7 @@ def forward( # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.bcdt_proj(hidden_states.transpose(-2, -1))[0] + ssm_parameters = self.bcdt_proj(hidden_states.transpose(-2, -1)) # Splitting the ssm_parameters as in modeling_plamo.py. B, C, time_step = torch.split( @@ -252,7 +254,7 @@ def forward( B = self.B_norm(B.contiguous()) C = self.C_norm(C.contiguous()) - discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + discrete_time_step = self.dt_proj(time_step).transpose(-2, -1) # 3.c perform the recurrence y ← SSM(A, B, C)(x) time_proj_bias = (self.dt_bias.float() if hasattr( self.dt_proj, "bias") else None) @@ -319,19 +321,21 @@ def __init__( self.hidden_size, [self.intermediate_size] * 2, bias=False, prefix=f"{prefix}.gate_up_proj", - quant_config=quant_config) + quant_config=quant_config, + return_bias=False, + ) self.act = SiluAndMul() self.down_proj = RowParallelLinear(self.intermediate_size, self.hidden_size, bias=False, prefix=f"{prefix}.down_proj", - quant_config=quant_config) + quant_config=quant_config, + return_bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - h = self.gate_up_proj(hidden_states)[0] + h = self.gate_up_proj(hidden_states) h = self.act(h) - output, _ = self.down_proj(h) - return output # type: ignore + return self.down_proj(h) class Plamo2AttentionMixer(nn.Module): From 7c423f137a319dfca7cada89e9d219216c44b74e Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Sun, 15 Jun 2025 16:14:06 +0900 Subject: [PATCH 09/15] Support tensor parallel Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Co-authored-by: Calvin Metzger --- vllm/model_executor/models/plamo2.py | 72 ++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 289417a59c16..6684ef3fe761 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -11,7 +11,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.config import VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul @@ -107,11 +107,11 @@ def __init__(self, self.conv_kernel_size = self.config.mamba_d_conv self.intermediate_size = (self.config.mamba_num_heads * self.config.hidden_size_per_head) + self.tp_size = get_tensor_model_parallel_world_size() self.head_dim = self.config.hidden_size_per_head self.hidden_size_per_head = self.config.hidden_size_per_head self.num_heads = self.config.mamba_num_heads self.time_step_rank = max(64, self.hidden_size // 16) - self.use_bias = False self.conv1d = ColumnParallelLinear( input_size=self.conv_kernel_size, output_size=self.intermediate_size, @@ -128,9 +128,10 @@ def __init__(self, self.in_proj = MergedColumnParallelLinear( self.hidden_size, [self.intermediate_size] * 2, - bias=self.use_bias, + bias=False, quant_config=self.quant_config, prefix=f"{prefix}.in_proj", + return_bias=False, ) # selective projection used to make dt, B and C input dependent self.bcdt_proj = RowParallelLinear( @@ -161,7 +162,8 @@ def __init__(self, dtype=torch.float32, )) self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) - self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + self.dt_bias = nn.Parameter( + torch.ones(divide(self.num_heads, self.tp_size))) set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) a_weight_loader = composed_weight_loader( @@ -173,10 +175,11 @@ def __init__(self, self.out_proj = RowParallelLinear( self.intermediate_size, self.hidden_size, - bias=self.use_bias, + bias=False, input_is_parallel=True, quant_config=self.quant_config, prefix=f"{prefix}.out_proj", + return_bias=False, ) # The activation function is fixed to SiLU. self.activation = "silu" @@ -198,16 +201,8 @@ def forward( attn_metadata: AttentionMetadata = get_forward_context().attn_metadata # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states)[0] - # Reshaping the projected states as in modeling_plamo.py. - length = len(hidden_states) - projected_states = projected_states.reshape(length, self.num_heads, -1) - gate, hidden_states = torch.split( - projected_states, - [self.hidden_size_per_head, self.hidden_size_per_head], - dim=-1) - hidden_states = hidden_states.reshape(length, -1).transpose(0, 1) - gate = gate.reshape(length, -1).transpose(0, 1) + projected_states = self.in_proj(hidden_states) + gate, hidden_states = projected_states.transpose(0, 1).chunk(2, dim=-2) # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), @@ -263,11 +258,11 @@ def forward( discrete_time_step = discrete_time_step.transpose( 0, 1)[..., None].expand(-1, -1, self.hidden_size_per_head) discrete_time_step = discrete_time_step.reshape( - -1, self.intermediate_size).transpose(0, 1) + -1, self.intermediate_size // self.tp_size).transpose(0, 1) time_proj_bias = time_proj_bias[..., None].expand(-1, self.hidden_size_per_head) - time_proj_bias = time_proj_bias.reshape(self.intermediate_size) + time_proj_bias = time_proj_bias.reshape(-1) if attn_metadata.query_start_loc is not None \ and attn_metadata.context_lens_tensor is not None: @@ -301,8 +296,7 @@ def forward( scan_outputs = scan_outputs.transpose(0, 1) # 4. Final linear projection - contextualized_states = self.out_proj(scan_outputs.transpose(-2, - -1))[0] + contextualized_states = self.out_proj(scan_outputs.transpose(-2, -1)) return contextualized_states @@ -409,8 +403,12 @@ def __init__(self, eps=config.rms_norm_eps) self.k_norm.weight = torch.nn.Parameter( torch.ones((self.num_kv_heads, config.hidden_size_per_head))) - set_weight_attrs(self.k_norm.weight, - {"weight_loader": sharded_weight_loader(0)}) + # Tensor-parallelism shards the K norm weights to the tp ranks + # in a head-wise manner. This approach does not work if there is only + # a single KV head, as is the case for PLaMo 2-1B. + if self.total_num_kv_heads != 1: + set_weight_attrs(self.k_norm.weight, + {"weight_loader": sharded_weight_loader(0)}) self.attn = Attention( self.num_heads, @@ -756,6 +754,38 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loaded_weight = loaded_weight[:, None].expand( -1, self.config.hidden_size_per_head) loaded_weight = loaded_weight.reshape(-1) + # Reshape the in_proj weights to match the shape expected + # by MergedColumnParallelLinear. + # This works both for unquantized weights and + # for quantized weights. + # In the quantized case, the weights are already transposed. + # Also, in addition to the quantized weights, + # the zero points and scales have to be reshaped as well. + # Packing should not be affected by this. + if ".mixer.in_proj.weight" in name \ + or "mixer.in_proj.qweight" in name \ + or "mixer.in_proj.scales" in name \ + or "mixer.in_proj.qzeros" in name: + if "mixer.in_proj.weight" in name: + loaded_weight = loaded_weight.transpose(0, 1) + # for weight: + # loaded_weight.shape[0] == self.config.hidden_size + # for qweight: + # loaded_weight.shape[0] == self.config.hidden_size // param.pack_factor # noqa + # for scales and qzeros: + # loaded_weight.shape[0] == self.config.hidden_size // self.vllm_config.quant_config.group_size # noqa + loaded_weight = loaded_weight.reshape( + loaded_weight.shape[0], self.config.mamba_num_heads, -1) + gate_weight, hidden_states_weight = loaded_weight.chunk(2, + dim=-1) + gate_weight = gate_weight.reshape(loaded_weight.shape[0], -1) + hidden_states_weight = hidden_states_weight.reshape( + loaded_weight.shape[0], -1) + loaded_weight = torch.cat([gate_weight, hidden_states_weight], + dim=-1) + if "mixer.in_proj.weight" in name: + loaded_weight = loaded_weight.transpose(0, 1) + # Offset parameter with vllm's RMSNorm haven't been supported yet. if ".pre_mixer_norm" in name: loaded_weight += 1.0 From 8e848d4f0e86005b9558ce22afc481af369aec8c Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Sun, 15 Jun 2025 16:55:07 +0900 Subject: [PATCH 10/15] Use mamba2 kernel Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Co-authored-by: Sixue Wang --- vllm/model_executor/models/plamo2.py | 281 +++++++++++++++++---------- 1 file changed, 180 insertions(+), 101 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 6684ef3fe761..0b7d28f28f7f 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -21,10 +21,14 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba2_metadata import ( + Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_scan_fn, selective_state_update) + selective_state_update) +from vllm.model_executor.layers.mamba.ops.ssd_combined import ( + mamba_chunk_scan_combined) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -90,9 +94,10 @@ def is_mamba(config: Plamo2Config, i: int) -> bool: return (i % config.mamba_step) != (config.mamba_step // 2) -# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +# Adapted from: +# vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2 +# transformers.models.mamba.modeling_mamba.MambaMixer class Plamo2MambaMixer(nn.Module): - # TODO(Shinichi): Rebase on Mamba2 implementation. def __init__(self, vllm_config: VllmConfig, @@ -108,8 +113,9 @@ def __init__(self, self.intermediate_size = (self.config.mamba_num_heads * self.config.hidden_size_per_head) self.tp_size = get_tensor_model_parallel_world_size() + self.intermediate_size_per_tp_worker = \ + self.intermediate_size // self.tp_size self.head_dim = self.config.hidden_size_per_head - self.hidden_size_per_head = self.config.hidden_size_per_head self.num_heads = self.config.mamba_num_heads self.time_step_rank = max(64, self.hidden_size // 16) self.conv1d = ColumnParallelLinear( @@ -154,14 +160,12 @@ def __init__(self, return_bias=False, ) - tp_size = get_tensor_model_parallel_world_size() self.A = nn.Parameter( torch.empty( - self.intermediate_size // tp_size, - self.ssm_state_size, + divide(self.num_heads, self.tp_size), dtype=torch.float32, )) - self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + self.D = nn.Parameter(torch.ones(divide(self.num_heads, self.tp_size))) self.dt_bias = nn.Parameter( torch.ones(divide(self.num_heads, self.tp_size))) @@ -191,113 +195,184 @@ def __init__(self, self.C_norm = RMSNorm(self.ssm_state_size, eps=self.config.rms_norm_eps) + def _project_ssm_parameters(self, hidden_states): + ssm_parameters = self.bcdt_proj(hidden_states) + B, C, time_step = torch.split( + ssm_parameters, + [self.ssm_state_size, self.ssm_state_size, self.time_step_rank], + dim=-1, + ) + + # vllm._custom_ops.rms_norm requires contiguous input tensors. + time_step = self.dt_norm(time_step.contiguous()) + B = self.B_norm(B.contiguous()) + C = self.C_norm(C.contiguous()) + dt = self.dt_proj(time_step) + return B, C, dt + def forward( self, hidden_states: torch.Tensor, mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, **kwargs, ) -> torch.Tensor: + # mamba2_metadata contains metadata necessary for the mamba2 triton + # kernels to operate in continuous batching and in chunked prefill + # modes; they are computed at top-level model forward since they + # stay the same and reused for all mamba layers in the same iteration attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states) - gate, hidden_states = projected_states.transpose(0, 1).chunk(2, dim=-2) + gate, hidden_states = projected_states.chunk(2, dim=-1) # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: + # Separate prefill and decode by splitting varlen input + # Split along token dimension + hidden_states_p, hidden_states_d = torch.split( + hidden_states, + [num_prefill_tokens, num_decodes], + dim=0, + ) + gate_p, gate_d = torch.split(gate, [num_prefill_tokens, num_decodes], + dim=0) + # Split along batch dimension + state_indices_tensor_p, state_indices_tensor_d = torch.split( + mamba_cache_params.state_indices_tensor, + [num_prefills, num_decodes], + dim=0, + ) + query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1] + if has_prefill else None) + + ssd_output_list = [] + + if has_prefill: # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| # |-------------------- seq_len ---------------------| # |-- query_len ---| - hidden_states = causal_conv1d_fn( - hidden_states, + hidden_states_p = causal_conv1d_fn( + hidden_states_p.transpose(0, 1), conv_weights, self.conv1d.bias, activation=self.activation, conv_states=mamba_cache_params.conv_state, - has_initial_state=attn_metadata.context_lens_tensor > 0, - cache_indices=mamba_cache_params.state_indices_tensor, - query_start_loc=attn_metadata.query_start_loc) - else: - hidden_states = causal_conv1d_update( - hidden_states.transpose(0, 1), + has_initial_state=mamba2_metadata.has_initial_states, + cache_indices=state_indices_tensor_p, + query_start_loc=query_start_loc_p) + hidden_states_p = hidden_states_p.transpose(0, 1) + hidden_states_p = hidden_states_p[:num_prefill_tokens] + # In some instances, the following `bcdt_proj` op + # requires contiguous inputs + # (e.g. if the Marlin kernel is used). + hidden_states_p = hidden_states_p.contiguous() + + B, C, dt = self._project_ssm_parameters(hidden_states_p) + + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + initial_states = None + if (mamba2_metadata.has_initial_states is not None + and mamba2_metadata.prep_initial_states): + # making a copy of the states + initial_states = torch.where( + mamba2_metadata.has_initial_states[:, None, None, None], + mamba_cache_params.ssm_state[state_indices_tensor_p], 0) + scan_output, varlen_state = mamba_chunk_scan_combined( + hidden_states_p.view(1, num_prefill_tokens, + self.num_heads // self.tp_size, + self.head_dim), + dt.unsqueeze(0), + self.A, + B.view(1, num_prefill_tokens, 1, -1), + C.view(1, num_prefill_tokens, 1, -1), + chunk_size=mamba2_metadata.chunk_size, + D=self.D, + z=gate_p.view(1, num_prefill_tokens, + self.num_heads // self.tp_size, self.head_dim), + dt_bias=self.dt_bias, + seq_idx=mamba2_metadata.seq_idx, + chunk_indices=mamba2_metadata.chunk_indices, + chunk_offsets=mamba2_metadata.chunk_offsets, + cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1], + initial_states=initial_states, + return_varlen_states=True, + return_final_states=False, + dt_softplus=True, + dt_limit=(0.0, float("inf")), + ) + + # update ssm states + # - varlen state is a (batch, nheads, headdim, dstate) tensor + mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state + + # - reshape + ssd_output_list.append(scan_output.view(num_prefill_tokens, -1)) + + # Process decode requests + if has_decode: + hidden_states_d = causal_conv1d_update( + hidden_states_d, mamba_cache_params.conv_state, conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=mamba_cache_params.state_indices_tensor) - hidden_states = hidden_states.transpose(0, 1) - - # 3. State Space Model sequence transformation - # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.bcdt_proj(hidden_states.transpose(-2, -1)) - - # Splitting the ssm_parameters as in modeling_plamo.py. - B, C, time_step = torch.split( - ssm_parameters, - [self.ssm_state_size, self.ssm_state_size, self.time_step_rank], - dim=-1, - ) - time_step = self.dt_norm(time_step.contiguous()) - B = self.B_norm(B.contiguous()) - C = self.C_norm(C.contiguous()) - - discrete_time_step = self.dt_proj(time_step).transpose(-2, -1) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = (self.dt_bias.float() if hasattr( - self.dt_proj, "bias") else None) - - # Broadcasting as in modeling_plamo.py. - discrete_time_step = discrete_time_step.transpose( - 0, 1)[..., None].expand(-1, -1, self.hidden_size_per_head) - discrete_time_step = discrete_time_step.reshape( - -1, self.intermediate_size // self.tp_size).transpose(0, 1) - time_proj_bias = time_proj_bias[..., - None].expand(-1, - self.hidden_size_per_head) - time_proj_bias = time_proj_bias.reshape(-1) - - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - scan_outputs = selective_scan_fn( - hidden_states, - mamba_cache_params.ssm_state, - discrete_time_step, - self.A, - B.transpose(-2, -1), - C.transpose(-2, -1), - self.D.float(), - gate, - time_proj_bias, - delta_softplus=True, - cache_indices=mamba_cache_params.state_indices_tensor, - has_initial_state=attn_metadata.context_lens_tensor > 0, - query_start_loc=attn_metadata.query_start_loc) - else: - scan_outputs = selective_state_update( + conv_state_indices=state_indices_tensor_d) + + B, C, dt = self._project_ssm_parameters(hidden_states_d) + + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + A = self.A[:, None, ...][:, :, + None].expand(-1, self.head_dim, + self.config.mamba_d_state) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.unsqueeze(1) + C = C.unsqueeze(1) + hidden_states_d = hidden_states_d.view( + -1, self.num_heads // self.tp_size, self.head_dim) + + # - the hidden is reshaped into (bs, num_heads, head_dim) + # - mamba_cache_params.ssm_state's slots will be selected + # using state_indices_tensor_d + + hidden_states_d = selective_state_update( mamba_cache_params.ssm_state, - hidden_states.transpose(0, 1), - discrete_time_step.transpose(0, 1), - self.A, + hidden_states_d, + dt, + A, B, C, - self.D, - gate.transpose(0, 1), - time_proj_bias, + D, + z=gate_d.reshape(num_decodes, -1, self.head_dim), + dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=mamba_cache_params.state_indices_tensor) - scan_outputs = scan_outputs.transpose(0, 1) + state_batch_indices=state_indices_tensor_d, + ) + ssd_output_list.append( + hidden_states_d.view(-1, (self.num_heads // self.tp_size) * + self.head_dim)) + + # Merge prefill and decode outputs before passing to MLP + hidden_states = torch.vstack(ssd_output_list) # 4. Final linear projection - contextualized_states = self.out_proj(scan_outputs.transpose(-2, -1)) - return contextualized_states + out = self.out_proj(hidden_states) + return out class DenseMLP(nn.Module): @@ -312,12 +387,13 @@ def __init__( self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_up_proj = MergedColumnParallelLinear( - self.hidden_size, [self.intermediate_size] * 2, + self.hidden_size, + [self.intermediate_size] * 2, bias=False, prefix=f"{prefix}.gate_up_proj", quant_config=quant_config, return_bias=False, - ) + ) self.act = SiluAndMul() self.down_proj = RowParallelLinear(self.intermediate_size, self.hidden_size, @@ -423,7 +499,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], **kwargs, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) @@ -479,6 +554,7 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, **kwargs, ): if residual is None: @@ -488,10 +564,12 @@ def forward( hidden_states, residual = self.pre_mixer_norm( hidden_states, residual) - hidden_states = self.mixer(positions=positions, - hidden_states=hidden_states, - residual=residual, - mamba_cache_params=mamba_cache_params) + hidden_states = self.mixer( + positions=positions, + hidden_states=hidden_states, + mamba_cache_params=mamba_cache_params, + mamba2_metadata=mamba2_metadata, + ) hidden_states = self.post_mixer_norm(hidden_states) # Fully Connected hidden_states, residual = self.pre_mlp_norm(hidden_states, residual) @@ -523,6 +601,7 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, ) -> torch.Tensor: mamba_cache_index = 0 for layer in self.layers[self.start_layer:self.end_layer]: @@ -536,7 +615,9 @@ def forward( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=layer_mamba_cache_params) + mamba_cache_params=layer_mamba_cache_params, + mamba2_metadata=mamba2_metadata, + ) return hidden_states, residual @@ -587,11 +668,19 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) + hidden_states, residual = self.layers( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=mamba_cache_params) + mamba_cache_params=mamba_cache_params, + mamba2_metadata=mamba2_metadata, + ) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -685,7 +774,7 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: + self) -> tuple[tuple[int, int], tuple[int, int, int]]: world_size = get_tensor_model_parallel_world_size() hidden_size = (self.config.mamba_num_heads * self.config.hidden_size_per_head) @@ -694,7 +783,8 @@ def _get_mamba_cache_shape( self.config.mamba_d_conv - 1, ) temporal_state_shape = ( - hidden_size // world_size, + divide(self.config.mamba_num_heads, world_size), + self.config.hidden_size_per_head, self.config.mamba_d_state, ) return conv_state_shape, temporal_state_shape @@ -743,17 +833,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): if old in name: name = name.replace(old, new) - # Broadcast the loaded weight to match the model's parameter shape. - if ".A" in name: - loaded_weight = loaded_weight[:, None, None].expand( - -1, self.config.hidden_size_per_head, - self.config.mamba_d_state) - loaded_weight = loaded_weight.reshape( - -1, self.config.mamba_d_state) - elif ".D" in name: - loaded_weight = loaded_weight[:, None].expand( - -1, self.config.hidden_size_per_head) - loaded_weight = loaded_weight.reshape(-1) # Reshape the in_proj weights to match the shape expected # by MergedColumnParallelLinear. # This works both for unquantized weights and From 53181094abe80f384aad515a5a131120fa7de372 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Mon, 16 Jun 2025 07:57:50 +0900 Subject: [PATCH 11/15] Note PP support in the docs Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- docs/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index a435c59a3042..5518a1f2c3b3 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -373,7 +373,7 @@ Specified using `--task generate`. | `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ | | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ | -| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | | +| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | | | `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | | ✅︎ | ✅︎ | From 5cf171aeb0e84a831635d228d0493cd9f46e914a Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Mon, 16 Jun 2025 15:16:25 +0900 Subject: [PATCH 12/15] Support torch compile Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Co-authored-by: Calvin Metzger --- vllm/model_executor/models/plamo2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 0b7d28f28f7f..5565e2216394 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention +from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group @@ -408,6 +409,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.down_proj(h) +@support_torch_compile class Plamo2AttentionMixer(nn.Module): def __init__(self, From 45e6d6d0d2af8a1e0d8ad5e5d497355d762b4023 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Tue, 17 Jun 2025 17:40:46 +0900 Subject: [PATCH 13/15] Activate tests for plamo2 Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- tests/distributed/test_pipeline_parallel.py | 1 + tests/quantization/test_experts_int8.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 7d569fd83821..fcedd85d8870 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -174,6 +174,7 @@ def iter_params(self, model_id: str): "internlm/internlm2-chat-7b": PPTestSettings.fast(), "inceptionai/jais-13b-chat": PPTestSettings.fast(), "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(), + "pfnet/plamo-2-1b": PPTestSettings.fast(), "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(), # Tests TransformersForCausalLM "ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(), diff --git a/tests/quantization/test_experts_int8.py b/tests/quantization/test_experts_int8.py index 50179b9a904d..84a656a3b9da 100644 --- a/tests/quantization/test_experts_int8.py +++ b/tests/quantization/test_experts_int8.py @@ -9,7 +9,7 @@ from tests.quantization.utils import is_quant_method_supported -MODELS = ["ai21labs/Jamba-tiny-random"] +MODELS = ["ai21labs/Jamba-tiny-random", "pfnet/plamo-2-1b"] @pytest.mark.skipif(not is_quant_method_supported("experts_int8"), From ed09215e1803e27c1e79742e7c5f3f4d7323cf28 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Wed, 25 Jun 2025 14:31:26 +0900 Subject: [PATCH 14/15] Update plamo mamba mixer comments Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Co-authored-by: Calvin Metzger --- vllm/model_executor/models/plamo2.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 5565e2216394..8a61563115ac 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -259,13 +259,11 @@ def forward( ssd_output_list = [] + # Process prefill requests if has_prefill: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| + # 2. Convolution sequence transformation + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" hidden_states_p = causal_conv1d_fn( hidden_states_p.transpose(0, 1), conv_weights, @@ -284,7 +282,7 @@ def forward( B, C, dt = self._project_ssm_parameters(hidden_states_p) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) + # 3. State Space Model sequence transformation initial_states = None if (mamba2_metadata.has_initial_states is not None and mamba2_metadata.prep_initial_states): @@ -325,6 +323,7 @@ def forward( # Process decode requests if has_decode: + # 2. Convolution sequence transformation hidden_states_d = causal_conv1d_update( hidden_states_d, mamba_cache_params.conv_state, @@ -335,7 +334,7 @@ def forward( B, C, dt = self._project_ssm_parameters(hidden_states_d) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) + # 3. State Space Model sequence transformation A = self.A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.config.mamba_d_state) From cc367efd0114b77501ddacadee47dd4077576f8c Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Fri, 25 Jul 2025 17:26:03 +0900 Subject: [PATCH 15/15] Ensure num_heads is divisible by tp_size Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 8a61563115ac..9bc577cfe3a3 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -363,6 +363,7 @@ def forward( dt_softplus=True, state_batch_indices=state_indices_tensor_d, ) + assert self.num_heads % self.tp_size == 0 ssd_output_list.append( hidden_states_d.view(-1, (self.num_heads // self.tp_size) * self.head_dim))