add io_process_plugin for sparse embedding #34214
Merged
Changes from all commits (48 commits):
- `826ac0b` wip: add io_process_plugin for sparse embedding
- `1687d0c` fix bugs for offline mode with array
- `2e4dc28` update code with gemeni suggestions
- `4589b09` update bge_m3_sparse_plugin with simple code to construct sparse embe…
- `b2e15fe` add params to determine whether output token_id to token text mapping
- `269a8b7` udpate bge_m3_sparse_plugin
- `75027a7` add input param for sparse embedding
- `3901619` update interface for io_processor_plugin
- `4850b7e` add return
- `79946b8` fix bugs in post_process
- `9c20e6f` fix bugs in post_process
- `d3c2d8d` make plugin compatible with main branch
- `07e2633` make plugin compatible with offline mode
- `32bda35` update pooling params for online mode
- `ac285c3` make code cleaner
- `93d754c` use convert_ids_list_to_tokens instead of convert_ids_to_tokens
- `e3dce21` use convert_ids_list_to_tokens instead of convert_ids_to_tokens
- `5c856cb` pass renderer during io_processor init
- `5eb5a33` let get_io_processor compatible with previous io_process_plugin
- `670d31c` add warnning msg for io_processor_plugin.__init__ api change
- `0782705` remove request parameter in merge_pooling_params
- `dc4ec89` fix bugs in call merge_pooling_params
- `0739106` update io_processor_plugins.md as abstract class IOProcessor is updated
- `a5d518b` remove fallbacks in update io_processor_plugins.md, return correct er…
- `49a2cef` Update vllm/plugins/io_processors/__init__.py
- `50db326` fix testcase for loading wrong io_processor plugin
- `ef9065d` Merge branch 'main' into bge-m3-sparse-plugin
- `29378ab` add e2e test case for bge_m3_sparse_plugin
- `ea6e9c1` fix bugs in passing hf_overrides
- `759314b` fix bugs in construct prompts for offline mode
- `15d6cf5` fix bugs in construct prompts for multi inputs in offline mode
- `d67346f` update verify logic for bge_m3_sparse_plugin
- `cb01d53` fix bugs in get pooler_output
- `a16d521` fix bugs in offline testcase
- `bee36a2` check embed result
- `15f54ba` fix bugs in check offline mode result
- `7efcc16` check token is None for return_tokens=False
- `a29968f` make _check_sparse_embedding compatible for both online serving and o…
- `4a18af7` fix test online
- `a07bed7` fix verify logic for online mode
- `43b1f54` update online test case
- `dd77e52` rename test file for bge_m3_sparse_plugin
- `add2882` Merge branch 'main' into bge-m3-sparse-plugin
- `58ebcc7` add bge_m3_sparse io processor plugin test into .buildkite
- `166cef8` fix pre-commit check
- `1f4f969` check sparse-embedding weight using loose equality
- `dfb3663` Merge branch 'main' into bge-m3-sparse-plugin
- `3a96af3` Merge branch 'main' into bge-m3-sparse-plugin

All commits authored by staugust.
tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/__init__.py (6 additions, 0 deletions):
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


def register_bge_m3_sparse_embeddings_processor():
    return "bge_m3_sparse_processor.sparse_embeddings_processor.BgeM3SparseEmbeddingsProcessor"  # noqa: E501
```
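The registration hook returns a dotted path string rather than the class itself, so the plugin loader can defer the import. A minimal sketch of how such a string could be resolved to a class (this mimics, but is not necessarily identical to, what vLLM's loader does; demonstrated with a stdlib class rather than the plugin):

```python
import importlib


def resolve_processor(qualname: str):
    """Resolve a dotted 'module.path.ClassName' string to the class object.

    Sketch of how a plugin loader might turn the string returned by
    register_bge_m3_sparse_embeddings_processor() into a class.
    """
    module_name, _, class_name = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)


# Demonstrated with a stdlib class rather than the plugin itself:
cls = resolve_processor("collections.OrderedDict")
print(cls.__name__)  # OrderedDict
```

Returning a string keeps the top-level `__init__.py` import-free, so registration stays cheap even when the processor module pulls in heavy dependencies.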
tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/sparse_embeddings_processor.py (135 additions, 0 deletions):
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence

from vllm.config import VllmConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens

from .types import (
    SparseEmbeddingCompletionRequestMixin,
    SparseEmbeddingResponse,
    SparseEmbeddingResponseData,
    SparseEmbeddingTokenWeight,
)

logger = init_logger(__name__)


class BgeM3SparseEmbeddingsProcessor(
    IOProcessor[SparseEmbeddingCompletionRequestMixin, SparseEmbeddingResponse]
):
    def __init__(self, vllm_config: VllmConfig, renderer: BaseRenderer):
        super().__init__(vllm_config, renderer)
        self.offline_requests: list[SparseEmbeddingCompletionRequestMixin] = []
        self.online_requests: dict[str, SparseEmbeddingCompletionRequestMixin] = {}
        self.renderer: BaseRenderer = renderer

    def merge_pooling_params(
        self,
        params: PoolingParams | None = None,
    ) -> PoolingParams:
        if params is None:
            params = PoolingParams()
        # Refer to PoolingCompletionRequest.to_pooling_params.
        params.task = "token_classify"
        return params

    def parse_request(
        self, request_data: object
    ) -> SparseEmbeddingCompletionRequestMixin:
        # Offline mode (vllm.entrypoints.llm.LLM) calls `encode` directly.
        if isinstance(request_data, dict):
            return SparseEmbeddingCompletionRequestMixin(**request_data)
        raise TypeError("request_data should be a dictionary")

    def pre_process(
        self,
        prompt: SparseEmbeddingCompletionRequestMixin,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        if request_id is not None:
            assert request_id not in self.online_requests, "request_id duplicated"
            self.online_requests[request_id] = prompt
        else:
            self.offline_requests.append(prompt)
        return prompt.input

    def _get_sparse_embedding_request(self, request_id: str | None = None):
        if request_id:
            return self.online_requests.pop(request_id, None)
        return self.offline_requests.pop()

    def _build_sparse_embedding_token_weights(
        self,
        sparse_embedding: dict[int, float],
        return_tokens: bool = False,
    ) -> list[SparseEmbeddingTokenWeight]:
        token_ids = list(sparse_embedding.keys())
        token_weights = list(sparse_embedding.values())
        tokens: list[str | None] = [None] * len(token_ids)

        if return_tokens and self.renderer is not None:
            tokens = convert_ids_list_to_tokens(
                self.renderer.get_tokenizer(), token_ids
            )
        return [
            SparseEmbeddingTokenWeight(token_id=token_id, weight=weight, token=token)
            for token_id, weight, token in zip(token_ids, token_weights, tokens)
        ]

    def post_process(
        self,
        model_output: Sequence[PoolingRequestOutput],
        request_id: str | None = None,
        **kwargs,
    ) -> SparseEmbeddingResponse:
        num_prompt_tokens = 0
        response_data = []
        # The online lookup may return None for an unknown request_id;
        # fall back to not returning token text in that case.
        request = self._get_sparse_embedding_request(request_id)
        return_tokens = bool(request.return_tokens) if request is not None else False
        for idx, mo in enumerate(model_output):
            sparse_embedding: dict[int, float] = {}
            num_prompt_tokens += len(mo.prompt_token_ids)
            if len(mo.prompt_token_ids) != len(mo.outputs.data):
                # add_special_tokens was True, so the prompt includes special
                # tokens; drop the leading one to realign ids with weights.
                mo.prompt_token_ids = mo.prompt_token_ids[1:]
            for token_id, weight in zip(mo.prompt_token_ids, mo.outputs.data.tolist()):
                # When a token id repeats in the prompt, keep its maximum weight.
                sparse_embedding[token_id] = max(
                    weight, sparse_embedding.get(token_id, 0.0)
                )
            response_data.append(
                SparseEmbeddingResponseData(
                    index=idx,
                    sparse_embedding=self._build_sparse_embedding_token_weights(
                        sparse_embedding,
                        return_tokens,
                    ),
                )
            )

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )
        return SparseEmbeddingResponse(
            data=response_data,
            usage=usage,
        )
```
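The core of `post_process` is collapsing per-position token weights into a sparse map, keeping the maximum weight when the same token id occurs more than once. A standalone sketch of that aggregation, with no vLLM dependencies (the function name is ours, for illustration):

```python
def build_sparse_embedding(
    token_ids: list[int], weights: list[float]
) -> dict[int, float]:
    """Collapse per-position token weights into a sparse embedding.

    Mirrors the aggregation in post_process: when the same token id
    appears at several positions, only the maximum weight is kept.
    """
    sparse: dict[int, float] = {}
    for token_id, weight in zip(token_ids, weights):
        sparse[token_id] = max(weight, sparse.get(token_id, 0.0))
    return sparse


# Token id 7 appears twice; the larger weight (0.9) wins.
result = build_sparse_embedding([7, 12, 7], [0.4, 0.2, 0.9])
print(result)  # {7: 0.9, 12: 0.2}
```

Max-pooling duplicate token ids is what makes the output a lexical sparse vector (one weight per vocabulary entry) rather than a per-position sequence.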
tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/types.py (32 additions, 0 deletions):
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from pydantic import BaseModel, Field

from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.protocol import CompletionRequestMixin


class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin):
    return_tokens: bool | None = Field(
        default=None,
        description="Whether to return the token_id to token text mapping. "
        "`None` or `False` means the mapping is not returned.",
    )


class SparseEmbeddingTokenWeight(BaseModel):
    token_id: int
    weight: float
    token: str | None


class SparseEmbeddingResponseData(BaseModel):
    index: int
    object: str = "sparse-embedding"
    sparse_embedding: list[SparseEmbeddingTokenWeight]


class SparseEmbeddingResponse(BaseModel):
    data: list[SparseEmbeddingResponseData]
    usage: UsageInfo
```
Plugin package setup script (15 additions, 0 deletions; filename not shown in this view):
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from setuptools import setup

setup(
    name="bge-m3-sparse-plugin",
    version="0.1",
    packages=["bge_m3_sparse_processor"],
    entry_points={
        "vllm.io_processor_plugins": [
            "bge_m3_sparse_plugin = bge_m3_sparse_processor:register_bge_m3_sparse_embeddings_processor",  # noqa: E501
        ]
    },
)
```