From ecb0656c9d588648a910173d013d93696bd75ad0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 11 Jan 2026 14:13:26 +0000 Subject: [PATCH 1/7] [Model] Avoid hardcoding pooling type Signed-off-by: DarkLight1337 --- vllm/model_executor/models/bert.py | 14 ++++++++++---- vllm/model_executor/models/bert_with_rope.py | 8 +++++++- vllm/model_executor/models/gritlm.py | 7 ++++++- vllm/model_executor/models/modernbert.py | 15 ++++++++++++--- vllm/model_executor/models/roberta.py | 4 +--- .../model_executor/models/transformers/pooling.py | 8 ++++---- 6 files changed, 40 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index b09e76015e6f..f2803122d5d5 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -25,11 +25,11 @@ PoolingParamsUpdate, ) from vllm.model_executor.layers.pooler.seqwise import ( - CLSPool, SequencePooler, SequencePoolerHeadOutput, SequencePoolerOutput, SequencePoolingMethodOutput, + get_seq_pooling_method, ) from vllm.model_executor.layers.pooler.tokwise import ( pooler_for_token_classify, @@ -94,9 +94,9 @@ def forward( class BertPooler(SequencePooler): - def __init__(self, config: BertConfig): + def __init__(self, config: BertConfig, seq_pooling_type: str): super().__init__( - pooling=CLSPool(), + pooling=get_seq_pooling_method(seq_pooling_type), head=self.head, ) @@ -450,7 +450,11 @@ def __init__( ) config = vllm_config.model_config.hf_config - self.pooler = BertPooler(config) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = BertPooler(config, pooler_config.seq_pooling_type) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: other_weights, loaded_stacked_params = self._load_weights(weights) @@ -711,6 +715,8 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12), ) + # None of vLLM's 
built-in sequence pooling types are + # applicable so it is overwritten by SPLADESparsePooler pooling_mode = getattr(self, "_splade_pooling", "max") cls_id = getattr(cfg, "cls_token_id", None) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index cfe350db1aa4..2fb30b299d2b 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -453,6 +453,7 @@ def __init__( add_pooling_layer: bool = False, ): super().__init__() + self.vllm_config = vllm_config self.add_pooling_layer = add_pooling_layer self.config = vllm_config.model_config.hf_config @@ -463,7 +464,12 @@ def __init__( rotary_kwargs=self.config.rotary_kwargs, prefix=f"{prefix}.encoder", ) - self.pooler = BertPooler(self.config) if add_pooling_layer else None + + if add_pooling_layer: + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = BertPooler(self.config, pooler_config.seq_pooling_type) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embeddings(input_ids) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 34d7e5c9286f..ac30f1a3ecc0 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -17,6 +17,7 @@ SequencePoolerHeadOutput, SequencePoolingMethod, SequencePoolingMethodOutput, + pooler_for_embed, ) from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed from vllm.model_executor.models.llama import LlamaForCausalLM @@ -235,6 +236,10 @@ def __init__( self.pooler = DispatchPooler( { "token_embed": pooler_for_token_embed(pooler_config), - "embed": GritLMPooler(vllm_config.model_config), + "embed": ( + GritLMPooler(vllm_config.model_config) + if pooler_config.seq_pooling_type == "MEAN" + else pooler_for_embed(pooler_config) + ), } ) diff --git a/vllm/model_executor/models/modernbert.py 
b/vllm/model_executor/models/modernbert.py index b80258daf375..879530df4606 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -282,9 +282,9 @@ def forward( class ModernBertPooler(SequencePooler): - def __init__(self, config: ModernBertConfig): + def __init__(self, config: ModernBertConfig, seq_pooling_type: str): super().__init__( - pooling=get_seq_pooling_method(config.classifier_pooling.upper()), + pooling=get_seq_pooling_method(seq_pooling_type), head=self.head, ) @@ -314,7 +314,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config = vllm_config.model_config.hf_config + self.config = config self.model = ModernBertModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert") @@ -324,11 +326,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.num_labels, dtype=vllm_config.model_config.head_dtype, ) - self.pooling = ModernBertPooler(config) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None + hf_pooling_type = config.classifier_pooling.upper() + vllm_pooling_type = pooler_config.seq_pooling_type + assert hf_pooling_type == vllm_pooling_type, ( + f"Found inconsistent sequence pooling type: {hf_pooling_type=!r} " + f"vs. 
{vllm_pooling_type=!r}" + ) + self.pooling = ModernBertPooler(config, vllm_pooling_type) + self.pooler = DispatchPooler.for_seq_cls( pooler_config, pooling=self.pooling, diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index f52123901827..7bf9a68824d8 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -9,7 +9,6 @@ from vllm.config import ModelConfig, VllmConfig from vllm.model_executor.layers.pooler import DispatchPooler -from vllm.model_executor.layers.pooler.seqwise import CLSPool from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.models.bert import ( TOKEN_TYPE_SHIFT, @@ -86,7 +85,7 @@ def __init__(self, model_config: "ModelConfig"): ) def forward(self, x: torch.Tensor) -> torch.Tensor: - # CLSPool has already been applied in `pooling` + # Token extraction has already been applied in `pooler.pooling` x = self.dense(x) x = torch.tanh(x) x = self.out_proj(x) @@ -194,7 +193,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.pooler = DispatchPooler.for_seq_cls( pooler_config, - pooling=CLSPool(), classifier=self.classifier, ) diff --git a/vllm/model_executor/models/transformers/pooling.py b/vllm/model_executor/models/transformers/pooling.py index 470ca48eedb7..0b169cfff10c 100644 --- a/vllm/model_executor/models/transformers/pooling.py +++ b/vllm/model_executor/models/transformers/pooling.py @@ -23,7 +23,6 @@ from vllm.config.utils import getattr_iter from vllm.model_executor.layers.pooler import DispatchPooler -from vllm.model_executor.layers.pooler.seqwise import CLSPool from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.model_executor.models.interfaces_base import VllmModelForPooling @@ -85,8 +84,10 @@ def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): self.init_parameters(self.classifier, dtype=self.model_config.head_dtype) class 
ClassifierWithReshape(self.classifier.__class__): - """CLSPool has already been applied in `pooling`. - Add dim to match expected input shape of `classifier.forward`.""" + """ + Token extraction has already been applied in `pooler.pooling`. + Add dim to match expected input shape of `classifier.forward`. + """ def forward(self, *args, **kwargs): if len(args) > 0: @@ -97,6 +98,5 @@ def forward(self, *args, **kwargs): self.pooler = DispatchPooler.for_seq_cls( pooler_config, - pooling=CLSPool(), classifier=self.classifier, ) From 7c25d15dcbddfd9b16e3b8a3da807b2dfc135fb8 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 11 Jan 2026 14:16:57 +0000 Subject: [PATCH 2/7] Standardize Signed-off-by: DarkLight1337 --- vllm/model_executor/models/gritlm.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index ac30f1a3ecc0..69c1227de4a5 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -17,7 +17,7 @@ SequencePoolerHeadOutput, SequencePoolingMethod, SequencePoolingMethodOutput, - pooler_for_embed, + get_seq_pooling_method, ) from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed from vllm.model_executor.models.llama import LlamaForCausalLM @@ -178,9 +178,13 @@ def forward( class GritLMPooler(SequencePooler): - def __init__(self, model_config: ModelConfig): + def __init__(self, model_config: ModelConfig, seq_pooling_type: str): super().__init__( - pooling=GritLMMeanPool(model_config), + pooling=( + GritLMMeanPool(model_config) + if seq_pooling_type == "MEAN" + else get_seq_pooling_method(seq_pooling_type) + ), head=self.head, ) @@ -236,10 +240,9 @@ def __init__( self.pooler = DispatchPooler( { "token_embed": pooler_for_token_embed(pooler_config), - "embed": ( - GritLMPooler(vllm_config.model_config) - if pooler_config.seq_pooling_type == "MEAN" - else pooler_for_embed(pooler_config) + "embed": 
GritLMPooler( + vllm_config.model_config, + pooler_config.seq_pooling_type, ), } ) From 5d5064e1e13388b8bbc2390bfff84ac433f79598 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 11 Jan 2026 14:21:30 +0000 Subject: [PATCH 3/7] Fallback Signed-off-by: DarkLight1337 --- vllm/model_executor/models/bert_with_rope.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 2fb30b299d2b..4e564db3df02 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -470,6 +470,8 @@ def __init__( assert pooler_config is not None self.pooler = BertPooler(self.config, pooler_config.seq_pooling_type) + else: + self.pooler = None def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embeddings(input_ids) From 7847f34389993354951cb0a62ed2c6913e6013e1 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 11 Jan 2026 14:27:17 +0000 Subject: [PATCH 4/7] Handle inside Signed-off-by: DarkLight1337 --- vllm/model_executor/models/bert.py | 4 ++-- vllm/model_executor/models/bert_with_rope.py | 2 +- vllm/model_executor/models/gritlm.py | 13 +++++------- vllm/model_executor/models/modernbert.py | 21 ++++++++++---------- 4 files changed, 19 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index f2803122d5d5..c821b8b4d139 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -94,9 +94,9 @@ def forward( class BertPooler(SequencePooler): - def __init__(self, config: BertConfig, seq_pooling_type: str): + def __init__(self, config: BertConfig, pooler_config: PoolerConfig): super().__init__( - pooling=get_seq_pooling_method(seq_pooling_type), + pooling=get_seq_pooling_method(pooler_config.seq_pooling_type), head=self.head, ) diff --git a/vllm/model_executor/models/bert_with_rope.py 
b/vllm/model_executor/models/bert_with_rope.py index 4e564db3df02..8f96170627c0 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -469,7 +469,7 @@ def __init__( pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = BertPooler(self.config, pooler_config.seq_pooling_type) + self.pooler = BertPooler(self.config, pooler_config) else: self.pooler = None diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 69c1227de4a5..08ace0c8ea81 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -5,7 +5,7 @@ import numpy as np import torch -from vllm.config import ModelConfig, VllmConfig +from vllm.config import ModelConfig, PoolerConfig, VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.pooler import ( DispatchPooler, @@ -178,12 +178,12 @@ def forward( class GritLMPooler(SequencePooler): - def __init__(self, model_config: ModelConfig, seq_pooling_type: str): + def __init__(self, model_config: ModelConfig, pooler_config: PoolerConfig): super().__init__( pooling=( GritLMMeanPool(model_config) - if seq_pooling_type == "MEAN" - else get_seq_pooling_method(seq_pooling_type) + if pooler_config.seq_pooling_type == "MEAN" + else get_seq_pooling_method(pooler_config.seq_pooling_type) ), head=self.head, ) @@ -240,9 +240,6 @@ def __init__( self.pooler = DispatchPooler( { "token_embed": pooler_for_token_embed(pooler_config), - "embed": GritLMPooler( - vllm_config.model_config, - pooler_config.seq_pooling_type, - ), + "embed": GritLMPooler(vllm_config.model_config, pooler_config), } ) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 879530df4606..11b7ff882083 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -8,7 +8,7 @@ from transformers.activations import ACT2FN 
from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.attention.encoder_only_attention import ( EncoderOnlyAttention, @@ -282,9 +282,16 @@ def forward( class ModernBertPooler(SequencePooler): - def __init__(self, config: ModernBertConfig, seq_pooling_type: str): + def __init__(self, config: ModernBertConfig, pooler_config: PoolerConfig): + hf_pooling_type = config.classifier_pooling.upper() + vllm_pooling_type = pooler_config.seq_pooling_type + assert hf_pooling_type == vllm_pooling_type, ( + f"Found inconsistent sequence pooling type: {hf_pooling_type=!r} " + f"vs. {vllm_pooling_type=!r}" + ) + super().__init__( - pooling=get_seq_pooling_method(seq_pooling_type), + pooling=get_seq_pooling_method(pooler_config.seq_pooling_type), head=self.head, ) @@ -330,13 +337,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - hf_pooling_type = config.classifier_pooling.upper() - vllm_pooling_type = pooler_config.seq_pooling_type - assert hf_pooling_type == vllm_pooling_type, ( - f"Found inconsistent sequence pooling type: {hf_pooling_type=!r} " - f"vs. 
{vllm_pooling_type=!r}" - ) - self.pooling = ModernBertPooler(config, vllm_pooling_type) + self.pooling = ModernBertPooler(config, pooler_config) self.pooler = DispatchPooler.for_seq_cls( pooler_config, From 79939080bd005f3091352b568c8d7347e286fe62 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 11 Jan 2026 14:27:43 +0000 Subject: [PATCH 5/7] Fix Signed-off-by: DarkLight1337 --- vllm/model_executor/models/bert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index c821b8b4d139..59e7688531c0 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -454,7 +454,7 @@ def __init__( pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = BertPooler(config, pooler_config.seq_pooling_type) + self.pooler = BertPooler(config, pooler_config) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: other_weights, loaded_stacked_params = self._load_weights(weights) From ce4cef5ef0581d6b212ad0c5f790e0d53f3ef6bb Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 11 Jan 2026 15:14:32 +0000 Subject: [PATCH 6/7] Fix Signed-off-by: DarkLight1337 --- vllm/model_executor/models/modernbert.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 11b7ff882083..2b56540e6a0e 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -284,14 +284,12 @@ def forward( class ModernBertPooler(SequencePooler): def __init__(self, config: ModernBertConfig, pooler_config: PoolerConfig): hf_pooling_type = config.classifier_pooling.upper() - vllm_pooling_type = pooler_config.seq_pooling_type - assert hf_pooling_type == vllm_pooling_type, ( - f"Found inconsistent sequence pooling type: {hf_pooling_type=!r} " - f"vs. 
{vllm_pooling_type=!r}" - ) + # vllm_pooling_type = pooler_config.seq_pooling_type + # Currently we don't have a way to see if the user set the pooling type + # explicitly or not, so we always use the HF pooling type for now. super().__init__( - pooling=get_seq_pooling_method(pooler_config.seq_pooling_type), + pooling=get_seq_pooling_method(hf_pooling_type), head=self.head, ) From 69c8fde4aea45239d303afbd532d84a1506d7e20 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 11 Jan 2026 15:42:08 +0000 Subject: [PATCH 7/7] Fix Signed-off-by: DarkLight1337 --- vllm/model_executor/models/transformers/pooling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/transformers/pooling.py b/vllm/model_executor/models/transformers/pooling.py index 0b169cfff10c..8f3173c33e4c 100644 --- a/vllm/model_executor/models/transformers/pooling.py +++ b/vllm/model_executor/models/transformers/pooling.py @@ -31,7 +31,7 @@ class EmbeddingMixin(VllmModelForPooling): - default_pooling_type = "CLS" + default_seq_pooling_type = "CLS" def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): # Skip VllmModelForPooling.__init__ and call the next class in MRO @@ -46,7 +46,7 @@ def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling): - default_pooling_type = "CLS" + default_seq_pooling_type = "CLS" def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): # Skip VllmModelForPooling.__init__ and call the next class in MRO