@@ -3,10 +3,10 @@

from collections.abc import Sequence

from vllm.config import VllmConfig
from vllm.config import ModelConfig, PoolerConfig, VllmConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.protocol import EmbedRequestMixin
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import (
IOProcessor,
@@ -16,14 +16,13 @@
from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens

from .types import (
EMBED_TASKS,
SparseEmbeddingCompletionRequestMixin,
SparseEmbeddingResponse,
SparseEmbeddingResponseData,
SparseEmbeddingTokenWeight,
)

logger = init_logger(__name__)


class BgeM3SparseEmbeddingsProcessor(
IOProcessor[SparseEmbeddingCompletionRequestMixin, SparseEmbeddingResponse]
@@ -33,6 +32,22 @@ def __init__(self, vllm_config: VllmConfig, renderer: BaseRenderer):
self.offline_requests: list[SparseEmbeddingCompletionRequestMixin] = []
self.online_requests: dict[str, SparseEmbeddingCompletionRequestMixin] = {}
self.renderer: BaseRenderer = renderer
self.default_pooling_params = {}
pooler_config: PoolerConfig = vllm_config.model_config.pooler_config
if pooler_config is not None:
for param in ["use_activation", "dimensions"]:
if getattr(pooler_config, param, None) is None:
continue
self.default_pooling_params[param] = getattr(pooler_config, param)
self.embed_dimensions = vllm_config.model_config.embedding_size
self.embed_request_queue: list[EmbedRequestMixin] = []

def __repr__(self) -> str:
return (
f"BgeM3SparseEmbeddingsProcessor("
f"embed_dimensions={self.embed_dimensions}, "
f"default_pooling_params={self.default_pooling_params})"
)

def merge_pooling_params(
self,
@@ -41,7 +56,57 @@ def merge_pooling_params(
if params is None:
params = PoolingParams()
# refer to PoolingCompletionRequest.to_pooling_params
params.task = "token_classify"
# set and verify pooling params
params.skip_reading_prefix_cache = True

raw_embed_request = self.embed_request_queue.pop(0)
if raw_embed_request.embed_task not in EMBED_TASKS:
raise ValueError(
f"Unsupported task {raw_embed_request}, "
f"Supported tasks are {EMBED_TASKS}"
)
has_dense_embed = True
if raw_embed_request.embed_task == "dense":
params.task = "embed"
params.skip_reading_prefix_cache = False
elif raw_embed_request.embed_task == "sparse":
params.task = "token_classify"
has_dense_embed = False
else:
params.task = "embed&token_classify"
params.use_activation = raw_embed_request.use_activation
if params.use_activation is None:
params.use_activation = True
if not has_dense_embed:
params.dimensions = None
return params

params.dimensions = raw_embed_request.dimensions

model_config: ModelConfig = self.vllm_config.model_config
for param in self.default_pooling_params:
if getattr(params, param, None) is None:
setattr(params, param, self.default_pooling_params[param])

if params.dimensions is not None:
if not model_config.is_matryoshka:
raise ValueError(
f'Model "{model_config.served_model_name}" does not '
f"support matryoshka representation, "
f"changing output dimensions will lead to poor results."
)

mds = model_config.matryoshka_dimensions
if mds is not None:
if params.dimensions not in mds:
raise ValueError(
f"Model {model_config.served_model_name!r} "
f"only supports {str(mds)} matryoshka dimensions, "
f"use other output dimensions will "
f"lead to poor results."
)
elif params.dimensions < 1:
raise ValueError("Dimensions must be greater than 0")
return params
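For orientation, here is a minimal, self-contained sketch (not part of the PR) of the task routing that merge_pooling_params performs above. `_DemoPoolingParams` and `route` are illustrative stand-ins, not vLLM's real `PoolingParams` API, and the matryoshka-dimension validation is omitted.

```python
# Hedged summary of the embed_task -> pooling-task routing shown above.
# _DemoPoolingParams is an illustrative stand-in, not vllm.PoolingParams.
from dataclasses import dataclass


@dataclass
class _DemoPoolingParams:
    task: str = ""
    skip_reading_prefix_cache: bool = True
    use_activation: bool | None = None
    dimensions: int | None = None


def route(embed_task: str, dimensions: int | None = None) -> _DemoPoolingParams:
    p = _DemoPoolingParams()
    if embed_task == "dense":
        p.task = "embed"
        p.skip_reading_prefix_cache = False  # dense-only requests may reuse the prefix cache
        p.dimensions = dimensions
    elif embed_task == "sparse":
        p.task = "token_classify"  # per-token weights only; no dense head
        p.dimensions = None        # matryoshka dimensions do not apply
    else:  # "dense&sparse"
        p.task = "embed&token_classify"  # combined task handled by the BGE-M3 pooler
        p.dimensions = dimensions
    if p.use_activation is None:
        p.use_activation = True  # same default the diff applies
    return p


assert route("sparse").task == "token_classify"
assert route("dense&sparse", dimensions=512).dimensions == 512
```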

def parse_request(
@@ -61,14 +126,16 @@ def pre_process(
if request_id is not None:
assert request_id not in self.online_requests, "request_id duplicated"
self.online_requests[request_id] = prompt
self.embed_request_queue.extend(prompt.to_embed_requests_online())
else:
self.offline_requests.append(prompt)
self.embed_request_queue.extend(prompt.to_embed_requests_offline())
return prompt.input

def _get_sparse_embedding_request(self, request_id: str | None = None):
if request_id:
return self.online_requests.pop(request_id, None)
return self.offline_requests.pop()
return self.offline_requests.pop(0)

def _build_sparse_embedding_token_weights(
self,
@@ -100,26 +167,45 @@ def post_process(
) -> SparseEmbeddingResponse:
num_prompt_tokens = 0
response_data = []
return_tokens = self._get_sparse_embedding_request(request_id).return_tokens
raw_request = self._get_sparse_embedding_request(request_id)
has_dense_embed = raw_request.embed_task in ["dense", "dense&sparse"]
has_sparse_embed = raw_request.embed_task in ["sparse", "dense&sparse"]
embed_dimensions = 0
if has_dense_embed:
embed_dimensions = (
self.embed_dimensions
if raw_request.dimensions is None
else raw_request.dimensions
)
for idx in range(len(model_output)):
mo = model_output[idx]
sparse_embedding: dict[int, float] = {}
sparse_embedding_dict: dict[int, float] = {}
num_prompt_tokens += len(mo.prompt_token_ids)
if len(mo.prompt_token_ids) != len(mo.outputs.data):
# this is the case that add_special_tokens is True,
# which means first token and last token are special tokens
mo.prompt_token_ids = mo.prompt_token_ids[1:]
for token_id, weight in zip(mo.prompt_token_ids, mo.outputs.data.tolist()):
sparse_embedding[token_id] = max(
weight, sparse_embedding.get(token_id, 0.0)
dense_embedding: list[float] | None = None
sparse_embedding: list[SparseEmbeddingTokenWeight] | None = None
if has_dense_embed:
dense_embedding = mo.outputs.data[:embed_dimensions].tolist()
if has_sparse_embed:
sparse_weights = mo.outputs.data[embed_dimensions:].tolist()
if len(mo.prompt_token_ids) != len(sparse_weights):
# This happens when add_special_tokens is True,
# i.e. the first and last tokens are special tokens
mo.prompt_token_ids = mo.prompt_token_ids[1:]
for token_id, weight in zip(mo.prompt_token_ids, sparse_weights):
sparse_embedding_dict[token_id] = max(
weight, sparse_embedding_dict.get(token_id, 0.0)
)
sparse_embedding = self._build_sparse_embedding_token_weights(
sparse_embedding_dict,
raw_request.return_tokens,
)

response_data.append(
SparseEmbeddingResponseData(
index=idx,
sparse_embedding=self._build_sparse_embedding_token_weights(
sparse_embedding,
return_tokens,
),
object=raw_request.embed_task,
sparse_embedding=sparse_embedding,
dense_embedding=dense_embedding,
)
)
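As a quick illustration of the splitting logic in post_process above, the following standalone sketch (tensor size, token ids, and variable names are made up) splits a combined pooler output into a dense prefix and per-token sparse weights, max-pooling repeated token ids the same way the loop does.

```python
# Standalone sketch of the dense/sparse split performed in post_process above.
# embed_dimensions, prompt_token_ids, and the tensor contents are illustrative.
import torch

embed_dimensions = 4
prompt_token_ids = [101, 2054, 2054, 102]  # duplicate id on purpose
combined = torch.randn(embed_dimensions + len(prompt_token_ids))

dense_embedding = combined[:embed_dimensions].tolist()  # first embed_dimensions values
sparse_weights = combined[embed_dimensions:].tolist()   # one weight per prompt token

# Max-pool repeated token ids into a single weight per token id.
sparse_embedding: dict[int, float] = {}
for token_id, weight in zip(prompt_token_ids, sparse_weights):
    sparse_embedding[token_id] = max(weight, sparse_embedding.get(token_id, 0.0))

print(len(dense_embedding), sparse_embedding)
```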

@@ -1,18 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Literal, get_args

from pydantic import BaseModel, Field

from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.protocol import CompletionRequestMixin
from vllm.entrypoints.pooling.base.protocol import (
CompletionRequestMixin,
EmbedRequestMixin,
)

EmbedTask = Literal[
"sparse",
"dense",
"dense&sparse",
]

EMBED_TASKS: tuple[EmbedTask, ...] = get_args(EmbedTask)


class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin):
class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin, EmbedRequestMixin):
return_tokens: bool | None = Field(
default=None,
description="Whether to return dict shows the mapping of token_id to text."
"`None` or False means not return.",
)
embed_task: EmbedTask = Field(
default="dense&sparse",
description="embed task, can be one of 'sparse', 'dense' , 'dense&sparse', "
"default to 'dense&sparse'",
)

def to_embed_requests_offline(self) -> list[EmbedRequestMixin]:
if isinstance(self.input, list):
return [self] * len(self.input)
return [self]

def to_embed_requests_online(self) -> list[EmbedRequestMixin]:
return [self]
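A tiny standalone sketch of the fan-out behaviour above (class and field names are illustrative, not the real vLLM protocol types): an offline request with a list input yields one queue entry per prompt, while the online path always yields a single entry.

```python
# Illustrative stand-in for the fan-out methods above; not the real request type.
from dataclasses import dataclass


@dataclass
class _DemoEmbedRequest:
    input: str | list[str]
    embed_task: str = "dense&sparse"

    def to_embed_requests_offline(self) -> list["_DemoEmbedRequest"]:
        if isinstance(self.input, list):
            return [self] * len(self.input)  # one queue entry per prompt
        return [self]

    def to_embed_requests_online(self) -> list["_DemoEmbedRequest"]:
        return [self]  # online requests carry a single prompt


req = _DemoEmbedRequest(input=["capital of France?", "capital of Germany?"])
assert len(req.to_embed_requests_offline()) == 2
assert len(req.to_embed_requests_online()) == 1
```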


class SparseEmbeddingTokenWeight(BaseModel):
@@ -23,8 +49,9 @@ class SparseEmbeddingTokenWeight(BaseModel):

class SparseEmbeddingResponseData(BaseModel):
index: int
object: str = "sparse-embedding"
sparse_embedding: list[SparseEmbeddingTokenWeight]
object: str = "dense&sparse"
sparse_embedding: list[SparseEmbeddingTokenWeight] | None
dense_embedding: list[float] | None


class SparseEmbeddingResponse(BaseModel):
25 changes: 24 additions & 1 deletion tests/plugins_tests/test_bge_m3_sparse_io_processor_plugins.py
@@ -19,6 +19,12 @@
),
}

dense_embedding_sum = [
-0.7214539647102356, # "What is the capital of France?"
-0.6926871538162231, # "What is the capital of Germany?"
-0.7129564881324768, # "What is the capital of Spain?"
]


def _float_close(expected: object, result: object):
assert isinstance(expected, float) and isinstance(result, float), (
@@ -33,6 +39,12 @@ def _get_attr_or_val(obj: object | dict, key: str):
return getattr(obj, key, None)


def _check_dense_embedding(data, index=0):
assert _float_close(sum(data), dense_embedding_sum[index]), (
"dense-embedding result not match"
)


def _check_sparse_embedding(data, check_tokens=False):
expected_weights = [
{"token_id": 32, "weight": 0.0552978515625, "token": "?"},
@@ -109,14 +121,19 @@ async def test_bge_m3_sparse_plugin_online(
assert len(_get_attr_or_val(parsed_response, "data")) > 0

data_entry = _get_attr_or_val(parsed_response, "data")[0]
assert _get_attr_or_val(data_entry, "object") == "sparse-embedding"
assert _get_attr_or_val(data_entry, "object") == "dense&sparse"
assert _get_attr_or_val(data_entry, "sparse_embedding")

# Verify sparse embedding format
sparse_embedding = _get_attr_or_val(data_entry, "sparse_embedding")
assert isinstance(sparse_embedding, list)
_check_sparse_embedding(sparse_embedding, return_tokens)

# Verify dense embedding format
dense_embedding = _get_attr_or_val(data_entry, "dense_embedding")
assert isinstance(dense_embedding, list)
_check_dense_embedding(dense_embedding)

# Verify usage information
usage = _get_attr_or_val(parsed_response, "usage")
assert usage, f"usage not found for {parsed_response}"
@@ -164,6 +181,9 @@ def test_bge_m3_sparse_plugin_offline(vllm_runner, return_tokens: bool):
sparse_embedding = output.sparse_embedding
assert isinstance(sparse_embedding, list)
_check_sparse_embedding(sparse_embedding, return_tokens)
dense_embedding = output.dense_embedding
assert isinstance(dense_embedding, list)
_check_dense_embedding(dense_embedding)

# Verify usage
assert response.usage.prompt_tokens > 0
@@ -206,6 +226,9 @@ def test_bge_m3_sparse_plugin_offline_multiple_inputs(vllm_runner):
# Each output should have sparse embeddings
sparse_embedding = output.sparse_embedding
assert isinstance(sparse_embedding, list)
dense_embedding = output.dense_embedding
assert isinstance(dense_embedding, list)
_check_dense_embedding(dense_embedding, i)

# Verify usage
assert response.usage.prompt_tokens > 0
40 changes: 39 additions & 1 deletion vllm/model_executor/layers/pooler/special.py
@@ -170,4 +170,42 @@
return pooled_outputs


__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"]
class BgeM3Pooler(Pooler):
def __init__(self, token_classify_pooler: Pooler, embed_pooler: Pooler) -> None:
super().__init__()
self.token_classify_pooler = token_classify_pooler
self.embed_pooler = embed_pooler

def forward(
self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata
) -> PoolerOutput:
embed_outputs = self.embed_pooler(hidden_states, pooling_metadata)
token_classify_outputs = self.token_classify_pooler(
hidden_states, pooling_metadata
)
pooler_outputs: list[torch.Tensor] = []
for embed_output, token_classify_output in zip(
embed_outputs, token_classify_outputs
):
pooler_outputs.append(
torch.cat(
[embed_output.view(-1), token_classify_output.view(-1)], dim=-1
)
)

return pooler_outputs

def get_supported_tasks(self) -> Set[PoolingTask]:
return {"embed&token_classify"}

def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.embed_pooler.get_pooling_updates(
"embed"
) | self.token_classify_pooler.get_pooling_updates("token_classify")

def extra_repr(self) -> str:
s = f"supported_task={self.get_supported_tasks()}"
return s


__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler", "BgeM3Pooler"]