Merged
52 changes: 28 additions & 24 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
@@ -159,6 +159,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# suppress tokens by setting their probability to ~1e-9 (a finite, very small value)
self.suppressed_tokens = self._get_talker_suppressed_tokens()
self.requires_raw_input_tokens = True
# Keys that should stay on GPU in model_intermediate_buffer to avoid CPU↔GPU round-trips
self.gpu_resident_buffer_keys: set[str] = {
"last_talker_hidden",
"trailing_text_hidden",
"tts_pad_embed_projected",
Comment on lines +162 to +166
P2: Keep code_predictor_codes in the GPU-resident buffer set

_talker_mtp_forward() writes the decode result under the default talker_mtp_output_key (code_predictor_codes), but _update_intermediate_buffer() only skips the .to("cpu") round-trip for keys listed in gpu_resident_buffer_keys (vllm_omni/worker/gpu_model_runner.py:1349-1354). Because this set omits code_predictor_codes, every decode step still synchronizes on a device-to-host copy of the MTP output, so the advertised hot-path CPU round-trip elimination never actually applies to Qwen3-Omni's codec codes.
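For reference, a rough sketch of the gating this comment refers to (a hypothetical simplification, not the actual _update_intermediate_buffer code in vllm_omni/worker/gpu_model_runner.py):

import torch

def _update_intermediate_buffer_sketch(buffer: dict, update: dict, gpu_resident_keys: set[str]) -> None:
    # Hypothetical simplification of the behavior described above: tensors whose key
    # is listed in gpu_resident_keys stay on their current (GPU) device, while every
    # other tensor is detached and copied to the CPU, which forces a device-to-host
    # synchronization on each decode step.
    for key, value in update.items():
        if isinstance(value, torch.Tensor) and key not in gpu_resident_keys:
            value = value.detach().to("cpu").contiguous()
        buffer[key] = value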


Contributor Author

code_predictor_codes will not be read from gpu_resident_buffer_keys, so there is no need to add it to gpu_resident_buffer_keys.

}

elif self.model_stage == "code2wav":
self.enable_update_additional_information = True
@@ -223,14 +229,16 @@ def embed_multimodal(self, **kwargs):

# ==================== Forward Pass ====================
def _get_talker_suppressed_tokens(self):
return [
i
for i in range(
self.config.talker_config.text_config.vocab_size - 1024,
self.config.talker_config.text_config.vocab_size,
)
if i != self.config.talker_config.codec_eos_token_id
]
"""Return a boolean mask on GPU for suppressed token positions."""
vocab_size = self.config.talker_config.text_config.vocab_size
mask = torch.zeros(vocab_size, dtype=torch.bool)
start = vocab_size - 1024
eos_id = self.config.talker_config.codec_eos_token_id
for i in range(start, vocab_size):
if i != eos_id:
mask[i] = True
# Will be moved to the correct device on first use
return mask
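As a quick, self-contained sanity check (made-up vocab_size and eos_id, purely illustrative and not part of this diff), the boolean mask marks exactly the positions the removed list comprehension produced:

import torch

vocab_size = 4096        # hypothetical value for illustration only
eos_id = vocab_size - 3  # hypothetical codec EOS id falling inside the suppressed tail

# Old behavior: explicit list of suppressed token ids
old_indices = [i for i in range(vocab_size - 1024, vocab_size) if i != eos_id]

# New behavior: boolean mask over the vocabulary
mask = torch.zeros(vocab_size, dtype=torch.bool)
mask[vocab_size - 1024:] = True
mask[eos_id] = False

assert mask.nonzero(as_tuple=True)[0].tolist() == old_indices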

def get_mrope_input_positions(
self,
@@ -578,7 +586,7 @@ def talker_postprocess(self, hidden_states: torch.Tensor, **info_dict: object):
Postprocess the talker hidden states.
"""
update_dict = {}
update_dict["last_talker_hidden"] = hidden_states[-1, :].detach().to("cpu").contiguous()
update_dict["last_talker_hidden"] = hidden_states[-1, :].detach()
return update_dict

def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **info_dict: dict):
@@ -632,9 +640,11 @@ def talker_mtp(
if inputs_embeds.shape[-1] == 2048:
inputs_embeds = self.text_projection(inputs_embeds)
code_predictor_codes, summed_embeddings = self.talker.code_predictor_forward(
input_ids, inputs_embeds.clone(), last_talker_hidden=last_talker_hidden
input_ids, inputs_embeds, last_talker_hidden=last_talker_hidden
)
inputs_embeds = summed_embeddings.clone()
# summed_embeddings is [B, seq_len, H] (3D) while text_step is [B, H] (2D).
# Flatten to 2D first to avoid wrong broadcasting: [B,1,H]+[B,H] → [B,B,H]
inputs_embeds = summed_embeddings.reshape(-1, self.talker_config.text_config.hidden_size)
inputs_embeds = (inputs_embeds + text_step).reshape(-1, self.talker_config.text_config.hidden_size)
return inputs_embeds, code_predictor_codes.squeeze(-1)
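To make the broadcasting pitfall noted in the comment above concrete, here is a tiny standalone example with dummy shapes (B, H and the zero tensors are arbitrary stand-ins):

import torch

B, H = 2, 4
summed = torch.zeros(B, 1, H)   # stand-in for summed_embeddings: [B, seq_len=1, H]
text_step = torch.zeros(B, H)   # stand-in for text_step: [B, H]

# The naive add broadcasts [B, 1, H] + [B, H] -> [B, B, H], silently mixing batch entries.
assert (summed + text_step).shape == (B, B, H)

# Flattening to 2D first keeps per-sample alignment: [B, H] + [B, H] -> [B, H].
assert (summed.reshape(-1, H) + text_step).shape == (B, H)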

@@ -758,7 +768,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch
# compatible with old shape [1,S,D]
rem_tail = trailing_text_hidden.squeeze(0)
if rem_tail.shape[0] > 0:
update_dict["trailing_text_hidden"] = rem_tail.detach().to("cpu").contiguous()
update_dict["trailing_text_hidden"] = rem_tail.detach()
# Also persist projected tts_pad for decode fallback if needed
if isinstance(tts_pad_thinker, torch.Tensor):
pad_in = tts_pad_thinker
@@ -767,7 +777,7 @@
if pad_in.ndim == 1:
pad_in = pad_in.view(1, 1, -1)
pad_proj = self.talker.text_projection(pad_in.to(self._module_device(self.talker)))
update_dict["tts_pad_embed_projected"] = pad_proj.detach().to("cpu").contiguous()
update_dict["tts_pad_embed_projected"] = pad_proj.detach()
except Exception:
pass
self._talker_cache_thinker_decode_embeds(info_dict, update_dict)
@@ -926,7 +936,7 @@ def talker_preprocess_decode(
if isinstance(q_tail, torch.Tensor) and q_tail.numel() > 0:
use_vec = q_tail[0:1, :]
new_q_tail = (
q_tail[1:, :].detach().to("cpu").contiguous()
q_tail[1:, :].detach()
if q_tail.shape[0] > 1
else self.tts_pad_embed.to(input_embeds.device, dtype=input_embeds.dtype)
)
@@ -1122,16 +1132,10 @@ def compute_logits(
# implemented by assigning their logits to -1e9.

if getattr(self, "model_stage", None) == "talker" and isinstance(logits, torch.Tensor):
try:
logits_cpu = logits.cpu()
logits_cpu[:, self.suppressed_tokens] = -1e9
logits = logits_cpu.to(logits.device)
except Exception as e:
print(f"Error in logits suppression: {e}")
print(f"logits.shape: {logits.shape}")
print(f"self.suppressed_tokens: {self.suppressed_tokens}")
raise e
logits[:, self.suppressed_tokens] = -1e9
# Move mask to device once (lazy), then reuse every step
if self.suppressed_tokens.device != logits.device:
self.suppressed_tokens = self.suppressed_tokens.to(logits.device)
logits.masked_fill_(self.suppressed_tokens.unsqueeze(0), -1e9)
return logits
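A minimal, hedged illustration of the in-place suppression above, using dummy logits and a dummy mask (the sizes and the choice of suppressed ids are arbitrary):

import torch

vocab_size = 8
suppressed = torch.zeros(vocab_size, dtype=torch.bool)
suppressed[-2:] = True  # pretend the last two ids are the suppressed codec tokens

logits = torch.randn(3, vocab_size)                 # [batch, vocab], stays on its device
logits.masked_fill_(suppressed.unsqueeze(0), -1e9)  # mask broadcasts over the batch dim

assert (logits[:, -2:] == -1e9).all()
assert (logits[:, :-2] > -1e9).all()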

def sample(