Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tests/dfx/perf/tests/test_qwen3_omni_async_chunk.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
"dataset_name": "random",
"backend": "openai-chat-omni",
"endpoint": "/v1/chat/completions",
"num_prompts": [4, 16, 32, 64],
"max_concurrency": [1, 4, 8, 16],
"num_prompts": [4, 16, 32, 64, 128],
"max_concurrency": [1, 4, 8, 16, 32],
"random_input_len": 2500,
"random_output_len": 900,
"ignore_eos": true,
"percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
"baseline": {
"mean_ttft_ms": [1000, 3000, 5000, 7000],
"mean_audio_ttfp_ms": [1000, 3000, 5000, 7000],
"mean_audio_rtf": [0.2, 0.35, 0.6, 0.85]
"mean_ttft_ms": [1000, 3000, 5000, 7000, 9000],
"mean_audio_ttfp_ms": [1000, 3000, 5000, 7000, 9000],
"mean_audio_rtf": [0.2, 0.35, 0.6, 0.85, 0.9]
}
},
{
Expand Down
9 changes: 7 additions & 2 deletions tests/e2e/online_serving/test_qwen3_omni_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_batch_token_config(default_path):
return modify_stage_config(
default_path,
updates={
"stages": {1: {"max_num_batched_tokens": 64}},
"stages": {0: {"max_num_batched_tokens": 64}, 1: {"max_num_batched_tokens": 64}},
},
)

Expand Down Expand Up @@ -95,7 +95,12 @@ def get_default_config(default_path):

test_token_params = [
pytest.param(
OmniServerParams(model=model, stage_config_path=get_batch_token_config(default_path), use_stage_cli=True),
OmniServerParams(
model=model,
stage_config_path=get_batch_token_config(default_path),
use_stage_cli=True,
server_args=["--async-chunk"],
),
id="batch_token_64",
)
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(self, vllm_config: Any):
self.waiting_for_chunk_running_requests: deque[Any] = deque()
self.requests_with_ready_chunks = set()
self.requests_origin_status = {}
self.requests_num_chunks_sent: dict[str, int] = defaultdict(int)

@classmethod
def create_connector(cls, model_config: Any):
Expand Down Expand Up @@ -117,6 +118,17 @@ def save_async(
pooling_output: Partial pooling output dictionary
request: Request object
"""

# If the request is preempted, skip the already saved chunks.
if request.num_computed_tokens < self.requests_num_chunks_sent.get(request.external_req_id, 0):
logger.warning(
f"Enqueue save_async for request {request.external_req_id}, "
f"request.num_computed_tokens={request.num_computed_tokens}, "
f"previous_chunks_sent={self.requests_num_chunks_sent.get(request.external_req_id, 0)}"
)
return

self.requests_num_chunks_sent[request.external_req_id] = request.num_computed_tokens
task = {
"pooling_output": pooling_output,
"request": request,
Expand Down Expand Up @@ -155,8 +167,7 @@ def _poll_single_request(self, request: Request):

meta = payload_data.get("meta", {})
if self.model_mode == "ar":
merged_payload = self._update_request_payload(external_req_id, payload_data)
request.additional_information = merged_payload
request.additional_information = payload_data
if meta.get("finished"):
self.finished_requests.add(req_id)
else:
Expand Down Expand Up @@ -198,42 +209,6 @@ def _poll_single_request(self, request: Request):

return False

def _update_request_payload(self, req_id: str, payload_data: dict[str, Any]) -> dict[str, Any]:
"""Update the stored payload for *req_id* with the latest chunk."""
if req_id not in self.request_payload:
self.request_payload[req_id] = payload_data
return payload_data
origin = self.request_payload[req_id]
raw_ok = payload_data.get("meta", {}).pop("override_keys", [])
override_keys = {tuple(k) if isinstance(k, list) else k for k in raw_ok}

for key, value in payload_data.items():
if isinstance(value, dict):
origin_sub = origin.get(key)
if not isinstance(origin_sub, dict):
continue
for qual, qval in value.items():
if key == "meta" and qual == "finished":
continue
if (key, qual) in override_keys:
continue
osv = origin_sub.get(qual)
if isinstance(qval, torch.Tensor) and isinstance(osv, torch.Tensor):
value[qual] = torch.cat([osv, qval], dim=0)
elif isinstance(qval, list) and isinstance(osv, list):
value[qual] = osv + qval
else:
if key in override_keys:
continue
ov = origin.get(key)
if isinstance(value, torch.Tensor) and isinstance(ov, torch.Tensor):
payload_data[key] = torch.cat([ov, value], dim=0)
elif isinstance(value, list) and isinstance(ov, list):
payload_data[key] = ov + value

self.request_payload[req_id] = payload_data
return payload_data

def _send_single_request(self, task: dict):
raw_po = task["pooling_output"]
pooling_output = unflatten_payload(raw_po) if isinstance(raw_po, dict) else raw_po
Expand Down Expand Up @@ -290,6 +265,7 @@ def _send_single_request(self, task: dict):

if is_finished:
self.code_prompt_token_ids.pop(external_req_id, None)
self.requests_num_chunks_sent.pop(external_req_id, None)
cached_ic = getattr(self, "_cached_ic", None)
if cached_ic is not None:
cached_ic.pop(external_req_id, None)
Expand Down Expand Up @@ -327,6 +303,7 @@ def cleanup_sender(self, external_req_id: str) -> None:
self.put_req_chunk.pop(external_req_id, None)
self.request_payload.pop(external_req_id, None)
self.code_prompt_token_ids.pop(external_req_id, None)
self.requests_num_chunks_sent.pop(external_req_id, None)

cached_ic = getattr(self, "_cached_ic", None)
if cached_ic is not None:
Expand Down Expand Up @@ -399,6 +376,11 @@ def postprocess_scheduler_output(
Add additional info for cached requests and
clean up ready chunks from scheduler output.
"""
stage_id = self.connector.stage_id

if stage_id == 0:
return

if requests is not None:
self.attach_cached_additional_information(scheduler_output, requests)
self._clear_chunk_ready(scheduler_output)
Expand All @@ -414,6 +396,8 @@ def attach_cached_additional_information(scheduler_output: Any, requests: dict[s
request = requests.get(req_id) if req_id else None
additional_info = getattr(request, "additional_information", None) if request else None
cached_reqs.additional_information[req_id] = additional_info
if request and additional_info:
request.additional_information = None

def _process_chunk_queue(
self,
Expand Down
21 changes: 11 additions & 10 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,19 +982,10 @@ def _thinker_decode_to_talker_decode(
"""
embed = payload.get("embed", {})
meta = payload.get("meta", {})
ids = payload.get("ids", {})

cached_thinker_decode_embeds = embed.get("cached_decode", None)
thinker_decode_embed = embed.get("decode", None)
start_index = meta.get("num_processed_tokens", 0)
thinker_output_token_ids = ids.get("output", [])
if start_index >= len(thinker_output_token_ids) - 1:
# When the tokens output by the thinker are exhausted, an EOS token needs to be appended.
# Use the finished_flag to mark that all tokens output by thinker have been consumed.
if meta.get("eos_emitted", False):
return self.tts_pad_embed.to(device)
update_dict.setdefault("meta", {})["eos_emitted"] = True
return self.tts_eos_embed.to(device)

if cached_thinker_decode_embeds is not None and start_index < cached_thinker_decode_embeds.shape[0]:
cached_thinker_decode_embeds = cached_thinker_decode_embeds.to(device)
Expand All @@ -1003,10 +994,20 @@ def _thinker_decode_to_talker_decode(
thinker_decode_embed = thinker_decode_embed.to(device)
cached_thinker_decode_embeds = torch.cat([cached_thinker_decode_embeds, thinker_decode_embed], dim=0)
update_dict.setdefault("embed", {})["cached_decode"] = cached_thinker_decode_embeds
else:

elif thinker_decode_embed is not None:
thinker_embed = thinker_decode_embed
if thinker_embed.device != device:
thinker_embed = thinker_embed.to(device)

else:
# When the tokens output by the thinker are exhausted, an EOS token needs to be appended.
# Use the finished_flag to mark that all tokens output by thinker have been consumed.
if meta.get("eos_emitted", False):
return self.tts_pad_embed.to(device)
update_dict.setdefault("meta", {})["eos_emitted"] = True
return self.tts_eos_embed.to(device)

update_dict.setdefault("embed", {})["decode"] = None
return self.talker.text_projection(thinker_embed).to(device)

Expand Down
39 changes: 20 additions & 19 deletions vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,27 +349,28 @@ def _maybe_cpu(t: Any) -> torch.Tensor | None:
payload.hidden_states.output = torch.cat(
(save_payload.get("hidden_states", {}).get("output"), payload.hidden_states.output), dim=0
)
prefill_shape = payload.embed.prefill.shape[0]
if not is_finished and prefill_shape <= len(prompt_token_ids):
transfer_manager.request_payload[request_id] = to_dict(payload)
return None
else:
output_token_ids = _ensure_list(request.output_token_ids)
meta = MetaStruct(finished=torch.tensor(is_finished, dtype=torch.bool))
if output_token_ids:
meta.override_keys = [("embed", "decode"), ("ids", "output")]
payload = OmniPayloadStruct(
meta=meta,
embed=EmbeddingsStruct(decode=thinker_emb.detach().cpu()),
ids=IdsStruct(output=output_token_ids),
speaker=speaker,
language=language,
)
else:
# When prefilling a chunked thinker, thinker_hidden_states needs to be updated.
payload = OmniPayloadStruct(
meta=meta,
embed=EmbeddingsStruct(prefill=thinker_emb.detach().cpu()),
hidden_states=HiddenStatesStruct(output=thinker_hid.detach().cpu()),
speaker=speaker,
language=language,
if thinker_emb.shape[0] > 1:
logger.warning(
"Unexpected multiple embeddings in thinker2talker_async_chunk for chunk_id %d: "
"request_id %s, num_computed_tokens%d %s. Expected shape [1, D].",
chunk_id,
request_id,
request.num_computed_tokens,
thinker_emb.shape,
)
return None
meta = MetaStruct(finished=torch.tensor(is_finished, dtype=torch.bool))
payload = OmniPayloadStruct(
meta=meta,
embed=EmbeddingsStruct(decode=thinker_emb.detach().cpu()),
speaker=speaker,
language=language,
)
return payload


Expand Down
Loading