Commit 6f1e7f7

NickLucche and hmellor authored
[DisaggEverything] Tokens in<>out /generate endpoint (vllm-project#24261)
Signed-off-by: NickLucche <[email protected]> Signed-off-by: Harry Mellor <[email protected]> Co-authored-by: Harry Mellor <[email protected]>
1 parent d54a18a commit 6f1e7f7
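
In short, the new /inference/v1/generate endpoint exchanges token IDs in both directions rather than text. The sketch below only summarizes the request/response shape as it appears in the example script and tests added by this commit; the field names come from those files, and nothing beyond them is assumed.

# Request body for POST /inference/v1/generate, as used in the example and tests below.
payload = {
    "model": "Qwen/Qwen3-0.6B",
    "token_ids": [1, 2, 3],  # the prompt is sent as token IDs, not text
    "sampling_params": {"max_tokens": 5, "detokenize": False},
    "stream": False,
}
# The response mirrors this: generated tokens come back as IDs in
# data["choices"][0]["token_ids"] and are decoded client-side when needed.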

File tree

12 files changed, +822 -9 lines changed
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import httpx
from transformers import AutoTokenizer

GEN_ENDPOINT = "http://localhost:8000/inference/v1/generate"
DUMMY_API_KEY = "empty"
MODEL_NAME = "Qwen/Qwen3-0.6B"

transport = httpx.HTTPTransport()
headers = {"Authorization": f"Bearer {DUMMY_API_KEY}"}
client = httpx.Client(
    transport=transport,
    base_url=GEN_ENDPOINT,
    timeout=600,
    headers=headers,
)
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "How many countries are in the EU?"},
]


def main(client):
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    token_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    payload = {
        "model": MODEL_NAME,
        "token_ids": token_ids,
        "sampling_params": {"max_tokens": 24, "temperature": 0.2, "detokenize": False},
        "stream": False,
    }
    resp = client.post(GEN_ENDPOINT, json=payload)
    resp.raise_for_status()
    data = resp.json()
    print(data)
    print("-" * 50)
    print("Token generation results:")
    res = tokenizer.decode(data["choices"][0]["token_ids"])
    print(res)
    print("-" * 50)


if __name__ == "__main__":
    main(client)
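
Note: the example assumes a vLLM server for Qwen/Qwen3-0.6B is already listening on localhost:8000, for instance one started with the usual `vllm serve Qwen/Qwen3-0.6B` invocation; the exact server flags needed for the tokens-in/tokens-out path are not part of this example, so treat that command as a starting point rather than a prescription.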

requirements/docs.txt

Lines changed: 4 additions & 0 deletions
@@ -10,3 +10,7 @@ mkdocs-minify-plugin
 regex
 ruff
 pydantic
+
+# For generating argparse docs.
+# Adding requirements here should only be used as a last resort.
+msgspec # Need for multiple inheritance involving msgspec.Struct
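
The comment above refers to engine-level request/response types that build on msgspec.Struct; those classes must be importable while the argparse docs are generated, so msgspec has to be present in the docs build environment. A minimal illustration of what a msgspec.Struct looks like (the class name here is hypothetical, not taken from the vLLM codebase):

import msgspec


# msgspec.Struct is a fast, typed record class; this standalone definition is
# only an illustration of why the package must be importable at docs-build time.
class GenerateRequestSketch(msgspec.Struct):
    token_ids: list[int]
    max_tokens: int = 16


req = GenerateRequestSketch(token_ids=[1, 2, 3])
print(msgspec.json.encode(req))  # b'{"token_ids":[1,2,3],"max_tokens":16}'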
Lines changed: 262 additions & 0 deletions
@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import httpx
import pytest
import pytest_asyncio
from transformers import AutoTokenizer

from vllm.config import ModelConfig
from vllm.v1.engine.detokenizer import check_stop_strings

from ...utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen3-0.6B"
GEN_ENDPOINT = "/inference/v1/generate"


def get_vocab_size(model_name):
    config = ModelConfig(
        model=model_name,
        seed=0,
        dtype="bfloat16",
    )
    return config.get_vocab_size()


@pytest.fixture(scope="module")
def tokenizer():
    return AutoTokenizer.from_pretrained(MODEL_NAME)


@pytest.fixture(scope="module")
def messages():
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "How many countries are in the EU?"},
    ]


@pytest.fixture(scope="module")
def server(request):
    args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "1024",
        "--enforce-eager",
    ]

    extra_args = getattr(request, "param", None)
    if extra_args is not None:
        args = args + (
            list(extra_args)
            if isinstance(extra_args, (list, tuple))
            else [str(extra_args)]
        )

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server: RemoteOpenAIServer):
    transport = httpx.AsyncHTTPTransport(uds=server.uds) if server.uds else None
    headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
    async with httpx.AsyncClient(
        transport=transport,
        base_url=server.url_root,
        timeout=600,
        headers=headers,
    ) as c:
        yield c


@pytest.mark.asyncio
async def test_generate_endpoint(client):
    payload = {
        "model": MODEL_NAME,
        "token_ids": [1, 2, 3],
        "sampling_params": {"max_tokens": 5},
        "stream": False,
    }
    resp = await client.post(GEN_ENDPOINT, json=payload)
    resp.raise_for_status()
    data = resp.json()
    assert "choices" in data


@pytest.mark.asyncio
async def test_same_response_as_chat_completions(client, tokenizer, messages):
    token_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        enable_thinking=False,  # default with Qwen3
    )
    for ignore_eos in [True, False]:
        payload = {
            "model": MODEL_NAME,
            "token_ids": token_ids,
            "sampling_params": {
                "max_tokens": 24,
                "temperature": 0.0,
                # NOTE coordinator will set this to skip detokenization
                "detokenize": False,
                "ignore_eos": ignore_eos,
            },
            "stream": False,
        }
        generate_resp = await client.post(GEN_ENDPOINT, json=payload)
        generate_data = generate_resp.json()
        generate_res = tokenizer.decode(
            generate_data["choices"][0]["token_ids"], skip_special_tokens=True
        )

        payload = {
            "model": MODEL_NAME,
            "messages": messages,
            "max_tokens": 24,
            "temperature": 0.0,
            "stream": False,
            "ignore_eos": ignore_eos,
            "chat_template_kwargs": dict(enable_thinking=False),
        }
        completions_resp = await client.post("/v1/chat/completions", json=payload)
        completions_data = completions_resp.json()
        completions_res = completions_data["choices"][0]["message"]["content"]

        assert generate_res == completions_res


@pytest.mark.asyncio
async def test_stop_string_workflow(client, tokenizer, messages):
    token_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        enable_thinking=False,  # default with Qwen3
    )
    payload = {
        "model": MODEL_NAME,
        "token_ids": token_ids,
        "sampling_params": {
            "max_tokens": 24,
            "temperature": 0.0,
            "detokenize": False,
            # stop strings are only supported when detokenize is True.
            "stop": ["27 member"],
        },
        # TODO stream test is much more interesting
        "stream": False,
    }
    with pytest.raises(httpx.HTTPStatusError):
        generate_resp = await client.post(GEN_ENDPOINT, json=payload)
        generate_resp.raise_for_status()

    payload["sampling_params"]["stop"] = None
    generate_resp = await client.post(
        GEN_ENDPOINT, json=payload, headers={"X-Request-Id": "42"}
    )
    generate_data = generate_resp.json()
    generate_res = tokenizer.decode(
        generate_data["choices"][0]["token_ids"], skip_special_tokens=True
    )

    # NOTE This is under the responsibility of the coordinator
    # stop_checker = StopChecker(
    #     max_model_len=1024, get_tokenizer_for_seq=lambda _: tokenizer
    # )
    stop_str, truncate_to = check_stop_strings(
        generate_res, len(generate_res), ["27 member"], False
    )
    assert stop_str == "27 member"
    # abort request that hit stop string (requires tokens-only mode)
    # res = await client.post("/abort_requests", json={"request_ids": ["generate-tokens-42"]})  # noqa: E501
    # res.raise_for_status()
    generate_res = generate_res[:truncate_to]

    # Get stop_str response from chat completions
    payload = {
        "model": MODEL_NAME,
        "messages": messages,
        "max_tokens": 24,
        "temperature": 0.0,
        "stream": False,
        "stop": ["27 member"],
        "chat_template_kwargs": dict(enable_thinking=False),
    }
    completions_resp = await client.post("/v1/chat/completions", json=payload)
    completions_data = completions_resp.json()
    completions_res = completions_data["choices"][0]["message"]["content"]
    assert generate_res == completions_res


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "server",
    [
        [
            "--enable-lora",
            "--lora-modules",
            "Alice=charent/self_cognition_Alice",
            "Bob=charent/self_cognition_Bob",
            "--max-lora-rank",
            "64",
            "--max-cpu-loras",
            "2",
        ]
    ],
    indirect=True,
)
async def test_generate_with_lora_adapter(client, tokenizer, messages):
    # Verify adapters are listed
    models_resp = await client.get("/v1/models")
    models_resp.raise_for_status()
    models = {m["id"] for m in models_resp.json().get("data", [])}
    assert {"Alice", "Bob"}.issubset(models)

    # Generate using a LoRA adapter by specifying its name as the model
    payload = {
        "model": "Alice",
        "token_ids": [1, 2, 3],
        "sampling_params": {"max_tokens": 5},
        "stream": False,
    }
    resp = await client.post(GEN_ENDPOINT, json=payload)
    resp.raise_for_status()
    data = resp.json()
    assert "choices" in data

    token_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        enable_thinking=False,  # default with Qwen3
    )
    payload = {
        "model": "Alice",
        "token_ids": token_ids,
        "sampling_params": {
            "max_tokens": 24,
            "temperature": 0.0,
            "detokenize": False,
        },
        "stream": False,
    }
    generate_resp = await client.post(GEN_ENDPOINT, json=payload)
    generate_data = generate_resp.json()
    generate_res = tokenizer.decode(
        generate_data["choices"][0]["token_ids"], skip_special_tokens=True
    )

    payload = {
        "model": "Alice",
        "messages": messages,
        "max_tokens": 24,
        "temperature": 0.0,
        "stream": False,
        "chat_template_kwargs": dict(enable_thinking=False),
    }
    completions_resp = await client.post("/v1/chat/completions", json=payload)
    completions_data = completions_resp.json()
    completions_res = completions_data["choices"][0]["message"]["content"]

    assert generate_res == completions_res

vllm/engine/arg_utils.py

Lines changed: 5 additions & 0 deletions
@@ -566,6 +566,7 @@ class EngineArgs:
     kv_offloading_backend: KVOffloadingBackend | None = (
         CacheConfig.kv_offloading_backend
     )
+    tokens_only: bool = False

     def __post_init__(self):
         # support `EngineArgs(compilation_config={...})`
@@ -1495,6 +1496,10 @@ def create_engine_config(
             else ParallelConfig.data_parallel_rpc_port
         )

+        if self.tokens_only and not model_config.skip_tokenizer_init:
+            model_config.skip_tokenizer_init = True
+            logger.info("Skipping tokenizer initialization for tokens-only mode.")
+
         # Forward the deprecated CLI args to the EPLB config.
         if self.num_redundant_experts is not None:
             self.eplb_config.num_redundant_experts = self.num_redundant_experts
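
For completeness, the field added above can be set like any other EngineArgs attribute (and presumably as a `--tokens-only` CLI flag via the usual argparse generation, though that mapping is an assumption and is not shown in this diff). A minimal sketch of what enabling it implies:

# Sketch only: `tokens_only` is the EngineArgs field added in this commit.
# When set, create_engine_config() forces model_config.skip_tokenizer_init = True
# (per the branch above), so the engine never loads a tokenizer and requests
# must carry token_ids, as in the /inference/v1/generate examples earlier.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="Qwen/Qwen3-0.6B", tokens_only=True)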
