
[Perf] Optimize maxsim scores computation for pooling models, 13.9% E2E throughput improvement #35330

Merged
noooop merged 2 commits into main from wentao-optimize-maxsim-calculation on Feb 26, 2026

Conversation

@yewentao256
Member

@yewentao256 yewentao256 commented Feb 25, 2026

Purpose

Optimize maxsim scores computation for pooling models

Originally the scores were computed on the CPU; now they are computed on the GPU with a batched implementation, which yields a substantial performance improvement.
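
For context, the MaxSim (late-interaction) score of a query against a document is the sum, over query tokens, of the maximum dot product against that document's tokens. A minimal sketch of a GPU-batched version follows; the function name, padding, and masking details are illustrative assumptions, not the exact implementation in this PR:

```python
import torch

def maxsim_scores_batched(q_emb: torch.Tensor, d_embs: list[torch.Tensor]) -> torch.Tensor:
    """Score one query against many documents with a single batched matmul.

    q_emb:  (q_len, dim) query token embeddings.
    d_embs: list of (d_len_i, dim) document token embeddings (variable length).
    Returns a (num_docs,) tensor of MaxSim scores.
    """
    device = q_emb.device
    # Pad variable-length documents into one (num_docs, max_len, dim) tensor.
    padded = torch.nn.utils.rnn.pad_sequence(d_embs, batch_first=True).to(device)
    # (num_docs, q_len, max_len) similarity matrices in one batched matmul.
    sims = torch.einsum("qd,ntd->nqt", q_emb, padded)
    # Mask padded document positions so they never win the max.
    lengths = torch.tensor([d.shape[0] for d in d_embs], device=device)
    pad_mask = torch.arange(padded.shape[1], device=device)[None, :] >= lengths[:, None]
    sims.masked_fill_(pad_mask[:, None, :], float("-inf"))
    # MaxSim: max over document tokens, then sum over query tokens.
    return sims.max(dim=-1).values.sum(dim=-1)
```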

Test

Acc

Covered in unit tests

tests/entrypoints/pooling/score/test_online_colbert.py::TestColBERTOnline::test_score
tests/entrypoints/pooling/score/test_online_colbert.py::TestColBERTOnline::test_rerank
tests/entrypoints/pooling/score/test_online_colbert.py::TestColBERTOnline::test_rerank_top_n
...

Perf

vllm serve --model jinaai/jina-colbert-v2 --runner pooling --port 9256 --enforce-eager --max-model-len 4096 --max-num-batched-tokens 4096 --disable-log-stats --hf-overrides '{"architectures": ["ColBERTJinaRobertaModel"]}' --trust-remote-code

vllm bench serve --model jinaai/jina-colbert-v2 --backend vllm-rerank --endpoint /v1/rerank --host 127.0.0.1 --port 9256 --dataset-name random-rerank --num-prompts 2000 --request-rate inf --max-concurrency 64 --seed 0 --random-input-len 2048 --random-range-ratio 0.5 --percentile-metrics e2el --metric-percentiles 50,95,99

# now
============ Serving Benchmark Result ============
Successful requests:                     2000      
Failed requests:                         0         
Maximum request concurrency:             64        
Benchmark duration (s):                  12.41     
Total input tokens:                      4116113   
Request throughput (req/s):              161.14    
Total token throughput (tok/s):          331637.66 
----------------End-to-end Latency----------------
Mean E2EL (ms):                          391.46    
Median E2EL (ms):                        395.45    
P50 E2EL (ms):                           395.45    
P95 E2EL (ms):                           421.21    
P99 E2EL (ms):                           431.76    
==================================================

# main
============ Serving Benchmark Result ============
Successful requests:                     2000      
Failed requests:                         0         
Maximum request concurrency:             64        
Benchmark duration (s):                  14.15     
Total input tokens:                      4116113   
Request throughput (req/s):              141.38    
Total token throughput (tok/s):          290961.01 
----------------End-to-end Latency----------------
Mean E2EL (ms):                          446.27    
Median E2EL (ms):                        440.32    
P50 E2EL (ms):                           440.32    
P95 E2EL (ms):                           516.98    
P99 E2EL (ms):                           531.58    
==================================================

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 yewentao256 requested a review from noooop as a code owner February 25, 2026 20:41
@yewentao256 yewentao256 added the ready label (ONLY add when PR is ready to merge/full CI is needed) Feb 25, 2026
@mergify mergify Bot added the frontend label Feb 25, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request optimizes the maxsim score computation by moving it to the GPU and processing in batches, which is a significant performance improvement. The implementation of the new batched function compute_maxsim_scores is mostly solid. However, I've identified a performance issue within the batching logic itself. The method for determining the batch size to avoid oversized memory allocations is inefficient and can be improved. I've provided a suggestion to refactor this part for better performance.
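
For reference, one way to bound the per-chunk allocation up front is to derive the chunk size directly from an element budget rather than probing allocation sizes iteratively. A minimal sketch (the helper name and signature are hypothetical; max_elements mirrors the max_score_matrix_elements value quoted later in this thread, not necessarily the PR's final logic):

```python
def chunk_size_for_budget(num_docs: int, q_len: int, max_doc_len: int,
                          max_elements: int = 16_000_000) -> int:
    """Largest number of documents per chunk whose padded score matrix
    (chunk, q_len, max_doc_len) stays within the element budget."""
    per_doc = max(q_len * max_doc_len, 1)
    return max(1, min(num_docs, max_elements // per_doc))
```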

Comment thread vllm/entrypoints/pooling/score/utils.py
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Collaborator

@noooop noooop left a comment


awesome

@noooop noooop enabled auto-merge (squash) February 26, 2026 03:50
@noooop
Collaborator

noooop commented Feb 26, 2026

(I’m not sure if it will work when the API server count is greater than 1. I have reservations about using the GPU in the API server (or during pre-processing and post-processing stages). There might be a risk of OOM or other weird CUDA errors. However, computing maxsim scores on the GPU is indeed better.)

@noooop noooop merged commit 99c7892 into main Feb 26, 2026
54 of 55 checks passed
@noooop noooop deleted the wentao-optimize-maxsim-calculation branch February 26, 2026 17:14
@mgoin
Member

mgoin commented Feb 26, 2026

@yewentao256 can you test this with api-server-count>1 without DP? I have concerns about the API server using GPU resources

@yewentao256
Member Author

yewentao256 commented Feb 26, 2026

> (I’m not sure if it will work when the API server count is greater than 1. I have reservations about using the GPU in the API server (or during pre-processing and post-processing stages). There might be a risk of OOM or other weird CUDA errors. However, computing maxsim scores on the GPU is indeed better.)

No worries, we can fix it accordingly if an issue comes up. With the current config max_score_matrix_elements=16_000_000 and batch_size=16, GPU memory usage would be ~40MB to ~200MB, so I think we don't need to worry about OOM.
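
As a rough, hedged sanity check on those numbers (assuming fp32/fp16 score matrices; the actual buffers depend on the implementation):

```python
# Back-of-envelope estimate (assumed dtypes; actual buffers depend on the code).
max_score_matrix_elements = 16_000_000
print(max_score_matrix_elements * 4 / 1024**2)  # ~61 MiB per fp32 score-matrix chunk
print(max_score_matrix_elements * 2 / 1024**2)  # ~31 MiB in fp16
```

Padded query/document embedding copies come on top of the score matrix itself, which is roughly in line with the range quoted above.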

if q_emb.shape[1] != d_emb.shape[1]:
    raise ValueError("Query and document embeddings must have same dim")

compute_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Member


I don't think you should be using torch.cuda.is_available() directly. You should use current_platform.is_cuda() at least
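
For illustration only, the suggested check might look like the sketch below (assuming vLLM's current_platform helper; this is not necessarily how the follow-up PR resolves it):

```python
import torch
from vllm.platforms import current_platform

# Platform-aware device selection instead of calling torch.cuda.is_available() directly.
compute_device = torch.device("cuda" if current_platform.is_cuda() else "cpu")
```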

Member Author


This should be addressed in #35427.

@mgoin
Member

mgoin commented Feb 26, 2026

> No worries, we can fix it accordingly if an issue comes up.

Waiting for issues to be reported is not a good testing strategy. We should have raised this PR with other contributors before merging

stakeswky pushed a commit to stakeswky/vllm that referenced this pull request Feb 26, 2026
…2E throughput improvement (vllm-project#35330)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: stakeswky <stakeswky@gmail.com>
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026
…2E throughput improvement (vllm-project#35330)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
roipony pushed a commit to roipony/vllm that referenced this pull request Apr 20, 2026
…ring

Replaces the vanilla padded-bmm MaxSim (PR vllm-project#35330, vllm-project#38620) with vendored
flash-maxsim Triton kernels for ColBERT/ColPali document scoring.

Three scoring paths, with automatic fallback:
  1. Zero-copy (default): project hidden_states once, rerank kernel
     reads doc slices directly from projected_batch. No torch.cat,
     no score-matrix materialization.
  2. Flash-packed (when zerocopy disabled or params not compatible):
     torch.cat + single fused kernel call, no padding.
  3. Vanilla (CPU, d<16, or no Triton): original padded bmm.

Key results on A100 80GB with ColBERT:
  - Kernel speedup on varlen docs: 100-3000x vs vanilla padded bmm
  - E2E throughput: +15-23% at 500+ docs/req (reranking workloads)
  - P95 latency: 13-24% lower
  - Score parity: max_abs_diff < 0.001 on 5K real docs,
    top-3 rankings identical

Kernel correctness: max_err=4e-6 vs fp32 reference. Falls back to
vanilla for CPU tensors, embedding dim < 16, chunked-prefill, or
when pooling params request matryoshka truncation / activation off.

Addresses: vllm-project#38282
Signed-off-by: roi.pony <roi.pony@ibm.com>

Labels

frontend, ready (ONLY add when PR is ready to merge/full CI is needed)
