Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions python/sglang/bench_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
PreTrainedTokenizerFast,
)

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
ASSISTANT_SUFFIX = "Assistant:"

global args
Expand All @@ -51,6 +50,13 @@ def _get_bool_env_var(name: str, default: str = "false") -> bool:
return value.lower() in ("true", "1")


def _create_bench_client_session():
# When the pressure is big, the read buffer could be full before aio thread read
# the content. We increase the read_bufsize from 64K to 10M.
aiohttp_timeout = aiohttp.ClientTimeout(total=6 * 60 * 60)
return aiohttp.ClientSession(timeout=aiohttp_timeout, read_bufsize=10 * 1024**2)


@dataclass
class RequestFuncInput:
prompt: str
Expand Down Expand Up @@ -106,7 +112,7 @@ async def async_request_trt_llm(
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")

async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
async with _create_bench_client_session() as session:
payload = {
"accumulate_tokens": True,
"text_input": request_func_input.prompt,
Expand Down Expand Up @@ -179,7 +185,7 @@ async def async_request_openai_completions(

prompt = request_func_input.prompt

async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
async with _create_bench_client_session() as session:
payload = {
"model": request_func_input.model,
"prompt": prompt,
Expand Down Expand Up @@ -261,7 +267,7 @@ async def async_request_truss(

prompt = request_func_input.prompt

async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
async with _create_bench_client_session() as session:
payload = {
"model": request_func_input.model,
"prompt": prompt,
Expand Down Expand Up @@ -338,7 +344,7 @@ async def async_request_sglang_generate(
api_url = request_func_input.api_url
prompt = request_func_input.prompt

async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
async with _create_bench_client_session() as session:
payload = {
("text" if isinstance(prompt, str) else "input_ids"): prompt,
"sampling_params": {
Expand Down Expand Up @@ -437,7 +443,7 @@ async def async_request_gserver(


async def async_request_profile(api_url: str) -> RequestFuncOutput:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
async with _create_bench_client_session() as session:
output = RequestFuncOutput()
try:
async with session.post(url=api_url) as response:
Expand Down
Loading