Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Keys that should stay on GPU in model_intermediate_buffer to avoid
# CPU-to-GPU round-trips on every decode step.
self.gpu_resident_buffer_keys: set[str] = {
"audio_codes",
"last_talker_hidden",
"tts_pad_embed",
"tailing_text_hidden",
Expand Down Expand Up @@ -650,7 +651,8 @@ def postprocess(self, hidden_states: torch.Tensor, **_: Any) -> dict[str, Any]:
# Stays on GPU - gpu_resident_buffer_keys avoids the CPU round-trip.
if hidden_states.numel() == 0:
return {}
return {"last_talker_hidden": hidden_states[-1, :].detach()}
last = hidden_states[-1, :].detach()
return {"last_talker_hidden": last}

# -------------------- prompt construction helpers --------------------

Expand Down
57 changes: 36 additions & 21 deletions vllm_omni/worker/gpu_ar_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,29 @@ def propose_draft_token_ids(sampled_token_ids):
hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output
)

# Pre-copy multimodal tensors to CPU once (not per-request) to avoid
# redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU.
mm_cpu: dict[str, object] = {}
if isinstance(multimodal_outputs, dict) and multimodal_outputs:
for k, v in multimodal_outputs.items():
try:
if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]:
mm_cpu[k] = v.detach().to("cpu").contiguous()
elif isinstance(v, dict):
sub_dict: dict[str, torch.Tensor] = {}
for sk, sv in v.items():
if isinstance(sv, torch.Tensor) and sv.shape[0] == hidden_states_cpu.shape[0]:
sub_dict[str(sk)] = sv.detach().to("cpu").contiguous()
if sub_dict:
mm_cpu[k] = sub_dict
elif isinstance(v, list):
element = v[0]
if isinstance(element, torch.Tensor):
element = element.detach().to("cpu").contiguous()
mm_cpu[k] = element
except Exception as e:
logger.error(f"Error in merge multimodal outputs: {e}")

pooler_output: list[dict[str, object]] = []
for rid in req_ids_output_copy:
idx = req_id_to_index_output_copy[rid]
Expand All @@ -567,28 +590,20 @@ def propose_draft_token_ids(sampled_token_ids):
end = start + sched
hidden_slice = hidden_states_cpu[start:end]
payload: dict[str, object] = {"hidden": hidden_slice}
if isinstance(multimodal_outputs, dict) and multimodal_outputs:
if mm_cpu:
mm_payload: dict[str, object] = {}
for k, v in multimodal_outputs.items():
try:
if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]:
mm_payload[k] = v.detach().to("cpu")[start:end].contiguous()
elif isinstance(v, dict):
sub_dict: dict[str, torch.Tensor] = {}
for sk, sv in v.items():
if isinstance(sv, torch.Tensor) and sv.shape[0] == hidden_states_cpu.shape[0]:
sub_dict[str(sk)] = sv.detach().to("cpu")[start:end].contiguous()
if sub_dict:
mm_payload[k] = sub_dict
elif isinstance(v, list):
element = v[0]
if isinstance(element, torch.Tensor):
element = element.detach().to("cpu").contiguous()
mm_payload[k] = element
except Exception as e:
logger.error(f"Error in merge multimodal outputs: {e}")
if mm_payload:
payload.update(mm_payload)
for k, v in mm_cpu.items():
if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]:
mm_payload[k] = v[start:end].contiguous()
elif isinstance(v, dict):
mm_payload[k] = {sk: sv[start:end].contiguous() for sk, sv in v.items()}
elif isinstance(v, torch.Tensor):
# List-derived tensor payloads are request-invariant; clone to
# avoid accidental cross-request aliasing on downstream mutation.
mm_payload[k] = v.clone()
else:
mm_payload[k] = v
payload.update(mm_payload)
pooler_output.append(payload)
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
if self.model_config.enable_return_routed_experts:
Expand Down
8 changes: 5 additions & 3 deletions vllm_omni/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,14 +1298,16 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te
None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc
):
req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step)
# update the inputs_embeds and code_predictor_codes
code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous()
# code_predictor_codes stays on GPU here; _update_intermediate_buffer
# keeps it device-resident when the key is in gpu_resident_buffer_keys.
# D2H is deferred to sample_tokens where hidden_states.to("cpu") already
# syncs the stream, avoiding a per-step cudaStreamSynchronize.
out_key = getattr(self.model, "talker_mtp_output_key", "code_predictor_codes")
for idx, req_id in enumerate(decode_req_ids):
req_index = self.input_batch.req_ids.index(req_id)
start_offset = int(self.query_start_loc.cpu[req_index])
inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1]
update_dict = {out_key: code_predictor_codes_cpu[idx : idx + 1]}
update_dict = {out_key: code_predictor_codes[idx : idx + 1]}
Comment thread
DomBrown marked this conversation as resolved.
self._merge_additional_information_update(req_id, update_dict)

def _model_forward(
Expand Down
Loading