diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index bdcad535a6a..de248f0f330 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -396,6 +396,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Keys that should stay on GPU in model_intermediate_buffer to avoid # CPU-to-GPU round-trips on every decode step. self.gpu_resident_buffer_keys: set[str] = { + "audio_codes", "last_talker_hidden", "tts_pad_embed", "tailing_text_hidden", @@ -650,7 +651,8 @@ def postprocess(self, hidden_states: torch.Tensor, **_: Any) -> dict[str, Any]: # Stays on GPU - gpu_resident_buffer_keys avoids the CPU round-trip. if hidden_states.numel() == 0: return {} - return {"last_talker_hidden": hidden_states[-1, :].detach()} + last = hidden_states[-1, :].detach() + return {"last_talker_hidden": last} # -------------------- prompt construction helpers -------------------- diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index d7d45031af4..0e25b5aa472 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -559,6 +559,29 @@ def propose_draft_token_ids(sampled_token_ids): hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output ) + # Pre-copy multimodal tensors to CPU once (not per-request) to avoid + # redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU. + mm_cpu: dict[str, object] = {} + if isinstance(multimodal_outputs, dict) and multimodal_outputs: + for k, v in multimodal_outputs.items(): + try: + if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: + mm_cpu[k] = v.detach().to("cpu").contiguous() + elif isinstance(v, dict): + sub_dict: dict[str, torch.Tensor] = {} + for sk, sv in v.items(): + if isinstance(sv, torch.Tensor) and sv.shape[0] == hidden_states_cpu.shape[0]: + sub_dict[str(sk)] = sv.detach().to("cpu").contiguous() + if sub_dict: + mm_cpu[k] = sub_dict + elif isinstance(v, list): + element = v[0] + if isinstance(element, torch.Tensor): + element = element.detach().to("cpu").contiguous() + mm_cpu[k] = element + except Exception as e: + logger.error(f"Error in merge multimodal outputs: {e}") + pooler_output: list[dict[str, object]] = [] for rid in req_ids_output_copy: idx = req_id_to_index_output_copy[rid] @@ -567,28 +590,20 @@ def propose_draft_token_ids(sampled_token_ids): end = start + sched hidden_slice = hidden_states_cpu[start:end] payload: dict[str, object] = {"hidden": hidden_slice} - if isinstance(multimodal_outputs, dict) and multimodal_outputs: + if mm_cpu: mm_payload: dict[str, object] = {} - for k, v in multimodal_outputs.items(): - try: - if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: - mm_payload[k] = v.detach().to("cpu")[start:end].contiguous() - elif isinstance(v, dict): - sub_dict: dict[str, torch.Tensor] = {} - for sk, sv in v.items(): - if isinstance(sv, torch.Tensor) and sv.shape[0] == hidden_states_cpu.shape[0]: - sub_dict[str(sk)] = sv.detach().to("cpu")[start:end].contiguous() - if sub_dict: - mm_payload[k] = sub_dict - elif isinstance(v, list): - element = v[0] - if isinstance(element, torch.Tensor): - element = element.detach().to("cpu").contiguous() - mm_payload[k] = element - except Exception as e: - logger.error(f"Error in merge multimodal outputs: {e}") - if mm_payload: - payload.update(mm_payload) + for k, v in mm_cpu.items(): + if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: + mm_payload[k] = v[start:end].contiguous() + elif isinstance(v, dict): + mm_payload[k] = {sk: sv[start:end].contiguous() for sk, sv in v.items()} + elif isinstance(v, torch.Tensor): + # List-derived tensor payloads are request-invariant; clone to + # avoid accidental cross-request aliasing on downstream mutation. + mm_payload[k] = v.clone() + else: + mm_payload[k] = v + payload.update(mm_payload) pooler_output.append(payload) with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): if self.model_config.enable_return_routed_experts: diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index e17d52bdd53..17acbe8b005 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1298,14 +1298,16 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc ): req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) - # update the inputs_embeds and code_predictor_codes - code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous() + # code_predictor_codes stays on GPU here; _update_intermediate_buffer + # keeps it device-resident when the key is in gpu_resident_buffer_keys. + # D2H is deferred to sample_tokens where hidden_states.to("cpu") already + # syncs the stream, avoiding a per-step cudaStreamSynchronize. out_key = getattr(self.model, "talker_mtp_output_key", "code_predictor_codes") for idx, req_id in enumerate(decode_req_ids): req_index = self.input_batch.req_ids.index(req_id) start_offset = int(self.query_start_loc.cpu[req_index]) inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] - update_dict = {out_key: code_predictor_codes_cpu[idx : idx + 1]} + update_dict = {out_key: code_predictor_codes[idx : idx + 1]} self._merge_additional_information_update(req_id, update_dict) def _model_forward(