Merged
28 commits
2ed4c33
Add models to registry and test
hmellor Sep 20, 2025
4946646
Add legacy base model names
hmellor Sep 20, 2025
db6db8f
Handle `embedding_size` which can be different to `hidden_size`
hmellor Sep 20, 2025
492f1d5
Switch to using pooler classes
hmellor Sep 22, 2025
25667bd
Fix runner type import
hmellor Sep 24, 2025
56191d9
Fix pooler class resolution
hmellor Sep 24, 2025
cb9cdae
Update tested checkpoints
hmellor Sep 24, 2025
8bb99b4
Check `using_transformers_backend` against the actual class from the r…
hmellor Sep 24, 2025
7edc010
Use Transformers backend for Roberta
hmellor Sep 24, 2025
29f577f
Don't replace Roberta yet
hmellor Sep 24, 2025
76a9040
Fix bad name
hmellor Sep 24, 2025
1725c0f
Allow setting `dtype` in `init_params`
hmellor Sep 24, 2025
66cf2df
Bert/Roberta only weight mapping for now
hmellor Sep 24, 2025
3399476
Don't load `lm_head` for pooling models
hmellor Sep 24, 2025
c4f356d
Partially fix classifier
hmellor Sep 24, 2025
5538654
pre-commit
hmellor Sep 24, 2025
db42a67
fix roberta edge case
Isotr0py Sep 25, 2025
343394d
Move `TransformersPoolingBase` to its own file
hmellor Sep 26, 2025
10ab82a
Fix registry
hmellor Sep 27, 2025
cf98214
Fix pooling class resolution
hmellor Sep 27, 2025
853fc9e
Skip another unwanted head
hmellor Sep 27, 2025
c2cb9a7
Don't use forward hook
hmellor Sep 29, 2025
b8fe0b3
Merge branch 'main' into add-new-encoder-models
hmellor Sep 30, 2025
00f1047
Remove reward class because no model currently supports it
hmellor Sep 30, 2025
3bb4020
Explicitly test pooling classes
hmellor Sep 30, 2025
372e4cc
Don't mutate `self` when checking transformers backend class
hmellor Sep 30, 2025
bf44b84
Don't raise error when checking transformers backend class
hmellor Sep 30, 2025
3c27b92
Fix loading of causal models for pooling
hmellor Sep 30, 2025
3 changes: 2 additions & 1 deletion tests/models/registry.py
@@ -656,7 +656,8 @@ def check_available_online(
}

_TRANSFORMERS_BACKEND_MODELS = {
"TransformersModel": _HfExamplesInfo("Qwen/Qwen3-Embedding-0.6B"),
"TransformersEmbeddingModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
}
99 changes: 41 additions & 58 deletions tests/models/test_transformers.py
@@ -9,9 +9,16 @@

from ..conftest import HfRunner, VllmRunner
from ..utils import multi_gpu_test, prep_prompts
from .registry import HF_EXAMPLE_MODELS
from .utils import check_embeddings_close, check_logprobs_close


def get_model(arch: str) -> str:
model_info = HF_EXAMPLE_MODELS.get_hf_info(arch)
model_info.check_transformers_version(on_fail="skip")
return model_info.default


def check_implementation(
runner_ref: type[Union[HfRunner, VllmRunner]],
runner_test: type[VllmRunner],
@@ -170,71 +177,47 @@ def test_embed_loading(vllm_runner, model):


@pytest.mark.parametrize(
"model",
[
# Encoder model
"BAAI/bge-base-en-v1.5",
])
def test_embed_correctness(hf_runner, vllm_runner, example_prompts, model):
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("4.57.0.dev0")
if installed < required:
pytest.skip("Encoder models with the Transformers backend require "
f"transformers>={required}, but got {installed}")

with vllm_runner(model, max_model_len=512,
model_impl="transformers") as vllm_model:
"arch",
["TransformersEmbeddingModel", "TransformersForSequenceClassification"])
def test_pooling(hf_runner, vllm_runner, example_prompts, arch):
model = get_model(arch)

vllm_kwargs = dict(
Review comment (Member): Should we add more models to test explicit pooling classes?

max_model_len=None,
model_impl="transformers",
compilation_config=dict(cudagraph_capture_sizes=[8]),
)

hf_kwargs = dict()
if arch == "TransformersEmbeddingModel":
hf_kwargs["is_sentence_transformer"] = True
elif arch == "TransformersForSequenceClassification":
from transformers import AutoModelForSequenceClassification
hf_kwargs["auto_cls"] = AutoModelForSequenceClassification

# The example_prompts end with "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n"
# sentence_transformers strips the input texts, see:
# https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
# This makes the input_ids differ between hf_model and vllm_model,
# so we strip the input texts to keep the test from failing.
example_prompts = [str(s).strip() for s in example_prompts]

with (vllm_runner(model, **vllm_kwargs) as
vllm_model, hf_runner(model, **hf_kwargs) as hf_model):
model_config = vllm_model.llm.llm_engine.model_config
assert model_config.using_transformers_backend()

vllm_outputs = vllm_model.embed(example_prompts)

with hf_runner(model, is_sentence_transformer=True) as hf_model:
hf_outputs = hf_model.encode(example_prompts)
if arch == "TransformersEmbeddingModel":
vllm_outputs = vllm_model.embed(example_prompts)
hf_outputs = hf_model.encode(example_prompts)
elif arch == "TransformersForSequenceClassification":
vllm_outputs = vllm_model.classify(example_prompts)
hf_outputs = hf_model.classify(example_prompts)

check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)


@pytest.mark.parametrize(
"model",
["jason9693/Qwen2.5-1.5B-apeach"],
)
@pytest.mark.parametrize("dtype", ["float"])
def test_classify(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
import torch
from transformers import AutoModelForSequenceClassification

with vllm_runner(model,
max_model_len=512,
dtype=dtype,
model_impl="transformers") as vllm_model:
model_config = vllm_model.llm.llm_engine.model_config
assert model_config.using_transformers_backend()

vllm_outputs = vllm_model.classify(example_prompts)

with hf_runner(model,
dtype=dtype,
auto_cls=AutoModelForSequenceClassification) as hf_model:
hf_outputs = hf_model.classify(example_prompts)

for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output)
vllm_output = torch.tensor(vllm_output)

assert torch.allclose(hf_output, vllm_output,
1e-3 if dtype == "float" else 1e-2)
33 changes: 27 additions & 6 deletions vllm/config/model.py
@@ -19,6 +19,7 @@
from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
MultiModalConfig)
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
from vllm.config.utils import assert_hashable, config
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -40,7 +41,6 @@
import vllm.model_executor.models as me_models
from vllm.config.load import LoadConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.scheduler import RunnerType
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.v1.sample.logits_processor import LogitsProcessor
else:
@@ -52,13 +52,12 @@
"vllm.model_executor.models")
LoadConfig = Any
ParallelConfig = Any
RunnerType = Any
QuantizationMethods = Any
LogitsProcessor = Any

logger = init_logger(__name__)

RunnerOption = Literal["auto", "generate", "pooling", "draft"]
RunnerOption = Literal["auto", RunnerType]
ConvertType = Literal["none", "embed", "classify", "reward"]
ConvertOption = Literal["auto", ConvertType]
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
Expand Down Expand Up @@ -668,8 +667,28 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
def _get_transformers_backend_cls(self) -> str:
"""Determine which Transformers backend class will be used if
`model_impl` is set to `transformers` or `auto`."""
if getattr(self, "runner_type", self.runner) == "pooling":
return "TransformersModel"
# Check if the architecture we're wrapping has defaults
runner = None
convert = None
if defaults := try_match_architecture_defaults(self.architectures[0]):
_, (runner, convert) = defaults
# Overwrite with user-specified values
if self.runner != "auto":
runner = self.runner
if self.convert not in {"auto", "none"}:
convert = self.convert
# Fall back to default values if still not set
if runner is None:
runner = "generate"
if convert in {None, "none"}:
convert = "embed"
# Resolve Transformers backend pooling classes
if runner == "pooling":
if convert == "embed":
return "TransformersEmbeddingModel"
if convert == "classify":
return "TransformersForSequenceClassification"
# Resolve Transformers backend generate classes
if self.hf_config != self.hf_text_config:
# If 'hf_text_config' is not the same as 'hf_config', it is
# probably a composite config, i.e. multimodal
@@ -678,7 +697,9 @@ def _get_transformers_backend_cls(self) -> str:

def using_transformers_backend(self) -> bool:
"""Check if the model is using the Transformers backend class."""
return self.architecture == self._get_transformers_backend_cls()
used_cls = self._model_info.architecture
transformers_backend_cls = self._get_transformers_backend_cls()
return used_cls == transformers_backend_cls

@property
def registry(self):
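For readers following the new resolution logic in `_get_transformers_backend_cls`, here is a simplified standalone sketch of the same decision flow. The function name and parameters are hypothetical, and the final generate branch (truncated in the diff above) is assumed to choose between the causal-LM and multimodal classes based on whether the config is composite; treat it as an illustration, not the actual vLLM method.

from typing import Optional

def resolve_transformers_backend_cls(
    arch_defaults: Optional[tuple[Optional[str], Optional[str]]],
    runner: str = "auto",
    convert: str = "auto",
    is_multimodal: bool = False,
) -> str:
    # Start from the (runner, convert) defaults of the wrapped architecture
    resolved_runner, resolved_convert = arch_defaults or (None, None)
    # User-specified values override the architecture defaults
    if runner != "auto":
        resolved_runner = runner
    if convert not in {"auto", "none"}:
        resolved_convert = convert
    # Fall back to default values if still unset
    if resolved_runner is None:
        resolved_runner = "generate"
    if resolved_convert in {None, "none"}:
        resolved_convert = "embed"
    # Pooling runners map to the new Transformers backend pooling classes
    if resolved_runner == "pooling":
        if resolved_convert == "embed":
            return "TransformersEmbeddingModel"
        if resolved_convert == "classify":
            return "TransformersForSequenceClassification"
    # Generate runners map to the existing generate classes (assumed)
    return ("TransformersForMultimodalLM" if is_multimodal
            else "TransformersForCausalLM")

For example, an architecture whose defaults are ("pooling", "classify") resolves to TransformersForSequenceClassification even when the user leaves both runner and convert on "auto".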
13 changes: 13 additions & 0 deletions vllm/config/utils.py
@@ -4,6 +4,7 @@
import ast
import inspect
import textwrap
from collections.abc import Iterable
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
from typing import TYPE_CHECKING, Any, Protocol, TypeVar

@@ -52,6 +53,18 @@ def get_field(cls: ConfigType, name: str) -> Field:
f"{cls.__name__}.{name} must have a default value or default factory.")


def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any:
"""
A helper function that retrieves an attribute from an object which may
have multiple possible names. This is useful when fetching attributes from
arbitrary `transformers.PretrainedConfig` instances.
"""
for name in names:
if hasattr(object, name):
return getattr(object, name)
return default


def contains_object_print(text: str) -> bool:
"""
Check if the text looks like a printed Python object, e.g.
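The new `getattr_iter` helper is exercised by the Transformers backend further down, where BERT-style configs expose `hidden_size` while some embedding checkpoints expose `embedding_size` instead. A minimal usage sketch, using a stand-in config object rather than a real `transformers.PretrainedConfig`:

from types import SimpleNamespace

from vllm.config.utils import getattr_iter

# Stand-in config; only `hidden_size` is set.
text_config = SimpleNamespace(hidden_size=768)

# Returns the first attribute that exists, or the default if none do.
embedding_dim = getattr_iter(text_config, ("embedding_size", "hidden_size"), None)
assert embedding_dim == 768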
5 changes: 3 additions & 2 deletions vllm/model_executor/models/registry.py
@@ -306,9 +306,10 @@
}

_TRANSFORMERS_BACKEND_MODELS = {
"TransformersModel": ("transformers", "TransformersModel"),
"TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"), # noqa: E501
"TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"), # noqa: E501
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
}
# yapf: enable

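The registry entries above are `(module, class)` pairs rather than imported classes, so a backend class is only imported when it is actually requested. A rough sketch of that lazy resolution (simplified; the real registry adds caching and other indirection around this step) could look like:

import importlib

def load_transformers_backend_cls(module_name: str, class_name: str):
    # e.g. load_transformers_backend_cls("transformers_pooling",
    #                                    "TransformersEmbeddingModel")
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)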
81 changes: 10 additions & 71 deletions vllm/model_executor/models/transformers.py
@@ -31,6 +31,7 @@
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
@@ -486,10 +487,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

# Input embeddings
if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
names = ("embedding_size", "hidden_size")
embedding_dim = getattr_iter(self.text_config, names, None)
assert embedding_dim is not None
self.model.set_input_embeddings(
VocabParallelEmbedding(
self.text_config.vocab_size,
self.text_config.hidden_size,
embedding_dim=embedding_dim,
org_num_embeddings=self.text_config.vocab_size,
quant_config=self.quant_config,
))
@@ -645,7 +649,9 @@ def create_attention_instances(
attn_type=attn_type)
return attention_instances

def init_parameters(self, module: nn.Module):
def init_parameters(self,
module: nn.Module,
dtype: Optional[torch.dtype] = None):
"""
If a `parameter` is on the `meta` device, then its parent
`module` is the original module created by:
@@ -659,11 +665,11 @@ def init_parameters(self, module: nn.Module):
if param.device == torch.device("meta"):
new_param = nn.Parameter(
torch.empty_like(param.data,
dtype=self.model_config.dtype,
dtype=dtype or self.model_config.dtype,
device=self.device_config.device))
setattr(module, name, new_param)
for child in module.children():
self.init_parameters(child)
self.init_parameters(child, dtype)

def forward(
self,
@@ -712,73 +718,6 @@ def load_weights(self, weights: Iterable[tuple[str,
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersModel(TransformersBase):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# Handle BERT-like models
"bert": "model",
# Add `model.` prefix for base model checkpoints
"": "model.",
# Remove `model.` prefix if it was already there
"model.model.": "model.",
# Pooling adapters will be adjacent to `model`
"model.pooler": "pooler",
"model.score": "score",
# Classifier adapter's classifier layer is renamed to score
"model.classifier": "score",
},
orig_to_new_suffix={
# Replace legacy suffixes used for norms
".gamma": ".weight",
".beta": ".bias",
})

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)

# After creating a pooling model, `pooler` will be duplicated.
# The one inside `model` comes from the Transformers modelling code.
# The one after `model` is an adapter from vLLM.
# We want to use the adapter so we nullify the original pooler.
if getattr(self.model, "pooler", None) is not None:
self.skip_prefixes.append("pooler.")
self.model.pooler = torch.nn.Identity()

# Some encoder models have the position_ids buffer in the checkpoint.
# vLLM will always pass position_ids as an argument, so we skip loading
# the buffer if it exists
self.skip_substrs.append("position_ids")

# Some encoder models have the bias of the final classifier layer
# in the checkpoint. vLLM does not use this bias, so we skip loading
# it if it exists
self.skip_substrs.append("score.bias")

def create_attention_instances(
self, attn_type: AttentionType = AttentionType.DECODER):
# TODO(hmellor): Better way to detect encoder models
# In encoder models, the attention layers will have `is_causal=False`
is_encoder = lambda m: not getattr(m, "is_causal", True)
# vLLM does not support encoder-decoder models, so if any encoder layer
# is found, we assume the whole model is an encoder model
if any(is_encoder(m) for m in self.model.modules()):
attn_type = AttentionType.ENCODER_ONLY

# Check minimum transformers version for encoder models support
if attn_type == AttentionType.ENCODER_ONLY:
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("4.57.0.dev0")
if installed < required:
raise ValueError(
"Encoder models with the Transformers backend require "
f"transformers>={required}, but got {installed}")

return super().create_attention_instances(attn_type)


@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):
