Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions tests/models/language/generation/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@

HYBRID_MODELS = [
"ai21labs/Jamba-tiny-dev",
# skipping until vLLM implementation issues are resolved
# "pfnet/plamo-2-1b",
"pfnet/plamo-2-1b",
"Zyphra/Zamba2-1.2B-instruct",
"hmellor/tiny-random-BambaForCausalLM",
"ibm-granite/granite-4.0-tiny-preview",
Expand Down
2 changes: 0 additions & 2 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,6 @@ def check_available_online(
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
max_transformers_version="4.53",
transformers_version_reason="vLLM impl inherits PreTrainedModel and clashes with get_input_embeddings", # noqa: E501
trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
max_transformers_version="4.53",
Expand Down
32 changes: 8 additions & 24 deletions vllm/model_executor/models/plamo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers import PretrainedConfig

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
Expand Down Expand Up @@ -72,20 +72,6 @@ class Plamo2Config(PretrainedConfig): # type: ignore
vocab_size: int


class Plamo2PreTrainedModel(PreTrainedModel):  # type: ignore
    """Thin HF ``PreTrainedModel`` shim supplying default weight init."""

    def _init_weights(self, module: torch.nn.Module) -> None:
        """Initialize *module* in-place with the standard HF scheme:
        weights ~ Normal(0, 0.02); Linear biases and the Embedding
        padding row are zeroed. Other module types are left untouched."""
        init_std = 0.02
        is_linear = isinstance(module, nn.Linear)
        is_embedding = isinstance(module, nn.Embedding)
        if not (is_linear or is_embedding):
            return
        module.weight.data.normal_(mean=0.0, std=init_std)
        if is_linear and module.bias is not None:
            module.bias.data.zero_()
        if is_embedding and module.padding_idx is not None:
            # Keep the padding embedding exactly zero so it contributes
            # nothing to lookups of the padding index.
            module.weight.data[module.padding_idx].zero_()


def is_mamba(config: Plamo2Config, i: int) -> bool:
assert config.mamba_step > 1

Expand Down Expand Up @@ -588,7 +574,7 @@ def forward(
return hidden_states, residual


class Plamo2Decoder(torch.nn.Module):
class Plamo2Decoder(nn.Module):

def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
Expand Down Expand Up @@ -631,10 +617,10 @@ def forward(
return hidden_states, residual


class Plamo2Model(Plamo2PreTrainedModel):
class Plamo2Model(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config.model_config.hf_config)
super().__init__()

config = vllm_config.model_config.hf_config

Expand All @@ -654,7 +640,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
["hidden_states", "residual"], config.hidden_size))
self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_init()

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
Expand Down Expand Up @@ -700,8 +685,8 @@ def forward(
return hidden_states


class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
IsHybrid, SupportsV0Only):
class Plamo2ForCausalLM(nn.Module, HasInnerState, SupportsPP, IsHybrid,
SupportsV0Only):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
Expand All @@ -711,12 +696,13 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()

config = vllm_config.model_config.hf_config
scheduler_config = vllm_config.scheduler_config
assert not vllm_config.cache_config.enable_prefix_caching, \
"PLaMo2 currently does not support prefix caching"

super().__init__(config)
self.config = config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
Expand Down Expand Up @@ -750,8 +736,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
Expand Down