feat: Add vLLM support #3794
Merged
Changes from all commits (51 commits):
- 33176ce init
- noooop 5ab3de0 init
- noooop 9ae5c58 ruff
- noooop d7e42f4 - vllm_loader
- noooop a7f6873 + TYPE_CHECKING
- noooop 0c1c606 Make vLLM exit properly.
- noooop 373d638 rename
- noooop 880b08b support rerank
- noooop d97b2ce refine
- noooop 9c12067 refine
- noooop c501253 Update mteb/models/vllm_wrapper.py
- noooop b9d4cf2 refine
- noooop 6bac26d + docs
- noooop 003bd20 + benchmark
- noooop 9b24f24 + more benchmark
- noooop 1b7e209 Update docs/advanced_usage/vllm_wrapper.md
- noooop 6ef8943 Update docs/advanced_usage/vllm_wrapper.md
- noooop be0c5bb refine docs
- noooop 2d90268 refine docs
- noooop 930b782 Merge branch 'main' into mtebXvllm
- Samoed e699013 fix typing
- Samoed 17a68ec move type ignore
- Samoed 5ca5c40 doc upd
- Samoed 577eac5 add test
- Samoed e6141e2 Update Makefile
- Samoed 1e40fa1 add support for prompts
- Samoed 69084e1 add support for prompts
- Samoed 870ee87 - demo
- noooop f2641fb make mypy happy
- noooop 019aadf fix typehints
- Samoed b2534d1 Merge branch 'main' into mtebXvllm
- noooop 1cd85e7 update pyproject
- Samoed 98bcdb7 update pyproject
- Samoed 0f6db42 update pyproject
- Samoed 5134c3f The pooling + dp fails to run.
- noooop 2528ede fix uv lock
- Samoed cfed518 fix docs
- Samoed c5b404f simplify conflicts
- Samoed 871a664 Merge branch 'main' into mtebXvllm
- Samoed 1fb06db upd lock
- Samoed 33b6558 upd lock
- Samoed 4c62e57 Update docs/advanced_usage/vllm_wrapper.md
- noooop caa9400 Update docs/advanced_usage/vllm_wrapper.md
- noooop 539fcee Update docs/advanced_usage/vllm_wrapper.md
- noooop 3ee7cb4 Update docs/advanced_usage/vllm_wrapper.md
- noooop c2ccc59 Apply suggestions from code review
- noooop 94cf757 Update docs/advanced_usage/vllm_wrapper.md
- noooop af56ad7 Merge branch 'embeddings-benchmark:main' into mtebXvllm
- noooop 53b8c45 Apply suggestion from @Samoed
- noooop 9c2c8ac update
- noooop ba0189b update
docs/advanced_usage/vllm_wrapper.md (new file):

@@ -0,0 +1,163 @@
## vLLM

!!! note
    vLLM currently supports only a limited number of models, and many implementations have subtle differences compared to the default implementations in mteb (see the [overview issue](add me) for more information). For the full list of supported models, refer to the [vllm documentation](https://docs.vllm.ai/en/stable/models/supported_models/#pooling-models).

## Installation

If you're using CUDA, you can run:

=== "pip"
    ```bash
    pip install "mteb[vllm]"
    ```
=== "uv"
    ```bash
    uv pip install "mteb[vllm]"
    ```

For other architectures, please refer to the [vllm installation guide](https://docs.vllm.ai/en/latest/getting_started/installation/).

## Usage

To use vLLM with MTEB, you have to wrap the model with its respective wrapper.

!!! note
    You must guard usage of vllm behind an `if __name__ == '__main__':` block. For example, instead of this:

    ```python
    import vllm

    llm = vllm.LLM(...)
    ```

    try this instead:

    ```python
    if __name__ == '__main__':
        import vllm

        llm = vllm.LLM(...)
    ```

    See the vLLM [troubleshooting guide](https://docs.vllm.ai/en/latest/usage/troubleshooting/#python-multiprocessing) for more details.
| === "Embedding models" | ||
| ```python | ||
| import mteb | ||
| from mteb.models.vllm_wrapper import VllmEncoderWrapper | ||
|
|
||
| def run_vllm_encoder(): | ||
| """Evaluate a model on specified MTEB tasks using vLLM for inference.""" | ||
| encoder = VllmEncoderWrapper(model="intfloat/e5-small") | ||
| return mteb.evaluate( | ||
| encoder, | ||
| mteb.get_task("STS12"), | ||
| ) | ||
|
|
||
| if __name__ == "__main__": | ||
noooop marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| results = run_vllm_encoder() | ||
| print(results) | ||
| ``` | ||
| === "Reranking models" | ||
| ```python | ||
| import mteb | ||
| from mteb.models.vllm_wrapper import VllmCrossEncoderWrapper | ||
|
|
||
| def run_vllm_crossencoder(): | ||
| """Evaluate a model on specified MTEB tasks using vLLM for inference.""" | ||
| cross_encoder = VllmCrossEncoderWrapper(model="cross-encoder/ms-marco-MiniLM-L-6-v2") | ||
| return mteb.evaluate( | ||
| cross_encoder, | ||
| mteb.get_task("AskUbuntuDupQuestions"), | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| results = run_vllm_crossencoder() | ||
| print(results) | ||
| ``` | ||
|
|
||
## Why is vLLM fast?

### Half-Precision Inference

By default, vLLM uses Flash Attention, which only supports float16 and bfloat16, not float32. vLLM does not optimize inference performance for float32.

<figure markdown="span">
  
  <figcaption>
    The throughput using float16 is approximately four times that of float32.
    ST: sentence transformers backend; vLLM: vLLM backend.
    X-axis: throughput (requests/s). Y-axis: latency, time needed for one step (ms), logarithmic scale.
    The curve to the lower right is better ↘
  </figcaption>
</figure>

!!! note

    | Format   | Bits | Exponent | Fraction |
    |----------|------|----------|----------|
    | float32  | 32   | 8        | 23       |
    | float16  | 16   | 5        | 10       |
    | bfloat16 | 16   | 8        | 7        |

    If the model weights are stored in float32:

    - vLLM uses float16 by default to run a float32 model. This preserves numerical precision in most cases, since float16 retains relatively many fraction bits. However, because of its smaller exponent (only 5 bits), some models (e.g., the Gemma family) risk producing NaN values; vLLM maintains a list of such models and uses bfloat16 for them by default.
    - Using bfloat16 for inference avoids the NaN risk because its exponent matches float32 (8 bits). However, with only 7 fraction bits, numerical precision decreases noticeably (illustrated in the snippet after this note).
    - Using float32 for inference incurs no precision loss but is about four times slower than float16/bfloat16.

    If the model weights are stored in float16 or bfloat16, vLLM defaults to the original dtype for inference.

    Quantization: as open-source large models advance, ever larger models are being fine-tuned for tasks like embedding and reranking, so exploring quantization methods to accelerate inference and reduce GPU memory usage may become necessary.
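The range/precision trade-off in the table above can be checked directly in PyTorch. This is an illustrative sketch (not part of the wrapper); it only requires `torch`:

```python
import torch

# float16 has only 5 exponent bits, so its maximum value is ~65504:
# moderately large values overflow to inf.
print(torch.tensor(70000.0, dtype=torch.float16))   # inf
print(torch.tensor(70000.0, dtype=torch.bfloat16))  # ~70144, still representable (8 exponent bits)

# bfloat16 has only 7 fraction bits, so small differences near 1.0 are rounded away.
print(torch.tensor(1.001, dtype=torch.float16))     # ~1.0010 (10 fraction bits)
print(torch.tensor(1.001, dtype=torch.bfloat16))    # 1.0
```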
### Unpadding

By default, Sentence Transformers (ST) pads all inputs in a batch to the length of the longest one, which is very inefficient. vLLM avoids padding entirely during inference.

<figure markdown="span">
  
  <figcaption>
    ST: sentence transformers backend; vLLM: vLLM backend.
    X-axis: throughput (requests/s). Y-axis: latency, time needed for one step (ms), logarithmic scale.
    The curve to the lower right is better ↘
  </figcaption>
</figure>

Sentence Transformers suffers a noticeable drop in speed when handling requests with varied input lengths, whereas vLLM does not.
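To get a feel for how much work padding can waste, here is a small sketch (added for illustration, not from the PR) that tokenizes a batch containing one long outlier with the Hugging Face tokenizer for `intfloat/e5-small` and counts how many of the padded token slots are real tokens:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-small")

batch = [
    "short query",
    "a slightly longer sentence about text embeddings",
    "a much longer passage about benchmarking embedding models " * 20,  # one long outlier
]

# Padded batch, as a padding-based backend would build it: every row is as long as the longest one.
padded = tokenizer(batch, padding=True)["input_ids"]
total_slots = len(padded) * len(padded[0])

# Unpadded token count, which is what vLLM actually processes.
real_tokens = sum(len(ids) for ids in tokenizer(batch)["input_ids"])

print(f"padded slots: {total_slots}, real tokens: {real_tokens}")
print(f"fraction of compute spent on padding: {1 - real_tokens / total_slots:.0%}")
```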
### Others

For models using bidirectional attention, such as BERT, vLLM offers a range of performance optimizations:

- Optimized CUDA kernels, including FlashAttention and FlashInfer integration
- CUDA Graphs and `torch.compile` support to reduce overhead and accelerate execution
- Support for tensor, pipeline, data, and expert parallelism for distributed inference
- Multiple quantization schemes (GPTQ, AWQ, AutoRound, INT4, INT8, and FP8) for efficient deployment
- Continuous batching of incoming requests to maximize throughput

For causal attention models, such as the Qwen3 reranker, the following optimizations are also applicable:

- Efficient KV cache memory management via PagedAttention
- Chunked prefill for improved memory handling during long-context processing
- Prefix caching to accelerate repeated prompt processing

vLLM's optimizations are primarily designed for, and most effective with, causal language models (generative models). For the full list of features, refer to the [vllm documentation](https://docs.vllm.ai/en/latest/features/).
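Several of the features listed above are plain vLLM engine arguments and can be toggled when constructing the engine. The snippet below is a sketch only: the model name is just an example, and the availability of each flag depends on your vLLM version.

```python
# Sketch: toggling some of the optimizations listed above via vLLM engine arguments.
if __name__ == "__main__":
    import vllm

    llm = vllm.LLM(
        model="Qwen/Qwen3-0.6B",        # example causal model
        dtype="bfloat16",               # half-precision inference (see the section above)
        gpu_memory_utilization=0.90,    # fraction of GPU memory reserved for weights + KV cache
        enable_prefix_caching=True,     # reuse KV cache for repeated prompt prefixes
        enable_chunked_prefill=True,    # process long prompts in chunks
    )
```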
## API Reference

> Contributor: Shouldn't we move this to the API section?

:::mteb.models.vllm_wrapper.VllmWrapperBase

!!! info
    For all vLLM parameters, please refer to https://docs.vllm.ai/en/latest/configuration/engine_args/.
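For example, engine-level options could be supplied when constructing the wrapper. This is a hedged sketch: it assumes the wrapper forwards extra keyword arguments such as `dtype` and `tensor_parallel_size` to the underlying `vllm.LLM` engine, as the note above suggests; check the reference entries below for the exact signature.

```python
import mteb
from mteb.models.vllm_wrapper import VllmEncoderWrapper

if __name__ == "__main__":
    encoder = VllmEncoderWrapper(
        model="intfloat/e5-small",
        dtype="float32",         # hypothetical pass-through: trade speed for full precision
        tensor_parallel_size=1,  # hypothetical pass-through: shard across GPUs when > 1
    )
    results = mteb.evaluate(encoder, mteb.get_task("STS12"))
    print(results)
```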
:::mteb.models.vllm_wrapper.VllmEncoderWrapper

:::mteb.models.vllm_wrapper.VllmCrossEncoderWrapper
Note to myself for when I rewrite this:
I’ve submitted everything I found useful. It’s best to just update this PR to get it ready.