From 6fd7f8ee76f8487f3b993ffd306430ea1ec3e49f Mon Sep 17 00:00:00 2001 From: Hemmi Shinichi Date: Mon, 25 Aug 2025 02:29:33 +0900 Subject: [PATCH 1/3] Remove dependency on transformers.PreTrainedModel Signed-off-by: Hemmi Shinichi --- vllm/model_executor/models/plamo2.py | 32 +++++++--------------------- 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index e5034b536266..e13c39c797dd 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -6,7 +6,7 @@ import torch from torch import nn -from transformers import PretrainedConfig, PreTrainedModel +from transformers import PretrainedConfig from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention @@ -72,20 +72,6 @@ class Plamo2Config(PretrainedConfig): # type: ignore vocab_size: int -class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore - - def _init_weights(self, module: torch.nn.Module) -> None: - std = 0.02 - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def is_mamba(config: Plamo2Config, i: int) -> bool: assert config.mamba_step > 1 @@ -588,7 +574,7 @@ def forward( return hidden_states, residual -class Plamo2Decoder(torch.nn.Module): +class Plamo2Decoder(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() @@ -631,10 +617,10 @@ def forward( return hidden_states, residual -class Plamo2Model(Plamo2PreTrainedModel): +class Plamo2Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config.model_config.hf_config) + super().__init__() config = vllm_config.model_config.hf_config @@ -654,7 +640,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ["hidden_states", "residual"], config.hidden_size)) self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_init() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -700,8 +685,8 @@ def forward( return hidden_states -class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, - IsHybrid, SupportsV0Only): +class Plamo2ForCausalLM(nn.Module, HasInnerState, SupportsPP, IsHybrid, + SupportsV0Only): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -711,12 +696,13 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, } def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_config scheduler_config = vllm_config.scheduler_config assert not vllm_config.cache_config.enable_prefix_caching, \ "PLaMo2 currently does not support prefix caching" - super().__init__(config) self.config = config self.vllm_config = vllm_config self.model_config = vllm_config.model_config @@ -750,8 +736,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - # Initialize weights and apply final processing - self.post_init() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) From fc234f42fadf64c4639f67da696cfb2945808b59 Mon Sep 17 00:00:00 2001 From: Hemmi Shinichi Date: Mon, 25 Aug 2025 13:18:47 +0900 Subject: [PATCH 2/3] Drop restriction of transformers version Signed-off-by: Hemmi Shinichi --- tests/models/registry.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 4871ade23104..901e653f0104 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -279,8 +279,6 @@ def check_available_online( "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", - max_transformers_version="4.53", - transformers_version_reason="vLLM impl inherits PreTrainedModel and clashes with get_input_embeddings", # noqa: E501 trust_remote_code=True), "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", max_transformers_version="4.53", From 970713c73eb75fcb1f391222fe23f7e86ac66aca Mon Sep 17 00:00:00 2001 From: Hemmi Shinichi Date: Mon, 25 Aug 2025 13:19:17 +0900 Subject: [PATCH 3/3] Activate test for plamo2 Signed-off-by: Hemmi Shinichi --- tests/models/language/generation/test_hybrid.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 2055c44c83cd..67040c9c2997 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -25,8 +25,7 @@ HYBRID_MODELS = [ "ai21labs/Jamba-tiny-dev", - # skipping until vLLM implementation issues are resolved - # "pfnet/plamo-2-1b", + "pfnet/plamo-2-1b", "Zyphra/Zamba2-1.2B-instruct", "hmellor/tiny-random-BambaForCausalLM", "ibm-granite/granite-4.0-tiny-preview",