
Commit e4485a0

zhongdaor-nv authored and nv-tusharma committed
fix: Add support for single element arrays for chat and completions prompts (#3482)
Signed-off-by: zhongdaor <[email protected]>
1 parent bb3b1c2 commit e4485a0
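
In API terms, the /v1/completions endpoint now unwraps a single-element array prompt onto the plain-string path and rejects arrays with more than one element. A minimal sketch of both cases against a locally running frontend, assuming the port 8000 used by the tests in this commit and a placeholder model name:

    import requests

    URL = "http://localhost:8000/v1/completions"  # frontend address used in the tests below

    # Single-element array prompt: unwrapped and handled like a plain string.
    ok = requests.post(URL, json={
        "model": "my-served-model",  # placeholder; use whatever model the frontend serves
        "prompt": ["Tell me about Mars"],
        "max_tokens": 64,
    })
    print(ok.status_code)  # expected 200

    # Multi-element array prompt: batch prompts are rejected after this change.
    err = requests.post(URL, json={
        "model": "my-served-model",
        "prompt": ["Tell me about Mars", "Tell me about Ceres"],
        "max_tokens": 64,
    })
    print(err.status_code)  # expected 500, per the test at the end of this commit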

File tree

4 files changed: +208 −19 lines changed

components/src/dynamo/sglang/multimodal_utils/multimodal_chat_processor.py

Lines changed: 0 additions & 1 deletion

@@ -38,7 +38,6 @@ def multimodal_request_to_sglang(raw_request, tokenizer, chat_template):
     sglang_request = {
         "model": raw_request.model,
         "token_ids": input_ids,
-        "batch_token_ids": None,
         "stop_conditions": {"max_tokens": raw_request.max_tokens or None},
         "sampling_options": {"temperature": raw_request.temperature or 0.7},
         "eos_token_ids": [tokenizer.eos_token_id],

lib/llm/src/preprocessor.rs

Lines changed: 14 additions & 14 deletions

@@ -14,12 +14,11 @@
 pub mod prompt;
 pub mod tools;

-use anyhow::Result;
+use anyhow::{Result, bail};
 use dynamo_async_openai::types::{ChatCompletionToolChoiceOption, EncodingFormat};
 use futures::Stream;
 use futures::stream::{self, StreamExt};
 use prompt::OAIPromptFormatter;
-use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
 use std::{collections::HashMap, pin::Pin, sync::Arc};
 use tracing;

@@ -282,8 +281,10 @@ impl OpenAIPreprocessor {
         if token_batches.len() == 1 {
             builder.token_ids(token_batches[0].clone());
         } else {
-            builder.batch_token_ids(Some(token_batches));
-            builder.token_ids(vec![]);
+            bail!(
+                "Batch token input not supported for more than one token in requests (got {})",
+                token_batches.len()
+            );
         }
     }
 }

@@ -345,16 +346,15 @@ impl OpenAIPreprocessor {
             builder.token_ids(tokens_vec);
         }
         TextInput::Batch(texts) => {
-            let token_batches: Vec<Vec<u32>> = texts
-                .par_iter()
-                .map(|text| {
-                    self.tokenizer
-                        .encode(text)
-                        .map(|encoded| encoded.token_ids().to_vec())
-                })
-                .collect::<Result<Vec<_>>>()?;
-            builder.batch_token_ids(Some(token_batches));
-            builder.token_ids(vec![]);
+            if texts.len() == 1 {
+                let encoding = self.tokenizer.encode(&texts[0])?;
+                builder.token_ids(encoding.token_ids().to_vec());
+            } else {
+                bail!(
+                    "Batch text input not supported for more than one text in requests (got {})",
+                    texts.len()
+                );
+            }
         }
     }
 }
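
Both hunks replace the old batch path (parallel rayon tokenization feeding batch_token_ids) with a guard: a single-element batch is unwrapped into token_ids, and anything larger fails fast via bail!. A rough Python analog of that control flow, using a hypothetical tokenizer whose encode() returns a list of token ids:

    from typing import List

    def preprocess_text_batch(texts: List[str], tokenizer) -> List[int]:
        # Mirrors the new TextInput::Batch branch: unwrap a single-element batch,
        # raise on anything larger (the Rust code uses anyhow's bail! macro).
        if len(texts) == 1:
            return tokenizer.encode(texts[0])
        raise ValueError(
            f"Batch text input not supported for more than one text in requests (got {len(texts)})"
        )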

lib/llm/src/protocols/common/preprocessor.rs

Lines changed: 0 additions & 4 deletions

@@ -18,10 +18,6 @@ pub struct PreprocessedRequest {
     /// Type of prompt
     pub token_ids: Vec<TokenIdType>,

-    /// Batch Token Ids - for batch completion requests (i.e. using ArrayOfIntegerArray type from OpenAI /completions)
-    #[builder(default)]
-    pub batch_token_ids: Option<Vec<Vec<TokenIdType>>>,
-
     /// StopConditions are conditions that the inference engine will use to stop generation.
     pub stop_conditions: StopConditions,
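
With batch_token_ids removed, PreprocessedRequest carries exactly one prompt's tokens. A hypothetical Python dataclass analog of the trimmed struct, with stand-in types for TokenIdType and StopConditions:

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class PreprocessedRequest:
        token_ids: List[int]  # the single prompt's token ids; no batch variant anymore
        stop_conditions: dict = field(default_factory=dict)  # stand-in for StopConditions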

Lines changed: 194 additions & 0 deletions

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from __future__ import annotations

import logging
import os
import shutil
import time
from typing import Any, Dict

import pytest
import requests

from tests.conftest import EtcdServer, NatsServer
from tests.utils.constants import QWEN
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_models_api

logger = logging.getLogger(__name__)

TEST_MODEL = QWEN


class DynamoFrontendProcess(ManagedProcess):
    """Process manager for Dynamo frontend"""

    def __init__(self, request):
        command = ["python", "-m", "dynamo.frontend", "--router-mode", "round-robin"]

        log_dir = f"{request.node.name}_frontend"

        # Clean up any existing log directory from previous runs
        try:
            shutil.rmtree(log_dir)
            logger.info(f"Cleaned up existing log directory: {log_dir}")
        except FileNotFoundError:
            # Directory doesn't exist, which is fine
            pass

        super().__init__(
            command=command,
            display_output=True,
            terminate_existing=True,
            log_dir=log_dir,
        )


class MockWorkerProcess(ManagedProcess):
    def __init__(self, request, worker_id: str = "mocker-worker"):
        self.worker_id = worker_id

        command = [
            "python3",
            "-m",
            "dynamo.mocker",
            "--model-path",
            TEST_MODEL,
            "--speedup-ratio",
            "100",
        ]

        env = os.environ.copy()
        env["DYN_LOG"] = "debug"
        env["DYN_SYSTEM_ENABLED"] = "true"
        env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
        env["DYN_SYSTEM_PORT"] = "8083"

        log_dir = f"{request.node.name}_{worker_id}"

        try:
            shutil.rmtree(log_dir)
        except FileNotFoundError:
            pass

        super().__init__(
            command=command,
            env=env,
            health_check_urls=[
                ("http://localhost:8000/v1/models", check_models_api),
                ("http://localhost:8083/health", self.is_ready),
            ],
            timeout=300,
            display_output=True,
            terminate_existing=False,
            stragglers=["VLLM::EngineCore"],
            straggler_commands=["-m dynamo.mocker"],
            log_dir=log_dir,
        )

    def is_ready(self, response) -> bool:
        try:
            status = (response.json() or {}).get("status")
        except ValueError:
            logger.warning("%s health response is not valid JSON", self.worker_id)
            return False

        is_ready = status == "ready"
        if is_ready:
            logger.info("%s status is ready", self.worker_id)
        else:
            logger.warning("%s status is not ready: %s", self.worker_id, status)
        return is_ready


def _send_completion_request(
    payload: Dict[str, Any],
    timeout: int = 180,
) -> requests.Response:
    """Send a text completion request"""

    headers = {"Content-Type": "application/json"}
    print(f"Sending request: {time.time()}")

    response = requests.post(
        "http://localhost:8000/v1/completions",
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    return response


@pytest.fixture(scope="module")
def runtime_services(request):
    """Module-scoped runtime services for this test file."""
    with NatsServer(request) as nats_process:
        with EtcdServer(request) as etcd_process:
            yield nats_process, etcd_process


@pytest.fixture(scope="module")
def start_services(request, runtime_services):
    """Start frontend and worker processes once for this module's tests."""
    with DynamoFrontendProcess(request):
        logger.info("Frontend started for tests")
        with MockWorkerProcess(request):
            logger.info("Worker started for tests")
            yield


@pytest.mark.usefixtures("start_services")
@pytest.mark.e2e
@pytest.mark.model(TEST_MODEL)
def test_completion_string_prompt() -> None:
    payload: Dict[str, Any] = {
        "model": TEST_MODEL,
        "prompt": "Tell me about Mars",
        "max_tokens": 2000,
    }

    response = _send_completion_request(payload)

    assert response.status_code == 200, (
        f"Completion request failed with status "
        f"{response.status_code}: {response.text}"
    )


@pytest.mark.usefixtures("start_services")
@pytest.mark.e2e
@pytest.mark.model(TEST_MODEL)
def test_completion_single_element_array_prompt() -> None:
    payload: Dict[str, Any] = {
        "model": TEST_MODEL,
        "prompt": ["Tell me about Mars"],
        "max_tokens": 2000,
    }

    response = _send_completion_request(payload)

    assert response.status_code == 200, (
        f"Completion request failed with status "
        f"{response.status_code}: {response.text}"
    )


@pytest.mark.usefixtures("start_services")
@pytest.mark.e2e
@pytest.mark.model(TEST_MODEL)
def test_completion_multi_element_array_prompt() -> None:
    payload: Dict[str, Any] = {
        "model": TEST_MODEL,
        "prompt": ["Tell me about Mars", "Tell me about Ceres"],
        "max_tokens": 2000,
    }

    response = _send_completion_request(payload)

    # request should fail because we are sending multiple prompts
    assert (
        response.status_code == 500
    ), f"Request should fail with code 500; response: {response.text}"
