diff --git a/tests/dfx/perf/tests/test_qwen3_omni_async_chunk.json b/tests/dfx/perf/tests/test_qwen3_omni_async_chunk.json index 98e31174817..4f73f2b6a96 100644 --- a/tests/dfx/perf/tests/test_qwen3_omni_async_chunk.json +++ b/tests/dfx/perf/tests/test_qwen3_omni_async_chunk.json @@ -10,16 +10,16 @@ "dataset_name": "random", "backend": "openai-chat-omni", "endpoint": "/v1/chat/completions", - "num_prompts": [4, 16, 32, 64, 128], - "max_concurrency": [1, 4, 8, 16, 32], + "num_prompts": [4, 16, 32, 64], + "max_concurrency": [1, 4, 8, 16], "random_input_len": 2500, "random_output_len": 900, "ignore_eos": true, "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration", "baseline": { - "mean_ttft_ms": [1000, 3000, 5000, 7000, 9000], - "mean_audio_ttfp_ms": [1000, 3000, 5000, 7000, 9000], - "mean_audio_rtf": [0.2, 0.35, 0.6, 0.85, 0.9] + "mean_ttft_ms": [1000, 3000, 5000, 7000], + "mean_audio_ttfp_ms": [1000, 3000, 5000, 7000], + "mean_audio_rtf": [0.2, 0.35, 0.6, 0.85] } }, { diff --git a/tests/e2e/online_serving/test_qwen3_omni_expansion.py b/tests/e2e/online_serving/test_qwen3_omni_expansion.py index bf022dd306e..aeaf27b31df 100644 --- a/tests/e2e/online_serving/test_qwen3_omni_expansion.py +++ b/tests/e2e/online_serving/test_qwen3_omni_expansion.py @@ -44,7 +44,7 @@ def get_batch_token_config(default_path): return modify_stage_config( default_path, updates={ - "stages": {0: {"max_num_batched_tokens": 64}, 1: {"max_num_batched_tokens": 64}}, + "stages": {1: {"max_num_batched_tokens": 64}}, }, ) @@ -95,12 +95,7 @@ def get_default_config(default_path): test_token_params = [ pytest.param( - OmniServerParams( - model=model, - stage_config_path=get_batch_token_config(default_path), - use_stage_cli=True, - server_args=["--async-chunk"], - ), + OmniServerParams(model=model, stage_config_path=get_batch_token_config(default_path), use_stage_cli=True), id="batch_token_64", ) ] diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index bc840c739bf..2bdb1136976 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -62,7 +62,6 @@ def __init__(self, vllm_config: Any): self.waiting_for_chunk_running_requests: deque[Any] = deque() self.requests_with_ready_chunks = set() self.requests_origin_status = {} - self.requests_num_chunks_sent: dict[str, int] = defaultdict(int) @classmethod def create_connector(cls, model_config: Any): @@ -118,17 +117,6 @@ def save_async( pooling_output: Partial pooling output dictionary request: Request object """ - - # If the request is preempted, skip the already saved chunks. - if request.num_computed_tokens < self.requests_num_chunks_sent.get(request.external_req_id, 0): - logger.warning( - f"Enqueue save_async for request {request.external_req_id}, " - f"request.num_computed_tokens={request.num_computed_tokens}, " - f"previous_chunks_sent={self.requests_num_chunks_sent.get(request.external_req_id, 0)}" - ) - return - - self.requests_num_chunks_sent[request.external_req_id] = request.num_computed_tokens task = { "pooling_output": pooling_output, "request": request, @@ -167,7 +155,8 @@ def _poll_single_request(self, request: Request): meta = payload_data.get("meta", {}) if self.model_mode == "ar": - request.additional_information = payload_data + merged_payload = self._update_request_payload(external_req_id, payload_data) + request.additional_information = merged_payload if meta.get("finished"): self.finished_requests.add(req_id) else: @@ -209,6 +198,42 @@ def _poll_single_request(self, request: Request): return False + def _update_request_payload(self, req_id: str, payload_data: dict[str, Any]) -> dict[str, Any]: + """Update the stored payload for *req_id* with the latest chunk.""" + if req_id not in self.request_payload: + self.request_payload[req_id] = payload_data + return payload_data + origin = self.request_payload[req_id] + raw_ok = payload_data.get("meta", {}).pop("override_keys", []) + override_keys = {tuple(k) if isinstance(k, list) else k for k in raw_ok} + + for key, value in payload_data.items(): + if isinstance(value, dict): + origin_sub = origin.get(key) + if not isinstance(origin_sub, dict): + continue + for qual, qval in value.items(): + if key == "meta" and qual == "finished": + continue + if (key, qual) in override_keys: + continue + osv = origin_sub.get(qual) + if isinstance(qval, torch.Tensor) and isinstance(osv, torch.Tensor): + value[qual] = torch.cat([osv, qval], dim=0) + elif isinstance(qval, list) and isinstance(osv, list): + value[qual] = osv + qval + else: + if key in override_keys: + continue + ov = origin.get(key) + if isinstance(value, torch.Tensor) and isinstance(ov, torch.Tensor): + payload_data[key] = torch.cat([ov, value], dim=0) + elif isinstance(value, list) and isinstance(ov, list): + payload_data[key] = ov + value + + self.request_payload[req_id] = payload_data + return payload_data + def _send_single_request(self, task: dict): raw_po = task["pooling_output"] pooling_output = unflatten_payload(raw_po) if isinstance(raw_po, dict) else raw_po @@ -265,7 +290,6 @@ def _send_single_request(self, task: dict): if is_finished: self.code_prompt_token_ids.pop(external_req_id, None) - self.requests_num_chunks_sent.pop(external_req_id, None) cached_ic = getattr(self, "_cached_ic", None) if cached_ic is not None: cached_ic.pop(external_req_id, None) @@ -303,7 +327,6 @@ def cleanup_sender(self, external_req_id: str) -> None: self.put_req_chunk.pop(external_req_id, None) self.request_payload.pop(external_req_id, None) self.code_prompt_token_ids.pop(external_req_id, None) - self.requests_num_chunks_sent.pop(external_req_id, None) cached_ic = getattr(self, "_cached_ic", None) if cached_ic is not None: @@ -376,11 +399,6 @@ def postprocess_scheduler_output( Add additional info for cached requests and clean up ready chunks from scheduler output. """ - stage_id = self.connector.stage_id - - if stage_id == 0: - return - if requests is not None: self.attach_cached_additional_information(scheduler_output, requests) self._clear_chunk_ready(scheduler_output) @@ -396,8 +414,6 @@ def attach_cached_additional_information(scheduler_output: Any, requests: dict[s request = requests.get(req_id) if req_id else None additional_info = getattr(request, "additional_information", None) if request else None cached_reqs.additional_information[req_id] = additional_info - if request and additional_info: - request.additional_information = None def _process_chunk_queue( self, diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 3c658fc3f29..0bdca2fd297 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -990,10 +990,19 @@ def _thinker_decode_to_talker_decode( """ embed = payload.get("embed", {}) meta = payload.get("meta", {}) + ids = payload.get("ids", {}) cached_thinker_decode_embeds = embed.get("cached_decode", None) thinker_decode_embed = embed.get("decode", None) start_index = meta.get("num_processed_tokens", 0) + thinker_output_token_ids = ids.get("output", []) + if start_index >= len(thinker_output_token_ids) - 1: + # When the tokens output by the thinker are exhausted, an EOS token needs to be appended. + # Use the finished_flag to mark that all tokens output by thinker have been consumed. + if meta.get("eos_emitted", False): + return self.tts_pad_embed.to(device) + update_dict.setdefault("meta", {})["eos_emitted"] = True + return self.tts_eos_embed.to(device) if cached_thinker_decode_embeds is not None and start_index < cached_thinker_decode_embeds.shape[0]: cached_thinker_decode_embeds = cached_thinker_decode_embeds.to(device) @@ -1002,20 +1011,10 @@ def _thinker_decode_to_talker_decode( thinker_decode_embed = thinker_decode_embed.to(device) cached_thinker_decode_embeds = torch.cat([cached_thinker_decode_embeds, thinker_decode_embed], dim=0) update_dict.setdefault("embed", {})["cached_decode"] = cached_thinker_decode_embeds - - elif thinker_decode_embed is not None: + else: thinker_embed = thinker_decode_embed if thinker_embed.device != device: thinker_embed = thinker_embed.to(device) - - else: - # When the tokens output by the thinker are exhausted, an EOS token needs to be appended. - # Use the finished_flag to mark that all tokens output by thinker have been consumed. - if meta.get("eos_emitted", False): - return self.tts_pad_embed.to(device) - update_dict.setdefault("meta", {})["eos_emitted"] = True - return self.tts_eos_embed.to(device) - update_dict.setdefault("embed", {})["decode"] = None return self.talker.text_projection(thinker_embed).to(device) diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index fd7cfd2aa60..b1672612bf3 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -431,28 +431,27 @@ def _maybe_cpu(t: Any) -> torch.Tensor | None: payload.hidden_states.output = torch.cat( (save_payload.get("hidden_states", {}).get("output"), payload.hidden_states.output), dim=0 ) - prefill_shape = payload.embed.prefill.shape[0] - if not is_finished and prefill_shape <= len(prompt_token_ids): - transfer_manager.request_payload[request_id] = to_dict(payload) - return None else: - if thinker_emb.shape[0] > 1: - logger.warning( - "Unexpected multiple embeddings in thinker2talker_async_chunk for chunk_id %d: " - "request_id %s, num_computed_tokens%d %s. Expected shape [1, D].", - chunk_id, - request_id, - request.num_computed_tokens, - thinker_emb.shape, - ) - return None + output_token_ids = _ensure_list(request.output_token_ids) meta = MetaStruct(finished=torch.tensor(is_finished, dtype=torch.bool)) - payload = OmniPayloadStruct( - meta=meta, - embed=EmbeddingsStruct(decode=thinker_emb.detach().cpu()), - speaker=speaker, - language=language, - ) + if output_token_ids: + meta.override_keys = [("embed", "decode"), ("ids", "output")] + payload = OmniPayloadStruct( + meta=meta, + embed=EmbeddingsStruct(decode=thinker_emb.detach().cpu()), + ids=IdsStruct(output=output_token_ids), + speaker=speaker, + language=language, + ) + else: + # When prefilling a chunked thinker, thinker_hidden_states needs to be updated. + payload = OmniPayloadStruct( + meta=meta, + embed=EmbeddingsStruct(prefill=thinker_emb.detach().cpu()), + hidden_states=HiddenStatesStruct(output=thinker_hid.detach().cpu()), + speaker=speaker, + language=language, + ) return payload