diff --git a/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/sparse_embeddings_processor.py b/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/sparse_embeddings_processor.py index 4749d3e81fed..b97f7de13d03 100644 --- a/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/sparse_embeddings_processor.py +++ b/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/sparse_embeddings_processor.py @@ -3,10 +3,10 @@ from collections.abc import Sequence -from vllm.config import VllmConfig +from vllm.config import ModelConfig, PoolerConfig, VllmConfig from vllm.entrypoints.openai.engine.protocol import UsageInfo +from vllm.entrypoints.pooling.base.protocol import EmbedRequestMixin from vllm.inputs.data import PromptType -from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput from vllm.plugins.io_processors.interface import ( IOProcessor, @@ -16,14 +16,13 @@ from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens from .types import ( + EMBED_TASKS, SparseEmbeddingCompletionRequestMixin, SparseEmbeddingResponse, SparseEmbeddingResponseData, SparseEmbeddingTokenWeight, ) -logger = init_logger(__name__) - class BgeM3SparseEmbeddingsProcessor( IOProcessor[SparseEmbeddingCompletionRequestMixin, SparseEmbeddingResponse] @@ -33,6 +32,22 @@ def __init__(self, vllm_config: VllmConfig, renderer: BaseRenderer): self.offline_requests: list[SparseEmbeddingCompletionRequestMixin] = [] self.online_requests: dict[str, SparseEmbeddingCompletionRequestMixin] = {} self.renderer: BaseRenderer = renderer + self.default_pooling_params = {} + pooler_config: PoolerConfig = vllm_config.model_config.pooler_config + if pooler_config is not None: + for param in ["use_activation", "dimensions"]: + if getattr(pooler_config, param, None) is None: + continue + self.default_pooling_params[param] = getattr(pooler_config, param) + self.embed_dimensions = vllm_config.model_config.embedding_size + self.embed_request_queue: list[EmbedRequestMixin] = [] + + def __repr__(self) -> str: + return ( + f"BgeM3SparseEmbeddingsProcessor(" + f"embed_dimensions={self.embed_dimensions}, " + f"default_pooling_params={self.default_pooling_params})" + ) def merge_pooling_params( self, @@ -41,7 +56,57 @@ def merge_pooling_params( if params is None: params = PoolingParams() # refer to PoolingCompletionRequest.to_pooling_params - params.task = "token_classify" + # set and verify pooling params + params.skip_reading_prefix_cache = True + + raw_embed_request = self.embed_request_queue.pop(0) + if raw_embed_request.embed_task not in EMBED_TASKS: + raise ValueError( + f"Unsupported task {raw_embed_request}, " + f"Supported tasks are {EMBED_TASKS}" + ) + has_dense_embed = True + if raw_embed_request.embed_task == "dense": + params.task = "embed" + params.skip_reading_prefix_cache = False + elif raw_embed_request.embed_task == "sparse": + params.task = "token_classify" + has_dense_embed = False + else: + params.task = "embed&token_classify" + params.use_activation = raw_embed_request.use_activation + if params.use_activation is None: + params.use_activation = True + if not has_dense_embed: + params.dimensions = None + return params + + params.dimensions = raw_embed_request.dimensions + + model_config: ModelConfig = self.vllm_config.model_config + for param in self.default_pooling_params: + if getattr(params, param, None) is None: + setattr(params, param, self.default_pooling_params[param]) + + if params.dimensions is not None: + if not model_config.is_matryoshka: + raise ValueError( + f'Model "{model_config.served_model_name}" does not ' + f"support matryoshka representation, " + f"changing output dimensions will lead to poor results." + ) + + mds = model_config.matryoshka_dimensions + if mds is not None: + if params.dimensions not in mds: + raise ValueError( + f"Model {model_config.served_model_name!r} " + f"only supports {str(mds)} matryoshka dimensions, " + f"use other output dimensions will " + f"lead to poor results." + ) + elif params.dimensions < 1: + raise ValueError("Dimensions must be greater than 0") return params def parse_request( @@ -61,14 +126,16 @@ def pre_process( if request_id is not None: assert request_id not in self.online_requests, "request_id duplicated" self.online_requests[request_id] = prompt + self.embed_request_queue.extend(prompt.to_embed_requests_online()) else: self.offline_requests.append(prompt) + self.embed_request_queue.extend(prompt.to_embed_requests_offline()) return prompt.input def _get_sparse_embedding_request(self, request_id: str | None = None): if request_id: return self.online_requests.pop(request_id, None) - return self.offline_requests.pop() + return self.offline_requests.pop(0) def _build_sparse_embedding_token_weights( self, @@ -100,26 +167,45 @@ def post_process( ) -> SparseEmbeddingResponse: num_prompt_tokens = 0 response_data = [] - return_tokens = self._get_sparse_embedding_request(request_id).return_tokens + raw_request = self._get_sparse_embedding_request(request_id) + has_dense_embed = raw_request.embed_task in ["dense", "dense&sparse"] + has_sparse_embed = raw_request.embed_task in ["sparse", "dense&sparse"] + embed_dimensions = 0 + if has_dense_embed: + embed_dimensions = ( + self.embed_dimensions + if raw_request.dimensions is None + else raw_request.dimensions + ) for idx in range(len(model_output)): mo = model_output[idx] - sparse_embedding: dict[int, float] = {} + sparse_embedding_dict: dict[int, float] = {} num_prompt_tokens += len(mo.prompt_token_ids) - if len(mo.prompt_token_ids) != len(mo.outputs.data): - # this is the case that add_special_tokens is True, - # which means first token and last token are special tokens - mo.prompt_token_ids = mo.prompt_token_ids[1:] - for token_id, weight in zip(mo.prompt_token_ids, mo.outputs.data.tolist()): - sparse_embedding[token_id] = max( - weight, sparse_embedding.get(token_id, 0.0) + dense_embedding: list[float] | None = None + sparse_embedding: list[SparseEmbeddingTokenWeight] | None = None + if has_dense_embed: + dense_embedding = mo.outputs.data[:embed_dimensions].tolist() + if has_sparse_embed: + sparse_weights = mo.outputs.data[embed_dimensions:].tolist() + if len(mo.prompt_token_ids) != len(sparse_weights): + # this is the case that add_special_tokens is True, + # which means first token and last token are special tokens + mo.prompt_token_ids = mo.prompt_token_ids[1:] + for token_id, weight in zip(mo.prompt_token_ids, sparse_weights): + sparse_embedding_dict[token_id] = max( + weight, sparse_embedding_dict.get(token_id, 0.0) + ) + sparse_embedding = self._build_sparse_embedding_token_weights( + sparse_embedding_dict, + raw_request.return_tokens, ) + response_data.append( SparseEmbeddingResponseData( index=idx, - sparse_embedding=self._build_sparse_embedding_token_weights( - sparse_embedding, - return_tokens, - ), + object=raw_request.embed_task, + sparse_embedding=sparse_embedding, + dense_embedding=dense_embedding, ) ) diff --git a/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/types.py b/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/types.py index 1dcf30a058c9..ba69932f45a7 100644 --- a/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/types.py +++ b/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/types.py @@ -1,18 +1,44 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Literal, get_args + from pydantic import BaseModel, Field from vllm.entrypoints.openai.engine.protocol import UsageInfo -from vllm.entrypoints.pooling.base.protocol import CompletionRequestMixin +from vllm.entrypoints.pooling.base.protocol import ( + CompletionRequestMixin, + EmbedRequestMixin, +) + +EmbedTask = Literal[ + "sparse", + "dense", + "dense&sparse", +] + +EMBED_TASKS: tuple[EmbedTask, ...] = get_args(EmbedTask) -class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin): +class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin, EmbedRequestMixin): return_tokens: bool | None = Field( default=None, description="Whether to return dict shows the mapping of token_id to text." "`None` or False means not return.", ) + embed_task: EmbedTask = Field( + default="dense&sparse", + description="embed task, can be one of 'sparse', 'dense' , 'dense&sparse', " + "default to 'dense&sparse'", + ) + + def to_embed_requests_offline(self) -> list[EmbedRequestMixin]: + if isinstance(self.input, list): + return [self] * len(self.input) + return [self] + + def to_embed_requests_online(self) -> list[EmbedRequestMixin]: + return [self] class SparseEmbeddingTokenWeight(BaseModel): @@ -23,8 +49,9 @@ class SparseEmbeddingTokenWeight(BaseModel): class SparseEmbeddingResponseData(BaseModel): index: int - object: str = "sparse-embedding" - sparse_embedding: list[SparseEmbeddingTokenWeight] + object: str = "dense&sparse" + sparse_embedding: list[SparseEmbeddingTokenWeight] | None + dense_embedding: list[float] | None class SparseEmbeddingResponse(BaseModel): diff --git a/tests/plugins_tests/test_bge_m3_sparse_io_processor_plugins.py b/tests/plugins_tests/test_bge_m3_sparse_io_processor_plugins.py index 20c400e59795..85293e55cd81 100644 --- a/tests/plugins_tests/test_bge_m3_sparse_io_processor_plugins.py +++ b/tests/plugins_tests/test_bge_m3_sparse_io_processor_plugins.py @@ -19,6 +19,12 @@ ), } +dense_embedding_sum = [ + -0.7214539647102356, # "What is the capital of France?" + -0.6926871538162231, # "What is the capital of Germany?" + -0.7129564881324768, # "What is the capital of Spain?" +] + def _float_close(expected: object, result: object): assert isinstance(expected, float) and isinstance(result, float), ( @@ -33,6 +39,12 @@ def _get_attr_or_val(obj: object | dict, key: str): return getattr(obj, key, None) +def _check_dense_embedding(data, index=0): + assert _float_close(sum(data), dense_embedding_sum[index]), ( + "dense-embedding result not match" + ) + + def _check_sparse_embedding(data, check_tokens=False): expected_weights = [ {"token_id": 32, "weight": 0.0552978515625, "token": "?"}, @@ -109,7 +121,7 @@ async def test_bge_m3_sparse_plugin_online( assert len(_get_attr_or_val(parsed_response, "data")) > 0 data_entry = _get_attr_or_val(parsed_response, "data")[0] - assert _get_attr_or_val(data_entry, "object") == "sparse-embedding" + assert _get_attr_or_val(data_entry, "object") == "dense&sparse" assert _get_attr_or_val(data_entry, "sparse_embedding") # Verify sparse embedding format @@ -117,6 +129,11 @@ async def test_bge_m3_sparse_plugin_online( assert isinstance(sparse_embedding, list) _check_sparse_embedding(sparse_embedding, return_tokens) + # Verify dense embedding format + dense_embedding = _get_attr_or_val(data_entry, "dense_embedding") + assert isinstance(dense_embedding, list) + _check_dense_embedding(dense_embedding) + # Verify usage information usage = _get_attr_or_val(parsed_response, "usage") assert usage, f"usage not found for {parsed_response}" @@ -164,6 +181,9 @@ def test_bge_m3_sparse_plugin_offline(vllm_runner, return_tokens: bool): sparse_embedding = output.sparse_embedding assert isinstance(sparse_embedding, list) _check_sparse_embedding(sparse_embedding, return_tokens) + dense_embedding = output.dense_embedding + assert isinstance(dense_embedding, list) + _check_dense_embedding(dense_embedding) # Verify usage assert response.usage.prompt_tokens > 0 @@ -206,6 +226,9 @@ def test_bge_m3_sparse_plugin_offline_multiple_inputs(vllm_runner): # Each output should have sparse embeddings sparse_embedding = output.sparse_embedding assert isinstance(sparse_embedding, list) + dense_embedding = output.dense_embedding + assert isinstance(dense_embedding, list) + _check_dense_embedding(dense_embedding, i) # Verify usage assert response.usage.prompt_tokens > 0 diff --git a/vllm/model_executor/layers/pooler/special.py b/vllm/model_executor/layers/pooler/special.py index bafa191dbac1..5e0f9ec75597 100644 --- a/vllm/model_executor/layers/pooler/special.py +++ b/vllm/model_executor/layers/pooler/special.py @@ -170,4 +170,42 @@ def forward( return pooled_outputs -__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"] +class BgeM3Pooler(Pooler): + def __init__(self, token_classify_pooler: Pooler, embed_pooler: Pooler) -> None: + super().__init__() + self.token_classify_pooler = token_classify_pooler + self.embed_pooler = embed_pooler + + def forward( + self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata + ) -> PoolerOutput: + embed_outputs = self.embed_pooler(hidden_states, pooling_metadata) + token_classify_outputs = self.token_classify_pooler( + hidden_states, pooling_metadata + ) + pooler_outputs: list[torch.Tensor] = [] + for embed_output, token_classify_output in zip( + embed_outputs, token_classify_outputs + ): + pooler_outputs.append( + torch.cat( + [embed_output.view(-1), token_classify_output.view(-1)], dim=-1 + ) + ) + + return pooler_outputs + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"embed&token_classify"} + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return self.embed_pooler.get_pooling_updates( + "embed" + ) | self.token_classify_pooler.get_pooling_updates("token_classify") + + def extra_repr(self) -> str: + s = f"supported_task={self.get_supported_tasks()}" + return s + + +__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler", "BgeM3Pooler"] diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 5faa64654e7b..46211e6eda02 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -10,6 +10,7 @@ from vllm.config import ModelConfig, PoolerConfig, VllmConfig from vllm.model_executor.layers.pooler import ( + BgeM3Pooler, BOSEOSFilter, DispatchPooler, Pooler, @@ -216,24 +217,29 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: self.colbert_linear = nn.Linear( self.hidden_size, self.hidden_size, dtype=self.head_dtype ) + embed_pooler = pooler_for_embed(pooler_config) + token_classify_pooler = BOSEOSFilter( + pooler_for_token_classify( + pooler_config, + pooling=AllPool(), + classifier=self.sparse_linear, + act_fn=torch.relu, + ), + self.bos_token_id, + self.eos_token_id, + ) return DispatchPooler( { - "embed": pooler_for_embed(pooler_config), + "embed": embed_pooler, "token_embed": BOSEOSFilter( pooler_for_token_embed(pooler_config, self.colbert_linear), self.bos_token_id, # for some reason m3 only filters the bos for colbert vectors ), - "token_classify": BOSEOSFilter( - pooler_for_token_classify( - pooler_config, - pooling=AllPool(), - classifier=self.sparse_linear, - act_fn=torch.relu, - ), - self.bos_token_id, - self.eos_token_id, + "token_classify": token_classify_pooler, + "embed&token_classify": BgeM3Pooler( + token_classify_pooler, embed_pooler ), } ) diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 6b85506abf1e..e5e993b75556 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -96,6 +96,10 @@ def verify(self, model_config: ModelConfig) -> None: self.skip_reading_prefix_cache = True return + # skipping verify, let plugins configure and validate pooling params + if self.task not in self.valid_parameters: + return + # NOTE: Task validation needs to done against the model instance, # which is not available in model config. So, it's not included # in this method diff --git a/vllm/tasks.py b/vllm/tasks.py index 950993279dfd..83dd7f85eee0 100644 --- a/vllm/tasks.py +++ b/vllm/tasks.py @@ -6,7 +6,13 @@ GENERATION_TASKS: tuple[GenerationTask, ...] = get_args(GenerationTask) PoolingTask = Literal[ - "embed", "classify", "score", "token_embed", "token_classify", "plugin" + "embed", + "classify", + "score", + "token_embed", + "token_classify", + "plugin", + "embed&token_classify", ] POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask)