Skip to content
50 changes: 26 additions & 24 deletions vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# suppress tokens by setting their logits to a large negative finite value (-1e9) rather than -inf
self.suppressed_tokens = self._get_talker_suppressed_tokens()
self.requires_raw_input_tokens = True
# Keys that should stay on GPU in model_intermediate_buffer to avoid CPU↔GPU round-trips
self.gpu_resident_buffer_keys: set[str] = {
"last_talker_hidden",
"trailing_text_hidden",
"tts_pad_embed_projected",
}

elif self.model_stage == "code2wav":
self.enable_update_additional_information = True
Expand Down Expand Up @@ -223,14 +229,16 @@ def embed_multimodal(self, **kwargs):

# ==================== Forward Pass ====================
def _get_talker_suppressed_tokens(self):
    """Build the boolean mask of talker vocabulary ids to suppress.

    The last 1024 ids of the talker text vocabulary are marked ``True``
    (suppressed), except ``codec_eos_token_id``, which is left ``False`` so
    that end-of-sequence can still be produced.

    Returns:
        A 1-D ``torch.bool`` tensor of length ``vocab_size``. No device is
        passed to ``torch.zeros``, so the mask lives on torch's default
        device; the consumer is expected to move it next to the logits
        before applying it.
    """
    talker_cfg = self.config.talker_config
    vocab_size = talker_cfg.text_config.vocab_size
    eos_id = talker_cfg.codec_eos_token_id

    # Clamp so a vocabulary smaller than 1024 entries cannot produce
    # negative (wrap-around) indices.
    start = max(vocab_size - 1024, 0)

    mask = torch.zeros(vocab_size, dtype=torch.bool)
    # One slice assignment instead of a per-element Python loop.
    mask[start:] = True
    # Only clear the EOS slot when it actually falls inside the suppressed
    # tail; an id outside that range is already False (or not indexable).
    if eos_id is not None and start <= eos_id < vocab_size:
        mask[eos_id] = False
    return mask

def get_mrope_input_positions(
self,
Expand Down Expand Up @@ -578,7 +586,7 @@ def talker_postprocess(self, hidden_states: torch.Tensor, **info_dict: object):
Postprocess the talker hidden states.
"""
update_dict = {}
update_dict["last_talker_hidden"] = hidden_states[-1, :].detach().to("cpu").contiguous()
update_dict["last_talker_hidden"] = hidden_states[-1, :].detach()
return update_dict

def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **info_dict: dict):
Expand Down Expand Up @@ -632,9 +640,9 @@ def talker_mtp(
if inputs_embeds.shape[-1] == 2048:
inputs_embeds = self.text_projection(inputs_embeds)
code_predictor_codes, summed_embeddings = self.talker.code_predictor_forward(
input_ids, inputs_embeds.clone(), last_talker_hidden=last_talker_hidden
input_ids, inputs_embeds, last_talker_hidden=last_talker_hidden
)
inputs_embeds = summed_embeddings.clone()
inputs_embeds = summed_embeddings
inputs_embeds = (inputs_embeds + text_step).reshape(-1, self.talker_config.text_config.hidden_size)
return inputs_embeds, code_predictor_codes.squeeze(-1)

Expand Down Expand Up @@ -758,7 +766,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch
# compatible with old shape [1,S,D]
rem_tail = trailing_text_hidden.squeeze(0)
if rem_tail.shape[0] > 0:
update_dict["trailing_text_hidden"] = rem_tail.detach().to("cpu").contiguous()
update_dict["trailing_text_hidden"] = rem_tail.detach()
# Also persist projected tts_pad for decode fallback if needed
if isinstance(tts_pad_thinker, torch.Tensor):
pad_in = tts_pad_thinker
Expand All @@ -767,7 +775,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch
if pad_in.ndim == 1:
pad_in = pad_in.view(1, 1, -1)
pad_proj = self.talker.text_projection(pad_in.to(self._module_device(self.talker)))
update_dict["tts_pad_embed_projected"] = pad_proj.detach().to("cpu").contiguous()
update_dict["tts_pad_embed_projected"] = pad_proj.detach()
except Exception:
pass
self._talker_cache_thinker_decode_embeds(info_dict, update_dict)
Expand Down Expand Up @@ -926,7 +934,7 @@ def talker_preprocess_decode(
if isinstance(q_tail, torch.Tensor) and q_tail.numel() > 0:
use_vec = q_tail[0:1, :]
new_q_tail = (
q_tail[1:, :].detach().to("cpu").contiguous()
q_tail[1:, :].detach()
if q_tail.shape[0] > 1
else self.tts_pad_embed.to(input_embeds.device, dtype=input_embeds.dtype)
)
Expand Down Expand Up @@ -1122,16 +1130,10 @@ def compute_logits(
# implemented by assigning their logits to -1e9 (a large negative finite value).

if getattr(self, "model_stage", None) == "talker" and isinstance(logits, torch.Tensor):
try:
logits_cpu = logits.cpu()
logits_cpu[:, self.suppressed_tokens] = -1e9
logits = logits_cpu.to(logits.device)
except Exception as e:
print(f"Error in logits suppression: {e}")
print(f"logits.shape: {logits.shape}")
print(f"self.suppressed_tokens: {self.suppressed_tokens}")
raise e
logits[:, self.suppressed_tokens] = -1e9
# Move mask to device once (lazy), then reuse every step
if self.suppressed_tokens.device != logits.device:
self.suppressed_tokens = self.suppressed_tokens.to(logits.device)
logits.masked_fill_(self.suppressed_tokens.unsqueeze(0), -1e9)
return logits

def sample(
Expand Down
Loading