Skip to content
174 changes: 99 additions & 75 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class ReqState:

# For streaming output
last_output_offset: int = 0
last_text_offset: int = 0

# For incremental state update.
# TODO(lianmin): do not initialize some lists if not needed.
Expand Down Expand Up @@ -1147,90 +1148,110 @@ async def _wait_one_response(
)
continue

# Drain all pending outputs atomically. For streaming, every
# chunk must be yielded to avoid dropping token deltas. For
# non-streaming only the latest cumulative output matters.
pending = state.out_list if is_stream else state.out_list[-1:]
# Drain all pending outputs atomically.
# With incremental streaming output, each chunk carries only a
# delta, so every queued chunk must be yielded to avoid dropping
# token ids. Without it, outputs are cumulative and only the
# latest chunk contains the full result, so we can safely skip
# intermediate ones.
incremental_stream = (
is_stream and self.server_args.incremental_streaming_output
)
out_list = state.out_list
state.out_list = []
finished = state.finished
state.event.clear()

for i, out in enumerate(pending):
is_last = i == len(pending) - 1

if finished and is_last:
# For non-streaming cases, response has not been sent yet (`response_sent_to_client_time` has not been set yet).
# Record response sent time right before we log finished results and metrics.
if not state.time_stats.response_sent_to_client_time:
state.time_stats.set_response_sent_to_client_time()
out["meta_info"][
"response_sent_to_client_ts"
] = state.time_stats.get_response_sent_to_client_realtime()
self.request_logger.log_finished_request(
obj,
out,
is_multimodal_gen=self.model_config.is_multimodal_gen,
request=request,
if incremental_stream and len(out_list) > 1:
if len(out_list) >= 20:
logger.warning(
"Streaming backlog: rid=%s, coalescing %d queued chunks into one. "
"This may inflate P99 ITL for affected requests.",
obj.rid,
len(out_list),
)
# Coalesce all deltas into a single chunk. Both text and
# output_ids are incremental, so we concatenate them; all
# other fields (meta_info, etc.) are taken from the last chunk.
out = dict(out_list[-1])
if "output_ids" in out:
out["output_ids"] = [
id for chunk in out_list for id in chunk["output_ids"]
]
if "text" in out:
out["text"] = "".join(chunk["text"] for chunk in out_list)
else:
out = out_list[-1]

if self.request_metrics_exporter_manager.exporter_enabled():
# Asynchronously write metrics for this request using the exporter manager.
asyncio.create_task(
self.request_metrics_exporter_manager.write_record(obj, out)
)
if finished:
# For non-streaming cases, response has not been sent yet (`response_sent_to_client_time` has not been set yet).
# Record response sent time right before we log finished results and metrics.
if not state.time_stats.response_sent_to_client_time:
state.time_stats.set_response_sent_to_client_time()
out["meta_info"][
"response_sent_to_client_ts"
] = state.time_stats.get_response_sent_to_client_realtime()
self.request_logger.log_finished_request(
obj,
out,
is_multimodal_gen=self.model_config.is_multimodal_gen,
request=request,
)

# Check if this was an abort/error created by scheduler
if isinstance(out["meta_info"].get("finish_reason"), dict):
finish_reason = out["meta_info"]["finish_reason"]
if (
finish_reason.get("type") == "abort"
and finish_reason.get("status_code")
== HTTPStatus.BAD_REQUEST
):
if not is_stream:
raise ValueError(finish_reason["message"])
else:
yield out
break

if finish_reason.get("type") == "abort" and finish_reason.get(
"status_code"
) in (
HTTPStatus.SERVICE_UNAVAILABLE,
HTTPStatus.INTERNAL_SERVER_ERROR,
):
# This is an abort request initiated by scheduler.
# Delete the key to prevent resending abort request to the scheduler and
# to ensure aborted request state is cleaned up.
if state.obj.rid in self.rid_to_state:
del self.rid_to_state[state.obj.rid]

# Mark ongoing LoRA request as finished.
if self.server_args.enable_lora and state.obj.lora_path:
await self.lora_registry.release(state.obj.lora_id)
if not is_stream:
raise fastapi.HTTPException(
status_code=finish_reason["status_code"],
detail=finish_reason["message"],
)
else:
yield out
break
yield out
break

if is_stream:
# Record response sent time right before we send response.
if not state.time_stats.response_sent_to_client_time:
state.time_stats.set_response_sent_to_client_time()
out["meta_info"][
"response_sent_to_client_ts"
] = state.time_stats.get_response_sent_to_client_realtime()
yield out
if self.request_metrics_exporter_manager.exporter_enabled():
# Asynchronously write metrics for this request using the exporter manager.
asyncio.create_task(
self.request_metrics_exporter_manager.write_record(obj, out)
)

if finished:
# Check if this was an abort/error created by scheduler
if isinstance(out["meta_info"].get("finish_reason"), dict):
finish_reason = out["meta_info"]["finish_reason"]
if (
finish_reason.get("type") == "abort"
and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
):
if not is_stream:
raise ValueError(finish_reason["message"])
else:
yield out
break

if finish_reason.get("type") == "abort" and finish_reason.get(
"status_code"
) in (
HTTPStatus.SERVICE_UNAVAILABLE,
HTTPStatus.INTERNAL_SERVER_ERROR,
):
# This is an abort request initiated by scheduler.
# Delete the key to prevent resending abort request to the scheduler and
# to ensure aborted request state is cleaned up.
if state.obj.rid in self.rid_to_state:
del self.rid_to_state[state.obj.rid]

# Mark ongoing LoRA request as finished.
if self.server_args.enable_lora and state.obj.lora_path:
await self.lora_registry.release(state.obj.lora_id)
if not is_stream:
raise fastapi.HTTPException(
status_code=finish_reason["status_code"],
detail=finish_reason["message"],
)
else:
yield out
break
yield out
break

if is_stream:
# Record response sent time right before we send response.
if not state.time_stats.response_sent_to_client_time:
state.time_stats.set_response_sent_to_client_time()
out["meta_info"][
"response_sent_to_client_ts"
] = state.time_stats.get_response_sent_to_client_realtime()
yield out

if not is_stream:
if (
request is not None
Expand Down Expand Up @@ -1589,12 +1610,15 @@ def _handle_batch_output(
state.output_ids.extend(recv_obj.output_ids[i])
output_token_ids = state.output_ids[state.last_output_offset :]
state.last_output_offset = len(state.output_ids)
output_text = state.text[state.last_text_offset :]
state.last_text_offset = len(state.text)
else:
state.output_ids.extend(recv_obj.output_ids[i])
output_token_ids = state.output_ids.copy()
output_text = state.text

out_dict = {
"text": state.text,
"text": output_text,
"output_ids": output_token_ids,
"meta_info": meta_info,
}
Expand Down
2 changes: 1 addition & 1 deletion test/registered/spec/eagle/test_eagle_infer_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_gsm8k(self):
if speculative_eagle_topk == 1:
self.assertGreater(avg_spec_accept_length, 2.5)
else:
self.assertGreater(avg_spec_accept_length, 3.49)
self.assertGreater(avg_spec_accept_length, 3.47)

# Wait a little bit so that the memory check happens.
time.sleep(4)
Expand Down
Loading