Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions vllm/model_executor/models/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
PoolingParamsUpdate,
)
from vllm.model_executor.layers.pooler.seqwise import (
CLSPool,
SequencePooler,
SequencePoolerHeadOutput,
SequencePoolerOutput,
SequencePoolingMethodOutput,
get_seq_pooling_method,
)
from vllm.model_executor.layers.pooler.tokwise import (
pooler_for_token_classify,
Expand Down Expand Up @@ -94,9 +94,9 @@ def forward(


class BertPooler(SequencePooler):
def __init__(self, config: BertConfig):
def __init__(self, config: BertConfig, pooler_config: PoolerConfig):
super().__init__(
pooling=CLSPool(),
pooling=get_seq_pooling_method(pooler_config.seq_pooling_type),
head=self.head,
)

Expand Down Expand Up @@ -450,7 +450,11 @@ def __init__(
)

config = vllm_config.model_config.hf_config
self.pooler = BertPooler(config)

pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None

self.pooler = BertPooler(config, pooler_config)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
other_weights, loaded_stacked_params = self._load_weights(weights)
Expand Down Expand Up @@ -711,6 +715,8 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
)

# None of vLLM's built-in sequence pooling types are
# applicable so it is overwritten by SPLADESparsePooler
pooling_mode = getattr(self, "_splade_pooling", "max")

cls_id = getattr(cfg, "cls_token_id", None)
Expand Down
10 changes: 9 additions & 1 deletion vllm/model_executor/models/bert_with_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ def __init__(
add_pooling_layer: bool = False,
):
super().__init__()

self.vllm_config = vllm_config
self.add_pooling_layer = add_pooling_layer
self.config = vllm_config.model_config.hf_config
Expand All @@ -463,7 +464,14 @@ def __init__(
rotary_kwargs=self.config.rotary_kwargs,
prefix=f"{prefix}.encoder",
)
self.pooler = BertPooler(self.config) if add_pooling_layer else None

if add_pooling_layer:
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None

self.pooler = BertPooler(self.config, pooler_config)
else:
self.pooler = None

def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids)
Expand Down
13 changes: 9 additions & 4 deletions vllm/model_executor/models/gritlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import torch

from vllm.config import ModelConfig, VllmConfig
from vllm.config import ModelConfig, PoolerConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import (
DispatchPooler,
Expand All @@ -17,6 +17,7 @@
SequencePoolerHeadOutput,
SequencePoolingMethod,
SequencePoolingMethodOutput,
get_seq_pooling_method,
)
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from vllm.model_executor.models.llama import LlamaForCausalLM
Expand Down Expand Up @@ -177,9 +178,13 @@ def forward(


class GritLMPooler(SequencePooler):
def __init__(self, model_config: ModelConfig):
def __init__(self, model_config: ModelConfig, pooler_config: PoolerConfig):
super().__init__(
pooling=GritLMMeanPool(model_config),
pooling=(
GritLMMeanPool(model_config)
if pooler_config.seq_pooling_type == "MEAN"
else get_seq_pooling_method(pooler_config.seq_pooling_type)
),
head=self.head,
)

Expand Down Expand Up @@ -235,6 +240,6 @@ def __init__(
self.pooler = DispatchPooler(
{
"token_embed": pooler_for_token_embed(pooler_config),
"embed": GritLMPooler(vllm_config.model_config),
"embed": GritLMPooler(vllm_config.model_config, pooler_config),
}
)
16 changes: 12 additions & 4 deletions vllm/model_executor/models/modernbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from transformers.activations import ACT2FN

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.attention.encoder_only_attention import (
EncoderOnlyAttention,
Expand Down Expand Up @@ -282,9 +282,14 @@ def forward(


class ModernBertPooler(SequencePooler):
def __init__(self, config: ModernBertConfig):
def __init__(self, config: ModernBertConfig, pooler_config: PoolerConfig):
hf_pooling_type = config.classifier_pooling.upper()
# vllm_pooling_type = pooler_config.seq_pooling_type
# Currently we don't have a way to see if the user set the pooling type
# explicitly or not, so we always use the HF pooling type for now.

super().__init__(
pooling=get_seq_pooling_method(config.classifier_pooling.upper()),
pooling=get_seq_pooling_method(hf_pooling_type),
head=self.head,
)

Expand Down Expand Up @@ -314,7 +319,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

config = vllm_config.model_config.hf_config

self.config = config
self.model = ModernBertModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
Expand All @@ -324,11 +331,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config.num_labels,
dtype=vllm_config.model_config.head_dtype,
)
self.pooling = ModernBertPooler(config)

pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None

self.pooling = ModernBertPooler(config, pooler_config)

self.pooler = DispatchPooler.for_seq_cls(
pooler_config,
pooling=self.pooling,
Expand Down
4 changes: 1 addition & 3 deletions vllm/model_executor/models/roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.pooler.seqwise import CLSPool
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.bert import (
TOKEN_TYPE_SHIFT,
Expand Down Expand Up @@ -86,7 +85,7 @@ def __init__(self, model_config: "ModelConfig"):
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# CLSPool has already been applied in `pooling`
# Token extraction has already been applied in `pooler.pooling`
x = self.dense(x)
x = torch.tanh(x)
x = self.out_proj(x)
Expand Down Expand Up @@ -194,7 +193,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.pooler = DispatchPooler.for_seq_cls(
pooler_config,
pooling=CLSPool(),
classifier=self.classifier,
)

Expand Down
12 changes: 6 additions & 6 deletions vllm/model_executor/models/transformers/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from vllm.config.utils import getattr_iter
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.pooler.seqwise import CLSPool
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.models.interfaces_base import VllmModelForPooling

Expand All @@ -32,7 +31,7 @@


class EmbeddingMixin(VllmModelForPooling):
default_pooling_type = "CLS"
default_seq_pooling_type = "CLS"

def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
# Skip VllmModelForPooling.__init__ and call the next class in MRO
Expand All @@ -47,7 +46,7 @@ def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):


class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
default_pooling_type = "CLS"
default_seq_pooling_type = "CLS"

def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
# Skip VllmModelForPooling.__init__ and call the next class in MRO
Expand Down Expand Up @@ -85,8 +84,10 @@ def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
self.init_parameters(self.classifier, dtype=self.model_config.head_dtype)

class ClassifierWithReshape(self.classifier.__class__):
"""CLSPool has already been applied in `pooling`.
Add dim to match expected input shape of `classifier.forward`."""
"""
Token extraction has already been applied in `pooler.pooling`.
Add dim to match expected input shape of `classifier.forward`.
"""

def forward(self, *args, **kwargs):
if len(args) > 0:
Expand All @@ -97,6 +98,5 @@ def forward(self, *args, **kwargs):

self.pooler = DispatchPooler.for_seq_cls(
pooler_config,
pooling=CLSPool(),
classifier=self.classifier,
)