[Perf] Batched projector for pooling model embed, 1.8% throughput improvement #39533

Open

yewentao256 wants to merge 6 commits into main from wentao-optimize-pooling-by-ragged-tensor

Conversation


@yewentao256 yewentao256 commented Apr 10, 2026

Purpose

Optimize the pooling model by using a ragged tensor so that the projection can be done in one batch.

CC: @noooop
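The core idea (projecting all requests' token embeddings with one matmul over a flat ragged layout, instead of looping request by request) can be sketched roughly as follows. This is a minimal numpy sketch of the technique, not the PR's actual `RaggedTokenBatch` implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, out_dim = 8, 4
lengths = [3, 5, 2]  # tokens per request in the batch
W = rng.standard_normal((hidden, out_dim))

# Per-request hidden states (what a naive head would project one by one).
per_request = [rng.standard_normal((n, hidden)) for n in lengths]

# Ragged layout: one flat (total_tokens, hidden) tensor plus per-request lengths.
flat = np.concatenate(per_request, axis=0)

# One batched projection over all tokens of all requests at once.
projected = flat @ W

# Split back into per-request outputs using cumulative offsets.
offsets = np.cumsum(lengths)[:-1]
batched_out = np.split(projected, offsets)

# The naive per-request loop gives the same result.
loop_out = [h @ W for h in per_request]
for a, b in zip(batched_out, loop_out):
    assert np.allclose(a, b)
```

The win is launching one projection kernel over the whole batch instead of one small matmul per request.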

Test

Acc

Covered by the newly added unit test and the existing e2e unit tests.

Perf

Generate data

python - <<'PY'
import json
model = "Qwen/Qwen3-Embedding-0.6B"
repeats = [8, 12, 16, 24, 32, 40, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192]
unit = "retrieval benchmark sentence "
prompts = [(unit * r).strip() for r in repeats]
with open("/tmp/token_embed.json", "w") as f:
    json.dump({
        "model": model,
        "task": "token_embed",
        "truncate_prompt_tokens": -1,
        "encoding_format": "base64",
        "input": prompts,
    }, f)
PY

vllm serve Qwen/Qwen3-Embedding-0.6B --runner pooling --port 9256 --enable-chunked-prefill --max-num-batched-tokens 512 --long-prefill-token-threshold 128

import http.client, threading, time, statistics, math, json

HOST, PORT, PATH = "127.0.0.1", 9256, "/pooling"
BODY = open("/tmp/token_embed.json", "rb").read()
HEADERS = {"Content-Type": "application/json", "Content-Length": str(len(BODY))}
CONCURRENCY = 4
DURATION = 120

def pct(xs, p):
    xs = sorted(xs)
    k = max(0, min(len(xs) - 1, math.ceil(len(xs) * p / 100) - 1))
    return xs[k]

all_lat = [[] for _ in range(CONCURRENCY)]
ok = [0] * CONCURRENCY
err = [0] * CONCURRENCY
prompt_tokens = [0] * CONCURRENCY

deadline = time.perf_counter() + DURATION

def worker(i):
    while time.perf_counter() < deadline:
        conn = None
        try:
            conn = http.client.HTTPConnection(HOST, PORT, timeout=180)
            t0 = time.perf_counter()
            conn.request("POST", PATH, body=BODY, headers=HEADERS)
            resp = conn.getresponse()
            raw = resp.read()
            dt = (time.perf_counter() - t0) * 1000

            if resp.status == 200:
                ok[i] += 1
                all_lat[i].append(dt)
                data = json.loads(raw)
                prompt_tokens[i] += data.get("usage", {}).get("prompt_tokens", 0)
            else:
                err[i] += 1
                print(f"[worker {i}] HTTP {resp.status}")
        except Exception as e:
            err[i] += 1
            print(f"[worker {i}] {type(e).__name__}: {e}")
        finally:
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass

threads = [threading.Thread(target=worker, args=(i,)) for i in range(CONCURRENCY)]
t0 = time.perf_counter()
for t in threads:
    t.start()
for t in threads:
    t.join()
elapsed = time.perf_counter() - t0

lat = [x for xs in all_lat for x in xs]
total_ok = sum(ok)
total_err = sum(err)
total_prompt_tokens = sum(prompt_tokens)

print(f"duration_s={elapsed:.2f}")
print(f"concurrency={CONCURRENCY}")
print(f"completed={total_ok}")
print(f"errors={total_err}")
print(f"req_per_s={total_ok / elapsed:.2f}")
print(f"input_tokens_per_s={total_prompt_tokens / elapsed:.2f}")
if lat:
    print(f"mean_ms={statistics.mean(lat):.2f}")
    print(f"p50_ms={pct(lat, 50):.2f}")
    print(f"p95_ms={pct(lat, 95):.2f}")
    print(f"p99_ms={pct(lat, 99):.2f}")

And we get:

# now
duration_s=120.23
concurrency=4
completed=1649
errors=0
req_per_s=13.71
input_tokens_per_s=55463.08
mean_ms=261.99
p50_ms=260.65
p95_ms=276.62
p99_ms=328.27
# (second run)
duration_s=120.19
concurrency=4
completed=1645
errors=0
req_per_s=13.69
input_tokens_per_s=55347.88
mean_ms=262.20
p50_ms=260.11
p95_ms=281.72
p99_ms=333.30

# main
duration_s=120.22
concurrency=4
completed=1617
errors=0
req_per_s=13.45
input_tokens_per_s=54391.20
mean_ms=267.22
p50_ms=264.07
p95_ms=296.72
p99_ms=371.64

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 yewentao256 requested a review from noooop as a code owner April 10, 2026 20:45
@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 10, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces RaggedTokenBatch to optimize token-wise pooling by enabling batch projection and uniform post-processing, refactors TokenEmbeddingPoolerHead to support these ragged batches, and adds corresponding tests. The review feedback recommends adding a defensive check for empty pooling_params in the forward_ragged method to prevent a potential IndexError.

Comment thread vllm/model_executor/layers/pooler/tokwise/heads.py
Comment thread vllm/model_executor/layers/pooler/tokwise/methods.py
Comment thread vllm/model_executor/layers/pooler/tokwise/heads.py Outdated
@noooop
Collaborator

noooop commented Apr 11, 2026

BAAI/bge-m3 uses bidirectional self-attention and cannot use --enable-chunked-prefill. I think your test is not very reasonable.

try https://huggingface.co/Qwen/Qwen3-Embedding-0.6B

@DarkLight1337
Member

> BAAI/bge-m3 uses bidirectional self-attention and cannot use --enable-chunked-prefill. I think your test is not very reasonable.
>
> try https://huggingface.co/Qwen/Qwen3-Embedding-0.6B

I think we should always test both models that don't support chunked prefill and those that do.

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256
Member Author

@noooop

This PR won't affect the chunked-prefill case:

if not self.enable_chunked_prefill:
    return RaggedTokenBatch.from_lengths(
        values=hidden_states,
        lengths_cpu=pooling_cursor.num_scheduled_tokens_cpu,
    )
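For intuition, a `from_lengths`-style constructor can be imagined like this. Everything below is a hypothetical sketch in plain numpy; the real `RaggedTokenBatch` fields and methods are not shown in this excerpt:

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class RaggedBatchSketch:
    """Flat token tensor plus per-request lengths (hypothetical sketch)."""

    values: np.ndarray   # (total_tokens, hidden)
    lengths: list        # tokens per request

    @classmethod
    def from_lengths(cls, values, lengths_cpu):
        # The flat tensor must account for every scheduled token.
        assert values.shape[0] == sum(lengths_cpu)
        return cls(values=values, lengths=list(lengths_cpu))

    def segments(self):
        # Yield per-request views of the flat tensor without copying.
        start = 0
        for n in self.lengths:
            yield self.values[start:start + n]
            start += n


batch = RaggedBatchSketch.from_lengths(np.arange(12).reshape(6, 2), [2, 1, 3])
print([seg.shape[0] for seg in batch.segments()])  # [2, 1, 3]
```

The point is that the flat `values` tensor can be projected in one shot, while `lengths` preserves enough information to recover each request's slice afterwards.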

Benchmark:

python - <<'PY'
import json
model = "Qwen/Qwen3-Embedding-0.6B"
repeats = [8, 12, 16, 24, 32, 40, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192]
unit = "retrieval benchmark sentence "
prompts = [(unit * r).strip() for r in repeats]
with open("/tmp/token_embed.json", "w") as f:
    json.dump({
        "model": model,
        "task": "token_embed",
        "truncate_prompt_tokens": -1,
        "encoding_format": "base64",
        "input": prompts,
    }, f)
PY

vllm serve Qwen/Qwen3-Embedding-0.6B --runner pooling --port 9256 --enable-chunked-prefill --max-num-batched-tokens 512 --long-prefill-token-threshold 128

python bench.py

# now
duration_s=120.17
concurrency=4
completed=1622
errors=0
req_per_s=13.50
input_tokens_per_s=54583.71
mean_ms=266.26
p50_ms=264.88
p95_ms=284.75
p99_ms=347.65
# main
duration_s=120.22
concurrency=4
completed=1617
errors=0
req_per_s=13.45
input_tokens_per_s=54391.20
mean_ms=267.22
p50_ms=264.07
p95_ms=296.72
p99_ms=371.64

@noooop
Collaborator

noooop commented Apr 13, 2026

vllm serve BAAI/bge-m3 --runner pooling --port 9256 --enable-chunked-prefill --max-num-batched-tokens 512 --long-prefill-token-threshold 128

Using --enable-chunked-prefill for BAAI/bge-m3 does not guarantee correct results.

Member Author

@yewentao256 yewentao256 left a comment


@noooop Thanks! After some consideration, I think we should add support for the chunked-prefill case as well. Code pushed, please take another look.

The test setup is similar to the comments above:

vllm serve Qwen/Qwen3-Embedding-0.6B --runner pooling --port 9256 --enable-chunked-prefill --max-num-batched-tokens 512 --long-prefill-token-threshold 128

# now
duration_s=120.23
concurrency=4
completed=1649
errors=0
req_per_s=13.71
input_tokens_per_s=55463.08
mean_ms=261.99
p50_ms=260.65
p95_ms=276.62
p99_ms=328.27
# (second run)
duration_s=120.19
concurrency=4
completed=1645
errors=0
req_per_s=13.69
input_tokens_per_s=55347.88
mean_ms=262.20
p50_ms=260.11
p95_ms=281.72
p99_ms=333.30

# main
duration_s=120.22
concurrency=4
completed=1617
errors=0
req_per_s=13.45
input_tokens_per_s=54391.20
mean_ms=267.22
p50_ms=264.07
p95_ms=296.72
p99_ms=371.64

@yewentao256 yewentao256 changed the title [Perf] Batched projector for pooling model embed, 7.4% throughput improvement [Perf] Batched projector for pooling model embed, 1.8% throughput improvement Apr 14, 2026
Comment on lines 128 to +129

pooling_cursor = pooling_metadata.get_pooling_cursor()
hidden_states_lst = [
    hidden_states[first : last + 1]
    for first, last in zip(
        pooling_cursor.first_token_indices_gpu.tolist(),
        pooling_cursor.last_token_indices_gpu.tolist(),
if self.enable_chunked_prefill:
Collaborator


Are you looking to optimize this code?

Please wait for #37861 landing. It will disable pooling multi-task support.

After that, we can be sure that only requests for the same task will be batched together, and we can revert #36614.
