Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors


class MyGemma2Embedding(nn.Module):
Expand All @@ -24,7 +23,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

self._pooler = Pooler.from_config_with_defaults(
self.pooler = Pooler.from_config_with_defaults(
vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
Expand Down Expand Up @@ -54,13 +53,6 @@ def forward(
# Return all-zero embeddings
return torch.zeros_like(hidden_states)

def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """Pool the given hidden states.

    Thin delegation wrapper: forwards directly to the internally
    configured ``self._pooler`` built in ``__init__``.
    """
    return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

weights = self.hf_to_vllm_mapper.apply(weights)
Expand Down
123 changes: 79 additions & 44 deletions vllm/model_executor/layers/pooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,25 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, Optional, TypeVar, Union
from typing import Callable, Literal, Optional, TypeVar, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from typing_extensions import assert_never

from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import ( # noqa: E501
PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
PoolingTask = Literal["encode", "embed", "classify", "score"]


class PoolingType(IntEnum):
Expand Down Expand Up @@ -64,6 +67,48 @@
)


class Pooler(nn.Module, ABC):
"""The interface required for all poolers used in pooling models in vLLM."""

@staticmethod
def from_config_with_defaults(
    pooler_config: PoolerConfig,
    pooling_type: PoolingType,
    normalize: bool,
    softmax: bool,
    step_tag_id: Optional[int] = None,
    returned_token_ids: Optional[list[int]] = None,
) -> "Pooler":
    """Build a concrete pooler from config, with model-supplied defaults.

    The ``pooling_type``/``normalize``/``softmax`` arguments act as
    defaults that ``pooler_config`` may override (resolution is done by
    ``ResolvedPoolingConfig.from_config_with_defaults``).

    Returns a ``StepPooler`` for ``PoolingType.STEP``; every other
    pooling type is handled by ``SimplePooler``.
    """
    resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
        pooler_config=pooler_config,
        pooling_type=pooling_type,
        normalize=normalize,
        softmax=softmax,
        step_tag_id=step_tag_id,
        returned_token_ids=returned_token_ids,
    )

    # STEP pooling needs token-level logic that SimplePooler cannot express.
    if pooling_type == PoolingType.STEP:
        return StepPooler.from_config(resolved_config)

    return SimplePooler.from_config(resolved_config)

def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the intended use of get_pooling_params()? Will it get called from serving_embedding.py somehow?

Copy link
Member Author

@DarkLight1337 DarkLight1337 Jul 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be called by:

  • LLMEngine (and its async version) to validate that the request is supported by the model.
  • The model runner, in order to get information such as use_cross_encoder and logits_processing_needs_token_ids.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The task will be set by our code at API level

Copy link
Member Author

@DarkLight1337 DarkLight1337 Jul 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example:

  • Score API: We set task="score"
  • LLMEngine: Call get_pooling_params with the task to see if it's supported
  • Model runner: Call get_pooling_params to pass use_cross_encoder to the pooler.

This abstraction lets each model define how to handle each task, instead of having static logic at the API level

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is good, we're starting to accumulate too much logic at the entrypoint level.

Just to understand the last detail: is EmbeddingCompletionRequest.to_pooling_params() going to be replaced with something like EmbeddingCompletionRequest.to_pooling_task()?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, since we still have some parameters (e.g. dimensions) that need to be forwarded. I will add a task attribute to PoolingParams so that the task can be set in to_pooling_params

"""
Construct the pooling parameters to use for a task,
or `None` if the task is not supported.
"""
return None

@abstractmethod
def forward(
    self,
    hidden_states: Union[list[torch.Tensor], torch.Tensor],
    pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
    """Pool ``hidden_states`` according to ``pooling_metadata``.

    Subclasses must implement this; the base class only defines the
    interface.
    """
    raise NotImplementedError


def get_prompt_lens(
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
Expand Down Expand Up @@ -104,17 +149,6 @@
return PoolerOutput(outputs=all_outputs)


class BasePooler(nn.Module, ABC):
    """Abstract interface for pooling layers.

    Fix: inherit from ``ABC`` so that ``@abstractmethod`` is actually
    enforced — with ``nn.Module`` alone the class lacks ``ABCMeta`` and
    could be instantiated without implementing ``forward``. This also
    matches the ``Pooler(nn.Module, ABC)`` base class used elsewhere in
    this file.
    """

    @abstractmethod
    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` into a ``PoolerOutput``; must be overridden."""
        raise NotImplementedError


class PoolingMethod(nn.Module, ABC):

@staticmethod
Expand Down Expand Up @@ -424,21 +458,17 @@
return self.activation(pooled_data)


class SimplePooler(BasePooler):
class SimplePooler(Pooler):
"""A layer that pools specific information from hidden states.

This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.

Attributes:
pooling_type: The type of pooling to use.
normalize: Whether to normalize the pooled data.
"""

@classmethod
def from_config_with_defaults(

Check failure on line 471 in vllm/model_executor/layers/pooler.py

View workflow job for this annotation

GitHub Actions / pre-commit

Signature of "from_config_with_defaults" incompatible with supertype "Pooler" [override]

Check failure on line 471 in vllm/model_executor/layers/pooler.py

View workflow job for this annotation

GitHub Actions / pre-commit

Signature of "from_config_with_defaults" incompatible with supertype "Pooler" [override]

Check failure on line 471 in vllm/model_executor/layers/pooler.py

View workflow job for this annotation

GitHub Actions / pre-commit

Signature of "from_config_with_defaults" incompatible with supertype "Pooler" [override]

Check failure on line 471 in vllm/model_executor/layers/pooler.py

View workflow job for this annotation

GitHub Actions / pre-commit

Signature of "from_config_with_defaults" incompatible with supertype "Pooler" [override]
cls,
pooler_config: PoolerConfig,
pooling_type: PoolingType,
Expand Down Expand Up @@ -471,6 +501,17 @@
self.pooling = pooling
self.head = head

def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Return the pooling parameters for ``task``, or ``None`` if the
    task is not supported by this pooler.

    ``"encode"`` is always supported. ``"embed"``/``"classify"``/
    ``"score"`` are only supported when the pooling method produces one
    vector per sequence (last-token, CLS, or mean pooling).
    """
    if task == "encode":
        return PoolingParams()
    # NOTE: use chained `==` comparisons instead of `task in (...)`:
    # mypy does not negatively narrow Literal types through `in`, which
    # made the assert_never below fail type checking (see the pre-commit
    # "incompatible type ... expected Never" failures on this line).
    if task == "embed" or task == "classify" or task == "score":
        if isinstance(self.pooling, (LastPool, CLSPool, MeanPool)):
            return PoolingParams()

        return None

    assert_never(task)

Check failure on line 513 in vllm/model_executor/layers/pooler.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "assert_never" has incompatible type "Literal['embed', 'classify', 'score']"; expected "Never" [arg-type]

Check failure on line 513 in vllm/model_executor/layers/pooler.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "assert_never" has incompatible type "Literal['embed', 'classify', 'score']"; expected "Never" [arg-type]

Check failure on line 513 in vllm/model_executor/layers/pooler.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "assert_never" has incompatible type "Literal['embed', 'classify', 'score']"; expected "Never" [arg-type]

def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
Expand All @@ -481,7 +522,7 @@
return build_output(pooled_data)


class StepPooler(BasePooler):
class StepPooler(Pooler):

@classmethod
def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler":
Expand Down Expand Up @@ -543,6 +584,14 @@

return pooled_data

def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Return the pooling parameters for ``task``, or ``None`` if the
    task is not supported.

    Step pooling only supports ``"encode"`` (and needs the token ids for
    logits processing); embedding/classification/scoring are rejected.
    """
    if task == "encode":
        return PoolingParams(logits_processing_needs_token_ids=True)
    # NOTE: chained `==` comparisons instead of `task in (...)` so mypy
    # narrows the Literal type and the assert_never below type-checks
    # (see the pre-commit "expected Never" failures on this line).
    if task == "embed" or task == "classify" or task == "score":
        return None

    assert_never(task)

Check failure on line 593 in vllm/model_executor/layers/pooler.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "assert_never" has incompatible type "Literal['embed', 'classify', 'score']"; expected "Never" [arg-type]

Check failure on line 593 in vllm/model_executor/layers/pooler.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "assert_never" has incompatible type "Literal['embed', 'classify', 'score']"; expected "Never" [arg-type]

Check failure on line 593 in vllm/model_executor/layers/pooler.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "assert_never" has incompatible type "Literal['embed', 'classify', 'score']"; expected "Never" [arg-type]

def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
Expand All @@ -553,32 +602,6 @@
return build_output(pooled_data)


class Pooler(nn.Module):
    """Factory namespace for constructing concrete poolers from config."""

    @staticmethod
    def from_config_with_defaults(
        pooler_config: PoolerConfig,
        pooling_type: PoolingType,
        normalize: bool,
        softmax: bool,
        step_tag_id: Optional[int] = None,
        returned_token_ids: Optional[list[int]] = None,
    ) -> BasePooler:
        """Build a pooler from ``pooler_config``, using the remaining
        arguments as model-supplied defaults that the config may override.

        Returns a ``StepPooler`` for ``PoolingType.STEP`` and a
        ``SimplePooler`` for every other pooling type.
        """
        resolved_config = ResolvedPoolingConfig.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=pooling_type,
            normalize=normalize,
            softmax=softmax,
            step_tag_id=step_tag_id,
            returned_token_ids=returned_token_ids,
        )

        if pooling_type == PoolingType.STEP:
            return StepPooler.from_config(resolved_config)

        return SimplePooler.from_config(resolved_config)


PoolingFn = Callable[
[Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
Union[torch.Tensor, list[torch.Tensor]]]
Expand Down Expand Up @@ -618,6 +641,18 @@
return (self.cross_encoder_act_fn
if use_cross_encoder else self.classification_act_fn)

def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Return the pooling parameters for ``task``, or ``None`` when the
    task is unsupported by a classification pooler.
    """
    if task == "embed":
        # A classification head cannot produce embeddings.
        return None
    if task == "score":
        # Scoring runs the classifier in cross-encoder mode.
        return PoolingParams(use_cross_encoder=True)
    if task == "encode" or task == "classify":
        return PoolingParams()

    assert_never(task)

def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
Expand Down
27 changes: 5 additions & 22 deletions vllm/model_executor/models/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast

import torch
import torch.nn as nn
Expand Down Expand Up @@ -42,8 +42,7 @@ def _create_pooling_model_cls(
default_softmax: bool,
) -> _T:
# Lazy import
from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.layers.pooler import Pooler

from .utils import AutoWeightsLoader, WeightsMapper

Expand Down Expand Up @@ -73,20 +72,13 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None

self._pooler = Pooler.from_config_with_defaults(
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)

def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
    """Pool hidden states by delegating to the configured ``self._pooler``."""
    return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# TODO: Support uninitialized params tracking

Expand Down Expand Up @@ -171,10 +163,8 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import (ClassifierPooler,
PoolerOutput, PoolingType,
SimplePooler)
PoolingType, SimplePooler)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors

from .utils import maybe_prefix
Expand Down Expand Up @@ -213,7 +203,7 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
softmax=True,
)

self._pooler = ClassifierPooler(
self.pooler = ClassifierPooler(
vllm_config.model_config,
pooling=pooler.pooling,
classifier=self._classifier,
Expand All @@ -234,13 +224,6 @@ def forward(
return super().forward(input_ids, positions, intermediate_tensors,
inputs_embeds)

def pooler(
    self,
    hidden_states: Union[torch.Tensor, list[torch.Tensor]],
    pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
    """Pool hidden states by delegating to the configured ``self._pooler``."""
    return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None)
method = getattr(self.config, "method", None)
Expand Down
20 changes: 3 additions & 17 deletions vllm/model_executor/models/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
Expand Down Expand Up @@ -408,7 +408,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = self._build_pooler(pooler_config)
self.pooler = self._build_pooler(pooler_config)

def forward(
self,
Expand All @@ -422,13 +422,6 @@ def forward(
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)

def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """Pool hidden states by delegating to the configured ``self._pooler``."""
    return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights_list = list(weights)

Expand Down Expand Up @@ -476,7 +469,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
embedding_class=BertEmbedding,
add_pooling_layer=True)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = ClassifierPooler(
self.pooler = ClassifierPooler(
vllm_config.model_config,
pooling=self.bert.pooler,
classifier=self.classifier,
Expand All @@ -487,13 +480,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loaded_params = loader.load_weights(weights)
return loaded_params

def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """Pool hidden states by delegating to the configured ``self._pooler``."""
    return self._pooler(hidden_states, pooling_metadata)

def forward(
self,
input_ids: Optional[torch.Tensor],
Expand Down
Loading