 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Iterable
-from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast

 import torch
 import torch.nn as nn

+from vllm.model_executor.models.config import VerifyAndUpdateConfig
+
 from .interfaces_base import VllmModelForPooling, is_pooling_model

 if TYPE_CHECKING:
+    from vllm.config import VllmConfig
     from vllm.model_executor.layers.pooler import PoolingType

 _T = TypeVar("_T", bound=type[nn.Module])
@@ -39,7 +42,6 @@ def _create_pooling_model_cls(
     default_softmax: bool,
 ) -> _T:
     # Lazy import
-    from vllm.config import VllmConfig
     from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
     from vllm.model_executor.pooling_metadata import PoolingMetadata

@@ -162,7 +164,6 @@ def as_seq_cls_model(cls: _T) -> _T:
         return cls

     # Lazy import
-    from vllm.config import VllmConfig
     from vllm.model_executor.layers.linear import RowParallelLinear
     from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
     from vllm.model_executor.models.interfaces import SupportsCrossEncoding
@@ -193,6 +194,7 @@ def __init__(
             config = vllm_config.model_config.hf_config
             quant_config = vllm_config.quant_config

+            self.vllm_config = vllm_config
             self.task = vllm_config.model_config.task
             self.pooling_type = (
                 vllm_config.model_config.pooler_config.pooling_type)
@@ -242,6 +244,17 @@ def get_logits(hidden_states):
             ]
             return PoolerOutput(outputs=pooled_outputs)

+        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+            tokens = getattr(self.config, "classifier_from_token", None)
+            method = getattr(self.config, "method", None)
+
+            if tokens is None and method is None:
+                return super().load_weights(weights)
+            else:
+                # Online convert ForCausalLM into
+                # ForSequenceClassification model.
+                return seq_cls_model_loader(self, weights)
+

     ModelForSequenceClassification.__name__ = \
         _get_pooling_model_name(cls.__name__, "ForSequenceClassification")
@@ -277,3 +290,86 @@ def as_reward_model(cls: _T) -> _T:
         _get_pooling_model_name(cls.__name__, "ForReward")

     return ModelForReward  # type: ignore
+
+
+class SequenceClassificationConfig(VerifyAndUpdateConfig):
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        config = vllm_config.model_config.hf_config
+        method = getattr(config, "method", None)
+        tokens = getattr(config, "classifier_from_token", None)
+
+        if method is None:
+            return
+
+        assert tokens is not None
+        assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
+
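+        # from_2_way_softmax folds the (false, true) token pair into a
+        # single logit, so the converted model exposes exactly one label;
+        # other methods keep one label per listed token.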
+        if method == "from_2_way_softmax":
+            assert len(tokens) == 2
+            config.num_labels = 1
+        else:
+            config.num_labels = len(tokens)
+
+
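+# For a 2-way softmax over (false, true) logits, P(true) equals
+# sigmoid(z_true - z_false), so projecting hidden states onto
+# W_lm[true_id] - W_lm[false_id] reproduces the score with a single output.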
+def load_weights_using_from_2_way_softmax(
+        model, weights: Iterable[tuple[str, torch.Tensor]]):
+    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
+    from vllm.model_executor.layers.vocab_parallel_embedding import (
+        ParallelLMHead)
+    from vllm.model_executor.models.utils import AutoWeightsLoader
+
+    model_config = model.vllm_config.model_config
+    tokens = getattr(model.config, "classifier_from_token", [])
+    tokens = cast(list[str], tokens)
+    assert len(tokens) == 2
+
+    device = model.score.weight.device
+
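+    # Materialize an LM head only long enough to read two of its rows;
+    # tied-embedding checkpoints reuse the input embedding matrix instead.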
+    if model.config.tie_word_embeddings:
+        model.lm_head = model.model.embed_tokens
+    else:
+        model.lm_head = ParallelLMHead(model.config.vocab_size,
+                                       model.config.hidden_size,
+                                       quant_config=model.quant_config)
+
+    loader = AutoWeightsLoader(model)
+    loaded_weights = loader.load_weights(weights)
+
+    from vllm.transformers_utils.tokenizer import get_tokenizer
+    tokenizer = get_tokenizer(model_config.tokenizer,
+                              revision=model_config.tokenizer_revision,
+                              tokenizer_mode=model_config.tokenizer_mode,
+                              trust_remote_code=model_config.trust_remote_code)
+
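+    # classifier_from_token is ordered (false_token, true_token), e.g.
+    # ("no", "yes") for Qwen3-Reranker-style checkpoints.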
+    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
+    true_id = tokenizer.convert_tokens_to_ids(tokens[1])
+    weight = model.lm_head.weight.data[true_id].to(device).to(
+        torch.float32) - model.lm_head.weight.data[false_id].to(device).to(
+            torch.float32)
+    model.score.weight.data.copy_(weight)
+
+    del model.lm_head
+    loaded_weights.add("score.weight")
+    loaded_weights.discard("lm_head.weight")
+    return loaded_weights
+
+
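+# Maps the HF config's `method` field to its conversion loader; new
+# conversion schemes need a loader function plus an entry here.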
+SEQ_CLS_LOAD_METHODS = {
+    "from_2_way_softmax": load_weights_using_from_2_way_softmax,
+}
+
+
+def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]):
+    # Online convert ForCausalLM into ForSequenceClassification model.
+    # - from_2_way_softmax:
+    #     - Qwen3ForCausalLM
+    #         - Qwen3-Reranker
+    #     - Qwen2ForCausalLM
+    #         - mxbai-rerank-v2
+
+    config = model.vllm_config.model_config.hf_config
+    method = getattr(config, "method", None)
+    assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported"
+    return SEQ_CLS_LOAD_METHODS[method](model, weights)