diff --git a/tests/dfx/perf/tests/test_qwen3_omni_async_chunk.json b/tests/dfx/perf/tests/test_qwen3_omni_async_chunk.json index 4f73f2b6a96..98e31174817 100644 --- a/tests/dfx/perf/tests/test_qwen3_omni_async_chunk.json +++ b/tests/dfx/perf/tests/test_qwen3_omni_async_chunk.json @@ -10,16 +10,16 @@ "dataset_name": "random", "backend": "openai-chat-omni", "endpoint": "/v1/chat/completions", - "num_prompts": [4, 16, 32, 64], - "max_concurrency": [1, 4, 8, 16], + "num_prompts": [4, 16, 32, 64, 128], + "max_concurrency": [1, 4, 8, 16, 32], "random_input_len": 2500, "random_output_len": 900, "ignore_eos": true, "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration", "baseline": { - "mean_ttft_ms": [1000, 3000, 5000, 7000], - "mean_audio_ttfp_ms": [1000, 3000, 5000, 7000], - "mean_audio_rtf": [0.2, 0.35, 0.6, 0.85] + "mean_ttft_ms": [1000, 3000, 5000, 7000, 9000], + "mean_audio_ttfp_ms": [1000, 3000, 5000, 7000, 9000], + "mean_audio_rtf": [0.2, 0.35, 0.6, 0.85, 0.9] } }, { diff --git a/tests/e2e/online_serving/test_qwen3_omni_expansion.py b/tests/e2e/online_serving/test_qwen3_omni_expansion.py index aeaf27b31df..bf022dd306e 100644 --- a/tests/e2e/online_serving/test_qwen3_omni_expansion.py +++ b/tests/e2e/online_serving/test_qwen3_omni_expansion.py @@ -44,7 +44,7 @@ def get_batch_token_config(default_path): return modify_stage_config( default_path, updates={ - "stages": {1: {"max_num_batched_tokens": 64}}, + "stages": {0: {"max_num_batched_tokens": 64}, 1: {"max_num_batched_tokens": 64}}, }, ) @@ -95,7 +95,12 @@ def get_default_config(default_path): test_token_params = [ pytest.param( - OmniServerParams(model=model, stage_config_path=get_batch_token_config(default_path), use_stage_cli=True), + OmniServerParams( + model=model, + stage_config_path=get_batch_token_config(default_path), + use_stage_cli=True, + server_args=["--async-chunk"], + ), id="batch_token_64", ) ] diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index 2bdb1136976..bc840c739bf 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -62,6 +62,7 @@ def __init__(self, vllm_config: Any): self.waiting_for_chunk_running_requests: deque[Any] = deque() self.requests_with_ready_chunks = set() self.requests_origin_status = {} + self.requests_num_chunks_sent: dict[str, int] = defaultdict(int) @classmethod def create_connector(cls, model_config: Any): @@ -117,6 +118,17 @@ def save_async( pooling_output: Partial pooling output dictionary request: Request object """ + + # If the request is preempted, skip the already saved chunks. + if request.num_computed_tokens < self.requests_num_chunks_sent.get(request.external_req_id, 0): + logger.warning( + f"Enqueue save_async for request {request.external_req_id}, " + f"request.num_computed_tokens={request.num_computed_tokens}, " + f"previous_chunks_sent={self.requests_num_chunks_sent.get(request.external_req_id, 0)}" + ) + return + + self.requests_num_chunks_sent[request.external_req_id] = request.num_computed_tokens task = { "pooling_output": pooling_output, "request": request, @@ -155,8 +167,7 @@ def _poll_single_request(self, request: Request): meta = payload_data.get("meta", {}) if self.model_mode == "ar": - merged_payload = self._update_request_payload(external_req_id, payload_data) - request.additional_information = merged_payload + request.additional_information = payload_data if meta.get("finished"): self.finished_requests.add(req_id) else: @@ -198,42 +209,6 @@ def _poll_single_request(self, request: Request): return False - def _update_request_payload(self, req_id: str, payload_data: dict[str, Any]) -> dict[str, Any]: - """Update the stored payload for *req_id* with the latest chunk.""" - if req_id not in self.request_payload: - self.request_payload[req_id] = payload_data - return payload_data - origin = self.request_payload[req_id] - raw_ok = payload_data.get("meta", {}).pop("override_keys", []) - override_keys = {tuple(k) if isinstance(k, list) else k for k in raw_ok} - - for key, value in payload_data.items(): - if isinstance(value, dict): - origin_sub = origin.get(key) - if not isinstance(origin_sub, dict): - continue - for qual, qval in value.items(): - if key == "meta" and qual == "finished": - continue - if (key, qual) in override_keys: - continue - osv = origin_sub.get(qual) - if isinstance(qval, torch.Tensor) and isinstance(osv, torch.Tensor): - value[qual] = torch.cat([osv, qval], dim=0) - elif isinstance(qval, list) and isinstance(osv, list): - value[qual] = osv + qval - else: - if key in override_keys: - continue - ov = origin.get(key) - if isinstance(value, torch.Tensor) and isinstance(ov, torch.Tensor): - payload_data[key] = torch.cat([ov, value], dim=0) - elif isinstance(value, list) and isinstance(ov, list): - payload_data[key] = ov + value - - self.request_payload[req_id] = payload_data - return payload_data - def _send_single_request(self, task: dict): raw_po = task["pooling_output"] pooling_output = unflatten_payload(raw_po) if isinstance(raw_po, dict) else raw_po @@ -290,6 +265,7 @@ def _send_single_request(self, task: dict): if is_finished: self.code_prompt_token_ids.pop(external_req_id, None) + self.requests_num_chunks_sent.pop(external_req_id, None) cached_ic = getattr(self, "_cached_ic", None) if cached_ic is not None: cached_ic.pop(external_req_id, None) @@ -327,6 +303,7 @@ def cleanup_sender(self, external_req_id: str) -> None: self.put_req_chunk.pop(external_req_id, None) self.request_payload.pop(external_req_id, None) self.code_prompt_token_ids.pop(external_req_id, None) + self.requests_num_chunks_sent.pop(external_req_id, None) cached_ic = getattr(self, "_cached_ic", None) if cached_ic is not None: @@ -399,6 +376,11 @@ def postprocess_scheduler_output( Add additional info for cached requests and clean up ready chunks from scheduler output. """ + stage_id = self.connector.stage_id + + if stage_id == 0: + return + if requests is not None: self.attach_cached_additional_information(scheduler_output, requests) self._clear_chunk_ready(scheduler_output) @@ -414,6 +396,8 @@ def attach_cached_additional_information(scheduler_output: Any, requests: dict[s request = requests.get(req_id) if req_id else None additional_info = getattr(request, "additional_information", None) if request else None cached_reqs.additional_information[req_id] = additional_info + if request and additional_info: + request.additional_information = None def _process_chunk_queue( self, diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index d7765026524..c0c6cdfbddb 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -982,19 +982,10 @@ def _thinker_decode_to_talker_decode( """ embed = payload.get("embed", {}) meta = payload.get("meta", {}) - ids = payload.get("ids", {}) cached_thinker_decode_embeds = embed.get("cached_decode", None) thinker_decode_embed = embed.get("decode", None) start_index = meta.get("num_processed_tokens", 0) - thinker_output_token_ids = ids.get("output", []) - if start_index >= len(thinker_output_token_ids) - 1: - # When the tokens output by the thinker are exhausted, an EOS token needs to be appended. - # Use the finished_flag to mark that all tokens output by thinker have been consumed. - if meta.get("eos_emitted", False): - return self.tts_pad_embed.to(device) - update_dict.setdefault("meta", {})["eos_emitted"] = True - return self.tts_eos_embed.to(device) if cached_thinker_decode_embeds is not None and start_index < cached_thinker_decode_embeds.shape[0]: cached_thinker_decode_embeds = cached_thinker_decode_embeds.to(device) @@ -1003,10 +994,20 @@ def _thinker_decode_to_talker_decode( thinker_decode_embed = thinker_decode_embed.to(device) cached_thinker_decode_embeds = torch.cat([cached_thinker_decode_embeds, thinker_decode_embed], dim=0) update_dict.setdefault("embed", {})["cached_decode"] = cached_thinker_decode_embeds - else: + + elif thinker_decode_embed is not None: thinker_embed = thinker_decode_embed if thinker_embed.device != device: thinker_embed = thinker_embed.to(device) + + else: + # When the tokens output by the thinker are exhausted, an EOS token needs to be appended. + # Use the finished_flag to mark that all tokens output by thinker have been consumed. + if meta.get("eos_emitted", False): + return self.tts_pad_embed.to(device) + update_dict.setdefault("meta", {})["eos_emitted"] = True + return self.tts_eos_embed.to(device) + update_dict.setdefault("embed", {})["decode"] = None return self.talker.text_projection(thinker_embed).to(device) diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 63403619e9b..1b1dc0f7740 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -349,27 +349,28 @@ def _maybe_cpu(t: Any) -> torch.Tensor | None: payload.hidden_states.output = torch.cat( (save_payload.get("hidden_states", {}).get("output"), payload.hidden_states.output), dim=0 ) + prefill_shape = payload.embed.prefill.shape[0] + if not is_finished and prefill_shape <= len(prompt_token_ids): + transfer_manager.request_payload[request_id] = to_dict(payload) + return None else: - output_token_ids = _ensure_list(request.output_token_ids) - meta = MetaStruct(finished=torch.tensor(is_finished, dtype=torch.bool)) - if output_token_ids: - meta.override_keys = [("embed", "decode"), ("ids", "output")] - payload = OmniPayloadStruct( - meta=meta, - embed=EmbeddingsStruct(decode=thinker_emb.detach().cpu()), - ids=IdsStruct(output=output_token_ids), - speaker=speaker, - language=language, - ) - else: - # When prefilling a chunked thinker, thinker_hidden_states needs to be updated. - payload = OmniPayloadStruct( - meta=meta, - embed=EmbeddingsStruct(prefill=thinker_emb.detach().cpu()), - hidden_states=HiddenStatesStruct(output=thinker_hid.detach().cpu()), - speaker=speaker, - language=language, + if thinker_emb.shape[0] > 1: + logger.warning( + "Unexpected multiple embeddings in thinker2talker_async_chunk for chunk_id %d: " + "request_id %s, num_computed_tokens%d %s. Expected shape [1, D].", + chunk_id, + request_id, + request.num_computed_tokens, + thinker_emb.shape, ) + return None + meta = MetaStruct(finished=torch.tensor(is_finished, dtype=torch.bool)) + payload = OmniPayloadStruct( + meta=meta, + embed=EmbeddingsStruct(decode=thinker_emb.detach().cpu()), + speaker=speaker, + language=language, + ) return payload