[Perf] Batched projector for pooling model embed, 1.8% throughput improvement #39533
yewentao256 wants to merge 6 commits into main
Conversation
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Code Review
This pull request introduces RaggedTokenBatch to optimize token-wise pooling by enabling batch projection and uniform post-processing, refactors TokenEmbeddingPoolerHead to support these ragged batches, and adds corresponding tests. The review feedback recommends adding a defensive check for empty pooling_params in the forward_ragged method to prevent a potential IndexError.
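A minimal sketch of the suggested guard. The class and method names mirror the PR (`TokenEmbeddingPoolerHead`, `forward_ragged`, `pooling_params`), but the body below is an illustrative stand-in, not vLLM's actual implementation:

```python
import torch


class TokenEmbeddingPoolerHead:
    """Illustrative stand-in; the real class lives in vLLM's pooling code."""

    def forward_ragged(
        self,
        values: torch.Tensor,        # flat (total_tokens, hidden) concatenation
        lengths: list[int],          # tokens per request
        pooling_params: list[dict],  # one params dict per request
    ) -> list[torch.Tensor]:
        # Defensive check from the review: on an empty batch,
        # pooling_params[0] below would raise IndexError.
        if not pooling_params:
            return []
        # Uniform post-processing keyed off the first request's params.
        normalize = pooling_params[0].get("normalize", False)
        chunks = list(values.split(lengths))
        if normalize:
            chunks = [torch.nn.functional.normalize(c, dim=-1) for c in chunks]
        return chunks


head = TokenEmbeddingPoolerHead()
# The empty-batch case returns early instead of raising IndexError.
assert head.forward_ragged(torch.empty(0, 4), [], []) == []
```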
BAAI/bge-m3 uses bidirectional self-attention and cannot use `--enable-chunked-prefill`. I don't think this test setup is representative.
I think we should always test both models that don't support chunked prefill and those that do.
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This PR won't affect the chunked prefill case:

```python
if not self.enable_chunked_prefill:
    return RaggedTokenBatch.from_lengths(
        values=hidden_states,
        lengths_cpu=pooling_cursor.num_scheduled_tokens_cpu,
    )
```

Benchmark:

```shell
python - <<'PY'
import json

model = "Qwen/Qwen3-Embedding-0.6B"
repeats = [8, 12, 16, 24, 32, 40, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192]
unit = "retrieval benchmark sentence "
prompts = [(unit * r).strip() for r in repeats]
with open("/tmp/token_embed.json", "w") as f:
    json.dump({
        "model": model,
        "task": "token_embed",
        "truncate_prompt_tokens": -1,
        "encoding_format": "base64",
        "input": prompts,
    }, f)
PY
```
```text
# now
duration_s=120.17
concurrency=4
completed=1622
errors=0
req_per_s=13.50
input_tokens_per_s=54583.71
mean_ms=266.26
p50_ms=264.88
p95_ms=284.75
p99_ms=347.65

# main
duration_s=120.22
concurrency=4
completed=1617
errors=0
req_per_s=13.45
input_tokens_per_s=54391.20
mean_ms=267.22
p50_ms=264.07
p95_ms=296.72
p99_ms=371.64
```
Using `--enable-chunked-prefill` for BAAI/bge-m3 does not guarantee correct results.
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@noooop Thanks! After some consideration, I think we should add support for chunked prefill case as well, code pushed, please take another look.
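For context, supporting chunked prefill means request boundaries in the flat hidden states are given by per-request first/last token indices rather than uniform lengths. A minimal sketch with hypothetical index values (not the PR's actual code):

```python
import torch

# Flat hidden states for all tokens scheduled in this step.
hidden_states = torch.randn(10, 4)

# Hypothetical per-request token ranges, as a pooling cursor would report them.
first_token_indices = [0, 3, 7]
last_token_indices = [2, 6, 9]

# Slice each request's tokens out of the flat buffer before pooling;
# boundaries are arbitrary under chunked prefill, so each range is
# sliced individually.
hidden_states_lst = [
    hidden_states[first : last + 1]
    for first, last in zip(first_token_indices, last_token_indices)
]
# Per-request chunk lengths here: 3, 4, 3
```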
Test similar to the comments above:

```shell
vllm serve Qwen/Qwen3-Embedding-0.6B --runner pooling --port 9256 --enable-chunked-prefill --max-num-batched-tokens 512 --long-prefill-token-threshold 128
```
```text
# now
duration_s=120.23
concurrency=4
completed=1649
errors=0
req_per_s=13.71
input_tokens_per_s=55463.08
mean_ms=261.99
p50_ms=260.65
p95_ms=276.62
p99_ms=328.27
(yewentao256) [yewentao256@nm-frk-h200-01-preserve vllm-source]$ python bench.py
duration_s=120.19
concurrency=4
completed=1645
errors=0
req_per_s=13.69
input_tokens_per_s=55347.88
mean_ms=262.20
p50_ms=260.11
p95_ms=281.72
p99_ms=333.30
```
```text
# main
duration_s=120.22
concurrency=4
completed=1617
errors=0
req_per_s=13.45
input_tokens_per_s=54391.20
mean_ms=267.22
p50_ms=264.07
p95_ms=296.72
p99_ms=371.64
```

The diff context under review:

```python
pooling_cursor = pooling_metadata.get_pooling_cursor()
hidden_states_lst = [
    hidden_states[first : last + 1]
    for first, last in zip(
        pooling_cursor.first_token_indices_gpu.tolist(),
        pooling_cursor.last_token_indices_gpu.tolist(),
    )
]
if self.enable_chunked_prefill:
    ...
```
Purpose
Optimize pooling models by using a ragged tensor so that the projection can be done in a single batch.
CC: @noooop
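The ragged-batch idea can be sketched like this. The names mirror the PR's `RaggedTokenBatch.from_lengths`, but this is an illustrative sketch, not the actual vLLM implementation:

```python
from dataclasses import dataclass

import torch


@dataclass
class RaggedTokenBatch:
    """Flat concatenation of per-request token embeddings plus lengths."""

    values: torch.Tensor     # (total_tokens, hidden)
    lengths_cpu: list[int]   # tokens per request

    @classmethod
    def from_lengths(cls, values: torch.Tensor, lengths_cpu) -> "RaggedTokenBatch":
        return cls(values=values, lengths_cpu=list(lengths_cpu))

    def split(self) -> tuple[torch.Tensor, ...]:
        # Per-request views into the flat tensor, no copies.
        return self.values.split(self.lengths_cpu)


hidden, out_dim, lengths = 8, 4, [3, 5, 2]
batch = RaggedTokenBatch.from_lengths(torch.randn(sum(lengths), hidden), lengths)

# One matmul projects every request at once, instead of launching one
# small projection per request.
proj = torch.nn.Linear(hidden, out_dim)
projected = proj(batch.values)
per_request = projected.split(batch.lengths_cpu)
```

Because the projection is a single large GEMM over the flat `values` tensor, kernel launch overhead and small-matrix inefficiency are amortized across all requests in the step.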
Test
Acc
Covered by the added unit tests and the existing e2e unit tests.
Perf
Generate data
```shell
vllm serve Qwen/Qwen3-Embedding-0.6B --runner pooling --port 9256 --enable-chunked-prefill --max-num-batched-tokens 512 --long-prefill-token-threshold 128
```

And we get:

```text
# now
duration_s=120.23
concurrency=4
completed=1649
errors=0
req_per_s=13.71
input_tokens_per_s=55463.08
mean_ms=261.99
p50_ms=260.65
p95_ms=276.62
p99_ms=328.27
(yewentao256) [yewentao256@nm-frk-h200-01-preserve vllm-source]$ python bench.py
duration_s=120.19
concurrency=4
completed=1645
errors=0
req_per_s=13.69
input_tokens_per_s=55347.88
mean_ms=262.20
p50_ms=260.11
p95_ms=281.72
p99_ms=333.30

# main
duration_s=120.22
concurrency=4
completed=1617
errors=0
req_per_s=13.45
input_tokens_per_s=54391.20
mean_ms=267.22
p50_ms=264.07
p95_ms=296.72
p99_ms=371.64
```