Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 24 additions & 26 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# suppress tokens by setting their probability to ~1e-9 (finite very small)
self.suppressed_tokens = self._get_talker_suppressed_tokens()
self.requires_raw_input_tokens = True
# Keys that should stay on GPU in model_intermediate_buffer to avoid CPU↔GPU round-trips
self.gpu_resident_buffer_keys: set[str] = {
"last_talker_hidden",
"trailing_text_hidden",
"tts_pad_embed_projected",
}

elif self.model_stage == "code2wav":
self.enable_update_additional_information = True
Expand Down Expand Up @@ -229,16 +223,14 @@ def embed_multimodal(self, **kwargs):

# ==================== Forward Pass ====================
def _get_talker_suppressed_tokens(self):
"""Return a boolean mask on GPU for suppressed token positions."""
vocab_size = self.config.talker_config.text_config.vocab_size
mask = torch.zeros(vocab_size, dtype=torch.bool)
start = vocab_size - 1024
eos_id = self.config.talker_config.codec_eos_token_id
for i in range(start, vocab_size):
if i != eos_id:
mask[i] = True
# Will be moved to the correct device on first use
return mask
return [
i
for i in range(
self.config.talker_config.text_config.vocab_size - 1024,
self.config.talker_config.text_config.vocab_size,
)
if i != self.config.talker_config.codec_eos_token_id
]

def get_mrope_input_positions(
self,
Expand Down Expand Up @@ -586,7 +578,7 @@ def talker_postprocess(self, hidden_states: torch.Tensor, **info_dict: object):
Postprocess the talker hidden states.
"""
update_dict = {}
update_dict["last_talker_hidden"] = hidden_states[-1, :].detach()
update_dict["last_talker_hidden"] = hidden_states[-1, :].detach().to("cpu").contiguous()
return update_dict

def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **info_dict: dict):
Expand Down Expand Up @@ -640,9 +632,9 @@ def talker_mtp(
if inputs_embeds.shape[-1] == 2048:
inputs_embeds = self.text_projection(inputs_embeds)
code_predictor_codes, summed_embeddings = self.talker.code_predictor_forward(
input_ids, inputs_embeds, last_talker_hidden=last_talker_hidden
input_ids, inputs_embeds.clone(), last_talker_hidden=last_talker_hidden
)
inputs_embeds = summed_embeddings
inputs_embeds = summed_embeddings.clone()
inputs_embeds = (inputs_embeds + text_step).reshape(-1, self.talker_config.text_config.hidden_size)
return inputs_embeds, code_predictor_codes.squeeze(-1)

Expand Down Expand Up @@ -766,7 +758,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch
# compatible with old shape [1,S,D]
rem_tail = trailing_text_hidden.squeeze(0)
if rem_tail.shape[0] > 0:
update_dict["trailing_text_hidden"] = rem_tail.detach()
update_dict["trailing_text_hidden"] = rem_tail.detach().to("cpu").contiguous()
# Also persist projected tts_pad for decode fallback if needed
if isinstance(tts_pad_thinker, torch.Tensor):
pad_in = tts_pad_thinker
Expand All @@ -775,7 +767,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch
if pad_in.ndim == 1:
pad_in = pad_in.view(1, 1, -1)
pad_proj = self.talker.text_projection(pad_in.to(self._module_device(self.talker)))
update_dict["tts_pad_embed_projected"] = pad_proj.detach()
update_dict["tts_pad_embed_projected"] = pad_proj.detach().to("cpu").contiguous()
except Exception:
pass
self._talker_cache_thinker_decode_embeds(info_dict, update_dict)
Expand Down Expand Up @@ -934,7 +926,7 @@ def talker_preprocess_decode(
if isinstance(q_tail, torch.Tensor) and q_tail.numel() > 0:
use_vec = q_tail[0:1, :]
new_q_tail = (
q_tail[1:, :].detach()
q_tail[1:, :].detach().to("cpu").contiguous()
if q_tail.shape[0] > 1
else self.tts_pad_embed.to(input_embeds.device, dtype=input_embeds.dtype)
)
Expand Down Expand Up @@ -1130,10 +1122,16 @@ def compute_logits(
# implemented by assigning their logits to log(1e-9).

if getattr(self, "model_stage", None) == "talker" and isinstance(logits, torch.Tensor):
# Move mask to device once (lazy), then reuse every step
if self.suppressed_tokens.device != logits.device:
self.suppressed_tokens = self.suppressed_tokens.to(logits.device)
logits.masked_fill_(self.suppressed_tokens.unsqueeze(0), -1e9)
try:
logits_cpu = logits.cpu()
logits_cpu[:, self.suppressed_tokens] = -1e9
logits = logits_cpu.to(logits.device)
Comment on lines +1126 to +1128
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Remove CPU copy from talker logits suppression

This block copies every talker logits tensor to CPU and back on each decode step, then applies the same suppression again on GPU. The extra host round-trip (logits.cpu() + .to(...)) forces synchronization and moves full-vocab logits across PCIe/host memory, causing a significant throughput regression in production decoding while providing no correctness benefit beyond the following in-place GPU assignment.

Useful? React with 👍 / 👎.

except Exception as e:
print(f"Error in logits suppression: {e}")
print(f"logits.shape: {logits.shape}")
print(f"self.suppressed_tokens: {self.suppressed_tokens}")
raise e
logits[:, self.suppressed_tokens] = -1e9
return logits

def sample(
Expand Down
Loading
Loading